ium_434749/PH.py
2021-03-22 13:03:07 +01:00

285 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
import unicodedata
from time import sleep
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import re
import random
import os
from tqdm import tqdm
# --- Run configuration -------------------------------------------------------
DATA_FILE = 'en_US.txt'  # tab-separated "word<TAB>pronunciations" dictionary file
EPOCHS = 4000
TEACHER_FORCING_PROBABILITY = 0.4  # per decoding step: chance of feeding the reference symbol
LEARNING_RATE = 0.01
BATCH_SIZE = 1024

plt.ion()  # interactive mode, so the figure can be redrawn while training is running

# Fetch the dictionary once; later runs reuse the local copy.
if not os.path.isfile(DATA_FILE):
    import requests

    response = requests.get(
        'https://raw.githubusercontent.com/open-dict-data/ipa-dict/master/data/' + DATA_FILE)
    # Fail before touching the disk: otherwise an HTTP error body would be stored
    # under DATA_FILE, and every later run would skip the download and try to
    # parse that error body instead of the dictionary.
    response.raise_for_status()
    # The context manager guarantees the handle is flushed and closed.
    with open(DATA_FILE, 'wb') as data_file:
        data_file.write(response.content)

# Use the first CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# Symbol tables: position in the list is the integer id of the symbol.
# Id 0 is the empty string, which the embeddings use as their padding index;
# '^' and '$' are the start and end markers wrapped around every sequence.
OUT_LOOKUP = [
    '',
    'b', 'a', 'ʊ', 't', 'k', 'ə', 'z', 'ɔ', 'ɹ', 's',
    'j', 'u', 'm', 'f', 'ɪ', 'o', 'ɡ', 'ɛ', 'n', 'e',
    'd', 'ɫ', 'w', 'i', 'p', 'ɑ', 'ɝ', 'θ', 'v', 'h',
    'æ', 'ŋ', 'ʃ', 'ʒ', 'ð',
    '^', '$',
]
IN_LOOKUP = [
    '',
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
    'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
    '$', '^',
]
# Reverse mappings: symbol -> integer id.
IN_ALPHABET = dict(zip(IN_LOOKUP, range(len(IN_LOOKUP))))
OUT_ALPHABET = dict(zip(OUT_LOOKUP, range(len(OUT_LOOKUP))))
TOTAL_OUT_LEN = 0
DATA: [(torch.tensor, torch.tensor)] = []
# The dictionary contains IPA symbols, so decode it explicitly as UTF-8 instead
# of relying on the platform's default encoding.
with open(DATA_FILE, encoding='utf-8') as f:
    for line in f:
        text, phonemes = line.split("\t")
        # Several comma-separated pronunciations may be listed; keep the first.
        phonemes = phonemes.strip().split(",")[0]
        # Drop the surrounding slashes and the stress marks, then add the markers.
        phonemes = '^' + re.sub(r'[/\'ˈˌ]', '', phonemes) + '$'
        # Anything that is not a lowercase ASCII letter is removed from the word.
        text = '^' + re.sub(r'[^a-z]', '', text.strip()) + '$'
        text = torch.tensor([IN_ALPHABET[letter] for letter in text], dtype=torch.int)
        phonemes = torch.tensor([OUT_ALPHABET[letter] for letter in phonemes], dtype=torch.int)
        DATA.append((text, phonemes))
random.shuffle(DATA)
# DATA = DATA[:2000]
print("Number of samples ", len(DATA))

# Half of the samples are used for training, rounded down to whole batches;
# everything that is left over becomes the evaluation set.
TRAINING_SET_SIZE = int(len(DATA) * 0.5)
TRAINING_SET_SIZE -= TRAINING_SET_SIZE % BATCH_SIZE
EVAL = DATA[TRAINING_SET_SIZE:]
DATA = DATA[:TRAINING_SET_SIZE]
assert len(DATA) % BATCH_SIZE == 0
print("Training samples ", len(DATA))
print("Evaluation samples ", len(EVAL))
print("Input alphabet ", IN_LOOKUP)
print("Output alphabet ", OUT_LOOKUP)

