wmt-2020-pl-en/lstm_model.py
2021-01-31 16:54:20 +01:00

39 lines
1.3 KiB
Python

import torch
from torch import nn
# Default device for tensors created by this module. Prefer the GPU, but fall
# back to the CPU so the module still works on hosts without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import torch.nn.functional as F
class EncoderRNN(nn.Module):
    """Embedding layer followed by a single-layer LSTM.

    Both the embedding dimension and the LSTM hidden dimension are
    ``hidden_size``.
    """

    def __init__(self, input_size: int, hidden_size: int):
        """Build the encoder.

        Args:
            input_size: Number of rows in the embedding table (passed to
                ``nn.Embedding`` as ``num_embeddings``).
            hidden_size: Width of the embeddings and of the LSTM state.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size)

    def forward(self, input, hidden):
        """Embed ``input`` and run it through the LSTM.

        Args:
            input: Integer index tensor. The LSTM is built with the default
                ``batch_first=False``, so a batched call presumably passes
                indices shaped (seq_len, batch) -- confirm against the caller.
            hidden: ``(h_0, c_0)`` state tuple, e.g. from :meth:`initHidden`.

        Returns:
            ``(output, hidden)`` exactly as returned by ``nn.LSTM``.
        """
        # NOTE: the parameter name ``input`` shadows the builtin; it is kept
        # because callers may pass it by keyword.
        embedded = self.embedding(input)
        return self.lstm(embedded, hidden)

    def initHidden(self):
        """Return a zero ``(h_0, c_0)`` pair, each of shape (1, 1, hidden_size).

        The tensors are allocated with the device and dtype of the module's
        own weights, so the state matches the model wherever it has been
        moved, instead of relying on a module-level device constant.
        """
        weight = self.embedding.weight
        state_shape = (1, 1, self.hidden_size)
        return weight.new_zeros(state_shape), weight.new_zeros(state_shape)
class DecoderRNN(nn.Module):
    """Embedding -> ReLU -> single-layer LSTM -> linear -> log-softmax.

    The embedding dimension and the LSTM hidden dimension are both
    ``hidden_size``; the linear layer maps to ``output_size`` scores.
    """

    def __init__(self, hidden_size, output_size):
        """Create the decoder layers.

        Args:
            hidden_size: Width of the embeddings and of the LSTM state.
            output_size: Number of embedding rows and of output scores.
        """
        super().__init__()
        self.hidden_size = hidden_size
        # Keep this creation order: it fixes the state_dict key order and the
        # order in which random initial weights are drawn.
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Run one decoding pass.

        Args:
            input: Integer index tensor fed to the embedding layer
                (presumably shaped (seq_len, batch) -- confirm with caller).
            hidden: ``(h, c)`` LSTM state tuple.

        Returns:
            ``(log_probs, hidden)``: the log-softmax (over dim 1) of the
            projected LSTM output at index 0 of its first dimension, and the
            updated LSTM state.
        """
        activated = F.relu(self.embedding(input))
        lstm_out, next_hidden = self.lstm(activated, hidden)
        # Only the entry at index 0 along the first axis is projected.
        scores = self.out(lstm_out[0])
        log_probs = self.softmax(scores)
        return log_probs, next_hidden