# GOATS/DialoguePolicy.py
# (76 lines, 2.9 KiB, Python)

from collections import defaultdict
import json
from copy import deepcopy
from convlab.policy.policy import Policy
# Default location of the JSON hotel database loaded by DialoguePolicy. The path
# is relative, so it is resolved against the process's current working directory.
db_path = './hotels_data.json'


class DialoguePolicy(Policy):
    """Rule-based dialogue policy backed by a local hotel database.

    Each turn, the user's dialogue acts are grouped by ``(domain, intent)``.
    The database is queried with the non-empty belief-state slot values for
    that domain. The reply is a list of ``[intent, domain, slot, value]``
    system acts.
    """

    def __init__(self, path=None):
        """Initialise the policy and load the database.

        Args:
            path: Optional path to the JSON database. When omitted, the
                module-level ``db_path`` is used, which preserves the
                original no-argument constructor.
        """
        Policy.__init__(self)
        self.db = self.load_database(db_path if path is None else path)
        # Results of the most recent database query. update_system_action()
        # overwrites it, and predict() reads it for the booking reply.
        # Initialised here so the attribute exists even before predict() runs.
        self.results = []

    def load_database(self, db_path):
        """Return the parsed contents of the UTF-8 JSON file at ``db_path``."""
        with open(db_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def query(self, domain, constraints):
        """Return the database entries that satisfy every constraint.

        Args:
            domain: Domain name. Only ``'hotel'`` is backed by the database;
                any other domain yields an empty list.
            constraints: Iterable of ``(slot, value)`` pairs. An entry matches
                when ``entry.get(slot) == value`` for every pair, so an empty
                iterable matches every entry.

        Returns:
            A list of the matching entries (the database objects themselves,
            not copies). ``self.db`` is presumably a JSON list of dicts, as
            implied by ``entry.get`` below; confirm against the data file.
        """
        if domain != 'hotel':
            return []
        # Materialise once: the constraints are re-checked for every entry, and
        # a one-shot iterator would otherwise be empty after the first entry.
        constraints = list(constraints)
        return [entry for entry in self.db
                if all(entry.get(key) == value for key, value in constraints)]

    def predict(self, state):
        """Compute and record the system acts for the current turn.

        Args:
            state: Dialogue state dict. The method reads
                ``state['user_action']`` (an iterable of
                ``[intent, domain, slot, value]``) and
                ``state['belief_state']``. It writes the result to
                ``state['system_action']``.

        Returns:
            The list of ``[intent, domain, slot, value]`` system acts.
        """
        self.results = []
        system_action = defaultdict(list)
        user_action = defaultdict(list)
        for intent, domain, slot, value in state['user_action']:
            user_action[(domain.lower(), intent.lower())].append((slot.lower(), value))

        for user_act in user_action:
            self.update_system_action(user_act, user_action, state, system_action)

        # If the user mentioned any booking slot and the most recent query (the
        # one for the last user act processed) found a match, the turn is
        # answered with a booking reference only. This replaces every act
        # collected above.
        booking_slots = {'book stay', 'book day', 'book people'}
        wants_booking = any(slot in booking_slots
                            for slots in user_action.values()
                            for slot, _ in slots)
        if wants_booking and self.results:
            system_action = {('Booking', 'Book'): [["Ref", self.results[0].get('Ref', 'N/A')]]}

        system_acts = [[intent, domain, slot, value]
                       for (domain, intent), slots in system_action.items()
                       for slot, value in slots]
        state['system_action'] = system_acts
        return system_acts

    def update_system_action(self, user_act, user_action, state, system_action):
        """Add the system acts answering one ``(domain, intent)`` user act.

        Side effect: ``self.results`` is replaced by a deep copy of the query
        results for this domain.

        Args:
            user_act: ``(domain, intent)`` key. predict() lower-cases both parts.
            user_action: Mapping from ``(domain, intent)`` to the list of
                ``(slot, value)`` pairs the user gave for it.
            state: Dialogue state. The non-empty values in
                ``state['belief_state'][domain]['info']`` are the query
                constraints. For ``inform`` acts with a match, those slots are
                overwritten with the first matching entry's values.
            system_action: Mapping from ``(domain, intent)`` to a list of
                ``[slot, value]`` pairs, updated in place. It must create
                missing keys on access, like predict()'s ``defaultdict(list)``,
                because ``.append`` is called on keys that may not exist yet.
        """
        domain, intent = user_act
        # A domain with no belief-state entry contributes no constraints
        # instead of raising KeyError. Presumably this covers domain-less acts
        # such as greetings; confirm against the user simulator.
        domain_info = state['belief_state'].get(domain, {}).get('info', {})
        constraints = [(slot, value) for slot, value in domain_info.items() if value != '']
        # Copy so later mutation of the results can never alter self.db.
        self.results = deepcopy(self.query(domain.lower(), constraints))

        if intent == 'request':
            if not self.results:
                system_action[(domain, 'NoOffer')] = []
            else:
                first = self.results[0]
                for requested in user_action[user_act]:
                    slot = requested[0]
                    # Requested slots the entry does not have are not answered.
                    if slot in first:
                        system_action[(domain, 'Inform')].append([slot, first[slot]])
        elif intent == 'inform':
            if not self.results:
                system_action[(domain, 'NoOffer')] = []
            else:
                system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])
                choice = self.results[0]
                if domain in ["hotel"]:
                    system_action[(domain, 'Recommend')].append(['Name', choice['name']])
                # NOTE(review): this copies the recommended entry's values back
                # into the belief state. On later turns they become query
                # constraints and narrow the results to this entry; confirm
                # that is intended. Writing through domain_info updates the
                # state when the domain exists, and is a no-op otherwise.
                for slot in domain_info:
                    if choice.get(slot):
                        domain_info[slot] = choice[slot]