From 51d468f7aa421ec64436c2ed3f1eb23e411924c0 Mon Sep 17 00:00:00 2001 From: SzamanFL Date: Sat, 23 Jan 2021 19:42:56 +0100 Subject: [PATCH] Works --- src/Decoder.py | 2 +- src/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Decoder.py b/src/Decoder.py index 14948be..46ddb54 100644 --- a/src/Decoder.py +++ b/src/Decoder.py @@ -8,7 +8,7 @@ class Decoder(torch.nn.Module): self. embedding = torch.nn.Embedding(output_size, hidden_size) #self.lstm = torch.nn.LSTM(hidden_size, output_size, num_layers=num_layers) - self.lstm = torch.nn.LSTM(hidden_size, output_size) + self.lstm = torch.nn.LSTM(hidden_size, hidden_size) self.out = torch.nn.Linear(hidden_size, output_size) self.softmax = torch.nn.LogSoftmax(dim=1) diff --git a/src/train.py b/src/train.py index 07925db..6a40ddf 100644 --- a/src/train.py +++ b/src/train.py @@ -70,7 +70,6 @@ def tensors_from_pair(pair, input_vocab, target_vocab): return (input_tensor, target_tensor) def train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_optim, criterion, max_length=MAX_LENGTH): - import ipdb; ipdb.set_trace() if not checkpoint: encoder_hidden = encoder.init_hidden(device) @@ -120,6 +119,7 @@ def train_iterate(pairs, encoder, decoder, n_iters, input_vocab, target_vocab, l criterion = torch.nn.NLLLoss() loss_total=0 + print("Start training") for i in range(1, n_iters + 1): training_pair = training_pairs[i - 1] input_tensor = training_pair[0]