diff --git a/README.md b/README.md
new file mode 100644
index 0000000..acbe0ae
--- /dev/null
+++ b/README.md
@@ -0,0 +1,13 @@
+## Zadanie 7
+
+- the solution is in **lab/08-parsing-semantyczny-uczenie(zmodyfikowany).ipynb**; the last cell contains a script that evaluates the model with the precision, recall and F1 metrics
+
+- model training is carried out in the modified lab notebook **lab/08-parsing-semantyczny-uczenie(zmodyfikowany).ipynb**
+
+- the training data were generated automatically from the previously collected dialogues by the rule-based script **tasks/zad8/pl/annotate.py** and then corrected by hand; the data are stored in two files, **tasks/zad8/pl/test.conllu** and **tasks/zad8/pl/train.conllu**
+
+- the model is used by the class in **src/components/NLU.py** (see the sketch below)
+
+- **src/dialogue_system.py** will eventually connect all modules of the dialogue system; for now it uses only the NLU tagger
+
+- to talk to the system, first run all cells of **lab/08-parsing-semantyczny-uczenie(zmodyfikowany).ipynb** to train the model, and then run the Python script **src/dialogue_system.py**
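+
+A minimal sketch of how the NLU component and the dialogue loop might use the trained tagger; the class and method names below are illustrative only and may differ from the actual code in **src/components/NLU.py** and **src/dialogue_system.py**:
+
+```python
+from flair.data import Sentence
+from flair.models import SequenceTagger
+
+
+class NLU:
+    """Hypothetical wrapper around the slot tagger trained in the lab notebook."""
+
+    def __init__(self, model_path='slot-model/final-model.pt'):
+        # load the Flair model saved by the training notebook
+        self.model = SequenceTagger.load(model_path)
+
+    def predict(self, text):
+        # tag a single utterance and return (token, slot label) pairs;
+        # assumes the tag type is named 'slot' (as suggested by the notebook);
+        # newer Flair releases use get_label() instead of get_tag()
+        sentence = Sentence(text)
+        self.model.predict(sentence)
+        return [(token.text, token.get_tag('slot').value) for token in sentence]
+
+
+if __name__ == '__main__':
+    # minimal console loop, roughly what src/dialogue_system.py is meant to do
+    nlu = NLU()
+    while True:
+        utterance = input('> ')
+        if not utterance:
+            break
+        print(nlu.predict(utterance))
+```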
diff --git a/lab/08-parsing-semantyczny-uczenie(zmodyfikowany).ipynb b/lab/08-parsing-semantyczny-uczenie(zmodyfikowany).ipynb
index d186b33..068fd49 100644
--- a/lab/08-parsing-semantyczny-uczenie(zmodyfikowany).ipynb
+++ b/lab/08-parsing-semantyczny-uczenie(zmodyfikowany).ipynb
@@ -82,7 +82,7 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
@@ -155,7 +155,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 14,
"metadata": {},
"outputs": [
{
@@ -172,7 +172,7 @@
"'
\\n\\n1 | wybieram | inform | O |
\\n2 | batmana | inform | B-title |
\\n\\n
'"
]
},
- "execution_count": 24,
+ "execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
@@ -184,7 +184,7 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 15,
"metadata": {},
"outputs": [
{
@@ -202,7 +202,7 @@
"'\\n\\n1 | chcę | inform | O |
\\n2 | zarezerwować | inform | B-goal |
\\n3 | bilety | inform | O |
\\n\\n
'"
]
},
- "execution_count": 25,
+ "execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
@@ -213,7 +213,7 @@
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 16,
"metadata": {},
"outputs": [
{
@@ -232,7 +232,7 @@
"'\\n\\n1 | chciałbym | inform | O |
\\n2 | anulować | inform | O |
\\n3 | rezerwację | inform | O |
\\n4 | biletu | inform | O |
\\n\\n
'"
]
},
- "execution_count": 26,
+ "execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@@ -251,7 +251,7 @@
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
@@ -263,6 +263,7 @@
"from flair.embeddings import FlairEmbeddings\n",
"from flair.models import SequenceTagger\n",
"from flair.trainers import ModelTrainer\n",
+ "from flair.datasets import DataLoader\n",
"\n",
"# determinizacja obliczeń\n",
"import random\n",
@@ -287,7 +288,7 @@
},
{
"cell_type": "code",
- "execution_count": 28,
+ "execution_count": 18,
"metadata": {},
"outputs": [
{
@@ -333,7 +334,7 @@
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
@@ -360,7 +361,7 @@
},
{
"cell_type": "code",
- "execution_count": 30,
+ "execution_count": 20,
"metadata": {},
"outputs": [
{
@@ -416,301 +417,23 @@
},
{
"cell_type": "code",
- "execution_count": 31,
+ "execution_count": 21,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "2022-05-01 12:13:39,609 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:13:39,610 Model: \"SequenceTagger(\n",
- " (embeddings): StackedEmbeddings(\n",
- " (list_embedding_0): WordEmbeddings('pl')\n",
- " (list_embedding_1): FlairEmbeddings(\n",
- " (lm): LanguageModel(\n",
- " (drop): Dropout(p=0.25, inplace=False)\n",
- " (encoder): Embedding(1602, 100)\n",
- " (rnn): LSTM(100, 2048)\n",
- " (decoder): Linear(in_features=2048, out_features=1602, bias=True)\n",
- " )\n",
- " )\n",
- " (list_embedding_2): FlairEmbeddings(\n",
- " (lm): LanguageModel(\n",
- " (drop): Dropout(p=0.25, inplace=False)\n",
- " (encoder): Embedding(1602, 100)\n",
- " (rnn): LSTM(100, 2048)\n",
- " (decoder): Linear(in_features=2048, out_features=1602, bias=True)\n",
- " )\n",
- " )\n",
- " (list_embedding_3): CharacterEmbeddings(\n",
- " (char_embedding): Embedding(275, 25)\n",
- " (char_rnn): LSTM(25, 25, bidirectional=True)\n",
- " )\n",
- " )\n",
- " (word_dropout): WordDropout(p=0.05)\n",
- " (locked_dropout): LockedDropout(p=0.5)\n",
- " (embedding2nn): Linear(in_features=4446, out_features=4446, bias=True)\n",
- " (rnn): LSTM(4446, 256, batch_first=True, bidirectional=True)\n",
- " (linear): Linear(in_features=512, out_features=20, bias=True)\n",
- " (beta): 1.0\n",
- " (weights): None\n",
- " (weight_tensor) None\n",
- ")\"\n",
- "2022-05-01 12:13:39,611 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:13:39,611 Corpus: \"Corpus: 345 train + 38 dev + 32 test sentences\"\n",
- "2022-05-01 12:13:39,612 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:13:39,613 Parameters:\n",
- "2022-05-01 12:13:39,614 - learning_rate: \"0.1\"\n",
- "2022-05-01 12:13:39,614 - mini_batch_size: \"32\"\n",
- "2022-05-01 12:13:39,615 - patience: \"3\"\n",
- "2022-05-01 12:13:39,616 - anneal_factor: \"0.5\"\n",
- "2022-05-01 12:13:39,616 - max_epochs: \"10\"\n",
- "2022-05-01 12:13:39,616 - shuffle: \"True\"\n",
- "2022-05-01 12:13:39,617 - train_with_dev: \"False\"\n",
- "2022-05-01 12:13:39,618 - batch_growth_annealing: \"False\"\n",
- "2022-05-01 12:13:39,618 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:13:39,619 Model training base path: \"slot-model\"\n",
- "2022-05-01 12:13:39,620 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:13:39,620 Device: cpu\n",
- "2022-05-01 12:13:39,621 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:13:39,621 Embeddings storage mode: cpu\n",
- "2022-05-01 12:13:39,623 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:13:42,490 epoch 1 - iter 1/11 - loss 9.59000492 - samples/sec: 11.17 - lr: 0.100000\n",
- "2022-05-01 12:13:44,150 epoch 1 - iter 2/11 - loss 9.31767702 - samples/sec: 19.29 - lr: 0.100000\n",
- "2022-05-01 12:13:45,968 epoch 1 - iter 3/11 - loss 8.70617644 - samples/sec: 17.61 - lr: 0.100000\n",
- "2022-05-01 12:13:47,791 epoch 1 - iter 4/11 - loss 8.11678410 - samples/sec: 17.57 - lr: 0.100000\n",
- "2022-05-01 12:13:49,815 epoch 1 - iter 5/11 - loss 7.65581417 - samples/sec: 15.82 - lr: 0.100000\n",
- "2022-05-01 12:13:52,296 epoch 1 - iter 6/11 - loss 7.27475810 - samples/sec: 12.90 - lr: 0.100000\n",
- "2022-05-01 12:13:54,454 epoch 1 - iter 7/11 - loss 6.95693064 - samples/sec: 14.84 - lr: 0.100000\n",
- "2022-05-01 12:13:56,845 epoch 1 - iter 8/11 - loss 6.61199290 - samples/sec: 13.39 - lr: 0.100000\n",
- "2022-05-01 12:13:59,195 epoch 1 - iter 9/11 - loss 6.58955601 - samples/sec: 13.63 - lr: 0.100000\n",
- "2022-05-01 12:14:01,065 epoch 1 - iter 10/11 - loss 6.63135071 - samples/sec: 17.11 - lr: 0.100000\n",
- "2022-05-01 12:14:02,415 epoch 1 - iter 11/11 - loss 6.52558366 - samples/sec: 23.72 - lr: 0.100000\n",
- "2022-05-01 12:14:02,416 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:14:02,417 EPOCH 1 done: loss 6.5256 - lr 0.1000000\n",
- "2022-05-01 12:14:05,139 DEV : loss 8.419286727905273 - score 0.0\n",
- "2022-05-01 12:14:05,141 BAD EPOCHS (no improvement): 0\n",
- "saving best model\n",
- "2022-05-01 12:14:15,906 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:14:16,782 epoch 2 - iter 1/11 - loss 7.61237478 - samples/sec: 40.25 - lr: 0.100000\n",
- "2022-05-01 12:14:17,253 epoch 2 - iter 2/11 - loss 7.02023911 - samples/sec: 68.09 - lr: 0.100000\n",
- "2022-05-01 12:14:17,744 epoch 2 - iter 3/11 - loss 6.25125138 - samples/sec: 65.31 - lr: 0.100000\n",
- "2022-05-01 12:14:18,282 epoch 2 - iter 4/11 - loss 5.91574061 - samples/sec: 59.59 - lr: 0.100000\n",
- "2022-05-01 12:14:18,742 epoch 2 - iter 5/11 - loss 5.80905600 - samples/sec: 69.87 - lr: 0.100000\n",
- "2022-05-01 12:14:19,262 epoch 2 - iter 6/11 - loss 5.51969266 - samples/sec: 61.66 - lr: 0.100000\n",
- "2022-05-01 12:14:19,753 epoch 2 - iter 7/11 - loss 5.34836953 - samples/sec: 65.31 - lr: 0.100000\n",
- "2022-05-01 12:14:20,267 epoch 2 - iter 8/11 - loss 5.33710295 - samples/sec: 62.38 - lr: 0.100000\n",
- "2022-05-01 12:14:20,750 epoch 2 - iter 9/11 - loss 5.28061861 - samples/sec: 66.32 - lr: 0.100000\n",
- "2022-05-01 12:14:21,379 epoch 2 - iter 10/11 - loss 5.20552692 - samples/sec: 50.95 - lr: 0.100000\n",
- "2022-05-01 12:14:21,922 epoch 2 - iter 11/11 - loss 5.26294283 - samples/sec: 59.03 - lr: 0.100000\n",
- "2022-05-01 12:14:21,923 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:14:21,924 EPOCH 2 done: loss 5.2629 - lr 0.1000000\n",
- "2022-05-01 12:14:22,145 DEV : loss 7.168168544769287 - score 0.0645\n",
- "2022-05-01 12:14:22,149 BAD EPOCHS (no improvement): 0\n",
- "saving best model\n",
- "2022-05-01 12:14:27,939 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:14:28,495 epoch 3 - iter 1/11 - loss 3.70659065 - samples/sec: 57.56 - lr: 0.100000\n",
- "2022-05-01 12:14:29,038 epoch 3 - iter 2/11 - loss 4.21530080 - samples/sec: 59.04 - lr: 0.100000\n",
- "2022-05-01 12:14:29,607 epoch 3 - iter 3/11 - loss 4.40864404 - samples/sec: 56.37 - lr: 0.100000\n",
- "2022-05-01 12:14:30,171 epoch 3 - iter 4/11 - loss 4.69527233 - samples/sec: 56.93 - lr: 0.100000\n",
- "2022-05-01 12:14:30,587 epoch 3 - iter 5/11 - loss 4.43719640 - samples/sec: 77.11 - lr: 0.100000\n",
- "2022-05-01 12:14:31,075 epoch 3 - iter 6/11 - loss 4.55344125 - samples/sec: 65.71 - lr: 0.100000\n",
- "2022-05-01 12:14:31,625 epoch 3 - iter 7/11 - loss 4.77397609 - samples/sec: 58.34 - lr: 0.100000\n",
- "2022-05-01 12:14:32,143 epoch 3 - iter 8/11 - loss 4.61572361 - samples/sec: 61.89 - lr: 0.100000\n",
- "2022-05-01 12:14:32,703 epoch 3 - iter 9/11 - loss 4.60090372 - samples/sec: 57.24 - lr: 0.100000\n",
- "2022-05-01 12:14:33,404 epoch 3 - iter 10/11 - loss 4.70502276 - samples/sec: 45.69 - lr: 0.100000\n",
- "2022-05-01 12:14:33,839 epoch 3 - iter 11/11 - loss 4.76321775 - samples/sec: 73.73 - lr: 0.100000\n",
- "2022-05-01 12:14:33,840 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:14:33,840 EPOCH 3 done: loss 4.7632 - lr 0.1000000\n",
- "2022-05-01 12:14:33,992 DEV : loss 7.209894180297852 - score 0.0\n",
- "2022-05-01 12:14:33,993 BAD EPOCHS (no improvement): 1\n",
- "2022-05-01 12:14:33,994 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:14:34,556 epoch 4 - iter 1/11 - loss 5.55247641 - samples/sec: 57.04 - lr: 0.100000\n",
- "2022-05-01 12:14:35,078 epoch 4 - iter 2/11 - loss 5.08158088 - samples/sec: 61.42 - lr: 0.100000\n",
- "2022-05-01 12:14:35,643 epoch 4 - iter 3/11 - loss 4.69475476 - samples/sec: 56.73 - lr: 0.100000\n",
- "2022-05-01 12:14:36,270 epoch 4 - iter 4/11 - loss 4.78649628 - samples/sec: 51.16 - lr: 0.100000\n",
- "2022-05-01 12:14:36,806 epoch 4 - iter 5/11 - loss 4.62873497 - samples/sec: 59.93 - lr: 0.100000\n",
- "2022-05-01 12:14:37,419 epoch 4 - iter 6/11 - loss 4.70938087 - samples/sec: 52.29 - lr: 0.100000\n",
- "2022-05-01 12:14:38,068 epoch 4 - iter 7/11 - loss 4.50588363 - samples/sec: 49.46 - lr: 0.100000\n",
- "2022-05-01 12:14:38,581 epoch 4 - iter 8/11 - loss 4.36334288 - samples/sec: 62.50 - lr: 0.100000\n",
- "2022-05-01 12:14:39,140 epoch 4 - iter 9/11 - loss 4.36617618 - samples/sec: 57.45 - lr: 0.100000\n",
- "2022-05-01 12:14:39,780 epoch 4 - iter 10/11 - loss 4.37847199 - samples/sec: 50.16 - lr: 0.100000\n",
- "2022-05-01 12:14:40,321 epoch 4 - iter 11/11 - loss 4.26116128 - samples/sec: 59.18 - lr: 0.100000\n",
- "2022-05-01 12:14:40,323 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:14:40,324 EPOCH 4 done: loss 4.2612 - lr 0.1000000\n",
- "2022-05-01 12:14:40,544 DEV : loss 5.882441997528076 - score 0.1714\n",
- "2022-05-01 12:14:40,546 BAD EPOCHS (no improvement): 0\n",
- "saving best model\n",
- "2022-05-01 12:14:46,159 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:14:46,709 epoch 5 - iter 1/11 - loss 3.86370564 - samples/sec: 58.29 - lr: 0.100000\n",
- "2022-05-01 12:14:47,349 epoch 5 - iter 2/11 - loss 3.80554891 - samples/sec: 50.08 - lr: 0.100000\n",
- "2022-05-01 12:14:47,857 epoch 5 - iter 3/11 - loss 3.34506067 - samples/sec: 63.11 - lr: 0.100000\n",
- "2022-05-01 12:14:48,579 epoch 5 - iter 4/11 - loss 3.88535106 - samples/sec: 44.38 - lr: 0.100000\n",
- "2022-05-01 12:14:49,170 epoch 5 - iter 5/11 - loss 3.81894360 - samples/sec: 54.28 - lr: 0.100000\n",
- "2022-05-01 12:14:49,708 epoch 5 - iter 6/11 - loss 4.18858314 - samples/sec: 59.53 - lr: 0.100000\n",
- "2022-05-01 12:14:50,171 epoch 5 - iter 7/11 - loss 4.13974752 - samples/sec: 69.26 - lr: 0.100000\n",
- "2022-05-01 12:14:50,593 epoch 5 - iter 8/11 - loss 4.01002905 - samples/sec: 75.98 - lr: 0.100000\n",
- "2022-05-01 12:14:51,062 epoch 5 - iter 9/11 - loss 3.97078644 - samples/sec: 68.52 - lr: 0.100000\n",
- "2022-05-01 12:14:51,508 epoch 5 - iter 10/11 - loss 3.94409857 - samples/sec: 71.91 - lr: 0.100000\n",
- "2022-05-01 12:14:51,960 epoch 5 - iter 11/11 - loss 3.80738796 - samples/sec: 70.95 - lr: 0.100000\n",
- "2022-05-01 12:14:51,961 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:14:51,963 EPOCH 5 done: loss 3.8074 - lr 0.1000000\n",
- "2022-05-01 12:14:52,103 DEV : loss 5.224854469299316 - score 0.1667\n",
- "2022-05-01 12:14:52,105 BAD EPOCHS (no improvement): 1\n",
- "2022-05-01 12:14:52,106 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:14:52,616 epoch 6 - iter 1/11 - loss 3.51282573 - samples/sec: 62.91 - lr: 0.100000\n",
- "2022-05-01 12:14:53,100 epoch 6 - iter 2/11 - loss 3.41601551 - samples/sec: 66.25 - lr: 0.100000\n",
- "2022-05-01 12:14:53,513 epoch 6 - iter 3/11 - loss 3.08380787 - samples/sec: 77.76 - lr: 0.100000\n",
- "2022-05-01 12:14:55,121 epoch 6 - iter 4/11 - loss 3.21056002 - samples/sec: 64.71 - lr: 0.100000\n",
- "2022-05-01 12:14:55,665 epoch 6 - iter 5/11 - loss 3.30184879 - samples/sec: 58.88 - lr: 0.100000\n",
- "2022-05-01 12:14:56,160 epoch 6 - iter 6/11 - loss 3.20993070 - samples/sec: 64.91 - lr: 0.100000\n",
- "2022-05-01 12:14:56,670 epoch 6 - iter 7/11 - loss 3.14396119 - samples/sec: 62.91 - lr: 0.100000\n",
- "2022-05-01 12:14:57,329 epoch 6 - iter 8/11 - loss 3.24591878 - samples/sec: 48.63 - lr: 0.100000\n",
- "2022-05-01 12:14:57,958 epoch 6 - iter 9/11 - loss 3.31877112 - samples/sec: 51.03 - lr: 0.100000\n",
- "2022-05-01 12:14:58,527 epoch 6 - iter 10/11 - loss 3.33475649 - samples/sec: 56.34 - lr: 0.100000\n",
- "2022-05-01 12:14:58,989 epoch 6 - iter 11/11 - loss 3.23232636 - samples/sec: 69.41 - lr: 0.100000\n",
- "2022-05-01 12:14:58,991 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:14:58,991 EPOCH 6 done: loss 3.2323 - lr 0.1000000\n",
- "2022-05-01 12:14:59,178 DEV : loss 4.557621002197266 - score 0.2381\n",
- "2022-05-01 12:14:59,180 BAD EPOCHS (no improvement): 0\n",
- "saving best model\n",
- "2022-05-01 12:15:25,844 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:15:26,423 epoch 7 - iter 1/11 - loss 2.71161938 - samples/sec: 55.36 - lr: 0.100000\n",
- "2022-05-01 12:15:26,886 epoch 7 - iter 2/11 - loss 2.50157821 - samples/sec: 69.26 - lr: 0.100000\n",
- "2022-05-01 12:15:27,347 epoch 7 - iter 3/11 - loss 2.78014056 - samples/sec: 69.56 - lr: 0.100000\n",
- "2022-05-01 12:15:27,853 epoch 7 - iter 4/11 - loss 2.82983196 - samples/sec: 63.36 - lr: 0.100000\n",
- "2022-05-01 12:15:28,393 epoch 7 - iter 5/11 - loss 2.84246483 - samples/sec: 59.37 - lr: 0.100000\n",
- "2022-05-01 12:15:28,847 epoch 7 - iter 6/11 - loss 2.89787177 - samples/sec: 70.64 - lr: 0.100000\n",
- "2022-05-01 12:15:29,338 epoch 7 - iter 7/11 - loss 2.74564961 - samples/sec: 65.30 - lr: 0.100000\n",
- "2022-05-01 12:15:29,813 epoch 7 - iter 8/11 - loss 2.79853699 - samples/sec: 67.58 - lr: 0.100000\n",
- "2022-05-01 12:15:30,364 epoch 7 - iter 9/11 - loss 2.89167126 - samples/sec: 58.18 - lr: 0.100000\n",
- "2022-05-01 12:15:30,834 epoch 7 - iter 10/11 - loss 2.86527851 - samples/sec: 68.22 - lr: 0.100000\n",
- "2022-05-01 12:15:31,296 epoch 7 - iter 11/11 - loss 2.82858575 - samples/sec: 69.41 - lr: 0.100000\n",
- "2022-05-01 12:15:31,297 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:15:31,298 EPOCH 7 done: loss 2.8286 - lr 0.1000000\n",
- "2022-05-01 12:15:31,462 DEV : loss 4.020608901977539 - score 0.3182\n",
- "2022-05-01 12:15:31,463 BAD EPOCHS (no improvement): 0\n",
- "saving best model\n",
- "2022-05-01 12:15:38,431 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:15:38,979 epoch 8 - iter 1/11 - loss 3.28806710 - samples/sec: 58.61 - lr: 0.100000\n",
- "2022-05-01 12:15:39,534 epoch 8 - iter 2/11 - loss 2.72140074 - samples/sec: 57.76 - lr: 0.100000\n",
- "2022-05-01 12:15:40,061 epoch 8 - iter 3/11 - loss 2.77740423 - samples/sec: 60.89 - lr: 0.100000\n",
- "2022-05-01 12:15:40,541 epoch 8 - iter 4/11 - loss 2.51573136 - samples/sec: 66.72 - lr: 0.100000\n",
- "2022-05-01 12:15:41,109 epoch 8 - iter 5/11 - loss 2.54271443 - samples/sec: 56.53 - lr: 0.100000\n",
- "2022-05-01 12:15:41,537 epoch 8 - iter 6/11 - loss 2.47530021 - samples/sec: 75.12 - lr: 0.100000\n",
- "2022-05-01 12:15:42,078 epoch 8 - iter 7/11 - loss 2.62978831 - samples/sec: 59.26 - lr: 0.100000\n",
- "2022-05-01 12:15:42,506 epoch 8 - iter 8/11 - loss 2.62844713 - samples/sec: 74.84 - lr: 0.100000\n",
- "2022-05-01 12:15:42,988 epoch 8 - iter 9/11 - loss 2.61604464 - samples/sec: 66.59 - lr: 0.100000\n",
- "2022-05-01 12:15:43,471 epoch 8 - iter 10/11 - loss 2.62512223 - samples/sec: 66.39 - lr: 0.100000\n",
- "2022-05-01 12:15:43,895 epoch 8 - iter 11/11 - loss 2.64045010 - samples/sec: 75.65 - lr: 0.100000\n",
- "2022-05-01 12:15:43,896 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:15:43,897 EPOCH 8 done: loss 2.6405 - lr 0.1000000\n",
- "2022-05-01 12:15:44,036 DEV : loss 3.542769432067871 - score 0.3846\n",
- "2022-05-01 12:15:44,038 BAD EPOCHS (no improvement): 0\n",
- "saving best model\n",
- "2022-05-01 12:15:51,672 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:15:52,235 epoch 9 - iter 1/11 - loss 1.73337626 - samples/sec: 56.99 - lr: 0.100000\n",
- "2022-05-01 12:15:52,801 epoch 9 - iter 2/11 - loss 2.09788013 - samples/sec: 56.74 - lr: 0.100000\n",
- "2022-05-01 12:15:53,288 epoch 9 - iter 3/11 - loss 2.24861153 - samples/sec: 65.84 - lr: 0.100000\n",
- "2022-05-01 12:15:53,735 epoch 9 - iter 4/11 - loss 2.42630130 - samples/sec: 71.75 - lr: 0.100000\n",
- "2022-05-01 12:15:54,189 epoch 9 - iter 5/11 - loss 2.42454610 - samples/sec: 70.64 - lr: 0.100000\n",
- "2022-05-01 12:15:54,720 epoch 9 - iter 6/11 - loss 2.39987107 - samples/sec: 60.38 - lr: 0.100000\n",
- "2022-05-01 12:15:55,192 epoch 9 - iter 7/11 - loss 2.29154910 - samples/sec: 67.94 - lr: 0.100000\n",
- "2022-05-01 12:15:55,632 epoch 9 - iter 8/11 - loss 2.22984707 - samples/sec: 73.06 - lr: 0.100000\n",
- "2022-05-01 12:15:56,162 epoch 9 - iter 9/11 - loss 2.32317919 - samples/sec: 60.49 - lr: 0.100000\n",
- "2022-05-01 12:15:56,559 epoch 9 - iter 10/11 - loss 2.24865967 - samples/sec: 80.81 - lr: 0.100000\n",
- "2022-05-01 12:15:56,986 epoch 9 - iter 11/11 - loss 2.27327953 - samples/sec: 75.12 - lr: 0.100000\n",
- "2022-05-01 12:15:56,988 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:15:56,988 EPOCH 9 done: loss 2.2733 - lr 0.1000000\n",
- "2022-05-01 12:15:57,130 DEV : loss 3.4634602069854736 - score 0.5517\n",
- "2022-05-01 12:15:57,132 BAD EPOCHS (no improvement): 0\n",
- "saving best model\n",
- "2022-05-01 12:16:04,067 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:16:04,643 epoch 10 - iter 1/11 - loss 2.22972107 - samples/sec: 55.65 - lr: 0.100000\n",
- "2022-05-01 12:16:05,144 epoch 10 - iter 2/11 - loss 2.20346498 - samples/sec: 64.00 - lr: 0.100000\n",
- "2022-05-01 12:16:05,576 epoch 10 - iter 3/11 - loss 2.07501336 - samples/sec: 74.24 - lr: 0.100000\n",
- "2022-05-01 12:16:06,036 epoch 10 - iter 4/11 - loss 2.09982607 - samples/sec: 69.72 - lr: 0.100000\n",
- "2022-05-01 12:16:06,508 epoch 10 - iter 5/11 - loss 2.08048103 - samples/sec: 67.94 - lr: 0.100000\n",
- "2022-05-01 12:16:07,062 epoch 10 - iter 6/11 - loss 2.08074635 - samples/sec: 57.87 - lr: 0.100000\n",
- "2022-05-01 12:16:07,590 epoch 10 - iter 7/11 - loss 2.07187140 - samples/sec: 60.84 - lr: 0.100000\n",
- "2022-05-01 12:16:08,116 epoch 10 - iter 8/11 - loss 2.10148455 - samples/sec: 60.95 - lr: 0.100000\n",
- "2022-05-01 12:16:08,563 epoch 10 - iter 9/11 - loss 2.06198527 - samples/sec: 71.74 - lr: 0.100000\n",
- "2022-05-01 12:16:09,066 epoch 10 - iter 10/11 - loss 2.00194792 - samples/sec: 63.75 - lr: 0.100000\n",
- "2022-05-01 12:16:09,486 epoch 10 - iter 11/11 - loss 2.00801701 - samples/sec: 76.37 - lr: 0.100000\n",
- "2022-05-01 12:16:09,487 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:16:09,488 EPOCH 10 done: loss 2.0080 - lr 0.1000000\n",
- "2022-05-01 12:16:09,624 DEV : loss 3.1866908073425293 - score 0.4706\n",
- "2022-05-01 12:16:09,625 BAD EPOCHS (no improvement): 1\n",
- "2022-05-01 12:16:16,655 ----------------------------------------------------------------------------------------------------\n",
- "2022-05-01 12:16:16,656 Testing using best model ...\n",
- "2022-05-01 12:16:16,676 loading file slot-model\\best-model.pt\n",
- "2022-05-01 12:16:22,739 0.4231\t0.3056\t0.3548\n",
- "2022-05-01 12:16:22,740 \n",
- "Results:\n",
- "- F1-score (micro) 0.3548\n",
- "- F1-score (macro) 0.2570\n",
- "\n",
- "By class:\n",
- "area tp: 1 - fp: 1 - fn: 2 - precision: 0.5000 - recall: 0.3333 - f1-score: 0.4000\n",
- "date tp: 0 - fp: 3 - fn: 3 - precision: 0.0000 - recall: 0.0000 - f1-score: 0.0000\n",
- "goal tp: 2 - fp: 2 - fn: 8 - precision: 0.5000 - recall: 0.2000 - f1-score: 0.2857\n",
- "interval tp: 0 - fp: 0 - fn: 1 - precision: 0.0000 - recall: 0.0000 - f1-score: 0.0000\n",
- "quantity tp: 4 - fp: 1 - fn: 2 - precision: 0.8000 - recall: 0.6667 - f1-score: 0.7273\n",
- "seats tp: 0 - fp: 1 - fn: 0 - precision: 0.0000 - recall: 0.0000 - f1-score: 0.0000\n",
- "time tp: 1 - fp: 4 - fn: 5 - precision: 0.2000 - recall: 0.1667 - f1-score: 0.1818\n",
- "title tp: 3 - fp: 3 - fn: 4 - precision: 0.5000 - recall: 0.4286 - f1-score: 0.4615\n",
- "2022-05-01 12:16:22,740 ----------------------------------------------------------------------------------------------------\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "{'test_score': 0.3548387096774194,\n",
- " 'dev_score_history': [0.0,\n",
- " 0.06451612903225806,\n",
- " 0.0,\n",
- " 0.17142857142857143,\n",
- " 0.16666666666666663,\n",
- " 0.23809523809523808,\n",
- " 0.3181818181818182,\n",
- " 0.38461538461538464,\n",
- " 0.5517241379310345,\n",
- " 0.47058823529411764],\n",
- " 'train_loss_history': [6.525583657351407,\n",
- " 5.26294283433394,\n",
- " 4.7632177526300605,\n",
- " 4.261161284013228,\n",
- " 3.807387958873402,\n",
- " 3.2323263558474453,\n",
- " 2.828585754741322,\n",
- " 2.6404500982978125,\n",
- " 2.2732795260169287,\n",
- " 2.0080170089548286],\n",
- " 'dev_loss_history': [8.419286727905273,\n",
- " 7.168168544769287,\n",
- " 7.209894180297852,\n",
- " 5.882441997528076,\n",
- " 5.224854469299316,\n",
- " 4.557621002197266,\n",
- " 4.020608901977539,\n",
- " 3.542769432067871,\n",
- " 3.4634602069854736,\n",
- " 3.1866908073425293]}"
- ]
- },
- "execution_count": 31,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "trainer = ModelTrainer(tagger, corpus)\n",
- "trainer.train('slot-model',\n",
- " learning_rate=0.1,\n",
- " mini_batch_size=32,\n",
- " max_epochs=10,\n",
- " train_with_dev=False)"
+    "from os.path import exists\n",
+    "\n",
+    "modelPath = 'slot-model/final-model.pt'\n",
+    "\n",
+    "# trenuj model tylko wtedy, gdy zapisany model jeszcze nie istnieje;\n",
+    "# w przeciwnym razie użyty zostanie model wczytywany w kolejnej komórce\n",
+    "if not exists(modelPath):\n",
+    "    trainer = ModelTrainer(tagger, corpus)\n",
+    "    trainer.train('slot-model',\n",
+    "                  learning_rate=0.1,\n",
+    "                  mini_batch_size=32,\n",
+    "                  max_epochs=10,\n",
+    "                  train_with_dev=False)"
]
},
{
@@ -756,19 +479,19 @@
},
{
"cell_type": "code",
- "execution_count": 32,
+ "execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "2022-05-01 12:16:22,953 loading file slot-model/final-model.pt\n"
+ "2022-05-05 17:34:34,767 loading file slot-model/final-model.pt\n"
]
}
],
"source": [
- "model = SequenceTagger.load('slot-model/final-model.pt')"
+ "model = SequenceTagger.load(modelPath)"
]
},
{
@@ -781,7 +504,7 @@
},
{
"cell_type": "code",
- "execution_count": 69,
+ "execution_count": 42,
"metadata": {},
"outputs": [
{
@@ -790,7 +513,7 @@
"[('kiedy', 'O'), ('gracie', 'O'), ('film', 'O'), ('zorro', 'B-title')]"
]
},
- "execution_count": 69,
+ "execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
@@ -815,7 +538,7 @@
},
{
"cell_type": "code",
- "execution_count": 68,
+ "execution_count": 24,
"metadata": {},
"outputs": [
{
@@ -834,7 +557,7 @@
"'\\n\\nkiedy | O |
\\ngracie | O |
\\nfilm | O |
\\nzorro | B-title |
\\n\\n
'"
]
},
- "execution_count": 68,
+ "execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
@@ -843,6 +566,67 @@
"tabulate(predict(model, 'kiedy gracie film zorro'.split()), tablefmt='html')"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 82,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "stats: \n",
+ "precision: 0.8076923076923077\n",
+ "recall: 0.4117647058823529\n",
+ "f1: 0.5454545454545454\n"
+ ]
+ }
+ ],
+ "source": [
+    "# evaluation: slot-level precision, recall and F1 on the test set (counted per token)\n",
+    "\n",
+    "def precision(tpScore, fpScore):\n",
+    "    return float(tpScore) / (tpScore + fpScore)\n",
+    "\n",
+    "def recall(tpScore, fnScore):\n",
+    "    return float(tpScore) / (tpScore + fnScore)\n",
+    "\n",
+    "def f1(precisionScore, recallScore):\n",
+    "    return 2 * precisionScore * recallScore / (precisionScore + recallScore)\n",
+    "\n",
+    "def evaluate():\n",
+    "    tp = 0\n",
+    "    fp = 0\n",
+    "    fn = 0\n",
+    "    for sentence in testset:\n",
+    "        # the sentence as a list of word forms\n",
+    "        termsList = [w[\"form\"] for w in sentence]\n",
+    "        # slot tags predicted by the tagger\n",
+    "        predTags = [tag[1] for tag in predict(model, termsList)]\n",
+    "        # gold slot tags from the annotated corpus\n",
+    "        expTags = [token[\"slot\"] for token in sentence]\n",
+    "        for i in range(len(predTags)):\n",
+    "            if expTags[i] == \"O\" and predTags[i] != \"O\":\n",
+    "                # a slot predicted where the gold annotation has none\n",
+    "                fp += 1\n",
+    "            elif expTags[i] != \"O\" and predTags[i] == \"O\":\n",
+    "                # a gold slot missed by the tagger\n",
+    "                fn += 1\n",
+    "            elif expTags[i] != \"O\" and predTags[i] == expTags[i]:\n",
+    "                # a gold slot predicted correctly\n",
+    "                tp += 1\n",
+    "\n",
+    "    precisionScore = precision(tp, fp)\n",
+    "    recallScore = recall(tp, fn)\n",
+    "    f1Score = f1(precisionScore, recallScore)\n",
+    "    print(\"stats: \")\n",
+    "    print(\"precision: \", precisionScore)\n",
+    "    print(\"recall: \", recallScore)\n",
+    "    print(\"f1: \", f1Score)\n",
+    "\n",
+    "evaluate()"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
diff --git a/lab/09-zarzadzanie-dialogiem-reguly.ipynb b/lab/09-zarzadzanie-dialogiem-reguly.ipynb
new file mode 100644
index 0000000..6d59ed6
--- /dev/null
+++ b/lab/09-zarzadzanie-dialogiem-reguly.ipynb
@@ -0,0 +1,528 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "collapsed": false
+ },
+ "source": [
+ "![Logo 1](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech1.jpg)\n",
+ "\n",
+    "Systemy Dialogowe\n",
+    "\n",
+    "9. Zarz\u0105dzanie dialogiem z wykorzystaniem regu\u0142 [laboratoria]\n",
+    "\n",
+    "Marek Kubis (2021)\n",
+ "\n",
+ "\n",
+ "![Logo 2](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech2.jpg)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Zarz\u0105dzanie dialogiem z wykorzystaniem regu\u0142\n",
+ "============================================\n",
+ "\n",
+    "Agent dialogowy wykorzystuje do zarz\u0105dzania dialogiem dwa modu\u0142y:\n",
+ "\n",
+ " - monitor stanu dialogu (dialogue state tracker, DST) \u2014 modu\u0142 odpowiedzialny za \u015bledzenie stanu dialogu.\n",
+ "\n",
+ " - taktyk\u0119 prowadzenia dialogu (dialogue policy) \u2014 modu\u0142, kt\u00f3ry na podstawie stanu dialogu\n",
+ " podejmuje decyzj\u0119 o tym jak\u0105 akcj\u0119 (akt systemu) agent ma podj\u0105\u0107 w kolejnej turze.\n",
+ "\n",
+ "Oba modu\u0142y mog\u0105 by\u0107 realizowane zar\u00f3wno z wykorzystaniem regu\u0142 jak i uczenia maszynowego.\n",
+ "Mog\u0105 one zosta\u0107 r\u00f3wnie\u017c po\u0142\u0105czone w pojedynczy modu\u0142 zwany w\u00f3wczas *mened\u017cerem dialogu*.\n",
+ "\n",
+ "Przyk\u0142ad\n",
+ "--------\n",
+ "\n",
+ "Zaimplementujemy regu\u0142owe modu\u0142y monitora stanu dialogu oraz taktyki dialogowej a nast\u0119pnie\n",
+ "osadzimy je w \u015brodowisku *[ConvLab-2](https://github.com/thu-coai/ConvLab-2)*,\n",
+ "kt\u00f3re s\u0142u\u017cy do ewaluacji system\u00f3w dialogowych.\n",
+ "\n",
+ "**Uwaga:** Niekt\u00f3re modu\u0142y \u015brodowiska *ConvLab-2* nie s\u0105 zgodne z najnowszymi wersjami Pythona,\n",
+ "dlatego przed uruchomieniem poni\u017cszych przyk\u0142ad\u00f3w nale\u017cy si\u0119 upewni\u0107, \u017ce maj\u0105 Pa\u0144stwo interpreter\n",
+ "Pythona w wersji 3.7. W przypadku nowszych wersji Ubuntu Pythona 3.7 mo\u017cna zainstalowa\u0107 z\n",
+ "repozytorium `deadsnakes`, wykonuj\u0105c polecenia przedstawione poni\u017cej.\n",
+ "\n",
+ "```\n",
+ "sudo add-apt-repository ppa:deadsnakes/ppa\n",
+ "sudo apt update\n",
+ "sudo apt install python3.7 python3.7-dev python3.7-venv\n",
+ "```\n",
+ "\n",
+ "W przypadku innych system\u00f3w mo\u017cna skorzysta\u0107 np. z narz\u0119dzia [pyenv](https://github.com/pyenv/pyenv) lub \u015brodowiska [conda](https://conda.io).\n",
+ "\n",
+ "Ze wzgl\u0119du na to, \u017ce *ConvLab-2* ma wiele zale\u017cno\u015bci zach\u0119cam r\u00f3wnie\u017c do skorzystania ze \u015brodowiska\n",
+ "wirtualnego `venv`, w kt\u00f3rym modu\u0142y zale\u017cne mog\u0105 zosta\u0107 zainstalowane.\n",
+ "W tym celu nale\u017cy wykona\u0107 nast\u0119puj\u0105ce polecenia\n",
+ "\n",
+ "```\n",
+ "python3.7 -m venv convenv # utworzenie nowego \u015brodowiska o nazwie convenv\n",
+ "source convenv/bin/activate # aktywacja \u015brodowiska w bie\u017c\u0105cej pow\u0142oce\n",
+ "pip install --ignore-installed jupyter # instalacja jupytera w \u015brodowisku convenv\n",
+ "```\n",
+ "\n",
+ "Po skonfigurowaniu \u015brodowiska mo\u017cna przyst\u0105pi\u0107 do instalacji *ConvLab-2*, korzystaj\u0105c z\n",
+ "nast\u0119puj\u0105cych polece\u0144\n",
+ "\n",
+ "```\n",
+ "mkdir -p l08\n",
+ "cd l08\n",
+ "git clone https://github.com/thu-coai/ConvLab-2.git\n",
+ "cd ConvLab-2\n",
+ "pip install -e .\n",
+ "python -m spacy download en_core_web_sm\n",
+ "cd ../..\n",
+ "```\n",
+ "\n",
+ "Po zako\u0144czeniu instalacji nale\u017cy ponownie uruchomi\u0107 notatnik w pow\u0142oce, w kt\u00f3rej aktywne jest\n",
+ "\u015brodowisko wirtualne *convenv*.\n",
+ "\n",
+ "```\n",
+ "jupyter notebook 08-zarzadzanie-dialogiem-reguly.ipynb\n",
+ "```\n",
+ "\n",
+ "Dzia\u0142anie zaimplementowanych modu\u0142\u00f3w zilustrujemy, korzystaj\u0105c ze zbioru danych\n",
+ "[MultiWOZ](https://github.com/budzianowski/multiwoz) (Budzianowski i in., 2018), kt\u00f3ry zawiera\n",
+ "wypowiedzi dotycz\u0105ce m.in. rezerwacji pokoi hotelowych, zamawiania bilet\u00f3w kolejowych oraz\n",
+ "rezerwacji stolik\u00f3w w restauracji.\n",
+ "\n",
+ "### Monitor Stanu Dialogu\n",
+ "\n",
+ "Do reprezentowania stanu dialogu u\u017cyjemy struktury danych wykorzystywanej w *ConvLab-2*."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from convlab2.util.multiwoz.state import default_state\n",
+ "default_state()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Metoda `update` naszego monitora stanu dialogu b\u0119dzie przyjmowa\u0107 akty u\u017cytkownika i odpowiednio\n",
+ "modyfikowa\u0107 stan dialogu.\n",
+ "W przypadku akt\u00f3w typu `inform` warto\u015bci slot\u00f3w zostan\u0105 zapami\u0119tane w s\u0142ownikach odpowiadaj\u0105cych\n",
+ "poszczeg\u00f3lnym dziedzinom pod kluczem `belief_state`.\n",
+    "W przypadku akt\u00f3w typu `request` sloty, o kt\u00f3re pyta u\u017cytkownik, zostan\u0105 zapisane pod kluczem\n",
+ "`request_state`.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "import os\n",
+ "from convlab2.dst.dst import DST\n",
+ "from convlab2.dst.rule.multiwoz.dst_util import normalize_value\n",
+ "from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA\n",
+ "\n",
+ "\n",
+ "class SimpleRuleDST(DST):\n",
+ " def __init__(self):\n",
+ " DST.__init__(self)\n",
+ " self.state = default_state()\n",
+ " self.value_dict = json.load(open('l08/ConvLab-2/data/multiwoz/value_dict.json'))\n",
+ "\n",
+ " def update(self, user_act=None):\n",
+ " for intent, domain, slot, value in user_act:\n",
+ " domain = domain.lower()\n",
+ " intent = intent.lower()\n",
+ "\n",
+ " if domain in ['unk', 'general', 'booking']:\n",
+ " continue\n",
+ "\n",
+ " if intent == 'inform':\n",
+ " k = REF_SYS_DA[domain.capitalize()].get(slot, slot)\n",
+ "\n",
+ " if k is None:\n",
+ " continue\n",
+ "\n",
+ " domain_dic = self.state['belief_state'][domain]\n",
+ "\n",
+ " if k in domain_dic['semi']:\n",
+ " nvalue = normalize_value(self.value_dict, domain, k, value)\n",
+ " self.state['belief_state'][domain]['semi'][k] = nvalue\n",
+ " elif k in domain_dic['book']:\n",
+ " self.state['belief_state'][domain]['book'][k] = value\n",
+ " elif k.lower() in domain_dic['book']:\n",
+ " self.state['belief_state'][domain]['book'][k.lower()] = value\n",
+ " elif intent == 'request':\n",
+ " k = REF_SYS_DA[domain.capitalize()].get(slot, slot)\n",
+ "\n",
+ " if domain not in self.state['request_state']:\n",
+ " self.state['request_state'][domain] = {}\n",
+ " if k not in self.state['request_state'][domain]:\n",
+ " self.state['request_state'][domain][k] = 0\n",
+ "\n",
+ " return self.state\n",
+ "\n",
+ " def init_session(self):\n",
+ " self.state = default_state()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "W definicji metody `update` zak\u0142adamy, \u017ce akty dialogowe przekazywane do monitora stanu dialogu z\n",
+ "modu\u0142u NLU s\u0105 czteroelementowymi listami z\u0142o\u017conymi z:\n",
+ "\n",
+ " - nazwy aktu u\u017cytkownika,\n",
+ " - nazwy dziedziny, kt\u00f3rej dotyczy wypowied\u017a,\n",
+ " - nazwy slotu,\n",
+ " - warto\u015bci slotu.\n",
+ "\n",
+ "Zobaczmy na kilku prostych przyk\u0142adach jak stan dialogu zmienia si\u0119 pod wp\u0142ywem przekazanych akt\u00f3w\n",
+ "u\u017cytkownika."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "dst = SimpleRuleDST()\n",
+ "dst.state"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "dst.update([['Inform', 'Hotel', 'Price', 'cheap'], ['Inform', 'Hotel', 'Parking', 'yes']])\n",
+ "dst.state['belief_state']['hotel']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "dst.update([['Inform', 'Hotel', 'Area', 'north']])\n",
+ "dst.state['belief_state']['hotel']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "dst.update([['Request', 'Hotel', 'Area', '?']])\n",
+ "dst.state['request_state']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "dst.update([['Inform', 'Hotel', 'Day', 'tuesday'], ['Inform', 'Hotel', 'People', '2'], ['Inform', 'Hotel', 'Stay', '4']])\n",
+ "dst.state['belief_state']['hotel']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dst.state"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Taktyka Prowadzenia Dialogu\n",
+ "\n",
+ "Prosta taktyka prowadzenia dialogu dla systemu rezerwacji pokoi hotelowych mo\u017ce sk\u0142ada\u0107 si\u0119 z nast\u0119puj\u0105cych regu\u0142:\n",
+ "\n",
+ " 1. Je\u017celi u\u017cytkownik przekaza\u0142 w ostatniej turze akt typu `Request`, to udziel odpowiedzi na jego\n",
+ " pytanie.\n",
+ "\n",
+ " 2. Je\u017celi u\u017cytkownik przekaza\u0142 w ostatniej turze akt typu `Inform`, to zaproponuj mu hotel\n",
+ " spe\u0142niaj\u0105cy zdefiniowane przez niego kryteria.\n",
+ "\n",
+ " 3. Je\u017celi u\u017cytkownik przekaza\u0142 w ostatniej turze akt typu `Inform` zawieraj\u0105cy szczeg\u00f3\u0142y\n",
+ " rezerwacji, to zarezerwuj pok\u00f3j.\n",
+ "\n",
+ "Metoda `predict` taktyki `SimpleRulePolicy` realizuje regu\u0142y przedstawione powy\u017cej."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from collections import defaultdict\n",
+ "import copy\n",
+ "import json\n",
+ "from copy import deepcopy\n",
+ "\n",
+ "from convlab2.policy.policy import Policy\n",
+ "from convlab2.util.multiwoz.dbquery import Database\n",
+ "from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA\n",
+ "\n",
+ "\n",
+ "class SimpleRulePolicy(Policy):\n",
+ " def __init__(self):\n",
+ " Policy.__init__(self)\n",
+ " self.db = Database()\n",
+ "\n",
+ " def predict(self, state):\n",
+ " self.results = []\n",
+ " system_action = defaultdict(list)\n",
+ " user_action = defaultdict(list)\n",
+ "\n",
+ " for intent, domain, slot, value in state['user_action']:\n",
+ " user_action[(domain, intent)].append((slot, value))\n",
+ "\n",
+ " for user_act in user_action:\n",
+ " self.update_system_action(user_act, user_action, state, system_action)\n",
+ "\n",
+ " # Regu\u0142a 3\n",
+ " if any(True for slots in user_action.values() for (slot, _) in slots if slot in ['Stay', 'Day', 'People']):\n",
+ " if self.results:\n",
+ " system_action = {('Booking', 'Book'): [[\"Ref\", self.results[0].get('Ref', 'N/A')]]}\n",
+ "\n",
+ " system_acts = [[intent, domain, slot, value] for (domain, intent), slots in system_action.items() for slot, value in slots]\n",
+ " state['system_action'] = system_acts\n",
+ " return system_acts\n",
+ "\n",
+ " def update_system_action(self, user_act, user_action, state, system_action):\n",
+ " domain, intent = user_act\n",
+ " constraints = [(slot, value) for slot, value in state['belief_state'][domain.lower()]['semi'].items() if value != '']\n",
+ " self.results = deepcopy(self.db.query(domain.lower(), constraints))\n",
+ "\n",
+ " # Regu\u0142a 1\n",
+ " if intent == 'Request':\n",
+ " if len(self.results) == 0:\n",
+ " system_action[(domain, 'NoOffer')] = []\n",
+ " else:\n",
+ " for slot in user_action[user_act]:\n",
+ " kb_slot_name = REF_SYS_DA[domain].get(slot[0], slot[0])\n",
+ "\n",
+ " if kb_slot_name in self.results[0]:\n",
+ " system_action[(domain, 'Inform')].append([slot[0], self.results[0].get(kb_slot_name, 'unknown')])\n",
+ "\n",
+ " # Regu\u0142a 2\n",
+ " elif intent == 'Inform':\n",
+ " if len(self.results) == 0:\n",
+ " system_action[(domain, 'NoOffer')] = []\n",
+ " else:\n",
+ " system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])\n",
+ " choice = self.results[0]\n",
+ "\n",
+ " if domain in [\"Hotel\", \"Attraction\", \"Police\", \"Restaurant\"]:\n",
+ " system_action[(domain, 'Recommend')].append(['Name', choice['name']])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "    - nazwy aktu systemowego,\n",
+ "\n",
+ " - nazwy aktu systemowe,\n",
+ " - nazwy dziedziny, kt\u00f3rej dotyczy wypowied\u017a,\n",
+ " - nazwy slotu,\n",
+ " - warto\u015bci slotu.\n",
+ "\n",
+    "Sprawd\u017amy, jakie akty systemowe zwraca taktyka `SimpleRulePolicy` w odpowiedzi na zmieniaj\u0105cy si\u0119 stan dialogu.\n",
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "from convlab2.dialog_agent import PipelineAgent\n",
+ "dst.init_session()\n",
+ "policy = SimpleRulePolicy()\n",
+ "agent = PipelineAgent(nlu=None, dst=dst, policy=policy, nlg=None, name='sys')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "agent.response([['Inform', 'Hotel', 'Price', 'cheap'], ['Inform', 'Hotel', 'Parking', 'yes']])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "agent.response([['Inform', 'Hotel', 'Area', 'north']])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "agent.response([['Request', 'Hotel', 'Area', '?']])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "agent.response([['Inform', 'Hotel', 'Day', 'tuesday'], ['Inform', 'Hotel', 'People', '2'], ['Inform', 'Hotel', 'Stay', '4']])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Testy End-to-End\n",
+ "\n",
+ "Na koniec przeprowad\u017amy dialog \u0142\u0105cz\u0105c w potok nasze modu\u0142y\n",
+ "z modu\u0142ami NLU i NLG dost\u0119pnymi dla MultiWOZ w \u015brodowisku `ConvLab-2`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from convlab2.nlu.svm.multiwoz import SVMNLU\n",
+ "from convlab2.nlg.template.multiwoz import TemplateNLG\n",
+ "\n",
+ "nlu = SVMNLU()\n",
+ "nlg = TemplateNLG(is_user=False)\n",
+ "agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "agent.response(\"I need a cheap hotel with free parking .\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "agent.response(\"Where it is located ?\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "agent.response(\"I would prefer the hotel be in the north part of town .\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "agent.response(\"Yeah , could you book me a room for 2 people for 4 nights starting Tuesday ?\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Zauwa\u017cmy, \u017ce nasza prosta taktyka dialogowa zawiera wiele luk, do kt\u00f3rych nale\u017c\u0105 m.in.:\n",
+ "\n",
+ " 1. Niezdolno\u015b\u0107 do udzielenia odpowiedzi na przywitanie, pro\u015bb\u0119 o pomoc lub restart.\n",
+ "\n",
+ " 2. Brak regu\u0142 dopytuj\u0105cych u\u017cytkownika o szczeg\u00f3\u0142y niezb\u0119dne do dokonania rezerwacji takie, jak d\u0142ugo\u015b\u0107 pobytu czy liczba os\u00f3b.\n",
+ "\n",
+ "Bardziej zaawansowane modu\u0142y zarz\u0105dzania dialogiem zbudowane z wykorzystaniem regu\u0142 mo\u017cna znale\u017a\u0107 w\n",
+ "\u015brodowisku `ConvLab-2`. Nale\u017c\u0105 do nich m.in. monitor [RuleDST](https://github.com/thu-coai/ConvLab-2/blob/master/convlab2/dst/rule/multiwoz/dst.py) oraz taktyka [RuleBasedMultiwozBot](https://github.com/thu-coai/ConvLab-2/blob/master/convlab2/policy/rule/multiwoz/rule_based_multiwoz_bot.py).\n",
+ "\n",
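+    "Poni\u017cszy szkic (wy\u0142\u0105cznie pogl\u0105dowy; konkretne nazwy akt\u00f3w, np. `greet` i `general`, zale\u017c\u0105 od u\u017cytego modu\u0142u NLU i szablon\u00f3w NLG) pokazuje, jak mo\u017cna by za\u0142ata\u0107 pierwsz\u0105 z powy\u017cszych luk, rozszerzaj\u0105c taktyk\u0119 `SimpleRulePolicy` o regu\u0142\u0119 odpowiadaj\u0105c\u0105 na przywitanie:\n",
+    "\n",
+    "```python\n",
+    "class GreetingRulePolicy(SimpleRulePolicy):\n",
+    "    # szkic: akty og\u00f3lne (np. przywitanie) obs\u0142ugujemy, zanim zadzia\u0142aj\u0105 regu\u0142y 1-3\n",
+    "    def predict(self, state):\n",
+    "        greetings = [\n",
+    "            [intent, domain, slot, value]\n",
+    "            for intent, domain, slot, value in state['user_action']\n",
+    "            if domain.lower() == 'general' and intent.lower() in ('greet', 'hello')\n",
+    "        ]\n",
+    "        if greetings:\n",
+    "            # zak\u0142adane nazwy aktu systemowego; w praktyce nale\u017cy je dopasowa\u0107 do szablon\u00f3w NLG\n",
+    "            system_acts = [['greet', 'general', 'none', 'none']]\n",
+    "            state['system_action'] = system_acts\n",
+    "            return system_acts\n",
+    "        return super().predict(state)\n",
+    "```\n",
+    "\n",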
+ "Zadania\n",
+ "-------\n",
+ " 1. Zaimplementowa\u0107 w projekcie monitor stanu dialogu.\n",
+ "\n",
+ " 2. Zaimplementowa\u0107 w projekcie taktyk\u0119 prowadzenia dialogu.\n",
+ "\n",
+ "Termin: 24.05.2021, godz. 23:59.\n",
+ "\n",
+ "Literatura\n",
+ "----------\n",
+ " 1. Pawel Budzianowski, Tsung-Hsien Wen, Bo-Hsiang Tseng, I\u00f1igo Casanueva, Stefan Ultes, Osman Ramadan, Milica Gasic, MultiWOZ - A Large-Scale Multi-Domain Wizard-of-Oz Dataset for Task-Oriented Dialogue Modelling. EMNLP 2018, pp. 5016-5026\n",
+ " 2. Cathy Pearl, Basic principles for designing voice user interfaces, https://www.oreilly.com/content/basic-principles-for-designing-voice-user-interfaces/ data dost\u0119pu: 21 marca 2021\n",
+ " 3. Cathy Pearl, Designing Voice User Interfaces, Excerpts from Chapter 5: Advanced Voice User Interface Design, https://www.uxmatters.com/mt/archives/2018/01/designing-voice-user-interfaces.php data dost\u0119pu: 21 marca 2021"
+ ]
+ }
+ ],
+ "metadata": {
+ "jupytext": {
+ "cell_metadata_filter": "-all",
+ "main_language": "python",
+ "notebook_metadata_filter": "-all"
+ },
+ "author": "Marek Kubis",
+ "email": "mkubis@amu.edu.pl",
+ "lang": "pl",
+  "subtitle": "9. Zarz\u0105dzanie dialogiem z wykorzystaniem regu\u0142 [laboratoria]",
+ "title": "Systemy Dialogowe",
+ "year": "2021"
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
\ No newline at end of file