12 KiB
12 KiB
import json
class Rules_DST():
    """Rule-based dialogue state tracker.

    The tracker keeps a single ``state`` dictionary that is loaded from a
    JSON file and updated in place from lists of user dialogue acts.  The
    code below relies on the state having at least the keys
    ``'belief_state'`` (``{domain: {slot: value}}``), ``'request_state'``
    and ``'user_action'``.
    """

    # Default location of the JSON file holding the pristine dialogue state.
    # An instance may override it through the constructor.
    state_path = 'data.json'

    def __init__(self, state_path=None):
        """Create a tracker and load the initial state.

        Args:
            state_path: Optional path of the JSON file with the initial
                state.  When omitted, ``'data.json'`` is used.
        """
        if state_path is not None:
            self.state_path = state_path
        self.state = self._load_initial_state()

    def _load_initial_state(self):
        """Read and return a fresh copy of the initial state from disk."""
        # A context manager guarantees the file handle is closed; 'utf-8-sig'
        # reads plain UTF-8 and also tolerates a leading byte-order mark.
        with open(self.state_path, encoding='utf-8-sig') as state_file:
            return json.load(state_file)

    def update_user(self, user_acts=None):
        """Apply user dialogue acts to the tracked state.

        Each act is an ``(intent, domain, slot, value)`` sequence; intent,
        domain and slot are lower-cased before being interpreted.

        * ``start_conversation`` acts are skipped and not recorded.
        * ``end_conversation`` reloads the initial state from disk; the act
          itself is then recorded in the fresh ``'user_action'`` list.
        * Acts whose domain is not a key of ``'belief_state'`` are skipped.
        * Intents containing ``'inform'`` store ``value`` in the belief
          state when the slot exists for the domain.  A ``"<domain>/"``
          prefix is removed from the slot name first.
        * ``request`` acts register the slot in ``'request_state'``: the
          first request for a slot stores ``0``, a repeated request stores
          the act's value.

        Args:
            user_acts: Iterable of acts; ``None`` is treated as "no acts".

        Returns:
            The (possibly reloaded) state dictionary.
        """
        for intent, domain, slot, value in user_acts or ():
            intent = intent.lower()
            domain = domain.lower()
            slot = slot.lower()

            if intent == 'start_conversation':
                continue

            if intent == 'end_conversation':
                self.state = self._load_initial_state()
            elif domain not in self.state['belief_state']:
                continue
            elif 'inform' in intent:
                if slot == 'inform':
                    continue
                if domain in slot:
                    # str.replace returns a new string, so the result has to
                    # be assigned back for the prefix to actually be dropped.
                    slot = slot.replace(domain + "/", '')
                domain_slots = self.state['belief_state'][domain]
                if slot in domain_slots:
                    domain_slots[slot] = value
            elif intent == 'request':
                requested = self.state['request_state'].setdefault(domain, {})
                requested[slot] = value if slot in requested else 0

            self.state["user_action"].append([intent, domain, slot, value])
        return self.state
# Demo of the state tracker. The bare literals that follow the expressions
# below appear to be the recorded notebook outputs of those expressions.

