# en-ner-conll-2003/train.py
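#
# Trains a window-based NER tagger on CoNLL-2003-style data: each line of
# train.tsv pairs a space-separated IOB tag sequence with its tokenized text.
# A single linear layer classifies every token from the embeddings of a
# 3-token context window plus hand-crafted orthographic features.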

from collections import Counter
import torch
import torchtext.vocab
from bidict import bidict
from string import punctuation
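
# Bidirectional mapping between IOB label strings and class indices; the
# inverse view turns predicted indices back into label strings.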
label2num = bidict({'O': 0, 'B-PER': 1, 'B-LOC': 2, 'I-PER': 3, 'B-MISC': 4, 'I-MISC': 5, 'I-LOC': 6, 'B-ORG': 7,
                    'I-ORG': 8})
num2label = label2num.inverse
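

# Build a torchtext vocabulary from token frequency counts. The specials are
# prepended, so '<unk>' gets index 0 and out-of-vocabulary tokens fall back
# to it via set_default_index(0).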
def build_vocab(dataset):
    counter = Counter()
    for document in dataset:
        counter.update(document)
    vocab = torchtext.vocab.vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
    vocab.set_default_index(0)
    return vocab
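

# Numericalize each document and frame it with <bos>/<eos> ids. Relies on
# the module-level `vocab`, which is built before the first call below.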
def data_process(dt):
    processed = [
        torch.tensor([vocab['<bos>']] + [vocab[token] for token in document] + [vocab['<eos>']], dtype=torch.long)
        for document in dt]
    return processed
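

# Map label strings to class indices and pad both ends with 0 ('O') so the
# label tensors stay aligned with the <bos>/<eos>-framed token tensors.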
def labels_process(dt):
    dt_num = [[label2num[label] for label in labels] for labels in dt]
    return [torch.tensor([0] + document + [0], dtype=torch.long) for document in dt_num]
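

# Compute 9 binary orthographic features (case, character class,
# punctuation, length) per word in the 3-word string window and append
# them to the base window of token ids.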
def add_extra_features(x_base, x_str):
    extra_features = []
    for word in x_str:
        word_features = [0] * 9
        if word.islower():
            word_features[0] = 1
        if word.isupper():
            word_features[1] = 1
        if word.isalnum():
            word_features[2] = 1
        if word.isalpha():
            word_features[3] = 1
        if word.isdigit():
            word_features[4] = 1
        if word.istitle():
            word_features[5] = 1
        for char in word:
            if char in punctuation:
                word_features[6] = 1
                break
        if len(word) == 1:
            if word in punctuation:
                word_features[7] = 1
        if len(word) < 3:
            word_features[8] = 1
        extra_features += word_features
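    # Pad with zero feature blocks when the window holds fewer than 3 words
    # (at document boundaries), so the result always has 3 * 9 = 27 values.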
    while len(extra_features) != 27:
        extra_features += [0] * 9
    extra_features = torch.tensor(extra_features)
    x_extra = torch.cat((x_base, extra_features), 0)
    return x_extra
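

# A minimal window classifier. The input is 30 indices: 3 token ids plus the
# 27 binary features, which are looked up in the embedding table like token
# ids (the 0/1 values hit the '<unk>'/'<pad>' rows). Flattened, that gives
# 30 * 200 = 6000 inputs to the linear layer.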
class NERModel(torch.nn.Module):
    def __init__(self, vocab_size):
        super(NERModel, self).__init__()
        self.emb = torch.nn.Embedding(vocab_size, 200)
        self.fc1 = torch.nn.Linear(6000, 9)

    def forward(self, x):
        x = self.emb(x)
        x = x.reshape(6000)
        x = self.fc1(x)
        return x
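

# Parallel lists: X/X_strings hold the token sequences, Y/Y_strings the tag
# sequences, one entry per document.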
X = []
Y = []
X_strings = []
Y_strings = []
with open('train.tsv', encoding='utf-8') as f:
    for l in f:
        l = l.strip().split('\t')
        tags_list = l[0].split()
        text_list = l[1].split()
        X.append(text_list)
        X_strings.append(text_list)
        Y.append(tags_list)
        Y_strings.append(tags_list)
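
# Build the vocabulary, then numericalize tokens and labels.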
vocab = build_vocab(X)
train_tokens_ids = data_process(X)
train_labels = labels_process(Y)
ner_model = NERModel(len(vocab))  # embedding table sized to the actual vocabulary
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(ner_model.parameters())
# TRAIN
print('-----TRAINING-----')
for epoch in range(5):
    loss_score = 0
    acc_score = 0
    prec_score = 0
    selected_items = 0
    recall_score = 0
    relevant_items = 0
    items_total = 0
    ner_model.train()
    for i in range(len(train_labels)):
        print(f'document {i + 1}/{len(train_labels)}')
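        # Classify every real token; positions 0 and len - 1 are <bos>/<eos>.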
        for j in range(1, len(train_labels[i]) - 1):
            X_base = train_tokens_ids[i][j - 1: j + 2]
            # Padded position j corresponds to raw token j - 1, so the
            # matching 3-word string window is X_strings[i][j - 2 : j + 1],
            # clipped at the document start.
            X_string = X_strings[i][max(0, j - 2): j + 1]
            X_extra = add_extra_features(X_base, X_string)
            Y_true = train_labels[i][j: j + 1]  # local name, so the global Y is not shadowed
            Y_predictions = ner_model(X_extra)
            acc_score += int(torch.argmax(Y_predictions) == Y_true)
            if torch.argmax(Y_predictions) != 0:
                selected_items += 1
            if torch.argmax(Y_predictions) != 0 and torch.argmax(Y_predictions) == Y_true.item():
                prec_score += 1
            if Y_true.item() != 0:
                relevant_items += 1
            if Y_true.item() != 0 and torch.argmax(Y_predictions) == Y_true.item():
                recall_score += 1
            items_total += 1
            optimizer.zero_grad()
            loss = criterion(Y_predictions.unsqueeze(0), Y_true)
            loss.backward()
            optimizer.step()
            loss_score += loss.item()
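    # Precision/recall treat 'O' (index 0) as the negative class; guard the
    # divisions in case nothing was selected or relevant this epoch.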
    precision = prec_score / selected_items if selected_items else 0.0
    recall = recall_score / relevant_items if relevant_items else 0.0
    f1_score = (2 * precision * recall) / (precision + recall) if precision + recall else 0.0
    print('epoch: ', epoch)
    print('loss: ', loss_score / items_total)
    print('acc: ', acc_score / items_total)
    print('prec: ', precision)
    print('recall: ', recall)
    print('f1: ', f1_score)
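
# Save the whole module (architecture + weights); loading it later requires
# the NERModel class definition to be importable.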
PATH = 'model.pt'
torch.save(ner_model, PATH)