Works
This commit is contained in:
parent
60775dec39
commit
51d468f7aa
@ -8,7 +8,7 @@ class Decoder(torch.nn.Module):
|
|||||||
|
|
||||||
self. embedding = torch.nn.Embedding(output_size, hidden_size)
|
self. embedding = torch.nn.Embedding(output_size, hidden_size)
|
||||||
#self.lstm = torch.nn.LSTM(hidden_size, output_size, num_layers=num_layers)
|
#self.lstm = torch.nn.LSTM(hidden_size, output_size, num_layers=num_layers)
|
||||||
self.lstm = torch.nn.LSTM(hidden_size, output_size)
|
self.lstm = torch.nn.LSTM(hidden_size, hidden_size)
|
||||||
self.out = torch.nn.Linear(hidden_size, output_size)
|
self.out = torch.nn.Linear(hidden_size, output_size)
|
||||||
self.softmax = torch.nn.LogSoftmax(dim=1)
|
self.softmax = torch.nn.LogSoftmax(dim=1)
|
||||||
|
|
||||||
|
@ -70,7 +70,6 @@ def tensors_from_pair(pair, input_vocab, target_vocab):
|
|||||||
return (input_tensor, target_tensor)
|
return (input_tensor, target_tensor)
|
||||||
|
|
||||||
def train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_optim, criterion, max_length=MAX_LENGTH):
|
def train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_optim, criterion, max_length=MAX_LENGTH):
|
||||||
import ipdb; ipdb.set_trace()
|
|
||||||
if not checkpoint:
|
if not checkpoint:
|
||||||
encoder_hidden = encoder.init_hidden(device)
|
encoder_hidden = encoder.init_hidden(device)
|
||||||
|
|
||||||
@ -120,6 +119,7 @@ def train_iterate(pairs, encoder, decoder, n_iters, input_vocab, target_vocab, l
|
|||||||
criterion = torch.nn.NLLLoss()
|
criterion = torch.nn.NLLLoss()
|
||||||
loss_total=0
|
loss_total=0
|
||||||
|
|
||||||
|
print("Start training")
|
||||||
for i in range(1, n_iters + 1):
|
for i in range(1, n_iters + 1):
|
||||||
training_pair = training_pairs[i - 1]
|
training_pair = training_pairs[i - 1]
|
||||||
input_tensor = training_pair[0]
|
input_tensor = training_pair[0]
|
||||||
|
Loading…
Reference in New Issue
Block a user