#!/usr/bin/python3
"""Character-level GRU encoder/decoder trained on tab-separated sentence pairs.

Each line of the data file holds a source sentence and a target sentence
separated by a tab.  Sentences are fed to the networks one character at a
time, using the character's code point as its token id.
"""
# https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

import sys

import torch
from torch import nn, optim

nb_of_plain_codes = 128  # ids 0..127 are ordinary 7-bit character codes
nb_of_char_codes = nb_of_plain_codes + 2  # plus the two sentinel ids below
SOS_token_id = 128  # start of sentence
EOS_token_id = 129  # end of sentence
hidden_size = 32
step = 200  # report the running average loss every `step` sentence pairs
device = torch.device('cpu')


def encode_text(text):
    """Return the code points of *text* that are below ``nb_of_plain_codes``.

    Characters with larger code points are dropped.  The bound is the number
    of plain codes rather than the vocabulary size so that U+0080 and U+0081
    can never be mistaken for the SOS/EOS sentinel ids.
    """
    return [code for code in map(ord, text) if code < nb_of_plain_codes]


def char_source(path='eng-fra.txt'):
    """Yield ``(source_ids, target_ids)`` for every sentence pair in *path*.

    Args:
        path: Tab-separated text file; the first two fields of each line are
            the source and target sentences.  Any further fields are ignored.

    Yields:
        A pair of lists of integer character codes (see ``encode_text``).
        Lines with fewer than two fields (e.g. blank lines) are skipped.
    """
    with open(path, encoding='utf-8') as f:
        for line in f:
            fields = line.rstrip('\n').split('\t')
            if len(fields) < 2:
                continue
            yield encode_text(fields[0]), encode_text(fields[1])


class EncoderRNN(nn.Module):
    """Embedding layer followed by a single GRU layer."""

    def __init__(self, input_size, hidden_size):
        """Create an encoder for *input_size* token ids and *hidden_size* units."""
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        """Embed *input* and run it through the GRU.

        Args:
            input: Long tensor of token ids; this script passes shape (1, 1).
            hidden: Previous GRU hidden state, as returned by ``initHidden``
                or by an earlier call.

        Returns:
            The ``(output, hidden)`` pair produced by the GRU.
        """
        embedded = self.embedding(input)
        return self.gru(embedded, hidden)

    def initHidden(self):
        """Return an all-zero hidden state of shape (1, 1, hidden_size)."""
        return torch.zeros(1, 1, self.hidden_size, device=device)


class DecoderRNN(nn.Module):
    """Embedding -> ReLU -> GRU -> linear -> log-softmax decoder."""

    def __init__(self, hidden_size, output_size):
        """Create a decoder with *hidden_size* units over *output_size* token ids."""
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Advance the decoder by one step.

        Args:
            input: Long tensor of token ids; this script passes shape (1, 1).
            hidden: Previous GRU hidden state.

        Returns:
            ``(log_probs, hidden)`` where ``log_probs`` is the log-softmax of
            the linear projection of the first GRU output step.
        """
        output = torch.nn.functional.relu(self.embedding(input))
        output, hidden = self.gru(output, hidden)
        return self.softmax(self.out(output[0])), hidden


def train_pair(source_ids, target_ids, encoder, decoder, optimizer, criterion):
    """Run one optimisation step on a single sentence pair.

    The decoder is fed its own greedy prediction at every step (no teacher
    forcing) and stops as soon as it predicts ``EOS_token_id`` or the target
    (plus its EOS) has been consumed.

    Args:
        source_ids: List of source character codes.
        target_ids: List of target character codes; it is not modified.
        encoder: ``EncoderRNN`` instance.
        decoder: ``DecoderRNN`` instance.
        optimizer: Optimizer holding the encoder and decoder parameters.
        criterion: Loss taking ``(log_probs, target)``.

    Returns:
        ``(loss_value, output_string)``: the summed loss as a float and the
        greedily decoded text, with sentinel ids left out.
    """
    encoder.zero_grad()
    decoder.zero_grad()

    x = torch.tensor(source_ids, dtype=torch.long, device=device)
    # Build a new list so the caller's target is not mutated.
    y = torch.tensor(list(target_ids) + [EOS_token_id],
                     dtype=torch.long, device=device)

    # Only the final hidden state is used; per-step encoder outputs are not
    # stored, so the source length is not limited by any buffer size.
    encoder_hidden = encoder.initHidden()
    for token in x:
        _, encoder_hidden = encoder(token.view(1, 1), encoder_hidden)

    decoder_hidden = encoder_hidden
    decoder_input = torch.tensor([[SOS_token_id]], device=device)

    loss = 0
    decoded_chars = []
    for expected in y:
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
        _, topi = decoder_output.topk(1)
        decoder_input = topi.detach()  # detach from history as input
        loss += criterion(decoder_output, expected.unsqueeze(0))

        predicted_id = topi.item()
        # Compare ids, not characters: the EOS sentinel is an integer.
        if predicted_id == EOS_token_id:
            break
        if predicted_id < nb_of_plain_codes:
            decoded_chars.append(chr(predicted_id))

    loss.backward()
    optimizer.step()
    return loss.item(), ''.join(decoded_chars)


encoder = EncoderRNN(nb_of_char_codes, hidden_size).to(device)
decoder = DecoderRNN(hidden_size, nb_of_char_codes).to(device)

criterion = nn.NLLLoss().to(device)
optimizer = optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()))


def main():
    """Train on every pair from ``char_source`` and report progress.

    Every ``step`` pairs, print the average loss since the last report along
    with the latest input, expected output and decoded output.
    """
    losses = []
    for counter, (s, t) in enumerate(char_source(), start=1):
        loss_value, output_string = train_pair(
            s, t, encoder, decoder, optimizer, criterion)
        losses.append(loss_value)

        if counter % step == 0:
            avg_loss = sum(losses) / len(losses)
            print(f"{counter}: {avg_loss}")
            losses = []
            print('IN :\t', ''.join([chr(a) for a in s]))
            print('EXP:\t', ''.join([chr(a) for a in t]))
            print('OUT:\t', output_string)


if __name__ == '__main__':
    main()