# https://arxiv.org/pdf/2001.11692.pdf
"""Learn a CNN embedding of spelled words whose Euclidean distances approximate
the Levenshtein distances between the words' phoneme transcriptions.

Each line of ``data_file`` holds a spelling and its phoneme string separated by
a tab. The spelling is encoded with ``in_lookup`` and fed one-hot to :class:`CNN`.
Configuration comes from the sacred config scope ``cfg``; ``run`` is the sacred
main function.
"""
import os
import random
import re
import sys
import unicodedata
from time import sleep
from typing import Callable, List, Sequence, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from Levenshtein import distance as levenshtein_distance
from sacred import Experiment
from sacred.observers import FileStorageObserver, MongoObserver
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

ex = Experiment("CNN")
ex.observers.append(FileStorageObserver('sacred_file_observer'))
# NOTE(review): the fallback URL embeds database credentials in source control.
# Prefer setting SACRED_MONGO_URL in the environment; the literal is kept only
# as a default so existing setups keep working (consider rotating that password).
_MONGO_URL = os.environ.get(
    'SACRED_MONGO_URL',
    'mongodb://mongo_user:mongo_password_IUM_2021@localhost:27017')
ex.observers.append(MongoObserver(url=_MONGO_URL, db_name='sacred'))

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')


class CNN(nn.Module):
    """Stack of unpadded 1-D convolutions followed by a linear projection.

    Input is a one-hot tensor of shape ``(batch, len(in_alphabet), max_len)``
    (as built by ``_make_collate``); output is ``(batch, embedding_size)``.
    """

    def __init__(self, kernel_size, hidden_layers, channels, embedding_size, in_alphabet, max_len):
        """Build the network.

        Raises:
            ValueError: if ``max_len`` is too short for ``hidden_layers + 1``
                unpadded convolutions of width ``kernel_size``.
        """
        super(CNN, self).__init__()
        self.input_conv = nn.Conv1d(in_channels=len(in_alphabet), out_channels=channels,
                                    kernel_size=kernel_size)
        self.conv_hidden = nn.ModuleList(
            [nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=kernel_size)
             for _ in range(hidden_layers)])
        # Every unpadded convolution shortens the sequence by kernel_size - 1.
        out_len = max_len - (kernel_size - 1) * (hidden_layers + 1)
        if out_len <= 0:
            raise ValueError(
                f"max_len={max_len} is too short for {hidden_layers + 1} convolutions "
                f"with kernel_size={kernel_size}")
        self.last_layer_size = out_len * channels
        self.lin = nn.Linear(self.last_layer_size, embedding_size)

    def forward(self, x):
        """Embed a one-hot batch; returns a ``(batch, embedding_size)`` tensor."""
        x = F.relu(self.input_conv(x), inplace=True)
        for conv in self.conv_hidden:
            x = F.relu(conv(x), inplace=True)
        x = x.view(x.size(0), self.last_layer_size)
        return self.lin(x)


def dist(a: Sequence[str], b: Sequence[str]) -> torch.Tensor:
    """Return the element-wise Levenshtein distances ``d(a[i], b[i])``.

    The result is a float tensor on ``device`` with ``len(a)`` entries.

    Raises:
        ValueError: if ``a`` and ``b`` differ in length.
    """
    if len(a) != len(b):
        raise ValueError(f"dist() needs equal-length inputs, got {len(a)} and {len(b)}")
    return torch.tensor([levenshtein_distance(x, y) for x, y in zip(a, b)],
                        dtype=torch.float, device=device)


def _make_collate(in_alphabet, max_len) -> Callable:
    """Build a DataLoader ``collate_fn`` that one-hot encodes ``(indices, phonemes)`` samples.

    The returned function maps a list of samples to
    ``(one-hot float tensor of shape (batch, len(in_alphabet), max_len), list of phoneme strings)``.
    Columns past a word's length stay all-zero.
    """
    def collate(batch: Sequence[Tuple[torch.Tensor, str]]) -> Tuple[torch.Tensor, List[str]]:
        batch_text = torch.zeros((len(batch), len(in_alphabet), max_len))
        batch_phonemes = [phonemes for _, phonemes in batch]
        for i, (sample, _) in enumerate(batch):
            # Sets batch_text[i, sample[p], p] = 1 for every position p in one step.
            batch_text[i, sample.long(), torch.arange(len(sample))] = 1
        return batch_text, batch_phonemes

    return collate


