161 lines
4.4 KiB
Python
161 lines
4.4 KiB
Python
from collections import Counter
|
|
import torch as torch
|
|
import torchtext.vocab
|
|
from bidict import bidict
|
|
from string import punctuation
|
|
|
|
# Two-way mapping between NER tag strings and their integer class ids.
# Ids are assigned by position in the tuple below, so 'O' is 0 and 'I-ORG' is 8.
label2num = bidict({
    tag: index
    for index, tag in enumerate((
        'O', 'B-PER', 'B-LOC', 'I-PER', 'B-MISC', 'I-MISC', 'I-LOC', 'B-ORG', 'I-ORG',
    ))
})
# Inverse view of the same mapping: class id -> tag string.
num2label = label2num.inverse
|
|
|
|
|
|
def build_vocab(dataset):
    """Build a torchtext vocabulary from an iterable of tokenised documents.

    Token occurrences are counted across all documents and passed to
    ``torchtext.vocab.vocab`` together with the special tokens
    ``<unk>``, ``<pad>``, ``<bos>`` and ``<eos>``.  The default index is set
    to 0, so looking up a token that is not in the vocabulary yields 0
    (presumably the id of ``<unk>``, given torchtext inserts specials first
    by default).

    Args:
        dataset: iterable of documents, each an iterable of token strings.

    Returns:
        The constructed vocabulary object.
    """
    token_counts = Counter()
    for tokens in dataset:
        token_counts.update(tokens)

    special_tokens = ['<unk>', '<pad>', '<bos>', '<eos>']
    vocabulary = torchtext.vocab.vocab(token_counts, specials=special_tokens)
    vocabulary.set_default_index(0)
    return vocabulary
|
|
|
|
|
|
def data_process(dt):
    """Encode tokenised documents as tensors of vocabulary ids.

    Each document is mapped token by token through the module-level
    ``vocab`` and framed with the ``<bos>`` id in front and the ``<eos>`` id
    at the end, so every tensor is two elements longer than its document.

    Args:
        dt: iterable of documents, each an iterable of token strings.

    Returns:
        A list with one ``torch.long`` tensor per document.
    """
    encoded = []
    for tokens in dt:
        ids = [vocab['<bos>']]
        ids.extend(vocab[token] for token in tokens)
        ids.append(vocab['<eos>'])
        encoded.append(torch.tensor(ids, dtype=torch.long))
    return encoded
|
|
|
|
|
|
def labels_process(dt):
    """Encode per-document tag sequences as tensors of class ids.

    Tags are translated through the module-level ``label2num`` mapping and
    a 0 (the id of ``'O'``) is added at both ends, which keeps the result
    aligned with the ``<bos>``/``<eos>`` framing applied by ``data_process``.

    Args:
        dt: iterable of documents, each an iterable of tag strings.

    Returns:
        A list with one ``torch.long`` tensor per document.
    """
    encoded = []
    for tags in dt:
        ids = [0]
        ids.extend(label2num[tag] for tag in tags)
        ids.append(0)
        encoded.append(torch.tensor(ids, dtype=torch.long))
    return encoded
