From ee1d7e45d444f39a717d79debca4cbcd729e1895 Mon Sep 17 00:00:00 2001 From: s495728 Date: Fri, 7 Jun 2024 14:45:13 +0200 Subject: [PATCH] dialog_policy --- src/service/convlab/policy/policy.py | 24 +++++++ src/service/convlab/util/module.py | 25 +++++++ .../convlab/util/restaurant/dbquery.py | 38 +++++++++++ .../data/restaurant/db/confirm_db.json | 4 ++ src/service/data/restaurant/db/dough_db.json | 5 ++ src/service/data/restaurant/db/drink_db.json | 11 ++++ src/service/data/restaurant/db/food_db.json | 3 + src/service/data/restaurant/db/meat_db.json | 5 ++ src/service/data/restaurant/db/menu_db.json | 7 ++ src/service/data/restaurant/db/pizza_db.json | 51 +++++++++++++++ src/service/data/restaurant/db/sauce_db.json | 4 ++ src/service/data/restaurant/db/size_db.json | 14 ++++ src/service/dialog_policy.py | 65 +++++++++++++++++-- 13 files changed, 249 insertions(+), 7 deletions(-) create mode 100644 src/service/convlab/policy/policy.py create mode 100644 src/service/convlab/util/module.py create mode 100644 src/service/convlab/util/restaurant/dbquery.py create mode 100644 src/service/data/restaurant/db/confirm_db.json create mode 100644 src/service/data/restaurant/db/dough_db.json create mode 100644 src/service/data/restaurant/db/drink_db.json create mode 100644 src/service/data/restaurant/db/food_db.json create mode 100644 src/service/data/restaurant/db/meat_db.json create mode 100644 src/service/data/restaurant/db/menu_db.json create mode 100644 src/service/data/restaurant/db/pizza_db.json create mode 100644 src/service/data/restaurant/db/sauce_db.json create mode 100644 src/service/data/restaurant/db/size_db.json diff --git a/src/service/convlab/policy/policy.py b/src/service/convlab/policy/policy.py new file mode 100644 index 0000000..5ac85a8 --- /dev/null +++ b/src/service/convlab/policy/policy.py @@ -0,0 +1,24 @@ +"""Policy Interface""" +from convlab.util.module import Module + + +class Policy(Module): + """Policy module interface.""" + + def predict(self, state): + """Predict the next agent action given dialog state. + + Args: + state (dict or list of list): + when the policy takes dialogue state as input, the type is dict. + else when the policy takes dialogue act as input, the type is list of list. + Returns: + action (list of list or str): + when the policy outputs dialogue act, the type is list of list. + else when the policy outputs utterance directly, the type is str. + """ + return [] + + def update_memory(self, utterance_list, state_list, action_list, reward_list): + pass + diff --git a/src/service/convlab/util/module.py b/src/service/convlab/util/module.py new file mode 100644 index 0000000..9d280ce --- /dev/null +++ b/src/service/convlab/util/module.py @@ -0,0 +1,25 @@ +"""module interface.""" +from abc import ABC + + +class Module(ABC): + + def train(self, *args, **kwargs): + """Model training entry point""" + pass + + def test(self, *args, **kwargs): + """Model testing entry point""" + pass + + def from_cache(self, *args, **kwargs): + """restore internal state for multi-turn dialog""" + return None + + def to_cache(self, *args, **kwargs): + """save internal state for multi-turn dialog""" + return None + + def init_session(self): + """Init the class variables for a new session.""" + pass diff --git a/src/service/convlab/util/restaurant/dbquery.py b/src/service/convlab/util/restaurant/dbquery.py new file mode 100644 index 0000000..735874d --- /dev/null +++ b/src/service/convlab/util/restaurant/dbquery.py @@ -0,0 +1,38 @@ +""" +""" +import json +import os +import random +from fuzzywuzzy import fuzz +from itertools import chain +from copy import deepcopy + + +class Database(object): + def __init__(self): + super(Database, self).__init__() + # loading databases + domains = ['menu', 'pizza', 'drink', 'size'] + self.dbs = {} + for domain in domains: + with open(os.path.join(os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))), + 'data/restaurant/db/{}_db.json'.format(domain))) as f: + self.dbs[domain] = json.load(f) + + def query(self, domain): + """Returns the list of entities for a given domain + based on the annotation of the belief state""" + # query the db + if domain == 'pizza': + return [{'Name': random.choice(self.dbs[domain]['name'])}] + if domain == 'menu': + return deepcopy(self.dbs[domain]) + if domain == 'drink': + return [{'Name': random.choice(self.dbs[domain]['name'])}] + if domain == 'size': + return [{'Size': random.choice(self.dbs[domain]['size'])}] + + +if __name__ == '__main__': + db = Database() \ No newline at end of file diff --git a/src/service/data/restaurant/db/confirm_db.json b/src/service/data/restaurant/db/confirm_db.json new file mode 100644 index 0000000..3dff699 --- /dev/null +++ b/src/service/data/restaurant/db/confirm_db.json @@ -0,0 +1,4 @@ +[ + "true", + "false" +] diff --git a/src/service/data/restaurant/db/dough_db.json b/src/service/data/restaurant/db/dough_db.json new file mode 100644 index 0000000..42e1445 --- /dev/null +++ b/src/service/data/restaurant/db/dough_db.json @@ -0,0 +1,5 @@ +[ + "pepsi", + "cola", + "water" +] \ No newline at end of file diff --git a/src/service/data/restaurant/db/drink_db.json b/src/service/data/restaurant/db/drink_db.json new file mode 100644 index 0000000..e61e4fe --- /dev/null +++ b/src/service/data/restaurant/db/drink_db.json @@ -0,0 +1,11 @@ +[ + { + "name":"pepsi" + }, + { + "name":"cola" + }, + { + "name":"water" + } +] \ No newline at end of file diff --git a/src/service/data/restaurant/db/food_db.json b/src/service/data/restaurant/db/food_db.json new file mode 100644 index 0000000..717beb7 --- /dev/null +++ b/src/service/data/restaurant/db/food_db.json @@ -0,0 +1,3 @@ +[ + "pizza" +] \ No newline at end of file diff --git a/src/service/data/restaurant/db/meat_db.json b/src/service/data/restaurant/db/meat_db.json new file mode 100644 index 0000000..a3c80eb --- /dev/null +++ b/src/service/data/restaurant/db/meat_db.json @@ -0,0 +1,5 @@ +[ + "chicken", + "ham", + "tuna" +] \ No newline at end of file diff --git a/src/service/data/restaurant/db/menu_db.json b/src/service/data/restaurant/db/menu_db.json new file mode 100644 index 0000000..76ae4ed --- /dev/null +++ b/src/service/data/restaurant/db/menu_db.json @@ -0,0 +1,7 @@ +[ + "capri", + "margarita", + "hawajska", + "barcelona", + "tuna" +] diff --git a/src/service/data/restaurant/db/pizza_db.json b/src/service/data/restaurant/db/pizza_db.json new file mode 100644 index 0000000..b2d9294 --- /dev/null +++ b/src/service/data/restaurant/db/pizza_db.json @@ -0,0 +1,51 @@ +[ + { + "name": "capri", + "ingredient": [ + "tomato", + "ham", + "mushrooms", + "cheese" + ], + "price": 25 + }, + { + "name": "margarita", + "ingredient": [ + "tomato", + "cheese" + ], + "price": 20 + }, + { + "name": "hawajska", + "ingredient": [ + "tomato", + "pineapple", + "chicken", + "cheese" + ], + "price": 30 + }, + { + "name": "barcelona", + "ingredient": [ + "tomato", + "onion", + "ham", + "pepper", + "cheese" + ], + "price": 40 + }, + { + "name": "tuna", + "ingredient": [ + "tomato", + "tuna", + "onion", + "cheese" + ], + "price": 40 + } +] diff --git a/src/service/data/restaurant/db/sauce_db.json b/src/service/data/restaurant/db/sauce_db.json new file mode 100644 index 0000000..f2c8ea7 --- /dev/null +++ b/src/service/data/restaurant/db/sauce_db.json @@ -0,0 +1,4 @@ +[ + "garlic", + "1000w" +] diff --git a/src/service/data/restaurant/db/size_db.json b/src/service/data/restaurant/db/size_db.json new file mode 100644 index 0000000..faaa0d8 --- /dev/null +++ b/src/service/data/restaurant/db/size_db.json @@ -0,0 +1,14 @@ +[ + { + "size": "m", + "price_multiplier": 1 + }, + { + "size": "l", + "price_multiplier": 1.2 + }, + { + "size": "xl", + "price_multiplier": 1.4 + } +] diff --git a/src/service/dialog_policy.py b/src/service/dialog_policy.py index ee9d29c..cafa615 100644 --- a/src/service/dialog_policy.py +++ b/src/service/dialog_policy.py @@ -1,8 +1,59 @@ -from model.frame import Frame +from collections import defaultdict +import copy +import json +from copy import deepcopy -class DialogPolicy: - def next_dialogue_act(self, frames: list[Frame]) -> Frame: - if frames[-1].act == "welcomemsg": - return Frame("system", "welcomemsg", []) - else: - return Frame("system", "canthelp", []) +from convlab.policy.policy import Policy +from convlab.util.restaurant.dbquery import Database + +class SimpleRulePolicy(Policy): + def __init__(self): + Policy.__init__(self) + self.db = Database() + + def predict(self, state): + self.results = [] + system_action = defaultdict(list) + user_action = defaultdict(list) + + for intent, domain, slot, value in state['user_action']: + user_action[(domain.lower(), intent.lower())].append((slot.lower(), value)) + + for user_act in user_action: + self.update_system_action(user_act, user_action, state, system_action) + + # Reguła 3 + if any(True for slots in user_action.values() for (slot, _) in slots if slot in ['pizza', 'size', 'drink']): + if self.results: + system_action = {('Ordering', 'Order'): [["Ref", self.results[0].get('Ref', 'N/A')]]} + + system_acts = [[intent, domain, slot, value] for (domain, intent), slots in system_action.items() for slot, value in slots] + state['system_action'] = system_acts + return system_acts + + def update_system_action(self, user_act, user_action, state, system_action): + domain, intent = user_act + constraints = [(slot, value) for slot, value in state['belief_state'][domain.lower()].items() if value != ''] + self.results = deepcopy(self.db.query(domain.lower(), constraints)) + + # Reguła 1 + if intent == 'request': + if len(self.results) == 0: + system_action[(domain, 'NoOffer')] = [] + else: + for slot in user_action[user_act]: + if slot[0] in self.results[0]: + system_action[(domain, 'Inform')].append([slot[0], self.results[0].get(slot[0], 'unknown')]) + + # Reguła 2 + elif intent == 'inform': + if len(self.results) == 0: + system_action[(domain, 'NoOffer')] = [] + else: + system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))]) + choice = self.results[0] + + if domain in ["pizza", "drink"]: + system_action[(domain, 'Recommend')].append(['Name', choice['name']]) + if domain in ["size"]: + system_action[(domain, 'Recommend')].append(['Size', choice['size']]) \ No newline at end of file