ium_434749/train_model.py
Alagris 638e8e92ea
All checks were successful
s434749-training/pipeline/head This commit looks good
no mongo
2021-05-10 12:00:34 +02:00

215 lines
8.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# https://arxiv.org/pdf/2001.11692.pdf
import numpy as np
import unicodedata
from time import sleep
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sacred.observers import FileStorageObserver, MongoObserver
from torch.utils.data import Dataset, DataLoader
import re
import random
import os
import sys
from tqdm import tqdm
from Levenshtein import distance as levenshtein_distance
from sacred import Experiment
import traceback
# Sacred experiment object; `cfg` and `run` below are registered on it via decorators.
ex = Experiment("CNN")
# Store run records on disk in the 'sacred_file_observer' directory.
ex.observers.append(FileStorageObserver('sacred_file_observer'))
# The MongoDB observer is disabled (commented out); the snippet is kept only as a
# reference for re-enabling it.
# try:
# ex.observers.append(MongoObserver(url='mongodb://mongo_user:mongo_password_IUM_2021@localhost:27017',
# db_name='sacred'))
# except Exception as e:
# traceback.print_exc()
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
class CNN(nn.Module):
    """Stack of un-padded 1-D convolutions followed by a linear embedding layer.

    The network expects input of shape ``(batch, len(in_alphabet), max_len)``
    and returns a tensor of shape ``(batch, embedding_size)``.

    Args:
        kernel_size: Kernel size shared by every convolution.
        hidden_layers: Number of convolutions applied after the input one.
        channels: Number of output channels of every convolution.
        embedding_size: Size of the vector produced for each sample.
        in_alphabet: Sized collection of input symbols; only its length is
            used, as the number of input channels.
        max_len: Length of the last input dimension.
    """

    def __init__(self, kernel_size, hidden_layers, channels, embedding_size, in_alphabet, max_len):
        super().__init__()
        alphabet_size = len(in_alphabet)
        self.input_conv = nn.Conv1d(in_channels=alphabet_size, out_channels=channels, kernel_size=kernel_size)
        hidden = []
        for _ in range(hidden_layers):
            hidden.append(nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=kernel_size))
        self.conv_hidden = nn.ModuleList(hidden)
        # Each un-padded convolution shortens the sequence by kernel_size - 1;
        # there are hidden_layers + 1 of them (input convolution included).
        conv_count = hidden_layers + 1
        remaining_positions = max_len - conv_count * (kernel_size - 1)
        self.last_layer_size = remaining_positions * channels
        self.lin = nn.Linear(self.last_layer_size, embedding_size)

    def forward(self, x):
        """Apply conv + ReLU for every layer, flatten, and project to the embedding."""
        for conv in (self.input_conv, *self.conv_hidden):
            x = F.relu(conv(x), inplace=True)
        flattened = x.view(x.size(0), self.last_layer_size)
        return self.lin(flattened)
def dist(a: [str], b: [str]):
    """Return the pairwise Levenshtein distances between ``a`` and ``b``.

    Element ``k`` of the result is the edit distance between ``a[k]`` and
    ``b[k]``. The result is a float tensor with ``len(a)`` entries placed on
    the module-level ``device``; ``b`` is indexed by position, so it must be
    at least as long as ``a``.
    """
    distances = [levenshtein_distance(left, b[position]) for position, left in enumerate(a)]
    return torch.tensor(distances, dtype=torch.float, device=device)
