en-ner-conll-2003/.ipynb_checkpoints/gotta_go-checkpoint.ipynb
2022-06-14 23:30:32 +02:00

102 KiB

def read_data(path, encoding='utf-8'):
    """Read a text file and split every line into whitespace-separated fields.

    Args:
        path: Location of the file to read.
        encoding: Text encoding used to decode the file. Defaults to UTF-8 so
            the result does not depend on the platform's locale encoding.

    Returns:
        A list with one entry per line of the file; each entry is the list
        produced by ``line.strip().split()``. Blank lines yield ``[]``.

    Note:
        ``split()`` without arguments breaks on any run of whitespace, so
        tabs and spaces are treated alike and tab-separated column
        boundaries are not preserved in the result.
    """
    with open(path, 'r', encoding=encoding) as f:
        return [line.strip().split() for line in f]
# Load the training file; each row is the list of whitespace-separated
# fields of one line.
dataset = read_data('train/train.tsv')
# NOTE(review): read_data splits on every whitespace run, so row[0] / row[1]
# are just the first and second field of a line, not whole tab-separated
# columns -- confirm this is what the rest of the notebook expects.
train_x = [row[1] for row in dataset]
train_y = [row[0] for row in dataset]
train_y
['B-ORG',
 'O',
 'B-LOC',
 'B-LOC',
 'B-MISC',
 'B-MISC',
 'B-ORG',
 'B-ORG',
 'O',
 'B-LOC',
 'B-MISC',
 'O',
 'B-MISC',
 'B-LOC',
 'B-PER',
 'B-MISC',
 'B-LOC',
 'O',
 'O',
 'B-ORG',
 'O',
 'O',
 'B-ORG',
 'B-MISC',
 'O',
 'O',
 'B-MISC',
 'B-MISC',
 'B-MISC',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-ORG',
 'B-ORG',
 'B-ORG',
 'B-ORG',
 'O',
 'O',
 'B-LOC',
 'O',
 'B-LOC',
 'O',
 'B-ORG',
 'O',
 'B-LOC',
 'B-MISC',
 'B-LOC',
 'B-MISC',
 'B-MISC',
 'O',
 'O',
 'B-MISC',
 'B-MISC',
 'B-MISC',
 'B-PER',
 'B-PER',
 'O',
 'B-LOC',
 'O',
 'B-MISC',
 'B-LOC',
 'O',
 'O',
 'B-LOC',
 'B-MISC',
 'B-LOC',
 'B-LOC',
 'B-MISC',
 'B-MISC',
 'B-PER',
 'B-LOC',
 'O',
 'B-ORG',
 'O',
 'B-MISC',
 'B-ORG',
 'B-PER',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'O',
 'B-PER',
 'B-ORG',
 'B-ORG',
 'B-ORG',
 'B-ORG',
 'O',
 'B-ORG',
 'B-ORG',
 'B-LOC',
 'B-MISC',
 'O',
 'B-LOC',
 'O',
 'B-LOC',
 'B-ORG',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-ORG',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-ORG',
 'B-MISC',
 'O',
 'O',
 'B-PER',
 'B-LOC',
 'O',
 'O',
 'B-ORG',
 'O',
 'O',
 'O',
 'B-LOC',
 'B-MISC',
 'O',
 'B-LOC',
 'B-LOC',
 'B-MISC',
 'O',
 'B-PER',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'O',
 'B-PER',
 'B-LOC',
 'B-ORG',
 'B-ORG',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'O',
 'B-ORG',
 'O',
 'B-LOC',
 'B-MISC',
 'B-ORG',
 'B-ORG',
 'B-ORG',
 'O',
 'B-LOC',
 'B-LOC',
 'O',
 'O',
 'O',
 'B-ORG',
 'B-ORG',
 'O',
 'B-ORG',
 'B-MISC',
 'B-MISC',
 'B-ORG',
 'O',
 'B-ORG',
 'B-ORG',
 'B-ORG',
 'B-MISC',
 'B-MISC',
 'B-ORG',
 'B-ORG',
 'B-LOC',
 'B-ORG',
 'B-ORG',
 'B-ORG',
 'B-ORG',
 'O',
 'B-ORG',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-LOC',
 'O',
 'B-MISC',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-ORG',
 'B-ORG',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-MISC',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'O',
 'B-LOC',
 'B-MISC',
 'B-MISC',
 'B-LOC',
 'B-LOC',
 'O',
 'B-MISC',
 'B-LOC',
 'O',
 'B-PER',
 'O',
 'B-LOC',
 'B-MISC',
 'B-MISC',
 'B-MISC',
 'B-LOC',
 'O',
 'B-LOC',
 'B-MISC',
 'O',
 'O',
 'B-ORG',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-MISC',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-MISC',
 'B-MISC',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-LOC',
 'O',
 'O',
 'O',
 'O',
 'B-MISC',
 'O',
 'B-MISC',
 'B-ORG',
 'B-MISC',
 'O',
 'B-LOC',
 'B-MISC',
 'O',
 'B-MISC',
 'O',
 'O',
 'O',
 'O',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'O',
 'B-PER',
 'O',
 'O',
 'B-MISC',
 'O',
 'B-MISC',
 'B-MISC',
 'O',
 'B-MISC',
 'O',
 'O',
 'O',
 'B-ORG',
 'B-ORG',
 'B-MISC',
 'B-MISC',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-MISC',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-ORG',
 'O',
 'B-ORG',
 'O',
 'B-PER',
 'B-ORG',
 'B-MISC',
 'B-ORG',
 'B-PER',
 'B-ORG',
 'O',
 'B-ORG',
 'B-ORG',
 'B-MISC',
 'B-LOC',
 'B-MISC',
 'O',
 'B-ORG',
 'B-MISC',
 'B-ORG',
 'O',
 'O',
 'B-ORG',
 'B-ORG',
 'B-MISC',
 'B-ORG',
 'O',
 'B-MISC',
 'B-LOC',
 'B-MISC',
 'B-LOC',
 'B-LOC',
 'O',
 'B-PER',
 'B-LOC',
 'B-PER',
 'O',
 'B-LOC',
 'B-MISC',
 'O',
 'B-LOC',
 'B-ORG',
 'B-ORG',
 'O',
 'O',
 'O',
 'B-PER',
 'B-MISC',
 'B-PER',
 'B-ORG',
 'B-ORG',
 'O',
 'B-ORG',
 'B-MISC',
 'B-MISC',
 'B-ORG',
 'B-LOC',
 'O',
 'O',
 'B-ORG',
 'O',
 'O',
 'B-MISC',
 'O',
 'O',
 'O',
 'B-MISC',
 'B-ORG',
 'B-LOC',
 'O',
 'O',
 'O',
 'B-ORG',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-ORG',
 'O',
 'O',
 'B-ORG',
 'B-ORG',
 'B-PER',
 'B-MISC',
 'B-MISC',
 'B-ORG',
 'O',
 'B-LOC',
 'O',
 'B-MISC',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'B-MISC',
 'O',
 'B-MISC',
 'B-MISC',
 'B-MISC',
 'B-LOC',
 'B-MISC',
 'B-LOC',
 'B-MISC',
 'B-LOC',
 'B-LOC',
 'O',
 'O',
 'B-ORG',
 'B-LOC',
 'B-LOC',
 'B-MISC',
 'O',
 'B-LOC',
 'O',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'B-PER',
 'O',
 'B-LOC',
 'B-LOC',
 'B-ORG',
 'B-LOC',
 'B-ORG',
 'B-MISC',
 'B-MISC',
 'O',
 'B-PER',
 'B-MISC',
 'B-MISC',
 'B-LOC',
 'B-LOC',
 'B-MISC',
 'B-LOC',
 'B-LOC',
 'B-MISC',
 'B-ORG',
 'B-ORG',
 'B-ORG',
 'B-ORG',
 'O',
 'B-ORG',
 'B-ORG',
 'B-LOC',
 'B-LOC',
 'B-ORG',
 'B-ORG',
 'O',
 'O',
 'O',
 'B-ORG',
 'O',
 'O',
 'O',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'O',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'B-MISC',
 'B-ORG',
 'B-MISC',
 'O',
 'B-ORG',
 'O',
 'B-LOC',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-ORG',
 'O',
 'B-ORG',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-MISC',
 'B-MISC',
 'B-ORG',
 'B-LOC',
 'B-ORG',
 'B-ORG',
 'B-LOC',
 'B-ORG',
 'B-LOC',
 'B-ORG',
 'B-LOC',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-LOC',
 'B-MISC',
 'O',
 'O',
 'B-LOC',
 'B-LOC',
 'O',
 'B-LOC',
 'B-PER',
 'O',
 'B-MISC',
 'B-LOC',
 'B-MISC',
 'B-MISC',
 'O',
 'B-LOC',
 'B-PER',
 'B-LOC',
 'B-LOC',
 'O',
 'O',
 'O',
 'B-ORG',
 'B-MISC',
 'B-ORG',
 'B-MISC',
 'B-MISC',
 'B-LOC',
 'B-PER',
 'B-MISC',
 'B-ORG',
 'O',
 'B-ORG',
 'B-ORG',
 'B-LOC',
 'B-MISC',
 'O',
 'B-MISC',
 'B-MISC',
 'B-MISC',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'B-MISC',
 'B-ORG',
 'B-PER',
 'B-MISC',
 'B-ORG',
 'B-PER',
 'B-LOC',
 'B-MISC',
 'B-LOC',
 'B-LOC',
 'B-ORG',
 'B-LOC',
 'B-PER',
 'O',
 'B-ORG',
 'B-LOC',
 'B-MISC',
 'O',
 'B-MISC',
 'B-LOC',
 'B-ORG',
 'B-PER',
 'O',
 'B-MISC',
 'O',
 'B-ORG',
 'B-ORG',
 'B-LOC',
 'B-ORG',
 'B-ORG',
 'B-ORG',
 'B-ORG',
 'B-ORG',
 'O',
 'B-ORG',
 'B-LOC',
 'B-LOC',
 'B-ORG',
 'O',
 'B-LOC',
 'B-PER',
 'B-LOC',
 'O',
 'B-ORG',
 'O',
 'B-LOC',
 'B-ORG',
 'B-ORG',
 'B-LOC',
 'B-ORG',
 'B-ORG',
 'B-MISC',
 'B-LOC',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-PER',
 'O',
 'O',
 'B-ORG',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-LOC',
 'B-LOC',
 'B-ORG',
 'O',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'B-MISC',
 'B-LOC',
 'B-LOC',
 'O',
 'B-ORG',
 'O',
 'B-PER',
 'B-ORG',
 'B-MISC',
 'O',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'O',
 'B-LOC',
 'O',
 'B-MISC',
 'B-MISC',
 'B-ORG',
 'B-MISC',
 'O',
 'B-LOC',
 'O',
 'O',
 'B-ORG',
 'O',
 'O',
 'B-PER',
 'B-PER',
 'B-LOC',
 'B-LOC',
 'B-ORG',
 'O',
 'B-LOC',
 'B-ORG',
 'B-MISC',
 'O',
 'B-LOC',
 'O',
 'O',
 'B-MISC',
 'B-ORG',
 'B-ORG',
 'B-ORG',
 'B-ORG',
 'B-ORG',
 'B-ORG',
 'B-ORG',
 'B-ORG',
 'B-LOC',
 'B-MISC',
 'B-LOC',
 'B-LOC',
 'B-LOC',
 'O',
 'B-MISC',
 'B-MISC',
 'B-LOC',
 'B-ORG',
 'O',
 'B-PER',
 'B-PER',
 'O',
 'O',
 'B-PER',
 'O',
 'B-LOC',
 'B-LOC',
 'B-MISC',
 'B-LOC',
 'O',
 'B-MISC',
 'O',
 'O',
 'B-ORG',
 'B-LOC',
 'B-MISC',
 'B-ORG',
 'B-ORG',
 'B-MISC',
 'O',
 'B-LOC',
 'B-LOC',
 'B-MISC',
 'B-ORG',
 'B-ORG',
 'B-LOC',
 'B-ORG',
 'B-LOC',
 'B-ORG',
 'B-LOC',
 'B-ORG',
 'B-ORG',
 'O',
 'O',
 'B-ORG',
 'B-LOC',
 'B-ORG',
 'B-LOC',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O']
import torchtext.vocab
from collections import Counter
def build_vocab(dataset):
    """Build a torchtext vocabulary from tokenised documents.

    Args:
        dataset: Iterable of documents, each an iterable of tokens.

    Returns:
        The vocabulary object created by ``torchtext.vocab.vocab`` from the
        token counts, with the special tokens ``<unk>``, ``<pad>``,
        ``<bos>`` and ``<eos>`` registered. Lookups of unknown tokens fall
        back to index 0 (presumably ``<unk>``, assuming torchtext puts the
        specials first -- TODO confirm).
    """
    token_counts = Counter(
        token for document in dataset for token in document
    )
    special_tokens = ['<unk>', '<pad>', '<bos>', '<eos>']
    vocabulary = torchtext.vocab.vocab(token_counts, specials=special_tokens)
    vocabulary.set_default_index(0)
    return vocabulary
# Split every training text on whitespace, then index the resulting tokens.
train_x = list(map(str.split, train_x))
vocab = build_vocab(train_x)
def data_process(dt):
    """Encode tokenised documents as tensors of vocabulary indices.

    Each document is looked up token by token in the module-level ``vocab``
    and framed by the ``<bos>`` and ``<eos>`` indices.

    Args:
        dt: Iterable of documents, each an iterable of tokens.

    Returns:
        A list with one 1-D ``torch.long`` tensor per document, of length
        ``len(document) + 2``.
    """
    encoded = []
    for document in dt:
        ids = [vocab['<bos>']]
        ids.extend(vocab[token] for token in document)
        ids.append(vocab['<eos>'])
        encoded.append(torch.tensor(ids, dtype=torch.long))
    return encoded

def labels_process(dt):
    """Frame label ids with a leading and trailing ``0`` and tensorise them.

    Args:
        dt: Iterable of documents. A document is either a single integer
            label id or a list/tuple of integer label ids.

    Returns:
        A list with one 1-D ``torch.long`` tensor per document, laid out as
        ``[0, <label ids...>, 0]``.
    """
    labels = []
    for document in dt:
        # A scalar label is treated as a one-token document so that both
        # per-token and per-document inputs share the same framing.
        if isinstance(document, (list, tuple)):
            body = list(document)
        else:
            body = [document]
        labels.append(torch.tensor([0] + body + [0], dtype=torch.long))
    return labels
# Tag -> integer class id; the position in this list is the id.
ner_tags = {
    tag: index
    for index, tag in enumerate([
        'O',
        'B-ORG', 'I-ORG',
        'B-PER', 'I-PER',
        'B-LOC', 'I-LOC',
        'B-MISC', 'I-MISC',
    ])
}
import torch

# Encode the training documents as index tensors.
train_tokens_ids = data_process(train_x)

# Load the development inputs / gold labels and the test inputs.
dev_x = read_data('dev-0/in.tsv')
dev_y = read_data('dev-0/expected.tsv')
test_x = read_data('test-A/in.tsv')

# Keep only the first field of every row and split it on whitespace again.
dev_x = [row[0].split() for row in dev_x]
dev_y = [row[0].split() for row in dev_y]
test_x = [row[0].split() for row in test_x]

# Re-derive the string labels from the raw rows, show a sample, then map
# them to class ids (tags missing from ner_tags become None).
train_y = [row[0] for row in dataset]
display(train_y[:5])
train_y = [ner_tags.get(label) for label in train_y]
train_y[:5]
['B-ORG', 'O', 'B-LOC', 'B-LOC', 'B-MISC']
[1, 0, 5, 5, 7]
# Flatten the per-row dev labels into one list of class ids.
dev_y = [ner_tags.get(label) for row_labels in dev_y for label in row_labels]
# NOTE: despite their names, test_tokens_ids / test_labels are built from
# the dev-0 data.
test_tokens_ids = data_process(dev_x)
train_labels = labels_process(train_y)
test_labels = labels_process(dev_y)
class NERModel(torch.nn.Module):
    """Embedding layer followed by one linear layer over a fixed window.

    The input is a 1-D tensor of ``window_len`` integer indices. Every index
    is embedded, the embeddings are flattened into a single vector and a
    linear layer maps that vector to ``num_tags`` raw class scores.

    No softmax is applied: it is not needed because
    ``torch.nn.CrossEntropyLoss`` is used as the training criterion
    (https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html).

    Args:
        vocab_size: Number of rows of the embedding table.
        embed_dim: Size of each embedding vector.
        window_len: Number of indices in one input window.
        num_tags: Number of output classes.
    """

    def __init__(self, vocab_size=23627, embed_dim=200, window_len=12,
                 num_tags=9):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, embed_dim)
        # The linear layer consumes the concatenation of all embeddings.
        self.fc1 = torch.nn.Linear(window_len * embed_dim, num_tags)

    def forward(self, x):
        """Return the ``num_tags`` raw class scores for one index window."""
        x = self.emb(x)
        # Flatten (window_len, embed_dim) into one feature vector; an input
        # of any other size fails here instead of being silently accepted.
        x = x.reshape(self.fc1.in_features)
        return self.fc1(x)
