Works
This commit is contained in:
parent
60775dec39
commit
51d468f7aa
@ -8,7 +8,7 @@ class Decoder(torch.nn.Module):
|
||||
|
||||
self. embedding = torch.nn.Embedding(output_size, hidden_size)
|
||||
#self.lstm = torch.nn.LSTM(hidden_size, output_size, num_layers=num_layers)
|
||||
self.lstm = torch.nn.LSTM(hidden_size, output_size)
|
||||
self.lstm = torch.nn.LSTM(hidden_size, hidden_size)
|
||||
self.out = torch.nn.Linear(hidden_size, output_size)
|
||||
self.softmax = torch.nn.LogSoftmax(dim=1)
|
||||
|
||||
|
@ -70,7 +70,6 @@ def tensors_from_pair(pair, input_vocab, target_vocab):
|
||||
return (input_tensor, target_tensor)
|
||||
|
||||
def train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_optim, criterion, max_length=MAX_LENGTH):
|
||||
import ipdb; ipdb.set_trace()
|
||||
if not checkpoint:
|
||||
encoder_hidden = encoder.init_hidden(device)
|
||||
|
||||
@ -120,6 +119,7 @@ def train_iterate(pairs, encoder, decoder, n_iters, input_vocab, target_vocab, l
|
||||
criterion = torch.nn.NLLLoss()
|
||||
loss_total=0
|
||||
|
||||
print("Start training")
|
||||
for i in range(1, n_iters + 1):
|
||||
training_pair = training_pairs[i - 1]
|
||||
input_tensor = training_pair[0]
|
||||
|
Loading…
Reference in New Issue
Block a user