diff --git a/cw/11_NER_RNN.ipynb b/cw/11_NER_RNN.ipynb index 23c498f..e5c91d2 100644 --- a/cw/11_NER_RNN.ipynb +++ b/cw/11_NER_RNN.ipynb @@ -28,19 +28,14 @@ "outputs": [], "source": [ "import numpy as np\n", - "import gensim\n", "import torch\n", "import pandas as pd\n", - "import seaborn as sns\n", - "from sklearn.model_selection import train_test_split\n", "\n", "from datasets import load_dataset\n", "import torchtext\n", "#from torchtext.vocab import vocab\n", "from collections import Counter\n", "\n", - "from sklearn.datasets import fetch_20newsgroups\n", - "# https://scikit-learn.org/0.19/datasets/twenty_newsgroups.html\n", "\n", "from sklearn.feature_extraction.text import TfidfVectorizer\n", "from sklearn.metrics import accuracy_score\n", @@ -53,8 +48,17 @@ { "cell_type": "code", "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "device = 'cpu'" + ] + }, + { + "cell_type": "code", + "execution_count": 3, "metadata": { - "scrolled": true + "scrolled": false }, "outputs": [ { @@ -67,7 +71,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "7c9a8ca324914c40b7606ab8cd487df2", + "model_id": "5537459a83cc486e927e938f813a5794", "version_major": 2, "version_minor": 0 }, @@ -85,7 +89,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -100,7 +104,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -109,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -118,7 +122,7 @@ "21" ] }, - "execution_count": 5, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -129,7 +133,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -139,7 +143,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -149,7 +153,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -158,7 +162,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -167,7 +171,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -176,7 +180,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": { "scrolled": true }, @@ -187,7 +191,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -196,7 +200,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -205,7 +209,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -214,7 +218,7 @@ "tensor([ 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 3])" ] }, - "execution_count": 14, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -225,7 +229,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -246,7 +250,7 @@ " 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0]}" ] }, - "execution_count": 15, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -257,7 +261,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 17, "metadata": { "scrolled": true }, @@ -268,7 +272,7 @@ "tensor([0, 3, 0, 7, 0, 0, 0, 7, 0, 0, 0])" ] }, - "execution_count": 16, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -279,7 +283,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -327,7 +331,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -336,7 +340,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -358,22 +362,13 @@ " return out_weights" ] }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [], - "source": [ - "lstm = LSTM()" - ] - }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ - "criterion = torch.nn.CrossEntropyLoss()" + "lstm = LSTM().to(device)" ] }, { @@ -382,7 +377,7 @@ "metadata": {}, "outputs": [], "source": [ - "optimizer = torch.optim.Adam(lstm.parameters())" + "criterion = torch.nn.CrossEntropyLoss().to(device)" ] }, { @@ -391,21 +386,7 @@ "metadata": {}, "outputs": [], "source": [ - "def eval_model(dataset_tokens, dataset_labels, model):\n", - " Y_true = []\n", - " Y_pred = []\n", - " for i in tqdm(range(len(dataset_labels))):\n", - " batch_tokens = dataset_tokens[i].unsqueeze(0)\n", - " tags = list(dataset_labels[i].numpy())\n", - " Y_true += tags\n", - " \n", - " Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n", - " Y_batch_pred = torch.argmax(Y_batch_pred_weights,1)\n", - " Y_pred += list(Y_batch_pred.numpy())\n", - " \n", - "\n", - " return get_scores(Y_true, Y_pred)\n", - " " + "optimizer = torch.optim.Adam(lstm.parameters())" ] }, { @@ -414,12 +395,35 @@ "metadata": {}, "outputs": [], "source": [ - "NUM_EPOCHS = 5" + "def eval_model(dataset_tokens, dataset_labels, model):\n", + " Y_true = []\n", + " Y_pred = []\n", + " for i in tqdm(range(len(dataset_labels))):\n", + " batch_tokens = dataset_tokens[i].unsqueeze(0).to(device)\n", + " tags = list(dataset_labels[i].numpy())\n", + " Y_true += tags\n", + " \n", + " Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n", + " Y_batch_pred = torch.argmax(Y_batch_pred_weights,1)\n", + " Y_pred += list(Y_batch_pred.cpu().numpy())\n", + " \n", + "\n", + " return get_scores(Y_true, Y_pred)\n", + " " ] }, { "cell_type": "code", "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "NUM_EPOCHS = 5" + ] + }, + { + "cell_type": "code", + "execution_count": 26, "metadata": { "scrolled": true }, @@ -427,12 +431,12 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "59e268fa2b29414fb6306ec4ee44d51f", + "model_id": "3b7cca5ee20b472d80f02c6d4fa54c4e", "version_major": 2, "version_minor": 0 }, "text/plain": [ - " 0%| | 0/500 [00:00