def train_model(model, learning_rate, in_alphabet, max_len, data, epochs, batch_size):
    """Train ``model`` so that embedding distances approximate edit distances.

    Every mini-batch holds ``3 * batch_size`` shuffled samples that are split
    by position into anchor / positive / negative thirds. For each triplet the
    loss adds the absolute error between the Euclidean distance of the
    embeddings and the Levenshtein distance of the phoneme strings (see
    ``dist``) for all three pairs, plus a hinge term that penalises
    ``est_pos - est_neg`` exceeding ``actual_pos - actual_neg``.

    The per-triplet average loss of each epoch is logged to the Sacred
    experiment as ``avg_loss``.

    Args:
        model: Module mapping a ``(batch, len(in_alphabet), max_len)`` one-hot
            tensor to embeddings.
        learning_rate: Learning rate for the Adam optimizer.
        in_alphabet: Sized collection of input symbols; its length is the
            one-hot depth.
        max_len: Width of the one-hot encoding (maximum text length).
        data: Dataset of ``(symbol_indices, phonemes)`` pairs.
        epochs: Number of passes over ``data``.
        batch_size: Number of triplets per optimisation step.
    """
    optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=learning_rate)

    def collate(batch: [(torch.tensor, str)]):
        """One-hot encode the texts of ``batch`` and collect its phoneme strings."""
        batch_text = torch.zeros((len(batch), len(in_alphabet), max_len))
        batch_phonemes = [phonemes for _, phonemes in batch]
        for row, (sample, _) in enumerate(batch):
            # Set all (symbol, position) cells of this row in a single indexed
            # assignment instead of looping over characters in Python.
            symbols = torch.as_tensor(sample, dtype=torch.long)
            batch_text[row, symbols, torch.arange(symbols.numel())] = 1
        return batch_text, batch_phonemes

    # drop_last guarantees every batch splits into three equal thirds.
    data_loader = DataLoader(dataset=data, drop_last=True,
                             batch_size=3 * batch_size,
                             collate_fn=collate,
                             shuffle=True)

    outer_bar = tqdm(total=epochs, position=0)
    inner_bar = tqdm(total=len(data), position=1)
    outer_bar.reset()
    outer_bar.set_description("Epochs")
    try:
        for _ in range(epochs):
            total_loss = 0.0
            triplets_seen = 0
            inner_bar.reset()
            for batch_text, batch_phonemes in data_loader:
                optimizer.zero_grad()
                anchor, positive, negative = batch_text.to(device).split(batch_size)
                ph_anchor = batch_phonemes[:batch_size]
                ph_positive = batch_phonemes[batch_size:2 * batch_size]
                ph_negative = batch_phonemes[2 * batch_size:]

                embedded_anchor = model(anchor)
                embedded_positive = model(positive)
                embedded_negative = model(negative)

                estimated_pos_dist = torch.linalg.norm(embedded_anchor - embedded_positive, dim=1)
                estimated_neg_dist = torch.linalg.norm(embedded_anchor - embedded_negative, dim=1)
                estimated_pos_neg_dist = torch.linalg.norm(embedded_positive - embedded_negative, dim=1)

                actual_pos_dist = dist(ph_anchor, ph_positive)
                actual_neg_dist = dist(ph_anchor, ph_negative)
                actual_pos_neg_dist = dist(ph_positive, ph_negative)

                ranking_violation = (estimated_pos_dist - estimated_neg_dist
                                     - (actual_pos_dist - actual_neg_dist))
                # Tensor.sum() reduces natively; the built-in sum() would walk
                # the tensor element by element in Python.
                loss = ((estimated_neg_dist - actual_neg_dist).abs()
                        + (estimated_pos_dist - actual_pos_dist).abs()
                        + (estimated_pos_neg_dist - actual_pos_neg_dist).abs()
                        + ranking_violation.clip(min=0)).sum()
                loss.backward()
                optimizer.step()

                inner_bar.update(3 * batch_size)
                loss_scalar = loss.item()
                total_loss += loss_scalar
                triplets_seen += batch_size
                inner_bar.set_description("loss %.2f" % loss_scalar)
            # Average over the triplets actually processed: drop_last discards
            # the incomplete final batch, so len(data) / 3 would over-count.
            ex.log_scalar("avg_loss", total_loss / triplets_seen if triplets_seen else 0.0)
            outer_bar.set_description("Epochs")
            outer_bar.update(1)
    finally:
        # Release the progress bars even when training is interrupted.
        inner_bar.close()
        outer_bar.close()
