"""Fine-tune, evaluate, or sample from a GPT-2 model on the BBC News dataset.

Set ``mode`` below to one of ``"train"``, ``"test"``, or ``"generate"`` and run
the script directly. Training/evaluation read the BBC News CSVs; generation
only needs a saved checkpoint (``model_name``) and a ``prompt``.
"""
import pandas as pd
import torch
from nltk import word_tokenize
from transformers import GPT2Tokenizer, GPT2LMHeadModel

from utils import Vocabulary  # noqa: F401 -- kept for compatibility with the rest of the project
from utils import POC
from utils import NewsEntry

# What this run does: "train", "test", or "generate".
mode = "generate"

# Seed text used only by the "generate" branch.
prompt = """
Mr Johnson's bakerys is bad. My bakery is good.
"""

# Checkpoint produced by a previous training run (loaded by "generate"/"test").
model_name = "model-4.pt"

# GPT-2's positional embeddings cap inputs at 1024 tokens, so longer
# articles are dropped rather than truncated.
length_limit = 1024


def _load_filtered(path):
    """Read a BBC News CSV and drop rows whose Text exceeds ``length_limit`` tokens."""
    df = pd.read_csv(path)
    return df[df["Text"].apply(lambda x: len(word_tokenize(x)) < length_limit)]


def _load_checkpoint(model):
    """Load ``model_name`` into ``model``, mapping to CPU when no GPU is available."""
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    map_location = None if torch.cuda.is_available() else torch.device("cpu")
    model.load_state_dict(torch.load(model_name, map_location=map_location))


tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

if mode == "train":
    # Data is loaded per-mode: tokenizing every article with nltk is
    # expensive, and "generate" needs no data at all.
    train = _load_filtered("data/BBC_News_Train.csv")
    dataset = NewsEntry(train["Text"], train["Text"], truncate=False, gpt2_type="gpt2")
    poc = POC()
    print("Starting training")
    poc.train(dataset, model, tokenizer, output_dir="./", save_model_on_epoch=True)
elif mode == "generate":
    poc = POC()
    _load_checkpoint(model)
    x = poc.generate(model, tokenizer, prompt, entry_count=1, entry_length=300)
    print(x)
elif mode == "test":
    test = _load_filtered("data/BBC_News_Test.csv")
    dataset = NewsEntry(test["Text"], test["Text"], truncate=False, gpt2_type="gpt2")
    poc = POC()
    _load_checkpoint(model)
    # NOTE(review): this branch loads the model and dataset but runs no
    # evaluation -- TODO confirm whether the evaluation step was lost.