DST (start)
This commit is contained in:
parent
2227204dc4
commit
2eca0a043d
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
*frame-model*
|
||||
*slot-model*
|
||||
.venv*
|
||||
env*
|
File diff suppressed because it is too large
Load Diff
@ -2,7 +2,7 @@ import jsgf
|
||||
|
||||
from Modules.NLG_module import NLG
|
||||
from Modules.DP_module import DP
|
||||
from Modules.DST_module import DST
|
||||
from Modules.DST_module import Rules_DST
|
||||
from Modules.Book_NLU_module import Book_NLU
|
||||
from Modules.ML_NLU_module import ML_NLU
|
||||
|
||||
@ -20,26 +20,15 @@ if torch.cuda.is_available():
|
||||
|
||||
class Janet:
|
||||
def __init__(self):
|
||||
self.acts={
|
||||
0: "greetings",
|
||||
1: "request",
|
||||
}
|
||||
self.arguments={
|
||||
0: "name"
|
||||
}
|
||||
self.nlg = NLG(self.acts, self.arguments)
|
||||
self.dp = DP(self.acts, self.arguments)
|
||||
self.dst = DST(self.acts, self.arguments)
|
||||
self.nlu = Book_NLU(self.acts, self.arguments, jsgf.parse_grammar_file('book.jsgf'))
|
||||
self.nlu_v2 = ML_NLU(self.acts, self.arguments)
|
||||
|
||||
def test(self, command):
|
||||
out = self.nlu_v2.test_nlu(command)
|
||||
return out
|
||||
self.nlg = NLG()
|
||||
self.dp = DP()
|
||||
self.dst = Rules_DST()
|
||||
self.nlu = Book_NLU(jsgf.parse_grammar_file('book.jsgf'))
|
||||
self.nlu_v2 = ML_NLU()
|
||||
|
||||
def process(self, command):
|
||||
act = self.nlu.analyze(command)
|
||||
self.dst.store(act)
|
||||
act = self.nlu_v2.test_nlu(command)
|
||||
self.dst.update(act)
|
||||
dest_act = self.dp.choose_tactic(self.dst.transfer())
|
||||
return self.nlg.change_to_text(dest_act)
|
||||
|
||||
@ -49,7 +38,7 @@ def main():
|
||||
while(1):
|
||||
print('\n')
|
||||
text = input("Wpisz tekst: ")
|
||||
print(janet.test(text))
|
||||
print(janet.process(text))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -8,9 +8,7 @@ class Book_NLU: #Natural Language Understanding
|
||||
|
||||
Wyjście: Akt użytkownika (rama)
|
||||
"""
|
||||
def __init__(self, acts, arguments, book_grammar):
|
||||
self.acts = acts
|
||||
self.arguments = arguments
|
||||
def __init__(self, book_grammar):
|
||||
self.book_grammar = book_grammar
|
||||
|
||||
def get_dialog_act(self, rule):
|
||||
|
@ -6,10 +6,8 @@ class DP:
|
||||
|
||||
Wyjście: Akt systemu (rama)
|
||||
"""
|
||||
def __init__(self, acts, arguments):
|
||||
self.acts = acts
|
||||
self.arguments = arguments
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def choose_tactic(self, frame_list):
|
||||
"""
|
||||
|
@ -1,4 +1,9 @@
|
||||
class DST: #Dialogue State Tracker
|
||||
import json
|
||||
from convlab2.dst.dst import DST
|
||||
from convlab2.dst.rule.multiwoz.dst_util import normalize_value
|
||||
|
||||
|
||||
class Rules_DST(DST): #Dialogue State Tracker
|
||||
"""
|
||||
Moduł odpowiedzialny za śledzenie stanu dialogu. Przechowuje informacje o tym jakie dane zostały uzyskane od użytkownika w toku prowadzonej konwersacji.
|
||||
|
||||
@ -6,22 +11,38 @@ class DST: #Dialogue State Tracker
|
||||
|
||||
Wyjście: Reprezentacja stanu dialogu (rama)
|
||||
"""
|
||||
def __init__(self, acts, arguments):
|
||||
self.acts = acts
|
||||
self.arguments = arguments
|
||||
self.frame_list= []
|
||||
def __init__(self):
|
||||
DST.__init__(self)
|
||||
self.state = json.load(open('default_state.json'))
|
||||
self.value_dict = json.load(open('value_dict.json'))
|
||||
|
||||
def update(self, user_act=None):
|
||||
slots = user_act["slots"]
|
||||
intent = user_act["act"]
|
||||
domain = user_act["act"].split('/')[0]
|
||||
|
||||
def store(self, rama):
|
||||
"""
|
||||
Dodanie nowego aktu do listy
|
||||
"""
|
||||
print("\nDodanie do listy nowej ramy: ")
|
||||
print(rama)
|
||||
self.frame_list.append(rama)
|
||||
if domain in ['password', 'name', 'email', 'enter_email', 'enter_name']:
|
||||
return
|
||||
|
||||
if 'appointment' in intent:
|
||||
for full_slot in slots:
|
||||
slot = full_slot[1]
|
||||
value = full_slot[1]
|
||||
k = self.value_dict[domain.lower()].get(slot, slot)
|
||||
|
||||
def transfer(self):
|
||||
print("Przekazanie dalej listy ram: ")
|
||||
print(self.frame_list)
|
||||
return self.frame_list
|
||||
if k is None:
|
||||
return
|
||||
|
||||
domain_dic = self.state['belief_state'][domain]
|
||||
|
||||
if k in domain_dic['semi']:
|
||||
nvalue = normalize_value(self.value_dict, domain, k, value)
|
||||
self.state['belief_state'][domain]['semi'][k] = nvalue
|
||||
elif k in domain_dic['book']:
|
||||
self.state['belief_state'][domain]['book'][k] = value
|
||||
elif k.lower() in domain_dic['book']:
|
||||
self.state['belief_state'][domain]['book'][k.lower()] = value
|
||||
elif intent == 'end_conversation':
|
||||
self.state = {}
|
||||
|
||||
return self.state
|
@ -5,9 +5,7 @@ from flair.datasets import SentenceDataset
|
||||
from flair.models import SequenceTagger
|
||||
|
||||
class ML_NLU:
|
||||
def __init__(self, acts, arguments):
|
||||
self.acts = acts
|
||||
self.arguments = arguments
|
||||
def __init__(self):
|
||||
self.slot_model, self.frame_model = self.setup()
|
||||
|
||||
def nolabel2o(self, line, i):
|
||||
|
@ -6,10 +6,8 @@ class NLG:
|
||||
|
||||
Wyjście: Tekst
|
||||
"""
|
||||
def __init__(self, acts, arguments):
|
||||
self.acts = acts
|
||||
self.arguments = arguments
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def change_to_text(self, act_vector):
|
||||
"""
|
||||
|
@ -46,6 +46,9 @@ frame_corpus = Corpus(train=conllu2flair(frame_trainset, 'frame'), test=conllu2f
|
||||
slot_tag_dictionary = slot_corpus.make_tag_dictionary(tag_type='slot')
|
||||
frame_tag_dictionary = frame_corpus.make_tag_dictionary(tag_type='frame')
|
||||
|
||||
print(slot_tag_dictionary)
|
||||
print(frame_tag_dictionary)
|
||||
|
||||
|
||||
embedding_types = [
|
||||
WordEmbeddings('pl'),
|
||||
@ -62,7 +65,7 @@ frame_tagger = SequenceTagger(hidden_size=256, embeddings=embeddings,
|
||||
tag_dictionary=frame_tag_dictionary,
|
||||
tag_type='frame', use_crf=True)
|
||||
|
||||
# slot_trainer = ModelTrainer(slot_tagger, slot_corpus)
|
||||
slot_trainer = ModelTrainer(slot_tagger, slot_corpus)
|
||||
# slot_trainer.train('slot-model',
|
||||
# learning_rate=0.1,
|
||||
# mini_batch_size=32,
|
||||
|
@ -828,8 +828,8 @@
|
||||
# intent: request_information/doctors
|
||||
# slots:
|
||||
1 jacy request_information/doctors NoLabel
|
||||
2 lekarze request_information/doctors B-appoinment/doctor
|
||||
3 specjaliści request_information/doctors B-appoinment/doctor
|
||||
2 lekarze request_information/doctors B-appointment/doctor
|
||||
3 specjaliści request_information/doctors B-appointment/doctor
|
||||
4 przyjmują request_information/doctors NoLabel
|
||||
5 w request_information/doctors NoLabel
|
||||
6 państwa request_information/doctors NoLabel
|
||||
@ -840,7 +840,7 @@
|
||||
# slots:
|
||||
1 chciałbym appointment/create_appointment NoLabel
|
||||
2 umówić appointment/create_appointment NoLabel
|
||||
3 wizytę appointment/create_appointment B-appoinment
|
||||
3 wizytę appointment/create_appointment B-appointment
|
||||
4 do appointment/create_appointment NoLabel
|
||||
5 doktora appointment/create_appointment B-appointment/doctor
|
||||
6 kolano appointment/create_appointment I-appointment/doctor
|
||||
@ -891,7 +891,7 @@
|
||||
5 okulisty appointment/create_appointment B-appointment/doctor
|
||||
6 ile request_information/cost NoLabel
|
||||
7 kosztuje request_information/cost NoLabel
|
||||
8 wizyta request_information/cost B-appoinment
|
||||
8 wizyta request_information/cost B-appointment
|
||||
|
||||
# text: Nie ten jest idealny
|
||||
# intent: deny
|
||||
|
17
default_state.json
Normal file
17
default_state.json
Normal file
@ -0,0 +1,17 @@
|
||||
{
|
||||
"user_action": [],
|
||||
"system_action": [],
|
||||
"belief_state": {
|
||||
"appointment": {
|
||||
"prescription": {
|
||||
"url": {},
|
||||
"type": {}
|
||||
},
|
||||
"doctor": {},
|
||||
"where": {}
|
||||
},
|
||||
"request_state": {},
|
||||
"terminated": false,
|
||||
"history": []
|
||||
}
|
||||
}
|
110
value_dict.json
Normal file
110
value_dict.json
Normal file
@ -0,0 +1,110 @@
|
||||
{
|
||||
"appointment": {
|
||||
"hour": [
|
||||
"08:00",
|
||||
"08:15",
|
||||
"08:30",
|
||||
"08:45",
|
||||
"09:00",
|
||||
"09:15",
|
||||
"09:30",
|
||||
"09:45",
|
||||
"10:00",
|
||||
"10:15",
|
||||
"10:30",
|
||||
"10:45",
|
||||
"11:00",
|
||||
"11:15",
|
||||
"11:30",
|
||||
"11:45",
|
||||
"12:00",
|
||||
"12:15",
|
||||
"12:30",
|
||||
"12:45",
|
||||
"13:00",
|
||||
"13:15",
|
||||
"13:30",
|
||||
"13:45",
|
||||
"14:00",
|
||||
"14:15",
|
||||
"14:30",
|
||||
"14:45",
|
||||
"15:00",
|
||||
"15:15",
|
||||
"15:30",
|
||||
"15:45",
|
||||
"16:00",
|
||||
"16:15",
|
||||
"16:30",
|
||||
"16:45",
|
||||
"17:00",
|
||||
"17:15",
|
||||
"17:30",
|
||||
"17:45",
|
||||
"18:00",
|
||||
"18:15",
|
||||
"18:30",
|
||||
"18:45",
|
||||
"19:00",
|
||||
"19:15",
|
||||
"19:30",
|
||||
"19:45"
|
||||
],
|
||||
"date": [
|
||||
"poniedziałek",
|
||||
"wtorek",
|
||||
"środa",
|
||||
"środę",
|
||||
"czwartek",
|
||||
"piątek",
|
||||
"sobota",
|
||||
"sobotę",
|
||||
"niedziela",
|
||||
"niedzielę",
|
||||
"01.05.2021",
|
||||
"02.05.2021",
|
||||
"03.05.2021",
|
||||
"04.05.2021",
|
||||
"05.05.2021",
|
||||
"06.05.2021",
|
||||
"07.05.2021",
|
||||
"08.05.2021",
|
||||
"09.05.2021",
|
||||
"10.05.2021",
|
||||
"11.05.2021",
|
||||
"12.05.2021",
|
||||
"13.05.2021",
|
||||
"14.05.2021",
|
||||
"15.05.2021",
|
||||
"16.05.2021",
|
||||
"17.05.2021",
|
||||
"18.05.2021",
|
||||
"19.05.2021",
|
||||
"20.05.2021",
|
||||
"21.05.2021",
|
||||
"22.05.2021",
|
||||
"23.05.2021",
|
||||
"24.05.2021",
|
||||
"25.05.2021",
|
||||
"26.05.2021",
|
||||
"27.05.2021",
|
||||
"28.05.2021",
|
||||
"29.05.2021",
|
||||
"20.05.2021"
|
||||
],
|
||||
"doctor": [
|
||||
"internista",
|
||||
"okulista",
|
||||
"rodzinny",
|
||||
"ginekolog",
|
||||
"dr. Kolano",
|
||||
"dr. Kowalska"
|
||||
]
|
||||
},
|
||||
"prescritpion": {
|
||||
"type": [
|
||||
"recepta",
|
||||
"e-recepta"
|
||||
]
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user