def evaluate_monte_carlo(model, repeats, data, batch_size, in_alphabet, max_len):
    """Estimate how well embedding distances match phoneme edit distances.

    For ``repeats`` passes over ``data``, shuffled batches of
    ``2 * batch_size`` samples are split into two halves that are paired by
    position. The absolute difference between the Euclidean distance of the
    embeddings and the Levenshtein distance of the phoneme strings (see
    ``dist``) is averaged over all pairs. The average is printed, written to
    ``results.txt`` and logged to the Sacred experiment as
    ``avg_estim_error``.

    Args:
        model: Module mapping a ``(batch, len(in_alphabet), max_len)`` one-hot
            tensor to embeddings.
        repeats: Number of passes over ``data``.
        data: Dataset of ``(symbol_indices, phonemes)`` pairs.
        batch_size: Number of pairs per batch.
        in_alphabet: Sized collection of input symbols; its length is the
            one-hot depth.
        max_len: Width of the one-hot encoding (maximum text length).

    Raises:
        ValueError: If no pair was evaluated (``repeats`` is zero or ``data``
            has fewer than ``2 * batch_size`` samples).
    """

    def collate(batch: [(torch.tensor, str)]):
        """One-hot encode the texts of ``batch`` and collect its phoneme strings."""
        batch_text = torch.zeros((len(batch), len(in_alphabet), max_len))
        batch_phonemes = [phonemes for _, phonemes in batch]
        for row, (sample, _) in enumerate(batch):
            # Set all (symbol, position) cells of this row in a single indexed
            # assignment instead of looping over characters in Python.
            symbols = torch.as_tensor(sample, dtype=torch.long)
            batch_text[row, symbols, torch.arange(symbols.numel())] = 1
        return batch_text, batch_phonemes

    # Built once: with shuffle=True the loader reshuffles on every iteration.
    # drop_last guarantees every batch splits into two equal halves.
    data_loader = DataLoader(dataset=data, drop_last=True,
                             batch_size=2 * batch_size,
                             collate_fn=collate,
                             shuffle=True)

    total_error = 0.0
    pairs_seen = 0
    outer_bar = tqdm(total=repeats, position=0)
    inner_bar = tqdm(total=len(data), position=1)
    outer_bar.set_description("Epochs")
    try:
        with torch.no_grad():
            for _ in range(repeats):
                inner_bar.reset()
                for batch_text, batch_phonemes in data_loader:
                    positive, negative = batch_text.to(device).split(batch_size)
                    ph_positive = batch_phonemes[0:batch_size]
                    ph_negative = batch_phonemes[batch_size:]
                    embedded_positive = model(positive)
                    embedded_negative = model(negative)
                    estimated_dist = torch.linalg.norm(embedded_negative - embedded_positive, dim=1)
                    actual_dist = dist(ph_negative, ph_positive)
                    # Accumulate a plain float so the total is well-defined
                    # even before the first batch has been processed.
                    total_error += (estimated_dist - actual_dist).abs().sum().item()
                    pairs_seen += batch_size
                    inner_bar.update(2 * batch_size)
                outer_bar.update(1)
    finally:
        # Release the progress bars even when evaluation is interrupted.
        inner_bar.close()
        outer_bar.close()

    if pairs_seen == 0:
        raise ValueError("no pairs were evaluated: need repeats >= 1 and at least "
                         "2 * batch_size samples in data")

    average_error = total_error / pairs_seen
    message = "Average estimation error " + str(average_error)
    with open('results.txt', 'w+') as r:
        print(message)
        r.write(message + "\n")
    ex.log_scalar("avg_estim_error", average_error)
