48 lines
1.7 KiB
Python
48 lines
1.7 KiB
Python
import json
|
|
from convlab2.dst.dst import DST
|
|
from convlab2.dst.rule.multiwoz.dst_util import normalize_value
|
|
|
|
|
|
class Rules_DST(DST):  # Dialogue State Tracker
    """Rule-based dialogue state tracker.

    Module responsible for tracking the state of the dialogue. It stores the
    information obtained from the user over the course of the conversation.

    Input: the user's dialogue act (a frame).
    Output: the dialogue-state representation (a frame).
    """

    def __init__(self):
        super().__init__()
        # Both JSON files are resolved relative to the current working
        # directory; opened with ``with`` so the handles are closed promptly.
        with open('default_state.json', encoding='utf-8') as state_file:
            self.state = json.load(state_file)
        with open('value_dict.json', encoding='utf-8') as value_file:
            self.value_dict = json.load(value_file)

    def update(self, user_act=None):
        """Update the belief state from one user act and return the state.

        Args:
            user_act: dict with an ``"act"`` string whose first ``/``-separated
                component is used as the domain, and a ``"slots"`` sequence
                of ``(slot, value)`` pairs (assumed from the NLU output that
                feeds this tracker -- confirm against the caller). ``None``
                leaves the state unchanged.

        Returns:
            ``self.state`` after the update. After an ``end_conversation``
            act this is an empty dict.
        """
        if user_act is None:
            return self.state

        slots = user_act["slots"]
        intent = user_act["act"]
        domain = intent.split('/')[0]

        # Account/identity acts are not tracked in the belief state.
        if domain in ('password', 'name', 'email', 'enter_email', 'enter_name'):
            return self.state

        if 'appointment' in intent:
            for full_slot in slots:
                slot = full_slot[0]
                value = full_slot[1]

                # Map the NLU slot name to the belief-state key; unknown
                # names fall through unchanged.
                k = self.value_dict[domain.lower()].get(slot, slot)
                if k is None:
                    # Slot explicitly mapped to None: skip it, keep the rest.
                    continue

                domain_dic = self.state['belief_state'][domain]
                if k in domain_dic['semi']:
                    # 'semi' values are normalized against the value dict.
                    domain_dic['semi'][k] = normalize_value(
                        self.value_dict, domain, k, value)
                elif k in domain_dic['book']:
                    domain_dic['book'][k] = value
                elif k.lower() in domain_dic['book']:
                    domain_dic['book'][k.lower()] = value
        elif intent == 'end_conversation':
            # NOTE(review): after this reset any later 'appointment' act will
            # raise KeyError on 'belief_state'; consider reloading
            # default_state.json instead if the tracker is reused.
            self.state = {}

        return self.state