diff --git a/src/service/dialog_state_monitor.py b/src/service/dialog_state_monitor.py index abebbea..7827de7 100644 --- a/src/service/dialog_state_monitor.py +++ b/src/service/dialog_state_monitor.py @@ -1,5 +1,6 @@ from src.model.frame import Frame import copy +import json def normalize(value): @@ -8,7 +9,9 @@ def normalize(value): class DialogStateMonitor: - def __init__(self): + def __init__(self, initial_state_file: str = '../attributes.json'): + with open(initial_state_file) as file: + constants = json.load(file) self.__initial_state = dict(belief_state={ 'order': [], 'address': {}, @@ -19,17 +22,20 @@ class DialogStateMonitor: 'time': {}, 'name': {}, }, + total_cost=0, stages=[ {'completed': False, 'name': 'collect_food'}, {'completed': False, 'name': 'collect_drinks'}, {'completed': False, 'name': 'collect_address'}, ], + was_previous_order_invalid=False, constraints={ 'order': [ 'sauce', 'pizza', ], }, + constants=constants, history=[]) self.state = copy.deepcopy(self.__initial_state) @@ -55,6 +61,12 @@ class DialogStateMonitor: stage['completed'] = True return + def pizza_exists(self, name: str) -> bool: + return normalize(name) in self.state['constants']['pizza'] + + def get_total_cost(self) -> int: + return self.state['total_cost'] + def update(self, frame: Frame) -> None: self.state['history'].append(frame) if frame.source != 'user': @@ -62,7 +74,14 @@ class DialogStateMonitor: if frame.act == 'inform/order': new_order = dict() for slot in frame.slots: - new_order[slot.name] = normalize(slot.value) + value = normalize(slot.value) + if slot.name == 'pizza': + if self.pizza_exists(value) is False: + self.state['was_previous_order_invalid'] = True + return + self.state['was_previous_order_invalid'] = False + self.state['total_cost'] += self.state['constants']['pizza'][value]['price'] + new_order[slot.name] = value self.state['belief_state']['order'].append(new_order) elif frame.act == 'inform/address': for slot in frame.slots: diff --git a/src/test/dialog_state_monitor.py b/src/test/dialog_state_monitor.py index f6509d4..363e302 100644 --- a/src/test/dialog_state_monitor.py +++ b/src/test/dialog_state_monitor.py @@ -4,22 +4,28 @@ from src.model.slot import Slot dsm = DialogStateMonitor() -frame1 = Frame('user', 'inform/order', [Slot('pizza', 'margaritta'), Slot('sauce', 'ketchup')]) +assert dsm.pizza_exists('capri') is True +assert dsm.state['was_previous_order_invalid'] is False + +frame1 = Frame('user', 'inform/order', [Slot('pizza', 'margarita'), Slot('sauce', 'ketchup')]) dsm.update(frame1) assert dsm.get_last_order_missing_fields() == [] -frame2 = Frame('user', 'inform/order', [Slot('pizza', 'carbonara')]) +assert dsm.get_total_cost() == 20 +frame2 = Frame('user', 'inform/order', [Slot('pizza', 'tuna')]) dsm.update(frame2) assert dsm.get_last_order_missing_fields() == ['sauce'] +assert dsm.get_total_cost() == 60 frame3 = Frame('user', 'inform/order-complete', []) dsm.update(frame3) -assert dsm.state['belief_state']['order'][0]['pizza'] == 'margaritta' +assert dsm.state['belief_state']['order'][0]['pizza'] == 'margarita' assert dsm.state['belief_state']['order'][0]['sauce'] == 'ketchup' assert dsm.state['belief_state']['order-complete'] is True assert dsm.state['history'][0] == frame1 assert dsm.state['history'][1] == frame2 assert dsm.state['history'][2] == frame3 + assert dsm.get_current_active_stage() == 'collect_food' dsm.mark_current_stage_completed() assert dsm.get_current_active_stage() == 'collect_drinks' @@ -30,7 +36,16 @@ assert dsm.get_current_active_stage() is None dsm.reset() +assert dsm.get_total_cost() == 0 assert dsm.get_current_active_stage() == 'collect_food' assert dsm.state['belief_state']['order'] == [] assert dsm.state['belief_state']['order-complete'] is False assert len(dsm.state['history']) == 0 + +dsm.reset() + +frame1 = Frame('user', 'inform/order', [Slot('pizza', 'buraczna')]) +dsm.update(frame1) +assert dsm.state['was_previous_order_invalid'] is True +assert dsm.state['belief_state']['order'] == [] +assert dsm.get_total_cost() == 0