fix DST bugs
This commit is contained in:
parent
44588ed65b
commit
f343ebec98
@ -13,6 +13,7 @@ class DP:
|
|||||||
|
|
||||||
def __init__(self, dst):
|
def __init__(self, dst):
|
||||||
self.DST = dst
|
self.DST = dst
|
||||||
|
self.meeting_to_update = False
|
||||||
|
|
||||||
def chooseTactic(self) -> SystemAct:
|
def chooseTactic(self) -> SystemAct:
|
||||||
dialogue_state, last_user_act, last_system_act = self.DST.get_dialogue_state()
|
dialogue_state, last_user_act, last_system_act = self.DST.get_dialogue_state()
|
||||||
@ -21,34 +22,58 @@ class DP:
|
|||||||
if dialogue_state == UserActType.CREATE_MEETING:
|
if dialogue_state == UserActType.CREATE_MEETING:
|
||||||
if not last_system_act:
|
if not last_system_act:
|
||||||
if 'date' not in slots:
|
if 'date' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['date'])
|
system_act = SystemAct(SystemActType.REQUEST, ['date'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif 'time' not in slots:
|
elif 'time' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['time'])
|
system_act = SystemAct(SystemActType.REQUEST, ['time'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif 'place' not in slots:
|
elif 'place' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['place'])
|
system_act = SystemAct(SystemActType.REQUEST, ['place'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif 'description' not in slots:
|
elif 'description' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['description'])
|
system_act = SystemAct(SystemActType.REQUEST, ['description'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif 'participants' not in slots:
|
elif 'participants' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['participants'])
|
system_act = SystemAct(SystemActType.REQUEST, ['participants'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
else:
|
else:
|
||||||
return SystemAct(SystemActType.CONFIRM_DOMAIN, slots)
|
system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, slots)
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif last_system_act.getActType() == SystemActType.REQUEST:
|
elif last_system_act.getActType() == SystemActType.REQUEST:
|
||||||
if last_user_act == UserActType.NEGATE:
|
if last_user_act == UserActType.NEGATE:
|
||||||
slot_type = last_system_act.getActParams()[0]
|
slot_type = last_system_act.getActParams()[0]
|
||||||
if slot_type not in ['date', 'time']:
|
if slot_type not in ['date', 'time']:
|
||||||
self.DST.insert_empty_slot(slot_type)
|
self.DST.insert_empty_slot(slot_type)
|
||||||
if 'date' not in slots:
|
if 'date' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['date'])
|
system_act = SystemAct(SystemActType.REQUEST, ['date'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif 'time' not in slots:
|
elif 'time' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['time'])
|
system_act = SystemAct(SystemActType.REQUEST, ['time'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif 'place' not in slots:
|
elif 'place' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['place'])
|
system_act = SystemAct(SystemActType.REQUEST, ['place'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif 'description' not in slots:
|
elif 'description' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['description'])
|
system_act = SystemAct(SystemActType.REQUEST, ['description'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif 'participants' not in slots:
|
elif 'participants' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['participants'])
|
system_act = SystemAct(SystemActType.REQUEST, ['participants'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
else:
|
else:
|
||||||
return SystemAct(SystemActType.CONFIRM_DOMAIN, slots)
|
system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, slots)
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif last_system_act.getActType() == SystemActType.CONFIRM_DOMAIN:
|
elif last_system_act.getActType() == SystemActType.CONFIRM_DOMAIN:
|
||||||
if last_user_act == UserActType.CONFIRM:
|
if last_user_act == UserActType.CONFIRM:
|
||||||
system_act = SystemAct(SystemActType.AFFIRM, ['create_meeting'])
|
system_act = SystemAct(SystemActType.AFFIRM, ['create_meeting'])
|
||||||
@ -62,56 +87,102 @@ class DP:
|
|||||||
return SystemAct(SystemActType.NOT_UNDERSTOOD, [])
|
return SystemAct(SystemActType.NOT_UNDERSTOOD, [])
|
||||||
# stan edycji spotkania
|
# stan edycji spotkania
|
||||||
elif dialogue_state == UserActType.UPDATE_MEETING:
|
elif dialogue_state == UserActType.UPDATE_MEETING:
|
||||||
meeting_to_update = False
|
|
||||||
if not last_system_act:
|
if not last_system_act:
|
||||||
if 'date' not in slots:
|
if 'date' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['date'])
|
system_act = SystemAct(SystemActType.REQUEST, ['date'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif 'time' not in slots:
|
elif 'time' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['time'])
|
system_act = SystemAct(SystemActType.REQUEST, ['time'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
else:
|
else:
|
||||||
# implementacja wyszukiwania odpowiedniego spotkania w bazie
|
# implementacja wyszukiwania odpowiedniego spotkania w bazie
|
||||||
return SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_update'])
|
return SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_update'])
|
||||||
elif last_system_act.getActType() == SystemActType.REQUEST:
|
elif last_system_act.getActType() == SystemActType.REQUEST:
|
||||||
if not meeting_to_update:
|
if not self.meeting_to_update:
|
||||||
if 'date' not in slots:
|
if 'date' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['date'])
|
system_act = SystemAct(SystemActType.REQUEST, ['date'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif 'time' not in slots:
|
elif 'time' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['time'])
|
system_act = SystemAct(SystemActType.REQUEST, ['time'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
else:
|
else:
|
||||||
# implementacja wyszukiwania odpowiedniego spotkania w bazie
|
# implementacja wyszukiwania odpowiedniego spotkania w bazie
|
||||||
return SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_update'])
|
system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_update'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
else:
|
else:
|
||||||
if last_user_act == UserActType.NEGATE:
|
if last_user_act == UserActType.NEGATE:
|
||||||
slot_type = last_system_act.getActParams()[0]
|
slot_type = last_system_act.getActParams()[0]
|
||||||
if slot_type not in ['date', 'time']:
|
|
||||||
self.DST.insert_empty_slot(slot_type)
|
self.DST.insert_empty_slot(slot_type)
|
||||||
if 'date' not in slots:
|
if 'date' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['date'])
|
system_act = SystemAct(SystemActType.REQUEST, ['date'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif 'time' not in slots:
|
elif 'time' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['time'])
|
system_act = SystemAct(SystemActType.REQUEST, ['time'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif 'place' not in slots:
|
elif 'place' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['place'])
|
system_act = SystemAct(SystemActType.REQUEST, ['place'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif 'description' not in slots:
|
elif 'description' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['description'])
|
system_act = SystemAct(SystemActType.REQUEST, ['description'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif 'participants' not in slots:
|
elif 'participants' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['participants'])
|
system_act = SystemAct(SystemActType.REQUEST, ['participants'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
else:
|
else:
|
||||||
return SystemAct(SystemActType.CONFIRM_DOMAIN, slots)
|
system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, slots)
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
|
else:
|
||||||
|
if 'date' not in slots:
|
||||||
|
system_act = SystemAct(SystemActType.REQUEST, ['date'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
|
elif 'time' not in slots:
|
||||||
|
system_act = SystemAct(SystemActType.REQUEST, ['time'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
|
elif 'place' not in slots:
|
||||||
|
system_act = SystemAct(SystemActType.REQUEST, ['place'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
|
elif 'description' not in slots:
|
||||||
|
system_act = SystemAct(SystemActType.REQUEST, ['description'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
|
elif 'participants' not in slots:
|
||||||
|
system_act = SystemAct(SystemActType.REQUEST, ['participants'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
|
else:
|
||||||
|
system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, slots)
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif last_system_act.getActType() == SystemActType.CONFIRM_DOMAIN:
|
elif last_system_act.getActType() == SystemActType.CONFIRM_DOMAIN:
|
||||||
if meeting_to_update:
|
if self.meeting_to_update:
|
||||||
if last_user_act == UserActType.CONFIRM:
|
if last_user_act == UserActType.CONFIRM:
|
||||||
meeting_to_update = False
|
|
||||||
self.DST.clear()
|
self.DST.clear()
|
||||||
return SystemAct(SystemActType.AFFIRM, ['update_meeting'])
|
return SystemAct(SystemActType.AFFIRM, ['update_meeting'])
|
||||||
elif last_user_act == UserActType.NEGATE:
|
elif last_user_act == UserActType.NEGATE:
|
||||||
self.DST.clear()
|
self.DST.clear()
|
||||||
return SystemAct(SystemActType.REQMORE, [])
|
return SystemAct(SystemActType.REQMORE, [])
|
||||||
meeting_to_update = False
|
self.meeting_to_update = False
|
||||||
if not meeting_to_update:
|
if not self.meeting_to_update:
|
||||||
if last_user_act == UserActType.CONFIRM:
|
if last_user_act == UserActType.CONFIRM:
|
||||||
meeting_to_update = True
|
self.meeting_to_update = True
|
||||||
self.DST.clear_slots()
|
self.DST.clear_slots()
|
||||||
return SystemAct(SystemActType.REQUEST, ['date'])
|
system_act = SystemAct(SystemActType.REQUEST, ['date'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif last_user_act == UserActType.NEGATE:
|
elif last_user_act == UserActType.NEGATE:
|
||||||
self.DST.clear()
|
self.DST.clear()
|
||||||
return SystemAct(SystemActType.REQMORE, [])
|
return SystemAct(SystemActType.REQMORE, [])
|
||||||
@ -121,20 +192,32 @@ class DP:
|
|||||||
elif dialogue_state == UserActType.CANCEL_MEETING:
|
elif dialogue_state == UserActType.CANCEL_MEETING:
|
||||||
if not last_system_act:
|
if not last_system_act:
|
||||||
if 'date' not in slots:
|
if 'date' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['date'])
|
system_act = SystemAct(SystemActType.REQUEST, ['date'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif 'time' not in slots:
|
elif 'time' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['time'])
|
system_act = SystemAct(SystemActType.REQUEST, ['time'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
else:
|
else:
|
||||||
# implementacja wyszukiwania odpowiedniego spotkania w bazie
|
# implementacja wyszukiwania odpowiedniego spotkania w bazie
|
||||||
return SystemAct(SystemActType.CONFIRM_DOMAIN, ['cancel_meeting'])
|
system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_cancel'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif last_system_act.getActType() == SystemActType.REQUEST:
|
elif last_system_act.getActType() == SystemActType.REQUEST:
|
||||||
if 'date' not in slots:
|
if 'date' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['date'])
|
system_act = SystemAct(SystemActType.REQUEST, ['date'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif 'time' not in slots:
|
elif 'time' not in slots:
|
||||||
return SystemAct(SystemActType.REQUEST, ['time'])
|
system_act = SystemAct(SystemActType.REQUEST, ['time'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
else:
|
else:
|
||||||
# implementacja wyszukiwania odpowiedniego spotkania w bazie
|
# implementacja wyszukiwania odpowiedniego spotkania w bazie
|
||||||
return SystemAct(SystemActType.CONFIRM_DOMAIN, ['cancel_meeting'])
|
system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_cancel'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
elif last_system_act.getActType() == SystemActType.CONFIRM_DOMAIN:
|
elif last_system_act.getActType() == SystemActType.CONFIRM_DOMAIN:
|
||||||
if last_user_act == UserActType.CONFIRM:
|
if last_user_act == UserActType.CONFIRM:
|
||||||
system_act = SystemAct(SystemActType.AFFIRM, ['cancel_meeting'])
|
system_act = SystemAct(SystemActType.AFFIRM, ['cancel_meeting'])
|
||||||
@ -157,7 +240,9 @@ class DP:
|
|||||||
self.DST.clear()
|
self.DST.clear()
|
||||||
return system_act
|
return system_act
|
||||||
else:
|
else:
|
||||||
return SystemAct(SystemActType.REQUEST, ['date'])
|
system_act = SystemAct(SystemActType.REQUEST, ['date'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
# stan prośby o czas wolny
|
# stan prośby o czas wolny
|
||||||
elif dialogue_state == UserActType.FREE_TIME:
|
elif dialogue_state == UserActType.FREE_TIME:
|
||||||
if last_user_act == UserActType.NEGATE:
|
if last_user_act == UserActType.NEGATE:
|
||||||
@ -169,7 +254,9 @@ class DP:
|
|||||||
self.DST.clear()
|
self.DST.clear()
|
||||||
return system_act
|
return system_act
|
||||||
else:
|
else:
|
||||||
return SystemAct(SystemActType.REQUEST, ['date'])
|
system_act = SystemAct(SystemActType.REQUEST, ['date'])
|
||||||
|
self.DST.system_update(system_act)
|
||||||
|
return system_act
|
||||||
# brak określonego stanu
|
# brak określonego stanu
|
||||||
else:
|
else:
|
||||||
if last_user_act == UserActType.HELLO:
|
if last_user_act == UserActType.HELLO:
|
||||||
|
@ -39,6 +39,7 @@ class DST:
|
|||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
self.state = None
|
self.state = None
|
||||||
|
self.last_system_act = None
|
||||||
self.slots = {}
|
self.slots = {}
|
||||||
|
|
||||||
def clear_slots(self):
|
def clear_slots(self):
|
||||||
|
7
main.py
7
main.py
@ -1,3 +1,4 @@
|
|||||||
|
from SystemActType import SystemActType
|
||||||
from NaturalLanguageUnderstanding import NLU
|
from NaturalLanguageUnderstanding import NLU
|
||||||
from NaturalLanguageGeneration import NLG
|
from NaturalLanguageGeneration import NLG
|
||||||
from DialogueStateTracker import DST
|
from DialogueStateTracker import DST
|
||||||
@ -20,7 +21,6 @@ if __name__ == "__main__":
|
|||||||
state, last_user_act, last_system_act = dst.get_dialogue_state()
|
state, last_user_act, last_system_act = dst.get_dialogue_state()
|
||||||
slots = dst.get_dialogue_slots()
|
slots = dst.get_dialogue_slots()
|
||||||
system_act = dp.chooseTactic()
|
system_act = dp.chooseTactic()
|
||||||
dst.system_update(system_act)
|
|
||||||
print('------ stan ------')
|
print('------ stan ------')
|
||||||
print(state, last_user_act, last_system_act)
|
print(state, last_user_act, last_system_act)
|
||||||
print('------ przechowywane sloty ------')
|
print('------ przechowywane sloty ------')
|
||||||
@ -31,6 +31,5 @@ if __name__ == "__main__":
|
|||||||
print('-----------------------------------')
|
print('-----------------------------------')
|
||||||
#text = nlg.toText(system_act)
|
#text = nlg.toText(system_act)
|
||||||
|
|
||||||
#print(text)
|
if system_act.getActType() == SystemActType.BYE:
|
||||||
#if system_act.isDialogFinished():
|
break
|
||||||
# break
|
|
||||||
|
Loading…
Reference in New Issue
Block a user