systemy_dialogowe/DST_DP.ipynb
Wojciech Lidwin 2d0ff0561e DST and DP
2023-05-16 23:31:26 +02:00

10 KiB

import json

class Rules_DST:
    """Rule-based dialogue state tracker.

    Keeps a single ``state`` dict, initially loaded from a JSON file, and
    updates it from user dialogue acts of the form
    ``[intent, domain, slot, value]``.
    """

    def __init__(self, state_path='data.json'):
        """Load the initial dialogue state.

        Args:
            state_path: JSON file holding the initial (empty) state. It is
                re-read whenever the user ends the conversation.
        """
        self.state_path = state_path
        self.state = self._load_initial_state()

    def _load_initial_state(self):
        """Return a fresh copy of the initial state, closing the file afterwards."""
        # utf-8-sig accepts the file with or without a BOM.
        with open(self.state_path, encoding='utf-8-sig') as state_file:
            return json.load(state_file)

    def update_user(self, user_acts=None):
        """Update the dialogue state from a batch of user acts.

        Each act is ``[intent, domain, slot, value]``. Intent, domain and slot
        are lower-cased before use.

        * ``start_conversation`` acts are skipped.
        * ``end_conversation`` resets the state from the initial-state file.
        * Acts for domains not present in ``belief_state`` are skipped.
        * Any intent containing ``'inform'`` fills a known slot of the
          domain's belief state. A leading ``"<domain>/"`` on the slot is
          stripped first. The slot name ``'inform'`` itself is skipped.
        * ``request`` records the slot in ``request_state``: 0 on the first
          request, the act's value on subsequent ones.

        Every act that is not skipped is appended to ``state['user_action']``.

        Args:
            user_acts: iterable of acts; ``None`` is treated as no acts.

        Returns:
            The (mutated) ``self.state`` dict.
        """
        for intent, domain, slot, value in user_acts or []:
            domain = domain.lower()
            intent = intent.lower()
            slot = slot.lower()

            if intent == 'start_conversation':
                continue
            elif intent == 'end_conversation':
                # Reset to the initial state; the act is still logged below.
                self.state = self._load_initial_state()
            elif domain not in self.state['belief_state']:
                # Acts for unknown domains are ignored entirely (not logged).
                continue
            elif 'inform' in intent:
                if slot == 'inform':
                    continue
                # Slots may arrive qualified as "<domain>/<slot>".
                prefix = domain + '/'
                if slot.startswith(prefix):
                    slot = slot[len(prefix):]
                domain_slots = self.state['belief_state'][domain]
                if slot in domain_slots:
                    domain_slots[slot] = value
            elif intent == 'request':
                requested = self.state['request_state'].setdefault(domain, {})
                # First request of a slot marks it with 0; a repeated request stores the value.
                if slot not in requested:
                    requested[slot] = 0
                else:
                    requested[slot] = value

            self.state['user_action'].append([intent, domain, slot, value])
        return self.state
# Demo: create a tracker from data.json and inspect its initial state.
# Bare literals following an expression are the recorded notebook outputs.
dst = Rules_DST()
dst.state
{'user_action': [],
 'system_action': [],
 'belief_state': {'food': {'name': '',
   'type': '',
   'price range': '',
   'size': '',
   'ingredients': ''},
  'drink': {'name': '', 'price range': '', 'size': ''},
  'sauce': {'name': '', 'price range': '', 'size': ''},
  'order': {'type': '',
   'price range': '',
   'restaurant_name': '',
   'area': '',
   'book time': '',
   'book day': ''},
  'booking': {'restaurant_name': '',
   'area': '',
   'book time': '',
   'book day': '',
   'book people': ''},
  'payment': {'type': '', 'amount': '', 'vat': ''}},
 'request_state': {},
 'terminated': False,
 'history': []}
dst.state['user_action']
[]
# The start_conversation act is skipped (not logged); the drink size is tracked.
dst.update_user([['start_conversation',"","",""], ['inform', 'drink', 'size', 'duża']])
# The food belief state is untouched by the drink inform.
dst.state['belief_state']['food']
{'name': '', 'type': '', 'price range': '', 'size': '', 'ingredients': ''}
dst.state['user_action']
[['inform', 'drink', 'size', 'duża']]
# A first request for a slot is recorded with the value 0.
dst.update_user([['request', 'drink', 'price range', '?']])
dst.state['request_state']
{'drink': {'price range': 0}}
dst.update_user([['inform', 'food', 'type', 'pizza'], ['inform', 'food', 'size', 'duża']])
dst.state['belief_state']['food']
{'name': '',
 'type': 'pizza',
 'price range': '',
 'size': 'duża',
 'ingredients': ''}
from collections import defaultdict
import jmespath

