en-ner-conll-2003/run.ipynb
2022-06-20 20:04:19 +02:00

2604 lines
52 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"id": "coral-camera",
"metadata": {},
"outputs": [],
"source": [
"def read_data(path):\n",
"    \"\"\"Read a TSV file and return a list of rows, each row split on tab.\"\"\"\n",
"    with open(path, 'r') as f:\n",
"        dataset = [line.strip().split('\\t') for line in f]\n",
"    return dataset"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "cc201a16",
"metadata": {},
"outputs": [],
"source": [
"# Each train row: column 0 is the tag sequence, column 1 the sentence text.\n",
"dataset = read_data('train/train.tsv')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "sharing-employment",
"metadata": {},
"outputs": [],
"source": [
"# Column 1 of each row is the sentence text; column 0 is the tag sequence.\n",
"train_x = [x[1] for x in dataset]\n",
"train_y = [y[0] for y in dataset]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "elder-trauma",
"metadata": {},
"outputs": [],
"source": [
"import torchtext.vocab\n",
"from collections import Counter"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "material-timothy",
"metadata": {},
"outputs": [],
"source": [
"def build_vocab(dataset):\n",
"    \"\"\"Build a torchtext vocabulary over all tokens in `dataset`.\n",
"\n",
"    `dataset` is an iterable of token lists. Index 0 ('<unk>') is set as\n",
"    the default index, so out-of-vocabulary tokens map to 0.\n",
"    \"\"\"\n",
"    counter = Counter()\n",
"    for document in dataset:\n",
"        counter.update(document)\n",
"\n",
"    vocab = torchtext.vocab.vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])\n",
"    vocab.set_default_index(0)\n",
"    return vocab"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "provincial-reader",
"metadata": {},
"outputs": [],
"source": [
"# Tokenise each training sentence (whitespace split).\n",
"train_x = [x.split() for x in train_x]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "invalid-nursing",
"metadata": {},
"outputs": [],
"source": [
"# Vocabulary over training tokens only; OOV tokens map to '<unk>' (id 0).\n",
"vocab = build_vocab(train_x)"
]
},
{
"cell_type": "code",
"execution_count": 76,
"id": "accredited-observation",
"metadata": {},
"outputs": [],
"source": [
"def data_process(dt):\n",
"    \"\"\"Map each token list to a LongTensor of vocab ids wrapped in <bos>/<eos>.\"\"\"\n",
"    return [ torch.tensor([vocab['<bos>']] +[vocab[token] for token in document ] + [vocab['<eos>']], dtype = torch.long) for document in dt]\n",
"\n",
"def labels_process(dt):\n",
"    \"\"\"Wrap each element of `dt` in a leading/trailing 0 (the 'O' tag).\n",
"\n",
"    NOTE(review): as called below, each `document` is a single int tag, so\n",
"    every output tensor has length 3 ([0, tag, 0]) — confirm this is the\n",
"    intended pairing with data_process, which keeps whole sentences.\n",
"    \"\"\"\n",
"    labels = []\n",
"    for document in dt:\n",
"        temp = []\n",
"        temp.append(0)   # label for the <bos> position\n",
"        temp.append(document)\n",
"        # print(document)\n",
"        temp.append(0)   # label for the <eos> position\n",
"        labels.append(torch.tensor(temp, dtype = torch.long))\n",
"    return labels\n",
"    \n",
"    \n",
"    #return [ torch.tensor([0] + document + [0], dtype = torch.long) for document in dt]\n"
]
},
{
"cell_type": "code",
"execution_count": 77,
"id": "united-local",
"metadata": {},
"outputs": [],
"source": [
"# CoNLL-2003 BIO tag -> integer class id; these ids are the model's targets.\n",
"ner_tags = {'O': 0, 'B-ORG': 1, 'I-ORG': 2, 'B-PER': 3, 'I-PER': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}"
]
},
{
"cell_type": "code",
"execution_count": 78,
"id": "reported-afghanistan",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Encode every training sentence as an id tensor (uses the vocab built above).\n",
"train_tokens_ids = data_process(train_x)"
]
},
{
"cell_type": "code",
"execution_count": 79,
"id": "southern-nirvana",
"metadata": {},
"outputs": [],
"source": [
"# Load dev/test splits; each in.tsv row has the sentence in column 0,\n",
"# each expected.tsv row the tag sequence in column 0.\n",
"dev_x = read_data('dev-0/in.tsv')\n",
"dev_y = read_data('dev-0/expected.tsv')\n",
"\n",
"test_x = read_data('test-A/in.tsv')\n",
"\n",
"dev_x = [x[0].split() for x in dev_x]\n",
"dev_y = [y[0].split() for y in dev_y]\n",
"test_x = [x[0].split() for x in test_x]"
]
},
{
"cell_type": "code",
"execution_count": 80,
"id": "played-transparency",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"B-ORG O B-MISC O O O B-MISC O O O B-PER I-PER O B-LOC O O O B-ORG I-ORG O O O O O O B-MISC O O O O O B-MISC O O O O O O O O O O O O O O O B-LOC O O O O B-ORG I-ORG O O O B-PER I-PER O O O O O O O O O O O B-LOC O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG O O O B-PER I-PER I-PER I-PER O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG I-ORG O O O O O O O O O B-ORG O O B-PER I-PER O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-PER O B-MISC O O O O B-LOC O B-LOC O O O O O O O B-MISC I-MISC I-MISC O B-MISC O O O O O O O O B-PER O O O O O O O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-MISC O O B-PER I-PER I-PER O O O B-PER O O B-ORG O O O O O O O O O O O O O O O O O O B-LOC O B-LOC O B-PER O O O O O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O O O O O O O O O O B-MISC O O O O O O B-MISC O O O O O B-LOC O O O O O O O O O O O O O O O O O O O B-LOC O O O O B-ORG I-ORG I-ORG I-ORG I-ORG O B-ORG O O B-PER I-PER I-PER O O B-ORG I-ORG O O B-LOC O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O O O O O O O O O B-LOC O O O O B-LOC O O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O\n"
]
},
{
"data": {
"text/plain": [
"[1,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 2,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 2,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 4,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 2,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 8,\n",
" 8,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 3,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 2,\n",
" 2,\n",
" 2,\n",
" 2,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 2,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 8,\n",
" 8,\n",
" 8,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 6,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 1,\n",
" 2,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 2,\n",
" 2,\n",
" 0,\n",
" 3,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 1,\n",
" 2,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" ...]"
]
},
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_y = [y[0] for y in dataset]\n",
"print(train_y[0])\n",
"# Flatten every document's tag string into ONE list of class ids, wrapped in\n",
"# an outer list — train_y[0] is the tag sequence of the whole corpus.\n",
"# NOTE(review): ner_tags.get silently yields None for an unknown tag; verify\n",
"# the data contains only the nine known tags.\n",
"train_y = [[ner_tags.get(tag) for y in train_y for tag in y.split()]]\n",
"train_y[0]"
]
},
{
"cell_type": "code",
"execution_count": 81,
"id": "assured-colonial",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 8,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 6,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 8,\n",
" 8,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 6,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 5,\n",
" 6,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 3,\n",
" 4,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 2,\n",
" 2,\n",
" 2,\n",
" 2,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 7,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 2,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 2,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 6,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 2,\n",
" 2,\n",
" 2,\n",
" 2,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 6,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 6,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 6,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 2,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 0,\n",
" 5,\n",
" ...]"
]
},
"execution_count": 81,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Flatten the dev tag sequences into one list of class ids (parallels train_y[0]).\n",
"dev_y = [ner_tags.get(tag) for y in dev_y for tag in y]\n",
"dev_y"
]
},
{
"cell_type": "code",
"execution_count": 84,
"id": "identical-subsection",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# NOTE(review): despite the names, these 'test' tensors come from the DEV split.\n",
"test_tokens_ids = data_process(dev_x)\n",
"train_labels = labels_process(train_y[0])\n",
"test_labels = labels_process(dev_y)"
]
},
{
"cell_type": "code",
"execution_count": 85,
"id": "demanding-bonus",
"metadata": {},
"outputs": [],
"source": [
"class NERModel(torch.nn.Module):\n",
"    \"\"\"Window classifier: embeds a 12-id input and maps it to 9 NER classes.\n",
"\n",
"    The input is a 1-D LongTensor of 12 ids: a 3-token context window plus\n",
"    the 9 binary feature flags appended by add_features.\n",
"    NOTE(review): the 0/1 feature flags pass through the embedding table as\n",
"    if they were token ids ('<unk>'/'<pad>') — confirm this is intended.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self,):\n",
"        super(NERModel, self).__init__()\n",
"        # 23627 = training vocabulary size; 200-dim embeddings.\n",
"        self.emb = torch.nn.Embedding(23627, 200)\n",
"        # 2400 = 12 input ids * 200 embedding dims, flattened.\n",
"        self.fc1 = torch.nn.Linear(2400, 9)\n",
"        #self.softmax = torch.nn.Softmax(dim=1)\n",
"        # not needed, because we use https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html\n",
"        # as the criterion (it applies log-softmax internally)\n",
"        \n",
"\n",
"    def forward(self, x):\n",
"        x = self.emb(x)\n",
"        x = x.reshape(2400)   # flatten (12, 200) -> 2400 for the linear layer\n",
"        x = self.fc1(x)\n",
"        #x = self.softmax(x)\n",
"        return x"
]
},
{
"cell_type": "code",
"execution_count": 86,
"id": "statistical-barbados",
"metadata": {},
"outputs": [],
"source": [
"# Fresh, untrained model instance.\n",
"ner_model = NERModel()"
]
},
{
"cell_type": "code",
"execution_count": 87,
"id": "impressive-insert",
"metadata": {},
"outputs": [],
"source": [
"# CrossEntropyLoss applies log-softmax itself, so the model outputs raw logits.\n",
"criterion = torch.nn.CrossEntropyLoss()"
]
},
{
"cell_type": "code",
"execution_count": 88,
"id": "speaking-seeking",
"metadata": {},
"outputs": [],
"source": [
"# Adam with default hyper-parameters over all model weights.\n",
"optimizer = torch.optim.Adam(ner_model.parameters())"
]
},
{
"cell_type": "code",
"execution_count": 89,
"id": "8161d438",
"metadata": {},
"outputs": [],
"source": [
"import string\n",
"def add_features(tens, tokens):\n",
"    \"\"\"Append 9 binary surface features of the centre word to `tens`.\n",
"\n",
"    tens   : 1-D LongTensor holding the 3-token context-window ids.\n",
"    tokens : matching raw token strings; tokens[1] is the centre word.\n",
"    Returns `tens` with a length-9 0/1 tensor concatenated at the end; the\n",
"    flags stay all zero when the window has fewer than 2 tokens or the\n",
"    centre word is empty.\n",
"    \"\"\"\n",
"    array = [0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
"    if len(tokens) >= 2:\n",
"        if len(tokens[1]) >= 1:\n",
"            word = tokens[1]\n",
"            if word[0].isupper():   # starts with a capital letter\n",
"                array[0] = 1\n",
"            if word.isalnum():      # purely alphanumeric\n",
"                array[1] = 1\n",
"            for i in word:\n",
"                # flag the word if it contains any punctuation character\n",
"                if i in string.punctuation:\n",
"                    array[2] = 1\n",
"            if word.isnumeric():    # all digits\n",
"                array[3] = 1\n",
"            if word.isupper():      # fully upper-case (acronyms)\n",
"                array[4] = 1\n",
"            if '-' in word:\n",
"                array[5] = 1\n",
"            if '/' in word:\n",
"                array[6] = 1\n",
"            if len(word) > 3:\n",
"                array[7] = 1\n",
"            if len(word) > 6:\n",
"                array[8] = 1\n",
"    x = torch.tensor(array)\n",
"    new_tensor = torch.cat((tens, x), 0)\n",
"    return new_tensor"
]
},
{
"cell_type": "code",
"execution_count": 102,
"id": "sized-mobile",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0\tloss: 0.4997586368946457\tacc: 0.89\tprec: 0.6875\trecall: : 0.6470588235294118\n",
"epoch: 1\tloss: 0.518995374645849\tacc: 0.9\tprec: 0.6666666666666666\trecall: : 0.7058823529411765\n",
"epoch: 2\tloss: 0.6272036719378185\tacc: 0.87\tprec: 0.625\trecall: : 0.5882352941176471\n",
"epoch: 3\tloss: 0.5379423921279067\tacc: 0.9\tprec: 0.7058823529411765\trecall: : 0.7058823529411765\n",
"epoch: 4\tloss: 0.6458101376151467\tacc: 0.88\tprec: 0.6470588235294118\trecall: : 0.6470588235294118\n",
"epoch: 5\tloss: 0.5032455809084422\tacc: 0.9\tprec: 0.7692307692307693\trecall: : 0.5882352941176471\n",
"epoch: 6\tloss: 0.5464647812097837\tacc: 0.87\tprec: 0.6111111111111112\trecall: : 0.6470588235294118\n",
"epoch: 7\tloss: 0.5818069439918144\tacc: 0.89\tprec: 0.6875\trecall: : 0.6470588235294118\n",
"epoch: 8\tloss: 0.6688040642930787\tacc: 0.86\tprec: 0.5789473684210527\trecall: : 0.6470588235294118\n",
"epoch: 9\tloss: 0.47163703395596485\tacc: 0.9\tprec: 0.7058823529411765\trecall: : 0.7058823529411765\n",
"epoch: 10\tloss: 0.6080643151845471\tacc: 0.87\tprec: 0.625\trecall: : 0.5882352941176471\n",
"epoch: 11\tloss: 0.6119919324012835\tacc: 0.86\tprec: 0.5789473684210527\trecall: : 0.6470588235294118\n",
"epoch: 12\tloss: 0.5809223624372385\tacc: 0.9\tprec: 0.7333333333333333\trecall: : 0.6470588235294118\n",
"epoch: 13\tloss: 0.5410229888884214\tacc: 0.89\tprec: 0.6875\trecall: : 0.6470588235294118\n",
"epoch: 14\tloss: 0.5213326926458853\tacc: 0.88\tprec: 0.631578947368421\trecall: : 0.7058823529411765\n",
"epoch: 15\tloss: 0.5297116384661035\tacc: 0.89\tprec: 0.7142857142857143\trecall: : 0.5882352941176471\n",
"epoch: 16\tloss: 0.5681106116262435\tacc: 0.87\tprec: 0.6111111111111112\trecall: : 0.6470588235294118\n",
"epoch: 17\tloss: 0.49915451315861675\tacc: 0.9\tprec: 0.7333333333333333\trecall: : 0.6470588235294118\n",
"epoch: 18\tloss: 0.5361382347030667\tacc: 0.88\tprec: 0.6470588235294118\trecall: : 0.6470588235294118\n",
"epoch: 19\tloss: 0.4398948981850981\tacc: 0.89\tprec: 0.6875\trecall: : 0.6470588235294118\n",
"epoch: 20\tloss: 0.587098065932239\tacc: 0.86\tprec: 0.5789473684210527\trecall: : 0.6470588235294118\n",
"epoch: 21\tloss: 0.4703573033369526\tacc: 0.9\tprec: 0.7333333333333333\trecall: : 0.6470588235294118\n",
"epoch: 22\tloss: 0.33882861601850434\tacc: 0.9\tprec: 0.6666666666666666\trecall: : 0.7058823529411765\n",
"epoch: 23\tloss: 0.6288586365318634\tacc: 0.86\tprec: 0.5789473684210527\trecall: : 0.6470588235294118\n",
"epoch: 24\tloss: 0.4446407145373905\tacc: 0.9\tprec: 0.7692307692307693\trecall: : 0.5882352941176471\n",
"epoch: 25\tloss: 0.47516126279861737\tacc: 0.89\tprec: 0.6875\trecall: : 0.6470588235294118\n",
"epoch: 26\tloss: 0.47878878450462253\tacc: 0.87\tprec: 0.6\trecall: : 0.7058823529411765\n",
"epoch: 27\tloss: 0.406448530066898\tacc: 0.9\tprec: 0.7333333333333333\trecall: : 0.6470588235294118\n",
"epoch: 28\tloss: 0.5326147545382947\tacc: 0.88\tprec: 0.6470588235294118\trecall: : 0.6470588235294118\n",
"epoch: 29\tloss: 0.35017394204057384\tacc: 0.89\tprec: 0.6875\trecall: : 0.6470588235294118\n",
"epoch: 30\tloss: 0.5492841665070227\tacc: 0.85\tprec: 0.5555555555555556\trecall: : 0.5882352941176471\n",
"epoch: 31\tloss: 0.45283244484153784\tacc: 0.9\tprec: 0.7058823529411765\trecall: : 0.7058823529411765\n",
"epoch: 32\tloss: 0.40580460080429476\tacc: 0.9\tprec: 0.7333333333333333\trecall: : 0.6470588235294118\n",
"epoch: 33\tloss: 0.5504078443901653\tacc: 0.86\tprec: 0.5789473684210527\trecall: : 0.6470588235294118\n",
"epoch: 34\tloss: 0.45548378403755124\tacc: 0.9\tprec: 0.7333333333333333\trecall: : 0.6470588235294118\n",
"epoch: 35\tloss: 0.4666948410707255\tacc: 0.89\tprec: 0.6875\trecall: : 0.6470588235294118\n",
"epoch: 36\tloss: 0.3942578120598796\tacc: 0.87\tprec: 0.6111111111111112\trecall: : 0.6470588235294118\n",
"epoch: 37\tloss: 0.395962362795658\tacc: 0.9\tprec: 0.7692307692307693\trecall: : 0.5882352941176471\n",
"epoch: 38\tloss: 0.44939344771950573\tacc: 0.87\tprec: 0.6111111111111112\trecall: : 0.6470588235294118\n",
"epoch: 39\tloss: 0.38211571767803887\tacc: 0.9\tprec: 0.7333333333333333\trecall: : 0.6470588235294118\n",
"epoch: 40\tloss: 0.48910969563327855\tacc: 0.87\tprec: 0.6111111111111112\trecall: : 0.6470588235294118\n",
"epoch: 41\tloss: 0.3446516449950968\tacc: 0.9\tprec: 0.7333333333333333\trecall: : 0.6470588235294118\n",
"epoch: 42\tloss: 0.4679804835932646\tacc: 0.87\tprec: 0.6111111111111112\trecall: : 0.6470588235294118\n",
"epoch: 43\tloss: 0.33552404487287274\tacc: 0.9\tprec: 0.7058823529411765\trecall: : 0.7058823529411765\n",
"epoch: 44\tloss: 0.4151001131474459\tacc: 0.87\tprec: 0.625\trecall: : 0.5882352941176471\n",
"epoch: 45\tloss: 0.36344730574960066\tacc: 0.9\tprec: 0.7333333333333333\trecall: : 0.6470588235294118\n",
"epoch: 46\tloss: 0.36800105327266464\tacc: 0.88\tprec: 0.6470588235294118\trecall: : 0.6470588235294118\n",
"epoch: 47\tloss: 0.3511931332464837\tacc: 0.89\tprec: 0.7142857142857143\trecall: : 0.5882352941176471\n",
"epoch: 48\tloss: 0.4371468522066334\tacc: 0.87\tprec: 0.6111111111111112\trecall: : 0.6470588235294118\n",
"epoch: 49\tloss: 0.3572919689995433\tacc: 0.9\tprec: 0.7333333333333333\trecall: : 0.6470588235294118\n"
]
}
],
"source": [
"# Train for 50 epochs over the first 100 examples, tracking token-level\n",
"# accuracy plus precision/recall over non-'O' predictions.\n",
"for epoch in range(50):\n",
"    loss_score = 0\n",
"    acc_score = 0\n",
"    prec_score = 0\n",
"    selected_items = 0\n",
"    recall_score = 0\n",
"    relevant_items = 0\n",
"    items_total = 0\n",
"    ner_model.train()\n",
"    #for i in range(len(train_labels)):\n",
"    for i in range(100):\n",
"        # NOTE(review): train_labels[i] has length 3 ([0, tag, 0]) because the\n",
"        # labels were flattened corpus-wide, so j only ever equals 1 and X_base\n",
"        # is always the first three tokens of document i — token window and\n",
"        # label look misaligned; confirm the intended pairing.\n",
"        for j in range(1, len(train_labels[i]) - 1):\n",
"\n",
"            X_base = train_tokens_ids[i][j-1: j+2]\n",
"            X_add = train_x[i][j-1: j+2]\n",
"            X_final = add_features(X_base, X_add)\n",
"\n",
"            Y = train_labels[i][j: j+1]\n",
"\n",
"            Y_predictions = ner_model(X_final)\n",
"\n",
"            acc_score += int(torch.argmax(Y_predictions) == Y)\n",
"\n",
"            # precision counts only non-'O' predictions\n",
"            if torch.argmax(Y_predictions) != 0:\n",
"                selected_items +=1\n",
"            if torch.argmax(Y_predictions) != 0 and torch.argmax(Y_predictions) == Y.item():\n",
"                prec_score += 1\n",
"\n",
"            # recall counts only non-'O' gold labels\n",
"            if Y.item() != 0:\n",
"                relevant_items +=1\n",
"            if Y.item() != 0 and torch.argmax(Y_predictions) == Y.item():\n",
"                recall_score += 1\n",
"\n",
"            items_total += 1\n",
"\n",
"            optimizer.zero_grad()\n",
"            loss = criterion(Y_predictions.unsqueeze(0), Y)\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"\n",
"            loss_score += loss.item()\n",
"\n",
"    # Guard against ZeroDivisionError when no entity was predicted / present\n",
"    # in this epoch (possible early in training).\n",
"    precision = prec_score / selected_items if selected_items else 0.0\n",
"    recall = recall_score / relevant_items if relevant_items else 0.0\n",
"    #f1_score = (2*precision * recall) / (precision + recall)\n",
"    print('epoch: ', epoch, end='\\t')\n",
"    print('loss: ', loss_score / items_total, end='\\t')\n",
"    print('acc: ', acc_score / items_total, end='\\t')\n",
"    print('prec: ', precision, end='\\t')\n",
"    print('recall: : ', recall)\n",
"    #display('f1: ', f1_score)"
]
},
{
"cell_type": "code",
"execution_count": 103,
"id": "defensive-discretion",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.6875"
]
},
"execution_count": 103,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# F1 score of the final epoch (harmonic mean of precision and recall).\n",
"(2*precision * recall) / (precision + recall)"
]
},
{
"cell_type": "code",
"execution_count": 104,
"id": "common-national",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-2.1121, 6.8412, -5.5392, -2.9573, -2.2702, -4.7000, -8.2250, -4.4908,\n",
" -8.5875], grad_fn=<AddBackward0>)"
]
},
"execution_count": 104,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Raw logits of the last training sample (largest entry = predicted class).\n",
"Y_predictions"
]
},
{
"cell_type": "code",
"execution_count": 105,
"id": "isolated-excess",
"metadata": {},
"outputs": [],
"source": [
"# Inverse of ner_tags above — id -> BIO tag. The model's output index i is\n",
"# the class that ner_tags encodes as i, so decoding MUST use the exact\n",
"# inverse. (The previous mapping here disagreed with ner_tags — e.g. it\n",
"# decoded class 1 as 'B-PER' although training encodes 'B-ORG' as 1 — which\n",
"# mislabelled every non-O prediction.)\n",
"ner_tags_re = {\n",
"    0: 'O',\n",
"    1: 'B-ORG',\n",
"    2: 'I-ORG',\n",
"    3: 'B-PER',\n",
"    4: 'I-PER',\n",
"    5: 'B-LOC',\n",
"    6: 'I-LOC',\n",
"    7: 'B-MISC',\n",
"    8: 'I-MISC'\n",
"}\n",
"\n",
"def generate_out(folder_path):\n",
"    \"\"\"Run the model over `folder_path`/in.tsv and write BIO tags to out.tsv.\n",
"\n",
"    One output line per input sentence. I- tags that do not continue an\n",
"    entity are repaired: I-X after O/None becomes B-X, otherwise the previous\n",
"    entity type is continued.\n",
"    \"\"\"\n",
"    ner_model.eval()\n",
"    ner_model.cpu()\n",
"    print('Generating out')\n",
"    X_dev = []\n",
"    with open(f\"{folder_path}/in.tsv\", 'r') as file:\n",
"        for line in file:\n",
"            line = line.strip()\n",
"            X_dev.append(line.split(' '))\n",
"    test_tokens_ids = data_process(X_dev)\n",
"\n",
"    predicted_values = []\n",
"    for i in range(len(test_tokens_ids)):\n",
"        pred_string = ''\n",
"        for j in range(1, len(test_tokens_ids[i]) - 1):\n",
"            X = test_tokens_ids[i][j - 1: j + 2]\n",
"            X_raw_single = X_dev[i][j - 1: j + 2]\n",
"            X = add_features(X, X_raw_single)\n",
"\n",
"            try:\n",
"                Y_predictions = ner_model(X)\n",
"                pred_id = torch.argmax(Y_predictions)\n",
"                val = ner_tags_re[int(pred_id)]\n",
"                pred_string += val + ' '\n",
"            except Exception as e:\n",
"                # NOTE(review): 'index out of range in self' errors seen here\n",
"                # suggest some ids exceed the Embedding size — verify that\n",
"                # Embedding(23627, ...) matches len(vocab).\n",
"                print('Error', e)\n",
"        predicted_values.append(pred_string[:-1])\n",
"    lines = []\n",
"    for line in predicted_values:\n",
"        last_label = None\n",
"        line = line.split(' ')\n",
"        new_line = []\n",
"        for label in line:\n",
"            if (label != \"O\" and label[0:2] == \"I-\"):\n",
"                if last_label == None or last_label == \"O\":\n",
"                    label = label.replace('I-', 'B-')\n",
"                else:\n",
"                    label = \"I-\" + last_label[2:]\n",
"            last_label = label\n",
"            new_line.append(label)\n",
"        lines.append(\" \".join(new_line))\n",
"    # `with` closes the file automatically; the old trailing f.close() was redundant.\n",
"    with open(f\"{folder_path}/out.tsv\", \"w\") as f:\n",
"        for line in lines:\n",
"            f.write(str(line) + \"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 106,
"id": "d362007a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generating out\n",
"Error index out of range in self\n",
"Error index out of range in self\n",
"Error index out of range in self\n",
"Generating out\n"
]
}
],
"source": [
"# Write out.tsv predictions for both evaluation splits.\n",
"generate_out('dev-0')\n",
"generate_out('test-A')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c7bdb256",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.4 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"vscode": {
"interpreter": {
"hash": "369f2c481f4da34e4445cda3fffd2e751bd1c4d706f27375911949ba6bb62e1c"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}