Improve libraries import

This commit is contained in:
unclecode
2024-05-13 02:46:35 +08:00
parent 11393183f7
commit 5fea6c064b
5 changed files with 231 additions and 125 deletions

View File

@@ -1,9 +1,9 @@
from abc import ABC, abstractmethod
import re
import spacy
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, TextTilingTokenizer
# spacy = lazy_import.lazy_module('spacy')
# nl = lazy_import.lazy_module('nltk')
# from nltk.corpus import stopwords
# from nltk.tokenize import word_tokenize, TextTilingTokenizer
from collections import Counter
import string
@@ -34,8 +34,10 @@ class RegexChunking(ChunkingStrategy):
return paragraphs
# NLP-based sentence chunking using spaCy
class NlpSentenceChunking(ChunkingStrategy):
def __init__(self, model='en_core_web_sm'):
import spacy
self.nlp = spacy.load(model)
def chunk(self, text: str) -> list:
@@ -44,8 +46,10 @@ class NlpSentenceChunking(ChunkingStrategy):
# Topic-based segmentation using TextTiling
class TopicSegmentationChunking(ChunkingStrategy):
def __init__(self, num_keywords=3):
self.tokenizer = TextTilingTokenizer()
import nltk as nl
self.tokenizer = nl.toknize.TextTilingTokenizer()
self.num_keywords = num_keywords
def chunk(self, text: str) -> list:
@@ -55,8 +59,9 @@ class TopicSegmentationChunking(ChunkingStrategy):
def extract_keywords(self, text: str) -> list:
# Tokenize and remove stopwords and punctuation
tokens = word_tokenize(text)
tokens = [token.lower() for token in tokens if token not in stopwords.words('english') and token not in string.punctuation]
import nltk as nl
tokens = nl.toknize.word_tokenize(text)
tokens = [token.lower() for token in tokens if token not in nl.corpus.stopwords.words('english') and token not in string.punctuation]
# Calculate frequency distribution
freq_dist = Counter(tokens)

View File

