import torch
from torch import nn

# Prefer the GPU, but fall back to the CPU when CUDA is unavailable so that
# the initHidden() helpers below do not raise on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

import torch.nn.functional as F
import torch.nn.init as init
from lang import SOS_token, EOS_token


def _zero_lstm_state(hidden_size):
    """Return a zero-filled LSTM state ``(h_0, c_0)`` on ``device``.

    Each tensor has shape ``(1, 1, hidden_size)``: one layer, batch size one.
    """
    return (torch.zeros(1, 1, hidden_size, device=device),
            torch.zeros(1, 1, hidden_size, device=device))


class EncoderRNN(nn.Module):
    """Single-layer LSTM encoder that consumes one token index per call."""

    def __init__(self, input_size, hidden_size):
        """Build the embedding table and the LSTM.

        Args:
            input_size: number of rows in the embedding table.
            hidden_size: embedding width and LSTM hidden width.
        """
        super().__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size)

    def forward(self, input, hidden):
        """Encode a single token.

        Args:
            input: tensor holding one token index (it is reshaped to
                ``(1, 1, hidden_size)`` after the embedding lookup).
            hidden: LSTM state tuple ``(h, c)`` as returned by
                ``initHidden`` or by a previous call.

        Returns:
            ``(output, hidden)`` exactly as produced by ``nn.LSTM``.
        """
        # (seq_len=1, batch=1, hidden_size), the layout nn.LSTM expects.
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.lstm(embedded, hidden)
        return output, hidden

    def initHidden(self):
        """Return a zero ``(h_0, c_0)`` state for the start of a sequence."""
        return _zero_lstm_state(self.hidden_size)


class DecoderRNN(nn.Module):
    """Single-layer LSTM decoder without attention, one token per call."""

    def __init__(self, hidden_size, output_size):
        """Build the embedding table, the LSTM and the output projection.

        Args:
            hidden_size: embedding width and LSTM hidden width.
            output_size: number of rows in the embedding table and number of
                output classes.
        """
        super().__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Decode a single step.

        Args:
            input: tensor holding one token index.
            hidden: LSTM state tuple ``(h, c)``.

        Returns:
            ``(output, hidden)`` where ``output`` holds log-probabilities of
            shape ``(1, output_size)`` and ``hidden`` is the new LSTM state.
        """
        output = self.embedding(input).view(1, 1, -1)
        output = F.relu(output)
        output, hidden = self.lstm(output, hidden)
        # output[0] drops the length-1 sequence dimension before projecting.
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def initHidden(self):
        """Return a zero ``(h_0, c_0)`` state for the start of a sequence."""
        return _zero_lstm_state(self.hidden_size)


class AttnDecoderRNN(nn.Module):
    """Single-layer LSTM decoder with attention over the encoder outputs."""

    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=300):
        """Build the decoder layers.

        Args:
            hidden_size: embedding width and LSTM hidden width.
            output_size: number of rows in the embedding table and number of
                output classes.
            dropout_p: dropout probability applied to the embedded input.
            max_length: number of attention slots, i.e. the number of rows
                ``encoder_outputs`` must have in ``forward``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        # The attention scorer is fed the embedded token concatenated with
        # the LSTM hidden state, hence hidden_size * 2 input features.
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.lstm = nn.LSTM(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        """Decode a single step while attending over ``encoder_outputs``.

        Args:
            input: tensor holding one token index.
            hidden: LSTM state tuple ``(h, c)``, each ``(1, 1, hidden_size)``.
            encoder_outputs: tensor of shape ``(max_length, hidden_size)``;
                the batched matmul below requires exactly ``max_length`` rows.

        Returns:
            ``(output, hidden, attn_weights)`` where ``output`` holds
            log-probabilities of shape ``(1, output_size)``, ``hidden`` is
            the new LSTM state tuple and ``attn_weights`` has shape
            ``(1, max_length)``.
        """
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        # hidden is the LSTM (h, c) tuple; hidden[0][0] is h without its
        # leading layer dimension, so both operands are (1, hidden_size).
        attn_input = torch.cat((embedded[0], hidden[0][0]), 1)
        attn_weights = F.softmax(self.attn(attn_input), dim=1)

        # (1, 1, max_length) x (1, max_length, hidden) -> (1, 1, hidden)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))

        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)
        output = F.relu(output)
        output, hidden = self.lstm(output, hidden)

        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights

    def initHidden(self):
        """Return a zero ``(h_0, c_0)`` state for the start of a sequence."""
        return _zero_lstm_state(self.hidden_size)