systemy_dialogowe/notebooks/08-parsing-semantyczny-uczenie.ipynb


{
"cells": [
{
"cell_type": "markdown",
"id": "68bc3d74",
"metadata": {},
"source": [
"Parsing semantyczny z wykorzystaniem technik uczenia maszynowego\n",
"================================================================\n",
"\n",
"Wprowadzenie\n",
"------------\n",
"Problem wykrywania slotów i ich wartości w wypowiedziach użytkownika można sformułować jako zadanie\n",
"polegające na przewidywaniu dla poszczególnych słów etykiet wskazujących na to czy i do jakiego\n",
"slotu dane słowo należy.\n",
"\n",
"<pre>chciałbym zarezerwować stolik na jutro<b>/day</b> na godzinę dwunastą<b>/hour</b> czterdzieści<b>/hour</b> pięć<b>/hour</b> na pięć<b>/size</b> osób</pre>\n",
"\n",
"Granice slotów oznacza się korzystając z wybranego schematu etykietowania.\n",
"\n",
"### Schemat IOB\n",
"\n",
"| Prefix | Znaczenie |\n",
"|:------:|:---------------------------|\n",
"| I | wnętrze slotu (inside) |\n",
"| O | poza slotem (outside) |\n",
"| B | początek slotu (beginning) |\n",
"\n",
"<pre>chciałbym zarezerwować stolik na jutro<b>/B-day</b> na godzinę dwunastą<b>/B-hour</b> czterdzieści<b>/I-hour</b> pięć<b>/I-hour</b> na pięć<b>/B-size</b> osób</pre>\n",
"\n",
"### Schemat IOBES\n",
"\n",
"| Prefix | Znaczenie |\n",
"|:------:|:---------------------------|\n",
"| I | wnętrze slotu (inside) |\n",
"| O | poza slotem (outside) |\n",
"| B | początek slotu (beginning) |\n",
"| E | koniec slotu (ending) |\n",
"| S | pojedyncze słowo (single) |\n",
"\n",
"<pre>chciałbym zarezerwować stolik na jutro<b>/S-day</b> na godzinę dwunastą<b>/B-hour</b> czterdzieści<b>/I-hour</b> pięć<b>/E-hour</b> na pięć<b>/S-size</b> osób</pre>\n",
"\n",
"Jeżeli dla tak sformułowanego zadania przygotujemy zbiór danych\n",
"złożony z wypowiedzi użytkownika z oznaczonymi slotami (tzw. *zbiór uczący*),\n",
"to możemy zastosować techniki (nadzorowanego) uczenia maszynowego w celu zbudowania modelu\n",
"annotującego wypowiedzi użytkownika etykietami slotów.\n",
"\n",
"Do zbudowania takiego modelu można wykorzystać między innymi:\n",
"\n",
" 1. warunkowe pola losowe (Lafferty i in.; 2001),\n",
"\n",
" 2. rekurencyjne sieci neuronowe, np. sieci LSTM (Hochreiter i Schmidhuber; 1997),\n",
"\n",
" 3. transformery (Vaswani i in., 2017).\n",
"\n",
"Przykład\n",
"--------\n",
"Skorzystamy ze zbioru danych przygotowanego przez Schustera (2019)."
]
},
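{
"cell_type": "markdown",
"id": "a3f1c2d0",
"metadata": {},
"source": [
"As announced above, the next cell contains a short sketch (added here for illustration; it is not part of the dataset tooling) that converts a sequence of IOB tags into the IOBES scheme, using the tags of the reservation example."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4e2d3c1",
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch: convert IOB slot tags to the IOBES scheme described above.\n",
"def iob2iobes(tags):\n",
"    iobes = []\n",
"    for i, tag in enumerate(tags):\n",
"        nxt = tags[i + 1] if i + 1 < len(tags) else 'O'\n",
"        if tag == 'O':\n",
"            iobes.append('O')\n",
"        elif tag.startswith('B-'):\n",
"            # a single-token slot becomes S-, otherwise B- is kept\n",
"            iobes.append(('B-' if nxt.startswith('I-') else 'S-') + tag[2:])\n",
"        else:\n",
"            # the last token of a slot becomes E-, otherwise I- is kept\n",
"            iobes.append(('I-' if nxt.startswith('I-') else 'E-') + tag[2:])\n",
"    return iobes\n",
"\n",
"# IOB tags of the reservation example above\n",
"iob2iobes(['O', 'O', 'O', 'O', 'B-day', 'O', 'O', 'B-hour', 'I-hour', 'I-hour', 'O', 'B-size', 'O'])"
]
},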
{
"cell_type": "code",
"execution_count": 1,
"id": "8cca8cd1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"g:\\studia\\studia uam\\systemy dialogowe\\notebooks\\l07\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"\n",
" 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0\n",
" 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0\n",
"\n",
" 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0\n",
" 14 8714k 14 1307k 0 0 860k 0 0:00:10 0:00:01 0:00:09 1354k\n",
"100 8714k 100 8714k 0 0 4650k 0 0:00:01 0:00:01 --:--:-- 6606k\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"g:\\studia\\studia uam\\systemy dialogowe\\notebooks\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"'unzip' is not recognized as an internal or external command,\n",
"operable program or batch file.\n"
]
}
],
"source": [
"!mkdir -p l07\n",
"%cd l07\n",
"!curl -L -C - https://fb.me/multilingual_task_oriented_data -o data.zip\n",
"!unzip data.zip\n",
"%cd .."
]
},
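{
"cell_type": "markdown",
"id": "c5f3e4d2",
"metadata": {},
"source": [
"If the `unzip` command is unavailable (as in the Windows session above), the archive can be unpacked with Python's standard library instead. The sketch below assumes that the previous cell saved the archive as `l07/data.zip`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d6a4f5e3",
"metadata": {},
"outputs": [],
"source": [
"# Fallback extraction when `unzip` is not available (e.g. on Windows).\n",
"# Assumes the archive was downloaded to l07/data.zip by the previous cell.\n",
"import zipfile\n",
"\n",
"with zipfile.ZipFile('l07/data.zip') as archive:\n",
"    archive.extractall('l07')"
]
},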
{
"cell_type": "markdown",
"id": "56d91f6c",
"metadata": {},
"source": [
"Zbiór ten gromadzi wypowiedzi w trzech językach opisane slotami dla dwunastu ram należących do trzech dziedzin `Alarm`, `Reminder` oraz `Weather`. Dane wczytamy korzystając z biblioteki [conllu](https://pypi.org/project/conllu/)."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "18b9a032",
"metadata": {},
"outputs": [],
"source": [
"from conllu import parse_incr\n",
"fields = ['id', 'form', 'frame', 'slot']\n",
"\n",
"def nolabel2o(line, i):\n",
" return 'O' if line[i] == 'NoLabel' else line[i]\n",
"\n",
"with open('l07/en/train-en.conllu') as trainfile:\n",
" trainset = list(parse_incr(trainfile, fields=fields, field_parsers={'slot': nolabel2o}))\n",
"with open('l07/en/test-en.conllu') as testfile:\n",
" testset = list(parse_incr(testfile, fields=fields, field_parsers={'slot': nolabel2o}))"
]
},
{
"cell_type": "markdown",
"id": "7477593e",
"metadata": {},
"source": [
"Zobaczmy kilka przykładowych wypowiedzi z tego zbioru."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b2799ad2",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"<tbody>\n",
"<tr><td style=\"text-align: right;\">1</td><td>tell </td><td>weather/find</td><td>O </td></tr>\n",
"<tr><td style=\"text-align: right;\">2</td><td>me </td><td>weather/find</td><td>O </td></tr>\n",
"<tr><td style=\"text-align: right;\">3</td><td>the </td><td>weather/find</td><td>O </td></tr>\n",
"<tr><td style=\"text-align: right;\">4</td><td>weather</td><td>weather/find</td><td>B-weather/noun</td></tr>\n",
"<tr><td style=\"text-align: right;\">5</td><td>report </td><td>weather/find</td><td>I-weather/noun</td></tr>\n",
"<tr><td style=\"text-align: right;\">6</td><td>for </td><td>weather/find</td><td>O </td></tr>\n",
"<tr><td style=\"text-align: right;\">7</td><td>half </td><td>weather/find</td><td>B-location </td></tr>\n",
"<tr><td style=\"text-align: right;\">8</td><td>moon </td><td>weather/find</td><td>I-location </td></tr>\n",
"<tr><td style=\"text-align: right;\">9</td><td>bay </td><td>weather/find</td><td>I-location </td></tr>\n",
"</tbody>\n",
"</table>"
],
"text/plain": [
"'<table>\\n<tbody>\\n<tr><td style=\"text-align: right;\">1</td><td>tell </td><td>weather/find</td><td>O </td></tr>\\n<tr><td style=\"text-align: right;\">2</td><td>me </td><td>weather/find</td><td>O </td></tr>\\n<tr><td style=\"text-align: right;\">3</td><td>the </td><td>weather/find</td><td>O </td></tr>\\n<tr><td style=\"text-align: right;\">4</td><td>weather</td><td>weather/find</td><td>B-weather/noun</td></tr>\\n<tr><td style=\"text-align: right;\">5</td><td>report </td><td>weather/find</td><td>I-weather/noun</td></tr>\\n<tr><td style=\"text-align: right;\">6</td><td>for </td><td>weather/find</td><td>O </td></tr>\\n<tr><td style=\"text-align: right;\">7</td><td>half </td><td>weather/find</td><td>B-location </td></tr>\\n<tr><td style=\"text-align: right;\">8</td><td>moon </td><td>weather/find</td><td>I-location </td></tr>\\n<tr><td style=\"text-align: right;\">9</td><td>bay </td><td>weather/find</td><td>I-location </td></tr>\\n</tbody>\\n</table>'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from tabulate import tabulate\n",
"tabulate(trainset[0], tablefmt='html')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "ba2c2706",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"<tbody>\n",
"<tr><td style=\"text-align: right;\">1</td><td>remind</td><td>reminder/set_reminder</td><td>O </td></tr>\n",
"<tr><td style=\"text-align: right;\">2</td><td>me </td><td>reminder/set_reminder</td><td>O </td></tr>\n",
"<tr><td style=\"text-align: right;\">3</td><td>about </td><td>reminder/set_reminder</td><td>O </td></tr>\n",
"<tr><td style=\"text-align: right;\">4</td><td>game </td><td>reminder/set_reminder</td><td>B-reminder/todo</td></tr>\n",
"<tr><td style=\"text-align: right;\">5</td><td>night </td><td>reminder/set_reminder</td><td>I-reminder/todo</td></tr>\n",
"<tr><td style=\"text-align: right;\">6</td><td>on </td><td>reminder/set_reminder</td><td>B-datetime </td></tr>\n",
"<tr><td style=\"text-align: right;\">7</td><td>friday</td><td>reminder/set_reminder</td><td>I-datetime </td></tr>\n",
"<tr><td style=\"text-align: right;\">8</td><td>at </td><td>reminder/set_reminder</td><td>I-datetime </td></tr>\n",
"<tr><td style=\"text-align: right;\">9</td><td>4pm </td><td>reminder/set_reminder</td><td>I-datetime </td></tr>\n",
"</tbody>\n",
"</table>"
],
"text/plain": [
"'<table>\\n<tbody>\\n<tr><td style=\"text-align: right;\">1</td><td>remind</td><td>reminder/set_reminder</td><td>O </td></tr>\\n<tr><td style=\"text-align: right;\">2</td><td>me </td><td>reminder/set_reminder</td><td>O </td></tr>\\n<tr><td style=\"text-align: right;\">3</td><td>about </td><td>reminder/set_reminder</td><td>O </td></tr>\\n<tr><td style=\"text-align: right;\">4</td><td>game </td><td>reminder/set_reminder</td><td>B-reminder/todo</td></tr>\\n<tr><td style=\"text-align: right;\">5</td><td>night </td><td>reminder/set_reminder</td><td>I-reminder/todo</td></tr>\\n<tr><td style=\"text-align: right;\">6</td><td>on </td><td>reminder/set_reminder</td><td>B-datetime </td></tr>\\n<tr><td style=\"text-align: right;\">7</td><td>friday</td><td>reminder/set_reminder</td><td>I-datetime </td></tr>\\n<tr><td style=\"text-align: right;\">8</td><td>at </td><td>reminder/set_reminder</td><td>I-datetime </td></tr>\\n<tr><td style=\"text-align: right;\">9</td><td>4pm </td><td>reminder/set_reminder</td><td>I-datetime </td></tr>\\n</tbody>\\n</table>'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tabulate(trainset[1000], tablefmt='html')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "b5c9db18",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"<tbody>\n",
"<tr><td style=\"text-align: right;\">1</td><td>set </td><td>alarm/set_alarm</td><td>O </td></tr>\n",
"<tr><td style=\"text-align: right;\">2</td><td>alarm </td><td>alarm/set_alarm</td><td>O </td></tr>\n",
"<tr><td style=\"text-align: right;\">3</td><td>for </td><td>alarm/set_alarm</td><td>B-datetime</td></tr>\n",
"<tr><td style=\"text-align: right;\">4</td><td>20 </td><td>alarm/set_alarm</td><td>I-datetime</td></tr>\n",
"<tr><td style=\"text-align: right;\">5</td><td>minutes</td><td>alarm/set_alarm</td><td>I-datetime</td></tr>\n",
"</tbody>\n",
"</table>"
],
"text/plain": [
"'<table>\\n<tbody>\\n<tr><td style=\"text-align: right;\">1</td><td>set </td><td>alarm/set_alarm</td><td>O </td></tr>\\n<tr><td style=\"text-align: right;\">2</td><td>alarm </td><td>alarm/set_alarm</td><td>O </td></tr>\\n<tr><td style=\"text-align: right;\">3</td><td>for </td><td>alarm/set_alarm</td><td>B-datetime</td></tr>\\n<tr><td style=\"text-align: right;\">4</td><td>20 </td><td>alarm/set_alarm</td><td>I-datetime</td></tr>\\n<tr><td style=\"text-align: right;\">5</td><td>minutes</td><td>alarm/set_alarm</td><td>I-datetime</td></tr>\\n</tbody>\\n</table>'"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tabulate(trainset[2000], tablefmt='html')"
]
},
{
"cell_type": "markdown",
"id": "0f35074d",
"metadata": {
"lines_to_next_cell": 0
},
"source": [
"Na potrzeby prezentacji procesu uczenia w jupyterowym notatniku zawęzimy zbiór danych do początkowych przykładów."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f735ca85",
"metadata": {},
"outputs": [],
"source": [
"trainset = trainset[:300]\n",
"testset = testset[:300]"
]
},
{
"cell_type": "markdown",
"id": "66284486",
"metadata": {},
"source": [
"Budując model skorzystamy z architektury opartej o rekurencyjne sieci neuronowe\n",
"zaimplementowanej w bibliotece [flair](https://github.com/flairNLP/flair) (Akbik i in. 2018)."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f3e30f81",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\macty\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from flair.data import Corpus, Sentence, Token\n",
"from flair.datasets import FlairDatapointDataset\n",
"from flair.embeddings import StackedEmbeddings\n",
"from flair.embeddings import WordEmbeddings\n",
"from flair.embeddings import CharacterEmbeddings\n",
"from flair.embeddings import FlairEmbeddings\n",
"from flair.models import SequenceTagger\n",
"from flair.trainers import ModelTrainer\n",
"\n",
"# determinizacja obliczeń\n",
"import random\n",
"import torch\n",
"random.seed(42)\n",
"torch.manual_seed(42)\n",
"\n",
"if torch.cuda.is_available():\n",
" torch.cuda.manual_seed(0)\n",
" torch.cuda.manual_seed_all(0)\n",
" torch.backends.cudnn.enabled = False\n",
" torch.backends.cudnn.benchmark = False\n",
" torch.backends.cudnn.deterministic = True"
]
},
{
"cell_type": "markdown",
"id": "c1a33987",
"metadata": {},
"source": [
"Dane skonwertujemy do formatu wykorzystywanego przez `flair`, korzystając z następującej funkcji."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "f3c47593",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Corpus: 270 train + 30 dev + 300 test sentences\n",
"2023-04-16 23:32:52,141 Computing label dictionary. Progress:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"270it [00:00, 53998.76it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:32:52,151 Dictionary created for label 'slot' with 5 values: datetime (seen 174 times), weather/attribute (seen 77 times), weather/noun (seen 66 times), location (seen 44 times)\n",
"Dictionary with 5 tags: <unk>, datetime, weather/attribute, weather/noun, location\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"def conllu2flair(sentences, label=None):\n",
" fsentences = []\n",
"\n",
" for sentence in sentences:\n",
" fsentence = Sentence(' '.join(token['form'] for token in sentence), use_tokenizer=False)\n",
" start_idx = None\n",
" end_idx = None\n",
" tag = None\n",
"\n",
" if label:\n",
" for idx, (token, ftoken) in enumerate(zip(sentence, fsentence)):\n",
" if token[label].startswith('B-'):\n",
" start_idx = idx\n",
" end_idx = idx\n",
" tag = token[label][2:]\n",
" elif token[label].startswith('I-'):\n",
" end_idx = idx\n",
" elif token[label] == 'O':\n",
" if start_idx is not None:\n",
" fsentence[start_idx:end_idx+1].add_label(label, tag)\n",
" start_idx = None\n",
" end_idx = None\n",
" tag = None\n",
"\n",
" if start_idx is not None:\n",
" fsentence[start_idx:end_idx+1].add_label(label, tag)\n",
"\n",
" fsentences.append(fsentence)\n",
"\n",
" return FlairDatapointDataset(fsentences)\n",
"\n",
"corpus = Corpus(train=conllu2flair(trainset, 'slot'), test=conllu2flair(testset, 'slot'))\n",
"print(corpus)\n",
"tag_dictionary = corpus.make_label_dictionary(label_type='slot')\n",
"print(tag_dictionary)"
]
},
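{
"cell_type": "markdown",
"id": "e7b5a6f4",
"metadata": {},
"source": [
"As a quick sanity check (a usage sketch added here, not part of the original notebook), we can inspect the spans recovered for the first converted training sentence; they should correspond to the `weather/noun` and `location` slots shown earlier."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8c6b7a5",
"metadata": {},
"outputs": [],
"source": [
"# Inspect the slot spans attached to the first converted training sentence.\n",
"example = conllu2flair(trainset[:1], 'slot')[0]\n",
"for span in example.get_spans('slot'):\n",
"    print(span)"
]
},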
{
"cell_type": "markdown",
"id": "0ed59fb2",
"metadata": {},
"source": [
"Nasz model będzie wykorzystywał wektorowe reprezentacje słów (zob. [Word Embeddings](https://github.com/flairNLP/flair/blob/master/resources/docs/TUTORIAL_EMBEDDINGS_OVERVIEW.md))."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "408cf961",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:32:52,457 https://flair.informatik.hu-berlin.de/resources/embeddings/token/en-fasttext-news-300d-1M.vectors.npy not found in cache, downloading to C:\\Users\\macty\\AppData\\Local\\Temp\\tmpxx1l309v\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1.12G/1.12G [01:03<00:00, 18.8MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:33:56,277 copying C:\\Users\\macty\\AppData\\Local\\Temp\\tmpxx1l309v to cache at C:\\Users\\macty\\.flair\\embeddings\\en-fasttext-news-300d-1M.vectors.npy\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:33:57,359 removing temp file C:\\Users\\macty\\AppData\\Local\\Temp\\tmpxx1l309v\n",
"2023-04-16 23:33:58,068 https://flair.informatik.hu-berlin.de/resources/embeddings/token/en-fasttext-news-300d-1M not found in cache, downloading to C:\\Users\\macty\\AppData\\Local\\Temp\\tmpnqnow7h9\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 52.1M/52.1M [00:02<00:00, 18.9MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:01,079 copying C:\\Users\\macty\\AppData\\Local\\Temp\\tmpnqnow7h9 to cache at C:\\Users\\macty\\.flair\\embeddings\\en-fasttext-news-300d-1M\n",
"2023-04-16 23:34:01,116 removing temp file C:\\Users\\macty\\AppData\\Local\\Temp\\tmpnqnow7h9\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:09,948 https://flair.informatik.hu-berlin.de/resources/embeddings/flair/news-forward-0.4.1.pt not found in cache, downloading to C:\\Users\\macty\\AppData\\Local\\Temp\\tmpf2iq2swn\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 69.7M/69.7M [00:03<00:00, 21.2MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:13,532 copying C:\\Users\\macty\\AppData\\Local\\Temp\\tmpf2iq2swn to cache at C:\\Users\\macty\\.flair\\embeddings\\news-forward-0.4.1.pt\n",
"2023-04-16 23:34:13,579 removing temp file C:\\Users\\macty\\AppData\\Local\\Temp\\tmpf2iq2swn\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:14,221 https://flair.informatik.hu-berlin.de/resources/embeddings/flair/news-backward-0.4.1.pt not found in cache, downloading to C:\\Users\\macty\\AppData\\Local\\Temp\\tmp1i_pq3dr\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 69.7M/69.7M [00:03<00:00, 22.5MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:17,584 copying C:\\Users\\macty\\AppData\\Local\\Temp\\tmp1i_pq3dr to cache at C:\\Users\\macty\\.flair\\embeddings\\news-backward-0.4.1.pt\n",
"2023-04-16 23:34:17,631 removing temp file C:\\Users\\macty\\AppData\\Local\\Temp\\tmp1i_pq3dr\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:18,016 https://flair.informatik.hu-berlin.de/resources/characters/common_characters not found in cache, downloading to C:\\Users\\macty\\AppData\\Local\\Temp\\tmpqyh0003q\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 2.82k/2.82k [00:00<00:00, 1.44MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:18,139 copying C:\\Users\\macty\\AppData\\Local\\Temp\\tmpqyh0003q to cache at C:\\Users\\macty\\.flair\\datasets\\common_characters\n",
"2023-04-16 23:34:18,145 removing temp file C:\\Users\\macty\\AppData\\Local\\Temp\\tmpqyh0003q\n",
"2023-04-16 23:34:18,147 SequenceTagger predicts: Dictionary with 17 tags: O, S-datetime, B-datetime, E-datetime, I-datetime, S-weather/attribute, B-weather/attribute, E-weather/attribute, I-weather/attribute, S-weather/noun, B-weather/noun, E-weather/noun, I-weather/noun, S-location, B-location, E-location, I-location\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"embedding_types = [\n",
" WordEmbeddings('en'),\n",
" FlairEmbeddings('en-forward'),\n",
" FlairEmbeddings('en-backward'),\n",
" CharacterEmbeddings(),\n",
"]\n",
"\n",
"embeddings = StackedEmbeddings(embeddings=embedding_types)\n",
"tagger = SequenceTagger(hidden_size=256, embeddings=embeddings,\n",
" tag_dictionary=tag_dictionary,\n",
" tag_type='slot', use_crf=True)"
]
},
{
"cell_type": "markdown",
"id": "ab634218",
"metadata": {},
"source": [
"Zobaczmy jak wygląda architektura sieci neuronowej, która będzie odpowiedzialna za przewidywanie\n",
"slotów w wypowiedziach."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "04d0bbf3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SequenceTagger(\n",
" (embeddings): StackedEmbeddings(\n",
" (list_embedding_0): WordEmbeddings(\n",
" 'en'\n",
" (embedding): Embedding(1000001, 300)\n",
" )\n",
" (list_embedding_1): FlairEmbeddings(\n",
" (lm): LanguageModel(\n",
" (drop): Dropout(p=0.05, inplace=False)\n",
" (encoder): Embedding(300, 100)\n",
" (rnn): LSTM(100, 2048)\n",
" )\n",
" )\n",
" (list_embedding_2): FlairEmbeddings(\n",
" (lm): LanguageModel(\n",
" (drop): Dropout(p=0.05, inplace=False)\n",
" (encoder): Embedding(300, 100)\n",
" (rnn): LSTM(100, 2048)\n",
" )\n",
" )\n",
" (list_embedding_3): CharacterEmbeddings(\n",
" (char_embedding): Embedding(275, 25)\n",
" (char_rnn): LSTM(25, 25, bidirectional=True)\n",
" )\n",
" )\n",
" (word_dropout): WordDropout(p=0.05)\n",
" (locked_dropout): LockedDropout(p=0.5)\n",
" (embedding2nn): Linear(in_features=4446, out_features=4446, bias=True)\n",
" (rnn): LSTM(4446, 256, batch_first=True, bidirectional=True)\n",
" (linear): Linear(in_features=512, out_features=19, bias=True)\n",
" (loss_function): ViterbiLoss()\n",
" (crf): CRF()\n",
")\n"
]
}
],
"source": [
"print(tagger)"
]
},
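{
"cell_type": "markdown",
"id": "a9d7c8b6",
"metadata": {},
"source": [
"The sizes in the printout fit together as follows: the stacked embedding concatenates the 300-dimensional word embedding, two 2048-dimensional Flair language-model states and the 50-dimensional (bidirectional, $2 \\cdot 25$) character embedding, i.e. $300 + 2048 + 2048 + 50 = 4446$, which is exactly the input size of the `embedding2nn` layer. The final linear layer produces 19 scores, because the slot dictionary expands under the IOBES scheme to 17 tags (`O` plus four prefixes for each of the four slot types), extended with the CRF's `<START>` and `<STOP>` markers."
]
},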
{
"cell_type": "markdown",
"id": "8e0da880",
"metadata": {},
"source": [
"Wykonamy dziesięć iteracji (epok) uczenia a wynikowy model zapiszemy w katalogu `slot-model`."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "0fd2b573",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:18,357 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:18,358 Model: \"SequenceTagger(\n",
" (embeddings): StackedEmbeddings(\n",
" (list_embedding_0): WordEmbeddings(\n",
" 'en'\n",
" (embedding): Embedding(1000001, 300)\n",
" )\n",
" (list_embedding_1): FlairEmbeddings(\n",
" (lm): LanguageModel(\n",
" (drop): Dropout(p=0.05, inplace=False)\n",
" (encoder): Embedding(300, 100)\n",
" (rnn): LSTM(100, 2048)\n",
" )\n",
" )\n",
" (list_embedding_2): FlairEmbeddings(\n",
" (lm): LanguageModel(\n",
" (drop): Dropout(p=0.05, inplace=False)\n",
" (encoder): Embedding(300, 100)\n",
" (rnn): LSTM(100, 2048)\n",
" )\n",
" )\n",
" (list_embedding_3): CharacterEmbeddings(\n",
" (char_embedding): Embedding(275, 25)\n",
" (char_rnn): LSTM(25, 25, bidirectional=True)\n",
" )\n",
" )\n",
" (word_dropout): WordDropout(p=0.05)\n",
" (locked_dropout): LockedDropout(p=0.5)\n",
" (embedding2nn): Linear(in_features=4446, out_features=4446, bias=True)\n",
" (rnn): LSTM(4446, 256, batch_first=True, bidirectional=True)\n",
" (linear): Linear(in_features=512, out_features=19, bias=True)\n",
" (loss_function): ViterbiLoss()\n",
" (crf): CRF()\n",
")\"\n",
"2023-04-16 23:34:18,359 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:18,360 Corpus: \"Corpus: 270 train + 30 dev + 300 test sentences\"\n",
"2023-04-16 23:34:18,360 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:18,361 Parameters:\n",
"2023-04-16 23:34:18,361 - learning_rate: \"0.100000\"\n",
"2023-04-16 23:34:18,361 - mini_batch_size: \"32\"\n",
"2023-04-16 23:34:18,362 - patience: \"3\"\n",
"2023-04-16 23:34:18,363 - anneal_factor: \"0.5\"\n",
"2023-04-16 23:34:18,363 - max_epochs: \"10\"\n",
"2023-04-16 23:34:18,363 - shuffle: \"True\"\n",
"2023-04-16 23:34:18,364 - train_with_dev: \"False\"\n",
"2023-04-16 23:34:18,364 - batch_growth_annealing: \"False\"\n",
"2023-04-16 23:34:18,365 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:18,365 Model training base path: \"slot-model\"\n",
"2023-04-16 23:34:18,366 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:18,367 Device: cpu\n",
"2023-04-16 23:34:18,367 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:18,368 Embeddings storage mode: cpu\n",
"2023-04-16 23:34:18,368 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:20,279 epoch 1 - iter 1/9 - loss 3.46586463 - time (sec): 1.91 - samples/sec: 152.88 - lr: 0.100000\n",
"2023-04-16 23:34:22,063 epoch 1 - iter 2/9 - loss 3.09300251 - time (sec): 3.69 - samples/sec: 153.49 - lr: 0.100000\n",
"2023-04-16 23:34:23,066 epoch 1 - iter 3/9 - loss 3.09930113 - time (sec): 4.70 - samples/sec: 154.78 - lr: 0.100000\n",
"2023-04-16 23:34:24,090 epoch 1 - iter 4/9 - loss 2.95447677 - time (sec): 5.72 - samples/sec: 155.39 - lr: 0.100000\n",
"2023-04-16 23:34:24,981 epoch 1 - iter 5/9 - loss 2.81596711 - time (sec): 6.61 - samples/sec: 157.92 - lr: 0.100000\n",
"2023-04-16 23:34:25,968 epoch 1 - iter 6/9 - loss 2.61288632 - time (sec): 7.60 - samples/sec: 156.47 - lr: 0.100000\n",
"2023-04-16 23:34:26,968 epoch 1 - iter 7/9 - loss 2.37758500 - time (sec): 8.60 - samples/sec: 153.85 - lr: 0.100000\n",
"2023-04-16 23:34:27,959 epoch 1 - iter 8/9 - loss 2.29585029 - time (sec): 9.59 - samples/sec: 153.81 - lr: 0.100000\n",
"2023-04-16 23:34:28,676 epoch 1 - iter 9/9 - loss 2.27926901 - time (sec): 10.31 - samples/sec: 150.87 - lr: 0.100000\n",
"2023-04-16 23:34:28,678 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:28,678 EPOCH 1 done: loss 2.2793 - lr 0.100000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:01<00:00, 1.12s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:29,797 Evaluating as a multi-label problem: False\n",
"2023-04-16 23:34:29,806 DEV : loss 1.2717913389205933 - f1-score (micro avg) 0.0357\n",
"2023-04-16 23:34:29,808 BAD EPOCHS (no improvement): 0\n",
"2023-04-16 23:34:29,809 saving best model\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:41,288 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:42,050 epoch 2 - iter 1/9 - loss 1.43326204 - time (sec): 0.76 - samples/sec: 241.47 - lr: 0.100000\n",
"2023-04-16 23:34:42,846 epoch 2 - iter 2/9 - loss 1.40312603 - time (sec): 1.56 - samples/sec: 234.92 - lr: 0.100000\n",
"2023-04-16 23:34:43,571 epoch 2 - iter 3/9 - loss 1.39219711 - time (sec): 2.28 - samples/sec: 235.65 - lr: 0.100000\n",
"2023-04-16 23:34:44,338 epoch 2 - iter 4/9 - loss 1.38006975 - time (sec): 3.05 - samples/sec: 232.13 - lr: 0.100000\n",
"2023-04-16 23:34:45,066 epoch 2 - iter 5/9 - loss 1.34341906 - time (sec): 3.78 - samples/sec: 232.93 - lr: 0.100000\n",
"2023-04-16 23:34:45,850 epoch 2 - iter 6/9 - loss 1.31603716 - time (sec): 4.56 - samples/sec: 234.55 - lr: 0.100000\n",
"2023-04-16 23:34:46,612 epoch 2 - iter 7/9 - loss 1.28879818 - time (sec): 5.32 - samples/sec: 237.23 - lr: 0.100000\n",
"2023-04-16 23:34:47,400 epoch 2 - iter 8/9 - loss 1.27353642 - time (sec): 6.11 - samples/sec: 238.71 - lr: 0.100000\n",
"2023-04-16 23:34:47,821 epoch 2 - iter 9/9 - loss 1.27439781 - time (sec): 6.53 - samples/sec: 238.02 - lr: 0.100000\n",
"2023-04-16 23:34:47,823 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:47,823 EPOCH 2 done: loss 1.2744 - lr 0.100000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:00<00:00, 4.76it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:48,036 Evaluating as a multi-label problem: False\n",
"2023-04-16 23:34:48,044 DEV : loss 0.7916695475578308 - f1-score (micro avg) 0.0\n",
"2023-04-16 23:34:48,046 BAD EPOCHS (no improvement): 1\n",
"2023-04-16 23:34:48,047 ----------------------------------------------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:48,750 epoch 3 - iter 1/9 - loss 0.95985486 - time (sec): 0.70 - samples/sec: 266.38 - lr: 0.100000\n",
"2023-04-16 23:34:49,421 epoch 3 - iter 2/9 - loss 0.93650848 - time (sec): 1.37 - samples/sec: 258.56 - lr: 0.100000\n",
"2023-04-16 23:34:50,208 epoch 3 - iter 3/9 - loss 0.91011594 - time (sec): 2.16 - samples/sec: 249.54 - lr: 0.100000\n",
"2023-04-16 23:34:50,909 epoch 3 - iter 4/9 - loss 0.89886124 - time (sec): 2.86 - samples/sec: 249.21 - lr: 0.100000\n",
"2023-04-16 23:34:51,580 epoch 3 - iter 5/9 - loss 0.87307849 - time (sec): 3.53 - samples/sec: 251.48 - lr: 0.100000\n",
"2023-04-16 23:34:52,375 epoch 3 - iter 6/9 - loss 0.85130603 - time (sec): 4.33 - samples/sec: 251.67 - lr: 0.100000\n",
"2023-04-16 23:34:53,161 epoch 3 - iter 7/9 - loss 0.83100431 - time (sec): 5.11 - samples/sec: 249.75 - lr: 0.100000\n",
"2023-04-16 23:34:53,966 epoch 3 - iter 8/9 - loss 0.79737653 - time (sec): 5.92 - samples/sec: 248.39 - lr: 0.100000\n",
"2023-04-16 23:34:54,372 epoch 3 - iter 9/9 - loss 0.78005277 - time (sec): 6.32 - samples/sec: 245.89 - lr: 0.100000\n",
"2023-04-16 23:34:54,373 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:54,374 EPOCH 3 done: loss 0.7801 - lr 0.100000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:00<00:00, 4.50it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:54,600 Evaluating as a multi-label problem: False\n",
"2023-04-16 23:34:54,608 DEV : loss 0.33898717164993286 - f1-score (micro avg) 0.7761\n",
"2023-04-16 23:34:54,610 BAD EPOCHS (no improvement): 0\n",
"2023-04-16 23:34:54,611 saving best model\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:35:08,824 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:35:09,623 epoch 4 - iter 1/9 - loss 0.41515362 - time (sec): 0.80 - samples/sec: 244.67 - lr: 0.100000\n",
"2023-04-16 23:35:10,428 epoch 4 - iter 2/9 - loss 0.51239108 - time (sec): 1.60 - samples/sec: 243.44 - lr: 0.100000\n",
"2023-04-16 23:35:11,205 epoch 4 - iter 3/9 - loss 0.45658054 - time (sec): 2.38 - samples/sec: 232.97 - lr: 0.100000\n",
"2023-04-16 23:35:11,983 epoch 4 - iter 4/9 - loss 0.45578642 - time (sec): 3.16 - samples/sec: 231.55 - lr: 0.100000\n",
"2023-04-16 23:35:12,730 epoch 4 - iter 5/9 - loss 0.44140354 - time (sec): 3.90 - samples/sec: 236.42 - lr: 0.100000\n",
"2023-04-16 23:35:13,509 epoch 4 - iter 6/9 - loss 0.42847639 - time (sec): 4.68 - samples/sec: 235.10 - lr: 0.100000\n",
"2023-04-16 23:35:14,254 epoch 4 - iter 7/9 - loss 0.42680964 - time (sec): 5.43 - samples/sec: 238.02 - lr: 0.100000\n",
"2023-04-16 23:35:15,046 epoch 4 - iter 8/9 - loss 0.41317872 - time (sec): 6.22 - samples/sec: 238.91 - lr: 0.100000\n",
"2023-04-16 23:35:15,362 epoch 4 - iter 9/9 - loss 0.40072542 - time (sec): 6.54 - samples/sec: 237.91 - lr: 0.100000\n",
"2023-04-16 23:35:15,364 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:35:15,364 EPOCH 4 done: loss 0.4007 - lr 0.100000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:00<00:00, 4.63it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:35:15,583 Evaluating as a multi-label problem: False\n",
"2023-04-16 23:35:15,591 DEV : loss 0.16208426654338837 - f1-score (micro avg) 0.9231\n",
"2023-04-16 23:35:15,593 BAD EPOCHS (no improvement): 0\n",
"2023-04-16 23:35:15,595 saving best model\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:35:26,934 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:35:27,683 epoch 5 - iter 1/9 - loss 0.23677549 - time (sec): 0.75 - samples/sec: 243.64 - lr: 0.100000\n",
"2023-04-16 23:35:28,430 epoch 5 - iter 2/9 - loss 0.21526806 - time (sec): 1.49 - samples/sec: 236.28 - lr: 0.100000\n",
"2023-04-16 23:35:29,218 epoch 5 - iter 3/9 - loss 0.18914680 - time (sec): 2.28 - samples/sec: 232.25 - lr: 0.100000\n",
"2023-04-16 23:35:30,001 epoch 5 - iter 4/9 - loss 0.21946464 - time (sec): 3.06 - samples/sec: 235.24 - lr: 0.100000\n",
"2023-04-16 23:35:30,749 epoch 5 - iter 5/9 - loss 0.20726706 - time (sec): 3.81 - samples/sec: 238.39 - lr: 0.100000\n",
"2023-04-16 23:35:31,440 epoch 5 - iter 6/9 - loss 0.20088127 - time (sec): 4.50 - samples/sec: 242.67 - lr: 0.100000\n",
"2023-04-16 23:35:32,199 epoch 5 - iter 7/9 - loss 0.21393993 - time (sec): 5.26 - samples/sec: 242.64 - lr: 0.100000\n",
"2023-04-16 23:35:32,995 epoch 5 - iter 8/9 - loss 0.20896170 - time (sec): 6.06 - samples/sec: 241.62 - lr: 0.100000\n",
"2023-04-16 23:35:33,438 epoch 5 - iter 9/9 - loss 0.22090499 - time (sec): 6.50 - samples/sec: 239.16 - lr: 0.100000\n",
"2023-04-16 23:35:33,440 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:35:33,440 EPOCH 5 done: loss 0.2209 - lr 0.100000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:00<00:00, 4.67it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:35:33,657 Evaluating as a multi-label problem: False\n",
"2023-04-16 23:35:33,664 DEV : loss 0.07899065315723419 - f1-score (micro avg) 0.9706\n",
"2023-04-16 23:35:33,667 BAD EPOCHS (no improvement): 0\n",
"2023-04-16 23:35:33,668 saving best model\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:35:47,995 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:35:48,716 epoch 6 - iter 1/9 - loss 0.08489894 - time (sec): 0.72 - samples/sec: 265.65 - lr: 0.100000\n",
"2023-04-16 23:35:49,474 epoch 6 - iter 2/9 - loss 0.16483877 - time (sec): 1.48 - samples/sec: 262.02 - lr: 0.100000\n",
"2023-04-16 23:35:50,229 epoch 6 - iter 3/9 - loss 0.17961087 - time (sec): 2.23 - samples/sec: 261.20 - lr: 0.100000\n",
"2023-04-16 23:35:50,883 epoch 6 - iter 4/9 - loss 0.17008273 - time (sec): 2.89 - samples/sec: 266.11 - lr: 0.100000\n",
"2023-04-16 23:35:51,580 epoch 6 - iter 5/9 - loss 0.16165644 - time (sec): 3.58 - samples/sec: 259.35 - lr: 0.100000\n",
"2023-04-16 23:35:52,303 epoch 6 - iter 6/9 - loss 0.16796368 - time (sec): 4.31 - samples/sec: 257.37 - lr: 0.100000\n",
"2023-04-16 23:35:52,971 epoch 6 - iter 7/9 - loss 0.15208281 - time (sec): 4.97 - samples/sec: 254.12 - lr: 0.100000\n",
"2023-04-16 23:35:53,698 epoch 6 - iter 8/9 - loss 0.14079077 - time (sec): 5.70 - samples/sec: 255.74 - lr: 0.100000\n",
"2023-04-16 23:35:54,094 epoch 6 - iter 9/9 - loss 0.14176936 - time (sec): 6.10 - samples/sec: 255.04 - lr: 0.100000\n",
"2023-04-16 23:35:54,095 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:35:54,096 EPOCH 6 done: loss 0.1418 - lr 0.100000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:00<00:00, 5.41it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:35:54,284 Evaluating as a multi-label problem: False\n",
"2023-04-16 23:35:54,292 DEV : loss 0.05525948479771614 - f1-score (micro avg) 0.9706\n",
"2023-04-16 23:35:54,294 BAD EPOCHS (no improvement): 0\n",
"2023-04-16 23:35:54,295 ----------------------------------------------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:35:55,007 epoch 7 - iter 1/9 - loss 0.06206851 - time (sec): 0.71 - samples/sec: 262.64 - lr: 0.100000\n",
"2023-04-16 23:35:55,659 epoch 7 - iter 2/9 - loss 0.07954220 - time (sec): 1.36 - samples/sec: 257.33 - lr: 0.100000\n",
"2023-04-16 23:35:56,426 epoch 7 - iter 3/9 - loss 0.07803471 - time (sec): 2.13 - samples/sec: 248.71 - lr: 0.100000\n",
"2023-04-16 23:35:57,138 epoch 7 - iter 4/9 - loss 0.09479619 - time (sec): 2.84 - samples/sec: 247.27 - lr: 0.100000\n",
"2023-04-16 23:35:57,921 epoch 7 - iter 5/9 - loss 0.12208708 - time (sec): 3.63 - samples/sec: 251.03 - lr: 0.100000\n",
"2023-04-16 23:35:58,664 epoch 7 - iter 6/9 - loss 0.11140702 - time (sec): 4.37 - samples/sec: 251.09 - lr: 0.100000\n",
"2023-04-16 23:35:59,452 epoch 7 - iter 7/9 - loss 0.10265536 - time (sec): 5.16 - samples/sec: 250.34 - lr: 0.100000\n",
"2023-04-16 23:36:00,146 epoch 7 - iter 8/9 - loss 0.09764915 - time (sec): 5.85 - samples/sec: 252.61 - lr: 0.100000\n",
"2023-04-16 23:36:00,536 epoch 7 - iter 9/9 - loss 0.09412069 - time (sec): 6.24 - samples/sec: 249.16 - lr: 0.100000\n",
"2023-04-16 23:36:00,537 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:36:00,538 EPOCH 7 done: loss 0.0941 - lr 0.100000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:00<00:00, 4.76it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:36:00,751 Evaluating as a multi-label problem: False\n",
"2023-04-16 23:36:00,759 DEV : loss 0.04532366245985031 - f1-score (micro avg) 0.9706\n",
"2023-04-16 23:36:00,761 BAD EPOCHS (no improvement): 0\n",
"2023-04-16 23:36:00,762 ----------------------------------------------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:36:01,551 epoch 8 - iter 1/9 - loss 0.03479994 - time (sec): 0.79 - samples/sec: 234.47 - lr: 0.100000\n",
"2023-04-16 23:36:02,299 epoch 8 - iter 2/9 - loss 0.03457490 - time (sec): 1.54 - samples/sec: 235.52 - lr: 0.100000\n",
"2023-04-16 23:36:03,135 epoch 8 - iter 3/9 - loss 0.02756396 - time (sec): 2.37 - samples/sec: 232.19 - lr: 0.100000\n",
"2023-04-16 23:36:03,831 epoch 8 - iter 4/9 - loss 0.03832822 - time (sec): 3.07 - samples/sec: 235.58 - lr: 0.100000\n",
"2023-04-16 23:36:04,675 epoch 8 - iter 5/9 - loss 0.05709582 - time (sec): 3.91 - samples/sec: 236.64 - lr: 0.100000\n",
"2023-04-16 23:36:05,470 epoch 8 - iter 6/9 - loss 0.07284620 - time (sec): 4.71 - samples/sec: 237.68 - lr: 0.100000\n",
"2023-04-16 23:36:06,242 epoch 8 - iter 7/9 - loss 0.09056851 - time (sec): 5.48 - samples/sec: 238.50 - lr: 0.100000\n",
"2023-04-16 23:36:06,934 epoch 8 - iter 8/9 - loss 0.08531148 - time (sec): 6.17 - samples/sec: 238.50 - lr: 0.100000\n",
"2023-04-16 23:36:07,344 epoch 8 - iter 9/9 - loss 0.08690437 - time (sec): 6.58 - samples/sec: 236.25 - lr: 0.100000\n",
"2023-04-16 23:36:07,345 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:36:07,346 EPOCH 8 done: loss 0.0869 - lr 0.100000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:00<00:00, 5.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:36:07,548 Evaluating as a multi-label problem: False\n",
"2023-04-16 23:36:07,556 DEV : loss 0.024222934618592262 - f1-score (micro avg) 0.9706\n",
"2023-04-16 23:36:07,558 BAD EPOCHS (no improvement): 0\n",
"2023-04-16 23:36:07,559 ----------------------------------------------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:36:08,281 epoch 9 - iter 1/9 - loss 0.04802091 - time (sec): 0.72 - samples/sec: 239.94 - lr: 0.100000\n",
"2023-04-16 23:36:09,019 epoch 9 - iter 2/9 - loss 0.04962141 - time (sec): 1.46 - samples/sec: 241.95 - lr: 0.100000\n",
"2023-04-16 23:36:09,715 epoch 9 - iter 3/9 - loss 0.04733611 - time (sec): 2.16 - samples/sec: 258.47 - lr: 0.100000\n",
"2023-04-16 23:36:10,377 epoch 9 - iter 4/9 - loss 0.05639010 - time (sec): 2.82 - samples/sec: 258.08 - lr: 0.100000\n",
"2023-04-16 23:36:11,111 epoch 9 - iter 5/9 - loss 0.11318641 - time (sec): 3.55 - samples/sec: 263.87 - lr: 0.100000\n",
"2023-04-16 23:36:11,812 epoch 9 - iter 6/9 - loss 0.10583430 - time (sec): 4.25 - samples/sec: 260.58 - lr: 0.100000\n",
"2023-04-16 23:36:12,486 epoch 9 - iter 7/9 - loss 0.09734085 - time (sec): 4.93 - samples/sec: 262.89 - lr: 0.100000\n",
"2023-04-16 23:36:13,213 epoch 9 - iter 8/9 - loss 0.09001355 - time (sec): 5.65 - samples/sec: 259.68 - lr: 0.100000\n",
"2023-04-16 23:36:13,582 epoch 9 - iter 9/9 - loss 0.08649074 - time (sec): 6.02 - samples/sec: 258.22 - lr: 0.100000\n",
"2023-04-16 23:36:13,583 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:36:13,584 EPOCH 9 done: loss 0.0865 - lr 0.100000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:00<00:00, 5.41it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:36:13,772 Evaluating as a multi-label problem: False\n",
"2023-04-16 23:36:13,779 DEV : loss 0.0387137271463871 - f1-score (micro avg) 0.9706\n",
"2023-04-16 23:36:13,781 BAD EPOCHS (no improvement): 1\n",
"2023-04-16 23:36:13,782 ----------------------------------------------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:36:14,486 epoch 10 - iter 1/9 - loss 0.01475051 - time (sec): 0.70 - samples/sec: 251.42 - lr: 0.100000\n",
"2023-04-16 23:36:15,220 epoch 10 - iter 2/9 - loss 0.02797203 - time (sec): 1.44 - samples/sec: 246.17 - lr: 0.100000\n",
"2023-04-16 23:36:15,881 epoch 10 - iter 3/9 - loss 0.03366118 - time (sec): 2.10 - samples/sec: 250.12 - lr: 0.100000\n",
"2023-04-16 23:36:16,583 epoch 10 - iter 4/9 - loss 0.02913842 - time (sec): 2.80 - samples/sec: 256.33 - lr: 0.100000\n",
"2023-04-16 23:36:17,329 epoch 10 - iter 5/9 - loss 0.02953250 - time (sec): 3.55 - samples/sec: 258.53 - lr: 0.100000\n",
"2023-04-16 23:36:18,035 epoch 10 - iter 6/9 - loss 0.04606061 - time (sec): 4.25 - samples/sec: 259.82 - lr: 0.100000\n",
"2023-04-16 23:36:18,729 epoch 10 - iter 7/9 - loss 0.04505843 - time (sec): 4.95 - samples/sec: 260.56 - lr: 0.100000\n",
"2023-04-16 23:36:19,462 epoch 10 - iter 8/9 - loss 0.05644751 - time (sec): 5.68 - samples/sec: 258.63 - lr: 0.100000\n",
"2023-04-16 23:36:19,893 epoch 10 - iter 9/9 - loss 0.06249184 - time (sec): 6.11 - samples/sec: 254.46 - lr: 0.100000\n",
"2023-04-16 23:36:19,895 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:36:19,895 EPOCH 10 done: loss 0.0625 - lr 0.100000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:00<00:00, 5.43it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:36:20,082 Evaluating as a multi-label problem: False\n",
"2023-04-16 23:36:20,089 DEV : loss 0.03223632648587227 - f1-score (micro avg) 0.9275\n",
"2023-04-16 23:36:20,091 BAD EPOCHS (no improvement): 2\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:36:31,813 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:36:35,403 SequenceTagger predicts: Dictionary with 19 tags: O, S-datetime, B-datetime, E-datetime, I-datetime, S-weather/attribute, B-weather/attribute, E-weather/attribute, I-weather/attribute, S-weather/noun, B-weather/noun, E-weather/noun, I-weather/noun, S-location, B-location, E-location, I-location, <START>, <STOP>\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 10/10 [00:13<00:00, 1.40s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:36:50,601 Evaluating as a multi-label problem: False\n",
"2023-04-16 23:36:50,612 0.2111\t0.1989\t0.2048\t0.1176\n",
"2023-04-16 23:36:50,613 \n",
"Results:\n",
"- F-score (micro) 0.2048\n",
"- F-score (macro) 0.1106\n",
"- Accuracy 0.1176\n",
"\n",
"By class:\n",
" precision recall f1-score support\n",
"\n",
" datetime 0.2136 0.2327 0.2227 202\n",
" weather/noun 0.2184 0.5588 0.3140 34\n",
" reminder/todo 0.0000 0.0000 0.0000 46\n",
" reminder/noun 0.0000 0.0000 0.0000 42\n",
" weather/attribute 0.1765 0.1667 0.1714 18\n",
" location 0.1765 0.1765 0.1765 17\n",
"reminder/recurring_period 0.0000 0.0000 0.0000 2\n",
" negation 0.0000 0.0000 0.0000 1\n",
"\n",
" micro avg 0.2111 0.1989 0.2048 362\n",
" macro avg 0.0981 0.1418 0.1106 362\n",
" weighted avg 0.1568 0.1989 0.1706 362\n",
"\n",
"2023-04-16 23:36:50,613 ----------------------------------------------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/plain": [
"{'test_score': 0.20483641536273117,\n",
" 'dev_score_history': [0.03571428571428571,\n",
" 0.0,\n",
" 0.7761194029850745,\n",
" 0.923076923076923,\n",
" 0.9705882352941176,\n",
" 0.9705882352941176,\n",
" 0.9705882352941176,\n",
" 0.9705882352941176,\n",
" 0.9705882352941176,\n",
" 0.9275362318840579],\n",
" 'train_loss_history': [2.279269006857918,\n",
" 1.2743978126256028,\n",
" 0.7800527689157958,\n",
" 0.40072541558857516,\n",
" 0.22090499240102493,\n",
" 0.14176936011605706,\n",
" 0.09412068554059486,\n",
" 0.08690436752662781,\n",
" 0.08649074149668408,\n",
" 0.06249183581189711],\n",
" 'dev_loss_history': [1.2717913389205933,\n",
" 0.7916695475578308,\n",
" 0.33898717164993286,\n",
" 0.16208426654338837,\n",
" 0.07899065315723419,\n",
" 0.05525948479771614,\n",
" 0.04532366245985031,\n",
" 0.024222934618592262,\n",
" 0.0387137271463871,\n",
" 0.03223632648587227]}"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer = ModelTrainer(tagger, corpus)\n",
"trainer.train('slot-model',\n",
" learning_rate=0.1,\n",
" mini_batch_size=32,\n",
" max_epochs=10,\n",
" train_with_dev=False)"
]
},
{
"cell_type": "markdown",
"id": "bcd0c303",
"metadata": {},
"source": [
"Jakość wyuczonego modelu możemy ocenić, korzystając z zaraportowanych powyżej metryk, tj.:\n",
"\n",
" - *tp (true positives)*\n",
"\n",
" > liczba słów oznaczonych w zbiorze testowym etykietą $e$, które model oznaczył tą etykietą\n",
"\n",
" - *fp (false positives)*\n",
"\n",
" > liczba słów nieoznaczonych w zbiorze testowym etykietą $e$, które model oznaczył tą etykietą\n",
"\n",
" - *fn (false negatives)*\n",
"\n",
" > liczba słów oznaczonych w zbiorze testowym etykietą $e$, którym model nie nadał etykiety $e$\n",
"\n",
" - *precision*\n",
"\n",
" > $$\\frac{tp}{tp + fp}$$\n",
"\n",
" - *recall*\n",
"\n",
" > $$\\frac{tp}{tp + fn}$$\n",
"\n",
" - $F_1$\n",
"\n",
" > $$\\frac{2 \\cdot precision \\cdot recall}{precision + recall}$$\n",
"\n",
" - *micro* $F_1$\n",
"\n",
" > $F_1$ w którym $tp$, $fp$ i $fn$ są liczone łącznie dla wszystkich etykiet, tj. $tp = \\sum_{e}{{tp}_e}$, $fn = \\sum_{e}{{fn}_e}$, $fp = \\sum_{e}{{fp}_e}$\n",
"\n",
" - *macro* $F_1$\n",
"\n",
" > średnia arytmetyczna z $F_1$ obliczonych dla poszczególnych etykiet z osobna.\n",
"\n",
"Wyuczony model możemy wczytać z pliku korzystając z metody `load`."
]
},
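{
"cell_type": "markdown",
"id": "b0e8d9c7",
"metadata": {},
"source": [
"The cell below is a toy illustration with made-up counts (not results of the model above): it shows how precision, recall and the micro/macro $F_1$ scores are derived from per-label $tp$, $fp$ and $fn$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1f9e0d8",
"metadata": {},
"outputs": [],
"source": [
"# Toy per-label counts, invented for illustration only.\n",
"counts = {\n",
"    'datetime': {'tp': 40, 'fp': 10, 'fn': 20},\n",
"    'location': {'tp': 5, 'fp': 5, 'fn': 5},\n",
"}\n",
"\n",
"def f1(tp, fp, fn):\n",
"    precision = tp / (tp + fp)\n",
"    recall = tp / (tp + fn)\n",
"    return 2 * precision * recall / (precision + recall)\n",
"\n",
"# micro F1: sum the counts over all labels first, then compute F1 once\n",
"micro = f1(sum(c['tp'] for c in counts.values()),\n",
"           sum(c['fp'] for c in counts.values()),\n",
"           sum(c['fn'] for c in counts.values()))\n",
"\n",
"# macro F1: compute F1 per label, then take the arithmetic mean\n",
"macro = sum(f1(**c) for c in counts.values()) / len(counts)\n",
"\n",
"print(f'micro F1 = {micro:.4f}, macro F1 = {macro:.4f}')"
]
},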
{
"cell_type": "code",
"execution_count": 18,
"id": "d12596c1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:37:00,272 SequenceTagger predicts: Dictionary with 19 tags: O, S-datetime, B-datetime, E-datetime, I-datetime, S-weather/attribute, B-weather/attribute, E-weather/attribute, I-weather/attribute, S-weather/noun, B-weather/noun, E-weather/noun, I-weather/noun, S-location, B-location, E-location, I-location, <START>, <STOP>\n"
]
}
],
"source": [
"model = SequenceTagger.load('slot-model/final-model.pt')"
]
},
{
"cell_type": "markdown",
"id": "a97dd603",
"metadata": {},
"source": [
"Wczytany model możemy wykorzystać do przewidywania slotów w wypowiedziach użytkownika, korzystając\n",
"z przedstawionej poniżej funkcji `predict`."
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "87c310cf",
"metadata": {},
"outputs": [],
"source": [
"def predict(model, sentence):\n",
" csentence = [{'form': word, 'slot': 'O'} for word in sentence]\n",
" fsentence = conllu2flair([csentence])[0]\n",
" model.predict(fsentence)\n",
"\n",
" for span in fsentence.get_spans('slot'):\n",
" tag = span.get_label('slot').value\n",
" csentence[span.tokens[0].idx - 1]['slot'] = f'B-{tag}'\n",
"\n",
" for token in span.tokens[1:]:\n",
" csentence[token.idx - 1]['slot'] = f'I-{tag}'\n",
"\n",
" return csentence\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "97043331",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"<tbody>\n",
"<tr><td>set </td><td>O </td></tr>\n",
"<tr><td>alarm </td><td>O </td></tr>\n",
"<tr><td>for </td><td>B-datetime</td></tr>\n",
"<tr><td>20 </td><td>B-datetime</td></tr>\n",
"<tr><td>minutes</td><td>I-datetime</td></tr>\n",
"</tbody>\n",
"</table>"
],
"text/plain": [
"'<table>\\n<tbody>\\n<tr><td>set </td><td>O </td></tr>\\n<tr><td>alarm </td><td>O </td></tr>\\n<tr><td>for </td><td>B-datetime</td></tr>\\n<tr><td>20 </td><td>B-datetime</td></tr>\\n<tr><td>minutes</td><td>I-datetime</td></tr>\\n</tbody>\\n</table>'"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tabulate(predict(model, 'set alarm for 20 minutes'.split()), tablefmt='html')"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "29856a8a",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"<tbody>\n",
"<tr><td>change</td><td>O </td></tr>\n",
"<tr><td>my </td><td>O </td></tr>\n",
"<tr><td>3 </td><td>O </td></tr>\n",
"<tr><td>pm </td><td>O </td></tr>\n",
"<tr><td>alarm </td><td>O </td></tr>\n",
"<tr><td>to </td><td>O </td></tr>\n",
"<tr><td>the </td><td>O </td></tr>\n",
"<tr><td>next </td><td>O </td></tr>\n",
"<tr><td>day </td><td>B-weather/noun</td></tr>\n",
"</tbody>\n",
"</table>"
],
"text/plain": [
"'<table>\\n<tbody>\\n<tr><td>change</td><td>O </td></tr>\\n<tr><td>my </td><td>O </td></tr>\\n<tr><td>3 </td><td>O </td></tr>\\n<tr><td>pm </td><td>O </td></tr>\\n<tr><td>alarm </td><td>O </td></tr>\\n<tr><td>to </td><td>O </td></tr>\\n<tr><td>the </td><td>O </td></tr>\\n<tr><td>next </td><td>O </td></tr>\\n<tr><td>day </td><td>B-weather/noun</td></tr>\\n</tbody>\\n</table>'"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tabulate(predict(model, 'change my 3 pm alarm to the next day'.split()), tablefmt='html')"
]
},
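{
"cell_type": "markdown",
"id": "d2a0f1e9",
"metadata": {},
"source": [
"In a dialogue system the tagged tokens are usually turned back into slot-value pairs. The helper below is a toy sketch of that step (added here as an illustration; it is not part of flair or the original notebook, and it simply merges repeated slots of the same name)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3b1a2f0",
"metadata": {},
"outputs": [],
"source": [
"def slots_to_dict(csentence):\n",
"    # Group B-/I- tagged tokens into slot -> value pairs (toy helper).\n",
"    slots = {}\n",
"    current = None\n",
"    for token in csentence:\n",
"        tag = token['slot']\n",
"        if tag.startswith('B-'):\n",
"            current = tag[2:]\n",
"            slots.setdefault(current, []).append(token['form'])\n",
"        elif tag.startswith('I-') and current == tag[2:]:\n",
"            slots[current].append(token['form'])\n",
"        else:\n",
"            current = None\n",
"    return {slot: ' '.join(words) for slot, words in slots.items()}\n",
"\n",
"slots_to_dict(predict(model, 'set alarm for 20 minutes'.split()))"
]
},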
{
"cell_type": "markdown",
"id": "21b00302",
"metadata": {},
"source": [
"Literatura\n",
"----------\n",
" 1. Sebastian Schuster, Sonal Gupta, Rushin Shah, Mike Lewis, Cross-lingual Transfer Learning for Multilingual Task Oriented Dialog. NAACL-HLT (1) 2019, pp. 3795-3805\n",
" 2. John D. Lafferty, Andrew McCallum, and Fernando C. N. Pereira. 2001. Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data. In Proceedings of the Eighteenth International Conference on Machine Learning (ICML '01). Morgan Kaufmann Publishers Inc., San Francisco, CA, USA, 282289, https://repository.upenn.edu/cgi/viewcontent.cgi?article=1162&context=cis_papers\n",
" 3. Sepp Hochreiter and Jürgen Schmidhuber. 1997. Long Short-Term Memory. Neural Comput. 9, 8 (November 15, 1997), 17351780, https://doi.org/10.1162/neco.1997.9.8.1735\n",
" 4. Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin, Attention is All you Need, NIPS 2017, pp. 5998-6008, https://arxiv.org/abs/1706.03762\n",
" 5. Alan Akbik, Duncan Blythe, Roland Vollgraf, Contextual String Embeddings for Sequence Labeling, Proceedings of the 27th International Conference on Computational Linguistics, pp. 16381649, https://www.aclweb.org/anthology/C18-1139.pdf\n"
]
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"main_language": "python",
"notebook_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}