# Build a tracker; the constructor loads the initial state from 'data.json'.
dst = Rules_DST()
dst.state
{'user_action': [], 'system_action': [], 'belief_state': {'food': {'name': '', 'type': '', 'price_range': '', 'size': '', 'ingredients': ''}, 'drink': {'name': '', 'price_range': '', 'size': ''}, 'sauce': {'name': '', 'price_range': '', 'size': ''}, 'order': {'type': '', 'price_range': '', 'restaurant_name': '', 'area': '', 'book_time': '', 'book_day': ''}, 'booking': {'restaurant_name': '', 'area': '', 'book_time': '', 'book_day': '', 'book_people': ''}, 'payment': {'type': '', 'amount': '', 'vat': ''}}, 'request_state': {}, 'terminated': False, 'history': []}
dst.state['user_action']
[]
# NOTE(review): 'star_conversation' looks like a typo for 'start_conversation'.
# As written, the act is skipped because its empty domain is not a key of
# 'belief_state', so only the inform act is recorded.
dst.update_user([['star_conversation',"","",""], ['inform', 'drink', 'size', 'duża']])
dst.state['belief_state']['food']
{'name': '', 'type': '', 'price_range': '', 'size': '', 'ingredients': ''}
dst.state['belief_state']['drink']
{'name': '', 'price_range': '', 'size': 'duża'}
dst.state['user_action']
[['inform', 'drink', 'size', 'duża']]
# NOTE(review): the requested slot is spelled 'price range' (with a space),
# while the belief-state key is 'price_range' -- confirm which one is intended.
dst.update_user([['request', 'drink', 'price range', '?']])
dst.state['request_state']
{'drink': {'price range': 0}}
# Two inform acts for the same domain fill two belief-state slots.
dst.update_user([['inform', 'food', 'type', 'pizza'], ['inform', 'food', 'size', 'duża']])
dst.state['belief_state']['food']
{'name': '', 'type': 'pizza', 'price_range': '', 'size': 'duża', 'ingredients': ''}
dst.update_user([['inform', 'sauce', 'size', 'standardowa'], ['inform', 'sauce', 'price_range', 'średnia']])
dst.state['belief_state']['sauce']
{'name': '', 'price_range': 'średnia', 'size': 'standardowa'}
from collections import defaultdict
import jmespath
class DP():
    """Rule-based dialogue policy backed by a JSON item database.

    The database is searched with JMESPath under the ``database.<domain>``
    path, using the filled slots of the belief state as constraints.
    """

    def __init__(self):
        """Load the item database from ``database.json``."""
        # 'utf-8-sig' skips a UTF-8 byte-order mark if the file has one.
        with open('database.json', encoding='utf-8-sig') as json_file:
            self.db = json.load(json_file)

    def predict(self, state):
        """Compute system actions for every user act stored in ``state``.

        User acts are grouped by ``(domain, intent)`` and each group is
        handed to :meth:`update_system_action`.  The list of per-group
        results is stored in ``state['system_action']``.

        Args:
            state: Dialogue state with ``'user_action'`` (a list of
                ``[intent, domain, slot, value]``) and ``'belief_state'``.

        Returns:
            The result for the last group: the accumulated
            ``{(domain, act): [[slot, value], ...]}`` mapping, or ``None``
            when the last group is ``end_conversation``.  When there are no
            user acts at all, the empty mapping is returned.
        """
        self.results = []
        system_action = defaultdict(list)
        user_action = defaultdict(list)

        for intent, domain, slot, value in state['user_action']:
            user_action[(domain, intent)].append((slot, value))

        system_acts = [
            self.update_system_action(user_act, user_action, state, system_action)
            for user_act in user_action
        ]
        state['system_action'] = system_acts
        # Without any user act there is no "last" result to index.
        return system_acts[-1] if system_acts else system_action

    @staticmethod
    def _build_query(domain, constraints):
        """Return the JMESPath expression selecting matching domain records.

        Slot names are emitted as quoted identifiers and values as JSON
        literals, so quotes or other special characters in either cannot
        break the expression.  Without constraints every record of the
        domain is selected.
        """
        if not constraints:
            return f"database.{domain}"
        conditions = []
        for slot, value in constraints:
            field = json.dumps(slot, ensure_ascii=False)
            # Inside a `...` JSON literal a backtick has to be escaped.
            literal = json.dumps(value, ensure_ascii=False).replace("`", "\\`")
            conditions.append(f"contains({field}, `{literal}`)")
        return f"database.{domain}[?{' && '.join(conditions)}]"

    def update_system_action(self, user_act, user_action, state, system_action):
        """Apply the policy rules for one ``(domain, intent)`` group.

        Args:
            user_act: ``(domain, intent)`` key of the group.
            user_action: Mapping of group keys to ``(slot, value)`` entries.
            state: Dialogue state; its ``'belief_state'`` supplies the
                search constraints for the domain.
            system_action: Mapping updated in place with the chosen acts.

        Returns:
            ``system_action``, or ``None`` for ``end_conversation``.
        """
        domain, intent = user_act

        # Rule 3: ending the conversation yields no system action.
        if intent == 'end_conversation':
            return None

        # Only slots that carry a value constrain the database search.
        constraints = [
            (slot, value)
            for slot, value in state['belief_state'][domain].items()
            if value != ''
        ]

        # Rule 1: answer requested slots from the first stored result.
        # NOTE(review): self.results holds whatever the most recent inform
        # search produced, which may belong to a different domain -- confirm.
        if intent == 'request':
            if not self.results:
                system_action[(domain, 'NoOffer')] = []
            else:
                best = self.results[0]
                for requested in user_action[user_act]:
                    if requested[0] in best:
                        system_action[(domain, 'Inform')].append(
                            [requested[0], best[requested[0]]]
                        )

        # Rule 2: search the database and recommend the first match.
        elif intent == 'inform':
            query = self._build_query(domain, constraints)
            print(query)
            # jmespath.search returns None when the path does not exist.
            self.results = jmespath.search(query, self.db) or []
            print(self.results)
            if not self.results:
                system_action[(domain, 'NoOffer')] = []
            else:
                system_action[(domain, 'Inform')].append(
                    ['Choice', str(len(self.results))]
                )
                choice = self.results[0]
                if domain in ["food", "drink", "sauce"]:
                    system_action[(domain, 'Recommend')].append(['Name', choice['name']])
                elif domain in ["order", "booking", "payment"]:
                    system_action[(domain, 'Recommend')].append(['Type', choice['type']])
        return system_action
# Instantiate the policy (its constructor reads 'database.json') and run it
# on the state accumulated by the tracker above.
dp= DP()
dp.predict(dst.state)
(size, 'duża') [{'name': 'lemoniada', 'price_range': 'średnia', 'size': 'duża'}] (type, 'pizza') && contains(size, 'duża') [{'name': 'pizza margherita', 'type': 'pizza', 'price_range': 'średnia', 'size': 'duża', 'ingredients': 'sos pomidorowy, ser mozzarella, bazylia'}, {'name': 'pizza vegetariana', 'type': 'pizza', 'price_range': 'wysoka', 'size': 'duża', 'ingredients': 'sos pomidorowy, mozzarella, warzywa (papryka, cebula, pomidory, pieczarki), oregano'}, {'name': 'pizza hawajska', 'type': 'pizza', 'price_range': 'wysoka', 'size': 'duża', 'ingredients': 'sos pomidorowy, mozzarella, szynka, ananas, oregano'}, {'name': 'pizza capricciosa', 'type': 'pizza', 'price_range': 'wysoka', 'size': 'duża', 'ingredients': 'sos pomidorowy, mozzarella, szynka, pieczarki, oregano'}] (price_range, 'średnia') && contains(size, 'standardowa') []
defaultdict(list, {('drink', 'Inform'): [['Choice', '1']], ('drink', 'Recommend'): [['Name', 'lemoniada']], ('food', 'Inform'): [['Choice', '4']], ('food', 'Recommend'): [['Name', 'pizza margherita']], ('sauce', 'NoOffer'): []})