# pbr-ayct-core/test.py
# Last modified: 2022-01-24 19:22:04 +01:00
import pandas as pd
from nltk import word_tokenize
from utils import Vocabulary
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from utils import POC
from utils import NewsEntry
import torch
# def encode_dataset(data):
# return data.apply(lambda x: vocabulary.sentence_to_numbers(x))
# mode = "train"
# mode = "test"
mode = "generate"
prompt = """
Mr Johnson's bakerys is bad. My bakery is good.
"""
model_name = "model-4.pt"
train = pd.read_csv("data/BBC_News_Train.csv")
test = pd.read_csv("data/BBC_News_Test.csv")
length_limit = 1024
train = train[train["Text"].apply(lambda x: len(word_tokenize(x)) < length_limit)]
test = test[test["Text"].apply(lambda x: len(word_tokenize(x)) < length_limit)]
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
if mode == "train":
    # Fine-tune GPT-2 on the filtered training articles.
    dataset = NewsEntry(train["Text"], train["Text"], truncate=False, gpt2_type="gpt2")
    poc = POC()
    print("Starting training")
    poc.train(dataset, model, tokenizer, output_dir="./", save_model_on_epoch=True)
elif mode == "generate":
    # Continue `prompt` using a previously trained checkpoint.
    poc = POC()
    # map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only
    # machine; load_state_dict then copies the weights onto the model's
    # current device, so this is safe either way.
    model.load_state_dict(torch.load(model_name, map_location="cpu"))
    x = poc.generate(model, tokenizer, prompt, entry_count=1, entry_length=300)
    print(x)
elif mode == "test":
dataset = NewsEntry(test["Text"], test["Text"], truncate=False, gpt2_type="gpt2")
poc = POC()
model.load_state_dict(torch.load(model_name))