DST (start)

This commit is contained in:
Anna Nowak 2021-05-30 13:31:34 +02:00
parent 2227204dc4
commit 2eca0a043d
12 changed files with 192 additions and 1520 deletions

4
.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
*frame-model*
*slot-model*
.venv*
env*

File diff suppressed because it is too large Load Diff

View File

@ -2,7 +2,7 @@ import jsgf
from Modules.NLG_module import NLG
from Modules.DP_module import DP
from Modules.DST_module import DST
from Modules.DST_module import Rules_DST
from Modules.Book_NLU_module import Book_NLU
from Modules.ML_NLU_module import ML_NLU
@ -20,26 +20,15 @@ if torch.cuda.is_available():
class Janet:
def __init__(self):
self.acts={
0: "greetings",
1: "request",
}
self.arguments={
0: "name"
}
self.nlg = NLG(self.acts, self.arguments)
self.dp = DP(self.acts, self.arguments)
self.dst = DST(self.acts, self.arguments)
self.nlu = Book_NLU(self.acts, self.arguments, jsgf.parse_grammar_file('book.jsgf'))
self.nlu_v2 = ML_NLU(self.acts, self.arguments)
def test(self, command):
out = self.nlu_v2.test_nlu(command)
return out
self.nlg = NLG()
self.dp = DP()
self.dst = Rules_DST()
self.nlu = Book_NLU(jsgf.parse_grammar_file('book.jsgf'))
self.nlu_v2 = ML_NLU()
def process(self, command):
act = self.nlu.analyze(command)
self.dst.store(act)
act = self.nlu_v2.test_nlu(command)
self.dst.update(act)
dest_act = self.dp.choose_tactic(self.dst.transfer())
return self.nlg.change_to_text(dest_act)
@ -49,7 +38,7 @@ def main():
while(1):
print('\n')
text = input("Wpisz tekst: ")
print(janet.test(text))
print(janet.process(text))
if __name__ == "__main__":
main()

View File

