# GPT-2 fine-tuning / generation proof-of-concept on the BBC News dataset.
|
import pandas as pd
|
||
|
from nltk import word_tokenize
|
||
|
from utils import Vocabulary
|
||
|
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||
|
from utils import POC
|
||
|
from utils import NewsEntry
|
||
|
import torch
|
||
|
|
||
|
# Run mode selector: "train" fine-tunes GPT-2 on the BBC training split,
# "test" loads a saved checkpoint against the test split, and "generate"
# samples a continuation of `prompt` from a saved checkpoint.
mode = "generate"

# Seed text handed to the language model when mode == "generate".
prompt = """
Mr Johnson's bakerys is bad. My bakery is good.
"""

# Checkpoint file (a torch state_dict) loaded by the "generate" and
# "test" branches below.
model_name = "model-4.pt"
|
||
|
|
||
|
# Load the BBC News splits and drop articles that would overflow GPT-2's
# context window. Length is measured with NLTK word tokenization, which
# only approximates the GPT-2 BPE token count — TODO confirm that is close
# enough for the 1024-token limit in practice.
train = pd.read_csv("data/BBC_News_Train.csv")
test = pd.read_csv("data/BBC_News_Test.csv")
length_limit = 1024


def _fits_context(article):
    """Return True when `article` word-tokenizes to fewer than `length_limit` tokens."""
    return len(word_tokenize(article)) < length_limit


train = train[train["Text"].map(_fits_context)]
test = test[test["Text"].map(_fits_context)]

# Pretrained GPT-2 tokenizer and LM-head model from the Hugging Face hub.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
|
||
|
|
||
|
|
||
|
if mode == "train":
    # Fine-tune GPT-2 on the filtered training articles.
    # NOTE(review): train["Text"] is passed twice to NewsEntry — presumably
    # (inputs, targets) for language-model training; confirm against the
    # utils.NewsEntry signature.
    dataset = NewsEntry(train["Text"], train["Text"], truncate=False, gpt2_type="gpt2")
    poc = POC()
    print("Starting training")
    poc.train(dataset, model, tokenizer, output_dir="./", save_model_on_epoch=True)

elif mode == "generate":
    # Sample from a fine-tuned checkpoint rather than the vanilla weights.
    poc = POC()
    # NOTE(review): torch.load without map_location fails on a CPU-only
    # machine when the checkpoint was saved from GPU — consider
    # torch.load(model_name, map_location="cpu") if that setup is possible.
    model.load_state_dict(torch.load(model_name))
    model.eval()  # disable dropout so sampling runs in inference mode
    x = poc.generate(model, tokenizer, prompt, entry_count=1, entry_length=300)
    print(x)
|
||
|
|
||
|
elif mode == "test":
    # Evaluate a fine-tuned checkpoint on the held-out test split.
    dataset = NewsEntry(test["Text"], test["Text"], truncate=False, gpt2_type="gpt2")
    poc = POC()
    model.load_state_dict(torch.load(model_name))
    # NOTE(review): this branch appears truncated in the visible chunk —
    # the call that actually runs the evaluation (presumably using
    # `dataset` and `poc`) is past the end of this view.