SystemyDialogowe-ProjektMag.../lab/08-parsing-semantyczny-uczenie.ipynb
2022-04-29 01:06:09 +02:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"![Logo 1](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech1.jpg)\n",
"<div class=\"alert alert-block alert-info\">\n",
"<h1> Systemy Dialogowe </h1>\n",
"<h2> 8. <i>Parsing semantyczny z wykorzystaniem technik uczenia maszynowego</i> [laboratoria]</h2> \n",
"<h3> Marek Kubis (2021)</h3>\n",
"</div>\n",
"\n",
"![Logo 2](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech2.jpg)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Parsing semantyczny z wykorzystaniem technik uczenia maszynowego\n",
"================================================================\n",
"\n",
"Wprowadzenie\n",
"------------\n",
"Problem wykrywania slotów i ich wartości w wypowiedziach użytkownika można sformułować jako zadanie\n",
"polegające na przewidywaniu dla poszczególnych słów etykiet wskazujących na to czy i do jakiego\n",
"slotu dane słowo należy.\n",
"\n",
"> chciałbym zarezerwować stolik na jutro**/day** na godzinę dwunastą**/hour** czterdzieści**/hour** pięć**/hour** na pięć**/size** osób\n",
"\n",
"Granice slotów oznacza się korzystając z wybranego schematu etykietowania.\n",
"\n",
"### Schemat IOB\n",
"\n",
"| Prefix | Znaczenie |\n",
"|:------:|:---------------------------|\n",
"| I | wnętrze slotu (inside) |\n",
"| O | poza slotem (outside) |\n",
"| B | początek slotu (beginning) |\n",
"\n",
"> chciałbym zarezerwować stolik na jutro**/B-day** na godzinę dwunastą**/B-hour** czterdzieści**/I-hour** pięć**/I-hour** na pięć**/B-size** osób\n",
"\n",
"### Schemat IOBES\n",
"\n",
"| Prefix | Znaczenie |\n",
"|:------:|:---------------------------|\n",
"| I | wnętrze slotu (inside) |\n",
"| O | poza slotem (outside) |\n",
"| B | początek slotu (beginning) |\n",
"| E | koniec slotu (ending) |\n",
"| S | pojedyncze słowo (single) |\n",
"\n",
"> chciałbym zarezerwować stolik na jutro**/S-day** na godzinę dwunastą**/B-hour** czterdzieści**/I-hour** pięć**/E-hour** na pięć**/S-size** osób\n",
"\n",
"Jeżeli dla tak sformułowanego zadania przygotujemy zbiór danych\n",
"złożony z wypowiedzi użytkownika z oznaczonymi slotami (tzw. *zbiór uczący*),\n",
"to możemy zastosować techniki (nadzorowanego) uczenia maszynowego w celu zbudowania modelu\n",
"annotującego wypowiedzi użytkownika etykietami slotów.\n",
"\n",
"Do zbudowania takiego modelu można wykorzystać między innymi:\n",
"\n",
" 1. warunkowe pola losowe (Lafferty i in.; 2001),\n",
"\n",
" 2. rekurencyjne sieci neuronowe, np. sieci LSTM (Hochreiter i Schmidhuber; 1997),\n",
"\n",
" 3. transformery (Vaswani i in., 2017).\n",
"\n",
"Przykład\n",
"--------\n",
"Skorzystamy ze zbioru danych przygotowanego przez Schustera (2019)."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"c:\\Develop\\wmi\\AITECH\\sem1\\Systemy dialogowe\\lab\\l07\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"A subdirectory or file -p already exists.\n",
"Error occurred while processing: -p.\n",
"A subdirectory or file l07 already exists.\n",
"Error occurred while processing: l07.\n",
"** Resuming transfer from byte position 8923190\n",
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"\n",
" 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0\n",
" 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0\n",
"\n",
"100 49 100 49 0 0 118 0 --:--:-- --:--:-- --:--:-- 118\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"c:\\Develop\\wmi\\AITECH\\sem1\\Systemy dialogowe\\lab\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"'unzip' is not recognized as an internal or external command,\n",
"operable program or batch file.\n"
]
}
],
"source": [
"!mkdir -p l07\n",
"%cd l07\n",
"!curl -L -C - https://fb.me/multilingual_task_oriented_data -o data.zip\n",
"!unzip data.zip\n",
"%cd .."
]
},
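{
"cell_type": "markdown",
"metadata": {},
"source": [
"The recorded output of the cell above shows that `mkdir -p` and `unzip` are not available in the Windows command shell. A minimal, platform-independent sketch using only the Python standard library is given below. The helper name `download_and_unpack` is hypothetical; the URL and the target directory are the ones used in the cell above, and whether the URL still serves a valid archive is not verified here.\n",
"\n",
"```python\n",
"import urllib.request\n",
"import zipfile\n",
"from pathlib import Path\n",
"\n",
"def download_and_unpack(url, target_dir):\n",
"    # Hypothetical helper: fetch a zip archive and extract it into target_dir.\n",
"    target = Path(target_dir)\n",
"    target.mkdir(parents=True, exist_ok=True)  # same effect as `mkdir -p`\n",
"    archive = target / 'data.zip'\n",
"    if not archive.exists():\n",
"        urllib.request.urlretrieve(url, str(archive))\n",
"    # Raises zipfile.BadZipFile if the download is not a valid archive.\n",
"    with zipfile.ZipFile(archive) as zf:\n",
"        zf.extractall(target)\n",
"\n",
"# Example call (left commented out because it needs network access):\n",
"# download_and_unpack('https://fb.me/multilingual_task_oriented_data', 'l07')\n",
"```"
]
},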
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Zbiór ten gromadzi wypowiedzi w trzech językach opisane slotami dla dwunastu ram należących do trzech dziedzin `Alarm`, `Reminder` oraz `Weather`. Dane wczytamy korzystając z biblioteki [conllu](https://pypi.org/project/conllu/)."
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"# text: halo\n",
"\n",
"# intent: hello\n",
"\n",
"# slots: \n",
"\n",
"1\thalo\thello\tNoLabel\n",
"\n",
"\n",
"\n",
"# text: chaciałbym pójść na premierę filmu jakie premiery są w tym tygodniu\n",
"\n",
"# intent: reqmore\n",
"\n",
"# slots: \n",
"\n",
"1\tchaciałbym\treqmore\tNoLabel\n",
"\n",
"2\tpójść\treqmore\tNoLabel\n",
"\n",
"3\tna\treqmore\tNoLabel\n",
"\n",
"4\tpremierę\treqmore\tNoLabel\n",
"\n",
"5\tfilmu\treqmore\tNoLabel\n",
"\n",
"6\tjakie\treqmore\tNoLabel\n",
"\n",
"7\tpremiery\treqmore\tNoLabel\n",
"\n"
]
}
],
"source": [
"from conllu import parse_incr\n",
"fields = ['id', 'form', 'frame', 'slot']\n",
"\n",
"def nolabel2o(line, i):\n",
" return 'O' if line[i] == 'NoLabel' else line[i]\n",
"# pathTrain = '../tasks/zad8/en/train-en.conllu'\n",
"# pathTest = '../tasks/zad8/en/test-en.conllu'\n",
"\n",
"pathTrain = '../tasks/zad8/pl/train.conllu'\n",
"pathTest = '../tasks/zad8/pl/test.conllu'\n",
"\n",
"with open(pathTrain, encoding=\"UTF-8\") as trainfile:\n",
" i=0\n",
" for line in trainfile:\n",
" print(line)\n",
" i+=1\n",
" if i==15: break \n",
" trainset = list(parse_incr(trainfile, fields=fields, field_parsers={'slot': nolabel2o}))\n",
"with open(pathTest, encoding=\"UTF-8\") as testfile:\n",
" testset = list(parse_incr(testfile, fields=fields, field_parsers={'slot': nolabel2o}))\n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Zobaczmy kilka przykładowych wypowiedzi z tego zbioru."
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"<tbody>\n",
"<tr><td style=\"text-align: right;\">1</td><td>wybieram</td><td>inform</td><td>O </td></tr>\n",
"<tr><td style=\"text-align: right;\">2</td><td>batmana </td><td>inform</td><td>B-title</td></tr>\n",
"</tbody>\n",
"</table>"
],
"text/plain": [
"'<table>\\n<tbody>\\n<tr><td style=\"text-align: right;\">1</td><td>wybieram</td><td>inform</td><td>O </td></tr>\\n<tr><td style=\"text-align: right;\">2</td><td>batmana </td><td>inform</td><td>B-title</td></tr>\\n</tbody>\\n</table>'"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from tabulate import tabulate\n",
"tabulate(trainset[1], tablefmt='html')"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"<tbody>\n",
"<tr><td style=\"text-align: right;\">1</td><td>chcę </td><td>inform</td><td>O</td></tr>\n",
"<tr><td style=\"text-align: right;\">2</td><td>zarezerwować</td><td>inform</td><td>O</td></tr>\n",
"<tr><td style=\"text-align: right;\">3</td><td>bilety </td><td>inform</td><td>O</td></tr>\n",
"</tbody>\n",
"</table>"
],
"text/plain": [
"'<table>\\n<tbody>\\n<tr><td style=\"text-align: right;\">1</td><td>chcę </td><td>inform</td><td>O</td></tr>\\n<tr><td style=\"text-align: right;\">2</td><td>zarezerwować</td><td>inform</td><td>O</td></tr>\\n<tr><td style=\"text-align: right;\">3</td><td>bilety </td><td>inform</td><td>O</td></tr>\\n</tbody>\\n</table>'"
]
},
"execution_count": 77,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tabulate(trainset[16], tablefmt='html')"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"<tbody>\n",
"<tr><td style=\"text-align: right;\">1</td><td>chciałbym </td><td>inform</td><td>O</td></tr>\n",
"<tr><td style=\"text-align: right;\">2</td><td>anulować </td><td>inform</td><td>O</td></tr>\n",
"<tr><td style=\"text-align: right;\">3</td><td>rezerwację</td><td>inform</td><td>O</td></tr>\n",
"<tr><td style=\"text-align: right;\">4</td><td>biletu </td><td>inform</td><td>O</td></tr>\n",
"</tbody>\n",
"</table>"
],
"text/plain": [
"'<table>\\n<tbody>\\n<tr><td style=\"text-align: right;\">1</td><td>chciałbym </td><td>inform</td><td>O</td></tr>\\n<tr><td style=\"text-align: right;\">2</td><td>anulować </td><td>inform</td><td>O</td></tr>\\n<tr><td style=\"text-align: right;\">3</td><td>rezerwację</td><td>inform</td><td>O</td></tr>\\n<tr><td style=\"text-align: right;\">4</td><td>biletu </td><td>inform</td><td>O</td></tr>\\n</tbody>\\n</table>'"
]
},
"execution_count": 78,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tabulate(trainset[20], tablefmt='html')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Budując model skorzystamy z architektury opartej o rekurencyjne sieci neuronowe\n",
"zaimplementowanej w bibliotece [flair](https://github.com/flairNLP/flair) (Akbik i in. 2018)."
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"from flair.data import Corpus, Sentence, Token\n",
"from flair.datasets import SentenceDataset\n",
"from flair.embeddings import StackedEmbeddings\n",
"from flair.embeddings import WordEmbeddings\n",
"from flair.embeddings import CharacterEmbeddings\n",
"from flair.embeddings import FlairEmbeddings\n",
"from flair.models import SequenceTagger\n",
"from flair.trainers import ModelTrainer\n",
"\n",
"# determinizacja obliczeń\n",
"import random\n",
"import torch\n",
"random.seed(42)\n",
"torch.manual_seed(42)\n",
"\n",
"if torch.cuda.is_available():\n",
" torch.cuda.manual_seed(0)\n",
" torch.cuda.manual_seed_all(0)\n",
" torch.backends.cudnn.enabled = False\n",
" torch.backends.cudnn.benchmark = False\n",
" torch.backends.cudnn.deterministic = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Dane skonwertujemy do formatu wykorzystywanego przez `flair`, korzystając z następującej funkcji."
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Corpus: 297 train + 33 dev + 33 test sentences\n",
"Dictionary with 14 tags: <unk>, O, B-date, I-date, B-time, I-time, B-area, I-area, B-title, B-quantity, I-title, I-quantity, <START>, <STOP>\n"
]
}
],
"source": [
"def conllu2flair(sentences, label=None):\n",
" fsentences = []\n",
"\n",
" for sentence in sentences:\n",
" fsentence = Sentence()\n",
"\n",
" for token in sentence:\n",
" ftoken = Token(token['form'])\n",
"\n",
" if label:\n",
" ftoken.add_tag(label, token[label])\n",
"\n",
" fsentence.add_token(ftoken)\n",
"\n",
" fsentences.append(fsentence)\n",
"\n",
" return SentenceDataset(fsentences)\n",
"\n",
"corpus = Corpus(train=conllu2flair(trainset, 'slot'), test=conllu2flair(testset, 'slot'))\n",
"print(corpus)\n",
"tag_dictionary = corpus.make_tag_dictionary(tag_type='slot')\n",
"print(tag_dictionary)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Nasz model będzie wykorzystywał wektorowe reprezentacje słów (zob. [Word Embeddings](https://github.com/flairNLP/flair/blob/master/resources/docs/TUTORIAL_3_WORD_EMBEDDING.md))."
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2022-04-28 22:14:01,525 https://flair.informatik.hu-berlin.de/resources/embeddings/token/pl-wiki-fasttext-300d-1M.vectors.npy not found in cache, downloading to C:\\Users\\48516\\AppData\\Local\\Temp\\tmp8ekygs88\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1199998928/1199998928 [01:00<00:00, 19734932.13B/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2022-04-28 22:15:02,505 copying C:\\Users\\48516\\AppData\\Local\\Temp\\tmp8ekygs88 to cache at C:\\Users\\48516\\.flair\\embeddings\\pl-wiki-fasttext-300d-1M.vectors.npy\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2022-04-28 22:15:03,136 removing temp file C:\\Users\\48516\\AppData\\Local\\Temp\\tmp8ekygs88\n",
"2022-04-28 22:15:03,420 https://flair.informatik.hu-berlin.de/resources/embeddings/token/pl-wiki-fasttext-300d-1M not found in cache, downloading to C:\\Users\\48516\\AppData\\Local\\Temp\\tmp612sxdgl\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 40874795/40874795 [00:02<00:00, 18943852.55B/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2022-04-28 22:15:05,807 copying C:\\Users\\48516\\AppData\\Local\\Temp\\tmp612sxdgl to cache at C:\\Users\\48516\\.flair\\embeddings\\pl-wiki-fasttext-300d-1M\n",
"2022-04-28 22:15:05,830 removing temp file C:\\Users\\48516\\AppData\\Local\\Temp\\tmp612sxdgl\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2022-04-28 22:15:13,095 https://flair.informatik.hu-berlin.de/resources/embeddings/flair/lm-polish-forward-v0.2.pt not found in cache, downloading to C:\\Users\\48516\\AppData\\Local\\Temp\\tmp05k_xff8\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 84244196/84244196 [00:04<00:00, 19653900.77B/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2022-04-28 22:15:17,599 copying C:\\Users\\48516\\AppData\\Local\\Temp\\tmp05k_xff8 to cache at C:\\Users\\48516\\.flair\\embeddings\\lm-polish-forward-v0.2.pt\n",
"2022-04-28 22:15:17,640 removing temp file C:\\Users\\48516\\AppData\\Local\\Temp\\tmp05k_xff8\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2022-04-28 22:15:18,034 https://flair.informatik.hu-berlin.de/resources/embeddings/flair/lm-polish-backward-v0.2.pt not found in cache, downloading to C:\\Users\\48516\\AppData\\Local\\Temp\\tmpbjevekqx\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 84244196/84244196 [00:04<00:00, 19850177.72B/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2022-04-28 22:15:22,467 copying C:\\Users\\48516\\AppData\\Local\\Temp\\tmpbjevekqx to cache at C:\\Users\\48516\\.flair\\embeddings\\lm-polish-backward-v0.2.pt\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2022-04-28 22:15:22,518 removing temp file C:\\Users\\48516\\AppData\\Local\\Temp\\tmpbjevekqx\n"
]
}
],
"source": [
"embedding_types = [\n",
" WordEmbeddings('pl'),\n",
" FlairEmbeddings('polish-forward'),\n",
" FlairEmbeddings('polish-backward'),\n",
" CharacterEmbeddings(),\n",
"]\n",
"\n",
"embeddings = StackedEmbeddings(embeddings=embedding_types)\n",
"tagger = SequenceTagger(hidden_size=256, embeddings=embeddings,\n",
" tag_dictionary=tag_dictionary,\n",
" tag_type='slot', use_crf=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Zobaczmy jak wygląda architektura sieci neuronowej, która będzie odpowiedzialna za przewidywanie\n",
"slotów w wypowiedziach."
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SequenceTagger(\n",
" (embeddings): StackedEmbeddings(\n",
" (list_embedding_0): WordEmbeddings('pl')\n",
" (list_embedding_1): FlairEmbeddings(\n",
" (lm): LanguageModel(\n",
" (drop): Dropout(p=0.25, inplace=False)\n",
" (encoder): Embedding(1602, 100)\n",
" (rnn): LSTM(100, 2048)\n",
" (decoder): Linear(in_features=2048, out_features=1602, bias=True)\n",
" )\n",
" )\n",
" (list_embedding_2): FlairEmbeddings(\n",
" (lm): LanguageModel(\n",
" (drop): Dropout(p=0.25, inplace=False)\n",
" (encoder): Embedding(1602, 100)\n",
" (rnn): LSTM(100, 2048)\n",
" (decoder): Linear(in_features=2048, out_features=1602, bias=True)\n",
" )\n",
" )\n",
" (list_embedding_3): CharacterEmbeddings(\n",
" (char_embedding): Embedding(275, 25)\n",
" (char_rnn): LSTM(25, 25, bidirectional=True)\n",
" )\n",
" )\n",
" (word_dropout): WordDropout(p=0.05)\n",
" (locked_dropout): LockedDropout(p=0.5)\n",
" (embedding2nn): Linear(in_features=4446, out_features=4446, bias=True)\n",
" (rnn): LSTM(4446, 256, batch_first=True, bidirectional=True)\n",
" (linear): Linear(in_features=512, out_features=14, bias=True)\n",
" (beta): 1.0\n",
" (weights): None\n",
" (weight_tensor) None\n",
")\n"
]
}
],
"source": [
"print(tagger)"
]
},
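{
"cell_type": "markdown",
"metadata": {},
"source": [
"The input size of 4446 reported for `embedding2nn` and `rnn` is the sum of the sizes of the stacked embeddings. The quick check below reads the component sizes off the printout above; the 300 dimensions of the `pl` word vectors are an assumption based on the name of the downloaded file (`pl-wiki-fasttext-300d-1M`).\n",
"\n",
"```python\n",
"word = 300             # fastText word vectors (size assumed from the file name)\n",
"flair_forward = 2048   # hidden size of the forward LSTM language model\n",
"flair_backward = 2048  # hidden size of the backward LSTM language model\n",
"char = 2 * 25          # bidirectional character LSTM with hidden size 25\n",
"\n",
"assert word + flair_forward + flair_backward + char == 4446\n",
"\n",
"# The tagger's BiLSTM has hidden size 256 per direction,\n",
"# hence the 512 input features of the final linear layer.\n",
"assert 2 * 256 == 512\n",
"```"
]
},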
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Wykonamy dziesięć iteracji (epok) uczenia a wynikowy model zapiszemy w katalogu `slot-model`."
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2022-04-28 22:15:23,085 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:15:23,086 Model: \"SequenceTagger(\n",
" (embeddings): StackedEmbeddings(\n",
" (list_embedding_0): WordEmbeddings('pl')\n",
" (list_embedding_1): FlairEmbeddings(\n",
" (lm): LanguageModel(\n",
" (drop): Dropout(p=0.25, inplace=False)\n",
" (encoder): Embedding(1602, 100)\n",
" (rnn): LSTM(100, 2048)\n",
" (decoder): Linear(in_features=2048, out_features=1602, bias=True)\n",
" )\n",
" )\n",
" (list_embedding_2): FlairEmbeddings(\n",
" (lm): LanguageModel(\n",
" (drop): Dropout(p=0.25, inplace=False)\n",
" (encoder): Embedding(1602, 100)\n",
" (rnn): LSTM(100, 2048)\n",
" (decoder): Linear(in_features=2048, out_features=1602, bias=True)\n",
" )\n",
" )\n",
" (list_embedding_3): CharacterEmbeddings(\n",
" (char_embedding): Embedding(275, 25)\n",
" (char_rnn): LSTM(25, 25, bidirectional=True)\n",
" )\n",
" )\n",
" (word_dropout): WordDropout(p=0.05)\n",
" (locked_dropout): LockedDropout(p=0.5)\n",
" (embedding2nn): Linear(in_features=4446, out_features=4446, bias=True)\n",
" (rnn): LSTM(4446, 256, batch_first=True, bidirectional=True)\n",
" (linear): Linear(in_features=512, out_features=14, bias=True)\n",
" (beta): 1.0\n",
" (weights): None\n",
" (weight_tensor) None\n",
")\"\n",
"2022-04-28 22:15:23,087 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:15:23,088 Corpus: \"Corpus: 297 train + 33 dev + 33 test sentences\"\n",
"2022-04-28 22:15:23,088 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:15:23,089 Parameters:\n",
"2022-04-28 22:15:23,089 - learning_rate: \"0.1\"\n",
"2022-04-28 22:15:23,090 - mini_batch_size: \"32\"\n",
"2022-04-28 22:15:23,090 - patience: \"3\"\n",
"2022-04-28 22:15:23,091 - anneal_factor: \"0.5\"\n",
"2022-04-28 22:15:23,092 - max_epochs: \"10\"\n",
"2022-04-28 22:15:23,093 - shuffle: \"True\"\n",
"2022-04-28 22:15:23,093 - train_with_dev: \"False\"\n",
"2022-04-28 22:15:23,094 - batch_growth_annealing: \"False\"\n",
"2022-04-28 22:15:23,094 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:15:23,095 Model training base path: \"slot-model\"\n",
"2022-04-28 22:15:23,095 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:15:23,096 Device: cpu\n",
"2022-04-28 22:15:23,096 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:15:23,097 Embeddings storage mode: cpu\n",
"2022-04-28 22:15:23,100 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:15:25,051 epoch 1 - iter 1/10 - loss 15.67058754 - samples/sec: 16.40 - lr: 0.100000\n",
"2022-04-28 22:15:27,334 epoch 1 - iter 2/10 - loss 13.01803017 - samples/sec: 14.02 - lr: 0.100000\n",
"2022-04-28 22:15:29,132 epoch 1 - iter 3/10 - loss 11.16305335 - samples/sec: 17.81 - lr: 0.100000\n",
"2022-04-28 22:15:30,629 epoch 1 - iter 4/10 - loss 9.23769999 - samples/sec: 21.39 - lr: 0.100000\n",
"2022-04-28 22:15:32,614 epoch 1 - iter 5/10 - loss 7.94914236 - samples/sec: 16.13 - lr: 0.100000\n",
"2022-04-28 22:15:34,081 epoch 1 - iter 6/10 - loss 7.05464562 - samples/sec: 21.83 - lr: 0.100000\n",
"2022-04-28 22:15:35,257 epoch 1 - iter 7/10 - loss 6.28502292 - samples/sec: 27.26 - lr: 0.100000\n",
"2022-04-28 22:15:37,386 epoch 1 - iter 8/10 - loss 5.74554797 - samples/sec: 15.04 - lr: 0.100000\n",
"2022-04-28 22:15:39,009 epoch 1 - iter 9/10 - loss 5.48559354 - samples/sec: 19.73 - lr: 0.100000\n",
"2022-04-28 22:15:39,892 epoch 1 - iter 10/10 - loss 5.10890775 - samples/sec: 36.28 - lr: 0.100000\n",
"2022-04-28 22:15:39,893 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:15:39,894 EPOCH 1 done: loss 5.1089 - lr 0.1000000\n",
"2022-04-28 22:15:41,651 DEV : loss 1.1116931438446045 - score 0.0\n",
"2022-04-28 22:15:41,654 BAD EPOCHS (no improvement): 0\n",
"saving best model\n",
"2022-04-28 22:15:54,970 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:15:55,703 epoch 2 - iter 1/10 - loss 2.39535546 - samples/sec: 48.71 - lr: 0.100000\n",
"2022-04-28 22:15:56,276 epoch 2 - iter 2/10 - loss 3.14594960 - samples/sec: 55.94 - lr: 0.100000\n",
"2022-04-28 22:15:56,849 epoch 2 - iter 3/10 - loss 2.96723008 - samples/sec: 55.94 - lr: 0.100000\n",
"2022-04-28 22:15:57,326 epoch 2 - iter 4/10 - loss 2.72414619 - samples/sec: 67.23 - lr: 0.100000\n",
"2022-04-28 22:15:57,799 epoch 2 - iter 5/10 - loss 2.52746274 - samples/sec: 67.80 - lr: 0.100000\n",
"2022-04-28 22:15:58,255 epoch 2 - iter 6/10 - loss 2.41920217 - samples/sec: 70.33 - lr: 0.100000\n",
"2022-04-28 22:15:58,770 epoch 2 - iter 7/10 - loss 2.48535442 - samples/sec: 62.26 - lr: 0.100000\n",
"2022-04-28 22:15:59,324 epoch 2 - iter 8/10 - loss 2.40343314 - samples/sec: 57.87 - lr: 0.100000\n",
"2022-04-28 22:15:59,827 epoch 2 - iter 9/10 - loss 2.41345758 - samples/sec: 63.74 - lr: 0.100000\n",
"2022-04-28 22:16:00,052 epoch 2 - iter 10/10 - loss 2.63766205 - samples/sec: 142.86 - lr: 0.100000\n",
"2022-04-28 22:16:00,053 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:16:00,054 EPOCH 2 done: loss 2.6377 - lr 0.1000000\n",
"2022-04-28 22:16:00,234 DEV : loss 1.2027416229248047 - score 0.0\n",
"2022-04-28 22:16:00,238 BAD EPOCHS (no improvement): 1\n",
"2022-04-28 22:16:00,241 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:16:00,771 epoch 3 - iter 1/10 - loss 2.07519531 - samples/sec: 60.61 - lr: 0.100000\n",
"2022-04-28 22:16:01,297 epoch 3 - iter 2/10 - loss 2.21946335 - samples/sec: 60.95 - lr: 0.100000\n",
"2022-04-28 22:16:01,826 epoch 3 - iter 3/10 - loss 2.32372427 - samples/sec: 60.61 - lr: 0.100000\n",
"2022-04-28 22:16:02,304 epoch 3 - iter 4/10 - loss 2.18133342 - samples/sec: 67.23 - lr: 0.100000\n",
"2022-04-28 22:16:02,727 epoch 3 - iter 5/10 - loss 2.10553741 - samples/sec: 75.83 - lr: 0.100000\n",
"2022-04-28 22:16:03,215 epoch 3 - iter 6/10 - loss 1.99518015 - samples/sec: 65.84 - lr: 0.100000\n",
"2022-04-28 22:16:03,670 epoch 3 - iter 7/10 - loss 2.03174150 - samples/sec: 70.64 - lr: 0.100000\n",
"2022-04-28 22:16:04,239 epoch 3 - iter 8/10 - loss 2.19520997 - samples/sec: 56.34 - lr: 0.100000\n",
"2022-04-28 22:16:04,686 epoch 3 - iter 9/10 - loss 2.15986861 - samples/sec: 71.75 - lr: 0.100000\n",
"2022-04-28 22:16:04,919 epoch 3 - iter 10/10 - loss 2.02860461 - samples/sec: 137.93 - lr: 0.100000\n",
"2022-04-28 22:16:04,920 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:16:04,921 EPOCH 3 done: loss 2.0286 - lr 0.1000000\n",
"2022-04-28 22:16:05,067 DEV : loss 0.9265440702438354 - score 0.0\n",
"2022-04-28 22:16:05,069 BAD EPOCHS (no improvement): 0\n",
"saving best model\n",
"2022-04-28 22:16:10,882 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:16:11,339 epoch 4 - iter 1/10 - loss 2.63443780 - samples/sec: 70.33 - lr: 0.100000\n",
"2022-04-28 22:16:11,858 epoch 4 - iter 2/10 - loss 2.35905457 - samples/sec: 61.78 - lr: 0.100000\n",
"2022-04-28 22:16:12,523 epoch 4 - iter 3/10 - loss 2.23206981 - samples/sec: 48.19 - lr: 0.100000\n",
"2022-04-28 22:16:13,026 epoch 4 - iter 4/10 - loss 2.28027773 - samples/sec: 63.75 - lr: 0.100000\n",
"2022-04-28 22:16:13,610 epoch 4 - iter 5/10 - loss 2.22129200 - samples/sec: 54.98 - lr: 0.100000\n",
"2022-04-28 22:16:14,074 epoch 4 - iter 6/10 - loss 2.10545621 - samples/sec: 69.11 - lr: 0.100000\n",
"2022-04-28 22:16:14,646 epoch 4 - iter 7/10 - loss 2.10457425 - samples/sec: 56.04 - lr: 0.100000\n",
"2022-04-28 22:16:15,144 epoch 4 - iter 8/10 - loss 2.04774940 - samples/sec: 64.38 - lr: 0.100000\n",
"2022-04-28 22:16:15,698 epoch 4 - iter 9/10 - loss 1.99643935 - samples/sec: 57.97 - lr: 0.100000\n",
"2022-04-28 22:16:15,935 epoch 4 - iter 10/10 - loss 1.81641705 - samples/sec: 136.14 - lr: 0.100000\n",
"2022-04-28 22:16:15,936 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:16:15,937 EPOCH 4 done: loss 1.8164 - lr 0.1000000\n",
"2022-04-28 22:16:16,092 DEV : loss 0.8311207890510559 - score 0.0\n",
"2022-04-28 22:16:16,094 BAD EPOCHS (no improvement): 0\n",
"saving best model\n",
"2022-04-28 22:16:21,938 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:16:22,424 epoch 5 - iter 1/10 - loss 1.31467295 - samples/sec: 66.12 - lr: 0.100000\n",
"2022-04-28 22:16:22,852 epoch 5 - iter 2/10 - loss 1.87177873 - samples/sec: 74.94 - lr: 0.100000\n",
"2022-04-28 22:16:23,440 epoch 5 - iter 3/10 - loss 1.83717314 - samples/sec: 54.51 - lr: 0.100000\n",
"2022-04-28 22:16:23,991 epoch 5 - iter 4/10 - loss 2.06565040 - samples/sec: 58.18 - lr: 0.100000\n",
"2022-04-28 22:16:24,364 epoch 5 - iter 5/10 - loss 1.95749507 - samples/sec: 86.25 - lr: 0.100000\n",
"2022-04-28 22:16:24,832 epoch 5 - iter 6/10 - loss 1.84727591 - samples/sec: 68.67 - lr: 0.100000\n",
"2022-04-28 22:16:25,238 epoch 5 - iter 7/10 - loss 1.79978011 - samples/sec: 79.21 - lr: 0.100000\n",
"2022-04-28 22:16:25,679 epoch 5 - iter 8/10 - loss 1.69797329 - samples/sec: 72.73 - lr: 0.100000\n",
"2022-04-28 22:16:26,173 epoch 5 - iter 9/10 - loss 1.70765987 - samples/sec: 64.84 - lr: 0.100000\n",
"2022-04-28 22:16:26,364 epoch 5 - iter 10/10 - loss 1.76581790 - samples/sec: 169.31 - lr: 0.100000\n",
"2022-04-28 22:16:26,366 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:16:26,367 EPOCH 5 done: loss 1.7658 - lr 0.1000000\n",
"2022-04-28 22:16:26,509 DEV : loss 0.7797471880912781 - score 0.2222\n",
"2022-04-28 22:16:26,510 BAD EPOCHS (no improvement): 0\n",
"saving best model\n",
"2022-04-28 22:16:32,211 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:16:32,666 epoch 6 - iter 1/10 - loss 2.04772544 - samples/sec: 70.64 - lr: 0.100000\n",
"2022-04-28 22:16:33,172 epoch 6 - iter 2/10 - loss 1.61218661 - samples/sec: 63.37 - lr: 0.100000\n",
"2022-04-28 22:16:33,673 epoch 6 - iter 3/10 - loss 1.55716117 - samples/sec: 64.00 - lr: 0.100000\n",
"2022-04-28 22:16:34,183 epoch 6 - iter 4/10 - loss 1.54974008 - samples/sec: 62.87 - lr: 0.100000\n",
"2022-04-28 22:16:34,687 epoch 6 - iter 5/10 - loss 1.50827932 - samples/sec: 63.62 - lr: 0.100000\n",
"2022-04-28 22:16:35,155 epoch 6 - iter 6/10 - loss 1.46459270 - samples/sec: 68.52 - lr: 0.100000\n",
"2022-04-28 22:16:35,658 epoch 6 - iter 7/10 - loss 1.50249643 - samples/sec: 63.87 - lr: 0.100000\n",
"2022-04-28 22:16:36,094 epoch 6 - iter 8/10 - loss 1.51979375 - samples/sec: 73.56 - lr: 0.100000\n",
"2022-04-28 22:16:36,548 epoch 6 - iter 9/10 - loss 1.56509953 - samples/sec: 70.64 - lr: 0.100000\n",
"2022-04-28 22:16:36,744 epoch 6 - iter 10/10 - loss 1.55241492 - samples/sec: 164.10 - lr: 0.100000\n",
"2022-04-28 22:16:36,746 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:16:36,746 EPOCH 6 done: loss 1.5524 - lr 0.1000000\n",
"2022-04-28 22:16:36,884 DEV : loss 0.9345423579216003 - score 0.3333\n",
"2022-04-28 22:16:36,885 BAD EPOCHS (no improvement): 0\n",
"saving best model\n",
"2022-04-28 22:16:42,377 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:16:42,856 epoch 7 - iter 1/10 - loss 2.15539050 - samples/sec: 67.09 - lr: 0.100000\n",
"2022-04-28 22:16:43,336 epoch 7 - iter 2/10 - loss 1.68949413 - samples/sec: 66.95 - lr: 0.100000\n",
"2022-04-28 22:16:43,781 epoch 7 - iter 3/10 - loss 1.81478349 - samples/sec: 72.07 - lr: 0.100000\n",
"2022-04-28 22:16:44,241 epoch 7 - iter 4/10 - loss 1.68033907 - samples/sec: 69.87 - lr: 0.100000\n",
"2022-04-28 22:16:44,730 epoch 7 - iter 5/10 - loss 1.64062953 - samples/sec: 65.57 - lr: 0.100000\n",
"2022-04-28 22:16:45,227 epoch 7 - iter 6/10 - loss 1.59568199 - samples/sec: 64.78 - lr: 0.100000\n",
"2022-04-28 22:16:45,663 epoch 7 - iter 7/10 - loss 1.46137918 - samples/sec: 73.39 - lr: 0.100000\n",
"2022-04-28 22:16:46,169 epoch 7 - iter 8/10 - loss 1.41721664 - samples/sec: 63.36 - lr: 0.100000\n",
"2022-04-28 22:16:46,734 epoch 7 - iter 9/10 - loss 1.39811980 - samples/sec: 56.74 - lr: 0.100000\n",
"2022-04-28 22:16:46,937 epoch 7 - iter 10/10 - loss 1.38412433 - samples/sec: 159.20 - lr: 0.100000\n",
"2022-04-28 22:16:46,938 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:16:46,939 EPOCH 7 done: loss 1.3841 - lr 0.1000000\n",
"2022-04-28 22:16:47,081 DEV : loss 0.6798948049545288 - score 0.5\n",
"2022-04-28 22:16:47,083 BAD EPOCHS (no improvement): 0\n",
"saving best model\n",
"2022-04-28 22:16:52,628 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:16:53,137 epoch 8 - iter 1/10 - loss 1.08732188 - samples/sec: 63.12 - lr: 0.100000\n",
"2022-04-28 22:16:53,606 epoch 8 - iter 2/10 - loss 1.29048711 - samples/sec: 68.38 - lr: 0.100000\n",
"2022-04-28 22:16:54,039 epoch 8 - iter 3/10 - loss 1.04415214 - samples/sec: 74.07 - lr: 0.100000\n",
"2022-04-28 22:16:54,568 epoch 8 - iter 4/10 - loss 1.02857886 - samples/sec: 60.60 - lr: 0.100000\n",
"2022-04-28 22:16:55,148 epoch 8 - iter 5/10 - loss 1.26690668 - samples/sec: 55.27 - lr: 0.100000\n",
"2022-04-28 22:16:55,602 epoch 8 - iter 6/10 - loss 1.30797880 - samples/sec: 70.80 - lr: 0.100000\n",
"2022-04-28 22:16:56,075 epoch 8 - iter 7/10 - loss 1.22035806 - samples/sec: 67.72 - lr: 0.100000\n",
"2022-04-28 22:16:56,494 epoch 8 - iter 8/10 - loss 1.23306625 - samples/sec: 76.51 - lr: 0.100000\n",
"2022-04-28 22:16:56,933 epoch 8 - iter 9/10 - loss 1.18903442 - samples/sec: 73.15 - lr: 0.100000\n",
"2022-04-28 22:16:57,147 epoch 8 - iter 10/10 - loss 1.31105986 - samples/sec: 150.24 - lr: 0.100000\n",
"2022-04-28 22:16:57,148 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:16:57,149 EPOCH 8 done: loss 1.3111 - lr 0.1000000\n",
"2022-04-28 22:16:57,289 DEV : loss 0.5563207864761353 - score 0.5\n",
"2022-04-28 22:16:57,290 BAD EPOCHS (no improvement): 0\n",
"saving best model\n",
"2022-04-28 22:17:02,550 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:17:03,134 epoch 9 - iter 1/10 - loss 1.32691610 - samples/sec: 54.89 - lr: 0.100000\n",
"2022-04-28 22:17:03,595 epoch 9 - iter 2/10 - loss 1.16159409 - samples/sec: 69.57 - lr: 0.100000\n",
"2022-04-28 22:17:04,014 epoch 9 - iter 3/10 - loss 1.10929267 - samples/sec: 76.56 - lr: 0.100000\n",
"2022-04-28 22:17:04,518 epoch 9 - iter 4/10 - loss 1.05318102 - samples/sec: 63.62 - lr: 0.100000\n",
"2022-04-28 22:17:04,966 epoch 9 - iter 5/10 - loss 1.07275693 - samples/sec: 71.75 - lr: 0.100000\n",
"2022-04-28 22:17:05,432 epoch 9 - iter 6/10 - loss 1.02824855 - samples/sec: 68.82 - lr: 0.100000\n",
"2022-04-28 22:17:05,909 epoch 9 - iter 7/10 - loss 1.04051120 - samples/sec: 67.23 - lr: 0.100000\n",
"2022-04-28 22:17:06,404 epoch 9 - iter 8/10 - loss 1.00513531 - samples/sec: 64.78 - lr: 0.100000\n",
"2022-04-28 22:17:06,831 epoch 9 - iter 9/10 - loss 1.03960636 - samples/sec: 75.29 - lr: 0.100000\n",
"2022-04-28 22:17:07,019 epoch 9 - iter 10/10 - loss 1.07805606 - samples/sec: 171.12 - lr: 0.100000\n",
"2022-04-28 22:17:07,020 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:17:07,021 EPOCH 9 done: loss 1.0781 - lr 0.1000000\n",
"2022-04-28 22:17:07,151 DEV : loss 0.909138560295105 - score 0.7143\n",
"2022-04-28 22:17:07,153 BAD EPOCHS (no improvement): 0\n",
"saving best model\n",
"2022-04-28 22:17:12,454 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:17:12,906 epoch 10 - iter 1/10 - loss 1.49117911 - samples/sec: 70.96 - lr: 0.100000\n",
"2022-04-28 22:17:13,334 epoch 10 - iter 2/10 - loss 1.23203236 - samples/sec: 74.94 - lr: 0.100000\n",
"2022-04-28 22:17:13,789 epoch 10 - iter 3/10 - loss 1.12988973 - samples/sec: 70.48 - lr: 0.100000\n",
"2022-04-28 22:17:14,275 epoch 10 - iter 4/10 - loss 1.07148103 - samples/sec: 65.98 - lr: 0.100000\n",
"2022-04-28 22:17:14,795 epoch 10 - iter 5/10 - loss 1.08848752 - samples/sec: 61.66 - lr: 0.100000\n",
"2022-04-28 22:17:15,328 epoch 10 - iter 6/10 - loss 1.05938606 - samples/sec: 60.26 - lr: 0.100000\n",
"2022-04-28 22:17:15,730 epoch 10 - iter 7/10 - loss 1.00324091 - samples/sec: 79.80 - lr: 0.100000\n",
"2022-04-28 22:17:16,245 epoch 10 - iter 8/10 - loss 0.93657552 - samples/sec: 62.26 - lr: 0.100000\n",
"2022-04-28 22:17:16,681 epoch 10 - iter 9/10 - loss 0.95801387 - samples/sec: 73.56 - lr: 0.100000\n",
"2022-04-28 22:17:16,901 epoch 10 - iter 10/10 - loss 0.87346228 - samples/sec: 146.77 - lr: 0.100000\n",
"2022-04-28 22:17:16,902 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:17:16,903 EPOCH 10 done: loss 0.8735 - lr 0.1000000\n",
"2022-04-28 22:17:17,047 DEV : loss 0.5443210601806641 - score 0.7143\n",
"2022-04-28 22:17:17,050 BAD EPOCHS (no improvement): 0\n",
"saving best model\n",
"2022-04-28 22:17:27,557 ----------------------------------------------------------------------------------------------------\n",
"2022-04-28 22:17:27,557 Testing using best model ...\n",
"2022-04-28 22:17:27,566 loading file slot-model\\best-model.pt\n",
"2022-04-28 22:17:33,102 0.6429\t0.4500\t0.5294\n",
"2022-04-28 22:17:33,103 \n",
"Results:\n",
"- F1-score (micro) 0.5294\n",
"- F1-score (macro) 0.4533\n",
"\n",
"By class:\n",
"area tp: 0 - fp: 0 - fn: 1 - precision: 0.0000 - recall: 0.0000 - f1-score: 0.0000\n",
"date tp: 1 - fp: 1 - fn: 0 - precision: 0.5000 - recall: 1.0000 - f1-score: 0.6667\n",
"quantity tp: 3 - fp: 1 - fn: 3 - precision: 0.7500 - recall: 0.5000 - f1-score: 0.6000\n",
"time tp: 2 - fp: 2 - fn: 4 - precision: 0.5000 - recall: 0.3333 - f1-score: 0.4000\n",
"title tp: 3 - fp: 1 - fn: 3 - precision: 0.7500 - recall: 0.5000 - f1-score: 0.6000\n",
"2022-04-28 22:17:33,104 ----------------------------------------------------------------------------------------------------\n"
]
},
{
"data": {
"text/plain": [
"{'test_score': 0.5294117647058824,\n",
" 'dev_score_history': [0.0,\n",
" 0.0,\n",
" 0.0,\n",
" 0.0,\n",
" 0.2222222222222222,\n",
" 0.3333333333333333,\n",
" 0.5,\n",
" 0.5,\n",
" 0.7142857142857143,\n",
" 0.7142857142857143],\n",
" 'train_loss_history': [5.108907747268677,\n",
" 2.6376620531082153,\n",
" 2.0286046147346495,\n",
" 1.816417047381401,\n",
" 1.7658178985118866,\n",
" 1.5524149179458617,\n",
" 1.384124332666397,\n",
" 1.3110598623752594,\n",
" 1.0780560612678527,\n",
" 0.8734622806310653],\n",
" 'dev_loss_history': [1.1116931438446045,\n",
" 1.2027416229248047,\n",
" 0.9265440702438354,\n",
" 0.8311207890510559,\n",
" 0.7797471880912781,\n",
" 0.9345423579216003,\n",
" 0.6798948049545288,\n",
" 0.5563207864761353,\n",
" 0.909138560295105,\n",
" 0.5443210601806641]}"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer = ModelTrainer(tagger, corpus)\n",
"trainer.train('slot-model',\n",
" learning_rate=0.1,\n",
" mini_batch_size=32,\n",
" max_epochs=10,\n",
" train_with_dev=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Jakość wyuczonego modelu możemy ocenić, korzystając z zaraportowanych powyżej metryk, tj.:\n",
"\n",
" - *tp (true positives)*\n",
"\n",
" > liczba słów oznaczonych w zbiorze testowym etykietą $e$, które model oznaczył tą etykietą\n",
"\n",
" - *fp (false positives)*\n",
"\n",
" > liczba słów nieoznaczonych w zbiorze testowym etykietą $e$, które model oznaczył tą etykietą\n",
"\n",
" - *fn (false negatives)*\n",
"\n",
" > liczba słów oznaczonych w zbiorze testowym etykietą $e$, którym model nie nadał etykiety $e$\n",
"\n",
" - *precision*\n",
"\n",
" > $$\\frac{tp}{tp + fp}$$\n",
"\n",
" - *recall*\n",
"\n",
" > $$\\frac{tp}{tp + fn}$$\n",
"\n",
" - $F_1$\n",
"\n",
" > $$\\frac{2 \\cdot precision \\cdot recall}{precision + recall}$$\n",
"\n",
" - *micro* $F_1$\n",
"\n",
" > $F_1$ w którym $tp$, $fp$ i $fn$ są liczone łącznie dla wszystkich etykiet, tj. $tp = \\sum_{e}{{tp}_e}$, $fn = \\sum_{e}{{fn}_e}$, $fp = \\sum_{e}{{fp}_e}$\n",
"\n",
" - *macro* $F_1$\n",
"\n",
" > średnia arytmetyczna z $F_1$ obliczonych dla poszczególnych etykiet z osobna.\n",
"\n",
"Wyuczony model możemy wczytać z pliku korzystając z metody `load`."
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2022-04-28 22:17:33,278 loading file slot-model/final-model.pt\n"
]
}
],
"source": [
"model = SequenceTagger.load('slot-model/final-model.pt')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Wczytany model możemy wykorzystać do przewidywania slotów w wypowiedziach użytkownika, korzystając\n",
"z przedstawionej poniżej funkcji `predict`."
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"def predict(model, sentence):\n",
" csentence = [{'form': word} for word in sentence]\n",
" fsentence = conllu2flair([csentence])[0]\n",
" model.predict(fsentence)\n",
" return [(token, ftoken.get_tag('slot').value) for token, ftoken in zip(sentence, fsentence)]\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Jak pokazuje przykład poniżej model wyuczony tylko na 100 przykładach popełnia w dosyć prostej\n",
"wypowiedzi błąd etykietując słowo `alarm` tagiem `B-weather/noun`."
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"<tbody>\n",
"<tr><td>co </td><td>O</td></tr>\n",
"<tr><td>gracie </td><td>O</td></tr>\n",
"<tr><td>popołudniu</td><td>O</td></tr>\n",
"</tbody>\n",
"</table>"
],
"text/plain": [
"'<table>\\n<tbody>\\n<tr><td>co </td><td>O</td></tr>\\n<tr><td>gracie </td><td>O</td></tr>\\n<tr><td>popołudniu</td><td>O</td></tr>\\n</tbody>\\n</table>'"
]
},
"execution_count": 90,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tabulate(predict(model, 'batman'.split()), tablefmt='html')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Literatura\n",
"----------\n",
" 1. Sebastian Schuster, Sonal Gupta, Rushin Shah, Mike Lewis, Cross-lingual Transfer Learning for Multilingual Task Oriented Dialog. NAACL-HLT (1) 2019, pp. 3795-3805\n",
" 2. John D. Lafferty, Andrew McCallum, and Fernando C. N. Pereira. 2001. Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data. In Proceedings of the Eighteenth International Conference on Machine Learning (ICML '01). Morgan Kaufmann Publishers Inc., San Francisco, CA, USA, 282289, https://repository.upenn.edu/cgi/viewcontent.cgi?article=1162&context=cis_papers\n",
" 3. Sepp Hochreiter and Jürgen Schmidhuber. 1997. Long Short-Term Memory. Neural Comput. 9, 8 (November 15, 1997), 17351780, https://doi.org/10.1162/neco.1997.9.8.1735\n",
" 4. Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin, Attention is All you Need, NIPS 2017, pp. 5998-6008, https://arxiv.org/abs/1706.03762\n",
" 5. Alan Akbik, Duncan Blythe, Roland Vollgraf, Contextual String Embeddings for Sequence Labeling, Proceedings of the 27th International Conference on Computational Linguistics, pp. 16381649, https://www.aclweb.org/anthology/C18-1139.pdf\n"
]
}
],
"metadata": {
"author": "Marek Kubis",
"email": "mkubis@amu.edu.pl",
"interpreter": {
"hash": "2be5faf79681da6f2a61fdfdd5405d65d042280f7fba6178067603e3a2925119"
},
"jupytext": {
"cell_metadata_filter": "-all",
"main_language": "python",
"notebook_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "Python 3.10.4 64-bit",
"language": "python",
"name": "python3"
},
"lang": "pl",
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
},
"subtitle": "8.Parsing semantyczny z wykorzystaniem technik uczenia maszynowego[laboratoria]",
"title": "Systemy Dialogowe",
"year": "2021"
},
"nbformat": 4,
"nbformat_minor": 4
}