en-ner-conll-2003/seq.py

174 lines
4.8 KiB
Python
Raw Normal View History

2021-06-21 21:56:24 +02:00
from numpy.lib.shape_base import split
2021-06-21 21:10:27 +02:00
import pandas as pd
import numpy as np
import gensim
import torch
import pandas as pd
from sklearn.model_selection import train_test_split
2021-06-21 21:56:24 +02:00
from collections import Counter
2021-06-22 03:40:29 +02:00
from torchtext.vocab import vocab
2021-06-22 01:19:45 +02:00
from TorchCRF import CRF
from tqdm import tqdm
2021-06-21 21:56:24 +02:00
2021-06-22 03:40:29 +02:00
# Training configuration.
EPOCHS: int = 1  # number of full passes train() makes over the training set
# NOTE(review): BATCH and SEQ_LEN are not referenced by the code in this file;
# train() and dev_eval() always feed one sentence at a time.
BATCH: int = 1
SEQ_LEN: int = 5
2021-06-21 21:56:24 +02:00
# Helper functions carried over from the Jupyter notebook
2021-06-22 01:19:45 +02:00
2021-06-21 21:56:24 +02:00
def build_vocab(dataset, unk_token=None):
    """Build a torchtext ``Vocab`` from tokenised documents.

    Args:
        dataset: iterable of documents, each an iterable of token strings.
        unk_token: optional dedicated "unknown" entry. When given, it is
            inserted at index 0 (unless already present) and used as the
            default index for out-of-vocabulary lookups.

    Returns:
        A torchtext ``Vocab`` containing every token seen in ``dataset``.
        Unknown lookups return index 0 (the original behaviour) when
        ``unk_token`` is None, or the index of ``unk_token`` otherwise.
    """
    counter = Counter()
    for document in dataset:
        counter.update(document)
    v = vocab(counter)
    if unk_token is None:
        # Original behaviour, kept as the default so existing callers (and the
        # label vocabulary, whose size must match the model's tag count) are
        # unaffected. Caveat: index 0 is a *real* token (the first one counted),
        # so every OOV lookup silently aliases to it.
        v.set_default_index(0)
    else:
        if unk_token not in v:
            v.insert_token(unk_token, 0)
        v.set_default_index(v[unk_token])
    return v
2021-06-21 21:56:24 +02:00
def data_process(dt, vocab):
    """Convert tokenised documents into index tensors.

    Args:
        dt: iterable of documents, each an iterable of tokens.
        vocab: any object supporting ``vocab[token] -> int``.

    Returns:
        A list with one 1-D ``torch.long`` tensor per document.
    """
    tensors = []
    for tokens in dt:
        indices = [vocab[tok] for tok in tokens]
        tensors.append(torch.tensor(indices, dtype=torch.long))
    return tensors
2021-06-21 21:56:24 +02:00
2021-06-22 01:19:45 +02:00
def get_scores(y_true, y_pred, outside_label=0):
    """Compute token-level precision, recall and F1 over entity tags.

    A token counts as an entity when its label differs from ``outside_label``.
    Labels are assumed to be non-negative vocabulary indices. By default the
    outside ("O") tag is taken to be index 0; pass the real index if the label
    vocabulary puts it elsewhere.

    Args:
        y_true: gold label per token.
        y_pred: predicted label per token, aligned with ``y_true``.
        outside_label: label value treated as "no entity".

    Returns:
        ``(precision, recall, f1)``. Precision (recall) is 1.0 when nothing
        was predicted (nothing was relevant), and F1 is 0.0 when both are 0.

    Raises:
        ValueError: if the two sequences differ in length. Previously ``zip``
            silently truncated the longer one, producing misleading scores.
    """
    y_true = list(y_true)
    y_pred = list(y_pred)
    if len(y_true) != len(y_pred):
        raise ValueError(
            f"y_true and y_pred differ in length: {len(y_true)} != {len(y_pred)}"
        )

    tp = sum(1 for p, t in zip(y_pred, y_true) if p != outside_label and p == t)
    selected_items = sum(1 for p in y_pred if p != outside_label)
    relevant_items = sum(1 for t in y_true if t != outside_label)

    precision = 1.0 if selected_items == 0 else tp / selected_items
    recall = 1.0 if relevant_items == 0 else tp / relevant_items
    if precision + recall == 0.0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1
