From 3dfb0c42b997f4c7c63c66d705422adc02aa62ef Mon Sep 17 00:00:00 2001 From: s495728 Date: Mon, 10 Jun 2024 02:39:46 +0200 Subject: [PATCH] Add dialog policy --- src/service/dialog_policy.py | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/src/service/dialog_policy.py b/src/service/dialog_policy.py index ee9d29c..96cc523 100644 --- a/src/service/dialog_policy.py +++ b/src/service/dialog_policy.py @@ -1,8 +1,33 @@ +from collections import defaultdict from model.frame import Frame +from src.model import slot class DialogPolicy: - def next_dialogue_act(self, frames: list[Frame]) -> Frame: - if frames[-1].act == "welcomemsg": - return Frame("system", "welcomemsg", []) - else: - return Frame("system", "canthelp", []) + + def predict(self, dsm): + system_action = defaultdict(list) + last_frame = dsm.state['history'][-1] + if(dsm.was_previous_order_invalid==False): + match(last_frame.act.strip('/')): + case "inform" | "affirm": + current_active_status = dsm.get_current_active_stage() + match(current_active_status): + case "collect_food": + system_action["inform"].append([slot("menu")], dsm.constants['menu']) + case "collect_drinks": + system_action["inform"].append([slot("drink")], dsm.constants['drink']) + case "collect_address": + system_action["request"].append([slot("address")], None) + return Frame(source = 'system', act = "request", slots = last_frame.slots) + case "request": + for slot in last_frame.slots: + system_action["inform"].append([slot(slot)], dsm.constants[slot]) + case "welcomemsg": + system_action["inform"].append([slot("menu")], dsm.constants['menu']) + case "bye": + system_action["bye"].append([], None) + case "negate": + system_action["affirm"].append([], None) + + return system_action +