From 48d472eb45e7a5c7280dece17c6af6195fbc8047 Mon Sep 17 00:00:00 2001 From: Maciej Sobkowiak Date: Tue, 22 Jun 2021 03:40:29 +0200 Subject: [PATCH] working on eval --- seq.py | 81 ++++++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 53 insertions(+), 28 deletions(-) diff --git a/seq.py b/seq.py index f47633e..7427361 100644 --- a/seq.py +++ b/seq.py @@ -6,12 +6,13 @@ import torch import pandas as pd from sklearn.model_selection import train_test_split from collections import Counter -from torchtext.vocab import Vocab +from torchtext.vocab import vocab from TorchCRF import CRF from tqdm import tqdm -EPOCHS = 5 +EPOCHS = 1 BATCH = 1 +SEQ_LEN = 5 # Functions from jupyter @@ -20,11 +21,13 @@ def build_vocab(dataset): counter = Counter() for document in dataset: counter.update(document) - return Vocab(counter) + v = vocab(counter) + v.set_default_index(0) + return v def data_process(dt, vocab): - return [torch.tensor([vocab['']] + [vocab[token] for token in document] + [vocab['']], dtype=torch.long) for document in dt] + return [torch.tensor([vocab[token] for token in document], dtype=torch.long) for document in dt] def get_scores(y_true, y_pred): @@ -33,17 +36,13 @@ def get_scores(y_true, y_pred): fp = 0 selected_items = 0 relevant_items = 0 - for p, t in zip(y_pred, y_true): if p == t: acc_score += 1 - if p > 0 and p == t: tp += 1 - if p > 0: selected_items += 1 - if t > 0: relevant_items += 1 @@ -65,26 +64,11 @@ def get_scores(y_true, y_pred): return precision, recall, f1 -def eval_model(dataset_tokens, dataset_labels, model): - Y_true = [] - Y_pred = [] - for i in tqdm(range(len(dataset_labels))): - batch_tokens = dataset_tokens[i].unsqueeze(0) - tags = list(dataset_labels[i].numpy()) - Y_true += tags - - Y_batch_pred_weights = model(batch_tokens).squeeze(0) - Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1) - Y_pred += list(Y_batch_pred.numpy()) - - return get_scores(Y_true, Y_pred) - - class GRU(torch.nn.Module): def __init__(self, vocab_len): super(GRU, self).__init__() self.emb = torch.nn.Embedding(vocab_len, 100) - self.rec = torch.nn.GRU(100, 256, 2, batch_first=True, dropout=0.2) + self.rec = torch.nn.GRU(100, 256, 1, batch_first=True, dropout=0.2) self.fc1 = torch.nn.Linear(256, 9) def forward(self, x): @@ -131,6 +115,42 @@ def train(model, crf, train_tokens, labels_tokens): optimizer.step() +def data_translate(dt, vocab): + return [[vocab.itos[token] for token in document] for document in dt] + + +def dev_eval(model, crf, dev_tokens, dev_labels_tokens, vocab): + Y_true = [] + Y_pred = [] + model.eval() + crf.eval() + for i in tqdm(range(len(dev_labels_tokens))): + batch_tokens = dev_tokens[i].unsqueeze(0) + tags = list(dev_labels_tokens[i].numpy()) + Y_true += tags + + Y_batch_pred_weights = model(batch_tokens).squeeze(0) + Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1) + # Y_pred += list(Y_batch_pred.numpy()) + Y_pred += [crf.decode(Y_batch_pred)[0]] + + # print(Y_pred) + # Y_pred_translated = data_translate(Y_pred, vocab) + # with open('dev-0/out.tsv', "w+") as file: + # temp_str = "" + # for i in Y_pred_translated: + # for j in i: + # temp_str += str(j) + # temp_str += " " + # temp_str = temp_str[:-1] + # temp_str += "\n" + # temp_str = temp_str[:-1] + # file.write(temp_str) + + precision, recall, f1 = get_scores(Y_true, Y_pred) + print(f'precision: {0}, recall: {1}, f1: {2}', precision, recall, f1) + + if __name__ == "__main__": X_train, Y_train, X_dev, Y_dev, X_test = load_data() vocab_x = build_vocab(X_train) @@ -138,11 +158,16 @@ if __name__ == "__main__": train_tokens = data_process(X_train, vocab_x) labels_tokens = data_process(Y_train, vocab_y) - # model - model = GRU(len(vocab_x)) - print(model) + # train + print(len(vocab_x.get_itos())) + model = GRU(len(vocab_x.get_itos())) crf = CRF(9) p = list(model.parameters()) + list(crf.parameters()) optimizer = torch.optim.Adam(p) - # mask = torch.ByteTensor([1, 1]) # (batch_size. sequence_size) + # # mask = torch.ByteTensor([1, 1]) # (batch_size. sequence_size) train(model, crf, train_tokens, labels_tokens) + + # eval dev + dev_tokens = data_process(X_dev, vocab_x) + dev_labels_tokens = data_process(Y_dev, vocab_y) + dev_eval(model, crf, dev_tokens, dev_labels_tokens, vocab_x)