en-ner-conll-2003/main.py

191 lines
6.0 KiB
Python
Raw Normal View History

2021-06-20 19:04:16 +02:00
from os import sep
from nltk import word_tokenize
import pandas as pd
import torch
2021-06-20 22:03:34 +02:00
from torch._C import device
2021-06-20 19:04:16 +02:00
from tqdm import tqdm
from torchtext.vocab import vocab
from collections import Counter, OrderedDict
import spacy
from torchcrf import CRF
from torch.utils.data import DataLoader
2021-06-21 00:43:43 +02:00
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, classification_report
2021-06-21 01:38:09 +02:00
import csv
import pickle
2021-06-28 11:48:51 +02:00
from model import Model
# NOTE(review): `nlp` is not referenced anywhere else in the code shown; loading
# the spaCy pipeline here only adds start-up time. Consider removing it.
nlp = spacy.load("en_core_web_sm")
2021-06-20 19:04:16 +02:00
2021-06-21 00:43:43 +02:00
2021-06-21 19:19:09 +02:00
def process_output(lines):
    """Repair predicted tag sequences into well-formed IOB2 strings.

    Each element of *lines* is a sequence of tag strings for one document.
    An ``I-X`` tag that does not continue an entity of the same type ``X``
    (at the start of a line, after ``O``, or after a tag of a different
    type) is rewritten to ``B-X``. Every other tag is kept unchanged.

    Returns a list with one space-joined tag string per input line.
    """
    result = []
    for line in lines:
        # Entity type of the previous tag; None at line start or after "O".
        prev_type = None
        new_line = []
        for label in line:
            # Only start a new entity when the I- tag does not continue the
            # previous entity's type; never rewrite the type itself.
            if label.startswith("I-") and label[2:] != prev_type:
                label = "B-" + label[2:]
            prev_type = None if label == "O" else label[2:]
            new_line.append(label)
        result.append(" ".join(new_line))
    return result
2021-06-20 19:04:16 +02:00
def process_document(document):
    """Split *document* on single spaces and normalise every piece.

    Normalisation is delegated to ``process_token`` so that vocabulary
    building and token-id lookup treat tokens identically.
    """
    pieces = document.split(" ")
    return list(map(process_token, pieces))
def save_file(path, obj):
    """Write the string *obj* to *path*, replacing any existing file.

    The file is written as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding.
    """
    with open(path, "w", encoding="utf-8") as file:
        file.write(obj)
def process_token(token):
    """Return the normalised (lower-cased) form of a single *token*."""
    normalised = token.lower()
    return normalised
2021-06-20 19:04:16 +02:00
def build_vocab(dataset):
    """Build a token vocabulary from an iterable of space-separated documents.

    Tokens are normalised via ``process_document`` and indexed in order of
    decreasing frequency (ties keep first-seen order). A dedicated ``"<unk>"``
    token sits at index 0 and is the default index, so tokens never seen here
    map to it instead of silently aliasing a real word.

    Returns a torchtext ``Vocab``.
    """
    unk_token = "<unk>"
    counter = Counter()
    for document in dataset:
        counter.update(process_document(document))
    # sorted() is stable even with reverse=True, so equal counts keep
    # first-seen order.
    sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True)
    ordered_dict = OrderedDict(sorted_by_freq_tuples)
    v = vocab(ordered_dict)
    # insert_token raises if the token already exists, so guard against a
    # literal "<unk>" appearing in the data.
    if unk_token not in v:
        v.insert_token(unk_token, 0)
    v.set_default_index(v[unk_token])
    return v
def data_process(dt, token_vocab=None):
    """Convert each space-separated document in *dt* into a tensor of token ids.

    Tokens are normalised with ``process_token`` (the same normalisation
    ``build_vocab`` applies through ``process_document``) and looked up in
    *token_vocab*.

    Args:
        dt: iterable of document strings.
        token_vocab: object supporting ``token_vocab[token] -> int``. Defaults
            to the module-level ``vocab`` (looked up at call time) so existing
            callers keep working.

    Returns:
        list of 1-D ``torch.long`` tensors, one per document.
    """
    if token_vocab is None:
        token_vocab = vocab
    return [
        torch.tensor(
            [token_vocab[process_token(token)] for token in document.split(" ")],
            dtype=torch.long,
        )
        for document in dt
    ]
