add basic main, dp, dst, nlg
This commit is contained in:
parent
22b61e0dd7
commit
de703d016c
28
dp.py
Normal file
28
dp.py
Normal file
@ -0,0 +1,28 @@
|
||||
from SystemActType import SystemActType
|
||||
from UserActType import UserActType
|
||||
|
||||
|
||||
class DP:
|
||||
def update_system_action(self, state, last_user_act, last_system_act, slots):
|
||||
if state == UserActType['order']:
|
||||
if 'kind' not in slots[state]:
|
||||
system_act = {'act': SystemActType['request'], 'slot': 'kind'}
|
||||
return system_act
|
||||
elif 'size' not in slots[state]:
|
||||
system_act = {'act': SystemActType['request'], 'slot': 'size'}
|
||||
return system_act
|
||||
elif 'plates' not in slots[state]:
|
||||
system_act = {'act': SystemActType['request'], 'slot': 'plates'}
|
||||
return system_act
|
||||
|
||||
else:
|
||||
if last_user_act == UserActType['hello']:
|
||||
return {'act': SystemActType['welcomemsg']}
|
||||
elif last_user_act == UserActType['bye']:
|
||||
return {'act': SystemActType['bye']}
|
||||
else:
|
||||
return {'act': SystemActType['canthelp']}
|
||||
|
||||
|
||||
|
||||
|
67
dst.py
67
dst.py
@ -1,34 +1,51 @@
|
||||
class SimpleRuleDST:
|
||||
from UserActType import UserActType
|
||||
|
||||
class DST:
|
||||
def __init__(self):
|
||||
self.state = None
|
||||
self.init_session()
|
||||
self.last_user_act = None
|
||||
self.last_system_act = None
|
||||
self.slots = self.init_slots()
|
||||
|
||||
def update(self, user_act=None):
|
||||
for intent, domain, slot, value in user_act:
|
||||
domain = domain.lower()
|
||||
intent = intent.lower()
|
||||
act = user_act['act']
|
||||
self.last_user_act = act
|
||||
if not self.state:
|
||||
if act in [UserActType['order'],
|
||||
UserActType['delivery'],
|
||||
UserActType['payment'],
|
||||
UserActType['price']]:
|
||||
self.state = act
|
||||
|
||||
for slot, value in user_act['slots']:
|
||||
slot = slot.lower()
|
||||
value = slot.lower()
|
||||
value = value.lower()
|
||||
|
||||
if domain not in self.state['belief_state']:
|
||||
continue
|
||||
|
||||
if intent == 'inform':
|
||||
if slot == 'none' or slot == '':
|
||||
continue
|
||||
|
||||
domain_dic = self.state['belief_state'][domain]
|
||||
|
||||
if slot in domain_dic:
|
||||
self.state['belief_state'][domain][slot] = value
|
||||
|
||||
elif intent == 'request':
|
||||
if domain not in self.state['request_state']:
|
||||
self.state['request_state'][domain] = {}
|
||||
if slot not in self.state['request_state'][domain]:
|
||||
self.state['request_state'][domain][slot] = 0
|
||||
self.slots[act][slot] = value
|
||||
|
||||
return self.state
|
||||
|
||||
def init_session(self):
|
||||
self.state = None
|
||||
def get_dialogue_state_tracker_state(self):
|
||||
return self.state, self.last_user_act, self.last_system_act, self.slots
|
||||
|
||||
def get_state(self):
|
||||
return self.state
|
||||
|
||||
def get_last_user_act(self):
|
||||
return self.last_user_act
|
||||
|
||||
def get_last_system_act(self):
|
||||
return self.last_system_act
|
||||
|
||||
def get_slots(self):
|
||||
return self.slots
|
||||
|
||||
def update_last_user_act(self, new_user_act):
|
||||
self.last_user_act = new_user_act
|
||||
|
||||
def update_last_system_act(self, new_system_act):
|
||||
self.last_system_act = new_system_act
|
||||
|
||||
def init_slots(self):
|
||||
return dict(order={}, delivery={}, payment={}, price={})
|
||||
|
||||
|
38
main.py
Normal file
38
main.py
Normal file
@ -0,0 +1,38 @@
|
||||
from UserActType import UserActType
|
||||
from nlu import nlu
|
||||
from dst import DST
|
||||
from dp import DP
|
||||
from nlg import NLG
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
dst = DST()
|
||||
dp = DP()
|
||||
nlg = NLG()
|
||||
|
||||
print("Witamy w restauracji πzza. W czym mogę pomóc?")
|
||||
|
||||
while True:
|
||||
user_input = input("$")
|
||||
# get user act frame from user input with Natural Language Understanding
|
||||
user_act_frame = nlu(user_input)
|
||||
# print('NLU', user_act_frame)
|
||||
# update Dialogue State Tracker with new user act frame
|
||||
dst.update(user_act_frame)
|
||||
state, last_user_act, last_system_act, slots = dst.get_dialogue_state_tracker_state()
|
||||
# print('state', state)
|
||||
# print('last_user_act', last_user_act)
|
||||
# print('last_system_act', last_system_act)
|
||||
# print('slots', slots)
|
||||
|
||||
# get system act frame which decides what's next from Dialogue Policy
|
||||
system_act_frame = dp.update_system_action(state, last_user_act, last_system_act, slots)
|
||||
dst.update_last_system_act(system_act_frame)
|
||||
# print('system_act_frame', system_act_frame)
|
||||
|
||||
# generate response based on system act frame
|
||||
system_response = nlg.generate_response(state, last_user_act, last_system_act, slots, system_act_frame)
|
||||
print('BOT:', system_response)
|
||||
|
||||
if user_act_frame['act'] == UserActType['bye']:
|
||||
break
|
18
nlg.py
Normal file
18
nlg.py
Normal file
@ -0,0 +1,18 @@
|
||||
from SystemActType import SystemActType
|
||||
from UserActType import UserActType
|
||||
|
||||
|
||||
class NLG:
|
||||
def generate_response(self, state, last_user_act, last_system_act, slots, system_act):
|
||||
if state == UserActType['order']:
|
||||
if system_act.act == SystemActType['request']:
|
||||
if system_act.slot == 'kind':
|
||||
return "Jaką pizzę chcesz zamówić?"
|
||||
elif system_act.slot == 'size':
|
||||
return 'Jakiego rozmiaru chcesz pizzę?'
|
||||
elif system_act.slot == 'plates':
|
||||
return 'Dla ilu osób ma to być?'
|
||||
elif last_user_act == UserActType['hello']:
|
||||
return "Dzień dobry, w czym mogę pomóc?"
|
||||
else:
|
||||
return "Przepraszam. Zdanie nie jest mi zrozumiałe. Spróbuj je sformułować w inny sposób."
|
Loading…
Reference in New Issue
Block a user