ium_434749/train_model.py

242 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# https://arxiv.org/pdf/2001.11692.pdf
import numpy as np
import unicodedata
from time import sleep
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sacred.observers import FileStorageObserver, MongoObserver
from torch.utils.data import Dataset, DataLoader
import re
import random
import os
import sys
from tqdm import tqdm
from Levenshtein import distance as levenshtein_distance
from sacred import Experiment
import traceback
from mlflow import log_metric, log_param, log_artifacts
import mlflow
import logging
logging.basicConfig(level=logging.WARN)
logger = logging.getLogger(__name__)

# Tracking endpoints: the environment variables override the defaults so the
# script can target another server without code changes.
mlflow.set_tracking_uri(os.environ.get("MLFLOW_TRACKING_URI", "http://172.17.0.1:5000"))
mlflow.set_experiment("s434749")

ex = Experiment("CNN")
ex.observers.append(FileStorageObserver('sacred_file_observer'))
try:
    # NOTE(review): the fallback URL embeds a plaintext password; prefer setting
    # SACRED_MONGO_URL in the environment and removing the fallback.
    ex.observers.append(MongoObserver(
        url=os.environ.get('SACRED_MONGO_URL',
                           'mongodb://mongo_user:mongo_password_IUM_2021@172.17.0.1:27017'),
        db_name='sacred'))
except Exception:
    # The Mongo observer is optional: keep running with the file observer only,
    # but report why it could not be attached.
    logger.exception("Could not attach the Sacred MongoObserver; continuing without it")

# Use the first GPU when one is available, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
class CNN(nn.Module):
    """1-D convolutional encoder mapping a one-hot encoded word to an embedding.

    The input is a float tensor of shape ``(batch, len(in_alphabet), max_len)``
    (the layout produced by ``encode``).  It goes through one input convolution
    and ``hidden_layers`` further convolutions, each followed by ReLU.  None of
    them pad, so each shortens the sequence by ``kernel_size - 1``.  The result
    is flattened and linearly projected to ``embedding_size``.

    Raises:
        ValueError: if ``max_len`` is too short for the convolution stack, i.e.
            no sequence positions would be left to flatten.
    """

    def __init__(self, kernel_size, hidden_layers, channels, embedding_size, in_alphabet, max_len):
        super(CNN, self).__init__()
        # hidden_layers + 1 unpadded convolutions each remove (kernel_size - 1) positions.
        out_len = max_len - (kernel_size - 1) * (hidden_layers + 1)
        if out_len <= 0:
            raise ValueError(
                f"max_len={max_len} is too short for {hidden_layers + 1} convolutions with "
                f"kernel_size={kernel_size}: the output sequence length would be {out_len}")
        self.input_conv = nn.Conv1d(in_channels=len(in_alphabet), out_channels=channels, kernel_size=kernel_size)
        self.conv_hidden = nn.ModuleList(
            [nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=kernel_size)
             for _ in range(hidden_layers)])
        self.last_layer_size = out_len * channels
        self.lin = nn.Linear(self.last_layer_size, embedding_size)

    def forward(self, x):
        """Return the ``(batch, embedding_size)`` embedding of the one-hot batch ``x``."""
        x = F.relu(self.input_conv(x), inplace=True)
        for conv in self.conv_hidden:
            x = F.relu(conv(x), inplace=True)
        # Flatten (batch, channels, out_len) -> (batch, channels * out_len).
        x = x.view(x.size(0), self.last_layer_size)
        return self.lin(x)
def dist(a: [str], b: [str]):
    """Return the element-wise Levenshtein distances between ``a[k]`` and ``b[k]``.

    The result is a 1-D float tensor of length ``len(a)`` placed on ``device``.
    """
    pairwise = [levenshtein_distance(left, b[k]) for k, left in enumerate(a)]
    return torch.tensor(pairwise, dtype=torch.float, device=device)
def encode(batch: [(torch.tensor, str)], in_alphabet, max_len):
    """One-hot encode a batch of index sequences and split off their targets.

    Args:
        batch: sequence of ``(indices, phonemes)`` pairs. ``indices`` holds the
            ``in_alphabet`` index of each input character (a 1-D tensor or any
            sequence of ints); ``phonemes`` is passed through untouched.
        in_alphabet: mapping whose length gives the number of one-hot channels.
        max_len: number of sequence positions. Each ``indices`` must have at
            most ``max_len`` entries; positions past its end stay all-zero.

    Returns:
        ``(batch_text, batch_phonemes)``. ``batch_text`` is a float tensor of
        shape ``(len(batch), len(in_alphabet), max_len)`` with
        ``batch_text[i, indices[p], p] == 1``. ``batch_phonemes`` is the list
        of second elements in batch order.
    """
    batch_text = torch.zeros((len(batch), len(in_alphabet), max_len))
    batch_phonemes = [phonemes for _, phonemes in batch]
    for row, (sample, _) in enumerate(batch):
        # Scatter every character of the sample in one indexed assignment
        # instead of one scalar tensor write per character (this runs on every
        # DataLoader collate).
        indices = torch.as_tensor(sample, dtype=torch.long)
        positions = torch.arange(indices.numel())
        batch_text[row, indices, positions] = 1
    return batch_text, batch_phonemes
