Add solution

This commit is contained in:
SzyGra 2021-01-31 16:54:20 +01:00
parent 2e7c5b13c0
commit 21dad8fc72
8 changed files with 2073 additions and 914 deletions

File diff suppressed because it is too large Load Diff

11
lang.py
View File

@ -1,18 +1,17 @@
from nltk.tokenize import RegexpTokenizer from nltk.tokenize import RegexpTokenizer
SOS_token = 0 SOS_token = 2
EOS_token = 1 PAD_token = 0
tokenizer = RegexpTokenizer(r'\w+')
class Lang: class Lang:
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
self.word2index = {} self.word2index = {}
self.word2count = {} self.word2count = {}
self.index2word = {0: "SOS", 1: "EOS"} self.index2word = {0: "PAD", 1: "UNK", 2: "SOS"}
self.n_words = 2 # Count SOS and EOS self.n_words = 2
def addSentence(self, sentence): def addSentence(self, sentence):
for word in tokenizer.tokenize(sentence): for word in sentence.split():
self.addWord(word) self.addWord(word)
def addWord(self, word): def addWord(self, word):

View File

@ -2,8 +2,6 @@ import torch
from torch import nn from torch import nn
device = 'cuda' device = 'cuda'
import torch.nn.functional as F import torch.nn.functional as F
import torch.nn.init as init
from lang import SOS_token, EOS_token
class EncoderRNN(nn.Module): class EncoderRNN(nn.Module):
def __init__(self, input_size, hidden_size): def __init__(self, input_size, hidden_size):
@ -14,7 +12,7 @@ class EncoderRNN(nn.Module):
self.lstm = nn.LSTM(hidden_size, hidden_size) self.lstm = nn.LSTM(hidden_size, hidden_size)
def forward(self, input, hidden): def forward(self, input, hidden):
embedded = self.embedding(input).view(1, 1, -1) embedded = self.embedding(input)
output = embedded output = embedded
output, hidden = self.lstm(output, hidden) output, hidden = self.lstm(output, hidden)
return output, hidden return output, hidden
@ -33,46 +31,8 @@ class DecoderRNN(nn.Module):
self.softmax = nn.LogSoftmax(dim=1) self.softmax = nn.LogSoftmax(dim=1)
def forward(self, input, hidden): def forward(self, input, hidden):
output = self.embedding(input).view(1, 1, -1) output = self.embedding(input)
output = F.relu(output) output = F.relu(output)
output, hidden = self.lstm(output, hidden) output, hidden = self.lstm(output, hidden)
output = self.softmax(self.out(output[0])) output = self.softmax(self.out(output[0]))
return output, hidden return output, hidden
def initHidden(self):
return (torch.zeros(1, 1, self.hidden_size, device=device), torch.zeros(1, 1, self.hidden_size, device=device))
class AttnDecoderRNN(nn.Module):
def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=300):
super(AttnDecoderRNN, self).__init__()
self.hidden_size = hidden_size
self.output_size = output_size
self.dropout_p = dropout_p
self.max_length = max_length
self.embedding = nn.Embedding(self.output_size, self.hidden_size)
self.attn = nn.Linear(self.hidden_size, self.max_length)
self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
self.dropout = nn.Dropout(self.dropout_p)
self.lstm = nn.LSTM(self.hidden_size, self.hidden_size)
self.out = nn.Linear(self.hidden_size, self.output_size)
def forward(self, input, hidden, encoder_outputs):
embedded = self.embedding(input).view(1, 1, -1)
embedded = self.dropout(embedded)
attn_weights = F.softmax(
self.attn(torch.cat((embedded, hidden[0]), 1)), dim=1)
attn_applied = torch.bmm(attn_weights.unsqueeze(0),
encoder_outputs.unsqueeze(0))
output = torch.cat((embedded[0], attn_applied[0]), 1)
output = self.attn_combine(output).unsqueeze(0)
output = F.relu(output)
output, hidden = self.lstm(output, hidden)
output = F.log_softmax(self.out(output[0]), dim=1)
print(output.shape, hidden.shape)
return output, hidden, attn_weights
def initHidden(self):
return (torch.zeros(1, 1, self.hidden_size, device=device), torch.zeros(1, 1, self.hidden_size, device=device))

View File

