en-ner-conll-2003/run.ipynb
2022-06-20 20:04:19 +02:00

52 KiB

def read_data(path):
    with open(path, 'r') as f:
        dataset = [line.strip().split('\t') for line in f]
    return dataset
dataset = read_data('train/train.tsv')
train_x = [x[1] for x in dataset]
train_y = [y[0] for y in dataset]
import torchtext.vocab
from collections import Counter
def build_vocab(dataset):
    counter = Counter()
    for document in dataset:
        counter.update(document)
    
    vocab = torchtext.vocab.vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
    vocab.set_default_index(0)
    return vocab
train_x = [x.split() for x in train_x]
vocab = build_vocab(train_x)
def data_process(dt):
    return [ torch.tensor([vocab['<bos>']] +[vocab[token]  for token in  document ] + [vocab['<eos>']], dtype = torch.long) for document in dt]

def labels_process(dt):
    labels = []
    for document in dt:
        temp = []
        temp.append(0)
        temp.append(document)
        # print(document)
        temp.append(0)
        labels.append(torch.tensor(temp, dtype = torch.long))
    return labels
        
    
    #return [ torch.tensor([0] + document + [0], dtype = torch.long) for document in dt]
ner_tags = {'O': 0, 'B-ORG': 1, 'I-ORG': 2, 'B-PER': 3, 'I-PER': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}
import torch

train_tokens_ids = data_process(train_x)
dev_x = read_data('dev-0/in.tsv')
dev_y = read_data('dev-0/expected.tsv')

test_x = read_data('test-A/in.tsv')

dev_x = [x[0].split() for x in dev_x]
dev_y = [y[0].split() for y in dev_y]
test_x = [x[0].split() for x in test_x]
train_y = [y[0] for y in dataset]
print(train_y[0])
train_y = [[ner_tags.get(tag) for y in train_y for tag in y.split()]]
train_y[0]
B-ORG O B-MISC O O O B-MISC O O O B-PER I-PER O B-LOC O O O B-ORG I-ORG O O O O O O B-MISC O O O O O B-MISC O O O O O O O O O O O O O O O B-LOC O O O O B-ORG I-ORG O O O B-PER I-PER O O O O O O O O O O O B-LOC O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG O O O B-PER I-PER I-PER I-PER O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG I-ORG O O O O O O O O O B-ORG O O B-PER I-PER O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-PER O B-MISC O O O O B-LOC O B-LOC O O O O O O O B-MISC I-MISC I-MISC O B-MISC O O O O O O O O B-PER O O O O O O O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-MISC O O B-PER I-PER I-PER O O O B-PER O O B-ORG O O O O O O O O O O O O O O O O O O B-LOC O B-LOC O B-PER O O O O O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O O O O O O O O O O B-MISC O O O O O O B-MISC O O O O O B-LOC O O O O O O O O O O O O O O O O O O O B-LOC O O O O B-ORG I-ORG I-ORG I-ORG I-ORG O B-ORG O O B-PER I-PER I-PER O O B-ORG I-ORG O O B-LOC O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O O O O O O O O O B-LOC O O O O B-LOC O O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O
[1,
 0,
 7,
 0,
 0,
 0,
 7,
 0,
 0,
 0,
 3,
 4,
 0,
 5,
 0,
 0,
 0,
 1,
 2,
 0,
 0,
 0,
 0,
 0,
 0,
 7,
 0,
 0,
 0,
 0,
 0,
 7,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 1,
 2,
 0,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 3,
 4,
 4,
 4,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 2,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 3,
 0,
 7,
 0,
 0,
 0,
 0,
 5,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 7,
 8,
 8,
 0,
 7,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 7,
 0,
 0,
 3,
 4,
 4,
 0,
 0,
 0,
 3,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 5,
 0,
 3,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 7,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 7,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 7,
 0,
 0,
 0,
 0,
 0,
 0,
 7,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 1,
 2,
 2,
 2,
 2,
 0,
 1,
 0,
 0,
 3,
 4,
 4,
 0,
 0,
 1,
 2,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 7,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 7,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 7,
 8,
 8,
 8,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 7,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 7,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 5,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 6,
 0,
 0,
 0,
 0,
 5,
 0,
 7,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 7,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 1,
 2,
 0,
 3,
 4,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 5,
 0,
 3,
 4,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 2,
 2,
 0,
 3,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 2,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 7,
 0,
 0,
 3,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 5,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 ...]
