# pbr-ayct-core/utils.py
import os
import csv
import random

import numpy as np
import pandas as pd
import spacy
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm, trange
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    AdamW,  # deprecated in newer transformers releases; torch.optim.AdamW is the usual replacement
    get_linear_schedule_with_warmup,
)

# spaCy English pipeline, used for lemmatisation in Vocabulary.
nlp = spacy.load("en_core_web_sm")
class NewsEntry(Dataset):
    """Dataset that wraps raw text rows as GPT-2 token tensors, one tensor per row."""

    def __init__(self, control_code, data, truncate=False, gpt2_type="gpt2", max_length=1024):
        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)
        self.data = []
        for row in data:
            # Prefix each row with a control marker and close it with the GPT-2 EOS token.
            # Note: max_length truncates characters, not tokens, so very long rows can
            # still exceed the model's 1024-token context window.
            self.data.append(torch.tensor(
                self.tokenizer.encode(f"<|{control_code}|>{row[:max_length]}<|endoftext|>")
            ))
        if truncate:
            self.data = self.data[:20000]
        self.data_count = len(self.data)

    def __len__(self):
        return self.data_count

    def __getitem__(self, item):
        return self.data[item]
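# Hedged usage sketch (not part of the original module): one way NewsEntry might be
# built from a CSV of articles. The file name "articles.csv" and the "text" column
# are assumptions for illustration only; "news" becomes the <|news|> control prefix.
def _example_build_dataset():
    df = pd.read_csv("articles.csv")         # assumed input file
    texts = df["text"].dropna().tolist()     # assumed column holding raw article text
    return NewsEntry("news", texts, truncate=True, gpt2_type="gpt2")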
class Vocabulary():
    """Maps lemmatised words to integer ids; unseen lemmas map to __UNKNOWN__ (-1)."""

    def __init__(self) -> None:
        self.__UNKNOWN__ = -1
        self.vocab = {}

    def create_vocab(self, data):
        counter = 0
        vocab = {}
        for row in data:
            for word in nlp(row):
                ex = word.lemma_
                if ex not in vocab:
                    vocab[ex] = counter
                    counter += 1
        self.vocab = vocab

    def word_to_number(self, word):
        # Only the first token of the input string is considered.
        word = nlp(word)
        for token in word:
            ex = token.lemma_
            if ex in self.vocab:
                return self.vocab[ex]
            else:
                return self.__UNKNOWN__

    def sentence_to_numbers(self, seq):
        result = []
        for word in nlp(seq):
            ex = word.lemma_
            if ex in self.vocab:
                result.append(self.vocab[ex])
            else:
                result.append(self.__UNKNOWN__)
        return result

    def sequence_to_numbers(self, seq):
        return [self.word_to_number(x) for x in seq]
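# Hedged usage sketch (not in the original module): because Vocabulary keys on lemmas,
# inflected forms of a word share one id, and unseen lemmas come back as -1.
def _example_vocabulary():
    vocab = Vocabulary()
    vocab.create_vocab(["The cats are sleeping", "A dog barked"])
    print(vocab.sentence_to_numbers("the cat barks"))   # lemmas already seen -> known ids
    print(vocab.word_to_number("elephant"))             # unseen lemma -> -1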
def pack_tensor(new_tensor, packed_tensor, max_seq_len):
    """Greedily concatenate tokenised entries until max_seq_len would be exceeded.

    Returns (packed_tensor, fits, remainder): if the new entry does not fit, the
    current pack is returned unchanged and the entry is handed back as the remainder.
    """
    if packed_tensor is None:
        return new_tensor, True, None
    if new_tensor.size()[1] + packed_tensor.size()[1] > max_seq_len:
        return packed_tensor, False, new_tensor
    else:
        # Drop the first token of the previously packed sequence before concatenating.
        packed_tensor = torch.cat([new_tensor, packed_tensor[:, 1:]], dim=1)
        return packed_tensor, True, None
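# Hedged sketch (not in the original module) of how pack_tensor behaves: entries are
# appended to the running pack until the next one would push it past max_seq_len, at
# which point the full pack is returned and the new entry comes back as the remainder.
def _example_pack_tensor():
    a = torch.ones(1, 300, dtype=torch.long)
    b = torch.ones(1, 300, dtype=torch.long)
    c = torch.ones(1, 300, dtype=torch.long)

    packed, fits, remainder = pack_tensor(a, None, 768)    # first entry always fits
    packed, fits, remainder = pack_tensor(b, packed, 768)  # 300 + 300 <= 768, still fits
    packed, fits, remainder = pack_tensor(c, packed, 768)  # 300 + 599 > 768: pack is full
    print(packed.size(1), fits, remainder.size(1))         # 599 False 300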
class POC():
    """Proof-of-concept wrapper around GPT-2 fine-tuning and top-p text generation."""

    def __init__(self) -> None:
        pass

    def train(
        self, dataset, model, tokenizer,
        batch_size=16, epochs=5, lr=2e-5,
        max_seq_len=400, warmup_steps=200,
        gpt2_type="gpt2", output_dir=".", output_prefix="model",
        test_mode=False, save_model_on_epoch=False,
    ):
        acc_steps = 100  # currently unused
        device = torch.device("cuda")  # training assumes a CUDA device is available
        model = model.cuda()
        model.train()

        train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

        optimizer = AdamW(model.parameters(), lr=lr)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            # Use the real number of optimisation steps; passing -1 makes the linear
            # schedule collapse the learning rate to zero right after warmup.
            num_training_steps=epochs * len(train_dataloader),
        )

        loss = 0
        accumulating_batch_count = 0
        input_tensor = None

        for epoch in range(epochs):
            print(f"Training epoch {epoch}")
            print(loss)
            for idx, entry in tqdm(enumerate(train_dataloader)):
                # Pack entries into a single sequence of up to 768 tokens
                # (the max_seq_len argument is not used here) before the forward pass.
                (input_tensor, carry_on, remainder) = pack_tensor(entry, input_tensor, 768)
                if carry_on and idx != len(train_dataloader) - 1:
                    continue

                input_tensor = input_tensor.to(device)
                outputs = model(input_tensor, labels=input_tensor)
                loss = outputs[0]
                loss.backward()

                # Gradient accumulation: only step the optimizer every batch_size packs.
                if (accumulating_batch_count % batch_size) == 0:
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    model.zero_grad()

                accumulating_batch_count += 1
                input_tensor = None

            if save_model_on_epoch:
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir, f"{output_prefix}-{epoch}.pt"),
                )
        return model
    def generate(
        self,
        model,
        tokenizer,
        prompt,
        entry_count=10,
        entry_length=30,  # maximum number of tokens generated per entry
        top_p=0.8,
        temperature=1.,
    ):
        model.eval()
        # Sample on whichever device the model currently lives on.
        device = next(model.parameters()).device

        generated_num = 0
        generated_list = []
        filter_value = -float("Inf")

        with torch.no_grad():
            for entry_idx in trange(entry_count):
                entry_finished = False
                generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
                if entry_idx % 100 == 0:
                    print(entry_idx)

                for i in range(entry_length):
                    outputs = model(generated, labels=generated)
                    loss, logits = outputs[:2]
                    logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)

                    # Nucleus (top-p) sampling: keep the smallest set of tokens whose
                    # cumulative probability exceeds top_p and mask out the rest.
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                        ..., :-1
                    ].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
                    logits[:, indices_to_remove] = filter_value

                    next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
                    generated = torch.cat((generated, next_token), dim=1)

                    # Stop once the model emits the <|endoftext|> token.
                    if next_token in tokenizer.encode("<|endoftext|>"):
                        entry_finished = True

                    if entry_finished:
                        generated_num = generated_num + 1
                        output_list = list(generated.squeeze().cpu().numpy())
                        output_text = tokenizer.decode(output_list)
                        generated_list.append(output_text)
                        break

                if not entry_finished:
                    output_list = list(generated.squeeze().cpu().numpy())
                    output_text = f"{tokenizer.decode(output_list)}<|endoftext|>"
                    generated_list.append(output_text)

        return generated_list
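# Hedged end-to-end sketch (not in the original module): wiring the pieces above
# together. The model and tokenizer names match the defaults used elsewhere in this
# file; the CSV path and "text" column are assumptions for illustration only.
if __name__ == "__main__":
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    df = pd.read_csv("articles.csv")  # assumed input file
    dataset = NewsEntry("news", df["text"].dropna().tolist(), truncate=True)

    poc = POC()
    model = poc.train(dataset, model, tokenizer, epochs=1, save_model_on_epoch=True)

    # Move the fine-tuned model back to CPU for lightweight sampling.
    model = model.cpu()
    samples = poc.generate(model, tokenizer, "<|news|>Breaking:", entry_count=3)
    for sample in samples:
        print(sample)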