@ -1,16 +1,16 @@
from lang import SOS_token, EOS_token from lang import SOS_token
import torch import torch
import random import random
import math import math
import time import time
from torch import nn, optim from torch import nn, optim
from torch.nn.utils.rnn import pad_sequence
import torch import torch
from lang import EOS_token, tokenizer
import pickle import pickle
MAX_LENGTH = 300 MAX_LENGTH = 25
device = 'cuda' device = 'cuda'
teacher_forcing_ratio = 0.5 teacher_forcing_ratio = 0.8
with open('data/pairs.pkl', 'rb') as input_file: with open('data/pairs.pkl', 'rb') as input_file:
pairs = pickle.load(input_file) pairs = pickle.load(input_file)
@ -21,14 +21,16 @@ with open('data/pl_lang.pkl', 'rb') as input_file:
with open('data/en_lang.pkl', 'rb') as out_file: with open('data/en_lang.pkl', 'rb') as out_file:
output_lang = pickle.load(out_file) output_lang = pickle.load(out_file)
def indexesFromSentence(lang, sentence): def indexesFromSentence(lang, sentence):
return [lang.word2index[word] for word in tokenizer.tokenize(sentence) if word in lang.word2index] return [lang.word2index[word] if word in lang.word2index else 1 for word in sentence]
def tensorFromSentence(lang, sentence): def tensorFromSentence(lang, sentence):
indexes = indexesFromSentence(lang, sentence) indexes = indexesFromSentence(lang, sentence)
indexes.append(EOS_token) indexes.append(0)
return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1) out = torch.tensor(indexes, device=device).view(-1, 1)
return out
def tensorsFromPair(pair): def tensorsFromPair(pair):
@ -57,14 +59,12 @@ def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, deco
input_length = input_tensor.size(0) input_length = input_tensor.size(0)
target_length = target_tensor.size(0) target_length = target_tensor.size(0)
encoder_outputs = torch.zeros(max_length, max_length, encoder.hidden_size, device=device) encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
loss = 0 loss = 0
encoder_output, encoder_hidden = encoder(input_tensor, encoder_hidden)
for ei in range(input_length):
encoder_output, encoder_hidden = encoder(
input_tensor[ei], encoder_hidden)
encoder_outputs[ei] = encoder_output[0, 0, 0]
decoder_input = torch.tensor([[SOS_token]], device=device) decoder_input = torch.tensor([[SOS_token]], device=device)
@ -72,25 +72,17 @@ def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, deco
use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
if use_teacher_forcing: if use_teacher_forcing:
# Teacher forcing: Feed the target as the next input
for di in range(target_length): for di in range(target_length):
decoder_output, decoder_hidden, decoder_attention = decoder( decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
decoder_input, decoder_hidden, encoder_outputs)
loss += criterion(decoder_output, target_tensor[di]) loss += criterion(decoder_output, target_tensor[di])
decoder_input = target_tensor[di] decoder_input = target_tensor[di].unsqueeze(0)
else: else:
# Without teacher forcing: use its own predictions as the next input
for di in range(target_length): for di in range(target_length):
decoder_output, decoder_hidden, decoder_attention = decoder( decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
decoder_input, decoder_hidden, encoder_outputs)
topv, topi = decoder_output.topk(1) topv, topi = decoder_output.topk(1)
decoder_input = topi.squeeze().detach() decoder_input = topi.transpose(0, 1).detach()
loss += criterion(decoder_output, target_tensor[di]) loss += criterion(decoder_output, target_tensor[di])
if decoder_input.item() == EOS_token:
break
loss.backward() loss.backward()
encoder_optimizer.step() encoder_optimizer.step()
@ -105,24 +97,34 @@ def trainIters(encoder, decoder, n_iters, print_every=10, plot_every=100, learni
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate) encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate) decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
training_pairs = [tensorsFromPair(random.choice(pairs))
for i in range(n_iters)]
criterion = nn.NLLLoss() criterion = nn.NLLLoss()
pairs_in = pairs[:10000]
for iter in range(1, n_iters + 1): for iter in range(1, n_iters + 1):
training_pair = training_pairs[iter - 1] try:
input_tensor = training_pair[0] for idx, training_pair in enumerate(pairs_in):
target_tensor = training_pair[1] input_ = training_pair[0]
target_ = training_pair[1]
input_ = input_.split()
input_ = input_[::-1]
target_ = target_.split()
loss = train(input_tensor, target_tensor, encoder, if len(input_)>1 and len(target_)>1:
decoder, encoder_optimizer, decoder_optimizer, criterion) input_tensor = tensorFromSentence(input_lang, input_)
print_loss_total += loss target_tensor = tensorFromSentence(output_lang, target_)
plot_loss_total += loss
if iter % print_every == 0: loss = train(input_tensor, target_tensor, encoder,
print_loss_avg = print_loss_total / print_every decoder, encoder_optimizer, decoder_optimizer, criterion)
print_loss_total = 0 print_loss_total += loss
print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters), plot_loss_total += loss
print(idx/len(pairs_in), end='\r')
if iter % print_every == 0:
print_loss_avg = print_loss_total / print_every
print_loss_total = 0
print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),
iter, iter / n_iters * 100, print_loss_avg)) iter, iter / n_iters * 100, print_loss_avg))
except KeyboardInterrupt:
torch.save(encoder.state_dict(), 'encoder.dict')
torch.save(decoder.state_dict(), 'decoder.dict')
torch.save(encoder.state_dict(), 'encoder.dict') torch.save(encoder.state_dict(), 'encoder.dict')
torch.save(decoder.state_dict(), 'decoder.dict') torch.save(decoder.state_dict(), 'decoder.dict')