dev_y = [ner_tags.get(tag) for y in dev_y for tag in y]
dev_y
[0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 7,
 8,
 0,
 3,
 4,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 6,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 3,
 4,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 3,
 4,
 0,
 3,
 4,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 5,
 0,
 0,
 3,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 1,
 0,
 3,
 4,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 3,
 4,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 7,
 3,
 4,
 0,
 0,
 0,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 3,
 4,
 0,
 7,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 7,
 8,
 8,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 7,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 5,
 0,
 1,
 0,
 0,
 0,
 0,
 3,
 4,
 0,
 0,
 3,
 4,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 1,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 6,
 0,
 1,
 0,
 0,
 3,
 4,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 5,
 0,
 5,
 6,
 0,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 0,
 3,
 4,
 0,
 0,
 3,
 4,
 0,
 0,
 3,
 4,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 0,
 5,
 0,
 1,
 0,
 0,
 3,
 4,
 0,
 0,
 3,
 4,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 1,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 5,
 0,
 1,
 0,
 0,
 0,
 0,
 3,
 4,
 0,
 0,
 3,
 4,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 3,
 4,
 0,
 0,
 3,
 0,
 0,
 0,
 0,
 3,
 4,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 5,
 0,
 1,
 0,
 0,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 3,
 4,
 0,
 0,
 3,
 4,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 7,
 0,
 0,
 0,
 5,
 0,
 0,
 5,
 0,
 0,
 0,
 7,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 2,
 2,
 2,
 2,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 7,
 0,
 0,
 0,
 0,
 0,
 1,
 2,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 2,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 5,
 6,
 0,
 0,
 0,
 0,
 1,
 2,
 2,
 2,
 2,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 6,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 6,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 6,
 0,
 0,
 0,
 0,
 0,
 1,
 2,
 0,
 0,
 5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 5,
 ...]
test_tokens_ids = data_process(dev_x)
train_labels = labels_process(train_y[0])
test_labels = labels_process(dev_y)
class NERModel(torch.nn.Module):

    def __init__(self,):
        super(NERModel, self).__init__()
        self.emb = torch.nn.Embedding(23627, 200)
        self.fc1 = torch.nn.Linear(2400, 9)
        #self.softmax = torch.nn.Softmax(dim=1)
        # nie trzeba, bo używamy https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
        # jako kryterium
        

    def forward(self, x):
        x = self.emb(x)
        x = x.reshape(2400) 
        x = self.fc1(x)
        #x = self.softmax(x)
        return x
ner_model = NERModel()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(ner_model.parameters())
import string
def add_features(tens, tokens):
    array = [0, 0, 0, 0, 0, 0, 0, 0, 0]
    if len(tokens) >= 2:
        if len(tokens[1]) >= 1:
            word = tokens[1]
            if word[0].isupper():
                array[0] = 1
            if word.isalnum():
                array[1] = 1
            for i in word:
                # checking whether the char is punctuation.
                if i in string.punctuation:
                    # Printing the punctuation values
                    array[2] = 1
            if word.isnumeric():
                array[3] = 1
            if word.isupper():
                array[4] = 1
            if '-' in word:
                array[5] = 1
            if '/' in word:
                array[6] = 1
            if len(word) > 3:
                array[7] = 1
            if len(word) > 6:
                array[8] = 1
    x = torch.tensor(array)
    new_tensor = torch.cat((tens, x), 0)
    return new_tensor
for epoch in range(50):
    loss_score = 0
    acc_score = 0
    prec_score = 0
    selected_items = 0
    recall_score = 0
    relevant_items = 0
    items_total = 0
    ner_model.train()
    #for i in range(len(train_labels)):
    for i in range(100):
        for j in range(1, len(train_labels[i]) - 1):
    
            X_base = train_tokens_ids[i][j-1: j+2]
            X_add = train_x[i][j-1: j+2]
            X_final = add_features(X_base, X_add)
            
            Y = train_labels[i][j: j+1]

            Y_predictions = ner_model(X_final)
            
            
            acc_score += int(torch.argmax(Y_predictions) == Y)
            
            if torch.argmax(Y_predictions) != 0:
                selected_items +=1
            if  torch.argmax(Y_predictions) != 0 and torch.argmax(Y_predictions) == Y.item():
                prec_score += 1
            
            if  Y.item() != 0:
                relevant_items +=1
            if  Y.item() != 0 and torch.argmax(Y_predictions) == Y.item():
                recall_score += 1
            
            items_total += 1

            
            optimizer.zero_grad()
            loss = criterion(Y_predictions.unsqueeze(0), Y)
            loss.backward()
            optimizer.step()


            loss_score += loss.item() 
    
    precision = prec_score / selected_items
    recall = recall_score / relevant_items
    #f1_score = (2*precision * recall) / (precision + recall)
    print('epoch: ', epoch, end='\t')
    print('loss: ', loss_score / items_total, end='\t')
    print('acc: ', acc_score / items_total, end='\t')
    print('prec: ', precision, end='\t')
    print('recall: : ', recall)
    #display('f1: ', f1_score)
