aitech-eks-pub-22/cw/11_NER_RNN_ODPOWIEDZI.ipynb

1052 lines
24 KiB
Plaintext
Raw Normal View History

2022-06-07 14:56:08 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![Logo 1](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech1.jpg)\n",
"<div class=\"alert alert-block alert-info\">\n",
"<h1> Ekstrakcja informacji </h1>\n",
"<h2> 11. <i>NER RNN</i> [ćwiczenia]</h2> \n",
"<h3> Jakub Pokrywka (2021)</h3>\n",
"</div>\n",
"\n",
"![Logo 2](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech2.jpg)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Podejście softmax z embeddingami na przykładzie NER"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"import pandas as pd\n",
"\n",
"from datasets import load_dataset\n",
"import torchtext\n",
"#from torchtext.vocab import vocab\n",
"from collections import Counter\n",
"\n",
"\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"from tqdm.notebook import tqdm\n",
"\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
2022-06-07 15:50:27 +02:00
"metadata": {},
"outputs": [],
"source": [
"device = 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": 3,
2022-06-07 14:56:08 +02:00
"metadata": {
2022-06-07 15:50:27 +02:00
"scrolled": false
2022-06-07 14:56:08 +02:00
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Reusing dataset conll2003 (/home/kuba/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "5537459a83cc486e927e938f813a5794",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dataset = load_dataset(\"conll2003\")"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 4,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
"def build_vocab(dataset):\n",
"    \"\"\"Build a torchtext vocabulary from an iterable of tokenised documents.\n",
"\n",
"    The four special tokens are placed first, so '<unk>' gets index 0, and\n",
"    out-of-vocabulary lookups fall back to the '<unk>' index. Regular tokens\n",
"    keep the order in which they first appear in ``dataset``.\n",
"    \"\"\"\n",
"    token_counts = Counter(token for document in dataset for token in document)\n",
"    token_vocab = torchtext.vocab.vocab(\n",
"        token_counts, specials=['<unk>', '<pad>', '<bos>', '<eos>']\n",
"    )\n",
"    token_vocab.set_default_index(token_vocab['<unk>'])\n",
"    return token_vocab"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 5,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
"vocab = build_vocab(dataset['train']['tokens'])"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 6,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"21"
]
},
2022-06-07 15:50:27 +02:00
"execution_count": 6,
2022-06-07 14:56:08 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vocab['on']"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 7,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
"def data_process(dt):\n",
" return [ torch.tensor([vocab['<bos>']] +[vocab[token] for token in document ] + [vocab['<eos>']], dtype = torch.long) for document in dt]"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 8,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
"def labels_process(dt):\n",
" return [ torch.tensor([0] + document + [0], dtype = torch.long) for document in dt]\n"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 9,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
"train_tokens_ids = data_process(dataset['train']['tokens'])"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 10,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
"test_tokens_ids = data_process(dataset['test']['tokens'])"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 11,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
"validation_tokens_ids = data_process(dataset['validation']['tokens'])"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 12,
2022-06-07 14:56:08 +02:00
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"train_labels = labels_process(dataset['train']['ner_tags'])"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 13,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
"validation_labels = labels_process(dataset['validation']['ner_tags'])"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 14,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
"test_labels = labels_process(dataset['test']['ner_tags'])"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 15,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 3])"
]
},
2022-06-07 15:50:27 +02:00
"execution_count": 15,
2022-06-07 14:56:08 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_tokens_ids[0]"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 16,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'id': '0',\n",
" 'tokens': ['EU',\n",
" 'rejects',\n",
" 'German',\n",
" 'call',\n",
" 'to',\n",
" 'boycott',\n",
" 'British',\n",
" 'lamb',\n",
" '.'],\n",
" 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7],\n",
" 'chunk_tags': [11, 21, 11, 12, 21, 22, 11, 12, 0],\n",
" 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0]}"
]
},
2022-06-07 15:50:27 +02:00
"execution_count": 16,
2022-06-07 14:56:08 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset['train'][0]"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 17,
2022-06-07 14:56:08 +02:00
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0, 3, 0, 7, 0, 0, 0, 7, 0, 0, 0])"
]
},
2022-06-07 15:50:27 +02:00
"execution_count": 17,
2022-06-07 14:56:08 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_labels[0]"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 18,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
"def get_scores(y_true, y_pred):\n",
" acc_score = 0\n",
" tp = 0\n",
" fp = 0\n",
" selected_items = 0\n",
" relevant_items = 0 \n",
"\n",
" for p,t in zip(y_pred, y_true):\n",
" if p == t:\n",
" acc_score +=1\n",
"\n",
" if p > 0 and p == t:\n",
" tp +=1\n",
"\n",
" if p > 0:\n",
" selected_items += 1\n",
"\n",
" if t > 0 :\n",
" relevant_items +=1\n",
"\n",
" \n",
" \n",
" if selected_items == 0:\n",
" precision = 1.0\n",
" else:\n",
" precision = tp / selected_items\n",
" \n",
" \n",
" if relevant_items == 0:\n",
" recall = 1.0\n",
" else:\n",
" recall = tp / relevant_items\n",
" \n",
" \n",
" if precision + recall == 0.0 :\n",
" f1 = 0.0\n",
" else:\n",
" f1 = 2* precision * recall / (precision + recall)\n",
"\n",
" return precision, recall, f1"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 19,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
"num_tags = max([max(x) for x in dataset['train']['ner_tags'] if x]) + 1 "
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 20,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
"class LSTM(torch.nn.Module):\n",
"\n",
" def __init__(self):\n",
" super(LSTM, self).__init__()\n",
" self.emb = torch.nn.Embedding(len(vocab.get_itos()),100)\n",
" self.rec = torch.nn.LSTM(100, 256, 1, batch_first = True)\n",
" self.fc1 = torch.nn.Linear( 256 , 9)\n",
"\n",
" def forward(self, x):\n",
" emb = torch.relu(self.emb(x))\n",
" \n",
" lstm_output, (h_n, c_n) = self.rec(emb)\n",
" \n",
" out_weights = self.fc1(lstm_output)\n",
"\n",
" return out_weights"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 21,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
2022-06-07 15:50:27 +02:00
"lstm = LSTM().to(device)"
2022-06-07 14:56:08 +02:00
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 22,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
2022-06-07 15:50:27 +02:00
"criterion = torch.nn.CrossEntropyLoss().to(device)"
2022-06-07 14:56:08 +02:00
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 23,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
"optimizer = torch.optim.Adam(lstm.parameters())"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 24,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
"def eval_model(dataset_tokens, dataset_labels, model):\n",
"    \"\"\"Evaluate ``model`` sentence by sentence and return get_scores' output.\n",
"\n",
"    Args:\n",
"        dataset_tokens: list of 1-D LongTensors of token ids, one per sentence.\n",
"        dataset_labels: list of 1-D LongTensors of gold tag ids aligned with\n",
"            ``dataset_tokens``.\n",
"        model: tagger mapping (1, seq_len) ids to (1, seq_len, n_tags) logits.\n",
"\n",
"    Returns:\n",
"        The token-level ``(precision, recall, f1)`` tuple from get_scores,\n",
"        computed over every position of every sentence.\n",
"\n",
"    The model runs in eval mode with autograd disabled; its previous\n",
"    train/eval mode is restored before returning.\n",
"    \"\"\"\n",
"    Y_true = []\n",
"    Y_pred = []\n",
"    was_training = model.training\n",
"    model.eval()\n",
"    try:\n",
"        # inference needs no autograd graph; skipping it saves memory and time\n",
"        with torch.no_grad():\n",
"            for sentence_ids, sentence_tags in tqdm(zip(dataset_tokens, dataset_labels), total=len(dataset_labels)):\n",
"                batch_tokens = sentence_ids.unsqueeze(0).to(device)  # batch of one sentence\n",
"                Y_true += sentence_tags.tolist()\n",
"\n",
"                Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n",
"                Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)\n",
"                Y_pred += Y_batch_pred.cpu().tolist()\n",
"    finally:\n",
"        model.train(was_training)\n",
"\n",
"    return get_scores(Y_true, Y_pred)"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 25,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
"NUM_EPOCHS = 5"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 26,
2022-06-07 14:56:08 +02:00
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "3b7cca5ee20b472d80f02c6d4fa54c4e",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
2022-06-07 15:50:27 +02:00
" 0%| | 0/14042 [00:00<?, ?it/s]"
2022-06-07 14:56:08 +02:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "dfc1a78154bf4efda20bd62bdf9e6c99",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3251 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-06-07 15:50:27 +02:00
"(0.516575591985428, 0.49447867023131464, 0.505285663380449)\n"
2022-06-07 14:56:08 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "4a94b241621943fd8cdd70bbda9c334b",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
2022-06-07 15:50:27 +02:00
" 0%| | 0/14042 [00:00<?, ?it/s]"
2022-06-07 14:56:08 +02:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "9d9b7c4e48ac469cadfb90e79da70107",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3251 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-06-07 15:50:27 +02:00
"(0.6624173748819642, 0.6523305823549924, 0.6573352855051245)\n"
2022-06-07 14:56:08 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "c5d395d9553d47e4b96a3fa176ce05d5",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
2022-06-07 15:50:27 +02:00
" 0%| | 0/14042 [00:00<?, ?it/s]"
2022-06-07 14:56:08 +02:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "061de4b1aac5429d8091ba07b5e8ba2f",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3251 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-06-07 15:50:27 +02:00
"(0.7022361255937898, 0.7045216784842496, 0.7033770453754206)\n"
2022-06-07 14:56:08 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "14ade5ef81ab45d0832e4999ac62467a",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
2022-06-07 15:50:27 +02:00
" 0%| | 0/14042 [00:00<?, ?it/s]"
2022-06-07 14:56:08 +02:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "9a7d4c3ffdd445fa9765bc2233bb2cf5",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3251 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-06-07 15:50:27 +02:00
"(0.7282225874618455, 0.7210275485295827, 0.7246072075229251)\n"
2022-06-07 14:56:08 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "ca77549e5d4248a4bdd51b66865505da",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
2022-06-07 15:50:27 +02:00
" 0%| | 0/14042 [00:00<?, ?it/s]"
2022-06-07 14:56:08 +02:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "bd2ec9174d50443db5aa98b9d8b50c66",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3251 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-06-07 15:50:27 +02:00
"(0.7124554367201426, 0.7433453446472161, 0.7275726719381079)\n"
2022-06-07 14:56:08 +02:00
]
}
],
"source": [
"# NOTE: `optimizer` and `criterion` are re-assigned for the GRU model further\n",
"# down the notebook; re-run the LSTM setup cells before re-running this one.\n",
"for epoch in range(NUM_EPOCHS):\n",
"    lstm.train()\n",
"    for sentence_ids, sentence_tags in tqdm(zip(train_tokens_ids, train_labels), total=len(train_labels)):\n",
"        # batch of one sentence: (1, seq_len) ids -> (1, seq_len, n_tags) logits\n",
"        batch_tokens = sentence_ids.unsqueeze(0).to(device)\n",
"        tags = sentence_tags.to(device)\n",
"\n",
"        optimizer.zero_grad()\n",
"        predicted_tags = lstm(batch_tokens)\n",
"        # CrossEntropyLoss gets (seq_len, n_tags) logits and (seq_len,) targets\n",
"        loss = criterion(predicted_tags.squeeze(0), tags)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"    lstm.eval()\n",
"    print(f'epoch {epoch + 1}/{NUM_EPOCHS}:', eval_model(validation_tokens_ids, validation_labels, lstm))"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 27,
2022-06-07 14:56:08 +02:00
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "f8fc75fb00954c3eb59aa2d40786fef7",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3251 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
2022-06-07 15:50:27 +02:00
"(0.7124554367201426, 0.7433453446472161, 0.7275726719381079)"
2022-06-07 14:56:08 +02:00
]
},
2022-06-07 15:50:27 +02:00
"execution_count": 27,
2022-06-07 14:56:08 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"eval_model(validation_tokens_ids, validation_labels, lstm)"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 28,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "0709c9b9be2446ea86e1ea0bc8b5ae3a",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3454 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
2022-06-07 15:50:27 +02:00
"(0.6445353594389246, 0.6797337278106509, 0.6616667666646667)"
2022-06-07 14:56:08 +02:00
]
},
2022-06-07 15:50:27 +02:00
"execution_count": 28,
2022-06-07 14:56:08 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"eval_model(test_tokens_ids, test_labels, lstm)"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 29,
2022-06-07 14:56:08 +02:00
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"14042"
]
},
2022-06-07 15:50:27 +02:00
"execution_count": 29,
2022-06-07 14:56:08 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(train_tokens_ids)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## pytania\n",
"\n",
"- co zrobić, żeby trenować na batchach o rozmiarze > 1?\n",
"- co zrobić, żeby sieć uwzględniała następne tokeny, a nie tylko poprzednie?\n",
"- w jaki sposób wykorzystać taką sieć do zadania zwykłej klasyfikacji?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Zadanie na zajęcia (20 minut)\n",
"\n",
"Zmodyfikować sieć tak, żeby używała dwuwarstwowej, dwukierunkowej warstwy GRU oraz dropoutu. Dropout ma być nałożony na embeddingi.\n"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 30,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
"class GRU(torch.nn.Module):\n",
"\n",
" def __init__(self):\n",
" super(GRU, self).__init__()\n",
" self.emb = torch.nn.Embedding(len(vocab.get_itos()),100)\n",
" self.dropout = torch.nn.Dropout(0.2)\n",
" self.rec = torch.nn.GRU(100, 256, 2, batch_first = True, bidirectional = True)\n",
" self.fc1 = torch.nn.Linear(2* 256 , 9)\n",
" \n",
" def forward(self, x):\n",
" emb = torch.relu(self.emb(x))\n",
" emb = self.dropout(emb)\n",
" \n",
" gru_output, h_n = self.rec(emb)\n",
" \n",
" out_weights = self.fc1(gru_output)\n",
"\n",
" return out_weights"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 31,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
2022-06-07 15:50:27 +02:00
"gru = GRU().to(device)"
2022-06-07 14:56:08 +02:00
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 32,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
"criterion = torch.nn.CrossEntropyLoss()"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 33,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
"optimizer = torch.optim.Adam(gru.parameters())"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 34,
2022-06-07 14:56:08 +02:00
"metadata": {},
"outputs": [],
"source": [
"NUM_EPOCHS = 5"
]
},
{
"cell_type": "code",
2022-06-07 15:50:27 +02:00
"execution_count": 35,
2022-06-07 14:56:08 +02:00
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "fc4d756d3f9d45cea875ecdc268ed9f9",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
2022-06-07 15:50:27 +02:00
" 0%| | 0/14042 [00:00<?, ?it/s]"
2022-06-07 14:56:08 +02:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "03b9fea03b8042f3bc143f0cc0ae70de",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3251 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-06-07 15:50:27 +02:00
"(0.6109818520241973, 0.4578635359758224, 0.5234551495016612)\n"
2022-06-07 14:56:08 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "9091f9231c7b4400b22360510a6dbca2",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
2022-06-07 15:50:27 +02:00
" 0%| | 0/14042 [00:00<?, ?it/s]"
2022-06-07 14:56:08 +02:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "9c968d5eda614e7da357cf260deb2372",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3251 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-06-07 15:50:27 +02:00
"(0.6290377039954981, 0.6496570963617343, 0.639181152790485)\n"
2022-06-07 14:56:08 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "dc78edb6313b4439ad4099d0842ded9b",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
2022-06-07 15:50:27 +02:00
" 0%| | 0/14042 [00:00<?, ?it/s]"
2022-06-07 14:56:08 +02:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "782be7d5c44a43bb8309a50ad85564d3",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3251 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-06-07 15:50:27 +02:00
"(0.6755871725383921, 0.6954550738114611, 0.6853771693682342)\n"
2022-06-07 14:56:08 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "1999a5193c7142039037dc567d6e56e5",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
2022-06-07 15:50:27 +02:00
" 0%| | 0/14042 [00:00<?, ?it/s]"
2022-06-07 14:56:08 +02:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "9ad5d2d4387b40ecbfb493fd3385fb1b",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3251 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-06-07 15:50:27 +02:00
"(0.7477821586988664, 0.7054515866558178, 0.7260003588731384)\n"
2022-06-07 14:56:08 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "47411db4679941519585f5a89227fd8d",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
2022-06-07 15:50:27 +02:00
" 0%| | 0/14042 [00:00<?, ?it/s]"
2022-06-07 14:56:08 +02:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2022-06-07 15:50:27 +02:00
"model_id": "a05b4e171efb4a5594e45c611f94aa18",
2022-06-07 14:56:08 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3251 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-06-07 15:50:27 +02:00
"(0.7669533169533169, 0.725677089387423, 0.745744490234725)\n"
2022-06-07 14:56:08 +02:00
]
}
],
"source": [
"# `optimizer` must be the Adam instance built from gru.parameters() just above.\n",
"for epoch in range(NUM_EPOCHS):\n",
"    gru.train()  # enables the embedding dropout\n",
"    for sentence_ids, sentence_tags in tqdm(zip(train_tokens_ids, train_labels), total=len(train_labels)):\n",
"        # batch of one sentence: (1, seq_len) ids -> (1, seq_len, n_tags) logits\n",
"        batch_tokens = sentence_ids.unsqueeze(0).to(device)\n",
"        tags = sentence_tags.to(device)\n",
"\n",
"        optimizer.zero_grad()\n",
"        predicted_tags = gru(batch_tokens)\n",
"        # CrossEntropyLoss gets (seq_len, n_tags) logits and (seq_len,) targets\n",
"        loss = criterion(predicted_tags.squeeze(0), tags)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"    gru.eval()\n",
"    print(f'epoch {epoch + 1}/{NUM_EPOCHS}:', eval_model(validation_tokens_ids, validation_labels, gru))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Zadanie domowe\n",
"\n",
"\n",
"- stworzyć model seq labelling bazujący na sieci neuronowej opisanej w punkcie niżej (można bazować na tym notebooku lub nie).\n",
"- model sieci to GRU (o dowolnych parametrach) + CRF w PyTorchu, korzystając z modułu CRF z poprzednich zajęć.\n",
"- stworzyć predykcje w plikach dev-0/out.tsv oraz test-A/out.tsv.\n",
"- wynik F-score sprawdzony za pomocą narzędzia geval (patrz poprzednie zadanie) powinien wynosić co najmniej 0.65.\n",
"\n",
"Termin: 22.06; 60 punktów, za najlepszy wynik 100 punktów.\n",
" "
]
}
],
"metadata": {
"author": "Jakub Pokrywka",
"email": "kubapok@wmi.amu.edu.pl",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"lang": "pl",
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"subtitle": "11.NER RNN[ćwiczenia]",
"title": "Ekstrakcja informacji",
"year": "2021"
},
"nbformat": 4,
"nbformat_minor": 4
}