518 lines
17 KiB
Plaintext
518 lines
17 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d8790b55",
|
|
"metadata": {},
|
|
"source": [
|
|
"Zarządzanie dialogiem z wykorzystaniem technik uczenia maszynowego\n",
|
|
"==================================================================\n",
|
|
"\n",
|
|
"Uczenie przez wzmacnianie\n",
|
|
"-------------------------\n",
|
|
"\n",
|
|
"Zamiast ręcznie implementować zbiór reguł odpowiedzialnych za wyznaczenie akcji, którą powinien podjąć agent będąc w danym stanie, odpowiednią taktykę prowadzenia dialogu można zbudować, wykorzystując techniki uczenia maszynowego.\n",
|
|
"\n",
|
|
"Obok metod uczenia nadzorowanego, które wykorzystaliśmy do zbudowania modelu NLU, do konstruowania taktyk\n",
|
|
"prowadzenia dialogu wykorzystuje się również *uczenie przez wzmacnianie* (ang. *reinforcement learning*).\n",
|
|
"\n",
|
|
"W tym ujęciu szukać będziemy funkcji $Q*: S \\times A \\to R$, która dla stanu dialogu $s \\in S$ oraz aktu\n",
|
|
"dialogowego $a \\in A$ zwraca nagrodę (ang. *reward*) $r \\in R$, tj. wartość rzeczywistą pozwalającą ocenić na ile\n",
|
|
"podjęcie akcji $a$ w stanie $s$ jest korzystne.\n",
|
|
"\n",
|
|
"Założymy również, że poszukiwana funkcja powinna maksymalizować *zwrot* (ang. *return*), tj.\n",
|
|
"skumulowaną nagrodę w toku prowadzonego dialogu, czyli dla tury $t_0$ cel uczenia powinien mieć postać:\n",
|
|
"\n",
|
|
"$$ \\sum_{t=t_0}^{\\infty}{\\gamma^{t-1}r_t} $$\n",
|
|
"\n",
|
|
"gdzie:\n",
|
|
"\n",
|
|
" - $t$: tura agenta,\n",
|
|
"\n",
|
|
" - $r_t$: nagroda w turze $t$,\n",
|
|
"\n",
|
|
" - $\\gamma \\in [0, 1]$: współczynnik dyskontowy (w przypadku agentów dialogowych bliżej $1$ niż $0$, por. np. Rieser i Lemon (2011))."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "18c3a6c0",
|
|
"metadata": {},
|
|
"source": [
|
|
"Agent dialogowy w procesie uczenia przez wzmacnianie wchodzi w interakcję ze *środowiskiem*, które\n",
|
|
"dla akcji podejmowanej przez taktykę prowadzenia dialogu zwraca kolejny stan oraz nagrodę powiązaną z\n",
|
|
"wykonaniem tej akcji w bieżącym stanie.\n",
|
|
"\n",
|
|
"Sposób w jaki informacje pochodzące ze środowiska są wykorzystywane do znalezienia funkcji $Q*$\n",
|
|
"zależy od wybranej metody uczenia.\n",
|
|
"W przykładzie przestawionym poniżej skorzystamy z algorytmu $DQN$ (Mnih i in., 2013) co oznacza, że:\n",
|
|
"\n",
|
|
" 1. będziemy aproksymować funkcję $Q*$ siecią neuronową,\n",
|
|
"\n",
|
|
" 2. wagi sieci będziemy wyznaczać korzystając z metody spadku gradientu.\n",
|
|
"\n",
|
|
"Przykład\n",
|
|
"--------\n",
|
|
"\n",
|
|
"Ze względu na to, że implementacja algorytmu $DQN$ nie została jeszcze przystosowana do wersji 3.0 środowiska `ConvLab` skorzystamy z wersji 2.0."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "50675b87",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"!mkdir -p l10\n",
|
|
"%cd l10\n",
|
|
"!git clone --depth 1 https://github.com/thu-coai/ConvLab-2.git\n",
|
|
"%cd ConvLab-2\n",
|
|
"!pip install -e .\n",
|
|
"%cd ../.."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "407a0ebe",
|
|
"metadata": {},
|
|
"source": [
|
|
"Po zainstalowaniu środowiska `ConvLab-2` należy zrestartować interpreter Pythona (opcja *Kernel -> Restart* w Jupyter)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "3b4cbd96",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from convlab2.dialog_agent.agent import PipelineAgent\n",
|
|
"from convlab2.dialog_agent.env import Environment\n",
|
|
"from convlab2.dst.rule.multiwoz import RuleDST\n",
|
|
"from convlab2.policy.rule.multiwoz import RulePolicy\n",
|
|
"from convlab2.policy.dqn import DQN\n",
|
|
"from convlab2.policy.rlmodule import Memory\n",
|
|
"from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator\n",
|
|
"import logging\n",
|
|
"\n",
|
|
"logging.disable(logging.DEBUG)\n",
|
|
"\n",
|
|
"# determinizacja obliczeń\n",
|
|
"import random\n",
|
|
"import torch\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"np.random.seed(123)\n",
|
|
"random.seed(123)\n",
|
|
"torch.manual_seed(123)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "cd7eeeab",
|
|
"metadata": {},
|
|
"source": [
|
|
"Środowisko, z którym agent będzie wchodził w interakcje zawierać będzie\n",
|
|
"symulator użytkownika wykorzystujący taktykę prowadzenia dialogu zbudowaną z wykorzystaniem reguł."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "c1a91522",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"usr_policy = RulePolicy(character='usr')\n",
|
|
"usr_simulator = PipelineAgent(None, None, usr_policy, None, 'user') # type: ignore\n",
|
|
"\n",
|
|
"dst = RuleDST()\n",
|
|
"evaluator = MultiWozEvaluator()\n",
|
|
"env = Environment(None, usr_simulator, None, dst, evaluator)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "331c14c9",
|
|
"metadata": {},
|
|
"source": [
|
|
"Zobaczmy jak w *ConvLab-2* zdefiniowana jest nagroda w klasie `Environment`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "031bd292",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%%script false --no-raise-error\n",
|
|
"#\n",
|
|
"# plik convlab2/dialog_agent/env.py\n",
|
|
"#\n",
|
|
"class Environment():\n",
|
|
"\n",
|
|
" # (...)\n",
|
|
"\n",
|
|
" def step(self, action):\n",
|
|
"\n",
|
|
" # (...)\n",
|
|
"\n",
|
|
" if self.evaluator:\n",
|
|
" if self.evaluator.task_success():\n",
|
|
" reward = 40\n",
|
|
" elif self.evaluator.cur_domain and self.evaluator.domain_success(self.evaluator.cur_domain):\n",
|
|
" reward = 5\n",
|
|
" else:\n",
|
|
" reward = -1\n",
|
|
" else:\n",
|
|
" reward = self.usr.get_reward()\n",
|
|
" terminated = self.usr.is_terminated()\n",
|
|
"\n",
|
|
" return state, reward, terminated\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2ef9da55",
|
|
"metadata": {},
|
|
"source": [
|
|
"Jak można zauważyć powyżej akcja, która prowadzi do pomyślnego zakończenia zadania uzyskuje nagrodę $40$,\n",
|
|
"akcja która prowadzi do prawidłowego rozpoznania dziedziny uzyskuje nagrodę $5$,\n",
|
|
"natomiast każda inna akcja uzyskuje \"karę\" $-1$. Taka definicja zwrotu premiuje krótkie dialogi\n",
|
|
"prowadzące do pomyślnego wykonania zadania.\n",
|
|
"\n",
|
|
"Sieć neuronowa, którą wykorzystamy do aproksymacji funkcji $Q*$ ma następującą architekturę"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "821a46a6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%%script false --no-raise-error\n",
|
|
"#\n",
|
|
"# plik convlab2/policy/rlmodule.py\n",
|
|
"# klasa EpsilonGreedyPolicy wykorzystywana w DQN\n",
|
|
"#\n",
|
|
"class EpsilonGreedyPolicy(nn.Module):\n",
|
|
" def __init__(self, s_dim, h_dim, a_dim, epsilon_spec={'start': 0.1, 'end': 0.0, 'end_epoch': 200}):\n",
|
|
" super(EpsilonGreedyPolicy, self).__init__()\n",
|
|
"\n",
|
|
" self.net = nn.Sequential(nn.Linear(s_dim, h_dim),\n",
|
|
" nn.ReLU(),\n",
|
|
" nn.Linear(h_dim, h_dim),\n",
|
|
" nn.ReLU(),\n",
|
|
" nn.Linear(h_dim, a_dim))\n",
|
|
" # (...)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "5d47ae91",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"policy = DQN(is_train=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "de41de6f",
|
|
"metadata": {},
|
|
"source": [
|
|
"Każdy krok procedury uczenia składa się z dwóch etapów:\n",
|
|
"\n",
|
|
" 1. Wygenerowania przy użyciu taktyki (metoda `policy.predict`) oraz środowiska (metoda `env.step`) *trajektorii*, tj. sekwencji przejść pomiędzy stanami złożonych z krotek postaci:\n",
|
|
" - stanu źródłowego,\n",
|
|
" - podjętej akcji (aktu systemowego),\n",
|
|
" - nagrody,\n",
|
|
" - stanu docelowego,\n",
|
|
" - znacznika końca dialogu."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "9063d436",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# por. ConvLab-2/convlab2/policy/dqn/train.py\n",
|
|
"def sample(env, policy, batch_size, warm_up):\n",
|
|
" buff = Memory()\n",
|
|
" sampled_num = 0\n",
|
|
" max_trajectory_len = 50\n",
|
|
"\n",
|
|
" while sampled_num < batch_size:\n",
|
|
" # rozpoczęcie nowego dialogu\n",
|
|
" s = env.reset()\n",
|
|
"\n",
|
|
" for t in range(max_trajectory_len):\n",
|
|
" try:\n",
|
|
" # podjęcie akcji przez agenta dialogowego\n",
|
|
" a = policy.predict(s, warm_up=warm_up)\n",
|
|
"\n",
|
|
" # odpowiedź środowiska na podjętą akcje\n",
|
|
" next_s, r, done = env.step(a)\n",
|
|
"\n",
|
|
" # dodanie krotki do zbioru danych\n",
|
|
" buff.push(torch.Tensor(policy.vector.state_vectorize(s)).numpy(), # stan źródłowy\n",
|
|
" policy.vector.action_vectorize(a), # akcja\n",
|
|
" r, # nagroda\n",
|
|
" torch.Tensor(policy.vector.state_vectorize(next_s)).numpy(), # stan docelowy\n",
|
|
" 0 if done else 1) # znacznik końca\n",
|
|
"\n",
|
|
" s = next_s\n",
|
|
"\n",
|
|
" if done:\n",
|
|
" break\n",
|
|
" except:\n",
|
|
" break\n",
|
|
"\n",
|
|
" sampled_num += t\n",
|
|
"\n",
|
|
" return buff"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "281d76a1",
|
|
"metadata": {},
|
|
"source": [
|
|
" 2. Wykorzystania wygenerowanych krotek do aktualizacji taktyki.\n",
|
|
"\n",
|
|
"Funkcja `train` realizująca pojedynczy krok uczenia przez wzmacnianie ma następującą postać"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a42a21da",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def train(env, policy, batch_size, epoch, warm_up):\n",
|
|
" print(f'epoch: {epoch}')\n",
|
|
" buff = sample(env, policy, batch_size, warm_up)\n",
|
|
" policy.update_memory(buff)\n",
|
|
" policy.update(epoch)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "074508c2",
|
|
"metadata": {},
|
|
"source": [
|
|
"Metoda `update` klasy `DQN` wykorzystywana do aktualizacji wag ma następującą postać"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "6ce332bd",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%%script false --no-raise-error\n",
|
|
"#\n",
|
|
"# plik convlab2/policy/dqn/dqn.py\n",
|
|
"# klasa DQN\n",
|
|
"#\n",
|
|
"class DQN(Policy):\n",
|
|
" # (...)\n",
|
|
" def update(self, epoch):\n",
|
|
" total_loss = 0.\n",
|
|
" for i in range(self.training_iter):\n",
|
|
" round_loss = 0.\n",
|
|
" # 1. batch a sample from memory\n",
|
|
" batch = self.memory.get_batch(batch_size=self.batch_size)\n",
|
|
"\n",
|
|
" for _ in range(self.training_batch_iter):\n",
|
|
" # 2. calculate the Q loss\n",
|
|
" loss = self.calc_q_loss(batch)\n",
|
|
"\n",
|
|
" # 3. make a optimization step\n",
|
|
" self.net_optim.zero_grad()\n",
|
|
" loss.backward()\n",
|
|
" self.net_optim.step()\n",
|
|
"\n",
|
|
" round_loss += loss.item()\n",
|
|
"\n",
|
|
" # (...)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "60e3db5d",
|
|
"metadata": {},
|
|
"source": [
|
|
"Przebieg procesu uczenia zilustrujemy wykonując 10 iteracji. W każdej iteracji ograniczymy się do 100 przykładów."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "6d50da62",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"epoch = 10\n",
|
|
"batch_size = 100\n",
|
|
"\n",
|
|
"train(env, policy, batch_size, 0, warm_up=True)\n",
|
|
"\n",
|
|
"for i in range(1, epoch):\n",
|
|
" train(env, policy, batch_size, i, warm_up=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ecf3ee2b",
|
|
"metadata": {},
|
|
"source": [
|
|
"Sprawdźmy jakie akty systemowe zwraca taktyka `DQN` w odpowiedzi na zmieniający się stan dialogu."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "575c0211",
|
|
"metadata": {
|
|
"lines_to_next_cell": 0
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from convlab2.dialog_agent import PipelineAgent\n",
|
|
"dst.init_session()\n",
|
|
"agent = PipelineAgent(nlu=None, dst=dst, policy=policy, nlg=None, name='sys')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ef846914",
|
|
"metadata": {
|
|
"lines_to_next_cell": 0
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"agent.response([['Inform', 'Hotel', 'Price', 'cheap'], ['Inform', 'Hotel', 'Parking', 'yes']])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "6f0aaa0d",
|
|
"metadata": {
|
|
"lines_to_next_cell": 0
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"agent.response([['Inform', 'Hotel', 'Area', 'north']])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "64cb7ad5",
|
|
"metadata": {
|
|
"lines_to_next_cell": 0
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"agent.response([['Request', 'Hotel', 'Area', '?']])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "157f3e0c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"agent.response([['Inform', 'Hotel', 'Day', 'tuesday'], ['Inform', 'Hotel', 'People', '2'], ['Inform', 'Hotel', 'Stay', '4']])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "9bc4474b",
|
|
"metadata": {},
|
|
"source": [
|
|
"Jakość wyuczonego modelu możemy ocenić mierząc tzw. wskaźnik sukcesu (ang. *task success rate*),\n",
|
|
"tj. stosunek liczby dialogów zakończonych powodzeniem do liczby wszystkich dialogów."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "58c875d7",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from convlab2.dialog_agent.session import BiSession\n",
|
|
"\n",
|
|
"sess = BiSession(agent, usr_simulator, None, evaluator)\n",
|
|
"dialog_num = 100\n",
|
|
"task_success_num = 0\n",
|
|
"max_turn_num = 50\n",
|
|
"\n",
|
|
"# por. ConvLab-2/convlab2/policy/evaluate.py\n",
|
|
"for dialog in range(dialog_num):\n",
|
|
" random.seed(dialog)\n",
|
|
" np.random.seed(dialog)\n",
|
|
" torch.manual_seed(dialog)\n",
|
|
" sess.init_session()\n",
|
|
" sys_act = []\n",
|
|
" task_success = 0\n",
|
|
"\n",
|
|
" for _ in range(max_turn_num):\n",
|
|
" sys_act, _, finished, _ = sess.next_turn(sys_act)\n",
|
|
"\n",
|
|
" if finished is True:\n",
|
|
" task_success = sess.evaluator.task_success()\n",
|
|
" break\n",
|
|
"\n",
|
|
" print(f'dialog: {dialog:02} success: {task_success}')\n",
|
|
" task_success_num += task_success\n",
|
|
"\n",
|
|
"print('')\n",
|
|
"print(f'task success rate: {task_success_num/dialog_num:.2f}')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "5313885b",
|
|
"metadata": {},
|
|
"source": [
|
|
"**Uwaga**: Chcąc uzyskać taktykę o skuteczności porównywalnej z wynikami przedstawionymi na stronie\n",
|
|
"[ConvLab-2](https://github.com/thu-coai/ConvLab-2/blob/master/README.md) trzeba odpowiednio\n",
|
|
"zwiększyć zarówno liczbę iteracji jak i liczbę przykładów generowanych w każdym przyroście.\n",
|
|
"W celu przyśpieszenia procesu uczenia warto zrównoleglić obliczenia, jak pokazano w\n",
|
|
"skrypcie [train.py](https://github.com/thu-coai/ConvLab-2/blob/master/convlab2/policy/dqn/train.py).\n",
|
|
"\n",
|
|
"Literatura\n",
|
|
"----------\n",
|
|
" 1. Rieser, V., Lemon, O., (2011). Reinforcement learning for adaptive dialogue systems: a data-driven methodology for dialogue management and natural language generation. (Theory and Applications of Natural Language Processing). Springer. https://doi.org/10.1007/978-3-642-24942-6\n",
|
|
"\n",
|
|
" 2. Richard S. Sutton and Andrew G. Barto, (2018). Reinforcement Learning: An Introduction, Second Edition, MIT Press, Cambridge, MA http://incompleteideas.net/book/RLbook2020.pdf\n",
|
|
"\n",
|
|
" 3. Volodymyr Mnih and Koray Kavukcuoglu and David Silver and Alex Graves and Ioannis Antonoglou and Daan Wierstra and Martin Riedmiller, (2013). Playing Atari with Deep Reinforcement Learning, NIPS Deep Learning Workshop, https://arxiv.org/pdf/1312.5602.pdf"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"jupytext": {
|
|
"cell_metadata_filter": "-all",
|
|
"main_language": "python",
|
|
"notebook_metadata_filter": "-all"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|