@ -8,9 +8,7 @@ class Book_NLU: #Natural Language Understanding
Wyjście: Akt użytkownika (rama)
"""
def __init__(self, acts, arguments, book_grammar):
self.acts = acts
self.arguments = arguments
def __init__(self, book_grammar):
self.book_grammar = book_grammar
def get_dialog_act(self, rule):

View File

@ -6,10 +6,8 @@ class DP:
Wyjście: Akt systemu (rama)
"""
def __init__(self, acts, arguments):
self.acts = acts
self.arguments = arguments
def __init__(self):
pass
def choose_tactic(self, frame_list):
"""

View File

@ -1,4 +1,9 @@
class DST: #Dialogue State Tracker
import json
from convlab2.dst.dst import DST
from convlab2.dst.rule.multiwoz.dst_util import normalize_value
class Rules_DST(DST): #Dialogue State Tracker
"""
Moduł odpowiedzialny za śledzenie stanu dialogu. Przechowuje informacje o tym jakie dane zostały uzyskane od użytkownika w toku prowadzonej konwersacji.
@ -6,22 +11,38 @@ class DST: #Dialogue State Tracker
Wyjście: Reprezentacja stanu dialogu (rama)
"""
def __init__(self, acts, arguments):
self.acts = acts
self.arguments = arguments
self.frame_list= []
def __init__(self):
DST.__init__(self)
self.state = json.load(open('default_state.json'))
self.value_dict = json.load(open('value_dict.json'))
def update(self, user_act=None):
slots = user_act["slots"]
intent = user_act["act"]
domain = user_act["act"].split('/')[0]
def store(self, rama):
"""
Dodanie nowego aktu do listy
"""
print("\nDodanie do listy nowej ramy: ")
print(rama)
self.frame_list.append(rama)
if domain in ['password', 'name', 'email', 'enter_email', 'enter_name']:
return
if 'appointment' in intent:
for full_slot in slots:
slot = full_slot[1]
value = full_slot[1]
k = self.value_dict[domain.lower()].get(slot, slot)
def transfer(self):
print("Przekazanie dalej listy ram: ")
print(self.frame_list)
return self.frame_list
if k is None:
return
domain_dic = self.state['belief_state'][domain]
if k in domain_dic['semi']:
nvalue = normalize_value(self.value_dict, domain, k, value)
self.state['belief_state'][domain]['semi'][k] = nvalue
elif k in domain_dic['book']:
self.state['belief_state'][domain]['book'][k] = value
elif k.lower() in domain_dic['book']:
self.state['belief_state'][domain]['book'][k.lower()] = value
elif intent == 'end_conversation':
self.state = {}
return self.state

View File

@ -5,9 +5,7 @@ from flair.datasets import SentenceDataset
from flair.models import SequenceTagger
class ML_NLU:
def __init__(self, acts, arguments):
self.acts = acts
self.arguments = arguments
def __init__(self):
self.slot_model, self.frame_model = self.setup()
def nolabel2o(self, line, i):

View File

@ -6,10 +6,8 @@ class NLG:
Wyjście: Tekst
"""
def __init__(self, acts, arguments):
self.acts = acts
self.arguments = arguments
def __init__(self):
pass
def change_to_text(self, act_vector):
"""

View File

@ -46,6 +46,9 @@ frame_corpus = Corpus(train=conllu2flair(frame_trainset, 'frame'), test=conllu2f
slot_tag_dictionary = slot_corpus.make_tag_dictionary(tag_type='slot')
frame_tag_dictionary = frame_corpus.make_tag_dictionary(tag_type='frame')
print(slot_tag_dictionary)
print(frame_tag_dictionary)
embedding_types = [
WordEmbeddings('pl'),
@ -62,7 +65,7 @@ frame_tagger = SequenceTagger(hidden_size=256, embeddings=embeddings,
tag_dictionary=frame_tag_dictionary,
tag_type='frame', use_crf=True)
# slot_trainer = ModelTrainer(slot_tagger, slot_corpus)
slot_trainer = ModelTrainer(slot_tagger, slot_corpus)
# slot_trainer.train('slot-model',
# learning_rate=0.1,
# mini_batch_size=32,

View File

@ -828,8 +828,8 @@
# intent: request_information/doctors
# slots:
1 jacy request_information/doctors NoLabel
2 lekarze request_information/doctors B-appoinment/doctor
3 specjaliści request_information/doctors B-appoinment/doctor
2 lekarze request_information/doctors B-appointment/doctor
3 specjaliści request_information/doctors B-appointment/doctor
4 przyjmują request_information/doctors NoLabel
5 w request_information/doctors NoLabel
6 państwa request_information/doctors NoLabel
@ -840,7 +840,7 @@
# slots:
1 chciałbym appointment/create_appointment NoLabel
2 umówić appointment/create_appointment NoLabel
3 wizytę appointment/create_appointment B-appoinment
3 wizytę appointment/create_appointment B-appointment
4 do appointment/create_appointment NoLabel
5 doktora appointment/create_appointment B-appointment/doctor
6 kolano appointment/create_appointment I-appointment/doctor
@ -891,7 +891,7 @@
5 okulisty appointment/create_appointment B-appointment/doctor
6 ile request_information/cost NoLabel
7 kosztuje request_information/cost NoLabel
8 wizyta request_information/cost B-appoinment
8 wizyta request_information/cost B-appointment
# text: Nie ten jest idealny
# intent: deny

17
default_state.json Normal file
View File

@ -0,0 +1,17 @@
{
"user_action": [],
"system_action": [],
"belief_state": {
"appointment": {
"prescription": {
"url": {},
"type": {}
},
"doctor": {},
"where": {}
},
"request_state": {},
"terminated": false,
"history": []
}
}

110
value_dict.json Normal file
View File

@ -0,0 +1,110 @@
{
"appointment": {
"hour": [
"08:00",
"08:15",
"08:30",
"08:45",
"09:00",
"09:15",
"09:30",
"09:45",
"10:00",
"10:15",
"10:30",
"10:45",
"11:00",
"11:15",
"11:30",
"11:45",
"12:00",
"12:15",
"12:30",
"12:45",
"13:00",
"13:15",
"13:30",
"13:45",
"14:00",
"14:15",
"14:30",
"14:45",
"15:00",
"15:15",
"15:30",
"15:45",
"16:00",
"16:15",
"16:30",
"16:45",
"17:00",
"17:15",
"17:30",
"17:45",
"18:00",
"18:15",
"18:30",
"18:45",
"19:00",
"19:15",
"19:30",
"19:45"
],
"date": [
"poniedziałek",
"wtorek",
"środa",
"środę",
"czwartek",
"piątek",
"sobota",
"sobotę",
"niedziela",
"niedzielę",
"01.05.2021",
"02.05.2021",
"03.05.2021",
"04.05.2021",
"05.05.2021",
"06.05.2021",
"07.05.2021",
"08.05.2021",
"09.05.2021",
"10.05.2021",
"11.05.2021",
"12.05.2021",
"13.05.2021",
"14.05.2021",
"15.05.2021",
"16.05.2021",
"17.05.2021",
"18.05.2021",
"19.05.2021",
"20.05.2021",
"21.05.2021",
"22.05.2021",
"23.05.2021",
"24.05.2021",
"25.05.2021",
"26.05.2021",
"27.05.2021",
"28.05.2021",
"29.05.2021",
"20.05.2021"
],
"doctor": [
"internista",
"okulista",
"rodzinny",
"ginekolog",
"dr. Kolano",
"dr. Kowalska"
]
},
"prescritpion": {
"type": [
"recepta",
"e-recepta"
]
}
}