fix DST bugs

This commit is contained in:
Łukasz Jędyk 2021-05-31 14:53:20 +02:00
parent 44588ed65b
commit f343ebec98
3 changed files with 133 additions and 46 deletions

View File

@ -13,6 +13,7 @@ class DP:
def __init__(self, dst): def __init__(self, dst):
self.DST = dst self.DST = dst
self.meeting_to_update = False
def chooseTactic(self) -> SystemAct: def chooseTactic(self) -> SystemAct:
dialogue_state, last_user_act, last_system_act = self.DST.get_dialogue_state() dialogue_state, last_user_act, last_system_act = self.DST.get_dialogue_state()
@ -21,34 +22,58 @@ class DP:
if dialogue_state == UserActType.CREATE_MEETING: if dialogue_state == UserActType.CREATE_MEETING:
if not last_system_act: if not last_system_act:
if 'date' not in slots: if 'date' not in slots:
return SystemAct(SystemActType.REQUEST, ['date']) system_act = SystemAct(SystemActType.REQUEST, ['date'])
self.DST.system_update(system_act)
return system_act
elif 'time' not in slots: elif 'time' not in slots:
return SystemAct(SystemActType.REQUEST, ['time']) system_act = SystemAct(SystemActType.REQUEST, ['time'])
self.DST.system_update(system_act)
return system_act
elif 'place' not in slots: elif 'place' not in slots:
return SystemAct(SystemActType.REQUEST, ['place']) system_act = SystemAct(SystemActType.REQUEST, ['place'])
self.DST.system_update(system_act)
return system_act
elif 'description' not in slots: elif 'description' not in slots:
return SystemAct(SystemActType.REQUEST, ['description']) system_act = SystemAct(SystemActType.REQUEST, ['description'])
self.DST.system_update(system_act)
return system_act
elif 'participants' not in slots: elif 'participants' not in slots:
return SystemAct(SystemActType.REQUEST, ['participants']) system_act = SystemAct(SystemActType.REQUEST, ['participants'])
self.DST.system_update(system_act)
return system_act
else: else:
return SystemAct(SystemActType.CONFIRM_DOMAIN, slots) system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, slots)
self.DST.system_update(system_act)
return system_act
elif last_system_act.getActType() == SystemActType.REQUEST: elif last_system_act.getActType() == SystemActType.REQUEST:
if last_user_act == UserActType.NEGATE: if last_user_act == UserActType.NEGATE:
slot_type = last_system_act.getActParams()[0] slot_type = last_system_act.getActParams()[0]
if slot_type not in ['date', 'time']: if slot_type not in ['date', 'time']:
self.DST.insert_empty_slot(slot_type) self.DST.insert_empty_slot(slot_type)
if 'date' not in slots: if 'date' not in slots:
return SystemAct(SystemActType.REQUEST, ['date']) system_act = SystemAct(SystemActType.REQUEST, ['date'])
self.DST.system_update(system_act)
return system_act
elif 'time' not in slots: elif 'time' not in slots:
return SystemAct(SystemActType.REQUEST, ['time']) system_act = SystemAct(SystemActType.REQUEST, ['time'])
self.DST.system_update(system_act)
return system_act
elif 'place' not in slots: elif 'place' not in slots:
return SystemAct(SystemActType.REQUEST, ['place']) system_act = SystemAct(SystemActType.REQUEST, ['place'])
self.DST.system_update(system_act)
return system_act
elif 'description' not in slots: elif 'description' not in slots:
return SystemAct(SystemActType.REQUEST, ['description']) system_act = SystemAct(SystemActType.REQUEST, ['description'])
self.DST.system_update(system_act)
return system_act
elif 'participants' not in slots: elif 'participants' not in slots:
return SystemAct(SystemActType.REQUEST, ['participants']) system_act = SystemAct(SystemActType.REQUEST, ['participants'])
self.DST.system_update(system_act)
return system_act
else: else:
return SystemAct(SystemActType.CONFIRM_DOMAIN, slots) system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, slots)
self.DST.system_update(system_act)
return system_act
elif last_system_act.getActType() == SystemActType.CONFIRM_DOMAIN: elif last_system_act.getActType() == SystemActType.CONFIRM_DOMAIN:
if last_user_act == UserActType.CONFIRM: if last_user_act == UserActType.CONFIRM:
system_act = SystemAct(SystemActType.AFFIRM, ['create_meeting']) system_act = SystemAct(SystemActType.AFFIRM, ['create_meeting'])
@ -62,56 +87,102 @@ class DP:
return SystemAct(SystemActType.NOT_UNDERSTOOD, []) return SystemAct(SystemActType.NOT_UNDERSTOOD, [])
# stan edycji spotkania # stan edycji spotkania
elif dialogue_state == UserActType.UPDATE_MEETING: elif dialogue_state == UserActType.UPDATE_MEETING:
meeting_to_update = False
if not last_system_act: if not last_system_act:
if 'date' not in slots: if 'date' not in slots:
return SystemAct(SystemActType.REQUEST, ['date']) system_act = SystemAct(SystemActType.REQUEST, ['date'])
self.DST.system_update(system_act)
return system_act
elif 'time' not in slots: elif 'time' not in slots:
return SystemAct(SystemActType.REQUEST, ['time']) system_act = SystemAct(SystemActType.REQUEST, ['time'])
self.DST.system_update(system_act)
return system_act
else: else:
# implementacja wyszukiwania odpowiedniego spotkania w bazie # implementacja wyszukiwania odpowiedniego spotkania w bazie
return SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_update']) return SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_update'])
elif last_system_act.getActType() == SystemActType.REQUEST: elif last_system_act.getActType() == SystemActType.REQUEST:
if not meeting_to_update: if not self.meeting_to_update:
if 'date' not in slots: if 'date' not in slots:
return SystemAct(SystemActType.REQUEST, ['date']) system_act = SystemAct(SystemActType.REQUEST, ['date'])
self.DST.system_update(system_act)
return system_act
elif 'time' not in slots: elif 'time' not in slots:
return SystemAct(SystemActType.REQUEST, ['time']) system_act = SystemAct(SystemActType.REQUEST, ['time'])
self.DST.system_update(system_act)
return system_act
else: else:
# implementacja wyszukiwania odpowiedniego spotkania w bazie # implementacja wyszukiwania odpowiedniego spotkania w bazie
return SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_update']) system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_update'])
self.DST.system_update(system_act)
return system_act
else: else:
if last_user_act == UserActType.NEGATE: if last_user_act == UserActType.NEGATE:
slot_type = last_system_act.getActParams()[0] slot_type = last_system_act.getActParams()[0]
if slot_type not in ['date', 'time']: self.DST.insert_empty_slot(slot_type)
self.DST.insert_empty_slot(slot_type)
if 'date' not in slots: if 'date' not in slots:
return SystemAct(SystemActType.REQUEST, ['date']) system_act = SystemAct(SystemActType.REQUEST, ['date'])
self.DST.system_update(system_act)
return system_act
elif 'time' not in slots: elif 'time' not in slots:
return SystemAct(SystemActType.REQUEST, ['time']) system_act = SystemAct(SystemActType.REQUEST, ['time'])
self.DST.system_update(system_act)
return system_act
elif 'place' not in slots: elif 'place' not in slots:
return SystemAct(SystemActType.REQUEST, ['place']) system_act = SystemAct(SystemActType.REQUEST, ['place'])
self.DST.system_update(system_act)
return system_act
elif 'description' not in slots: elif 'description' not in slots:
return SystemAct(SystemActType.REQUEST, ['description']) system_act = SystemAct(SystemActType.REQUEST, ['description'])
self.DST.system_update(system_act)
return system_act
elif 'participants' not in slots: elif 'participants' not in slots:
return SystemAct(SystemActType.REQUEST, ['participants']) system_act = SystemAct(SystemActType.REQUEST, ['participants'])
self.DST.system_update(system_act)
return system_act
else: else:
return SystemAct(SystemActType.CONFIRM_DOMAIN, slots) system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, slots)
self.DST.system_update(system_act)
return system_act
else:
if 'date' not in slots:
system_act = SystemAct(SystemActType.REQUEST, ['date'])
self.DST.system_update(system_act)
return system_act
elif 'time' not in slots:
system_act = SystemAct(SystemActType.REQUEST, ['time'])
self.DST.system_update(system_act)
return system_act
elif 'place' not in slots:
system_act = SystemAct(SystemActType.REQUEST, ['place'])
self.DST.system_update(system_act)
return system_act
elif 'description' not in slots:
system_act = SystemAct(SystemActType.REQUEST, ['description'])
self.DST.system_update(system_act)
return system_act
elif 'participants' not in slots:
system_act = SystemAct(SystemActType.REQUEST, ['participants'])
self.DST.system_update(system_act)
return system_act
else:
system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, slots)
self.DST.system_update(system_act)
return system_act
elif last_system_act.getActType() == SystemActType.CONFIRM_DOMAIN: elif last_system_act.getActType() == SystemActType.CONFIRM_DOMAIN:
if meeting_to_update: if self.meeting_to_update:
if last_user_act == UserActType.CONFIRM: if last_user_act == UserActType.CONFIRM:
meeting_to_update = False
self.DST.clear() self.DST.clear()
return SystemAct(SystemActType.AFFIRM, ['update_meeting']) return SystemAct(SystemActType.AFFIRM, ['update_meeting'])
elif last_user_act == UserActType.NEGATE: elif last_user_act == UserActType.NEGATE:
self.DST.clear() self.DST.clear()
return SystemAct(SystemActType.REQMORE, []) return SystemAct(SystemActType.REQMORE, [])
meeting_to_update = False self.meeting_to_update = False
if not meeting_to_update: if not self.meeting_to_update:
if last_user_act == UserActType.CONFIRM: if last_user_act == UserActType.CONFIRM:
meeting_to_update = True self.meeting_to_update = True
self.DST.clear_slots() self.DST.clear_slots()
return SystemAct(SystemActType.REQUEST, ['date']) system_act = SystemAct(SystemActType.REQUEST, ['date'])
self.DST.system_update(system_act)
return system_act
elif last_user_act == UserActType.NEGATE: elif last_user_act == UserActType.NEGATE:
self.DST.clear() self.DST.clear()
return SystemAct(SystemActType.REQMORE, []) return SystemAct(SystemActType.REQMORE, [])
@ -121,20 +192,32 @@ class DP:
elif dialogue_state == UserActType.CANCEL_MEETING: elif dialogue_state == UserActType.CANCEL_MEETING:
if not last_system_act: if not last_system_act:
if 'date' not in slots: if 'date' not in slots:
return SystemAct(SystemActType.REQUEST, ['date']) system_act = SystemAct(SystemActType.REQUEST, ['date'])
self.DST.system_update(system_act)
return system_act
elif 'time' not in slots: elif 'time' not in slots:
return SystemAct(SystemActType.REQUEST, ['time']) system_act = SystemAct(SystemActType.REQUEST, ['time'])
self.DST.system_update(system_act)
return system_act
else: else:
# implementacja wyszukiwania odpowiedniego spotkania w bazie # implementacja wyszukiwania odpowiedniego spotkania w bazie
return SystemAct(SystemActType.CONFIRM_DOMAIN, ['cancel_meeting']) system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_cancel'])
self.DST.system_update(system_act)
return system_act
elif last_system_act.getActType() == SystemActType.REQUEST: elif last_system_act.getActType() == SystemActType.REQUEST:
if 'date' not in slots: if 'date' not in slots:
return SystemAct(SystemActType.REQUEST, ['date']) system_act = SystemAct(SystemActType.REQUEST, ['date'])
self.DST.system_update(system_act)
return system_act
elif 'time' not in slots: elif 'time' not in slots:
return SystemAct(SystemActType.REQUEST, ['time']) system_act = SystemAct(SystemActType.REQUEST, ['time'])
self.DST.system_update(system_act)
return system_act
else: else:
# implementacja wyszukiwania odpowiedniego spotkania w bazie # implementacja wyszukiwania odpowiedniego spotkania w bazie
return SystemAct(SystemActType.CONFIRM_DOMAIN, ['cancel_meeting']) system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_cancel'])
self.DST.system_update(system_act)
return system_act
elif last_system_act.getActType() == SystemActType.CONFIRM_DOMAIN: elif last_system_act.getActType() == SystemActType.CONFIRM_DOMAIN:
if last_user_act == UserActType.CONFIRM: if last_user_act == UserActType.CONFIRM:
system_act = SystemAct(SystemActType.AFFIRM, ['cancel_meeting']) system_act = SystemAct(SystemActType.AFFIRM, ['cancel_meeting'])
@ -157,7 +240,9 @@ class DP:
self.DST.clear() self.DST.clear()
return system_act return system_act
else: else:
return SystemAct(SystemActType.REQUEST, ['date']) system_act = SystemAct(SystemActType.REQUEST, ['date'])
self.DST.system_update(system_act)
return system_act
# stan prośby o czas wolny # stan prośby o czas wolny
elif dialogue_state == UserActType.FREE_TIME: elif dialogue_state == UserActType.FREE_TIME:
if last_user_act == UserActType.NEGATE: if last_user_act == UserActType.NEGATE:
@ -169,7 +254,9 @@ class DP:
self.DST.clear() self.DST.clear()
return system_act return system_act
else: else:
return SystemAct(SystemActType.REQUEST, ['date']) system_act = SystemAct(SystemActType.REQUEST, ['date'])
self.DST.system_update(system_act)
return system_act
# brak określonego stanu # brak określonego stanu
else: else:
if last_user_act == UserActType.HELLO: if last_user_act == UserActType.HELLO:

