From f343ebec98e954f18c326126ab2c768d8057b6b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20J=C4=99dyk?= Date: Mon, 31 May 2021 14:53:20 +0200 Subject: [PATCH] fix DST bugs --- DialoguePolicy.py | 169 ++++++++++++++++++++++++++++++---------- DialogueStateTracker.py | 1 + main.py | 9 +-- 3 files changed, 133 insertions(+), 46 deletions(-) diff --git a/DialoguePolicy.py b/DialoguePolicy.py index a2392c4..abd7595 100644 --- a/DialoguePolicy.py +++ b/DialoguePolicy.py @@ -13,6 +13,7 @@ class DP: def __init__(self, dst): self.DST = dst + self.meeting_to_update = False def chooseTactic(self) -> SystemAct: dialogue_state, last_user_act, last_system_act = self.DST.get_dialogue_state() @@ -21,34 +22,58 @@ class DP: if dialogue_state == UserActType.CREATE_MEETING: if not last_system_act: if 'date' not in slots: - return SystemAct(SystemActType.REQUEST, ['date']) + system_act = SystemAct(SystemActType.REQUEST, ['date']) + self.DST.system_update(system_act) + return system_act elif 'time' not in slots: - return SystemAct(SystemActType.REQUEST, ['time']) + system_act = SystemAct(SystemActType.REQUEST, ['time']) + self.DST.system_update(system_act) + return system_act elif 'place' not in slots: - return SystemAct(SystemActType.REQUEST, ['place']) + system_act = SystemAct(SystemActType.REQUEST, ['place']) + self.DST.system_update(system_act) + return system_act elif 'description' not in slots: - return SystemAct(SystemActType.REQUEST, ['description']) + system_act = SystemAct(SystemActType.REQUEST, ['description']) + self.DST.system_update(system_act) + return system_act elif 'participants' not in slots: - return SystemAct(SystemActType.REQUEST, ['participants']) + system_act = SystemAct(SystemActType.REQUEST, ['participants']) + self.DST.system_update(system_act) + return system_act else: - return SystemAct(SystemActType.CONFIRM_DOMAIN, slots) + system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, slots) + self.DST.system_update(system_act) + return system_act elif last_system_act.getActType() == SystemActType.REQUEST: if last_user_act == UserActType.NEGATE: slot_type = last_system_act.getActParams()[0] if slot_type not in ['date', 'time']: self.DST.insert_empty_slot(slot_type) if 'date' not in slots: - return SystemAct(SystemActType.REQUEST, ['date']) + system_act = SystemAct(SystemActType.REQUEST, ['date']) + self.DST.system_update(system_act) + return system_act elif 'time' not in slots: - return SystemAct(SystemActType.REQUEST, ['time']) + system_act = SystemAct(SystemActType.REQUEST, ['time']) + self.DST.system_update(system_act) + return system_act elif 'place' not in slots: - return SystemAct(SystemActType.REQUEST, ['place']) + system_act = SystemAct(SystemActType.REQUEST, ['place']) + self.DST.system_update(system_act) + return system_act elif 'description' not in slots: - return SystemAct(SystemActType.REQUEST, ['description']) + system_act = SystemAct(SystemActType.REQUEST, ['description']) + self.DST.system_update(system_act) + return system_act elif 'participants' not in slots: - return SystemAct(SystemActType.REQUEST, ['participants']) + system_act = SystemAct(SystemActType.REQUEST, ['participants']) + self.DST.system_update(system_act) + return system_act else: - return SystemAct(SystemActType.CONFIRM_DOMAIN, slots) + system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, slots) + self.DST.system_update(system_act) + return system_act elif last_system_act.getActType() == SystemActType.CONFIRM_DOMAIN: if last_user_act == UserActType.CONFIRM: system_act = SystemAct(SystemActType.AFFIRM, ['create_meeting']) @@ -62,56 +87,102 @@ class DP: return SystemAct(SystemActType.NOT_UNDERSTOOD, []) # stan edycji spotkania elif dialogue_state == UserActType.UPDATE_MEETING: - meeting_to_update = False if not last_system_act: if 'date' not in slots: - return SystemAct(SystemActType.REQUEST, ['date']) + system_act = SystemAct(SystemActType.REQUEST, ['date']) + self.DST.system_update(system_act) + return system_act elif 'time' not in slots: - return SystemAct(SystemActType.REQUEST, ['time']) + system_act = SystemAct(SystemActType.REQUEST, ['time']) + self.DST.system_update(system_act) + return system_act else: # implementacja wyszukiwania odpowiedniego spotkania w bazie return SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_update']) elif last_system_act.getActType() == SystemActType.REQUEST: - if not meeting_to_update: + if not self.meeting_to_update: if 'date' not in slots: - return SystemAct(SystemActType.REQUEST, ['date']) + system_act = SystemAct(SystemActType.REQUEST, ['date']) + self.DST.system_update(system_act) + return system_act elif 'time' not in slots: - return SystemAct(SystemActType.REQUEST, ['time']) + system_act = SystemAct(SystemActType.REQUEST, ['time']) + self.DST.system_update(system_act) + return system_act else: # implementacja wyszukiwania odpowiedniego spotkania w bazie - return SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_update']) + system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_update']) + self.DST.system_update(system_act) + return system_act else: if last_user_act == UserActType.NEGATE: slot_type = last_system_act.getActParams()[0] - if slot_type not in ['date', 'time']: - self.DST.insert_empty_slot(slot_type) + self.DST.insert_empty_slot(slot_type) if 'date' not in slots: - return SystemAct(SystemActType.REQUEST, ['date']) + system_act = SystemAct(SystemActType.REQUEST, ['date']) + self.DST.system_update(system_act) + return system_act elif 'time' not in slots: - return SystemAct(SystemActType.REQUEST, ['time']) + system_act = SystemAct(SystemActType.REQUEST, ['time']) + self.DST.system_update(system_act) + return system_act elif 'place' not in slots: - return SystemAct(SystemActType.REQUEST, ['place']) + system_act = SystemAct(SystemActType.REQUEST, ['place']) + self.DST.system_update(system_act) + return system_act elif 'description' not in slots: - return SystemAct(SystemActType.REQUEST, ['description']) + system_act = SystemAct(SystemActType.REQUEST, ['description']) + self.DST.system_update(system_act) + return system_act elif 'participants' not in slots: - return SystemAct(SystemActType.REQUEST, ['participants']) + system_act = SystemAct(SystemActType.REQUEST, ['participants']) + self.DST.system_update(system_act) + return system_act else: - return SystemAct(SystemActType.CONFIRM_DOMAIN, slots) + system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, slots) + self.DST.system_update(system_act) + return system_act + else: + if 'date' not in slots: + system_act = SystemAct(SystemActType.REQUEST, ['date']) + self.DST.system_update(system_act) + return system_act + elif 'time' not in slots: + system_act = SystemAct(SystemActType.REQUEST, ['time']) + self.DST.system_update(system_act) + return system_act + elif 'place' not in slots: + system_act = SystemAct(SystemActType.REQUEST, ['place']) + self.DST.system_update(system_act) + return system_act + elif 'description' not in slots: + system_act = SystemAct(SystemActType.REQUEST, ['description']) + self.DST.system_update(system_act) + return system_act + elif 'participants' not in slots: + system_act = SystemAct(SystemActType.REQUEST, ['participants']) + self.DST.system_update(system_act) + return system_act + else: + system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, slots) + self.DST.system_update(system_act) + return system_act elif last_system_act.getActType() == SystemActType.CONFIRM_DOMAIN: - if meeting_to_update: + if self.meeting_to_update: if last_user_act == UserActType.CONFIRM: - meeting_to_update = False self.DST.clear() return SystemAct(SystemActType.AFFIRM, ['update_meeting']) elif last_user_act == UserActType.NEGATE: self.DST.clear() return SystemAct(SystemActType.REQMORE, []) - meeting_to_update = False - if not meeting_to_update: + self.meeting_to_update = False + if not self.meeting_to_update: if last_user_act == UserActType.CONFIRM: - meeting_to_update = True + self.meeting_to_update = True self.DST.clear_slots() - return SystemAct(SystemActType.REQUEST, ['date']) + system_act = SystemAct(SystemActType.REQUEST, ['date']) + self.DST.system_update(system_act) + return system_act elif last_user_act == UserActType.NEGATE: self.DST.clear() return SystemAct(SystemActType.REQMORE, []) @@ -121,20 +192,32 @@ class DP: elif dialogue_state == UserActType.CANCEL_MEETING: if not last_system_act: if 'date' not in slots: - return SystemAct(SystemActType.REQUEST, ['date']) + system_act = SystemAct(SystemActType.REQUEST, ['date']) + self.DST.system_update(system_act) + return system_act elif 'time' not in slots: - return SystemAct(SystemActType.REQUEST, ['time']) + system_act = SystemAct(SystemActType.REQUEST, ['time']) + self.DST.system_update(system_act) + return system_act else: # implementacja wyszukiwania odpowiedniego spotkania w bazie - return SystemAct(SystemActType.CONFIRM_DOMAIN, ['cancel_meeting']) + system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_cancel']) + self.DST.system_update(system_act) + return system_act elif last_system_act.getActType() == SystemActType.REQUEST: if 'date' not in slots: - return SystemAct(SystemActType.REQUEST, ['date']) + system_act = SystemAct(SystemActType.REQUEST, ['date']) + self.DST.system_update(system_act) + return system_act elif 'time' not in slots: - return SystemAct(SystemActType.REQUEST, ['time']) + system_act = SystemAct(SystemActType.REQUEST, ['time']) + self.DST.system_update(system_act) + return system_act else: # implementacja wyszukiwania odpowiedniego spotkania w bazie - return SystemAct(SystemActType.CONFIRM_DOMAIN, ['cancel_meeting']) + system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_cancel']) + self.DST.system_update(system_act) + return system_act elif last_system_act.getActType() == SystemActType.CONFIRM_DOMAIN: if last_user_act == UserActType.CONFIRM: system_act = SystemAct(SystemActType.AFFIRM, ['cancel_meeting']) @@ -157,7 +240,9 @@ class DP: self.DST.clear() return system_act else: - return SystemAct(SystemActType.REQUEST, ['date']) + system_act = SystemAct(SystemActType.REQUEST, ['date']) + self.DST.system_update(system_act) + return system_act # stan prośby o czas wolny elif dialogue_state == UserActType.FREE_TIME: if last_user_act == UserActType.NEGATE: @@ -169,7 +254,9 @@ class DP: self.DST.clear() return system_act else: - return SystemAct(SystemActType.REQUEST, ['date']) + system_act = SystemAct(SystemActType.REQUEST, ['date']) + self.DST.system_update(system_act) + return system_act # brak określonego stanu else: if last_user_act == UserActType.HELLO: diff --git a/DialogueStateTracker.py b/DialogueStateTracker.py index 48b9339..0eddd49 100644 --- a/DialogueStateTracker.py +++ b/DialogueStateTracker.py @@ -39,6 +39,7 @@ class DST: def clear(self): self.state = None + self.last_system_act = None self.slots = {} def clear_slots(self): diff --git a/main.py b/main.py index e73c729..30b0867 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,4 @@ +from SystemActType import SystemActType from NaturalLanguageUnderstanding import NLU from NaturalLanguageGeneration import NLG from DialogueStateTracker import DST @@ -20,7 +21,6 @@ if __name__ == "__main__": state, last_user_act, last_system_act = dst.get_dialogue_state() slots = dst.get_dialogue_slots() system_act = dp.chooseTactic() - dst.system_update(system_act) print('------ stan ------') print(state, last_user_act, last_system_act) print('------ przechowywane sloty ------') @@ -30,7 +30,6 @@ if __name__ == "__main__": print('-----------------------------------') print('-----------------------------------') #text = nlg.toText(system_act) - - #print(text) - #if system_act.isDialogFinished(): - # break + + if system_act.getActType() == SystemActType.BYE: + break