2024-05-27 21:50:25 +02:00
|
|
|
import json
|
|
|
|
from convlab.dst.dst import DST
|
2024-04-21 10:09:03 +02:00
|
|
|
|
2024-05-27 21:50:25 +02:00
|
|
|
|
|
|
|
def default_state():
    """Build a brand-new, empty dialogue state for the hotel domain.

    Every call constructs new dictionaries and lists, so a caller may mutate
    the returned state freely without affecting any other state object.

    Returns:
        dict: A state with keys ``belief_state`` (hotel ``info`` and
        ``booking`` slots, all initialised to ``''``), ``request_state``
        (empty dict), ``history``, ``user_action``, ``system_action`` and
        ``booked`` (empty lists), and ``terminated`` (``False``).
    """
    # All slot values start as the empty string; strings are immutable, so
    # sharing the single '' default across keys via fromkeys is safe.
    hotel_info = dict.fromkeys(
        ('name', 'area', 'parking', 'price range', 'stars', 'internet', 'type'),
        '',
    )
    hotel_booking = dict.fromkeys(('book stay', 'book day', 'book people'), '')

    return dict(
        belief_state={'hotel': {'info': hotel_info, 'booking': hotel_booking}},
        request_state={},
        history=[],
        user_action=[],
        system_action=[],
        terminated=False,
        booked=[],
    )
|
|
|
|
|
|
|
|
|
|
|
|
class DialogueStateTracker(DST):
    """Rule-based dialogue state tracker for the hotel domain.

    Holds a state dictionary created by ``default_state()`` and applies user
    dialogue acts to it: ``inform`` acts fill belief-state slots, ``request``
    acts register requested slots in ``request_state``.
    """

    def __init__(self, value_dict_path='./hotels_data.json'):
        """Create a tracker with an empty state and load the value dictionary.

        Args:
            value_dict_path: Path of the JSON file whose content is passed to
                ``normalize_value`` as ``value_dict``. Defaults to
                ``'./hotels_data.json'``, which is resolved against the
                current working directory.
        """
        super().__init__()
        self.state = default_state()
        # JSON text is UTF-8 by specification; don't depend on the locale's
        # default encoding (which differs on e.g. Windows).
        with open(value_dict_path, encoding='utf-8') as f:
            self.value_dict = json.load(f)

    def update(self, user_act=None):
        """Apply user dialogue acts to the tracked state in place.

        Args:
            user_act: Iterable of ``(intent, domain, slot, value)`` acts.
                Intent, domain and slot are compared case-insensitively.
                ``None`` (the default) is treated as "no acts".

        Returns:
            dict: ``self.state`` after the update.
        """
        if user_act is None:
            # Previously iterating the default raised TypeError.
            return self.state

        for intent, domain, slot, value in user_act:
            intent = intent.lower()
            domain = domain.lower()
            slot = slot.lower()

            # Acts for domains without a belief-state entry are ignored.
            if domain not in self.state['belief_state']:
                continue

            if intent == 'inform':
                if slot in ('none', '') or value == 'dontcare':
                    continue
                domain_state = self.state['belief_state'][domain]
                # Search-constraint slots live under 'info'; booking slots
                # (e.g. 'book day') live under 'booking'. Previously only
                # 'info' was checked, so booking slots were silently dropped.
                for section in ('info', 'booking'):
                    section_slots = domain_state.get(section, {})
                    if slot in section_slots:
                        section_slots[slot] = self.normalize_value(
                            self.value_dict, domain, slot, value)
                        break

            elif intent == 'request':
                # Register the requested slot with 0 unless already present.
                requested = self.state['request_state'].setdefault(domain, {})
                requested.setdefault(slot, 0)

        return self.state

    def normalize_value(self, value_dict, domain, slot, value):
        """Map a raw slot value to its canonical form when a mapping exists.

        The lookup key is ``value`` lower-cased and stripped. If
        ``value_dict[domain][slot]`` is a dict containing that key, the mapped
        value is returned; otherwise the ORIGINAL ``value`` is returned
        unchanged (not the lower-cased/stripped key).

        Args:
            value_dict: Nested mapping, presumably loaded from the tracker's
                JSON file; only ``value_dict[domain][slot]`` entries that are
                dicts are used for mapping (schema not visible here).
            domain: Lower-cased domain name, e.g. ``'hotel'``.
            slot: Lower-cased slot name.
            value: Raw slot value from the user act.

        Returns:
            The canonical value if found, else ``value`` as given.
        """
        # Non-string values (e.g. an integer party size) have no surface form
        # to look up; pass them through instead of raising AttributeError.
        if not isinstance(value, str):
            return value

        if domain in value_dict and slot in value_dict[domain]:
            possible_values = value_dict[domain][slot]
            if isinstance(possible_values, dict):
                return possible_values.get(value.lower().strip(), value)
        return value

    def init_session(self):
        """Reset the tracker to a fresh, empty state for a new dialogue."""
        self.state = default_state()
|