diff --git a/src/main.py b/src/main.py index 0d6d7cd..3a505a5 100644 --- a/src/main.py +++ b/src/main.py @@ -32,7 +32,7 @@ while True: monitor.update(frame) # DP - # print(dialog_policy.next_dialogue_act(monitor.read()).act) + system_action = dialog_policy.predict(monitor) # NLG act, slots = parse_frame(frame) diff --git a/src/service/dialog_policy.py b/src/service/dialog_policy.py index ee9d29c..13812a1 100644 --- a/src/service/dialog_policy.py +++ b/src/service/dialog_policy.py @@ -1,8 +1,34 @@ +from collections import defaultdict from model.frame import Frame +from src.model.slot import Slot class DialogPolicy: - def next_dialogue_act(self, frames: list[Frame]) -> Frame: - if frames[-1].act == "welcomemsg": - return Frame("system", "welcomemsg", []) - else: - return Frame("system", "canthelp", []) + + def predict(self, dsm): + system_action = defaultdict(list) + last_frame = dsm.state['history'][-1] + if(dsm.was_previous_order_invalid==False): + match(last_frame.act.strip('/')): + case "inform" | "affirm": + current_active_status = dsm.get_current_active_stage() + match(current_active_status): + case "collect_food": + system_action["inform"].append([Slot("menu", dsm.constants['menu'])]) + case "collect_drinks": + system_action["inform"].append([Slot("drink"), dsm.constants['drink']]) + case "collect_address": + system_action["request"].append([Slot("address")]) + if(current_active_status == None): + system_action["inform"].append([Slot("order"), dsm.belief_state]) + case "request": + for slot in last_frame.slots: + system_action["inform"].append([Slot(slot), dsm.constants[slot]]) + case "welcomemsg": + system_action["inform"].append([Slot("menu"), dsm.constants['menu']]) + case "bye": + system_action["bye"].append([], None) + case "negate": + system_action["affirm"].append([], None) + + return system_action +