From 44588ed65b27e50deef0cc908552d248d0238739 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20J=C4=99dyk?= Date: Mon, 31 May 2021 13:52:01 +0200 Subject: [PATCH] DP and DST improvement --- ActFrame.py | 6 -- DialoguePolicy.py | 205 ++++++++++++++++++++++++++++++++-------- DialogueStateTracker.py | 44 +++++++-- SystemActType.py | 15 ++- main.py | 18 +++- 5 files changed, 224 insertions(+), 64 deletions(-) diff --git a/ActFrame.py b/ActFrame.py index 14e0ab5..3f1ecd8 100644 --- a/ActFrame.py +++ b/ActFrame.py @@ -9,9 +9,6 @@ class ActFrame(ABC): self.__actType = actType if actParams != None: - if type(actParams) is not list: - raise Exception( - 'actParams has wrong type: expected type \'list\', got \'{}\''.format(type(actParams))) self.__actParams = actParams def __repr__(self): @@ -22,9 +19,6 @@ class ActFrame(ABC): def setActParams(self, actParams): - if type(actParams) is not list: - raise Exception( - 'actParams has wrong type: expected type \'list\', got \'{}\''.format(type(actParams))) self.__actParams = actParams def getActParams(self): diff --git a/DialoguePolicy.py b/DialoguePolicy.py index a837e84..a2392c4 100644 --- a/DialoguePolicy.py +++ b/DialoguePolicy.py @@ -11,43 +11,172 @@ class DP: Wyjście: Akt systemu (rama) """ - def __init__(self): - self.results = [] + def __init__(self, dst): + self.DST = dst - def chooseTactic(self, current_frame) -> SystemAct: - #userAct = frameList[-1] - if current_frame.getActType() == UserActType.HELLO: - return SystemAct(SystemActType.WELCOME_MSG) - elif current_frame.getActType() == UserActType.BYE: - return SystemAct(SystemActType.BYE) - elif current_frame.getActType() == UserActType.CONFIRM: - # Czy napewno zawsze po Confirm jest Affirm? - return SystemAct(SystemActType.AFFIRM) - elif current_frame.getActType() == UserActType.NEGATE: - # TODO rozpoznanie czy ma się już komplet danych - # Affirm (gdy ma się wszystkie potrzebne zdanie) - # Request (gdy potrzeba się dopytać dalej) - # Bye (gdy to odp na REQMORE) - return SystemAct(SystemActType.AFFIRM) - elif current_frame.getActType() == UserActType.THANKYOU: - return SystemAct(SystemActType.REQMORE) - elif current_frame.getActType() == UserActType.INFORM: - # TODO najczęściej chyba AFFIRM, CONFIRM_DOMAIN i REQUEST - return SystemAct(SystemActType.REQUEST) - elif current_frame.getActType() == UserActType.CREATE_MEETING: - # TODO najczęściej chyba CONFIRM_DOMAIN i REQUEST - return SystemAct(SystemActType.REQUEST) - elif current_frame.getActType() == UserActType.UPDATE_MEETING: - # TODO rozpoznanie czy ma się już komplet danych jak nie to REQUEST jak tak to CONFIRM_DOMAIN - return SystemAct(SystemActType.REQUEST) - elif current_frame.getActType() == UserActType.CANCEL_MEETING: - # TODO rozpoznanie czy ma się już komplet danych jak nie to REQUEST jak tak to CONFIRM_DOMAIN - return SystemAct(SystemActType.REQUEST) - elif current_frame.getActType() == UserActType.MEETING_LIST: - return SystemAct(SystemActType.INFORM, ["meeting_list"]) - elif current_frame.getActType() == UserActType.FREE_TIME: - return SystemAct(SystemActType.INFORM, ["freetime"]) - elif current_frame.getActType() == UserActType.INVALID: - return SystemAct(SystemActType.NOT_UNDERSTOOD) + def chooseTactic(self) -> SystemAct: + dialogue_state, last_user_act, last_system_act = self.DST.get_dialogue_state() + slots = self.DST.get_dialogue_slots() + # stan dodawania spotkania + if dialogue_state == UserActType.CREATE_MEETING: + if not last_system_act: + if 'date' not in slots: + return SystemAct(SystemActType.REQUEST, ['date']) + elif 'time' not in slots: + return SystemAct(SystemActType.REQUEST, ['time']) + elif 'place' not in slots: + return SystemAct(SystemActType.REQUEST, ['place']) + elif 'description' not in slots: + return SystemAct(SystemActType.REQUEST, ['description']) + elif 'participants' not in slots: + return SystemAct(SystemActType.REQUEST, ['participants']) + else: + return SystemAct(SystemActType.CONFIRM_DOMAIN, slots) + elif last_system_act.getActType() == SystemActType.REQUEST: + if last_user_act == UserActType.NEGATE: + slot_type = last_system_act.getActParams()[0] + if slot_type not in ['date', 'time']: + self.DST.insert_empty_slot(slot_type) + if 'date' not in slots: + return SystemAct(SystemActType.REQUEST, ['date']) + elif 'time' not in slots: + return SystemAct(SystemActType.REQUEST, ['time']) + elif 'place' not in slots: + return SystemAct(SystemActType.REQUEST, ['place']) + elif 'description' not in slots: + return SystemAct(SystemActType.REQUEST, ['description']) + elif 'participants' not in slots: + return SystemAct(SystemActType.REQUEST, ['participants']) + else: + return SystemAct(SystemActType.CONFIRM_DOMAIN, slots) + elif last_system_act.getActType() == SystemActType.CONFIRM_DOMAIN: + if last_user_act == UserActType.CONFIRM: + system_act = SystemAct(SystemActType.AFFIRM, ['create_meeting']) + # implementacja wpisywanie spotkania do bazy + self.DST.clear() + return system_act + elif last_user_act == UserActType.NEGATE: + self.DST.clear() + return SystemAct(SystemActType.REQMORE, []) + else: + return SystemAct(SystemActType.NOT_UNDERSTOOD, []) + # stan edycji spotkania + elif dialogue_state == UserActType.UPDATE_MEETING: + meeting_to_update = False + if not last_system_act: + if 'date' not in slots: + return SystemAct(SystemActType.REQUEST, ['date']) + elif 'time' not in slots: + return SystemAct(SystemActType.REQUEST, ['time']) + else: + # implementacja wyszukiwania odpowiedniego spotkania w bazie + return SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_update']) + elif last_system_act.getActType() == SystemActType.REQUEST: + if not meeting_to_update: + if 'date' not in slots: + return SystemAct(SystemActType.REQUEST, ['date']) + elif 'time' not in slots: + return SystemAct(SystemActType.REQUEST, ['time']) + else: + # implementacja wyszukiwania odpowiedniego spotkania w bazie + return SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_update']) + else: + if last_user_act == UserActType.NEGATE: + slot_type = last_system_act.getActParams()[0] + if slot_type not in ['date', 'time']: + self.DST.insert_empty_slot(slot_type) + if 'date' not in slots: + return SystemAct(SystemActType.REQUEST, ['date']) + elif 'time' not in slots: + return SystemAct(SystemActType.REQUEST, ['time']) + elif 'place' not in slots: + return SystemAct(SystemActType.REQUEST, ['place']) + elif 'description' not in slots: + return SystemAct(SystemActType.REQUEST, ['description']) + elif 'participants' not in slots: + return SystemAct(SystemActType.REQUEST, ['participants']) + else: + return SystemAct(SystemActType.CONFIRM_DOMAIN, slots) + elif last_system_act.getActType() == SystemActType.CONFIRM_DOMAIN: + if meeting_to_update: + if last_user_act == UserActType.CONFIRM: + meeting_to_update = False + self.DST.clear() + return SystemAct(SystemActType.AFFIRM, ['update_meeting']) + elif last_user_act == UserActType.NEGATE: + self.DST.clear() + return SystemAct(SystemActType.REQMORE, []) + meeting_to_update = False + if not meeting_to_update: + if last_user_act == UserActType.CONFIRM: + meeting_to_update = True + self.DST.clear_slots() + return SystemAct(SystemActType.REQUEST, ['date']) + elif last_user_act == UserActType.NEGATE: + self.DST.clear() + return SystemAct(SystemActType.REQMORE, []) + else: + return SystemAct(SystemActType.NOT_UNDERSTOOD, []) + # stan anulowania spotkania + elif dialogue_state == UserActType.CANCEL_MEETING: + if not last_system_act: + if 'date' not in slots: + return SystemAct(SystemActType.REQUEST, ['date']) + elif 'time' not in slots: + return SystemAct(SystemActType.REQUEST, ['time']) + else: + # implementacja wyszukiwania odpowiedniego spotkania w bazie + return SystemAct(SystemActType.CONFIRM_DOMAIN, ['cancel_meeting']) + elif last_system_act.getActType() == SystemActType.REQUEST: + if 'date' not in slots: + return SystemAct(SystemActType.REQUEST, ['date']) + elif 'time' not in slots: + return SystemAct(SystemActType.REQUEST, ['time']) + else: + # implementacja wyszukiwania odpowiedniego spotkania w bazie + return SystemAct(SystemActType.CONFIRM_DOMAIN, ['cancel_meeting']) + elif last_system_act.getActType() == SystemActType.CONFIRM_DOMAIN: + if last_user_act == UserActType.CONFIRM: + system_act = SystemAct(SystemActType.AFFIRM, ['cancel_meeting']) + # implementacja usuwania spotkania z bazy + self.DST.clear() + return system_act + elif last_user_act == UserActType.NEGATE: + self.DST.clear() + return SystemAct(SystemActType.REQMORE, []) + else: + return SystemAct(SystemActType.NOT_UNDERSTOOD, []) + # stan prośby o listę spotkań + elif dialogue_state == UserActType.MEETING_LIST: + if last_user_act == UserActType.NEGATE: + self.DST.clear() + return SystemAct(SystemActType.REQMORE, []) + else: + if 'date' in slots: + system_act = SystemAct(SystemActType.MEETING_LIST, slots) + self.DST.clear() + return system_act + else: + return SystemAct(SystemActType.REQUEST, ['date']) + # stan prośby o czas wolny + elif dialogue_state == UserActType.FREE_TIME: + if last_user_act == UserActType.NEGATE: + self.DST.clear() + return SystemAct(SystemActType.REQMORE, []) + else: + if 'date' in slots: + system_act = SystemAct(SystemActType.FREE_TIME, slots) + self.DST.clear() + return system_act + else: + return SystemAct(SystemActType.REQUEST, ['date']) + # brak określonego stanu else: - return SystemAct(SystemActType.INFORM,['name']) + if last_user_act == UserActType.HELLO: + return SystemAct(SystemActType.WELCOME_MSG, []) + elif last_user_act == UserActType.BYE: + return SystemAct(SystemActType.BYE, []) + elif last_user_act == UserActType.THANKYOU: + return SystemAct(SystemActType.REQMORE, []) + else: + return SystemAct(SystemActType.NOT_UNDERSTOOD, []) diff --git a/DialogueStateTracker.py b/DialogueStateTracker.py index f9d6ff1..48b9339 100644 --- a/DialogueStateTracker.py +++ b/DialogueStateTracker.py @@ -1,3 +1,4 @@ +from os import system from UserActType import UserActType from UserAct import UserAct @@ -10,16 +11,41 @@ class DST: """ def __init__(self): - self.frameList = [] self.state = None + self.last_user_act = None + self.last_system_act = None + self.slots = {} - def update(self, frame): - self.addFrame(frame) - self.state = frame - return self.state + def user_update(self, frame): + user_act = frame.getActType() + self.last_user_act = user_act + for slot in frame.getActParams(): + if slot[0] == 'participant': + if 'participants' not in self.slots: + self.slots['participants'] = [slot[1]] + else: + self.slots['participants'].append(slot[1]) + else: + self.slots[slot[0]] = slot[1] + if not self.state: + if user_act in [UserActType.CREATE_MEETING, UserActType.UPDATE_MEETING, UserActType.CANCEL_MEETING, UserActType.MEETING_LIST, UserActType.FREE_TIME]: + self.state = user_act - def addFrame(self, frame): - self.frameList.append(frame) + def system_update(self, system_act): + self.last_system_act = system_act - def getFrames(self): - return self.frameList + def insert_empty_slot(self, slot_name): + self.slots[slot_name] = None + + def clear(self): + self.state = None + self.slots = {} + + def clear_slots(self): + self.slots = {} + + def get_dialogue_state(self): + return self.state, self.last_user_act, self.last_system_act + + def get_dialogue_slots(self): + return self.slots diff --git a/SystemActType.py b/SystemActType.py index 3747b10..6ecceec 100644 --- a/SystemActType.py +++ b/SystemActType.py @@ -4,12 +4,11 @@ from enum import Enum, unique @unique class SystemActType(Enum): WELCOME_MSG = 0 - INFORM = 1 - BYE = 2 - REQUEST = 3 - INFORM = 4 - AFFIRM = 5 - CONFIRM_DOMAIN = 6 - OFFER = 7 - REQMORE = 8 + BYE = 1 + REQMORE = 2 + AFFIRM = 3 + CONFIRM_DOMAIN = 4 + MEETING_LIST = 5 + FREE_TIME = 6 + REQUEST = 7 NOT_UNDERSTOOD = -1 diff --git a/main.py b/main.py index 68ae437..e73c729 100644 --- a/main.py +++ b/main.py @@ -7,16 +7,28 @@ if __name__ == "__main__": nlu = NLU() dst = DST() - dp = DP() + dp = DP(dst) nlg = NLG() while(1): user_input = input("Wpisz tekst: ") user_frame = nlu.parse_user_input(user_input) + print('------ rozpoznany user frame ------') print(user_frame) - #dst.addFrame(user_frame) - #system_act = dp.chooseTactic(dst.getFrames()) + dst.user_update(user_frame) + state, last_user_act, last_system_act = dst.get_dialogue_state() + slots = dst.get_dialogue_slots() + system_act = dp.chooseTactic() + dst.system_update(system_act) + print('------ stan ------') + print(state, last_user_act, last_system_act) + print('------ przechowywane sloty ------') + print(slots) + print('------ wybrana akcja systemu ------') + print(system_act) + print('-----------------------------------') + print('-----------------------------------') #text = nlg.toText(system_act) #print(text)