{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Semantic parsing with machine learning techniques\n",
"=================================================\n",
"\n",
"Introduction\n",
"------------\n",
"The problem of detecting slots and their values in user utterances can be formulated as a task\n",
"of predicting, for each word, a label that indicates whether the word belongs to a slot and,\n",
"if so, to which one.\n",
"\n",
"> chciałbym zarezerwować stolik na jutro**/day** na godzinę dwunastą**/hour** czterdzieści**/hour** pięć**/hour** na pięć**/size** osób\n",
"\n",
"Slot boundaries are marked using a chosen labelling scheme.\n",
"\n",
"### IOB scheme\n",
"\n",
"| Prefix | Meaning |\n",
"|:------:|:---------------------------|\n",
"| I | inside a slot |\n",
"| O | outside any slot |\n",
"| B | beginning of a slot |\n",
"\n",
"> chciałbym zarezerwować stolik na jutro**/B-day** na godzinę dwunastą**/B-hour** czterdzieści**/I-hour** pięć**/I-hour** na pięć**/B-size** osób\n",
"\n",
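"The IOB tags determine slot boundaries unambiguously: a `B-` tag opens a slot, `I-` tags of the same slot extend it, and anything else closes it. Below is a minimal pure-Python sketch of reading slot values back from IOB tags, using the example above. The helper name `iob_to_spans` is hypothetical, and stray `I-` tags that do not continue an open slot are simply ignored here:\n",
"\n",
"```python\n",
"def iob_to_spans(words, tags):\n",
"    # Collect (slot, value) pairs from parallel lists of words and IOB tags.\n",
"    spans = []\n",
"    current = None  # the slot currently being extended, if any\n",
"    for word, tag in zip(words, tags):\n",
"        if tag.startswith('B-'):\n",
"            # a B- tag always opens a new slot\n",
"            current = (tag[2:], [word])\n",
"            spans.append(current)\n",
"        elif tag.startswith('I-') and current is not None and current[0] == tag[2:]:\n",
"            # an I- tag of the same slot extends the open slot\n",
"            current[1].append(word)\n",
"        else:\n",
"            # O, or an I- tag that does not match the open slot, closes it\n",
"            current = None\n",
"    return [(slot, ' '.join(value)) for slot, value in spans]\n",
"\n",
"words = 'chciałbym zarezerwować stolik na jutro na godzinę dwunastą czterdzieści pięć na pięć osób'.split()\n",
"tags = ['O', 'O', 'O', 'O', 'B-day', 'O', 'O', 'B-hour', 'I-hour', 'I-hour', 'O', 'B-size', 'O']\n",
"print(iob_to_spans(words, tags))\n",
"# [('day', 'jutro'), ('hour', 'dwunastą czterdzieści pięć'), ('size', 'pięć')]\n",
"```\n",
"\n",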
"### IOBES scheme\n",
"\n",
"| Prefix | Meaning |\n",
"|:------:|:---------------------------|\n",
"| I | inside a slot |\n",
"| O | outside any slot |\n",
"| B | beginning of a slot |\n",
"| E | end of a slot |\n",
"| S | single-word slot |\n",
"\n",
"> chciałbym zarezerwować stolik na jutro**/S-day** na godzinę dwunastą**/B-hour** czterdzieści**/I-hour** pięć**/E-hour** na pięć**/S-size** osób\n",
"\n",
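"Since IOBES only refines the IOB prefixes, IOB tags can be converted to IOBES by looking one tag ahead: a slot tag becomes `S-` or `E-` when the next tag does not continue the same slot. A minimal sketch, assuming well-formed IOB input; the function name `iob_to_iobes` is hypothetical:\n",
"\n",
"```python\n",
"def iob_to_iobes(tags):\n",
"    # Convert a list of IOB tags to IOBES using one tag of look-ahead.\n",
"    iobes = []\n",
"    for i, tag in enumerate(tags):\n",
"        if tag == 'O':\n",
"            iobes.append(tag)\n",
"            continue\n",
"        prefix, slot = tag.split('-', 1)\n",
"        nxt = tags[i + 1] if i + 1 < len(tags) else 'O'\n",
"        continues = nxt == 'I-' + slot  # does the next word stay in this slot?\n",
"        if prefix == 'B':\n",
"            iobes.append(('B-' if continues else 'S-') + slot)\n",
"        else:\n",
"            iobes.append(('I-' if continues else 'E-') + slot)\n",
"    return iobes\n",
"\n",
"iob = ['O', 'O', 'O', 'O', 'B-day', 'O', 'O', 'B-hour', 'I-hour', 'I-hour', 'O', 'B-size', 'O']\n",
"print(iob_to_iobes(iob))\n",
"# ['O', 'O', 'O', 'O', 'S-day', 'O', 'O', 'B-hour', 'I-hour', 'E-hour', 'O', 'S-size', 'O']\n",
"```\n",
"\n",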
"If we prepare a dataset for this task consisting of user utterances with annotated slots\n",
"(a so-called *training set*), we can apply (supervised) machine learning techniques to build\n",
"a model that annotates user utterances with slot labels.\n",
"\n",
"Such a model can be built using, among others:\n",
"\n",
" 1. conditional random fields (Lafferty et al., 2001),\n",
"\n",
" 2. recurrent neural networks, e.g. LSTM networks (Hochreiter and Schmidhuber, 1997),\n",
"\n",
" 3. transformers (Vaswani et al., 2017).\n",
"\n",
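"Models that score tags word by word may produce invalid sequences, e.g. an `I-hour` right after `O`. A linear-chain CRF (item 1, also commonly used as the output layer on top of an LSTM) adds tag-to-tag transition scores and decodes the best whole sequence with the Viterbi algorithm. Below is a minimal pure-Python sketch; the labels and scores are invented for illustration, with a transition score of -100 standing in for the forbidden step from `O` to `I-hour`:\n",
"\n",
"```python\n",
"def viterbi(emissions, transitions, labels):\n",
"    # emissions[t][j]: score of label j for word t\n",
"    # transitions[i][j]: score of label i being followed by label j\n",
"    n = len(labels)\n",
"    score = list(emissions[0])  # best path score ending in each label\n",
"    backptrs = []\n",
"    for t in range(1, len(emissions)):\n",
"        prev_best, new_score = [], []\n",
"        for j in range(n):\n",
"            # best previous label i for label j at position t\n",
"            i = max(range(n), key=lambda k: score[k] + transitions[k][j])\n",
"            prev_best.append(i)\n",
"            new_score.append(score[i] + transitions[i][j] + emissions[t][j])\n",
"        backptrs.append(prev_best)\n",
"        score = new_score\n",
"    # follow the back-pointers from the best final label\n",
"    best = max(range(n), key=lambda k: score[k])\n",
"    path = [best]\n",
"    for prev_best in reversed(backptrs):\n",
"        best = prev_best[best]\n",
"        path.append(best)\n",
"    return [labels[k] for k in reversed(path)]\n",
"\n",
"labels = ['O', 'B-hour', 'I-hour']\n",
"emissions = [[2.0, 1.0, 0.0], [0.0, 0.8, 1.0], [0.0, 0.0, 1.5]]\n",
"transitions = [[0.0, 0.0, -100.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]\n",
"print(viterbi(emissions, transitions, labels))\n",
"# ['O', 'B-hour', 'I-hour'] (a word-by-word argmax would give O, I-hour, I-hour)\n",
"```\n",
"\n",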
"Example\n",
"-------\n",
"We will use the dataset prepared by Schuster et al. (2019)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
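"import os\n",
"import zipfile\n",
"os.makedirs('l07', exist_ok=True)  # portable replacement for 'mkdir -p'\n",
"%cd l07\n",
"!curl -L -C - https://fb.me/multilingual_task_oriented_data -o data.zip\n",
"# unpack with Python, since the 'unzip' command may be missing (e.g. on Windows)\n",
"zipfile.ZipFile('data.zip').extractall()\n",
"%cd .."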
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This dataset contains utterances in three languages, annotated with slots for twelve frames from three domains: `Alarm`, `Reminder` and `Weather`. We load the data with the [conllu](https://pypi.org/project/conllu/) library."
]
},
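{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `conllu` library does the parsing for us, but the underlying format is simple. Below is a minimal sketch of reading it by hand, assuming one token per line with whitespace-separated columns `id`, `form`, `frame` and `slot` (the `fields` list defined in the next cell), `#`-prefixed comment lines, and blank lines between utterances. The helper `read_slots` and the sample lines are hypothetical; `NoLabel` is mapped to `O` in the same way as `nolabel2o` does:\n",
"\n",
"```python\n",
"def read_slots(lines):\n",
"    # Group token lines into utterances of (id, form, frame, slot) tuples.\n",
"    sentences, current = [], []\n",
"    for line in lines:\n",
"        line = line.strip()\n",
"        if line.startswith('#'):\n",
"            continue  # comment line, e.g. utterance metadata\n",
"        if not line:\n",
"            # a blank line ends the current utterance\n",
"            if current:\n",
"                sentences.append(current)\n",
"            current = []\n",
"            continue\n",
"        token_id, form, frame, slot = line.split()[:4]\n",
"        current.append((int(token_id), form, frame, 'O' if slot == 'NoLabel' else slot))\n",
"    if current:  # the last utterance may lack a trailing blank line\n",
"        sentences.append(current)\n",
"    return sentences\n",
"\n",
"sample = ['# text: hej', '1 hej greeting NoLabel', '']\n",
"print(read_slots(sample))\n",
"# [[(1, 'hej', 'greeting', 'O')]]\n",
"```"
]
},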
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"!pip3 install conllu\n",
"from conllu import parse_incr\n",
"\n",
"# column names of the four-column CoNLL-U-like files\n",
"fields = ['id', 'form', 'frame', 'slot']\n",
"\n",
"def nolabel2o(line, i):\n",
"    # map the dataset's 'NoLabel' marker to the IOB outside tag 'O'\n",
"    return 'O' if line[i] == 'NoLabel' else line[i]\n",
"\n",
"# NOTE: both sets are read from the same file here, so the model is evaluated on its training data\n",
"with open('l07/Janet_test.conllu', encoding='utf-8') as trainfile:\n",
"    trainset = list(parse_incr(trainfile, fields=fields, field_parsers={'slot': nolabel2o}))\n",
"with open('l07/Janet_test.conllu', encoding='utf-8') as testfile:\n",
"    testset = list(parse_incr(testfile, fields=fields, field_parsers={'slot': nolabel2o}))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look at a few example utterances from this dataset."
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"<tbody>\n",
"<tr><td style=\"text-align: right;\">1</td><td>hej</td><td>greeting</td><td>O</td></tr>\n",
"</tbody>\n",
"</table>"
],
"text/plain": [
"'<table>\\n<tbody>\\n<tr><td style=\"text-align: right;\">1</td><td>hej</td><td>greeting</td><td>O</td></tr>\\n</tbody>\\n</table>'"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"!pip3 install tabulate\n",
"from tabulate import tabulate\n",
"tabulate(trainset[0], tablefmt='html')"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"1 | chcialbym | prescription/collect | O |
\n",
"2 | odebrac | prescription/collect | O |
\n",
"3 | receptę | prescription/collect | O |
\n",
"\n",
"
"
],
"text/plain": [
"'\\n\\n1 | chcialbym | prescription/collect | O |
\\n2 | odebrac | prescription/collect | O |
\\n3 | receptę | prescription/collect | O |
\\n\\n
'"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tabulate(trainset[10], tablefmt='html')"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
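"text/html": [
"<table>\n",
"<tbody>\n",
"<tr><td style=\"text-align: right;\"> 1</td><td>dzień</td><td>appoinment/create_appointment</td><td>O</td></tr>\n",
"<tr><td style=\"text-align: right;\"> 2</td><td>dobry,</td><td>appoinment/create_appointment</td><td>O</td></tr>\n",
"<tr><td style=\"text-align: right;\"> 3</td><td>chciałbym</td><td>appoinment/create_appointment</td><td>O</td></tr>\n",
"<tr><td style=\"text-align: right;\"> 4</td><td>umówić</td><td>appoinment/create_appointment</td><td>O</td></tr>\n",
"<tr><td style=\"text-align: right;\"> 5</td><td>się</td><td>appoinment/create_appointment</td><td>O</td></tr>\n",
"<tr><td style=\"text-align: right;\"> 6</td><td>na</td><td>appoinment/create_appointment</td><td>O</td></tr>\n",
"<tr><td style=\"text-align: right;\"> 7</td><td>wizytę</td><td>appoinment/create_appointment</td><td>O</td></tr>\n",
"<tr><td style=\"text-align: right;\"> 8</td><td>do</td><td>appoinment/create_appointment</td><td>O</td></tr>\n",
"<tr><td style=\"text-align: right;\"> 9</td><td>lekarza</td><td>appoinment/create_appointment</td><td>B-appoinment/doctor</td></tr>\n",
"<tr><td style=\"text-align: right;\">10</td><td>rodzinnego.</td><td>appoinment/create_appointment</td><td>I-appoinment/doctor</td></tr>\n",
"<tr><td style=\"text-align: right;\">11</td><td>najlepiej</td><td>appoinment/create_appointment</td><td>O</td></tr>\n",
"<tr><td style=\"text-align: right;\">12</td><td>dzisiaj</td><td>appoinment/create_appointment</td><td>B-datetime</td></tr>\n",
"<tr><td style=\"text-align: right;\">13</td><td>w</td><td>appoinment/create_appointment</td><td>I-datetime</td></tr>\n",
"<tr><td style=\"text-align: right;\">14</td><td>godzinach</td><td>appoinment/create_appointment</td><td>I-datetime</td></tr>\n",
"<tr><td style=\"text-align: right;\">15</td><td>popołudniowych.</td><td>appoinment/create_appointment</td><td>I-datetime</td></tr>\n",
"</tbody>\n",
"</table>"
],
"text/plain": [
"'<table>\\n<tbody>\\n<tr><td style=\"text-align: right;\"> 1</td><td>dzień</td><td>appoinment/create_appointment</td><td>O</td></tr>\\n<tr><td style=\"text-align: right;\"> 2</td><td>dobry,</td><td>appoinment/create_appointment</td><td>O</td></tr>\\n<tr><td style=\"text-align: right;\"> 3</td><td>chciałbym</td><td>appoinment/create_appointment</td><td>O</td></tr>\\n<tr><td style=\"text-align: right;\"> 4</td><td>umówić</td><td>appoinment/create_appointment</td><td>O</td></tr>\\n<tr><td style=\"text-align: right;\"> 5</td><td>się</td><td>appoinment/create_appointment</td><td>O</td></tr>\\n<tr><td style=\"text-align: right;\"> 6</td><td>na</td><td>appoinment/create_appointment</td><td>O</td></tr>\\n<tr><td style=\"text-align: right;\"> 7</td><td>wizytę</td><td>appoinment/create_appointment</td><td>O</td></tr>\\n<tr><td style=\"text-align: right;\"> 8</td><td>do</td><td>appoinment/create_appointment</td><td>O</td></tr>\\n<tr><td style=\"text-align: right;\"> 9</td><td>lekarza</td><td>appoinment/create_appointment</td><td>B-appoinment/doctor</td></tr>\\n<tr><td style=\"text-align: right;\">10</td><td>rodzinnego.</td><td>appoinment/create_appointment</td><td>I-appoinment/doctor</td></tr>\\n<tr><td style=\"text-align: right;\">11</td><td>najlepiej</td><td>appoinment/create_appointment</td><td>O</td></tr>\\n<tr><td style=\"text-align: right;\">12</td><td>dzisiaj</td><td>appoinment/create_appointment</td><td>B-datetime</td></tr>\\n<tr><td style=\"text-align: right;\">13</td><td>w</td><td>appoinment/create_appointment</td><td>I-datetime</td></tr>\\n<tr><td style=\"text-align: right;\">14</td><td>godzinach</td><td>appoinment/create_appointment</td><td>I-datetime</td></tr>\\n<tr><td style=\"text-align: right;\">15</td><td>popołudniowych.</td><td>appoinment/create_appointment</td><td>I-datetime</td></tr>\\n</tbody>\\n</table>'"
]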
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tabulate(trainset[1], tablefmt='html')"
]
},
{
"cell_type": "markdown",
"metadata": {
"lines_to_next_cell": 0
},
"source": [
"To demonstrate the training process in a Jupyter notebook, we restrict the dataset to its first examples."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"trainset = trainset[:100]\n",
"testset = testset[:100]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We build the model using an architecture based on recurrent neural networks\n",
"implemented in the [flair](https://github.com/flairNLP/flair) library (Akbik et al., 2018)."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"!pip3 install flair\n",
"from flair.data import Corpus, Sentence, Token\n",
"from flair.datasets import SentenceDataset\n",
"from flair.embeddings import StackedEmbeddings\n",
"from flair.embeddings import WordEmbeddings\n",
"from flair.embeddings import CharacterEmbeddings\n",
"from flair.embeddings import FlairEmbeddings\n",
"from flair.models import SequenceTagger\n",
"from flair.trainers import ModelTrainer\n",
"\n",
"!pip3 install torch\n",
"# make computations deterministic\n",
"import random\n",
"import torch\n",
"random.seed(42)\n",
"torch.manual_seed(42)\n",
"\n",
"if torch.cuda.is_available():\n",
"    torch.cuda.manual_seed(42)\n",
"    torch.cuda.manual_seed_all(42)\n",
"    torch.backends.cudnn.enabled = False\n",
"    torch.backends.cudnn.benchmark = False\n",
"    torch.backends.cudnn.deterministic = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We convert the data to the format used by `flair` with the following function."
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Corpus: 37 train + 4 dev + 41 test sentences\n",
"Dictionary with 13 tags: <unk>, O, B-appoinment/doctor, I-appoinment/doctor, B-datetime, I-datetime, B-login/id, B-login/password, B-appointment/type, I-appointment/type, B-prescription/type, <START>, <STOP>\n"
]
}
],
"source": [
"def conllu2flair(sentences, label=None):\n",
"    # convert conllu token lists into a flair SentenceDataset,\n",
"    # optionally copying the `label` column (e.g. 'slot') as a token tag\n",
"    fsentences = []\n",
"\n",
"    for sentence in sentences:\n",
"        fsentence = Sentence()\n",
"\n",
"        for token in sentence:\n",
"            ftoken = Token(token['form'])\n",
"\n",
"            if label:\n",
"                ftoken.add_tag(label, token[label])\n",
"\n",
"            fsentence.add_token(ftoken)\n",
"\n",
"        fsentences.append(fsentence)\n",
"\n",
"    return SentenceDataset(fsentences)\n",
"\n",
"# no dev split is passed, so flair samples one from the training data\n",
"corpus = Corpus(train=conllu2flair(trainset, 'slot'), test=conllu2flair(testset, 'slot'))\n",
"print(corpus)\n",
"tag_dictionary = corpus.make_tag_dictionary(tag_type='slot')\n",
"print(tag_dictionary)"
]
},
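{
"cell_type": "markdown",
"metadata": {},
"source": [
"No `dev` set is passed to `Corpus` above, yet the printout reports 4 dev sentences: when the dev split is missing, `flair` samples it from the training data. Below is a minimal sketch of such a random hold-out, assuming about 10% of the training sentences are held out, which reproduces the 37 + 4 split of 41 sentences. It illustrates the idea and is not flair's implementation; `split_off_dev` is a hypothetical helper:\n",
"\n",
"```python\n",
"import random\n",
"\n",
"def split_off_dev(items, fraction=0.1, seed=42):\n",
"    # Randomly hold out about `fraction` of the items as a dev set.\n",
"    dev_size = round(len(items) * fraction)\n",
"    indices = list(range(len(items)))\n",
"    random.Random(seed).shuffle(indices)\n",
"    dev_idx = set(indices[:dev_size])\n",
"    train = [x for i, x in enumerate(items) if i not in dev_idx]\n",
"    dev = [x for i, x in enumerate(items) if i in dev_idx]\n",
"    return train, dev\n",
"\n",
"train, dev = split_off_dev(list(range(41)))\n",
"print(len(train), len(dev))  # 37 4\n",
"```"
]
},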
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our model will use vector representations of words (see [Word Embeddings](https://github.com/flairNLP/flair/blob/master/resources/docs/TUTORIAL_3_WORD_EMBEDDING.md))."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2021-05-12 17:01:27,807 https://flair.informatik.hu-berlin.de/resources/embeddings/token/pl-wiki-fasttext-300d-1M.vectors.npy not found in cache, downloading to C:\\Users\\domstr2\\AppData\\Local\\Temp\\tmpq9mlzfps\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1199998928/1199998928 [00:52<00:00, 22832915.30B/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2021-05-12 17:02:20,552 copying C:\\Users\\domstr2\\AppData\\Local\\Temp\\tmpq9mlzfps to cache at C:\\Users\\domstr2\\.flair\\embeddings\\pl-wiki-fasttext-300d-1M.vectors.npy\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2021-05-12 17:02:32,864 removing temp file C:\\Users\\domstr2\\AppData\\Local\\Temp\\tmpq9mlzfps\n",
"2021-05-12 17:02:33,344 https://flair.informatik.hu-berlin.de/resources/embeddings/token/pl-wiki-fasttext-300d-1M not found in cache, downloading to C:\\Users\\domstr2\\AppData\\Local\\Temp\\tmpp2reld0s\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 40874795/40874795 [00:01<00:00, 21969279.66B/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2021-05-12 17:02:35,412 copying C:\\Users\\domstr2\\AppData\\Local\\Temp\\tmpp2reld0s to cache at C:\\Users\\domstr2\\.flair\\embeddings\\pl-wiki-fasttext-300d-1M\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2021-05-12 17:02:36,260 removing temp file C:\\Users\\domstr2\\AppData\\Local\\Temp\\tmpp2reld0s\n",
"2021-05-12 17:02:39,489 https://flair.informatik.hu-berlin.de/resources/embeddings/flair/lm-polish-forward-v0.2.pt not found in cache, downloading to C:\\Users\\domstr2\\AppData\\Local\\Temp\\tmpin9zi6n_\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 84244196/84244196 [00:03<00:00, 27120526.13B/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2021-05-12 17:02:42,804 copying C:\\Users\\domstr2\\AppData\\Local\\Temp\\tmpin9zi6n_ to cache at C:\\Users\\domstr2\\.flair\\embeddings\\lm-polish-forward-v0.2.pt\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2021-05-12 17:02:42,861 removing temp file C:\\Users\\domstr2\\AppData\\Local\\Temp\\tmpin9zi6n_\n",
"2021-05-12 17:02:43,329 https://flair.informatik.hu-berlin.de/resources/embeddings/flair/lm-polish-backward-v0.2.pt not found in cache, downloading to C:\\Users\\domstr2\\AppData\\Local\\Temp\\tmp30skh32n\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 84244196/84244196 [00:03<00:00, 25790261.34B/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2021-05-12 17:02:46,769 copying C:\\Users\\domstr2\\AppData\\Local\\Temp\\tmp30skh32n to cache at C:\\Users\\domstr2\\.flair\\embeddings\\lm-polish-backward-v0.2.pt\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2021-05-12 17:02:46,828 removing temp file C:\\Users\\domstr2\\AppData\\Local\\Temp\\tmp30skh32n\n"
]
}
],
"source": [
"embedding_types = [\n",
" WordEmbeddings('pl'),\n",
" FlairEmbeddings('pl-forward'),\n",
" FlairEmbeddings('pl-backward'),\n",
" CharacterEmbeddings(),\n",
"]\n",
"\n",
"embeddings = StackedEmbeddings(embeddings=embedding_types)\n",
"tagger = SequenceTagger(hidden_size=256, embeddings=embeddings,\n",
" tag_dictionary=tag_dictionary,\n",
" tag_type='slot', use_crf=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look at the architecture of the neural network that will predict\n",
"the slots in utterances."
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SequenceTagger(\n",
" (embeddings): StackedEmbeddings(\n",
" (list_embedding_0): WordEmbeddings('pl')\n",
" (list_embedding_1): FlairEmbeddings(\n",
" (lm): LanguageModel(\n",
" (drop): Dropout(p=0.25, inplace=False)\n",
" (encoder): Embedding(1602, 100)\n",
" (rnn): LSTM(100, 2048)\n",
" (decoder): Linear(in_features=2048, out_features=1602, bias=True)\n",
" )\n",
" )\n",
" (list_embedding_2): FlairEmbeddings(\n",
" (lm): LanguageModel(\n",
" (drop): Dropout(p=0.25, inplace=False)\n",
" (encoder): Embedding(1602, 100)\n",
" (rnn): LSTM(100, 2048)\n",
" (decoder): Linear(in_features=2048, out_features=1602, bias=True)\n",
" )\n",
" )\n",
" (list_embedding_3): CharacterEmbeddings(\n",
" (char_embedding): Embedding(275, 25)\n",
" (char_rnn): LSTM(25, 25, bidirectional=True)\n",
" )\n",
" )\n",
" (word_dropout): WordDropout(p=0.05)\n",
" (locked_dropout): LockedDropout(p=0.5)\n",
" (embedding2nn): Linear(in_features=4446, out_features=4446, bias=True)\n",
" (rnn): LSTM(4446, 256, batch_first=True, bidirectional=True)\n",
" (linear): Linear(in_features=512, out_features=13, bias=True)\n",
" (beta): 1.0\n",
" (weights): None\n",
" (weight_tensor) None\n",
")\n"
]
}
],
"source": [
"print(tagger)"
]
},
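{
"cell_type": "markdown",
"metadata": {},
"source": [
"The layer sizes printed above fit together as follows: `StackedEmbeddings` concatenates the vectors of its components, the bidirectional LSTM doubles its hidden size, and the final linear layer has one output per tag in `tag_dictionary`. The sizes below are read off the printout (300 comes from the `pl-wiki-fasttext-300d-1M` vectors downloaded earlier):\n",
"\n",
"```python\n",
"word_dim = 300       # WordEmbeddings('pl'): 300-dimensional fastText vectors\n",
"flair_fw_dim = 2048  # FlairEmbeddings('pl-forward'): LSTM(100, 2048) hidden state\n",
"flair_bw_dim = 2048  # FlairEmbeddings('pl-backward'): LSTM(100, 2048) hidden state\n",
"char_dim = 2 * 25    # CharacterEmbeddings: LSTM(25, 25, bidirectional=True)\n",
"\n",
"stacked_dim = word_dim + flair_fw_dim + flair_bw_dim + char_dim\n",
"print(stacked_dim)  # 4446: in_features of embedding2nn and input size of rnn\n",
"\n",
"hidden_size = 256  # hidden_size passed to SequenceTagger\n",
"rnn_out_dim = 2 * hidden_size  # forward and backward states are concatenated\n",
"print(rnn_out_dim)  # 512: in_features of linear\n",
"\n",
"num_tags = 13  # size of tag_dictionary, incl. <unk>, <START> and <STOP>\n",
"print(num_tags)  # 13: out_features of linear\n",
"```"
]
},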
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We run ten training iterations (epochs) and save the resulting model in the `slot-model` directory."
]
},
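{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training log printed below lists `learning_rate` 0.1, `patience` 3 and `anneal_factor` 0.5: the learning rate is annealed when the score on the dev set stops improving. The sketch below is a simplified plateau rule modelled on PyTorch's `ReduceLROnPlateau` (no improvement threshold, no cooldown); flair's own annealing may differ in details, and the dev scores are invented for illustration. `anneal_on_plateau` is a hypothetical helper:\n",
"\n",
"```python\n",
"def anneal_on_plateau(dev_scores, lr=0.1, factor=0.5, patience=3):\n",
"    # Return the learning rate used in each epoch, given per-epoch dev scores.\n",
"    best = float('-inf')\n",
"    bad_epochs = 0\n",
"    history = []\n",
"    for score in dev_scores:\n",
"        history.append(lr)\n",
"        if score > best:\n",
"            best, bad_epochs = score, 0\n",
"        else:\n",
"            bad_epochs += 1\n",
"            if bad_epochs > patience:\n",
"                lr *= factor  # anneal after more than `patience` bad epochs\n",
"                bad_epochs = 0\n",
"    return history\n",
"\n",
"scores = [0.0, 0.2, 0.3, 0.3, 0.3, 0.3, 0.3, 0.35, 0.35, 0.35]\n",
"print(anneal_on_plateau(scores))\n",
"# [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.05, 0.05, 0.05]\n",
"```"
]
},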
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2021-05-12 17:07:41,538 ----------------------------------------------------------------------------------------------------\n",
"2021-05-12 17:07:41,539 Model: \"SequenceTagger(\n",
" (embeddings): StackedEmbeddings(\n",
" (list_embedding_0): WordEmbeddings('pl')\n",
" (list_embedding_1): FlairEmbeddings(\n",
" (lm): LanguageModel(\n",
" (drop): Dropout(p=0.25, inplace=False)\n",
" (encoder): Embedding(1602, 100)\n",
" (rnn): LSTM(100, 2048)\n",
" (decoder): Linear(in_features=2048, out_features=1602, bias=True)\n",
" )\n",
" )\n",
" (list_embedding_2): FlairEmbeddings(\n",
" (lm): LanguageModel(\n",
" (drop): Dropout(p=0.25, inplace=False)\n",
" (encoder): Embedding(1602, 100)\n",
" (rnn): LSTM(100, 2048)\n",
" (decoder): Linear(in_features=2048, out_features=1602, bias=True)\n",
" )\n",
" )\n",
" (list_embedding_3): CharacterEmbeddings(\n",
" (char_embedding): Embedding(275, 25)\n",
" (char_rnn): LSTM(25, 25, bidirectional=True)\n",
" )\n",
" )\n",
" (word_dropout): WordDropout(p=0.05)\n",
" (locked_dropout): LockedDropout(p=0.5)\n",
" (embedding2nn): Linear(in_features=4446, out_features=4446, bias=True)\n",
" (rnn): LSTM(4446, 256, batch_first=True, bidirectional=True)\n",
" (linear): Linear(in_features=512, out_features=13, bias=True)\n",
" (beta): 1.0\n",
" (weights): None\n",
" (weight_tensor) None\n",
")\"\n",
"2021-05-12 17:07:41,540 ----------------------------------------------------------------------------------------------------\n",
"2021-05-12 17:07:41,541 Corpus: \"Corpus: 37 train + 4 dev + 41 test sentences\"\n",
"2021-05-12 17:07:41,541 ----------------------------------------------------------------------------------------------------\n",
"2021-05-12 17:07:41,542 Parameters:\n",
"2021-05-12 17:07:41,542 - learning_rate: \"0.1\"\n",
"2021-05-12 17:07:41,543 - mini_batch_size: \"32\"\n",
"2021-05-12 17:07:41,543 - patience: \"3\"\n",
"2021-05-12 17:07:41,544 - anneal_factor: \"0.5\"\n",
"2021-05-12 17:07:41,544 - max_epochs: \"10\"\n",
"2021-05-12 17:07:41,545 - shuffle: \"True\"\n",
"2021-05-12 17:07:41,546 - train_with_dev: \"False\"\n",
"2021-05-12 17:07:41,546 - batch_growth_annealing: \"False\"\n",
"2021-05-12 17:07:41,547 ----------------------------------------------------------------------------------------------------\n",
"2021-05-12 17:07:41,547 Model training base path: \"slot-model\"\n",
"2021-05-12 17:07:41,548 ----------------------------------------------------------------------------------------------------\n",
"2021-05-12 17:07:41,549 Device: cpu\n",
"2021-05-12 17:07:41,549 ----------------------------------------------------------------------------------------------------\n",
"2021-05-12 17:07:41,550 Embeddings storage mode: cpu\n",
"2021-05-12 17:07:41,552 ----------------------------------------------------------------------------------------------------\n",
"2021-05-12 17:07:46,139 epoch 1 - iter 1/2 - loss 9.51263237 - samples/sec: 6.98 - lr: 0.100000\n",
"2021-05-12 17:07:47,186 epoch 1 - iter 2/2 - loss 7.22621894 - samples/sec: 30.58 - lr: 0.100000\n",
"2021-05-12 17:07:47,188 ----------------------------------------------------------------------------------------------------\n",
"2021-05-12 17:07:47,189 EPOCH 1 done: loss 7.2262 - lr 0.1000000\n",
"2021-05-12 17:07:48,466 DEV : loss 5.046579837799072 - score 0.0\n",
"2021-05-12 17:07:48,468 BAD EPOCHS (no improvement): 0\n",
"saving best model\n"
]
},
{
"ename": "RuntimeError",
"evalue": "[enforce fail at ..\\caffe2\\serialize\\inline_container.cc:274] . unexpected pos 64 vs 0",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mOSError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m~\\anaconda3\\lib\\site-packages\\torch\\serialization.py\u001b[0m in \u001b[0;36msave\u001b[1;34m(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization)\u001b[0m\n\u001b[0;32m 371\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0m_open_zipfile_writer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mopened_file\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0mopened_zipfile\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 372\u001b[1;33m \u001b[0m_save\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mopened_zipfile\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpickle_protocol\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 373\u001b[0m \u001b[1;32mreturn\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\lib\\site-packages\\torch\\serialization.py\u001b[0m in \u001b[0;36m_save\u001b[1;34m(obj, zip_file, pickle_module, pickle_protocol)\u001b[0m\n\u001b[0;32m 477\u001b[0m \u001b[0mdata_value\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdata_buf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgetvalue\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 478\u001b[1;33m \u001b[0mzip_file\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mwrite_record\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'data.pkl'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdata_value\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata_value\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 479\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mOSError\u001b[0m: [Errno 28] No space left on device",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[0mtrainer\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mModelTrainer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtagger\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcorpus\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m trainer.train('slot-model',\n\u001b[0m\u001b[0;32m 3\u001b[0m \u001b[0mlearning_rate\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0.1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[0mmini_batch_size\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m32\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[0mmax_epochs\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m10\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\lib\\site-packages\\flair\\trainers\\trainer.py\u001b[0m in \u001b[0;36mtrain\u001b[1;34m(self, base_path, learning_rate, mini_batch_size, mini_batch_chunk_size, max_epochs, scheduler, cycle_momentum, anneal_factor, patience, initial_extra_patience, min_learning_rate, train_with_dev, train_with_test, monitor_train, monitor_test, embeddings_storage_mode, checkpoint, save_final_model, anneal_with_restarts, anneal_with_prestarts, batch_growth_annealing, shuffle, param_selection_mode, write_weights, num_workers, sampler, use_amp, amp_opt_level, eval_on_train_fraction, eval_on_train_shuffle, save_model_at_each_epoch, **kwargs)\u001b[0m\n\u001b[0;32m 592\u001b[0m ):\n\u001b[0;32m 593\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"saving best model\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 594\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msave\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbase_path\u001b[0m \u001b[1;33m/\u001b[0m \u001b[1;34m\"best-model.pt\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 595\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 596\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0manneal_with_prestarts\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\lib\\site-packages\\flair\\nn.py\u001b[0m in \u001b[0;36msave\u001b[1;34m(self, model_file)\u001b[0m\n\u001b[0;32m 70\u001b[0m \u001b[0mmodel_state\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_get_state_dict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 71\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 72\u001b[1;33m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msave\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel_state\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel_file\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpickle_protocol\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 73\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 74\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mclassmethod\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\lib\\site-packages\\torch\\serialization.py\u001b[0m in \u001b[0;36msave\u001b[1;34m(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization)\u001b[0m\n\u001b[0;32m 371\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0m_open_zipfile_writer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mopened_file\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0mopened_zipfile\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 372\u001b[0m \u001b[0m_save\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mopened_zipfile\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpickle_protocol\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 373\u001b[1;33m \u001b[1;32mreturn\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 374\u001b[0m \u001b[0m_legacy_save\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mopened_file\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpickle_protocol\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 375\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\lib\\site-packages\\torch\\serialization.py\u001b[0m in \u001b[0;36m__exit__\u001b[1;34m(self, *args)\u001b[0m\n\u001b[0;32m 257\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 258\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__exit__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 259\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfile_like\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mwrite_end_of_file\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 260\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mflush\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 261\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mRuntimeError\u001b[0m: [enforce fail at ..\\caffe2\\serialize\\inline_container.cc:274] . unexpected pos 64 vs 0"
]
}
],
"source": [
"trainer = ModelTrainer(tagger, corpus)\n",
"trainer.train('slot-model',\n",
" learning_rate=0.1,\n",
" mini_batch_size=32,\n",
" max_epochs=10,\n",
" train_with_dev=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The quality of the trained model can be assessed with the metrics reported above, namely:\n",
"\n",
" - *tp (true positives)*\n",
"\n",
" > the number of words labeled $e$ in the test set that the model also labeled $e$\n",
"\n",
" - *fp (false positives)*\n",
"\n",
" > the number of words not labeled $e$ in the test set that the model labeled $e$\n",
"\n",
" - *fn (false negatives)*\n",
"\n",
" > the number of words labeled $e$ in the test set that the model did not label $e$\n",
"\n",
" - *precision*\n",
"\n",
" > $$\\frac{tp}{tp + fp}$$\n",
"\n",
" - *recall*\n",
"\n",
" > $$\\frac{tp}{tp + fn}$$\n",
"\n",
" - $F_1$\n",
"\n",
" > $$\\frac{2 \\cdot precision \\cdot recall}{precision + recall}$$\n",
"\n",
" - *micro* $F_1$\n",
"\n",
" > $F_1$ in which $tp$, $fp$ and $fn$ are summed over all labels, i.e. $tp = \\sum_{e}{{tp}_e}$, $fn = \\sum_{e}{{fn}_e}$, $fp = \\sum_{e}{{fp}_e}$\n",
"\n",
" - *macro* $F_1$\n",
"\n",
" > the arithmetic mean of the $F_1$ scores computed separately for each label.\n",
"\n",
"The trained model can be loaded from a file with the `load` method."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2021-05-12 16:58:59,033 loading file slot-model/final-model.pt\n"
]
}
],
"source": [
"model = SequenceTagger.load('slot-model/final-model.pt')"
]
},
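{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the word-level definitions above concrete, the cell below sketches in plain Python how $tp$, $fp$, $fn$,\n",
"precision, recall, $F_1$ and the micro and macro averages can be computed from a gold and a predicted tag sequence.\n",
"It is not flair's own evaluation code: the function names `label_counts`, `prf` and `micro_macro_f1` are hypothetical,\n",
"the tag sequences are a made-up toy example, and skipping the `O` tag is an assumption (the scores reported by flair\n",
"may be computed differently, e.g. over whole slots rather than single words). For this toy input the sketch gives\n",
"micro $F_1 = 2/3$ and macro $F_1 = 7/12 \\approx 0.583$: the never-recognized label `B-size` lowers the macro average\n",
"more than the micro one, because every label weighs the same in the macro average."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"\n",
"def label_counts(gold, pred, ignore=('O',)):\n",
"    # Per-label tp/fp/fn counted over words, following the definitions above.\n",
"    # Assumption: labels listed in `ignore` (by default the outside tag 'O') are not scored.\n",
"    tp, fp, fn = Counter(), Counter(), Counter()\n",
"    for g, p in zip(gold, pred):\n",
"        if g == p:\n",
"            if g not in ignore:\n",
"                tp[g] += 1  # gold label e, predicted label e\n",
"        else:\n",
"            if p not in ignore:\n",
"                fp[p] += 1  # predicted label p, but the test set has a different label\n",
"            if g not in ignore:\n",
"                fn[g] += 1  # gold label g, but the model assigned something else\n",
"    labels = sorted((set(gold) | set(pred)) - set(ignore))\n",
"    return labels, tp, fp, fn\n",
"\n",
"\n",
"def prf(tp, fp, fn):\n",
"    # precision, recall and F1; a zero denominator yields 0.0\n",
"    precision = tp / (tp + fp) if tp + fp else 0.0\n",
"    recall = tp / (tp + fn) if tp + fn else 0.0\n",
"    if precision + recall == 0:\n",
"        return precision, recall, 0.0\n",
"    return precision, recall, 2 * precision * recall / (precision + recall)\n",
"\n",
"\n",
"def micro_macro_f1(gold, pred, ignore=('O',)):\n",
"    labels, tp, fp, fn = label_counts(gold, pred, ignore)\n",
"    # micro F1: sum tp, fp and fn over all labels, then compute a single F1\n",
"    micro = prf(sum(tp.values()), sum(fp.values()), sum(fn.values()))[2]\n",
"    # macro F1: arithmetic mean of the per-label F1 scores\n",
"    per_label = [prf(tp[e], fp[e], fn[e])[2] for e in labels]\n",
"    macro = sum(per_label) / len(per_label) if per_label else 0.0\n",
"    return micro, macro\n",
"\n",
"\n",
"# toy example (made up for illustration): one 'I-hour' missed, 'B-size' predicted as 'B-hour'\n",
"gold = ['O', 'B-day', 'O', 'B-hour', 'I-hour', 'I-hour', 'O', 'B-size', 'O']\n",
"pred = ['O', 'B-day', 'O', 'B-hour', 'I-hour', 'O', 'O', 'B-hour', 'O']\n",
"micro, macro = micro_macro_f1(gold, pred)\n",
"print(f'micro F1 = {micro:.3f}, macro F1 = {macro:.3f}')"
]
},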
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loaded model can be used to predict slots in user utterances with the `predict` function\n",
"defined below."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
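"def predict(model, sentence):\n",
"    # wrap each word in a CoNLL-U-like dict and convert the sentence to flair's format\n",
"    csentence = [{'form': word} for word in sentence]\n",
"    fsentence = conllu2flair([csentence])[0]\n",
"    # tag the flair sentence in place\n",
"    model.predict(fsentence)\n",
"    # pair every input word with its predicted 'slot' tag\n",
"    return [(token, ftoken.get_tag('slot').value) for token, ftoken in zip(sentence, fsentence)]\n"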
]
},
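{
"cell_type": "markdown",
"metadata": {},
"source": [
"`predict` returns one tag per word, while a dialogue system typically needs slot values. The cell below is a minimal\n",
"sketch of merging IOB tags back into slots, assuming tags of the form `B-<slot>`, `I-<slot>` and `O`. The helper name\n",
"`iob_to_slots` is hypothetical, and treating an `I-` tag that does not continue an open slot of the same type as the\n",
"start of a new slot is one common convention, not the only possible one. For the utterance from the introduction it\n",
"should return `day = jutro`, `hour = dwunastą czterdzieści pięć` and `size = pięć`; with a trained model it could be\n",
"applied as `iob_to_slots(predict(model, sentence.split()))`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def iob_to_slots(tagged):\n",
"    # tagged: list of (word, tag) pairs, e.g. the output of predict()\n",
"    # returns a list of (slot_name, slot_value) pairs\n",
"    slots = []      # (slot_name, [words]) pairs, in order of appearance\n",
"    current = None  # slot currently being extended, or None\n",
"    for word, tag in tagged:\n",
"        if tag == 'O':\n",
"            current = None  # a word outside any slot closes the open slot\n",
"            continue\n",
"        prefix, _, name = tag.partition('-')\n",
"        if prefix == 'I' and current is not None and current[0] == name:\n",
"            current[1].append(word)  # continue the open slot\n",
"        else:\n",
"            # 'B-' (or an 'I-' with no matching open slot) starts a new slot\n",
"            current = (name, [word])\n",
"            slots.append(current)\n",
"    return [(name, ' '.join(words)) for name, words in slots]\n",
"\n",
"\n",
"# the IOB-annotated utterance from the introduction of this notebook\n",
"example = [('chciałbym', 'O'), ('zarezerwować', 'O'), ('stolik', 'O'), ('na', 'O'),\n",
"           ('jutro', 'B-day'), ('na', 'O'), ('godzinę', 'O'), ('dwunastą', 'B-hour'),\n",
"           ('czterdzieści', 'I-hour'), ('pięć', 'I-hour'), ('na', 'O'),\n",
"           ('pięć', 'B-size'), ('osób', 'O')]\n",
"print(iob_to_slots(example))"
]
},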
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the example below shows, a model trained on only 100 examples makes errors even on a fairly simple\n",
"utterance: it assigns the `O` tag to every word, so expressions such as `jutro` (tomorrow) and `13:00` are not recognized as slots."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
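"<table>\n",
"<tbody>\n",
"<tr><td>doktor </td><td>O</td></tr>\n",
"<tr><td>lekarz </td><td>O</td></tr>\n",
"<tr><td>wizyta </td><td>O</td></tr>\n",
"<tr><td>kolano </td><td>O</td></tr>\n",
"<tr><td>na     </td><td>O</td></tr>\n",
"<tr><td>godzine</td><td>O</td></tr>\n",
"<tr><td>jutro  </td><td>O</td></tr>\n",
"<tr><td>dzisiaj</td><td>O</td></tr>\n",
"<tr><td>13:00  </td><td>O</td></tr>\n",
"</tbody>\n",
"</table>"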
],
"text/plain": [
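"'<table>\\n<tbody>\\n<tr><td>doktor </td><td>O</td></tr>\\n<tr><td>lekarz </td><td>O</td></tr>\\n<tr><td>wizyta </td><td>O</td></tr>\\n<tr><td>kolano </td><td>O</td></tr>\\n<tr><td>na     </td><td>O</td></tr>\\n<tr><td>godzine</td><td>O</td></tr>\\n<tr><td>jutro  </td><td>O</td></tr>\\n<tr><td>dzisiaj</td><td>O</td></tr>\\n<tr><td>13:00  </td><td>O</td></tr>\\n</tbody>\\n</table>'"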
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tabulate(predict(model, 'doktor lekarz wizyta kolano na godzine jutro dzisiaj 13:00'.split()), tablefmt='html')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"References\n",
"----------\n",
" 1. Sebastian Schuster, Sonal Gupta, Rushin Shah, Mike Lewis. 2019. Cross-lingual Transfer Learning for Multilingual Task Oriented Dialog. NAACL-HLT (1) 2019, pp. 3795–3805\n",
" 2. John D. Lafferty, Andrew McCallum, Fernando C. N. Pereira. 2001. Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data. In Proceedings of the Eighteenth International Conference on Machine Learning (ICML '01), Morgan Kaufmann Publishers Inc., San Francisco, CA, USA, pp. 282–289, https://repository.upenn.edu/cgi/viewcontent.cgi?article=1162&context=cis_papers\n",
" 3. Sepp Hochreiter, Jürgen Schmidhuber. 1997. Long Short-Term Memory. Neural Computation 9(8) (November 15, 1997), pp. 1735–1780, https://doi.org/10.1162/neco.1997.9.8.1735\n",
" 4. Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin. 2017. Attention is All you Need. NIPS 2017, pp. 5998–6008, https://arxiv.org/abs/1706.03762\n",
" 5. Alan Akbik, Duncan Blythe, Roland Vollgraf. 2018. Contextual String Embeddings for Sequence Labeling. Proceedings of the 27th International Conference on Computational Linguistics, pp. 1638–1649, https://www.aclweb.org/anthology/C18-1139.pdf\n"
]
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"main_language": "python",
"notebook_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}