add basic main, dp, dst, nlg
This commit is contained in:
parent
22b61e0dd7
commit
de703d016c
28
dp.py
Normal file
28
dp.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from SystemActType import SystemActType
|
||||||
|
from UserActType import UserActType
|
||||||
|
|
||||||
|
|
||||||
|
class DP:
|
||||||
|
def update_system_action(self, state, last_user_act, last_system_act, slots):
|
||||||
|
if state == UserActType['order']:
|
||||||
|
if 'kind' not in slots[state]:
|
||||||
|
system_act = {'act': SystemActType['request'], 'slot': 'kind'}
|
||||||
|
return system_act
|
||||||
|
elif 'size' not in slots[state]:
|
||||||
|
system_act = {'act': SystemActType['request'], 'slot': 'size'}
|
||||||
|
return system_act
|
||||||
|
elif 'plates' not in slots[state]:
|
||||||
|
system_act = {'act': SystemActType['request'], 'slot': 'plates'}
|
||||||
|
return system_act
|
||||||
|
|
||||||
|
else:
|
||||||
|
if last_user_act == UserActType['hello']:
|
||||||
|
return {'act': SystemActType['welcomemsg']}
|
||||||
|
elif last_user_act == UserActType['bye']:
|
||||||
|
return {'act': SystemActType['bye']}
|
||||||
|
else:
|
||||||
|
return {'act': SystemActType['canthelp']}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
67
dst.py
67
dst.py
@ -1,34 +1,51 @@
|
|||||||
class SimpleRuleDST:
|
from UserActType import UserActType
|
||||||
|
|
||||||
|
class DST:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.state = None
|
self.state = None
|
||||||
self.init_session()
|
self.last_user_act = None
|
||||||
|
self.last_system_act = None
|
||||||
|
self.slots = self.init_slots()
|
||||||
|
|
||||||
def update(self, user_act=None):
|
def update(self, user_act=None):
|
||||||
for intent, domain, slot, value in user_act:
|
act = user_act['act']
|
||||||
domain = domain.lower()
|
self.last_user_act = act
|
||||||
intent = intent.lower()
|
if not self.state:
|
||||||
|
if act in [UserActType['order'],
|
||||||
|
UserActType['delivery'],
|
||||||
|
UserActType['payment'],
|
||||||
|
UserActType['price']]:
|
||||||
|
self.state = act
|
||||||
|
|
||||||
|
for slot, value in user_act['slots']:
|
||||||
slot = slot.lower()
|
slot = slot.lower()
|
||||||
value = slot.lower()
|
value = value.lower()
|
||||||
|
|
||||||
if domain not in self.state['belief_state']:
|
self.slots[act][slot] = value
|
||||||
continue
|
|
||||||
|
|
||||||
if intent == 'inform':
|
|
||||||
if slot == 'none' or slot == '':
|
|
||||||
continue
|
|
||||||
|
|
||||||
domain_dic = self.state['belief_state'][domain]
|
|
||||||
|
|
||||||
if slot in domain_dic:
|
|
||||||
self.state['belief_state'][domain][slot] = value
|
|
||||||
|
|
||||||
elif intent == 'request':
|
|
||||||
if domain not in self.state['request_state']:
|
|
||||||
self.state['request_state'][domain] = {}
|
|
||||||
if slot not in self.state['request_state'][domain]:
|
|
||||||
self.state['request_state'][domain][slot] = 0
|
|
||||||
|
|
||||||
return self.state
|
return self.state
|
||||||
|
|
||||||
def init_session(self):
|
def get_dialogue_state_tracker_state(self):
|
||||||
self.state = None
|
return self.state, self.last_user_act, self.last_system_act, self.slots
|
||||||
|
|
||||||
|
def get_state(self):
|
||||||
|
return self.state
|
||||||
|
|
||||||
|
def get_last_user_act(self):
|
||||||
|
return self.last_user_act
|
||||||
|
|
||||||
|
def get_last_system_act(self):
|
||||||
|
return self.last_system_act
|
||||||
|
|
||||||
|
def get_slots(self):
|
||||||
|
return self.slots
|
||||||
|
|
||||||
|
def update_last_user_act(self, new_user_act):
|
||||||
|
self.last_user_act = new_user_act
|
||||||
|
|
||||||
|
def update_last_system_act(self, new_system_act):
|
||||||
|
self.last_system_act = new_system_act
|
||||||
|
|
||||||
|
def init_slots(self):
|
||||||
|
return dict(order={}, delivery={}, payment={}, price={})
|
||||||
|
|
||||||
|
38
main.py
Normal file
38
main.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
from UserActType import UserActType
|
||||||
|
from nlu import nlu
|
||||||
|
from dst import DST
|
||||||
|
from dp import DP
|
||||||
|
from nlg import NLG
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
dst = DST()
|
||||||
|
dp = DP()
|
||||||
|
nlg = NLG()
|
||||||
|
|
||||||
|
print("Witamy w restauracji πzza. W czym mogę pomóc?")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
user_input = input("$")
|
||||||
|
# get user act frame from user input with Natural Language Understanding
|
||||||
|
user_act_frame = nlu(user_input)
|
||||||
|
# print('NLU', user_act_frame)
|
||||||
|
# update Dialogue State Tracker with new user act frame
|
||||||
|
dst.update(user_act_frame)
|
||||||
|
state, last_user_act, last_system_act, slots = dst.get_dialogue_state_tracker_state()
|
||||||
|
# print('state', state)
|
||||||
|
# print('last_user_act', last_user_act)
|
||||||
|
# print('last_system_act', last_system_act)
|
||||||
|
# print('slots', slots)
|
||||||
|
|
||||||
|
# get system act frame which decides what's next from Dialogue Policy
|
||||||
|
system_act_frame = dp.update_system_action(state, last_user_act, last_system_act, slots)
|
||||||
|
dst.update_last_system_act(system_act_frame)
|
||||||
|
# print('system_act_frame', system_act_frame)
|
||||||
|
|
||||||
|
# generate response based on system act frame
|
||||||
|
system_response = nlg.generate_response(state, last_user_act, last_system_act, slots, system_act_frame)
|
||||||
|
print('BOT:', system_response)
|
||||||
|
|
||||||
|
if user_act_frame['act'] == UserActType['bye']:
|
||||||
|
break
|
18
nlg.py
Normal file
18
nlg.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from SystemActType import SystemActType
|
||||||
|
from UserActType import UserActType
|
||||||
|
|
||||||
|
|
||||||
|
class NLG:
|
||||||
|
def generate_response(self, state, last_user_act, last_system_act, slots, system_act):
|
||||||
|
if state == UserActType['order']:
|
||||||
|
if system_act.act == SystemActType['request']:
|
||||||
|
if system_act.slot == 'kind':
|
||||||
|
return "Jaką pizzę chcesz zamówić?"
|
||||||
|
elif system_act.slot == 'size':
|
||||||
|
return 'Jakiego rozmiaru chcesz pizzę?'
|
||||||
|
elif system_act.slot == 'plates':
|
||||||
|
return 'Dla ilu osób ma to być?'
|
||||||
|
elif last_user_act == UserActType['hello']:
|
||||||
|
return "Dzień dobry, w czym mogę pomóc?"
|
||||||
|
else:
|
||||||
|
return "Przepraszam. Zdanie nie jest mi zrozumiałe. Spróbuj je sformułować w inny sposób."
|
Loading…
Reference in New Issue
Block a user