# predict.py
"""Tag the tokens of ``<dir>/in.tsv`` with NER labels and write ``<dir>/out.tsv``.

The script loads the pickled model from ``model.pt``, classifies every token
from a three-token context window plus handcrafted word flags, repairs
inconsistent BIO tags and writes one line of space-separated labels per
input line.
"""
from collections import Counter
import torch as torch
import torchtext.vocab
from bidict import bidict
from string import punctuation
# NOTE(review): importing ``train`` executes its module-level code (it reads
# train.tsv, builds the vocabulary used by ``data_process``, trains and saves
# model.pt) before anything below runs.
from train import add_extra_features, data_process

# Bidirectional mapping between tag strings and the class indices the model
# was trained on.
label2num = bidict({'O': 0, 'B-PER': 1, 'B-LOC': 2, 'I-PER': 3, 'B-MISC': 4, 'I-MISC': 5, 'I-LOC': 6, 'B-ORG': 7,
                    'I-ORG': 8})
num2label = label2num.inverse


class NERModel(torch.nn.Module):
    """Embedding layer followed by a single linear classifier over 9 tags.

    Presumably duplicated from train.py so that ``torch.load`` can resolve the
    class when the model was pickled from a script run -- confirm.
    """

    def __init__(self, ):
        super(NERModel, self).__init__()
        self.emb = torch.nn.Embedding(23627, 200)
        # 6000 = 30 input indices (3 token ids + 27 feature flags) * 200 dims.
        self.fc1 = torch.nn.Linear(6000, 9)

    def forward(self, x):
        """Return the 9 unnormalised tag scores for a 30-element index tensor."""
        x = self.emb(x)
        x = x.reshape(6000)
        x = self.fc1(x)
        return x


# NOTE(review): torch.load unpickles a complete Python object, so only load
# model files from a trusted source.  Recent PyTorch releases may require
# ``weights_only=False`` to load a full module -- confirm against the
# installed version.
ner_model = torch.load('model.pt')
ner_model.eval()


def _repair_bio_sequence(labels):
    """Return *labels* with every ``I-`` tag made consistent with its predecessor.

    An ``I-X`` tag at the start of a line or right after ``O`` becomes ``B-X``;
    an ``I-`` tag after any other tag inherits the entity type of that tag.
    """
    repaired = []
    previous = None
    for label in labels:
        if label.startswith('I'):
            if previous is None or previous == 'O':
                # Swap only the prefix: str.replace('I', 'B') would also
                # rewrite the 'I' inside 'MISC' ('I-MISC' -> 'B-MBSC').
                label = 'B' + label[1:]
            else:
                label = 'I' + previous[1:]
        previous = label
        repaired.append(label)
    return repaired


def predict(path):
    """Label every token of ``{path}/in.tsv`` and write ``{path}/out.tsv``.

    Args:
        path: Directory containing ``in.tsv`` (one document per line, tokens
            separated by single spaces).  ``out.tsv`` is created or
            overwritten in the same directory.
    """
    with open(f"{path}/in.tsv", 'r', encoding='utf-8') as f:
        documents = [line.strip().split(' ') for line in f]

    # data_process wraps each document in one leading and one trailing
    # special-token id, so every real token has a neighbour on both sides.
    documents_ids = data_process(documents)

    output_lines = []
    with torch.no_grad():  # inference only; no gradients needed
        for token_ids, words in zip(documents_ids, documents):
            labels = []
            for j in range(1, len(token_ids) - 1):
                window_ids = token_ids[j - 1: j + 2]
                # NOTE(review): ``words`` is not padded, so this slice is
                # shifted by one relative to ``window_ids``.  It is kept
                # as-is because train.py builds its windows the same way.
                window_words = words[j - 1: j + 2]
                features = add_extra_features(window_ids, window_words)

                # The model expects ids plus feature flags (30 values);
                # passing the bare 3-id window cannot be reshaped to 6000.
                scores = ner_model(features)
                labels.append(num2label[int(torch.argmax(scores))])

            output_lines.append(' '.join(_repair_bio_sequence(labels)))

    with open(f'{path}/out.tsv', 'w', encoding='utf-8') as f:
        f.writelines(line + '\n' for line in output_lines)


predict('test-A')
predict('dev-0')
# train.py
"""Train a window-based NER tagger on ``train.tsv`` and save it to ``model.pt``.

NOTE(review): everything below the definitions runs at import time, and
predict.py imports ``add_extra_features`` and ``data_process`` from this
module, so importing it re-reads train.tsv, retrains and overwrites model.pt.
"""
from collections import Counter
import torch as torch
import torchtext.vocab
from bidict import bidict
from string import punctuation

# Bidirectional mapping between tag strings and class indices.
label2num = bidict({'O': 0, 'B-PER': 1, 'B-LOC': 2, 'I-PER': 3, 'B-MISC': 4, 'I-MISC': 5, 'I-LOC': 6, 'B-ORG': 7,
                    'I-ORG': 8})
num2label = label2num.inverse

# Layout of the handcrafted feature vector appended to every id window.
FEATURES_PER_WORD = 9
WINDOW_SIZE = 3


def build_vocab(dataset):
    """Build a torchtext vocabulary from an iterable of token lists.

    Unknown tokens map to index 0, the first special token.
    """
    counter = Counter()
    for document in dataset:
        counter.update(document)
    # The four special tokens must be distinct; the angle-bracket names were
    # presumably stripped by an extraction step that left four empty strings.
    built = torchtext.vocab.vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
    built.set_default_index(0)
    return built


