Zmiany dla dialog policy
This commit is contained in:
parent
a928c925fc
commit
0f2041b481
@ -5,29 +5,51 @@ from model.slot import Slot
|
|||||||
class DialogPolicy:
|
class DialogPolicy:
|
||||||
|
|
||||||
def predict(self, dsm):
|
def predict(self, dsm):
|
||||||
system_action = defaultdict(list)
|
|
||||||
last_frame = dsm.state['history'][-1]
|
last_frame = dsm.state['history'][-1]
|
||||||
if(dsm.state['was_previous_order_invalid'] == False):
|
if(dsm.state['was_previous_order_invalid'] == False):
|
||||||
match(last_frame.act.strip('/')):
|
if("inform" in last_frame.act):
|
||||||
|
act = last_frame.act.split('/')[0]
|
||||||
|
else:
|
||||||
|
act = last_frame.act
|
||||||
|
match(act):
|
||||||
case "inform" | "affirm":
|
case "inform" | "affirm":
|
||||||
current_active_status = dsm.get_current_active_stage()
|
current_active_status = dsm.get_current_active_stage()
|
||||||
match(current_active_status):
|
match(current_active_status):
|
||||||
case "collect_food":
|
case "collect_food":
|
||||||
system_action["inform"].append(["menu", dsm.state['constants']['menu']])
|
return Frame(source="system", act = "inform", slots = [Slot("menu", dsm.state['constants']['menu'])])
|
||||||
case "collect_drinks":
|
case "collect_drinks":
|
||||||
system_action["inform"].append(["drink", dsm.state['constants']['drink']])
|
return Frame(source="system", act = "inform", slots = [Slot("drink", dsm.state['constants']['drink'])])
|
||||||
case "collect_address":
|
case "collect_address":
|
||||||
system_action["request"].append(["address"])
|
return Frame(source="system", act = "request", slots = [Slot("address")])
|
||||||
if(current_active_status == None):
|
if(current_active_status == None):
|
||||||
system_action["inform"].append(["order", dsm.belief_state])
|
dsm.state['was_previous_order_invalid'] = False
|
||||||
|
return Frame(source="system", act = "end")
|
||||||
|
case "request/menu":
|
||||||
|
return Frame(source="system", act = "inform", slots = [Slot("menu", dsm.state['constants']['menu'])])
|
||||||
|
case "request/price":
|
||||||
|
return Frame(source="system", act = "inform", slots = [Slot("price", dsm.state['total_cost'])])
|
||||||
|
case "request/ingredients":
|
||||||
|
return Frame(source="system", act = "inform", slots = [Slot("ingredients", dsm.state['constants']['ingredients'])])
|
||||||
|
case "request/sauce":
|
||||||
|
return Frame(source="system", act = "inform", slots = [Slot("sauce", dsm.state['constants']['sauce'])])
|
||||||
|
case "request/time":
|
||||||
|
return Frame(source="system", act = "inform", slots = [Slot("time", dsm.state['belief_state']['time'])])
|
||||||
|
case "request/size":
|
||||||
|
return Frame(source="system", act = "inform", slots = [Slot("size", dsm.state['constants']['size'])])
|
||||||
|
case "request/delivery-price":
|
||||||
|
return Frame(source="system", act = "inform", slots = [Slot("delivery-price", "10")])
|
||||||
|
case "request/drinks":
|
||||||
|
return Frame(source="system", act = "inform", slots = [Slot("drink", dsm.state['constants']['drink'])])
|
||||||
case "request":
|
case "request":
|
||||||
|
slots = []
|
||||||
for slot in last_frame.slots:
|
for slot in last_frame.slots:
|
||||||
system_action["inform"].append([slot, dsm.state['constants'][slot]])
|
slots.append(Slot(slot.name, dsm.state['constants'][slot.name]))
|
||||||
|
return Frame(source="system", act = "inform", slots = slots)
|
||||||
case "welcomemsg":
|
case "welcomemsg":
|
||||||
system_action["inform"].append(["menu", dsm.state['constants']['menu']])
|
return Frame(source="system", act = "inform", slots = [Slot("menu", dsm.state['constants']['menu'])])
|
||||||
case "bye":
|
case "bye":
|
||||||
system_action["bye"].append([[], None])
|
return Frame(source="system", act = "bye")
|
||||||
case "negate":
|
case "negate":
|
||||||
system_action["affirm"].append([[], None])
|
return Frame(source="system", act = "affirm")
|
||||||
|
|
||||||
return system_action
|
return Frame(source="system", act = [])
|
Loading…
Reference in New Issue
Block a user