This commit is contained in:
s495727 2024-06-13 16:11:22 +02:00
parent 118099bb37
commit 02312cc3b9

View File

@ -70,16 +70,12 @@ class DialogStateMonitor:
if stage['name'] == "collect_food": if stage['name'] == "collect_food":
for order in self.state['belief_state']['order']: for order in self.state['belief_state']['order']:
for o in order: if order.get("pizza"):
k, v = o
if k == 'pizza':
stage['completed'] = True stage['completed'] = True
return return
elif stage["name"] == "collect_drinks": elif stage["name"] == "collect_drinks":
for order in self.state['belief_state']['order']: for order in self.state['belief_state']['order']:
for o in order: if order.get("drink"):
k, v = o
if k == 'drink':
stage['completed'] = True stage['completed'] = True
return return
elif stage["name"] == "collect_address": elif stage["name"] == "collect_address":
@ -129,7 +125,7 @@ class DialogStateMonitor:
if frame.source != 'user': if frame.source != 'user':
return return
if frame.act == 'inform/order': if frame.act == 'inform/order':
new_order = list() new_order = dict()
for slot in frame.slots: for slot in frame.slots:
value = normalize(slot.value) value = normalize(slot.value)
slot = self.slot_augmentation(slot, value) slot = self.slot_augmentation(slot, value)
@ -146,7 +142,11 @@ class DialogStateMonitor:
return return
self.state['was_previous_order_invalid'] = False self.state['was_previous_order_invalid'] = False
self.state['total_cost'] += self.state['constants'][slot.name][value]['price'] self.state['total_cost'] += self.state['constants'][slot.name][value]['price']
new_order.append((slot.name, value)) if slot.name == "pizza":
if new_order.get("pizza"):
new_order[slot.name].append(value)
else:
new_order[slot.name] = [value]
if len(new_order) > 0: if len(new_order) > 0:
self.state['belief_state']['order'].append(new_order) self.state['belief_state']['order'].append(new_order)
self.complete_stage_if_valid('collect_food') self.complete_stage_if_valid('collect_food')