"""Tag the ``in.tsv`` file of each dataset directory with a trained NER model.

Running this module loads ``model.pt``, predicts one label per token for every
line of ``<dir>/in.tsv`` and writes the space-separated labels to
``<dir>/out.tsv`` for the ``test-A`` and ``dev-0`` directories.
"""
from collections import Counter
from string import punctuation

import torch
import torchtext.vocab
from bidict import bidict

from train import add_extra_features, data_process

# Bidirectional mapping between label strings and the class indices the model emits.
label2num = bidict({
    'O': 0,
    'B-PER': 1,
    'B-LOC': 2,
    'I-PER': 3,
    'B-MISC': 4,
    'I-MISC': 5,
    'I-LOC': 6,
    'B-ORG': 7,
    'I-ORG': 8,
})
num2label = label2num.inverse


class NERModel(torch.nn.Module):
    """Embedding layer followed by a single linear classifier over 9 labels.

    The class has to be defined here so that ``torch.load`` can resolve it when
    unpickling the saved model.
    """

    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(23627, 200)
        self.fc1 = torch.nn.Linear(6000, 9)

    def forward(self, x):
        """Return the 9 class scores for one window of indices.

        The embedded input is flattened to exactly 6000 values, i.e. ``x``
        must hold 6000 / 200 = 30 indices.
        """
        x = self.emb(x)
        x = x.reshape(6000)
        x = self.fc1(x)
        return x


# NOTE: torch.load unpickles arbitrary objects - only load model files you trust.
# NOTE(review): recent PyTorch releases default ``weights_only`` to True, which
# rejects fully pickled modules - confirm against the installed version.
ner_model = torch.load('model.pt')
ner_model.eval()


def _repair_bio_labels(labels):
    """Return a copy of ``labels`` in which every ``I-`` tag continues an entity.

    An ``I-`` tag at the start of the sequence or directly after ``O`` becomes
    the matching ``B-`` tag; an ``I-`` tag after another entity tag takes over
    the entity type of that previous tag.
    """
    repaired = []
    prev_label = None
    for label in labels:
        if label.startswith('I'):
            if prev_label is None or prev_label == 'O':
                # Swap only the prefix: str.replace('I', 'B') would also hit
                # the 'I' inside the type and turn 'I-MISC' into 'B-MBSC'.
                label = 'B' + label[1:]
            else:
                label = 'I' + prev_label[1:]
        prev_label = label
        repaired.append(label)
    return repaired


def predict(path):
    """Predict NER labels for ``{path}/in.tsv`` and write ``{path}/out.tsv``.

    Each input line is stripped and split on single spaces. One output line is
    written per input line, holding the space-separated predicted labels after
    BIO repair.

    Args:
        path: Directory that contains ``in.tsv``; ``out.tsv`` is created or
            overwritten in the same directory.
    """
    with open(f"{path}/in.tsv", 'r', encoding='utf-8') as f:
        sentences = [line.strip().split(' ') for line in f]

    # Keep an independent copy of the raw tokens before handing the lists to
    # data_process, as the original did by splitting every line twice.
    X_strings = [list(tokens) for tokens in sentences]
    token_ids = data_process(sentences)

    predictions = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for i, ids in enumerate(token_ids):
            strings = X_strings[i]
            labels = []
            # The first and last positions are skipped - presumably boundary
            # ids added by data_process; confirm against train.py.
            for j in range(1, len(ids) - 1):
                window_ids = ids[j - 1: j + 2]
                window_strings = strings[j - 1: j + 2]
                features = add_extra_features(window_ids, window_strings)
                # Feed the augmented features, not the bare 3-id window: the
                # model's forward() needs 30 indices, so the window alone can
                # not be reshaped to 6000 values. The features were previously
                # computed and then discarded. Assumes add_extra_features
                # returns the 30-index input used in training - TODO confirm.
                scores = ner_model(features)
                labels.append(num2label[int(torch.argmax(scores))])
            predictions.append(labels)

    with open(f'{path}/out.tsv', 'w', encoding='utf-8') as f:
        f.writelines(
            ' '.join(_repair_bio_labels(labels)) + '\n' for labels in predictions
        )


predict('test-A')
predict('dev-0')