import spacy
import pandas as pd
import numpy as np
import random
import csv
import os

import torch
import torch.nn.functional as F
from torch.optim import AdamW  # transformers' own AdamW is deprecated in recent versions
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, get_linear_schedule_with_warmup
from tqdm import tqdm, trange

nlp = spacy.load("en_core_web_sm")


class NewsEntry(Dataset):
    """Wraps raw text rows as GPT-2 token tensors, framed by a control code and <|endoftext|>."""

    def __init__(self, control_code, data, truncate=False, gpt2_type="gpt2", max_length=1024):
        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)
        self.data = []
        for row in data:
            self.data.append(torch.tensor(
                self.tokenizer.encode(f"<|{control_code}|>{row[:max_length]}<|endoftext|>")
            ))
        if truncate:
            self.data = self.data[:20000]
        self.data_count = len(self.data)

    def __len__(self):
        return self.data_count

    def __getitem__(self, item):
        return self.data[item]


class Vocabulary:
    """Maps spaCy lemmas to integer ids; unseen lemmas map to -1."""

    def __init__(self) -> None:
        self.__UNKNOWN__ = -1
        self.vocab = {}

    def create_vocab(self, data):
        counter = 0
        vocab = {}
        for row in data:
            for word in nlp(row):
                ex = word.lemma_
                if ex not in vocab:
                    vocab[ex] = counter
                    counter += 1
        self.vocab = vocab

    def word_to_number(self, word):
        # Intended for single words: only the first spaCy token is considered.
        for token in nlp(word):
            ex = token.lemma_
            if ex in self.vocab:
                return self.vocab[ex]
            return self.__UNKNOWN__
        return self.__UNKNOWN__

    def sentence_to_numbers(self, seq):
        result = []
        for word in nlp(seq):
            ex = word.lemma_
            if ex in self.vocab:
                result.append(self.vocab[ex])
            else:
                result.append(self.__UNKNOWN__)
        return result

    def sequence_to_numbers(self, seq):
        return [self.word_to_number(x) for x in seq]


def pack_tensor(new_tensor, packed_tensor, max_seq_len):
    """Packs short examples into one sequence of at most max_seq_len tokens.

    Returns (tensor, carry_on, remainder): carry_on is True while the packed
    tensor still has room; remainder is the example that did not fit and should
    start the next pack.
    """
    if packed_tensor is None:
        return new_tensor, True, None
    if new_tensor.size()[1] + packed_tensor.size()[1] > max_seq_len:
        return packed_tensor, False, new_tensor
    packed_tensor = torch.cat([new_tensor, packed_tensor[:, 1:]], dim=1)
    return packed_tensor, True, None


class POC:
    """Proof of concept: fine-tune GPT-2 on the packed dataset and sample from it."""

    def __init__(self) -> None:
        pass

    def train(
        self,
        dataset,
        model,
        tokenizer,
        batch_size=16,
        epochs=5,
        lr=2e-5,
        max_seq_len=400,
        warmup_steps=200,
        gpt2_type="gpt2",
        output_dir=".",
        output_prefix="model",
        test_mode=False,
        save_model_on_epoch=False,
    ):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        model.train()

        # Examples are loaded one at a time, packed into longer sequences, and
        # gradients are accumulated over `batch_size` packed sequences per step.
        train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

        optimizer = AdamW(model.parameters(), lr=lr)
        # The scheduler steps once per accumulated batch, so the total number of
        # decay steps is roughly epochs * len(dataloader) / batch_size.
        total_steps = max(1, epochs * len(train_dataloader) // batch_size)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
        )

        loss = 0
        accumulating_batch_count = 0
        input_tensor = None

        for epoch in range(epochs):
            print(f"Training epoch {epoch}")
            print(loss)
            for idx, entry in tqdm(enumerate(train_dataloader)):
                (input_tensor, carry_on, remainder) = pack_tensor(entry, input_tensor, 768)

                # Keep packing until the sequence is full (or the data runs out).
                if carry_on and idx != len(train_dataloader) - 1:
                    continue

                input_tensor = input_tensor.to(device)
                outputs = model(input_tensor, labels=input_tensor)
                loss = outputs[0]
                loss.backward()

                if (accumulating_batch_count % batch_size) == 0:
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    model.zero_grad()

                accumulating_batch_count += 1
                input_tensor = None

            if save_model_on_epoch:
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir, f"{output_prefix}-{epoch}.pt"),
                )
        return model

    def generate(
        self,
        model,
        tokenizer,
        prompt,
        entry_count=10,
        entry_length=30,  # maximum number of tokens generated per entry
        top_p=0.8,
        temperature=1.0,
    ):
        model.eval()
        device = next(model.parameters()).device
        generated_num = 0
        generated_list = []

        filter_value = -float("Inf")

        with torch.no_grad():
            for entry_idx in trange(entry_count):
                entry_finished = False
                generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
                if entry_idx % 100 == 0:
                    print(entry_idx)

                for i in range(entry_length):
                    outputs = model(generated, labels=generated)
                    loss, logits = outputs[:2]
                    logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)

                    # Nucleus (top-p) sampling: keep the smallest set of tokens whose
                    # cumulative probability exceeds top_p and mask out the rest.
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
                    logits[:, indices_to_remove] = filter_value

                    next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
                    generated = torch.cat((generated, next_token), dim=1)

                    if next_token.item() in tokenizer.encode("<|endoftext|>"):
                        entry_finished = True

                    if entry_finished:
                        generated_num += 1
                        output_list = list(generated.squeeze().cpu().numpy())
                        output_text = tokenizer.decode(output_list)
                        generated_list.append(output_text)
                        break

                if not entry_finished:
                    output_list = list(generated.squeeze().cpu().numpy())
                    output_text = f"{tokenizer.decode(output_list)}<|endoftext|>"
                    generated_list.append(output_text)

        return generated_list
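

# --- Example usage -----------------------------------------------------------
# A minimal sketch of how the pieces above fit together. The CSV path
# ("headlines.csv") and the "title" column are hypothetical placeholders;
# substitute your own corpus of text rows.
if __name__ == "__main__":
    df = pd.read_csv("headlines.csv")          # hypothetical input file
    texts = df["title"].astype(str).tolist()   # hypothetical column name

    # The control code "news" here must match the prompt prefix used at generation time.
    dataset = NewsEntry("news", texts, truncate=True, gpt2_type="gpt2")

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    poc = POC()
    model = poc.train(dataset, model, tokenizer, epochs=1, save_model_on_epoch=True)

    samples = poc.generate(model, tokenizer, "<|news|>", entry_count=3, entry_length=30)
    for sample in samples:
        print(sample)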