From b64eb49ba0c1ee61b70bb924fd7e300e1421fc79 Mon Sep 17 00:00:00 2001 From: Patryk Date: Fri, 7 Jun 2024 12:49:41 +0200 Subject: [PATCH] Implementacja DSM wraz z testami --- attributes.json | 7 +--- requirements.txt | 3 +- src/main.py | 4 +- src/service/dialog_state_monitor.py | 57 +++++++++++++++++++++++++---- src/test/__init__.py | 0 src/test/dialog_state_monitor.py | 20 ++++++++++ 6 files changed, 75 insertions(+), 16 deletions(-) create mode 100644 src/test/__init__.py create mode 100644 src/test/dialog_state_monitor.py diff --git a/attributes.json b/attributes.json index 7783078..c92db86 100644 --- a/attributes.json +++ b/attributes.json @@ -1,9 +1,6 @@ { - "addr": null, - "confirm": [ - "true", - "false" - ], + "address": null, + "order-complete": false, "dough": ["thick"], "drink": ["pepsi", "cola", "water"], "food": ["pizza"], diff --git a/requirements.txt b/requirements.txt index 565e922..75af938 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ flair==0.13.1 conllu==4.5.3 pandas==1.5.3 numpy==1.26.4 -torch==2.3.0 \ No newline at end of file +torch==2.3.0 +convlab==3.0.2a0 \ No newline at end of file diff --git a/src/main.py b/src/main.py index 01a0f04..aa0abb9 100644 --- a/src/main.py +++ b/src/main.py @@ -18,10 +18,10 @@ while True: # print(frame) # DSM - # monitor.append(frame) + monitor.update(frame) # DP - # print(dialog_policy.next_dialogue_act(monitor.get_all()).act) + # print(dialog_policy.next_dialogue_act(monitor.read()).act) # NLG act, slots = parse_frame(frame) diff --git a/src/service/dialog_state_monitor.py b/src/service/dialog_state_monitor.py index 0c62631..3edcf45 100644 --- a/src/service/dialog_state_monitor.py +++ b/src/service/dialog_state_monitor.py @@ -1,14 +1,55 @@ -from model.frame import Frame +from src.model.frame import Frame +from convlab.dst.dst import DST +import json +import copy class DialogStateMonitor: - dialog = [] + def __init__(self, initial_state_file: str = '../attributes.json'): + DST.__init__(self) + with open(initial_state_file) as file: + self.__initial_state = json.load(file) + self.__memory = copy.deepcopy(self.__initial_state) - def append(self, frame: Frame): - self.dialog.append(frame) + # def __access_memory__(self, path: str) -> str | int | float | None: + # result = self.state['memory'] + # for segment in path.split('.'): + # if segment not in result: + # return None + # result = result[segment] + # return result - def get_all(self) -> [Frame]: - return self.dialog + def update(self, frame: Frame): + if frame.source != 'user': + return + if frame.act == 'inform/order': + new_order = dict() + for slot in frame.slots: + new_order[slot.name] = slot.value + self.__memory['order'].append(new_order) + elif frame.act == 'inform/address': + for slot in frame.slots: + self.__memory['address'][slot.name] = slot.value + elif frame.act == 'inform/phone': + for slot in frame.slots: + self.__memory['phone'][slot.name] = slot.value + elif frame.act == 'inform/order-complete': + self.__memory['order-complete'] = True + elif frame.act == 'inform/delivery': + for slot in frame.slots: + self.__memory['delivery'][slot.name] = slot.value + elif frame.act == 'inform/payment': + for slot in frame.slots: + self.__memory['payment'][slot.name] = slot.value + elif frame.act == 'inform/time': + for slot in frame.slots: + self.__memory['time'][slot.name] = slot.value + elif frame.act == 'inform/name': + for slot in frame.slots: + self.__memory['name'][slot.name] = slot.value - def get_last(self) -> Frame: - return self.dialog[len(self.dialog) - 1] + def read(self) -> dict: + return self.__memory + + def reset(self): + self.__memory = copy.deepcopy(self.__initial_state) diff --git a/src/test/__init__.py b/src/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/test/dialog_state_monitor.py b/src/test/dialog_state_monitor.py new file mode 100644 index 0000000..a76c1ee --- /dev/null +++ b/src/test/dialog_state_monitor.py @@ -0,0 +1,20 @@ +from src.service.dialog_state_monitor import DialogStateMonitor +from src.model.frame import Frame +from src.model.slot import Slot + +dst = DialogStateMonitor() + +assert dst.read()['pizza']['capri']['price'] == 25 + +dst.update(Frame('user', 'inform/order', [Slot('B-pizza', 'margaritta'), Slot('B-sauce', 'ketchup')])) +dst.update(Frame('user', 'inform/order', [Slot('B-pizza', 'carbonara')])) +dst.update(Frame('user', 'inform/order-complete', [])) + +assert dst.read()['order'][0]['B-pizza'] == 'margaritta' +assert dst.read()['order'][0]['B-sauce'] == 'ketchup' +assert dst.read()['order-complete'] == True + +dst.reset() + +assert dst.read()['order'] == [] +assert dst.read()['order-complete'] == False