First draft of Machine translation
This commit is contained in:
parent
9ec62bc7bb
commit
ba956127bd
23
src/Decoder.py
Normal file
23
src/Decoder.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import torch.nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class Decoder(torch.nn.Module):
    """LSTM decoder that emits one target token per call.

    Each call embeds a single token index, applies ReLU, runs one LSTM
    time step and projects the result onto the target vocabulary as
    log-probabilities.

    Args:
        hidden_size: Size of the embedding and of the LSTM hidden state.
        output_size: Number of entries in the target vocabulary.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, hidden_size, output_size, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = torch.nn.Embedding(output_size, hidden_size)
        # The LSTM must keep ``hidden_size`` features because ``self.out``
        # below consumes ``hidden_size`` inputs.
        self.lstm = torch.nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)
        self.out = torch.nn.Linear(hidden_size, output_size)
        self.softmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, x, hidden):
        """Run one decoding step.

        Args:
            x: Tensor holding a single token index.
            hidden: ``(h, c)`` LSTM state, each of shape
                ``(num_layers, 1, hidden_size)``.

        Returns:
            A pair ``(output, hidden)`` where ``output`` has shape
            ``(1, output_size)`` and contains log-probabilities, and
            ``hidden`` is the updated LSTM state.
        """
        # One token, batch of one: (seq_len=1, batch=1, hidden_size).
        embedded = self.embedding(x).view(1, 1, -1)
        output = F.relu(embedded)
        output, hidden = self.lstm(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def init_hidden(self, device):
        """Return a zero-filled ``(h_0, c_0)`` state on ``device``.

        An LSTM needs both a hidden and a cell state, each shaped
        ``(num_layers, 1, hidden_size)``.
        """
        shape = (self.num_layers, 1, self.hidden_size)
        return (
            torch.zeros(shape, device=device),
            torch.zeros(shape, device=device),
        )
|
||||||
|
|
17
src/Encoder.py
Normal file
17
src/Encoder.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
import torch.nn
|
||||||
|
|
||||||
|
class Encoder(torch.nn.Module):
    """LSTM encoder that consumes one source token per call.

    Args:
        input_size: Number of entries in the source vocabulary.
        hidden_size: Size of the embedding and of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, num_layers=4):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = torch.nn.Embedding(input_size, hidden_size)
        self.lstm = torch.nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)

    def forward(self, x, hidden):
        """Encode a single token.

        Args:
            x: Tensor holding a single token index.
            hidden: ``(h, c)`` LSTM state, each of shape
                ``(num_layers, 1, hidden_size)``.

        Returns:
            A pair ``(output, hidden)``: the LSTM output of shape
            ``(1, 1, hidden_size)`` and the updated LSTM state.
        """
        # One token, batch of one: (seq_len=1, batch=1, hidden_size).
        embedded = self.embedding(x).view(1, 1, -1)
        output, hidden = self.lstm(embedded, hidden)
        return output, hidden

    def init_hidden(self, device):
        """Return a zero-filled ``(h_0, c_0)`` state on ``device``.

        An LSTM needs both a hidden and a cell state, each shaped
        ``(num_layers, 1, hidden_size)``.
        """
        shape = (self.num_layers, 1, self.hidden_size)
        return (
            torch.zeros(shape, device=device),
            torch.zeros(shape, device=device),
        )
|
20
src/Vocab.py
Normal file
20
src/Vocab.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
class Vocab:
    """Word <-> index mapping for one language.

    Indices 0 and 1 are reserved for the "SOS" and "EOS" markers, so the
    first real word receives index 2.
    """

    def __init__(self, lang):
        self.lang = lang
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        # Next free index; starts after the two reserved markers.
        self.size = 2

    def add_sentence(self, sentence):
        """Register every space-separated word of ``sentence``."""
        for word in sentence.split(' '):
            self.add_word(word)

    def add_word(self, word):
        """Add ``word`` to the vocabulary or bump its occurrence count."""
        if word not in self.word2index:
            self.word2index[word] = self.size
            self.word2count[word] = 1
            self.index2word[self.size] = word
            self.size += 1
        else:
            self.word2count[word] += 1
