DST (start)
This commit is contained in:
parent
2227204dc4
commit
2eca0a043d
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
*frame-model*
|
||||||
|
*slot-model*
|
||||||
|
.venv*
|
||||||
|
env*
|
File diff suppressed because it is too large
Load Diff
@ -2,7 +2,7 @@ import jsgf
|
|||||||
|
|
||||||
from Modules.NLG_module import NLG
|
from Modules.NLG_module import NLG
|
||||||
from Modules.DP_module import DP
|
from Modules.DP_module import DP
|
||||||
from Modules.DST_module import DST
|
from Modules.DST_module import Rules_DST
|
||||||
from Modules.Book_NLU_module import Book_NLU
|
from Modules.Book_NLU_module import Book_NLU
|
||||||
from Modules.ML_NLU_module import ML_NLU
|
from Modules.ML_NLU_module import ML_NLU
|
||||||
|
|
||||||
@ -20,26 +20,15 @@ if torch.cuda.is_available():
|
|||||||
|
|
||||||
class Janet:
|
class Janet:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.acts={
|
self.nlg = NLG()
|
||||||
0: "greetings",
|
self.dp = DP()
|
||||||
1: "request",
|
self.dst = Rules_DST()
|
||||||
}
|
self.nlu = Book_NLU(jsgf.parse_grammar_file('book.jsgf'))
|
||||||
self.arguments={
|
self.nlu_v2 = ML_NLU()
|
||||||
0: "name"
|
|
||||||
}
|
|
||||||
self.nlg = NLG(self.acts, self.arguments)
|
|
||||||
self.dp = DP(self.acts, self.arguments)
|
|
||||||
self.dst = DST(self.acts, self.arguments)
|
|
||||||
self.nlu = Book_NLU(self.acts, self.arguments, jsgf.parse_grammar_file('book.jsgf'))
|
|
||||||
self.nlu_v2 = ML_NLU(self.acts, self.arguments)
|
|
||||||
|
|
||||||
def test(self, command):
|
|
||||||
out = self.nlu_v2.test_nlu(command)
|
|
||||||
return out
|
|
||||||
|
|
||||||
def process(self, command):
|
def process(self, command):
|
||||||
act = self.nlu.analyze(command)
|
act = self.nlu_v2.test_nlu(command)
|
||||||
self.dst.store(act)
|
self.dst.update(act)
|
||||||
dest_act = self.dp.choose_tactic(self.dst.transfer())
|
dest_act = self.dp.choose_tactic(self.dst.transfer())
|
||||||
return self.nlg.change_to_text(dest_act)
|
return self.nlg.change_to_text(dest_act)
|
||||||
|
|
||||||
@ -49,7 +38,7 @@ def main():
|
|||||||
while(1):
|
while(1):
|
||||||
print('\n')
|
print('\n')
|
||||||
text = input("Wpisz tekst: ")
|
text = input("Wpisz tekst: ")
|
||||||
print(janet.test(text))
|
print(janet.process(text))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
@ -8,9 +8,7 @@ class Book_NLU: #Natural Language Understanding
|
|||||||
|
|
||||||
Wyjście: Akt użytkownika (rama)
|
Wyjście: Akt użytkownika (rama)
|
||||||
"""
|
"""
|
||||||
def __init__(self, acts, arguments, book_grammar):
|
def __init__(self, book_grammar):
|
||||||
self.acts = acts
|
|
||||||
self.arguments = arguments
|
|
||||||
self.book_grammar = book_grammar
|
self.book_grammar = book_grammar
|
||||||
|
|
||||||
def get_dialog_act(self, rule):
|
def get_dialog_act(self, rule):
|
||||||
|
@ -6,10 +6,8 @@ class DP:
|
|||||||
|
|
||||||
Wyjście: Akt systemu (rama)
|
Wyjście: Akt systemu (rama)
|
||||||
"""
|
"""
|
||||||
def __init__(self, acts, arguments):
|
def __init__(self):
|
||||||
self.acts = acts
|
pass
|
||||||
self.arguments = arguments
|
|
||||||
|
|
||||||
|
|
||||||
def choose_tactic(self, frame_list):
|
def choose_tactic(self, frame_list):
|
||||||
"""
|
"""
|
||||||
|
@ -1,4 +1,9 @@
|
|||||||
class DST: #Dialogue State Tracker
|
import json
|
||||||
|
from convlab2.dst.dst import DST
|
||||||
|
from convlab2.dst.rule.multiwoz.dst_util import normalize_value
|
||||||
|
|
||||||
|
|
||||||
|
class Rules_DST(DST): #Dialogue State Tracker
|
||||||
"""
|
"""
|
||||||
Moduł odpowiedzialny za śledzenie stanu dialogu. Przechowuje informacje o tym jakie dane zostały uzyskane od użytkownika w toku prowadzonej konwersacji.
|
Moduł odpowiedzialny za śledzenie stanu dialogu. Przechowuje informacje o tym jakie dane zostały uzyskane od użytkownika w toku prowadzonej konwersacji.
|
||||||
|
|
||||||
@ -6,22 +11,38 @@ class DST: #Dialogue State Tracker
|
|||||||
|
|
||||||
Wyjście: Reprezentacja stanu dialogu (rama)
|
Wyjście: Reprezentacja stanu dialogu (rama)
|
||||||
"""
|
"""
|
||||||
def __init__(self, acts, arguments):
|
def __init__(self):
|
||||||
self.acts = acts
|
DST.__init__(self)
|
||||||
self.arguments = arguments
|
self.state = json.load(open('default_state.json'))
|
||||||
self.frame_list= []
|
self.value_dict = json.load(open('value_dict.json'))
|
||||||
|
|
||||||
|
def update(self, user_act=None):
|
||||||
|
slots = user_act["slots"]
|
||||||
|
intent = user_act["act"]
|
||||||
|
domain = user_act["act"].split('/')[0]
|
||||||
|
|
||||||
def store(self, rama):
|
if domain in ['password', 'name', 'email', 'enter_email', 'enter_name']:
|
||||||
"""
|
return
|
||||||
Dodanie nowego aktu do listy
|
|
||||||
"""
|
|
||||||
print("\nDodanie do listy nowej ramy: ")
|
|
||||||
print(rama)
|
|
||||||
self.frame_list.append(rama)
|
|
||||||
|
|
||||||
|
|
||||||
def transfer(self):
|
if 'appointment' in intent:
|
||||||
print("Przekazanie dalej listy ram: ")
|
for full_slot in slots:
|
||||||
print(self.frame_list)
|
slot = full_slot[1]
|
||||||
return self.frame_list
|
value = full_slot[1]
|
||||||
|
k = self.value_dict[domain.lower()].get(slot, slot)
|
||||||
|
|
||||||
|
if k is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
domain_dic = self.state['belief_state'][domain]
|
||||||
|
|
||||||
|
if k in domain_dic['semi']:
|
||||||
|
nvalue = normalize_value(self.value_dict, domain, k, value)
|
||||||
|
self.state['belief_state'][domain]['semi'][k] = nvalue
|
||||||
|
elif k in domain_dic['book']:
|
||||||
|
self.state['belief_state'][domain]['book'][k] = value
|
||||||
|
elif k.lower() in domain_dic['book']:
|
||||||
|
self.state['belief_state'][domain]['book'][k.lower()] = value
|
||||||
|
elif intent == 'end_conversation':
|
||||||
|
self.state = {}
|
||||||
|
|
||||||
|
return self.state
|
@ -5,9 +5,7 @@ from flair.datasets import SentenceDataset
|
|||||||
from flair.models import SequenceTagger
|
from flair.models import SequenceTagger
|
||||||
|
|
||||||
class ML_NLU:
|
class ML_NLU:
|
||||||
def __init__(self, acts, arguments):
|
def __init__(self):
|
||||||
self.acts = acts
|
|
||||||
self.arguments = arguments
|
|
||||||
self.slot_model, self.frame_model = self.setup()
|
self.slot_model, self.frame_model = self.setup()
|
||||||
|
|
||||||
def nolabel2o(self, line, i):
|
def nolabel2o(self, line, i):
|
||||||
|
@ -6,10 +6,8 @@ class NLG:
|
|||||||
|
|
||||||
Wyjście: Tekst
|
Wyjście: Tekst
|
||||||
"""
|
"""
|
||||||
def __init__(self, acts, arguments):
|
def __init__(self):
|
||||||
self.acts = acts
|
pass
|
||||||
self.arguments = arguments
|
|
||||||
|
|
||||||
|
|
||||||
def change_to_text(self, act_vector):
|
def change_to_text(self, act_vector):
|
||||||
"""
|
"""
|
||||||
|
@ -46,6 +46,9 @@ frame_corpus = Corpus(train=conllu2flair(frame_trainset, 'frame'), test=conllu2f
|
|||||||
slot_tag_dictionary = slot_corpus.make_tag_dictionary(tag_type='slot')
|
slot_tag_dictionary = slot_corpus.make_tag_dictionary(tag_type='slot')
|
||||||
frame_tag_dictionary = frame_corpus.make_tag_dictionary(tag_type='frame')
|
frame_tag_dictionary = frame_corpus.make_tag_dictionary(tag_type='frame')
|
||||||
|
|
||||||
|
print(slot_tag_dictionary)
|
||||||
|
print(frame_tag_dictionary)
|
||||||
|
|
||||||
|
|
||||||
embedding_types = [
|
embedding_types = [
|
||||||
WordEmbeddings('pl'),
|
WordEmbeddings('pl'),
|
||||||
@ -62,7 +65,7 @@ frame_tagger = SequenceTagger(hidden_size=256, embeddings=embeddings,
|
|||||||
tag_dictionary=frame_tag_dictionary,
|
tag_dictionary=frame_tag_dictionary,
|
||||||
tag_type='frame', use_crf=True)
|
tag_type='frame', use_crf=True)
|
||||||
|
|
||||||
# slot_trainer = ModelTrainer(slot_tagger, slot_corpus)
|
slot_trainer = ModelTrainer(slot_tagger, slot_corpus)
|
||||||
# slot_trainer.train('slot-model',
|
# slot_trainer.train('slot-model',
|
||||||
# learning_rate=0.1,
|
# learning_rate=0.1,
|
||||||
# mini_batch_size=32,
|
# mini_batch_size=32,
|
||||||
|
@ -828,8 +828,8 @@
|
|||||||
# intent: request_information/doctors
|
# intent: request_information/doctors
|
||||||
# slots:
|
# slots:
|
||||||
1 jacy request_information/doctors NoLabel
|
1 jacy request_information/doctors NoLabel
|
||||||
2 lekarze request_information/doctors B-appoinment/doctor
|
2 lekarze request_information/doctors B-appointment/doctor
|
||||||
3 specjaliści request_information/doctors B-appoinment/doctor
|
3 specjaliści request_information/doctors B-appointment/doctor
|
||||||
4 przyjmują request_information/doctors NoLabel
|
4 przyjmują request_information/doctors NoLabel
|
||||||
5 w request_information/doctors NoLabel
|
5 w request_information/doctors NoLabel
|
||||||
6 państwa request_information/doctors NoLabel
|
6 państwa request_information/doctors NoLabel
|
||||||
@ -840,7 +840,7 @@
|
|||||||
# slots:
|
# slots:
|
||||||
1 chciałbym appointment/create_appointment NoLabel
|
1 chciałbym appointment/create_appointment NoLabel
|
||||||
2 umówić appointment/create_appointment NoLabel
|
2 umówić appointment/create_appointment NoLabel
|
||||||
3 wizytę appointment/create_appointment B-appoinment
|
3 wizytę appointment/create_appointment B-appointment
|
||||||
4 do appointment/create_appointment NoLabel
|
4 do appointment/create_appointment NoLabel
|
||||||
5 doktora appointment/create_appointment B-appointment/doctor
|
5 doktora appointment/create_appointment B-appointment/doctor
|
||||||
6 kolano appointment/create_appointment I-appointment/doctor
|
6 kolano appointment/create_appointment I-appointment/doctor
|
||||||
@ -891,7 +891,7 @@
|
|||||||
5 okulisty appointment/create_appointment B-appointment/doctor
|
5 okulisty appointment/create_appointment B-appointment/doctor
|
||||||
6 ile request_information/cost NoLabel
|
6 ile request_information/cost NoLabel
|
||||||
7 kosztuje request_information/cost NoLabel
|
7 kosztuje request_information/cost NoLabel
|
||||||
8 wizyta request_information/cost B-appoinment
|
8 wizyta request_information/cost B-appointment
|
||||||
|
|
||||||
# text: Nie ten jest idealny
|
# text: Nie ten jest idealny
|
||||||
# intent: deny
|
# intent: deny
|
||||||
|
17
default_state.json
Normal file
17
default_state.json
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
{
|
||||||
|
"user_action": [],
|
||||||
|
"system_action": [],
|
||||||
|
"belief_state": {
|
||||||
|
"appointment": {
|
||||||
|
"prescription": {
|
||||||
|
"url": {},
|
||||||
|
"type": {}
|
||||||
|
},
|
||||||
|
"doctor": {},
|
||||||
|
"where": {}
|
||||||
|
},
|
||||||
|
"request_state": {},
|
||||||
|
"terminated": false,
|
||||||
|
"history": []
|
||||||
|
}
|
||||||
|
}
|
110
value_dict.json
Normal file
110
value_dict.json
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
{
|
||||||
|
"appointment": {
|
||||||
|
"hour": [
|
||||||
|
"08:00",
|
||||||
|
"08:15",
|
||||||
|
"08:30",
|
||||||
|
"08:45",
|
||||||
|
"09:00",
|
||||||
|
"09:15",
|
||||||
|
"09:30",
|
||||||
|
"09:45",
|
||||||
|
"10:00",
|
||||||
|
"10:15",
|
||||||
|
"10:30",
|
||||||
|
"10:45",
|
||||||
|
"11:00",
|
||||||
|
"11:15",
|
||||||
|
"11:30",
|
||||||
|
"11:45",
|
||||||
|
"12:00",
|
||||||
|
"12:15",
|
||||||
|
"12:30",
|
||||||
|
"12:45",
|
||||||
|
"13:00",
|
||||||
|
"13:15",
|
||||||
|
"13:30",
|
||||||
|
"13:45",
|
||||||
|
"14:00",
|
||||||
|
"14:15",
|
||||||
|
"14:30",
|
||||||
|
"14:45",
|
||||||
|
"15:00",
|
||||||
|
"15:15",
|
||||||
|
"15:30",
|
||||||
|
"15:45",
|
||||||
|
"16:00",
|
||||||
|
"16:15",
|
||||||
|
"16:30",
|
||||||
|
"16:45",
|
||||||
|
"17:00",
|
||||||
|
"17:15",
|
||||||
|
"17:30",
|
||||||
|
"17:45",
|
||||||
|
"18:00",
|
||||||
|
"18:15",
|
||||||
|
"18:30",
|
||||||
|
"18:45",
|
||||||
|
"19:00",
|
||||||
|
"19:15",
|
||||||
|
"19:30",
|
||||||
|
"19:45"
|
||||||
|
],
|
||||||
|
"date": [
|
||||||
|
"poniedziałek",
|
||||||
|
"wtorek",
|
||||||
|
"środa",
|
||||||
|
"środę",
|
||||||
|
"czwartek",
|
||||||
|
"piątek",
|
||||||
|
"sobota",
|
||||||
|
"sobotę",
|
||||||
|
"niedziela",
|
||||||
|
"niedzielę",
|
||||||
|
"01.05.2021",
|
||||||
|
"02.05.2021",
|
||||||
|
"03.05.2021",
|
||||||
|
"04.05.2021",
|
||||||
|
"05.05.2021",
|
||||||
|
"06.05.2021",
|
||||||
|
"07.05.2021",
|
||||||
|
"08.05.2021",
|
||||||
|
"09.05.2021",
|
||||||
|
"10.05.2021",
|
||||||
|
"11.05.2021",
|
||||||
|
"12.05.2021",
|
||||||
|
"13.05.2021",
|
||||||
|
"14.05.2021",
|
||||||
|
"15.05.2021",
|
||||||
|
"16.05.2021",
|
||||||
|
"17.05.2021",
|
||||||
|
"18.05.2021",
|
||||||
|
"19.05.2021",
|
||||||
|
"20.05.2021",
|
||||||
|
"21.05.2021",
|
||||||
|
"22.05.2021",
|
||||||
|
"23.05.2021",
|
||||||
|
"24.05.2021",
|
||||||
|
"25.05.2021",
|
||||||
|
"26.05.2021",
|
||||||
|
"27.05.2021",
|
||||||
|
"28.05.2021",
|
||||||
|
"29.05.2021",
|
||||||
|
"20.05.2021"
|
||||||
|
],
|
||||||
|
"doctor": [
|
||||||
|
"internista",
|
||||||
|
"okulista",
|
||||||
|
"rodzinny",
|
||||||
|
"ginekolog",
|
||||||
|
"dr. Kolano",
|
||||||
|
"dr. Kowalska"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"prescritpion": {
|
||||||
|
"type": [
|
||||||
|
"recepta",
|
||||||
|
"e-recepta"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user