ium_434749/train_model.py

242 lines
9.9 KiB
Python
Raw Normal View History

2021-04-25 20:55:45 +02:00
# https://arxiv.org/pdf/2001.11692.pdf
import numpy as np
import unicodedata
from time import sleep
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
2021-05-09 18:35:29 +02:00
from sacred.observers import FileStorageObserver, MongoObserver
2021-04-25 20:55:45 +02:00
from torch.utils.data import Dataset, DataLoader
import re
import random
import os
2021-04-26 17:36:15 +02:00
import sys
2021-04-25 20:55:45 +02:00
from tqdm import tqdm
from Levenshtein import distance as levenshtein_distance
2021-05-09 18:35:29 +02:00
from sacred import Experiment
2021-05-10 11:57:53 +02:00
import traceback
2021-05-23 17:00:43 +02:00
from mlflow import log_metric, log_param, log_artifacts
import mlflow
import logging
logging.basicConfig(level=logging.WARN)
logger = logging.getLogger(__name__)

# Tracking endpoint defaults to the original docker-bridge host but can be overridden
# through the standard MLFLOW_TRACKING_URI environment variable.
_MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://172.17.0.1:5000")
mlflow.set_tracking_uri(_MLFLOW_TRACKING_URI)
mlflow.set_experiment("s434749")

# NOTE(review): the default below embeds database credentials in source control;
# prefer supplying SACRED_MONGO_URL from the environment or a secret store.
_SACRED_MONGO_URL = os.environ.get(
    "SACRED_MONGO_URL",
    'mongodb://mongo_user:mongo_password_IUM_2021@172.17.0.1:27017',
)

ex = Experiment("CNN")
ex.observers.append(FileStorageObserver('sacred_file_observer'))
try:
    ex.observers.append(MongoObserver(url=_SACRED_MONGO_URL, db_name='sacred'))
except Exception:
    # The Mongo observer is optional: keep running with the file observer only,
    # but leave a visible record of why it could not be attached.
    logger.warning("Could not attach sacred MongoObserver; continuing without it", exc_info=True)

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')


class CNN(nn.Module):
    """1-D convolutional encoder mapping a one-hot encoded word to a fixed-size embedding.

    Architecture: ``input_conv`` followed by ``hidden_layers`` further ``Conv1d`` layers
    (unpadded, default stride 1, ReLU after each), then the flattened feature map is
    projected by ``lin`` to ``embedding_size``.

    Args:
        kernel_size: Width of every convolution kernel.
        hidden_layers: Number of convolutions applied after ``input_conv``.
        channels: Output channels of every convolution.
        embedding_size: Dimension of the returned embedding.
        in_alphabet: Input alphabet; only ``len(in_alphabet)`` (number of input channels) is used.
        max_len: Length of the position axis of the input.

    Raises:
        ValueError: If the convolution stack would shrink a ``max_len`` input to nothing.
    """

    def __init__(self, kernel_size, hidden_layers, channels, embedding_size, in_alphabet, max_len):
        super().__init__()
        # Every unpadded conv trims (kernel_size - 1) positions and there are
        # hidden_layers + 1 convs in total.
        out_len = max_len - (kernel_size - 1) * (hidden_layers + 1)
        if out_len <= 0:
            raise ValueError(
                "max_len=%d is too short for %d convolutions with kernel_size=%d"
                % (max_len, hidden_layers + 1, kernel_size))
        self.last_layer_size = out_len * channels
        # Attribute names are kept: they are the keys of the saved state_dict.
        self.input_conv = nn.Conv1d(in_channels=len(in_alphabet), out_channels=channels, kernel_size=kernel_size)
        self.conv_hidden = nn.ModuleList(
            [nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=kernel_size)
             for _ in range(hidden_layers)])
        self.lin = nn.Linear(self.last_layer_size, embedding_size)

    def forward(self, x):
        """Embed ``x`` of shape (batch, len(in_alphabet), max_len) into (batch, embedding_size)."""
        x = F.relu(self.input_conv(x), inplace=True)
        for conv in self.conv_hidden:
            x = F.relu(conv(x), inplace=True)
        x = x.view(x.size()[0], self.last_layer_size)
        return self.lin(x)


def dist(a: [str], b: [str]):
    """Return the element-wise Levenshtein distances ``a[k]`` vs ``b[k]``.

    The result is a float tensor of length ``len(a)`` placed on the module-level
    ``device``. Only the first ``len(a)`` entries of ``b`` are read, so ``b`` must be
    at least as long as ``a``.
    """
    distances = [levenshtein_distance(word, b[k]) for k, word in enumerate(a)]
    return torch.tensor(distances, dtype=torch.float, device=device)


