From 5f128e7e885c6cae3d68f432b87b885b95d05b8f Mon Sep 17 00:00:00 2001 From: Patryk Date: Mon, 10 Jun 2024 23:35:45 +0200 Subject: [PATCH] =?UTF-8?q?Zmiana=20staga=20po=20dodaniu=20okre=C5=9Bloneg?= =?UTF-8?q?o=20wpisu=20do=20DSM?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/service/dialog_state_monitor.py | 8 +++++--- src/test/dialog_state_monitor.py | 16 ++++++---------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/service/dialog_state_monitor.py b/src/service/dialog_state_monitor.py index 1057d1d..4d8e91c 100644 --- a/src/service/dialog_state_monitor.py +++ b/src/service/dialog_state_monitor.py @@ -1,4 +1,4 @@ -from model.frame import Frame +from src.model.frame import Frame import copy import json @@ -62,17 +62,19 @@ class DialogStateMonitor: new_order = dict() for slot in frame.slots: value = normalize(slot.value) - if slot.name == 'pizza' or slot.name == 'drink': + if (slot.name == 'pizza' and self.get_current_active_stage() == 'collect_food') or (slot.name == 'drink' and self.get_current_active_stage() == 'collect_drinks'): if self.item_exists(slot.name, value) is False: self.state['was_previous_order_invalid'] = True return self.state['was_previous_order_invalid'] = False self.state['total_cost'] += self.state['constants'][slot.name][value]['price'] + self.mark_current_stage_completed() new_order[slot.name] = value self.state['belief_state']['order'].append(new_order) - elif frame.act == 'inform/address': + elif frame.act == 'inform/address' and self.get_current_active_stage() == 'collect_address': for slot in frame.slots: self.state['belief_state']['address'][slot.name] = normalize(slot.value) + self.mark_current_stage_completed() elif frame.act == 'inform/phone': for slot in frame.slots: self.state['belief_state']['phone'][slot.name] = normalize(slot.value) diff --git a/src/test/dialog_state_monitor.py b/src/test/dialog_state_monitor.py index 5d8087b..cd4090e 100644 --- a/src/test/dialog_state_monitor.py +++ b/src/test/dialog_state_monitor.py @@ -9,17 +9,21 @@ assert dsm.item_exists('pizza', 'buraczana') is False assert dsm.item_exists('drink', 'cola') is True assert dsm.state['was_previous_order_invalid'] is False +assert dsm.get_current_active_stage() == 'collect_food' frame1 = Frame('user', 'inform/order', [Slot('pizza', 'margarita'), Slot('sauce', 'ketchup')]) dsm.update(frame1) +assert dsm.get_current_active_stage() == 'collect_drinks' assert dsm.get_total_cost() == 20 frame2 = Frame('user', 'inform/order', [Slot('pizza', 'tuna')]) dsm.update(frame2) -assert dsm.get_total_cost() == 60 +assert dsm.get_current_active_stage() == 'collect_drinks' +assert dsm.get_total_cost() == 20 # Pizza is not added, as previous stage is closed already frame3 = Frame('user', 'inform/order-complete', []) dsm.update(frame3) frame4 = Frame('user', 'inform/order', [Slot('drink', 'cola')]) dsm.update(frame4) -assert dsm.get_total_cost() == 70 +assert dsm.get_current_active_stage() == 'collect_address' +assert dsm.get_total_cost() == 30 assert dsm.state['belief_state']['order'][0]['pizza'] == 'margarita' assert dsm.state['belief_state']['order'][0]['sauce'] == 'ketchup' @@ -29,14 +33,6 @@ assert dsm.state['history'][1] == frame2 assert dsm.state['history'][2] == frame3 assert dsm.state['history'][3] == frame4 -assert dsm.get_current_active_stage() == 'collect_food' -dsm.mark_current_stage_completed() -assert dsm.get_current_active_stage() == 'collect_drinks' -dsm.mark_current_stage_completed() -assert dsm.get_current_active_stage() == 'collect_address' -dsm.mark_current_stage_completed() -assert dsm.get_current_active_stage() is None - dsm.reset() assert dsm.get_total_cost() == 0