#!/usr/bin/python3

# Character-level sequence-to-sequence translation with attention, adapted from
# https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

import sys

import torch
from torch import nn, optim

nb_of_char_codes = 128 + 2  # ASCII plus the two special tokens below
SOS_token_id = 128  # start of sentence
EOS_token_id = 129  # end of sentence
MAX_LENGTH = 20

hidden_size = 32
step = 200  # report the running average loss every `step` sentence pairs

device = torch.device('cpu')

f = open('eng-fra.txt', encoding='utf-8')  # the tutorial's corpus is UTF-8 encoded
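
# eng-fra.txt is the tab-separated English<TAB>French sentence-pair file used by
# the tutorial linked above; it is assumed to sit in the current working
# directory (the tutorial distributes it in data.zip at
# https://download.pytorch.org/tutorial/data.zip).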

def char_source():
    """Yield (source, target) pairs of character-code lists from the corpus."""
    for line in f:
        s, t = line.rstrip('\n').split('\t')
        s_list = []
        t_list = []

        for c in s:
            c_code = ord(c)
            if c_code < nb_of_char_codes:
                s_list.append(c_code)

        for c in t:
            c_code = ord(c)
            if c_code < nb_of_char_codes:
                t_list.append(c_code)

        # Skip pairs whose source would overflow the fixed-size encoder_output
        # buffer (MAX_LENGTH rows) allocated in the training loop below.
        if len(s_list) > MAX_LENGTH:
            continue

        yield s_list, t_list

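# The encoder below reads the source sentence one character per time step; the
# training loop collects its GRU outputs into the fixed MAX_LENGTH buffer used
# by the decoder's attention and reuses its final hidden state to initialise
# the decoder.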
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        embedded = self.embedding(input)
        output = embedded
        output, hidden = self.gru(output, hidden)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

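# Attention decoder, following the tutorial's AttnDecoderRNN (without dropout):
# attention weights over the MAX_LENGTH encoder positions are computed from the
# embedded previous character and the current hidden state, applied to the
# encoder outputs with a batched matrix product, combined with the embedding,
# and passed through the GRU; the output layer is a log-softmax over all
# nb_of_char_codes character codes.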
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, max_length=MAX_LENGTH):
        super().__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

        self.attn = nn.Linear(self.hidden_size * 2, max_length)
        self.attn_combine = nn.Linear(hidden_size * 2, hidden_size)

    def forward(self, input, hidden, encoder_output):
        output = self.embedding(input)

        attn_weights = torch.nn.functional.softmax(
            self.attn(torch.cat((output[0], hidden[0]), 1)), dim=1)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_output.unsqueeze(0))
        output = torch.cat((output[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)

        output = torch.nn.functional.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden

encoder = EncoderRNN(nb_of_char_codes, hidden_size).to(device)
decoder = DecoderRNN(hidden_size, nb_of_char_codes).to(device)
criterion = nn.NLLLoss().to(device)
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()))

counter = 0
losses = []

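# Training loop: one optimizer step per sentence pair. No teacher forcing is
# used; at every decoding step the decoder is fed its own argmax prediction,
# and the per-character NLL losses are summed before backpropagating.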
for s, t in char_source():
    counter += 1
    encoder.zero_grad()
    decoder.zero_grad()
    x = torch.tensor(s, dtype=torch.long, device=device)
    encoder_hidden = encoder.initHidden()
    encoder_output = torch.zeros(MAX_LENGTH, hidden_size, device=device)
    for i in range(x.shape[0]):
        output, encoder_hidden = encoder(x[i].unsqueeze(0).unsqueeze(0), encoder_hidden)
        encoder_output[i] = output[0, 0]

    decoder_hidden = encoder_hidden

    decoder_input = torch.tensor([[SOS_token_id]], device=device)

    t.append(EOS_token_id)
    y = torch.tensor(t, dtype=torch.long, device=device)
    loss = 0
    output_string = ''
    for di in range(y.shape[0]):
        decoder_output, decoder_hidden = decoder(
            decoder_input, decoder_hidden, encoder_output)
        topv, topi = decoder_output.topk(1)
        decoder_input = topi.detach()  # detach from history as input

        output_string += chr(topi.item())
        loss += criterion(decoder_output, y[di].unsqueeze(0))
        if topi.item() == EOS_token_id:  # stop once the model emits the end-of-sentence token
            break

    losses.append(loss.item())
    if counter % step == 0:
        # print(counter, end='\t')
        avg_loss = sum(losses) / len(losses)
        print(f"{counter}: {avg_loss}")
        losses = []
        print('IN :\t', ''.join([chr(a) for a in s]))
        print('EXP:\t', ''.join([chr(a) for a in t]))
        print('OUT:\t', output_string)

    loss.backward()
    optimizer.step()