2021-06-20 19:04:16 +02:00
def labels_process(dt, label_vocab=None):
    """Convert each space-separated tag string in *dt* into a tensor of tag ids.

    Args:
        dt: iterable of strings of space-separated tags.
        label_vocab: mapping from tag string to integer id. Defaults to the
            module-level ``labels_vocab`` (looked up at call time) so existing
            callers keep working.

    Returns:
        list of 1-D ``torch.long`` tensors, one per input string.

    Raises:
        KeyError: if a tag is missing from *label_vocab*.
    """
    if label_vocab is None:
        label_vocab = labels_vocab
    return [
        torch.tensor([label_vocab[tag] for tag in document.split(" ")], dtype=torch.long)
        for document in dt
    ]
2021-06-21 01:38:09 +02:00
# Run mode:
#   "train"    - build + save the vocabulary, train the model, save weights;
#   "eval"     - load vocab + weights, predict, print metrics, write predictions;
#   "generate" - load vocab + weights, predict, write predictions.
# mode = "train"
# mode = "eval"
mode = "generate"

# Predictions are written here, one tab-separated row per input line.
save_path = "dev-0/out.tsv"

# One document per line with tokens separated by single spaces; the expected
# file holds the matching space-separated tags. quoting=QUOTE_NONE keeps
# literal '"' tokens in the text from being parsed as CSV quoting, which could
# merge or alter lines.
data = pd.read_csv("dev-0/in.tsv", sep="\t", names=['data'], quoting=csv.QUOTE_NONE)
ex_data = pd.read_csv("dev-0/expected.tsv", sep="\t", names=['labels'], quoting=csv.QUOTE_NONE)

in_data = data['data']
target = ex_data['labels']

# Alternative used for test-A (no gold labels):
#   test_data = pd.read_csv("test-A/in.tsv", sep="\t", names=['0'], quoting=csv.QUOTE_NONE)
#   in_data = test_data['0']
# NOTE(review): labels_process splits tag *strings*, so a placeholder target
# must be strings (e.g. all-"O" lines), not list(np.zeros(...)).
2021-06-21 01:38:09 +02:00
2021-06-20 19:04:16 +02:00
# Token vocabulary: built from the input documents and cached to vocab.pickle
# in "train" mode; every other mode reloads that cached copy so token ids stay
# aligned with the training run.
# NOTE: assigning `vocab` rebinds the name imported from torchtext.vocab.
# vocab.pickle is a local artefact written by this script; never point this at
# an untrusted file (pickle.load can execute arbitrary code).
if mode != "train":
    with open("vocab.pickle", "rb") as vocab_file:
        vocab = pickle.load(vocab_file)
else:
    vocab = build_vocab(in_data)
    with open("vocab.pickle", "wb") as vocab_file:
        pickle.dump(vocab, vocab_file)
    print("Vocab saved")
2021-06-20 19:04:16 +02:00
# Fixed tag inventory; a tag's position in this tuple is its integer id.
LABEL_NAMES = (
    'O',
    'B-PER',
    'B-LOC',
    'I-PER',
    'B-MISC',
    'I-MISC',
    'I-LOC',
    'B-ORG',
    'I-ORG',
)
labels_vocab = {name: idx for idx, name in enumerate(LABEL_NAMES)}
inv_labels_vocab = dict(enumerate(LABEL_NAMES))

# Despite the "train_" prefix these hold whatever split was read into
# in_data / target.
train_tokens_ids = data_process(in_data)
train_labels = labels_process(target)
2021-06-20 22:03:34 +02:00
2021-06-20 19:04:16 +02:00
# One tag id per entry of labels_vocab.
num_tags = len(labels_vocab)
NUM_EPOCHS = 5
# Documents are fed to the model in windows of this many tokens.
seq_length = 5