def train_model(model, learning_rate, in_alphabet, max_len, data, epochs, batch_size):
    """Train ``model`` in place on random triplets drawn from ``data``.

    Each DataLoader batch of ``3 * batch_size`` samples is split into anchor,
    positive and negative thirds. For all three pairs, the loss is the absolute
    error between the embedding distance and the Levenshtein distance of the
    phoneme strings. It also adds a hinge term that fires when
    ``est_pos - est_neg`` exceeds ``actual_pos - actual_neg``. The average loss
    per processed triplet is logged to sacred as ``avg_loss`` once per epoch.

    Raises:
        ValueError: if ``data`` cannot fill a single batch (with ``drop_last``
            no training step would run at all).
    """
    if len(data) < 3 * batch_size:
        raise ValueError(
            f"need at least {3 * batch_size} samples for one training batch, got {len(data)}")

    optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=learning_rate)
    data_loader = DataLoader(dataset=data, drop_last=True, batch_size=3 * batch_size,
                             collate_fn=_make_collate(in_alphabet, max_len), shuffle=True)

    outer_bar = tqdm(total=epochs, position=0)
    inner_bar = tqdm(total=len(data), position=1)
    outer_bar.set_description("Epochs")
    try:
        for _ in range(epochs):
            total_loss = 0.0
            triplets = 0
            inner_bar.reset()
            for batch_text, batch_phonemes in data_loader:
                optimizer.zero_grad()
                anchor, positive, negative = batch_text.to(device).split(batch_size)
                ph_anchor = batch_phonemes[:batch_size]
                ph_positive = batch_phonemes[batch_size:2 * batch_size]
                ph_negative = batch_phonemes[2 * batch_size:]

                embedded_anchor = model(anchor)
                embedded_positive = model(positive)
                embedded_negative = model(negative)
                estimated_pos_dist = torch.linalg.norm(embedded_anchor - embedded_positive, dim=1)
                estimated_neg_dist = torch.linalg.norm(embedded_anchor - embedded_negative, dim=1)
                estimated_pos_neg_dist = torch.linalg.norm(embedded_positive - embedded_negative, dim=1)

                actual_pos_dist = dist(ph_anchor, ph_positive)
                actual_neg_dist = dist(ph_anchor, ph_negative)
                actual_pos_neg_dist = dist(ph_positive, ph_negative)

                hinge = (estimated_pos_dist - estimated_neg_dist
                         - (actual_pos_dist - actual_neg_dist)).clamp(min=0)
                loss = (torch.abs(estimated_neg_dist - actual_neg_dist)
                        + torch.abs(estimated_pos_dist - actual_pos_dist)
                        + torch.abs(estimated_pos_neg_dist - actual_pos_neg_dist)
                        + hinge).sum()
                loss.backward()
                optimizer.step()

                inner_bar.update(3 * batch_size)
                loss_scalar = loss.item()
                total_loss += loss_scalar
                triplets += batch_size
                inner_bar.set_description("loss %.2f" % loss_scalar)
            # Average over triplets actually trained on (drop_last may skip a remainder).
            ex.log_scalar("avg_loss", total_loss / triplets)
            outer_bar.set_description("Epochs")
            outer_bar.update(1)
    finally:
        inner_bar.close()
        outer_bar.close()


def evaluate_monte_carlo(model, repeats, data, batch_size, in_alphabet, max_len):
    """Estimate the mean absolute error between embedding and phoneme edit distances.

    For each of ``repeats`` shuffled passes over ``data``, batches of
    ``2 * batch_size`` samples are split into two halves that are compared
    element-wise. The average error per pair is printed, written to
    ``results.txt`` (overwriting it) and logged to sacred as ``avg_estim_error``.

    Raises:
        ValueError: if ``repeats < 1`` or ``data`` cannot fill one batch, since
            no pair would be evaluated.
    """
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")
    if len(data) < 2 * batch_size:
        raise ValueError(
            f"need at least {2 * batch_size} samples for one evaluation batch, got {len(data)}")

    # shuffle=True reshuffles on every iteration, so one loader serves all repeats.
    data_loader = DataLoader(dataset=data, drop_last=True, batch_size=2 * batch_size,
                             collate_fn=_make_collate(in_alphabet, max_len), shuffle=True)
    total_error = 0.0
    pairs = 0
    outer_bar = tqdm(total=repeats, position=0)
    inner_bar = tqdm(total=len(data), position=1)
    outer_bar.set_description("Epochs")
    try:
        with torch.no_grad():
            for _ in range(repeats):
                inner_bar.reset()
                for batch_text, batch_phonemes in data_loader:
                    positive, negative = batch_text.to(device).split(batch_size)
                    ph_positive = batch_phonemes[:batch_size]
                    ph_negative = batch_phonemes[batch_size:]
                    embedded_positive = model(positive)
                    embedded_negative = model(negative)
                    estimated_dist = torch.linalg.norm(embedded_negative - embedded_positive, dim=1)
                    actual_dist = dist(ph_negative, ph_positive)
                    total_error += torch.abs(estimated_dist - actual_dist).sum().item()
                    pairs += batch_size
                    inner_bar.update(2 * batch_size)
                outer_bar.update(1)
    finally:
        inner_bar.close()
        outer_bar.close()

    avg_error = total_error / pairs
    message = "Average estimation error " + str(avg_error)
    print(message)
    with open('results.txt', 'w+', encoding='utf-8') as r:
        r.write(message + "\n")
    ex.log_scalar("avg_estim_error", avg_error)