# Model, loss function and optimiser used by the training loop.
ner_model = NERModel()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=ner_model.parameters())
import string
def add_features(tens, tokens):
    """Append nine binary surface features of ``tokens[1]`` to ``tens``.

    The flags are, in order: first character is upper case, word is
    alphanumeric, word contains a punctuation character, word is numeric,
    word is all upper case, word contains ``-``, word contains ``/``,
    word is longer than 3 characters, word is longer than 6 characters.
    All flags stay 0 when ``tokens`` has fewer than two entries or
    ``tokens[1]`` is empty.

    Args:
        tens: 1-D tensor the flags are concatenated to.
        tokens: Sequence of raw token strings.

    Returns:
        ``tens`` followed by the nine 0/1 flags, as a new 1-D tensor.
    """
    word = tokens[1] if len(tokens) >= 2 else ''
    if word:
        checks = (
            word[0].isupper(),
            word.isalnum(),
            any(char in string.punctuation for char in word),
            word.isnumeric(),
            word.isupper(),
            '-' in word,
            '/' in word,
            len(word) > 3,
            len(word) > 6,
        )
        flags = [int(check) for check in checks]
    else:
        flags = [0] * 9
    return torch.cat((tens, torch.tensor(flags)), 0)
# Train for 50 epochs with one optimiser step per token window and report
# per-epoch loss, accuracy, precision and recall. Class id 0 is treated as
# "no entity" for precision/recall.
for epoch in range(50):
    loss_score = 0
    acc_score = 0
    prec_score = 0
    selected_items = 0
    recall_score = 0
    relevant_items = 0
    items_total = 0
    ner_model.train()
    # Only the first 100 documents are used (fewer if the corpus is smaller);
    # iterate over range(len(train_labels)) to train on everything.
    for i in range(min(100, len(train_labels))):
        # Positions 0 and -1 hold the padding labels, so skip them.
        for j in range(1, len(train_labels[i]) - 1):

            # Window of three token ids centred on position j.
            X_base = train_tokens_ids[i][j-1: j+2]
            # NOTE(review): train_x[i] has no <bos> prefix while the id
            # sequence does, so this raw-token slice is shifted by one
            # position relative to X_base -- confirm this is intended.
            X_add = train_x[i][j-1: j+2]
            X_final = add_features(X_base, X_add)

            Y = train_labels[i][j: j+1]

            Y_predictions = ner_model(X_final)

            # Compute the predicted / gold class once and reuse them.
            predicted = torch.argmax(Y_predictions).item()
            expected = Y.item()
            is_correct = predicted == expected

            acc_score += int(is_correct)
            if predicted != 0:
                selected_items += 1
                prec_score += int(is_correct)
            if expected != 0:
                relevant_items += 1
                recall_score += int(is_correct)
            items_total += 1

            optimizer.zero_grad()
            # CrossEntropyLoss expects a batch dimension on the scores.
            loss = criterion(Y_predictions.unsqueeze(0), Y)
            loss.backward()
            optimizer.step()

            loss_score += loss.item()

    # Guard the ratios: an epoch may predict no entity, contain no entity,
    # or (with empty data) process no item at all.
    precision = prec_score / selected_items if selected_items else 0.0
    recall = recall_score / relevant_items if relevant_items else 0.0
    display('epoch: ', epoch)
    display('loss: ', loss_score / items_total if items_total else 0.0)
    display('acc: ', acc_score / items_total if items_total else 0.0)
    display('prec: ', precision)
    display('recall: : ', recall)
