This commit is contained in:
SzamanFL 2021-01-23 19:42:56 +01:00
parent 60775dec39
commit 51d468f7aa
2 changed files with 2 additions and 2 deletions

View File

@ -8,7 +8,7 @@ class Decoder(torch.nn.Module):
self. embedding = torch.nn.Embedding(output_size, hidden_size) self. embedding = torch.nn.Embedding(output_size, hidden_size)
#self.lstm = torch.nn.LSTM(hidden_size, output_size, num_layers=num_layers) #self.lstm = torch.nn.LSTM(hidden_size, output_size, num_layers=num_layers)
self.lstm = torch.nn.LSTM(hidden_size, output_size) self.lstm = torch.nn.LSTM(hidden_size, hidden_size)
self.out = torch.nn.Linear(hidden_size, output_size) self.out = torch.nn.Linear(hidden_size, output_size)
self.softmax = torch.nn.LogSoftmax(dim=1) self.softmax = torch.nn.LogSoftmax(dim=1)

View File

@ -70,7 +70,6 @@ def tensors_from_pair(pair, input_vocab, target_vocab):
return (input_tensor, target_tensor) return (input_tensor, target_tensor)
def train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_optim, criterion, max_length=MAX_LENGTH): def train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_optim, criterion, max_length=MAX_LENGTH):
import ipdb; ipdb.set_trace()
if not checkpoint: if not checkpoint:
encoder_hidden = encoder.init_hidden(device) encoder_hidden = encoder.init_hidden(device)
@ -120,6 +119,7 @@ def train_iterate(pairs, encoder, decoder, n_iters, input_vocab, target_vocab, l
criterion = torch.nn.NLLLoss() criterion = torch.nn.NLLLoss()
loss_total=0 loss_total=0
print("Start training")
for i in range(1, n_iters + 1): for i in range(1, n_iters + 1):
training_pair = training_pairs[i - 1] training_pair = training_pairs[i - 1]
input_tensor = training_pair[0] input_tensor = training_pair[0]