diff --git a/hotels_data.json b/hotels_data.json new file mode 100644 index 0000000..52c4ecc --- /dev/null +++ b/hotels_data.json @@ -0,0 +1,254 @@ +[ + { + "name": "Hotel Marriott", + "area": "centre", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "hotel" + }, + { + "name": "Hotel Cambridge", + "area": "centre", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "hotel" + }, + { + "name": "Hotel Belfry", + "area": "suburbs", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "hotel" + }, + { + "name": "Aylesbray Guest House", + "area": "centre", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "guesthouse" + }, + { + "name": "University Arms Hotel", + "area": "centre", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "hotel" + }, + { + "name": "Lensfield Hotel", + "area": "north", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "hotel" + }, + { + "name": "Autumn House Hotel", + "area": "east", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "guesthouse" + }, + { + "name": "Finches Bed and Breakfast", + "area": "west", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "bed and breakfast" + }, + { + "name": "Arbury Lodge Guest House", + "area": "north", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "guesthouse" + }, + { + "name": "Royal Cambridge Hotel", + "area": "centre", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "hotel" + }, + { + "name": "Hilton Hotel", + "area": "suburbs", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "hotel" + }, + { + "name": "Holiday Inn Hotel", + "area": "centre", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "hotel" + }, + { + "name": "Radisson Hotel", + "area": "centre", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "hotel" + }, + { + "name": "Regent Guest House", + "area": "north", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "guesthouse" + }, + { + "name": "Travelodge Hotel", + "area": "centre", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "hotel" + }, + { + "name": "Premier Inn Hotel", + "area": "centre", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "hotel" + }, + { + "name": "Ibis Hotel", + "area": "suburbs", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "hotel" + }, + { + "name": "Novotel Hotel", + "area": "centre", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "hotel" + }, + { + "name": "Mercure Hotel", + "area": "suburbs", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "hotel" + }, + { + "name": "Crowne Plaza Hotel", + "area": "centre", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "hotel" + }, + { + "name": "Best Western Hotel", + "area": "north", + "parking": "yes", + "price range": "cheap", + "stars": "4", + "internet": "yes", + "type": "hotel" + }, + { + "name": "Marriott Hotel", + "area": "west", + "parking": "yes", + "price range": "expensive", + "stars": "5", + "internet": "yes", + "type": "hotel" + }, + { + "name": "Hyatt Regency Hotel", + "area": "south", + "parking": "yes", + "price range": "expensive", + "stars": "5", + "internet": "yes", + "type": "hotel" + }, + { + "name": "Four Seasons Hotel", + "area": "centre", + "parking": "yes", + "price range": "expensive", + "stars": "5", + "internet": "yes", + "type": "hotel" + }, + { + "name": "The Ritz Hotel", + "area": "centre", + "parking": "yes", + "price range": "expensive", + "stars": "5", + "internet": "yes", + "type": "hotel" + }, + { + "name": "The Savoy Hotel", + "area": "centre", + "parking": "yes", + "price range": "expensive", + "stars": "5", + "internet": "yes", + "type": "hotel" + }, + { + "name": "Mandarin Oriental Hotel", + "area": "suburbs", + "parking": "yes", + "price range": "expensive", + "stars": "5", + "internet": "yes", + "type": "hotel" + }, + { + "name": "Shangri-La Hotel", + "area": "centre", + "parking": "yes", + "price range": "expensive", + "stars": "5", + "internet": "yes", + "type": "hotel" + } +] diff --git a/reguly.ipynb b/reguly.ipynb index 4ddb3a9..7224dda 100644 --- a/reguly.ipynb +++ b/reguly.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 29, + "execution_count": 212, "id": "706dd5e1-57ee-416b-a77c-5d15df8dbdc8", "metadata": {}, "outputs": [], @@ -43,63 +43,7 @@ }, { "cell_type": "code", - "execution_count": 27, - "id": "423f0821-000a-4aaa-b400-2e7554866175", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'user_action': [],\n", - " 'system_action': [],\n", - " 'belief_state': {'attraction': {'type': '', 'name': '', 'area': ''},\n", - " 'hotel': {'name': '',\n", - " 'area': '',\n", - " 'parking': '',\n", - " 'price range': '',\n", - " 'stars': '4',\n", - " 'internet': 'yes',\n", - " 'type': 'hotel',\n", - " 'book stay': '',\n", - " 'book day': '',\n", - " 'book people': ''},\n", - " 'restaurant': {'food': '',\n", - " 'price range': '',\n", - " 'name': '',\n", - " 'area': '',\n", - " 'book time': '',\n", - " 'book day': '',\n", - " 'book people': ''},\n", - " 'taxi': {'leave at': '',\n", - " 'destination': '',\n", - " 'departure': '',\n", - " 'arrive by': ''},\n", - " 'train': {'leave at': '',\n", - " 'destination': '',\n", - " 'day': '',\n", - " 'arrive by': '',\n", - " 'departure': '',\n", - " 'book people': ''},\n", - " 'hospital': {'department': ''}},\n", - " 'booked': {},\n", - " 'request_state': {},\n", - " 'terminated': False,\n", - " 'history': []}" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from convlab.util.multiwoz.state import default_state\n", - "default_state()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, + "execution_count": 213, "id": "06926543-cab1-48e7-8e82-0560fc0fa16a", "metadata": {}, "outputs": [], @@ -109,11 +53,42 @@ "from convlab.dst.dst import DST\n", "from convlab.dst.rule.multiwoz.dst_util import normalize_value\n", "\n", - "class SimpleRuleDST(DST):\n", + "\n", + "def default_state():\n", + " return {\n", + " 'belief_state': {\n", + " 'hotel': {\n", + " 'info': {\n", + " 'name': '',\n", + " 'area': '',\n", + " 'parking': '',\n", + " 'price range': '',\n", + " 'stars': '',\n", + " 'internet': '',\n", + " 'type': ''\n", + " },\n", + " 'booking': {\n", + " 'book stay': '',\n", + " 'book day': '',\n", + " 'book people': ''\n", + " }\n", + " }\n", + " },\n", + " 'request_state': {},\n", + " 'history': [],\n", + " 'user_action': [],\n", + " 'system_action': [],\n", + " 'terminated': False,\n", + " 'booked': []\n", + " }\n", + "\n", + "\n", + "class DialogueStateTracker(DST):\n", " def __init__(self):\n", " DST.__init__(self)\n", " self.state = default_state()\n", - " self.value_dict = json.load(open('value_dict.json'))\n", + " with open('./hotels_data.json') as f:\n", + " self.value_dict = json.load(f)\n", "\n", " def update(self, user_act=None):\n", " for intent, domain, slot, value in user_act:\n", @@ -125,14 +100,14 @@ " continue\n", "\n", " if intent == 'inform':\n", - " if slot == 'none' or slot == '':\n", + " if slot == 'none' or slot == '' or value == 'dontcare':\n", " continue\n", "\n", - " domain_dic = self.state['belief_state'][domain]\n", + " domain_dic = self.state['belief_state'][domain]['info']\n", "\n", " if slot in domain_dic:\n", - " nvalue = normalize_value(self.value_dict, domain, slot, value)\n", - " self.state['belief_state'][domain][slot] = nvalue\n", + " nvalue = self.normalize_value(self.value_dict, domain, slot, value)\n", + " self.state['belief_state'][domain]['info'][slot] = nvalue\n", "\n", " elif intent == 'request':\n", " if domain not in self.state['request_state']:\n", @@ -142,55 +117,21 @@ "\n", " return self.state\n", "\n", + " def normalize_value(self, value_dict, domain, slot, value):\n", + " normalized_value = value.lower().strip()\n", + " if domain in value_dict and slot in value_dict[domain]:\n", + " possible_values = value_dict[domain][slot]\n", + " if isinstance(possible_values, dict) and normalized_value in possible_values:\n", + " return possible_values[normalized_value]\n", + " return value\n", + "\n", " def init_session(self):\n", - " self.state = default_state()" + " self.state = default_state()\n" ] }, { "cell_type": "code", - "execution_count": null, - "id": "b1d42d5f-e923-4c46-a930-48da9b72d77b", - "metadata": {}, - "outputs": [], - "source": [ - "dst = SimpleRuleDST()\n", - "dst.state" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "749e3a90-17c3-4a3e-acd7-856560445eaf", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'name': '',\n", - " 'area': '',\n", - " 'parking': 'yes',\n", - " 'price range': 'cheap',\n", - " 'stars': '4',\n", - " 'internet': 'yes',\n", - " 'type': 'hotel',\n", - " 'book stay': '',\n", - " 'book day': '',\n", - " 'book people': ''}" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dst.update([['Inform', 'Hotel', 'Price Range', 'cheap'], ['Inform', 'Hotel', 'Parking', 'yes']])\n", - "dst.state['belief_state']['hotel']" - ] - }, - { - "cell_type": "code", - "execution_count": 10, + "execution_count": 214, "id": "a7f3d067-3a95-4ef5-b216-be5840bc8831", "metadata": {}, "outputs": [], @@ -203,11 +144,27 @@ "from convlab.policy.policy import Policy\n", "from convlab.util.multiwoz.dbquery import Database\n", "\n", + "db_path = './hotels_data.json'\n", "\n", - "class SimpleRulePolicy(Policy):\n", + "class DialoguePolicy(Policy):\n", " def __init__(self):\n", " Policy.__init__(self)\n", - " self.db = Database()\n", + " self.db = self.load_database(db_path)\n", + "\n", + " def load_database(self, db_path):\n", + " with open(db_path, 'r', encoding='utf-8') as f:\n", + " return json.load(f)\n", + "\n", + " def query(self, domain, constraints):\n", + " if domain != 'hotel':\n", + " return []\n", + " \n", + " results = []\n", + " for entry in self.db:\n", + " match = all(entry.get(key) == value for key, value in constraints)\n", + " if match:\n", + " results.append(entry)\n", + " return results\n", "\n", " def predict(self, state):\n", " self.results = []\n", @@ -220,7 +177,6 @@ " for user_act in user_action:\n", " self.update_system_action(user_act, user_action, state, system_action)\n", "\n", - " # Reguła 3\n", " if any(True for slots in user_action.values() for (slot, _) in slots if slot in ['book stay', 'book day', 'book people']):\n", " if self.results:\n", " system_action = {('Booking', 'Book'): [[\"Ref\", self.results[0].get('Ref', 'N/A')]]}\n", @@ -231,10 +187,11 @@ "\n", " def update_system_action(self, user_act, user_action, state, system_action):\n", " domain, intent = user_act\n", - " constraints = [(slot, value) for slot, value in state['belief_state'][domain.lower()].items() if value != '']\n", - " self.results = deepcopy(self.db.query(domain.lower(), constraints))\n", + " constraints = [(slot, value) for slot, value in state['belief_state'][domain]['info'].items() if value != '']\n", + " print(f\"Constraints: {constraints}\")\n", + " self.results = deepcopy(self.query(domain.lower(), constraints))\n", + " print(f\"Query results: {self.results}\")\n", "\n", - " # Reguła 1\n", " if intent == 'request':\n", " if len(self.results) == 0:\n", " system_action[(domain, 'NoOffer')] = []\n", @@ -243,7 +200,6 @@ " if slot[0] in self.results[0]:\n", " system_action[(domain, 'Inform')].append([slot[0], self.results[0].get(slot[0], 'unknown')])\n", "\n", - " # Reguła 2\n", " elif intent == 'inform':\n", " if len(self.results) == 0:\n", " system_action[(domain, 'NoOffer')] = []\n", @@ -251,60 +207,17 @@ " system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])\n", " choice = self.results[0]\n", "\n", - " if domain in [\"hotel\", \"attraction\", \"police\", \"restaurant\"]:\n", - " system_action[(domain, 'Recommend')].append(['Name', choice['name']])" + " if domain in [\"hotel\"]:\n", + " system_action[(domain, 'Recommend')].append(['Name', choice['name']])\n", + " for slot in state['belief_state'][domain]['info']:\n", + " if choice.get(slot):\n", + " state['belief_state'][domain]['info'][slot] = choice[slot]" ] }, { "cell_type": "code", - "execution_count": 11, - "id": "089dbfa8-d34a-457c-9084-ef335372ea05", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:root:nlu info_dict is not initialized\n", - "WARNING:root:dst info_dict is not initialized\n", - "WARNING:root:policy info_dict is not initialized\n", - "WARNING:root:nlg info_dict is not initialized\n" - ] - } - ], - "source": [ - "from convlab.dialog_agent import PipelineAgent\n", - "dst.init_session()\n", - "policy = SimpleRulePolicy()\n", - "agent = PipelineAgent(nlu=None, dst=dst, policy=policy, nlg=None, name='sys')" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "5ac57cc8-6650-4a1b-a87e-2cda67d9b0f3", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[['Inform', 'hotel', 'Choice', '3'],\n", - " ['Recommend', 'hotel', 'Name', 'huntingdon marriott hotel']]" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "agent.response([['Inform', 'Hotel', 'Price Range', 'cheap'], ['Inform', 'Hotel', 'Parking', 'yes']])" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "id": "eaeca7b0-08d5-4db0-9eb3-3aceda24f987", + "execution_count": 218, + "id": "11f34b20-c5b0-4752-8610-21f5eef4b569", "metadata": {}, "outputs": [ { @@ -326,34 +239,50 @@ } ], "source": [ - "from convlab.base_models.t5.nlu import T5NLU\n", "from convlab.nlg.template.multiwoz import TemplateNLG\n", + "from convlab.dialog_agent import PipelineAgent\n", "\n", - "# nlu = T5NLU(speaker='user', context_window_size=0, model_name_or_path='ConvLab/t5-small-nlu-multiwoz21')\n", "nlu = NaturalLanguageAnalyzer()\n", + "dst = DialogueStateTracker()\n", + "policy = DialoguePolicy()\n", "nlg = TemplateNLG(is_user=False)\n", + "\n", "agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')" ] }, { "cell_type": "code", - "execution_count": 37, - "id": "b559fcd3-861b-49d7-ac2b-d3160d4c5a1d", + "execution_count": 219, + "id": "faf05778-2bca-4044-97a7-d6facf853e10", + "metadata": {}, + "outputs": [], + "source": [ + "# nla = NaturalLanguageAnalyzer()\n", + "# nla_response = nla.predict(\"chciałbym zarezerwować drogi hotel bez parkingu 1 stycznia w Warszawie w centrum\")\n", + "# print(nla_response)\n", + "# response = agent.response(nla_response)\n", + "# print(response)" + ] + }, + { + "cell_type": "code", + "execution_count": 220, + "id": "6c837788-e7d5-483e-b873-00061f118619", "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "'We have 3 such places . Would huntingdon marriott hotel work for you ?'" - ] - }, - "execution_count": 37, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "Constraints: [('area', 'centre'), ('parking', 'yes'), ('price range', 'expensive'), ('type', 'hotel')]\n", + "Query results: [{'name': 'Four Seasons Hotel', 'area': 'centre', 'parking': 'yes', 'price range': 'expensive', 'stars': '5', 'internet': 'yes', 'type': 'hotel'}, {'name': 'The Ritz Hotel', 'area': 'centre', 'parking': 'yes', 'price range': 'expensive', 'stars': '5', 'internet': 'yes', 'type': 'hotel'}, {'name': 'The Savoy Hotel', 'area': 'centre', 'parking': 'yes', 'price range': 'expensive', 'stars': '5', 'internet': 'yes', 'type': 'hotel'}, {'name': 'Shangri-La Hotel', 'area': 'centre', 'parking': 'yes', 'price range': 'expensive', 'stars': '5', 'internet': 'yes', 'type': 'hotel'}]\n", + "We have 4 such places . Four Seasons Hotel looks like it would be a good choice .\n" + ] } ], "source": [ - "agent.response(\"I need a cheap hotel with free parking .\")" + "response = agent.response(\"chciałbym zarezerwować drogi hotel z parkingiem 1 stycznia w Warszawie w centrum\")\n", + "print(response)" ] }, {