class DP:
    """Rule-based dialogue policy.

    Turns the user acts recorded in a dialogue state into system acts,
    looking candidate items up in a JSON database with JMESPath queries of
    the form ``database.<domain>[?...]``.
    """

    # Domains for which the best-matching item is recommended by name.
    RECOMMEND_DOMAINS = ("food", "drink", "police", "sauce", "order", "booking", "payment")

    def __init__(self, db_path='database.json'):
        """Load the item database.

        Args:
            db_path: JSON file read with ``utf-8-sig``. The queries assume a
                top-level ``"database"`` key mapping domain names to item
                lists — confirm against the data file.
        """
        with open(db_path, encoding='utf-8-sig') as json_file:
            self.db = json.load(json_file)
        # Items matched by the most recent "inform" search; read by "request" handling.
        self.results = []

    def predict(self, state):
        """Compute system acts for the user acts stored in ``state``.

        User acts are grouped by ``(domain, intent)`` in first-seen order, and
        each group is handled by :meth:`update_system_action`. The list of
        per-group return values is stored in ``state['system_action']``.

        Args:
            state: dialogue state with ``'user_action'`` (list of
                ``[intent, domain, slot, value]``) and ``'belief_state'``.

        Returns:
            The accumulated ``defaultdict(list)`` mapping ``(domain, act)`` to
            act arguments. Returns ``None`` when the last group is
            ``end_conversation``, and an empty mapping when there are no user acts.
        """
        self.results = []
        system_action = defaultdict(list)
        user_action = defaultdict(list)
        for intent, domain, slot, value in state['user_action']:
            user_action[(domain, intent)].append((slot, value))

        system_acts = [
            self.update_system_action(user_act, user_action, state, system_action)
            for user_act in user_action
        ]
        state['system_action'] = system_acts
        if not system_acts:
            return system_action
        # Every non-None entry is the same shared system_action mapping.
        return system_acts[-1]

    @staticmethod
    def _build_query(domain, constraints):
        """Build a JMESPath query selecting ``database.<domain>`` items matching all constraints.

        Field names are emitted as quoted identifiers and values as JSON
        literals. This lets slot names with spaces (e.g. ``'price range'``)
        and values containing quotes or brackets round-trip correctly. With no
        constraints the filter keeps every item.
        """
        def literal(value):
            # Backticks delimit a JMESPath JSON literal, so escape any inside it.
            return "`" + json.dumps(value).replace("`", "\\`") + "`"

        clauses = [
            f"contains({json.dumps(slot)}, {literal(value)})"
            for slot, value in constraints
        ]
        condition = " && ".join(clauses) or "`true`"
        return f"database.{json.dumps(domain)}[?{condition}]"

    def update_system_action(self, user_act, user_action, state, system_action):
        """Apply the policy rules to one ``(domain, intent)`` group of user acts.

        Args:
            user_act: ``(domain, intent)`` key of the group being handled.
            user_action: mapping from such keys to lists of ``(slot, value)``.
            state: dialogue state. ``state['belief_state'][domain]`` supplies
                the search constraints for ``inform``.
            system_action: ``defaultdict(list)`` receiving the system acts
                (mutated in place).

        Returns:
            ``None`` for ``end_conversation``; otherwise ``system_action``.
        """
        domain, intent = user_act

        # Rule 3: the conversation is over.
        if intent == 'end_conversation':
            return None

        # Rule 1: answer requested slots from the best match of the last search.
        if intent == 'request':
            # NOTE(review): self.results comes from the most recent inform search,
            # which may belong to a different domain than this request.
            if not self.results:
                system_action[(domain, 'NoOffer')] = []
            else:
                best = self.results[0]
                for slot, _value in user_action[user_act]:
                    if slot in best:
                        system_action[(domain, 'Inform')].append([slot, best[slot]])

        # Rule 2: search the database using every filled slot of the domain.
        elif intent == 'inform':
            constraints = [
                (slot, value)
                for slot, value in state['belief_state'][domain].items()
                if value != ''
            ]
            query = self._build_query(domain, constraints)
            # jmespath returns None when the domain is missing from the database.
            self.results = list(jmespath.search(query, self.db) or [])
            if not self.results:
                system_action[(domain, 'NoOffer')] = []
            else:
                system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])
                if domain in self.RECOMMEND_DOMAINS:
                    system_action[(domain, 'Recommend')].append(['Name', self.results[0]['name']])
        return system_action
                    
# Demo: run the dialogue policy on the tracker state built by the cells above.
dp= DP()
dp.predict(dst.state)
# Recorded notebook output of the call above (system acts keyed by (domain, act)):
defaultdict(list,
            {('drink', 'Inform'): [['Choice', '1'],
              ['price range', 'średnia']],
             ('drink', 'Recommend'): [['Name', 'lemoniada']],
             ('food', 'Inform'): [['Choice', '4']],
             ('food', 'Recommend'): [['Name', 'pizza margherita']]})