add solutions

parent f8bee81f0a
commit 433b2abced
predict.py (Normal file, 76 lines)
@@ -0,0 +1,76 @@
import torch
from bidict import bidict

# NOTE: importing from train runs train.py top to bottom, i.e. training happens at import time
from train import add_extra_features, data_process
label2num = bidict({'O': 0, 'B-PER': 1, 'B-LOC': 2, 'I-PER': 3, 'B-MISC': 4,
                    'I-MISC': 5, 'I-LOC': 6, 'B-ORG': 7, 'I-ORG': 8})
num2label = label2num.inverse
class NERModel(torch.nn.Module):

    def __init__(self):
        super(NERModel, self).__init__()
        # 23627 vocabulary ids, each embedded into 200 dimensions
        self.emb = torch.nn.Embedding(23627, 200)
        # 30 embedded ids (3 token ids + 27 feature flags) * 200 dims = 6000 inputs, 9 label logits
        self.fc1 = torch.nn.Linear(6000, 9)

    def forward(self, x):
        x = self.emb(x)        # (30,) -> (30, 200)
        x = x.reshape(6000)    # flatten to a single feature vector
        x = self.fc1(x)        # -> (9,) logits, one per NER label
        return x
ner_model = torch.load('model.pt')
ner_model.eval()
def predict(path):
    X_base = []
    X_strings = []

    with open(f"{path}/in.tsv", 'r', encoding='utf-8') as f:
        for l in f:
            l = l.strip()
            X_base.append(l.split(' '))
            X_strings.append(l.split(' '))

    train_tokens_ids = data_process(X_base)

    predictions = []
    for i in range(len(train_tokens_ids)):
        labels_str = ''
        for j in range(1, len(train_tokens_ids[i]) - 1):
            X_base = train_tokens_ids[i][j - 1: j + 2]
            X_string = X_strings[i][j - 1: j + 2]
            X_extra = add_extra_features(X_base, X_string)

            # the model was trained on the 30-id window (token ids + feature flags),
            # so predict from X_extra, not the bare 3-token slice
            Y_pred = ner_model(X_extra)
            label = num2label[int(torch.argmax(Y_pred))]
            labels_str += label + ' '

        predictions.append(labels_str[:-1])

    # post-process: make I- tags BIO-consistent
    lines = []
    for line in predictions:
        prev_label = None
        line_corr = []
        for label in line.split():
            if label != 'O' and label[0] == 'I':
                if prev_label is None or prev_label == 'O':
                    # an I- tag with no open span starts a new one
                    label = 'B' + label[1:]
                else:
                    # continue the open span under its entity type
                    label = 'I' + prev_label[1:]
            prev_label = label
            line_corr.append(label)
        lines.append(' '.join(line_corr))

    with open(f'{path}/out.tsv', 'w', encoding='utf-8') as f:
        for l in lines:
            f.write(l + '\n')


predict('test-A')
predict('dev-0')
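The BIO clean-up pass in predict() is self-contained enough to check in isolation. A minimal, runnable sketch of the same logic (the fix_bio name is hypothetical, not defined in this commit):

# Standalone sketch of the BIO-consistency pass inside predict().
def fix_bio(labels):
    prev = None
    out = []
    for label in labels:
        if label != 'O' and label[0] == 'I':
            if prev is None or prev == 'O':
                label = 'B' + label[1:]   # orphan I- tag opens a new span
            else:
                label = 'I' + prev[1:]    # continue the open span's entity type
        prev = label
        out.append(label)
    return out

print(fix_bio(['I-PER', 'I-PER', 'O', 'I-ORG']))
# ['B-PER', 'I-PER', 'O', 'B-ORG']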
train.py (Normal file, 160 lines)
@@ -0,0 +1,160 @@
from collections import Counter
from string import punctuation

import torch
import torchtext.vocab
from bidict import bidict
label2num = bidict({'O': 0, 'B-PER': 1, 'B-LOC': 2, 'I-PER': 3, 'B-MISC': 4,
                    'I-MISC': 5, 'I-LOC': 6, 'B-ORG': 7, 'I-ORG': 8})
num2label = label2num.inverse
def build_vocab(dataset):
    counter = Counter()
    for document in dataset:
        counter.update(document)
    vocab = torchtext.vocab.vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
    vocab.set_default_index(0)  # unknown tokens map to <unk>
    return vocab
def data_process(dt):
    # wrap every document in <bos>/<eos> and map tokens to vocabulary ids
    processed = [
        torch.tensor([vocab['<bos>']] + [vocab[token] for token in document] + [vocab['<eos>']],
                     dtype=torch.long)
        for document in dt]
    return processed
def labels_process(dt):
    dt_num = [[label2num[label] for label in labels] for labels in dt]
    # pad label sequences with 'O' (0) so they line up with <bos>/<eos>
    return [torch.tensor([0] + document + [0], dtype=torch.long) for document in dt_num]
def add_extra_features(x_base, x_str):
    # 9 hand-crafted binary flags per word in the 3-word window
    extra_features = []
    for word in x_str:
        word_features = [0] * 9
        if word.islower():
            word_features[0] = 1
        if word.isupper():
            word_features[1] = 1
        if word.isalnum():
            word_features[2] = 1
        if word.isalpha():
            word_features[3] = 1
        if word.isdigit():
            word_features[4] = 1
        if word.istitle():
            word_features[5] = 1
        for char in word:
            if char in punctuation:
                word_features[6] = 1
                break
        if len(word) == 1 and word in punctuation:
            word_features[7] = 1
        if len(word) < 3:
            word_features[8] = 1
        extra_features += word_features
    # windows at sentence edges have fewer than 3 words: pad to 27 flags
    while len(extra_features) != 27:
        extra_features += [0] * 9
    extra_features = torch.tensor(extra_features)
    # append the flags to the 3 token ids; the model embeds all 30 ids alike
    x_extra = torch.cat((x_base, extra_features), 0)
    return x_extra
class NERModel(torch.nn.Module):

    def __init__(self):
        super(NERModel, self).__init__()
        # 23627 vocabulary ids, each embedded into 200 dimensions
        self.emb = torch.nn.Embedding(23627, 200)
        # 30 embedded ids (3 token ids + 27 feature flags) * 200 dims = 6000 inputs, 9 label logits
        self.fc1 = torch.nn.Linear(6000, 9)

    def forward(self, x):
        x = self.emb(x)        # (30,) -> (30, 200)
        x = x.reshape(6000)    # flatten to a single feature vector
        x = self.fc1(x)        # -> (9,) logits, one per NER label
        return x
X = []
Y = []
X_strings = []
Y_strings = []

# train.tsv: one document per line, "tags<TAB>text", both space-separated
with open('train.tsv', encoding='utf-8') as f:
    for l in f:
        l = l.strip().split('\t')
        tags_list = l[0].split()
        text_list = l[1].split()
        X.append(text_list)
        X_strings.append(text_list)
        Y.append(tags_list)
        Y_strings.append(tags_list)
vocab = build_vocab(X)
train_tokens_ids = data_process(X)
train_labels = labels_process(Y)

ner_model = NERModel()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(ner_model.parameters())
# TRAIN
print('-----TRAINING-----')
for epoch in range(1):
    loss_score = 0
    acc_score = 0
    prec_score = 0
    selected_items = 0
    recall_score = 0
    relevant_items = 0
    items_total = 0
    ner_model.train()
    for i in range(len(train_labels)):
        print(i + 1)  # progress: documents processed so far
        for j in range(1, len(train_labels[i]) - 1):
            # sliding 3-token window around position j, plus hand-crafted flags
            X_base = train_tokens_ids[i][j - 1: j + 2]
            X_string = X_strings[i][j - 1: j + 2]
            X_extra = add_extra_features(X_base, X_string)

            Y = train_labels[i][j: j + 1]

            Y_predictions = ner_model(X_extra)

            acc_score += int(torch.argmax(Y_predictions) == Y)

            # precision counts: predicted non-O labels, and how many were right
            if torch.argmax(Y_predictions) != 0:
                selected_items += 1
                if torch.argmax(Y_predictions) == Y.item():
                    prec_score += 1

            # recall counts: gold non-O labels, and how many were found
            if Y.item() != 0:
                relevant_items += 1
                if torch.argmax(Y_predictions) == Y.item():
                    recall_score += 1

            items_total += 1

            optimizer.zero_grad()
            loss = criterion(Y_predictions.unsqueeze(0), Y)
            loss.backward()
            optimizer.step()

            loss_score += loss.item()

    precision = prec_score / selected_items
    recall = recall_score / relevant_items
    f1_score = (2 * precision * recall) / (precision + recall)
    print('epoch: ', epoch)
    print('loss: ', loss_score / items_total)
    print('acc: ', acc_score / items_total)
    print('prec: ', precision)
    print('recall: ', recall)
    print('f1: ', f1_score)

PATH = 'model.pt'
torch.save(ner_model, PATH)
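The model's input contract is easy to smoke-test with random ids. A minimal sketch, assuming NERModel as defined in train.py above is in scope; the snippet is illustrative and not part of the commit:

import torch

model = NERModel()
window = torch.randint(0, 23627, (3,))         # 3 token ids
flags = torch.randint(0, 2, (27,))             # 27 binary feature flags
logits = model(torch.cat((window, flags), 0))  # 30 ids -> 30 * 200 = 6000 -> 9
print(logits.shape)                            # torch.Size([9])

Note that the feature flags pass through the same embedding table as vocabulary ids 0 and 1, which is exactly how add_extra_features concatenates them onto the token window.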