forked from kubapok/en-ner-conll-2003
841 lines
22 KiB
Plaintext
841 lines
22 KiB
Plaintext
{
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0,
|
|
"metadata": {
|
|
"colab": {
|
|
"name": "main.ipynb",
|
|
"provenance": []
|
|
},
|
|
"kernelspec": {
|
|
"name": "python3",
|
|
"display_name": "Python 3"
|
|
},
|
|
"language_info": {
|
|
"name": "python"
|
|
},
|
|
"accelerator": "GPU"
|
|
},
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "OY5VomOSCBez"
|
|
},
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"import gensim\n",
|
|
"import torch\n",
|
|
"import pandas as pd\n",
|
|
"\n",
|
|
"from torchtext.vocab import Vocab\n",
|
|
"from collections import Counter\n",
|
|
"\n",
|
|
"import lzma\n",
|
|
"import re\n",
|
|
"import itertools"
|
|
],
|
|
"execution_count": 1,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "VXcowLY6HlNC"
|
|
},
|
|
"source": [
|
"class NeuralNetworkModel(torch.nn.Module):\n",
"    \"\"\"One linear layer followed by a softmax over its outputs.\n",
"\n",
"    Args:\n",
"        output_size: number of output units of the linear layer.\n",
"        input_size: number of input features; defaults to 10_000, the width\n",
"            that used to be hard-coded.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, output_size, input_size=10_000):\n",
"        super(NeuralNetworkModel, self).__init__()\n",
"        # Size the layer from the constructor argument. It used to be read from\n",
"        # the notebook global `train_tokens_ids`, so `output_size` was ignored.\n",
"        self.fc1 = torch.nn.Linear(input_size, output_size)\n",
"        # dim=0 normalises along the first axis of whatever `forward` receives.\n",
"        self.softmax = torch.nn.Softmax(dim=0)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return the softmax of the linear projection of `x`.\"\"\"\n",
"        return self.softmax(self.fc1(x))"
|
|
],
|
|
"execution_count": 2,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "OXX_vPpTHhOq"
|
|
},
|
|
"source": [
|
"class NERModel(torch.nn.Module):\n",
"    \"\"\"Window tagger: embed a fixed window of token ids, flatten, score tags.\n",
"\n",
"    The defaults reproduce the sizes that used to be hard-coded.\n",
"\n",
"    Args:\n",
"        vocab_size: number of rows in the embedding table.\n",
"        embedding_dim: size of each token embedding.\n",
"        window_size: number of token ids passed to `forward` at once.\n",
"        num_tags: number of scores produced per window.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, vocab_size=23627, embedding_dim=200, window_size=3, num_tags=9):\n",
"        super(NERModel, self).__init__()\n",
"        self.emb = torch.nn.Embedding(vocab_size, embedding_dim)\n",
"        self.fc1 = torch.nn.Linear(window_size * embedding_dim, num_tags)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return `num_tags` unnormalised scores for one window of token ids.\"\"\"\n",
"        x = self.emb(x)\n",
"        # Concatenate the window's embeddings into a single vector; reshape\n",
"        # raises if `x` does not hold exactly `window_size` ids.\n",
"        x = x.reshape(self.fc1.in_features)\n",
"        return self.fc1(x)"
|
|
],
|
|
"execution_count": 3,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "NNpGPta9C4TI"
|
|
},
|
|
"source": [
|
"def get_dataset(path):\n",
"    \"\"\"Read an xz-compressed, UTF-8 TSV file into a list of column lists.\n",
"\n",
"    Args:\n",
"        path: location of the `.xz` file.\n",
"\n",
"    Returns:\n",
"        One list of tab-separated fields per line of the file.\n",
"    \"\"\"\n",
"    # The context manager closes the handle even if decoding fails.\n",
"    with lzma.open(path) as compressed:\n",
"        lines = compressed.read().decode('UTF-8').split('\\n')\n",
"    # A trailing newline leaves one empty element at the end; drop only that,\n",
"    # so the last row survives when the file does not end with a newline.\n",
"    if lines and lines[-1] == '':\n",
"        lines.pop()\n",
"    return [line.split('\\t') for line in lines]\n",
|
|
"\n",
|
"train_data = get_dataset('train.tsv.xz')\n",
"\n",
"# Each row is [space-separated tags, space-separated tokens].\n",
"tokens = [row[1].split() for row in train_data]\n",
"ner_tags = [row[0].split() for row in train_data]\n",
"\n",
"# Give every tag a stable index. Iterating a plain set yields an arbitrary\n",
"# order that can change between kernel restarts, yet the label padding and the\n",
"# metric code elsewhere in this notebook treat index 0 as \"not an entity\".\n",
"# Sorting with 'O' first makes the mapping reproducible and puts it at 0.\n",
"# NOTE(review): assumes 'O' is the outside tag of this dataset - confirm.\n",
"ner_tags_set = sorted(set(itertools.chain(*ner_tags)), key=lambda tag: (tag != 'O', tag))\n",
"\n",
"ner_tags_dictionary = {tag: index for index, tag in enumerate(ner_tags_set)}"
|
|
],
|
|
"execution_count": 4,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "vvOF0opUGEMN"
|
|
},
|
|
"source": [
|
"# Replace every tag string with its integer index. The list is rebuilt rather\n",
"# than mutated in place, and values that are already encoded pass through, so\n",
"# re-running this cell no longer fails with a KeyError.\n",
"ner_tags = [\n",
"    [ner_tags_dictionary[tag] if isinstance(tag, str) else tag for tag in document]\n",
"    for document in ner_tags\n",
"]\n",
|
|
"\n",
|
"def data_preprocessing(data):\n",
"    \"\"\"Turn tokenised documents into id tensors wrapped in <bos> ... <eos>.\n",
"\n",
"    Ids are looked up in the notebook-level `vocab`.\n",
"\n",
"    Args:\n",
"        data: iterable of documents, each an iterable of token strings.\n",
"\n",
"    Returns:\n",
"        A list with one `torch.long` tensor per document.\n",
"    \"\"\"\n",
"    encoded_documents = []\n",
"    for document in data:\n",
"        ids = [vocab['<bos>']]\n",
"        ids.extend(vocab[token] for token in document)\n",
"        ids.append(vocab['<eos>'])\n",
"        encoded_documents.append(torch.tensor(ids, dtype=torch.long))\n",
"    return encoded_documents\n",
|
|
"\n",
|
"def labels_preprocessing(data):\n",
"    \"\"\"Turn per-document label lists into tensors padded with 0 at both ends.\n",
"\n",
"    Args:\n",
"        data: iterable of documents, each a list of integer labels.\n",
"\n",
"    Returns:\n",
"        A list with one `torch.long` tensor per document.\n",
"    \"\"\"\n",
"    padded_documents = []\n",
"    for document in data:\n",
"        # One extra 0 on each side of the document's labels.\n",
"        padded = [0] + document + [0]\n",
"        padded_documents.append(torch.tensor(padded, dtype=torch.long))\n",
"    return padded_documents\n",
|
|
"\n",
|
"def build_vocab(dataset):\n",
"    \"\"\"Build a torchtext `Vocab` from the token counts of `dataset`.\n",
"\n",
"    Args:\n",
"        dataset: iterable of documents, each an iterable of token strings.\n",
"\n",
"    Returns:\n",
"        A `Vocab` built from the counts with the specials\n",
"        '<unk>', '<pad>', '<bos>' and '<eos>'.\n",
"    \"\"\"\n",
"    token_counts = Counter(itertools.chain.from_iterable(dataset))\n",
"    special_tokens = ['<unk>', '<pad>', '<bos>', '<eos>']\n",
"    return Vocab(token_counts, specials=special_tokens)\n",
|
|
"\n",
|
|
"\n",
|
"# The vocabulary has to exist before tokens can be converted to ids.\n",
"vocab = build_vocab(tokens)\n",
"\n",
"# Labels and token ids are both wrapped with one extra position at each end.\n",
"train_labels = labels_preprocessing(ner_tags)\n",
"train_tokens_ids = data_preprocessing(tokens)"
|
|
],
|
|
"execution_count": 5,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 437
|
|
},
|
|
"id": "yoCYSZNeHJeT",
|
|
"outputId": "8d9dba9d-2bc1-4579-bd9e-98ba4629e642"
|
|
},
|
|
"source": [
|
"# Kept only because later cells still reference `nn_model`; it is never trained.\n",
"nn_model = NeuralNetworkModel(len(train_tokens_ids))\n",
"\n",
"ner_model = NERModel()\n",
"\n",
"criterion = torch.nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(ner_model.parameters())\n",
"\n",
"# Training uses the first 100 documents; clamp so a smaller dataset does not\n",
"# raise an IndexError.\n",
"train_documents = min(100, len(train_labels))\n",
"\n",
"for epoch in range(2):\n",
"    loss_score = 0\n",
"    acc_score = 0\n",
"    prec_score = 0\n",
"    selected_items = 0\n",
"    recall_score = 0\n",
"    relevant_items = 0\n",
"    items_total = 0\n",
"    # Switch the model that is actually being optimised into training mode.\n",
"    ner_model.train()\n",
"    for i in range(train_documents):\n",
"        for j in range(1, len(train_labels[i]) - 1):\n",
"            # Window of three ids centred on position j, and that position's label.\n",
"            X = train_tokens_ids[i][j - 1: j + 2]\n",
"            Y = train_labels[i][j: j + 1]\n",
"\n",
"            Y_predictions = ner_model(X)\n",
"\n",
"            predicted_tag = int(torch.argmax(Y_predictions))\n",
"            expected_tag = Y.item()\n",
"            is_correct = predicted_tag == expected_tag\n",
"\n",
"            acc_score += int(is_correct)\n",
"\n",
"            # Label 0 is excluded from precision/recall counts.\n",
"            if predicted_tag != 0:\n",
"                selected_items += 1\n",
"                if is_correct:\n",
"                    prec_score += 1\n",
"\n",
"            if expected_tag != 0:\n",
"                relevant_items += 1\n",
"                if is_correct:\n",
"                    recall_score += 1\n",
"\n",
"            items_total += 1\n",
"\n",
"            optimizer.zero_grad()\n",
"            # CrossEntropyLoss expects a batch axis: (1, num_tags) against (1,).\n",
"            loss = criterion(Y_predictions.unsqueeze(0), Y)\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"\n",
"            loss_score += loss.item()\n",
"\n",
"    # Guard every ratio: an epoch with no non-zero predictions or labels would\n",
"    # otherwise stop the notebook with ZeroDivisionError.\n",
"    precision = prec_score / selected_items if selected_items else 0.0\n",
"    recall = recall_score / relevant_items if relevant_items else 0.0\n",
"    f1_score = (2 * precision * recall) / (precision + recall) if precision + recall else 0.0\n",
"    display('epoch: ', epoch)\n",
"    display('loss: ', loss_score / items_total if items_total else 0.0)\n",
"    display('acc: ', acc_score / items_total if items_total else 0.0)\n",
"    display('prec: ', precision)\n",
"    display('recall: : ', recall)\n",
"    display('f1: ', f1_score)"
|
|
],
|
|
"execution_count": 6,
|
|
"outputs": [
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"application/vnd.google.colaboratory.intrinsic+json": {
|
|
"type": "string"
|
|
},
|
|
"text/plain": [
|
|
"'epoch: '"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/plain": [
|
|
"0"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"application/vnd.google.colaboratory.intrinsic+json": {
|
|
"type": "string"
|
|
},
|
|
"text/plain": [
|
|
"'loss: '"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/plain": [
|
|
"0.5382220030078203"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"application/vnd.google.colaboratory.intrinsic+json": {
|
|
"type": "string"
|
|
},
|
|
"text/plain": [
|
|
"'acc: '"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/plain": [
|
|
"0.8581935187313261"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"application/vnd.google.colaboratory.intrinsic+json": {
|
|
"type": "string"
|
|
},
|
|
"text/plain": [
|
|
"'prec: '"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/plain": [
|
|
"0.8677398098465594"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"application/vnd.google.colaboratory.intrinsic+json": {
|
|
"type": "string"
|
|
},
|
|
"text/plain": [
|
|
"'recall: : '"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/plain": [
|
|
"0.8674948240165632"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"application/vnd.google.colaboratory.intrinsic+json": {
|
|
"type": "string"
|
|
},
|
|
"text/plain": [
|
|
"'f1: '"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/plain": [
|
|
"0.8676172996376301"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"application/vnd.google.colaboratory.intrinsic+json": {
|
|
"type": "string"
|
|
},
|
|
"text/plain": [
|
|
"'epoch: '"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/plain": [
|
|
"1"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"application/vnd.google.colaboratory.intrinsic+json": {
|
|
"type": "string"
|
|
},
|
|
"text/plain": [
|
|
"'loss: '"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/plain": [
|
|
"0.2793121223593968"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"application/vnd.google.colaboratory.intrinsic+json": {
|
|
"type": "string"
|
|
},
|
|
"text/plain": [
|
|
"'acc: '"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/plain": [
|
|
"0.9241553665823948"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"application/vnd.google.colaboratory.intrinsic+json": {
|
|
"type": "string"
|
|
},
|
|
"text/plain": [
|
|
"'prec: '"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/plain": [
|
|
"0.9306665413180408"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"application/vnd.google.colaboratory.intrinsic+json": {
|
|
"type": "string"
|
|
},
|
|
"text/plain": [
|
|
"'recall: : '"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/plain": [
|
|
"0.9316299642386598"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"application/vnd.google.colaboratory.intrinsic+json": {
|
|
"type": "string"
|
|
},
|
|
"text/plain": [
|
|
"'f1: '"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/plain": [
|
|
"0.931148003574284"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "KOVSTjGWVuq9"
|
|
},
|
|
"source": [
|
"# Tokenised dev-0 documents: one whitespace-split line per document.\n",
"with open('dev-0/in.tsv', \"r\", encoding=\"utf-8\") as f:\n",
"    dev_0_data = [line.split() for line in f]\n",
"\n",
"# Expected tags for the same documents, encoded with the training tag indices.\n",
"with open('dev-0/expected.tsv', \"r\", encoding=\"utf-8\") as f:\n",
"    dev_0_tags = [[ner_tags_dictionary[tag] for tag in line.split()] for line in f]\n",
"\n",
"test_tokens_ids = data_preprocessing(dev_0_data)\n",
"test_labels = labels_preprocessing(dev_0_tags)\n"
|
|
],
|
|
"execution_count": 8,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 192
|
|
},
|
|
"id": "Pt7sVRdhWCqC",
|
|
"outputId": "f4e9ee61-f7a8-47a3-cc90-6d222215024e"
|
|
},
|
|
"source": [
|
"result = []\n",
"\n",
"loss_score = 0\n",
"acc_score = 0\n",
"prec_score = 0\n",
"selected_items = 0\n",
"recall_score = 0\n",
"relevant_items = 0\n",
"items_total = 0\n",
"# Put the model that is being evaluated into eval mode.\n",
"ner_model.eval()\n",
"\n",
"# No parameter is updated here, so skip building autograd graphs.\n",
"with torch.no_grad():\n",
"    for i in range(len(test_tokens_ids)):\n",
"        document_predictions = []\n",
"        for j in range(1, len(test_labels[i]) - 1):\n",
"            # Window of three ids centred on position j, and that position's label.\n",
"            X = test_tokens_ids[i][j - 1: j + 2]\n",
"            Y = test_labels[i][j: j + 1]\n",
"\n",
"            Y_predictions = ner_model(X)\n",
"\n",
"            predicted_tag = int(torch.argmax(Y_predictions))\n",
"            expected_tag = Y.item()\n",
"            is_correct = predicted_tag == expected_tag\n",
"\n",
"            acc_score += int(is_correct)\n",
"\n",
"            # Label 0 is excluded from precision/recall counts.\n",
"            if predicted_tag != 0:\n",
"                selected_items += 1\n",
"                if is_correct:\n",
"                    prec_score += 1\n",
"\n",
"            if expected_tag != 0:\n",
"                relevant_items += 1\n",
"                if is_correct:\n",
"                    recall_score += 1\n",
"\n",
"            items_total += 1\n",
"            loss = criterion(Y_predictions.unsqueeze(0), Y)\n",
"            loss_score += loss.item()\n",
"\n",
"            document_predictions.append(predicted_tag)\n",
"        result.append(document_predictions)\n",
"\n",
"# Guard every ratio so an empty or all-zero evaluation set reports 0.0\n",
"# instead of raising ZeroDivisionError.\n",
"precision = prec_score / selected_items if selected_items else 0.0\n",
"recall = recall_score / relevant_items if relevant_items else 0.0\n",
"f1_score = (2 * precision * recall) / (precision + recall) if precision + recall else 0.0\n",
"display('loss: ', loss_score / items_total if items_total else 0.0)\n",
"display('acc: ', acc_score / items_total if items_total else 0.0)\n",
"display('prec: ', precision)\n",
"display('recall: : ', recall)\n",
"display('f1: ', f1_score)"
|
|
],
|
|
"execution_count": 9,
|
|
"outputs": [
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"application/vnd.google.colaboratory.intrinsic+json": {
|
|
"type": "string"
|
|
},
|
|
"text/plain": [
|
|
"'loss: '"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/plain": [
|
|
"0.7380534848964866"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"application/vnd.google.colaboratory.intrinsic+json": {
|
|
"type": "string"
|
|
},
|
|
"text/plain": [
|
|
"'acc: '"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/plain": [
|
|
"0.846621708531633"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"application/vnd.google.colaboratory.intrinsic+json": {
|
|
"type": "string"
|
|
},
|
|
"text/plain": [
|
|
"'prec: '"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/plain": [
|
|
"0.8595547727017202"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"application/vnd.google.colaboratory.intrinsic+json": {
|
|
"type": "string"
|
|
},
|
|
"text/plain": [
|
|
"'recall: : '"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/plain": [
|
|
"0.8640559071729957"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"application/vnd.google.colaboratory.intrinsic+json": {
|
|
"type": "string"
|
|
},
|
|
"text/plain": [
|
|
"'f1: '"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/plain": [
|
|
"0.8617994626787158"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "Bx-pxGsjlVLJ"
|
|
},
|
|
"source": [
|
"def save_file(path, data):\n",
"    \"\"\"Write `data` to `path`, one space-joined row per line.\n",
"\n",
"    An existing file is replaced. The file used to be opened in append mode,\n",
"    so every re-run of a prediction cell duplicated the rows already written.\n",
"\n",
"    Args:\n",
"        path: destination file.\n",
"        data: iterable of rows, each an iterable of strings.\n",
"    \"\"\"\n",
"    # `with` closes the handle even when a row cannot be joined or written.\n",
"    with open(path, \"w\", encoding=\"utf-8\") as f:\n",
"        for row in data:\n",
"            f.write(' '.join(row) + '\\n')"
|
|
],
|
|
"execution_count": 17,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "20zgcUx4sMCe"
|
|
},
|
|
"source": [
|
"# Position k of `tmp` is the tag whose index is k: the dictionary was filled\n",
"# in index order, so iterating its keys gives the tags back in that order.\n",
"tmp = list(ner_tags_dictionary)\n",
"tags = [[tmp[tag_id] for tag_id in document] for document in result]\n",
"\n",
"save_file(\"dev-0/out.tsv\", tags)\n",
"\n",
"# Reload the expected tags as strings for a direct comparison.\n",
"with open('dev-0/expected.tsv', \"r\", encoding=\"utf-8\") as f:\n",
"    dev_0_tags = [line.rstrip().split() for line in f]\n",
"\n",
"import math\n",
"# Count the positions where the predicted tag equals the expected tag.\n",
"t = 0\n",
"for i, predicted_document in enumerate(tags):\n",
"    for j, predicted_tag in enumerate(predicted_document):\n",
"        if predicted_tag == dev_0_tags[i][j]:\n",
"            t += 1\n"
|
|
],
|
|
"execution_count": 18,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "d-QCHMrycKwH"
|
|
},
|
|
"source": [
|
"with open('test-A/in.tsv', \"r\", encoding=\"utf-8\") as file:\n",
"    test_data = [line.rstrip().split() for line in file]\n",
"\n",
"test_tokens_ids = data_preprocessing(test_data)\n",
"\n",
"# test-A has no labels, so this cell only predicts; the metric counters that\n",
"# used to be reset here were never read and have been removed.\n",
"# Put the model that is used for prediction into eval mode.\n",
"ner_model.eval()\n",
"\n",
"result = []\n",
"# Inference only, so skip building autograd graphs.\n",
"with torch.no_grad():\n",
"    for document_ids in test_tokens_ids:\n",
"        document_predictions = []\n",
"        # Skip the first and last position (the <bos>/<eos> wrappers).\n",
"        for j in range(1, len(document_ids) - 1):\n",
"            X = document_ids[j - 1: j + 2]\n",
"            Y_predictions = ner_model(X)\n",
"            document_predictions.append(int(torch.argmax(Y_predictions)))\n",
"        result.append(document_predictions)"
|
|
],
|
|
"execution_count": 19,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "862mKzAMkbx_"
|
|
},
|
|
"source": [
|
"# Position k of `tmp` is the tag whose index is k: the dictionary was filled\n",
"# in index order, so iterating its keys gives the tags back in that order.\n",
"tmp = list(ner_tags_dictionary)\n",
"\n",
"result_length = len(result)\n",
"\n",
"tags = [[tmp[tag_id] for tag_id in result[i]] for i in range(result_length)]\n",
"\n",
"save_file(\"test-A/out.tsv\", tags)"
|
|
],
|
|
"execution_count": 20,
|
|
"outputs": []
|
|
}
|
|
]
|
|
} |