View File

@ -39,6 +39,7 @@ class DST:
def clear(self): def clear(self):
self.state = None self.state = None
self.last_system_act = None
self.slots = {} self.slots = {}
def clear_slots(self): def clear_slots(self):

View File

@ -1,3 +1,4 @@
from SystemActType import SystemActType
from NaturalLanguageUnderstanding import NLU from NaturalLanguageUnderstanding import NLU
from NaturalLanguageGeneration import NLG from NaturalLanguageGeneration import NLG
from DialogueStateTracker import DST from DialogueStateTracker import DST
@ -20,7 +21,6 @@ if __name__ == "__main__":
state, last_user_act, last_system_act = dst.get_dialogue_state() state, last_user_act, last_system_act = dst.get_dialogue_state()
slots = dst.get_dialogue_slots() slots = dst.get_dialogue_slots()
system_act = dp.chooseTactic() system_act = dp.chooseTactic()
dst.system_update(system_act)
print('------ stan ------') print('------ stan ------')
print(state, last_user_act, last_system_act) print(state, last_user_act, last_system_act)
print('------ przechowywane sloty ------') print('------ przechowywane sloty ------')
@ -30,7 +30,6 @@ if __name__ == "__main__":
print('-----------------------------------') print('-----------------------------------')
print('-----------------------------------') print('-----------------------------------')
#text = nlg.toText(system_act) #text = nlg.toText(system_act)
#print(text) if system_act.getActType() == SystemActType.BYE:
#if system_act.isDialogFinished(): break
# break