# NER inference script: loads a trained window-based tagger and writes BIO predictions.
from collections import Counter
|
|
import torch as torch
|
|
import torchtext.vocab
|
|
from bidict import bidict
|
|
from string import punctuation
|
|
from train import add_extra_features, data_process
|
|
|
|
# Bidirectional mapping between BIO label strings and the class indices
# the model was trained with (order must stay exactly as trained).
_LABELS = ('O', 'B-PER', 'B-LOC', 'I-PER', 'B-MISC', 'I-MISC', 'I-LOC', 'B-ORG', 'I-ORG')
label2num = bidict(zip(_LABELS, range(len(_LABELS))))
num2label = label2num.inverse
|
|
|
|
|
|
class NERModel(torch.nn.Module):
    """Window-based NER classifier: embed a fixed token window, flatten, classify.

    The defaults reproduce the trained checkpoint exactly: a 23627-word
    vocabulary, 200-dim embeddings, a 6000-dim flattened window and 9 BIO
    classes.  Parameters are exposed (with those defaults) so the same
    architecture can be re-instantiated for other vocabularies without
    breaking ``torch.load`` of the existing ``model.pt``.
    """

    def __init__(self, vocab_size: int = 23627, emb_dim: int = 200,
                 window_features: int = 6000, num_labels: int = 9):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, emb_dim)
        self.fc1 = torch.nn.Linear(window_features, num_labels)

    def forward(self, x):
        """Return unnormalized label logits for one flattened token window.

        ``x`` is a 1-D LongTensor of token ids whose embedded size must
        equal ``fc1.in_features`` (30 tokens x 200 dims = 6000 by default).
        """
        x = self.emb(x)
        # Flatten to a single feature vector; -1 generalizes the original
        # hard-coded reshape(6000) and is identical for valid input sizes.
        x = x.reshape(-1)
        return self.fc1(x)
|
|
|
|
|
|
# Load the trained model saved with torch.save (pickle-based).
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoints.  Presumably the checkpoint was saved on CPU; if it was saved
# on GPU, a map_location argument would be needed — confirm against train.py.
ner_model = torch.load('model.pt')
# Switch to inference mode (disables dropout / batch-norm updates).
ner_model.eval()
|
|
|
|
|
|
def predict(path):
    """Tag ``{path}/in.tsv`` and write one BIO label sequence per line to ``{path}/out.tsv``.

    Each input line is a space-separated sentence.  For every interior
    position a 3-token window is fed to the global ``ner_model`` and the
    argmax class is decoded via ``num2label``.  The raw predictions are
    then repaired so an I-* tag never starts an entity.
    """
    sentences = []
    sentence_strings = []

    with open(f"{path}/in.tsv", 'r', encoding='utf-8') as f:
        for l in f:
            l = l.strip()
            sentences.append(l.split(' '))
            sentence_strings.append(l.split(' '))

    train_tokens_ids = data_process(sentences)

    predictions = []
    for i in range(len(train_tokens_ids)):
        labels = []
        # Slide a 3-token window over the sentence, skipping the first and
        # last positions (no full window exists there).
        for j in range(1, len(train_tokens_ids[i]) - 1):
            window_ids = train_tokens_ids[i][j - 1: j + 2]
            window_strings = sentence_strings[i][j - 1: j + 2]
            # NOTE(review): the extra features are computed but never passed
            # to the model — confirm whether ner_model should receive them.
            add_extra_features(window_ids, window_strings)

            y_pred = ner_model(window_ids)
            labels.append(num2label[int(torch.argmax(y_pred))])

        predictions.append(' '.join(labels))

    # Repair invalid BIO sequences: an I-* that does not continue an entity
    # becomes B-*; an I-* that follows an entity inherits the previous type.
    lines = []
    for line in predictions:
        prev_label = None
        line_corr = []
        for label in line.split():
            if label != 'O' and label[0] == 'I':
                if prev_label is None or prev_label == 'O':
                    # BUG FIX: was label.replace('I', 'B'), which replaced
                    # every 'I' in the tag, e.g. 'I-MISC' -> 'B-MBSC'.
                    label = 'B' + label[1:]
                else:
                    label = 'I' + prev_label[1:]
            prev_label = label
            line_corr.append(label)
        lines.append(' '.join(line_corr))

    with open(f'{path}/out.tsv', 'w', encoding='utf-8') as f:
        for l in lines:
            f.write(l + '\n')
|
|
|
|
|
|
if __name__ == '__main__':
    # Run inference on both evaluation splits.  The guard prevents the
    # (slow, file-writing) predictions from running on mere import.
    predict('test-A')
    predict('dev-0')
|