|
160
src/train.py
160
src/train.py
@ -3,35 +3,169 @@
|
|||||||
# an LSTM language model trained on sentence pairs
|
# an LSTM language model trained on sentence pairs
|
||||||
|
|
||||||
import argparse
import os
import pickle
import random
import re
import unicodedata
from collections import Counter

import torch

from Decoder import Decoder
from Encoder import Encoder
from Vocab import Vocab
|
||||||
|
|
||||||
|
# Sentence pairs are kept only when both sides split into fewer than
# MAX_LEN space-separated words.
MAX_LEN = 25

# Token indices of the start-of-sentence and end-of-sentence markers.
SOS = 0
EOS = 1

# Probability of feeding the reference token (instead of the model's own
# prediction) to the decoder during a training step.
teacher_forcing_ratio = 0.5

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
def _normalize_text(text):
    """Lower-case ``text`` and reduce it to the characters ``a-z`` and space."""
    # NFD splits an accented letter into its base letter plus combining
    # marks (category "Mn"); dropping the marks keeps the base letter.
    without_marks = ''.join(
        c for c in unicodedata.normalize('NFD', text)
        if unicodedata.category(c) != 'Mn'
    )
    # NOTE: letters with no canonical decomposition (e.g. the Polish "ł")
    # are not mapped to an ASCII letter and are removed by this filter.
    return re.sub("[^a-z ]", "", without_marks.lower())


def clear_line(string, target):
    """Normalize a source/target sentence pair.

    Both sentences have their diacritics stripped, are lower-cased and
    have every character other than ``a-z`` and the space removed (this
    includes digits, punctuation and the trailing newline).

    Args:
        string: Source-language sentence.
        target: Target-language sentence.

    Returns:
        A tuple ``(string, target)`` with the cleaned sentences.
    """
    return _normalize_text(string), _normalize_text(target)
|
||||||
|
|
||||||
|
def read_clear_data(in_file_path, expected_file_path):
    """Read two line-aligned text files and build cleaned sentence pairs.

    Corresponding lines of the two files are cleaned with ``clear_line``;
    a pair is kept only when both sentences split into fewer than
    ``MAX_LEN`` space-separated words.

    Args:
        in_file_path: Path to the source-language file.
        expected_file_path: Path to the target-language file.

    Returns:
        A tuple ``(pairs, input_vocab, target_vocab)`` where ``pairs`` is a
        list of ``[source, target]`` sentences and the two vocabularies are
        freshly created (still empty) ``Vocab`` objects.
    """
    print("Reading data")
    pairs = []
    # Explicit UTF-8 so decoding does not depend on the platform's locale.
    # NOTE(review): assumes the corpus files are UTF-8 encoded - confirm.
    with open(in_file_path, encoding="utf-8") as in_file, \
            open(expected_file_path, encoding="utf-8") as exp_file:
        for string, target in zip(in_file, exp_file):
            string, target = clear_line(string, target)
            if len(string.split(' ')) < MAX_LEN and len(target.split(' ')) < MAX_LEN:
                pairs.append([string, target])
    input_vocab = Vocab("pl")
    target_vocab = Vocab("en")
    return pairs, input_vocab, target_vocab
|
||||||
|
|
||||||
def prepare_data(in_file_path, expected_file_path):
    """Load the sentence pairs and fill both vocabularies from them.

    Args:
        in_file_path: Path to the source-language file.
        expected_file_path: Path to the target-language file.

    Returns:
        A tuple ``(pairs, input_vocab, target_vocab)``; the vocabularies
        contain every word that occurs in ``pairs``.
    """
    pairs, input_vocab, target_vocab = read_clear_data(in_file_path, expected_file_path)

    for pair in pairs:
        input_vocab.add_sentence(pair[0])
        target_vocab.add_sentence(pair[1])

    return pairs, input_vocab, target_vocab
