From 0f409ae23b4857a42196e574de029feed11156c3 Mon Sep 17 00:00:00 2001 From: Jakub Pokrywka Date: Mon, 18 Jan 2021 20:45:46 +0100 Subject: [PATCH] add rnn attention --- pytorch12.py | 5 +- pytorch13.py | 136 +++++++++++++++++++++++++++++++++++++++++++++++++++ rnnatt.py | 0 rnnreg.py | 0 4 files changed, 139 insertions(+), 2 deletions(-) mode change 100644 => 100755 pytorch12.py create mode 100755 pytorch13.py create mode 100644 rnnatt.py create mode 100644 rnnreg.py diff --git a/pytorch12.py b/pytorch12.py old mode 100644 new mode 100755 index d71f614..8d125ee --- a/pytorch12.py +++ b/pytorch12.py @@ -9,6 +9,7 @@ from torch import nn, optim nb_of_char_codes = 128 + 2 SOS_token_id = 128 # start of sentence EOS_token_id = 129 # end of sentence +MAX_LENGTH = 20 hidden_size = 32 step = 200 @@ -72,7 +73,7 @@ class DecoderRNN(nn.Module): encoder = EncoderRNN(nb_of_char_codes, hidden_size).to(device) decoder = DecoderRNN(hidden_size, nb_of_char_codes).to(device) criterion = nn.NLLLoss().to(device) -optimizer = optim.Adam((list(encoder.parameters()) + list(decoder.parameters()))) +optimizer = optim.Adam((list(encoder.parameters()) + list(decoder.parameters()))) counter = 0 losses = [] @@ -83,7 +84,7 @@ for s,t in char_source(): decoder.zero_grad() x = torch.tensor(s, dtype=torch.long, device=device) encoder_hidden = encoder.initHidden() - encoder_output = torch.zeros(hidden_size, hidden_size, device=device) + encoder_output = torch.zeros(MAX_LENGTH, hidden_size, device=device) for i in range(x.shape[0]): output, encoder_hidden = encoder(x[i].unsqueeze(0).unsqueeze(0), encoder_hidden) encoder_output[i] = output[0,0] diff --git a/pytorch13.py b/pytorch13.py new file mode 100755 index 0000000..ede597f --- /dev/null +++ b/pytorch13.py @@ -0,0 +1,136 @@ +#!/usr/bin/python3 + +# https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html + +import sys +import torch +from torch import nn, optim + +nb_of_char_codes = 128 + 2 +SOS_token_id = 128 # start of sentence +EOS_token_id = 129 # end of sentence +MAX_LENGTH = 20 + +hidden_size = 32 +step = 200 + +device = torch.device('cpu') + +f = open('eng-fra.txt') +def char_source(): + for line in f: + s, t = line.rstrip('\n').split('\t') + s_list = [] + t_list = [] + + for c in s: + c_code = ord(c) + if c_code < nb_of_char_codes: + s_list.append(ord(c)) + + for c in t: + c_code = ord(c) + if c_code < nb_of_char_codes: + t_list.append(ord(c)) + + yield s_list, t_list + +class EncoderRNN(nn.Module): + def __init__(self, input_size, hidden_size): + super(EncoderRNN, self).__init__() + self.hidden_size = hidden_size + + self.embedding = nn.Embedding(input_size, hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size) + + def forward(self, input, hidden): + embedded = self.embedding(input) + output = embedded + output, hidden = self.gru(output, hidden) + return output, hidden + + def initHidden(self): + return torch.zeros(1,1, self.hidden_size, device=device) + +class DecoderRNN(nn.Module): + def __init__(self, hidden_size, output_size, max_length=MAX_LENGTH): + super(DecoderRNN, self).__init__() + self.hidden_size = hidden_size + + self.embedding = nn.Embedding(output_size, hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size) + self.out = nn.Linear(hidden_size, output_size) + self.softmax = nn.LogSoftmax(dim=1) + + + self.attn = nn.Linear(self.hidden_size * 2, max_length) + self.attn_combine = nn.Linear(hidden_size * 2, hidden_size) + + def forward(self, input, hidden, encoder_output): + output = self.embedding(input) + + + attn_weights = torch.nn.functional.softmax(self.attn(torch.cat((output[0], hidden[0]), 1)), dim=1) + attn_applied = torch.bmm(attn_weights.unsqueeze(0), encoder_output.unsqueeze(0)) + output = torch.cat((output[0], attn_applied[0]), 1) + output = self.attn_combine(output).unsqueeze(0) + + + + output = torch.nn.functional.relu(output) + output, hidden = self.gru(output, hidden) + output = self.softmax(self.out(output[0])) + return output, hidden + + +encoder = EncoderRNN(nb_of_char_codes, hidden_size).to(device) +decoder = DecoderRNN(hidden_size, nb_of_char_codes).to(device) +criterion = nn.NLLLoss().to(device) +optimizer = optim.Adam((list(encoder.parameters()) + list(decoder.parameters()))) + +counter = 0 +losses = [] + +for s,t in char_source(): + counter += 1 + encoder.zero_grad() + decoder.zero_grad() + x = torch.tensor(s, dtype=torch.long, device=device) + encoder_hidden = encoder.initHidden() + encoder_output = torch.zeros(MAX_LENGTH, hidden_size, device=device) + for i in range(x.shape[0]): + output, encoder_hidden = encoder(x[i].unsqueeze(0).unsqueeze(0), encoder_hidden) + encoder_output[i] = output[0,0] + + decoder_hidden = encoder_hidden + + decoder_input = torch.tensor([[SOS_token_id]], device=device) + + t.append(EOS_token_id) + y = torch.tensor(t, dtype=torch.long, device=device) + loss = 0 + output_string = '' + for di in range(y.shape[0]): + decoder_output, decoder_hidden = decoder( + decoder_input, decoder_hidden, encoder_output) + topv, topi = decoder_output.topk(1) + decoder_input = topi.detach() # detach from history as input + + output_string += chr(topi) + loss += criterion(decoder_output, y[di].unsqueeze(0)) + if chr(topi) == EOS_token_id: + break + + losses.append(loss.item()) + if counter % step == 0: + # print(counter, end='\t') + avg_loss = sum(losses)/len(losses) + print(f"{counter}: {avg_loss}") + losses = [] + print('IN :\t', ''.join([chr(a) for a in s])) + print('EXP:\t', ''.join([chr(a) for a in t])) + print('OUT:\t', output_string) + + loss.backward() + optimizer.step() + diff --git a/rnnatt.py b/rnnatt.py new file mode 100644 index 0000000..e69de29 diff --git a/rnnreg.py b/rnnreg.py new file mode 100644 index 0000000..e69de29