epoch:  0	loss:  0.4997586368946457	acc:  0.89	prec:  0.6875	recall: :  0.6470588235294118
epoch:  1	loss:  0.518995374645849	acc:  0.9	prec:  0.6666666666666666	recall: :  0.7058823529411765
epoch:  2	loss:  0.6272036719378185	acc:  0.87	prec:  0.625	recall: :  0.5882352941176471
epoch:  3	loss:  0.5379423921279067	acc:  0.9	prec:  0.7058823529411765	recall: :  0.7058823529411765
epoch:  4	loss:  0.6458101376151467	acc:  0.88	prec:  0.6470588235294118	recall: :  0.6470588235294118
epoch:  5	loss:  0.5032455809084422	acc:  0.9	prec:  0.7692307692307693	recall: :  0.5882352941176471
epoch:  6	loss:  0.5464647812097837	acc:  0.87	prec:  0.6111111111111112	recall: :  0.6470588235294118
epoch:  7	loss:  0.5818069439918144	acc:  0.89	prec:  0.6875	recall: :  0.6470588235294118
epoch:  8	loss:  0.6688040642930787	acc:  0.86	prec:  0.5789473684210527	recall: :  0.6470588235294118
epoch:  9	loss:  0.47163703395596485	acc:  0.9	prec:  0.7058823529411765	recall: :  0.7058823529411765
epoch:  10	loss:  0.6080643151845471	acc:  0.87	prec:  0.625	recall: :  0.5882352941176471
epoch:  11	loss:  0.6119919324012835	acc:  0.86	prec:  0.5789473684210527	recall: :  0.6470588235294118
epoch:  12	loss:  0.5809223624372385	acc:  0.9	prec:  0.7333333333333333	recall: :  0.6470588235294118
epoch:  13	loss:  0.5410229888884214	acc:  0.89	prec:  0.6875	recall: :  0.6470588235294118
epoch:  14	loss:  0.5213326926458853	acc:  0.88	prec:  0.631578947368421	recall: :  0.7058823529411765
epoch:  15	loss:  0.5297116384661035	acc:  0.89	prec:  0.7142857142857143	recall: :  0.5882352941176471
epoch:  16	loss:  0.5681106116262435	acc:  0.87	prec:  0.6111111111111112	recall: :  0.6470588235294118
epoch:  17	loss:  0.49915451315861675	acc:  0.9	prec:  0.7333333333333333	recall: :  0.6470588235294118
epoch:  18	loss:  0.5361382347030667	acc:  0.88	prec:  0.6470588235294118	recall: :  0.6470588235294118
epoch:  19	loss:  0.4398948981850981	acc:  0.89	prec:  0.6875	recall: :  0.6470588235294118
epoch:  20	loss:  0.587098065932239	acc:  0.86	prec:  0.5789473684210527	recall: :  0.6470588235294118
epoch:  21	loss:  0.4703573033369526	acc:  0.9	prec:  0.7333333333333333	recall: :  0.6470588235294118
epoch:  22	loss:  0.33882861601850434	acc:  0.9	prec:  0.6666666666666666	recall: :  0.7058823529411765
epoch:  23	loss:  0.6288586365318634	acc:  0.86	prec:  0.5789473684210527	recall: :  0.6470588235294118
epoch:  24	loss:  0.4446407145373905	acc:  0.9	prec:  0.7692307692307693	recall: :  0.5882352941176471
epoch:  25	loss:  0.47516126279861737	acc:  0.89	prec:  0.6875	recall: :  0.6470588235294118
epoch:  26	loss:  0.47878878450462253	acc:  0.87	prec:  0.6	recall: :  0.7058823529411765
epoch:  27	loss:  0.406448530066898	acc:  0.9	prec:  0.7333333333333333	recall: :  0.6470588235294118
epoch:  28	loss:  0.5326147545382947	acc:  0.88	prec:  0.6470588235294118	recall: :  0.6470588235294118
epoch:  29	loss:  0.35017394204057384	acc:  0.89	prec:  0.6875	recall: :  0.6470588235294118
epoch:  30	loss:  0.5492841665070227	acc:  0.85	prec:  0.5555555555555556	recall: :  0.5882352941176471
epoch:  31	loss:  0.45283244484153784	acc:  0.9	prec:  0.7058823529411765	recall: :  0.7058823529411765
epoch:  32	loss:  0.40580460080429476	acc:  0.9	prec:  0.7333333333333333	recall: :  0.6470588235294118
epoch:  33	loss:  0.5504078443901653	acc:  0.86	prec:  0.5789473684210527	recall: :  0.6470588235294118
epoch:  34	loss:  0.45548378403755124	acc:  0.9	prec:  0.7333333333333333	recall: :  0.6470588235294118
epoch:  35	loss:  0.4666948410707255	acc:  0.89	prec:  0.6875	recall: :  0.6470588235294118
epoch:  36	loss:  0.3942578120598796	acc:  0.87	prec:  0.6111111111111112	recall: :  0.6470588235294118
epoch:  37	loss:  0.395962362795658	acc:  0.9	prec:  0.7692307692307693	recall: :  0.5882352941176471
epoch:  38	loss:  0.44939344771950573	acc:  0.87	prec:  0.6111111111111112	recall: :  0.6470588235294118
epoch:  39	loss:  0.38211571767803887	acc:  0.9	prec:  0.7333333333333333	recall: :  0.6470588235294118
epoch:  40	loss:  0.48910969563327855	acc:  0.87	prec:  0.6111111111111112	recall: :  0.6470588235294118
epoch:  41	loss:  0.3446516449950968	acc:  0.9	prec:  0.7333333333333333	recall: :  0.6470588235294118
epoch:  42	loss:  0.4679804835932646	acc:  0.87	prec:  0.6111111111111112	recall: :  0.6470588235294118
epoch:  43	loss:  0.33552404487287274	acc:  0.9	prec:  0.7058823529411765	recall: :  0.7058823529411765
epoch:  44	loss:  0.4151001131474459	acc:  0.87	prec:  0.625	recall: :  0.5882352941176471
epoch:  45	loss:  0.36344730574960066	acc:  0.9	prec:  0.7333333333333333	recall: :  0.6470588235294118
epoch:  46	loss:  0.36800105327266464	acc:  0.88	prec:  0.6470588235294118	recall: :  0.6470588235294118
epoch:  47	loss:  0.3511931332464837	acc:  0.89	prec:  0.7142857142857143	recall: :  0.5882352941176471
epoch:  48	loss:  0.4371468522066334	acc:  0.87	prec:  0.6111111111111112	recall: :  0.6470588235294118
epoch:  49	loss:  0.3572919689995433	acc:  0.9	prec:  0.7333333333333333	recall: :  0.6470588235294118
(2*precision * recall) / (precision + recall)
0.6875
Y_predictions
tensor([-2.1121,  6.8412, -5.5392, -2.9573, -2.2702, -4.7000, -8.2250, -4.4908,
        -8.5875], grad_fn=<AddBackward0>)
