Now using multiple LSTM layers

SzamanFL 2021-01-24 17:07:05 +01:00
parent e45b0118bd
commit ecd88b2531
2 changed files with 9 additions and 7 deletions

Changed file 1: Decoder

@@ -2,13 +2,14 @@ import torch
 import torch.nn.functional as F
 
 class Decoder(torch.nn.Module):
-    def __init__(self, hidden_size, output_size, num_layers=2):
+    def __init__(self, hidden_size, output_size, num_layers=3):
         super(Decoder, self).__init__()
         self.hidden_size = hidden_size
         self.num_layers = num_layers
         self.embedding = torch.nn.Embedding(output_size, hidden_size)
-        #self.lstm = torch.nn.LSTM(hidden_size, output_size, num_layers=num_layers)
-        self.lstm = torch.nn.LSTM(hidden_size, hidden_size)
+        self.lstm = torch.nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)
+        #self.lstm = torch.nn.LSTM(hidden_size, hidden_size)
         self.out = torch.nn.Linear(hidden_size, output_size)
         self.softmax = torch.nn.LogSoftmax(dim=1)

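Why the Decoder and Encoder defaults now both read num_layers=3: in a seq2seq setup the decoder is seeded with the encoder's final (hidden, cell) state, and torch.nn.LSTM states have shape (num_layers, batch, hidden_size), so the two modules must agree on the layer count. A minimal sketch of that constraint (sizes and variable names are illustrative, not taken from this repository):

import torch

hidden_size = 256                              # illustrative size
enc = torch.nn.LSTM(hidden_size, hidden_size, num_layers=3)
dec = torch.nn.LSTM(hidden_size, hidden_size, num_layers=3)

src = torch.zeros(5, 1, hidden_size)           # 5 embedded source steps, batch of 1
_, state = enc(src)                            # state = (h, c), each (3, 1, hidden_size)

tgt = torch.zeros(1, 1, hidden_size)           # one embedded target token
_, state = dec(tgt, state)                     # only valid because num_layers matches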
Changed file 2: Encoder

@@ -1,13 +1,14 @@
 import torch
 
 class Encoder(torch.nn.Module):
-    def __init__(self, input_size, hidden_size, num_layers=4):
+    def __init__(self, input_size, hidden_size, num_layers=3):
         super(Encoder, self).__init__()
         self.hidden_size = hidden_size
         self.num_layers = num_layers
         self.embedding = torch.nn.Embedding(input_size, hidden_size)
-        #self.lstm = torch.nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)
-        self.lstm = torch.nn.LSTM(hidden_size, hidden_size)
+        self.lstm = torch.nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)
+        #self.lstm = torch.nn.LSTM(hidden_size, hidden_size)
 
     def forward(self, x, hidden):
         embedded = self.embedding(x).view(1,1,-1)
@@ -15,4 +16,4 @@ class Encoder(torch.nn.Module):
         return output, hidden
 
     def init_hidden(self, device):
-        return (torch.zeros(1, 1, self.hidden_size, device=device), torch.zeros(1, 1, self.hidden_size, device=device))
+        return (torch.zeros(self.num_layers, 1, self.hidden_size, device=device), torch.zeros(self.num_layers, 1, self.hidden_size, device=device))
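Context for the init_hidden change: a multi-layer torch.nn.LSTM expects its initial hidden and cell states to have shape (num_layers, batch, hidden_size), one slice per layer, so the old (1, 1, hidden_size) tensors would raise a RuntimeError once num_layers is 3. A quick sketch (sizes are illustrative):

import torch

num_layers, batch, hidden_size = 3, 1, 256     # illustrative sizes
lstm = torch.nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)

h0 = torch.zeros(num_layers, batch, hidden_size)   # one state slice per layer
c0 = torch.zeros(num_layers, batch, hidden_size)

x = torch.zeros(1, batch, hidden_size)         # (seq_len, batch, input_size)
output, (hn, cn) = lstm(x, (h0, c0))
print(output.shape)                            # torch.Size([1, 1, 256])
print(hn.shape)                                # torch.Size([3, 1, 256])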