2024-05-27 21:50:25 +02:00
|
|
|
from collections import defaultdict
|
|
|
|
import json
|
2024-06-08 12:55:12 +02:00
|
|
|
import random
|
|
|
|
import string
|
2024-05-27 21:50:25 +02:00
|
|
|
from copy import deepcopy
|
|
|
|
from convlab.policy.policy import Policy
|
|
|
|
|
|
|
|
# Location of the JSON hotel database; being relative, it is resolved
# against the process's current working directory when opened.
db_path = "./hotels_data.json"
|
|
|
|
|
|
|
|
|
2024-06-08 12:55:12 +02:00
|
|
|
def generate_reference_number(length=8):
|
|
|
|
letters_and_digits = string.ascii_uppercase + string.digits
|
|
|
|
reference_number = ''.join(random.choice(letters_and_digits) for _ in range(length))
|
|
|
|
return reference_number
|
|
|
|
|
2024-05-27 21:50:25 +02:00
|
|
|
class DialoguePolicy(Policy):
    """Rule-based system policy that answers hotel queries from a JSON database.

    For each turn the user's acts are grouped by ``(domain, intent)`` and each
    group is answered from a database query built on the belief state. If the
    user mentions any booking slot and the latest query found hotels, the whole
    reply is replaced by a booking confirmation carrying a random reference.
    """

    # Not read anywhere in this class; kept in case external code uses it.
    info_dict = None

    # Belief-state values that express "no preference"; they must not be
    # used as database filters, otherwise no entry could ever match.
    _DONT_CARE_VALUES = frozenset({
        'dontcare', 'dont care', "don't care", "do n't care", 'do not care', 'not mentioned',
    })

    # User slots whose presence turns the reply into a booking confirmation.
    _BOOKING_SLOTS = ('book stay', 'book day', 'book people')

    def __init__(self):
        super().__init__()
        # Results of the most recent query; defined up front so the attribute
        # exists even before the first call to predict().
        self.results = []
        self.db = self.load_database(db_path)

    def load_database(self, db_path):
        """Load and return the JSON database stored at ``db_path`` (UTF-8)."""
        with open(db_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @classmethod
    def _is_dont_care(cls, value):
        """Return True if ``value`` is a string meaning "no preference"."""
        return isinstance(value, str) and value.strip().lower() in cls._DONT_CARE_VALUES

    @staticmethod
    def _values_match(db_value, wanted):
        """Compare a DB value to a constraint value; strings ignore case/whitespace."""
        if isinstance(db_value, str) and isinstance(wanted, str):
            return db_value.strip().lower() == wanted.strip().lower()
        return db_value == wanted

    def query(self, domain, constraints):
        """Return the database entries of ``domain`` that satisfy every constraint.

        Only the ``'hotel'`` domain is backed by the database; any other domain
        yields ``[]``. Constraints whose value means "don't care" are ignored,
        and string values are compared case-insensitively after stripping.

        Args:
            domain: Domain name (callers pass it lower-cased).
            constraints: Iterable of ``(slot, value)`` pairs.

        Returns:
            List of matching entries (the same objects held in ``self.db``).
        """
        if domain != 'hotel':
            return []

        # Materialize once: the original re-iterated ``constraints`` per entry.
        active = [
            (slot, value) for slot, value in constraints
            if not self._is_dont_care(value)
        ]
        return [
            entry for entry in self.db
            if all(self._values_match(entry.get(slot), value) for slot, value in active)
        ]

    def predict(self, state):
        """Choose the system dialogue acts for the current turn.

        Args:
            state: Dialogue state dict. ``state['user_action']`` is read as an
                iterable of ``[intent, domain, slot, value]`` entries and
                ``state['belief_state']`` is read (and may be updated, see
                ``update_system_action``).

        Returns:
            List of ``[intent, domain, slot, value]`` system acts, which is
            also stored in ``state['system_action']``.
        """
        self.results = []
        system_action = defaultdict(list)
        user_action = defaultdict(list)

        # Group the user's acts by (domain, intent), lower-casing the names.
        for intent, domain, slot, value in state['user_action']:
            user_action[(domain.lower(), intent.lower())].append((slot.lower(), value))

        for user_act in user_action:
            self.update_system_action(user_act, user_action, state, system_action)

        # A booking slot anywhere in the user's acts replaces the whole reply
        # with a confirmation, but only if the most recent query this turn
        # (self.results) found at least one entry.
        wants_booking = any(
            slot in self._BOOKING_SLOTS
            for slots in user_action.values()
            for slot, _ in slots
        )
        if wants_booking and self.results:
            reference_number = generate_reference_number()
            system_action = {('Booking', 'Book'): [['Ref', reference_number]]}

        # Flatten {(domain, intent): [[slot, value], ...]} into act quadruples.
        system_acts = [
            [intent, domain, slot, value]
            for (domain, intent), slots in system_action.items()
            for slot, value in slots
        ]
        state['system_action'] = system_acts
        return system_acts

    def update_system_action(self, user_act, user_action, state, system_action):
        """Add the system's response to one grouped user act to ``system_action``.

        Args:
            user_act: ``(domain, intent)`` key into ``user_action``.
            user_action: Mapping of ``(domain, intent)`` to ``(slot, value)`` lists.
            state: Dialogue state. ``state['belief_state'][domain]['info']``
                supplies the query constraints and, for a successful hotel
                ``inform``, is overwritten with the recommended entry's values.
            system_action: ``defaultdict(list)`` mapping ``(domain, intent)``
                to ``[slot, value]`` pairs; updated in place.
        """
        domain, intent = user_act

        # self.results is refreshed only when the domain has a belief-state
        # entry; otherwise the previous query's results are reused.
        if domain in state['belief_state']:
            constraints = [
                (slot, value)
                for slot, value in state['belief_state'][domain]['info'].items()
                if value != ''
            ]
            self.results = deepcopy(self.query(domain.lower(), constraints))

        if intent not in ('request', 'inform'):
            return

        if not self.results:
            # An empty slot list would yield no act at all when predict()
            # flattens system_action, so NoOffer carries the slot-less
            # placeholder pair instead.
            system_action[(domain, 'NoOffer')] = [['none', 'none']]
            return

        top = self.results[0]

        if intent == 'request':
            # Answer each requested slot the top entry actually has.
            for slot, _ in user_action[user_act]:
                if slot in top:
                    system_action[(domain, 'Inform')].append([slot, top[slot]])
            return

        # intent == 'inform': report how many entries match, then recommend one.
        system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])
        if domain == 'hotel':
            if 'name' in top:
                system_action[(domain, 'Recommend')].append(['Name', top['name']])
            # Copy the recommended entry's values into the belief state so
            # later queries in the dialogue are constrained to this choice.
            info = state['belief_state'][domain]['info']
            for slot in info:
                if top.get(slot):
                    info[slot] = top[slot]
|