def encode_str(batch: [(str, str)], in_alphabet, max_len):
    """Encode ``(text, phonemes)`` string pairs with ``encode``.

    Each character of ``text`` is first looked up in ``in_alphabet``, producing
    an int32 index tensor. A character missing from the alphabet raises
    ``KeyError``.
    """
    indexed = []
    for text, phonemes in batch:
        letter_ids = torch.tensor([in_alphabet[ch] for ch in text], dtype=torch.int)
        indexed.append((letter_ids, phonemes))
    return encode(indexed, in_alphabet, max_len)
def _triplet_loss(model, batch_text, batch_phonemes, batch_size):
    """Return the summed loss for one collated batch of ``3 * batch_size`` samples.

    The batch is split into three consecutive thirds ("anchor", "positive",
    "negative"). Since the loader shuffles, these are simply three random groups.
    For every triple, each estimated embedding distance is pulled towards the
    matching phoneme Levenshtein distance.  A hinge term is added when the
    estimated (anchor-positive minus anchor-negative) gap exceeds the actual gap.
    """
    anchor, positive, negative = batch_text.to(device).split(batch_size)
    ph_anchor = batch_phonemes[:batch_size]
    ph_positive = batch_phonemes[batch_size:2 * batch_size]
    ph_negative = batch_phonemes[2 * batch_size:]
    embedded_anchor = model(anchor)
    embedded_positive = model(positive)
    embedded_negative = model(negative)
    estimated_pos_dist = torch.linalg.norm(embedded_anchor - embedded_positive, dim=1)
    estimated_neg_dist = torch.linalg.norm(embedded_anchor - embedded_negative, dim=1)
    estimated_pos_neg_dist = torch.linalg.norm(embedded_positive - embedded_negative, dim=1)
    actual_pos_dist = dist(ph_anchor, ph_positive)
    actual_neg_dist = dist(ph_anchor, ph_negative)
    actual_pos_neg_dist = dist(ph_positive, ph_negative)
    per_triplet = ((estimated_neg_dist - actual_neg_dist).abs()
                   + (estimated_pos_dist - actual_pos_dist).abs()
                   + (estimated_pos_neg_dist - actual_pos_neg_dist).abs()
                   + (estimated_pos_dist - estimated_neg_dist - (actual_pos_dist - actual_neg_dist)).clip(min=0))
    # Tensor.sum() reduces in a single op; the builtin sum() iterated the
    # tensor and added one autograd node per element.
    return per_triplet.sum()


def train_model(model, learning_rate, in_alphabet, max_len, data, epochs, batch_size):
    """Train ``model`` with Adam so that embedding distances track phoneme edit distances.

    Args:
        model: the network to train; it is switched to training mode.
        learning_rate: Adam learning rate.
        in_alphabet: input alphabet passed to ``encode`` for one-hot encoding.
        max_len: one-hot sequence length passed to ``encode``.
        data: dataset of ``(index_tensor, phoneme_string)`` samples.
        epochs: number of passes over ``data``.
        batch_size: number of triplets per optimisation step.  Each step
            consumes ``3 * batch_size`` samples; leftovers are dropped.

    After every epoch, the mean loss per processed triplet is logged as
    ``avg_loss`` to Sacred and to MLflow (with ``step=epoch``).
    """
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=learning_rate)

    def collate(batch: [(torch.tensor, str)]):
        return encode(batch, in_alphabet, max_len)

    data_loader = DataLoader(dataset=data, drop_last=True,
                             batch_size=3 * batch_size,
                             collate_fn=collate,
                             shuffle=True)
    model.train()
    outer_bar = tqdm(total=epochs, position=0)
    inner_bar = tqdm(total=len(data), position=1)
    outer_bar.set_description("Epochs")
    try:
        for epoch in range(epochs):
            total_loss = 0.0
            n_triplets = 0
            inner_bar.reset()
            for batch_text, batch_phonemes in data_loader:
                optimizer.zero_grad()
                loss = _triplet_loss(model, batch_text, batch_phonemes, batch_size)
                loss.backward()
                optimizer.step()
                inner_bar.update(3 * batch_size)
                loss_scalar = loss.item()
                total_loss += loss_scalar
                n_triplets += batch_size
                inner_bar.set_description("loss %.2f" % loss_scalar)
            # Average over the triplets actually processed: drop_last discards
            # the remainder, so len(data) / 3 would overcount (and is 0 for
            # empty data).
            avg_loss = total_loss / n_triplets if n_triplets else 0.0
            ex.log_scalar("avg_loss", avg_loss)
            log_metric("avg_loss", avg_loss, step=epoch)
            outer_bar.update(1)
    finally:
        inner_bar.close()
        outer_bar.close()
