# File: en-ner-conll-2003/main.py
# Listing timestamp: 2021-06-22 23:48:57 +02:00
#
# Listing size: 167 lines, 4.4 KiB (Python)

import pandas as pd
import torch
import pandas as pd
from collections import Counter
from torchtext.vocab import vocab
from TorchCRF import CRF
from tqdm import tqdm
import sys
import os
# Number of passes train() makes over the training set.
EPOCHS = 1
# BATCH and SEQ_LEN are not referenced anywhere in the visible code.
BATCH = 5
SEQ_LEN = 5
# When True, main() skips training and loads the weights stored in MODEL.
PRETRAINED = True
# Path train() saves the model state dict to and main() loads it from.
MODEL = "model.torch"
# File names main() writes to / reads from inside each directory it is given.
OUT_FILE_NAME = "out.tsv"
IN_FILE_NAME = "in.tsv"
# Column names assigned to the tab-separated data when it is loaded.
LABELS_COL = "labels"
DOCS_COL = "document"
# Field separator used for every file read by main().
VALUE_SEP = '\t'
# Directory and file name of the training data read by main().
TRAIN_PATH = "train"
TRAIN_FNAME = "train.tsv"
class GRU(torch.nn.Module):
    """Embedding -> single-layer GRU -> linear layer producing per-token tag scores.

    Args:
        vocab_len: Number of rows in the embedding table.
        emb_dim: Size of each embedding vector.
        hidden_dim: Size of the GRU hidden state.
        num_tags: Number of scores emitted for every token.
    """

    def __init__(self, vocab_len, emb_dim=100, hidden_dim=256, num_tags=9):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_len, emb_dim)
        # No `dropout` argument: torch.nn.GRU applies dropout only between
        # stacked layers, so with one layer it does nothing except emit a
        # UserWarning. Dropout has no parameters, so state dicts are unchanged.
        self.rec = torch.nn.GRU(emb_dim, hidden_dim, 1, batch_first=True)
        self.fc1 = torch.nn.Linear(hidden_dim, num_tags)

    def forward(self, x):
        """Return unnormalised tag scores for every position of ``x``.

        Args:
            x: Tensor of token indices; assumed to be laid out as
                (batch, seq_len) because the GRU is built with
                ``batch_first=True``.

        Returns:
            Tensor with one ``num_tags``-sized score vector per input token.
        """
        emb = torch.relu(self.emb(x))
        gru_output, _ = self.rec(emb)
        return self.fc1(gru_output)
def main(dirnames):
    """Train or load the tagger, then write predictions for every directory.

    The training file (TRAIN_PATH/TRAIN_FNAME) is read as two tab-separated
    columns, labels first and document second, each a space-separated
    sequence. Every directory in ``dirnames`` must contain IN_FILE_NAME with
    one document per line; the predicted labels are written next to it as
    OUT_FILE_NAME.

    Args:
        dirnames: Directories to run prediction on.
    """
    train_df = pd.read_csv(
        os.path.join(TRAIN_PATH, TRAIN_FNAME),
        sep=VALUE_SEP,
        names=[LABELS_COL, DOCS_COL],
    )
    Y_train = [y.split(sep=" ") for y in train_df[LABELS_COL].values]
    X_train = [x.split(sep=" ") for x in train_df[DOCS_COL].values]

    # Read every input file up front so a bad path fails before any training.
    in_sets = []
    for d in dirnames:
        df = pd.read_csv(
            os.path.join(d, IN_FILE_NAME), sep=VALUE_SEP, names=[DOCS_COL]
        )
        in_sets.append([x.split(sep=" ") for x in df[DOCS_COL].values])

    # The vocabularies are always rebuilt from the training data, so token and
    # label indices stay the same whether or not the weights are loaded.
    vocab_x = build_vocab(X_train)
    vocab_y = build_vocab(Y_train)

    model = GRU(len(vocab_x.get_itos()))
    # The CRF is needed by predict() in both modes; creating it only in the
    # training branch raised NameError when PRETRAINED was True.
    crf = CRF(9)

    if PRETRAINED:
        # NOTE(review): train() stores only the GRU weights in MODEL, so the
        # CRF keeps its freshly initialised parameters on this path.
        model.load_state_dict(torch.load(MODEL))
    else:
        train_tokens = data_process(X_train, vocab_x)
        train_labels = data_process(Y_train, vocab_y)
        params = list(model.parameters()) + list(crf.parameters())
        train(model, crf, train_tokens, train_labels, torch.optim.Adam(params))

    for d, documents in zip(dirnames, in_sets):
        print(f"Processing directory: {d}...")
        tokens = data_process(documents, vocab_x)
        predictions = predict(model, crf, tokens, vocab_y)
        save_to_file(predictions, os.path.join(d, OUT_FILE_NAME))
def build_vocab(dataset):
    """Create a torchtext vocabulary from an iterable of token sequences.

    Tokens are counted across every document of ``dataset`` in encounter
    order. The vocabulary's default index is set to 0 (in torchtext this is
    the index returned when an unknown token is looked up).
    """
    frequencies = Counter(token for document in dataset for token in document)
    vocabulary = vocab(frequencies)
    vocabulary.set_default_index(0)
    return vocabulary
