10 KiB
10 KiB
import json
class Rules_DST():
    """Rule-based dialogue state tracker.

    The state is a dict loaded from ``data.json``. This class reads and
    writes its ``belief_state`` (domain -> slot -> value),
    ``request_state`` (domain -> slot -> value) and ``user_action`` (list of
    processed ``[intent, domain, slot, value]`` acts) entries.
    """

    def __init__(self):
        self.state = self._load_initial_state()

    @staticmethod
    def _load_initial_state():
        """Read a fresh copy of the initial dialogue state from data.json."""
        # `with` closes the handle; utf-8-sig also tolerates a leading BOM.
        with open('data.json', encoding='utf-8-sig') as json_file:
            return json.load(json_file)

    def update_user(self, user_acts=None):
        """Apply a batch of user dialogue acts to the tracked state.

        Intent, domain and slot are lower-cased before use.
        ``start_conversation`` acts and acts for domains missing from
        ``belief_state`` are skipped and not recorded.
        ``end_conversation`` resets the state from data.json.
        ``inform`` acts fill an existing belief slot; a ``"domain/"`` slot
        prefix is stripped first.
        ``request`` acts mark the slot with 0 the first time it is
        requested, and store the value on later requests.
        Every act that is not skipped is appended to ``state["user_action"]``.

        Args:
            user_acts: iterable of ``[intent, domain, slot, value]`` items;
                ``None`` is treated as no acts.

        Returns:
            The updated ``self.state`` dict.
        """
        for intent, domain, slot, value in user_acts or ():
            domain = domain.lower()
            intent = intent.lower()
            slot = slot.lower()
            if intent == 'start_conversation':
                continue
            if intent == 'end_conversation':
                # Start over from a fresh state; the act itself is still recorded below.
                self.state = self._load_initial_state()
            elif domain not in self.state['belief_state']:
                continue
            elif 'inform' in intent:
                if slot == 'inform':
                    continue
                if domain in slot:
                    # Strip a "domain/" prefix, e.g. "food/size" -> "size".
                    slot = slot.replace(domain + "/", '')
                domain_slots = self.state['belief_state'][domain]
                if slot in domain_slots:
                    domain_slots[slot] = value
            elif intent == 'request':
                requested = self.state['request_state'].setdefault(domain, {})
                if slot not in requested:
                    requested[slot] = 0
                else:
                    requested[slot] = value
            self.state["user_action"].append([intent, domain, slot, value])
        return self.state