def data_process(dt):
    """Convert token lists to id tensors wrapped in ``<bos>`` / ``<eos>`` ids.

    Reads the module-level ``vocab``, which must have been built beforehand.
    """
    bos_id = vocab['<bos>']
    eos_id = vocab['<eos>']
    return [
        torch.tensor([bos_id] + [vocab[token] for token in document] + [eos_id], dtype=torch.long)
        for document in dt]


def labels_process(dt):
    """Convert tag lists to index tensors padded with a 0 ('O') at both ends."""
    return [
        torch.tensor([0] + [label2num[label] for label in labels] + [0], dtype=torch.long)
        for labels in dt]


def _word_flags(word):
    """Return the 9 binary surface-form flags describing *word*."""
    return [
        int(word.islower()),
        int(word.isupper()),
        int(word.isalnum()),
        int(word.isalpha()),
        int(word.isdigit()),
        int(word.istitle()),
        int(any(char in punctuation for char in word)),  # contains punctuation
        int(len(word) == 1 and word in punctuation),     # is one punctuation mark
        int(len(word) < 3),                              # very short word
    ]


def add_extra_features(x_base, x_str):
    """Append handcrafted word flags to a window of token ids.

    Args:
        x_base: 1-D long tensor holding the token ids of the window.
        x_str: The words of the window (at most ``WINDOW_SIZE``); shorter
            windows are padded with zero flags.

    Returns:
        ``x_base`` concatenated with ``WINDOW_SIZE * FEATURES_PER_WORD`` flags.

    Raises:
        ValueError: If ``x_str`` holds more than ``WINDOW_SIZE`` words.
    """
    if len(x_str) > WINDOW_SIZE:
        # The previous ``while len(...) != 27`` padding loop never terminated
        # for oversized windows; fail loudly instead.
        raise ValueError(f'expected at most {WINDOW_SIZE} words, got {len(x_str)}')

    extra_features = []
    for word in x_str:
        extra_features += _word_flags(word)
    missing = WINDOW_SIZE * FEATURES_PER_WORD - len(extra_features)
    extra_features += [0] * missing

    extra_tensor = torch.tensor(extra_features, dtype=torch.long)
    return torch.cat((x_base, extra_tensor), 0)


class NERModel(torch.nn.Module):
    """Embedding layer followed by a single linear classifier over 9 tags."""

    def __init__(self, vocab_size=23627):
        """Create the layers; ``vocab_size`` is the number of embedding rows."""
        super(NERModel, self).__init__()
        self.emb = torch.nn.Embedding(vocab_size, 200)
        # 6000 = 30 input indices (3 token ids + 27 feature flags) * 200 dims.
        self.fc1 = torch.nn.Linear(6000, 9)

    def forward(self, x):
        """Return the 9 unnormalised tag scores for a 30-element index tensor."""
        x = self.emb(x)
        x = x.reshape(6000)
        x = self.fc1(x)
        return x


X = []
Y = []

with open('train.tsv', encoding='utf-8') as f:
    for l in f:
        fields = l.strip().split('\t')
        if len(fields) < 2:
            # Skip blank or malformed lines instead of raising IndexError.
            continue
        X.append(fields[1].split())
        Y.append(fields[0].split())

# Separate outer lists sharing the same per-document lists, as before.
X_strings = list(X)
Y_strings = list(Y)

vocab = build_vocab(X)
train_tokens_ids = data_process(X)

train_labels = labels_process(Y)

# Size the embedding from the actual vocabulary so other datasets also fit.
ner_model = NERModel(vocab_size=len(vocab))
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(ner_model.parameters())

# TRAIN
print('-----TRAINING-----')
for epoch in range(1):
    loss_score = 0
    acc_score = 0
    prec_score = 0
    selected_items = 0
    recall_score = 0
    relevant_items = 0
    items_total = 0
    ner_model.train()
    documents = zip(train_tokens_ids, X_strings, train_labels)
    for doc_number, (token_ids, words, labels) in enumerate(documents, start=1):
        print(doc_number)  # progress indicator
        for j in range(1, len(labels) - 1):
            X_base = token_ids[j - 1: j + 2]
            # NOTE(review): ``words`` is not padded, so this slice is shifted
            # by one relative to ``X_base``.  Kept as-is because predict.py
            # builds its windows the same way.
            X_string = words[j - 1: j + 2]
            X_extra = add_extra_features(X_base, X_string)

            target = labels[j: j + 1]
            expected = target.item()

            Y_predictions = ner_model(X_extra)
            predicted = int(torch.argmax(Y_predictions))

            acc_score += int(predicted == expected)

            # Class 0 is 'O'; precision and recall only count entity tags.
            if predicted != 0:
                selected_items += 1
                if predicted == expected:
                    prec_score += 1

            if expected != 0:
                relevant_items += 1
                if predicted == expected:
                    recall_score += 1

            items_total += 1

            optimizer.zero_grad()
            loss = criterion(Y_predictions.unsqueeze(0), target)
            loss.backward()
            optimizer.step()

            loss_score += loss.item()

    # Guard every ratio: any counter may legitimately be zero.
    precision = prec_score / selected_items if selected_items else 0.0
    recall = recall_score / relevant_items if relevant_items else 0.0
    if precision + recall:
        f1_score = (2 * precision * recall) / (precision + recall)
    else:
        f1_score = 0.0
    print('epoch: ', epoch)
    print('loss: ', loss_score / items_total if items_total else 0.0)
    print('acc: ', acc_score / items_total if items_total else 0.0)
    print('prec: ', precision)
    print('recall: : ', recall)
    print('f1: ', f1_score)

PATH = 'model.pt'
torch.save(ner_model, PATH)