def evaluate_monte_carlo(model, repeats, data, batch_size, in_alphabet, max_len):
    """Estimate the mean error of embedding distances against phoneme edit distances.

    Each of the ``repeats`` passes shuffles ``data`` into batches of
    ``2 * batch_size`` samples and pairs the first half with the second half.
    For every pair it accumulates
    ``|norm(e(x) - e(y)) - levenshtein(ph_x, ph_y)|``.  The mean over all pairs
    is printed, written to ``results.txt``, and logged as ``avg_estim_error`` to
    Sacred and MLflow.

    Raises:
        ValueError: if no pair was evaluated, i.e. ``repeats`` is 0 or ``data``
            has fewer than ``2 * batch_size`` samples.
    """
    def collate(batch: [(torch.tensor, str)]):
        return encode(batch, in_alphabet, max_len)

    # A shuffling DataLoader draws a fresh permutation each time it is
    # iterated, so one loader serves every repeat.
    data_loader = DataLoader(dataset=data, drop_last=True,
                             batch_size=2 * batch_size,
                             collate_fn=collate,
                             shuffle=True)
    total_error = 0.0
    n_pairs = 0
    outer_bar = tqdm(total=repeats, position=0)
    inner_bar = tqdm(total=len(data), position=1)
    outer_bar.set_description("Epochs")
    try:
        with torch.no_grad():
            for _ in range(repeats):
                inner_bar.reset()
                for batch_text, batch_phonemes in data_loader:
                    positive, negative = batch_text.to(device).split(batch_size)
                    ph_positive = batch_phonemes[:batch_size]
                    ph_negative = batch_phonemes[batch_size:]
                    embedded_positive = model(positive)
                    embedded_negative = model(negative)
                    estimated_dist = torch.linalg.norm(embedded_negative - embedded_positive, dim=1)
                    actual_dist = dist(ph_negative, ph_positive)
                    total_error += (estimated_dist - actual_dist).abs().sum().item()
                    n_pairs += batch_size
                    inner_bar.update(2 * batch_size)
                outer_bar.update(1)
    finally:
        inner_bar.close()
        outer_bar.close()
    if n_pairs == 0:
        raise ValueError(
            f"nothing to evaluate: need repeats >= 1 and at least {2 * batch_size} samples, "
            f"got repeats={repeats} and {len(data)} samples")
    avg_error = total_error / n_pairs
    message = "Average estimation error " + str(avg_error)
    print(message)
    with open('results.txt', 'w', encoding='utf-8') as r:
        r.write(message + "\n")
    ex.log_scalar("avg_estim_error", avg_error)
    # Mirror train_model, which reports its metric to both Sacred and MLflow.
    log_metric("avg_estim_error", avg_error)
@ex.config
def cfg():
    # Sacred config scope: the local variables assigned below become the
    # experiment's config entries and are injected by name into run().
    # NOTE(review): because of that, do not add helper locals here.
    kernel_size = 3  # Conv1d kernel width used for every convolution in CNN
    hidden_layers = 14  # number of convolutions after CNN's input convolution
    data_file = 'preprocessed.tsv'  # lines of "text<TAB>phonemes", read by run()
    epochs = 14  # passes over the data in train mode
    mode = 'train'  # run() handles 'train' and 'eval'
    teacher_forcing_probability = 0.4  # NOTE(review): accepted by run() but not used there
    learning_rate = 0.01  # Adam learning rate
    batch_size = 512  # triplets per training step (3 * batch_size samples); pairs per eval batch
    max_len = 32  # one-hot sequence length; run() also uses it as CNN channels and embedding size
    total_out_len = 0  # NOTE(review): accepted by run() but not used there
    model_file = 'cnn.pth'  # path of the model state_dict that run() saves in train mode
    # Phoneme symbols. NOTE(review): some entries are non-ASCII look-alikes
    # (e.g. 'ɡ' is presumably U+0261 rather than ASCII 'g'); confirm they match data_file.
    out_lookup = ['', 'b', 'a', 'ʊ', 't', 'k', 'ə', 'z', 'ɔ', 'ɹ', 's', 'j', 'u', 'm', 'f', 'ɪ', 'o', 'ɡ', 'ɛ', 'n',
                  'e', 'd',
                  'ɫ', 'w', 'i', 'p', 'ɑ', 'ɝ', 'θ', 'v', 'h', 'æ', 'ŋ', 'ʃ', 'ʒ', 'ð']
    # Input letters; run() maps each letter to its list position, which is the
    # one-hot channel index used by encode().
    in_lookup = ['', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't',
                 'u', 'v', 'w', 'x', 'y', 'z']