# Number of output symbols per set, used as the accuracy denominators.
# One symbol per sample is subtracted: the leading '^' is never predicted.
TOTAL_TRAINING_OUT_LEN = sum(len(phonemes) for _, phonemes in DATA) - len(DATA)
TOTAL_EVALUATION_OUT_LEN = sum(len(phonemes) for _, phonemes in EVAL) - len(EVAL)
print("Total output length in training set", TOTAL_TRAINING_OUT_LEN)
print("Total output length in evaluation set", TOTAL_EVALUATION_OUT_LEN)
def shuffle_but_keep_sorted_by_output_lengths(data: [(torch.tensor, torch.tensor)]):
    """Shuffle ``data`` in place, then order it by descending output length.

    Each element is an ``(input, output)`` pair; only ``len(output)`` is used
    for ordering. The sort is stable, so pairs whose outputs have the same
    length keep the random relative order produced by the shuffle.

    Returns:
        None. ``data`` itself is reordered.
    """
    random.shuffle(data)
    # An ascending stable sort on the negated length gives the same ordering
    # as a descending stable sort on the length.
    data.sort(key=lambda sample: -len(sample[1]))
def collate(batch: [(torch.tensor, torch.tensor)]):
    """Combine ``(input, output)`` index tensors into zero-padded batch tensors.

    ``batch`` is sorted in place by descending input length before padding, so
    row ``k`` of both returned tensors belongs to the k-th longest input.

    Args:
        batch: list of ``(input, output)`` pairs of 1-D index tensors.

    Returns:
        ``(padded_in, in_lengths, padded_out, out_lengths)`` where
        ``padded_in`` is an int tensor of shape (batch, longest input),
        ``padded_out`` is a long tensor of shape (batch, longest output),
        and the two length lists hold the unpadded lengths in row order.
        Unused trailing positions are left at 0.
    """
    batch.sort(key=lambda sample: len(sample[0]), reverse=True)

    in_lengths = []
    out_lengths = []
    for source, target in batch:
        in_lengths.append(len(source))
        out_lengths.append(len(target))

    rows = len(batch)
    padded_in = torch.zeros((rows, max(in_lengths)), dtype=torch.int)
    padded_out = torch.zeros((rows, max(out_lengths)), dtype=torch.long)
    for row, (source, target) in enumerate(batch):
        padded_in[row, :in_lengths[row]] = source
        padded_out[row, :out_lengths[row]] = target
    return padded_in, in_lengths, padded_out, out_lengths
class EncoderRNN(nn.Module):
    """Embeds padded input-symbol ids and runs them through a stacked LSTM."""

    def __init__(self, hidden_size, layers):
        """Create the embedding and the recurrent stack.

        Args:
            hidden_size: width of both the embeddings and the LSTM state.
            layers: number of stacked LSTM layers.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.hidden_layers = layers
        self.embedding = nn.Embedding(len(IN_ALPHABET), hidden_size,
                                      padding_idx=IN_ALPHABET[''])
        # NOTE: despite the attribute name, this module is an LSTM.
        self.gru = nn.LSTM(hidden_size, hidden_size, layers, batch_first=True)

    def forward(self, padded_in, in_lengths):
        """Encode a padded batch.

        Args:
            padded_in: (batch, longest input) tensor of input-symbol ids.
            in_lengths: unpadded length of every row; it is handed to
                ``pack_padded_sequence`` with default settings, which expects
                the lengths in descending order.

        Returns:
            ``(unpacked, hidden)``: the per-step LSTM outputs of shape
            (batch, longest input, hidden_size) and the final
            ``(hidden_state, cell_state)`` pair of the LSTM.
        """
        rnn_utils = torch.nn.utils.rnn
        batch_size = len(in_lengths)

        embedded = self.embedding(padded_in)
        expected_shape = (batch_size, max(in_lengths), self.hidden_size)
        assert embedded.size() == expected_shape

        packed = rnn_utils.pack_padded_sequence(embedded, in_lengths, batch_first=True)
        packed_out, hidden = self.gru(packed)
        unpacked, _ = rnn_utils.pad_packed_sequence(packed_out, batch_first=True)
        assert unpacked.size() == expected_shape
        return unpacked, hidden

    def init_hidden(self, batch_size, device):
        """Return all-zero ``(hidden_state, cell_state)`` tensors on ``device``.

        Both tensors have shape (layers, batch_size, hidden_size).
        """
        state_shape = (self.hidden_layers, batch_size, self.hidden_size)
        hidden_state = torch.zeros(state_shape, device=device)
        cell_state = torch.zeros(state_shape, device=device)
        return hidden_state, cell_state
class DecoderRNN(nn.Module):
    """Single-step decoder: one output symbol per call, driven by an LSTM state."""

    def __init__(self, hidden_size, layers):
        """Create the embedding, the recurrent stack and the output projection.

        Args:
            hidden_size: width of both the embeddings and the LSTM state.
            layers: number of stacked LSTM layers.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.hidden_layers = layers
        vocabulary_size = len(OUT_ALPHABET)
        self.embedding = nn.Embedding(vocabulary_size, hidden_size,
                                      padding_idx=OUT_ALPHABET[''])
        # NOTE: despite the attribute name, this module is an LSTM.
        self.gru = nn.LSTM(hidden_size, hidden_size, layers, batch_first=True)
        self.out = nn.Linear(hidden_size, vocabulary_size)
        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, padded_out, hidden):
        """Advance the decoder by exactly one time step.

        Args:
            padded_out: (batch,) tensor holding one output-symbol id per sample.
            hidden: ``(hidden_state, cell_state)`` pair, each of shape
                (layers, batch, hidden_size).

        Returns:
            ``(probabilities, hidden)``: log-probabilities of shape
            (batch, 1, output alphabet size) and the updated LSTM state.
        """
        batch_size = len(padded_out)
        step_input = padded_out.unsqueeze(1)  # add a sequence axis of length 1
        step_shape = (batch_size, 1)
        state_shape = (self.hidden_layers, batch_size, self.hidden_size)

        hidden_state, cell_state = hidden
        assert hidden_state.size() == state_shape
        assert cell_state.size() == state_shape

        embedded = self.embedding(step_input)
        assert embedded.size() == step_shape + (self.hidden_size,)

        rnn_out, hidden = self.gru(embedded, hidden)
        assert rnn_out.size() == step_shape + (self.hidden_size,)

        scores = self.out(rnn_out)
        assert scores.size() == step_shape + (len(OUT_ALPHABET),)

        probabilities = self.softmax(scores)
        assert probabilities.size() == step_shape + (len(OUT_ALPHABET),)
        return probabilities, hidden
