Policy and DST
This commit is contained in:
parent
6eb2ecfeed
commit
d1dbd09a15
2
.gitignore
vendored
2
.gitignore
vendored
@ -1 +1,3 @@
|
|||||||
__pycache__
|
__pycache__
|
||||||
|
*.log
|
||||||
|
*.ipynb
|
263
Modules.py
263
Modules.py
@ -1,19 +1,29 @@
|
|||||||
|
from convlab2.dst.dst import DST
|
||||||
|
from convlab2.dst.rule.multiwoz.dst_util import normalize_value
|
||||||
|
from collections import defaultdict
|
||||||
|
from convlab2.policy.policy import Policy
|
||||||
|
from convlab2.util.multiwoz.dbquery import Database
|
||||||
|
import copy
|
||||||
|
from copy import deepcopy
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import jsgf
|
import jsgf
|
||||||
|
|
||||||
#Natural Language Understanding
|
# Natural Language Understanding
|
||||||
class NLU:
|
class NLU:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.grammars = [jsgf.parse_grammar_file(f'JSGFs/{file_name}') for file_name in os.listdir('JSGFs')]
|
self.grammars = [
|
||||||
|
jsgf.parse_grammar_file(f"JSGFs/{file_name}")
|
||||||
|
for file_name in os.listdir("JSGFs")
|
||||||
|
]
|
||||||
|
|
||||||
def get_dialog_act(self, rule):
|
def get_dialog_act(self, rule):
|
||||||
slots = []
|
slots = []
|
||||||
self.get_slots(rule.expansion, slots)
|
self.get_slots(rule.expansion, slots)
|
||||||
return {'act': rule.grammar.name, 'slots': slots}
|
return {"act": rule.grammar.name, "slots": slots}
|
||||||
|
|
||||||
def get_slots(self, expansion, slots):
|
def get_slots(self, expansion, slots):
|
||||||
if expansion.tag != '':
|
if expansion.tag != "":
|
||||||
slots.append((expansion.tag, expansion.current_match))
|
slots.append((expansion.tag, expansion.current_match))
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -24,51 +34,225 @@ class NLU:
|
|||||||
self.get_slots(expansion.referenced_rule.expansion, slots)
|
self.get_slots(expansion.referenced_rule.expansion, slots)
|
||||||
|
|
||||||
def match(self, utterance):
|
def match(self, utterance):
|
||||||
list_of_illegal_character = [',', '.', "'", '?', '!', ':', '-', '/']
|
list_of_illegal_character = [",", ".", "'", "?", "!", ":", "-", "/"]
|
||||||
for illegal_character in list_of_illegal_character[:-2]:
|
for illegal_character in list_of_illegal_character[:-2]:
|
||||||
utterance = utterance.replace(f'{illegal_character}','')
|
utterance = utterance.replace(f"{illegal_character}", "")
|
||||||
for illegal_character in list_of_illegal_character[-2:]:
|
for illegal_character in list_of_illegal_character[-2:]:
|
||||||
utterance = utterance.replace(f'{illegal_character}',' ')
|
utterance = utterance.replace(f"{illegal_character}", " ")
|
||||||
|
|
||||||
for grammar in self.grammars:
|
for grammar in self.grammars:
|
||||||
matched = grammar.find_matching_rules(utterance)
|
matched = grammar.find_matching_rules(utterance)
|
||||||
if matched:
|
if matched:
|
||||||
return self.get_dialog_act(matched[0])
|
return self.get_dialog_act(matched[0])
|
||||||
return {'act': 'null', 'slots': []}
|
return {"act": "null", "slots": []}
|
||||||
|
|
||||||
#Dialogue policy
|
|
||||||
class DP:
|
|
||||||
#Module decide what act takes next
|
|
||||||
def __init__(self, acts, arguments):
|
|
||||||
self.acts = acts
|
|
||||||
self.arguments = arguments
|
|
||||||
|
|
||||||
def tacticChoice(self, frame_list):
|
class DP(Policy):
|
||||||
actVector = [0, 0]
|
def __init__(self):
|
||||||
return actVector
|
Policy.__init__(self)
|
||||||
|
self.db = Database()
|
||||||
|
|
||||||
#Dialogue State Tracker
|
def predict(self, state):
|
||||||
class DST:
|
self.results = []
|
||||||
#Contain informations about state of the dialogue and data taken from user
|
system_action = defaultdict(list)
|
||||||
def __init__(self, acts, arguments):
|
user_action = defaultdict(list)
|
||||||
self.acts = acts
|
|
||||||
self.arguments = arguments
|
|
||||||
self.frameList= []
|
|
||||||
|
|
||||||
#store new act into frame
|
for intent, domain, slot, value in state["user_action"]:
|
||||||
def store(self, frame):
|
user_action[(domain, intent)].append((slot, value))
|
||||||
self.frameList.append(frame)
|
|
||||||
|
|
||||||
def transfer(self):
|
for user_act in user_action:
|
||||||
return self.frameList
|
self.update_system_action(user_act, user_action, state, system_action)
|
||||||
#Natural Language Generator
|
|
||||||
|
system_acts = [
|
||||||
|
[intent, domain, slot, value]
|
||||||
|
for (domain, intent), slots in system_action.items()
|
||||||
|
for slot, value in slots
|
||||||
|
]
|
||||||
|
state["system_action"] = system_acts
|
||||||
|
return system_acts
|
||||||
|
|
||||||
|
def update_system_action(self, user_act, user_action, state, system_action):
|
||||||
|
domain, intent = user_act
|
||||||
|
constraints = [
|
||||||
|
(slot, value)
|
||||||
|
for slot, value in state["belief_state"][domain.lower()]["semi"].items()
|
||||||
|
if value != ""
|
||||||
|
]
|
||||||
|
self.db.dbs = {
|
||||||
|
"book": [
|
||||||
|
{
|
||||||
|
"author": "autor",
|
||||||
|
"title": "krew",
|
||||||
|
"edition": "2020",
|
||||||
|
"lang": "polski",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"author": "Marcin Bruczkowski",
|
||||||
|
"title": "Bezsenność w Tokio",
|
||||||
|
"genre": "reportaż",
|
||||||
|
"publisher": "Społeczny Instytut Wydawniczy Znak",
|
||||||
|
"edition": "2004",
|
||||||
|
"lang": "polski",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"author": "Harari Yuval Noah",
|
||||||
|
"title": "Sapiens Od zwierząt do bogów",
|
||||||
|
"edition": "2011",
|
||||||
|
"lang": "polski",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"author": "Haruki Murakami",
|
||||||
|
"title": "1Q84",
|
||||||
|
"edition": "2009",
|
||||||
|
"lang": "polski",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"author": "Fiodor Dostojewski",
|
||||||
|
"title": "Zbrodnia i Kara",
|
||||||
|
"publisher": "Wydawnictwo Mg",
|
||||||
|
"edition": "2015",
|
||||||
|
"lang": "polski",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
self.results = deepcopy(self.db.query(domain.lower(), constraints))
|
||||||
|
|
||||||
|
# Reguła 1
|
||||||
|
if intent == "Request":
|
||||||
|
if len(self.results) == 0:
|
||||||
|
system_action[(domain, "NoOffer")] = []
|
||||||
|
else:
|
||||||
|
for slot in user_action[user_act]:
|
||||||
|
kb_slot_name = ref[domain].get(slot[0], slot[0])
|
||||||
|
|
||||||
|
if kb_slot_name in self.results[0]:
|
||||||
|
system_action[(domain, "Inform")].append(
|
||||||
|
[slot[0], self.results[0].get(kb_slot_name, "unknown")]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reguła 2
|
||||||
|
elif intent == "Inform":
|
||||||
|
if len(self.results) == 0:
|
||||||
|
system_action[(domain, "NoOffer")] = []
|
||||||
|
else:
|
||||||
|
system_action[(domain, "Inform")].append(
|
||||||
|
["Choice", str(len(self.results))]
|
||||||
|
)
|
||||||
|
choice = self.results[0]
|
||||||
|
|
||||||
|
if domain in ["Book"]:
|
||||||
|
system_action[(domain, "Recommend")].append(
|
||||||
|
["Title", choice["title"]]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Dialogue State Tracker
|
||||||
|
class SDST(DST):
|
||||||
|
def __init__(self):
|
||||||
|
DST.__init__(self)
|
||||||
|
self.state = {
|
||||||
|
"user_action": [],
|
||||||
|
"system_action": [],
|
||||||
|
"belief_state": {
|
||||||
|
"books": {
|
||||||
|
"reserve": {"reservation": []},
|
||||||
|
"semi": {
|
||||||
|
"title": "",
|
||||||
|
"author": "",
|
||||||
|
"genre": "",
|
||||||
|
"publisher": "",
|
||||||
|
"edition": "",
|
||||||
|
"lang": "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"library": {
|
||||||
|
"semi": {
|
||||||
|
"location": "",
|
||||||
|
"status": "",
|
||||||
|
"events": "",
|
||||||
|
"days": "",
|
||||||
|
"phone number": "",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"card": {"semi": {"lost": "", "destroyed": "", "new": ""}},
|
||||||
|
"date": {"semi": {"day": "", "month": "", "year": ""}},
|
||||||
|
},
|
||||||
|
"request_state": {},
|
||||||
|
"terminated": False,
|
||||||
|
"history": [],
|
||||||
|
}
|
||||||
|
self.ref = {
|
||||||
|
"Books": {
|
||||||
|
"Title": "title",
|
||||||
|
"Author": "author",
|
||||||
|
"Genre": "genre",
|
||||||
|
"Publisher": "publisher",
|
||||||
|
"Edition": "edition",
|
||||||
|
"Lang": "lang",
|
||||||
|
"None": "none",
|
||||||
|
},
|
||||||
|
"Library": {
|
||||||
|
"Location": "location",
|
||||||
|
"Status": "status",
|
||||||
|
"Events": "events",
|
||||||
|
"Days": "days",
|
||||||
|
"Phone number": "phone number",
|
||||||
|
"None": "none",
|
||||||
|
},
|
||||||
|
"Card": {
|
||||||
|
"Lost": "lost",
|
||||||
|
"Destroyed": "destroyed",
|
||||||
|
"New": "new",
|
||||||
|
"None": "none",
|
||||||
|
},
|
||||||
|
"Date": {"Day": "day", "Month": "month", "Year": "year", "None": "none"},
|
||||||
|
}
|
||||||
|
self.value_dict = json.load(open("value_dict.json"))
|
||||||
|
|
||||||
|
def update(self, user_act=None):
|
||||||
|
for intent, domain, slot, value in user_act:
|
||||||
|
domain = domain.lower()
|
||||||
|
intent = intent.lower()
|
||||||
|
|
||||||
|
if domain in ["unk", "general", "booking"]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if intent == "inform":
|
||||||
|
k = self.ref[domain.capitalize()].get(slot, slot)
|
||||||
|
|
||||||
|
if k is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
domain_dic = self.state["belief_state"][domain]
|
||||||
|
|
||||||
|
if k in domain_dic["semi"]:
|
||||||
|
nvalue = normalize_value(self.value_dict, domain, k, value)
|
||||||
|
self.state["belief_state"][domain]["semi"][k] = nvalue
|
||||||
|
elif k in domain_dic["book"]:
|
||||||
|
self.state["belief_state"][domain]["book"][k] = value
|
||||||
|
elif k.lower() in domain_dic["book"]:
|
||||||
|
self.state["belief_state"][domain]["book"][k.lower()] = value
|
||||||
|
elif intent == "request":
|
||||||
|
k = self.ref[domain.capitalize()].get(slot, slot)
|
||||||
|
if domain not in self.state["request_state"]:
|
||||||
|
self.state["request_state"][domain] = {}
|
||||||
|
if k not in self.state["request_state"][domain]:
|
||||||
|
self.state["request_state"][domain][k] = 0
|
||||||
|
|
||||||
|
return self.state
|
||||||
|
|
||||||
|
def init_session(self):
|
||||||
|
self.state = self_state
|
||||||
|
|
||||||
|
|
||||||
|
# Natural Language Generator
|
||||||
class NLG:
|
class NLG:
|
||||||
def __init__(self, acts, arguments):
|
def __init__(self, acts, arguments):
|
||||||
self.acts = acts
|
self.acts = acts
|
||||||
self.arguments = arguments
|
self.arguments = arguments
|
||||||
|
|
||||||
def vectorToText(self, actVector):
|
def vectorToText(self, actVector):
|
||||||
if(actVector == [0, 0]):
|
if actVector == [0, 0]:
|
||||||
return "Witaj, nazywam się Mateusz."
|
return "Witaj, nazywam się Mateusz."
|
||||||
else:
|
else:
|
||||||
return "Przykro mi, nie zrozumiałem Cię"
|
return "Przykro mi, nie zrozumiałem Cię"
|
||||||
@ -76,13 +260,11 @@ class NLG:
|
|||||||
|
|
||||||
class Run:
|
class Run:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.acts={
|
self.acts = {
|
||||||
0: "hello",
|
0: "hello",
|
||||||
1: "request",
|
1: "request",
|
||||||
}
|
}
|
||||||
self.arguments={
|
self.arguments = {0: "name"}
|
||||||
0: "name"
|
|
||||||
}
|
|
||||||
|
|
||||||
self.nlu = NLU()
|
self.nlu = NLU()
|
||||||
self.dp = DP(self.acts, self.arguments)
|
self.dp = DP(self.acts, self.arguments)
|
||||||
@ -98,15 +280,8 @@ class Run:
|
|||||||
|
|
||||||
return self.nlg.vectorToText(basic_act)
|
return self.nlg.vectorToText(basic_act)
|
||||||
|
|
||||||
|
|
||||||
# run = Run()
|
# run = Run()
|
||||||
# while(1):
|
# while(1):
|
||||||
# message = input("Napisz coś: ")
|
# message = input("Napisz coś: ")
|
||||||
# print(run.inputProcessing(message))
|
# print(run.inputProcessing(message))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
12
evaluate.py
12
evaluate.py
@ -4,7 +4,7 @@ import pandas as pd
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from Modules import NLU
|
from Modules import NLU
|
||||||
|
|
||||||
PATTERN = r'[^(]*'
|
PATTERN = r"[^(]*"
|
||||||
|
|
||||||
# Algorytm sprawdzający
|
# Algorytm sprawdzający
|
||||||
|
|
||||||
@ -13,17 +13,17 @@ hits = 0
|
|||||||
|
|
||||||
nlu = NLU()
|
nlu = NLU()
|
||||||
|
|
||||||
for file_name in os.listdir('data'):
|
for file_name in os.listdir("data"):
|
||||||
df = pd.read_csv(f'data/{file_name}', sep='\t', names=['user', 'sentence', 'acts'])
|
df = pd.read_csv(f"data/{file_name}", sep="\t", names=["user", "sentence", "acts"])
|
||||||
df = df[df.user == 'user']
|
df = df[df.user == "user"]
|
||||||
data = np.array(df)
|
data = np.array(df)
|
||||||
|
|
||||||
for row in data:
|
for row in data:
|
||||||
rows += 1
|
rows += 1
|
||||||
sentence = row[1]
|
sentence = row[1]
|
||||||
user_acts = row[2].split('&')
|
user_acts = row[2].split("&")
|
||||||
nlu_match = nlu.match(sentence)
|
nlu_match = nlu.match(sentence)
|
||||||
if nlu_match['act'] in user_acts:
|
if nlu_match["act"] in user_acts:
|
||||||
hits += 1
|
hits += 1
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user