146 lines
4.9 KiB
Python
146 lines
4.9 KiB
Python
import spacy
|
|
from spacy.training import Example
|
|
import random
|
|
import nltk
|
|
from nltk.corpus import reuters
|
|
import torch
|
|
|
|
def save_spacy_model_as_torch(nlp, model_dir="models/reuters"):
|
|
# Extract the TextCategorizer component
|
|
textcat = nlp.get_pipe("textcat_multilabel")
|
|
|
|
# Convert the weights to a PyTorch state dictionary
|
|
state_dict = {name: torch.tensor(param.data) for name, param in textcat.model.named_parameters()}
|
|
|
|
# Save the state dictionary
|
|
torch.save(state_dict, f"{model_dir}/model_weights.pth")
|
|
|
|
# Extract and save the vocabulary
|
|
vocab = extract_vocab(nlp)
|
|
with open(f"{model_dir}/vocab.txt", "w") as vocab_file:
|
|
for word, idx in vocab.items():
|
|
vocab_file.write(f"{word}\t{idx}\n")
|
|
|
|
print(f"Model weights and vocabulary saved to: {model_dir}")
|
|
|
|
def extract_vocab(nlp):
|
|
# Extract vocabulary from the SpaCy model
|
|
vocab = {word: i for i, word in enumerate(nlp.vocab.strings)}
|
|
return vocab
|
|
|
|
nlp = spacy.load("models/reuters")
|
|
save_spacy_model_as_torch(nlp, model_dir="models")
|
|
|
|
def train_and_save_reuters_model(model_dir="models/reuters"):
|
|
# Ensure the Reuters corpus is downloaded
|
|
nltk.download('reuters')
|
|
nltk.download('punkt')
|
|
if not reuters.fileids():
|
|
print("Reuters corpus not found.")
|
|
return
|
|
|
|
# Load a blank English spaCy model
|
|
nlp = spacy.blank("en")
|
|
|
|
# Create a TextCategorizer with the ensemble model for multi-label classification
|
|
textcat = nlp.add_pipe("textcat_multilabel")
|
|
|
|
# Add labels to text classifier
|
|
for label in reuters.categories():
|
|
textcat.add_label(label)
|
|
|
|
# Prepare training data
|
|
train_examples = []
|
|
for fileid in reuters.fileids():
|
|
categories = reuters.categories(fileid)
|
|
text = reuters.raw(fileid)
|
|
cats = {label: label in categories for label in reuters.categories()}
|
|
# Prepare spacy Example objects
|
|
doc = nlp.make_doc(text)
|
|
example = Example.from_dict(doc, {'cats': cats})
|
|
train_examples.append(example)
|
|
|
|
# Initialize the text categorizer with the example objects
|
|
nlp.initialize(lambda: train_examples)
|
|
|
|
# Train the model
|
|
random.seed(1)
|
|
spacy.util.fix_random_seed(1)
|
|
for i in range(5): # Adjust iterations for better accuracy
|
|
random.shuffle(train_examples)
|
|
losses = {}
|
|
# Create batches of data
|
|
batches = spacy.util.minibatch(train_examples, size=8)
|
|
for batch in batches:
|
|
nlp.update(batch, drop=0.2, losses=losses)
|
|
print(f"Losses at iteration {i}: {losses}")
|
|
|
|
# Save the trained model
|
|
nlp.to_disk(model_dir)
|
|
print(f"Model saved to: {model_dir}")
|
|
|
|
def train_model(model_dir, additional_epochs=0):
|
|
# Load the model if it exists, otherwise start with a blank model
|
|
try:
|
|
nlp = spacy.load(model_dir)
|
|
print("Model loaded from disk.")
|
|
except IOError:
|
|
print("No existing model found. Starting with a new model.")
|
|
nlp = spacy.blank("en")
|
|
textcat = nlp.add_pipe("textcat_multilabel")
|
|
for label in reuters.categories():
|
|
textcat.add_label(label)
|
|
|
|
# Prepare training data
|
|
train_examples = []
|
|
for fileid in reuters.fileids():
|
|
categories = reuters.categories(fileid)
|
|
text = reuters.raw(fileid)
|
|
cats = {label: label in categories for label in reuters.categories()}
|
|
doc = nlp.make_doc(text)
|
|
example = Example.from_dict(doc, {'cats': cats})
|
|
train_examples.append(example)
|
|
|
|
# Initialize the model if it was newly created
|
|
if 'textcat_multilabel' not in nlp.pipe_names:
|
|
nlp.initialize(lambda: train_examples)
|
|
else:
|
|
print("Continuing training with existing model.")
|
|
|
|
# Train the model
|
|
random.seed(1)
|
|
spacy.util.fix_random_seed(1)
|
|
num_epochs = 5 + additional_epochs
|
|
for i in range(num_epochs):
|
|
random.shuffle(train_examples)
|
|
losses = {}
|
|
batches = spacy.util.minibatch(train_examples, size=8)
|
|
for batch in batches:
|
|
nlp.update(batch, drop=0.2, losses=losses)
|
|
print(f"Losses at iteration {i}: {losses}")
|
|
|
|
# Save the trained model
|
|
nlp.to_disk(model_dir)
|
|
print(f"Model saved to: {model_dir}")
|
|
|
|
def load_model_and_predict(model_dir, text, tok_k = 3):
|
|
# Load the trained model from the specified directory
|
|
nlp = spacy.load(model_dir)
|
|
|
|
# Process the text with the loaded model
|
|
doc = nlp(text)
|
|
|
|
# gee top 3 categories
|
|
top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
|
|
print(f"Top {tok_k} categories:")
|
|
|
|
return top_categories
|
|
|
|
if __name__ == "__main__":
|
|
train_and_save_reuters_model()
|
|
train_model("models/reuters", additional_epochs=5)
|
|
model_directory = "reuters_model_10"
|
|
print(reuters.categories())
|
|
example_text = "Apple Inc. is reportedly buying a startup for $1 billion"
|
|
r =load_model_and_predict(model_directory, example_text)
|
|
print(r) |