2021-06-21 21:56:24 +02:00
|
|
|
from numpy.lib.shape_base import split
|
2021-06-21 21:10:27 +02:00
|
|
|
import pandas as pd
|
|
|
|
import numpy as np
|
|
|
|
import gensim
|
|
|
|
import torch
|
|
|
|
import pandas as pd
|
|
|
|
from sklearn.model_selection import train_test_split
|
2021-06-21 21:56:24 +02:00
|
|
|
from collections import Counter
|
2021-06-22 03:40:29 +02:00
|
|
|
from torchtext.vocab import vocab
|
2021-06-22 01:19:45 +02:00
|
|
|
from TorchCRF import CRF
|
|
|
|
from tqdm import tqdm
|
2021-06-21 21:56:24 +02:00
|
|
|
|
2021-06-22 03:40:29 +02:00
|
|
|
# Training hyper-parameters.
# NOTE(review): BATCH and SEQ_LEN are not referenced anywhere else in this file;
# training feeds one sentence per step regardless of BATCH.
EPOCHS, BATCH, SEQ_LEN = 1, 5, 5
|
2021-06-21 21:56:24 +02:00
|
|
|
|
|
|
|
# Functions ported from the Jupyter notebook
|
2021-06-22 01:19:45 +02:00
|
|
|
|
|
|
|
|
2021-06-21 21:56:24 +02:00
|
|
|
def build_vocab(dataset):
    """Build a torchtext vocabulary from an iterable of token sequences.

    Token frequencies are collected with a ``Counter`` (which keeps first-seen
    insertion order) and handed to torchtext's ``vocab`` factory. Any token
    that is not in the vocabulary is looked up as index 0.

    NOTE(review): index 0 is just the first token seen in *dataset*, not a
    dedicated ``<unk>`` entry, so out-of-vocabulary tokens collide with it.
    Adding ``<unk>`` would change the vocabulary size and invalidate the saved
    ``model.torch`` weights, so this is documented rather than changed here.
    """
    token_counts = Counter(token for document in dataset for token in document)
    result = vocab(token_counts)
    result.set_default_index(0)
    return result
|
2021-06-21 21:56:24 +02:00
|
|
|
|
|
|
|
|
|
|
|
def data_process(dt, vocab):
    """Convert tokenized documents into tensors of vocabulary indices.

    Args:
        dt: iterable of token sequences.
        vocab: any object supporting ``vocab[token] -> int``.

    Returns:
        A list with one 1-D ``torch.long`` tensor per document.
    """
    tensors = []
    for document in dt:
        indices = [vocab[token] for token in document]
        tensors.append(torch.tensor(indices, dtype=torch.long))
    return tensors
|
2021-06-21 21:56:24 +02:00
|
|
|
|
|
|
|
|
2021-06-22 01:19:45 +02:00
|
|
|
def get_scores(y_true, y_pred):
    """Return token-level accuracy between two nested label sequences.

    Both arguments are lists of per-sentence label lists; they are flattened
    and compared position by position.

    Args:
        y_true: reference labels, one sequence per sentence.
        y_pred: predicted labels, one sequence per sentence.

    Returns:
        The number of positions where the labels agree divided by the number
        of flattened ``y_pred`` labels, or ``0.0`` when ``y_pred`` is empty
        (instead of raising ``ZeroDivisionError``).
    """
    flat_true = [item for sublist in y_true for item in sublist]
    flat_pred = [item for sublist in y_pred for item in sublist]
    if not flat_pred:
        return 0.0
    # zip() stops at the shorter sequence; positions beyond it never count
    # as matches.
    matches = sum(p == t for p, t in zip(flat_pred, flat_true))
    return matches / len(flat_pred)
