Slight changes

SzamanFL 2021-01-20 09:48:43 +01:00
parent ba956127bd
commit 60775dec39
4 changed files with 42 additions and 30 deletions

Decoder.py

@@ -1,15 +1,16 @@
-import torch.nn
+import torch
 import torch.nn.functional as F
 
-class Decoder:
+class Decoder(torch.nn.Module):
     def __init__(self, hidden_size, output_size, num_layers=2):
         super(Decoder, self).__init__()
         self.hidden_size = hidden_size
-        self.embedding = nn.Embedding(output_size, hidden_size)
-        self.lstm = nn.LSTM(hidden_size, output_size, num_layers=num_layers)
-        self.out = nn.Linear(hidden_size, output_size)
-        self.softmax = nn.LogSoftmax(dim=1)
+        self.embedding = torch.nn.Embedding(output_size, hidden_size)
+        #self.lstm = torch.nn.LSTM(hidden_size, output_size, num_layers=num_layers)
+        self.lstm = torch.nn.LSTM(hidden_size, output_size)
+        self.out = torch.nn.Linear(hidden_size, output_size)
+        self.softmax = torch.nn.LogSoftmax(dim=1)
 
     def forward(self, x, hidden):
         embedded = self.embedding(x).view(1, 1, -1)
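
A note on these changes: `import torch.nn` does not bind the bare name `nn`, so every `nn.*` reference in the old file raised a NameError, and a plain `class Decoder:` has no `.to()`, `.parameters()`, or `__call__`, so `Decoder(...).to(device)` in the training script failed as well; subclassing torch.nn.Module fixes both uses. One caveat: torch.nn.LSTM(hidden_size, output_size) makes the LSTM's hidden width output_size, while self.out expects hidden_size-wide inputs, so the two only line up when those sizes happen to match. A minimal smoke test of the patched class (the rest of forward() is not shown in this view):

import torch
from Decoder import Decoder

decoder = Decoder(hidden_size=256, output_size=1000)
# Only works because Decoder is now a torch.nn.Module:
print(sum(p.numel() for p in decoder.parameters()))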

Encoder.py

@@ -1,12 +1,13 @@
-import torch.nn
+import torch
 
-class Encoder(nn.Module):
+class Encoder(torch.nn.Module):
     def __init__(self, input_size, hidden_size, num_layers=4):
         super(Encoder, self).__init__()
         self.hidden_size = hidden_size
-        self.embedding = nn.Embedding(input_size, hidden_size)
-        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)
+        self.embedding = torch.nn.Embedding(input_size, hidden_size)
+        #self.lstm = torch.nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)
+        self.lstm = torch.nn.LSTM(hidden_size, hidden_size)
 
     def forward(self, x, hidden):
         embedded = self.embedding(x).view(1,1,-1)
@@ -14,4 +15,4 @@ class Encoder(nn.Module):
         return output, hidden
 
     def init_hidden(self, device):
-        return torch.zeros(1, 1, self.hidden_size, device=device)
+        return (torch.zeros(1, 1, self.hidden_size, device=device), torch.zeros(1, 1, self.hidden_size, device=device))
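
The init_hidden change follows directly from using an LSTM: torch.nn.LSTM carries a pair of state tensors (h_0, c_0), where plain RNN/GRU modules carry a single h_0, and it rejects a bare tensor where the tuple is expected. A minimal illustration:

import torch

lstm = torch.nn.LSTM(8, 8)
x = torch.randn(1, 1, 8)             # (seq_len, batch, input_size)
state = (torch.zeros(1, 1, 8),       # h_0
         torch.zeros(1, 1, 8))       # c_0
output, (h_n, c_n) = lstm(x, state)  # the state going in and out is a tuple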

Vocab.py

@@ -8,7 +8,7 @@ class Vocab:
     def add_sentence(self, sentence):
         for word in sentence.split(' '):
-            self.addWord(word)
+            self.add_word(word)
 
     def add_word(self, word):
         if word not in self.word2index:
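
The rename fixes an AttributeError: the class defines add_word, so the camelCase self.addWord call could never resolve. For context, a sketch of the bookkeeping this class appears to do; only word2index, size, and the name argument are confirmed by this commit, the rest (including reserving indices 0 and 1 for SOS/EOS) is an assumption:

class Vocab:
    def __init__(self, name):
        self.name = name        # e.g. "pl" or "en"
        self.word2index = {}
        self.size = 2           # assumption: 0 and 1 reserved for SOS/EOS

    def add_sentence(self, sentence):
        for word in sentence.split(' '):
            self.add_word(word)

    def add_word(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.size
            self.size += 1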

Training script

@@ -7,10 +7,13 @@ import unicodedata
 import torch
 import random
 import pickle
+import re
 
 from Vocab import Vocab
+from Encoder import Encoder
+from Decoder import Decoder
 
-MAX_LEN = 25
+MAX_LENGTH = 25
 SOS=0
 EOS=1
 teacher_forcing_ratio=0.5
@@ -19,12 +22,12 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 def clear_line(string, target):
     string = ''.join(
-        c for c in unicodedata.normalize('NFD', s)
+        c for c in unicodedata.normalize('NFD', string)
         if unicodedata.category(c) != 'Mn'
     )
     target = ''.join(
-        c for c in unicodedata.normalize('NFD', s)
+        c for c in unicodedata.normalize('NFD', string)
         if unicodedata.category(c) != 'Mn'
     )
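
Besides replacing the undefined name `s` with the actual parameter, it is worth spelling out what this normalization does: NFD decomposes an accented character into its base letter plus combining marks (Unicode category 'Mn'), and the filter then drops the marks, stripping diacritics. Note that the second branch still normalizes `string` where `target` was presumably intended, so the target sentence keeps its diacritics. A quick illustration:

import unicodedata

s = 'świat'
stripped = ''.join(c for c in unicodedata.normalize('NFD', s)
                   if unicodedata.category(c) != 'Mn')
print(stripped)  # swiat  (the combining acute is dropped)

Characters with no decomposition, such as the Polish 'ł', pass through unchanged.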
@@ -38,7 +41,7 @@ def read_clear_data(in_file_path, expected_file_path):
     with open(in_file_path) as in_file, open(expected_file_path) as exp_file:
         for string, target in zip(in_file, exp_file):
             string, target = clear_line(string, target)
-            if len(string.split(' ')) < MAX_LEN and len(target.split(' ')) < MAX_LEN:
+            if len(string.split(' ')) < MAX_LENGTH and len(target.split(' ')) < MAX_LENGTH:
                 pairs.append([string, target])
     input_vocab = Vocab("pl")
     target_vocab = Vocab("en")
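
The MAX_LEN to MAX_LENGTH rename is load-bearing: train() below declares max_length=MAX_LENGTH as a default argument, and Python evaluates defaults when the def statement runs, so the old module raised a NameError as soon as it was imported.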
@@ -48,8 +51,8 @@ def prepare_data(in_file_path, expected_file_path):
     pairs, input_vocab, target_vocab = read_clear_data(in_file_path, expected_file_path)
     for pair in pairs:
-        input_lang.add_sentence(pair[0])
-        target_lang.add_sentence(pair[1])
+        input_vocab.add_sentence(pair[0])
+        target_vocab.add_sentence(pair[1])
     return pairs, input_vocab, target_vocab
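
input_lang and target_lang were never defined in this function (they look like leftovers of the PyTorch seq2seq tutorial this script follows); the loop now populates the vocabularies that read_clear_data actually returns.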
@@ -67,6 +70,7 @@ def tensors_from_pair(pair, input_vocab, target_vocab):
     return (input_tensor, target_tensor)
 
 def train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_optim, criterion, max_length=MAX_LENGTH):
+    import ipdb; ipdb.set_trace()
     if not checkpoint:
         encoder_hidden = encoder.init_hidden(device)
@@ -81,7 +85,7 @@ def train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_optim, criterion, max_length=MAX_LENGTH):
     loss = 0
     for e in range(input_len):
         encoder_output, encoder_hidden = encoder(input_tensor[e], encoder_hidden)
-        encoder_outputs[i] = encoder_output[0, 0]
+        encoder_outputs[e] = encoder_output[0, 0]
 
     decoder_hidden = encoder_hidden
     decoder_input = torch.tensor([[SOS]], device=device)
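
Two fixes meet in this hunk: the loop variable (`i` is not defined here, `e` is), and the hand-off of the encoder's final state to the decoder with SOS as the first decoder input. The decoding loop itself is elided from this view; a sketch of its usual shape, assuming teacher forcing as suggested by teacher_forcing_ratio=0.5 at the top of the file, and assuming decoder() returns (log-probabilities, hidden); the surrounding names come from the diff, the loop structure is an assumption:

use_teacher_forcing = random.random() < teacher_forcing_ratio
for d in range(target_len):
    decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
    loss += criterion(decoder_output, target_tensor[d])
    if use_teacher_forcing:
        decoder_input = target_tensor[d]          # feed the ground truth
    else:
        topv, topi = decoder_output.topk(1)
        decoder_input = topi.squeeze().detach()   # feed the model's own guess
        if decoder_input.item() == EOS:
            break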
@@ -108,10 +112,11 @@ def train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_optim, criterion, max_length=MAX_LENGTH):
     encoder_optim.step()
 
     return loss.item() / target_len
 
-def train_iterate(pairs, encoder, decoder, n_iters, lr=0.01):
+def train_iterate(pairs, encoder, decoder, n_iters, input_vocab, target_vocab, lr=0.01):
     encoder_optim = torch.optim.SGD(encoder.parameters(), lr=lr)
     decoder_optim = torch.optim.SGD(decoder.parameters(), lr=lr)
-    training_pairs = [tensors_from_pair(random.choice(pairs)) for i in range(n_iters)]
+    #import ipdb; ipdb.set_trace()
+    training_pairs = [tensors_from_pair(random.choice(pairs), input_vocab, target_vocab) for i in range(n_iters)]
     criterion = torch.nn.NLLLoss()
     loss_total=0
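
The hunk header above shows that tensors_from_pair already takes the two vocabularies, so the old one-argument call was a TypeError; train_iterate now accepts them and forwards them. The body of tensors_from_pair is not part of this diff; a sketch consistent with its signature, in which the tensor_from_sentence helper and the word2index-plus-EOS encoding are assumptions:

def tensor_from_sentence(vocab, sentence):
    # Assumed encoding: word indices followed by the EOS token.
    indexes = [vocab.word2index[word] for word in sentence.split(' ')]
    indexes.append(EOS)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)

def tensors_from_pair(pair, input_vocab, target_vocab):
    input_tensor = tensor_from_sentence(input_vocab, pair[0])
    target_tensor = tensor_from_sentence(target_vocab, pair[1])
    return (input_tensor, target_tensor)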
@@ -120,7 +125,7 @@ def train_iterate(pairs, encoder, decoder, n_iters, lr=0.01):
         input_tensor = training_pair[0]
         target_tensor = training_pair[1]
-        loss = train(input_tensor, target_tensor, encoder, de, encoder_optim, decoder_optim, criterion)
+        loss = train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_optim, criterion)
         loss_total += loss
 
         if i % 1000 == 0:
@@ -142,30 +147,35 @@ def main():
     parser.add_argument("--seed")
     args = parser.parse_args()
 
+    global seed
     if args.seed:
         seed = int(args.seed)
     else:
-        seed = random.rand
-    global seed
+        seed = random.randint(0,50)
+    print(seed)
 
+    #global input_vocab
+    #global target_vocab
     if args.vocab:
-        with open(args.vocab, 'wb+') as p:
+        with open(args.vocab, 'rb') as p:
             pairs, input_vocab, target_vocab = pickle.load(p)
     else:
         pairs, input_vocab, target_vocab = prepare_data(args.in_f, args.exp)
-        with open("vocabs.pckl", 'rb') as p:
+        with open("vocabs.pckl", 'wb+') as p:
             pickle.dump([pairs, input_vocab, target_vocab], p)
 
     hidden_size = 256
     encoder = Encoder(input_vocab.size, hidden_size).to(device)
     decoder = Decoder(hidden_size, target_vocab.size).to(device)
 
+    global checkpoint
+    checkpoint = False
     if args.encoder:
         encoder.load_state_dict(torch.load(args.encoder))
+        checkpoint = True
     if args.decoder:
         decoder.load_state_dict(torch.load(args.decoder))
 
-    train_iterate(pairs, encoder, decoder, 50000)
+    train_iterate(pairs, encoder, decoder, 50000, input_vocab, target_vocab)
 
 main()
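
The mode swap is the substantive fix in main(): the old code opened the vocab cache with 'wb+' and then called pickle.load on the handle it had just truncated, and opened the file it wanted to write with 'rb'. Binary read for load, binary write for dump:

import pickle

payload = [["ala ma kota", "ala has a cat"]]  # stand-in data
with open("vocabs.pckl", "wb") as p:          # write mode for dump
    pickle.dump(payload, p)
with open("vocabs.pckl", "rb") as p:          # read mode for load
    assert pickle.load(p) == payload

The new checkpoint global tells train() whether to reinitialize the encoder state; note that when a checkpoint is loaded, train() skips the init_hidden call and then reads encoder_hidden unbound, which looks like a latent NameError.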