forked from kubapok/en-ner-conll-2003
working on eval
This commit is contained in:
parent
81fa0ec07f
commit
48d472eb45
81
seq.py
81
seq.py
@ -6,12 +6,13 @@ import torch
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from torchtext.vocab import Vocab
|
from torchtext.vocab import vocab
|
||||||
from TorchCRF import CRF
|
from TorchCRF import CRF
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
EPOCHS = 5
|
EPOCHS = 1
|
||||||
BATCH = 1
|
BATCH = 1
|
||||||
|
SEQ_LEN = 5
|
||||||
|
|
||||||
# Functions from jupyter
|
# Functions from jupyter
|
||||||
|
|
||||||
@ -20,11 +21,13 @@ def build_vocab(dataset):
|
|||||||
counter = Counter()
|
counter = Counter()
|
||||||
for document in dataset:
|
for document in dataset:
|
||||||
counter.update(document)
|
counter.update(document)
|
||||||
return Vocab(counter)
|
v = vocab(counter)
|
||||||
|
v.set_default_index(0)
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
def data_process(dt, vocab):
|
def data_process(dt, vocab):
|
||||||
return [torch.tensor([vocab['<bos>']] + [vocab[token] for token in document] + [vocab['<eos>']], dtype=torch.long) for document in dt]
|
return [torch.tensor([vocab[token] for token in document], dtype=torch.long) for document in dt]
|
||||||
|
|
||||||
|
|
||||||
def get_scores(y_true, y_pred):
|
def get_scores(y_true, y_pred):
|
||||||
@ -33,17 +36,13 @@ def get_scores(y_true, y_pred):
|
|||||||
fp = 0
|
fp = 0
|
||||||
selected_items = 0
|
selected_items = 0
|
||||||
relevant_items = 0
|
relevant_items = 0
|
||||||
|
|
||||||
for p, t in zip(y_pred, y_true):
|
for p, t in zip(y_pred, y_true):
|
||||||
if p == t:
|
if p == t:
|
||||||
acc_score += 1
|
acc_score += 1
|
||||||
|
|
||||||
if p > 0 and p == t:
|
if p > 0 and p == t:
|
||||||
tp += 1
|
tp += 1
|
||||||
|
|
||||||
if p > 0:
|
if p > 0:
|
||||||
selected_items += 1
|
selected_items += 1
|
||||||
|
|
||||||
if t > 0:
|
if t > 0:
|
||||||
relevant_items += 1
|
relevant_items += 1
|
||||||
|
|
||||||
@ -65,26 +64,11 @@ def get_scores(y_true, y_pred):
|
|||||||
return precision, recall, f1
|
return precision, recall, f1
|
||||||
|
|
||||||
|
|
||||||
def eval_model(dataset_tokens, dataset_labels, model):
|
|
||||||
Y_true = []
|
|
||||||
Y_pred = []
|
|
||||||
for i in tqdm(range(len(dataset_labels))):
|
|
||||||
batch_tokens = dataset_tokens[i].unsqueeze(0)
|
|
||||||
tags = list(dataset_labels[i].numpy())
|
|
||||||
Y_true += tags
|
|
||||||
|
|
||||||
Y_batch_pred_weights = model(batch_tokens).squeeze(0)
|
|
||||||
Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)
|
|
||||||
Y_pred += list(Y_batch_pred.numpy())
|
|
||||||
|
|
||||||
return get_scores(Y_true, Y_pred)
|
|
||||||
|
|
||||||
|
|
||||||
class GRU(torch.nn.Module):
|
class GRU(torch.nn.Module):
|
||||||
def __init__(self, vocab_len):
|
def __init__(self, vocab_len):
|
||||||
super(GRU, self).__init__()
|
super(GRU, self).__init__()
|
||||||
self.emb = torch.nn.Embedding(vocab_len, 100)
|
self.emb = torch.nn.Embedding(vocab_len, 100)
|
||||||
self.rec = torch.nn.GRU(100, 256, 2, batch_first=True, dropout=0.2)
|
self.rec = torch.nn.GRU(100, 256, 1, batch_first=True, dropout=0.2)
|
||||||
self.fc1 = torch.nn.Linear(256, 9)
|
self.fc1 = torch.nn.Linear(256, 9)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -131,6 +115,42 @@ def train(model, crf, train_tokens, labels_tokens):
|
|||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
|
||||||
|
def data_translate(dt, vocab):
|
||||||
|
return [[vocab.itos[token] for token in document] for document in dt]
|
||||||
|
|
||||||
|
|
||||||
|
def dev_eval(model, crf, dev_tokens, dev_labels_tokens, vocab):
|
||||||
|
Y_true = []
|
||||||
|
Y_pred = []
|
||||||
|
model.eval()
|
||||||
|
crf.eval()
|
||||||
|
for i in tqdm(range(len(dev_labels_tokens))):
|
||||||
|
batch_tokens = dev_tokens[i].unsqueeze(0)
|
||||||
|
tags = list(dev_labels_tokens[i].numpy())
|
||||||
|
Y_true += tags
|
||||||
|
|
||||||
|
Y_batch_pred_weights = model(batch_tokens).squeeze(0)
|
||||||
|
Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)
|
||||||
|
# Y_pred += list(Y_batch_pred.numpy())
|
||||||
|
Y_pred += [crf.decode(Y_batch_pred)[0]]
|
||||||
|
|
||||||
|
# print(Y_pred)
|
||||||
|
# Y_pred_translated = data_translate(Y_pred, vocab)
|
||||||
|
# with open('dev-0/out.tsv', "w+") as file:
|
||||||
|
# temp_str = ""
|
||||||
|
# for i in Y_pred_translated:
|
||||||
|
# for j in i:
|
||||||
|
# temp_str += str(j)
|
||||||
|
# temp_str += " "
|
||||||
|
# temp_str = temp_str[:-1]
|
||||||
|
# temp_str += "\n"
|
||||||
|
# temp_str = temp_str[:-1]
|
||||||
|
# file.write(temp_str)
|
||||||
|
|
||||||
|
precision, recall, f1 = get_scores(Y_true, Y_pred)
|
||||||
|
print(f'precision: {0}, recall: {1}, f1: {2}', precision, recall, f1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
X_train, Y_train, X_dev, Y_dev, X_test = load_data()
|
X_train, Y_train, X_dev, Y_dev, X_test = load_data()
|
||||||
vocab_x = build_vocab(X_train)
|
vocab_x = build_vocab(X_train)
|
||||||
@ -138,11 +158,16 @@ if __name__ == "__main__":
|
|||||||
train_tokens = data_process(X_train, vocab_x)
|
train_tokens = data_process(X_train, vocab_x)
|
||||||
labels_tokens = data_process(Y_train, vocab_y)
|
labels_tokens = data_process(Y_train, vocab_y)
|
||||||
|
|
||||||
# model
|
# train
|
||||||
model = GRU(len(vocab_x))
|
print(len(vocab_x.get_itos()))
|
||||||
print(model)
|
model = GRU(len(vocab_x.get_itos()))
|
||||||
crf = CRF(9)
|
crf = CRF(9)
|
||||||
p = list(model.parameters()) + list(crf.parameters())
|
p = list(model.parameters()) + list(crf.parameters())
|
||||||
optimizer = torch.optim.Adam(p)
|
optimizer = torch.optim.Adam(p)
|
||||||
# mask = torch.ByteTensor([1, 1]) # (batch_size. sequence_size)
|
# # mask = torch.ByteTensor([1, 1]) # (batch_size. sequence_size)
|
||||||
train(model, crf, train_tokens, labels_tokens)
|
train(model, crf, train_tokens, labels_tokens)
|
||||||
|
|
||||||
|
# eval dev
|
||||||
|
dev_tokens = data_process(X_dev, vocab_x)
|
||||||
|
dev_labels_tokens = data_process(Y_dev, vocab_y)
|
||||||
|
dev_eval(model, crf, dev_tokens, dev_labels_tokens, vocab_x)
|
||||||
|
Loading…
Reference in New Issue
Block a user