# GOATS/DialogueStateTracker.py
#
# 77 lines
# 2.4 KiB
# Python
# Raw Normal View History

import json
from convlab.dst.dst import DST
# 2024-04-21 10:09:03 +02:00
def default_state():
    """Build a fresh, empty dialogue state for the hotel domain.

    A brand-new nested structure is created on every call, so the caller
    may mutate the result freely without affecting other states.

    Returns:
        dict: A state with the keys ``'belief_state'``, ``'request_state'``,
        ``'history'``, ``'user_action'``, ``'system_action'``,
        ``'terminated'`` and ``'booked'``. Every hotel ``'info'`` and
        ``'booking'`` slot starts as the empty string.
    """
    return {
        'belief_state': {
            'hotel': {
                # Search constraints on the hotel itself.
                'info': {
                    'name': '',
                    'area': '',
                    'parking': '',
                    'price range': '',
                    'stars': '',
                    'internet': '',
                    'type': '',
                },
                # Details of the reservation.
                'booking': {
                    'book stay': '',
                    'book day': '',
                    'book people': '',
                },
            },
        },
        'request_state': {},
        'history': [],
        'user_action': [],
        'system_action': [],
        'terminated': False,
        'booked': [],
    }
class DialogueStateTracker(DST):
    """Rule-based dialogue state tracker for the hotel domain.

    The tracker keeps ``self.state`` (shaped as returned by
    ``default_state()``) and folds user dialogue acts into its
    ``'belief_state'`` and ``'request_state'`` entries.
    """

    def __init__(self):
        """Create a tracker with an empty state and load the value dictionary.

        Raises:
            OSError: If ``./hotels_data.json`` cannot be opened.
            json.JSONDecodeError: If the file does not contain valid JSON.
        """
        super().__init__()
        self.state = default_state()
        # NOTE: the path is relative, so it is resolved against the current
        # working directory of the process, not against this module.
        # JSON is decoded as UTF-8 explicitly rather than with the
        # platform-dependent default encoding.
        with open('./hotels_data.json', encoding='utf-8') as f:
            self.value_dict = json.load(f)

    def update(self, user_act=None):
        """Fold a list of user dialogue acts into the tracked state.

        Args:
            user_act: Iterable of ``(intent, domain, slot, value)`` items.
                Intent, domain and slot are compared case-insensitively.
                ``None`` (the default) is treated as "no acts".

        Returns:
            dict: ``self.state`` after the update (the same object, mutated
            in place).
        """
        # The default argument is None; iterate an empty tuple in that case
        # instead of failing with "NoneType is not iterable".
        for intent, domain, slot, value in user_act or ():
            domain = domain.lower()
            intent = intent.lower()
            slot = slot.lower()

            # Acts about domains that are not tracked are ignored.
            belief = self.state['belief_state'].get(domain)
            if belief is None:
                continue

            if intent == 'inform':
                # Placeholder slots and "don't care" values carry no
                # constraint, so they leave the belief state untouched.
                if slot in ('none', '') or value == 'dontcare':
                    continue
                nvalue = self.normalize_value(
                    self.value_dict, domain, slot, value)
                if slot in belief['info']:
                    belief['info'][slot] = nvalue
                elif slot in belief.get('booking', {}):
                    # Booking slots ('book stay', 'book day', ...) live in
                    # their own sub-dict; without this branch an inform act
                    # for them would be silently dropped.
                    belief['booking'][slot] = nvalue
            elif intent == 'request':
                # Register the requested slot once; an existing entry keeps
                # its current value.
                requested = self.state['request_state'].setdefault(domain, {})
                requested.setdefault(slot, 0)
        return self.state

    def normalize_value(self, value_dict, domain, slot, value):
        """Map a raw slot value to its canonical form, if one is known.

        Args:
            value_dict: Mapping ``domain -> slot -> possible values``.
            domain: Lower-cased domain name.
            slot: Lower-cased slot name.
            value: Raw value taken from the dialogue act.

        Returns:
            The entry of ``value_dict[domain][slot]`` keyed by the
            lower-cased, stripped ``value`` when that entry is a dict
            containing the key; otherwise ``value`` unchanged.
        """
        if not isinstance(value, str):
            # Non-string values (e.g. numbers) cannot be case-folded;
            # hand them back untouched instead of raising AttributeError.
            return value
        normalized_value = value.lower().strip()
        if domain in value_dict and slot in value_dict[domain]:
            possible_values = value_dict[domain][slot]
            if (isinstance(possible_values, dict)
                    and normalized_value in possible_values):
                return possible_values[normalized_value]
        return value

    def init_session(self):
        """Reset ``self.state`` to a fresh default state."""
        self.state = default_state()