From 72e17d2106a119223c0758f6f66b756fb8f62142 Mon Sep 17 00:00:00 2001 From: Patryk Date: Sun, 9 Jun 2024 19:48:26 +0200 Subject: [PATCH] =?UTF-8?q?Zmiana=20metody=20zapisu=20pami=C4=99ci=20w=20D?= =?UTF-8?q?ST?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- attributes.json | 5 +-- src/service/dialog_state_monitor.py | 54 ++++++++++++++++------------- src/test/dialog_state_monitor.py | 16 ++++----- 3 files changed, 37 insertions(+), 38 deletions(-) diff --git a/attributes.json b/attributes.json index c92db86..969da03 100644 --- a/attributes.json +++ b/attributes.json @@ -1,6 +1,4 @@ { - "address": null, - "order-complete": false, "dough": ["thick"], "drink": ["pepsi", "cola", "water"], "food": ["pizza"], @@ -49,6 +47,5 @@ "xl": { "price_multiplier": 1.4 } - }, - "order": [] + } } \ No newline at end of file diff --git a/src/service/dialog_state_monitor.py b/src/service/dialog_state_monitor.py index b5bea59..dfeb970 100644 --- a/src/service/dialog_state_monitor.py +++ b/src/service/dialog_state_monitor.py @@ -1,6 +1,5 @@ from src.model.frame import Frame from convlab.dst.dst import DST -import json import copy @@ -10,19 +9,27 @@ def normalize(value): class DialogStateMonitor(DST): - def __init__(self, initial_state_file: str = '../attributes.json'): - DST.__init__(self) - with open(initial_state_file) as file: - self.__initial_state = json.load(file) - self.__memory = copy.deepcopy(self.__initial_state) + domain = 'restaurant' - # def __access_memory__(self, path: str) -> str | int | float | None: - # result = self.state['memory'] - # for segment in path.split('.'): - # if segment not in result: - # return None - # result = result[segment] - # return result + def __init__(self): + DST.__init__(self) + self.__initial_state = dict(user_action=[], + system_action=[], + belief_state={ + 'order': [], + 'address': {}, + 'order-complete': False, + 'phone': {}, + 'delivery': {}, + 'payment': {}, + 'time': {}, + 'name': {}, + }, + booked={}, + request_state={}, + terminated=False, + history=[]) + self.state = copy.deepcopy(self.__initial_state) def update(self, frame: Frame): if frame.source != 'user': @@ -31,30 +38,27 @@ class DialogStateMonitor(DST): new_order = dict() for slot in frame.slots: new_order[slot.name] = normalize(slot.value) - self.__memory['order'].append(new_order) + self.state['belief_state']['order'].append(new_order) elif frame.act == 'inform/address': for slot in frame.slots: - self.__memory['address'][slot.name] = normalize(slot.value) + self.state['belief_state']['address'][slot.name] = normalize(slot.value) elif frame.act == 'inform/phone': for slot in frame.slots: - self.__memory['phone'][slot.name] = normalize(slot.value) + self.state['belief_state']['phone'][slot.name] = normalize(slot.value) elif frame.act == 'inform/order-complete': - self.__memory['order-complete'] = True + self.state['belief_state']['order-complete'] = True elif frame.act == 'inform/delivery': for slot in frame.slots: - self.__memory['delivery'][slot.name] = normalize(slot.value) + self.state['belief_state']['delivery'][slot.name] = normalize(slot.value) elif frame.act == 'inform/payment': for slot in frame.slots: - self.__memory['payment'][slot.name] = normalize(slot.value) + self.state['belief_state']['payment'][slot.name] = normalize(slot.value) elif frame.act == 'inform/time': for slot in frame.slots: - self.__memory['time'][slot.name] = normalize(slot.value) + self.state['belief_state']['time'][slot.name] = normalize(slot.value) elif frame.act == 'inform/name': for slot in frame.slots: - self.__memory['name'][slot.name] = normalize(slot.value) - - def read(self) -> dict: - return self.__memory + self.state['belief_state']['name'][slot.name] = normalize(slot.value) def reset(self): - self.__memory = copy.deepcopy(self.__initial_state) + self.state = copy.deepcopy(self.__initial_state) diff --git a/src/test/dialog_state_monitor.py b/src/test/dialog_state_monitor.py index a76c1ee..922b26e 100644 --- a/src/test/dialog_state_monitor.py +++ b/src/test/dialog_state_monitor.py @@ -4,17 +4,15 @@ from src.model.slot import Slot dst = DialogStateMonitor() -assert dst.read()['pizza']['capri']['price'] == 25 - -dst.update(Frame('user', 'inform/order', [Slot('B-pizza', 'margaritta'), Slot('B-sauce', 'ketchup')])) -dst.update(Frame('user', 'inform/order', [Slot('B-pizza', 'carbonara')])) +dst.update(Frame('user', 'inform/order', [Slot('pizza', 'margaritta'), Slot('sauce', 'ketchup')])) +dst.update(Frame('user', 'inform/order', [Slot('pizza', 'carbonara')])) dst.update(Frame('user', 'inform/order-complete', [])) -assert dst.read()['order'][0]['B-pizza'] == 'margaritta' -assert dst.read()['order'][0]['B-sauce'] == 'ketchup' -assert dst.read()['order-complete'] == True +assert dst.state['belief_state']['order'][0]['pizza'] == 'margaritta' +assert dst.state['belief_state']['order'][0]['sauce'] == 'ketchup' +assert dst.state['belief_state']['order-complete'] == True dst.reset() -assert dst.read()['order'] == [] -assert dst.read()['order-complete'] == False +assert dst.state['belief_state']['order'] == [] +assert dst.state['belief_state']['order-complete'] == False