@ex.config
def cfg():
    """Default experiment configuration.

    Registered with ``@ex.config``; the names assigned here match the
    parameters of ``run`` one-to-one, so they must not be renamed.
    """
    kernel_size = 3  # kernel size passed to the CNN
    hidden_layers = 14  # number of hidden convolutions passed to the CNN
    data_file = 'preprocessed.tsv'  # tab-separated "text<TAB>phonemes" lines read by run
    epochs = 14  # number of training epochs
    mode = 'train'  # run compares this against 'train' and 'eval'
    teacher_forcing_probability = 0.4  # accepted by run but not used in the visible code
    learning_rate = 0.01  # Adam learning rate
    batch_size = 512  # triplets (training) or pairs (evaluation) per batch
    max_len = 32  # maximum text length; run also uses it as channel count and embedding size
    total_out_len = 0  # accepted by run but not used in the visible code
    model_file = 'cnn.pth'  # state dict is loaded from here if present, otherwise saved after training
    # Output symbols (presumably the phoneme alphabet); run only builds an index from it.
    out_lookup = ['', 'b', 'a', 'ʊ', 't', 'k', 'ə', 'z', 'ɔ', 'ɹ', 's', 'j', 'u', 'm', 'f', 'ɪ', 'o', 'ɡ', 'ɛ', 'n',
                  'e', 'd',
                  'ɫ', 'w', 'i', 'p', 'ɑ', 'ɝ', 'θ', 'v', 'h', 'æ', 'ŋ', 'ʃ', 'ʒ', 'ð']
    # Input symbols; a symbol's position in this list is its one-hot channel index.
    in_lookup = ['', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't',
                 'u', 'v', 'w', 'x', 'y', 'z']
@ex.automain
def run(kernel_size, hidden_layers, data_file, epochs, teacher_forcing_probability, learning_rate, batch_size, max_len,
        total_out_len, model_file, out_lookup, in_lookup, mode):
    """Load the dataset, obtain a CNN (from disk or by training), optionally evaluate it.

    If ``model_file`` exists its state dict is loaded. Otherwise the model is
    trained and saved when ``mode == 'train'``, and the process exits with
    status 2 for any other mode. When ``mode == 'eval'`` the model is then
    evaluated with one Monte Carlo pass.

    All parameters are supplied from the experiment configuration.
    ``teacher_forcing_probability``, ``total_out_len`` and ``out_lookup`` are
    accepted to match that configuration but are not used here.

    Raises:
        ValueError: If a text in ``data_file`` is longer than ``max_len`` or a
            non-empty line does not consist of exactly two tab-separated fields.
        KeyError: If a text contains a symbol missing from ``in_lookup``.
    """
    in_alphabet = {letter: idx for idx, letter in enumerate(in_lookup)}
    data: [(torch.tensor, str)] = []
    # The phoneme column holds non-ASCII IPA symbols, so do not rely on the
    # platform default encoding.
    with open(data_file, encoding="utf-8") as f:
        for line in f:
            # Drop the line terminator so it does not become part of the
            # phoneme string that edit distances are computed on.
            line = line.rstrip("\r\n")
            if not line:
                continue
            text, phonemes = line.split("\t")
            if len(text) > max_len:
                raise ValueError(f"text longer than max_len={max_len}: {text!r}")
            encoded = torch.tensor([in_alphabet[letter] for letter in text], dtype=torch.int)
            data.append((encoded, phonemes))

    # max_len doubles as the channel count and the embedding size.
    cnn = CNN(kernel_size=kernel_size, hidden_layers=hidden_layers, channels=max_len, embedding_size=max_len,
              in_alphabet=in_alphabet, max_len=max_len).to(device)
    if os.path.isfile(model_file):
        cnn.load_state_dict(torch.load(model_file, map_location=torch.device('cpu')))
    elif mode == 'train':
        train_model(cnn, learning_rate, in_alphabet, max_len, data, epochs, batch_size)
        torch.save(cnn.state_dict(), model_file)
        ex.add_artifact(model_file)
    else:
        print(model_file + " missing!")
        # sys.exit is always available; the bare exit() helper is only
        # installed by the site module.
        sys.exit(2)
    if mode == 'eval':
        cnn.eval()
        evaluate_monte_carlo(cnn, 1, data, batch_size, in_alphabet, max_len)