|
2021-06-22 01:19:45 +02:00
|
|
|
|
|
|
|
|
2021-06-22 02:24:08 +02:00
|
|
|
class GRU(torch.nn.Module):
    """Embedding -> single-layer GRU -> linear layer producing per-token tag scores.

    The output scores are used as CRF emissions elsewhere in this script.

    Args:
        vocab_len: number of rows in the token embedding table.
        num_tags: size of the output layer (default 9, the original value).
        emb_dim: embedding width (default 100).
        hidden_dim: GRU hidden size (default 256).
    """

    def __init__(self, vocab_len, num_tags=9, emb_dim=100, hidden_dim=256):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_len, emb_dim)
        # No dropout: torch.nn.GRU applies it only between stacked layers, so
        # with num_layers=1 the former dropout=0.2 had no effect and only
        # triggered a UserWarning. Parameter names/shapes are unchanged, so
        # existing state_dicts still load.
        self.rec = torch.nn.GRU(emb_dim, hidden_dim, 1, batch_first=True)
        self.fc1 = torch.nn.Linear(hidden_dim, num_tags)

    def forward(self, x):
        """Map a LongTensor of token indices to tag scores.

        With ``batch_first=True`` an input of shape (batch, seq) yields an
        output of shape (batch, seq, num_tags).
        """
        # NOTE(review): ReLU on raw embeddings is unusual, but it is kept so
        # already-trained weights keep producing the same outputs.
        emb = torch.relu(self.emb(x))
        gru_output, _ = self.rec(emb)
        return self.fc1(gru_output)
|
2021-06-21 21:10:27 +02:00
|
|
|
|
2021-06-22 23:51:46 +02:00
|
|
|
# Helpers
|
|
|
|
|
|
|
|
|
|
|
|
def translate(dt, vocab):
    """Map sequences of vocabulary indices back to their string tokens.

    Args:
        dt: iterable of index sequences (anything usable as a list index,
            e.g. ints).
        vocab: object exposing ``get_itos()`` returning the index->string list.

    Returns:
        A list of token lists, one per input sequence.
    """
    # Fetch the index->string table once instead of calling get_itos() for
    # every single token.
    itos = vocab.get_itos()
    return [[itos[token] for token in d] for d in dt]
|
|
|
|
|
|
|
|
|
2021-06-23 00:27:54 +02:00
|
|
|
def FIX_OUTPUT_FOR_GEVAL(out):
    """Repair IOB label sequences and join each into a space-separated line.

    An ``I-`` label that opens a span (at sentence start or right after
    ``O``) becomes ``B-`` with the same type. An ``I-`` label that follows
    any other label takes that label's type, so each span stays
    single-typed. All other labels pass through unchanged.

    Args:
        out: iterable of label sequences (lists of strings).

    Returns:
        A list of strings, one per input sequence, labels joined by spaces.
    """
    result = []
    for line in out:
        last_label = None
        new_line = []
        for label in line:
            if label.startswith("I-"):
                if last_label is None or last_label == "O":
                    # Rewrite only the prefix; str.replace would also touch
                    # any later "I-" inside the label text.
                    label = "B-" + label[2:]
                else:
                    label = "I-" + last_label[2:]
            last_label = label
            new_line.append(label)
        result.append(" ".join(new_line))
    return result
|
|
|
|
|
|
|
|
|
2021-06-22 23:51:46 +02:00
|
|
|
def save_to_file(out, out_path):
    """Write predicted label sequences to *out_path*, one sentence per line.

    Sequences are first normalised with ``FIX_OUTPUT_FOR_GEVAL`` (I-/B-
    prefix repair, labels joined by spaces).

    Args:
        out: iterable of label sequences.
        out_path: destination file path; an existing file is overwritten.
    """
    lines = FIX_OUTPUT_FOR_GEVAL(out)
    # Plain write mode is enough (nothing is read back), and an explicit
    # encoding keeps the output independent of the platform locale.
    with open(out_path, "w", encoding="utf-8") as f:
        f.writelines(f"{line}\n" for line in lines)
