From d1dbd09a15e35a923298b38f1d7accb0c274e648 Mon Sep 17 00:00:00 2001
From: Bartosz
Date: Sat, 29 May 2021 11:22:47 +0200
Subject: [PATCH] Policy and DST

---
 .gitignore  |   4 +-
 Modules.py  | 263 +++++++++++++++++++++++++++++++++++++++++++---------
 evaluate.py |  14 +--
 3 files changed, 229 insertions(+), 52 deletions(-)

diff --git a/.gitignore b/.gitignore
index ed8ebf5..3724127 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,3 @@
-__pycache__
\ No newline at end of file
+__pycache__
+*.log
+*.ipynb
\ No newline at end of file
diff --git a/Modules.py b/Modules.py
index 47f0b4d..fa55fa7 100644
--- a/Modules.py
+++ b/Modules.py
@@ -1,19 +1,29 @@
+from convlab2.dst.dst import DST
+from convlab2.dst.rule.multiwoz.dst_util import normalize_value
+from collections import defaultdict
+from convlab2.policy.policy import Policy
+from convlab2.util.multiwoz.dbquery import Database
+import copy
+from copy import deepcopy
+import json
 import os
 import jsgf
 
-#Natural Language Understanding
+# Natural Language Understanding
 class NLU:
-
     def __init__(self):
-        self.grammars = [jsgf.parse_grammar_file(f'JSGFs/{file_name}') for file_name in os.listdir('JSGFs')]
+        self.grammars = [
+            jsgf.parse_grammar_file(f"JSGFs/{file_name}")
+            for file_name in os.listdir("JSGFs")
+        ]
 
     def get_dialog_act(self, rule):
         slots = []
         self.get_slots(rule.expansion, slots)
-        return {'act': rule.grammar.name, 'slots': slots}
+        return {"act": rule.grammar.name, "slots": slots}
 
     def get_slots(self, expansion, slots):
-        if expansion.tag != '':
+        if expansion.tag != "":
             slots.append((expansion.tag, expansion.current_match))
             return
 
@@ -24,51 +34,225 @@ class NLU:
             self.get_slots(expansion.referenced_rule.expansion, slots)
 
     def match(self, utterance):
-        list_of_illegal_character = [',', '.', "'", '?', '!', ':', '-', '/']
+        list_of_illegal_character = [",", ".", "'", "?", "!", ":", "-", "/"]
         for illegal_character in list_of_illegal_character[:-2]:
-            utterance = utterance.replace(f'{illegal_character}','')
+            utterance = utterance.replace(f"{illegal_character}", "")
         for illegal_character in list_of_illegal_character[-2:]:
-            utterance = utterance.replace(f'{illegal_character}',' ')
+            utterance = utterance.replace(f"{illegal_character}", " ")
 
         for grammar in self.grammars:
            matched = grammar.find_matching_rules(utterance)
 
            if matched:
                return self.get_dialog_act(matched[0])
 
-        return {'act': 'null', 'slots': []}
+        return {"act": "null", "slots": []}
 
-#Dialogue policy
-class DP:
-    #Module decide what act takes next
-    def __init__(self, acts, arguments):
-        self.acts = acts
-        self.arguments = arguments
 
-    def tacticChoice(self, frame_list):
-        actVector = [0, 0]
-        return actVector
+class DP(Policy):
+    def __init__(self):
+        Policy.__init__(self)
+        self.db = Database()
 
-#Dialogue State Tracker
-class DST:
-    #Contain informations about state of the dialogue and data taken from user
-    def __init__(self, acts, arguments):
-        self.acts = acts
-        self.arguments = arguments
-        self.frameList= []
+    def predict(self, state):
+        self.results = []
+        system_action = defaultdict(list)
+        user_action = defaultdict(list)
 
-    #store new act into frame
-    def store(self, frame):
-        self.frameList.append(frame)
+        for intent, domain, slot, value in state["user_action"]:
+            user_action[(domain, intent)].append((slot, value))
 
