This commit is contained in:
Jakub Pokrywka 2022-06-07 15:50:27 +02:00
parent 3d85ca4084
commit 86915640a6
2 changed files with 223 additions and 215 deletions

View File

@ -28,19 +28,14 @@
"outputs": [],
"source": [
"import numpy as np\n",
"import gensim\n",
"import torch\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"from datasets import load_dataset\n",
"import torchtext\n",
"#from torchtext.vocab import vocab\n",
"from collections import Counter\n",
"\n",
"from sklearn.datasets import fetch_20newsgroups\n",
"# https://scikit-learn.org/0.19/datasets/twenty_newsgroups.html\n",
"\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.metrics import accuracy_score\n",
@ -53,8 +48,17 @@
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"device = 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": true
"scrolled": false
},
"outputs": [
{
@ -67,7 +71,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7c9a8ca324914c40b7606ab8cd487df2",
"model_id": "5537459a83cc486e927e938f813a5794",
"version_major": 2,
"version_minor": 0
},
@ -85,7 +89,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@ -100,7 +104,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@ -109,7 +113,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
@ -118,7 +122,7 @@
"21"
]
},
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@ -129,7 +133,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@ -139,7 +143,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@ -149,7 +153,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@ -158,7 +162,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@ -167,7 +171,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@ -176,7 +180,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"metadata": {
"scrolled": true
},
@ -187,7 +191,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@ -196,7 +200,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@ -205,7 +209,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"metadata": {},
"outputs": [
{
@ -214,7 +218,7 @@
"tensor([ 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 3])"
]
},
"execution_count": 14,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
@ -225,7 +229,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"metadata": {},
"outputs": [
{
@ -246,7 +250,7 @@
" 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0]}"
]
},
"execution_count": 15,
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@ -257,7 +261,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"metadata": {
"scrolled": true
},
@ -268,7 +272,7 @@
"tensor([0, 3, 0, 7, 0, 0, 0, 7, 0, 0, 0])"
]
},
"execution_count": 16,
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
@ -279,7 +283,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
@ -327,7 +331,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
@ -336,7 +340,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@ -358,22 +362,13 @@
" return out_weights"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"lstm = LSTM()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"criterion = torch.nn.CrossEntropyLoss()"
"lstm = LSTM().to(device)"
]
},
{
@ -382,7 +377,7 @@
"metadata": {},
"outputs": [],
"source": [
"optimizer = torch.optim.Adam(lstm.parameters())"
"criterion = torch.nn.CrossEntropyLoss().to(device)"
]
},
{
@ -391,21 +386,7 @@
"metadata": {},
"outputs": [],
"source": [
"def eval_model(dataset_tokens, dataset_labels, model):\n",
" Y_true = []\n",
" Y_pred = []\n",
" for i in tqdm(range(len(dataset_labels))):\n",
" batch_tokens = dataset_tokens[i].unsqueeze(0)\n",
" tags = list(dataset_labels[i].numpy())\n",
" Y_true += tags\n",
" \n",
" Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n",
" Y_batch_pred = torch.argmax(Y_batch_pred_weights,1)\n",
" Y_pred += list(Y_batch_pred.numpy())\n",
" \n",
"\n",
" return get_scores(Y_true, Y_pred)\n",
" "
"optimizer = torch.optim.Adam(lstm.parameters())"
]
},
{
@ -414,12 +395,35 @@
"metadata": {},
"outputs": [],
"source": [
"NUM_EPOCHS = 5"
"def eval_model(dataset_tokens, dataset_labels, model):\n",
" Y_true = []\n",
" Y_pred = []\n",
" for i in tqdm(range(len(dataset_labels))):\n",
" batch_tokens = dataset_tokens[i].unsqueeze(0).to(device)\n",
" tags = list(dataset_labels[i].numpy())\n",
" Y_true += tags\n",
" \n",
" Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n",
" Y_batch_pred = torch.argmax(Y_batch_pred_weights,1)\n",
" Y_pred += list(Y_batch_pred.cpu().numpy())\n",
" \n",
"\n",
" return get_scores(Y_true, Y_pred)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"NUM_EPOCHS = 5"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"scrolled": true
},
@ -427,12 +431,12 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "59e268fa2b29414fb6306ec4ee44d51f",
"model_id": "3b7cca5ee20b472d80f02c6d4fa54c4e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/500 [00:00<?, ?it/s]"
" 0%| | 0/14042 [00:00<?, ?it/s]"
]
},
"metadata": {},
@ -441,7 +445,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "77f4b857b41143429af8391023430e23",
"model_id": "dfc1a78154bf4efda20bd62bdf9e6c99",
"version_major": 2,
"version_minor": 0
},
@ -456,18 +460,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(0.2310126582278481, 0.02545623619667558, 0.04585907234844519)\n"
"(0.516575591985428, 0.49447867023131464, 0.505285663380449)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "558d39ff9ab34f458e4d64f24028fe50",
"model_id": "4a94b241621943fd8cdd70bbda9c334b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/500 [00:00<?, ?it/s]"
" 0%| | 0/14042 [00:00<?, ?it/s]"
]
},
"metadata": {},
@ -476,7 +480,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3094de37bef4484a87ed4789bfc85bdc",
"model_id": "9d9b7c4e48ac469cadfb90e79da70107",
"version_major": 2,
"version_minor": 0
},
@ -491,18 +495,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(0.22903453136011276, 0.15111007787980937, 0.1820855802227047)\n"
"(0.6624173748819642, 0.6523305823549924, 0.6573352855051245)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8baf610abdb04715924dba6109782efd",
"model_id": "c5d395d9553d47e4b96a3fa176ce05d5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/500 [00:00<?, ?it/s]"
" 0%| | 0/14042 [00:00<?, ?it/s]"
]
},
"metadata": {},
@ -511,7 +515,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "204b0274b9ea42caa10d8d05838ed035",
"model_id": "061de4b1aac5429d8091ba07b5e8ba2f",
"version_major": 2,
"version_minor": 0
},
@ -526,18 +530,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(0.22289679098005205, 0.20911310008136696, 0.21578505457598657)\n"
"(0.7022361255937898, 0.7045216784842496, 0.7033770453754206)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fc6e663b99614e0e8c2382ef93a6402f",
"model_id": "14ade5ef81ab45d0832e4999ac62467a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/500 [00:00<?, ?it/s]"
" 0%| | 0/14042 [00:00<?, ?it/s]"
]
},
"metadata": {},
@ -546,7 +550,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b963be2045a7494499c309693632e506",
"model_id": "9a7d4c3ffdd445fa9765bc2233bb2cf5",
"version_major": 2,
"version_minor": 0
},
@ -561,18 +565,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(0.2553244180287271, 0.23968383122166687, 0.2472570297979495)\n"
"(0.7282225874618455, 0.7210275485295827, 0.7246072075229251)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9e941358d44949c5a0f147f2287bf226",
"model_id": "ca77549e5d4248a4bdd51b66865505da",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/500 [00:00<?, ?it/s]"
" 0%| | 0/14042 [00:00<?, ?it/s]"
]
},
"metadata": {},
@ -581,7 +585,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2363bcab950947b8bff899cd01f4ec0a",
"model_id": "bd2ec9174d50443db5aa98b9d8b50c66",
"version_major": 2,
"version_minor": 0
},
@ -596,17 +600,17 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(0.26687507236308905, 0.2679297919330466, 0.26740139211136893)\n"
"(0.7124554367201426, 0.7433453446472161, 0.7275726719381079)\n"
]
}
],
"source": [
"for i in range(NUM_EPOCHS):\n",
" lstm.train()\n",
" for i in tqdm(range(500)):\n",
" #for i in tqdm(range(len(train_labels))):\n",
" batch_tokens = train_tokens_ids[i].unsqueeze(0)\n",
" tags = train_labels[i].unsqueeze(1)\n",
" #for i in tqdm(range(5000)):\n",
" for i in tqdm(range(len(train_labels))):\n",
" batch_tokens = train_tokens_ids[i].unsqueeze(0).to(device)\n",
" tags = train_labels[i].unsqueeze(1).to(device)\n",
" \n",
" \n",
" predicted_tags = lstm(batch_tokens)\n",
@ -624,7 +628,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 27,
"metadata": {
"scrolled": true
},
@ -632,7 +636,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "11001c61092a4fd89efd1e155f6b0682",
"model_id": "f8fc75fb00954c3eb59aa2d40786fef7",
"version_major": 2,
"version_minor": 0
},
@ -646,10 +650,10 @@
{
"data": {
"text/plain": [
"(0.26687507236308905, 0.2679297919330466, 0.26740139211136893)"
"(0.7124554367201426, 0.7433453446472161, 0.7275726719381079)"
]
},
"execution_count": 26,
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
@ -660,13 +664,13 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "90336a538c2443608d45e094cc62e916",
"model_id": "0709c9b9be2446ea86e1ea0bc8b5ae3a",
"version_major": 2,
"version_minor": 0
},
@ -680,10 +684,10 @@
{
"data": {
"text/plain": [
"(0.2493934363427404, 0.24075443786982248, 0.24499780467916954)"
"(0.6445353594389246, 0.6797337278106509, 0.6616667666646667)"
]
},
"execution_count": 27,
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
@ -694,7 +698,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 29,
"metadata": {
"scrolled": true
},
@ -705,7 +709,7 @@
"14042"
]
},
"execution_count": 28,
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}

