From ecd88b2531ca771b17dc6b03b1b85d3090ca74ef Mon Sep 17 00:00:00 2001
From: SzamanFL
Date: Sun, 24 Jan 2021 17:07:05 +0100
Subject: [PATCH] Now using multiple lstm layers

---
 src/Decoder.py | 7 ++++---
 src/Encoder.py | 9 +++++----
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/src/Decoder.py b/src/Decoder.py
index 46ddb54..10bad90 100644
--- a/src/Decoder.py
+++ b/src/Decoder.py
@@ -2,13 +2,14 @@ import torch
 import torch.nn.functional as F
 
 class Decoder(torch.nn.Module):
-    def __init__(self, hidden_size, output_size, num_layers=2):
+    def __init__(self, hidden_size, output_size, num_layers=3):
         super(Decoder, self).__init__()
         self.hidden_size = hidden_size
+        self.num_layers = num_layers
         self. embedding = torch.nn.Embedding(output_size, hidden_size)
-        #self.lstm = torch.nn.LSTM(hidden_size, output_size, num_layers=num_layers)
-        self.lstm = torch.nn.LSTM(hidden_size, hidden_size)
+        self.lstm = torch.nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)
+        #self.lstm = torch.nn.LSTM(hidden_size, hidden_size)
         self.out = torch.nn.Linear(hidden_size, output_size)
         self.softmax = torch.nn.LogSoftmax(dim=1)
 
diff --git a/src/Encoder.py b/src/Encoder.py
index cdd8f39..a03b594 100644
--- a/src/Encoder.py
+++ b/src/Encoder.py
@@ -1,13 +1,14 @@
 import torch
 
 class Encoder(torch.nn.Module):
-    def __init__(self, input_size, hidden_size, num_layers=4):
+    def __init__(self, input_size, hidden_size, num_layers=3):
         super(Encoder, self).__init__()
         self.hidden_size = hidden_size
+        self.num_layers = num_layers
         self.embedding = torch.nn.Embedding(input_size, hidden_size)
-        #self.lstm = torch.nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)
-        self.lstm = torch.nn.LSTM(hidden_size, hidden_size)
+        self.lstm = torch.nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)
+        #self.lstm = torch.nn.LSTM(hidden_size, hidden_size)
 
     def forward(self, x, hidden):
         embedded = self.embedding(x).view(1,1,-1)
         output, hidden = self.lstm(embedded, hidden)
@@ -15,4 +16,4 @@ class Encoder(torch.nn.Module):
         return output, hidden
 
     def init_hidden(self, device):
-        return (torch.zeros(1, 1, self.hidden_size, device = device), torch.zeros(1, 1, self.hidden_size, device = device))
+        return (torch.zeros(self.num_layers, 1, self.hidden_size, device = device), torch.zeros(self.num_layers, 1, self.hidden_size, device = device))