Update:
- Text Categorization - Crawler, Extraction, and Chunking strategies - Clustering for semantic segmentation
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -165,4 +165,6 @@ Crawl4AI.egg-info/
|
|||||||
Crawl4AI.egg-info/*
|
Crawl4AI.egg-info/*
|
||||||
crawler_data.db
|
crawler_data.db
|
||||||
.vscode/
|
.vscode/
|
||||||
test_pad.py
|
test_pad.py
|
||||||
|
.data/
|
||||||
|
Crawl4AI.egg-info/
|
||||||
95
crawl4ai/chunking_strategy.py
Normal file
95
crawl4ai/chunking_strategy.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import re
|
||||||
|
import spacy
|
||||||
|
import nltk
|
||||||
|
from nltk.corpus import stopwords
|
||||||
|
from nltk.tokenize import word_tokenize, TextTilingTokenizer
|
||||||
|
from collections import Counter
|
||||||
|
import string
|
||||||
|
|
||||||
|
# Abstract base class that every chunking strategy implements.
class ChunkingStrategy(ABC):
    """Interface for splitting a text into a list of chunks."""

    @abstractmethod
    def chunk(self, text: str) -> list:
        """
        Split *text* into chunks and return them as a list.
        """
        ...
|
||||||
|
|
||||||
|
# Regex-based chunking
class RegexChunking(ChunkingStrategy):
    """Chunk text by successively splitting on a list of regex patterns."""

    def __init__(self, patterns=None):
        # Default to splitting on blank lines (paragraph boundaries).
        self.patterns = [r'\n\n'] if patterns is None else patterns

    def chunk(self, text: str) -> list:
        """Apply each pattern in turn, re-splitting every piece produced so far."""
        pieces = [text]
        for pattern in self.patterns:
            pieces = [part for piece in pieces for part in re.split(pattern, piece)]
        return pieces
|
||||||
|
|
||||||
|
# NLP-based sentence chunking using spaCy
class NlpSentenceChunking(ChunkingStrategy):
    """Chunk text into individual sentences with a spaCy language model."""

    def __init__(self, model='en_core_web_sm'):
        # Loading the pipeline is expensive; do it once per instance.
        self.nlp = spacy.load(model)

    def chunk(self, text: str) -> list:
        """Return the stripped text of every sentence spaCy detects."""
        return [sentence.text.strip() for sentence in self.nlp(text).sents]
|
||||||
|
|
||||||
|
# Topic-based segmentation using TextTiling
class TopicSegmentationChunking(ChunkingStrategy):
    """Segment text into topical blocks via NLTK's TextTiling algorithm."""

    def __init__(self, num_keywords=3):
        self.tokenizer = TextTilingTokenizer()
        # How many keywords to report for each segment.
        self.num_keywords = num_keywords

    def chunk(self, text: str) -> list:
        """Return the topic segments produced by TextTiling."""
        return self.tokenizer.tokenize(text)

    def extract_keywords(self, text: str) -> list:
        """Return the most frequent non-stopword, non-punctuation tokens."""
        # Hoist the stopword list out of the filter loop; output is lowercased
        # while filtering happens on the raw token, exactly as before.
        stop_words = stopwords.words('english')
        candidates = [
            token.lower()
            for token in word_tokenize(text)
            if token not in stop_words and token not in string.punctuation
        ]
        counts = Counter(candidates)
        return [word for word, _ in counts.most_common(self.num_keywords)]

    def chunk_with_topics(self, text: str) -> list:
        """Return (segment, keywords) pairs for each topic segment."""
        return [(segment, self.extract_keywords(segment)) for segment in self.chunk(text)]
|
||||||
|
|
||||||
|
# Fixed-length word chunks
class FixedLengthWordChunking(ChunkingStrategy):
    """Chunk text into consecutive groups of at most `chunk_size` words."""

    def __init__(self, chunk_size=100):
        self.chunk_size = chunk_size

    def chunk(self, text: str) -> list:
        words = text.split()
        chunks = []
        # Step through the word list in fixed-size strides; the final chunk
        # may be shorter than chunk_size.
        for start in range(0, len(words), self.chunk_size):
            chunks.append(' '.join(words[start:start + self.chunk_size]))
        return chunks
|
||||||
|
|
||||||
|
# Sliding window chunking
class SlidingWindowChunking(ChunkingStrategy):
    """Chunk text into overlapping word windows of `window_size`, advancing
    by `step` words each time."""

    def __init__(self, window_size=100, step=50):
        self.window_size = window_size
        self.step = step

    def chunk(self, text: str) -> list:
        words = text.split()
        # Windows near the end of the text may be shorter than window_size.
        return [
            ' '.join(words[start:start + self.window_size])
            for start in range(0, len(words), self.step)
        ]
|
||||||
|
|
||||||
|
|
||||||
58
crawl4ai/crawler_strategy.py
Normal file
58
crawl4ai/crawler_strategy.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from selenium import webdriver
|
||||||
|
from selenium.webdriver.chrome.service import Service
|
||||||
|
from selenium.webdriver.common.by import By
|
||||||
|
from selenium.webdriver.support.ui import WebDriverWait
|
||||||
|
from selenium.webdriver.support import expected_conditions as EC
|
||||||
|
from selenium.webdriver.chrome.options import Options
|
||||||
|
import chromedriver_autoinstaller
|
||||||
|
from typing import List
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
class CrawlerStrategy(ABC):
    """Interface for fetching the HTML content of a URL."""

    @abstractmethod
    def crawl(self, url: str) -> str:
        """Fetch *url* and return its HTML as a string."""
        ...
|
||||||
|
|
||||||
|
class CloudCrawlerStrategy(CrawlerStrategy):
    """Crawl by delegating to the hosted crawl4ai crawling service."""

    def crawl(self, url: str) -> str:
        """POST the URL to the service and return the raw HTML of the result."""
        payload = {
            "urls": [url],
            "provider_model": "",
            "api_token": "token",
            "include_raw_html": True,
            "forced": True,
            "extract_blocks": False,
            "word_count_threshold": 10,
        }
        body = requests.post("http://crawl4ai.uccode.io/crawl", json=payload).json()
        # The service returns one result per requested URL; we sent exactly one.
        return body["results"][0]["html"]
|
||||||
|
|
||||||
|
class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
    """Crawl pages locally with a headless Chrome WebDriver."""

    def __init__(self):
        # Configure Chrome for headless, container-friendly operation.
        self.options = Options()
        self.options.headless = True
        self.options.add_argument("--no-sandbox")
        self.options.add_argument("--disable-dev-shm-usage")
        self.options.add_argument("--headless")

        # NOTE(review): chromedriver_autoinstaller.install() is invoked twice
        # (here and again inside Service(...) below); the first call looks redundant.
        chromedriver_autoinstaller.install()
        self.service = Service(chromedriver_autoinstaller.install())
        self.driver = webdriver.Chrome(service=self.service, options=self.options)

    def crawl(self, url: str, use_cached_html = False) -> str:
        """Navigate to *url* and return the rendered page source."""
        # NOTE(review): get_content_of_website() elsewhere in this codebase is
        # declared as get_content_of_website(html, word_count_threshold=...),
        # i.e. it expects HTML, yet a URL is passed here — confirm the intended
        # cache-lookup behavior before relying on use_cached_html=True.
        if use_cached_html:
            return get_content_of_website(url)
        self.driver.get(url)
        # Wait up to 10 s for the <html> element to exist before reading the DOM.
        WebDriverWait(self.driver, 10).until(
            EC.presence_of_all_elements_located((By.TAG_NAME, "html"))
        )
        html = self.driver.page_source
        return html

    def quit(self):
        # Shut down the underlying WebDriver/browser process.
        self.driver.quit()
|
||||||
358
crawl4ai/extraction_strategy.py
Normal file
358
crawl4ai/extraction_strategy.py
Normal file
@@ -0,0 +1,358 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, List, Dict, Optional, Union
|
||||||
|
from scipy.cluster.hierarchy import linkage, fcluster
|
||||||
|
from scipy.spatial.distance import pdist
|
||||||
|
from transformers import BertTokenizer, BertModel, pipeline
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
import nltk
|
||||||
|
from nltk.tokenize import TextTilingTokenizer
|
||||||
|
import json, time
|
||||||
|
import torch
|
||||||
|
import spacy
|
||||||
|
|
||||||
|
from .prompts import PROMPT_EXTRACT_BLOCKS
|
||||||
|
from .config import *
|
||||||
|
from .utils import *
|
||||||
|
|
||||||
|
class ExtractionStrategy(ABC):
    """
    Abstract base class for all extraction strategies.
    """

    def __init__(self):
        # Delimiter used to join/split sections handed between strategies.
        self.DEL = "<|DEL|>"

    @abstractmethod
    def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
        """
        Extract meaningful blocks or chunks from the given HTML.

        :param url: The URL of the webpage.
        :param html: The HTML content of the webpage.
        :return: A list of extracted blocks or chunks.
        """
        ...

    def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
        """
        Process sections of text in parallel by default.

        :param url: The URL of the webpage.
        :param sections: List of sections (strings) to process.
        :return: A list of processed JSON blocks.
        """
        results: List[Dict[str, Any]] = []
        with ThreadPoolExecutor() as pool:
            pending = [pool.submit(self.extract, url, section, **kwargs) for section in sections]
            # Results arrive in completion order, not submission order.
            for done in as_completed(pending):
                results.extend(done.result())
        return results
|
||||||
|
|
||||||
|
class LLMExtractionStrategy(ExtractionStrategy):
    """Extraction strategy that asks an LLM provider to pull content blocks
    out of sanitized HTML."""

    def __init__(self, provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None):
        """
        Initialize the strategy with an LLM provider and credentials.

        :param provider: Identifier of the completion provider/model.
        :param api_token: Optional API token; if omitted, a provider-specific
            token is looked up at extraction time.
        """
        super().__init__()
        self.provider = provider
        self.api_token = api_token

    def extract(self, url: str, html: str, provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None) -> List[Dict[str, Any]]:
        # NOTE(review): self.provider / self.api_token set in __init__ are not
        # consulted here — callers must pass provider/api_token explicitly.
        # Explicit token wins; otherwise fall back to the provider's configured one.
        api_token = PROVIDER_MODELS.get(provider, None) if not api_token else api_token

        variable_values = {
            "URL": url,
            "HTML": escape_json_string(sanitize_html(html)),
        }

        # Substitute {URL}/{HTML} placeholders into the extraction prompt.
        prompt_with_variables = PROMPT_EXTRACT_BLOCKS
        for variable in variable_values:
            prompt_with_variables = prompt_with_variables.replace(
                "{" + variable + "}", variable_values[variable]
            )

        response = perform_completion_with_backoff(provider, prompt_with_variables, api_token)

        try:
            # Happy path: the model wraps a JSON array in <blocks> XML tags.
            blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks']
            blocks = json.loads(blocks)
            for block in blocks:
                block['error'] = False
        except Exception as e:
            # Fallback: salvage whatever JSON objects can be parsed out of the
            # raw completion and record the unparsable remainder as an error block.
            print("Error extracting blocks:", str(e))
            parsed, unparsed = split_and_parse_json_objects(response.choices[0].message.content)
            blocks = parsed
            if unparsed:
                blocks.append({
                    "index": 0,
                    "error": True,
                    "tags": ["error"],
                    "content": unparsed
                })
        return blocks

    def run(self, url: str, sections: List[str], provider: str, api_token: Optional[str]) -> List[Dict[str, Any]]:
        """
        Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy.
        """
        parsed_json = []
        if provider.startswith("groq/"):
            # Sequential processing with a delay (groq endpoints are rate limited)
            for section in sections:
                parsed_json.extend(self.extract(url, section, provider, api_token))
                time.sleep(0.5)  # 500 ms delay between each processing
        else:
            # Parallel processing using ThreadPoolExecutor
            with ThreadPoolExecutor() as executor:
                futures = [executor.submit(self.extract, url, section, provider, api_token) for section in sections]
                for future in as_completed(futures):
                    parsed_json.extend(future.result())

        return parsed_json
|
||||||
|
|
||||||
|
class HierarchicalClusteringStrategy(ExtractionStrategy):
    """Cluster text chunks with BERT embeddings and hierarchical clustering,
    then tag each cluster using a spaCy text-categorization model."""

    def __init__(self, word_count_threshold=20, max_dist=0.2, linkage_method='ward', top_k=3):
        """
        Initialize the strategy with clustering parameters.

        :param word_count_threshold: Minimum number of words per cluster.
        :param max_dist: The maximum cophenetic distance on the dendrogram to form clusters.
        :param linkage_method: The linkage method for hierarchical clustering.
        :param top_k: Number of top categories to extract.
        """
        super().__init__()
        self.word_count_threshold = word_count_threshold
        self.max_dist = max_dist
        self.linkage_method = linkage_method
        self.top_k = top_k
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
        self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
        # NOTE(review): relative path — presumably the categorizer trained by
        # train.py; only resolves when run from the repo root. Confirm layout.
        self.nlp = spacy.load("models/reuters")

    def get_embeddings(self, sentences: List[str]):
        """
        Get BERT embeddings for a list of sentences.

        :param sentences: List of text chunks (sentences).
        :return: NumPy array of embeddings.
        """
        # Tokenize sentences and convert to tensor
        encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
        # Compute token embeddings without building autograd graphs
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        # Mean-pool the last hidden state over the token axis
        embeddings = model_output.last_hidden_state.mean(1)
        return embeddings.numpy()

    def hierarchical_clustering(self, sentences: List[str]):
        """
        Perform hierarchical clustering on sentences and return cluster labels.

        :param sentences: List of text chunks (sentences).
        :return: NumPy array of cluster labels.
        """
        # Get embeddings
        embeddings = self.get_embeddings(sentences)
        # Compute pairwise cosine distances
        distance_matrix = pdist(embeddings, 'cosine')
        # Build the dendrogram with the configured linkage method
        linked = linkage(distance_matrix, method=self.linkage_method)
        # Cut the dendrogram at max_dist to form flat clusters
        labels = fcluster(linked, self.max_dist, criterion='distance')
        return labels

    def filter_clusters_by_word_count(self, clusters: Dict[int, List[str]]):
        """
        Filter clusters to remove those with a word count below the threshold.

        :param clusters: Dictionary of clusters.
        :return: Filtered dictionary of clusters.
        """
        filtered_clusters = {}
        for cluster_id, texts in clusters.items():
            # Concatenate texts for analysis
            full_text = " ".join(texts)
            # Count words
            word_count = len(full_text.split())

            # Keep clusters with word count above the threshold
            if word_count >= self.word_count_threshold:
                filtered_clusters[cluster_id] = texts

        return filtered_clusters

    def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
        """
        Extract clusters from HTML content using hierarchical clustering.

        :param url: The URL of the webpage.
        :param html: DEL-joined text chunks (not raw HTML for this strategy).
        :return: A list of dictionaries representing the clusters.
        """
        # `html` is a DEL-joined list of text chunks for this strategy
        text_chunks = html.split(self.DEL)  # Split by lines or paragraphs as needed

        # Perform clustering
        labels = self.hierarchical_clustering(text_chunks)

        # Organize texts by their cluster labels, retaining order
        clusters = {}
        for index, label in enumerate(labels):
            clusters.setdefault(label, []).append(text_chunks[index])

        # Filter clusters by word count
        filtered_clusters = self.filter_clusters_by_word_count(clusters)

        # Convert filtered clusters to a sorted list of dictionaries
        cluster_list = [{"index": int(idx), "tags" : [], "content": " ".join(filtered_clusters[idx])} for idx in sorted(filtered_clusters)]

        # Tag each cluster with the spaCy categorizer's top-k categories
        for cluster in cluster_list:
            doc = self.nlp(cluster['content'])
            tok_k = self.top_k  # NOTE(review): "tok_k" looks like a typo for "top_k"
            top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
            cluster['tags'] = [cat for cat, _ in top_categories]

        return cluster_list

    def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
        """
        Process sections using hierarchical clustering.

        :param url: The URL of the webpage.
        :param sections: List of sections (strings) to process.
        :return: A list of processed JSON blocks.
        """
        # This strategy clusters all sections together, so join them with DEL
        return self.extract(url, self.DEL.join(sections), **kwargs)
|
||||||
|
|
||||||
|
class TopicExtractionStrategy(ExtractionStrategy):
    """Split content into topic segments and tag each with its most frequent tokens."""

    def __init__(self, num_keywords: int = 3):
        """
        Initialize the topic extraction strategy with parameters for topic segmentation.

        :param num_keywords: Number of keywords to represent each topic segment.
        """
        super().__init__()
        self.num_keywords = num_keywords
        self.tokenizer = TextTilingTokenizer()

    def extract_keywords(self, text: str) -> List[str]:
        """
        Extract keywords from a given text segment using simple frequency analysis.

        :param text: The text segment from which to extract keywords.
        :return: A list of keyword strings.
        """
        frequency = nltk.FreqDist(nltk.word_tokenize(text))
        return [token for token, _ in frequency.most_common(self.num_keywords)]

    def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
        """
        Build topic records (index, content, keywords) from pre-segmented content.

        :param url: The URL of the webpage.
        :param html: DEL-joined topic segments.
        :return: A list of dictionaries representing the topics.
        """
        # The incoming content is already segmented and joined with DEL.
        segments = html.split(self.DEL)
        return [
            {"index": position, "content": segment, "keywords": self.extract_keywords(segment)}
            for position, segment in enumerate(segments)
        ]

    def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
        """
        Process sections using topic segmentation and keyword extraction.

        :param url: The URL of the webpage.
        :param sections: List of sections (strings) to process.
        :return: A list of processed JSON blocks.
        """
        # Join sections so segmentation operates on one coherent document.
        return self.extract(url, self.DEL.join(sections), **kwargs)
|
||||||
|
|
||||||
|
class ContentSummarizationStrategy(ExtractionStrategy):
    """Summarize each section with a Hugging Face summarization pipeline."""

    def __init__(self, model_name: str = "sshleifer/distilbart-cnn-12-6"):
        """
        Initialize the content summarization strategy with a specific model.

        :param model_name: The model to use for summarization.
        """
        # Bug fix: the base-class initializer was never called, so inherited
        # state (self.DEL) was missing on instances of this strategy.
        super().__init__()
        self.summarizer = pipeline("summarization", model=model_name)

    def extract(self, url: str, text: str, provider: str = None, api_token: Optional[str] = None) -> Dict[str, Any]:
        """
        Summarize a single section of text.

        Note: unlike other strategies, this returns a single dict (the
        annotation previously claimed a list); `run` collects these dicts.

        :param url: The URL of the webpage.
        :param text: A section of text to summarize.
        :param provider: The provider to be used for extraction (not used here).
        :param api_token: Optional API token for the provider (not used here).
        :return: A dictionary with the summary.
        """
        try:
            summary = self.summarizer(text, max_length=130, min_length=30, do_sample=False)
            return {"summary": summary[0]['summary_text']}
        except Exception as e:
            print(f"Error summarizing text: {e}")
            return {"summary": text}  # Fallback to original text if summarization fails

    def run(self, url: str, sections: List[str], provider: str = None, api_token: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Process each section in parallel to produce summaries.

        :param url: The URL of the webpage.
        :param sections: List of sections (strings) to summarize.
        :param provider: The provider to be used for extraction (not used here).
        :param api_token: Optional API token for the provider (not used here).
        :return: A list of dictionaries with summaries for each section.
        """
        # Use a ThreadPoolExecutor to summarize in parallel
        summaries = []
        with ThreadPoolExecutor() as executor:
            # Create a future for each section's summarization, remembering its index
            future_to_section = {executor.submit(self.extract, url, section, provider, api_token): i for i, section in enumerate(sections)}
            for future in as_completed(future_to_section):
                section_index = future_to_section[future]
                try:
                    summary_result = future.result()
                    summaries.append((section_index, summary_result))
                except Exception as e:
                    print(f"Error processing section {section_index}: {e}")
                    summaries.append((section_index, {"summary": sections[section_index]}))  # Fallback to original text

        # Sort summaries by the original section index to maintain order
        summaries.sort(key=lambda x: x[0])
        return [summary for _, summary in summaries]
|
||||||
122
crawl4ai/train.py
Normal file
122
crawl4ai/train.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
import spacy
|
||||||
|
from spacy.training import Example
|
||||||
|
import random
|
||||||
|
import nltk
|
||||||
|
from nltk.corpus import reuters
|
||||||
|
|
||||||
|
def train_and_save_reuters_model(model_dir="models/reuters"):
    """Train a multi-label spaCy text categorizer on the NLTK Reuters corpus
    and save it to *model_dir*."""
    # Ensure the Reuters corpus and tokenizer data are downloaded
    nltk.download('reuters')
    nltk.download('punkt')
    if not reuters.fileids():
        print("Reuters corpus not found.")
        return

    # Load a blank English spaCy model
    nlp = spacy.blank("en")

    # Create a TextCategorizer with the ensemble model for multi-label classification
    textcat = nlp.add_pipe("textcat_multilabel")

    # Add labels to text classifier
    for label in reuters.categories():
        textcat.add_label(label)

    # Prepare training data: one Example per document, with a True/False
    # flag per category (multi-label targets)
    train_examples = []
    for fileid in reuters.fileids():
        categories = reuters.categories(fileid)
        text = reuters.raw(fileid)
        cats = {label: label in categories for label in reuters.categories()}
        # Prepare spacy Example objects
        doc = nlp.make_doc(text)
        example = Example.from_dict(doc, {'cats': cats})
        train_examples.append(example)

    # Initialize the text categorizer with the example objects
    nlp.initialize(lambda: train_examples)

    # Train the model with fixed seeds for reproducibility
    random.seed(1)
    spacy.util.fix_random_seed(1)
    for i in range(5):  # Adjust iterations for better accuracy
        random.shuffle(train_examples)
        losses = {}
        # Create batches of data
        batches = spacy.util.minibatch(train_examples, size=8)
        for batch in batches:
            nlp.update(batch, drop=0.2, losses=losses)
        print(f"Losses at iteration {i}: {losses}")

    # Save the trained model
    nlp.to_disk(model_dir)
    print(f"Model saved to: {model_dir}")
|
||||||
|
|
||||||
|
def train_model(model_dir, additional_epochs=0):
    """Continue training the categorizer saved in *model_dir* (or start a new
    one if none exists) for 5 + *additional_epochs* epochs, then save it."""
    # Load the model if it exists, otherwise start with a blank model
    try:
        nlp = spacy.load(model_dir)
        print("Model loaded from disk.")
    except IOError:
        print("No existing model found. Starting with a new model.")
        nlp = spacy.blank("en")
        textcat = nlp.add_pipe("textcat_multilabel")
        for label in reuters.categories():
            textcat.add_label(label)

    # Prepare training data (same layout as train_and_save_reuters_model)
    train_examples = []
    for fileid in reuters.fileids():
        categories = reuters.categories(fileid)
        text = reuters.raw(fileid)
        cats = {label: label in categories for label in reuters.categories()}
        doc = nlp.make_doc(text)
        example = Example.from_dict(doc, {'cats': cats})
        train_examples.append(example)

    # Initialize the model if it was newly created
    # NOTE(review): when a blank model is created above, add_pipe has already
    # placed 'textcat_multilabel' in nlp.pipe_names, so this condition is False
    # and nlp.initialize is skipped for brand-new models — confirm intent.
    if 'textcat_multilabel' not in nlp.pipe_names:
        nlp.initialize(lambda: train_examples)
    else:
        print("Continuing training with existing model.")

    # Train the model with fixed seeds for reproducibility
    random.seed(1)
    spacy.util.fix_random_seed(1)
    num_epochs = 5 + additional_epochs
    for i in range(num_epochs):
        random.shuffle(train_examples)
        losses = {}
        batches = spacy.util.minibatch(train_examples, size=8)
        for batch in batches:
            nlp.update(batch, drop=0.2, losses=losses)
        print(f"Losses at iteration {i}: {losses}")

    # Save the trained model
    nlp.to_disk(model_dir)
    print(f"Model saved to: {model_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_and_predict(model_dir, text, tok_k = 3):
    """Load a trained spaCy model and return its top-scoring categories for *text*."""
    # Load the trained model from the specified directory
    pipeline = spacy.load(model_dir)

    # Score the text and rank categories by descending confidence
    scores = pipeline(text).cats
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    top_categories = ranked[:tok_k]
    print(f"Top {tok_k} categories:")

    return top_categories
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Train from scratch, then continue training the saved model for 5 more epochs.
    train_and_save_reuters_model()
    train_model("models/reuters", additional_epochs=5)
    # NOTE(review): prediction loads "reuters_model_10", not the directory the
    # training runs above wrote to ("models/reuters") — confirm this is intended.
    model_directory = "reuters_model_10"
    print(reuters.categories())
    example_text = "Apple Inc. is reportedly buying a startup for $1 billion"
    r =load_model_and_predict(model_directory, example_text)
    print(r)
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import requests
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from bs4 import BeautifulSoup, Comment, element, Tag, NavigableString
|
from bs4 import BeautifulSoup, Comment, element, Tag, NavigableString
|
||||||
import html2text
|
import html2text
|
||||||
import json
|
import json
|
||||||
@@ -272,7 +273,8 @@ def get_content_of_website(html, word_count_threshold = MIN_WORD_THRESHOLD):
|
|||||||
h = CustomHTML2Text()
|
h = CustomHTML2Text()
|
||||||
h.ignore_links = True
|
h.ignore_links = True
|
||||||
markdown = h.handle(cleaned_html)
|
markdown = h.handle(cleaned_html)
|
||||||
|
markdown = markdown.replace(' ```', '```')
|
||||||
|
|
||||||
# Return the Markdown content
|
# Return the Markdown content
|
||||||
return{
|
return{
|
||||||
'markdown': markdown,
|
'markdown': markdown,
|
||||||
@@ -416,4 +418,50 @@ def extract_blocks_batch(batch_data, provider = "groq/llama3-70b-8192", api_toke
|
|||||||
}]
|
}]
|
||||||
all_blocks.append(blocks)
|
all_blocks.append(blocks)
|
||||||
|
|
||||||
return sum(all_blocks, [])
|
return sum(all_blocks, [])
|
||||||
|
|
||||||
|
|
||||||
|
def merge_chunks_based_on_token_threshold(chunks, token_threshold):
    """
    Merges small chunks into larger ones based on the total token threshold.

    :param chunks: List of text chunks to be merged based on token count.
    :param token_threshold: Max number of tokens for each merged chunk.
    :return: List of merged text chunks.
    """
    merged = []
    pending = []
    running_tokens = 0

    for piece in chunks:
        # Rough token estimate: word count scaled by a 1.3 factor.
        estimated = len(piece.split()) * 1.3
        if running_tokens + estimated < token_threshold:
            pending.append(piece)
            running_tokens += estimated
        else:
            # Current group is full: flush it and start a new group with this piece.
            if pending:
                merged.append('\n\n'.join(pending))
            pending = [piece]
            running_tokens = estimated

    # Flush the trailing group, if any.
    if pending:
        merged.append('\n\n'.join(pending))

    return merged
|
||||||
|
|
||||||
|
def process_sections(url: str, sections: list, provider: str, api_token: str) -> list:
    """Extract blocks from every section — sequentially with a delay for
    rate-limited groq providers, in parallel otherwise."""
    results = []
    if provider.startswith("groq/"):
        # Sequential processing with a delay
        for section in sections:
            results.extend(extract_blocks(url, section, provider, api_token))
            time.sleep(0.5)  # 500 ms delay between each processing
    else:
        # Parallel processing using ThreadPoolExecutor
        with ThreadPoolExecutor() as executor:
            pending = [executor.submit(extract_blocks, url, section, provider, api_token) for section in sections]
            for finished in as_completed(pending):
                results.extend(finished.result())

    return results
|
||||||
@@ -1,39 +1,21 @@
|
|||||||
import asyncio
|
|
||||||
import os, time
|
import os, time
|
||||||
import json
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from selenium import webdriver
|
|
||||||
from selenium.webdriver.chrome.service import Service
|
|
||||||
from selenium.webdriver.common.by import By
|
|
||||||
from selenium.webdriver.support.ui import WebDriverWait
|
|
||||||
from selenium.webdriver.support import expected_conditions as EC
|
|
||||||
from selenium.webdriver.chrome.options import Options
|
|
||||||
import chromedriver_autoinstaller
|
|
||||||
from pydantic import parse_obj_as
|
|
||||||
from .models import UrlModel, CrawlResult
|
from .models import UrlModel, CrawlResult
|
||||||
from .database import init_db, get_cached_url, cache_url
|
from .database import init_db, get_cached_url, cache_url
|
||||||
from .utils import *
|
from .utils import *
|
||||||
|
from .chunking_strategy import *
|
||||||
|
from .extraction_strategy import *
|
||||||
|
from .crawler_strategy import *
|
||||||
from typing import List
|
from typing import List
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from .config import *
|
from .config import *
|
||||||
|
|
||||||
class WebCrawler:
|
class WebCrawler:
|
||||||
def __init__(self, db_path: str):
|
def __init__(self, db_path: str, crawler_strategy: CrawlerStrategy = LocalSeleniumCrawlerStrategy()):
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
init_db(self.db_path)
|
init_db(self.db_path)
|
||||||
self.options = Options()
|
self.crawler_strategy = crawler_strategy
|
||||||
self.options.headless = True
|
|
||||||
self.options.add_argument("--no-sandbox")
|
|
||||||
self.options.add_argument("--disable-dev-shm-usage")
|
|
||||||
# make it headless
|
|
||||||
self.options.add_argument("--headless")
|
|
||||||
|
|
||||||
# Automatically install or update chromedriver
|
|
||||||
chromedriver_autoinstaller.install()
|
|
||||||
|
|
||||||
# Initialize WebDriver for crawling
|
|
||||||
self.service = Service(chromedriver_autoinstaller.install())
|
|
||||||
self.driver = webdriver.Chrome(service=self.service, options=self.options)
|
|
||||||
|
|
||||||
# Create the .crawl4ai folder in the user's home directory if it doesn't exist
|
# Create the .crawl4ai folder in the user's home directory if it doesn't exist
|
||||||
self.crawl4ai_folder = os.path.join(Path.home(), ".crawl4ai")
|
self.crawl4ai_folder = os.path.join(Path.home(), ".crawl4ai")
|
||||||
@@ -47,10 +29,15 @@ class WebCrawler:
|
|||||||
api_token: str = None,
|
api_token: str = None,
|
||||||
extract_blocks_flag: bool = True,
|
extract_blocks_flag: bool = True,
|
||||||
word_count_threshold = MIN_WORD_THRESHOLD,
|
word_count_threshold = MIN_WORD_THRESHOLD,
|
||||||
use_cached_html: bool = False) -> CrawlResult:
|
use_cached_html: bool = False,
|
||||||
|
extraction_strategy: ExtractionStrategy = LLMExtractionStrategy(),
|
||||||
|
chunking_strategy: ChunkingStrategy = RegexChunking(),
|
||||||
|
**kwargs
|
||||||
|
) -> CrawlResult:
|
||||||
|
|
||||||
# make sure word_count_threshold is not lesser than MIN_WORD_THRESHOLD
|
# make sure word_count_threshold is not lesser than MIN_WORD_THRESHOLD
|
||||||
# if word_count_threshold < MIN_WORD_THRESHOLD:
|
if word_count_threshold < MIN_WORD_THRESHOLD:
|
||||||
# word_count_threshold = MIN_WORD_THRESHOLD
|
word_count_threshold = MIN_WORD_THRESHOLD
|
||||||
|
|
||||||
# Check cache first
|
# Check cache first
|
||||||
cached = get_cached_url(self.db_path, str(url_model.url))
|
cached = get_cached_url(self.db_path, str(url_model.url))
|
||||||
@@ -67,87 +54,38 @@ class WebCrawler:
|
|||||||
|
|
||||||
|
|
||||||
# Initialize WebDriver for crawling
|
# Initialize WebDriver for crawling
|
||||||
if use_cached_html:
|
t = time.time()
|
||||||
# load html from crawl4ai_folder/cache
|
try:
|
||||||
valid_file_name = str(url_model.url).replace("/", "_").replace(":", "_")
|
html = self.crawler_strategy.crawl(str(url_model.url))
|
||||||
if os.path.exists(os.path.join(self.crawl4ai_folder, "cache", valid_file_name)):
|
|
||||||
with open(os.path.join(self.crawl4ai_folder, "cache", valid_file_name), "r") as f:
|
|
||||||
html = f.read()
|
|
||||||
else:
|
|
||||||
raise Exception("Cached HTML file not found")
|
|
||||||
|
|
||||||
success = True
|
success = True
|
||||||
error_message = ""
|
error_message = ""
|
||||||
else:
|
except Exception as e:
|
||||||
service = self.service
|
html = ""
|
||||||
driver = self.driver
|
success = False
|
||||||
|
error_message = str(e)
|
||||||
try:
|
|
||||||
driver.get(str(url_model.url))
|
|
||||||
WebDriverWait(driver, 10).until(
|
|
||||||
EC.presence_of_all_elements_located((By.TAG_NAME, "html"))
|
|
||||||
)
|
|
||||||
html = driver.page_source
|
|
||||||
success = True
|
|
||||||
error_message = ""
|
|
||||||
|
|
||||||
# Save html in crawl4ai_folder/cache
|
|
||||||
valid_file_name = str(url_model.url).replace("/", "_").replace(":", "_")
|
|
||||||
with open(os.path.join(self.crawl4ai_folder, "cache", valid_file_name), "w") as f:
|
|
||||||
f.write(html)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
html = ""
|
|
||||||
success = False
|
|
||||||
error_message = str(e)
|
|
||||||
finally:
|
|
||||||
driver.quit()
|
|
||||||
|
|
||||||
# Extract content from HTML
|
# Extract content from HTML
|
||||||
result = get_content_of_website(html, word_count_threshold)
|
result = get_content_of_website(html, word_count_threshold)
|
||||||
cleaned_html = result.get('cleaned_html', html)
|
cleaned_html = result.get('cleaned_html', html)
|
||||||
markdown = result.get('markdown', "")
|
markdown = result.get('markdown', "")
|
||||||
|
|
||||||
print("Crawling is done 🚀")
|
# Print a profession LOG style message, show time taken and say crawling is done
|
||||||
|
print(f"[LOG] 🚀 Crawling done for {url_model.url}, success: {success}, time taken: {time.time() - t} seconds")
|
||||||
|
|
||||||
|
|
||||||
parsed_json = []
|
parsed_json = []
|
||||||
if extract_blocks_flag:
|
if extract_blocks_flag:
|
||||||
|
print(f"[LOG] 🚀 Extracting semantic blocks for {url_model.url}")
|
||||||
# Split markdown into sections
|
# Split markdown into sections
|
||||||
paragraphs = markdown.split('\n\n')
|
sections = chunking_strategy.chunk(markdown)
|
||||||
sections = []
|
# sections = merge_chunks_based_on_token_threshold(sections, CHUNK_TOKEN_THRESHOLD)
|
||||||
chunks = []
|
|
||||||
total_token_so_far = 0
|
|
||||||
|
|
||||||
for paragraph in paragraphs:
|
|
||||||
if total_token_so_far < CHUNK_TOKEN_THRESHOLD:
|
|
||||||
chunk = paragraph.split(' ')
|
|
||||||
total_token_so_far += len(chunk) * 1.3
|
|
||||||
chunks.append(paragraph)
|
|
||||||
else:
|
|
||||||
sections.append('\n\n'.join(chunks))
|
|
||||||
chunks = [paragraph]
|
|
||||||
total_token_so_far = len(paragraph.split(' ')) * 1.3
|
|
||||||
|
|
||||||
if chunks:
|
|
||||||
sections.append('\n\n'.join(chunks))
|
|
||||||
|
|
||||||
# Process sections to extract blocks
|
|
||||||
parsed_json = []
|
|
||||||
if provider.startswith("groq/"):
|
|
||||||
# Sequential processing with a delay
|
|
||||||
for section in sections:
|
|
||||||
parsed_json.extend(extract_blocks(str(url_model.url), section, provider, api_token))
|
|
||||||
time.sleep(0.5) # 500 ms delay between each processing
|
|
||||||
else:
|
|
||||||
# Parallel processing using ThreadPoolExecutor
|
|
||||||
with ThreadPoolExecutor() as executor:
|
|
||||||
futures = [executor.submit(extract_blocks, str(url_model.url), section, provider, api_token) for section in sections]
|
|
||||||
for future in as_completed(futures):
|
|
||||||
parsed_json.extend(future.result())
|
|
||||||
|
|
||||||
|
parsed_json = extraction_strategy.run(str(url_model.url), sections, provider, api_token)
|
||||||
parsed_json = json.dumps(parsed_json)
|
parsed_json = json.dumps(parsed_json)
|
||||||
|
print(f"[LOG] 🚀 Extraction done for {url_model.url}")
|
||||||
else:
|
else:
|
||||||
parsed_json = "{}"
|
parsed_json = "{}"
|
||||||
|
print(f"[LOG] 🚀 Skipping extraction for {url_model.url}")
|
||||||
|
|
||||||
# Cache the result
|
# Cache the result
|
||||||
cleaned_html = beautify_html(cleaned_html)
|
cleaned_html = beautify_html(cleaned_html)
|
||||||
@@ -163,7 +101,23 @@ class WebCrawler:
|
|||||||
error_message=error_message
|
error_message=error_message
|
||||||
)
|
)
|
||||||
|
|
||||||
def fetch_pages(self, url_models: List[UrlModel], provider: str = DEFAULT_PROVIDER, api_token: str = None) -> List[CrawlResult]:
|
def fetch_pages(self, url_models: List[UrlModel], provider: str = DEFAULT_PROVIDER, api_token: str = None,
|
||||||
|
extract_blocks_flag: bool = True, word_count_threshold=MIN_WORD_THRESHOLD,
|
||||||
|
use_cached_html: bool = False, extraction_strategy: ExtractionStrategy = LLMExtractionStrategy(),
|
||||||
|
chunking_strategy: ChunkingStrategy = RegexChunking(), **kwargs) -> List[CrawlResult]:
|
||||||
|
|
||||||
|
def fetch_page_wrapper(url_model, *args, **kwargs):
|
||||||
|
return self.fetch_page(url_model, *args, **kwargs)
|
||||||
|
|
||||||
with ThreadPoolExecutor() as executor:
|
with ThreadPoolExecutor() as executor:
|
||||||
results = list(executor.map(self.fetch_page, url_models, [provider] * len(url_models), [api_token] * len(url_models)))
|
results = list(executor.map(fetch_page_wrapper, url_models,
|
||||||
|
[provider] * len(url_models),
|
||||||
|
[api_token] * len(url_models),
|
||||||
|
[extract_blocks_flag] * len(url_models),
|
||||||
|
[word_count_threshold] * len(url_models),
|
||||||
|
[use_cached_html] * len(url_models),
|
||||||
|
[extraction_strategy] * len(url_models),
|
||||||
|
[chunking_strategy] * len(url_models),
|
||||||
|
*[kwargs] * len(url_models)))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
@@ -8,11 +8,12 @@ def main():
|
|||||||
crawler = WebCrawler(db_path='crawler_data.db')
|
crawler = WebCrawler(db_path='crawler_data.db')
|
||||||
|
|
||||||
# Fetch a single page
|
# Fetch a single page
|
||||||
single_url = UrlModel(url='https://www.nbcnews.com/business', forced=False)
|
single_url = UrlModel(url='https://www.nbcnews.com/business', forced=True)
|
||||||
result = crawler.fetch_page(
|
result = crawler.fetch_page(
|
||||||
single_url,
|
single_url,
|
||||||
provider= "openai/gpt-3.5-turbo",
|
provider= "openai/gpt-3.5-turbo",
|
||||||
api_token = os.getenv('OPENAI_API_KEY'),
|
api_token = os.getenv('OPENAI_API_KEY'),
|
||||||
|
use_cached_html = True,
|
||||||
extract_blocks_flag=True,
|
extract_blocks_flag=True,
|
||||||
word_count_threshold=10
|
word_count_threshold=10
|
||||||
)
|
)
|
||||||
|
|||||||
144
models/reuters/config.cfg
Normal file
144
models/reuters/config.cfg
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
[paths]
|
||||||
|
train = null
|
||||||
|
dev = null
|
||||||
|
vectors = null
|
||||||
|
init_tok2vec = null
|
||||||
|
|
||||||
|
[system]
|
||||||
|
seed = 0
|
||||||
|
gpu_allocator = null
|
||||||
|
|
||||||
|
[nlp]
|
||||||
|
lang = "en"
|
||||||
|
pipeline = ["textcat_multilabel"]
|
||||||
|
disabled = []
|
||||||
|
before_creation = null
|
||||||
|
after_creation = null
|
||||||
|
after_pipeline_creation = null
|
||||||
|
batch_size = 1000
|
||||||
|
tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"}
|
||||||
|
vectors = {"@vectors":"spacy.Vectors.v1"}
|
||||||
|
|
||||||
|
[components]
|
||||||
|
|
||||||
|
[components.textcat_multilabel]
|
||||||
|
factory = "textcat_multilabel"
|
||||||
|
scorer = {"@scorers":"spacy.textcat_multilabel_scorer.v2"}
|
||||||
|
threshold = 0.5
|
||||||
|
|
||||||
|
[components.textcat_multilabel.model]
|
||||||
|
@architectures = "spacy.TextCatEnsemble.v2"
|
||||||
|
nO = null
|
||||||
|
|
||||||
|
[components.textcat_multilabel.model.linear_model]
|
||||||
|
@architectures = "spacy.TextCatBOW.v3"
|
||||||
|
exclusive_classes = false
|
||||||
|
length = 262144
|
||||||
|
ngram_size = 1
|
||||||
|
no_output_layer = false
|
||||||
|
nO = null
|
||||||
|
|
||||||
|
[components.textcat_multilabel.model.tok2vec]
|
||||||
|
@architectures = "spacy.Tok2Vec.v2"
|
||||||
|
|
||||||
|
[components.textcat_multilabel.model.tok2vec.embed]
|
||||||
|
@architectures = "spacy.MultiHashEmbed.v2"
|
||||||
|
width = 64
|
||||||
|
rows = [2000,2000,500,1000,500]
|
||||||
|
attrs = ["NORM","LOWER","PREFIX","SUFFIX","SHAPE"]
|
||||||
|
include_static_vectors = false
|
||||||
|
|
||||||
|
[components.textcat_multilabel.model.tok2vec.encode]
|
||||||
|
@architectures = "spacy.MaxoutWindowEncoder.v2"
|
||||||
|
width = 64
|
||||||
|
window_size = 1
|
||||||
|
maxout_pieces = 3
|
||||||
|
depth = 2
|
||||||
|
|
||||||
|
[corpora]
|
||||||
|
|
||||||
|
[corpora.dev]
|
||||||
|
@readers = "spacy.Corpus.v1"
|
||||||
|
path = ${paths.dev}
|
||||||
|
gold_preproc = false
|
||||||
|
max_length = 0
|
||||||
|
limit = 0
|
||||||
|
augmenter = null
|
||||||
|
|
||||||
|
[corpora.train]
|
||||||
|
@readers = "spacy.Corpus.v1"
|
||||||
|
path = ${paths.train}
|
||||||
|
gold_preproc = false
|
||||||
|
max_length = 0
|
||||||
|
limit = 0
|
||||||
|
augmenter = null
|
||||||
|
|
||||||
|
[training]
|
||||||
|
seed = ${system.seed}
|
||||||
|
gpu_allocator = ${system.gpu_allocator}
|
||||||
|
dropout = 0.1
|
||||||
|
accumulate_gradient = 1
|
||||||
|
patience = 1600
|
||||||
|
max_epochs = 0
|
||||||
|
max_steps = 20000
|
||||||
|
eval_frequency = 200
|
||||||
|
frozen_components = []
|
||||||
|
annotating_components = []
|
||||||
|
dev_corpus = "corpora.dev"
|
||||||
|
train_corpus = "corpora.train"
|
||||||
|
before_to_disk = null
|
||||||
|
before_update = null
|
||||||
|
|
||||||
|
[training.batcher]
|
||||||
|
@batchers = "spacy.batch_by_words.v1"
|
||||||
|
discard_oversize = false
|
||||||
|
tolerance = 0.2
|
||||||
|
get_length = null
|
||||||
|
|
||||||
|
[training.batcher.size]
|
||||||
|
@schedules = "compounding.v1"
|
||||||
|
start = 100
|
||||||
|
stop = 1000
|
||||||
|
compound = 1.001
|
||||||
|
t = 0.0
|
||||||
|
|
||||||
|
[training.logger]
|
||||||
|
@loggers = "spacy.ConsoleLogger.v1"
|
||||||
|
progress_bar = false
|
||||||
|
|
||||||
|
[training.optimizer]
|
||||||
|
@optimizers = "Adam.v1"
|
||||||
|
beta1 = 0.9
|
||||||
|
beta2 = 0.999
|
||||||
|
L2_is_weight_decay = true
|
||||||
|
L2 = 0.01
|
||||||
|
grad_clip = 1.0
|
||||||
|
use_averages = false
|
||||||
|
eps = 0.00000001
|
||||||
|
learn_rate = 0.001
|
||||||
|
|
||||||
|
[training.score_weights]
|
||||||
|
cats_score = 1.0
|
||||||
|
cats_score_desc = null
|
||||||
|
cats_micro_p = null
|
||||||
|
cats_micro_r = null
|
||||||
|
cats_micro_f = null
|
||||||
|
cats_macro_p = null
|
||||||
|
cats_macro_r = null
|
||||||
|
cats_macro_f = null
|
||||||
|
cats_macro_auc = null
|
||||||
|
cats_f_per_type = null
|
||||||
|
|
||||||
|
[pretraining]
|
||||||
|
|
||||||
|
[initialize]
|
||||||
|
vectors = ${paths.vectors}
|
||||||
|
init_tok2vec = ${paths.init_tok2vec}
|
||||||
|
vocab_data = null
|
||||||
|
lookups = null
|
||||||
|
before_init = null
|
||||||
|
after_init = null
|
||||||
|
|
||||||
|
[initialize.components]
|
||||||
|
|
||||||
|
[initialize.tokenizer]
|
||||||
122
models/reuters/meta.json
Normal file
122
models/reuters/meta.json
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
{
|
||||||
|
"lang":"en",
|
||||||
|
"name":"pipeline",
|
||||||
|
"version":"0.0.0",
|
||||||
|
"spacy_version":">=3.7.4,<3.8.0",
|
||||||
|
"description":"",
|
||||||
|
"author":"",
|
||||||
|
"email":"",
|
||||||
|
"url":"",
|
||||||
|
"license":"",
|
||||||
|
"spacy_git_version":"bff8725f4",
|
||||||
|
"vectors":{
|
||||||
|
"width":0,
|
||||||
|
"vectors":0,
|
||||||
|
"keys":0,
|
||||||
|
"name":null,
|
||||||
|
"mode":"default"
|
||||||
|
},
|
||||||
|
"labels":{
|
||||||
|
"textcat_multilabel":[
|
||||||
|
"acq",
|
||||||
|
"alum",
|
||||||
|
"barley",
|
||||||
|
"bop",
|
||||||
|
"carcass",
|
||||||
|
"castor-oil",
|
||||||
|
"cocoa",
|
||||||
|
"coconut",
|
||||||
|
"coconut-oil",
|
||||||
|
"coffee",
|
||||||
|
"copper",
|
||||||
|
"copra-cake",
|
||||||
|
"corn",
|
||||||
|
"cotton",
|
||||||
|
"cotton-oil",
|
||||||
|
"cpi",
|
||||||
|
"cpu",
|
||||||
|
"crude",
|
||||||
|
"dfl",
|
||||||
|
"dlr",
|
||||||
|
"dmk",
|
||||||
|
"earn",
|
||||||
|
"fuel",
|
||||||
|
"gas",
|
||||||
|
"gnp",
|
||||||
|
"gold",
|
||||||
|
"grain",
|
||||||
|
"groundnut",
|
||||||
|
"groundnut-oil",
|
||||||
|
"heat",
|
||||||
|
"hog",
|
||||||
|
"housing",
|
||||||
|
"income",
|
||||||
|
"instal-debt",
|
||||||
|
"interest",
|
||||||
|
"ipi",
|
||||||
|
"iron-steel",
|
||||||
|
"jet",
|
||||||
|
"jobs",
|
||||||
|
"l-cattle",
|
||||||
|
"lead",
|
||||||
|
"lei",
|
||||||
|
"lin-oil",
|
||||||
|
"livestock",
|
||||||
|
"lumber",
|
||||||
|
"meal-feed",
|
||||||
|
"money-fx",
|
||||||
|
"money-supply",
|
||||||
|
"naphtha",
|
||||||
|
"nat-gas",
|
||||||
|
"nickel",
|
||||||
|
"nkr",
|
||||||
|
"nzdlr",
|
||||||
|
"oat",
|
||||||
|
"oilseed",
|
||||||
|
"orange",
|
||||||
|
"palladium",
|
||||||
|
"palm-oil",
|
||||||
|
"palmkernel",
|
||||||
|
"pet-chem",
|
||||||
|
"platinum",
|
||||||
|
"potato",
|
||||||
|
"propane",
|
||||||
|
"rand",
|
||||||
|
"rape-oil",
|
||||||
|
"rapeseed",
|
||||||
|
"reserves",
|
||||||
|
"retail",
|
||||||
|
"rice",
|
||||||
|
"rubber",
|
||||||
|
"rye",
|
||||||
|
"ship",
|
||||||
|
"silver",
|
||||||
|
"sorghum",
|
||||||
|
"soy-meal",
|
||||||
|
"soy-oil",
|
||||||
|
"soybean",
|
||||||
|
"strategic-metal",
|
||||||
|
"sugar",
|
||||||
|
"sun-meal",
|
||||||
|
"sun-oil",
|
||||||
|
"sunseed",
|
||||||
|
"tea",
|
||||||
|
"tin",
|
||||||
|
"trade",
|
||||||
|
"veg-oil",
|
||||||
|
"wheat",
|
||||||
|
"wpi",
|
||||||
|
"yen",
|
||||||
|
"zinc"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"pipeline":[
|
||||||
|
"textcat_multilabel"
|
||||||
|
],
|
||||||
|
"components":[
|
||||||
|
"textcat_multilabel"
|
||||||
|
],
|
||||||
|
"disabled":[
|
||||||
|
|
||||||
|
]
|
||||||
|
}
|
||||||
95
models/reuters/textcat_multilabel/cfg
Normal file
95
models/reuters/textcat_multilabel/cfg
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
{
|
||||||
|
"labels":[
|
||||||
|
"acq",
|
||||||
|
"alum",
|
||||||
|
"barley",
|
||||||
|
"bop",
|
||||||
|
"carcass",
|
||||||
|
"castor-oil",
|
||||||
|
"cocoa",
|
||||||
|
"coconut",
|
||||||
|
"coconut-oil",
|
||||||
|
"coffee",
|
||||||
|
"copper",
|
||||||
|
"copra-cake",
|
||||||
|
"corn",
|
||||||
|
"cotton",
|
||||||
|
"cotton-oil",
|
||||||
|
"cpi",
|
||||||
|
"cpu",
|
||||||
|
"crude",
|
||||||
|
"dfl",
|
||||||
|
"dlr",
|
||||||
|
"dmk",
|
||||||
|
"earn",
|
||||||
|
"fuel",
|
||||||
|
"gas",
|
||||||
|
"gnp",
|
||||||
|
"gold",
|
||||||
|
"grain",
|
||||||
|
"groundnut",
|
||||||
|
"groundnut-oil",
|
||||||
|
"heat",
|
||||||
|
"hog",
|
||||||
|
"housing",
|
||||||
|
"income",
|
||||||
|
"instal-debt",
|
||||||
|
"interest",
|
||||||
|
"ipi",
|
||||||
|
"iron-steel",
|
||||||
|
"jet",
|
||||||
|
"jobs",
|
||||||
|
"l-cattle",
|
||||||
|
"lead",
|
||||||
|
"lei",
|
||||||
|
"lin-oil",
|
||||||
|
"livestock",
|
||||||
|
"lumber",
|
||||||
|
"meal-feed",
|
||||||
|
"money-fx",
|
||||||
|
"money-supply",
|
||||||
|
"naphtha",
|
||||||
|
"nat-gas",
|
||||||
|
"nickel",
|
||||||
|
"nkr",
|
||||||
|
"nzdlr",
|
||||||
|
"oat",
|
||||||
|
"oilseed",
|
||||||
|
"orange",
|
||||||
|
"palladium",
|
||||||
|
"palm-oil",
|
||||||
|
"palmkernel",
|
||||||
|
"pet-chem",
|
||||||
|
"platinum",
|
||||||
|
"potato",
|
||||||
|
"propane",
|
||||||
|
"rand",
|
||||||
|
"rape-oil",
|
||||||
|
"rapeseed",
|
||||||
|
"reserves",
|
||||||
|
"retail",
|
||||||
|
"rice",
|
||||||
|
"rubber",
|
||||||
|
"rye",
|
||||||
|
"ship",
|
||||||
|
"silver",
|
||||||
|
"sorghum",
|
||||||
|
"soy-meal",
|
||||||
|
"soy-oil",
|
||||||
|
"soybean",
|
||||||
|
"strategic-metal",
|
||||||
|
"sugar",
|
||||||
|
"sun-meal",
|
||||||
|
"sun-oil",
|
||||||
|
"sunseed",
|
||||||
|
"tea",
|
||||||
|
"tin",
|
||||||
|
"trade",
|
||||||
|
"veg-oil",
|
||||||
|
"wheat",
|
||||||
|
"wpi",
|
||||||
|
"yen",
|
||||||
|
"zinc"
|
||||||
|
],
|
||||||
|
"threshold":0.5
|
||||||
|
}
|
||||||
BIN
models/reuters/textcat_multilabel/model
Normal file
BIN
models/reuters/textcat_multilabel/model
Normal file
Binary file not shown.
3
models/reuters/tokenizer
Normal file
3
models/reuters/tokenizer
Normal file
File diff suppressed because one or more lines are too long
1
models/reuters/vocab/key2row
Normal file
1
models/reuters/vocab/key2row
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<EFBFBD>
|
||||||
1
models/reuters/vocab/lookups.bin
Normal file
1
models/reuters/vocab/lookups.bin
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<EFBFBD>
|
||||||
83457
models/reuters/vocab/strings.json
Normal file
83457
models/reuters/vocab/strings.json
Normal file
File diff suppressed because it is too large
Load Diff
BIN
models/reuters/vocab/vectors
Normal file
BIN
models/reuters/vocab/vectors
Normal file
Binary file not shown.
3
models/reuters/vocab/vectors.cfg
Normal file
3
models/reuters/vocab/vectors.cfg
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"mode":"default"
|
||||||
|
}
|
||||||
@@ -10,4 +10,6 @@ requests
|
|||||||
bs4
|
bs4
|
||||||
html2text
|
html2text
|
||||||
litellm
|
litellm
|
||||||
python-dotenv
|
python-dotenv
|
||||||
|
nltk
|
||||||
|
spacy
|
||||||
Reference in New Issue
Block a user