# Sacred config scope: every local variable below becomes a config entry, so
# only plain assignments belong here.
@ex.config
def cfg():
    kernel_size = 3
    hidden_layers = 14
    data_file = 'preprocessed.tsv'  # tab-separated: spelling, phonemes
    epochs = 14
    mode = 'train'  # 'train' or 'eval'
    teacher_forcing_probability = 0.4  # not used by this script
    learning_rate = 0.01
    batch_size = 512
    max_len = 32
    total_out_len = 0  # not used by this script
    model_file = 'cnn.pth'
    out_lookup = ['', 'b', 'a', 'ʊ', 't', 'k', 'ə', 'z', 'ɔ', 'ɹ', 's', 'j', 'u', 'm', 'f', 'ɪ', 'o', 'ɡ', 'ɛ', 'n', 'e', 'd', 'ɫ', 'w', 'i', 'p', 'ɑ', 'ɝ', 'θ', 'v', 'h', 'æ', 'ŋ', 'ʃ', 'ʒ', 'ð']
    in_lookup = ['', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


def _load_data(data_file, in_alphabet, max_len) -> List[Tuple[torch.Tensor, str]]:
    """Read ``(letter-index tensor, phoneme string)`` pairs from a UTF-8 TSV file.

    Blank lines are skipped and line terminators are stripped, so no phoneme
    string carries a trailing newline into the edit distance.

    Raises:
        ValueError: on a line without exactly one tab, or a spelling longer
            than ``max_len``.
        KeyError: on a letter missing from ``in_alphabet``.
    """
    data: List[Tuple[torch.Tensor, str]] = []
    with open(data_file, encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            line = line.rstrip('\r\n')
            if not line:
                continue
            text, phonemes = line.split("\t")
            if len(text) > max_len:
                raise ValueError(f"{data_file}:{line_no}: {text!r} is longer than max_len={max_len}")
            indices = torch.tensor([in_alphabet[letter] for letter in text], dtype=torch.int)
            data.append((indices, phonemes))
    return data


# automain runs the experiment as soon as it is applied when this file is
# executed as a script, so it must stay the last definition in the module.
@ex.automain
def run(kernel_size, hidden_layers, data_file, epochs, teacher_forcing_probability, learning_rate,
        batch_size, max_len, total_out_len, model_file, out_lookup, in_lookup, mode):
    """Sacred entry point: load data, load or train the model, and evaluate in 'eval' mode.

    If ``model_file`` exists it is loaded and training is skipped, even when
    ``mode == 'train'``. Otherwise the model is trained in 'train' mode, saved
    and attached as a sacred artifact; any other mode exits with status 2.
    """
    in_alphabet = {letter: idx for idx, letter in enumerate(in_lookup)}
    data = _load_data(data_file, in_alphabet, max_len)

    cnn = CNN(kernel_size=kernel_size, hidden_layers=hidden_layers, channels=max_len,
              embedding_size=max_len, in_alphabet=in_alphabet, max_len=max_len).to(device)
    if os.path.isfile(model_file):
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        cnn.load_state_dict(torch.load(model_file, map_location=torch.device('cpu')))
    elif mode == 'train':
        train_model(cnn, learning_rate, in_alphabet, max_len, data, epochs, batch_size)
        torch.save(cnn.state_dict(), model_file)
        ex.add_artifact(model_file)
    else:
        print(model_file + " missing!")
        sys.exit(2)

    if mode == 'eval':
        cnn.eval()
        evaluate_monte_carlo(cnn, 1, data, batch_size, in_alphabet, max_len)