'epoch: '
0
'loss: '
2.811322446731947
'acc: '
0.48
'prec: '
0.18604651162790697
'recall: : '
0.1702127659574468
'epoch: '
1
'loss: '
2.5642633876085097
'acc: '
0.43
'prec: '
0.1702127659574468
'recall: : '
0.1702127659574468
'epoch: '
2
'loss: '
2.2745147155076166
'acc: '
0.54
'prec: '
0.3023255813953488
'recall: : '
0.2765957446808511
'epoch: '
3
'loss: '
2.3077734905840908
'acc: '
0.56
'prec: '
0.3023255813953488
'recall: : '
0.2765957446808511
'epoch: '
4
'loss: '
2.2327055485211984
'acc: '
0.51
'prec: '
0.24489795918367346
'recall: : '
0.2553191489361702
'epoch: '
5
'loss: '
2.032022957816762
'acc: '
0.58
'prec: '
0.2978723404255319
'recall: : '
0.2978723404255319
'epoch: '
6
'loss: '
1.9094040171859614
'acc: '
0.57
'prec: '
0.2765957446808511
'recall: : '
0.2765957446808511
'epoch: '
7
'loss: '
1.8801336237322357
'acc: '
0.54
'prec: '
0.2391304347826087
'recall: : '
0.23404255319148937
'epoch: '
8
'loss: '
1.853852765722122
'acc: '
0.53
'prec: '
0.2222222222222222
'recall: : '
0.2127659574468085
'epoch: '
9
'loss: '
1.8288560365608282
'acc: '
0.53
'prec: '
0.2222222222222222
'recall: : '
0.2127659574468085
'epoch: '
10
'loss: '
1.8022360281114742
'acc: '
0.53
'prec: '
0.2222222222222222
'recall: : '
0.2127659574468085
'epoch: '
11
'loss: '
1.775143896874324
'acc: '
0.53
'prec: '
0.22727272727272727
'recall: : '
0.2127659574468085
'epoch: '
12
'loss: '
1.748205496848568
'acc: '
0.53
'prec: '
0.22727272727272727
'recall: : '
0.2127659574468085
'epoch: '
13
'loss: '
1.7217520299459284
'acc: '
0.54
'prec: '
0.24444444444444444
'recall: : '
0.23404255319148937
'epoch: '
14
'loss: '
1.6958234204566542
'acc: '
0.55
'prec: '
0.2608695652173913
'recall: : '
0.2553191489361702
'epoch: '
15
'loss: '
1.6712835076041666
'acc: '
0.55
'prec: '
0.2608695652173913
'recall: : '
0.2553191489361702
'epoch: '
16
'loss: '
1.6480589298788255
'acc: '
0.55
'prec: '
0.2608695652173913
'recall: : '
0.2553191489361702
'epoch: '
17
'loss: '
1.6257991834238783
'acc: '
0.55
'prec: '
0.2608695652173913
'recall: : '
0.2553191489361702
'epoch: '
18
'loss: '
1.604242644636688
'acc: '
0.55
'prec: '
0.26666666666666666
'recall: : '
0.2553191489361702
'epoch: '
19
'loss: '
1.5837320431148691
'acc: '
0.55
'prec: '
0.26666666666666666
'recall: : '
0.2553191489361702
'epoch: '
20
'loss: '
1.5641135577083332
'acc: '
0.54
'prec: '
0.24444444444444444
'recall: : '
0.23404255319148937
'epoch: '
21
'loss: '
1.5454202877922216
'acc: '
0.54
'prec: '
0.24444444444444444
'recall: : '
0.23404255319148937
'epoch: '
22
'loss: '
1.5275404900076683
'acc: '
0.54
'prec: '
0.24444444444444444
'recall: : '
0.23404255319148937
'epoch: '
23
'loss: '
1.5099791488426855
'acc: '
0.53
'prec: '
0.22727272727272727
'recall: : '
0.2127659574468085
'epoch: '
24
'loss: '
1.4932281806698302
'acc: '
0.53
'prec: '
0.22727272727272727
'recall: : '
0.2127659574468085
'epoch: '
25
'loss: '
1.4772486361171469
'acc: '
0.53
'prec: '
0.22727272727272727
'recall: : '
0.2127659574468085
'epoch: '
26
'loss: '
1.4617241937015206
'acc: '
0.53
'prec: '
0.22727272727272727
'recall: : '
0.2127659574468085
'epoch: '
27
'loss: '
1.447145789535134
'acc: '
0.53
'prec: '
0.22727272727272727
'recall: : '
0.2127659574468085
'epoch: '
28
'loss: '
1.433001351452549
'acc: '
0.53
'prec: '
0.22727272727272727
'recall: : '
0.2127659574468085
'epoch: '
29
'loss: '
1.4193171267636353
'acc: '
0.54
'prec: '
0.25
'recall: : '
0.23404255319148937
'epoch: '
30
'loss: '
1.4062099850556116
'acc: '
0.55
'prec: '
0.2727272727272727
'recall: : '
0.2553191489361702
'epoch: '
31
'loss: '
1.3941219550243114
'acc: '
0.55
'prec: '
0.2727272727272727
'recall: : '
0.2553191489361702
'epoch: '
32
'loss: '
1.3825843345944304
'acc: '
0.56
'prec: '
0.29545454545454547
'recall: : '
0.2765957446808511
'epoch: '
33
'loss: '
1.3714176365407185
'acc: '
0.56
'prec: '
0.3023255813953488
'recall: : '
0.2765957446808511
'epoch: '
34
'loss: '
1.3609581479639745
'acc: '
0.56
'prec: '
0.3023255813953488
'recall: : '
0.2765957446808511
'epoch: '
35
'loss: '
1.3509947879862738
'acc: '
0.56
'prec: '
0.3023255813953488
'recall: : '
0.2765957446808511
'epoch: '
36
'loss: '
1.3424826521927025
'acc: '
0.56
'prec: '
0.3023255813953488
'recall: : '
0.2765957446808511
'epoch: '
37
'loss: '
1.3336302372731734
'acc: '
0.56
'prec: '
0.3023255813953488
'recall: : '
0.2765957446808511
'epoch: '
38
'loss: '
1.3246490387670928
'acc: '
0.56
'prec: '
0.3023255813953488
'recall: : '
0.2765957446808511
'epoch: '
39
'loss: '
1.316349835752626
'acc: '
0.57
'prec: '
0.30952380952380953
'recall: : '
0.2765957446808511
'epoch: '
40
'loss: '
1.3090153592341813
'acc: '
0.57
'prec: '
0.30952380952380953
'recall: : '
0.2765957446808511
'epoch: '
41
'loss: '
1.3016801220795606
'acc: '
0.56
'prec: '
0.3023255813953488
'recall: : '
0.2765957446808511
'epoch: '
42
'loss: '
1.2947906140016858
'acc: '
0.56
'prec: '
0.3023255813953488
'recall: : '
0.2765957446808511
'epoch: '
43
'loss: '
1.2887717709777644
'acc: '
0.56
'prec: '
0.3023255813953488
'recall: : '
0.2765957446808511
'epoch: '
44
'loss: '
1.2825759449476026
'acc: '
0.56
'prec: '
0.3023255813953488
'recall: : '
0.2765957446808511
'epoch: '
45
'loss: '
1.2770079451325
'acc: '
0.56
'prec: '
0.3023255813953488
'recall: : '
0.2765957446808511
'epoch: '
46
'loss: '
1.2715366566940793
'acc: '
0.56
'prec: '
0.30952380952380953
'recall: : '
0.2765957446808511
'epoch: '
47
'loss: '
1.266929566776089
'acc: '
0.56
'prec: '
0.30952380952380953
'recall: : '
0.2765957446808511
'epoch: '
48
'loss: '
1.2626329537964194
'acc: '
0.56
'prec: '
0.30952380952380953
'recall: : '
0.2765957446808511
'epoch: '
49
'loss: '
1.2600517650827532
'acc: '
0.56
'prec: '
0.30952380952380953
'recall: : '
0.2765957446808511
# F1 score of the last epoch; 0.0 when precision and recall are both zero,
# which would otherwise raise ZeroDivisionError.
(2 * precision * recall) / (precision + recall) if (precision + recall) else 0.0
0.29213483146067415
# Show the raw class scores left over from the most recent forward pass.
Y_predictions
tensor([ -0.5326,  -1.1218, -10.0297,  -0.1610, -10.9741,   1.3533, -11.9781,
          0.6097,  -9.1263], grad_fn=<AddBackward0>)