-    def transfer(self):
-        return self.frameList
-#Natural Language Generator
+        for user_act in user_action:
+            self.update_system_action(user_act, user_action, state, system_action)
+
+        system_acts = [
+            [intent, domain, slot, value]
+            for (domain, intent), slots in system_action.items()
+            for slot, value in slots
+        ]
+        state["system_action"] = system_acts
+        return system_acts
+
+    def update_system_action(self, user_act, user_action, state, system_action):
+        domain, intent = user_act
+        constraints = [
+            (slot, value)
+            for slot, value in state["belief_state"][domain.lower()]["semi"].items()
+            if value != ""
+        ]
+        self.db.dbs = {
+            "books": [  # key must match domain.lower() used in the query below
+                {
+                    "author": "autor",
+                    "title": "krew",
+                    "edition": "2020",
+                    "lang": "polski",
+                },
+                {
+                    "author": "Marcin Bruczkowski",
+                    "title": "Bezsenność w Tokio",
+                    "genre": "reportaż",
+                    "publisher": "Społeczny Instytut Wydawniczy Znak",
+                    "edition": "2004",
+                    "lang": "polski",
+                },
+                {
+                    "author": "Harari Yuval Noah",
+                    "title": "Sapiens Od zwierząt do bogów",
+                    "edition": "2011",
+                    "lang": "polski",
+                },
+                {
+                    "author": "Haruki Murakami",
+                    "title": "1Q84",
+                    "edition": "2009",
+                    "lang": "polski",
+                },
+                {
+                    "author": "Fiodor Dostojewski",
+                    "title": "Zbrodnia i Kara",
+                    "publisher": "Wydawnictwo Mg",
+                    "edition": "2015",
+                    "lang": "polski",
+                },
+            ]
+        }
+        self.results = deepcopy(self.db.query(domain.lower(), constraints))
+
+        # Rule 1: the user requests slot values - answer them from the first matching record
+        if intent == "Request":
+            if len(self.results) == 0:
+                system_action[(domain, "NoOffer")] = []
+            else:
+                for slot in user_action[user_act]:
+                    kb_slot_name = slot[0].lower()  # DB field names are the lower-cased slot names
+
+                    if kb_slot_name in self.results[0]:
+                        system_action[(domain, "Inform")].append(
+                            [slot[0], self.results[0].get(kb_slot_name, "unknown")]
+                        )
+
+        # Rule 2: the user gives constraints - report the number of matches and recommend one
+        elif intent == "Inform":
+            if len(self.results) == 0:
+                system_action[(domain, "NoOffer")] = []
+            else:
+                system_action[(domain, "Inform")].append(
+                    ["Choice", str(len(self.results))]
+                )
+                choice = self.results[0]
+
+                if domain in ["Books"]:
+                    system_action[(domain, "Recommend")].append(
+                        ["Title", choice["title"]]
+                    )
+
+
+# Dialogue State Tracker
+class SDST(DST):
+    def __init__(self):
+        DST.__init__(self)
+        self.state = {
+            "user_action": [],
+            "system_action": [],
+            "belief_state": {
+                "books": {
+                    "reserve": {"reservation": []},
+                    "semi": {
+                        "title": "",
+                        "author": "",
+                        "genre": "",
+                        "publisher": "",
+                        "edition": "",
+                        "lang": "",
+                    },
+                },
+                "library": {
+                    "semi": {
+                        "location": "",
+                        "status": "",
+                        "events": "",
+                        "days": "",
+                        "phone number": "",
+                    }
+                },
+                "card": {"semi": {"lost": "", "destroyed": "", "new": ""}},
+                "date": {"semi": {"day": "", "month": "", "year": ""}},
+            },
+            "request_state": {},
+            "terminated": False,
+            "history": [],
+        }
+        self.ref = {
+            "Books": {
+                "Title": "title",
+                "Author": "author",
+                "Genre": "genre",
+                "Publisher": "publisher",
+                "Edition": "edition",
+                "Lang": "lang",
+                "None": "none",
+            },
+            "Library": {
+                "Location": "location",
+                "Status": "status",
+                "Events": "events",
+                "Days": "days",
+                "Phone number": "phone number",
+                "None": "none",
+            },
+            "Card": {
+                "Lost": "lost",
+                "Destroyed": "destroyed",
+                "New": "new",
+                "None": "none",
+            },
+            "Date": {"Day": "day", "Month": "month", "Year": "year", "None": "none"},
+        }
+        self.value_dict = json.load(open("value_dict.json"))
+
+    def update(self, user_act=None):
+        for intent, domain, slot, value in user_act:
+            domain = domain.lower()
+            intent = intent.lower()
+
+            if domain in ["unk", "general", "booking"]:
+                continue
+
+            if intent == "inform":
+                k = self.ref[domain.capitalize()].get(slot, slot)
+
+                if k is None:
+                    continue
+
+                domain_dic = self.state["belief_state"][domain]
+
+                if k in domain_dic["semi"]:
+                    nvalue = normalize_value(self.value_dict, domain, k, value)
+                    self.state["belief_state"][domain]["semi"][k] = nvalue
+                elif k in domain_dic.get("reserve", {}):  # booking-style slots exist only for books
domain_dic["book"]: + self.state["belief_state"][domain]["book"][k] = value + elif k.lower() in domain_dic["book"]: + self.state["belief_state"][domain]["book"][k.lower()] = value + elif intent == "request": + k = self.ref[domain.capitalize()].get(slot, slot) + if domain not in self.state["request_state"]: + self.state["request_state"][domain] = {} + if k not in self.state["request_state"][domain]: + self.state["request_state"][domain][k] = 0 + + return self.state + + def init_session(self): + self.state = self_state + + +# Natural Language Generator class NLG: def __init__(self, acts, arguments): self.acts = acts self.arguments = arguments def vectorToText(self, actVector): - if(actVector == [0, 0]): + if actVector == [0, 0]: return "Witaj, nazywam się Mateusz." else: return "Przykro mi, nie zrozumiałem Cię" @@ -76,13 +260,11 @@ class NLG: class Run: def __init__(self): - self.acts={ + self.acts = { 0: "hello", 1: "request", } - self.arguments={ - 0: "name" - } + self.arguments = {0: "name"} self.nlu = NLU() self.dp = DP(self.acts, self.arguments) @@ -98,15 +280,8 @@ class Run: return self.nlg.vectorToText(basic_act) + # run = Run() # while(1): # message = input("Napisz coś: ") # print(run.inputProcessing(message)) - - - - - - - - diff --git a/evaluate.py b/evaluate.py index 10f8709..7e00cc6 100644 --- a/evaluate.py +++ b/evaluate.py @@ -4,7 +4,7 @@ import pandas as pd import numpy as np from Modules import NLU -PATTERN = r'[^(]*' +PATTERN = r"[^(]*" # Algorytm sprawdzający @@ -13,19 +13,19 @@ hits = 0 nlu = NLU() -for file_name in os.listdir('data'): - df = pd.read_csv(f'data/{file_name}', sep='\t', names=['user', 'sentence', 'acts']) - df = df[df.user == 'user'] +for file_name in os.listdir("data"): + df = pd.read_csv(f"data/{file_name}", sep="\t", names=["user", "sentence", "acts"]) + df = df[df.user == "user"] data = np.array(df) for row in data: rows += 1 sentence = row[1] - user_acts = row[2].split('&') + user_acts = row[2].split("&") nlu_match = nlu.match(sentence) - if nlu_match['act'] in user_acts: + if nlu_match["act"] in user_acts: hits += 1 print(f"Accuracy: {(hits / rows)*100}") -# Dokładność 38.5% \ No newline at end of file +# Dokładność 38.5%