model = Model(num_tags, seq_length, vocab)
# Use the GPU when present, otherwise fall back to CPU instead of failing.
# .to(device) moves every registered submodule, so no extra model.cuda(...)
# call is needed (it would also override a CPU choice).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
2021-06-20 22:03:34 +02:00
2021-06-21 01:38:09 +02:00
if mode == "train":
    optimizer = torch.optim.Adam(model.parameters())

    for epoch in range(NUM_EPOCHS):
        model.train()
        model.train_mode()
        for doc_idx in tqdm(range(len(train_labels))):
            doc_tokens = train_tokens_ids[doc_idx]
            doc_labels = train_labels[doc_idx]
            # Visit every full window [k, k + seq_length). The "+ 1" makes the
            # last full window count when the length is an exact multiple of
            # seq_length. A shorter trailing remainder is not trained on.
            for k in range(0, len(doc_tokens) - seq_length + 1, seq_length):
                batch_tokens = doc_tokens[k: k + seq_length].unsqueeze(0)
                tags = doc_labels[k: k + seq_length].unsqueeze(1)
                # Clears gradients of all registered submodules (model.crf
                # included, assuming it is registered as a submodule).
                model.zero_grad()
                # model(...) returns a scalar that is back-propagated directly,
                # presumably the CRF negative log-likelihood — TODO confirm in
                # model.py.
                loss = model(batch_tokens.to(device), tags.to(device))
                loss.backward()
                optimizer.step()

    torch.save(model.state_dict(), "model.torch")
2021-06-21 01:38:09 +02:00
if mode == "eval" or mode == "generate":
    model.eval()
    model.eval_mode()
    predicted = []
    correct = []
    # map_location lets weights saved from a GPU run load on the device
    # selected for this run (e.g. CPU).
    model.load_state_dict(torch.load("model.torch", map_location=device))

    # Inference only: no autograd graph is needed for decoding.
    with torch.no_grad():
        for doc_idx in tqdm(range(len(train_tokens_ids))):
            doc_tokens = train_tokens_ids[doc_idx]
            doc_labels = train_labels[doc_idx]
            doc_len = len(doc_tokens)
            full_len = (doc_len // seq_length) * seq_length
            # Full windows of seq_length tokens, then the remaining tail
            # [full_len, doc_len). The tail starts where the last full window
            # ends (not at that window's start, which would re-decode tokens
            # already predicted and skip the real tail).
            spans = [(k, k + seq_length) for k in range(0, full_len, seq_length)]
            if full_len < doc_len:
                spans.append((full_len, doc_len))
            for start, end in spans:
                batch_tokens = doc_tokens[start:end].unsqueeze(0)
                predicted_tags = model.decode(batch_tokens.to(device))
                predicted += predicted_tags[0]
                correct += doc_labels[start:end].tolist()

    if mode == "eval":
        print(classification_report(correct, predicted))
        print(accuracy_score(correct, predicted))
        print(f1_score(correct, predicted, average="micro"))
        save_file("correct.txt", "\n".join(str(x) for x in correct))
        save_file("predicted.txt", "\n".join(str(x) for x in predicted))

    # Convert ids back to tag strings and re-split the flat prediction list
    # into one chunk per input line; token counts use the same single-space
    # split as data_process, so the chunks line up with the inputs.
    predicted = [inv_labels_vocab[x] for x in predicted]
    slices = [len(x.split(" ")) for x in in_data]
    output = []
    offset = 0
    for n_tokens in slices:
        output.append(predicted[offset: offset + n_tokens])
        offset += n_tokens

    # newline="" is what the csv module expects; the writer's lineterminator
    # then controls line endings on every platform.
    with open(save_path, "w", newline="", encoding="utf-8") as save:
        writer = csv.writer(save, delimiter="\t", lineterminator="\n")
        for line in process_output(output):
            writer.writerow([line])
2021-06-21 00:43:43 +02:00