|
||||||
|
|
||||||
|
def indexes_from_sentence(vocab, sentence):
    """Translate the space-separated words of ``sentence`` into vocabulary indices.

    Raises:
        KeyError: If a word is missing from ``vocab.word2index``.
    """
    lookup = vocab.word2index
    indexes = []
    for token in sentence.split(' '):
        indexes.append(lookup[token])
    return indexes
|
||||||
|
|
||||||
|
def tensor_from_sentece(vocab, sentence):
    """Encode ``sentence`` as a column tensor of token indices ending with EOS.

    Returns:
        A ``torch.long`` tensor of shape ``(number_of_words + 1, 1)`` placed
        on the module-level ``device``.
    """
    token_ids = [*indexes_from_sentence(vocab, sentence), EOS]
    flat = torch.tensor(token_ids, dtype=torch.long, device=device)
    return flat.view(-1, 1)
|
||||||
|
|
||||||
|
def tensors_from_pair(pair, input_vocab, target_vocab):
    """Convert a ``[source, target]`` sentence pair into a pair of index tensors.

    Returns:
        A tuple ``(input_tensor, target_tensor)``.
    """
    source_sentence = pair[0]
    target_sentence = pair[1]
    return (
        tensor_from_sentece(input_vocab, source_sentence),
        tensor_from_sentece(target_vocab, target_sentence),
    )
|
||||||
|
|
||||||
|
def train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_optim, criterion, max_length=MAX_LEN):
    """Run one optimization step on a single sentence pair.

    The source sentence is fed to the encoder token by token; the final
    encoder state initializes the decoder, which is then trained either
    with teacher forcing (reference token as next input) or on its own
    predictions, chosen at random according to ``teacher_forcing_ratio``.

    Args:
        input_tensor: Source token indices, one token per row.
        target_tensor: Target token indices, one token per row.
        encoder: Encoder module providing ``init_hidden`` and ``hidden_size``.
        decoder: Decoder module. It receives the encoder's final state
            unchanged, so both must use compatible state shapes.
        encoder_optim: Optimizer for the encoder parameters.
        decoder_optim: Optimizer for the decoder parameters.
        criterion: Loss applied to each decoder output and reference token.
        max_length: Number of rows allocated for the per-token encoder
            outputs; the input must not be longer than this.

    Returns:
        The summed loss divided by the target length, as a Python float.
    """
    encoder_hidden = encoder.init_hidden(device)

    encoder_optim.zero_grad()
    decoder_optim.zero_grad()

    input_len = input_tensor.size(0)
    target_len = target_tensor.size(0)

    # Per-token encoder outputs. They are collected but not read anywhere
    # else in this function.
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

    loss = 0
    for e in range(input_len):
        encoder_output, encoder_hidden = encoder(input_tensor[e], encoder_hidden)
        encoder_outputs[e] = encoder_output[0, 0]

    # The decoder continues from the state the encoder finished in.
    decoder_hidden = encoder_hidden
    decoder_input = torch.tensor([[SOS]], device=device)
    use_teacher_forcing = random.random() < teacher_forcing_ratio

    if use_teacher_forcing:
        # Feed the reference token as the next input.
        for d in range(target_len):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            loss += criterion(decoder_output, target_tensor[d])
            decoder_input = target_tensor[d]
    else:
        # Feed the model's own best guess as the next input.
        for d in range(target_len):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            topv, topi = decoder_output.topk(1)
            # Detach so gradients do not flow back through the chosen token.
            decoder_input = topi.squeeze().detach()

            loss += criterion(decoder_output, target_tensor[d])
            if decoder_input.item() == EOS:
                break

    loss.backward()

    encoder_optim.step()
    decoder_optim.step()
    return loss.item() / target_len