ner_tags_re = {
    0: 'O',
    1: 'B-PER',
    2: 'B-LOC',
    3: 'I-PER',
    4: 'B-MISC',
    5: 'I-MISC',
    6: 'I-LOC',
    7: 'B-ORG',
    8: 'I-ORG'
}

def generate_out(folder_path):
    ner_model.eval()
    ner_model.cpu()
    print('Generating out')
    X_dev = []
    with open(f"{folder_path}/in.tsv", 'r') as file:
        for line in file:
            line = line.strip()
            X_dev.append(line.split(' '))
    test_tokens_ids = data_process(X_dev)

    predicted_values = []
    # for i in range(100):
    for i in range(len(test_tokens_ids)):
        pred_string = ''
        for j in range(1, len(test_tokens_ids[i]) - 1):
            X = test_tokens_ids[i][j - 1: j + 2]
            X_raw_single = X_dev[i][j - 1: j + 2]
            X = add_features(X, X_raw_single)
            
            # X = X.to(device)
            # print('train is cuda?', X.is_cuda)

            try:
                Y_predictions = ner_model(X)
                id = torch.argmax(Y_predictions)
                val = ner_tags_re[int(id)]
                pred_string += val + ' '
            except Exception as e:
                print('Error', e)
        predicted_values.append(pred_string[:-1])
    lines = []
    for line in predicted_values:
        last_label = None
        line = line.split(' ')
        new_line = []
        for label in line:
            if (label != "O" and label[0:2] == "I-"):
                if last_label == None or last_label == "O":
                    label = label.replace('I-', 'B-')
                else:
                    label = "I-" + last_label[2:]
            last_label = label
            new_line.append(label)
        lines.append(" ".join(new_line))
    with open(f"{folder_path}/out.tsv", "w") as f:
        for line in lines:
            f.write(str(line) + "\n")

    f.close()
generate_out('dev-0')
generate_out('test-A')
Generating out
Error index out of range in self
Error index out of range in self
Error index out of range in self
Generating out