en-ner-conll-2003/sequence_labeling_fras.ipynb

1035 lines
26 KiB
Plaintext
Raw Permalink Normal View History

2021-06-08 15:55:00 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Klasyfikacja wieloklasowa i sequence labelling"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import gensim\n",
"import torch\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"from datasets import load_dataset\n",
"from torchtext.vocab import Vocab\n",
"from collections import Counter\n",
"\n",
"from sklearn.datasets import fetch_20newsgroups\n",
"\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.metrics import accuracy_score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Zadanie domowe\n",
"\n",
"- sklonować repozytorium https://git.wmi.amu.edu.pl/kubapok/en-ner-conll-2003\n",
"- stworzyć klasyfikator bazujący na sieci neuronowej typu feed-forward w PyTorchu (można bazować na tym notebooku lub nie).\n",
"- klasyfikator powinien obejmować dodatkowe cechy (np. długość wyrazu, czy wyraz zaczyna się od wielkiej litery, stemming słowa, czy zawiera cyfrę)\n",
"- stworzyć predykcje w plikach dev-0/out.tsv oraz test-A/out.tsv\n",
"- wynik F-score sprawdzony za pomocą narzędzia geval (patrz poprzednie zadanie) powinien wynosić co najmniej 0.60\n",
"- proszę umieścić predykcję oraz skrypty generujące (w postaci tekstowej, a nie jako notebook Jupyter) w repo, a w MS Teams umieścić link do swojego repo\n",
"termin 08.06, 80 punktów\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# train"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import lzma\n",
"import re\n",
"import itertools\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def read_data(filename):\n",
"    \"\"\"Read an xz-compressed TSV file and return its rows split on tabs.\n",
"\n",
"    Every non-empty line becomes a list of tab-separated fields. The file is\n",
"    closed after reading, and a file without a trailing newline no longer\n",
"    loses its last row (the old `[:-1]` slice assumed one was present).\n",
"    \"\"\"\n",
"    with lzma.open(filename, 'rt', encoding='UTF-8') as f:\n",
"        return [line.split('\\t') for line in f.read().split('\\n') if line]\n",
"\n",
"train_data = read_data('train/train.tsv.xz')\n",
"\n",
"# Column 0 holds the space-separated NER tags, column 1 the space-separated tokens.\n",
"tokens, ner_tags = [], []\n",
"for row in train_data:\n",
"    ner_tags.append(row[0].split())\n",
"    tokens.append(row[1].split())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['B-PER', 'B-LOC', 'I-LOC', 'B-ORG', 'I-ORG', 'I-MISC', 'O', 'B-MISC', 'I-PER']\n"
]
}
],
"source": [
"# Deterministic tag order: sorted, with the non-entity tag 'O' first so that it\n",
"# gets id 0 (the same id labels_process uses for the <bos>/<eos> padding).\n",
"# Iteration order of a plain set of strings depends on string hashes, which are\n",
"# randomised per interpreter run, so list(set(...)) changed the tag->id mapping\n",
"# after every kernel restart.\n",
"all_tags = set(itertools.chain.from_iterable(ner_tags))\n",
"if 'O' in all_tags:\n",
"    ner_tags_set = ['O'] + sorted(all_tags - {'O'})\n",
"else:\n",
"    ner_tags_set = sorted(all_tags)\n",
"print(ner_tags_set)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'B-PER': 0, 'B-LOC': 1, 'I-LOC': 2, 'B-ORG': 3, 'I-ORG': 4, 'I-MISC': 5, 'O': 6, 'B-MISC': 7, 'I-PER': 8}\n"
]
}
],
"source": [
"# Tag string -> its position in ner_tags_set (keys are inserted in id order).\n",
"ner_tags_dic = {tag: idx for idx, tag in enumerate(ner_tags_set)}\n",
"print(ner_tags_dic)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Replace every tag string by its integer id. Not idempotent: once the ids are\n",
"# in place, re-running this cell raises KeyError (ids are not dict keys).\n",
"ner_tags = [[ner_tags_dic[tag] for tag in doc] for doc in ner_tags]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def data_process(dt):\n",
"    \"\"\"Turn each token list in `dt` into a LongTensor of vocabulary ids.\n",
"\n",
"    Every sequence is wrapped with the ids of '<bos>' and '<eos>'. Uses the\n",
"    global `vocab` built from the training tokens.\n",
"    \"\"\"\n",
"    bos_id, eos_id = vocab['<bos>'], vocab['<eos>']\n",
"    return [\n",
"        torch.tensor([bos_id, *(vocab[token] for token in document), eos_id], dtype=torch.long)\n",
"        for document in dt\n",
"    ]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def labels_process(dt):\n",
"    \"\"\"Turn each list of tag ids in `dt` into a LongTensor.\n",
"\n",
"    Label 0 is added at both ends so positions line up with the <bos>/<eos>\n",
"    ids that `data_process` adds around the tokens.\n",
"    \"\"\"\n",
"    padded = ([0, *document, 0] for document in dt)\n",
"    return [torch.tensor(seq, dtype=torch.long) for seq in padded]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def build_vocab(dataset):\n",
"    \"\"\"Count every token of `dataset` (an iterable of token lists) and build a\n",
"    torchtext `Vocab` with '<unk>', '<pad>', '<bos>', '<eos>' as specials.\n",
"\n",
"    NOTE(review): this is the legacy `Vocab(counter, specials=...)` constructor;\n",
"    newer torchtext releases changed the Vocab API, so pin the version.\n",
"    \"\"\"\n",
"    counter = Counter(token for document in dataset for token in document)\n",
"    return Vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# The vocabulary is built from the training tokens only.\n",
"vocab = build_vocab(tokens)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Per-document id tensors; both include the <bos>/<eos> positions.\n",
"train_tokens_ids = data_process(tokens)\n",
"train_labels = labels_process(ner_tags)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 2, 967, 22410, 239, 774, 10, 4588, 213, 7687, 5,\n",
" 4, 740, 2091, 4, 1388, 138, 4, 22, 231, 460,\n",
" 17, 16, 70, 39, 10855, 28, 239, 4552, 10, 2621,\n",
" 10, 22766, 213, 7687, 425, 4100, 2178, 514, 1897, 2010,\n",
" 663, 295, 43, 11848, 10, 2056, 5, 4, 118, 18,\n",
" 3489, 10, 7, 231, 494, 18, 3107, 1089, 10434, 10494,\n",
" 17, 16, 75, 2621, 264, 893, 11638, 30, 547, 128,\n",
" 116, 126, 425, 7, 2717, 4552, 23, 19846, 5, 4,\n",
" 15, 121, 172, 202, 348, 217, 584, 7880, 159, 103,\n",
" 172, 202, 847, 217, 3987, 19, 39, 6, 15, 7,\n",
" 460, 18, 451, 179, 17516, 1380, 2632, 17769, 91, 11,\n",
" 241, 3909, 5, 4, 86, 17, 724, 2717, 2464, 23,\n",
" 3071, 14, 201, 39, 23, 340, 29, 804, 23, 991,\n",
" 39, 264, 43, 566, 31, 7, 231, 494, 5, 4,\n",
" 86, 17, 11, 2444, 72, 224, 31, 967, 6654, 3178,\n",
" 5219, 3683, 10, 639, 2056, 10634, 6, 11710, 14, 4861,\n",
" 10782, 30, 7, 814, 14, 2949, 1146, 3915, 23, 11,\n",
" 3993, 3508, 14, 22123, 1358, 10, 5997, 814, 944, 5,\n",
" 4, 3683, 1651, 15772, 1549, 46, 730, 30, 126, 14,\n",
" 134, 29, 107, 7686, 938, 2056, 119, 807, 8919, 10229,\n",
" 9189, 12, 2088, 13, 55, 1897, 2010, 663, 5, 4,\n",
" 111, 3683, 415, 10, 3494, 40, 2444, 46, 7, 967,\n",
" 18, 2731, 3107, 1089, 6, 21529, 2949, 944, 142, 6,\n",
" 2047, 201, 584, 804, 23, 5890, 34, 145, 23, 139,\n",
" 11, 4112, 1285, 10, 814, 944, 5, 4, 1846, 6654,\n",
" 148, 17056, 484, 17738, 37, 249, 600, 3683, 27, 44,\n",
" 967, 1445, 1759, 115, 236, 8, 5706, 23399, 7280, 184,\n",
" 15, 1870, 20842, 5, 15, 4, 5, 4, 4444, 134,\n",
" 14, 126, 3338, 3683, 18, 2444, 5, 4, 22, 967,\n",
" 18, 2717, 3107, 14, 21666, 10734, 57, 283, 10, 11507,\n",
" 7, 391, 274, 166, 224, 14, 382, 11515, 10, 7,\n",
" 909, 3107, 142, 5, 4, 10166, 45, 666, 53, 757,\n",
" 10, 807, 11615, 6, 11, 7350, 663, 1055, 10, 2088,\n",
" 61, 32, 836, 10, 45, 53, 8050, 10, 2006, 184,\n",
" 1351, 4615, 2949, 3541, 5, 4, 213, 1269, 980, 16,\n",
" 70, 145, 23, 217, 2394, 10, 814, 944, 30, 58,\n",
" 2056, 6, 50, 2184, 1438, 29, 239, 78, 4552, 10,\n",
" 2621, 10, 1612, 213, 7687, 649, 5874, 2621, 684, 587,\n",
" 5, 4, 15, 1990, 103, 45, 10, 43, 2991, 19735,\n",
" 8, 32, 843, 128, 547, 57, 432, 10, 259, 118,\n",
" 18, 276, 6, 15, 10431, 265, 9239, 115, 494, 12,\n",
" 17439, 13, 860, 448, 1129, 1401, 17, 16, 8822, 994,\n",
" 5, 4, 2798, 38, 628, 1623, 10, 5997, 711, 944,\n",
" 46, 1618, 2387, 7394, 9, 637, 46, 11, 213, 409,\n",
" 6109, 7636, 119, 807, 44, 3425, 1055, 10, 1897, 2010,\n",
" 663, 31, 2983, 10768, 2369, 5, 4, 118, 4693, 8565,\n",
" 2056, 30, 126, 72, 68, 6, 866, 245, 8, 609,\n",
" 1886, 5, 4, 87, 746, 9, 8525, 253, 8, 213,\n",
" 7751, 6, 108, 92, 67, 8, 1210, 1886, 5, 4,\n",
" 3])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Preview the first 20 ids of document 0 instead of dumping the whole tensor.\n",
"train_tokens_ids[0][:20]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"class NeuralNetworkModel(torch.nn.Module):\n",
"    \"\"\"One linear layer followed by a softmax.\n",
"\n",
"    NOTE(review): this class is instantiated as `nn_model` but no prediction in\n",
"    this notebook comes from it; predictions come from `NERModel`.\n",
"\n",
"    Args:\n",
"        output_size: number of output units. It used to be ignored in favour\n",
"            of the global `len(train_tokens_ids)`, which is the value the\n",
"            caller passes, so that call behaves the same.\n",
"        input_size: number of input features (default 10_000, as before).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, output_size, input_size=10_000):\n",
"        super().__init__()\n",
"        self.fc1 = torch.nn.Linear(input_size, output_size)\n",
"        # dim=-1 normalises over the output units for both 1-D and batched\n",
"        # input; dim=0 would normalise across the batch for 2-D input.\n",
"        self.softmax = torch.nn.Softmax(dim=-1)\n",
"\n",
"    def forward(self, x):\n",
"        return self.softmax(self.fc1(x))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"class NERModel(torch.nn.Module):\n",
"    \"\"\"Window tagger: embed a window of token ids, concatenate the embeddings\n",
"    and map them to one logit per tag with a single linear layer.\n",
"\n",
"    Returns raw logits (no softmax), as expected by `CrossEntropyLoss`.\n",
"\n",
"    Args:\n",
"        vocab_size: rows of the embedding table. NOTE(review): the default\n",
"            23627 is hard-coded and presumably equals `len(vocab)`; confirm,\n",
"            since any larger token id fails inside the embedding lookup.\n",
"        emb_dim: embedding size per token.\n",
"        window: number of token ids per input window.\n",
"        n_tags: number of output classes.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, vocab_size=23627, emb_dim=200, window=3, n_tags=9):\n",
"        super().__init__()\n",
"        self.emb = torch.nn.Embedding(vocab_size, emb_dim)\n",
"        self.fc1 = torch.nn.Linear(window * emb_dim, n_tags)\n",
"\n",
"    def forward(self, x):\n",
"        # Accepts (window,) or (batch, window) token ids.\n",
"        x = self.emb(x)\n",
"        # Flatten each window's embeddings, keeping any leading batch dims;\n",
"        # unbatched input gives shape (window * emb_dim,) exactly as before.\n",
"        x = x.reshape(*x.shape[:-2], -1)\n",
"        return self.fc1(x)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# NOTE(review): `nn_model` produces no predictions in this notebook;\n",
"# those come from `ner_model`.\n",
"nn_model = NeuralNetworkModel(output_size=len(train_tokens_ids))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 967, 22410, 239])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# First 3-token window of document 0 (position 0 holds the <bos> id).\n",
"train_tokens_ids[0][1:4]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# Seed the RNG so the randomly initialised weights are reproducible across\n",
"# Restart & Run All.\n",
"torch.manual_seed(0)\n",
"ner_model = NERModel()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 0.7428, 1.0342, -0.5970, 0.1479, 0.4966, 0.8864, 0.0432, -0.0845,\n",
" 0.2145], grad_fn=<AddBackward0>)"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sanity check: an untrained forward pass on one window yields one logit per tag.\n",
"ner_model(train_tokens_ids[0][1:4])"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# Cross-entropy on raw logits (NERModel applies no softmax).\n",
"criterion = torch.nn.CrossEntropyLoss()"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# Adam over all ner_model parameters with default hyperparameters.\n",
"optimizer = torch.optim.Adam(params=ner_model.parameters())"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"945"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Number of training documents.\n",
"len(train_labels)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'epoch: '"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'loss: '"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0.5410224926585327"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'acc: '"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0.856768558951965"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'prec: '"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0.8666126186274977"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'recall: : '"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0.868891651525294"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'f1: '"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0.8677506386839527"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'epoch: '"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"1"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'loss: '"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0.28820573237663566"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'acc: '"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0.923373937025971"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'prec: '"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0.9287656853857531"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'recall: : '"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0.9307640814765229"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'f1: '"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0.9297638096147876"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Epoch count and number of training documents used per epoch. Only the first\n",
"# N_TRAIN_DOCS documents are used (as before); raise it up to len(train_labels)\n",
"# to train on the whole set at the cost of a longer run.\n",
"N_EPOCHS = 2\n",
"N_TRAIN_DOCS = 100\n",
"# Id of the non-entity tag; precision/recall count only the other tags.\n",
"# Looked up by name instead of assuming the 'O' tag has id 0.\n",
"O_IDX = ner_tags_dic['O']\n",
"\n",
"for epoch in range(N_EPOCHS):\n",
"    loss_score = 0\n",
"    acc_score = 0\n",
"    prec_score = 0\n",
"    selected_items = 0\n",
"    recall_score = 0\n",
"    relevant_items = 0\n",
"    items_total = 0\n",
"    ner_model.train()  # train mode for the model that is actually optimised\n",
"    for i in range(min(N_TRAIN_DOCS, len(train_labels))):\n",
"        # Skip the <bos>/<eos> positions; each token is classified from a 3-token window.\n",
"        for j in range(1, len(train_labels[i]) - 1):\n",
"            X = train_tokens_ids[i][j - 1: j + 2]\n",
"            Y = train_labels[i][j: j + 1]\n",
"\n",
"            Y_predictions = ner_model(X)\n",
"            predicted = int(torch.argmax(Y_predictions))\n",
"            gold = Y.item()\n",
"\n",
"            acc_score += int(predicted == gold)\n",
"            if predicted != O_IDX:\n",
"                selected_items += 1\n",
"                prec_score += int(predicted == gold)\n",
"            if gold != O_IDX:\n",
"                relevant_items += 1\n",
"                recall_score += int(predicted == gold)\n",
"            items_total += 1\n",
"\n",
"            optimizer.zero_grad()\n",
"            loss = criterion(Y_predictions.unsqueeze(0), Y)\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"\n",
"            loss_score += loss.item()\n",
"\n",
"    # Guard against empty denominators instead of raising ZeroDivisionError.\n",
"    precision = prec_score / selected_items if selected_items else 0.0\n",
"    recall = recall_score / relevant_items if relevant_items else 0.0\n",
"    f1_score = (2 * precision * recall) / (precision + recall) if precision + recall else 0.0\n",
"    display('epoch: ', epoch)\n",
"    display('loss: ', loss_score / items_total if items_total else 0.0)\n",
"    display('acc: ', acc_score / items_total if items_total else 0.0)\n",
"    display('prec: ', precision)\n",
"    display('recall: ', recall)\n",
"    display('f1: ', f1_score)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# dev-0"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# One document per line, split on whitespace into tokens.\n",
"with open('dev-0/in.tsv', \"r\", encoding=\"utf-8\") as f:\n",
"    dev_0_data = [line.split() for line in f]"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# Gold tags: one space-separated tag sequence per line.\n",
"with open('dev-0/expected.tsv', \"r\", encoding=\"utf-8\") as f:\n",
"    dev_0_tags = [line.split() for line in f]"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"# Tag strings -> ids. Not idempotent: re-running after conversion raises KeyError.\n",
"dev_0_tags = [[ner_tags_dic[tag] for tag in doc] for doc in dev_0_tags]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"# Despite the `test_` prefix, these hold the dev-0 split.\n",
"test_labels = labels_process(dev_0_tags)\n",
"test_tokens_ids = data_process(dev_0_data)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'loss: '"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0.7757424341984906"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'acc: '"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0.8510501460833134"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'prec: '"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0.8772459727385378"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'recall: : '"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0.8616800745516441"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'f1: '"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0.8693933550163583"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Id of the non-entity tag; precision/recall count only the other tags.\n",
"# Looked up by name instead of assuming the 'O' tag has id 0.\n",
"O_IDX = ner_tags_dic['O']\n",
"\n",
"result = []\n",
"\n",
"loss_score = 0\n",
"acc_score = 0\n",
"prec_score = 0\n",
"selected_items = 0\n",
"recall_score = 0\n",
"relevant_items = 0\n",
"items_total = 0\n",
"ner_model.eval()  # eval mode for the model that is actually evaluated\n",
"with torch.no_grad():  # evaluation only: no autograd graph is needed\n",
"    for i in range(len(test_tokens_ids)):\n",
"        result.append([])\n",
"        # Skip the <bos>/<eos> positions; each token is classified from a 3-token window.\n",
"        for j in range(1, len(test_labels[i]) - 1):\n",
"            X = test_tokens_ids[i][j - 1: j + 2]\n",
"            Y = test_labels[i][j: j + 1]\n",
"\n",
"            Y_predictions = ner_model(X)\n",
"            predicted = int(torch.argmax(Y_predictions))\n",
"            gold = Y.item()\n",
"\n",
"            acc_score += int(predicted == gold)\n",
"            if predicted != O_IDX:\n",
"                selected_items += 1\n",
"                prec_score += int(predicted == gold)\n",
"            if gold != O_IDX:\n",
"                relevant_items += 1\n",
"                recall_score += int(predicted == gold)\n",
"            items_total += 1\n",
"\n",
"            loss = criterion(Y_predictions.unsqueeze(0), Y)\n",
"            loss_score += loss.item()\n",
"\n",
"            result[i].append(predicted)\n",
"\n",
"# Guard against empty denominators instead of raising ZeroDivisionError.\n",
"precision = prec_score / selected_items if selected_items else 0.0\n",
"recall = recall_score / relevant_items if relevant_items else 0.0\n",
"f1_score = (2 * precision * recall) / (precision + recall) if precision + recall else 0.0\n",
"display('loss: ', loss_score / items_total if items_total else 0.0)\n",
"display('acc: ', acc_score / items_total if items_total else 0.0)\n",
"display('prec: ', precision)\n",
"display('recall: ', recall)\n",
"display('f1: ', f1_score)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"# ner_tags_dic keys are stored in id order, so listing them inverts the mapping.\n",
"id_to_tag = list(ner_tags_dic)\n",
"tags = [[id_to_tag[tag_id] for tag_id in doc] for doc in result]"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"# Write mode 'w' (not append 'a'): re-running the cell must replace the file,\n",
"# not append another full copy of the predictions that geval would then read.\n",
"with open(\"dev-0/out.tsv\", \"w\", encoding=\"utf-8\") as f:\n",
"    f.writelines(' '.join(doc_tags) + '\\n' for doc_tags in tags)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"# Re-read the gold tags as strings; dev_0_tags currently holds tag ids.\n",
"with open('dev-0/expected.tsv', \"r\", encoding=\"utf-8\") as f:\n",
"    dev_0_tags = [line.split() for line in f]"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.8510501460833134\n"
]
}
],
"source": [
"# Token-level accuracy of the predicted tag strings against the gold tags.\n",
"n_correct = sum(\n",
"    pred == dev_0_tags[i][j]\n",
"    for i, pred_doc in enumerate(tags)\n",
"    for j, pred in enumerate(pred_doc)\n",
")\n",
"n_total = sum(len(pred_doc) for pred_doc in tags)\n",
"print(n_correct / n_total if n_total else 0.0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# test"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"# test-A input: one document per line, split on whitespace into tokens.\n",
"with open('test-A/in.tsv', \"r\", encoding=\"utf-8\") as f:\n",
"    test_data = [line.split() for line in f]"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"# Token ids for test-A (replaces the dev-0 ids held under the same name).\n",
"test_tokens_ids = data_process(test_data)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"# Predict a tag id for every token of test-A; no labels are involved here, so\n",
"# the unused metric counters of the dev cell are gone.\n",
"result = []\n",
"ner_model.eval()  # eval mode for the model that makes the predictions\n",
"with torch.no_grad():  # inference only: no autograd graph is needed\n",
"    for i in range(len(test_tokens_ids)):\n",
"        # Skip the <bos>/<eos> positions; each token is classified from a 3-token window.\n",
"        result.append([\n",
"            int(torch.argmax(ner_model(test_tokens_ids[i][j - 1: j + 2])))\n",
"            for j in range(1, len(test_tokens_ids[i]) - 1)\n",
"        ])"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"# Map predicted ids back to tag strings; ner_tags_dic keys are in id order.\n",
"id_to_tag = dict(enumerate(ner_tags_dic))\n",
"tags = [[id_to_tag[tag_id] for tag_id in doc] for doc in result]"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"# Write mode 'w' (not append 'a'): re-running the cell must replace the file,\n",
"# not append another full copy of the predictions.\n",
"with open(\"test-A/out.tsv\", \"w\", encoding=\"utf-8\") as f:\n",
"    f.writelines(' '.join(doc_tags) + '\\n' for doc_tags in tags)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}