Add solution
This commit is contained in:
parent
2e7c5b13c0
commit
21dad8fc72
1600
dev-0/out.tsv
1600
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
11
lang.py
11
lang.py
@ -1,18 +1,17 @@
|
|||||||
from nltk.tokenize import RegexpTokenizer
|
from nltk.tokenize import RegexpTokenizer
|
||||||
SOS_token = 0
|
SOS_token = 2
|
||||||
EOS_token = 1
|
PAD_token = 0
|
||||||
tokenizer = RegexpTokenizer(r'\w+')
|
|
||||||
|
|
||||||
class Lang:
|
class Lang:
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.word2index = {}
|
self.word2index = {}
|
||||||
self.word2count = {}
|
self.word2count = {}
|
||||||
self.index2word = {0: "SOS", 1: "EOS"}
|
self.index2word = {0: "PAD", 1: "UNK", 2: "SOS"}
|
||||||
self.n_words = 2 # Count SOS and EOS
|
self.n_words = 2
|
||||||
|
|
||||||
def addSentence(self, sentence):
|
def addSentence(self, sentence):
|
||||||
for word in tokenizer.tokenize(sentence):
|
for word in sentence.split():
|
||||||
self.addWord(word)
|
self.addWord(word)
|
||||||
|
|
||||||
def addWord(self, word):
|
def addWord(self, word):
|
||||||
|
@ -2,8 +2,6 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
device = 'cuda'
|
device = 'cuda'
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.nn.init as init
|
|
||||||
from lang import SOS_token, EOS_token
|
|
||||||
|
|
||||||
class EncoderRNN(nn.Module):
|
class EncoderRNN(nn.Module):
|
||||||
def __init__(self, input_size, hidden_size):
|
def __init__(self, input_size, hidden_size):
|
||||||
@ -14,7 +12,7 @@ class EncoderRNN(nn.Module):
|
|||||||
self.lstm = nn.LSTM(hidden_size, hidden_size)
|
self.lstm = nn.LSTM(hidden_size, hidden_size)
|
||||||
|
|
||||||
def forward(self, input, hidden):
|
def forward(self, input, hidden):
|
||||||
embedded = self.embedding(input).view(1, 1, -1)
|
embedded = self.embedding(input)
|
||||||
output = embedded
|
output = embedded
|
||||||
output, hidden = self.lstm(output, hidden)
|
output, hidden = self.lstm(output, hidden)
|
||||||
return output, hidden
|
return output, hidden
|
||||||
@ -33,46 +31,8 @@ class DecoderRNN(nn.Module):
|
|||||||
self.softmax = nn.LogSoftmax(dim=1)
|
self.softmax = nn.LogSoftmax(dim=1)
|
||||||
|
|
||||||
def forward(self, input, hidden):
|
def forward(self, input, hidden):
|
||||||
output = self.embedding(input).view(1, 1, -1)
|
output = self.embedding(input)
|
||||||
output = F.relu(output)
|
output = F.relu(output)
|
||||||
output, hidden = self.lstm(output, hidden)
|
output, hidden = self.lstm(output, hidden)
|
||||||
output = self.softmax(self.out(output[0]))
|
output = self.softmax(self.out(output[0]))
|
||||||
return output, hidden
|
return output, hidden
|
||||||
|
|
||||||
def initHidden(self):
|
|
||||||
return (torch.zeros(1, 1, self.hidden_size, device=device), torch.zeros(1, 1, self.hidden_size, device=device))
|
|
||||||
|
|
||||||
class AttnDecoderRNN(nn.Module):
|
|
||||||
def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=300):
|
|
||||||
super(AttnDecoderRNN, self).__init__()
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.output_size = output_size
|
|
||||||
self.dropout_p = dropout_p
|
|
||||||
self.max_length = max_length
|
|
||||||
|
|
||||||
self.embedding = nn.Embedding(self.output_size, self.hidden_size)
|
|
||||||
self.attn = nn.Linear(self.hidden_size, self.max_length)
|
|
||||||
self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
|
|
||||||
self.dropout = nn.Dropout(self.dropout_p)
|
|
||||||
self.lstm = nn.LSTM(self.hidden_size, self.hidden_size)
|
|
||||||
self.out = nn.Linear(self.hidden_size, self.output_size)
|
|
||||||
|
|
||||||
def forward(self, input, hidden, encoder_outputs):
|
|
||||||
embedded = self.embedding(input).view(1, 1, -1)
|
|
||||||
embedded = self.dropout(embedded)
|
|
||||||
attn_weights = F.softmax(
|
|
||||||
self.attn(torch.cat((embedded, hidden[0]), 1)), dim=1)
|
|
||||||
attn_applied = torch.bmm(attn_weights.unsqueeze(0),
|
|
||||||
encoder_outputs.unsqueeze(0))
|
|
||||||
output = torch.cat((embedded[0], attn_applied[0]), 1)
|
|
||||||
output = self.attn_combine(output).unsqueeze(0)
|
|
||||||
|
|
||||||
output = F.relu(output)
|
|
||||||
output, hidden = self.lstm(output, hidden)
|
|
||||||
|
|
||||||
output = F.log_softmax(self.out(output[0]), dim=1)
|
|
||||||
print(output.shape, hidden.shape)
|
|
||||||
return output, hidden, attn_weights
|
|
||||||
|
|
||||||
def initHidden(self):
|
|
||||||
return (torch.zeros(1, 1, self.hidden_size, device=device), torch.zeros(1, 1, self.hidden_size, device=device))
|
|
@ -1,16 +1,16 @@
|
|||||||
from lang import SOS_token, EOS_token
|
from lang import SOS_token
|
||||||
import torch
|
import torch
|
||||||
import random
|
import random
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
from torch import nn, optim
|
from torch import nn, optim
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
import torch
|
import torch
|
||||||
from lang import EOS_token, tokenizer
|
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
MAX_LENGTH = 300
|
MAX_LENGTH = 25
|
||||||
device = 'cuda'
|
device = 'cuda'
|
||||||
teacher_forcing_ratio = 0.5
|
teacher_forcing_ratio = 0.8
|
||||||
|
|
||||||
with open('data/pairs.pkl', 'rb') as input_file:
|
with open('data/pairs.pkl', 'rb') as input_file:
|
||||||
pairs = pickle.load(input_file)
|
pairs = pickle.load(input_file)
|
||||||
@ -21,14 +21,16 @@ with open('data/pl_lang.pkl', 'rb') as input_file:
|
|||||||
with open('data/en_lang.pkl', 'rb') as out_file:
|
with open('data/en_lang.pkl', 'rb') as out_file:
|
||||||
output_lang = pickle.load(out_file)
|
output_lang = pickle.load(out_file)
|
||||||
|
|
||||||
|
|
||||||
def indexesFromSentence(lang, sentence):
|
def indexesFromSentence(lang, sentence):
|
||||||
return [lang.word2index[word] for word in tokenizer.tokenize(sentence) if word in lang.word2index]
|
return [lang.word2index[word] if word in lang.word2index else 1 for word in sentence]
|
||||||
|
|
||||||
|
|
||||||
def tensorFromSentence(lang, sentence):
|
def tensorFromSentence(lang, sentence):
|
||||||
indexes = indexesFromSentence(lang, sentence)
|
indexes = indexesFromSentence(lang, sentence)
|
||||||
indexes.append(EOS_token)
|
indexes.append(0)
|
||||||
return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)
|
out = torch.tensor(indexes, device=device).view(-1, 1)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
def tensorsFromPair(pair):
|
def tensorsFromPair(pair):
|
||||||
@ -57,14 +59,12 @@ def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, deco
|
|||||||
input_length = input_tensor.size(0)
|
input_length = input_tensor.size(0)
|
||||||
target_length = target_tensor.size(0)
|
target_length = target_tensor.size(0)
|
||||||
|
|
||||||
encoder_outputs = torch.zeros(max_length, max_length, encoder.hidden_size, device=device)
|
encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
|
||||||
|
|
||||||
loss = 0
|
loss = 0
|
||||||
|
encoder_output, encoder_hidden = encoder(input_tensor, encoder_hidden)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for ei in range(input_length):
|
|
||||||
encoder_output, encoder_hidden = encoder(
|
|
||||||
input_tensor[ei], encoder_hidden)
|
|
||||||
encoder_outputs[ei] = encoder_output[0, 0, 0]
|
|
||||||
|
|
||||||
|
|
||||||
decoder_input = torch.tensor([[SOS_token]], device=device)
|
decoder_input = torch.tensor([[SOS_token]], device=device)
|
||||||
@ -72,25 +72,17 @@ def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, deco
|
|||||||
use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
|
use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
|
||||||
|
|
||||||
if use_teacher_forcing:
|
if use_teacher_forcing:
|
||||||
# Teacher forcing: Feed the target as the next input
|
|
||||||
for di in range(target_length):
|
for di in range(target_length):
|
||||||
decoder_output, decoder_hidden, decoder_attention = decoder(
|
decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
|
||||||
decoder_input, decoder_hidden, encoder_outputs)
|
|
||||||
loss += criterion(decoder_output, target_tensor[di])
|
loss += criterion(decoder_output, target_tensor[di])
|
||||||
decoder_input = target_tensor[di]
|
decoder_input = target_tensor[di].unsqueeze(0)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Without teacher forcing: use its own predictions as the next input
|
|
||||||
for di in range(target_length):
|
for di in range(target_length):
|
||||||
decoder_output, decoder_hidden, decoder_attention = decoder(
|
decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
|
||||||
decoder_input, decoder_hidden, encoder_outputs)
|
|
||||||
topv, topi = decoder_output.topk(1)
|
topv, topi = decoder_output.topk(1)
|
||||||
decoder_input = topi.squeeze().detach()
|
decoder_input = topi.transpose(0, 1).detach()
|
||||||
|
|
||||||
loss += criterion(decoder_output, target_tensor[di])
|
loss += criterion(decoder_output, target_tensor[di])
|
||||||
if decoder_input.item() == EOS_token:
|
|
||||||
break
|
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
encoder_optimizer.step()
|
encoder_optimizer.step()
|
||||||
@ -105,24 +97,34 @@ def trainIters(encoder, decoder, n_iters, print_every=10, plot_every=100, learni
|
|||||||
|
|
||||||
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
|
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
|
||||||
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
|
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
|
||||||
training_pairs = [tensorsFromPair(random.choice(pairs))
|
|
||||||
for i in range(n_iters)]
|
|
||||||
criterion = nn.NLLLoss()
|
criterion = nn.NLLLoss()
|
||||||
|
pairs_in = pairs[:10000]
|
||||||
for iter in range(1, n_iters + 1):
|
for iter in range(1, n_iters + 1):
|
||||||
training_pair = training_pairs[iter - 1]
|
try:
|
||||||
input_tensor = training_pair[0]
|
for idx, training_pair in enumerate(pairs_in):
|
||||||
target_tensor = training_pair[1]
|
input_ = training_pair[0]
|
||||||
|
target_ = training_pair[1]
|
||||||
|
input_ = input_.split()
|
||||||
|
input_ = input_[::-1]
|
||||||
|
target_ = target_.split()
|
||||||
|
|
||||||
loss = train(input_tensor, target_tensor, encoder,
|
if len(input_)>1 and len(target_)>1:
|
||||||
decoder, encoder_optimizer, decoder_optimizer, criterion)
|
input_tensor = tensorFromSentence(input_lang, input_)
|
||||||
print_loss_total += loss
|
target_tensor = tensorFromSentence(output_lang, target_)
|
||||||
plot_loss_total += loss
|
|
||||||
|
|
||||||
if iter % print_every == 0:
|
loss = train(input_tensor, target_tensor, encoder,
|
||||||
print_loss_avg = print_loss_total / print_every
|
decoder, encoder_optimizer, decoder_optimizer, criterion)
|
||||||
print_loss_total = 0
|
print_loss_total += loss
|
||||||
print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),
|
plot_loss_total += loss
|
||||||
|
print(idx/len(pairs_in), end='\r')
|
||||||
|
|
||||||
|
if iter % print_every == 0:
|
||||||
|
print_loss_avg = print_loss_total / print_every
|
||||||
|
print_loss_total = 0
|
||||||
|
print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),
|
||||||
iter, iter / n_iters * 100, print_loss_avg))
|
iter, iter / n_iters * 100, print_loss_avg))
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
torch.save(encoder.state_dict(), 'encoder.dict')
|
||||||
|
torch.save(decoder.state_dict(), 'decoder.dict')
|
||||||
torch.save(encoder.state_dict(), 'encoder.dict')
|
torch.save(encoder.state_dict(), 'encoder.dict')
|
||||||
torch.save(decoder.state_dict(), 'decoder.dict')
|
torch.save(decoder.state_dict(), 'decoder.dict')
|
||||||
|
42
predict.py
42
predict.py
@ -1,4 +1,4 @@
|
|||||||
from model_train import tensorFromSentence, SOS_token, MAX_LENGTH, device, EOS_token
|
from model_train import tensorFromSentence, SOS_token, MAX_LENGTH, device
|
||||||
import pickle
|
import pickle
|
||||||
from lstm_model import EncoderRNN, DecoderRNN
|
from lstm_model import EncoderRNN, DecoderRNN
|
||||||
import sys
|
import sys
|
||||||
@ -14,45 +14,43 @@ with open('data/en_lang.pkl', 'rb') as out_file:
|
|||||||
def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
|
def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
input_tensor = tensorFromSentence(input_lang, sentence)
|
input_tensor = tensorFromSentence(input_lang, sentence)
|
||||||
input_length = input_tensor.size()[0]
|
|
||||||
encoder_hidden = encoder.initHidden()
|
encoder_hidden = encoder.initHidden()
|
||||||
|
|
||||||
encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
|
encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
|
||||||
|
|
||||||
loss = 0
|
encoder_output, encoder_hidden = encoder(input_tensor,encoder_hidden)
|
||||||
|
|
||||||
for ei in range(input_length):
|
encoder_outputs = encoder_output
|
||||||
encoder_output, encoder_hidden = encoder(input_tensor[ei],
|
|
||||||
encoder_hidden)
|
|
||||||
encoder_outputs[ei] = encoder_output[0, 0]
|
|
||||||
|
|
||||||
decoder_input = torch.tensor([[SOS_token]], dtype=torch.long, device=device).view(-1, 1) # SOS
|
decoder_input = torch.tensor([[SOS_token]], device=device)
|
||||||
|
|
||||||
decoder_hidden = encoder_hidden
|
decoder_hidden = encoder_hidden
|
||||||
|
|
||||||
decoded_words = []
|
decoded_words = []
|
||||||
decoder_attentions = torch.zeros(max_length, max_length)
|
|
||||||
|
|
||||||
for di in range(max_length):
|
for di in range(max_length):
|
||||||
decoder_output, decoder_hidden = decoder(
|
decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
|
||||||
decoder_input, decoder_hidden)
|
topv, topi = decoder_output.topk(1)
|
||||||
topv, topi = decoder_output.data.topk(1)
|
decoded_words.append(topi)
|
||||||
if topi.item() == EOS_token:
|
decoder_input = topi.transpose(0, 1)
|
||||||
break
|
out = torch.stack(decoded_words)
|
||||||
else:
|
return out
|
||||||
decoded_words.append(output_lang.index2word[topi.item()])
|
|
||||||
|
|
||||||
decoder_input = topi.squeeze().detach()
|
|
||||||
|
|
||||||
return decoded_words
|
|
||||||
|
|
||||||
hidden_size = 256
|
hidden_size = 256
|
||||||
encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
|
encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
|
||||||
decoder = DecoderRNN(hidden_size, output_lang.n_words).to(device)
|
decoder = DecoderRNN(hidden_size, output_lang.n_words).to(device)
|
||||||
encoder.load_state_dict(torch.load('encoder.dict'))
|
encoder.load_state_dict(torch.load('encoder.dict'))
|
||||||
decoder.load_state_dict(torch.load('decoder.dict'))
|
decoder.load_state_dict(torch.load('decoder.dict'))
|
||||||
|
encoder.eval()
|
||||||
|
decoder.eval()
|
||||||
for line in sys.stdin:
|
for line in sys.stdin:
|
||||||
line = line.rstrip()
|
line = line.rstrip()
|
||||||
dec_words = evaluate(encoder, decoder, line, MAX_LENGTH)
|
dec_words = evaluate(encoder, decoder, line, MAX_LENGTH)
|
||||||
print(' '.join(dec_words))
|
dec_words = dec_words.transpose(0, 1)
|
||||||
|
for sen in dec_words:
|
||||||
|
out = []
|
||||||
|
for idx in sen:
|
||||||
|
if idx == 0:
|
||||||
|
break
|
||||||
|
out.append(output_lang.index2word[idx.item()])
|
||||||
|
print(' '.join(out))
|
||||||
|
@ -11,7 +11,7 @@ from torch import optim
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from lang import *
|
from lang import *
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
MAX_LENGTH = 300
|
MAX_LENGTH = 25
|
||||||
|
|
||||||
# Turn a Unicode string to plain ASCII, thanks to
|
# Turn a Unicode string to plain ASCII, thanks to
|
||||||
# https://stackoverflow.com/a/518232/2809427
|
# https://stackoverflow.com/a/518232/2809427
|
||||||
|
1200
test-A/out.tsv
1200
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
6
train.py
6
train.py
@ -1,7 +1,7 @@
|
|||||||
from lstm_model import EncoderRNN, DecoderRNN, AttnDecoderRNN
|
from lstm_model import EncoderRNN, DecoderRNN
|
||||||
from model_train import *
|
from model_train import *
|
||||||
hidden_size = 256
|
hidden_size = 256
|
||||||
encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)
|
encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)
|
||||||
attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)
|
attn_decoder1 = DecoderRNN(hidden_size, output_lang.n_words).to(device)
|
||||||
|
|
||||||
trainIters(encoder1, attn_decoder1, 10000, print_every=100)
|
trainIters(encoder1, attn_decoder1, 5, print_every=1)
|
||||||
|
Loading…
Reference in New Issue
Block a user