# Demo of Rules_DST. The "# Output:" comments preserve the values that were
# pasted after each bare expression (apparently captured notebook output).
dst = Rules_DST()
dst.state
# Output: {'user_action': [], 'system_action': [], 'belief_state': {'food': {'name': '', 'type': '', 'price range': '', 'size': '', 'ingredients': ''}, 'drink': {'name': '', 'price range': '', 'size': ''}, 'sauce': {'name': '', 'price range': '', 'size': ''}, 'order': {'type': '', 'price range': '', 'restaurant_name': '', 'area': '', 'book time': '', 'book day': ''}, 'booking': {'restaurant_name': '', 'area': '', 'book time': '', 'book day': '', 'book people': ''}, 'payment': {'type': '', 'amount': '', 'vat': ''}}, 'request_state': {}, 'terminated': False, 'history': []}
dst.state['user_action']
# Output: []
# 'start_conversation' acts are skipped and not recorded by update_user.
dst.update_user([['start_conversation', "", "", ""], ['inform', 'drink', 'size', 'duża']])
dst.state['belief_state']['food']
# Output: {'name': '', 'type': '', 'price range': '', 'size': '', 'ingredients': ''}
dst.state['user_action']
# Output: [['inform', 'drink', 'size', 'duża']]
dst.update_user([['request', 'drink', 'price range', '?']])
dst.state['request_state']
# Output: {'drink': {'price range': 0}}
dst.update_user([['inform', 'food', 'type', 'pizza'], ['inform', 'food', 'size', 'duża']])
dst.state['belief_state']['food']
# Output: {'name': '', 'type': 'pizza', 'price range': '', 'size': 'duża', 'ingredients': ''}
from collections import defaultdict
import jmespath
class DP():
    """Rule-based dialogue policy that turns the tracked state into system actions.

    Candidate records are loaded from ``database.json`` and filtered with
    JMESPath queries of the form ``database.<domain>[?...]``. Presumably
    each domain maps to a list of record dicts — confirm against the data.
    """

    def __init__(self):
        with open('database.json', encoding='utf-8-sig') as json_file:
            self.db = json.load(json_file)
        # Records found by the most recent inform search (any domain).
        self.results = []
        # Records found per domain, so a request is answered from its own domain.
        self._results_by_domain = {}

    def predict(self, state):
        """Compute system actions for every user act recorded in ``state``.

        User acts are grouped by ``(domain, intent)`` in first-seen order.
        Each group is handled by ``update_system_action``, and all groups
        write into one shared action map.

        Args:
            state: DST state dict; ``user_action`` and ``belief_state`` are
                read and ``system_action`` is overwritten with the per-group
                return values.

        Returns:
            The value returned for the last group: the shared
            ``defaultdict(list)`` keyed by ``(domain, act)``, or ``None`` if
            that group was ``end_conversation``. With no user acts, the
            empty action map.
        """
        self.results = []
        self._results_by_domain = {}
        system_action = defaultdict(list)
        user_action = defaultdict(list)
        for intent, domain, slot, value in state['user_action']:
            user_action[(domain, intent)].append((slot, value))
        system_acts = [
            self.update_system_action(user_act, user_action, state, system_action)
            for user_act in user_action
        ]
        state['system_action'] = system_acts
        # An empty history used to raise IndexError here.
        return system_acts[-1] if system_acts else system_action

    def update_system_action(self, user_act, user_action, state, system_action):
        """Add the system response for one ``(domain, intent)`` group of user acts.

        Args:
            user_act: ``(domain, intent)`` key of the group.
            user_action: mapping from such keys to lists of ``(slot, value)``.
            state: DST state; for ``inform``, the non-empty slots of
                ``belief_state[domain]`` become search constraints.
            system_action: ``defaultdict(list)`` updated in place.

        Returns:
            ``system_action``, or ``None`` for ``end_conversation``.

        Raises:
            Whatever ``jmespath.search`` raises for records the query cannot
            evaluate.
        """
        domain, intent = user_act
        if not hasattr(self, '_results_by_domain'):
            self._results_by_domain = {}
        # Rule 3: no response once the conversation has ended.
        if intent == 'end_conversation':
            return None
        # Rule 1: answer requested slots from the first record found for this domain.
        if intent == 'request':
            results = self._results_by_domain.get(domain, [])
            if not results:
                system_action[(domain, 'NoOffer')] = []
            else:
                best = results[0]
                for slot, _ in user_action[user_act]:
                    if slot in best:
                        system_action[(domain, 'Inform')].append([slot, best[slot]])
        # Rule 2: search the database using every filled belief slot of the domain.
        elif intent == 'inform':
            constraints = [
                (slot, value)
                for slot, value in state['belief_state'][domain].items()
                if value != ''
            ]
            query = self._build_query(domain, constraints)
            # search() yields None when the path is missing; treat that as no match.
            self.results = jmespath.search(query, self.db) or []
            self._results_by_domain[domain] = self.results
            if not self.results:
                system_action[(domain, 'NoOffer')] = []
            else:
                system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])
                choice = self.results[0]
                if domain in ("food", "drink", "police", "sauce", "order", "booking", "payment"):
                    system_action[(domain, 'Recommend')].append(['Name', choice['name']])
        return system_action

    @classmethod
    def _build_query(cls, domain, constraints):
        """Build a JMESPath query selecting ``domain`` records that contain every constraint value."""
        # Quoted identifiers (JSON strings) allow names with spaces, e.g. "price range".
        path = f"database.{json.dumps(domain)}"
        if not constraints:
            # No filled slots: every record of the domain matches.
            return path
        conditions = " && ".join(
            f"contains({json.dumps(slot)}, {cls._jmespath_literal(value)})"
            for slot, value in constraints
        )
        return f"{path}[?{conditions}]"

    @staticmethod
    def _jmespath_literal(value):
        """Render ``value`` as a backtick-quoted JMESPath JSON string literal."""
        # Backticks delimit the literal, so escape any inside the value.
        return "`" + json.dumps(str(value)).replace("`", "\\`") + "`"
# Demo of DP on the state built by the Rules_DST demo; the "# Output:" comment
# preserves the value that was pasted after the call (apparently notebook output).
dp = DP()
dp.predict(dst.state)
# Output: defaultdict(list, {('drink', 'Inform'): [['Choice', '1'], ['price range', 'średnia']], ('drink', 'Recommend'): [['Name', 'lemoniada']], ('food', 'Inform'): [['Choice', '4']], ('food', 'Recommend'): [['Name', 'pizza margherita']]})