diff --git a/src/service/dialog_policy.py b/src/service/dialog_policy.py index cb8a1de..9a405d0 100644 --- a/src/service/dialog_policy.py +++ b/src/service/dialog_policy.py @@ -5,29 +5,49 @@ from model.slot import Slot class DialogPolicy: def predict(self, dsm): - system_action = defaultdict(list) last_frame = dsm.state['history'][-1] - if(dsm.state['was_previous_order_invalid']==False): - match(last_frame.act.strip('/')): + if(dsm.state['was_previous_order_invalid'] == False): + if("inform" in last_frame.act): + act = last_frame.act.split('/')[0] + else: + act = last_frame.act + match(act): case "inform" | "affirm": current_active_status = dsm.get_current_active_stage() match(current_active_status): case "collect_food": - system_action["inform"].append(["menu", dsm.state['constants']['menu']]) + return Frame(source="system", act = "inform", slots = [Slot("menu", dsm.state['constants']['menu'])]) case "collect_drinks": - system_action["inform"].append(["drink", dsm.state['constants']['drink']]) + return Frame(source="system", act = "inform", slots = [Slot("drink", dsm.state['constants']['drink'])]) case "collect_address": - system_action["request"].append(["address"]) + return Frame(source="system", act = "request", slots = [Slot("address")]) if(current_active_status == None): - system_action["inform"].append(["order", dsm.belief_state]) + dsm.state['was_previous_order_invalid'] = False + return Frame(source="system", act = "end") + case "request/menu": + return Frame(source="system", act = "inform", slots = [Slot("menu", dsm.state['constants']['menu'])]) + case "request/price": + return Frame(source="system", act = "inform", slots = [Slot("price", dsm.state['total_cost'])]) + case "request/ingredients": + return Frame(source="system", act = "inform", slots = [Slot("ingredients", dsm.state['constants']['ingredients'])]) + case "request/sauce": + return Frame(source="system", act = "inform", slots = [Slot("sauce", dsm.state['constants']['sauce'])]) + case "request/time": + return Frame(source="system", act = "inform", slots = [Slot("time", dsm.state['belief_state']['time'])]) + case "request/size": + return Frame(source="system", act = "inform", slots = [Slot("size", dsm.state['constants']['size'])]) + case "request/delivery-price": + return Frame(source="system", act = "inform", slots = [Slot("delivery-price", "10")]) + case "request/drinks": + return Frame(source="system", act = "inform", slots = [Slot("drink", dsm.state['constants']['drink'])]) case "request": + slots = [] for slot in last_frame.slots: - system_action["inform"].append([slot, dsm.state['constants'][slot]]) + slots.append(Slot(slot.name, dsm.state['constants'][slot.name])) + return Frame(source="system", act = "inform", slots = slots) case "welcomemsg": - system_action["inform"].append(["menu", dsm.state['constants']['menu']]) + return Frame(source="system", act = "inform", slots = [Slot("menu", dsm.state['constants']['menu'])]) case "bye": - system_action["bye"].append([[], None]) + return Frame(source="system", act = "bye") case "negate": - system_action["affirm"].append([[], None]) - - return system_action \ No newline at end of file + return Frame(source="system", act = "affirm") \ No newline at end of file