From 60775dec3973527d19718e13c30c6c97f5a40d95 Mon Sep 17 00:00:00 2001
From: SzamanFL
Date: Wed, 20 Jan 2021 09:48:43 +0100
Subject: [PATCH] Slight changes

---
 src/Decoder.py | 13 +++++++------
 src/Encoder.py | 11 ++++++-----
 src/Vocab.py   |  2 +-
 src/train.py   | 46 ++++++++++++++++++++++++++++------------------
 4 files changed, 42 insertions(+), 30 deletions(-)

diff --git a/src/Decoder.py b/src/Decoder.py
index 3e2a2ad..14948be 100644
--- a/src/Decoder.py
+++ b/src/Decoder.py
@@ -1,15 +1,16 @@
-import torch.nn
+import torch
 import torch.nn.functional as F
 
 
-class Decoder:
+class Decoder(torch.nn.Module):
     def __init__(self, hidden_size, output_size, num_layers=2):
         super(Decoder, self).__init__()
         self.hidden_size = hidden_size
-        self. embedding = nn.Embedding(output_size, hidden_size)
-        self.lstm = nn.LSTM(hidden_size, output_size, num_layers=num_layers)
-        self.out = nn.Linear(hidden_size, output_size)
-        self.softmax = nn.LogSoftmax(dim=1)
+        self.embedding = torch.nn.Embedding(output_size, hidden_size)
+        #self.lstm = torch.nn.LSTM(hidden_size, output_size, num_layers=num_layers)
+        self.lstm = torch.nn.LSTM(hidden_size, output_size)
+        self.out = torch.nn.Linear(hidden_size, output_size)
+        self.softmax = torch.nn.LogSoftmax(dim=1)
 
     def forward(self, x, hidden):
         embedded = self.embedding(x).view(1, 1, -1)
diff --git a/src/Encoder.py b/src/Encoder.py
index 055d5d9..cdd8f39 100644
--- a/src/Encoder.py
+++ b/src/Encoder.py
@@ -1,12 +1,13 @@
-import torch.nn
+import torch
 
 
-class Encoder(nn.Module):
+class Encoder(torch.nn.Module):
     def __init__(self, input_size, hidden_size, num_layers=4):
         super(Encoder, self).__init__()
         self.hidden_size = hidden_size
-        self.embedding = nn.Embedding(input_size, hidden_size)
-        self.lstm = nn.LSTM(hidden_size, hidden_size. num_layers=num_layers)
+        self.embedding = torch.nn.Embedding(input_size, hidden_size)
+        #self.lstm = torch.nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)
+        self.lstm = torch.nn.LSTM(hidden_size, hidden_size)
 
     def forward(self, x, hidden):
         embedded = self.embedding(x).view(1,1,-1)
@@ -14,4 +15,4 @@ class Encoder(nn.Module):
         return output, hidden
 
     def init_hidden(self, device):
-        return torch.zeros(1, 1, self.hidden_size, device = device)
+        return (torch.zeros(1, 1, self.hidden_size, device = device), torch.zeros(1, 1, self.hidden_size, device = device))
diff --git a/src/Vocab.py b/src/Vocab.py
index 8b8589f..6a5f103 100644
--- a/src/Vocab.py
+++ b/src/Vocab.py
@@ -8,7 +8,7 @@ class Vocab:
 
     def add_sentence(self, sentence):
         for word in sentence.split(' '):
-            self.addWord(word)
+            self.add_word(word)
 
     def add_word(self, word):
         if word not in self.word2index:
diff --git a/src/train.py b/src/train.py
index 11320cc..07925db 100644
--- a/src/train.py
+++ b/src/train.py
@@ -7,10 +7,13 @@ import unicodedata
 import torch
 import random
 import pickle
+import re
 
 from Vocab import Vocab
+from Encoder import Encoder
+from Decoder import Decoder
 
-MAX_LEN = 25
+MAX_LENGTH = 25
 SOS=0
 EOS=1
 teacher_forcing_ratio=0.5
@@ -19,12 +22,12 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 def clear_line(string, target):
     string = ''.join(
-            c for c in unicodedata.normalize('NFD', s)
+            c for c in unicodedata.normalize('NFD', string)
             if unicodedata.category(c) != 'Mn'
             )
     target = ''.join(
-            c for c in unicodedata.normalize('NFD', s)
+            c for c in unicodedata.normalize('NFD', target)
             if unicodedata.category(c) != 'Mn'
             )
 
@@ -38,7 +41,7 @@ def read_clear_data(in_file_path, expected_file_path):
     with open(in_file_path) as in_file, open(expected_file_path) as exp_file:
         for string, target in zip(in_file, exp_file):
             string, target = clear_line(string, target)
-            if len(string.split(' ')) < MAX_LEN and len(target.split(' ')) < MAX_LEN:
+            if len(string.split(' ')) < MAX_LENGTH and len(target.split(' ')) < MAX_LENGTH:
                 pairs.append([string, target])
     input_vocab = Vocab("pl")
     target_vocab = Vocab("en")
@@ -48,8 +51,8 @@ def prepare_data(in_file_path, expected_file_path):
     pairs, input_vocab, target_vocab = read_clear_data(in_file_path, expected_file_path)
 
     for pair in pairs:
-        input_lang.add_sentence(pair[0])
-        target_lang.add_sentence(pair[1])
+        input_vocab.add_sentence(pair[0])
+        target_vocab.add_sentence(pair[1])
 
     return pairs, input_vocab, target_vocab
 
@@ -67,6 +70,7 @@ def tensors_from_pair(pair, input_vocab, target_vocab):
     return (input_tensor, target_tensor)
 
 
 def train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_optim, criterion, max_length=MAX_LENGTH):
+    import ipdb; ipdb.set_trace()
     if not checkpoint:
         encoder_hidden = encoder.init_hidden(device)
@@ -81,7 +85,7 @@ def train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_
     loss = 0
     for e in range(input_len):
         encoder_output, encoder_hidden = encoder(input_tensor[e], encoder_hidden)
-        encoder_outputs[i] = encoder_output[0, 0]
+        encoder_outputs[e] = encoder_output[0, 0]
 
     decoder_hidden = encoder_hidden
     decoder_input = torch.tensor([[SOS]], device=device)
@@ -108,10 +112,11 @@ def train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_
     encoder_optim.step()
 
     return loss.item()/ target_len
 
-def train_iterate(pairs, encoder, decoder, n_iters, lr=0.01):
+def train_iterate(pairs, encoder, decoder, n_iters, input_vocab, target_vocab, lr=0.01):
     encoder_optim = torch.optim.SGD(encoder.parameters(), lr=lr)
     decoder_optim = torch.optim.SGD(decoder.parameters(), lr=lr)
-    training_pairs = [tensors_from_pair(random.choice(pairs)) for i in range(n_iters)]
+    #import ipdb; ipdb.set_trace()
+    training_pairs = [tensors_from_pair(random.choice(pairs), input_vocab, target_vocab) for i in range(n_iters)]
     criterion = torch.nn.NLLLoss()
     loss_total=0
@@ -120,7 +125,7 @@ def train_iterate(pairs, encoder, decoder, n_iters, lr=0.01):
         input_tensor = training_pair[0]
         target_tensor = training_pair[1]
 
-        loss = train(input_tensor, target_tensor, encoder, de, encoder_optim, decoder_optim, criterion)
+        loss = train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_optim, criterion)
         loss_total += loss
 
         if i % 1000 == 0:
@@ -142,30 +147,35 @@ def main():
     parser.add_argument("--seed")
     args = parser.parse_args()
 
+    global seed
     if args.seed:
         seed = int(args.seed)
     else:
-        seed = random.rand
-
-    global seed
-
+        seed = random.randint(0,50)
+    print(seed)
+
+    #global input_vocab
+    #global target_vocab
     if args.vocab:
-        with open(args.vocab, 'wb+') as p:
+        with open(args.vocab, 'rb') as p:
            pairs, input_vocab, target_vocab = pickle.load(p)
     else:
        pairs, input_vocab, target_vocab = prepare_data(args.in_f, args.exp)
-        with open("vocabs.pckl", 'rb') as p:
+        with open("vocabs.pckl", 'wb+') as p:
            pickle.dump([pairs, input_vocab, target_vocab], p)
 
     hidden_size = 256
     encoder = Encoder(input_vocab.size, hidden_size).to(device)
     decoder = Decoder(hidden_size, target_vocab.size).to(device)
-
+    global checkpoint
+    checkpoint = False
     if args.encoder:
         encoder.load_state_dict(torch.load(args.encoder))
+
+        checkpoint = True
    if args.decoder:
         decoder.load_state_dict(torch.load(args.decoder))
 
-    train_iterate(pairs, encoder, decoder, 50000)
+    train_iterate(pairs, encoder, decoder, 50000, input_vocab, target_vocab)
 
 main()
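
Note on the Encoder.init_hidden change: torch.nn.LSTM takes and returns its recurrent state as a (hidden state, cell state) tuple, which is the shape the new return value provides. A minimal standalone sketch of that contract, using hidden_size = 256 as in main(); the rest of the snippet is illustrative and not code from this repo:

import torch

hidden_size = 256  # matches hidden_size in main()

# The module the patched Encoder builds: torch.nn.LSTM(hidden_size, hidden_size)
lstm = torch.nn.LSTM(hidden_size, hidden_size)

# One token, batch size 1, matching the .view(1, 1, -1) in Encoder.forward
embedded = torch.zeros(1, 1, hidden_size)

# The LSTM state is an (h, c) pair -- the same pair of zero tensors
# that the patched init_hidden returns
hidden = (torch.zeros(1, 1, hidden_size), torch.zeros(1, 1, hidden_size))

output, hidden = lstm(embedded, hidden)
print(output.shape)  # torch.Size([1, 1, 256])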
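
The clear_line change in src/train.py keeps the same normalization scheme: NFD splits each accented character into a base character plus combining marks, and dropping category 'Mn' (nonspacing mark) removes the accents. A small self-contained illustration of that filter; the helper name and sample word are made up for the example:

import unicodedata

def strip_marks(text):
    # NFD decomposition: 'ś' becomes 's' + combining acute accent (category 'Mn')
    return ''.join(
        c for c in unicodedata.normalize('NFD', text)
        if unicodedata.category(c) != 'Mn'
    )

print(strip_marks("świeca"))  # prints: swieca

Characters without a canonical decomposition (for example Polish 'ł') pass through unchanged.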