GOATS/DialoguePolicy.py

87 lines
3.4 KiB
Python
Raw Permalink Normal View History

from collections import defaultdict
import json
2024-06-08 12:55:12 +02:00
import random
import string
from copy import deepcopy
from convlab.policy.policy import Policy
# Hotel database JSON file read by DialoguePolicy.__init__. The path is
# relative, so it is resolved against the process's current working directory
# rather than this module's location.
db_path = "./hotels_data.json"
2024-06-08 12:55:12 +02:00
def generate_reference_number(length=8):
    """Return a random booking reference made of ``length`` characters.

    Each character is drawn independently with ``random.choice`` from the
    uppercase ASCII letters followed by the digits 0-9, so the result is
    reproducible under ``random.seed``. A non-positive ``length`` yields ''.
    """
    alphabet = string.ascii_uppercase + string.digits
    chars = []
    for _ in range(length):
        chars.append(random.choice(alphabet))
    return ''.join(chars)
class DialoguePolicy(Policy):
    """Rule-based system policy answering hotel acts from a local JSON database.

    ``predict`` groups the user's dialogue acts by ``(domain, intent)``,
    queries the hotel database with the non-empty ``info`` slots of the
    matching belief-state domain, and emits ``[intent, domain, slot, value]``
    system acts (Inform / Recommend / NoOffer, or a single Booking-Book act
    carrying a generated reference number).
    """

    # NOTE(review): not read anywhere in this class — confirm whether external
    # code uses it before removing.
    info_dict = None

    # Slot names that, when present in any user act, trigger a booking reply.
    _BOOKING_SLOTS = ('book stay', 'book day', 'book people')

    def __init__(self):
        """Initialise the base policy and load the database from ``db_path``."""
        Policy.__init__(self)
        self.db = self.load_database(db_path)
        # Results of the most recent database query. Initialised here so
        # update_system_action can be called before the first predict()
        # without raising AttributeError.
        self.results = []

    def load_database(self, db_path):
        """Load and return the parsed JSON document stored at ``db_path``."""
        with open(db_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def query(self, domain, constraints):
        """Return the database entries that satisfy every ``(slot, value)`` constraint.

        Args:
            domain: Domain name. Only ``'hotel'`` is backed by the database;
                any other domain yields an empty list.
            constraints: Iterable of ``(slot, value)`` pairs. An entry matches
                when ``entry.get(slot) == value`` holds for every pair (exact,
                case-sensitive comparison).

        Returns:
            A new list of the matching entries (the entry objects themselves,
            not copies), in database order.
        """
        if domain != 'hotel':
            return []
        # Materialise once: the pairs are re-scanned for every entry, and a
        # one-shot iterator would be exhausted after the first entry (making
        # every later entry match vacuously).
        constraints = list(constraints)
        # Assumes each entry is a dict-like record (``.get`` is used).
        return [entry for entry in self.db
                if all(entry.get(key) == value for key, value in constraints)]

    def predict(self, state):
        """Compute the system dialogue acts for the current turn.

        Args:
            state: Dialogue-state dict. Reads ``state['user_action']`` (an
                iterable of ``[intent, domain, slot, value]``) and
                ``state['belief_state']``; belief-state slots may be
                overwritten by ``update_system_action``, and
                ``state['system_action']`` is set to the returned list.

        Returns:
            A list of ``[intent, domain, slot, value]`` system acts.
        """
        self.results = []
        system_action = defaultdict(list)

        # Group the user's (slot, value) pairs by lower-cased (domain, intent).
        user_action = defaultdict(list)
        for intent, domain, slot, value in state['user_action']:
            user_action[(domain.lower(), intent.lower())].append((slot.lower(), value))

        for user_act in user_action:
            self.update_system_action(user_act, user_action, state, system_action)

        # A booking slot anywhere in the user's acts replaces the whole reply
        # with a booking confirmation, provided the most recent query made in
        # the loop above returned at least one entry.
        wants_booking = any(slot in self._BOOKING_SLOTS
                            for slots in user_action.values()
                            for slot, _ in slots)
        if wants_booking and self.results:
            system_action = {('Booking', 'Book'): [['Ref', generate_reference_number()]]}

        # NOTE(review): acts stored with an empty slot list (the NoOffer acts
        # set in update_system_action) produce no rows here, so they never
        # reach the output — confirm whether a placeholder slot/value should
        # be emitted instead.
        system_acts = [[intent, domain, slot, value]
                       for (domain, intent), slots in system_action.items()
                       for slot, value in slots]
        state['system_action'] = system_acts
        return system_acts

    def update_system_action(self, user_act, user_action, state, system_action):
        """Add the system's response to one grouped user act to ``system_action``.

        Args:
            user_act: ``(domain, intent)`` key into ``user_action``.
            user_action: Mapping from ``(domain, intent)`` to a list of
                ``(slot, value)`` pairs.
            state: Dialogue state. ``state['belief_state'][domain]['info']``
                supplies the query constraints and, for an ``inform`` act with
                results, is overwritten with the first result's non-empty
                values.
            system_action: ``defaultdict(list)`` keyed by ``(domain, intent)``;
                updated in place.

        Sets ``self.results`` to a deep copy of the query results when the
        domain is present in the belief state.
        """
        domain, intent = user_act
        belief_state = state['belief_state']
        # Without a belief state for this domain there is nothing to query;
        # self.results would still hold an earlier act's results (possibly for
        # another domain), and the inform branch would raise KeyError on the
        # write-back. self.results is left untouched because predict() reads
        # it afterwards for the booking decision.
        if domain not in belief_state:
            return

        info = belief_state[domain]['info']
        constraints = [(slot, value) for slot, value in info.items() if value != '']
        self.results = deepcopy(self.query(domain.lower(), constraints))

        if intent == 'request':
            if not self.results:
                system_action[(domain, 'NoOffer')] = []
                return
            top = self.results[0]
            for slot in user_action[user_act]:
                slot_name = slot[0]
                if slot_name in top:
                    system_action[(domain, 'Inform')].append([slot_name, top[slot_name]])
        elif intent == 'inform':
            if not self.results:
                system_action[(domain, 'NoOffer')] = []
                return
            system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])
            choice = self.results[0]
            if domain in ["hotel"]:
                system_action[(domain, 'Recommend')].append(['Name', choice['name']])
            # Fill the belief state from the chosen entry's non-empty values.
            for slot in info:
                if choice.get(slot):
                    info[slot] = choice[slot]