def encode(batch: [(torch.tensor, str)], in_alphabet, max_len):
    """One-hot encode a batch of letter-index sequences.

    Args:
        batch: ``(indices, phonemes)`` pairs; ``indices`` is a 1-D sequence of letter
            indices (a tensor, or anything ``torch.as_tensor`` accepts).
        in_alphabet: Input alphabet; only its length (number of channels) is used.
        max_len: Size of the position axis; shorter sequences are zero-padded.

    Returns:
        ``(batch_text, batch_phonemes)``: a float tensor of shape
        ``(len(batch), len(in_alphabet), max_len)`` with ``batch_text[i, c, p] == 1`` iff
        letter index ``c`` sits at position ``p`` of sample ``i``, and the list of the
        samples' second elements.

    An out-of-range letter index or a sequence longer than ``max_len`` makes the
    tensor indexing fail.
    """
    batch_text = torch.zeros((len(batch), len(in_alphabet), max_len))
    batch_phonemes = [item[1] for item in batch]
    for i, (sample, _) in enumerate(batch):
        # One advanced-indexing write per sample (letter index, position) instead of
        # a Python-level scalar tensor write per character.
        letters = torch.as_tensor(sample, dtype=torch.long)
        positions = torch.arange(len(letters))
        batch_text[i, letters, positions] = 1
    return batch_text, batch_phonemes


def encode_str(batch: [(str, str)], in_alphabet, max_len):
    """One-hot encode raw ``(word, phonemes)`` string pairs.

    Each word is first mapped letter-by-letter through ``in_alphabet`` to an int
    tensor of indices (a letter missing from ``in_alphabet`` raises on lookup); the
    result is then handed to ``encode``.
    """
    indexed = []
    for in_str, out_str in batch:
        letter_ids = torch.tensor([in_alphabet[ch] for ch in in_str], dtype=torch.int)
        indexed.append((letter_ids, out_str))
    return encode(indexed, in_alphabet, max_len)


def train_model(model, learning_rate, in_alphabet, max_len, data, epochs, batch_size):
    """Train ``model`` so Euclidean embedding distances match phoneme edit distances.

    Every optimisation step loads ``3 * batch_size`` shuffled samples and splits them
    into anchor / positive / negative groups of ``batch_size``. See
    ``_triplet_distance_loss`` for the objective.

    Args:
        model: Network mapping an encoded batch to embeddings; updated in place.
        learning_rate: Adam learning rate.
        in_alphabet: Input alphabet passed to ``encode``.
        max_len: Padded input length passed to ``encode``.
        data: Sequence of ``(letter_indices, phonemes)`` pairs.
        epochs: Number of passes over ``data``.
        batch_size: Triplets per optimisation step.

    Side effects:
        Logs the per-triplet average loss of each epoch as ``avg_loss`` to sacred and MLflow.

    Raises:
        ValueError: If ``epochs > 0`` but ``data`` cannot fill a single step.
    """
    samples_per_step = 3 * batch_size
    if epochs > 0 and len(data) < samples_per_step:
        raise ValueError("need at least %d samples (3 * batch_size) to train, got %d"
                         % (samples_per_step, len(data)))

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)

    def collate(batch: [(torch.tensor, str)]):
        return encode(batch, in_alphabet, max_len)

    # A DataLoader with shuffle=True reshuffles on every iteration, so one suffices.
    data_loader = DataLoader(dataset=data, drop_last=True,
                             batch_size=samples_per_step,
                             collate_fn=collate,
                             shuffle=True)

    outer_bar = tqdm(total=epochs, position=0)
    inner_bar = tqdm(total=len(data), position=1)
    outer_bar.set_description("Epochs")
    try:
        for epoch in range(epochs):
            total_loss = 0.0
            triplets_seen = 0
            inner_bar.reset()
            for batch_text, batch_phonemes in data_loader:
                optimizer.zero_grad()
                anchor, positive, negative = batch_text.to(device).split(batch_size)
                ph_anchor = batch_phonemes[:batch_size]
                ph_positive = batch_phonemes[batch_size:2 * batch_size]
                ph_negative = batch_phonemes[2 * batch_size:]
                embedded_anchor = model(anchor)
                embedded_positive = model(positive)
                embedded_negative = model(negative)
                loss = _triplet_distance_loss(
                    embedded_anchor, embedded_positive, embedded_negative,
                    dist(ph_anchor, ph_positive),
                    dist(ph_anchor, ph_negative),
                    dist(ph_positive, ph_negative))
                loss.backward()
                optimizer.step()
                inner_bar.update(samples_per_step)
                loss_scalar = loss.item()
                total_loss += loss_scalar
                triplets_seen += batch_size
                inner_bar.set_description("loss %.2f" % loss_scalar)
            # Divide by the triplets actually processed (drop_last discards a remainder).
            avg_loss = total_loss / triplets_seen
            ex.log_scalar("avg_loss", avg_loss)
            log_metric("avg_loss", avg_loss, step=epoch)
            outer_bar.set_description("Epochs")
            outer_bar.update(1)
    finally:
        inner_bar.close()
        outer_bar.close()