def signature(model, in_alphabet, max_len):
    """Infer an MLflow model signature by running ``model`` on a one-word mock batch.

    The mock input is moved to the device holding the model's parameters, so a
    model that was moved to CUDA works too.  The output is copied back to the
    CPU before being converted to NumPy.
    """
    mock_x = [('abc', 'xyz')]
    mock_text, _ = encode_str(mock_x, in_alphabet, max_len)
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')
    with torch.no_grad():
        mock_y = model(mock_text.to(model_device))
    return mlflow.models.signature.infer_signature(mock_text.detach().numpy(), mock_y.detach().cpu().numpy())
# Defined before run(): when the script is executed directly, Sacred's
# @ex.automain starts the experiment as soon as run() is decorated.
def _load_data(data_file, in_alphabet, max_len):
    """Read ``text<TAB>phonemes`` lines into ``(int32 index tensor, phonemes)`` pairs.

    Line endings are stripped so that the phoneme strings carry no trailing
    newline, and blank lines are skipped.

    Raises:
        ValueError: for a line without exactly one tab, a text longer than
            ``max_len``, or a character missing from ``in_alphabet``.
    """
    data = []
    with open(data_file, encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            line = line.rstrip('\r\n')
            if not line:
                continue
            try:
                text, phonemes = line.split("\t")
            except ValueError:
                raise ValueError(f"{data_file}:{line_no}: expected 'text<TAB>phonemes', got {line!r}") from None
            if len(text) > max_len:
                raise ValueError(f"{data_file}:{line_no}: text longer than max_len={max_len}: {text!r}")
            try:
                indices = [in_alphabet[letter] for letter in text]
            except KeyError as err:
                raise ValueError(
                    f"{data_file}:{line_no}: character {err.args[0]!r} is not in the input alphabet") from err
            data.append((torch.tensor(indices, dtype=torch.int), phonemes))
    return data


@ex.automain
def run(kernel_size, hidden_layers, data_file, epochs, teacher_forcing_probability, learning_rate, batch_size, max_len,
        total_out_len, model_file, out_lookup, in_lookup, mode):
    """Sacred entry point: load ``data_file``, then train or evaluate the CNN inside an MLflow run.

    In ``mode='train'`` the model is trained, and its state_dict is saved to
    ``model_file``.  It is also registered as a Sacred artifact and logged to
    MLflow.  In ``mode='eval'`` the weights are loaded from ``model_file`` and
    evaluated with ``evaluate_monte_carlo``.  ``teacher_forcing_probability``,
    ``total_out_len`` and ``out_lookup`` are accepted from the config but unused.

    Raises:
        ValueError: for an unknown ``mode`` or malformed data.
    """
    if mode not in ('train', 'eval'):
        raise ValueError(f"mode must be 'train' or 'eval', got {mode!r}")
    with mlflow.start_run():
        for name, value in (("kernel_size", kernel_size),
                            ("hidden_layers", hidden_layers),
                            ("data_file", data_file),
                            ("epochs", epochs),
                            ("learning_rate", learning_rate),
                            ("batch_size", batch_size),
                            ("max_len", max_len)):
            log_param(name, value)
        in_alphabet = {letter: idx for idx, letter in enumerate(in_lookup)}
        data = _load_data(data_file, in_alphabet, max_len)
        cnn = CNN(kernel_size=kernel_size, hidden_layers=hidden_layers, channels=max_len, embedding_size=max_len,
                  in_alphabet=in_alphabet, max_len=max_len).to(device)
        if mode == 'train':
            train_model(cnn, learning_rate, in_alphabet, max_len, data, epochs, batch_size)
            torch.save(cnn.state_dict(), model_file)
            ex.add_artifact(model_file)
            mlflow.pytorch.log_model(cnn, "cnn-model", registered_model_name="PhoneticEdDistEmbeddings",
                                     signature=signature(cnn, in_alphabet, max_len))
        else:
            # Evaluate the weights saved by a training run, not a freshly
            # initialised network.
            cnn.load_state_dict(torch.load(model_file, map_location=device))
            cnn.eval()
            evaluate_monte_carlo(cnn, 1, data, batch_size, in_alphabet, max_len)