View File

@ -1,4 +1,4 @@
from model_train import tensorFromSentence, SOS_token, MAX_LENGTH, device, EOS_token from model_train import tensorFromSentence, SOS_token, MAX_LENGTH, device
import pickle import pickle
from lstm_model import EncoderRNN, DecoderRNN from lstm_model import EncoderRNN, DecoderRNN
import sys import sys
@ -14,45 +14,43 @@ with open('data/en_lang.pkl', 'rb') as out_file:
def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH): def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
with torch.no_grad(): with torch.no_grad():
input_tensor = tensorFromSentence(input_lang, sentence) input_tensor = tensorFromSentence(input_lang, sentence)
input_length = input_tensor.size()[0]
encoder_hidden = encoder.initHidden() encoder_hidden = encoder.initHidden()
encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device) encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
loss = 0 encoder_output, encoder_hidden = encoder(input_tensor,encoder_hidden)
for ei in range(input_length): encoder_outputs = encoder_output
encoder_output, encoder_hidden = encoder(input_tensor[ei],
encoder_hidden)
encoder_outputs[ei] = encoder_output[0, 0]
decoder_input = torch.tensor([[SOS_token]], dtype=torch.long, device=device).view(-1, 1) # SOS decoder_input = torch.tensor([[SOS_token]], device=device)
decoder_hidden = encoder_hidden decoder_hidden = encoder_hidden
decoded_words = [] decoded_words = []
decoder_attentions = torch.zeros(max_length, max_length)
for di in range(max_length): for di in range(max_length):
decoder_output, decoder_hidden = decoder( decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
decoder_input, decoder_hidden) topv, topi = decoder_output.topk(1)
topv, topi = decoder_output.data.topk(1) decoded_words.append(topi)
if topi.item() == EOS_token: decoder_input = topi.transpose(0, 1)
break out = torch.stack(decoded_words)
else: return out
decoded_words.append(output_lang.index2word[topi.item()])
decoder_input = topi.squeeze().detach()
return decoded_words
hidden_size = 256 hidden_size = 256
encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device) encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
decoder = DecoderRNN(hidden_size, output_lang.n_words).to(device) decoder = DecoderRNN(hidden_size, output_lang.n_words).to(device)
encoder.load_state_dict(torch.load('encoder.dict')) encoder.load_state_dict(torch.load('encoder.dict'))
decoder.load_state_dict(torch.load('decoder.dict')) decoder.load_state_dict(torch.load('decoder.dict'))
encoder.eval()
decoder.eval()
for line in sys.stdin: for line in sys.stdin:
line = line.rstrip() line = line.rstrip()
dec_words = evaluate(encoder, decoder, line, MAX_LENGTH) dec_words = evaluate(encoder, decoder, line, MAX_LENGTH)
print(' '.join(dec_words)) dec_words = dec_words.transpose(0, 1)
for sen in dec_words:
out = []
for idx in sen:
if idx == 0:
break
out.append(output_lang.index2word[idx.item()])
print(' '.join(out))

View File

@ -11,7 +11,7 @@ from torch import optim
import torch.nn.functional as F import torch.nn.functional as F
from lang import * from lang import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_LENGTH = 300 MAX_LENGTH = 25
# Turn a Unicode string to plain ASCII, thanks to # Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427 # https://stackoverflow.com/a/518232/2809427

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,7 @@
from lstm_model import EncoderRNN, DecoderRNN, AttnDecoderRNN from lstm_model import EncoderRNN, DecoderRNN
from model_train import * from model_train import *
hidden_size = 256 hidden_size = 256
encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device) encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)
attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device) attn_decoder1 = DecoderRNN(hidden_size, output_lang.n_words).to(device)
trainIters(encoder1, attn_decoder1, 10000, print_every=100) trainIters(encoder1, attn_decoder1, 5, print_every=1)