def _triplet_distance_loss(emb_anchor, emb_positive, emb_negative, d_anchor_pos, d_anchor_neg, d_pos_neg):
    """Return the summed triplet loss for one batch (a scalar tensor).

    Per triplet: absolute errors between each estimated pairwise embedding distance
    and the matching edit distance, plus a hinge that penalises an estimated
    (anchor-positive minus anchor-negative) gap larger than the actual gap.
    """
    est_pos = torch.linalg.norm(emb_anchor - emb_positive, dim=1)
    est_neg = torch.linalg.norm(emb_anchor - emb_negative, dim=1)
    est_pos_neg = torch.linalg.norm(emb_positive - emb_negative, dim=1)
    per_triplet = (torch.abs(est_neg - d_anchor_neg)
                   + torch.abs(est_pos - d_anchor_pos)
                   + torch.abs(est_pos_neg - d_pos_neg)
                   + (est_pos - est_neg - (d_anchor_pos - d_anchor_neg)).clip(min=0))
    # Tensor reduction instead of builtin sum(), which iterates element by element
    # and builds a long chain of scalar autograd nodes.
    return per_triplet.sum()


def evaluate_monte_carlo(model, repeats, data, batch_size, in_alphabet, max_len):
    """Estimate the mean absolute error between embedding and phoneme edit distances.

    For each of ``repeats`` shuffled passes over ``data``, chunks of ``2 * batch_size``
    samples are split in half and paired element-wise. For every pair, the absolute
    difference between the Euclidean distance of the two embeddings and the
    Levenshtein distance of their phoneme strings is accumulated.

    The average over all pairs is printed, written to ``results.txt`` (overwritten)
    and logged to sacred as ``avg_estim_error``.

    Returns:
        The average estimation error as a float.

    Raises:
        ValueError: If no pair was evaluated (``repeats < 1`` or fewer than
            ``2 * batch_size`` samples in ``data``).
    """
    def collate(batch: [(torch.tensor, str)]):
        return encode(batch, in_alphabet, max_len)

    # shuffle=True reshuffles on every iteration, so the loader can be built once.
    data_loader = DataLoader(dataset=data, drop_last=True,
                             batch_size=2 * batch_size,
                             collate_fn=collate,
                             shuffle=True)
    pairs = 0
    diff = 0.0
    outer_bar = tqdm(total=repeats, position=0)
    inner_bar = tqdm(total=len(data), position=1)
    outer_bar.set_description("Epochs")
    try:
        with torch.no_grad():
            for _ in range(repeats):
                inner_bar.reset()
                for batch_text, batch_phonemes in data_loader:
                    positive, negative = batch_text.to(device).split(batch_size)
                    ph_positive = batch_phonemes[:batch_size]
                    ph_negative = batch_phonemes[batch_size:]
                    embedded_positive = model(positive)
                    embedded_negative = model(negative)
                    estimated_dist = torch.linalg.norm(embedded_negative - embedded_positive, dim=1)
                    actual_dist = dist(ph_negative, ph_positive)
                    diff += torch.abs(estimated_dist - actual_dist).sum().item()
                    pairs += batch_size
                    inner_bar.update(2 * batch_size)
                outer_bar.update(1)
    finally:
        inner_bar.close()
        outer_bar.close()

    if pairs == 0:
        raise ValueError("no pairs evaluated: need repeats >= 1 and at least %d samples (2 * batch_size), got %d"
                         % (2 * batch_size, len(data)))
    avg_error = diff / pairs
    message = "Average estimation error " + str(avg_error)
    print(message)
    with open('results.txt', 'w') as r:
        r.write(message + "\n")
    ex.log_scalar("avg_estim_error", avg_error)
    return avg_error