2021-06-22 02:24:08 +02:00
class GRU(torch.nn.Module):
    """Embedding -> GRU -> linear layer producing per-token tag scores.

    Input ``x`` is a ``(batch, seq_len)`` tensor of token ids (the recurrent
    layer is built with ``batch_first=True``). Output is a
    ``(batch, seq_len, num_tags)`` tensor of unnormalised scores.

    All size arguments default to the previously hard-coded values, so
    ``GRU(vocab_len)`` builds the same architecture (and the same parameter
    names ``emb`` / ``rec`` / ``fc1``) as before.
    """

    def __init__(self, vocab_len, emb_dim=100, hidden_dim=256, num_tags=9,
                 num_layers=1, dropout=0.2):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_len, emb_dim)
        # torch.nn.GRU applies dropout only *between* stacked layers. With a
        # single layer, the original dropout=0.2 had no effect and only
        # triggered a UserWarning, so it is forwarded only when there is
        # something to apply it to.
        self.rec = torch.nn.GRU(
            emb_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.fc1 = torch.nn.Linear(hidden_dim, num_tags)

    def forward(self, x):
        # NOTE(review): ReLU on the embedding output zeroes every negative
        # component. Kept to preserve behaviour, but it is likely unintended.
        emb = torch.relu(self.emb(x))
        gru_output, _ = self.rec(emb)  # final hidden state is not needed
        return self.fc1(gru_output)
2021-06-21 21:10:27 +02:00
# Load data
2021-06-22 01:19:45 +02:00
2021-06-21 21:56:24 +02:00
def _read_tsv(path, names):
    """Read a headerless, tab-separated file into a DataFrame of raw strings.

    ``quoting=csv.QUOTE_NONE`` treats ``"`` as an ordinary character. With
    pandas' default quoting, a field that starts with a double quote (common in
    tokenised news text) opens a quoted field that swallows subsequent
    tabs/newlines. ``dtype=str`` with ``na_filter=False`` keeps tokens such as
    ``NA`` or an all-numeric column as strings, so ``str.split`` always works.
    """
    import csv

    return pd.read_csv(path, sep='\t', names=names, quoting=csv.QUOTE_NONE,
                       dtype=str, na_filter=False)


def load_data(data_dir='.'):
    """Load the train, dev-0 and test-A splits.

    Args:
        data_dir: directory containing ``train/``, ``dev-0/`` and ``test-A/``.
            Defaults to the current working directory, as before.

    Returns:
        ``(X_train, Y_train, X_dev, Y_dev, X_test)``. The train/dev entries are
        lists of token lists, split on single spaces. ``X_test`` is the raw
        (unsplit) document column as a NumPy array.
    """
    import os

    train = _read_tsv(os.path.join(data_dir, 'train', 'train.tsv'),
                      ['labels', 'document'])
    dev = _read_tsv(os.path.join(data_dir, 'dev-0', 'in.tsv'), ['document'])
    exp = _read_tsv(os.path.join(data_dir, 'dev-0', 'expected.tsv'), ['labels'])
    test = _read_tsv(os.path.join(data_dir, 'test-A', 'in.tsv'), ['document'])

    Y_train = [y.split(sep=" ") for y in train['labels'].values]
    X_train = [x.split(sep=" ") for x in train['document'].values]
    X_dev = [x.split(sep=" ") for x in dev['document'].values]
    Y_dev = [y.split(sep=" ") for y in exp['labels'].values]
    # NOTE(review): unlike the other splits, X_test is not tokenised. This is
    # kept unchanged to preserve the return contract.
    X_test = test['document'].values
    return X_train, Y_train, X_dev, Y_dev, X_test
2021-06-21 21:10:27 +02:00
2021-06-22 01:19:45 +02:00
def train(model, crf, train_tokens, labels_tokens, optimizer=None, epochs=EPOCHS):
    """Train ``model`` and ``crf`` jointly, one sentence per step.

    Args:
        model: module mapping a ``(1, seq_len)`` id tensor to
            ``(1, seq_len, num_tags)`` scores (see ``GRU``).
        crf: CRF layer called as ``crf(emissions, tags)``; its negation is
            used as the loss. Assumes a log-likelihood return and a
            ``(seq_len, batch, num_tags)`` emission layout, which is what this
            code builds. TODO: confirm against the installed TorchCRF version.
        train_tokens: list of 1-D token-id tensors.
        labels_tokens: list of 1-D tag-id tensors aligned with ``train_tokens``.
        optimizer: optimizer to step. When omitted, a fresh Adam over the
            model + CRF parameters is created (the same construction the
            script entry point uses) instead of silently depending on a
            module-level ``optimizer`` global.
        epochs: number of passes over the data (defaults to ``EPOCHS``).
    """
    if optimizer is None:
        optimizer = torch.optim.Adam(
            list(model.parameters()) + list(crf.parameters())
        )
    for _ in range(epochs):
        crf.train()
        model.train()
        for tokens, gold in tqdm(zip(train_tokens, labels_tokens),
                                 total=len(labels_tokens)):
            # (seq,) -> (1, seq) -> model -> (1, seq, tags) -> (seq, 1, tags)
            emissions = model(tokens.unsqueeze(0)).squeeze(0).unsqueeze(1)
            tags = gold.unsqueeze(1)  # (seq,) -> (seq, 1)
            optimizer.zero_grad()
            loss = -crf(emissions, tags)
            loss.backward()
            optimizer.step()
2021-06-22 03:40:29 +02:00
def data_translate(dt, vocab):
    """Map sequences of vocabulary indices back to token strings.

    Args:
        dt: iterable of documents, each an iterable of integer indices
            (plain ints or single-element tensors).
        vocab: a torchtext ``Vocab`` (index->token list via ``get_itos()``) or
            a legacy vocab object exposing an ``itos`` list.

    Returns:
        A list of token lists, one per document.
    """
    # The Vocab produced by torchtext.vocab.vocab() has no ``itos`` attribute,
    # so the previous ``vocab.itos`` lookup raised AttributeError. The list is
    # fetched once rather than once per token.
    itos = vocab.get_itos() if hasattr(vocab, "get_itos") else vocab.itos
    return [[itos[int(token)] for token in document] for document in dt]
def dev_eval(model, crf, dev_tokens, dev_labels_tokens, vocab):
    """Decode the dev set with the CRF and print token-level P/R/F1.

    Args:
        model: emission model (see ``GRU``).
        crf: CRF layer exposing ``decode(emissions)``, which returns one list of
            tag ids per batch item. This assumes the same
            ``(seq_len, batch, num_tags)`` layout that ``train`` feeds to
            ``crf(...)``. TODO: confirm against the installed TorchCRF version.
        dev_tokens: list of 1-D token-id tensors.
        dev_labels_tokens: list of 1-D gold tag-id tensors.
        vocab: currently unused; accepted for compatibility with existing
            callers.

    Returns:
        ``(precision, recall, f1)`` as computed by ``get_scores``.
        Previously the function returned None.
    """
    y_true = []
    y_pred = []
    model.eval()
    crf.eval()
    with torch.no_grad():
        for tokens, gold in tqdm(zip(dev_tokens, dev_labels_tokens),
                                 total=len(dev_labels_tokens)):
            y_true.extend(gold.tolist())
            # The CRF must see the full emission scores, not argmax tags
            # (the old code passed a 1-D argmax tensor). Use the same
            # (seq, 1, tags) layout as in training.
            emissions = model(tokens.unsqueeze(0)).squeeze(0).unsqueeze(1)
            # extend (not append) so predictions stay aligned token-by-token
            # with y_true.
            y_pred.extend(crf.decode(emissions)[0])
    # TODO: writing dev-0/out.tsv (e.g. via data_translate) was only ever
    # present as commented-out code and is not implemented here.
    precision, recall, f1 = get_scores(y_true, y_pred)
    print(f'precision: {precision}, recall: {recall}, f1: {f1}')
    return precision, recall, f1
2021-06-21 21:56:24 +02:00
if __name__ == "__main__":
    X_train, Y_train, X_dev, Y_dev, X_test = load_data()

    # Token and label vocabularies are built from the training split only.
    vocab_x = build_vocab(X_train)
    vocab_y = build_vocab(Y_train)
    train_tokens = data_process(X_train, vocab_x)
    labels_tokens = data_process(Y_train, vocab_y)

    # GRU(vocab_len) scores a fixed 9 tags and the CRF below is sized to match.
    # A label vocabulary with more entries would yield tag ids neither can
    # score, so fail early with a clear message.
    num_tags = 9
    if len(vocab_y) > num_tags:
        raise ValueError(
            f"label vocabulary has {len(vocab_y)} entries but the model "
            f"only scores {num_tags} tags"
        )

    # train
    print(len(vocab_x))
    model = GRU(len(vocab_x))
    crf = CRF(num_tags)
    params = list(model.parameters()) + list(crf.parameters())
    # Kept at module scope in case train() reads `optimizer` as a global.
    optimizer = torch.optim.Adam(params)
    train(model, crf, train_tokens, labels_tokens)

    # eval dev (unseen tokens/labels map to the vocab's default index)
    dev_tokens = data_process(X_dev, vocab_x)
    dev_labels_tokens = data_process(Y_dev, vocab_y)
    dev_eval(model, crf, dev_tokens, dev_labels_tokens, vocab_x)