def run(encoder, decoder, batch_in, i_lengths, batch_out, o_lengths, teacher_forcing_prob, criterion):
    """Encode one batch, decode it symbol by symbol, and score the predictions.

    Args:
        encoder: module called as ``encoder(batch_in, i_lengths)``; its second
            return value is used as the decoder's initial state.
        decoder: module called as ``decoder(symbols, state)`` once per step.
        batch_in: (BATCH_SIZE, input length) tensor of padded input ids.
        i_lengths: unpadded input lengths, forwarded to the encoder.
        batch_out: (BATCH_SIZE, output length) tensor of padded target ids;
            column 0 is fed to the decoder at the first step.
        o_lengths: unpadded output lengths (currently unused).
        teacher_forcing_prob: probability, drawn once per time step for the
            whole batch, of feeding the reference symbol instead of the
            decoder's own previous prediction.
        criterion: loss called as ``criterion(step_output, targets)`` and
            summed over the steps, or ``None`` to skip the loss.

    Returns:
        ``(loss, total_correct_predictions)``. ``loss`` stays the integer 0
        when ``criterion`` is ``None``. The count covers only non-padding
        target positions.
    """
    batch_in = batch_in.to(DEVICE)
    batch_out = batch_out.to(DEVICE)
    out_seq_len = batch_out.size()[1]
    in_seq_len = batch_in.size()[1]
    assert batch_in.size() == (BATCH_SIZE, in_seq_len)
    assert batch_out.size() == (BATCH_SIZE, out_seq_len)

    padding_index = OUT_ALPHABET['']
    loss = 0
    total_correct_predictions = 0
    _, hidden = encoder(batch_in, i_lengths)
    output = batch_out[:, 0]
    for i in range(out_seq_len - 1):
        if random.random() < teacher_forcing_prob:
            out = batch_out[:, i]
        else:
            out = output
        output, hidden = decoder(out, hidden)
        output = output.squeeze(1)  # drop the length-1 sequence axis
        expected_output = batch_out[:, i + 1]
        if criterion is not None:
            loss += criterion(output, expected_output)
        argmax_output = torch.argmax(output, 1)
        with torch.no_grad():
            # Sequences shorter than the longest one in the batch are padded.
            # A predicted padding id at a padded position is not a real hit and
            # must not be counted, because the accuracy denominators only count
            # real symbols.
            is_real_symbol = expected_output != padding_index
            hits = (argmax_output == expected_output) & is_real_symbol
            total_correct_predictions += hits.sum().item()
        output = argmax_output
    return loss, total_correct_predictions
