From 53d498fb2cc2f0fbb3b5516cc0210d068e53fa98 Mon Sep 17 00:00:00 2001 From: s459312 Date: Thu, 25 May 2023 14:49:44 +0200 Subject: [PATCH] change dialog policy --- DialogManager.ipynb | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/DialogManager.ipynb b/DialogManager.ipynb index 5a5e99a..7f0ed5e 100644 --- a/DialogManager.ipynb +++ b/DialogManager.ipynb @@ -102,20 +102,16 @@ }, { "cell_type": "code", - "execution_count": 122, + "execution_count": 132, "metadata": {}, "outputs": [], "source": [ "from collections import defaultdict\n", "import json\n", "\n", - "\n", - "\n", - "\n", "class SimpleRulePolicy():\n", " def __init__(self):\n", - " with open('product_db.json', encoding='utf-8') as json_file:\n", - " self.db = json.load(json_file)\n", + " self.db = json.load(open('product_db.json'))\n", "\n", " def predict(self, state):\n", " self.results = []\n", @@ -135,14 +131,15 @@ "\n", " def update_system_action(self, user_act, user_action, state, system_action):\n", " domain, intent = user_act\n", + " self.results = self.db['database'][domain]\n", "\n", " # Reguła 1\n", " if intent == 'request':\n", " if len(self.results) == 0:\n", " system_action[(domain, 'NoOffer')] = []\n", " else:\n", - " for slot in user_action[user_act]: \n", - " if slot[0] in self.results[0]:\n", + " for slot in user_action[user_act]:\n", + " if self.results and slot[0] in self.results[0]:\n", " system_action[(domain, 'Inform')].append([slot[0], self.results[0].get(slot[0], 'unknown')])\n", "\n", " # Reguła 2\n", @@ -151,15 +148,17 @@ " system_action[(domain, 'NoOffer')] = []\n", " else:\n", " system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])\n", - " choice = self.results[0]\n", "\n", - " if domain in [\"payment\", \"delivery\", \"product\"]:\n", - " system_action[(domain, 'Recommend')].append(['Name', choice['name']])" + " if self.results and 'name' in self.results[0]:\n", + " choice = self.results[0]\n", + "\n", + " if domain in [\"payment\", \"delivery\", \"product\"]:\n", + " system_action[(domain, 'Recommend')].append(['Name', choice['name']])" ] }, { "cell_type": "code", - "execution_count": 123, + "execution_count": 133, "metadata": {}, "outputs": [ { @@ -176,10 +175,19 @@ { "data": { "text/plain": [ - "[]" + "[['Inform', 'payment', 'Choice', '1'],\n", + " ['Inform', 'payment', 'Choice', '1'],\n", + " ['Inform', 'delivery', 'Choice', '1'],\n", + " ['Inform', 'delivery', 'Choice', '1'],\n", + " ['Inform', 'product', 'Choice', '3'],\n", + " ['Inform', 'product', 'Choice', '3'],\n", + " ['Inform', 'product', 'name', 'banan'],\n", + " ['Inform', 'product', 'name', 'banan'],\n", + " ['Recommend', 'product', 'Name', 'banan'],\n", + " ['Recommend', 'product', 'Name', 'banan']]" ] }, - "execution_count": 123, + "execution_count": 133, "metadata": {}, "output_type": "execute_result" }