# Class id -> tag, used to decode model predictions. This must be the exact
# inverse of ner_tags, because the training labels are encoded with that
# mapping; any other order decodes predictions to the wrong tags.
ner_tags_re = {
    0: 'O',
    1: 'B-ORG',
    2: 'I-ORG',
    3: 'B-PER',
    4: 'I-PER',
    5: 'B-LOC',
    6: 'I-LOC',
    7: 'B-MISC',
    8: 'I-MISC',
}

def _repair_bio(labels):
    """Return ``labels`` with every ``I-`` tag made consistent with its predecessor.

    An ``I-`` tag at the start of the sequence or directly after ``O``
    becomes the matching ``B-`` tag; an ``I-`` tag after an entity tag takes
    over that entity's type.
    """
    repaired = []
    previous = None
    for label in labels:
        if label.startswith("I-"):
            if previous is None or previous == "O":
                label = "B-" + label[2:]
            else:
                label = "I-" + previous[2:]
        previous = label
        repaired.append(label)
    return repaired


def generate_out(folder_path):
    """Tag ``<folder_path>/in.tsv`` and write the result to ``<folder_path>/out.tsv``.

    Every input line is split on single spaces, encoded with
    ``data_process``, and each token window is classified with the
    module-level ``ner_model``. Class ids are decoded through
    ``ner_tags_re``, the tag sequence is made BIO-consistent, and one
    space-separated line of tags is written per input line.

    Args:
        folder_path: Directory containing ``in.tsv``; ``out.tsv`` is created
            (or overwritten) in the same directory.

    Returns:
        None.
    """
    ner_model.eval()
    ner_model.cpu()
    print('Generating out')

    with open(f"{folder_path}/in.tsv", 'r', encoding='utf-8') as file:
        X_dev = [line.strip().split(' ') for line in file]
    test_tokens_ids = data_process(X_dev)

    lines = []
    # Inference only: no gradients are needed.
    with torch.no_grad():
        for token_ids, raw_tokens in zip(test_tokens_ids, X_dev):
            predicted = []
            # Positions 0 and -1 are the <bos>/<eos> ids added by data_process.
            for j in range(1, len(token_ids) - 1):
                X = token_ids[j - 1: j + 2]
                # NOTE(review): raw_tokens has no <bos> prefix while
                # token_ids does, so this slice is shifted by one position
                # relative to X -- confirm this is intended.
                X_raw_single = raw_tokens[j - 1: j + 2]
                X = add_features(X, X_raw_single)

                try:
                    Y_predictions = ner_model(X)
                    label = ner_tags_re[torch.argmax(Y_predictions).item()]
                except (RuntimeError, IndexError, KeyError) as e:
                    print('Error', e)
                    # Fall back to 'O' so the output keeps exactly one tag
                    # per input token instead of silently dropping one.
                    label = 'O'
                predicted.append(label)
            lines.append(" ".join(_repair_bio(predicted)))

    with open(f"{folder_path}/out.tsv", "w", encoding='utf-8') as f:
        f.writelines(line + "\n" for line in lines)
# Write predictions for the development and the test split.
for split_dir in ('dev-0', 'test-A'):
    generate_out(split_dir)
Generating out
step 6
Generating out
step 6