Zmiany dla dialog policy

This commit is contained in:
s495728 2024-06-10 19:31:38 +02:00
parent a928c925fc
commit fda06f5240

View File

@ -5,29 +5,49 @@ from model.slot import Slot
class DialogPolicy: class DialogPolicy:
def predict(self, dsm): def predict(self, dsm):
system_action = defaultdict(list)
last_frame = dsm.state['history'][-1] last_frame = dsm.state['history'][-1]
if(dsm.state['was_previous_order_invalid'] == False): if(dsm.state['was_previous_order_invalid'] == False):
match(last_frame.act.strip('/')): if("inform" in last_frame.act):
act = last_frame.act.split('/')[0]
else:
act = last_frame.act
match(act):
case "inform" | "affirm": case "inform" | "affirm":
current_active_status = dsm.get_current_active_stage() current_active_status = dsm.get_current_active_stage()
match(current_active_status): match(current_active_status):
case "collect_food": case "collect_food":
system_action["inform"].append(["menu", dsm.state['constants']['menu']]) return Frame(source="system", act = "inform", slots = [Slot("menu", dsm.state['constants']['menu'])])
case "collect_drinks": case "collect_drinks":
system_action["inform"].append(["drink", dsm.state['constants']['drink']]) return Frame(source="system", act = "inform", slots = [Slot("drink", dsm.state['constants']['drink'])])
case "collect_address": case "collect_address":
system_action["request"].append(["address"]) return Frame(source="system", act = "request", slots = [Slot("address")])
if(current_active_status == None): if(current_active_status == None):
system_action["inform"].append(["order", dsm.belief_state]) dsm.state['was_previous_order_invalid'] = False
return Frame(source="system", act = "end")
case "request/menu":
return Frame(source="system", act = "inform", slots = [Slot("menu", dsm.state['constants']['menu'])])
case "request/price":
return Frame(source="system", act = "inform", slots = [Slot("price", dsm.state['total_cost'])])
case "request/ingredients":
return Frame(source="system", act = "inform", slots = [Slot("ingredients", dsm.state['constants']['ingredients'])])
case "request/sauce":
return Frame(source="system", act = "inform", slots = [Slot("sauce", dsm.state['constants']['sauce'])])
case "request/time":
return Frame(source="system", act = "inform", slots = [Slot("time", dsm.state['belief_state']['time'])])
case "request/size":
return Frame(source="system", act = "inform", slots = [Slot("size", dsm.state['constants']['size'])])
case "request/delivery-price":
return Frame(source="system", act = "inform", slots = [Slot("delivery-price", "10")])
case "request/drinks":
return Frame(source="system", act = "inform", slots = [Slot("drink", dsm.state['constants']['drink'])])
case "request": case "request":
slots = []
for slot in last_frame.slots: for slot in last_frame.slots:
system_action["inform"].append([slot, dsm.state['constants'][slot]]) slots.append(Slot(slot.name, dsm.state['constants'][slot.name]))
return Frame(source="system", act = "inform", slots = slots)
case "welcomemsg": case "welcomemsg":
system_action["inform"].append(["menu", dsm.state['constants']['menu']]) return Frame(source="system", act = "inform", slots = [Slot("menu", dsm.state['constants']['menu'])])
case "bye": case "bye":
system_action["bye"].append([[], None]) return Frame(source="system", act = "bye")
case "negate": case "negate":
system_action["affirm"].append([[], None]) return Frame(source="system", act = "affirm")
return system_action