en-ner-conll-2003/rnn_fras.ipynb

1404 lines
34 KiB
Plaintext
Raw Permalink Normal View History

2021-06-22 20:21:17 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Zadanie domowe\n",
"\n",
"\n",
"- sklonować repozytorium https://git.wmi.amu.edu.pl/kubapok/en-ner-conll-2003\n",
"- stworzyć model seq labelling bazujący na sieci neuronowej opisanej w punkcie niżej (można bazować na tym jupyterze lub nie).\n",
"- model sieci to GRU (o dowolnych parametrach) + CRF w PyTorchu, korzystając z modułu CRF z poprzednich zajęć\n",
"- stworzyć predykcje w plikach dev-0/out.tsv oraz test-A/out.tsv\n",
"- wynik F-score sprawdzony za pomocą narzędzia geval (patrz poprzednie zadanie) powinien wynosić co najmniej 0.65\n",
"- proszę umieścić predykcje oraz skrypty generujące (w postaci tekstowej, a nie jako notatnik Jupyter) w repo, a w MS Teams umieścić link do swojego repo\n",
"\n",
"Termin: 22.06; 60 punktów, za najlepszy wynik 100 punktów.\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"from torchtext.vocab import Vocab\n",
"from collections import Counter\n",
"from tqdm.notebook import tqdm\n",
"import lzma\n",
"import itertools\n",
"from torchcrf import CRF"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def read_data(filename):\n",
"    \"\"\"Read an xz-compressed, UTF-8, tab-separated file.\n",
"\n",
"    Args:\n",
"        filename: path to the ``.xz`` file.\n",
"\n",
"    Returns:\n",
"        One list of tab-separated fields per line. The empty element left by\n",
"        a trailing newline is dropped; a final line without a newline is kept.\n",
"    \"\"\"\n",
"    # The context manager closes the compressed file handle.\n",
"    with lzma.open(filename) as f:\n",
"        lines = f.read().decode('UTF-8').split('\\n')\n",
"    # Slicing [:-1] unconditionally would silently drop the last real row of\n",
"    # a file that does not end with a newline.\n",
"    if lines and lines[-1] == '':\n",
"        lines.pop()\n",
"    return [line.split('\\t') for line in lines]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def data_process(dt):\n",
"    \"\"\"Encode tokenized documents as id tensors wrapped in <bos> ... <eos>.\n",
"\n",
"    Looks tokens up in the notebook-level ``vocab``. Each result is a 1-D\n",
"    torch.long tensor of length len(document) + 2.\n",
"    \"\"\"\n",
"    encoded = []\n",
"    for document in dt:\n",
"        ids = [vocab['<bos>'], *(vocab[token] for token in document), vocab['<eos>']]\n",
"        encoded.append(torch.tensor(ids, dtype=torch.long))\n",
"    return encoded"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def labels_process(dt, pad_label=0):\n",
"    \"\"\"Turn per-document tag-index lists into tensors aligned with data_process.\n",
"\n",
"    One ``pad_label`` is added before and after each document so the labels\n",
"    line up with the <bos>/<eos> ids added by data_process.\n",
"\n",
"    Args:\n",
"        dt: iterable of per-document sequences of integer tag indices.\n",
"        pad_label: tag index for the <bos>/<eos> positions. The default 0 keeps\n",
"            the original behaviour, but with the sorted tag list used in this\n",
"            notebook index 0 is 'B-LOC', not 'O' -- consider passing\n",
"            ner_tags_set.index('O').\n",
"\n",
"    Returns:\n",
"        List of 1-D torch.long tensors of length len(document) + 2.\n",
"    \"\"\"\n",
"    return [torch.tensor([pad_label, *document, pad_label], dtype=torch.long) for document in dt]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def build_vocab(dataset):\n",
"    \"\"\"Count tokens across all documents and wrap the counts in a torchtext Vocab.\n",
"\n",
"    The specials '<unk>', '<pad>', '<bos>', '<eos>' are passed to Vocab.\n",
"    NOTE(review): uses the legacy ``Vocab(counter, specials=...)`` constructor;\n",
"    newer torchtext releases changed this API -- confirm the installed version.\n",
"    \"\"\"\n",
"    token_counts = Counter(token for document in dataset for token in document)\n",
"    return Vocab(token_counts, specials=['<unk>', '<pad>', '<bos>', '<eos>'])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"train_data = read_data('train/train.tsv.xz')\n",
"\n",
"# Column 0 holds the space-separated tags, column 1 the matching tokens.\n",
"ner_tags = [row[0].split() for row in train_data]\n",
"tokens = [row[1].split() for row in train_data]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# The vocabulary is built from the training tokens only; dev/test sentences\n",
"# are later encoded through this same `vocab` object by data_process.\n",
"vocab = build_vocab(tokens)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# One 1-D LongTensor per training sentence: [<bos>, token ids..., <eos>].\n",
"train_tokens_ids = data_process(tokens)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['B-LOC', 'B-MISC', 'B-ORG', 'B-PER', 'I-LOC', 'I-MISC', 'I-ORG', 'I-PER', 'O']\n"
]
}
],
"source": [
"# Sorted tag inventory: a tag's position in this list is its integer label.\n",
"ner_tags_set = sorted(set(itertools.chain.from_iterable(ner_tags)))\n",
"print(ner_tags_set)\n",
"train_labels = labels_process([list(map(ner_tags_set.index, doc)) for doc in ner_tags])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Every tag in ner_tags_set occurs in the training data and labels are indices\n",
"# into that list, so the tag count is simply its length (a plain Python int).\n",
"num_tags = len(ner_tags_set)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"class GRU(torch.nn.Module):\n",
"    \"\"\"Embedding -> dropout -> bidirectional GRU -> linear layer of per-tag emission scores.\n",
"\n",
"    The outputs are fed to the CRF as emission scores.\n",
"\n",
"    Args:\n",
"        vocab_size: embedding table size; defaults to ``len(vocab.itos)`` of the\n",
"            notebook-level ``vocab`` (read when the model is constructed).\n",
"        emb_dim: embedding dimension.\n",
"        hidden_dim: GRU hidden size per direction.\n",
"        num_layers: number of stacked GRU layers.\n",
"        num_tags: number of output scores per position.\n",
"        dropout: dropout probability applied to the embeddings.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, vocab_size=None, emb_dim=100, hidden_dim=256,\n",
"                 num_layers=2, num_tags=9, dropout=0.2):\n",
"        super().__init__()\n",
"        if vocab_size is None:\n",
"            vocab_size = len(vocab.itos)\n",
"        # Submodules are created in the same order and with the same attribute\n",
"        # names as before, so state_dict keys are unchanged.\n",
"        self.emb = torch.nn.Embedding(vocab_size, emb_dim)\n",
"        self.dropout = torch.nn.Dropout(dropout)\n",
"        # batch_first=True: input is expected as (batch, seq_len).\n",
"        self.rec = torch.nn.GRU(emb_dim, hidden_dim, num_layers,\n",
"                                batch_first=True, bidirectional=True)\n",
"        # Bidirectional output concatenates both directions -> 2 * hidden_dim.\n",
"        self.fc1 = torch.nn.Linear(2 * hidden_dim, num_tags)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"For a (batch, seq_len) LongTensor x return (batch, seq_len, num_tags) scores.\"\"\"\n",
"        # NOTE(review): ReLU on embeddings zeroes every negative component, which\n",
"        # is unusual; kept to preserve the existing model's behaviour.\n",
"        emb = torch.relu(self.emb(x))\n",
"        emb = self.dropout(emb)\n",
"        gru_output, _ = self.rec(emb)\n",
"        return self.fc1(gru_output)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def get_scores(y_true, y_pred, outside_label=0):\n",
"    \"\"\"Token-level precision, recall and F1 for labels other than ``outside_label``.\n",
"\n",
"    A token is *selected* when its predicted label is not ``outside_label``,\n",
"    *relevant* when its true label is not, and a true positive when it is\n",
"    selected and the prediction equals the true label. An empty denominator\n",
"    gives 1.0; F1 is 0.0 when precision + recall == 0.\n",
"\n",
"    NOTE(review): the default ``outside_label=0`` keeps the original behaviour,\n",
"    but with the sorted tag list built in this notebook index 0 is 'B-LOC' and\n",
"    'O' is index 8, so 'O' tokens are counted as entities unless callers pass\n",
"    ``ner_tags_set.index('O')``. This is also a per-token score, which likely\n",
"    differs from the entity-level F-score reported by geval.\n",
"\n",
"    Args:\n",
"        y_true: sequence of true integer labels.\n",
"        y_pred: sequence of predicted integer labels (zipped with y_true).\n",
"        outside_label: label meaning \"no entity\".\n",
"\n",
"    Returns:\n",
"        (precision, recall, f1)\n",
"    \"\"\"\n",
"    tp = selected_items = relevant_items = 0\n",
"    for p, t in zip(y_pred, y_true):\n",
"        if p != outside_label:\n",
"            selected_items += 1\n",
"            if p == t:\n",
"                tp += 1\n",
"        if t != outside_label:\n",
"            relevant_items += 1\n",
"\n",
"    precision = tp / selected_items if selected_items else 1.0\n",
"    recall = tp / relevant_items if relevant_items else 1.0\n",
"    if precision + recall == 0.0:\n",
"        f1 = 0.0\n",
"    else:\n",
"        f1 = 2 * precision * recall / (precision + recall)\n",
"    return precision, recall, f1"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def eval_model(dataset_tokens, dataset_labels, model):\n",
"    \"\"\"Decode every sequence with ``model`` and the global ``crf``, then score it.\n",
"\n",
"    Args:\n",
"        dataset_tokens: list of 1-D token-id tensors (from data_process).\n",
"        dataset_labels: list of 1-D gold tag-id tensors (from labels_process).\n",
"        model: emission network (e.g. the GRU defined above).\n",
"\n",
"    Returns:\n",
"        get_scores(...) over all tokens, including the <bos>/<eos> positions.\n",
"    \"\"\"\n",
"    Y_true = []\n",
"    Y_pred = []\n",
"    # Inference only: no autograd graph is needed.\n",
"    with torch.no_grad():\n",
"        for i in tqdm(range(len(dataset_labels))):\n",
"            # NOTE(review): unsqueeze(1) gives a (seq_len, 1) input; with the GRU's\n",
"            # batch_first=True each token is treated as its own length-1 sequence,\n",
"            # so the GRU sees no context. The (seq_len, 1, num_tags) emissions fit\n",
"            # torchcrf's default (seq_len, batch, num_tags) layout. Kept identical\n",
"            # to the training loop; any fix must be applied there too.\n",
"            batch_tokens = dataset_tokens[i].unsqueeze(1)\n",
"            tags = list(dataset_labels[i].numpy())\n",
"            emissions = model(batch_tokens).squeeze(0)\n",
"            Y_pred += crf.decode(emissions)[0]\n",
"            Y_true += tags\n",
"    return get_scores(Y_true, Y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Seed the global torch RNG so the weight initialisation (and, on CPU, the\n",
"# dropout masks drawn during training) is reproducible after a kernel restart.\n",
"torch.manual_seed(42)\n",
"gru = GRU()\n",
"crf = CRF(num_tags)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Optimise the emission network and the CRF parameters jointly (Adam defaults).\n",
"params = [*gru.parameters(), *crf.parameters()]\n",
"optimizer = torch.optim.Adam(params)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# Number of passes over the training sentences in the training loop.\n",
"NUM_EPOCHS = 20"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# NOTE(review): `criterion` is never used below -- the training loop minimises\n",
"# the CRF negative log-likelihood (-crf(...)) instead. Safe to delete.\n",
"criterion = torch.nn.CrossEntropyLoss()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c5da680182d74dbe8a6e6e515f39c304",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/zosia/.local/lib/python3.8/site-packages/torchcrf/__init__.py:249: UserWarning: where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead. (Triggered internally at /pytorch/aten/src/ATen/native/TensorCompare.cpp:255.)\n",
" score = torch.where(mask[i].unsqueeze(1), next_score, score)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3ca55e4b508d4fc9b2d0720e1def2a58",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.8601941656899232, 0.8751514345303986, 0.8676083403589915)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "31afa4456a9240789283af09788a3ed9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a6b3ba1f3b474cf092a826c87a0345be",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.8815602436292092, 0.8897984198549079, 0.8856601748234387)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0d18da57114b4e0ab646fcb52860dabd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9cfb2facab9c4b56924c27e287ba05d4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.9144309250302297, 0.919752763828645, 0.9170841238373373)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a0b5a4064324446c8a1a0c70d07cda59",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dea1cbf55a0c43fa84167e376b309125",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.9361905528132853, 0.9398110097060626, 0.9379972877369673)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0524c9827f294852a5cd271bfbdbfcd2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "04b306aaa0604677aa251727feacd2ba",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.9519541852390448, 0.9547763044748607, 0.9533631563717097)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1404f55a3ce546c2b99ddc12679b5d97",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c848b31135b14155b98aeaea7b8ac2be",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.960722713444972, 0.9632376346282668, 0.961978530336279)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8713c93530b94398a96354138783e326",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "57d75f0d65be4ca4a5401cb7ed3d5fe0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.9697570414352719, 0.9714709221947199, 0.9706132252353172)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "47614d60ab4b4f0abd56c395420103a6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2b88316f13234cdeb1486f66cf08d5b6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.9760554565110192, 0.9779891394717963, 0.9770213412246582)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9df02fe711de4ccbac63cbcc77b9185b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d45a98359f90443b9de4017908729121",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.9811127302761178, 0.9819703829690195, 0.9815413692723396)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "85e862f466b049089dd1ae4d4b3b25b4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3c38cdbf365448a8bec573e7f9bf3831",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.984655071665091, 0.9846831395763159, 0.9846691054206851)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fc3a0e186cb94c649f6d489d8afe0b28",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "aa122bd240074ffb9ad1a8e8df787497",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.9871442343767067, 0.9875194192515452, 0.9873317911716786)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bc95fc4c8eb84aa99bc9c09e449edc53",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "36dc7b6450a24bac82ced0efb0d9c4a0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.9893908786272786, 0.9889114292094049, 0.9891510958201069)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ec3fe5ad630e42b181096b84f985428f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3b1bcac29f3f4f5d949ee0860c345329",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.9911312527046112, 0.9901989196482444, 0.9906648668174991)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4dd0d5d81d5943f1ad7d7451014727d9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "98ef88bd0dcf45fd9ebeaf2ab77b2dbf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.9924332083291745, 0.9919900041332719, 0.9922115567382627)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "99089ca7a97a4168a1cc46d1eba49a62",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "294ec84096cd474a80ebff9c36ea0644",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.9930640069977942, 0.9924270857582653, 0.9927454442197611)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "130e89d10ba54246a65645381c69538a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "11ae3075d74443a79bbd4789a5bdd9b7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.9739162872556146, 0.9674801769230403, 0.9706875636048171)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cfaa4824a18046d898818cbed675450e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ba9b8454320c4d53b041f6f424b773ab",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.9848088502477955, 0.9837187094689933, 0.9842634780066597)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a5c8d9a82a9b4b4f8a58b83dcf50dfc8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "286d0b6ad83146e4911b967e9cbde195",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.9808100926458495, 0.9802695653413275, 0.9805397545015183)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b6342f18da15402e9c391550235a7ded",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5fc5c4feee604ddea5ead7ff203f07e8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.9668917478143436, 0.9694090371376854, 0.968148756174055)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "927b69c7183442aeaf3dc08ae3e20cbe",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "15f90bd1520b4773bcbb699f985ea031",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.9793555195345366, 0.9788157938495013, 0.979085582310423)\n"
]
}
],
"source": [
"for epoch in range(NUM_EPOCHS):\n",
"    gru.train()\n",
"    crf.train()\n",
"    for idx in tqdm(range(len(train_labels))):\n",
"        # NOTE(review): unsqueeze(1) gives a (seq_len, 1) input; with the GRU's\n",
"        # batch_first=True each token is its own length-1 sequence, so the GRU\n",
"        # sees no context. The (seq_len, 1, num_tags) emissions and (seq_len, 1)\n",
"        # tags fit torchcrf's default (seq_len, batch, ...) layout. Any fix must be\n",
"        # applied identically in eval_model, predict_labels and predict.\n",
"        batch_tokens = train_tokens_ids[idx].unsqueeze(1)\n",
"        tags = train_labels[idx].unsqueeze(1)\n",
"        emissions = gru(batch_tokens).squeeze(0)\n",
"        optimizer.zero_grad()\n",
"        # The CRF returns a log-likelihood; minimise its negation.\n",
"        loss = -crf(emissions, tags.squeeze(0))\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"    gru.eval()\n",
"    crf.eval()\n",
"    # Scores are on the training data (fit, not generalisation); no gradients needed.\n",
"    with torch.no_grad():\n",
"        print(eval_model(train_tokens_ids, train_labels, gru))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## dev-0 i test-A"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def predict_labels(dataset_tokens, dataset_labels, model):\n",
"    \"\"\"Decode every sequence, print token-level scores and return the tag paths.\n",
"\n",
"    Args:\n",
"        dataset_tokens: list of 1-D token-id tensors (from data_process).\n",
"        dataset_labels: list of 1-D gold tag-id tensors (from labels_process).\n",
"        model: emission network (e.g. the GRU defined above); decoded with the\n",
"            global ``crf``.\n",
"\n",
"    Returns:\n",
"        One decoded tag-index list per sequence, including <bos>/<eos> positions.\n",
"    \"\"\"\n",
"    Y_true = []\n",
"    Y_pred = []\n",
"    result = []\n",
"    # Inference only: no autograd graph is needed.\n",
"    with torch.no_grad():\n",
"        for i in tqdm(range(len(dataset_labels))):\n",
"            # NOTE(review): (seq_len, 1) input -- with batch_first=True the GRU sees\n",
"            # each token in isolation; kept identical to the training loop.\n",
"            batch_tokens = dataset_tokens[i].unsqueeze(1)\n",
"            emissions = model(batch_tokens).squeeze(0)\n",
"            path = crf.decode(emissions)[0]\n",
"            Y_pred += path\n",
"            result.append(path)\n",
"            Y_true += list(dataset_labels[i].numpy())\n",
"    print(get_scores(Y_true, Y_pred))\n",
"    return result"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# One whitespace-tokenised sentence per line of dev-0/in.tsv.\n",
"with open('dev-0/in.tsv', \"r\", encoding=\"utf-8\") as f:\n",
"    dev_0_data = [line.rstrip().split() for line in f]\n",
"\n",
"dev_0_tokens_ids = data_process(dev_0_data)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# Gold tags for dev-0, converted to label tensors aligned with dev_0_tokens_ids.\n",
"with open('dev-0/expected.tsv', \"r\", encoding=\"utf-8\") as f:\n",
"    dev_0_labels = [line.rstrip().split() for line in f]\n",
"\n",
"dev_0_labels = labels_process([list(map(ner_tags_set.index, doc)) for doc in dev_0_labels])"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"458 458\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e48f16faacc043ac8237af22f32b0af1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=215.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(0.9501477944520237, 0.9535808009736432, 0.9518612023310112)\n"
]
}
],
"source": [
"# Decode dev-0 (prints token-level scores); returns one tag-index path per\n",
"# sentence, still including the <bos>/<eos> positions.\n",
"tmp = predict_labels(dev_0_tokens_ids, dev_0_labels, gru)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# Map tag indices back to tag strings and drop the <bos>/<eos> positions.\n",
"r = [[ner_tags_set[idx] for idx in path[1:-1]] for path in tmp]"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"# Repair the BIO prefixes of every predicted sentence in place:\n",
"# - an entity tag at sentence start or right after 'O' becomes 'B-';\n",
"# - an entity tag following one of the same type becomes 'I-', otherwise 'B-'.\n",
"for doc in r:\n",
"    # Empty sentences have nothing to repair (doc[0] would raise IndexError).\n",
"    if not doc:\n",
"        continue\n",
"    if doc[0] != 'O':\n",
"        doc[0] = 'B' + doc[0][1:]\n",
"    for i in range(len(doc) - 1):\n",
"        if doc[i + 1] == 'O':\n",
"            continue\n",
"        if doc[i] != 'O' and doc[i][1:] == doc[i + 1][1:]:\n",
"            doc[i + 1] = 'I' + doc[i + 1][1:]\n",
"        else:\n",
"            doc[i + 1] = 'B' + doc[i + 1][1:]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"# \"w\" (not \"a\") so re-running the cell overwrites the file instead of\n",
"# appending a second copy of the predictions.\n",
"with open(\"dev-0/out.tsv\", \"w\", encoding=\"utf-8\") as f:\n",
"    for sentence in r:\n",
"        f.write(' '.join(sentence) + '\\n')"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9cce1860765e420f9b0bfaa23b651f58",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=215.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "42e2565e95db4efb9343d93f195212d5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=230.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"def _repair_bio(doc):\n",
"    \"\"\"Normalise the BIO prefixes of one predicted sentence in place and return it.\n",
"\n",
"    An entity tag at the start of the sentence or right after 'O' becomes 'B-';\n",
"    an entity tag following one of the same type becomes 'I-', otherwise 'B-'.\n",
"    Empty sentences are returned unchanged.\n",
"    \"\"\"\n",
"    if not doc:\n",
"        return doc\n",
"    if doc[0] != 'O':\n",
"        doc[0] = 'B' + doc[0][1:]\n",
"    for i in range(len(doc) - 1):\n",
"        if doc[i + 1] == 'O':\n",
"            continue\n",
"        if doc[i] != 'O' and doc[i][1:] == doc[i + 1][1:]:\n",
"            doc[i + 1] = 'I' + doc[i + 1][1:]\n",
"        else:\n",
"            doc[i + 1] = 'B' + doc[i + 1][1:]\n",
"    return doc\n",
"\n",
"\n",
"def predict(path, model):\n",
"    \"\"\"Tag <path>/in.tsv with ``model`` and the global ``crf``; write <path>/out.tsv.\n",
"\n",
"    Each input line is one whitespace-tokenised sentence; each output line holds\n",
"    that sentence's predicted tags, space-separated and BIO-repaired.\n",
"\n",
"    Args:\n",
"        path: dataset directory, e.g. 'dev-0' or 'test-A'.\n",
"        model: emission network (e.g. the GRU defined above).\n",
"\n",
"    Returns:\n",
"        One decoded tag-index list per sentence, including <bos>/<eos> positions.\n",
"    \"\"\"\n",
"    with open(path + '/in.tsv', \"r\", encoding=\"utf-8\") as f:\n",
"        data = [line.rstrip().split() for line in f]\n",
"    tokens_ids = data_process(data)\n",
"\n",
"    result = []\n",
"    # Inference only: no autograd graph is needed.\n",
"    with torch.no_grad():\n",
"        for ids in tqdm(tokens_ids):\n",
"            # NOTE(review): (seq_len, 1) input -- with batch_first=True the GRU sees\n",
"            # each token in isolation; kept identical to the training loop.\n",
"            emissions = model(ids.unsqueeze(1)).squeeze(0)\n",
"            result.append(crf.decode(emissions)[0])\n",
"\n",
"    # Map indices to tag strings, drop <bos>/<eos>, then repair BIO prefixes.\n",
"    tagged = [_repair_bio([ner_tags_set[idx] for idx in seq[1:-1]]) for seq in result]\n",
"    # \"w\" (not \"a\") so re-running overwrites the file instead of appending.\n",
"    with open(path + '/out.tsv', \"w\", encoding=\"utf-8\") as f:\n",
"        for sentence in tagged:\n",
"            f.write(' '.join(sentence) + '\\n')\n",
"    return result\n",
"\n",
"\n",
"result = predict('dev-0', gru)\n",
"result = predict('test-A', gru)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}