View File

@ -28,19 +28,14 @@
"outputs": [],
"source": [
"import numpy as np\n",
"import gensim\n",
"import torch\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"from datasets import load_dataset\n",
"import torchtext\n",
"#from torchtext.vocab import vocab\n",
"from collections import Counter\n",
"\n",
"from sklearn.datasets import fetch_20newsgroups\n",
"# https://scikit-learn.org/0.19/datasets/twenty_newsgroups.html\n",
"\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.metrics import accuracy_score\n",
@ -53,8 +48,17 @@
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"device = 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": true
"scrolled": false
},
"outputs": [
{
@ -67,7 +71,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7c9a8ca324914c40b7606ab8cd487df2",
"model_id": "5537459a83cc486e927e938f813a5794",
"version_major": 2,
"version_minor": 0
},
@ -85,7 +89,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@ -100,7 +104,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@ -109,7 +113,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
@ -118,7 +122,7 @@
"21"
]
},
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@ -129,7 +133,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@ -139,7 +143,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@ -149,7 +153,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@ -158,7 +162,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@ -167,7 +171,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@ -176,7 +180,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"metadata": {
"scrolled": true
},
@ -187,7 +191,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@ -196,7 +200,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@ -205,7 +209,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"metadata": {},
"outputs": [
{
@ -214,7 +218,7 @@
"tensor([ 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 3])"
]
},
"execution_count": 14,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
@ -225,7 +229,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"metadata": {},
"outputs": [
{
@ -246,7 +250,7 @@
" 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0]}"
]
},
"execution_count": 15,
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@ -257,7 +261,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"metadata": {
"scrolled": true
},
@ -268,7 +272,7 @@
"tensor([0, 3, 0, 7, 0, 0, 0, 7, 0, 0, 0])"
]
},
"execution_count": 16,
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
@ -279,7 +283,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
@ -327,7 +331,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
@ -336,7 +340,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@ -358,22 +362,13 @@
" return out_weights"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"lstm = LSTM()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"criterion = torch.nn.CrossEntropyLoss()"
"lstm = LSTM().to(device)"
]
},
{
@ -382,7 +377,7 @@
"metadata": {},
"outputs": [],
"source": [
"optimizer = torch.optim.Adam(lstm.parameters())"
"criterion = torch.nn.CrossEntropyLoss().to(device)"
]
},
{
@ -391,21 +386,7 @@
"metadata": {},
"outputs": [],
"source": [
"def eval_model(dataset_tokens, dataset_labels, model):\n",
" Y_true = []\n",
" Y_pred = []\n",
" for i in tqdm(range(len(dataset_labels))):\n",
" batch_tokens = dataset_tokens[i].unsqueeze(0)\n",
" tags = list(dataset_labels[i].numpy())\n",
" Y_true += tags\n",
" \n",
" Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n",
" Y_batch_pred = torch.argmax(Y_batch_pred_weights,1)\n",
" Y_pred += list(Y_batch_pred.numpy())\n",
" \n",
"\n",
" return get_scores(Y_true, Y_pred)\n",
" "
"optimizer = torch.optim.Adam(lstm.parameters())"
]
},
{
@ -414,12 +395,35 @@
"metadata": {},
"outputs": [],
"source": [
"NUM_EPOCHS = 5"
"def eval_model(dataset_tokens, dataset_labels, model):\n",
" Y_true = []\n",
" Y_pred = []\n",
" for i in tqdm(range(len(dataset_labels))):\n",
" batch_tokens = dataset_tokens[i].unsqueeze(0).to(device)\n",
" tags = list(dataset_labels[i].numpy())\n",
" Y_true += tags\n",
" \n",
" Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n",
" Y_batch_pred = torch.argmax(Y_batch_pred_weights,1)\n",
" Y_pred += list(Y_batch_pred.cpu().numpy())\n",
" \n",
"\n",
" return get_scores(Y_true, Y_pred)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"NUM_EPOCHS = 5"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"scrolled": true
},
@ -427,12 +431,12 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "59e268fa2b29414fb6306ec4ee44d51f",
"model_id": "3b7cca5ee20b472d80f02c6d4fa54c4e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/500 [00:00<?, ?it/s]"
" 0%| | 0/14042 [00:00<?, ?it/s]"
]
},
"metadata": {},
@ -441,7 +445,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "77f4b857b41143429af8391023430e23",
"model_id": "dfc1a78154bf4efda20bd62bdf9e6c99",
"version_major": 2,
"version_minor": 0
},
@ -456,18 +460,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(0.2310126582278481, 0.02545623619667558, 0.04585907234844519)\n"
"(0.516575591985428, 0.49447867023131464, 0.505285663380449)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "558d39ff9ab34f458e4d64f24028fe50",
"model_id": "4a94b241621943fd8cdd70bbda9c334b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/500 [00:00<?, ?it/s]"
" 0%| | 0/14042 [00:00<?, ?it/s]"
]
},
"metadata": {},
@ -476,7 +480,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3094de37bef4484a87ed4789bfc85bdc",
"model_id": "9d9b7c4e48ac469cadfb90e79da70107",
"version_major": 2,
"version_minor": 0
},
@ -491,18 +495,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(0.22903453136011276, 0.15111007787980937, 0.1820855802227047)\n"
"(0.6624173748819642, 0.6523305823549924, 0.6573352855051245)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8baf610abdb04715924dba6109782efd",
"model_id": "c5d395d9553d47e4b96a3fa176ce05d5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/500 [00:00<?, ?it/s]"
" 0%| | 0/14042 [00:00<?, ?it/s]"
]
},
"metadata": {},
@ -511,7 +515,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "204b0274b9ea42caa10d8d05838ed035",
"model_id": "061de4b1aac5429d8091ba07b5e8ba2f",
"version_major": 2,
"version_minor": 0
},
@ -526,18 +530,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(0.22289679098005205, 0.20911310008136696, 0.21578505457598657)\n"
"(0.7022361255937898, 0.7045216784842496, 0.7033770453754206)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fc6e663b99614e0e8c2382ef93a6402f",
"model_id": "14ade5ef81ab45d0832e4999ac62467a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/500 [00:00<?, ?it/s]"
" 0%| | 0/14042 [00:00<?, ?it/s]"
]
},
"metadata": {},
@ -546,7 +550,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b963be2045a7494499c309693632e506",
"model_id": "9a7d4c3ffdd445fa9765bc2233bb2cf5",
"version_major": 2,
"version_minor": 0
},
@ -561,18 +565,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(0.2553244180287271, 0.23968383122166687, 0.2472570297979495)\n"
"(0.7282225874618455, 0.7210275485295827, 0.7246072075229251)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9e941358d44949c5a0f147f2287bf226",
"model_id": "ca77549e5d4248a4bdd51b66865505da",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/500 [00:00<?, ?it/s]"
" 0%| | 0/14042 [00:00<?, ?it/s]"
]
},
"metadata": {},
@ -581,7 +585,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2363bcab950947b8bff899cd01f4ec0a",
"model_id": "bd2ec9174d50443db5aa98b9d8b50c66",
"version_major": 2,
"version_minor": 0
},
@ -596,17 +600,17 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(0.26687507236308905, 0.2679297919330466, 0.26740139211136893)\n"
"(0.7124554367201426, 0.7433453446472161, 0.7275726719381079)\n"
]
}
],
"source": [
"for i in range(NUM_EPOCHS):\n",
" lstm.train()\n",
" for i in tqdm(range(500)):\n",
" #for i in tqdm(range(len(train_labels))):\n",
" batch_tokens = train_tokens_ids[i].unsqueeze(0)\n",
" tags = train_labels[i].unsqueeze(1)\n",
" #for i in tqdm(range(5000)):\n",
" for i in tqdm(range(len(train_labels))):\n",
" batch_tokens = train_tokens_ids[i].unsqueeze(0).to(device)\n",
" tags = train_labels[i].unsqueeze(1).to(device)\n",
" \n",
" \n",
" predicted_tags = lstm(batch_tokens)\n",
@ -624,7 +628,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 27,
"metadata": {
"scrolled": true
},
@ -632,7 +636,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "11001c61092a4fd89efd1e155f6b0682",
"model_id": "f8fc75fb00954c3eb59aa2d40786fef7",
"version_major": 2,
"version_minor": 0
},
@ -646,10 +650,10 @@
{
"data": {
"text/plain": [
"(0.26687507236308905, 0.2679297919330466, 0.26740139211136893)"
"(0.7124554367201426, 0.7433453446472161, 0.7275726719381079)"
]
},
"execution_count": 26,
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
@ -660,13 +664,13 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "90336a538c2443608d45e094cc62e916",
"model_id": "0709c9b9be2446ea86e1ea0bc8b5ae3a",
"version_major": 2,
"version_minor": 0
},
@ -680,10 +684,10 @@
{
"data": {
"text/plain": [
"(0.2493934363427404, 0.24075443786982248, 0.24499780467916954)"
"(0.6445353594389246, 0.6797337278106509, 0.6616667666646667)"
]
},
"execution_count": 27,
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
@ -694,7 +698,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 29,
"metadata": {
"scrolled": true
},
@ -705,7 +709,7 @@
"14042"
]
},
"execution_count": 28,
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
@ -736,7 +740,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
@ -760,22 +764,13 @@
" return out_weights"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"gru = GRU()"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"criterion = torch.nn.CrossEntropyLoss()"
"gru = GRU().to(device)"
]
},
{
@ -784,7 +779,7 @@
"metadata": {},
"outputs": [],
"source": [
"optimizer = torch.optim.Adam(gru.parameters())"
"criterion = torch.nn.CrossEntropyLoss()"
]
},
{
@ -793,12 +788,21 @@
"metadata": {},
"outputs": [],
"source": [
"NUM_EPOCHS = 5"
"optimizer = torch.optim.Adam(gru.parameters())"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"NUM_EPOCHS = 5"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"scrolled": true
},
@ -806,12 +810,12 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "109e0891142545ee8315c040cb231fb2",
"model_id": "fc4d756d3f9d45cea875ecdc268ed9f9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/500 [00:00<?, ?it/s]"
" 0%| | 0/14042 [00:00<?, ?it/s]"
]
},
"metadata": {},
@ -820,7 +824,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "46b4b71a6b4b4c20bf1d093abb25b8c6",
"model_id": "03b9fea03b8042f3bc143f0cc0ae70de",
"version_major": 2,
"version_minor": 0
},
@ -835,18 +839,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(0.38776758409785933, 0.14739044519353714, 0.2135938684410006)\n"
"(0.6109818520241973, 0.4578635359758224, 0.5234551495016612)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a5ea7182cef84130ac18adc4b47c4ea4",
"model_id": "9091f9231c7b4400b22360510a6dbca2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/500 [00:00<?, ?it/s]"
" 0%| | 0/14042 [00:00<?, ?it/s]"
]
},
"metadata": {},
@ -855,7 +859,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "151e6a36e4cd48f89a6bd7d68eff9e2b",
"model_id": "9c968d5eda614e7da357cf260deb2372",
"version_major": 2,
"version_minor": 0
},
@ -870,18 +874,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(0.27651183172655563, 0.22003952109729163, 0.24506440546313676)\n"
"(0.6290377039954981, 0.6496570963617343, 0.639181152790485)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "44b30d1dcb914c5fa5281e2ba0264a4c",
"model_id": "dc78edb6313b4439ad4099d0842ded9b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/500 [00:00<?, ?it/s]"
" 0%| | 0/14042 [00:00<?, ?it/s]"
]
},
"metadata": {},
@ -890,7 +894,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "48cf8a32a9b3464e8dc467ade38d6b64",
"model_id": "782be7d5c44a43bb8309a50ad85564d3",
"version_major": 2,
"version_minor": 0
},
@ -905,18 +909,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(0.31285223367697595, 0.2645588748111124, 0.28668598060209094)\n"
"(0.6755871725383921, 0.6954550738114611, 0.6853771693682342)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "24d10c567c6444faa26dd543ec1405d6",
"model_id": "1999a5193c7142039037dc567d6e56e5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/500 [00:00<?, ?it/s]"
" 0%| | 0/14042 [00:00<?, ?it/s]"
]
},
"metadata": {},
@ -925,7 +929,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e3061f7c3278474f8fcdc286cf3fda5b",
"model_id": "9ad5d2d4387b40ecbfb493fd3385fb1b",
"version_major": 2,
"version_minor": 0
},
@ -940,18 +944,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(0.2596728376922323, 0.3081483203533651, 0.2818413778439294)\n"
"(0.7477821586988664, 0.7054515866558178, 0.7260003588731384)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "527e776a01634bf9bd8d6c1dab5dafaf",
"model_id": "47411db4679941519585f5a89227fd8d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/500 [00:00<?, ?it/s]"
" 0%| | 0/14042 [00:00<?, ?it/s]"
]
},
"metadata": {},
@ -960,7 +964,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "25bc5029685d41a7b1498ea2b54bb33c",
"model_id": "a05b4e171efb4a5594e45c611f94aa18",
"version_major": 2,
"version_minor": 0
},
@ -975,17 +979,17 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(0.29086115992970124, 0.30779960478902707, 0.29909075506861693)\n"
"(0.7669533169533169, 0.725677089387423, 0.745744490234725)\n"
]
}
],
"source": [
"for i in range(NUM_EPOCHS):\n",
" gru.train()\n",
" for i in tqdm(range(500)):\n",
" #for i in tqdm(range(len(train_labels))):\n",
" batch_tokens = train_tokens_ids[i].unsqueeze(0)\n",
" tags = train_labels[i].unsqueeze(1)\n",
" #for i in tqdm(range(500)):\n",
" for i in tqdm(range(len(train_labels))):\n",
" batch_tokens = train_tokens_ids[i].unsqueeze(0).to(device)\n",
" tags = train_labels[i].unsqueeze(1).to(device)\n",
" \n",
" \n",
" predicted_tags = gru(batch_tokens)\n",