@@ -1,20 +1,12 @@
from abc import ABC, abstractmethod
from typing import Any, List, Dict, Optional, Union
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import pdist
from transformers import BertTokenizer, BertModel, pipeline
from transformers import AutoTokenizer, AutoModel
from concurrent.futures import ThreadPoolExecutor, as_completed
import nltk
from nltk.tokenize import TextTilingTokenizer
import json, time
import torch
import spacy
# from optimum.intel import IPEXModel
from .prompts import PROMPT_EXTRACT_BLOCKS
from .config import *
from .utils import *
from functools import partial
class ExtractionStrategy(ABC):
"""
@@ -50,23 +42,32 @@ class ExtractionStrategy(ABC):
parsed_json.extend(future.result())
return parsed_json
class NoExtractionStrategy(ExtractionStrategy):
def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
return [{"index": 0, "content": html}]
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
return [{"index": i, "tags": [], "content": section} for i, section in enumerate(sections)]
class LLMExtractionStrategy(ExtractionStrategy):
def __init__(self, provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None):
"""
Initialize the strategy with clustering parameters.
"""
Initialize the strategy with clustering parameters.
:param word_count_threshold: Minimum number of words per cluster.
:param max_dist: The maximum cophenetic distance on the dendrogram to form clusters.
:param linkage_method: The linkage method for hierarchical clustering.
"""
super().__init__()
self.provider = provider
self.api_token = api_token
def extract(self, url: str, html: str, provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None) -> List[Dict[str, Any]]:
api_token = PROVIDER_MODELS.get(provider, None) if not api_token else api_token
:param word_count_threshold: Minimum number of words per cluster.
:param max_dist: The maximum cophenetic distance on the dendrogram to form clusters.
:param linkage_method: The linkage method for hierarchical clustering.
"""
super().__init__()
self.provider = provider
self.api_token = api_token or PROVIDER_MODELS.get(provider, None) or os.getenv("OPENAI_API_KEY")
if not self.api_token:
raise ValueError("API token must be provided for LLMExtractionStrategy. Update the config.py or set OPENAI_API_KEY environment variable.")
def extract(self, url: str, html: str) -> List[Dict[str, Any]]:
print("Extracting blocks ...")
variable_values = {
"URL": url,
"HTML": escape_json_string(sanitize_html(html)),
@@ -78,7 +79,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
"{" + variable + "}", variable_values[variable]
)
response = perform_completion_with_backoff(provider, prompt_with_variables, api_token)
response = perform_completion_with_backoff(self.provider, prompt_with_variables, self.api_token)
try:
blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks']
@@ -96,28 +97,54 @@ class LLMExtractionStrategy(ExtractionStrategy):
"tags": ["error"],
"content": unparsed
})
print("Extracted", len(blocks), "blocks.")
return blocks
def run(self, url: str, sections: List[str], provider: str, api_token: Optional[str]) -> List[Dict[str, Any]]:
def _merge(self, documents):
chunks = []
sections = []
total_token_so_far = 0
for document in documents:
if total_token_so_far < CHUNK_TOKEN_THRESHOLD:
chunk = document.split(' ')
total_token_so_far += len(chunk) * 1.3
chunks.append(document)
else:
sections.append('\n\n'.join(chunks))
chunks = [document]
total_token_so_far = len(document.split(' ')) * 1.3
if chunks:
sections.append('\n\n'.join(chunks))
return sections
def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]:
"""
Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy.
"""
merged_sections = self._merge(sections)
parsed_json = []
if provider.startswith("groq/"):
if self.provider.startswith("groq/"):
# Sequential processing with a delay
for section in sections:
parsed_json.extend(self.extract(url, section, provider, api_token))
for section in merged_sections:
parsed_json.extend(self.extract(url, section))
time.sleep(0.5) # 500 ms delay between each processing
else:
# Parallel processing using ThreadPoolExecutor
with ThreadPoolExecutor() as executor:
futures = [executor.submit(self.extract, url, section, provider, api_token) for section in sections]
with ThreadPoolExecutor(max_workers=4) as executor:
extract_func = partial(self.extract, url)
futures = [executor.submit(extract_func, section) for section in merged_sections]
for future in as_completed(futures):
parsed_json.extend(future.result())
return parsed_json
class HierarchicalClusteringStrategy(ExtractionStrategy):
class CosinegStrategy(ExtractionStrategy):
def __init__(self, word_count_threshold=20, max_dist=0.2, linkage_method='ward', top_k=3, model_name = 'BAAI/bge-small-en-v1.5'):
"""
Initialize the strategy with clustering parameters.
@@ -128,6 +155,10 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
:param top_k: Number of top categories to extract.
"""
super().__init__()
from transformers import BertTokenizer, BertModel, pipeline
from transformers import AutoTokenizer, AutoModel
import spacy
self.word_count_threshold = word_count_threshold
self.max_dist = max_dist
self.linkage_method = linkage_method
@@ -156,6 +187,7 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
:param sentences: List of text chunks (sentences).
:return: NumPy array of embeddings.
"""
import torch
# Tokenize sentences and convert to tensor
encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
# Compute token embeddings
@@ -174,9 +206,11 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
:return: NumPy array of cluster labels.
"""
# Get embeddings
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import pdist
self.timer = time.time()
embeddings = self.get_embeddings(sentences)
print(f"[LOG] 🚀 Embeddings computed in {time.time() - self.timer:.2f} seconds")
# print(f"[LOG] 🚀 Embeddings computed in {time.time() - self.timer:.2f} seconds")
# Compute pairwise cosine distances
distance_matrix = pdist(embeddings, 'cosine')
# Perform agglomerative clustering respecting order
@@ -219,7 +253,7 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
# Perform clustering
labels = self.hierarchical_clustering(text_chunks)
print(f"[LOG] 🚀 Clustering done in {time.time() - t:.2f} seconds")
# print(f"[LOG] 🚀 Clustering done in {time.time() - t:.2f} seconds")
# Organize texts by their cluster labels, retaining order
t = time.time()
@@ -240,7 +274,7 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
cluster['tags'] = [cat for cat, _ in top_categories]
print(f"[LOG] 🚀 Categorization done in {time.time() - t:.2f} seconds")
# print(f"[LOG] 🚀 Categorization done in {time.time() - t:.2f} seconds")
return cluster_list
@@ -265,9 +299,10 @@ class TopicExtractionStrategy(ExtractionStrategy):
:param num_keywords: Number of keywords to represent each topic segment.
"""
import nltk
super().__init__()
self.num_keywords = num_keywords
self.tokenizer = TextTilingTokenizer()
self.tokenizer = nltk.TextTilingTokenizer()
def extract_keywords(self, text: str) -> List[str]:
"""
@@ -276,6 +311,7 @@ class TopicExtractionStrategy(ExtractionStrategy):
:param text: The text segment from which to extract keywords.
:return: A list of keyword strings.
"""
import nltk
# Tokenize the text and compute word frequency
words = nltk.word_tokenize(text)
freq_dist = nltk.FreqDist(words)

View File

