# LSTM language model: trains a next-word predictor on the challenge training corpus
# and (in the commented-out section below) produces top-k word predictions for dev-0.
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, IterableDataset
import numpy as np
from collections import Counter
import string
import lzma
import pdb
import copy
import itertools
import regex as re
import pickle
import utils
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = 'cuda'

# Load the pre-built vocabulary and make out-of-vocabulary words map to <unk>.
with open("vocab.pickle", 'rb') as handle:
    vocab = pickle.load(handle)
vocab.set_default_index(vocab['<unk>'])


def get_word_lines_from_file(file_name):
    """Yield (input, target) index tensors: seq_len-long windows shifted by one token."""
    counter = 0
    seq_len = 10
    with lzma.open(file_name, 'r') as fh:
        for line in fh:
            counter += 1
            # if counter == 100000:
            #     break
            line = line.decode("utf-8")
            line_splitted = utils.get_words_from_line(line)
            vocab_line = [vocab[t] for t in line_splitted]
            for i in range(len(vocab_line) - seq_len):
                yield torch.tensor(vocab_line[i:i + seq_len]), torch.tensor(vocab_line[i + 1:i + seq_len + 1])


class Grams_10(IterableDataset):
    """Iterable dataset streaming 10-token windows from the xz-compressed corpus."""

    def __init__(self, text_file, vocab):
        self.vocab = vocab
        self.vocab.set_default_index(self.vocab['<unk>'])
        self.text_file = text_file

    def __iter__(self):
        return get_word_lines_from_file(self.text_file)


vocab_size = utils.vocab_size
train_dataset = Grams_10('train/in.tsv.xz', vocab)
BATCH_SIZE = 1024


class Model(nn.Module):
    """Embedding -> bidirectional LSTM -> linear projection onto the vocabulary."""

    def __init__(self, vocab_size):
        super(Model, self).__init__()
        self.lstm_size = 150
        self.embedding_dim = 200
        self.num_layers = 1
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            batch_first=True,
            bidirectional=True,
            # dropout=0.2,
        )
        # Two directions, hence lstm_size * 2 features going into the classifier.
        self.fc = nn.Linear(self.lstm_size * 2, vocab_size)

    def forward(self, x, prev_state=None):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, batch_size):
        # Hidden and cell states are shaped (num_layers * num_directions, batch, hidden).
        return (torch.zeros(self.num_layers * 2, batch_size, self.lstm_size).to(device),
                torch.zeros(self.num_layers * 2, batch_size, self.lstm_size).to(device))


def train(dataloader, model, max_epochs):
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    for epoch in range(max_epochs):
        step = 0
        for batch_i, (x, y) in enumerate(dataloader):
            # pdb.set_trace()
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            y_pred, (state_h, state_c) = model(x)
            # pdb.set_trace()
            # CrossEntropyLoss expects (batch, classes, seq_len), hence the transpose.
            loss = criterion(y_pred.transpose(1, 2), y)

            loss.backward()
            optimizer.step()

            step += 1
            if step % 500 == 0:
                print({'epoch': epoch, 'step': step, 'loss': loss.item()})
                # torch.save(model.state_dict(), f'lstm_step_{step}.bin')
            if step % 5000 == 0:
                print({'epoch': epoch, 'step': step, 'loss': loss.item()})
                torch.save(model.state_dict(), f'lstm_step_{step}.bin')
        torch.save(model.state_dict(), f'lstm_epoch_{epoch}.bin')
        # break


print('Starting training')
model = Model(vocab_size=vocab_size).to(device)
dataset = DataLoader(train_dataset, batch_size=BATCH_SIZE)
train(dataset, model, 1)
torch.save(model.state_dict(), 'lstm.bin')


# def predict(model, text_splitted):
#     model.eval()
#     words = text_splitted
#     x = torch.tensor([[vocab[w] for w in words]]).to(device)
#     state_h, state_c = model.init_state(x.size()[0])
#     y_pred, (state_h, state_c) = model(x, (state_h, state_c))
#     last_word_logits = y_pred[0][-1]
#     p = torch.nn.functional.softmax(last_word_logits, dim=0)
#     top = torch.topk(p, 64)
#     top_indices = top.indices.tolist()
#     top_probs = top.values.tolist()
#     top_words = vocab.lookup_tokens(top_indices)
#     return top_words, top_probs


# print('Starting prediction')
# inference_result = []
# with lzma.open('dev-0/in.tsv.xz', 'r') as file:
#     for line in file:
#         line = line.decode("utf-8")
#         line = line.rstrip()
#         line = line.translate(str.maketrans('', '', string.punctuation))
#         line_splitted_by_tab = line.split('\t')
#         left_context = line_splitted_by_tab[-2]
#         left_context_splitted = list(utils.get_words_from_line(left_context))
#         top_words, top_probs = predict(model, left_context_splitted)
#         string_to_print = ''
#         sum_probs = 0
#         for w, p in zip(top_words, top_probs):
#             # print(top_words)
#             if '<unk>' in w:
#                 continue
#             string_to_print += f"{w}:{p} "
#             sum_probs += p
#         if string_to_print == '':
#             inference_result.append("the:0.2 a:0.3 :0.5")
#             continue
#         unknown_prob = 1 - sum_probs
#         string_to_print += f":{unknown_prob}"
#         inference_result.append(string_to_print)
#
# with open('dev-0/out.tsv', 'w') as f:
#     for line in inference_result:
#         f.write(line + '\n')

print('All done')
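
# A minimal sketch (not part of the original run) of how the checkpoint saved above
# could be loaded back in a fresh process before calling predict(); it assumes
# vocab.pickle and lstm.bin produced by this script are present.
# loaded_model = Model(vocab_size=vocab_size).to(device)
# loaded_model.load_state_dict(torch.load('lstm.bin', map_location=device))
# loaded_model.eval()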