|
|
|
|
|
2021-06-22 23:51:46 +02:00
|
|
|
|
2021-06-21 21:10:27 +02:00
|
|
|
# Load data
|
2021-06-22 01:19:45 +02:00
|
|
|
|
|
|
|
|
2021-06-21 21:56:24 +02:00
|
|
|
def load_data():
    """Load the train, dev-0 and test-A splits from their TSV files.

    Every field is split on single spaces into tokens (documents) or labels.

    Returns:
        ``(X_train, Y_train, X_dev, Y_dev, X_test)`` — lists of token/label
        lists, one inner list per line of the corresponding file.
    """
    import csv

    def _read_tsv(path, names):
        # Read every field as raw text. With pandas' default quoting, a field
        # starting with '"' would swallow tabs/newlines and merge rows. With
        # NA detection on, a field such as "NA" or "null" would become NaN
        # and break .split().
        return pd.read_csv(path, sep='\t', names=names,
                           quoting=csv.QUOTE_NONE, na_filter=False, dtype=str)

    def _split(column):
        # One list of space-separated items per row.
        return [value.split(sep=" ") for value in column.values]

    train = _read_tsv('train/train.tsv', ['labels', 'document'])
    Y_train = _split(train['labels'])
    X_train = _split(train['document'])

    dev = _read_tsv('dev-0/in.tsv', ['document'])
    exp = _read_tsv('dev-0/expected.tsv', ['labels'])
    X_dev = _split(dev['document'])
    Y_dev = _split(exp['labels'])

    test = _read_tsv('test-A/in.tsv', ['document'])
    X_test = _split(test['document'])

    return X_train, Y_train, X_dev, Y_dev, X_test
|
2021-06-21 21:10:27 +02:00
|
|
|
|
2021-06-22 23:51:46 +02:00
|
|
|
# Train and save model
|
|
|
|
|
2021-06-21 21:10:27 +02:00
|
|
|
|
2021-06-22 01:19:45 +02:00
|
|
|
def train(model, crf, train_tokens, labels_tokens, optimizer=None):
    """Train *model* and *crf* jointly, one sentence per step, then save weights.

    Args:
        model: emission network mapping (1, seq) token indices to
            (1, seq, num_tags) scores.
        crf: CRF layer; the negated value of ``crf(emissions, tags)`` is used
            as the loss (presumably the sentence log-likelihood).
        train_tokens: list of 1-D LongTensors of token indices.
        labels_tokens: list of 1-D LongTensors of tag indices, aligned with
            *train_tokens*.
        optimizer: optimizer over the parameters of *model* and *crf*. When
            omitted, a default ``torch.optim.Adam`` over both modules is
            created. This is the same construction the script entry point
            uses, and it removes the old dependency on a global ``optimizer``.

    Side effects:
        Writes ``model.torch`` (model weights) and ``crf.torch`` (CRF weights,
        which are trained here and must be saved for decoding to reuse them).
    """
    if optimizer is None:
        optimizer = torch.optim.Adam(list(model.parameters()) + list(crf.parameters()))

    for _ in range(EPOCHS):
        crf.train()
        model.train()
        for tokens, labels in tqdm(zip(train_tokens, labels_tokens), total=len(labels_tokens)):
            batch_tokens = tokens.unsqueeze(0)  # add batch dim -> (1, seq)
            tags = labels.unsqueeze(1)  # (seq, 1)
            # (1, seq, num_tags) -> (seq, 1, num_tags): sequence-first layout
            # matching `tags` above.
            predicted_tags = model(batch_tokens).squeeze(0).unsqueeze(1)

            optimizer.zero_grad()
            loss = -crf(predicted_tags, tags)
            loss.backward()
            optimizer.step()

    torch.save(model.state_dict(), "model.torch")
    torch.save(crf.state_dict(), "crf.torch")
