78 lines
3.1 KiB
Python
78 lines
3.1 KiB
Python
import torch
|
|
from torch import nn
|
|
device = 'cuda'
|
|
import torch.nn.functional as F
|
|
import torch.nn.init as init
|
|
from lang import SOS_token, EOS_token
|
|
|
|
class EncoderRNN(nn.Module):
    """Single-layer LSTM encoder that consumes one token index per call.

    Args:
        input_size: Number of rows in the embedding table.
        hidden_size: Width of the embedding vectors and of the LSTM state.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size)

    def forward(self, input, hidden):
        """Embed one token and advance the LSTM by a single step.

        Args:
            input: Index tensor holding exactly one token id (the embedding
                result is reshaped to ``(1, 1, hidden_size)``).
            hidden: ``(h, c)`` tuple of LSTM state tensors, e.g. the value
                returned by :meth:`initHidden` or by a previous call.

        Returns:
            ``(output, hidden)`` exactly as produced by ``nn.LSTM``: the step
            output of shape ``(1, 1, hidden_size)`` and the new ``(h, c)``.
        """
        # nn.LSTM takes (seq_len, batch, features); a single token is (1, 1, H).
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.lstm(embedded, hidden)
        return output, hidden

    def initHidden(self):
        """Return a zero ``(h_0, c_0)`` pair, each of shape ``(1, 1, hidden_size)``.

        The tensors are created on the device (and with the dtype) of this
        module's own weights rather than on a hard-coded global device, so the
        state always matches the module it is fed to, including on CPU-only
        machines.
        """
        weight = self.embedding.weight
        shape = (1, 1, self.hidden_size)
        return (
            torch.zeros(shape, device=weight.device, dtype=weight.dtype),
            torch.zeros(shape, device=weight.device, dtype=weight.dtype),
        )
|
|
|
|
class DecoderRNN(nn.Module):
    """Single-layer LSTM decoder that emits log-probabilities for one step.

    Args:
        hidden_size: Width of the embedding vectors and of the LSTM state.
        output_size: Number of rows in the embedding table and number of
            classes scored by the output projection.
    """

    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Run one decoding step.

        Args:
            input: Index tensor holding exactly one token id (the embedding
                result is reshaped to ``(1, 1, hidden_size)``).
            hidden: ``(h, c)`` tuple of LSTM state tensors.

        Returns:
            ``(output, hidden)`` where ``output`` has shape
            ``(1, output_size)`` and holds log-probabilities, and ``hidden``
            is the updated ``(h, c)`` tuple.
        """
        # nn.LSTM takes (seq_len, batch, features); a single token is (1, 1, H).
        step_input = self.embedding(input).view(1, 1, -1)
        step_input = F.relu(step_input)
        lstm_out, hidden = self.lstm(step_input, hidden)
        # lstm_out[0] drops the length-1 sequence axis -> (1, hidden_size).
        output = self.softmax(self.out(lstm_out[0]))
        return output, hidden

    def initHidden(self):
        """Return a zero ``(h_0, c_0)`` pair, each of shape ``(1, 1, hidden_size)``.

        The tensors are created on the device (and with the dtype) of this
        module's own weights rather than on a hard-coded global device, so the
        state always matches the module it is fed to, including on CPU-only
        machines.
        """
        weight = self.embedding.weight
        shape = (1, 1, self.hidden_size)
        return (
            torch.zeros(shape, device=weight.device, dtype=weight.dtype),
            torch.zeros(shape, device=weight.device, dtype=weight.dtype),
        )
|
|
|
|
class AttnDecoderRNN(nn.Module):
    """LSTM decoder with learned attention over a fixed number of encoder steps.

    Args:
        hidden_size: Width of the embedding vectors and of the LSTM state.
        output_size: Number of rows in the embedding table and number of
            classes scored by the output projection.
        dropout_p: Dropout probability applied to the embedded input.
        max_length: Number of attention slots; ``encoder_outputs`` passed to
            :meth:`forward` must have exactly this many rows.
    """

    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=300):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        # The attention scorer sees the embedded token and the LSTM hidden
        # state side by side, hence 2 * hidden_size input features.
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.lstm = nn.LSTM(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        """Run one attention-weighted decoding step.

        Args:
            input: Index tensor holding exactly one token id (the embedding
                result is reshaped to ``(1, 1, hidden_size)``).
            hidden: ``(h, c)`` tuple of LSTM state tensors, each of shape
                ``(1, 1, hidden_size)``.
            encoder_outputs: 2-D tensor of shape ``(max_length, hidden_size)``;
                it is batch-multiplied with the ``(1, max_length)`` attention
                weights.

        Returns:
            ``(output, hidden, attn_weights)``: log-probabilities of shape
            ``(1, output_size)``, the updated ``(h, c)`` tuple, and the
            attention weights of shape ``(1, max_length)``.
        """
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        # ``hidden`` is an (h, c) tuple: hidden[0] is h with shape (1, 1, H),
        # and hidden[0][0] drops the layer axis so it can be concatenated with
        # embedded[0] into a single (1, 2H) feature row.
        attn_features = torch.cat((embedded[0], hidden[0][0]), 1)
        attn_weights = F.softmax(self.attn(attn_features), dim=1)

        # (1, 1, max_length) @ (1, max_length, H) -> (1, 1, H) context vector.
        attn_applied = torch.bmm(
            attn_weights.unsqueeze(0),
            encoder_outputs.unsqueeze(0),
        )

        combined = torch.cat((embedded[0], attn_applied[0]), 1)
        lstm_input = F.relu(self.attn_combine(combined).unsqueeze(0))
        lstm_out, hidden = self.lstm(lstm_input, hidden)

        output = F.log_softmax(self.out(lstm_out[0]), dim=1)
        return output, hidden, attn_weights

    def initHidden(self):
        """Return a zero ``(h_0, c_0)`` pair, each of shape ``(1, 1, hidden_size)``.

        The tensors are created on the device (and with the dtype) of this
        module's own weights rather than on a hard-coded global device, so the
        state always matches the module it is fed to, including on CPU-only
        machines.
        """
        weight = self.embedding.weight
        shape = (1, 1, self.hidden_size)
        return (
            torch.zeros(shape, device=weight.device, dtype=weight.dtype),
            torch.zeros(shape, device=weight.device, dtype=weight.dtype),
        )