# en-ner-conll-2003/predict.py
from collections import Counter
import torch as torch
import torchtext.vocab
from bidict import bidict
from string import punctuation
from train import add_extra_features, data_process
# BIO tag scheme for CoNLL-2003 NER: bidirectional map from each label string
# to the integer class index the model was trained with ('O' = outside entity).
label2num = bidict({'O': 0, 'B-PER': 1, 'B-LOC': 2, 'I-PER': 3, 'B-MISC': 4, 'I-MISC': 5, 'I-LOC': 6, 'B-ORG': 7,
'I-ORG': 8})
# Reverse view (class index -> label string), used to decode argmax output.
num2label = label2num.inverse
class NERModel(torch.nn.Module):
    """Minimal NER classifier: an embedding lookup plus one linear layer.

    The forward pass flattens the looked-up embeddings to a 6000-element
    vector (i.e. 30 token ids x 200 embedding dims) and projects it onto
    the 9 NER classes.
    """

    def __init__(self):
        super().__init__()
        # Vocabulary of 23627 token ids, each mapped to a 200-dim vector.
        self.emb = torch.nn.Embedding(23627, 200)
        # Single classification head over the flattened window embedding.
        self.fc1 = torch.nn.Linear(6000, 9)

    def forward(self, x):
        embedded = self.emb(x)
        flat = embedded.reshape(6000)
        return self.fc1(flat)
# Load the fully pickled model object (architecture + weights) and switch to
# inference mode. NOTE(review): loading a whole pickled module requires the
# NERModel class above to be importable at unpickling time; on torch >= 2.6
# this call also needs weights_only=False — confirm the installed version.
ner_model = torch.load('model.pt')
ner_model.eval()
def predict(path):
    """Run NER inference over ``{path}/in.tsv`` and write corrected BIO tags
    to ``{path}/out.tsv``.

    Each input line is a space-separated sentence; the output line holds one
    predicted label per 3-token window position, post-processed so that every
    ``I-*`` tag validly continues the entity type of the tag before it.
    """
    token_lines = []
    string_lines = []
    with open(f"{path}/in.tsv", 'r', encoding='utf-8') as f:
        for raw in f:
            tokens = raw.strip().split(' ')
            token_lines.append(tokens)
            string_lines.append(tokens[:])  # independent copy, as in training
    train_tokens_ids = data_process(token_lines)

    predictions = []
    with torch.no_grad():  # pure inference: skip autograd bookkeeping
        for i in range(len(train_tokens_ids)):
            labels = []
            # Slide a 3-id window over the sequence, skipping the first and
            # last positions (presumably BOS/EOS markers added by
            # data_process — TODO confirm against train.py).
            for j in range(1, len(train_tokens_ids[i]) - 1):
                window_ids = train_tokens_ids[i][j - 1: j + 2]
                window_strings = string_lines[i][j - 1: j + 2]
                # NOTE(review): the extra features are computed but never fed
                # to the model — this looks like an unused leftover or a bug;
                # the call is kept for behavioral parity. Verify intent.
                add_extra_features(window_ids, window_strings)
                y_pred = ner_model(window_ids)
                labels.append(num2label[int(torch.argmax(y_pred))])
            # join instead of quadratic `s += label + ' '` plus [:-1] trim
            predictions.append(' '.join(labels))

    # Repair invalid BIO sequences: an I-* tag may only continue the entity
    # type of the previous tag; otherwise promote it to B-* (fresh entity) or
    # inherit the previous entity's type.
    lines = []
    for line in predictions:
        prev_label = None
        corrected = []
        for label in line.split():
            if label != 'O' and label[0] == 'I':
                if prev_label is None or prev_label == 'O':
                    label = label.replace('I', 'B')
                else:
                    label = 'I' + prev_label[1:]
            prev_label = label
            corrected.append(label)
        lines.append(' '.join(corrected))

    with open(f'{path}/out.tsv', 'w', encoding='utf-8') as f:
        f.writelines(l + '\n' for l in lines)
if __name__ == "__main__":
    # Run inference for both evaluation splits only when executed as a
    # script, so importing this module does not trigger file I/O + inference.
    predict('test-A')
    predict('dev-0')