|
|
|
|
|
|
def add_extra_features(x_base, x_str, window=3):
    """Append binary surface-form features for a window of words to ``x_base``.

    Nine 0/1 flags are computed per word, in this order: all lower-case,
    all upper-case, alphanumeric, alphabetic, digits only, title-case,
    contains a punctuation character, is a single punctuation character,
    shorter than three characters.  Windows with fewer than ``window`` words
    are right-padded with zeros, so the feature part always has
    ``9 * window`` entries.

    Args:
        x_base: 1-D tensor the features are concatenated to (the token ids
            of the window in the training script).
        x_str: the words of the window, at most ``window`` of them.
        window: number of word slots in the feature vector.

    Returns:
        A 1-D tensor: ``x_base`` followed by the ``9 * window`` flags.

    Raises:
        ValueError: if ``x_str`` holds more than ``window`` words.  The
            previous implementation looped forever in that situation.
    """
    features_per_word = 9
    expected_len = features_per_word * window

    extra_features = []
    for word in x_str:
        extra_features += [
            int(word.islower()),
            int(word.isupper()),
            int(word.isalnum()),
            int(word.isalpha()),
            int(word.isdigit()),
            int(word.istitle()),
            int(any(char in punctuation for char in word)),
            # The length guard matters: multi-character substrings of
            # ``punctuation`` (e.g. '!"') must not count as a single mark.
            int(len(word) == 1 and word in punctuation),
            int(len(word) < 3),
        ]

    if len(extra_features) > expected_len:
        raise ValueError(
            'expected at most {} words, got {}'.format(
                window, len(extra_features) // features_per_word))

    # Zero-fill the slots of words missing from a short window.
    extra_features += [0] * (expected_len - len(extra_features))

    extra = torch.tensor(extra_features, dtype=torch.long, device=x_base.device)
    return torch.cat((x_base, extra), 0)
|
|
|
|
|
|
class NERModel(torch.nn.Module):
    """Embedding layer followed by one linear layer over a fixed-length input.

    Every element of the input id vector is embedded, the embeddings are
    flattened into a single vector and projected to one score per label.
    The defaults reproduce the original hard-coded configuration:
    23627 embeddings of size 200, 30 input positions (30 * 200 = 6000
    linear inputs) and 9 output classes.

    Args:
        vocab_size: number of rows in the embedding table; every input id
            must be smaller than this.
        embedding_dim: size of each embedding vector.
        input_len: number of ids in one input vector.
        num_labels: number of output classes.
    """

    def __init__(self, vocab_size=23627, embedding_dim=200, input_len=30, num_labels=9):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, embedding_dim)
        self.fc1 = torch.nn.Linear(input_len * embedding_dim, num_labels)

    def forward(self, x):
        """Return unnormalised label scores for one input vector.

        Args:
            x: 1-D integer tensor holding ``input_len`` ids.

        Returns:
            1-D tensor with ``num_labels`` scores.
        """
        x = self.emb(x)
        # Flatten (input_len, embedding_dim) into the vector fc1 expects;
        # reshape raises if the input does not have exactly input_len ids.
        x = x.reshape(self.fc1.in_features)
        x = self.fc1(x)
        return x
|
|
|
|
|
|
# Training data read from train.tsv.  X / X_strings hold the token lists and
# Y / Y_strings the tag lists; each pair shares the very same list objects.
X = []
Y = []
X_strings = []
Y_strings = []

# Each line of train.tsv is "<space-separated tags>\t<space-separated tokens>".
with open('train.tsv', encoding='utf-8') as f:
    for line_no, line in enumerate(f, start=1):
        fields = line.strip().split('\t')
        if fields == ['']:
            # Blank line: nothing to learn from, and it used to crash with
            # an IndexError on the missing text column.
            continue
        if len(fields) < 2:
            raise ValueError(
                'train.tsv line {}: expected "<tags>\\t<text>"'.format(line_no))
        tags_list = fields[0].split()
        text_list = fields[1].split()
        X.append(text_list)
        X_strings.append(text_list)
        Y.append(tags_list)
        Y_strings.append(tags_list)

vocab = build_vocab(X)
train_tokens_ids = data_process(X)

train_labels = labels_process(Y)

ner_model = NERModel()
# Fail fast: an id beyond the embedding table would otherwise only surface
# as an IndexError somewhere inside the first training epoch.
if len(vocab) > ner_model.emb.num_embeddings:
    raise ValueError(
        'vocabulary has {} entries but the embedding table only {}'.format(
            len(vocab), ner_model.emb.num_embeddings))
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(ner_model.parameters())
|
|
|
|
# TRAIN
# Five epochs of per-token training: each step feeds a three-id window plus
# surface features to the model and updates on that single prediction.
# Precision/recall treat class 0 ('O') as the negative class.
print('-----TRAINING-----')
for epoch in range(5):
    loss_score = 0
    acc_score = 0
    prec_score = 0       # non-'O' predictions that were correct
    selected_items = 0   # non-'O' predictions
    recall_score = 0     # non-'O' gold labels that were predicted correctly
    relevant_items = 0   # non-'O' gold labels
    items_total = 0
    ner_model.train()
    for a, (token_ids, labels, words) in enumerate(
            zip(train_tokens_ids, train_labels, X_strings), start=1):
        print(a)
        # Positions 0 and len-1 are the <bos>/<eos> padding and are skipped.
        for j in range(1, len(labels) - 1):
            X_base = token_ids[j - 1: j + 2]
            # NOTE(review): token_ids carries a leading <bos>, words does not,
            # so this word window sits one position to the right of the id
            # window.  Left as is because any inference code presumably
            # mirrors it - confirm before realigning.
            X_string = words[j - 1: j + 2]
            X_extra = add_extra_features(X_base, X_string)

            # One-element slice: the shape CrossEntropyLoss expects as target.
            target = labels[j: j + 1]

            Y_predictions = ner_model(X_extra)

            predicted = int(torch.argmax(Y_predictions))
            gold = target.item()
            is_correct = predicted == gold

            acc_score += int(is_correct)

            if predicted != 0:
                selected_items += 1
                if is_correct:
                    prec_score += 1

            if gold != 0:
                relevant_items += 1
                if is_correct:
                    recall_score += 1

            items_total += 1

            optimizer.zero_grad()
            loss = criterion(Y_predictions.unsqueeze(0), target)
            loss.backward()
            optimizer.step()

            loss_score += loss.item()

    # Guard every ratio: an epoch without entity predictions or entity labels
    # used to abort the run with ZeroDivisionError.
    precision = prec_score / selected_items if selected_items else 0.0
    recall = recall_score / relevant_items if relevant_items else 0.0
    if precision + recall:
        f1_score = (2 * precision * recall) / (precision + recall)
    else:
        f1_score = 0.0
    print('epoch: ', epoch)
    print('loss: ', loss_score / items_total if items_total else 0.0)
    print('acc: ', acc_score / items_total if items_total else 0.0)
    print('prec: ', precision)
    print('recall: : ', recall)
    print('f1: ', f1_score)

# Saving the module object pickles the whole model, so loading it later
# requires the NERModel class to be importable at load time.
PATH = 'model.pt'
torch.save(ner_model, PATH)
|