|
||||||
|
|
||||||
|
def train_iterate(pairs, encoder, decoder, n_iters, lr=0.01,
                  input_vocab=None, target_vocab=None, seed=0):
    """Train the encoder/decoder on ``n_iters`` randomly drawn sentence pairs.

    Prints the average loss every 1000 iterations and writes both state
    dicts to ``models/`` every 5000 iterations.

    Args:
        pairs: List of ``[source, target]`` sentences.
        encoder: Encoder module to train.
        decoder: Decoder module to train.
        n_iters: Number of single-pair training steps.
        lr: Learning rate of both SGD optimizers.
        input_vocab: Vocabulary used to index the source sentences.
        target_vocab: Vocabulary used to index the target sentences.
        seed: Tag embedded in the checkpoint file names.

    Raises:
        ValueError: If either vocabulary is missing.
    """
    if input_vocab is None or target_vocab is None:
        raise ValueError("train_iterate requires both input_vocab and target_vocab")

    encoder_optim = torch.optim.SGD(encoder.parameters(), lr=lr)
    decoder_optim = torch.optim.SGD(decoder.parameters(), lr=lr)
    training_pairs = [
        tensors_from_pair(random.choice(pairs), input_vocab, target_vocab)
        for _ in range(n_iters)
    ]
    criterion = torch.nn.NLLLoss()
    loss_total = 0

    for i, (input_tensor, target_tensor) in enumerate(training_pairs, start=1):
        loss = train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_optim, criterion)
        loss_total += loss

        if i % 1000 == 0:
            loss_avg = loss_total / 1000
            print(f"lavg loss: {loss_avg}")
            loss_total = 0

        if i % 5000 == 0:
            # torch.save does not create missing directories.
            os.makedirs("models", exist_ok=True)
            torch.save(encoder.state_dict(), f'models/encoder-{i}-{seed}')
            torch.save(decoder.state_dict(), f'models/decoder-{i}-{seed}')


def main():
    """Parse the command line, prepare data and models, and start training.

    Options:
        --in_f / --exp: Source and target corpus files (used when no
            pickled vocabulary is given).
        --vocab: Pickle file holding ``[pairs, input_vocab, target_vocab]``.
        --encoder / --decoder: Optional state-dict files to resume from.
        --seed: Integer seed; a random one is drawn when omitted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--in_f')
    parser.add_argument('--exp')
    parser.add_argument("--vocab")
    parser.add_argument("--encoder")
    parser.add_argument("--decoder")
    parser.add_argument("--seed")
    args = parser.parse_args()

    if args.seed is not None:
        seed = int(args.seed)
    else:
        seed = random.randrange(2 ** 32)
    # Seed both generators so a run can be reproduced from its seed.
    random.seed(seed)
    torch.manual_seed(seed)

    if args.vocab:
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # vocabulary files that come from a trusted source.
        with open(args.vocab, 'rb') as p:
            pairs, input_vocab, target_vocab = pickle.load(p)
    else:
        pairs, input_vocab, target_vocab = prepare_data(args.in_f, args.exp)
        with open("vocabs.pckl", 'wb') as p:
            pickle.dump([pairs, input_vocab, target_vocab], p)

    hidden_size = 256
    encoder = Encoder(input_vocab.size, hidden_size).to(device)
    # train() hands the encoder's final LSTM state straight to the decoder,
    # so both LSTMs need the same number of layers.
    decoder = Decoder(
        hidden_size, target_vocab.size, num_layers=encoder.lstm.num_layers
    ).to(device)

    if args.encoder:
        encoder.load_state_dict(torch.load(args.encoder, map_location=device))
    if args.decoder:
        decoder.load_state_dict(torch.load(args.decoder, map_location=device))

    train_iterate(
        pairs, encoder, decoder, 50000,
        input_vocab=input_vocab, target_vocab=target_vocab, seed=seed,
    )


if __name__ == "__main__":
    main()
|
||||||
|
Loading…
Reference in New Issue
Block a user