Compare commits
5 Commits
18858dddc8
...
de703d016c
Author | SHA1 | Date | |
---|---|---|---|
![]() |
de703d016c | ||
![]() |
22b61e0dd7 | ||
![]() |
d63f926b49 | ||
![]() |
5bf25d0b2b | ||
![]() |
cdd9ce17e9 |
3
.gitignore
vendored
3
.gitignore
vendored
@ -1 +1,2 @@
|
||||
.idea/
|
||||
.idea/
|
||||
__pycache__/
|
18
SystemActType.py
Normal file
18
SystemActType.py
Normal file
@ -0,0 +1,18 @@
|
||||
SystemActType = dict(
|
||||
affirm='affirm',
|
||||
bye='bye',
|
||||
canthear='canthear',
|
||||
confirm_domain='confirm-domain',
|
||||
negate='negate',
|
||||
repeat='repeat',
|
||||
reqmore='reqmore',
|
||||
welcomemsg='welcomemsg',
|
||||
canthelp='canthelp',
|
||||
canthelp_missing_slot_value='canthelp.missing_slot_value',
|
||||
expl_conf='expl-conf',
|
||||
impl_conf='inform-conf',
|
||||
inform='infomr',
|
||||
offer='offer',
|
||||
request='request',
|
||||
select='select'
|
||||
)
|
23
UserActType.py
Normal file
23
UserActType.py
Normal file
@ -0,0 +1,23 @@
|
||||
UserActType = dict(
|
||||
ack='ack',
|
||||
affirm='affirm',
|
||||
bye='bye',
|
||||
hello='hello',
|
||||
help='help',
|
||||
negate='negate',
|
||||
null='null',
|
||||
repeat='repeat',
|
||||
requalts='requalts',
|
||||
reqmore='reqmore',
|
||||
restart='restart',
|
||||
silence='silence',
|
||||
thankyou='thankyou',
|
||||
confirm='confirm',
|
||||
deny='deny',
|
||||
inform='inform',
|
||||
request='request',
|
||||
order='order',
|
||||
delivery='delivery',
|
||||
payment='payment',
|
||||
price='price'
|
||||
)
|
28
dp.py
Normal file
28
dp.py
Normal file
@ -0,0 +1,28 @@
|
||||
from SystemActType import SystemActType
|
||||
from UserActType import UserActType
|
||||
|
||||
|
||||
class DP:
|
||||
def update_system_action(self, state, last_user_act, last_system_act, slots):
|
||||
if state == UserActType['order']:
|
||||
if 'kind' not in slots[state]:
|
||||
system_act = {'act': SystemActType['request'], 'slot': 'kind'}
|
||||
return system_act
|
||||
elif 'size' not in slots[state]:
|
||||
system_act = {'act': SystemActType['request'], 'slot': 'size'}
|
||||
return system_act
|
||||
elif 'plates' not in slots[state]:
|
||||
system_act = {'act': SystemActType['request'], 'slot': 'plates'}
|
||||
return system_act
|
||||
|
||||
else:
|
||||
if last_user_act == UserActType['hello']:
|
||||
return {'act': SystemActType['welcomemsg']}
|
||||
elif last_user_act == UserActType['bye']:
|
||||
return {'act': SystemActType['bye']}
|
||||
else:
|
||||
return {'act': SystemActType['canthelp']}
|
||||
|
||||
|
||||
|
||||
|
67
dst.py
67
dst.py
@ -1,34 +1,51 @@
|
||||
class SimpleRuleDST:
|
||||
from UserActType import UserActType
|
||||
|
||||
class DST:
|
||||
def __init__(self):
|
||||
self.state = None
|
||||
self.init_session()
|
||||
self.last_user_act = None
|
||||
self.last_system_act = None
|
||||
self.slots = self.init_slots()
|
||||
|
||||
def update(self, user_act=None):
|
||||
for intent, domain, slot, value in user_act:
|
||||
domain = domain.lower()
|
||||
intent = intent.lower()
|
||||
act = user_act['act']
|
||||
self.last_user_act = act
|
||||
if not self.state:
|
||||
if act in [UserActType['order'],
|
||||
UserActType['delivery'],
|
||||
UserActType['payment'],
|
||||
UserActType['price']]:
|
||||
self.state = act
|
||||
|
||||
for slot, value in user_act['slots']:
|
||||
slot = slot.lower()
|
||||
value = slot.lower()
|
||||
value = value.lower()
|
||||
|
||||
if domain not in self.state['belief_state']:
|
||||
continue
|
||||
|
||||
if intent == 'inform':
|
||||
if slot == 'none' or slot == '':
|
||||
continue
|
||||
|
||||
domain_dic = self.state['belief_state'][domain]
|
||||
|
||||
if slot in domain_dic:
|
||||
self.state['belief_state'][domain][slot] = value
|
||||
|
||||
elif intent == 'request':
|
||||
if domain not in self.state['request_state']:
|
||||
self.state['request_state'][domain] = {}
|
||||
if slot not in self.state['request_state'][domain]:
|
||||
self.state['request_state'][domain][slot] = 0
|
||||
self.slots[act][slot] = value
|
||||
|
||||
return self.state
|
||||
|
||||
def init_session(self):
|
||||
self.state = None
|
||||
def get_dialogue_state_tracker_state(self):
|
||||
return self.state, self.last_user_act, self.last_system_act, self.slots
|
||||
|
||||
def get_state(self):
|
||||
return self.state
|
||||
|
||||
def get_last_user_act(self):
|
||||
return self.last_user_act
|
||||
|
||||
def get_last_system_act(self):
|
||||
return self.last_system_act
|
||||
|
||||
def get_slots(self):
|
||||
return self.slots
|
||||
|
||||
def update_last_user_act(self, new_user_act):
|
||||
self.last_user_act = new_user_act
|
||||
|
||||
def update_last_system_act(self, new_system_act):
|
||||
self.last_system_act = new_system_act
|
||||
|
||||
def init_slots(self):
|
||||
return dict(order={}, delivery={}, payment={}, price={})
|
||||
|
||||
|
38
main.py
Normal file
38
main.py
Normal file
@ -0,0 +1,38 @@
|
||||
from UserActType import UserActType
|
||||
from nlu import nlu
|
||||
from dst import DST
|
||||
from dp import DP
|
||||
from nlg import NLG
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
dst = DST()
|
||||
dp = DP()
|
||||
nlg = NLG()
|
||||
|
||||
print("Witamy w restauracji πzza. W czym mogę pomóc?")
|
||||
|
||||
while True:
|
||||
user_input = input("$")
|
||||
# get user act frame from user input with Natural Language Understanding
|
||||
user_act_frame = nlu(user_input)
|
||||
# print('NLU', user_act_frame)
|
||||
# update Dialogue State Tracker with new user act frame
|
||||
dst.update(user_act_frame)
|
||||
state, last_user_act, last_system_act, slots = dst.get_dialogue_state_tracker_state()
|
||||
# print('state', state)
|
||||
# print('last_user_act', last_user_act)
|
||||
# print('last_system_act', last_system_act)
|
||||
# print('slots', slots)
|
||||
|
||||
# get system act frame which decides what's next from Dialogue Policy
|
||||
system_act_frame = dp.update_system_action(state, last_user_act, last_system_act, slots)
|
||||
dst.update_last_system_act(system_act_frame)
|
||||
# print('system_act_frame', system_act_frame)
|
||||
|
||||
# generate response based on system act frame
|
||||
system_response = nlg.generate_response(state, last_user_act, last_system_act, slots, system_act_frame)
|
||||
print('BOT:', system_response)
|
||||
|
||||
if user_act_frame['act'] == UserActType['bye']:
|
||||
break
|
18
nlg.py
Normal file
18
nlg.py
Normal file
@ -0,0 +1,18 @@
|
||||
from SystemActType import SystemActType
|
||||
from UserActType import UserActType
|
||||
|
||||
|
||||
class NLG:
|
||||
def generate_response(self, state, last_user_act, last_system_act, slots, system_act):
|
||||
if state == UserActType['order']:
|
||||
if system_act.act == SystemActType['request']:
|
||||
if system_act.slot == 'kind':
|
||||
return "Jaką pizzę chcesz zamówić?"
|
||||
elif system_act.slot == 'size':
|
||||
return 'Jakiego rozmiaru chcesz pizzę?'
|
||||
elif system_act.slot == 'plates':
|
||||
return 'Dla ilu osób ma to być?'
|
||||
elif last_user_act == UserActType['hello']:
|
||||
return "Dzień dobry, w czym mogę pomóc?"
|
||||
else:
|
||||
return "Przepraszam. Zdanie nie jest mi zrozumiałe. Spróbuj je sformułować w inny sposób."
|
6
nlu.py
6
nlu.py
@ -44,6 +44,6 @@ def nlu(utterance):
|
||||
return {'act': 'null', 'slots': []}
|
||||
|
||||
|
||||
print(nlu('chciałbym zamowic pizze vesuvio XXL na dwie osoby'))
|
||||
print(nlu('na dowoz'))
|
||||
print(nlu('dowoz'))
|
||||
# print(nlu('chciałbym zamowic pizze vesuvio XXL na dwie osoby'))
|
||||
# print(nlu('na dowoz'))
|
||||
# print(nlu('dowoz'))
|
||||
|
@ -1,20 +0,0 @@
|
||||
from enum import Enum, unique
|
||||
|
||||
|
||||
@unique
|
||||
class SystemActType(Enum):
|
||||
affirm = 0,
|
||||
bye = 1,
|
||||
confirm_domain = 2,
|
||||
negate = 3,
|
||||
repeat = 4,
|
||||
reqmore = 5,
|
||||
welcomemsg = 6,
|
||||
canthelp = 7,
|
||||
canthelp_missing_slot_value = 8,
|
||||
expl_conf = 9,
|
||||
impl_conf = 10,
|
||||
inform = 11,
|
||||
offer = 12,
|
||||
request = 13,
|
||||
select = 14
|
Loading…
Reference in New Issue
Block a user