|
2021-06-22 01:19:45 +02:00
|
|
|
|
2021-06-22 23:51:46 +02:00
|
|
|
# Eval dev set and save output
|
2021-06-22 03:40:29 +02:00
|
|
|
|
|
|
|
|
|
|
|
def dev_eval(model, crf, dev_tokens, dev_labels_tokens, vocab):
    """Decode the dev set, print token accuracy, and return predicted labels.

    Args:
        model: trained emission network.
        crf: trained CRF layer exposing ``decode``.
        dev_tokens: list of 1-D LongTensors of token indices.
        dev_labels_tokens: list of 1-D LongTensors of gold tag indices.
        vocab: tag vocabulary exposing ``get_itos()``.

    Returns:
        List of predicted label-string lists, one per dev sentence.
    """
    Y_true = []
    Y_pred = []
    model.eval()
    crf.eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for tokens, labels in tqdm(zip(dev_tokens, dev_labels_tokens), total=len(dev_labels_tokens)):
            # Gold tags must come from the dev labels, not from the global
            # training-set `labels_tokens`.
            Y_true.append(labels.tolist())
            emissions = model(tokens.unsqueeze(0)).squeeze(0).unsqueeze(1)
            Y_pred.append(crf.decode(emissions)[0])

    Y_pred_translate = translate(Y_pred, vocab)
    Y_true_translate = translate(Y_true, vocab)

    # get_scores computes token accuracy; pass arguments in (y_true, y_pred) order.
    accuracy = get_scores(Y_true_translate, Y_pred_translate)
    print(f'accuracy: {accuracy}')
    return Y_pred_translate
|
2021-06-22 03:40:29 +02:00
|
|
|
|
|
|
|
|
2021-06-22 23:57:22 +02:00
|
|
|
def test_eval(model, crf, test_tokens, vocab):
    """Decode every test sentence and return the predicted label strings.

    Args:
        model: trained emission network.
        crf: trained CRF layer exposing ``decode``.
        test_tokens: list of 1-D LongTensors of token indices.
        vocab: tag vocabulary exposing ``get_itos()``.

    Returns:
        List of predicted label-string lists, one per test sentence.
    """
    model.eval()
    crf.eval()
    Y_pred = []
    # Inference only: no_grad avoids building an autograd graph per sentence.
    with torch.no_grad():
        for tokens in tqdm(test_tokens):
            # (1, seq, num_tags) -> (seq, 1, num_tags) for the CRF decoder.
            emissions = model(tokens.unsqueeze(0)).squeeze(0).unsqueeze(1)
            Y_pred.append(crf.decode(emissions)[0])
    return translate(Y_pred, vocab)
|
|
|
|
|
|
|
|
|
2021-06-21 21:56:24 +02:00
|
|
|
if __name__ == "__main__":
    import os

    X_train, Y_train, X_dev, Y_dev, X_test = load_data()
    vocab_x = build_vocab(X_train)
    vocab_y = build_vocab(Y_train)
    train_tokens = data_process(X_train, vocab_x)
    labels_tokens = data_process(Y_train, vocab_y)

    # train
    model = GRU(len(vocab_x.get_itos()))
    # NOTE(review): 9 must match len(vocab_y) and the GRU output size.
    crf = CRF(9)
    p = list(model.parameters()) + list(crf.parameters())
    optimizer = torch.optim.Adam(p)
    # # mask = torch.ByteTensor([1, 1]) # (batch_size. sequence_size)
    # train(model, crf, train_tokens, labels_tokens)

    # eval dev
    # model.load_state_dict(torch.load("model.torch"))
    # dev_tokens = data_process(X_dev, vocab_x)
    # dev_labels_tokens = data_process(Y_dev, vocab_y)
    # dev_pred = dev_eval(model, crf, dev_tokens, dev_labels_tokens, vocab_y)
    # save_to_file(dev_pred, 'dev-0/out.tsv')

    # predict test
    model.load_state_dict(torch.load("model.torch"))
    # The CRF parameters are learned jointly with the model. Without restoring
    # them, decoding would use a freshly initialised (random) CRF.
    if os.path.exists("crf.torch"):
        crf.load_state_dict(torch.load("crf.torch"))
    test_tokens = data_process(X_test, vocab_x)
    test_pred = test_eval(model, crf, test_tokens, vocab_y)
    save_to_file(test_pred, 'test-A/out.tsv')
|