Slight changes
This commit is contained in:
parent
ba956127bd
commit
60775dec39
@ -1,15 +1,16 @@
|
||||
import torch.nn
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Decoder:
|
||||
class Decoder(torch.nn.Module):
|
||||
def __init__(self, hidden_size, output_size, num_layers=2):
|
||||
super(Decoder, self).__init__()
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
self. embedding = nn.Embedding(output_size, hidden_size)
|
||||
self.lstm = nn.LSTM(hidden_size, output_size, num_layers=num_layers)
|
||||
self.out = nn.Linear(hidden_size, output_size)
|
||||
self.softmax = nn.LogSoftmax(dim=1)
|
||||
self. embedding = torch.nn.Embedding(output_size, hidden_size)
|
||||
#self.lstm = torch.nn.LSTM(hidden_size, output_size, num_layers=num_layers)
|
||||
self.lstm = torch.nn.LSTM(hidden_size, output_size)
|
||||
self.out = torch.nn.Linear(hidden_size, output_size)
|
||||
self.softmax = torch.nn.LogSoftmax(dim=1)
|
||||
|
||||
def forward(self, x, hidden):
|
||||
embedded = self.embedding(x).view(1, 1, -1)
|
||||
|
@ -1,12 +1,13 @@
|
||||
import torch.nn
|
||||
import torch
|
||||
|
||||
class Encoder(nn.Module):
|
||||
class Encoder(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_layers=4):
|
||||
super(Encoder, self).__init__()
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
self.embedding = nn.Embedding(input_size, hidden_size)
|
||||
self.lstm = nn.LSTM(hidden_size, hidden_size. num_layers=num_layers)
|
||||
self.embedding = torch.nn.Embedding(input_size, hidden_size)
|
||||
#self.lstm = torch.nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)
|
||||
self.lstm = torch.nn.LSTM(hidden_size, hidden_size)
|
||||
|
||||
def forward(self, x, hidden):
|
||||
embedded = self.embedding(x).view(1,1,-1)
|
||||
@ -14,4 +15,4 @@ class Encoder(nn.Module):
|
||||
return output, hidden
|
||||
|
||||
def init_hidden(self, device):
|
||||
return torch.zeros(1, 1, self.hidden_size, device = device)
|
||||
return (torch.zeros(1, 1, self.hidden_size, device = device), torch.zeros(1, 1, self.hidden_size, device = device))
|
||||
|
@ -8,7 +8,7 @@ class Vocab:
|
||||
|
||||
def add_sentence(self, sentence):
|
||||
for word in sentence.split(' '):
|
||||
self.addWord(word)
|
||||
self.add_word(word)
|
||||
|
||||
def add_word(self, word):
|
||||
if word not in self.word2index:
|
||||
|
46
src/train.py
46
src/train.py
@ -7,10 +7,13 @@ import unicodedata
|
||||
import torch
|
||||
import random
|
||||
import pickle
|
||||
import re
|
||||
|
||||
from Vocab import Vocab
|
||||
from Encoder import Encoder
|
||||
from Decoder import Decoder
|
||||
|
||||
MAX_LEN = 25
|
||||
MAX_LENGTH = 25
|
||||
SOS=0
|
||||
EOS=1
|
||||
teacher_forcing_ratio=0.5
|
||||
@ -19,12 +22,12 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
def clear_line(string, target):
|
||||
string = ''.join(
|
||||
c for c in unicodedata.normalize('NFD', s)
|
||||
c for c in unicodedata.normalize('NFD', string)
|
||||
if unicodedata.category(c) != 'Mn'
|
||||
)
|
||||
|
||||
target = ''.join(
|
||||
c for c in unicodedata.normalize('NFD', s)
|
||||
c for c in unicodedata.normalize('NFD', string)
|
||||
if unicodedata.category(c) != 'Mn'
|
||||
)
|
||||
|
||||
@ -38,7 +41,7 @@ def read_clear_data(in_file_path, expected_file_path):
|
||||
with open(in_file_path) as in_file, open(expected_file_path) as exp_file:
|
||||
for string, target in zip(in_file, exp_file):
|
||||
string, target = clear_line(string, target)
|
||||
if len(string.split(' ')) < MAX_LEN and len(target.split(' ')) < MAX_LEN:
|
||||
if len(string.split(' ')) < MAX_LENGTH and len(target.split(' ')) < MAX_LENGTH:
|
||||
pairs.append([string, target])
|
||||
input_vocab = Vocab("pl")
|
||||
target_vocab = Vocab("en")
|
||||
@ -48,8 +51,8 @@ def prepare_data(in_file_path, expected_file_path):
|
||||
pairs, input_vocab, target_vocab = read_clear_data(in_file_path, expected_file_path)
|
||||
|
||||
for pair in pairs:
|
||||
input_lang.add_sentence(pair[0])
|
||||
target_lang.add_sentence(pair[1])
|
||||
input_vocab.add_sentence(pair[0])
|
||||
target_vocab.add_sentence(pair[1])
|
||||
|
||||
return pairs, input_vocab, target_vocab
|
||||
|
||||
@ -67,6 +70,7 @@ def tensors_from_pair(pair, input_vocab, target_vocab):
|
||||
return (input_tensor, target_tensor)
|
||||
|
||||
def train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_optim, criterion, max_length=MAX_LENGTH):
|
||||
import ipdb; ipdb.set_trace()
|
||||
if not checkpoint:
|
||||
encoder_hidden = encoder.init_hidden(device)
|
||||
|
||||
@ -81,7 +85,7 @@ def train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_
|
||||
loss = 0
|
||||
for e in range(input_len):
|
||||
encoder_output, encoder_hidden = encoder(input_tensor[e], encoder_hidden)
|
||||
encoder_outputs[i] = encoder_output[0, 0]
|
||||
encoder_outputs[e] = encoder_output[0, 0]
|
||||
decoder_hidden = encoder_hidden
|
||||
|
||||
decoder_input = torch.tensor([[SOS]], device=device)
|
||||
@ -108,10 +112,11 @@ def train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_
|
||||
encoder_optim.step()
|
||||
return loss.item()/ target_len
|
||||
|
||||
def train_iterate(pairs, encoder, decoder, n_iters, lr=0.01):
|
||||
def train_iterate(pairs, encoder, decoder, n_iters, input_vocab, target_vocab, lr=0.01):
|
||||
encoder_optim = torch.optim.SGD(encoder.parameters(), lr=lr)
|
||||
decoder_optim = torch.optim.SGD(decoder.parameters(), lr=lr)
|
||||
training_pairs = [tensors_from_pair(random.choice(pairs)) for i in range(n_iters)]
|
||||
#import ipdb; ipdb.set_trace()
|
||||
training_pairs = [tensors_from_pair(random.choice(pairs), input_vocab, target_vocab) for i in range(n_iters)]
|
||||
criterion = torch.nn.NLLLoss()
|
||||
loss_total=0
|
||||
|
||||
@ -120,7 +125,7 @@ def train_iterate(pairs, encoder, decoder, n_iters, lr=0.01):
|
||||
input_tensor = training_pair[0]
|
||||
target_tensor = training_pair[1]
|
||||
|
||||
loss = train(input_tensor, target_tensor, encoder, de, encoder_optim, decoder_optim, criterion)
|
||||
loss = train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_optim, criterion)
|
||||
loss_total += loss
|
||||
|
||||
if i % 1000 == 0:
|
||||
@ -142,30 +147,35 @@ def main():
|
||||
parser.add_argument("--seed")
|
||||
args = parser.parse_args()
|
||||
|
||||
global seed
|
||||
if args.seed:
|
||||
seed = int(args.seed)
|
||||
else:
|
||||
seed = random.rand
|
||||
|
||||
global seed
|
||||
|
||||
seed = random.randint(0,50)
|
||||
print(seed)
|
||||
|
||||
#global input_vocab
|
||||
#global target_vocab
|
||||
if args.vocab:
|
||||
with open(args.vocab, 'wb+') as p:
|
||||
with open(args.vocab, 'rb') as p:
|
||||
pairs, input_vocab, target_vocab = pickle.load(p)
|
||||
else:
|
||||
pairs, input_vocab, target_vocab = prepare_data(args.in_f, args.exp)
|
||||
with open("vocabs.pckl", 'rb') as p:
|
||||
with open("vocabs.pckl", 'wb+') as p:
|
||||
pickle.dump([pairs, input_vocab, target_vocab], p)
|
||||
|
||||
hidden_size = 256
|
||||
encoder = Encoder(input_vocab.size, hidden_size).to(device)
|
||||
decoder = Decoder(hidden_size, target_vocab.size).to(device)
|
||||
|
||||
global checkpoint
|
||||
checkpoint = False
|
||||
if args.encoder:
|
||||
encoder.load_state_dict(torch.load(args.encoder))
|
||||
|
||||
checkpoint = True
|
||||
if args.decoder:
|
||||
decoder.load_state_dict(torch.load(args.decoder))
|
||||
|
||||
train_iterate(pairs, encoder, decoder, 50000)
|
||||
train_iterate(pairs, encoder, decoder, 50000, input_vocab, target_vocab)
|
||||
|
||||
main()
|
||||
|
Loading…
Reference in New Issue
Block a user