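"""GRU + CRF named-entity tagger for the en-ner-conll-2003 task
(forked from kubapok/en-ner-conll-2003).

Reads train/train.tsv, trains the model (or loads a pretrained one when
PRETRAINED is True) and writes an out.tsv with predicted NER tags for every
directory passed on the command line; each directory must contain an in.tsv.
"""
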
import os
import sys
from collections import Counter

import pandas as pd
import torch
from torchtext.vocab import vocab
# CRF layer; the usage below (crf(emissions, tags) -> log-likelihood,
# crf.decode(emissions) -> best tag sequence) assumes a pytorch-crf style interface.
from TorchCRF import CRF
from tqdm import tqdm

EPOCHS = 1
BATCH = 5
SEQ_LEN = 5
PRETRAINED = True
MODEL = "model.torch"

OUT_FILE_NAME = "out.tsv"
IN_FILE_NAME = "in.tsv"

LABELS_COL = "labels"
DOCS_COL = "document"
VALUE_SEP = '\t'

TRAIN_PATH = "train"
TRAIN_FNAME = "train.tsv"

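# Expected data layout (as used by the paths above):
#   train/train.tsv   two tab-separated columns: labels, document (both space-tokenised)
#   <dir>/in.tsv      one space-tokenised document per line
#   <dir>/out.tsv     written by this script: one predicted tag sequence per line

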
class GRU(torch.nn.Module):
    """Token embedding -> single-layer GRU -> per-token emission scores for 9 NER tags."""

    def __init__(self, vocab_len):
        super(GRU, self).__init__()
        self.emb = torch.nn.Embedding(vocab_len, 100)
        # Single-layer GRU; `dropout` only applies between stacked layers,
        # so it has no effect here.
        self.rec = torch.nn.GRU(100, 256, 1, batch_first=True, dropout=0.2)
        # 9 output classes: the CoNLL-2003 BIO tag set (O plus B-/I- for PER, ORG, LOC, MISC).
        self.fc1 = torch.nn.Linear(256, 9)

    def forward(self, x):
        emb = torch.relu(self.emb(x))        # (batch, seq_len, 100)
        gru_output, _ = self.rec(emb)        # (batch, seq_len, 256)
        out_weights = self.fc1(gru_output)   # (batch, seq_len, 9)
        return out_weights

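# Minimal shape check (illustrative only; the vocabulary size 1000 and length 7 are arbitrary):
#   model = GRU(1000)
#   model(torch.randint(0, 1000, (1, 7))).shape   # -> torch.Size([1, 7, 9])

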
def main(dirnames):
    # Training data: one row per document, "labels<TAB>document".
    tdf = pd.read_csv(os.path.join(TRAIN_PATH, TRAIN_FNAME), sep=VALUE_SEP,
                      names=[LABELS_COL, DOCS_COL])

    Y_train = [y.split(sep=" ") for y in tdf[LABELS_COL].values]
    X_train = [x.split(sep=" ") for x in tdf[DOCS_COL].values]

    # Input documents for every directory that predictions will be written to.
    in_sets = []
    for d in dirnames:
        df = pd.read_csv(os.path.join(d, IN_FILE_NAME),
                         sep=VALUE_SEP, names=[DOCS_COL])
        in_sets.append([x.split(sep=" ") for x in df[DOCS_COL].values])

    vocab_x = build_vocab(X_train)
    vocab_y = build_vocab(Y_train)
    train_tokens = data_process(X_train, vocab_x)
    train_labels = data_process(Y_train, vocab_y)

    model = GRU(len(vocab_x.get_itos()))
    # The CRF is also needed for decoding, so create it in both branches;
    # note that train() only persists the GRU weights, not the CRF transitions.
    crf = CRF(9)

    if not PRETRAINED:
        p = list(model.parameters()) + list(crf.parameters())
        train(model, crf, train_tokens, train_labels, torch.optim.Adam(p))
    else:
        model.load_state_dict(torch.load(MODEL))

    for i, d in enumerate(dirnames):
        print(f"Processing directory: {d}...")
        t = data_process(in_sets[i], vocab_x)
        p = predict(model, crf, t, vocab_y)
        save_to_file(p, os.path.join(d, OUT_FILE_NAME))


def build_vocab(dataset):
    counter = Counter()
    for document in dataset:
        counter.update(document)
    v = vocab(counter)
    # No explicit <unk> entry: out-of-vocabulary tokens fall back to index 0,
    # i.e. the first token that was counted.
    v.set_default_index(0)
    return v


def data_process(data, vocab):
    return [torch.tensor([vocab[t] for t in doc], dtype=torch.long) for doc in data]


def get_scores(y_true, y_pred):
    # Flat token-level accuracy (kept for evaluation; not called in main()).
    y_true = [item for sublist in y_true for item in sublist]
    y_pred = [item for sublist in y_pred for item in sublist]

    acc_score = 0
    for p, t in zip(y_pred, y_true):
        if p == t:
            acc_score += 1

    return acc_score / len(y_pred)


def translate(dt, vocab):
    # Map index sequences back to their string tokens/tags.
    itos = vocab.get_itos()
    translated = []
    for d in dt:
        translated.append([itos[token] for token in d])
    return translated


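# save_to_file() also repairs BIO sequences before writing: an I- tag that does not
# continue an entity is promoted to B-, and an I- tag whose type differs from the
# running entity inherits the previous type, e.g. (illustrative):
#   predicted: O I-PER I-ORG O I-LOC   ->   written: O B-PER I-PER O B-LOC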
def save_to_file(out, out_path):
    lines = []
    for line in out:
        last_label = None
        new_line = []
        for label in line:
            if label.startswith("I-"):
                if last_label is None or last_label == "O":
                    # An entity cannot start with I-: promote to B-.
                    label = label.replace('I-', 'B-')
                else:
                    # Keep the type of the entity that is being continued.
                    label = "I-" + last_label[2:]
            last_label = label
            new_line.append(label)
        lines.append(" ".join(new_line))

    with open(out_path, 'w') as f:
        for line in lines:
            f.write(line + "\n")


def train(model, crf, train_tokens, train_labels, optimizer):
    for epoch in range(EPOCHS):
        crf.train()
        model.train()
        # Documents are processed one at a time (batch size 1).
        for i in tqdm(range(len(train_labels))):
            batch_tokens = train_tokens[i].unsqueeze(0)   # (1, seq_len)
            tags = train_labels[i].unsqueeze(1)           # (seq_len, 1)

            # Emission scores reshaped to (seq_len, 1, num_tags) for the CRF.
            predicted_tags = model(batch_tokens).squeeze(0).unsqueeze(1)

            optimizer.zero_grad()
            # The CRF returns the log-likelihood of the gold tags; minimise its negative.
            loss = -crf(predicted_tags, tags)

            loss.backward()
            optimizer.step()

    # Only the GRU weights are persisted; the CRF transition parameters are not saved.
    torch.save(model.state_dict(), MODEL)


def predict(model, crf, test_tokens, vocab):
    Y_pred = []
    model.eval()
    crf.eval()
    for i in tqdm(range(len(test_tokens))):
        batch_tokens = test_tokens[i].unsqueeze(0)

        # Emission scores for one document, shaped (seq_len, 1, num_tags).
        Y_batch_pred = model(batch_tokens).squeeze(0).unsqueeze(1)
        # Viterbi decoding returns one tag-index sequence per batch element.
        Y_pred.append(crf.decode(Y_batch_pred)[0])

    Y_pred_translate = translate(Y_pred, vocab)
    return Y_pred_translate


if __name__ == "__main__":
|
|
if len(sys.argv) < 2:
|
|
raise Exception("Name of working dir not specified!")
|
|
main(sys.argv[1:])
|
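# Example invocation (script and directory names are illustrative; each directory
# must contain an in.tsv):
#   python ner_gru_crf.py dev-0 test-A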