def read_data(path):
    with open(path, 'r') as f:
        dataset = [line.strip().split() for line in f]
    return dataset
dataset = read_data('train/train.tsv')
train_x = [x[1] for x in dataset]
train_y = [y[0] for y in dataset]
train_y
['B-ORG', 'O', 'B-LOC', 'B-LOC', 'B-MISC', 'B-MISC', 'B-ORG', 'B-ORG', 'O', 'B-LOC', 'B-MISC', 'O', 'B-MISC', 'B-LOC', 'B-PER', 'B-MISC', 'B-LOC', 'O', 'O', 'B-ORG', 'O', 'O', 'B-ORG', 'B-MISC', 'O', 'O', 'B-MISC', 'B-MISC', 'B-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'B-ORG', 'B-ORG', 'B-ORG', 'O', 'O', 'B-LOC', 'O', 'B-LOC', 'O', 'B-ORG', 'O', 'B-LOC', 'B-MISC', 'B-LOC', 'B-MISC', 'B-MISC', 'O', 'O', 'B-MISC', 'B-MISC', 'B-MISC', 'B-PER', 'B-PER', 'O', 'B-LOC', 'O', 'B-MISC', 'B-LOC', 'O', 'O', 'B-LOC', 'B-MISC', 'B-LOC', 'B-LOC', 'B-MISC', 'B-MISC', 'B-PER', 'B-LOC', 'O', 'B-ORG', 'O', 'B-MISC', 'B-ORG', 'B-PER', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'O', 'B-PER', 'B-ORG', 'B-ORG', 'B-ORG', 'B-ORG', 'O', 'B-ORG', 'B-ORG', 'B-LOC', 'B-MISC', 'O', 'B-LOC', 'O', 'B-LOC', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'B-MISC', 'O', 'O', 'B-PER', 'B-LOC', 'O', 'O', 'B-ORG', 'O', 'O', 'O', 'B-LOC', 'B-MISC', 'O', 'B-LOC', 'B-LOC', 'B-MISC', 'O', 'B-PER', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'O', 'B-PER', 'B-LOC', 'B-ORG', 'B-ORG', 'B-LOC', 'B-LOC', 'B-LOC', 'O', 'B-ORG', 'O', 'B-LOC', 'B-MISC', 'B-ORG', 'B-ORG', 'B-ORG', 'O', 'B-LOC', 'B-LOC', 'O', 'O', 'O', 'B-ORG', 'B-ORG', 'O', 'B-ORG', 'B-MISC', 'B-MISC', 'B-ORG', 'O', 'B-ORG', 'B-ORG', 'B-ORG', 'B-MISC', 'B-MISC', 'B-ORG', 'B-ORG', 'B-LOC', 'B-ORG', 'B-ORG', 'B-ORG', 'B-ORG', 'O', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'B-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'B-LOC', 'B-LOC', 'O', 'B-LOC', 'B-MISC', 'B-MISC', 'B-LOC', 'B-LOC', 'O', 'B-MISC', 'B-LOC', 'O', 'B-PER', 'O', 'B-LOC', 'B-MISC', 'B-MISC', 'B-MISC', 'B-LOC', 'O', 'B-LOC', 'B-MISC', 'O', 'O', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'B-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'B-MISC', 'O', 'B-MISC', 'B-ORG', 'B-MISC', 'O', 'B-LOC', 'B-MISC', 'O', 'B-MISC', 'O', 'O', 'O', 'O', 'B-LOC', 'B-LOC', 'B-LOC', 'O', 'B-PER', 'O', 'O', 'B-MISC', 'O', 'B-MISC', 'B-MISC', 'O', 'B-MISC', 'O', 'O', 'O', 'B-ORG', 'B-ORG', 'B-MISC', 'B-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'O', 'B-ORG', 'O', 'B-PER', 'B-ORG', 'B-MISC', 'B-ORG', 'B-PER', 'B-ORG', 'O', 'B-ORG', 'B-ORG', 'B-MISC', 'B-LOC', 'B-MISC', 'O', 'B-ORG', 'B-MISC', 'B-ORG', 'O', 'O', 'B-ORG', 'B-ORG', 'B-MISC', 'B-ORG', 'O', 'B-MISC', 'B-LOC', 'B-MISC', 'B-LOC', 'B-LOC', 'O', 'B-PER', 'B-LOC', 'B-PER', 'O', 'B-LOC', 'B-MISC', 'O', 'B-LOC', 'B-ORG', 'B-ORG', 'O', 'O', 'O', 'B-PER', 'B-MISC', 'B-PER', 'B-ORG', 'B-ORG', 'O', 'B-ORG', 'B-MISC', 'B-MISC', 'B-ORG', 'B-LOC', 'O', 'O', 'B-ORG', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'B-ORG', 'B-LOC', 'O', 'O', 'O', 'B-ORG', 'O', 
'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'O', 'O', 'B-ORG', 'B-ORG', 'B-PER', 'B-MISC', 'B-MISC', 'B-ORG', 'O', 'B-LOC', 'O', 'B-MISC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-MISC', 'O', 'B-MISC', 'B-MISC', 'B-MISC', 'B-LOC', 'B-MISC', 'B-LOC', 'B-MISC', 'B-LOC', 'B-LOC', 'O', 'O', 'B-ORG', 'B-LOC', 'B-LOC', 'B-MISC', 'O', 'B-LOC', 'O', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-PER', 'O', 'B-LOC', 'B-LOC', 'B-ORG', 'B-LOC', 'B-ORG', 'B-MISC', 'B-MISC', 'O', 'B-PER', 'B-MISC', 'B-MISC', 'B-LOC', 'B-LOC', 'B-MISC', 'B-LOC', 'B-LOC', 'B-MISC', 'B-ORG', 'B-ORG', 'B-ORG', 'B-ORG', 'O', 'B-ORG', 'B-ORG', 'B-LOC', 'B-LOC', 'B-ORG', 'B-ORG', 'O', 'O', 'O', 'B-ORG', 'O', 'O', 'O', 'B-LOC', 'B-LOC', 'B-LOC', 'O', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-MISC', 'B-ORG', 'B-MISC', 'O', 'B-ORG', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'O', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'B-MISC', 'B-ORG', 'B-LOC', 'B-ORG', 'B-ORG', 'B-LOC', 'B-ORG', 'B-LOC', 'B-ORG', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'B-MISC', 'O', 'O', 'B-LOC', 'B-LOC', 'O', 'B-LOC', 'B-PER', 'O', 'B-MISC', 'B-LOC', 'B-MISC', 'B-MISC', 'O', 'B-LOC', 'B-PER', 'B-LOC', 'B-LOC', 'O', 'O', 'O', 'B-ORG', 'B-MISC', 'B-ORG', 'B-MISC', 'B-MISC', 'B-LOC', 'B-PER', 'B-MISC', 'B-ORG', 'O', 'B-ORG', 'B-ORG', 'B-LOC', 'B-MISC', 'O', 'B-MISC', 'B-MISC', 'B-MISC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-MISC', 'B-ORG', 'B-PER', 'B-MISC', 'B-ORG', 'B-PER', 'B-LOC', 'B-MISC', 'B-LOC', 'B-LOC', 'B-ORG', 'B-LOC', 'B-PER', 'O', 'B-ORG', 'B-LOC', 'B-MISC', 'O', 'B-MISC', 'B-LOC', 'B-ORG', 'B-PER', 'O', 'B-MISC', 'O', 'B-ORG', 'B-ORG', 'B-LOC', 'B-ORG', 'B-ORG', 'B-ORG', 'B-ORG', 'B-ORG', 'O', 'B-ORG', 'B-LOC', 'B-LOC', 'B-ORG', 'O', 'B-LOC', 'B-PER', 'B-LOC', 'O', 'B-ORG', 'O', 'B-LOC', 'B-ORG', 'B-ORG', 'B-LOC', 'B-ORG', 'B-ORG', 'B-MISC', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'B-PER', 'O', 'O', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'B-LOC', 'B-ORG', 'O', 'B-LOC', 'B-LOC', 'B-LOC', 'B-MISC', 'B-LOC', 'B-LOC', 'O', 'B-ORG', 'O', 'B-PER', 'B-ORG', 'B-MISC', 'O', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'O', 'B-LOC', 'O', 'B-MISC', 'B-MISC', 'B-ORG', 'B-MISC', 'O', 'B-LOC', 'O', 'O', 'B-ORG', 'O', 'O', 'B-PER', 'B-PER', 'B-LOC', 'B-LOC', 'B-ORG', 'O', 'B-LOC', 'B-ORG', 'B-MISC', 'O', 'B-LOC', 'O', 'O', 'B-MISC', 'B-ORG', 'B-ORG', 'B-ORG', 'B-ORG', 'B-ORG', 'B-ORG', 'B-ORG', 'B-ORG', 'B-LOC', 'B-MISC', 'B-LOC', 'B-LOC', 'B-LOC', 'O', 'B-MISC', 'B-MISC', 'B-LOC', 'B-ORG', 'O', 'B-PER', 'B-PER', 'O', 'O', 'B-PER', 'O', 'B-LOC', 'B-LOC', 'B-MISC', 'B-LOC', 'O', 'B-MISC', 'O', 'O', 'B-ORG', 'B-LOC', 'B-MISC', 'B-ORG', 'B-ORG', 'B-MISC', 'O', 'B-LOC', 'B-LOC', 'B-MISC', 'B-ORG', 'B-ORG', 'B-LOC', 'B-ORG', 'B-LOC', 'B-ORG', 'B-LOC', 'B-ORG', 'B-ORG', 'O', 'O', 'B-ORG', 'B-LOC', 'B-ORG', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
import torchtext.vocab
from collections import Counter
def build_vocab(dataset):
    counter = Counter()
    for document in dataset:
        counter.update(document)
    vocab = torchtext.vocab.vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
    vocab.set_default_index(0)
    return vocab
train_x = [x.split() for x in train_x]
vocab = build_vocab(train_x)
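A quick sanity check of the vocabulary (assuming a torchtext version where `vocab()` and `set_default_index` are available, roughly 0.12+): the specials occupy the first indices, and any out-of-vocabulary token falls back to index 0, i.e. '<unk>'. The unknown-token string below is made up for illustration.
print(vocab['<unk>'], vocab['<pad>'], vocab['<bos>'], vocab['<eos>'])  # 0 1 2 3 (specials come first)
print(vocab['token-surely-not-in-training-data'])                     # 0, because of set_default_index(0)
print(len(vocab))                                                      # total vocabulary size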
import torch

def data_process(dt):
    # Map each document's tokens to vocabulary ids and surround them with <bos>/<eos>.
    return [torch.tensor([vocab['<bos>']] + [vocab[token] for token in document] + [vocab['<eos>']],
                         dtype=torch.long)
            for document in dt]

def labels_process(dt):
    # Each element of dt is a single integer tag; wrap it as [0, tag, 0] so the
    # indices line up with the <bos>/<eos> padding added in data_process.
    labels = []
    for document in dt:
        temp = []
        temp.append(0)
        temp.append(document)
        temp.append(0)
        labels.append(torch.tensor(temp, dtype=torch.long))
    return labels
    #return [ torch.tensor([0] + document + [0], dtype = torch.long) for document in dt]

ner_tags = {'O': 0, 'B-ORG': 1, 'I-ORG': 2, 'B-PER': 3, 'I-PER': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}
train_tokens_ids = data_process(train_x)
dev_x = read_data('dev-0/in.tsv')
dev_y = read_data('dev-0/expected.tsv')
test_x = read_data('test-A/in.tsv')
dev_x = [x[0].split() for x in dev_x]
dev_y = [y[0].split() for y in dev_y]
test_x = [x[0].split() for x in test_x]
train_y = [y[0] for y in dataset]
display(train_y[:5])
train_y = [ner_tags.get(tag) for tag in train_y]
train_y[:5]
['B-ORG', 'O', 'B-LOC', 'B-LOC', 'B-MISC']
[1, 0, 5, 5, 7]
dev_y = [ner_tags.get(tag) for y in dev_y for tag in y]
test_tokens_ids = data_process(dev_x)
train_labels = labels_process(train_y)
test_labels = labels_process(dev_y)
class NERModel(torch.nn.Module):

    def __init__(self):
        super(NERModel, self).__init__()
        # Hard-coded upper bound on the vocabulary size, 200-dimensional embeddings.
        self.emb = torch.nn.Embedding(23627, 200)
        # 12 input ids per window (3 token ids + 9 binary features) x 200 dims = 2400.
        self.fc1 = torch.nn.Linear(2400, 9)
        #self.softmax = torch.nn.Softmax(dim=1)
        # not needed, because we use https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
        # as the criterion

    def forward(self, x):
        x = self.emb(x)
        x = x.reshape(2400)
        x = self.fc1(x)
        #x = self.softmax(x)
        return x
ner_model = NERModel()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(ner_model.parameters())
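As the comment in NERModel notes, no Softmax layer is needed because CrossEntropyLoss applies log-softmax to raw logits internally. A minimal shape-and-loss sanity check on a dummy window (the input ids below are random and purely illustrative; any ids below 23627 work):
dummy_ids = torch.randint(0, 23627, (12,))   # 3 token ids + 9 binary features, as built later by add_features
logits = ner_model(dummy_ids)                # 9 raw scores, one per NER tag
target = torch.tensor([ner_tags['B-LOC']])   # gold class index for this position
loss = criterion(logits.unsqueeze(0), target)
print(logits.shape, loss.item())             # torch.Size([9]) and a scalar loss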
import string

def add_features(tens, tokens):
    # Nine hand-crafted binary features for the middle token of the window; they are
    # concatenated to the token ids and later looked up in the embedding as ids 0/1.
    array = [0, 0, 0, 0, 0, 0, 0, 0, 0]
    if len(tokens) >= 2:
        if len(tokens[1]) >= 1:
            word = tokens[1]
            if word[0].isupper():
                array[0] = 1
            if word.isalnum():
                array[1] = 1
            for i in word:
                # mark whether the word contains any punctuation character
                if i in string.punctuation:
                    array[2] = 1
            if word.isnumeric():
                array[3] = 1
            if word.isupper():
                array[4] = 1
            if '-' in word:
                array[5] = 1
            if '/' in word:
                array[6] = 1
            if len(word) > 3:
                array[7] = 1
            if len(word) > 6:
                array[8] = 1
    x = torch.tensor(array)
    new_tensor = torch.cat((tens, x), 0)
    return new_tensor
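A quick illustration of add_features on a made-up window (the token ids are arbitrary; the middle word 'Warsaw' drives the nine binary features):
window_ids = torch.tensor([2, 17, 3])            # hypothetical <bos>, token, <eos> ids
window_tokens = ['<bos>', 'Warsaw', '<eos>']
print(add_features(window_ids, window_tokens))
# tensor([ 2, 17,  3,  1,  1,  0,  0,  0,  0,  0,  1,  0])
# -> 3 token ids followed by: capitalized, alphanumeric, punctuation, numeric,
#    all-caps, '-', '/', len > 3, len > 6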
for epoch in range(50):
    loss_score = 0
    acc_score = 0
    prec_score = 0
    selected_items = 0
    recall_score = 0
    relevant_items = 0
    items_total = 0
    ner_model.train()
    #for i in range(len(train_labels)):
    for i in range(100):
        for j in range(1, len(train_labels[i]) - 1):
            # 3-id window around position j, plus the raw tokens for feature extraction
            X_base = train_tokens_ids[i][j-1: j+2]
            X_add = train_x[i][j-1: j+2]
            X_final = add_features(X_base, X_add)
            Y = train_labels[i][j: j+1]
            Y_predictions = ner_model(X_final)
            acc_score += int(torch.argmax(Y_predictions) == Y)
            if torch.argmax(Y_predictions) != 0:
                selected_items += 1
            if torch.argmax(Y_predictions) != 0 and torch.argmax(Y_predictions) == Y.item():
                prec_score += 1
            if Y.item() != 0:
                relevant_items += 1
            if Y.item() != 0 and torch.argmax(Y_predictions) == Y.item():
                recall_score += 1
            items_total += 1
            optimizer.zero_grad()
            loss = criterion(Y_predictions.unsqueeze(0), Y)
            loss.backward()
            optimizer.step()
            loss_score += loss.item()
    precision = prec_score / selected_items
    recall = recall_score / relevant_items
    #f1_score = (2*precision * recall) / (precision + recall)
    display('epoch: ', epoch)
    display('loss: ', loss_score / items_total)
    display('acc: ', acc_score / items_total)
    display('prec: ', precision)
    display('recall: ', recall)
    #display('f1: ', f1_score)
Training log (one row per epoch):
epoch | loss                 | acc  | prec                 | recall
0     | 2.811322446731947    | 0.48 | 0.18604651162790697  | 0.1702127659574468
1     | 2.5642633876085097   | 0.43 | 0.1702127659574468   | 0.1702127659574468
2     | 2.2745147155076166   | 0.54 | 0.3023255813953488   | 0.2765957446808511
3     | 2.3077734905840908   | 0.56 | 0.3023255813953488   | 0.2765957446808511
4     | 2.2327055485211984   | 0.51 | 0.24489795918367346  | 0.2553191489361702
5     | 2.032022957816762    | 0.58 | 0.2978723404255319   | 0.2978723404255319
6     | 1.9094040171859614   | 0.57 | 0.2765957446808511   | 0.2765957446808511
7     | 1.8801336237322357   | 0.54 | 0.2391304347826087   | 0.23404255319148937
8     | 1.853852765722122    | 0.53 | 0.2222222222222222   | 0.2127659574468085
9     | 1.8288560365608282   | 0.53 | 0.2222222222222222   | 0.2127659574468085
10    | 1.8022360281114742   | 0.53 | 0.2222222222222222   | 0.2127659574468085
11    | 1.775143896874324    | 0.53 | 0.22727272727272727  | 0.2127659574468085
12    | 1.748205496848568    | 0.53 | 0.22727272727272727  | 0.2127659574468085
13    | 1.7217520299459284   | 0.54 | 0.24444444444444444  | 0.23404255319148937
14    | 1.6958234204566542   | 0.55 | 0.2608695652173913   | 0.2553191489361702
15    | 1.6712835076041666   | 0.55 | 0.2608695652173913   | 0.2553191489361702
16    | 1.6480589298788255   | 0.55 | 0.2608695652173913   | 0.2553191489361702
17    | 1.6257991834238783   | 0.55 | 0.2608695652173913   | 0.2553191489361702
18    | 1.604242644636688    | 0.55 | 0.26666666666666666  | 0.2553191489361702
19    | 1.5837320431148691   | 0.55 | 0.26666666666666666  | 0.2553191489361702
20    | 1.5641135577083332   | 0.54 | 0.24444444444444444  | 0.23404255319148937
21    | 1.5454202877922216   | 0.54 | 0.24444444444444444  | 0.23404255319148937
22    | 1.5275404900076683   | 0.54 | 0.24444444444444444  | 0.23404255319148937
23    | 1.5099791488426855   | 0.53 | 0.22727272727272727  | 0.2127659574468085
24    | 1.4932281806698302   | 0.53 | 0.22727272727272727  | 0.2127659574468085
25    | 1.4772486361171469   | 0.53 | 0.22727272727272727  | 0.2127659574468085
26    | 1.4617241937015206   | 0.53 | 0.22727272727272727  | 0.2127659574468085
27    | 1.447145789535134    | 0.53 | 0.22727272727272727  | 0.2127659574468085
28    | 1.433001351452549    | 0.53 | 0.22727272727272727  | 0.2127659574468085
29    | 1.4193171267636353   | 0.54 | 0.25                 | 0.23404255319148937
30    | 1.4062099850556116   | 0.55 | 0.2727272727272727   | 0.2553191489361702
31    | 1.3941219550243114   | 0.55 | 0.2727272727272727   | 0.2553191489361702
32    | 1.3825843345944304   | 0.56 | 0.29545454545454547  | 0.2765957446808511
33    | 1.3714176365407185   | 0.56 | 0.3023255813953488   | 0.2765957446808511
34    | 1.3609581479639745   | 0.56 | 0.3023255813953488   | 0.2765957446808511
35    | 1.3509947879862738   | 0.56 | 0.3023255813953488   | 0.2765957446808511
36    | 1.3424826521927025   | 0.56 | 0.3023255813953488   | 0.2765957446808511
37    | 1.3336302372731734   | 0.56 | 0.3023255813953488   | 0.2765957446808511
38    | 1.3246490387670928   | 0.56 | 0.3023255813953488   | 0.2765957446808511
39    | 1.316349835752626    | 0.57 | 0.30952380952380953  | 0.2765957446808511
40    | 1.3090153592341813   | 0.57 | 0.30952380952380953  | 0.2765957446808511
41    | 1.3016801220795606   | 0.56 | 0.3023255813953488   | 0.2765957446808511
42    | 1.2947906140016858   | 0.56 | 0.3023255813953488   | 0.2765957446808511
43    | 1.2887717709777644   | 0.56 | 0.3023255813953488   | 0.2765957446808511
44    | 1.2825759449476026   | 0.56 | 0.3023255813953488   | 0.2765957446808511
45    | 1.2770079451325      | 0.56 | 0.3023255813953488   | 0.2765957446808511
46    | 1.2715366566940793   | 0.56 | 0.30952380952380953  | 0.2765957446808511
47    | 1.266929566776089    | 0.56 | 0.30952380952380953  | 0.2765957446808511
48    | 1.2626329537964194   | 0.56 | 0.30952380952380953  | 0.2765957446808511
49    | 1.2600517650827532   | 0.56 | 0.30952380952380953  | 0.2765957446808511
(2*precision * recall) / (precision + recall)
0.29213483146067415
Y_predictions
tensor([ -0.5326, -1.1218, -10.0297, -0.1610, -10.9741, 1.3533, -11.9781, 0.6097, -9.1263], grad_fn=<AddBackward0>)
# Reverse mapping for decoding predictions; it must be the exact inverse of ner_tags.
ner_tags_re = {
    0: 'O',
    1: 'B-ORG',
    2: 'I-ORG',
    3: 'B-PER',
    4: 'I-PER',
    5: 'B-LOC',
    6: 'I-LOC',
    7: 'B-MISC',
    8: 'I-MISC'
}
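Since ner_tags_re has to be the exact inverse of ner_tags, it can equivalently be derived from it, which keeps the two dictionaries from drifting apart:
ner_tags_re = {index: tag for tag, index in ner_tags.items()}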
def generate_out(folder_path):
    ner_model.eval()
    ner_model.cpu()
    print('Generating out')
    X_dev = []
    with open(f"{folder_path}/in.tsv", 'r') as file:
        for line in file:
            line = line.strip()
            X_dev.append(line.split(' '))
    test_tokens_ids = data_process(X_dev)
    predicted_values = []
    # for i in range(100):
    for i in range(len(test_tokens_ids)):
        pred_string = ''
        for j in range(1, len(test_tokens_ids[i]) - 1):
            X = test_tokens_ids[i][j - 1: j + 2]
            X_raw_single = X_dev[i][j - 1: j + 2]
            X = add_features(X, X_raw_single)
            # X = X.to(device)
            # print('train is cuda?', X.is_cuda)
            try:
                Y_predictions = ner_model(X)
                id = torch.argmax(Y_predictions)
                val = ner_tags_re[int(id)]
                pred_string += val + ' '
            except Exception as e:
                print('Error', e)
        predicted_values.append(pred_string[:-1])
    lines = []
    for line in predicted_values:
        last_label = None
        line = line.split(' ')
        new_line = []
        for label in line:
            # Repair inconsistent BIO tags: an I-X that does not continue an entity becomes B-X.
            if label != "O" and label[0:2] == "I-":
                if last_label == None or last_label == "O":
                    label = label.replace('I-', 'B-')
                else:
                    label = "I-" + last_label[2:]
            last_label = label
            new_line.append(label)
        lines.append(" ".join(new_line))
    with open(f"{folder_path}/out.tsv", "w") as f:
        for line in lines:
            f.write(str(line) + "\n")
generate_out('dev-0')
generate_out('test-A')
Generating out
Generating out
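For reference, the post-processing loop inside generate_out enforces BIO consistency: an I-X tag that does not continue a preceding entity is promoted to B-X, and an I- tag that follows a different entity type inherits that type. A standalone sketch of the same rule on made-up tags:
def repair_bio(labels):
    # Mirrors the label-repair loop in generate_out.
    fixed, last = [], None
    for label in labels:
        if label != "O" and label.startswith("I-"):
            if last is None or last == "O":
                label = "B-" + label[2:]
            else:
                label = "I-" + last[2:]
        last = label
        fixed.append(label)
    return fixed

print(repair_bio(["I-PER", "I-PER", "O", "I-LOC"]))
# ['B-PER', 'I-PER', 'O', 'B-LOC']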