@ex.config
def cfg():
    # Sacred config scope: each local assigned here becomes a config entry, and `run`
    # receives the entries through its same-named parameters. Do not add scratch
    # variables here - they would become config entries too.
    kernel_size = 3  # width of every Conv1d kernel in CNN
    hidden_layers = 14  # number of Conv1d layers after the input convolution
    data_file = 'preprocessed.tsv'  # read by run: one "<word>\t<phonemes>" pair per line
    epochs = 14  # passes over the data in 'train' mode
    mode = 'train'  # 'train': train, save and log the model; 'eval': run evaluate_monte_carlo
    teacher_forcing_probability = 0.4  # NOTE(review): not used by any code in this file
    learning_rate = 0.01  # Adam learning rate
    batch_size = 512  # triplets per training step (3 * batch_size samples); pairs per eval step
    max_len = 32  # maximum word length; run also uses it as CNN channel count and embedding size
    total_out_len = 0  # NOTE(review): not used by any code in this file
    model_file = 'cnn.pth'  # path of the model state_dict written in 'train' mode
    # Phoneme symbols. NOTE(review): not needed by the training/eval code in this file.
    out_lookup = ['', 'b', 'a', 'ʊ', 't', 'k', 'ə', 'z', 'ɔ', 'ɹ', 's', 'j', 'u', 'm', 'f', 'ɪ', 'o', 'ɡ', 'ɛ', 'n',
                  'e', 'd',
                  'ɫ', 'w', 'i', 'p', 'ɑ', 'ɝ', 'θ', 'v', 'h', 'æ', 'ŋ', 'ʃ', 'ʒ', 'ð']
    # Input letters; a letter's list position is its one-hot channel index. Entry 0 ('')
    # is presumably a placeholder that never occurs in words - confirm against the data.
    in_lookup = ['', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't',
                 'u', 'v', 'w', 'x', 'y', 'z']


def signature(model, in_alphabet, max_len):
    """Build an MLflow model signature for ``model`` from one dummy encoded word.

    The word ``'abc'`` is one-hot encoded with ``encode_str`` and passed through the
    model. The input and output arrays are then handed to MLflow's ``infer_signature``.
    """
    mock_x = [('abc', 'xyz')]
    mock_text, _ = encode_str(mock_x, in_alphabet, max_len)
    # encode_str builds CPU tensors while the model may live on CUDA: probe on the
    # model's own device, then bring the output back to CPU before converting to NumPy.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')
    with torch.no_grad():
        mock_y = model(mock_text.to(model_device))
    return mlflow.models.signature.infer_signature(mock_text.numpy(), mock_y.cpu().numpy())


@ex.automain
def run(kernel_size, hidden_layers, data_file, epochs, teacher_forcing_probability, learning_rate, batch_size, max_len,
        total_out_len, model_file, out_lookup, in_lookup, mode):
    """Sacred entry point: load ``data_file``, then train or evaluate the CNN in an MLflow run.

    ``mode == 'train'`` trains the model, saves its ``state_dict`` to ``model_file``,
    attaches it as a sacred artifact and registers it with MLflow.
    ``mode == 'eval'`` loads the weights from ``model_file`` and reports the
    Monte-Carlo distance-estimation error.

    ``teacher_forcing_probability``, ``total_out_len`` and ``out_lookup`` are received
    because they are config entries, but they are not used.

    Raises:
        ValueError: On a malformed data line, a word longer than ``max_len`` or an
            unknown ``mode``.
    """
    with mlflow.start_run():
        log_param("kernel_size", kernel_size)
        log_param("hidden_layers", hidden_layers)
        log_param("data_file", data_file)
        log_param("epochs", epochs)
        log_param("learning_rate", learning_rate)
        log_param("batch_size", batch_size)
        log_param("max_len", max_len)
        in_alphabet = {letter: idx for idx, letter in enumerate(in_lookup)}
        data: [(torch.Tensor, str)] = []
        # The phoneme column contains IPA symbols: decode explicitly as UTF-8 instead
        # of relying on the platform's locale encoding.
        with open(data_file, encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                # Drop the line terminator so it is not part of the phoneme string
                # (a last line without one would otherwise get skewed edit distances).
                line = line.rstrip('\r\n')
                if not line:
                    continue
                fields = line.split("\t")
                if len(fields) != 2:
                    raise ValueError(f"{data_file}:{line_no}: expected 2 tab-separated fields, got {len(fields)}")
                text, phonemes = fields
                if len(text) > max_len:
                    raise ValueError(f"{data_file}:{line_no}: word {text!r} is longer than max_len={max_len}")
                letters = torch.tensor([in_alphabet[letter] for letter in text], dtype=torch.int)
                data.append((letters, phonemes))
        cnn = CNN(kernel_size=kernel_size, hidden_layers=hidden_layers, channels=max_len, embedding_size=max_len,
                  in_alphabet=in_alphabet, max_len=max_len).to(device)
        if mode == 'train':
            train_model(cnn, learning_rate, in_alphabet, max_len, data, epochs, batch_size)
            torch.save(cnn.state_dict(), model_file)
            ex.add_artifact(model_file)
            mlflow.pytorch.log_model(cnn, "cnn-model", registered_model_name="PhoneticEdDistEmbeddings",
                                     signature=signature(cnn, in_alphabet, max_len))
        elif mode == 'eval':
            # Evaluate the trained weights; without loading them the freshly
            # initialised (random) network would be measured.
            cnn.load_state_dict(torch.load(model_file, map_location=device))
            cnn.eval()
            evaluate_monte_carlo(cnn, 1, data, batch_size, in_alphabet, max_len)
        else:
            raise ValueError(f"unknown mode {mode!r}; expected 'train' or 'eval'")