wmt-2020-pl-en/model_train.py

from lang import SOS_token

import math
import pickle
import random
import time

import torch
from torch import nn, optim
from torch.nn.utils.rnn import pad_sequence

MAX_LENGTH = 25               # size of the per-step encoder output buffer in train()
device = 'cuda'               # training assumes a CUDA-capable GPU
teacher_forcing_ratio = 0.8   # probability of feeding the ground-truth token to the decoder
with open('data/pairs.pkl', 'rb') as input_file:
    pairs = pickle.load(input_file)
with open('data/pl_lang.pkl', 'rb') as input_file:
    input_lang = pickle.load(input_file)
with open('data/en_lang.pkl', 'rb') as out_file:
    output_lang = pickle.load(out_file)
def indexesFromSentence(lang, sentence):
    # Map each token to its vocabulary index; unknown words fall back to index 1 (UNK).
    return [lang.word2index[word] if word in lang.word2index else 1 for word in sentence]


def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(0)  # append the EOS index
    return torch.tensor(indexes, device=device).view(-1, 1)


def tensorsFromPair(pair):
    # Expects a pair of already tokenised sentences (lists of words).
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)
def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / percent
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))
def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):
    encoder_hidden = encoder.initHidden()

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    # Not used by this decoder; kept as a buffer for attention-based variants.
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

    loss = 0

    # Encode the whole source sentence in a single call.
    encoder_output, encoder_hidden = encoder(input_tensor, encoder_hidden)

    decoder_input = torch.tensor([[SOS_token]], device=device)
    decoder_hidden = encoder_hidden

    use_teacher_forcing = random.random() < teacher_forcing_ratio

    if use_teacher_forcing:
        # Teacher forcing: feed the ground-truth token as the next decoder input.
        for di in range(target_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            loss += criterion(decoder_output, target_tensor[di])
            decoder_input = target_tensor[di].unsqueeze(0)
    else:
        # No teacher forcing: feed the decoder's own prediction back in.
        for di in range(target_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.transpose(0, 1).detach()  # detach from the graph
            loss += criterion(decoder_output, target_tensor[di])

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length
def trainIters(encoder, decoder, n_iters, print_every=10, plot_every=100, learning_rate=0.01):
    start = time.time()
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()

    pairs_in = pairs[:10000]  # train on the first 10k sentence pairs

    for iter in range(1, n_iters + 1):
        try:
            for idx, training_pair in enumerate(pairs_in):
                input_ = training_pair[0].split()[::-1]  # reverse the source sentence
                target_ = training_pair[1].split()
                if len(input_) > 1 and len(target_) > 1:
                    input_tensor = tensorFromSentence(input_lang, input_)
                    target_tensor = tensorFromSentence(output_lang, target_)
                    loss = train(input_tensor, target_tensor, encoder,
                                 decoder, encoder_optimizer, decoder_optimizer, criterion)
                    print_loss_total += loss
                    plot_loss_total += loss
                print(idx / len(pairs_in), end='\r')

            if iter % print_every == 0:
                print_loss_avg = print_loss_total / print_every
                print_loss_total = 0
                print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),
                                             iter, iter / n_iters * 100, print_loss_avg))
        except KeyboardInterrupt:
            # Save a checkpoint and stop when training is interrupted.
            torch.save(encoder.state_dict(), 'encoder.dict')
            torch.save(decoder.state_dict(), 'decoder.dict')
            break

    # Save the final weights once training finishes.
    torch.save(encoder.state_dict(), 'encoder.dict')
    torch.save(decoder.state_dict(), 'decoder.dict')
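

# The encoder/decoder classes are not defined in this file, so the block below
# is only a hedged usage sketch: it smoke-tests the preprocessing helpers on the
# first loaded pair, and the commented lines show how trainIters would be
# invoked once encoder/decoder models (defined elsewhere in the repo; class
# names below are assumptions, not taken from this file) are built on `device`.
if __name__ == '__main__':
    sample_src, sample_tgt = pairs[0]
    print(tensorFromSentence(input_lang, sample_src.split()[::-1]).shape)
    print(tensorFromSentence(output_lang, sample_tgt.split()).shape)
    # encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)   # assumed class
    # decoder = DecoderRNN(hidden_size, output_lang.n_words).to(device)  # assumed class
    # trainIters(encoder, decoder, n_iters=10)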