From 513f09e06b1cd0d743cf35b067050207af7d8d3c Mon Sep 17 00:00:00 2001
From: 464962
Date: Mon, 27 May 2024 21:50:25 +0200
Subject: [PATCH] move code from jupyter to python classes

---
 DialoguePolicy.py          | 112 ++++++++++++++++++++++++++++++++----
 DialogueStateTracker.py    |  97 ++++++++++++++++++++++++++++---
 Main.py                    |  28 +++------
 NaturalLanguageAnalyzer.py |   5 +-
 4 files changed, 211 insertions(+), 31 deletions(-)

diff --git a/DialoguePolicy.py b/DialoguePolicy.py
index 544b7fe..5983d03 100644
--- a/DialoguePolicy.py
+++ b/DialoguePolicy.py
@@ -1,8 +1,106 @@
-class DialoguePolicy:
+from collections import defaultdict
+import json
+from copy import deepcopy
+from convlab.policy.policy import Policy
 
-    def policy(self, state):
-        system_act = None
-        name = "James"
-        if state == "what name":
-            system_act = f"inform(name={name})"
-        return system_act
+db_path = './hotels_data.json'
+# User slots that signal a booking request.
+BOOKING_SLOTS = ('book stay', 'book day', 'book people')
+
+
+class DialoguePolicy(Policy):
+    """Rule-based system policy over the hotel entries in ``db_path``."""
+
+    def __init__(self):
+        Policy.__init__(self)
+        self.db = self.load_database(db_path)
+        # Matches of the latest query; set here so it exists before predict().
+        self.results = []
+
+    def load_database(self, db_path):
+        """Return the parsed JSON content of the file at ``db_path``."""
+        with open(db_path, 'r', encoding='utf-8') as f:
+            return json.load(f)
+
+    def query(self, domain, constraints):
+        """Return the entries matching every ``(slot, value)`` constraint.
+
+        Only the 'hotel' domain is searched; other domains give ``[]``.
+        """
+        if domain != 'hotel':
+            return []
+
+        return [
+            entry for entry in self.db
+            if all(entry.get(key) == value for key, value in constraints)
+        ]
+
+    def predict(self, state):
+        """Return the system acts for ``state['user_action']``.
+
+        The acts are ``[intent, domain, slot, value]`` lists and are also
+        stored in ``state['system_action']``.
+        """
+        self.results = []
+        system_action = defaultdict(list)
+        user_action = defaultdict(list)
+
+        for intent, domain, slot, value in state['user_action']:
+            user_action[(domain.lower(), intent.lower())].append((slot.lower(), value))
+
+        for user_act in user_action:
+            self.update_system_action(user_act, user_action, state, system_action)
+
+        # A booking slot plus at least one match replaces all other acts
+        # with a booking confirmation.
+        wants_booking = any(slot in BOOKING_SLOTS
+                            for slots in user_action.values()
+                            for slot, _ in slots)
+        if wants_booking and self.results:
+            system_action = {('Booking', 'Book'): [["Ref", self.results[0].get('Ref', 'N/A')]]}
+
+        # NOTE(review): acts with no slots (the 'NoOffer' entries) add no
+        # items here, so they are absent from the result -- confirm intended.
+        system_acts = [[intent, domain, slot, value]
+                       for (domain, intent), slots in system_action.items()
+                       for slot, value in slots]
+        state['system_action'] = system_acts
+        return system_acts
+
+    def update_system_action(self, user_act, user_action, state, system_action):
+        """Add the system acts that answer one ``(domain, intent)`` user act.
+
+        The non-empty 'info' slots of the domain's belief state are used as
+        query constraints; the matches are kept in ``self.results`` and
+        ``system_action`` is filled in place.
+        """
+        domain, intent = user_act
+        domain_state = state['belief_state'].get(domain)
+        if domain_state is None:
+            # Untracked domain: there is no belief state to query with.
+            return
+
+        constraints = [(slot, value) for slot, value in domain_state['info'].items() if value != '']
+        self.results = deepcopy(self.query(domain.lower(), constraints))
+
+        if intent == 'request':
+            if not self.results:
+                system_action[(domain, 'NoOffer')] = []
+            else:
+                for slot in user_action[user_act]:
+                    if slot[0] in self.results[0]:
+                        system_action[(domain, 'Inform')].append([slot[0], self.results[0].get(slot[0], 'unknown')])
+
+        elif intent == 'inform':
+            if not self.results:
+                system_action[(domain, 'NoOffer')] = []
+            else:
+                system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])
+                choice = self.results[0]
+
+                if domain in ["hotel"]:
+                    system_action[(domain, 'Recommend')].append(['Name', choice['name']])
+                    # Copy the recommended entry's values into the belief state.
+                    for slot in domain_state['info']:
+                        if choice.get(slot):
+                            domain_state['info'][slot] = choice[slot]
diff --git a/DialogueStateTracker.py b/DialogueStateTracker.py
index 1f15dbc..d95b8e9 100644
--- a/DialogueStateTracker.py
+++ b/DialogueStateTracker.py
@@ -1,7 +1,92 @@
-class DialogueStateTracker:
+import json
+from convlab.dst.dst import DST
 
