From adc78b066c0323b819679061ffdd05062789c74a Mon Sep 17 00:00:00 2001
From: Maciej Sobkowiak
Date: Tue, 22 Jun 2021 01:19:45 +0200
Subject: [PATCH] working on training

---
 seq.py | 107 +++++++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 104 insertions(+), 3 deletions(-)

diff --git a/seq.py b/seq.py
index a7f447b..847ea74 100644
--- a/seq.py
+++ b/seq.py
@@ -7,25 +7,96 @@ import pandas as pd
 from sklearn.model_selection import train_test_split
 from collections import Counter
 from torchtext.vocab import Vocab
+from TorchCRF import CRF
+from tqdm import tqdm
 
+EPOCHS = 5
+BATCH = 1
 
 # Functions from jupyter
+
+
 def build_vocab(dataset):
     counter = Counter()
     for document in dataset:
         counter.update(document)
-    return Vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
+    return Vocab(counter)
 
 
 def data_process(dt, vocab):
     return [torch.tensor([vocab['<bos>']] + [vocab[token] for token in document]
                          + [vocab['<eos>']], dtype=torch.long) for document in dt]
 
 
-def labels_process(dt, vocab):
-    return [torch.tensor([0] + document + [0], dtype=torch.long) for document in dt]
+def get_scores(y_true, y_pred):
+    acc_score = 0
+    tp = 0
+    fp = 0
+    selected_items = 0
+    relevant_items = 0
+    for p, t in zip(y_pred, y_true):
+        if p == t:
+            acc_score += 1
+
+        if p > 0 and p == t:
+            tp += 1
+
+        if p > 0:
+            selected_items += 1
+
+        if t > 0:
+            relevant_items += 1
+
+    if selected_items == 0:
+        precision = 1.0
+    else:
+        precision = tp / selected_items
+
+    if relevant_items == 0:
+        recall = 1.0
+    else:
+        recall = tp / relevant_items
+
+    if precision + recall == 0.0:
+        f1 = 0.0
+    else:
+        f1 = 2 * precision * recall / (precision + recall)
+
+    return precision, recall, f1
+
+
+def eval_model(dataset_tokens, dataset_labels, model):
+    Y_true = []
+    Y_pred = []
+    for i in tqdm(range(len(dataset_labels))):
+        batch_tokens = dataset_tokens[i].unsqueeze(0)
+        tags = list(dataset_labels[i].numpy())
+        Y_true += tags
+
+        Y_batch_pred_weights = model(batch_tokens).squeeze(0)
+        Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)
+        Y_pred += list(Y_batch_pred.numpy())
+
+    return get_scores(Y_true, Y_pred)
+
+
+class LSTM(torch.nn.Module):
+    def __init__(self, vocab_len):
+        super(LSTM, self).__init__()
+        self.emb = torch.nn.Embedding(vocab_len, 100)
+        self.rec = torch.nn.LSTM(100, 256, 1, batch_first=True)
+        self.fc1 = torch.nn.Linear(256, 9)
+
+    def forward(self, x):
+        emb = torch.relu(self.emb(x))
+        lstm_output, (h_n, c_n) = self.rec(emb)
+        out_weights = self.fc1(lstm_output)
+
+        return out_weights
 
 
 # Load data
+
+
 def load_data():
     train = pd.read_csv('train/train.tsv', sep='\t',
                         names=['labels', 'document'])
@@ -44,5 +115,35 @@ def load_data():
     return X_train, Y_train, X_dev, Y_dev, X_test
 
 
+def train(model, crf, train_tokens, labels_tokens):
+    for i in range(EPOCHS):
+        crf.train()
+        model.train()
+        for i in tqdm(range(len(labels_tokens))):
+            batch_tokens = train_tokens[i].unsqueeze(0)
+            tags = labels_tokens[i].unsqueeze(1)
+
+            predicted_tags = model(batch_tokens).squeeze(0).unsqueeze(1)
+
+            optimizer.zero_grad()
+            loss = criterion(predicted_tags.squeeze(0), tags.squeeze(1))
+
+            loss.backward()
+            optimizer.step()
+
+
 if __name__ == "__main__":
     X_train, Y_train, X_dev, Y_dev, X_test = load_data()
+    vocab_x = build_vocab(X_train)
+    vocab_y = build_vocab(Y_train)
+    train_tokens = data_process(X_train, vocab_x)
+    labels_tokens = data_process(Y_train, vocab_y)
+    print(train_tokens[0])
+
+    # model
+    model = LSTM(len(vocab_x))
+    crf = CRF(9)
+    p = list(model.parameters()) + list(crf.parameters())
+    optimizer = torch.optim.Adam(p)
+    criterion = torch.nn.CrossEntropyLoss()
+    train(model, crf, train_tokens, labels_tokens)
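
Notes on the patch (not part of the diff):

1. `data_process(Y_train, vocab_y)` routes the labels through a freshly built
torchtext Vocab, so the resulting ids depend on token frequency and can exceed
the 9 classes that `Linear(256, 9)`, `CrossEntropyLoss`, and `CRF(9)` expect.
Below is a minimal sketch of a direct mapping, in the spirit of the
`labels_process` this patch removes; the LABELS list is a hypothetical
stand-in for the dataset's actual tag set:

LABELS = ['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC',
          'B-ORG', 'I-ORG', 'B-MISC', 'I-MISC']  # hypothetical tag inventory
LABEL_TO_IDX = {label: idx for idx, label in enumerate(LABELS)}

def labels_process(dt):
    # 0 ('O') doubles as the padding label on both ends, as in the
    # function this patch removes.
    return [torch.tensor([0] + [LABEL_TO_IDX[t] for t in document] + [0],
                         dtype=torch.long) for document in dt]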
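
2. `train()` adds the CRF's parameters to the optimizer but still computes a
plain cross-entropy over the emission scores, so the CRF is never exercised
(as written, the loop also hands `CrossEntropyLoss` a (seq_len, 1, 9) input
against a (seq_len,) target). A hedged sketch of wiring the CRF in, assuming
the TorchCRF API where `crf.forward(emissions, tags, mask)` returns the
per-sequence log-likelihood (worth checking against the installed version):

def train_crf(model, crf, train_tokens, labels_tokens, optimizer):
    model.train()
    crf.train()
    for epoch in range(EPOCHS):
        for i in tqdm(range(len(labels_tokens))):
            batch_tokens = train_tokens[i].unsqueeze(0)    # (1, seq_len)
            tags = labels_tokens[i].unsqueeze(0)           # (1, seq_len)
            mask = torch.ones_like(batch_tokens, dtype=torch.bool)

            emissions = model(batch_tokens)                # (1, seq_len, 9)

            optimizer.zero_grad()
            # negative log-likelihood of the gold tag sequence
            loss = -crf.forward(emissions, tags, mask).mean()
            loss.backward()
            optimizer.step()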
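
3. Likewise, at evaluation time the CRF's Viterbi path would replace the
per-token argmax in `eval_model` (same caveat about the TorchCRF API, here
its `viterbi_decode(emissions, mask)`, which returns one best tag sequence
per batch element):

def predict_crf(model, crf, tokens):
    emissions = model(tokens.unsqueeze(0))                 # (1, seq_len, 9)
    mask = torch.ones(emissions.shape[:2], dtype=torch.bool)
    return crf.viterbi_decode(emissions, mask)[0]          # best tag path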