def data_process(data, vocab):
    """Map each document of ``data`` to a ``torch.long`` tensor of indices.

    ``vocab`` is indexed with every token (``vocab[token]``); one tensor is
    produced per document, in the same order as ``data``.
    """
    tensors = []
    for doc in data:
        indices = [vocab[token] for token in doc]
        tensors.append(torch.tensor(indices, dtype=torch.long))
    return tensors
def get_scores(y_true, y_pred):
    """Return token-level accuracy of ``y_pred`` against ``y_true``.

    Both arguments are sequences of sequences; they are flattened and compared
    position by position. The number of matches is divided by the number of
    predicted items.

    Returns:
        The accuracy as a float, or 0.0 when there are no predictions at all
        (this case used to raise ZeroDivisionError).
    """
    flat_true = [item for sequence in y_true for item in sequence]
    flat_pred = [item for sequence in y_pred for item in sequence]
    if not flat_pred:
        return 0.0
    matches = sum(1 for p, t in zip(flat_pred, flat_true) if p == t)
    return matches / len(flat_pred)
def translate(dt, vocab):
    """Convert sequences of indices back into their vocabulary entries.

    Args:
        dt: Iterable of index sequences.
        vocab: Object whose ``get_itos()`` returns the index-to-entry list.

    Returns:
        A list with one list of entries per input sequence.
    """
    # Fetch the lookup table once; the original asked the vocabulary for it
    # again for every single token.
    itos = vocab.get_itos()
    return [[itos[token] for token in sequence] for sequence in dt]
def _repair_bio_line(line):
    """Return the labels of ``line`` with every ``I-`` tag made consistent."""
    repaired = []
    last_label = None
    for label in line:
        if label.startswith("I-"):
            if last_label is None or last_label == "O":
                # An inside tag cannot open an entity: turn it into a begin tag.
                label = label.replace("I-", "B-")
            else:
                # Continue the entity type of the preceding label.
                label = "I-" + last_label[2:]
        last_label = label
        repaired.append(label)
    return repaired


def save_to_file(out, out_path):
    """Write predicted label sequences to ``out_path``, one sequence per line.

    Before writing, every ``I-`` label is adjusted: at the start of a line or
    directly after ``O`` it becomes the matching ``B-`` label, otherwise it
    takes over the entity type of the label before it. Labels within a line
    are joined with single spaces.

    Args:
        out: Iterable of label sequences (lists of strings).
        out_path: Path of the file to create or overwrite.
    """
    lines = [" ".join(_repair_bio_line(line)) for line in out]
    with open(out_path, "w", encoding="utf-8") as f:
        f.writelines(line + "\n" for line in lines)
def train(model, crf, train_tokens, train_labels, optimizer):
    """Fit ``model`` and ``crf`` one sequence at a time, then save the model.

    Runs EPOCHS passes over the paired token/label tensors, minimising the
    negated value returned by ``crf(emissions, tags)``. After the last pass
    the state dict of ``model`` is written to MODEL; the CRF parameters are
    not saved.

    Args:
        model: Network producing per-token tag scores.
        crf: CRF layer; called for the training objective.
        train_tokens: List of 1-D index tensors, one per document.
        train_labels: List of 1-D label-index tensors aligned with
            ``train_tokens``.
        optimizer: Optimizer holding the parameters to update.
    """
    for _epoch in range(EPOCHS):
        crf.train()
        model.train()
        # Pair the tensors directly; the original inner loop reused the epoch
        # counter's name as its index.
        pairs = zip(train_tokens, train_labels)
        for tokens, labels in tqdm(pairs, total=len(train_labels)):
            batch_tokens = tokens.unsqueeze(0)  # (1, seq_len)
            tags = labels.unsqueeze(1)  # (seq_len, 1)
            # (1, seq_len, num_tags) -> (seq_len, 1, num_tags); presumably the
            # sequence-first layout the CRF expects - confirm against the CRF
            # package in use.
            emissions = model(batch_tokens).squeeze(0).unsqueeze(1)
            optimizer.zero_grad()
            loss = -crf(emissions, tags)
            loss.backward()
            optimizer.step()
    # Saved once, after training has finished.
    torch.save(model.state_dict(), MODEL)
def predict(model, crf, test_tokens, vocab):
    """Decode the best label sequence for every document in ``test_tokens``.

    Args:
        model: Network producing per-token tag scores.
        crf: CRF layer whose ``decode`` yields the label indices.
        test_tokens: List of 1-D index tensors, one per document.
        vocab: Label vocabulary used to turn indices back into label strings.

    Returns:
        A list with one list of label strings per document.
    """
    model.eval()
    crf.eval()
    Y_pred = []
    # Inference needs no gradients; skipping the autograd graph saves memory
    # and time without changing the decoded labels.
    with torch.no_grad():
        for tokens in tqdm(test_tokens):
            batch_tokens = tokens.unsqueeze(0)
            emissions = model(batch_tokens).squeeze(0).unsqueeze(1)
            # decode() returns one sequence per batch entry; the batch size is 1.
            Y_pred.append(crf.decode(emissions)[0])
    return translate(Y_pred, vocab)
if __name__ == "__main__":
    # Every command-line argument is a directory to run prediction on.
    _cli_dirs = sys.argv[1:]
    if not _cli_dirs:
        raise Exception("Name of working dir not specified!")
    main(_cli_dirs)