This commit is contained in:
parent
2b7cfff7f4
commit
bd67201997
@ -8,7 +8,7 @@ pipeline {
|
|||||||
steps {
|
steps {
|
||||||
git 'https://git.wmi.amu.edu.pl/s434749/ium_434749.git'
|
git 'https://git.wmi.amu.edu.pl/s434749/ium_434749.git'
|
||||||
copyArtifacts fingerprintArtifacts: true, projectName: 's434749-training', selector: lastSuccessful()
|
copyArtifacts fingerprintArtifacts: true, projectName: 's434749-training', selector: lastSuccessful()
|
||||||
sh 'python3 train_model.py eval'
|
sh 'python3 train_model.py with "mode=eval"'
|
||||||
script{
|
script{
|
||||||
def results = readFile "${env.WORKSPACE}/results.txt"
|
def results = readFile "${env.WORKSPACE}/results.txt"
|
||||||
}
|
}
|
||||||
@ -17,7 +17,7 @@ pipeline {
|
|||||||
post {
|
post {
|
||||||
success {
|
success {
|
||||||
emailext body: 'Evaluation of CNN for english phonetic embeddings has finished successfully!\n'+results, subject: 's434749 evaluation finished', to: '26ab8f35.uam.onmicrosoft.com@emea.teams.ms'
|
emailext body: 'Evaluation of CNN for english phonetic embeddings has finished successfully!\n'+results, subject: 's434749 evaluation finished', to: '26ab8f35.uam.onmicrosoft.com@emea.teams.ms'
|
||||||
archiveArtifacts 'results.txt'
|
archiveArtifacts 'results.txt, sacred_file_observer'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -8,13 +8,13 @@ pipeline {
|
|||||||
steps {
|
steps {
|
||||||
git 'https://git.wmi.amu.edu.pl/s434749/ium_434749.git'
|
git 'https://git.wmi.amu.edu.pl/s434749/ium_434749.git'
|
||||||
copyArtifacts fingerprintArtifacts: true, projectName: 's434749-create-dataset', selector: lastSuccessful()
|
copyArtifacts fingerprintArtifacts: true, projectName: 's434749-create-dataset', selector: lastSuccessful()
|
||||||
sh 'python3 train_model.py train'
|
sh 'python3 train_model.py'
|
||||||
}
|
}
|
||||||
|
|
||||||
post {
|
post {
|
||||||
success {
|
success {
|
||||||
emailext body: 'Training of CNN for english phonetic embeddings has finished successfully', subject: 's434749 training finished', to: '26ab8f35.uam.onmicrosoft.com@emea.teams.ms'
|
emailext body: 'Training of CNN for english phonetic embeddings has finished successfully', subject: 's434749 training finished', to: '26ab8f35.uam.onmicrosoft.com@emea.teams.ms'
|
||||||
archiveArtifacts 'cnn.pth'
|
archiveArtifacts 'cnn.pth,sacred_file_observer'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
192
train_model.py
192
train_model.py
@ -9,6 +9,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
|
from sacred.observers import FileStorageObserver, MongoObserver
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
import re
|
import re
|
||||||
import random
|
import random
|
||||||
@ -16,60 +17,22 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from Levenshtein import distance as levenshtein_distance
|
from Levenshtein import distance as levenshtein_distance
|
||||||
|
from sacred import Experiment
|
||||||
|
|
||||||
DATA_FILE = 'preprocessed.tsv'
|
ex = Experiment("CNN")
|
||||||
EPOCHS = 14
|
ex.observers.append(FileStorageObserver('sacred_file_observer'))
|
||||||
TEACHER_FORCING_PROBABILITY = 0.4
|
ex.observers.append(MongoObserver(url='mongodb://mongo_user:mongo_password_IUM_2021@localhost:27017',
|
||||||
LEARNING_RATE = 0.01
|
db_name='sacred'))
|
||||||
BATCH_SIZE = 512
|
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
||||||
|
|
||||||
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
|
||||||
|
|
||||||
OUT_LOOKUP = ['', 'b', 'a', 'ʊ', 't', 'k', 'ə', 'z', 'ɔ', 'ɹ', 's', 'j', 'u', 'm', 'f', 'ɪ', 'o', 'ɡ', 'ɛ', 'n',
|
|
||||||
'e', 'd',
|
|
||||||
'ɫ', 'w', 'i', 'p', 'ɑ', 'ɝ', 'θ', 'v', 'h', 'æ', 'ŋ', 'ʃ', 'ʒ', 'ð']
|
|
||||||
|
|
||||||
IN_LOOKUP = ['', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't',
|
|
||||||
'u', 'v', 'w', 'x', 'y', 'z']
|
|
||||||
|
|
||||||
IN_ALPHABET = {letter: idx for idx, letter in enumerate(IN_LOOKUP)}
|
|
||||||
|
|
||||||
OUT_ALPHABET = {letter: idx for idx, letter in enumerate(OUT_LOOKUP)}
|
|
||||||
|
|
||||||
TOTAL_OUT_LEN = 0
|
|
||||||
|
|
||||||
DATA: [(torch.tensor, torch.tensor)] = []
|
|
||||||
|
|
||||||
TEXT: [str] = []
|
|
||||||
|
|
||||||
MAX_LEN = 32
|
|
||||||
|
|
||||||
with open(DATA_FILE) as f:
|
|
||||||
for line in f:
|
|
||||||
text, phonemes = line.split("\t")
|
|
||||||
TEXT.append(text)
|
|
||||||
assert len(text) <= MAX_LEN, text
|
|
||||||
text = torch.tensor([IN_ALPHABET[letter] for letter in text], dtype=torch.int)
|
|
||||||
DATA.append((text, phonemes))
|
|
||||||
|
|
||||||
|
|
||||||
def collate(batch: [(torch.tensor, str)]):
|
|
||||||
batch_text = torch.zeros((len(batch), len(IN_ALPHABET), MAX_LEN))
|
|
||||||
batch_phonemes = list(map(lambda x: x[1], batch))
|
|
||||||
for i, (sample, _) in enumerate(batch):
|
|
||||||
for chr_pos, index in enumerate(sample):
|
|
||||||
batch_text[i, index, chr_pos] = 1
|
|
||||||
return batch_text, batch_phonemes
|
|
||||||
|
|
||||||
|
|
||||||
class CNN(nn.Module):
|
class CNN(nn.Module):
|
||||||
def __init__(self, kernel_size, hidden_layers, channels, embedding_size):
|
def __init__(self, kernel_size, hidden_layers, channels, embedding_size, in_alphabet, max_len):
|
||||||
super(CNN, self).__init__()
|
super(CNN, self).__init__()
|
||||||
self.input_conv = nn.Conv1d(in_channels=len(IN_ALPHABET), out_channels=channels, kernel_size=kernel_size)
|
self.input_conv = nn.Conv1d(in_channels=len(in_alphabet), out_channels=channels, kernel_size=kernel_size)
|
||||||
self.conv_hidden = nn.ModuleList(
|
self.conv_hidden = nn.ModuleList(
|
||||||
[nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=kernel_size) for _ in
|
[nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=kernel_size) for _ in
|
||||||
range(hidden_layers)])
|
range(hidden_layers)])
|
||||||
self.last_layer_size = (MAX_LEN - (kernel_size - 1) * (hidden_layers + 1)) * channels
|
self.last_layer_size = (max_len - (kernel_size - 1) * (hidden_layers + 1)) * channels
|
||||||
self.lin = nn.Linear(self.last_layer_size, embedding_size)
|
self.lin = nn.Linear(self.last_layer_size, embedding_size)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -83,34 +46,40 @@ class CNN(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
outer_bar = tqdm(total=EPOCHS, position=0)
|
|
||||||
inner_bar = tqdm(total=len(DATA), position=1)
|
|
||||||
|
|
||||||
|
|
||||||
def dist(a: [str], b: [str]):
|
def dist(a: [str], b: [str]):
|
||||||
return torch.tensor([levenshtein_distance(a[i], b[i]) for i in range(len(a))], dtype=torch.float, device=DEVICE)
|
return torch.tensor([levenshtein_distance(a[i], b[i]) for i in range(len(a))], dtype=torch.float, device=device)
|
||||||
|
|
||||||
|
|
||||||
def train_model(model):
|
def train_model(model, learning_rate, in_alphabet, max_len, data, epochs, batch_size):
|
||||||
optimizer = optim.Adam(filter(lambda x: x.requires_grad, model.parameters()),
|
optimizer = optim.Adam(filter(lambda x: x.requires_grad, model.parameters()),
|
||||||
lr=LEARNING_RATE)
|
lr=learning_rate)
|
||||||
loss_snapshots = []
|
outer_bar = tqdm(total=epochs, position=0)
|
||||||
|
inner_bar = tqdm(total=len(data), position=1)
|
||||||
outer_bar.reset()
|
outer_bar.reset()
|
||||||
outer_bar.set_description("Epochs")
|
outer_bar.set_description("Epochs")
|
||||||
data_loader = DataLoader(dataset=DATA, drop_last=True,
|
|
||||||
batch_size=3 * BATCH_SIZE,
|
def collate(batch: [(torch.tensor, str)]):
|
||||||
|
batch_text = torch.zeros((len(batch), len(in_alphabet), max_len))
|
||||||
|
batch_phonemes = list(map(lambda x: x[1], batch))
|
||||||
|
for i, (sample, _) in enumerate(batch):
|
||||||
|
for chr_pos, index in enumerate(sample):
|
||||||
|
batch_text[i, index, chr_pos] = 1
|
||||||
|
return batch_text, batch_phonemes
|
||||||
|
|
||||||
|
data_loader = DataLoader(dataset=data, drop_last=True,
|
||||||
|
batch_size=3 * batch_size,
|
||||||
collate_fn=collate,
|
collate_fn=collate,
|
||||||
shuffle=True)
|
shuffle=True)
|
||||||
for epoch in range(EPOCHS):
|
for epoch in range(epochs):
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
inner_bar.reset()
|
inner_bar.reset()
|
||||||
|
|
||||||
for batch_text, batch_phonemes in data_loader:
|
for batch_text, batch_phonemes in data_loader:
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
anchor, positive, negative = batch_text.to(DEVICE).split(BATCH_SIZE)
|
anchor, positive, negative = batch_text.to(device).split(batch_size)
|
||||||
ph_anchor = batch_phonemes[:BATCH_SIZE]
|
ph_anchor = batch_phonemes[:batch_size]
|
||||||
ph_positive = batch_phonemes[BATCH_SIZE:2 * BATCH_SIZE]
|
ph_positive = batch_phonemes[batch_size:2 * batch_size]
|
||||||
ph_negative = batch_phonemes[2 * BATCH_SIZE:]
|
ph_negative = batch_phonemes[2 * batch_size:]
|
||||||
embedded_anchor = model(anchor)
|
embedded_anchor = model(anchor)
|
||||||
embedded_positive = model(positive)
|
embedded_positive = model(positive)
|
||||||
embedded_negative = model(negative)
|
embedded_negative = model(negative)
|
||||||
@ -126,11 +95,11 @@ def train_model(model):
|
|||||||
+ (estimated_pos_dist - estimated_neg_dist - (actual_pos_dist - actual_neg_dist)).clip(min=0))
|
+ (estimated_pos_dist - estimated_neg_dist - (actual_pos_dist - actual_neg_dist)).clip(min=0))
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
inner_bar.update(3 * BATCH_SIZE)
|
inner_bar.update(3 * batch_size)
|
||||||
loss_scalar = loss.item()
|
loss_scalar = loss.item()
|
||||||
total_loss += loss_scalar
|
total_loss += loss_scalar
|
||||||
inner_bar.set_description("loss %.2f" % loss_scalar)
|
inner_bar.set_description("loss %.2f" % loss_scalar)
|
||||||
loss_snapshots.append(total_loss / len(DATA) * 3)
|
ex.log_scalar("avg_loss", total_loss / len(data) * 3)
|
||||||
# print()
|
# print()
|
||||||
# print("Total epoch loss:", total_loss)
|
# print("Total epoch loss:", total_loss)
|
||||||
# print("Total epoch avg loss:", total_loss / TOTAL_TRAINING_OUT_LEN)
|
# print("Total epoch avg loss:", total_loss / TOTAL_TRAINING_OUT_LEN)
|
||||||
@ -142,46 +111,99 @@ def train_model(model):
|
|||||||
outer_bar.update(1)
|
outer_bar.update(1)
|
||||||
|
|
||||||
|
|
||||||
def evaluate_monte_carlo(model, repeats):
|
def evaluate_monte_carlo(model, repeats, data, batch_size, in_alphabet, max_len):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
i = 0
|
i = 0
|
||||||
diff = 0
|
diff = 0
|
||||||
outer_bar.reset(total=repeats)
|
outer_bar = tqdm(total=repeats, position=0)
|
||||||
|
inner_bar = tqdm(total=len(data), position=1)
|
||||||
outer_bar.set_description("Epochs")
|
outer_bar.set_description("Epochs")
|
||||||
|
|
||||||
|
def collate(batch: [(torch.tensor, str)]):
|
||||||
|
batch_text = torch.zeros((len(batch), len(in_alphabet), max_len))
|
||||||
|
batch_phonemes = list(map(lambda x: x[1], batch))
|
||||||
|
for i, (sample, _) in enumerate(batch):
|
||||||
|
for chr_pos, index in enumerate(sample):
|
||||||
|
batch_text[i, index, chr_pos] = 1
|
||||||
|
return batch_text, batch_phonemes
|
||||||
|
|
||||||
for _ in range(repeats):
|
for _ in range(repeats):
|
||||||
data_loader = DataLoader(dataset=DATA, drop_last=True,
|
data_loader = DataLoader(dataset=data, drop_last=True,
|
||||||
batch_size=2 * BATCH_SIZE,
|
batch_size=2 * batch_size,
|
||||||
collate_fn=collate,
|
collate_fn=collate,
|
||||||
shuffle=True)
|
shuffle=True)
|
||||||
inner_bar.reset()
|
inner_bar.reset()
|
||||||
for batch_text, batch_phonemes in data_loader:
|
for batch_text, batch_phonemes in data_loader:
|
||||||
positive, negative = batch_text.to(DEVICE).split(BATCH_SIZE)
|
positive, negative = batch_text.to(device).split(batch_size)
|
||||||
ph_positive = batch_phonemes[0:BATCH_SIZE]
|
ph_positive = batch_phonemes[0:batch_size]
|
||||||
ph_negative = batch_phonemes[BATCH_SIZE:]
|
ph_negative = batch_phonemes[batch_size:]
|
||||||
embedded_positive = model(positive)
|
embedded_positive = model(positive)
|
||||||
embedded_negative = model(negative)
|
embedded_negative = model(negative)
|
||||||
estimated_dist = torch.linalg.norm(embedded_negative - embedded_positive, dim=1)
|
estimated_dist = torch.linalg.norm(embedded_negative - embedded_positive, dim=1)
|
||||||
actual_dist = dist(ph_negative, ph_positive)
|
actual_dist = dist(ph_negative, ph_positive)
|
||||||
diff += sum(abs(estimated_dist - actual_dist))
|
diff += sum(abs(estimated_dist - actual_dist))
|
||||||
i += BATCH_SIZE
|
i += batch_size
|
||||||
inner_bar.update(2 * BATCH_SIZE)
|
inner_bar.update(2 * batch_size)
|
||||||
outer_bar.update(1)
|
outer_bar.update(1)
|
||||||
with open('results.txt', 'w+') as r:
|
with open('results.txt', 'w+') as r:
|
||||||
print("Average estimation error " + str(diff.item() / i))
|
print("Average estimation error " + str(diff.item() / i))
|
||||||
r.write("Average estimation error " + str(diff.item() / i) + "\n")
|
r.write("Average estimation error " + str(diff.item() / i) + "\n")
|
||||||
|
ex.log_scalar("avg_estim_error", diff.item() / i)
|
||||||
|
|
||||||
|
|
||||||
cnn = CNN(kernel_size=3, hidden_layers=14, channels=MAX_LEN, embedding_size=MAX_LEN).to(DEVICE)
|
@ex.config
|
||||||
if os.path.isfile('cnn.pth'):
|
def cfg():
|
||||||
cnn.load_state_dict(torch.load('cnn.pth', map_location=torch.device('cpu')))
|
kernel_size = 3
|
||||||
else:
|
hidden_layers = 14
|
||||||
if len(sys.argv) > 1 and sys.argv[1] == 'train':
|
data_file = 'preprocessed.tsv'
|
||||||
train_model(cnn)
|
epochs = 14
|
||||||
torch.save(cnn.state_dict(), 'cnn.pth')
|
mode = 'train'
|
||||||
|
teacher_forcing_probability = 0.4
|
||||||
|
learning_rate = 0.01
|
||||||
|
batch_size = 512
|
||||||
|
max_len = 32
|
||||||
|
total_out_len = 0
|
||||||
|
model_file = 'cnn.pth'
|
||||||
|
out_lookup = ['', 'b', 'a', 'ʊ', 't', 'k', 'ə', 'z', 'ɔ', 'ɹ', 's', 'j', 'u', 'm', 'f', 'ɪ', 'o', 'ɡ', 'ɛ', 'n',
|
||||||
|
'e', 'd',
|
||||||
|
'ɫ', 'w', 'i', 'p', 'ɑ', 'ɝ', 'θ', 'v', 'h', 'æ', 'ŋ', 'ʃ', 'ʒ', 'ð']
|
||||||
|
in_lookup = ['', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't',
|
||||||
|
'u', 'v', 'w', 'x', 'y', 'z']
|
||||||
|
|
||||||
|
|
||||||
|
@ex.automain
|
||||||
|
def run(kernel_size, hidden_layers, data_file, epochs, teacher_forcing_probability, learning_rate, batch_size, max_len,
|
||||||
|
total_out_len, model_file, out_lookup, in_lookup, mode):
|
||||||
|
in_alphabet = {letter: idx for idx, letter in enumerate(in_lookup)}
|
||||||
|
|
||||||
|
out_alphabet = {letter: idx for idx, letter in enumerate(out_lookup)}
|
||||||
|
|
||||||
|
data: [(torch.tensor, torch.tensor)] = []
|
||||||
|
|
||||||
|
texts: [str] = []
|
||||||
|
|
||||||
|
with open(data_file) as f:
|
||||||
|
for line in f:
|
||||||
|
text, phonemes = line.split("\t")
|
||||||
|
texts.append(text)
|
||||||
|
assert len(text) <= max_len, text
|
||||||
|
text = torch.tensor([in_alphabet[letter] for letter in text], dtype=torch.int)
|
||||||
|
data.append((text, phonemes))
|
||||||
|
|
||||||
|
cnn = CNN(kernel_size=kernel_size, hidden_layers=hidden_layers, channels=max_len, embedding_size=max_len,
|
||||||
|
in_alphabet=in_alphabet, max_len=max_len).to(device)
|
||||||
|
if os.path.isfile(model_file):
|
||||||
|
cnn.load_state_dict(torch.load(model_file, map_location=torch.device('cpu')))
|
||||||
else:
|
else:
|
||||||
print("cnn.pth missing!")
|
if mode == 'train':
|
||||||
exit(2)
|
train_model(cnn, learning_rate, in_alphabet, max_len, data, epochs, batch_size)
|
||||||
|
torch.save(cnn.state_dict(), model_file)
|
||||||
|
ex.add_artifact(model_file)
|
||||||
|
else:
|
||||||
|
print(model_file + " missing!")
|
||||||
|
exit(2)
|
||||||
|
|
||||||
|
if mode == 'eval':
|
||||||
|
cnn.eval()
|
||||||
|
evaluate_monte_carlo(cnn, 1, data, batch_size, in_alphabet, max_len)
|
||||||
|
|
||||||
if len(sys.argv) > 1 and sys.argv[1] == 'eval':
|
|
||||||
cnn.eval()
|
|
||||||
evaluate_monte_carlo(cnn, 1)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user