add notebooks
This commit is contained in:
parent
d9c9c06603
commit
5cde582e5e
494
notebooks/09-zarzadzanie-dialogiem-reguly.ipynb
Normal file
494
notebooks/09-zarzadzanie-dialogiem-reguly.ipynb
Normal file
@ -0,0 +1,494 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "90c05009",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Zarządzanie dialogiem z wykorzystaniem reguł\n",
|
||||||
|
"============================================\n",
|
||||||
|
"\n",
|
||||||
|
"Agent dialogowy wykorzystuje do zarządzanie dialogiem dwa moduły:\n",
|
||||||
|
"\n",
|
||||||
|
" - monitor stanu dialogu (dialogue state tracker, DST) — moduł odpowiedzialny za śledzenie stanu dialogu.\n",
|
||||||
|
"\n",
|
||||||
|
" - taktykę prowadzenia dialogu (dialogue policy) — moduł, który na podstawie stanu dialogu\n",
|
||||||
|
" podejmuje decyzję o tym jaką akcję (akt systemu) agent ma podjąć w kolejnej turze.\n",
|
||||||
|
"\n",
|
||||||
|
"Oba moduły mogą być realizowane zarówno z wykorzystaniem reguł jak i uczenia maszynowego.\n",
|
||||||
|
"Mogą one zostać również połączone w pojedynczy moduł zwany wówczas *menedżerem dialogu*.\n",
|
||||||
|
"\n",
|
||||||
|
"Przykład\n",
|
||||||
|
"--------\n",
|
||||||
|
"\n",
|
||||||
|
"Zaimplementujemy regułowe moduły monitora stanu dialogu oraz taktyki dialogowej a następnie\n",
|
||||||
|
"osadzimy je w środowisku *[ConvLab](https://github.com/ConvLab/ConvLab-3)*,\n",
|
||||||
|
"które służy do ewaluacji systemów dialogowych.\n",
|
||||||
|
"\n",
|
||||||
|
"**Uwaga:** Niektóre moduły środowiska *ConvLab* nie są zgodne z najnowszymi wersjami Pythona,\n",
|
||||||
|
"dlatego przed uruchomieniem poniższych przykładów należy się upewnić, że mają Państwo interpreter\n",
|
||||||
|
"Pythona w wersji 3.8.\n",
|
||||||
|
"Odpowiednią wersję Pythona można zainstalować korzystając m.in. z narzędzia [pyenv](https://github.com/pyenv/pyenv) oraz środowiska [conda](https://conda.io).\n",
|
||||||
|
"\n",
|
||||||
|
"Środowisko *ConvLab* można zainstalować korzystając z poniższych poleceń."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "4205706b",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!mkdir -p l09\n",
|
||||||
|
"%cd l09\n",
|
||||||
|
"!git clone --depth 1 https://github.com/ConvLab/ConvLab-3\n",
|
||||||
|
"%cd ConvLab-3\n",
|
||||||
|
"!pip install -e .\n",
|
||||||
|
"%cd ../.."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "c14555bd",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Po zainstalowaniu środowiska `ConvLab` należy zrestartować interpreter Pythona (opcja *Kernel -> Restart* w Jupyter).\n",
|
||||||
|
"\n",
|
||||||
|
"Działanie zaimplementowanych modułów zilustrujemy, korzystając ze zbioru danych\n",
|
||||||
|
"[MultiWOZ](https://github.com/budzianowski/multiwoz) (Budzianowski i in., 2018), który zawiera\n",
|
||||||
|
"wypowiedzi dotyczące m.in. rezerwacji pokoi hotelowych, zamawiania biletów kolejowych oraz\n",
|
||||||
|
"rezerwacji stolików w restauracji.\n",
|
||||||
|
"\n",
|
||||||
|
"### Monitor Stanu Dialogu\n",
|
||||||
|
"\n",
|
||||||
|
"Do reprezentowania stanu dialogu użyjemy struktury danych wykorzystywanej w *ConvLab*."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "38c4de37",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from convlab.util.multiwoz.state import default_state\n",
|
||||||
|
"default_state()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "09fecf16",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Metoda `update` naszego monitora stanu dialogu będzie przyjmować akty użytkownika i odpowiednio\n",
|
||||||
|
"modyfikować stan dialogu.\n",
|
||||||
|
"W przypadku aktów typu `inform` wartości slotów zostaną zapamiętane w słownikach odpowiadających\n",
|
||||||
|
"poszczególnym dziedzinom pod kluczem `belief_state`.\n",
|
||||||
|
"W przypadku aktów typu `request` sloty, o które pyta użytkownik zostaną zapisane pod kluczem\n",
|
||||||
|
"`request_state`.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "172a883f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import json\n",
|
||||||
|
"import os\n",
|
||||||
|
"from convlab.dst.dst import DST\n",
|
||||||
|
"from convlab.dst.rule.multiwoz.dst_util import normalize_value\n",
|
||||||
|
"\n",
|
||||||
|
"class SimpleRuleDST(DST):\n",
|
||||||
|
" def __init__(self):\n",
|
||||||
|
" DST.__init__(self)\n",
|
||||||
|
" self.state = default_state()\n",
|
||||||
|
" self.value_dict = json.load(open('l09/ConvLab-3/data/multiwoz/value_dict.json'))\n",
|
||||||
|
"\n",
|
||||||
|
" def update(self, user_act=None):\n",
|
||||||
|
" for intent, domain, slot, value in user_act:\n",
|
||||||
|
" domain = domain.lower()\n",
|
||||||
|
" intent = intent.lower()\n",
|
||||||
|
" slot = slot.lower()\n",
|
||||||
|
" \n",
|
||||||
|
" if domain not in self.state['belief_state']:\n",
|
||||||
|
" continue\n",
|
||||||
|
"\n",
|
||||||
|
" if intent == 'inform':\n",
|
||||||
|
" if slot == 'none' or slot == '':\n",
|
||||||
|
" continue\n",
|
||||||
|
"\n",
|
||||||
|
" domain_dic = self.state['belief_state'][domain]\n",
|
||||||
|
"\n",
|
||||||
|
" if slot in domain_dic:\n",
|
||||||
|
" nvalue = normalize_value(self.value_dict, domain, slot, value)\n",
|
||||||
|
" self.state['belief_state'][domain][slot] = nvalue\n",
|
||||||
|
"\n",
|
||||||
|
" elif intent == 'request':\n",
|
||||||
|
" if domain not in self.state['request_state']:\n",
|
||||||
|
" self.state['request_state'][domain] = {}\n",
|
||||||
|
" if slot not in self.state['request_state'][domain]:\n",
|
||||||
|
" self.state['request_state'][domain][slot] = 0\n",
|
||||||
|
"\n",
|
||||||
|
" return self.state\n",
|
||||||
|
"\n",
|
||||||
|
" def init_session(self):\n",
|
||||||
|
" self.state = default_state()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "a411a1ca",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"W definicji metody `update` zakładamy, że akty dialogowe przekazywane do monitora stanu dialogu z\n",
|
||||||
|
"modułu NLU są czteroelementowymi listami złożonymi z:\n",
|
||||||
|
"\n",
|
||||||
|
" - nazwy aktu użytkownika,\n",
|
||||||
|
" - nazwy dziedziny, której dotyczy wypowiedź,\n",
|
||||||
|
" - nazwy slotu,\n",
|
||||||
|
" - wartości slotu.\n",
|
||||||
|
"\n",
|
||||||
|
"Zobaczmy na kilku prostych przykładach jak stan dialogu zmienia się pod wpływem przekazanych aktów\n",
|
||||||
|
"użytkownika."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "2abb1707",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dst = SimpleRuleDST()\n",
|
||||||
|
"dst.state"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "31312e0f",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dst.update([['Inform', 'Hotel', 'Price Range', 'cheap'], ['Inform', 'Hotel', 'Parking', 'yes']])\n",
|
||||||
|
"dst.state['belief_state']['hotel']"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "38a02c80",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dst.update([['Inform', 'Hotel', 'Area', 'north']])\n",
|
||||||
|
"dst.state['belief_state']['hotel']"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "23a47f33",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dst.update([['Request', 'Hotel', 'Area', '?']])\n",
|
||||||
|
"dst.state['request_state']"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "9a0698f2",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dst.update([['Inform', 'Hotel', 'Book Day', 'tuesday'], ['Inform', 'Hotel', 'Book People', '2'], ['Inform', 'Hotel', 'Book Stay', '4']])\n",
|
||||||
|
"dst.state['belief_state']['hotel']"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "1c5c8093",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dst.state"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "b0814105",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Taktyka Prowadzenia Dialogu\n",
|
||||||
|
"\n",
|
||||||
|
"Prosta taktyka prowadzenia dialogu dla systemu rezerwacji pokoi hotelowych może składać się z następujących reguł:\n",
|
||||||
|
"\n",
|
||||||
|
" 1. Jeżeli użytkownik przekazał w ostatniej turze akt typu `Request`, to udziel odpowiedzi na jego\n",
|
||||||
|
" pytanie.\n",
|
||||||
|
"\n",
|
||||||
|
" 2. Jeżeli użytkownik przekazał w ostatniej turze akt typu `Inform`, to zaproponuj mu hotel\n",
|
||||||
|
" spełniający zdefiniowane przez niego kryteria.\n",
|
||||||
|
"\n",
|
||||||
|
" 3. Jeżeli użytkownik przekazał w ostatniej turze akt typu `Inform` zawierający szczegóły\n",
|
||||||
|
" rezerwacji, to zarezerwuj pokój.\n",
|
||||||
|
"\n",
|
||||||
|
"Metoda `predict` taktyki `SimpleRulePolicy` realizuje reguły przedstawione powyżej."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "14412255",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from collections import defaultdict\n",
|
||||||
|
"import copy\n",
|
||||||
|
"import json\n",
|
||||||
|
"from copy import deepcopy\n",
|
||||||
|
"\n",
|
||||||
|
"from convlab.policy.policy import Policy\n",
|
||||||
|
"from convlab.util.multiwoz.dbquery import Database\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"class SimpleRulePolicy(Policy):\n",
|
||||||
|
" def __init__(self):\n",
|
||||||
|
" Policy.__init__(self)\n",
|
||||||
|
" self.db = Database()\n",
|
||||||
|
"\n",
|
||||||
|
" def predict(self, state):\n",
|
||||||
|
" self.results = []\n",
|
||||||
|
" system_action = defaultdict(list)\n",
|
||||||
|
" user_action = defaultdict(list)\n",
|
||||||
|
"\n",
|
||||||
|
" for intent, domain, slot, value in state['user_action']:\n",
|
||||||
|
" user_action[(domain.lower(), intent.lower())].append((slot.lower(), value))\n",
|
||||||
|
"\n",
|
||||||
|
" for user_act in user_action:\n",
|
||||||
|
" self.update_system_action(user_act, user_action, state, system_action)\n",
|
||||||
|
"\n",
|
||||||
|
" # Reguła 3\n",
|
||||||
|
" if any(True for slots in user_action.values() for (slot, _) in slots if slot in ['book stay', 'book day', 'book people']):\n",
|
||||||
|
" if self.results:\n",
|
||||||
|
" system_action = {('Booking', 'Book'): [[\"Ref\", self.results[0].get('Ref', 'N/A')]]}\n",
|
||||||
|
"\n",
|
||||||
|
" system_acts = [[intent, domain, slot, value] for (domain, intent), slots in system_action.items() for slot, value in slots]\n",
|
||||||
|
" state['system_action'] = system_acts\n",
|
||||||
|
" return system_acts\n",
|
||||||
|
"\n",
|
||||||
|
" def update_system_action(self, user_act, user_action, state, system_action):\n",
|
||||||
|
" domain, intent = user_act\n",
|
||||||
|
" constraints = [(slot, value) for slot, value in state['belief_state'][domain.lower()].items() if value != '']\n",
|
||||||
|
" self.results = deepcopy(self.db.query(domain.lower(), constraints))\n",
|
||||||
|
"\n",
|
||||||
|
" # Reguła 1\n",
|
||||||
|
" if intent == 'request':\n",
|
||||||
|
" if len(self.results) == 0:\n",
|
||||||
|
" system_action[(domain, 'NoOffer')] = []\n",
|
||||||
|
" else:\n",
|
||||||
|
" for slot in user_action[user_act]: \n",
|
||||||
|
" if slot[0] in self.results[0]:\n",
|
||||||
|
" system_action[(domain, 'Inform')].append([slot[0], self.results[0].get(slot[0], 'unknown')])\n",
|
||||||
|
"\n",
|
||||||
|
" # Reguła 2\n",
|
||||||
|
" elif intent == 'inform':\n",
|
||||||
|
" if len(self.results) == 0:\n",
|
||||||
|
" system_action[(domain, 'NoOffer')] = []\n",
|
||||||
|
" else:\n",
|
||||||
|
" system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])\n",
|
||||||
|
" choice = self.results[0]\n",
|
||||||
|
"\n",
|
||||||
|
" if domain in [\"hotel\", \"attraction\", \"police\", \"restaurant\"]:\n",
|
||||||
|
" system_action[(domain, 'Recommend')].append(['Name', choice['name']])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "bff4572c",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Podobnie jak w przypadku aktów użytkownika akty systemowe przekazywane do modułu NLG są czteroelementowymi listami złożonymi z:\n",
|
||||||
|
"\n",
|
||||||
|
" - nazwy aktu systemowe,\n",
|
||||||
|
" - nazwy dziedziny, której dotyczy wypowiedź,\n",
|
||||||
|
" - nazwy slotu,\n",
|
||||||
|
" - wartości slotu.\n",
|
||||||
|
"\n",
|
||||||
|
"Sprawdźmy jakie akty systemowe zwraca taktyka `SimpleRulePolicy` w odpowiedzi na zmieniający się stan dialogu."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "1b50240f",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from convlab.dialog_agent import PipelineAgent\n",
|
||||||
|
"dst.init_session()\n",
|
||||||
|
"policy = SimpleRulePolicy()\n",
|
||||||
|
"agent = PipelineAgent(nlu=None, dst=dst, policy=policy, nlg=None, name='sys')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "116d54d7",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"agent.response([['Inform', 'Hotel', 'Price Range', 'cheap'], ['Inform', 'Hotel', 'Parking', 'yes']])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "7d1c5be8",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"agent.response([['Request', 'Hotel', 'Area', '?']])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "f296e283",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"agent.response([['Inform', 'Hotel', 'Area', 'centre']])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "08d6bfaa",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"agent.response([['Inform', 'Hotel', 'Book Day', 'tuesday'], ['Inform', 'Hotel', 'Book People', '2'], ['Inform', 'Hotel', 'Book Stay', '4']])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "ecf28c41",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Testy End-to-End\n",
|
||||||
|
"\n",
|
||||||
|
"Na koniec przeprowadźmy dialog łącząc w potok nasze moduły\n",
|
||||||
|
"z modułami NLU i NLG dostępnymi dla MultiWOZ w środowisku `ConvLab`."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "d5d95c0c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from convlab.base_models.t5.nlu import T5NLU\n",
|
||||||
|
"from convlab.nlg.template.multiwoz import TemplateNLG\n",
|
||||||
|
"\n",
|
||||||
|
"nlu = T5NLU(speaker='user', context_window_size=0, model_name_or_path='ConvLab/t5-small-nlu-multiwoz21')\n",
|
||||||
|
"nlg = TemplateNLG(is_user=False)\n",
|
||||||
|
"agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "200e6941",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"agent.response(\"I need a cheap hotel with free parking .\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "da78fca0",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"agent.response(\"Yeah , could you book me a room for 2 people for 4 nights starting Tuesday ?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "fbb4e0cf",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"agent.response(\"what is the hotel phone number ?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "6116cf5c",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Zauważmy, ze nasza prosta taktyka dialogowa zawiera wiele luk, do których należą m.in.:\n",
|
||||||
|
"\n",
|
||||||
|
" 1. Niezdolność do udzielenia odpowiedzi na przywitanie, prośbę o pomoc lub restart.\n",
|
||||||
|
"\n",
|
||||||
|
" 2. Brak reguł dopytujących użytkownika o szczegóły niezbędne do dokonania rezerwacji takie, jak długość pobytu czy liczba osób.\n",
|
||||||
|
"\n",
|
||||||
|
"Bardziej zaawansowane moduły zarządzania dialogiem zbudowane z wykorzystaniem reguł można znaleźć w\n",
|
||||||
|
"środowisku `ConvLab`. Należą do nich m.in. monitor [RuleDST](https://github.com/ConvLab/ConvLab-3/blob/master/convlab/dst/rule/multiwoz/dst.py) oraz taktyka [RuleBasedMultiwozBot](https://github.com/ConvLab/ConvLab-3/blob/master/convlab/policy/rule/multiwoz/rule_based_multiwoz_bot.py).\n",
|
||||||
|
"\n",
|
||||||
|
"Zadania\n",
|
||||||
|
"-------\n",
|
||||||
|
" 1. Zaimplementować w projekcie monitor stanu dialogu.\n",
|
||||||
|
"\n",
|
||||||
|
" 2. Zaimplementować w projekcie taktykę prowadzenia dialogu.\n",
|
||||||
|
"\n",
|
||||||
|
"Literatura\n",
|
||||||
|
"----------\n",
|
||||||
|
" 1. Pawel Budzianowski, Tsung-Hsien Wen, Bo-Hsiang Tseng, Iñigo Casanueva, Stefan Ultes, Osman Ramadan, Milica Gasic, MultiWOZ - A Large-Scale Multi-Domain Wizard-of-Oz Dataset for Task-Oriented Dialogue Modelling. EMNLP 2018, pp. 5016-5026\n",
|
||||||
|
" 2. Cathy Pearl, Basic principles for designing voice user interfaces, https://www.oreilly.com/content/basic-principles-for-designing-voice-user-interfaces/ data dostępu: 21 marca 2021\n",
|
||||||
|
" 3. Cathy Pearl, Designing Voice User Interfaces, Excerpts from Chapter 5: Advanced Voice User Interface Design, https://www.uxmatters.com/mt/archives/2018/01/designing-voice-user-interfaces.php data dostępu: 21 marca 2021"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"jupytext": {
|
||||||
|
"cell_metadata_filter": "-all",
|
||||||
|
"main_language": "python",
|
||||||
|
"notebook_metadata_filter": "-all"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
517
notebooks/10-zarzadzanie-dialogiem-uczenie.ipynb
Normal file
517
notebooks/10-zarzadzanie-dialogiem-uczenie.ipynb
Normal file
@ -0,0 +1,517 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "d8790b55",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Zarządzanie dialogiem z wykorzystaniem technik uczenia maszynowego\n",
|
||||||
|
"==================================================================\n",
|
||||||
|
"\n",
|
||||||
|
"Uczenie przez wzmacnianie\n",
|
||||||
|
"-------------------------\n",
|
||||||
|
"\n",
|
||||||
|
"Zamiast ręcznie implementować zbiór reguł odpowiedzialnych za wyznaczenie akcji, którą powinien podjąć agent będąc w danym stanie, odpowiednią taktykę prowadzenia dialogu można zbudować, wykorzystując techniki uczenia maszynowego.\n",
|
||||||
|
"\n",
|
||||||
|
"Obok metod uczenia nadzorowanego, które wykorzystaliśmy do zbudowania modelu NLU, do konstruowania taktyk\n",
|
||||||
|
"prowadzenia dialogu wykorzystuje się również *uczenie przez wzmacnianie* (ang. *reinforcement learning*).\n",
|
||||||
|
"\n",
|
||||||
|
"W tym ujęciu szukać będziemy funkcji $Q*: S \\times A \\to R$, która dla stanu dialogu $s \\in S$ oraz aktu\n",
|
||||||
|
"dialogowego $a \\in A$ zwraca nagrodę (ang. *reward*) $r \\in R$, tj. wartość rzeczywistą pozwalającą ocenić na ile\n",
|
||||||
|
"podjęcie akcji $a$ w stanie $s$ jest korzystne.\n",
|
||||||
|
"\n",
|
||||||
|
"Założymy również, że poszukiwana funkcja powinna maksymalizować *zwrot* (ang. *return*), tj.\n",
|
||||||
|
"skumulowaną nagrodę w toku prowadzonego dialogu, czyli dla tury $t_0$ cel uczenia powinien mieć postać:\n",
|
||||||
|
"\n",
|
||||||
|
"$$ \\sum_{t=t_0}^{\\infty}{\\gamma^{t-1}r_t} $$\n",
|
||||||
|
"\n",
|
||||||
|
"gdzie:\n",
|
||||||
|
"\n",
|
||||||
|
" - $t$: tura agenta,\n",
|
||||||
|
"\n",
|
||||||
|
" - $r_t$: nagroda w turze $t$,\n",
|
||||||
|
"\n",
|
||||||
|
" - $\\gamma \\in [0, 1]$: współczynnik dyskontowy (w przypadku agentów dialogowych bliżej $1$ niż $0$, por. np. Rieser i Lemon (2011))."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "18c3a6c0",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Agent dialogowy w procesie uczenia przez wzmacnianie wchodzi w interakcję ze *środowiskiem*, które\n",
|
||||||
|
"dla akcji podejmowanej przez taktykę prowadzenia dialogu zwraca kolejny stan oraz nagrodę powiązaną z\n",
|
||||||
|
"wykonaniem tej akcji w bieżącym stanie.\n",
|
||||||
|
"\n",
|
||||||
|
"Sposób w jaki informacje pochodzące ze środowiska są wykorzystywane do znalezienia funkcji $Q*$\n",
|
||||||
|
"zależy od wybranej metody uczenia.\n",
|
||||||
|
"W przykładzie przestawionym poniżej skorzystamy z algorytmu $DQN$ (Mnih i in., 2013) co oznacza, że:\n",
|
||||||
|
"\n",
|
||||||
|
" 1. będziemy aproksymować funkcję $Q*$ siecią neuronową,\n",
|
||||||
|
"\n",
|
||||||
|
" 2. wagi sieci będziemy wyznaczać korzystając z metody spadku gradientu.\n",
|
||||||
|
"\n",
|
||||||
|
"Przykład\n",
|
||||||
|
"--------\n",
|
||||||
|
"\n",
|
||||||
|
"Ze względu na to, że implementacja algorytmu $DQN$ nie została jeszcze przystosowana do wersji 3.0 środowiska `ConvLab` skorzystamy z wersji 2.0."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "50675b87",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!mkdir -p l10\n",
|
||||||
|
"%cd l10\n",
|
||||||
|
"!git clone --depth 1 https://github.com/thu-coai/ConvLab-2.git\n",
|
||||||
|
"%cd ConvLab-2\n",
|
||||||
|
"!pip install -e .\n",
|
||||||
|
"%cd ../.."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "407a0ebe",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Po zainstalowaniu środowiska `ConvLab-2` należy zrestartować interpreter Pythona (opcja *Kernel -> Restart* w Jupyter)."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "3b4cbd96",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from convlab2.dialog_agent.agent import PipelineAgent\n",
|
||||||
|
"from convlab2.dialog_agent.env import Environment\n",
|
||||||
|
"from convlab2.dst.rule.multiwoz import RuleDST\n",
|
||||||
|
"from convlab2.policy.rule.multiwoz import RulePolicy\n",
|
||||||
|
"from convlab2.policy.dqn import DQN\n",
|
||||||
|
"from convlab2.policy.rlmodule import Memory\n",
|
||||||
|
"from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator\n",
|
||||||
|
"import logging\n",
|
||||||
|
"\n",
|
||||||
|
"logging.disable(logging.DEBUG)\n",
|
||||||
|
"\n",
|
||||||
|
"# determinizacja obliczeń\n",
|
||||||
|
"import random\n",
|
||||||
|
"import torch\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"\n",
|
||||||
|
"np.random.seed(123)\n",
|
||||||
|
"random.seed(123)\n",
|
||||||
|
"torch.manual_seed(123)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "cd7eeeab",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Środowisko, z którym agent będzie wchodził w interakcje zawierać będzie\n",
|
||||||
|
"symulator użytkownika wykorzystujący taktykę prowadzenia dialogu zbudowaną z wykorzystaniem reguł."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "c1a91522",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"usr_policy = RulePolicy(character='usr')\n",
|
||||||
|
"usr_simulator = PipelineAgent(None, None, usr_policy, None, 'user') # type: ignore\n",
|
||||||
|
"\n",
|
||||||
|
"dst = RuleDST()\n",
|
||||||
|
"evaluator = MultiWozEvaluator()\n",
|
||||||
|
"env = Environment(None, usr_simulator, None, dst, evaluator)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "331c14c9",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Zobaczmy jak w *ConvLab-2* zdefiniowana jest nagroda w klasie `Environment`"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "031bd292",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%%script false --no-raise-error\n",
|
||||||
|
"#\n",
|
||||||
|
"# plik convlab2/dialog_agent/env.py\n",
|
||||||
|
"#\n",
|
||||||
|
"class Environment():\n",
|
||||||
|
"\n",
|
||||||
|
" # (...)\n",
|
||||||
|
"\n",
|
||||||
|
" def step(self, action):\n",
|
||||||
|
"\n",
|
||||||
|
" # (...)\n",
|
||||||
|
"\n",
|
||||||
|
" if self.evaluator:\n",
|
||||||
|
" if self.evaluator.task_success():\n",
|
||||||
|
" reward = 40\n",
|
||||||
|
" elif self.evaluator.cur_domain and self.evaluator.domain_success(self.evaluator.cur_domain):\n",
|
||||||
|
" reward = 5\n",
|
||||||
|
" else:\n",
|
||||||
|
" reward = -1\n",
|
||||||
|
" else:\n",
|
||||||
|
" reward = self.usr.get_reward()\n",
|
||||||
|
" terminated = self.usr.is_terminated()\n",
|
||||||
|
"\n",
|
||||||
|
" return state, reward, terminated\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "2ef9da55",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Jak można zauważyć powyżej akcja, która prowadzi do pomyślnego zakończenia zadania uzyskuje nagrodę $40$,\n",
|
||||||
|
"akcja która prowadzi do prawidłowego rozpoznania dziedziny uzyskuje nagrodę $5$,\n",
|
||||||
|
"natomiast każda inna akcja uzyskuje \"karę\" $-1$. Taka definicja zwrotu premiuje krótkie dialogi\n",
|
||||||
|
"prowadzące do pomyślnego wykonania zadania.\n",
|
||||||
|
"\n",
|
||||||
|
"Sieć neuronowa, którą wykorzystamy do aproksymacji funkcji $Q*$ ma następującą architekturę"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "821a46a6",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%%script false --no-raise-error\n",
|
||||||
|
"#\n",
|
||||||
|
"# plik convlab2/policy/rlmodule.py\n",
|
||||||
|
"# klasa EpsilonGreedyPolicy wykorzystywana w DQN\n",
|
||||||
|
"#\n",
|
||||||
|
"class EpsilonGreedyPolicy(nn.Module):\n",
|
||||||
|
" def __init__(self, s_dim, h_dim, a_dim, epsilon_spec={'start': 0.1, 'end': 0.0, 'end_epoch': 200}):\n",
|
||||||
|
" super(EpsilonGreedyPolicy, self).__init__()\n",
|
||||||
|
"\n",
|
||||||
|
" self.net = nn.Sequential(nn.Linear(s_dim, h_dim),\n",
|
||||||
|
" nn.ReLU(),\n",
|
||||||
|
" nn.Linear(h_dim, h_dim),\n",
|
||||||
|
" nn.ReLU(),\n",
|
||||||
|
" nn.Linear(h_dim, a_dim))\n",
|
||||||
|
" # (...)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "5d47ae91",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"policy = DQN(is_train=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "de41de6f",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Każdy krok procedury uczenia składa się z dwóch etapów:\n",
|
||||||
|
"\n",
|
||||||
|
" 1. Wygenerowania przy użyciu taktyki (metoda `policy.predict`) oraz środowiska (metoda `env.step`) *trajektorii*, tj. sekwencji przejść pomiędzy stanami złożonych z krotek postaci:\n",
|
||||||
|
" - stanu źródłowego,\n",
|
||||||
|
" - podjętej akcji (aktu systemowego),\n",
|
||||||
|
" - nagrody,\n",
|
||||||
|
" - stanu docelowego,\n",
|
||||||
|
" - znacznika końca dialogu."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "9063d436",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# por. ConvLab-2/convlab2/policy/dqn/train.py\n",
|
||||||
|
"def sample(env, policy, batch_size, warm_up):\n",
|
||||||
|
" buff = Memory()\n",
|
||||||
|
" sampled_num = 0\n",
|
||||||
|
" max_trajectory_len = 50\n",
|
||||||
|
"\n",
|
||||||
|
" while sampled_num < batch_size:\n",
|
||||||
|
" # rozpoczęcie nowego dialogu\n",
|
||||||
|
" s = env.reset()\n",
|
||||||
|
"\n",
|
||||||
|
" for t in range(max_trajectory_len):\n",
|
||||||
|
" try:\n",
|
||||||
|
" # podjęcie akcji przez agenta dialogowego\n",
|
||||||
|
" a = policy.predict(s, warm_up=warm_up)\n",
|
||||||
|
"\n",
|
||||||
|
" # odpowiedź środowiska na podjętą akcje\n",
|
||||||
|
" next_s, r, done = env.step(a)\n",
|
||||||
|
"\n",
|
||||||
|
" # dodanie krotki do zbioru danych\n",
|
||||||
|
" buff.push(torch.Tensor(policy.vector.state_vectorize(s)).numpy(), # stan źródłowy\n",
|
||||||
|
" policy.vector.action_vectorize(a), # akcja\n",
|
||||||
|
" r, # nagroda\n",
|
||||||
|
" torch.Tensor(policy.vector.state_vectorize(next_s)).numpy(), # stan docelowy\n",
|
||||||
|
" 0 if done else 1) # znacznik końca\n",
|
||||||
|
"\n",
|
||||||
|
" s = next_s\n",
|
||||||
|
"\n",
|
||||||
|
" if done:\n",
|
||||||
|
" break\n",
|
||||||
|
" except:\n",
|
||||||
|
" break\n",
|
||||||
|
"\n",
|
||||||
|
" sampled_num += t\n",
|
||||||
|
"\n",
|
||||||
|
" return buff"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "281d76a1",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
" 2. Wykorzystania wygenerowanych krotek do aktualizacji taktyki.\n",
|
||||||
|
"\n",
|
||||||
|
"Funkcja `train` realizująca pojedynczy krok uczenia przez wzmacnianie ma następującą postać"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "a42a21da",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def train(env, policy, batch_size, epoch, warm_up):\n",
|
||||||
|
" print(f'epoch: {epoch}')\n",
|
||||||
|
" buff = sample(env, policy, batch_size, warm_up)\n",
|
||||||
|
" policy.update_memory(buff)\n",
|
||||||
|
" policy.update(epoch)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "074508c2",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Metoda `update` klasy `DQN` wykorzystywana do aktualizacji wag ma następującą postać"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "6ce332bd",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%%script false --no-raise-error\n",
|
||||||
|
"#\n",
|
||||||
|
"# plik convlab2/policy/dqn/dqn.py\n",
|
||||||
|
"# klasa DQN\n",
|
||||||
|
"#\n",
|
||||||
|
"class DQN(Policy):\n",
|
||||||
|
" # (...)\n",
|
||||||
|
" def update(self, epoch):\n",
|
||||||
|
" total_loss = 0.\n",
|
||||||
|
" for i in range(self.training_iter):\n",
|
||||||
|
" round_loss = 0.\n",
|
||||||
|
" # 1. batch a sample from memory\n",
|
||||||
|
" batch = self.memory.get_batch(batch_size=self.batch_size)\n",
|
||||||
|
"\n",
|
||||||
|
" for _ in range(self.training_batch_iter):\n",
|
||||||
|
" # 2. calculate the Q loss\n",
|
||||||
|
" loss = self.calc_q_loss(batch)\n",
|
||||||
|
"\n",
|
||||||
|
" # 3. make a optimization step\n",
|
||||||
|
" self.net_optim.zero_grad()\n",
|
||||||
|
" loss.backward()\n",
|
||||||
|
" self.net_optim.step()\n",
|
||||||
|
"\n",
|
||||||
|
" round_loss += loss.item()\n",
|
||||||
|
"\n",
|
||||||
|
" # (...)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "60e3db5d",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Przebieg procesu uczenia zilustrujemy wykonując 10 iteracji. W każdej iteracji ograniczymy się do 100 przykładów."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "6d50da62",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"epoch = 10\n",
|
||||||
|
"batch_size = 100\n",
|
||||||
|
"\n",
|
||||||
|
"train(env, policy, batch_size, 0, warm_up=True)\n",
|
||||||
|
"\n",
|
||||||
|
"for i in range(1, epoch):\n",
|
||||||
|
" train(env, policy, batch_size, i, warm_up=False)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "ecf3ee2b",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Sprawdźmy jakie akty systemowe zwraca taktyka `DQN` w odpowiedzi na zmieniający się stan dialogu."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "575c0211",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from convlab2.dialog_agent import PipelineAgent\n",
|
||||||
|
"dst.init_session()\n",
|
||||||
|
"agent = PipelineAgent(nlu=None, dst=dst, policy=policy, nlg=None, name='sys')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "ef846914",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"agent.response([['Inform', 'Hotel', 'Price', 'cheap'], ['Inform', 'Hotel', 'Parking', 'yes']])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "6f0aaa0d",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"agent.response([['Inform', 'Hotel', 'Area', 'north']])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "64cb7ad5",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"agent.response([['Request', 'Hotel', 'Area', '?']])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "157f3e0c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"agent.response([['Inform', 'Hotel', 'Day', 'tuesday'], ['Inform', 'Hotel', 'People', '2'], ['Inform', 'Hotel', 'Stay', '4']])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "9bc4474b",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Jakość wyuczonego modelu możemy ocenić mierząc tzw. wskaźnik sukcesu (ang. *task success rate*),\n",
|
||||||
|
"tj. stosunek liczby dialogów zakończonych powodzeniem do liczby wszystkich dialogów."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "58c875d7",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from convlab2.dialog_agent.session import BiSession\n",
|
||||||
|
"\n",
|
||||||
|
"sess = BiSession(agent, usr_simulator, None, evaluator)\n",
|
||||||
|
"dialog_num = 100\n",
|
||||||
|
"task_success_num = 0\n",
|
||||||
|
"max_turn_num = 50\n",
|
||||||
|
"\n",
|
||||||
|
"# por. ConvLab-2/convlab2/policy/evaluate.py\n",
|
||||||
|
"for dialog in range(dialog_num):\n",
|
||||||
|
" random.seed(dialog)\n",
|
||||||
|
" np.random.seed(dialog)\n",
|
||||||
|
" torch.manual_seed(dialog)\n",
|
||||||
|
" sess.init_session()\n",
|
||||||
|
" sys_act = []\n",
|
||||||
|
" task_success = 0\n",
|
||||||
|
"\n",
|
||||||
|
" for _ in range(max_turn_num):\n",
|
||||||
|
" sys_act, _, finished, _ = sess.next_turn(sys_act)\n",
|
||||||
|
"\n",
|
||||||
|
" if finished is True:\n",
|
||||||
|
" task_success = sess.evaluator.task_success()\n",
|
||||||
|
" break\n",
|
||||||
|
"\n",
|
||||||
|
" print(f'dialog: {dialog:02} success: {task_success}')\n",
|
||||||
|
" task_success_num += task_success\n",
|
||||||
|
"\n",
|
||||||
|
"print('')\n",
|
||||||
|
"print(f'task success rate: {task_success_num/dialog_num:.2f}')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "5313885b",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"**Uwaga**: Chcąc uzyskać taktykę o skuteczności porównywalnej z wynikami przedstawionymi na stronie\n",
|
||||||
|
"[ConvLab-2](https://github.com/thu-coai/ConvLab-2/blob/master/README.md) trzeba odpowiednio\n",
|
||||||
|
"zwiększyć zarówno liczbę iteracji jak i liczbę przykładów generowanych w każdym przyroście.\n",
|
||||||
|
"W celu przyśpieszenia procesu uczenia warto zrównoleglić obliczenia, jak pokazano w\n",
|
||||||
|
"skrypcie [train.py](https://github.com/thu-coai/ConvLab-2/blob/master/convlab2/policy/dqn/train.py).\n",
|
||||||
|
"\n",
|
||||||
|
"Literatura\n",
|
||||||
|
"----------\n",
|
||||||
|
" 1. Rieser, V., Lemon, O., (2011). Reinforcement learning for adaptive dialogue systems: a data-driven methodology for dialogue management and natural language generation. (Theory and Applications of Natural Language Processing). Springer. https://doi.org/10.1007/978-3-642-24942-6\n",
|
||||||
|
"\n",
|
||||||
|
" 2. Richard S. Sutton and Andrew G. Barto, (2018). Reinforcement Learning: An Introduction, Second Edition, MIT Press, Cambridge, MA http://incompleteideas.net/book/RLbook2020.pdf\n",
|
||||||
|
"\n",
|
||||||
|
" 3. Volodymyr Mnih and Koray Kavukcuoglu and David Silver and Alex Graves and Ioannis Antonoglou and Daan Wierstra and Martin Riedmiller, (2013). Playing Atari with Deep Reinforcement Learning, NIPS Deep Learning Workshop, https://arxiv.org/pdf/1312.5602.pdf"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"jupytext": {
|
||||||
|
"cell_metadata_filter": "-all",
|
||||||
|
"main_language": "python",
|
||||||
|
"notebook_metadata_filter": "-all"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
287
notebooks/11-generowanie-odpowiedzi.ipynb
Normal file
287
notebooks/11-generowanie-odpowiedzi.ipynb
Normal file
@ -0,0 +1,287 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "1a9e48a0",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Generowanie odpowiedzi\n",
|
||||||
|
"======================\n",
|
||||||
|
"\n",
|
||||||
|
"W systemie dialogowym taktyka prowadzenia dialogu odpowiada za wyznaczanie aktów systemowych, czyli wskazanie tego **co ma zostać przez system wypowiedziane** i/lub wykonane.\n",
|
||||||
|
"Zadaniem modułu generowania odpowiedzi jest zamiana aktów dialogowych na wypowiedzi w języku\n",
|
||||||
|
"naturalnym, czyli wskazanie tego **w jaki sposób** ma zostać wypowiedziane to co ma zostać\n",
|
||||||
|
"wypowiedziane.\n",
|
||||||
|
"\n",
|
||||||
|
"Generowanie odpowiedzi przy użyciu szablonów\n",
|
||||||
|
"--------------------------------------------\n",
|
||||||
|
"Podstawowe narzędzie wykorzystywane w modułach generowania odpowiedzi stanowią szablony tekstowe\n",
|
||||||
|
"interpolujące zmienne. W Pythonie mechanizm ten jest dostępny za pośrednictwem\n",
|
||||||
|
"[f-stringów](https://docs.python.org/3/reference/lexical_analysis.html#f-strings), metody\n",
|
||||||
|
"[format](https://docs.python.org/3/library/string.html#formatstrings) oraz zewnętrznych bibliotek takich, jak [Jinja2](https://jinja.palletsprojects.com/).\n",
|
||||||
|
"\n",
|
||||||
|
"O ile podejście wykorzystujące wbudowane mechanizmy języka Python sprawdza się w prostych\n",
|
||||||
|
"przypadkach..."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "250a4248",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def nlg(system_act):\n",
|
||||||
|
" domain, intent, slot, value = system_act\n",
|
||||||
|
"\n",
|
||||||
|
" if intent == 'Inform' and slot == 'Phone':\n",
|
||||||
|
" return f'Numer telefonu to {value}'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "54e4076a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"nlg(['Hotel', 'Inform', 'Phone', '1234567890'])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "38dfc0e6",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"... to trzeba mieć świadomość, że w toku prac nad agentem dialogowym może być konieczne\n",
|
||||||
|
"uwzględnienie m.in.:\n",
|
||||||
|
"\n",
|
||||||
|
" 1. szablonów zależnych od wartości slotów"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "5f0930e1",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def nlg(system_act):\n",
|
||||||
|
" domain, intent, slot, value = system_act\n",
|
||||||
|
"\n",
|
||||||
|
" if domain == 'Restaurant' and intent == 'Inform' and slot == 'Count':\n",
|
||||||
|
" if value == 0:\n",
|
||||||
|
" return f'Nie znalazłem restauracji spełniających podane kryteria.'\n",
|
||||||
|
" elif value == 1:\n",
|
||||||
|
" return f'Znalazłem jedną restaurację spełniającą podane kryteria.'\n",
|
||||||
|
" elif value <= 4:\n",
|
||||||
|
" return f'Znalazłem {value} restauracje spełniające podane kryteria.'\n",
|
||||||
|
" elif value <= 9:\n",
|
||||||
|
" return f'Znalazłem {value} restauracji spełniających podane kryteria.'\n",
|
||||||
|
" else:\n",
|
||||||
|
" return f'Znalazłem wiele restauracji spełniających podane kryteria.'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "1245bde1",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"nlg(['Restaurant', 'Inform', 'Count', 0])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "7c09def2",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"nlg(['Restaurant', 'Inform', 'Count', 1])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "bcf67aa4",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"nlg(['Restaurant', 'Inform', 'Count', 2])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "6390d6c5",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"nlg(['Restaurant', 'Inform', 'Count', 6])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "3b269a47",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"nlg(['Restaurant', 'Inform', 'Count', 100])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "acec991b",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
" 2. wielu wariantów tej samej wypowiedzi"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "d35b82f9",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import random\n",
|
||||||
|
"\n",
|
||||||
|
"def nlg(system_act):\n",
|
||||||
|
" domain, intent, slot, value = system_act\n",
|
||||||
|
"\n",
|
||||||
|
" if intent == 'Affirm':\n",
|
||||||
|
" r = random.randint(1, 3)\n",
|
||||||
|
"\n",
|
||||||
|
" if r == 1:\n",
|
||||||
|
" return 'Tak'\n",
|
||||||
|
" elif r == 2:\n",
|
||||||
|
" return 'Zgadza się'\n",
|
||||||
|
" else:\n",
|
||||||
|
" return 'Potwierdzam'\n",
|
||||||
|
"\n",
|
||||||
|
"nlg(['Hotel', 'Affirm', '', ''])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "8e0244f4",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
" 3. wielojęzycznego interfejsu użytkownika"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "bed51e01",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def nlg_en(system_act):\n",
|
||||||
|
" domain, intent, slot, value = system_act\n",
|
||||||
|
"\n",
|
||||||
|
" if domain == 'Hotel' and intent == 'Request' and slot == 'CreditCardNo':\n",
|
||||||
|
" return 'What is your credit card number?'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "6762fc98",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"nlg_en(['Hotel', 'Request', 'CreditCardNo', '?'])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "23323149",
|
||||||
|
"metadata": {
|
||||||
|
"lines_to_next_cell": 0
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"Generowanie odpowiedzi z wykorzystaniem uczenia maszynowego\n",
|
||||||
|
"-----------------------------------------------------------\n",
|
||||||
|
"Obok mechanizmu szablonów do generowania odpowiedzi można również\n",
|
||||||
|
"stosować techniki uczenia maszynowego.\n",
|
||||||
|
"Zagadnienie to stanowiło\n",
|
||||||
|
"przedmiot konkursu [E2E NLG Challenge](http://www.macs.hw.ac.uk/InteractionLab/E2E/) (Novikova i in., 2017).\n",
|
||||||
|
"Przyjrzyjmy się danym, jakie udostępnili organizatorzy."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "0e616020",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!mkdir -p l11\n",
|
||||||
|
"!curl -L -C - https://github.com/tuetschek/e2e-dataset/releases/download/v1.0.0/e2e-dataset.zip -o l11/e2e-dataset.zip\n",
|
||||||
|
"!unzip l11/e2e-dataset.zip -d l11"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "d9a1032e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"\n",
|
||||||
|
"trainset = pd.read_csv('l11/e2e-dataset/trainset.csv')\n",
|
||||||
|
"trainset"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "d3d5e0d6",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Zadanie\n",
|
||||||
|
"-------\n",
|
||||||
|
"Zaimplementować moduł generowania odpowiedzi obejmujący akty systemowe występujące w zgromadzonym korpusie.\n",
|
||||||
|
"\n",
|
||||||
|
"Literatura\n",
|
||||||
|
"----------\n",
|
||||||
|
" 1. Jekaterina Novikova, Ondřej Dušek, Verena Rieser, The E2E Dataset: New Challenges For End-to-End Generation, Proceedings of the SIGDIAL 2017 Conference, pages 201-206, Saarbrücken, Germany https://arxiv.org/pdf/1706.09254.pdf"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"jupytext": {
|
||||||
|
"cell_metadata_filter": "-all",
|
||||||
|
"main_language": "python",
|
||||||
|
"notebook_metadata_filter": "-all"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user