from collections import defaultdict
from copy import deepcopy

from convlab2.policy.policy import Policy
from convlab2.util.multiwoz.dbquery import Database
# The stock MultiWOZ mapping is replaced by the local REF_SYS_DA below,
# which covers the Cinema domain.
# from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
from convlab2.dialog_agent import PipelineAgent
from DST import DST

# Per domain: maps slot names used in dialogue acts to the key under which
# the value is looked up in database result rows. Slots missing from a
# domain's mapping are looked up under their own name.
REF_SYS_DA = {
    'Cinema': {
        'Type': 'type', 'Price': 'price', 'Stars': 'stars',
        'Name': 'name', 'Day': 'day', 'People': 'people', 'Movie': 'movie',
        'E-mail': 'e-mail', 'none': None
    },
}

# User-action slots whose presence triggers the booking rule (Rule 3).
BOOKING_SLOTS = ('Stay', 'Day', 'People')

# Domains in which a user Inform is answered with a Recommend of the first
# matching entity's name (Rule 2).
RECOMMEND_DOMAINS = ('Cinema', 'Hotel', 'Attraction', 'Police', 'Restaurant')


# Dialogue policy
class DP(Policy):
    """Rule-based dialogue policy backed by database lookups.

    * Rule 1 - ``Request``: answer each requested slot from the first
      matching database row, or emit ``NoOffer`` when nothing matches.
    * Rule 2 - ``Inform``: report the number of matches (``Choice``) and,
      for domains in ``RECOMMEND_DOMAINS``, recommend the first match by
      name; emit ``NoOffer`` when nothing matches.
    * Rule 3 - if the user mentioned any of ``BOOKING_SLOTS`` and the most
      recent database query returned rows, replace all other acts with a
      single ``Booking``/``Book`` act carrying the first row's ``Ref``.
    """

    def __init__(self):
        super().__init__()
        self.db = Database()
        # Rows returned by the most recent database query.
        self.results = []

    def predict(self, state):
        """Compute the system dialogue acts for the current turn.

        Args:
            state: dialogue-state dict. Reads ``state['user_action']`` (an
                iterable of ``(intent, domain, slot, value)`` acts) and
                ``state['belief_state']``; the result is also written to
                ``state['system_action']``.

        Returns:
            list of ``[intent, domain, slot, value]`` system acts.
        """
        self.results = []
        system_action = defaultdict(list)

        # Group the user's (slot, value) pairs by (domain, intent).
        user_action = defaultdict(list)
        for intent, domain, slot, value in state['user_action']:
            user_action[(domain, intent)].append((slot, value))

        for user_act in user_action:
            self.update_system_action(user_act, user_action, state, system_action)

        # Rule 3: booking-related slots switch the reply to a booking act.
        # NOTE: self.results holds the rows of the LAST query made above,
        # which is not necessarily the domain being booked.
        mentions_booking = any(
            slot in BOOKING_SLOTS
            for slots in user_action.values()
            for slot, _ in slots
        )
        if mentions_booking and self.results:
            system_action = {
                ('Booking', 'Book'): [['Ref', self.results[0].get('Ref', 'N/A')]]
            }

        system_acts = [
            [intent, domain, slot, value]
            for (domain, intent), slots in system_action.items()
            for slot, value in slots
        ]
        state['system_action'] = system_acts
        return system_acts

    def update_system_action(self, user_act, user_action, state, system_action):
        """Apply Rules 1 and 2 to one grouped user act.

        Args:
            user_act: ``(domain, intent)`` key into ``user_action``.
            user_action: mapping ``(domain, intent) -> [(slot, value), ...]``.
            state: dialogue state; ``state['belief_state'][domain.lower()]
                ['semi']`` supplies the non-empty database constraints.
            system_action: ``defaultdict(list)`` mapping
                ``(domain, intent) -> [[slot, value], ...]``; mutated in place.

        Side effects:
            Sets ``self.results`` to the rows returned by the database query
            for ``domain``, unless the domain has no belief-state entry.
        """
        domain, intent = user_act
        domain_state = state['belief_state'].get(domain.lower())
        if domain_state is None:
            # No belief state for this domain means nothing to query; this
            # used to raise KeyError and abort the whole turn.
            return

        constraints = [
            (slot, value)
            for slot, value in domain_state['semi'].items()
            if value != ''
        ]
        self.results = deepcopy(self.db.query(domain.lower(), constraints))

        # Rule 1: answer requested slots from the first matching entity.
        if intent == 'Request':
            if not self.results:
                system_action[(domain, 'NoOffer')] = []
                return
            first = self.results[0]
            # Domains without a REF_SYS_DA entry fall back to the raw slot
            # name instead of raising KeyError.
            slot_map = REF_SYS_DA.get(domain, {})
            for slot, _ in user_action[user_act]:
                kb_slot_name = slot_map.get(slot, slot)
                if kb_slot_name in first:
                    system_action[(domain, 'Inform')].append([slot, first[kb_slot_name]])

        # Rule 2: report how many entities match and recommend the first one.
        elif intent == 'Inform':
            if not self.results:
                system_action[(domain, 'NoOffer')] = []
                return
            system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])
            choice = self.results[0]
            # Only recommend when the row actually carries a name; a row
            # without 'name' used to raise KeyError.
            if domain in RECOMMEND_DOMAINS and 'name' in choice:
                system_action[(domain, 'Recommend')].append(['Name', choice['name']])


# Example run (written before the code was adapted to cinema ticket booking).
"""
dst = DST()
dp = DP()
agent = PipelineAgent(nlu=None, dst=dst, policy=dp, nlg=None, name='sys')
print(agent.response([['Inform', 'Cinema', 'Price', '15 zł'], ['Inform', 'Cinema', 'Movie', 'Batman']]))
"""