Now using multiple lstm layers
This commit is contained in:
parent
e45b0118bd
commit
ecd88b2531
@ -2,13 +2,14 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Decoder(torch.nn.Module):
|
||||
def __init__(self, hidden_size, output_size, num_layers=2):
|
||||
def __init__(self, hidden_size, output_size, num_layers=3):
|
||||
super(Decoder, self).__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
|
||||
self. embedding = torch.nn.Embedding(output_size, hidden_size)
|
||||
#self.lstm = torch.nn.LSTM(hidden_size, output_size, num_layers=num_layers)
|
||||
self.lstm = torch.nn.LSTM(hidden_size, hidden_size)
|
||||
self.lstm = torch.nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)
|
||||
#self.lstm = torch.nn.LSTM(hidden_size, hidden_size)
|
||||
self.out = torch.nn.Linear(hidden_size, output_size)
|
||||
self.softmax = torch.nn.LogSoftmax(dim=1)
|
||||
|
||||
|
@ -1,13 +1,14 @@
|
||||
import torch
|
||||
|
||||
class Encoder(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_layers=4):
|
||||
def __init__(self, input_size, hidden_size, num_layers=3):
|
||||
super(Encoder, self).__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
|
||||
self.embedding = torch.nn.Embedding(input_size, hidden_size)
|
||||
#self.lstm = torch.nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)
|
||||
self.lstm = torch.nn.LSTM(hidden_size, hidden_size)
|
||||
self.lstm = torch.nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)
|
||||
#self.lstm = torch.nn.LSTM(hidden_size, hidden_size)
|
||||
|
||||
def forward(self, x, hidden):
|
||||
embedded = self.embedding(x).view(1,1,-1)
|
||||
@ -15,4 +16,4 @@ class Encoder(torch.nn.Module):
|
||||
return output, hidden
|
||||
|
||||
def init_hidden(self, device):
|
||||
return (torch.zeros(1, 1, self.hidden_size, device = device), torch.zeros(1, 1, self.hidden_size, device = device))
|
||||
return (torch.zeros(self.num_layers, 1, self.hidden_size, device = device), torch.zeros(self.num_layers, 1, self.hidden_size, device = device))
|
||||
|
Loading…
Reference in New Issue
Block a user