@@ -3,15 +3,12 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
from bs4 import BeautifulSoup, Comment, element, Tag, NavigableString
import html2text
import json
import html
import re
import os
import litellm
from litellm import completion, batch_completion
from html2text import HTML2Text
from .prompts import PROMPT_EXTRACT_BLOCKS
from .config import *
import re
import html
from html2text import HTML2Text
def beautify_html(escaped_html):
@@ -303,17 +300,16 @@ def extract_xml_data(tags, string):
return data
import time
import litellm
# Function to perform the completion with exponential backoff
def perform_completion_with_backoff(provider, prompt_with_variables, api_token):
from litellm import completion
from litellm.exceptions import RateLimitError
max_attempts = 3
base_delay = 2 # Base delay in seconds, you can adjust this based on your needs
for attempt in range(max_attempts):
try:
response = completion(
response =completion(
model=provider,
messages=[
{"role": "user", "content": prompt_with_variables}
@@ -322,7 +318,7 @@ def perform_completion_with_backoff(provider, prompt_with_variables, api_token):
api_key=api_token
)
return response # Return the successful response
except litellm.exceptions.RateLimitError as e:
except RateLimitError as e:
print("Rate limit error:", str(e))
# Check if we have exhausted our max attempts
@@ -378,7 +374,7 @@ def extract_blocks(url, html, provider = DEFAULT_PROVIDER, api_token = None):
def extract_blocks_batch(batch_data, provider = "groq/llama3-70b-8192", api_token = None):
api_token = os.getenv('GROQ_API_KEY', None) if not api_token else api_token
from litellm import batch_completion
messages = []
for url, html in batch_data:

View File

@@ -9,49 +9,93 @@ from .extraction_strategy import *
from .crawler_strategy import *
from typing import List
from concurrent.futures import ThreadPoolExecutor
from .config import *
from .config import *
class WebCrawler:
def __init__(self, db_path: str, crawler_strategy: CrawlerStrategy = LocalSeleniumCrawlerStrategy()):
def __init__(
self,
db_path: str = None,
crawler_strategy: CrawlerStrategy = LocalSeleniumCrawlerStrategy(),
):
self.db_path = db_path
init_db(self.db_path)
self.crawler_strategy = crawler_strategy
# Create the .crawl4ai folder in the user's home directory if it doesn't exist
self.crawl4ai_folder = os.path.join(Path.home(), ".crawl4ai")
os.makedirs(self.crawl4ai_folder, exist_ok=True)
os.makedirs(self.crawl4ai_folder, exist_ok=True)
os.makedirs(f"{self.crawl4ai_folder}/cache", exist_ok=True)
def fetch_page(self,
url_model: UrlModel,
provider: str = DEFAULT_PROVIDER,
api_token: str = None,
extract_blocks_flag: bool = True,
word_count_threshold = MIN_WORD_THRESHOLD,
use_cached_html: bool = False,
extraction_strategy: ExtractionStrategy = LLMExtractionStrategy(),
chunking_strategy: ChunkingStrategy = RegexChunking(),
**kwargs
) -> CrawlResult:
# If db_path is not provided, use the default path
if not db_path:
self.db_path = f"{self.crawl4ai_folder}/crawl4ai.db"
init_db(self.db_path)
self.ready = False
def warmup(self):
print("[LOG] 🌤️ Warming up the WebCrawler")
single_url = UrlModel(url='https://crawl4ai.uccode.io/', forced=False)
result = self.run(
single_url,
word_count_threshold=5,
extraction_strategy= CosinegStrategy(),
verbose = False
)
self.ready = True
print("[LOG] 🌞 WebCrawler is ready to crawl")
def fetch_page(
self,
url_model: UrlModel,
provider: str = DEFAULT_PROVIDER,
api_token: str = None,
extract_blocks_flag: bool = True,
word_count_threshold=MIN_WORD_THRESHOLD,
use_cached_html: bool = False,
extraction_strategy: ExtractionStrategy = LLMExtractionStrategy(),
chunking_strategy: ChunkingStrategy = RegexChunking(),
**kwargs,
) -> CrawlResult:
return self.run(
url_model,
word_count_threshold,
extraction_strategy,
chunking_strategy,
**kwargs,
)
pass
def run(
self,
url_model: UrlModel,
word_count_threshold=MIN_WORD_THRESHOLD,
extraction_strategy: ExtractionStrategy = NoExtractionStrategy(),
chunking_strategy: ChunkingStrategy = RegexChunking(),
verbose=True,
**kwargs,
) -> CrawlResult:
# make sure word_count_threshold is not lesser than MIN_WORD_THRESHOLD
if word_count_threshold < MIN_WORD_THRESHOLD:
word_count_threshold = MIN_WORD_THRESHOLD
# Check cache first
cached = get_cached_url(self.db_path, str(url_model.url))
if cached and not url_model.forced:
return CrawlResult(**{
"url": cached[0],
"html": cached[1],
"cleaned_html": cached[2],
"markdown": cached[3],
"parsed_json": cached[4],
"success": cached[5],
"error_message": ""
})
return CrawlResult(
**{
"url": cached[0],
"html": cached[1],
"cleaned_html": cached[2],
"markdown": cached[3],
"parsed_json": cached[4],
"success": cached[5],
"error_message": "",
}
)
# Initialize WebDriver for crawling
t = time.time()
@@ -62,65 +106,89 @@ class WebCrawler:
except Exception as e:
html = ""
success = False
error_message = str(e)
error_message = str(e)
# Extract content from HTML
result = get_content_of_website(html, word_count_threshold)
cleaned_html = result.get('cleaned_html', html)
markdown = result.get('markdown', "")
cleaned_html = result.get("cleaned_html", html)
markdown = result.get("markdown", "")
# Print a profession LOG style message, show time taken and say crawling is done
print(f"[LOG] 🚀 Crawling done for {url_model.url}, success: {success}, time taken: {time.time() - t} seconds")
if verbose:
print(
f"[LOG] 🚀 Crawling done for {url_model.url}, success: {success}, time taken: {time.time() - t} seconds"
)
parsed_json = []
if extract_blocks_flag:
if verbose:
print(f"[LOG] 🔥 Extracting semantic blocks for {url_model.url}")
t = time.time()
# Split markdown into sections
sections = chunking_strategy.chunk(markdown)
# sections = merge_chunks_based_on_token_threshold(sections, CHUNK_TOKEN_THRESHOLD)
t = time.time()
# Split markdown into sections
sections = chunking_strategy.chunk(markdown)
# sections = merge_chunks_based_on_token_threshold(sections, CHUNK_TOKEN_THRESHOLD)
parsed_json = extraction_strategy.run(str(url_model.url), sections, provider, api_token)
parsed_json = json.dumps(parsed_json)
print(f"[LOG] 🚀 Extraction done for {url_model.url}, time taken: {time.time() - t} seconds.")
else:
parsed_json = "{}"
print(f"[LOG] 🚀 Skipping extraction for {url_model.url}")
parsed_json = extraction_strategy.run(
str(url_model.url), sections,
)
parsed_json = json.dumps(parsed_json)
if verbose:
print(
f"[LOG] 🚀 Extraction done for {url_model.url}, time taken: {time.time() - t} seconds."
)
# Cache the result
cleaned_html = beautify_html(cleaned_html)
cache_url(self.db_path, str(url_model.url), html, cleaned_html, markdown, parsed_json, success)
return CrawlResult(
url=str(url_model.url),
html=html,
cleaned_html=cleaned_html,
markdown=markdown,
parsed_json=parsed_json,
success=success,
error_message=error_message
cache_url(
self.db_path,
str(url_model.url),
html,
cleaned_html,
markdown,
parsed_json,
success,
)
def fetch_pages(self, url_models: List[UrlModel], provider: str = DEFAULT_PROVIDER, api_token: str = None,
extract_blocks_flag: bool = True, word_count_threshold=MIN_WORD_THRESHOLD,
use_cached_html: bool = False, extraction_strategy: ExtractionStrategy = LLMExtractionStrategy(),
chunking_strategy: ChunkingStrategy = RegexChunking(), **kwargs) -> List[CrawlResult]:
return CrawlResult(
url=str(url_model.url),
html=html,
cleaned_html=cleaned_html,
markdown=markdown,
parsed_json=parsed_json,
success=success,
error_message=error_message,
)
def fetch_pages(
self,
url_models: List[UrlModel],
provider: str = DEFAULT_PROVIDER,
api_token: str = None,
extract_blocks_flag: bool = True,
word_count_threshold=MIN_WORD_THRESHOLD,
use_cached_html: bool = False,
extraction_strategy: ExtractionStrategy = LLMExtractionStrategy(),
chunking_strategy: ChunkingStrategy = RegexChunking(),
**kwargs,
) -> List[CrawlResult]:
def fetch_page_wrapper(url_model, *args, **kwargs):
return self.fetch_page(url_model, *args, **kwargs)
with ThreadPoolExecutor() as executor:
results = list(executor.map(fetch_page_wrapper, url_models,
[provider] * len(url_models),
[api_token] * len(url_models),
[extract_blocks_flag] * len(url_models),
[word_count_threshold] * len(url_models),
[use_cached_html] * len(url_models),
[extraction_strategy] * len(url_models),
[chunking_strategy] * len(url_models),
*[kwargs] * len(url_models)))
results = list(
executor.map(
fetch_page_wrapper,
url_models,
[provider] * len(url_models),
[api_token] * len(url_models),
[extract_blocks_flag] * len(url_models),
[word_count_threshold] * len(url_models),
[use_cached_html] * len(url_models),
[extraction_strategy] * len(url_models),
[chunking_strategy] * len(url_models),
*[kwargs] * len(url_models),
)
)
return results
return results

View File

@@ -12,4 +12,5 @@ html2text
litellm
python-dotenv
nltk
lazy_import
# spacy