eval_bar = tqdm(total=len(EVAL), position=2)


def eval_model(encoder, decoder):
    """Count correctly predicted output symbols over the evaluation set.

    Decoding runs without teacher forcing, without a loss and with gradient
    tracking disabled. The loader uses ``drop_last=True``, so samples that do
    not fill a complete batch at the end of ``EVAL`` are not evaluated.

    Returns:
        The number of correct predictions summed over all evaluated batches.
    """
    eval_bar.reset()
    eval_bar.set_description("Evaluation")
    loader = DataLoader(dataset=EVAL,
                        batch_size=BATCH_SIZE,
                        drop_last=True,
                        collate_fn=collate)
    correct_so_far = 0
    with torch.no_grad():
        for batch_in, i_lengths, batch_out, o_lengths in loader:
            # Teacher forcing probability 0 and no criterion: pure inference.
            _, batch_correct = run(encoder, decoder,
                                   batch_in, i_lengths,
                                   batch_out, o_lengths,
                                   0, None)
            correct_so_far += batch_correct
            eval_bar.update(BATCH_SIZE)
    return correct_so_far
outer_bar = tqdm(total=EPOCHS, position=0)
inner_bar = tqdm(total=len(DATA), position=1)


def train_model(encoder, decoder):
    """Train the encoder/decoder pair for ``EPOCHS`` epochs with live plotting.

    Each model gets its own Adam optimizer. Before every epoch the training
    data is reshuffled and re-sorted by output length; after every epoch the
    model is evaluated and the training/evaluation accuracy curves are redrawn.

    Args:
        encoder: module passed through to ``run`` as the encoder.
        decoder: module passed through to ``run`` as the decoder.
    """
    encoder_optimizer = optim.Adam((p for p in encoder.parameters() if p.requires_grad),
                                   lr=LEARNING_RATE)
    decoder_optimizer = optim.Adam((p for p in decoder.parameters() if p.requires_grad),
                                   lr=LEARNING_RATE)
    # Padded target positions do not contribute to the loss.
    criterion = nn.NLLLoss(ignore_index=OUT_ALPHABET[''])

    # Snapshot 0: no training accuracy yet, evaluation of the untrained model.
    train_snapshots_percentage = [0]
    train_snapshots = [0]
    eval_snapshots = [eval_model(encoder, decoder)]
    eval_snapshots_percentage = [eval_snapshots[0] / TOTAL_EVALUATION_OUT_LEN]

    outer_bar.reset()
    outer_bar.set_description("Epochs")
    for _ in range(EPOCHS):
        shuffle_but_keep_sorted_by_output_lengths(DATA)
        total_loss = 0
        epoch_correct = 0
        inner_bar.reset()
        loader = DataLoader(dataset=DATA,
                            batch_size=BATCH_SIZE,
                            drop_last=True,
                            collate_fn=collate)
        for batch_in, i_lengths, batch_out, o_lengths in loader:
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
            loss, batch_correct = run(encoder, decoder,
                                      batch_in, i_lengths,
                                      batch_out, o_lengths,
                                      TEACHER_FORCING_PROBABILITY, criterion)
            epoch_correct += batch_correct
            loss.backward()
            encoder_optimizer.step()
            decoder_optimizer.step()
            inner_bar.update(BATCH_SIZE)
            loss_scalar = loss.item()
            total_loss += loss_scalar
            inner_bar.set_description("Avg loss %.2f" % (loss_scalar / batch_out.size()[1]))

        # Record this epoch's accuracy on both sets.
        train_snapshots.append(epoch_correct)
        train_snapshots_percentage.append(epoch_correct / TOTAL_TRAINING_OUT_LEN)
        eval_snapshots.append(eval_model(encoder, decoder))
        eval_snapshots_percentage.append(eval_snapshots[-1] / TOTAL_EVALUATION_OUT_LEN)

        # Redraw both curves from scratch.
        plt.clf()
        plt.plot(train_snapshots_percentage, label="Training %")
        plt.plot(eval_snapshots_percentage, label="Evaluation %")
        plt.legend(loc="upper left")
        plt.pause(interval=0.01)

        outer_bar.set_description("Epochs")
        outer_bar.update(1)
# Script entry point: build both networks with hidden size 32 and 4 layers,
# move them to the selected device and start training.
train_model(
    encoder=EncoderRNN(32, 4).to(DEVICE),
    decoder=DecoderRNN(32, 4).to(DEVICE),
)