-    def dst(self, user_act):
-        state = None
-        if user_act == "request(firstname)":
-            state = "what name"
-        return state
+
+def default_state():
+    """Return a fresh dialogue state with every hotel slot empty."""
+    return {
+        'belief_state': {
+            'hotel': {
+                'info': {
+                    'name': '',
+                    'area': '',
+                    'parking': '',
+                    'price range': '',
+                    'stars': '',
+                    'internet': '',
+                    'type': ''
+                },
+                'booking': {
+                    'book stay': '',
+                    'book day': '',
+                    'book people': ''
+                }
+            }
+        },
+        'request_state': {},
+        'history': [],
+        'user_action': [],
+        'system_action': [],
+        'terminated': False,
+        'booked': []
+    }
+
+
+class DialogueStateTracker(DST):
+    """Tracks the hotel belief state from the user's dialogue acts."""
+
+    def __init__(self):
+        DST.__init__(self)
+        self.state = default_state()
+        # Read as UTF-8, the encoding DialoguePolicy uses for the same file.
+        with open('./hotels_data.json', encoding='utf-8') as f:
+            self.value_dict = json.load(f)
+
+    def update(self, user_act=None):
+        """Apply the user acts to the state and return the state.
+
+        ``user_act`` is an iterable of ``(intent, domain, slot, value)``
+        items; ``None`` (the default) leaves the state unchanged.
+        """
+        for intent, domain, slot, value in user_act or ():
+            domain = domain.lower()
+            intent = intent.lower()
+            slot = slot.lower()
+
+            if domain not in self.state['belief_state']:
+                continue
+
+            if intent == 'inform':
+                if slot == 'none' or slot == '' or value == 'dontcare':
+                    continue
+
+                domain_state = self.state['belief_state'][domain]
+                # Booking slots live under 'booking', all others under 'info'.
+                for section in ('info', 'booking'):
+                    if slot in domain_state.get(section, {}):
+                        domain_state[section][slot] = self.normalize_value(
+                            self.value_dict, domain, slot, value)
+                        break
+
+            elif intent == 'request':
+                requested = self.state['request_state'].setdefault(domain, {})
+                requested.setdefault(slot, 0)
+
+        return self.state
+
+    def normalize_value(self, value_dict, domain, slot, value):
+        """Return the canonical form of ``value`` if ``value_dict`` has one.
+
+        The lookup key is ``value`` lower-cased and stripped; when no
+        mapping applies ``value`` is returned unchanged.
+        """
+        normalized_value = value.lower().strip()
+        if domain in value_dict and slot in value_dict[domain]:
+            possible_values = value_dict[domain][slot]
+            if isinstance(possible_values, dict) and normalized_value in possible_values:
+                return possible_values[normalized_value]
+        return value
+
+    def init_session(self):
+        """Reset the tracker to the default state."""
+        self.state = default_state()
diff --git a/Main.py b/Main.py
index dd61ebc..4717208 100644
--- a/Main.py
+++ b/Main.py
@@ -1,24 +1,18 @@
 from NaturalLanguageAnalyzer import NaturalLanguageAnalyzer
-from DialogueStateTracker import DialogueStateTracker
 from DialoguePolicy import DialoguePolicy
-from NaturalLanguageGeneration import NaturalLanguageGeneration
+from DialogueStateTracker import DialogueStateTracker
+from convlab.dialog_agent import PipelineAgent
+from convlab.nlg.template.multiwoz import TemplateNLG
 
 
 
 if __name__ == "__main__":
-    text = "chciałbym zarezerwować pokój z balkonem 1 stycznia w Warszawie"
-    nla = NaturalLanguageAnalyzer()
-    user_act = nla.process(text)
-    print(user_act)
+    text = "chciałbym zarezerwować drogi hotel z parkingiem 1 stycznia w Warszawie w centrum"
+    nlu = NaturalLanguageAnalyzer()
+    dst = DialogueStateTracker()
+    policy = DialoguePolicy()
+    nlg = TemplateNLG(is_user=False)
 
-    # dst = DialogueStateTracker()
-    # state = dst.dst(user_act)
-    # print(state)
-    #
-    # dp = DialoguePolicy()
-    # system_act = dp.policy(state)
-    # print(system_act)
-    #
-    # nlg = NaturalLanguageGeneration()
-    # response = nlg.nlg(system_act)
-    # print(response)
+    agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')
+    response = agent.response(text)
+    print(response)
diff --git a/NaturalLanguageAnalyzer.py b/NaturalLanguageAnalyzer.py
index 8ed6a5d..c58a526 100644
--- a/NaturalLanguageAnalyzer.py
+++ b/NaturalLanguageAnalyzer.py
@@ -14,7 +14,7 @@ def translate_text(text, target_language='en'):
 
 
 class NaturalLanguageAnalyzer:
-    def process(self, text):
+    def predict(self, text, context=None):
         # Inicjalizacja modelu NLU
         model_name = "ConvLab/t5-small-nlu-multiwoz21"
         nlu_model = T5NLU(speaker='user', context_window_size=0, model_name_or_path=model_name)
@@ -26,3 +26,6 @@ class NaturalLanguageAnalyzer:
         nlu_output = nlu_model.predict(translated_input)
 
         return nlu_output
+
+    def init_session(self):
+        """Session hook; currently a no-op."""