{
"cells": [
{
"cell_type": "markdown",
"id": "68bc3d74",
"metadata": {},
"source": [
"Semantic parsing with machine learning techniques\n",
"=================================================\n",
"\n",
"Introduction\n",
"------------\n",
"The problem of detecting slots and their values in user utterances can be formulated as the task\n",
"of predicting, for each word, a label that indicates whether the word belongs to a slot and,\n",
"if so, to which one.\n",
"\n",
"<pre>I would like to book a table for tomorrow<b>/day</b> at twelve<b>/hour</b> forty<b>/hour</b> five<b>/hour</b> for five<b>/size</b> people</pre>\n",
"\n",
"Slot boundaries are marked using a chosen labeling scheme.\n",
"\n",
"### IOB scheme\n",
"\n",
"| Prefix | Meaning |\n",
"|:------:|:---------------------------|\n",
"| I | inside a slot |\n",
"| O | outside any slot |\n",
"| B | beginning of a slot |\n",
"\n",
"<pre>I would like to book a table for tomorrow<b>/B-day</b> at twelve<b>/B-hour</b> forty<b>/I-hour</b> five<b>/I-hour</b> for five<b>/B-size</b> people</pre>\n",
"\n",
"### IOBES scheme\n",
"\n",
"| Prefix | Meaning |\n",
"|:------:|:---------------------------|\n",
"| I | inside a slot |\n",
"| O | outside any slot |\n",
"| B | beginning of a slot |\n",
"| E | end of a slot |\n",
"| S | single-word slot |\n",
"\n",
"<pre>I would like to book a table for tomorrow<b>/S-day</b> at twelve<b>/B-hour</b> forty<b>/I-hour</b> five<b>/E-hour</b> for five<b>/S-size</b> people</pre>\n",
"\n",
"If, for the task formulated this way, we prepare a dataset\n",
"consisting of user utterances with annotated slots (a so-called *training set*),\n",
"we can apply (supervised) machine learning techniques to build a model\n",
"that annotates user utterances with slot labels.\n",
"\n",
"Such a model can be built using, among others:\n",
"\n",
" 1. conditional random fields (Lafferty et al., 2001),\n",
"\n",
" 2. recurrent neural networks, e.g. LSTM networks (Hochreiter and Schmidhuber, 1997),\n",
"\n",
" 3. transformers (Vaswani et al., 2017).\n",
"\n",
"Example\n",
"-------\n",
"We will use the dataset prepared by Schuster et al. (2019)."
]
},
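{
"cell_type": "markdown",
"id": "a3f9c1d2",
"metadata": {},
"source": [
"The relationship between the two schemes can be sketched with a short conversion function. `iob_to_iobes` below is a hypothetical helper written for illustration; it is not part of any library used in this notebook. It assumes well-formed IOB input, i.e. every slot starts with a `B-` tag. A `B-` tag that is not followed by an `I-` tag of the same slot becomes `S-`, and the last `I-` tag of a slot becomes `E-`. For the end of the introductory example this yields `S-day`, `B-hour I-hour E-hour` and `S-size`, which matches the IOBES annotation."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7e2d4f6",
"metadata": {},
"outputs": [],
"source": [
"def iob_to_iobes(tags):\n",
"    # Convert IOB tags (e.g. 'B-hour', 'I-hour', 'O') to IOBES tags.\n",
"    # Assumes well-formed IOB input: every slot starts with a 'B-' tag.\n",
"    iobes = []\n",
"    for i, tag in enumerate(tags):\n",
"        if tag == 'O':\n",
"            iobes.append(tag)\n",
"            continue\n",
"        prefix, slot = tag.split('-', 1)\n",
"        # the slot goes on if the next tag is an I- tag of the same slot\n",
"        next_tag = tags[i + 1] if i + 1 < len(tags) else 'O'\n",
"        continues = next_tag == 'I-' + slot\n",
"        if prefix == 'B':\n",
"            iobes.append(('B-' if continues else 'S-') + slot)\n",
"        else:  # prefix == 'I'\n",
"            iobes.append(('I-' if continues else 'E-') + slot)\n",
"    return iobes\n",
"\n",
"\n",
"# the end of the introductory example utterance, in English, with IOB tags\n",
"words = ['tomorrow', 'at', 'twelve', 'forty', 'five', 'for', 'five', 'people']\n",
"tags = ['B-day', 'O', 'B-hour', 'I-hour', 'I-hour', 'O', 'B-size', 'O']\n",
"print(list(zip(words, iob_to_iobes(tags))))"
]
},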
{
"cell_type": "code",
"execution_count": 1,
"id": "8cca8cd1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"g:\\studia\\studia uam\\systemy dialogowe\\notebooks\\l07\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"\n",
" 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0\n",
" 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0\n",
"\n",
" 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0\n",
" 14 8714k 14 1307k 0 0 860k 0 0:00:10 0:00:01 0:00:09 1354k\n",
"100 8714k 100 8714k 0 0 4650k 0 0:00:01 0:00:01 --:--:-- 6606k\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"g:\\studia\\studia uam\\systemy dialogowe\\notebooks\n"
]
}
],
"source": [
"import os\n",
"import zipfile\n",
"\n",
"os.makedirs('l07', exist_ok=True)\n",
"%cd l07\n",
"!curl -L -C - https://fb.me/multilingual_task_oriented_data -o data.zip\n",
"# extract with Python's zipfile, which also works where the `unzip` command is unavailable (e.g. on Windows)\n",
"with zipfile.ZipFile('data.zip') as archive:\n",
"    archive.extractall()\n",
"%cd .."
]
},
{
"cell_type": "markdown",
"id": "56d91f6c",
"metadata": {},
"source": [
"This dataset contains utterances in three languages, annotated with slots for twelve frames belonging to three domains: `Alarm`, `Reminder` and `Weather`. We will load the data using the [conllu](https://pypi.org/project/conllu/) library."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "18b9a032",
"metadata": {},
"outputs": [],
"source": [
"from conllu import parse_incr\n",
"\n",
"# names of the four columns in the data files\n",
"fields = ['id', 'form', 'frame', 'slot']\n",
"\n",
"def nolabel2o(line, i):\n",
"    # map the dataset's 'NoLabel' marker to the IOB outside tag 'O'\n",
"    return 'O' if line[i] == 'NoLabel' else line[i]\n",
"\n",
"with open('l07/en/train-en.conllu', encoding='utf-8') as trainfile:\n",
"    trainset = list(parse_incr(trainfile, fields=fields, field_parsers={'slot': nolabel2o}))\n",
"with open('l07/en/test-en.conllu', encoding='utf-8') as testfile:\n",
"    testset = list(parse_incr(testfile, fields=fields, field_parsers={'slot': nolabel2o}))"
]
},
{
"cell_type": "markdown",
"id": "7477593e",
"metadata": {},
"source": [
"Let's look at a few example utterances from this dataset."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b2799ad2",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"<tbody>\n",
"<tr><td style=\"text-align: right;\">1</td><td>tell </td><td>weather/find</td><td>O </td></tr>\n",
"<tr><td style=\"text-align: right;\">2</td><td>me </td><td>weather/find</td><td>O </td></tr>\n",
"<tr><td style=\"text-align: right;\">3</td><td>the </td><td>weather/find</td><td>O </td></tr>\n",
"<tr><td style=\"text-align: right;\">4</td><td>weather</td><td>weather/find</td><td>B-weather/noun</td></tr>\n",
"<tr><td style=\"text-align: right;\">5</td><td>report </td><td>weather/find</td><td>I-weather/noun</td></tr>\n",
"<tr><td style=\"text-align: right;\">6</td><td>for </td><td>weather/find</td><td>O </td></tr>\n",
"<tr><td style=\"text-align: right;\">7</td><td>half </td><td>weather/find</td><td>B-location </td></tr>\n",
"<tr><td style=\"text-align: right;\">8</td><td>moon </td><td>weather/find</td><td>I-location </td></tr>\n",
"<tr><td style=\"text-align: right;\">9</td><td>bay </td><td>weather/find</td><td>I-location </td></tr>\n",
"</tbody>\n",
"</table>"
],
"text/plain": [
"'<table>\\n<tbody>\\n<tr><td style=\"text-align: right;\">1</td><td>tell </td><td>weather/find</td><td>O </td></tr>\\n<tr><td style=\"text-align: right;\">2</td><td>me </td><td>weather/find</td><td>O </td></tr>\\n<tr><td style=\"text-align: right;\">3</td><td>the </td><td>weather/find</td><td>O </td></tr>\\n<tr><td style=\"text-align: right;\">4</td><td>weather</td><td>weather/find</td><td>B-weather/noun</td></tr>\\n<tr><td style=\"text-align: right;\">5</td><td>report </td><td>weather/find</td><td>I-weather/noun</td></tr>\\n<tr><td style=\"text-align: right;\">6</td><td>for </td><td>weather/find</td><td>O </td></tr>\\n<tr><td style=\"text-align: right;\">7</td><td>half </td><td>weather/find</td><td>B-location </td></tr>\\n<tr><td style=\"text-align: right;\">8</td><td>moon </td><td>weather/find</td><td>I-location </td></tr>\\n<tr><td style=\"text-align: right;\">9</td><td>bay </td><td>weather/find</td><td>I-location </td></tr>\\n</tbody>\\n</table>'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from tabulate import tabulate\n",
"tabulate(trainset[0], tablefmt='html')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "ba2c2706",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"<tbody>\n",
"<tr><td style=\"text-align: right;\">1</td><td>remind</td><td>reminder/set_reminder</td><td>O </td></tr>\n",
"<tr><td style=\"text-align: right;\">2</td><td>me </td><td>reminder/set_reminder</td><td>O </td></tr>\n",
"<tr><td style=\"text-align: right;\">3</td><td>about </td><td>reminder/set_reminder</td><td>O </td></tr>\n",
"<tr><td style=\"text-align: right;\">4</td><td>game </td><td>reminder/set_reminder</td><td>B-reminder/todo</td></tr>\n",
"<tr><td style=\"text-align: right;\">5</td><td>night </td><td>reminder/set_reminder</td><td>I-reminder/todo</td></tr>\n",
"<tr><td style=\"text-align: right;\">6</td><td>on </td><td>reminder/set_reminder</td><td>B-datetime </td></tr>\n",
"<tr><td style=\"text-align: right;\">7</td><td>friday</td><td>reminder/set_reminder</td><td>I-datetime </td></tr>\n",
"<tr><td style=\"text-align: right;\">8</td><td>at </td><td>reminder/set_reminder</td><td>I-datetime </td></tr>\n",
"<tr><td style=\"text-align: right;\">9</td><td>4pm </td><td>reminder/set_reminder</td><td>I-datetime </td></tr>\n",
"</tbody>\n",
"</table>"
],
"text/plain": [
"'<table>\\n<tbody>\\n<tr><td style=\"text-align: right;\">1</td><td>remind</td><td>reminder/set_reminder</td><td>O </td></tr>\\n<tr><td style=\"text-align: right;\">2</td><td>me </td><td>reminder/set_reminder</td><td>O </td></tr>\\n<tr><td style=\"text-align: right;\">3</td><td>about </td><td>reminder/set_reminder</td><td>O </td></tr>\\n<tr><td style=\"text-align: right;\">4</td><td>game </td><td>reminder/set_reminder</td><td>B-reminder/todo</td></tr>\\n<tr><td style=\"text-align: right;\">5</td><td>night </td><td>reminder/set_reminder</td><td>I-reminder/todo</td></tr>\\n<tr><td style=\"text-align: right;\">6</td><td>on </td><td>reminder/set_reminder</td><td>B-datetime </td></tr>\\n<tr><td style=\"text-align: right;\">7</td><td>friday</td><td>reminder/set_reminder</td><td>I-datetime </td></tr>\\n<tr><td style=\"text-align: right;\">8</td><td>at </td><td>reminder/set_reminder</td><td>I-datetime </td></tr>\\n<tr><td style=\"text-align: right;\">9</td><td>4pm </td><td>reminder/set_reminder</td><td>I-datetime </td></tr>\\n</tbody>\\n</table>'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tabulate(trainset[1000], tablefmt='html')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "b5c9db18",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"<tbody>\n",
"<tr><td style=\"text-align: right;\">1</td><td>set </td><td>alarm/set_alarm</td><td>O </td></tr>\n",
"<tr><td style=\"text-align: right;\">2</td><td>alarm </td><td>alarm/set_alarm</td><td>O </td></tr>\n",
"<tr><td style=\"text-align: right;\">3</td><td>for </td><td>alarm/set_alarm</td><td>B-datetime</td></tr>\n",
"<tr><td style=\"text-align: right;\">4</td><td>20 </td><td>alarm/set_alarm</td><td>I-datetime</td></tr>\n",
"<tr><td style=\"text-align: right;\">5</td><td>minutes</td><td>alarm/set_alarm</td><td>I-datetime</td></tr>\n",
"</tbody>\n",
"</table>"
],
"text/plain": [
"'<table>\\n<tbody>\\n<tr><td style=\"text-align: right;\">1</td><td>set </td><td>alarm/set_alarm</td><td>O </td></tr>\\n<tr><td style=\"text-align: right;\">2</td><td>alarm </td><td>alarm/set_alarm</td><td>O </td></tr>\\n<tr><td style=\"text-align: right;\">3</td><td>for </td><td>alarm/set_alarm</td><td>B-datetime</td></tr>\\n<tr><td style=\"text-align: right;\">4</td><td>20 </td><td>alarm/set_alarm</td><td>I-datetime</td></tr>\\n<tr><td style=\"text-align: right;\">5</td><td>minutes</td><td>alarm/set_alarm</td><td>I-datetime</td></tr>\\n</tbody>\\n</table>'"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tabulate(trainset[2000], tablefmt='html')"
]
},
{
"cell_type": "markdown",
"id": "0f35074d",
"metadata": {
"lines_to_next_cell": 0
},
"source": [
"To keep the training demonstration in this Jupyter notebook short, we restrict both the training set and the test set to their first 300 examples."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f735ca85",
"metadata": {},
"outputs": [],
"source": [
"trainset = trainset[:300]\n",
"testset = testset[:300]"
]
},
{
"cell_type": "markdown",
"id": "66284486",
"metadata": {},
"source": [
"To build the model, we will use an architecture based on recurrent neural networks\n",
"implemented in the [flair](https://github.com/flairNLP/flair) library (Akbik et al., 2018)."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f3e30f81",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\macty\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from flair.data import Corpus, Sentence, Token\n",
"from flair.datasets import FlairDatapointDataset\n",
"from flair.embeddings import StackedEmbeddings\n",
"from flair.embeddings import WordEmbeddings\n",
"from flair.embeddings import CharacterEmbeddings\n",
"from flair.embeddings import FlairEmbeddings\n",
"from flair.models import SequenceTagger\n",
"from flair.trainers import ModelTrainer\n",
"\n",
"# make the computations deterministic (the same seed everywhere)\n",
"import random\n",
"import torch\n",
"random.seed(42)\n",
"torch.manual_seed(42)\n",
"\n",
"if torch.cuda.is_available():\n",
"    torch.cuda.manual_seed(42)\n",
"    torch.cuda.manual_seed_all(42)\n",
"    torch.backends.cudnn.enabled = False\n",
"    torch.backends.cudnn.benchmark = False\n",
"    torch.backends.cudnn.deterministic = True"
]
},
{
"cell_type": "markdown",
"id": "c1a33987",
"metadata": {},
"source": [
"We convert the data to the format used by `flair` with the following function."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "f3c47593",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Corpus: 270 train + 30 dev + 300 test sentences\n",
"2023-04-16 23:32:52,141 Computing label dictionary. Progress:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"270it [00:00, 53998.76it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:32:52,151 Dictionary created for label 'slot' with 5 values: datetime (seen 174 times), weather/attribute (seen 77 times), weather/noun (seen 66 times), location (seen 44 times)\n",
"Dictionary with 5 tags: <unk>, datetime, weather/attribute, weather/noun, location\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"def conllu2flair(sentences, label=None):\n",
"    fsentences = []\n",
"\n",
"    for sentence in sentences:\n",
"        fsentence = Sentence(' '.join(token['form'] for token in sentence), use_tokenizer=False)\n",
"        start_idx = None\n",
"        end_idx = None\n",
"        tag = None\n",
"\n",
"        if label:\n",
"            for idx, token in enumerate(sentence):\n",
"                if token[label].startswith('B-'):\n",
"                    # close a slot that ends right before this one (e.g. I-x followed by B-y)\n",
"                    if start_idx is not None:\n",
"                        fsentence[start_idx:end_idx+1].add_label(label, tag)\n",
"                    start_idx = idx\n",
"                    end_idx = idx\n",
"                    tag = token[label][2:]\n",
"                elif token[label].startswith('I-'):\n",
"                    end_idx = idx\n",
"                elif token[label] == 'O':\n",
"                    if start_idx is not None:\n",
"                        fsentence[start_idx:end_idx+1].add_label(label, tag)\n",
"                    start_idx = None\n",
"                    end_idx = None\n",
"                    tag = None\n",
"\n",
"            # close a slot that runs to the end of the utterance\n",
"            if start_idx is not None:\n",
"                fsentence[start_idx:end_idx+1].add_label(label, tag)\n",
"\n",
"        fsentences.append(fsentence)\n",
"\n",
"    return FlairDatapointDataset(fsentences)\n",
"\n",
"# no dev set is given, so flair samples one from the training data (see the 270 + 30 split below)\n",
"corpus = Corpus(train=conllu2flair(trainset, 'slot'), test=conllu2flair(testset, 'slot'))\n",
"print(corpus)\n",
"tag_dictionary = corpus.make_label_dictionary(label_type='slot')\n",
"print(tag_dictionary)"
]
},
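{
"cell_type": "markdown",
"id": "c4a8e0b1",
"metadata": {},
"source": [
"The core of `conllu2flair` is turning per-token IOB tags into labelled spans. As a sketch under simplifying assumptions, this step can be isolated in plain Python, without `flair`, and checked on a toy tag sequence. The helper `iob_spans` and the sample tags are hypothetical and serve only as an illustration. The returned `(start, end, slot)` triples use an inclusive `end`, like the `start_idx:end_idx+1` slices in `conllu2flair`. A `B-` tag explicitly closes any slot that is still open, so two adjacent slots (here `weather/attribute` directly followed by `datetime`) are both kept. The cell prints `[(1, 1, 'weather/attribute'), (2, 3, 'datetime'), (5, 5, 'location')]`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1f6a2c9",
"metadata": {},
"outputs": [],
"source": [
"def iob_spans(tags):\n",
"    # Collect (start, end, slot) triples from IOB tags; end is inclusive.\n",
"    spans = []\n",
"    start, end, slot = None, None, None\n",
"    for idx, tag in enumerate(tags):\n",
"        if tag.startswith('B-') or tag == 'O':\n",
"            # a new slot or an outside token closes the currently open slot\n",
"            if start is not None:\n",
"                spans.append((start, end, slot))\n",
"            if tag.startswith('B-'):\n",
"                start, end, slot = idx, idx, tag[2:]\n",
"            else:\n",
"                start, end, slot = None, None, None\n",
"        elif tag.startswith('I-') and start is not None:\n",
"            # extend the open slot (assumes the I- tag belongs to it)\n",
"            end = idx\n",
"    if start is not None:\n",
"        # a slot that runs to the end of the utterance\n",
"        spans.append((start, end, slot))\n",
"    return spans\n",
"\n",
"\n",
"tags = ['O', 'B-weather/attribute', 'B-datetime', 'I-datetime', 'O', 'B-location']\n",
"print(iob_spans(tags))"
]
},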
{
"cell_type": "markdown",
"id": "0ed59fb2",
"metadata": {},
"source": [
"Our model will use vector representations of words (see [Word Embeddings](https://github.com/flairNLP/flair/blob/master/resources/docs/TUTORIAL_EMBEDDINGS_OVERVIEW.md)): classic word embeddings, contextual `flair` embeddings from a forward and a backward language model, and character-level embeddings, stacked into a single vector per word."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "408cf961",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:32:52,457 https://flair.informatik.hu-berlin.de/resources/embeddings/token/en-fasttext-news-300d-1M.vectors.npy not found in cache, downloading to C:\\Users\\macty\\AppData\\Local\\Temp\\tmpxx1l309v\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1.12G/1.12G [01:03<00:00, 18.8MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:33:56,277 copying C:\\Users\\macty\\AppData\\Local\\Temp\\tmpxx1l309v to cache at C:\\Users\\macty\\.flair\\embeddings\\en-fasttext-news-300d-1M.vectors.npy\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:33:57,359 removing temp file C:\\Users\\macty\\AppData\\Local\\Temp\\tmpxx1l309v\n",
"2023-04-16 23:33:58,068 https://flair.informatik.hu-berlin.de/resources/embeddings/token/en-fasttext-news-300d-1M not found in cache, downloading to C:\\Users\\macty\\AppData\\Local\\Temp\\tmpnqnow7h9\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 52.1M/52.1M [00:02<00:00, 18.9MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:01,079 copying C:\\Users\\macty\\AppData\\Local\\Temp\\tmpnqnow7h9 to cache at C:\\Users\\macty\\.flair\\embeddings\\en-fasttext-news-300d-1M\n",
"2023-04-16 23:34:01,116 removing temp file C:\\Users\\macty\\AppData\\Local\\Temp\\tmpnqnow7h9\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:09,948 https://flair.informatik.hu-berlin.de/resources/embeddings/flair/news-forward-0.4.1.pt not found in cache, downloading to C:\\Users\\macty\\AppData\\Local\\Temp\\tmpf2iq2swn\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 69.7M/69.7M [00:03<00:00, 21.2MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:13,532 copying C:\\Users\\macty\\AppData\\Local\\Temp\\tmpf2iq2swn to cache at C:\\Users\\macty\\.flair\\embeddings\\news-forward-0.4.1.pt\n",
"2023-04-16 23:34:13,579 removing temp file C:\\Users\\macty\\AppData\\Local\\Temp\\tmpf2iq2swn\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:14,221 https://flair.informatik.hu-berlin.de/resources/embeddings/flair/news-backward-0.4.1.pt not found in cache, downloading to C:\\Users\\macty\\AppData\\Local\\Temp\\tmp1i_pq3dr\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 69.7M/69.7M [00:03<00:00, 22.5MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:17,584 copying C:\\Users\\macty\\AppData\\Local\\Temp\\tmp1i_pq3dr to cache at C:\\Users\\macty\\.flair\\embeddings\\news-backward-0.4.1.pt\n",
"2023-04-16 23:34:17,631 removing temp file C:\\Users\\macty\\AppData\\Local\\Temp\\tmp1i_pq3dr\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:18,016 https://flair.informatik.hu-berlin.de/resources/characters/common_characters not found in cache, downloading to C:\\Users\\macty\\AppData\\Local\\Temp\\tmpqyh0003q\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 2.82k/2.82k [00:00<00:00, 1.44MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:18,139 copying C:\\Users\\macty\\AppData\\Local\\Temp\\tmpqyh0003q to cache at C:\\Users\\macty\\.flair\\datasets\\common_characters\n",
"2023-04-16 23:34:18,145 removing temp file C:\\Users\\macty\\AppData\\Local\\Temp\\tmpqyh0003q\n",
"2023-04-16 23:34:18,147 SequenceTagger predicts: Dictionary with 17 tags: O, S-datetime, B-datetime, E-datetime, I-datetime, S-weather/attribute, B-weather/attribute, E-weather/attribute, I-weather/attribute, S-weather/noun, B-weather/noun, E-weather/noun, I-weather/noun, S-location, B-location, E-location, I-location\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"embedding_types = [\n",
"    WordEmbeddings('en'),\n",
"    FlairEmbeddings('en-forward'),\n",
"    FlairEmbeddings('en-backward'),\n",
"    CharacterEmbeddings(),\n",
"]\n",
"\n",
"embeddings = StackedEmbeddings(embeddings=embedding_types)\n",
"tagger = SequenceTagger(hidden_size=256, embeddings=embeddings,\n",
"                        tag_dictionary=tag_dictionary,\n",
"                        tag_type='slot', use_crf=True)"
]
},
{
"cell_type": "markdown",
"id": "ab634218",
"metadata": {},
"source": [
"Let's look at the architecture of the neural network that will predict\n",
"the slots in utterances."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "04d0bbf3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SequenceTagger(\n",
" (embeddings): StackedEmbeddings(\n",
" (list_embedding_0): WordEmbeddings(\n",
" 'en'\n",
" (embedding): Embedding(1000001, 300)\n",
" )\n",
" (list_embedding_1): FlairEmbeddings(\n",
" (lm): LanguageModel(\n",
" (drop): Dropout(p=0.05, inplace=False)\n",
" (encoder): Embedding(300, 100)\n",
" (rnn): LSTM(100, 2048)\n",
" )\n",
" )\n",
" (list_embedding_2): FlairEmbeddings(\n",
" (lm): LanguageModel(\n",
" (drop): Dropout(p=0.05, inplace=False)\n",
" (encoder): Embedding(300, 100)\n",
" (rnn): LSTM(100, 2048)\n",
" )\n",
" )\n",
" (list_embedding_3): CharacterEmbeddings(\n",
" (char_embedding): Embedding(275, 25)\n",
" (char_rnn): LSTM(25, 25, bidirectional=True)\n",
" )\n",
" )\n",
" (word_dropout): WordDropout(p=0.05)\n",
" (locked_dropout): LockedDropout(p=0.5)\n",
" (embedding2nn): Linear(in_features=4446, out_features=4446, bias=True)\n",
" (rnn): LSTM(4446, 256, batch_first=True, bidirectional=True)\n",
" (linear): Linear(in_features=512, out_features=19, bias=True)\n",
" (loss_function): ViterbiLoss()\n",
" (crf): CRF()\n",
")\n"
]
}
],
"source": [
"print(tagger)"
]
},
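{
"cell_type": "markdown",
"id": "d9b3f5a7",
"metadata": {},
"source": [
"The layer sizes in this printout follow from the configuration above. `StackedEmbeddings` concatenates the vectors of its components, so the input size of `embedding2nn` and of the BiLSTM is the sum of the component sizes, and the BiLSTM with `hidden_size=256` outputs `2 * 256 = 512` features. The tag set consists of `O` plus the four IOBES prefixes for each of the four slot types. The cell below only reproduces this arithmetic from numbers visible in the printout and in the tag dictionary log. We assume, although the output does not show it, that the two extra outputs of the final `Linear` layer (19 rather than 17) are the start and stop tags that `flair` adds for the CRF."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0c7b8e3",
"metadata": {},
"outputs": [],
"source": [
"# component sizes read off the printout above\n",
"word_dim = 300             # WordEmbeddings('en'): Embedding(1000001, 300)\n",
"flair_forward_dim = 2048   # FlairEmbeddings('en-forward'): LSTM(100, 2048)\n",
"flair_backward_dim = 2048  # FlairEmbeddings('en-backward'): LSTM(100, 2048)\n",
"char_dim = 2 * 25          # CharacterEmbeddings: LSTM(25, 25, bidirectional=True)\n",
"\n",
"stacked_dim = word_dim + flair_forward_dim + flair_backward_dim + char_dim\n",
"print('stacked embedding size:', stacked_dim)  # 4446, cf. Linear(in_features=4446, ...)\n",
"\n",
"hidden_size = 256\n",
"print('BiLSTM output size:', 2 * hidden_size)  # 512, cf. Linear(in_features=512, ...)\n",
"\n",
"# O plus S-/B-/E-/I- for each slot type, as listed by 'SequenceTagger predicts'\n",
"slot_types = ['datetime', 'weather/attribute', 'weather/noun', 'location']\n",
"tags = ['O'] + [f'{p}-{s}' for s in slot_types for p in 'SBEI']\n",
"print('number of tags:', len(tags))  # 17"
]
},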
{
"cell_type": "markdown",
"id": "8e0da880",
"metadata": {},
"source": [
"We will run ten training iterations (epochs) and save the resulting model in the `slot-model` directory."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "0fd2b573",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:18,357 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:18,358 Model: \"SequenceTagger(\n",
" (embeddings): StackedEmbeddings(\n",
" (list_embedding_0): WordEmbeddings(\n",
" 'en'\n",
" (embedding): Embedding(1000001, 300)\n",
" )\n",
" (list_embedding_1): FlairEmbeddings(\n",
" (lm): LanguageModel(\n",
" (drop): Dropout(p=0.05, inplace=False)\n",
" (encoder): Embedding(300, 100)\n",
" (rnn): LSTM(100, 2048)\n",
" )\n",
" )\n",
" (list_embedding_2): FlairEmbeddings(\n",
" (lm): LanguageModel(\n",
" (drop): Dropout(p=0.05, inplace=False)\n",
" (encoder): Embedding(300, 100)\n",
" (rnn): LSTM(100, 2048)\n",
" )\n",
" )\n",
" (list_embedding_3): CharacterEmbeddings(\n",
" (char_embedding): Embedding(275, 25)\n",
" (char_rnn): LSTM(25, 25, bidirectional=True)\n",
" )\n",
" )\n",
" (word_dropout): WordDropout(p=0.05)\n",
" (locked_dropout): LockedDropout(p=0.5)\n",
" (embedding2nn): Linear(in_features=4446, out_features=4446, bias=True)\n",
" (rnn): LSTM(4446, 256, batch_first=True, bidirectional=True)\n",
" (linear): Linear(in_features=512, out_features=19, bias=True)\n",
" (loss_function): ViterbiLoss()\n",
" (crf): CRF()\n",
")\"\n",
"2023-04-16 23:34:18,359 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:18,360 Corpus: \"Corpus: 270 train + 30 dev + 300 test sentences\"\n",
"2023-04-16 23:34:18,360 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:18,361 Parameters:\n",
"2023-04-16 23:34:18,361 - learning_rate: \"0.100000\"\n",
"2023-04-16 23:34:18,361 - mini_batch_size: \"32\"\n",
"2023-04-16 23:34:18,362 - patience: \"3\"\n",
"2023-04-16 23:34:18,363 - anneal_factor: \"0.5\"\n",
"2023-04-16 23:34:18,363 - max_epochs: \"10\"\n",
"2023-04-16 23:34:18,363 - shuffle: \"True\"\n",
"2023-04-16 23:34:18,364 - train_with_dev: \"False\"\n",
"2023-04-16 23:34:18,364 - batch_growth_annealing: \"False\"\n",
"2023-04-16 23:34:18,365 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:18,365 Model training base path: \"slot-model\"\n",
"2023-04-16 23:34:18,366 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:18,367 Device: cpu\n",
"2023-04-16 23:34:18,367 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:18,368 Embeddings storage mode: cpu\n",
"2023-04-16 23:34:18,368 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:20,279 epoch 1 - iter 1/9 - loss 3.46586463 - time (sec): 1.91 - samples/sec: 152.88 - lr: 0.100000\n",
"2023-04-16 23:34:22,063 epoch 1 - iter 2/9 - loss 3.09300251 - time (sec): 3.69 - samples/sec: 153.49 - lr: 0.100000\n",
"2023-04-16 23:34:23,066 epoch 1 - iter 3/9 - loss 3.09930113 - time (sec): 4.70 - samples/sec: 154.78 - lr: 0.100000\n",
"2023-04-16 23:34:24,090 epoch 1 - iter 4/9 - loss 2.95447677 - time (sec): 5.72 - samples/sec: 155.39 - lr: 0.100000\n",
"2023-04-16 23:34:24,981 epoch 1 - iter 5/9 - loss 2.81596711 - time (sec): 6.61 - samples/sec: 157.92 - lr: 0.100000\n",
"2023-04-16 23:34:25,968 epoch 1 - iter 6/9 - loss 2.61288632 - time (sec): 7.60 - samples/sec: 156.47 - lr: 0.100000\n",
"2023-04-16 23:34:26,968 epoch 1 - iter 7/9 - loss 2.37758500 - time (sec): 8.60 - samples/sec: 153.85 - lr: 0.100000\n",
"2023-04-16 23:34:27,959 epoch 1 - iter 8/9 - loss 2.29585029 - time (sec): 9.59 - samples/sec: 153.81 - lr: 0.100000\n",
"2023-04-16 23:34:28,676 epoch 1 - iter 9/9 - loss 2.27926901 - time (sec): 10.31 - samples/sec: 150.87 - lr: 0.100000\n",
"2023-04-16 23:34:28,678 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:28,678 EPOCH 1 done: loss 2.2793 - lr 0.100000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:01<00:00, 1.12s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:29,797 Evaluating as a multi-label problem: False\n",
"2023-04-16 23:34:29,806 DEV : loss 1.2717913389205933 - f1-score (micro avg) 0.0357\n",
"2023-04-16 23:34:29,808 BAD EPOCHS (no improvement): 0\n",
"2023-04-16 23:34:29,809 saving best model\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:41,288 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:42,050 epoch 2 - iter 1/9 - loss 1.43326204 - time (sec): 0.76 - samples/sec: 241.47 - lr: 0.100000\n",
"2023-04-16 23:34:42,846 epoch 2 - iter 2/9 - loss 1.40312603 - time (sec): 1.56 - samples/sec: 234.92 - lr: 0.100000\n",
"2023-04-16 23:34:43,571 epoch 2 - iter 3/9 - loss 1.39219711 - time (sec): 2.28 - samples/sec: 235.65 - lr: 0.100000\n",
"2023-04-16 23:34:44,338 epoch 2 - iter 4/9 - loss 1.38006975 - time (sec): 3.05 - samples/sec: 232.13 - lr: 0.100000\n",
"2023-04-16 23:34:45,066 epoch 2 - iter 5/9 - loss 1.34341906 - time (sec): 3.78 - samples/sec: 232.93 - lr: 0.100000\n",
"2023-04-16 23:34:45,850 epoch 2 - iter 6/9 - loss 1.31603716 - time (sec): 4.56 - samples/sec: 234.55 - lr: 0.100000\n",
"2023-04-16 23:34:46,612 epoch 2 - iter 7/9 - loss 1.28879818 - time (sec): 5.32 - samples/sec: 237.23 - lr: 0.100000\n",
"2023-04-16 23:34:47,400 epoch 2 - iter 8/9 - loss 1.27353642 - time (sec): 6.11 - samples/sec: 238.71 - lr: 0.100000\n",
"2023-04-16 23:34:47,821 epoch 2 - iter 9/9 - loss 1.27439781 - time (sec): 6.53 - samples/sec: 238.02 - lr: 0.100000\n",
"2023-04-16 23:34:47,823 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:47,823 EPOCH 2 done: loss 1.2744 - lr 0.100000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:00<00:00, 4.76it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:48,036 Evaluating as a multi-label problem: False\n",
"2023-04-16 23:34:48,044 DEV : loss 0.7916695475578308 - f1-score (micro avg) 0.0\n",
"2023-04-16 23:34:48,046 BAD EPOCHS (no improvement): 1\n",
"2023-04-16 23:34:48,047 ----------------------------------------------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:48,750 epoch 3 - iter 1/9 - loss 0.95985486 - time (sec): 0.70 - samples/sec: 266.38 - lr: 0.100000\n",
"2023-04-16 23:34:49,421 epoch 3 - iter 2/9 - loss 0.93650848 - time (sec): 1.37 - samples/sec: 258.56 - lr: 0.100000\n",
"2023-04-16 23:34:50,208 epoch 3 - iter 3/9 - loss 0.91011594 - time (sec): 2.16 - samples/sec: 249.54 - lr: 0.100000\n",
"2023-04-16 23:34:50,909 epoch 3 - iter 4/9 - loss 0.89886124 - time (sec): 2.86 - samples/sec: 249.21 - lr: 0.100000\n",
"2023-04-16 23:34:51,580 epoch 3 - iter 5/9 - loss 0.87307849 - time (sec): 3.53 - samples/sec: 251.48 - lr: 0.100000\n",
"2023-04-16 23:34:52,375 epoch 3 - iter 6/9 - loss 0.85130603 - time (sec): 4.33 - samples/sec: 251.67 - lr: 0.100000\n",
"2023-04-16 23:34:53,161 epoch 3 - iter 7/9 - loss 0.83100431 - time (sec): 5.11 - samples/sec: 249.75 - lr: 0.100000\n",
"2023-04-16 23:34:53,966 epoch 3 - iter 8/9 - loss 0.79737653 - time (sec): 5.92 - samples/sec: 248.39 - lr: 0.100000\n",
"2023-04-16 23:34:54,372 epoch 3 - iter 9/9 - loss 0.78005277 - time (sec): 6.32 - samples/sec: 245.89 - lr: 0.100000\n",
"2023-04-16 23:34:54,373 ----------------------------------------------------------------------------------------------------\n",
"2023-04-16 23:34:54,374 EPOCH 3 done: loss 0.7801 - lr 0.100000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:00<00:00, 4.50it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-04-16 23:34:54,600 Evaluating as a multi-label problem: False\n",
"2023-04-16 23:34:54,608 DEV : loss 0.33898717164993286 - f1-score (micro avg) 0.7761\n",
"2023-04-16 23:34:54,610 BAD EPOCHS (no improvement): 0\n",
|
||
"2023-04-16 23:34:54,611 saving best model\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"2023-04-16 23:35:08,824 ----------------------------------------------------------------------------------------------------\n",
|
||
"2023-04-16 23:35:09,623 epoch 4 - iter 1/9 - loss 0.41515362 - time (sec): 0.80 - samples/sec: 244.67 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:10,428 epoch 4 - iter 2/9 - loss 0.51239108 - time (sec): 1.60 - samples/sec: 243.44 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:11,205 epoch 4 - iter 3/9 - loss 0.45658054 - time (sec): 2.38 - samples/sec: 232.97 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:11,983 epoch 4 - iter 4/9 - loss 0.45578642 - time (sec): 3.16 - samples/sec: 231.55 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:12,730 epoch 4 - iter 5/9 - loss 0.44140354 - time (sec): 3.90 - samples/sec: 236.42 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:13,509 epoch 4 - iter 6/9 - loss 0.42847639 - time (sec): 4.68 - samples/sec: 235.10 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:14,254 epoch 4 - iter 7/9 - loss 0.42680964 - time (sec): 5.43 - samples/sec: 238.02 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:15,046 epoch 4 - iter 8/9 - loss 0.41317872 - time (sec): 6.22 - samples/sec: 238.91 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:15,362 epoch 4 - iter 9/9 - loss 0.40072542 - time (sec): 6.54 - samples/sec: 237.91 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:15,364 ----------------------------------------------------------------------------------------------------\n",
|
||
"2023-04-16 23:35:15,364 EPOCH 4 done: loss 0.4007 - lr 0.100000\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"100%|██████████| 1/1 [00:00<00:00, 4.63it/s]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"2023-04-16 23:35:15,583 Evaluating as a multi-label problem: False\n",
|
||
"2023-04-16 23:35:15,591 DEV : loss 0.16208426654338837 - f1-score (micro avg) 0.9231\n",
|
||
"2023-04-16 23:35:15,593 BAD EPOCHS (no improvement): 0\n",
|
||
"2023-04-16 23:35:15,595 saving best model\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"2023-04-16 23:35:26,934 ----------------------------------------------------------------------------------------------------\n",
|
||
"2023-04-16 23:35:27,683 epoch 5 - iter 1/9 - loss 0.23677549 - time (sec): 0.75 - samples/sec: 243.64 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:28,430 epoch 5 - iter 2/9 - loss 0.21526806 - time (sec): 1.49 - samples/sec: 236.28 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:29,218 epoch 5 - iter 3/9 - loss 0.18914680 - time (sec): 2.28 - samples/sec: 232.25 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:30,001 epoch 5 - iter 4/9 - loss 0.21946464 - time (sec): 3.06 - samples/sec: 235.24 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:30,749 epoch 5 - iter 5/9 - loss 0.20726706 - time (sec): 3.81 - samples/sec: 238.39 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:31,440 epoch 5 - iter 6/9 - loss 0.20088127 - time (sec): 4.50 - samples/sec: 242.67 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:32,199 epoch 5 - iter 7/9 - loss 0.21393993 - time (sec): 5.26 - samples/sec: 242.64 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:32,995 epoch 5 - iter 8/9 - loss 0.20896170 - time (sec): 6.06 - samples/sec: 241.62 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:33,438 epoch 5 - iter 9/9 - loss 0.22090499 - time (sec): 6.50 - samples/sec: 239.16 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:33,440 ----------------------------------------------------------------------------------------------------\n",
|
||
"2023-04-16 23:35:33,440 EPOCH 5 done: loss 0.2209 - lr 0.100000\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"100%|██████████| 1/1 [00:00<00:00, 4.67it/s]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"2023-04-16 23:35:33,657 Evaluating as a multi-label problem: False\n",
|
||
"2023-04-16 23:35:33,664 DEV : loss 0.07899065315723419 - f1-score (micro avg) 0.9706\n",
|
||
"2023-04-16 23:35:33,667 BAD EPOCHS (no improvement): 0\n",
|
||
"2023-04-16 23:35:33,668 saving best model\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"2023-04-16 23:35:47,995 ----------------------------------------------------------------------------------------------------\n",
|
||
"2023-04-16 23:35:48,716 epoch 6 - iter 1/9 - loss 0.08489894 - time (sec): 0.72 - samples/sec: 265.65 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:49,474 epoch 6 - iter 2/9 - loss 0.16483877 - time (sec): 1.48 - samples/sec: 262.02 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:50,229 epoch 6 - iter 3/9 - loss 0.17961087 - time (sec): 2.23 - samples/sec: 261.20 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:50,883 epoch 6 - iter 4/9 - loss 0.17008273 - time (sec): 2.89 - samples/sec: 266.11 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:51,580 epoch 6 - iter 5/9 - loss 0.16165644 - time (sec): 3.58 - samples/sec: 259.35 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:52,303 epoch 6 - iter 6/9 - loss 0.16796368 - time (sec): 4.31 - samples/sec: 257.37 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:52,971 epoch 6 - iter 7/9 - loss 0.15208281 - time (sec): 4.97 - samples/sec: 254.12 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:53,698 epoch 6 - iter 8/9 - loss 0.14079077 - time (sec): 5.70 - samples/sec: 255.74 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:54,094 epoch 6 - iter 9/9 - loss 0.14176936 - time (sec): 6.10 - samples/sec: 255.04 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:54,095 ----------------------------------------------------------------------------------------------------\n",
|
||
"2023-04-16 23:35:54,096 EPOCH 6 done: loss 0.1418 - lr 0.100000\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"100%|██████████| 1/1 [00:00<00:00, 5.41it/s]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"2023-04-16 23:35:54,284 Evaluating as a multi-label problem: False\n",
|
||
"2023-04-16 23:35:54,292 DEV : loss 0.05525948479771614 - f1-score (micro avg) 0.9706\n",
|
||
"2023-04-16 23:35:54,294 BAD EPOCHS (no improvement): 0\n",
|
||
"2023-04-16 23:35:54,295 ----------------------------------------------------------------------------------------------------\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"2023-04-16 23:35:55,007 epoch 7 - iter 1/9 - loss 0.06206851 - time (sec): 0.71 - samples/sec: 262.64 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:55,659 epoch 7 - iter 2/9 - loss 0.07954220 - time (sec): 1.36 - samples/sec: 257.33 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:56,426 epoch 7 - iter 3/9 - loss 0.07803471 - time (sec): 2.13 - samples/sec: 248.71 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:57,138 epoch 7 - iter 4/9 - loss 0.09479619 - time (sec): 2.84 - samples/sec: 247.27 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:57,921 epoch 7 - iter 5/9 - loss 0.12208708 - time (sec): 3.63 - samples/sec: 251.03 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:58,664 epoch 7 - iter 6/9 - loss 0.11140702 - time (sec): 4.37 - samples/sec: 251.09 - lr: 0.100000\n",
|
||
"2023-04-16 23:35:59,452 epoch 7 - iter 7/9 - loss 0.10265536 - time (sec): 5.16 - samples/sec: 250.34 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:00,146 epoch 7 - iter 8/9 - loss 0.09764915 - time (sec): 5.85 - samples/sec: 252.61 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:00,536 epoch 7 - iter 9/9 - loss 0.09412069 - time (sec): 6.24 - samples/sec: 249.16 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:00,537 ----------------------------------------------------------------------------------------------------\n",
|
||
"2023-04-16 23:36:00,538 EPOCH 7 done: loss 0.0941 - lr 0.100000\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"100%|██████████| 1/1 [00:00<00:00, 4.76it/s]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"2023-04-16 23:36:00,751 Evaluating as a multi-label problem: False\n",
|
||
"2023-04-16 23:36:00,759 DEV : loss 0.04532366245985031 - f1-score (micro avg) 0.9706\n",
|
||
"2023-04-16 23:36:00,761 BAD EPOCHS (no improvement): 0\n",
|
||
"2023-04-16 23:36:00,762 ----------------------------------------------------------------------------------------------------\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"2023-04-16 23:36:01,551 epoch 8 - iter 1/9 - loss 0.03479994 - time (sec): 0.79 - samples/sec: 234.47 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:02,299 epoch 8 - iter 2/9 - loss 0.03457490 - time (sec): 1.54 - samples/sec: 235.52 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:03,135 epoch 8 - iter 3/9 - loss 0.02756396 - time (sec): 2.37 - samples/sec: 232.19 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:03,831 epoch 8 - iter 4/9 - loss 0.03832822 - time (sec): 3.07 - samples/sec: 235.58 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:04,675 epoch 8 - iter 5/9 - loss 0.05709582 - time (sec): 3.91 - samples/sec: 236.64 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:05,470 epoch 8 - iter 6/9 - loss 0.07284620 - time (sec): 4.71 - samples/sec: 237.68 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:06,242 epoch 8 - iter 7/9 - loss 0.09056851 - time (sec): 5.48 - samples/sec: 238.50 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:06,934 epoch 8 - iter 8/9 - loss 0.08531148 - time (sec): 6.17 - samples/sec: 238.50 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:07,344 epoch 8 - iter 9/9 - loss 0.08690437 - time (sec): 6.58 - samples/sec: 236.25 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:07,345 ----------------------------------------------------------------------------------------------------\n",
|
||
"2023-04-16 23:36:07,346 EPOCH 8 done: loss 0.0869 - lr 0.100000\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"100%|██████████| 1/1 [00:00<00:00, 5.03it/s]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"2023-04-16 23:36:07,548 Evaluating as a multi-label problem: False\n",
|
||
"2023-04-16 23:36:07,556 DEV : loss 0.024222934618592262 - f1-score (micro avg) 0.9706\n",
|
||
"2023-04-16 23:36:07,558 BAD EPOCHS (no improvement): 0\n",
|
||
"2023-04-16 23:36:07,559 ----------------------------------------------------------------------------------------------------\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"2023-04-16 23:36:08,281 epoch 9 - iter 1/9 - loss 0.04802091 - time (sec): 0.72 - samples/sec: 239.94 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:09,019 epoch 9 - iter 2/9 - loss 0.04962141 - time (sec): 1.46 - samples/sec: 241.95 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:09,715 epoch 9 - iter 3/9 - loss 0.04733611 - time (sec): 2.16 - samples/sec: 258.47 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:10,377 epoch 9 - iter 4/9 - loss 0.05639010 - time (sec): 2.82 - samples/sec: 258.08 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:11,111 epoch 9 - iter 5/9 - loss 0.11318641 - time (sec): 3.55 - samples/sec: 263.87 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:11,812 epoch 9 - iter 6/9 - loss 0.10583430 - time (sec): 4.25 - samples/sec: 260.58 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:12,486 epoch 9 - iter 7/9 - loss 0.09734085 - time (sec): 4.93 - samples/sec: 262.89 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:13,213 epoch 9 - iter 8/9 - loss 0.09001355 - time (sec): 5.65 - samples/sec: 259.68 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:13,582 epoch 9 - iter 9/9 - loss 0.08649074 - time (sec): 6.02 - samples/sec: 258.22 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:13,583 ----------------------------------------------------------------------------------------------------\n",
|
||
"2023-04-16 23:36:13,584 EPOCH 9 done: loss 0.0865 - lr 0.100000\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"100%|██████████| 1/1 [00:00<00:00, 5.41it/s]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"2023-04-16 23:36:13,772 Evaluating as a multi-label problem: False\n",
|
||
"2023-04-16 23:36:13,779 DEV : loss 0.0387137271463871 - f1-score (micro avg) 0.9706\n",
|
||
"2023-04-16 23:36:13,781 BAD EPOCHS (no improvement): 1\n",
|
||
"2023-04-16 23:36:13,782 ----------------------------------------------------------------------------------------------------\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"2023-04-16 23:36:14,486 epoch 10 - iter 1/9 - loss 0.01475051 - time (sec): 0.70 - samples/sec: 251.42 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:15,220 epoch 10 - iter 2/9 - loss 0.02797203 - time (sec): 1.44 - samples/sec: 246.17 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:15,881 epoch 10 - iter 3/9 - loss 0.03366118 - time (sec): 2.10 - samples/sec: 250.12 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:16,583 epoch 10 - iter 4/9 - loss 0.02913842 - time (sec): 2.80 - samples/sec: 256.33 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:17,329 epoch 10 - iter 5/9 - loss 0.02953250 - time (sec): 3.55 - samples/sec: 258.53 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:18,035 epoch 10 - iter 6/9 - loss 0.04606061 - time (sec): 4.25 - samples/sec: 259.82 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:18,729 epoch 10 - iter 7/9 - loss 0.04505843 - time (sec): 4.95 - samples/sec: 260.56 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:19,462 epoch 10 - iter 8/9 - loss 0.05644751 - time (sec): 5.68 - samples/sec: 258.63 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:19,893 epoch 10 - iter 9/9 - loss 0.06249184 - time (sec): 6.11 - samples/sec: 254.46 - lr: 0.100000\n",
|
||
"2023-04-16 23:36:19,895 ----------------------------------------------------------------------------------------------------\n",
|
||
"2023-04-16 23:36:19,895 EPOCH 10 done: loss 0.0625 - lr 0.100000\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"100%|██████████| 1/1 [00:00<00:00, 5.43it/s]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"2023-04-16 23:36:20,082 Evaluating as a multi-label problem: False\n",
|
||
"2023-04-16 23:36:20,089 DEV : loss 0.03223632648587227 - f1-score (micro avg) 0.9275\n",
|
||
"2023-04-16 23:36:20,091 BAD EPOCHS (no improvement): 2\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"2023-04-16 23:36:31,813 ----------------------------------------------------------------------------------------------------\n",
|
||
"2023-04-16 23:36:35,403 SequenceTagger predicts: Dictionary with 19 tags: O, S-datetime, B-datetime, E-datetime, I-datetime, S-weather/attribute, B-weather/attribute, E-weather/attribute, I-weather/attribute, S-weather/noun, B-weather/noun, E-weather/noun, I-weather/noun, S-location, B-location, E-location, I-location, <START>, <STOP>\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"100%|██████████| 10/10 [00:13<00:00, 1.40s/it]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"2023-04-16 23:36:50,601 Evaluating as a multi-label problem: False\n",
|
||
"2023-04-16 23:36:50,612 0.2111\t0.1989\t0.2048\t0.1176\n",
|
||
"2023-04-16 23:36:50,613 \n",
|
||
"Results:\n",
|
||
"- F-score (micro) 0.2048\n",
|
||
"- F-score (macro) 0.1106\n",
|
||
"- Accuracy 0.1176\n",
|
||
"\n",
|
||
"By class:\n",
|
||
" precision recall f1-score support\n",
|
||
"\n",
|
||
" datetime 0.2136 0.2327 0.2227 202\n",
|
||
" weather/noun 0.2184 0.5588 0.3140 34\n",
|
||
" reminder/todo 0.0000 0.0000 0.0000 46\n",
|
||
" reminder/noun 0.0000 0.0000 0.0000 42\n",
|
||
" weather/attribute 0.1765 0.1667 0.1714 18\n",
|
||
" location 0.1765 0.1765 0.1765 17\n",
|
||
"reminder/recurring_period 0.0000 0.0000 0.0000 2\n",
|
||
" negation 0.0000 0.0000 0.0000 1\n",
|
||
"\n",
|
||
" micro avg 0.2111 0.1989 0.2048 362\n",
|
||
" macro avg 0.0981 0.1418 0.1106 362\n",
|
||
" weighted avg 0.1568 0.1989 0.1706 362\n",
|
||
"\n",
|
||
"2023-04-16 23:36:50,613 ----------------------------------------------------------------------------------------------------\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"{'test_score': 0.20483641536273117,\n",
|
||
" 'dev_score_history': [0.03571428571428571,\n",
|
||
" 0.0,\n",
|
||
" 0.7761194029850745,\n",
|
||
" 0.923076923076923,\n",
|
||
" 0.9705882352941176,\n",
|
||
" 0.9705882352941176,\n",
|
||
" 0.9705882352941176,\n",
|
||
" 0.9705882352941176,\n",
|
||
" 0.9705882352941176,\n",
|
||
" 0.9275362318840579],\n",
|
||
" 'train_loss_history': [2.279269006857918,\n",
|
||
" 1.2743978126256028,\n",
|
||
" 0.7800527689157958,\n",
|
||
" 0.40072541558857516,\n",
|
||
" 0.22090499240102493,\n",
|
||
" 0.14176936011605706,\n",
|
||
" 0.09412068554059486,\n",
|
||
" 0.08690436752662781,\n",
|
||
" 0.08649074149668408,\n",
|
||
" 0.06249183581189711],\n",
|
||
" 'dev_loss_history': [1.2717913389205933,\n",
|
||
" 0.7916695475578308,\n",
|
||
" 0.33898717164993286,\n",
|
||
" 0.16208426654338837,\n",
|
||
" 0.07899065315723419,\n",
|
||
" 0.05525948479771614,\n",
|
||
" 0.04532366245985031,\n",
|
||
" 0.024222934618592262,\n",
|
||
" 0.0387137271463871,\n",
|
||
" 0.03223632648587227]}"
|
||
]
|
||
},
|
||
"execution_count": 17,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"trainer = ModelTrainer(tagger, corpus)\n",
|
||
"trainer.train('slot-model',\n",
|
||
" learning_rate=0.1,\n",
|
||
" mini_batch_size=32,\n",
|
||
" max_epochs=10,\n",
|
||
" train_with_dev=False)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "bcd0c303",
|
||
"metadata": {},
|
||
"source": [
|
||
"We can assess the quality of the trained model with the metrics reported above. For a label $e$ they are defined as follows (the definitions refer to words, while Flair's report above counts whole slots, i.e. labeled spans):\n",
"\n",
" - *tp (true positives)*\n",
"\n",
"    > the number of words labeled $e$ in the test set that the model also labeled $e$\n",
"\n",
" - *fp (false positives)*\n",
"\n",
"    > the number of words not labeled $e$ in the test set that the model labeled $e$\n",
"\n",
" - *fn (false negatives)*\n",
"\n",
"    > the number of words labeled $e$ in the test set that the model did not label $e$\n",
"\n",
" - *precision*\n",
"\n",
"    > $$\\frac{tp}{tp + fp}$$\n",
"\n",
" - *recall*\n",
"\n",
"    > $$\\frac{tp}{tp + fn}$$\n",
"\n",
" - $F_1$\n",
"\n",
"    > $$\\frac{2 \\cdot precision \\cdot recall}{precision + recall}$$\n",
"\n",
" - *micro* $F_1$\n",
"\n",
"    > $F_1$ in which $tp$, $fp$ and $fn$ are summed over all labels, i.e. $tp = \\sum_{e}{{tp}_e}$, $fn = \\sum_{e}{{fn}_e}$, $fp = \\sum_{e}{{fp}_e}$\n",
"\n",
" - *macro* $F_1$\n",
"\n",
"    > the arithmetic mean of the $F_1$ scores computed separately for each label\n",
|
||
"\n",
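"As a worked example, consider the micro-averaged test scores reported above. Flair does not print the raw counts, but they can be reconstructed from the report, so treat them as an illustration: with 362 gold slots and recall 0.1989, $tp \\approx 0.1989 \\cdot 362 \\approx 72$, hence $fn = 362 - 72 = 290$, and precision 0.2111 then gives $fp = 269$. Using the equivalent form $F_1 = \\frac{2 \\cdot tp}{2 \\cdot tp + fp + fn}$:\n",
"\n",
"$$precision = \\frac{72}{72 + 269} \\approx 0.2111, \\quad recall = \\frac{72}{72 + 290} \\approx 0.1989, \\quad F_1 = \\frac{2 \\cdot 72}{2 \\cdot 72 + 269 + 290} = \\frac{144}{703} \\approx 0.2048$$\n",
"\n",
"which agrees with the reported `test_score` of 0.20483641536273117. Macro $F_1$ instead averages the per-class $F_1$ scores of the eight classes, four of which are 0 (the tag dictionary printed above contains no `reminder/*` or `negation` tags, so the model cannot predict these classes at all):\n",
"\n",
"$$\\frac{0.2227 + 0.3140 + 0.1714 + 0.1765 + 0 + 0 + 0 + 0}{8} = \\frac{0.8846}{8} \\approx 0.1106$$\n",
"\n",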
|
||
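"The trained model can be loaded from a file with the `load` method. Note that `final-model.pt` holds the weights after the last epoch, while the checkpoint with the best dev score is saved separately as `best-model.pt`."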
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 18,
|
||
"id": "d12596c1",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"2023-04-16 23:37:00,272 SequenceTagger predicts: Dictionary with 19 tags: O, S-datetime, B-datetime, E-datetime, I-datetime, S-weather/attribute, B-weather/attribute, E-weather/attribute, I-weather/attribute, S-weather/noun, B-weather/noun, E-weather/noun, I-weather/noun, S-location, B-location, E-location, I-location, <START>, <STOP>\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"model = SequenceTagger.load('slot-model/final-model.pt')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "a97dd603",
|
||
"metadata": {},
|
||
"source": [
|
||
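"The loaded model can be used to predict slots in user utterances with the `predict` function below. It tags the words with the model and converts every predicted span back to IOB tags: the first word of a span gets the `B-` prefix and the remaining words get `I-`,\n",
"so two adjacent spans of the same slot show up as consecutive `B-` tags."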
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 19,
|
||
"id": "87c310cf",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def predict(model, sentence):\n",
"    # Wrap each word with a placeholder 'O' slot (the gold slot is unknown at prediction time)\n",
"    csentence = [{'form': word, 'slot': 'O'} for word in sentence]\n",
"    fsentence = conllu2flair([csentence])[0]\n",
"    # Flair adds the predicted labels to the sentence in place\n",
"    model.predict(fsentence)\n",
"\n",
"    # Convert each predicted span back to IOB tags (Flair token indices are 1-based)\n",
"    for span in fsentence.get_spans('slot'):\n",
"        tag = span.get_label('slot').value\n",
"        csentence[span.tokens[0].idx - 1]['slot'] = f'B-{tag}'\n",
"\n",
"        for token in span.tokens[1:]:\n",
"            csentence[token.idx - 1]['slot'] = f'I-{tag}'\n",
"\n",
"    return csentence\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 20,
|
||
"id": "97043331",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<table>\n",
|
||
"<tbody>\n",
|
||
"<tr><td>set </td><td>O </td></tr>\n",
|
||
"<tr><td>alarm </td><td>O </td></tr>\n",
|
||
"<tr><td>for </td><td>B-datetime</td></tr>\n",
|
||
"<tr><td>20 </td><td>B-datetime</td></tr>\n",
|
||
"<tr><td>minutes</td><td>I-datetime</td></tr>\n",
|
||
"</tbody>\n",
|
||
"</table>"
|
||
],
|
||
"text/plain": [
|
||
"'<table>\\n<tbody>\\n<tr><td>set </td><td>O </td></tr>\\n<tr><td>alarm </td><td>O </td></tr>\\n<tr><td>for </td><td>B-datetime</td></tr>\\n<tr><td>20 </td><td>B-datetime</td></tr>\\n<tr><td>minutes</td><td>I-datetime</td></tr>\\n</tbody>\\n</table>'"
|
||
]
|
||
},
|
||
"execution_count": 20,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"tabulate(predict(model, 'set alarm for 20 minutes'.split()), tablefmt='html')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 21,
|
||
"id": "29856a8a",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<table>\n",
|
||
"<tbody>\n",
|
||
"<tr><td>change</td><td>O </td></tr>\n",
|
||
"<tr><td>my </td><td>O </td></tr>\n",
|
||
"<tr><td>3 </td><td>O </td></tr>\n",
|
||
"<tr><td>pm </td><td>O </td></tr>\n",
|
||
"<tr><td>alarm </td><td>O </td></tr>\n",
|
||
"<tr><td>to </td><td>O </td></tr>\n",
|
||
"<tr><td>the </td><td>O </td></tr>\n",
|
||
"<tr><td>next </td><td>O </td></tr>\n",
|
||
"<tr><td>day </td><td>B-weather/noun</td></tr>\n",
|
||
"</tbody>\n",
|
||
"</table>"
|
||
],
|
||
"text/plain": [
|
||
"'<table>\\n<tbody>\\n<tr><td>change</td><td>O </td></tr>\\n<tr><td>my </td><td>O </td></tr>\\n<tr><td>3 </td><td>O </td></tr>\\n<tr><td>pm </td><td>O </td></tr>\\n<tr><td>alarm </td><td>O </td></tr>\\n<tr><td>to </td><td>O </td></tr>\\n<tr><td>the </td><td>O </td></tr>\\n<tr><td>next </td><td>O </td></tr>\\n<tr><td>day </td><td>B-weather/noun</td></tr>\\n</tbody>\\n</table>'"
|
||
]
|
||
},
|
||
"execution_count": 21,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"tabulate(predict(model, 'change my 3 pm alarm to the next day'.split()), tablefmt='html')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "21b00302",
|
||
"metadata": {},
|
||
"source": [
|
||
"References\n",
|
||
"----------\n",
|
||
" 1. Sebastian Schuster, Sonal Gupta, Rushin Shah, Mike Lewis. 2019. Cross-lingual Transfer Learning for Multilingual Task Oriented Dialog. In Proceedings of NAACL-HLT 2019 (Volume 1), pp. 3795–3805.\n",
" 2. John D. Lafferty, Andrew McCallum, Fernando C. N. Pereira. 2001. Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data. In Proceedings of the Eighteenth International Conference on Machine Learning (ICML '01), pp. 282–289. Morgan Kaufmann Publishers Inc., San Francisco, CA, USA. https://repository.upenn.edu/cgi/viewcontent.cgi?article=1162&context=cis_papers\n",
" 3. Sepp Hochreiter, Jürgen Schmidhuber. 1997. Long Short-Term Memory. Neural Computation 9(8), pp. 1735–1780. https://doi.org/10.1162/neco.1997.9.8.1735\n",
" 4. Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin. 2017. Attention Is All You Need. In NIPS 2017, pp. 5998–6008. https://arxiv.org/abs/1706.03762\n",
" 5. Alan Akbik, Duncan Blythe, Roland Vollgraf. 2018. Contextual String Embeddings for Sequence Labeling. In Proceedings of the 27th International Conference on Computational Linguistics (COLING 2018), pp. 1638–1649. https://www.aclweb.org/anthology/C18-1139.pdf\n"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"jupytext": {
|
||
"cell_metadata_filter": "-all",
|
||
"main_language": "python",
|
||
"notebook_metadata_filter": "-all"
|
||
},
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.11.2"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|