From 9001dfae126e31e05e5e3a1693366937f2546636 Mon Sep 17 00:00:00 2001 From: eugene Date: Thu, 25 May 2023 13:32:35 +0200 Subject: [PATCH] edit domain in dialogmanager --- DialogManager.ipynb | 66 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/DialogManager.ipynb b/DialogManager.ipynb index 361c4c5..e20ff14 100644 --- a/DialogManager.ipynb +++ b/DialogManager.ipynb @@ -70,6 +70,72 @@ "print(state['belief_state'])\n", "print(state['request_state'])" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from collections import defaultdict\n", + "import copy\n", + "import json\n", + "from copy import deepcopy\n", + "\n", + "from convlab.policy.policy import Policy\n", + "from convlab.util.multiwoz.dbquery import Database\n", + "\n", + "\n", + "class SimpleRulePolicy(Policy):\n", + " def __init__(self):\n", + " Policy.__init__(self)\n", + " self.db = Database()\n", + "\n", + " def predict(self, state):\n", + " self.results = []\n", + " system_action = defaultdict(list)\n", + " user_action = defaultdict(list)\n", + "\n", + " for intent, domain, slot, value in state['user_action']:\n", + " user_action[(domain.lower(), intent.lower())].append((slot.lower(), value))\n", + "\n", + " for user_act in user_action:\n", + " self.update_system_action(user_act, user_action, state, system_action)\n", + "\n", + " # Reguła 3\n", + " if any(True for slots in user_action.values() for (slot, _) in slots if slot in ['book stay', 'book day', 'book people']):\n", + " if self.results:\n", + " system_action = {('Booking', 'Book'): [[\"Ref\", self.results[0].get('Ref', 'N/A')]]}\n", + "\n", + " system_acts = [[intent, domain, slot, value] for (domain, intent), slots in system_action.items() for slot, value in slots]\n", + " state['system_action'] = system_acts\n", + " return system_acts\n", + "\n", + " def update_system_action(self, user_act, user_action, state, system_action):\n", + " domain, intent = user_act\n", + " constraints = [(slot, value) for slot, value in state['belief_state'][domain.lower()].items() if value != '']\n", + " self.results = deepcopy(self.db.query(domain.lower(), constraints))\n", + "\n", + " # Reguła 1\n", + " if intent == 'request':\n", + " if len(self.results) == 0:\n", + " system_action[(domain, 'NoOffer')] = []\n", + " else:\n", + " for slot in user_action[user_act]: \n", + " if slot[0] in self.results[0]:\n", + " system_action[(domain, 'Inform')].append([slot[0], self.results[0].get(slot[0], 'unknown')])\n", + "\n", + " # Reguła 2\n", + " elif intent == 'inform':\n", + " if len(self.results) == 0:\n", + " system_action[(domain, 'NoOffer')] = []\n", + " else:\n", + " system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])\n", + " choice = self.results[0]\n", + "\n", + " if domain in [\"product\"]:\n", + " system_action[(domain, 'Recommend')].append(['Name', choice['name']])" + ] } ], "metadata": {