import math
import pickle
import random
import time

import torch
from torch import nn, optim

from lang import EOS_token, SOS_token, tokenizer

MAX_LENGTH = 300
# Fall back to CPU so the script also runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
teacher_forcing_ratio = 0.5

with open('data/pairs.pkl', 'rb') as input_file:
    pairs = pickle.load(input_file)

with open('data/pl_lang.pkl', 'rb') as input_file:
    input_lang = pickle.load(input_file)

with open('data/en_lang.pkl', 'rb') as out_file:
    output_lang = pickle.load(out_file)
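# NOTE (assumption): the pickles are expected to follow the layout produced by
# the preprocessing step, i.e. `pairs` is a list of
# (polish_sentence, english_sentence) string tuples, and each Lang object
# exposes at least `word2index` and `n_words`. If your artifacts differ,
# adjust the attribute names used below accordingly.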


def indexesFromSentence(lang, sentence):
    # Words missing from the vocabulary are skipped rather than raising KeyError.
    return [lang.word2index[word] for word in tokenizer.tokenize(sentence)
            if word in lang.word2index]


def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    # Column vector of shape (sentence_length + 1, 1); the +1 is the EOS token.
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)


def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)
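# Example (hypothetical sentences): tensorsFromPair(('ala ma kota', 'ala has a cat'))
# returns two column tensors of word indices, one per language, each ending in
# EOS_token, ready to be consumed token by token by train() below.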


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / percent  # projected total runtime at the current pace
    rs = es - s       # projected time remaining
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))


def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer,
          decoder_optimizer, criterion, max_length=MAX_LENGTH):
    encoder_hidden = encoder.initHidden()
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    # One encoder hidden-state vector per source position.
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

    loss = 0

    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(
            input_tensor[ei], encoder_hidden)
        # encoder_output has shape (1, 1, hidden_size); keep the full vector.
        encoder_outputs[ei] = encoder_output[0, 0]

    decoder_input = torch.tensor([[SOS_token]], device=device)
    decoder_hidden = encoder_hidden
    use_teacher_forcing = random.random() < teacher_forcing_ratio

    if use_teacher_forcing:
        # Teacher forcing: feed the target as the next input.
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            loss += criterion(decoder_output, target_tensor[di])
            decoder_input = target_tensor[di]
    else:
        # Without teacher forcing: use the decoder's own predictions as the next input.
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()  # detach from history as input

            loss += criterion(decoder_output, target_tensor[di])
            if decoder_input.item() == EOS_token:
                break

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length
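# Teacher forcing (ratio 0.5 above) mixes ground-truth and model-predicted
# decoder inputs: feeding the target speeds up convergence, while feeding the
# model's own predictions better matches inference-time behaviour.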


def trainIters(encoder, decoder, n_iters, print_every=10, plot_every=100,
               learning_rate=0.01):
    start = time.time()
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    training_pairs = [tensorsFromPair(random.choice(pairs))
                      for i in range(n_iters)]
    criterion = nn.NLLLoss()

    for iter in range(1, n_iters + 1):
        training_pair = training_pairs[iter - 1]
        input_tensor = training_pair[0]
        target_tensor = training_pair[1]

        loss = train(input_tensor, target_tensor, encoder,
                     decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if iter % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),
                                         iter, iter / n_iters * 100, print_loss_avg))

    # Persist the trained weights once all iterations have finished.
    torch.save(encoder.state_dict(), 'encoder.dict')
    torch.save(decoder.state_dict(), 'decoder.dict')
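

# A minimal usage sketch. It assumes the repo defines encoder/decoder modules
# with the interface this file already relies on (initHidden(), hidden_size,
# and a decoder taking (input, hidden, encoder_outputs)); the module name
# `models`, the class names, and the constructor signatures below are
# assumptions, so adjust them to match the actual definitions.
if __name__ == '__main__':
    from models import EncoderRNN, AttnDecoderRNN  # hypothetical module

    hidden_size = 256  # assumed; use whatever size the models were built with
    encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
    decoder = AttnDecoderRNN(hidden_size, output_lang.n_words,
                             max_length=MAX_LENGTH).to(device)
    trainIters(encoder, decoder, n_iters=75000, print_every=100)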