Update:
- Text Categorization - Crawler, Extraction, and Chunking strategies - Clustering for semantic segmentation
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -166,3 +166,5 @@ Crawl4AI.egg-info/*
|
||||
crawler_data.db
|
||||
.vscode/
|
||||
test_pad.py
|
||||
.data/
|
||||
Crawl4AI.egg-info/
|
||||
95
crawl4ai/chunking_strategy.py
Normal file
95
crawl4ai/chunking_strategy.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import re
|
||||
import spacy
|
||||
import nltk
|
||||
from nltk.corpus import stopwords
|
||||
from nltk.tokenize import word_tokenize, TextTilingTokenizer
|
||||
from collections import Counter
|
||||
import string
|
||||
|
||||
# Chunking strategies: pluggable algorithms for splitting text into pieces.
class ChunkingStrategy(ABC):
    """Interface for algorithms that split a text into a list of chunks."""

    @abstractmethod
    def chunk(self, text: str) -> list:
        """
        Abstract method to chunk the given text.
        """
|
||||
|
||||
# Regex-based chunking
class RegexChunking(ChunkingStrategy):
    """Split text by applying a sequence of regular-expression patterns."""

    def __init__(self, patterns=None):
        # Fall back to splitting on blank lines when no patterns are given.
        self.patterns = patterns if patterns is not None else [r'\n\n']

    def chunk(self, text: str) -> list:
        """Re-split every piece produced so far with each pattern in turn."""
        pieces = [text]
        for pattern in self.patterns:
            pieces = [part
                      for piece in pieces
                      for part in re.split(pattern, piece)]
        return pieces
|
||||
|
||||
# NLP-based sentence chunking using spaCy
class NlpSentenceChunking(ChunkingStrategy):
    """Sentence-level chunking backed by a spaCy pipeline."""

    def __init__(self, model='en_core_web_sm'):
        # Loading the pipeline is expensive; do it once per instance.
        self.nlp = spacy.load(model)

    def chunk(self, text: str) -> list:
        """Return the whitespace-stripped sentences of *text* in order."""
        return [span.text.strip() for span in self.nlp(text).sents]
|
||||
|
||||
# Topic-based segmentation using TextTiling
class TopicSegmentationChunking(ChunkingStrategy):
    """Segment text into topical passages with NLTK's TextTiling algorithm."""

    def __init__(self, num_keywords=3):
        """
        :param num_keywords: Number of keywords used to label each segment.
        """
        self.tokenizer = TextTilingTokenizer()
        self.num_keywords = num_keywords

    def chunk(self, text: str) -> list:
        """Split *text* into topic segments using the TextTiling tokenizer."""
        segmented_topics = self.tokenizer.tokenize(text)
        return segmented_topics

    def extract_keywords(self, text: str) -> list:
        """Return the most frequent non-stopword, non-punctuation tokens."""
        # Perf fix: the original called stopwords.words('english') for EVERY
        # token, re-building the corpus word list and scanning it linearly.
        # Build both filter sets once; membership tests are then O(1).
        stop_words = set(stopwords.words('english'))
        punctuation = set(string.punctuation)

        tokens = word_tokenize(text)
        tokens = [token.lower() for token in tokens
                  if token not in stop_words and token not in punctuation]

        # Rank tokens by frequency and keep the top num_keywords.
        freq_dist = Counter(tokens)
        return [word for word, _ in freq_dist.most_common(self.num_keywords)]

    def chunk_with_topics(self, text: str) -> list:
        """Return (segment, keywords) pairs for each topic segment."""
        segments = self.chunk(text)
        return [(segment, self.extract_keywords(segment)) for segment in segments]
|
||||
|
||||
# Fixed-length word chunks
class FixedLengthWordChunking(ChunkingStrategy):
    """Chunk text into consecutive groups of a fixed number of words."""

    def __init__(self, chunk_size=100):
        # Words per chunk; the final chunk may be shorter.
        self.chunk_size = chunk_size

    def chunk(self, text: str) -> list:
        words = text.split()
        size = self.chunk_size
        chunks = []
        for start in range(0, len(words), size):
            chunks.append(' '.join(words[start:start + size]))
        return chunks
|
||||
|
||||
# Sliding window chunking
class SlidingWindowChunking(ChunkingStrategy):
    """Produce overlapping word windows of fixed size, advancing by `step`."""

    def __init__(self, window_size=100, step=50):
        self.window_size = window_size
        self.step = step

    def chunk(self, text: str) -> list:
        words = text.split()
        # Trailing windows near the end of the text may hold fewer than
        # window_size words, matching the original slicing behavior.
        return [' '.join(words[start:start + self.window_size])
                for start in range(0, len(words), self.step)]
|
||||
|
||||
|
||||
58
crawl4ai/crawler_strategy.py
Normal file
58
crawl4ai/crawler_strategy.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from selenium import webdriver
|
||||
from selenium.webdriver.chrome.service import Service
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support.ui import WebDriverWait
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
from selenium.webdriver.chrome.options import Options
|
||||
import chromedriver_autoinstaller
|
||||
from typing import List
|
||||
import requests
|
||||
|
||||
|
||||
class CrawlerStrategy(ABC):
    """Interface for strategies that fetch the raw HTML of a URL."""

    @abstractmethod
    def crawl(self, url: str) -> str:
        """Fetch *url* and return its HTML as a string."""
|
||||
|
||||
class CloudCrawlerStrategy(CrawlerStrategy):
    """Crawl a page by delegating to the hosted crawl4ai service."""

    def crawl(self, url: str) -> str:
        """POST *url* to the cloud endpoint and return the page's raw HTML."""
        payload = {
            "urls": [url],
            "provider_model": "",
            "api_token": "token",
            "include_raw_html": True,
            "forced": True,
            "extract_blocks": False,
            "word_count_threshold": 10,
        }

        # The service returns one result per requested URL.
        reply = requests.post("http://crawl4ai.uccode.io/crawl", json=payload)
        body = reply.json()
        return body["results"][0]["html"]
|
||||
|
||||
class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
    # Crawls pages with a local headless Chrome driven by Selenium.

    def __init__(self):
        """Configure a headless Chrome instance and start the WebDriver."""
        self.options = Options()
        # NOTE(review): `headless = True` and the "--headless" argument below
        # both request headless mode — redundant; confirm which one the
        # installed Selenium version actually honors.
        self.options.headless = True
        self.options.add_argument("--no-sandbox")
        self.options.add_argument("--disable-dev-shm-usage")
        self.options.add_argument("--headless")

        # chromedriver_autoinstaller.install() is invoked twice; the first
        # call's return value is discarded, the second feeds the Service.
        chromedriver_autoinstaller.install()
        self.service = Service(chromedriver_autoinstaller.install())
        self.driver = webdriver.Chrome(service=self.service, options=self.options)

    def crawl(self, url: str, use_cached_html = False) -> str:
        """
        Load *url* in the browser and return the rendered page source.

        :param url: Address to fetch.
        :param use_cached_html: When True, bypass the browser entirely.
        :return: HTML source of the page.
        """
        if use_cached_html:
            # NOTE(review): get_content_of_website is neither defined nor
            # imported in this module, and elsewhere in the project it takes
            # HTML rather than a URL — this branch looks broken; confirm the
            # intended cache lookup before relying on it.
            return get_content_of_website(url)
        self.driver.get(url)
        # Block (up to 10 s) until the <html> element exists before reading
        # page_source, so partially loaded documents are not returned.
        WebDriverWait(self.driver, 10).until(
            EC.presence_of_all_elements_located((By.TAG_NAME, "html"))
        )
        html = self.driver.page_source
        return html

    def quit(self):
        """Shut down the underlying WebDriver and release the browser."""
        self.driver.quit()
|
||||
358
crawl4ai/extraction_strategy.py
Normal file
358
crawl4ai/extraction_strategy.py
Normal file
@@ -0,0 +1,358 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Dict, Optional, Union
|
||||
from scipy.cluster.hierarchy import linkage, fcluster
|
||||
from scipy.spatial.distance import pdist
|
||||
from transformers import BertTokenizer, BertModel, pipeline
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import nltk
|
||||
from nltk.tokenize import TextTilingTokenizer
|
||||
import json, time
|
||||
import torch
|
||||
import spacy
|
||||
|
||||
from .prompts import PROMPT_EXTRACT_BLOCKS
|
||||
from .config import *
|
||||
from .utils import *
|
||||
|
||||
class ExtractionStrategy(ABC):
    """
    Base class for strategies that turn raw page content into JSON blocks.
    """

    def __init__(self):
        # Sentinel used to join/split section lists into a single string.
        self.DEL = "<|DEL|>"

    @abstractmethod
    def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
        """
        Extract meaningful blocks or chunks from the given HTML.

        :param url: The URL of the webpage.
        :param html: The HTML content of the webpage.
        :return: A list of extracted blocks or chunks.
        """

    def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
        """
        Process sections of text in parallel by default.

        :param url: The URL of the webpage.
        :param sections: List of sections (strings) to process.
        :return: A list of processed JSON blocks.
        """
        collected = []
        with ThreadPoolExecutor() as pool:
            pending = [pool.submit(self.extract, url, section, **kwargs)
                       for section in sections]
            # Results are appended in completion order, as in the original.
            for done in as_completed(pending):
                collected.extend(done.result())
        return collected
|
||||
|
||||
class LLMExtractionStrategy(ExtractionStrategy):
    # Extraction backed by an LLM completion provider.

    def __init__(self, provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None):
        """
        Initialize the LLM extraction strategy.

        :param provider: Completion provider identifier (e.g. "groq/...").
        :param api_token: API token for the provider; when omitted, a
            per-provider default is looked up at extraction time.
        """
        super().__init__()
        self.provider = provider
        self.api_token = api_token

    def extract(self, url: str, html: str, provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Ask the LLM to extract JSON blocks from *html*.

        :param url: The URL of the webpage.
        :param html: The HTML content to analyze.
        :param provider: Completion provider identifier.
        :param api_token: Token for the provider; falls back to the
            provider's configured default when falsy.
        :return: List of block dicts; on parse failure, best-effort parsed
            blocks plus one error block holding the unparseable remainder.
        """
        # Fall back to the provider's configured token when none was passed.
        api_token = PROVIDER_MODELS.get(provider, None) if not api_token else api_token

        variable_values = {
            "URL": url,
            "HTML": escape_json_string(sanitize_html(html)),
        }

        # Substitute {URL}/{HTML} placeholders into the extraction prompt.
        prompt_with_variables = PROMPT_EXTRACT_BLOCKS
        for variable in variable_values:
            prompt_with_variables = prompt_with_variables.replace(
                "{" + variable + "}", variable_values[variable]
            )

        response = perform_completion_with_backoff(provider, prompt_with_variables, api_token)

        try:
            # Expect the model to wrap a JSON array inside <blocks>...</blocks>.
            blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks']
            blocks = json.loads(blocks)
            for block in blocks:
                block['error'] = False
        except Exception as e:
            # Fallback: salvage any parseable JSON objects from the raw reply
            # and preserve the unparseable remainder as an error block.
            print("Error extracting blocks:", str(e))
            parsed, unparsed = split_and_parse_json_objects(response.choices[0].message.content)
            blocks = parsed
            if unparsed:
                blocks.append({
                    "index": 0,
                    "error": True,
                    "tags": ["error"],
                    "content": unparsed
                })
        return blocks

    def run(self, url: str, sections: List[str], provider: str, api_token: Optional[str]) -> List[Dict[str, Any]]:
        """
        Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy.
        """
        parsed_json = []
        if provider.startswith("groq/"):
            # Sequential processing with a delay (groq providers rate-limit).
            for section in sections:
                parsed_json.extend(self.extract(url, section, provider, api_token))
                time.sleep(0.5)  # 500 ms delay between each processing
        else:
            # Parallel processing using ThreadPoolExecutor; note results are
            # collected in completion order, not section order.
            with ThreadPoolExecutor() as executor:
                futures = [executor.submit(self.extract, url, section, provider, api_token) for section in sections]
                for future in as_completed(futures):
                    parsed_json.extend(future.result())

        return parsed_json
|
||||
|
||||
class HierarchicalClusteringStrategy(ExtractionStrategy):
    # Clusters text chunks by BERT-embedding similarity, then tags each
    # cluster with categories predicted by a spaCy text classifier.

    def __init__(self, word_count_threshold=20, max_dist=0.2, linkage_method='ward', top_k=3):
        """
        Initialize the strategy with clustering parameters.

        :param word_count_threshold: Minimum number of words per cluster.
        :param max_dist: The maximum cophenetic distance on the dendrogram to form clusters.
        :param linkage_method: The linkage method for hierarchical clustering.
        :param top_k: Number of top categories to extract.
        """
        super().__init__()
        self.word_count_threshold = word_count_threshold
        self.max_dist = max_dist
        self.linkage_method = linkage_method
        self.top_k = top_k
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
        self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
        # NOTE(review): expects a trained categorizer at "models/reuters"
        # (produced by train.py) — confirm the path relative to the CWD.
        self.nlp = spacy.load("models/reuters")

    def get_embeddings(self, sentences: List[str]):
        """
        Get BERT embeddings for a list of sentences.

        :param sentences: List of text chunks (sentences).
        :return: NumPy array of embeddings.
        """
        # Tokenize sentences and convert to tensor
        encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
        # Compute token embeddings without tracking gradients (inference only).
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        # Get embeddings from the last hidden state (mean pooling over tokens)
        embeddings = model_output.last_hidden_state.mean(1)
        return embeddings.numpy()

    def hierarchical_clustering(self, sentences: List[str]):
        """
        Perform hierarchical clustering on sentences and return cluster labels.

        :param sentences: List of text chunks (sentences).
        :return: NumPy array of cluster labels.
        """
        # Get embeddings
        embeddings = self.get_embeddings(sentences)
        # Compute pairwise cosine distances (condensed distance matrix)
        distance_matrix = pdist(embeddings, 'cosine')
        # Perform agglomerative clustering
        linked = linkage(distance_matrix, method=self.linkage_method)
        # Cut the dendrogram at max_dist to form flat clusters
        labels = fcluster(linked, self.max_dist, criterion='distance')
        return labels

    def filter_clusters_by_word_count(self, clusters: Dict[int, List[str]]):
        """
        Filter clusters to remove those with a word count below the threshold.

        :param clusters: Dictionary of clusters.
        :return: Filtered dictionary of clusters.
        """
        filtered_clusters = {}
        for cluster_id, texts in clusters.items():
            # Concatenate texts for analysis
            full_text = " ".join(texts)
            # Count words
            word_count = len(full_text.split())

            # Keep clusters with word count above the threshold
            if word_count >= self.word_count_threshold:
                filtered_clusters[cluster_id] = texts

        return filtered_clusters

    def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
        """
        Extract clusters from HTML content using hierarchical clustering.

        :param url: The URL of the webpage.
        :param html: DEL-joined text chunks (see run()), not raw HTML.
        :return: A list of dictionaries representing the clusters.
        """
        # `html` is actually the DEL-joined section list built by run().
        text_chunks = html.split(self.DEL)

        # Perform clustering
        labels = self.hierarchical_clustering(text_chunks)

        # Organize texts by their cluster labels, retaining order
        clusters = {}
        for index, label in enumerate(labels):
            clusters.setdefault(label, []).append(text_chunks[index])

        # Filter clusters by word count
        filtered_clusters = self.filter_clusters_by_word_count(clusters)

        # Convert filtered clusters to a sorted list of dictionaries
        cluster_list = [{"index": int(idx), "tags" : [], "content": " ".join(filtered_clusters[idx])} for idx in sorted(filtered_clusters)]

        # Tag each cluster with the top_k categories predicted by the
        # spaCy classifier loaded in __init__.
        for cluster in cluster_list:
            doc = self.nlp(cluster['content'])
            tok_k = self.top_k  # (sic: "tok_k" typo kept — local only)
            top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
            cluster['tags'] = [cat for cat, _ in top_categories]

        return cluster_list

    def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
        """
        Process sections using hierarchical clustering.

        :param url: The URL of the webpage.
        :param sections: List of sections (strings) to process.
        :return: A list of processed JSON blocks.
        """
        # This strategy processes all sections together, joined by DEL.
        return self.extract(url, self.DEL.join(sections), **kwargs)
|
||||
|
||||
class TopicExtractionStrategy(ExtractionStrategy):
    # Splits content into segments and labels each with frequency-based keywords.

    def __init__(self, num_keywords: int = 3):
        """
        Initialize the topic extraction strategy with parameters for topic segmentation.

        :param num_keywords: Number of keywords to represent each topic segment.
        """
        super().__init__()
        self.num_keywords = num_keywords
        # NOTE(review): this tokenizer is never used by extract(), which
        # splits on self.DEL instead of running TextTiling — confirm whether
        # TextTiling segmentation was meant to be applied here.
        self.tokenizer = TextTilingTokenizer()

    def extract_keywords(self, text: str) -> List[str]:
        """
        Extract keywords from a given text segment using simple frequency analysis.

        :param text: The text segment from which to extract keywords.
        :return: A list of keyword strings.
        """
        # Tokenize the text and compute word frequency (no stopword filtering).
        words = nltk.word_tokenize(text)
        freq_dist = nltk.FreqDist(words)
        # Get the most common words as keywords
        keywords = [word for (word, _) in freq_dist.most_common(self.num_keywords)]
        return keywords

    def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
        """
        Extract topics from pre-segmented content and attach keywords to each.

        :param url: The URL of the webpage.
        :param html: DEL-joined segments (see run()), not raw HTML.
        :return: A list of dictionaries representing the topics.
        """
        # Segmentation here is a plain split on the DEL sentinel inserted by
        # run() — TextTiling is NOT applied despite the tokenizer in __init__.
        segmented_topics = html.split(self.DEL)

        # Prepare the output as a list of dictionaries
        topic_list = []
        for i, segment in enumerate(segmented_topics):
            # Extract keywords for each segment
            keywords = self.extract_keywords(segment)
            topic_list.append({
                "index": i,
                "content": segment,
                "keywords": keywords
            })

        return topic_list

    def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
        """
        Process sections using topic segmentation and keyword extraction.

        :param url: The URL of the webpage.
        :param sections: List of sections (strings) to process.
        :return: A list of processed JSON blocks.
        """
        # Concatenate sections into a single text for coherent topic segmentation
        return self.extract(url, self.DEL.join(sections), **kwargs)
|
||||
|
||||
class ContentSummarizationStrategy(ExtractionStrategy):
    """Summarize each section with a transformers summarization pipeline."""

    def __init__(self, model_name: str = "sshleifer/distilbart-cnn-12-6"):
        """
        Initialize the content summarization strategy with a specific model.

        :param model_name: The model to use for summarization.
        """
        # Bug fix: the base-class initializer was never called, so inherited
        # state (self.DEL) was missing on instances of this strategy.
        super().__init__()
        self.summarizer = pipeline("summarization", model=model_name)

    def extract(self, url: str, text: str, provider: str = None, api_token: Optional[str] = None) -> Dict[str, str]:
        """
        Summarize a single section of text.

        :param url: The URL of the webpage.
        :param text: A section of text to summarize.
        :param provider: Unused; kept for interface compatibility.
        :param api_token: Unused; kept for interface compatibility.
        :return: A dict with a "summary" key. (Annotation fixed: unlike the
            sibling strategies this returns a single dict, which is exactly
            what run() below consumes.)
        """
        try:
            summary = self.summarizer(text, max_length=130, min_length=30, do_sample=False)
            return {"summary": summary[0]['summary_text']}
        except Exception as e:
            # Best-effort: fall back to the original text when the model fails.
            print(f"Error summarizing text: {e}")
            return {"summary": text}

    def run(self, url: str, sections: List[str], provider: str = None, api_token: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Process each section in parallel to produce summaries.

        :param url: The URL of the webpage.
        :param sections: List of sections (strings) to summarize.
        :param provider: Unused; kept for interface compatibility.
        :param api_token: Unused; kept for interface compatibility.
        :return: A list of summary dicts, in the original section order.
        """
        summaries = []
        with ThreadPoolExecutor() as executor:
            # Map each future back to its section index so order can be restored.
            future_to_section = {executor.submit(self.extract, url, section, provider, api_token): i
                                 for i, section in enumerate(sections)}
            for future in as_completed(future_to_section):
                section_index = future_to_section[future]
                try:
                    summaries.append((section_index, future.result()))
                except Exception as e:
                    print(f"Error processing section {section_index}: {e}")
                    # Fallback to original text
                    summaries.append((section_index, {"summary": sections[section_index]}))

        # Sort summaries by the original section index to maintain order
        summaries.sort(key=lambda x: x[0])
        return [summary for _, summary in summaries]
|
||||
122
crawl4ai/train.py
Normal file
122
crawl4ai/train.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import spacy
|
||||
from spacy.training import Example
|
||||
import random
|
||||
import nltk
|
||||
from nltk.corpus import reuters
|
||||
|
||||
def train_and_save_reuters_model(model_dir="models/reuters"):
    """
    Train a multi-label text categorizer on the NLTK Reuters corpus and save it.

    :param model_dir: Directory where the trained spaCy model is written.
    """
    # Ensure the Reuters corpus (and its tokenizer data) is downloaded
    nltk.download('reuters')
    nltk.download('punkt')
    if not reuters.fileids():
        print("Reuters corpus not found.")
        return

    # Load a blank English spaCy model
    nlp = spacy.blank("en")

    # Create a TextCategorizer with the ensemble model for multi-label classification
    textcat = nlp.add_pipe("textcat_multilabel")

    # Add labels to text classifier
    for label in reuters.categories():
        textcat.add_label(label)

    # Prepare training data
    train_examples = []
    for fileid in reuters.fileids():
        categories = reuters.categories(fileid)
        text = reuters.raw(fileid)
        # One boolean per known category: True when this doc carries the label.
        cats = {label: label in categories for label in reuters.categories()}
        # Prepare spacy Example objects
        doc = nlp.make_doc(text)
        example = Example.from_dict(doc, {'cats': cats})
        train_examples.append(example)

    # Initialize the text categorizer with the example objects
    nlp.initialize(lambda: train_examples)

    # Train the model with fixed seeds for reproducibility
    random.seed(1)
    spacy.util.fix_random_seed(1)
    for i in range(5):  # Adjust iterations for better accuracy
        random.shuffle(train_examples)
        losses = {}
        # Create batches of data
        batches = spacy.util.minibatch(train_examples, size=8)
        for batch in batches:
            nlp.update(batch, drop=0.2, losses=losses)
        print(f"Losses at iteration {i}: {losses}")

    # Save the trained model
    nlp.to_disk(model_dir)
    print(f"Model saved to: {model_dir}")
|
||||
|
||||
def train_model(model_dir, additional_epochs=0):
    """
    Continue training the Reuters text categorizer, or train from scratch.

    :param model_dir: Directory containing (or to receive) the spaCy model.
    :param additional_epochs: Extra epochs on top of the default five.
    """
    # Load the model if it exists, otherwise start with a blank model
    is_new_model = False
    try:
        nlp = spacy.load(model_dir)
        print("Model loaded from disk.")
    except IOError:
        print("No existing model found. Starting with a new model.")
        is_new_model = True
        nlp = spacy.blank("en")
        textcat = nlp.add_pipe("textcat_multilabel")
        for label in reuters.categories():
            textcat.add_label(label)

    # Prepare training data
    train_examples = []
    for fileid in reuters.fileids():
        categories = reuters.categories(fileid)
        text = reuters.raw(fileid)
        cats = {label: label in categories for label in reuters.categories()}
        doc = nlp.make_doc(text)
        example = Example.from_dict(doc, {'cats': cats})
        train_examples.append(example)

    # Bug fix: the original tested `'textcat_multilabel' not in nlp.pipe_names`,
    # which is always False at this point (the pipe is added above for new
    # models and already present in loaded ones), so a freshly created model
    # was never initialized and nlp.update() would fail. Initialize exactly
    # when the model is new.
    if is_new_model:
        nlp.initialize(lambda: train_examples)
    else:
        print("Continuing training with existing model.")

    # Train the model with fixed seeds for reproducibility
    random.seed(1)
    spacy.util.fix_random_seed(1)
    num_epochs = 5 + additional_epochs
    for i in range(num_epochs):
        random.shuffle(train_examples)
        losses = {}
        batches = spacy.util.minibatch(train_examples, size=8)
        for batch in batches:
            nlp.update(batch, drop=0.2, losses=losses)
        print(f"Losses at iteration {i}: {losses}")

    # Save the trained model
    nlp.to_disk(model_dir)
    print(f"Model saved to: {model_dir}")
|
||||
|
||||
|
||||
|
||||
def load_model_and_predict(model_dir, text, tok_k = 3):
    """
    Load a trained spaCy text categorizer and rank categories for *text*.

    :param model_dir: Directory containing the trained model.
    :param text: Text to classify.
    :param tok_k: Number of top categories to return (sic: "tok_k" is kept
        as-is for caller compatibility).
    :return: List of (category, score) pairs, highest score first.
    """
    # Load the trained model from the specified directory
    nlp = spacy.load(model_dir)

    # Process the text with the loaded model
    doc = nlp(text)

    # Take the tok_k highest-scoring categories from the classifier output.
    top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
    print(f"Top {tok_k} categories:")

    return top_categories
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Train from scratch, then continue training the saved model.
    train_and_save_reuters_model()
    train_model("models/reuters", additional_epochs=5)
    # Bug fix: prediction previously loaded "reuters_model_10", a directory
    # this script never creates; load the model trained above instead.
    model_directory = "models/reuters"
    print(reuters.categories())
    example_text = "Apple Inc. is reportedly buying a startup for $1 billion"
    r = load_model_and_predict(model_directory, example_text)
    print(r)
|
||||
@@ -1,4 +1,5 @@
|
||||
import requests
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from bs4 import BeautifulSoup, Comment, element, Tag, NavigableString
|
||||
import html2text
|
||||
import json
|
||||
@@ -272,6 +273,7 @@ def get_content_of_website(html, word_count_threshold = MIN_WORD_THRESHOLD):
|
||||
h = CustomHTML2Text()
|
||||
h.ignore_links = True
|
||||
markdown = h.handle(cleaned_html)
|
||||
markdown = markdown.replace(' ```', '```')
|
||||
|
||||
# Return the Markdown content
|
||||
return{
|
||||
@@ -417,3 +419,49 @@ def extract_blocks_batch(batch_data, provider = "groq/llama3-70b-8192", api_toke
|
||||
all_blocks.append(blocks)
|
||||
|
||||
return sum(all_blocks, [])
|
||||
|
||||
|
||||
def merge_chunks_based_on_token_threshold(chunks, token_threshold):
    """
    Merges small chunks into larger ones based on the total token threshold.

    :param chunks: List of text chunks to be merged based on token count.
    :param token_threshold: Max number of tokens for each merged chunk.
    :return: List of merged text chunks.
    """
    merged = []
    pending = []
    pending_tokens = 0

    for piece in chunks:
        # Rough token estimate: word count scaled by 1.3.
        estimate = len(piece.split()) * 1.3
        if pending_tokens + estimate < token_threshold:
            pending.append(piece)
            pending_tokens += estimate
        else:
            # The current group would overflow: flush it and start a new one.
            if pending:
                merged.append('\n\n'.join(pending))
            pending = [piece]
            pending_tokens = estimate

    # Flush the final group, if any.
    if pending:
        merged.append('\n\n'.join(pending))

    return merged
|
||||
|
||||
def process_sections(url: str, sections: list, provider: str, api_token: str) -> list:
    """
    Run block extraction over every section, parallelising when safe.

    :param url: The URL the sections came from.
    :param sections: Text sections to process.
    :param provider: Completion provider id; "groq/..." providers are rate
        limited and therefore processed sequentially.
    :param api_token: API token passed through to extract_blocks.
    :return: Flat list of extracted JSON blocks.
    """
    parsed_json = []
    if provider.startswith("groq/"):
        # Sequential processing with a delay (rate-limit friendly)
        for section in sections:
            parsed_json.extend(extract_blocks(url, section, provider, api_token))
            time.sleep(0.5)  # 500 ms delay between each processing
    else:
        # Parallel processing using ThreadPoolExecutor; results are gathered
        # in completion order, not section order.
        with ThreadPoolExecutor() as executor:
            futures = [executor.submit(extract_blocks, url, section, provider, api_token) for section in sections]
            for future in as_completed(futures):
                parsed_json.extend(future.result())

    return parsed_json
|
||||
@@ -1,39 +1,21 @@
|
||||
import asyncio
|
||||
import os, time
|
||||
import json
|
||||
from pathlib import Path
|
||||
from selenium import webdriver
|
||||
from selenium.webdriver.chrome.service import Service
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support.ui import WebDriverWait
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
from selenium.webdriver.chrome.options import Options
|
||||
import chromedriver_autoinstaller
|
||||
from pydantic import parse_obj_as
|
||||
|
||||
from .models import UrlModel, CrawlResult
|
||||
from .database import init_db, get_cached_url, cache_url
|
||||
from .utils import *
|
||||
from .chunking_strategy import *
|
||||
from .extraction_strategy import *
|
||||
from .crawler_strategy import *
|
||||
from typing import List
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from .config import *
|
||||
|
||||
class WebCrawler:
|
||||
def __init__(self, db_path: str):
|
||||
def __init__(self, db_path: str, crawler_strategy: CrawlerStrategy = LocalSeleniumCrawlerStrategy()):
|
||||
self.db_path = db_path
|
||||
init_db(self.db_path)
|
||||
self.options = Options()
|
||||
self.options.headless = True
|
||||
self.options.add_argument("--no-sandbox")
|
||||
self.options.add_argument("--disable-dev-shm-usage")
|
||||
# make it headless
|
||||
self.options.add_argument("--headless")
|
||||
|
||||
# Automatically install or update chromedriver
|
||||
chromedriver_autoinstaller.install()
|
||||
|
||||
# Initialize WebDriver for crawling
|
||||
self.service = Service(chromedriver_autoinstaller.install())
|
||||
self.driver = webdriver.Chrome(service=self.service, options=self.options)
|
||||
self.crawler_strategy = crawler_strategy
|
||||
|
||||
# Create the .crawl4ai folder in the user's home directory if it doesn't exist
|
||||
self.crawl4ai_folder = os.path.join(Path.home(), ".crawl4ai")
|
||||
@@ -47,10 +29,15 @@ class WebCrawler:
|
||||
api_token: str = None,
|
||||
extract_blocks_flag: bool = True,
|
||||
word_count_threshold = MIN_WORD_THRESHOLD,
|
||||
use_cached_html: bool = False) -> CrawlResult:
|
||||
use_cached_html: bool = False,
|
||||
extraction_strategy: ExtractionStrategy = LLMExtractionStrategy(),
|
||||
chunking_strategy: ChunkingStrategy = RegexChunking(),
|
||||
**kwargs
|
||||
) -> CrawlResult:
|
||||
|
||||
# make sure word_count_threshold is not lesser than MIN_WORD_THRESHOLD
|
||||
# if word_count_threshold < MIN_WORD_THRESHOLD:
|
||||
# word_count_threshold = MIN_WORD_THRESHOLD
|
||||
if word_count_threshold < MIN_WORD_THRESHOLD:
|
||||
word_count_threshold = MIN_WORD_THRESHOLD
|
||||
|
||||
# Check cache first
|
||||
cached = get_cached_url(self.db_path, str(url_model.url))
|
||||
@@ -67,87 +54,38 @@ class WebCrawler:
|
||||
|
||||
|
||||
# Initialize WebDriver for crawling
|
||||
if use_cached_html:
|
||||
# load html from crawl4ai_folder/cache
|
||||
valid_file_name = str(url_model.url).replace("/", "_").replace(":", "_")
|
||||
if os.path.exists(os.path.join(self.crawl4ai_folder, "cache", valid_file_name)):
|
||||
with open(os.path.join(self.crawl4ai_folder, "cache", valid_file_name), "r") as f:
|
||||
html = f.read()
|
||||
else:
|
||||
raise Exception("Cached HTML file not found")
|
||||
|
||||
t = time.time()
|
||||
try:
|
||||
html = self.crawler_strategy.crawl(str(url_model.url))
|
||||
success = True
|
||||
error_message = ""
|
||||
else:
|
||||
service = self.service
|
||||
driver = self.driver
|
||||
|
||||
try:
|
||||
driver.get(str(url_model.url))
|
||||
WebDriverWait(driver, 10).until(
|
||||
EC.presence_of_all_elements_located((By.TAG_NAME, "html"))
|
||||
)
|
||||
html = driver.page_source
|
||||
success = True
|
||||
error_message = ""
|
||||
|
||||
# Save html in crawl4ai_folder/cache
|
||||
valid_file_name = str(url_model.url).replace("/", "_").replace(":", "_")
|
||||
with open(os.path.join(self.crawl4ai_folder, "cache", valid_file_name), "w") as f:
|
||||
f.write(html)
|
||||
|
||||
except Exception as e:
|
||||
html = ""
|
||||
success = False
|
||||
error_message = str(e)
|
||||
finally:
|
||||
driver.quit()
|
||||
except Exception as e:
|
||||
html = ""
|
||||
success = False
|
||||
error_message = str(e)
|
||||
|
||||
# Extract content from HTML
|
||||
result = get_content_of_website(html, word_count_threshold)
|
||||
cleaned_html = result.get('cleaned_html', html)
|
||||
markdown = result.get('markdown', "")
|
||||
|
||||
print("Crawling is done 🚀")
|
||||
# Print a profession LOG style message, show time taken and say crawling is done
|
||||
print(f"[LOG] 🚀 Crawling done for {url_model.url}, success: {success}, time taken: {time.time() - t} seconds")
|
||||
|
||||
|
||||
parsed_json = []
|
||||
if extract_blocks_flag:
|
||||
print(f"[LOG] 🚀 Extracting semantic blocks for {url_model.url}")
|
||||
# Split markdown into sections
|
||||
paragraphs = markdown.split('\n\n')
|
||||
sections = []
|
||||
chunks = []
|
||||
total_token_so_far = 0
|
||||
|
||||
for paragraph in paragraphs:
|
||||
if total_token_so_far < CHUNK_TOKEN_THRESHOLD:
|
||||
chunk = paragraph.split(' ')
|
||||
total_token_so_far += len(chunk) * 1.3
|
||||
chunks.append(paragraph)
|
||||
else:
|
||||
sections.append('\n\n'.join(chunks))
|
||||
chunks = [paragraph]
|
||||
total_token_so_far = len(paragraph.split(' ')) * 1.3
|
||||
|
||||
if chunks:
|
||||
sections.append('\n\n'.join(chunks))
|
||||
|
||||
# Process sections to extract blocks
|
||||
parsed_json = []
|
||||
if provider.startswith("groq/"):
|
||||
# Sequential processing with a delay
|
||||
for section in sections:
|
||||
parsed_json.extend(extract_blocks(str(url_model.url), section, provider, api_token))
|
||||
time.sleep(0.5) # 500 ms delay between each processing
|
||||
else:
|
||||
# Parallel processing using ThreadPoolExecutor
|
||||
with ThreadPoolExecutor() as executor:
|
||||
futures = [executor.submit(extract_blocks, str(url_model.url), section, provider, api_token) for section in sections]
|
||||
for future in as_completed(futures):
|
||||
parsed_json.extend(future.result())
|
||||
sections = chunking_strategy.chunk(markdown)
|
||||
# sections = merge_chunks_based_on_token_threshold(sections, CHUNK_TOKEN_THRESHOLD)
|
||||
|
||||
parsed_json = extraction_strategy.run(str(url_model.url), sections, provider, api_token)
|
||||
parsed_json = json.dumps(parsed_json)
|
||||
print(f"[LOG] 🚀 Extraction done for {url_model.url}")
|
||||
else:
|
||||
parsed_json = "{}"
|
||||
print(f"[LOG] 🚀 Skipping extraction for {url_model.url}")
|
||||
|
||||
# Cache the result
|
||||
cleaned_html = beautify_html(cleaned_html)
|
||||
@@ -163,7 +101,23 @@ class WebCrawler:
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
def fetch_pages(self, url_models: List[UrlModel], provider: str = DEFAULT_PROVIDER, api_token: str = None) -> List[CrawlResult]:
|
||||
def fetch_pages(self, url_models: List[UrlModel], provider: str = DEFAULT_PROVIDER, api_token: str = None,
|
||||
extract_blocks_flag: bool = True, word_count_threshold=MIN_WORD_THRESHOLD,
|
||||
use_cached_html: bool = False, extraction_strategy: ExtractionStrategy = LLMExtractionStrategy(),
|
||||
chunking_strategy: ChunkingStrategy = RegexChunking(), **kwargs) -> List[CrawlResult]:
|
||||
|
||||
def fetch_page_wrapper(url_model, *args, **kwargs):
|
||||
return self.fetch_page(url_model, *args, **kwargs)
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
results = list(executor.map(self.fetch_page, url_models, [provider] * len(url_models), [api_token] * len(url_models)))
|
||||
results = list(executor.map(fetch_page_wrapper, url_models,
|
||||
[provider] * len(url_models),
|
||||
[api_token] * len(url_models),
|
||||
[extract_blocks_flag] * len(url_models),
|
||||
[word_count_threshold] * len(url_models),
|
||||
[use_cached_html] * len(url_models),
|
||||
[extraction_strategy] * len(url_models),
|
||||
[chunking_strategy] * len(url_models),
|
||||
*[kwargs] * len(url_models)))
|
||||
|
||||
return results
|
||||
@@ -8,11 +8,12 @@ def main():
|
||||
crawler = WebCrawler(db_path='crawler_data.db')
|
||||
|
||||
# Fetch a single page
|
||||
single_url = UrlModel(url='https://www.nbcnews.com/business', forced=False)
|
||||
single_url = UrlModel(url='https://www.nbcnews.com/business', forced=True)
|
||||
result = crawler.fetch_page(
|
||||
single_url,
|
||||
provider= "openai/gpt-3.5-turbo",
|
||||
api_token = os.getenv('OPENAI_API_KEY'),
|
||||
use_cached_html = True,
|
||||
extract_blocks_flag=True,
|
||||
word_count_threshold=10
|
||||
)
|
||||
|
||||
144
models/reuters/config.cfg
Normal file
144
models/reuters/config.cfg
Normal file
@@ -0,0 +1,144 @@
|
||||
[paths]
|
||||
train = null
|
||||
dev = null
|
||||
vectors = null
|
||||
init_tok2vec = null
|
||||
|
||||
[system]
|
||||
seed = 0
|
||||
gpu_allocator = null
|
||||
|
||||
[nlp]
|
||||
lang = "en"
|
||||
pipeline = ["textcat_multilabel"]
|
||||
disabled = []
|
||||
before_creation = null
|
||||
after_creation = null
|
||||
after_pipeline_creation = null
|
||||
batch_size = 1000
|
||||
tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"}
|
||||
vectors = {"@vectors":"spacy.Vectors.v1"}
|
||||
|
||||
[components]
|
||||
|
||||
[components.textcat_multilabel]
|
||||
factory = "textcat_multilabel"
|
||||
scorer = {"@scorers":"spacy.textcat_multilabel_scorer.v2"}
|
||||
threshold = 0.5
|
||||
|
||||
[components.textcat_multilabel.model]
|
||||
@architectures = "spacy.TextCatEnsemble.v2"
|
||||
nO = null
|
||||
|
||||
[components.textcat_multilabel.model.linear_model]
|
||||
@architectures = "spacy.TextCatBOW.v3"
|
||||
exclusive_classes = false
|
||||
length = 262144
|
||||
ngram_size = 1
|
||||
no_output_layer = false
|
||||
nO = null
|
||||
|
||||
[components.textcat_multilabel.model.tok2vec]
|
||||
@architectures = "spacy.Tok2Vec.v2"
|
||||
|
||||
[components.textcat_multilabel.model.tok2vec.embed]
|
||||
@architectures = "spacy.MultiHashEmbed.v2"
|
||||
width = 64
|
||||
rows = [2000,2000,500,1000,500]
|
||||
attrs = ["NORM","LOWER","PREFIX","SUFFIX","SHAPE"]
|
||||
include_static_vectors = false
|
||||
|
||||
[components.textcat_multilabel.model.tok2vec.encode]
|
||||
@architectures = "spacy.MaxoutWindowEncoder.v2"
|
||||
width = 64
|
||||
window_size = 1
|
||||
maxout_pieces = 3
|
||||
depth = 2
|
||||
|
||||
[corpora]
|
||||
|
||||
[corpora.dev]
|
||||
@readers = "spacy.Corpus.v1"
|
||||
path = ${paths.dev}
|
||||
gold_preproc = false
|
||||
max_length = 0
|
||||
limit = 0
|
||||
augmenter = null
|
||||
|
||||
[corpora.train]
|
||||
@readers = "spacy.Corpus.v1"
|
||||
path = ${paths.train}
|
||||
gold_preproc = false
|
||||
max_length = 0
|
||||
limit = 0
|
||||
augmenter = null
|
||||
|
||||
[training]
|
||||
seed = ${system.seed}
|
||||
gpu_allocator = ${system.gpu_allocator}
|
||||
dropout = 0.1
|
||||
accumulate_gradient = 1
|
||||
patience = 1600
|
||||
max_epochs = 0
|
||||
max_steps = 20000
|
||||
eval_frequency = 200
|
||||
frozen_components = []
|
||||
annotating_components = []
|
||||
dev_corpus = "corpora.dev"
|
||||
train_corpus = "corpora.train"
|
||||
before_to_disk = null
|
||||
before_update = null
|
||||
|
||||
[training.batcher]
|
||||
@batchers = "spacy.batch_by_words.v1"
|
||||
discard_oversize = false
|
||||
tolerance = 0.2
|
||||
get_length = null
|
||||
|
||||
[training.batcher.size]
|
||||
@schedules = "compounding.v1"
|
||||
start = 100
|
||||
stop = 1000
|
||||
compound = 1.001
|
||||
t = 0.0
|
||||
|
||||
[training.logger]
|
||||
@loggers = "spacy.ConsoleLogger.v1"
|
||||
progress_bar = false
|
||||
|
||||
[training.optimizer]
|
||||
@optimizers = "Adam.v1"
|
||||
beta1 = 0.9
|
||||
beta2 = 0.999
|
||||
L2_is_weight_decay = true
|
||||
L2 = 0.01
|
||||
grad_clip = 1.0
|
||||
use_averages = false
|
||||
eps = 0.00000001
|
||||
learn_rate = 0.001
|
||||
|
||||
[training.score_weights]
|
||||
cats_score = 1.0
|
||||
cats_score_desc = null
|
||||
cats_micro_p = null
|
||||
cats_micro_r = null
|
||||
cats_micro_f = null
|
||||
cats_macro_p = null
|
||||
cats_macro_r = null
|
||||
cats_macro_f = null
|
||||
cats_macro_auc = null
|
||||
cats_f_per_type = null
|
||||
|
||||
[pretraining]
|
||||
|
||||
[initialize]
|
||||
vectors = ${paths.vectors}
|
||||
init_tok2vec = ${paths.init_tok2vec}
|
||||
vocab_data = null
|
||||
lookups = null
|
||||
before_init = null
|
||||
after_init = null
|
||||
|
||||
[initialize.components]
|
||||
|
||||
[initialize.tokenizer]
|
||||
122
models/reuters/meta.json
Normal file
122
models/reuters/meta.json
Normal file
@@ -0,0 +1,122 @@
|
||||
{
|
||||
"lang":"en",
|
||||
"name":"pipeline",
|
||||
"version":"0.0.0",
|
||||
"spacy_version":">=3.7.4,<3.8.0",
|
||||
"description":"",
|
||||
"author":"",
|
||||
"email":"",
|
||||
"url":"",
|
||||
"license":"",
|
||||
"spacy_git_version":"bff8725f4",
|
||||
"vectors":{
|
||||
"width":0,
|
||||
"vectors":0,
|
||||
"keys":0,
|
||||
"name":null,
|
||||
"mode":"default"
|
||||
},
|
||||
"labels":{
|
||||
"textcat_multilabel":[
|
||||
"acq",
|
||||
"alum",
|
||||
"barley",
|
||||
"bop",
|
||||
"carcass",
|
||||
"castor-oil",
|
||||
"cocoa",
|
||||
"coconut",
|
||||
"coconut-oil",
|
||||
"coffee",
|
||||
"copper",
|
||||
"copra-cake",
|
||||
"corn",
|
||||
"cotton",
|
||||
"cotton-oil",
|
||||
"cpi",
|
||||
"cpu",
|
||||
"crude",
|
||||
"dfl",
|
||||
"dlr",
|
||||
"dmk",
|
||||
"earn",
|
||||
"fuel",
|
||||
"gas",
|
||||
"gnp",
|
||||
"gold",
|
||||
"grain",
|
||||
"groundnut",
|
||||
"groundnut-oil",
|
||||
"heat",
|
||||
"hog",
|
||||
"housing",
|
||||
"income",
|
||||
"instal-debt",
|
||||
"interest",
|
||||
"ipi",
|
||||
"iron-steel",
|
||||
"jet",
|
||||
"jobs",
|
||||
"l-cattle",
|
||||
"lead",
|
||||
"lei",
|
||||
"lin-oil",
|
||||
"livestock",
|
||||
"lumber",
|
||||
"meal-feed",
|
||||
"money-fx",
|
||||
"money-supply",
|
||||
"naphtha",
|
||||
"nat-gas",
|
||||
"nickel",
|
||||
"nkr",
|
||||
"nzdlr",
|
||||
"oat",
|
||||
"oilseed",
|
||||
"orange",
|
||||
"palladium",
|
||||
"palm-oil",
|
||||
"palmkernel",
|
||||
"pet-chem",
|
||||
"platinum",
|
||||
"potato",
|
||||
"propane",
|
||||
"rand",
|
||||
"rape-oil",
|
||||
"rapeseed",
|
||||
"reserves",
|
||||
"retail",
|
||||
"rice",
|
||||
"rubber",
|
||||
"rye",
|
||||
"ship",
|
||||
"silver",
|
||||
"sorghum",
|
||||
"soy-meal",
|
||||
"soy-oil",
|
||||
"soybean",
|
||||
"strategic-metal",
|
||||
"sugar",
|
||||
"sun-meal",
|
||||
"sun-oil",
|
||||
"sunseed",
|
||||
"tea",
|
||||
"tin",
|
||||
"trade",
|
||||
"veg-oil",
|
||||
"wheat",
|
||||
"wpi",
|
||||
"yen",
|
||||
"zinc"
|
||||
]
|
||||
},
|
||||
"pipeline":[
|
||||
"textcat_multilabel"
|
||||
],
|
||||
"components":[
|
||||
"textcat_multilabel"
|
||||
],
|
||||
"disabled":[
|
||||
|
||||
]
|
||||
}
|
||||
95
models/reuters/textcat_multilabel/cfg
Normal file
95
models/reuters/textcat_multilabel/cfg
Normal file
@@ -0,0 +1,95 @@
|
||||
{
|
||||
"labels":[
|
||||
"acq",
|
||||
"alum",
|
||||
"barley",
|
||||
"bop",
|
||||
"carcass",
|
||||
"castor-oil",
|
||||
"cocoa",
|
||||
"coconut",
|
||||
"coconut-oil",
|
||||
"coffee",
|
||||
"copper",
|
||||
"copra-cake",
|
||||
"corn",
|
||||
"cotton",
|
||||
"cotton-oil",
|
||||
"cpi",
|
||||
"cpu",
|
||||
"crude",
|
||||
"dfl",
|
||||
"dlr",
|
||||
"dmk",
|
||||
"earn",
|
||||
"fuel",
|
||||
"gas",
|
||||
"gnp",
|
||||
"gold",
|
||||
"grain",
|
||||
"groundnut",
|
||||
"groundnut-oil",
|
||||
"heat",
|
||||
"hog",
|
||||
"housing",
|
||||
"income",
|
||||
"instal-debt",
|
||||
"interest",
|
||||
"ipi",
|
||||
"iron-steel",
|
||||
"jet",
|
||||
"jobs",
|
||||
"l-cattle",
|
||||
"lead",
|
||||
"lei",
|
||||
"lin-oil",
|
||||
"livestock",
|
||||
"lumber",
|
||||
"meal-feed",
|
||||
"money-fx",
|
||||
"money-supply",
|
||||
"naphtha",
|
||||
"nat-gas",
|
||||
"nickel",
|
||||
"nkr",
|
||||
"nzdlr",
|
||||
"oat",
|
||||
"oilseed",
|
||||
"orange",
|
||||
"palladium",
|
||||
"palm-oil",
|
||||
"palmkernel",
|
||||
"pet-chem",
|
||||
"platinum",
|
||||
"potato",
|
||||
"propane",
|
||||
"rand",
|
||||
"rape-oil",
|
||||
"rapeseed",
|
||||
"reserves",
|
||||
"retail",
|
||||
"rice",
|
||||
"rubber",
|
||||
"rye",
|
||||
"ship",
|
||||
"silver",
|
||||
"sorghum",
|
||||
"soy-meal",
|
||||
"soy-oil",
|
||||
"soybean",
|
||||
"strategic-metal",
|
||||
"sugar",
|
||||
"sun-meal",
|
||||
"sun-oil",
|
||||
"sunseed",
|
||||
"tea",
|
||||
"tin",
|
||||
"trade",
|
||||
"veg-oil",
|
||||
"wheat",
|
||||
"wpi",
|
||||
"yen",
|
||||
"zinc"
|
||||
],
|
||||
"threshold":0.5
|
||||
}
|
||||
BIN
models/reuters/textcat_multilabel/model
Normal file
BIN
models/reuters/textcat_multilabel/model
Normal file
Binary file not shown.
3
models/reuters/tokenizer
Normal file
3
models/reuters/tokenizer
Normal file
File diff suppressed because one or more lines are too long
1
models/reuters/vocab/key2row
Normal file
1
models/reuters/vocab/key2row
Normal file
@@ -0,0 +1 @@
|
||||
<EFBFBD>
|
||||
1
models/reuters/vocab/lookups.bin
Normal file
1
models/reuters/vocab/lookups.bin
Normal file
@@ -0,0 +1 @@
|
||||
<EFBFBD>
|
||||
83457
models/reuters/vocab/strings.json
Normal file
83457
models/reuters/vocab/strings.json
Normal file
File diff suppressed because it is too large
Load Diff
BIN
models/reuters/vocab/vectors
Normal file
BIN
models/reuters/vocab/vectors
Normal file
Binary file not shown.
3
models/reuters/vocab/vectors.cfg
Normal file
3
models/reuters/vocab/vectors.cfg
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"mode":"default"
|
||||
}
|
||||
@@ -11,3 +11,5 @@ bs4
|
||||
html2text
|
||||
litellm
|
||||
python-dotenv
|
||||
nltk
|
||||
spacy
|
||||
Reference in New Issue
Block a user