From f7ee0ba0590839eef11f408983e6edd21cc5e10a Mon Sep 17 00:00:00 2001 From: s495728 Date: Mon, 10 Jun 2024 17:48:55 +0200 Subject: [PATCH] Changes for dialog_policy --- requirements.txt | 3 ++- src/main.py | 22 +++++++++------------- src/service/dialog_policy.py | 18 +++++++++--------- 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/requirements.txt b/requirements.txt index a4fc1ff..0450c2f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ conllu==4.5.3 pandas==1.5.3 numpy==1.26.4 torch==1.13 -convlab==3.0.2a0 \ No newline at end of file +convlab==3.0.2a0 +scipy==1.12 \ No newline at end of file diff --git a/src/main.py b/src/main.py index d8dd502..aafeefb 100644 --- a/src/main.py +++ b/src/main.py @@ -1,16 +1,17 @@ from service.dialog_state_monitor import DialogStateMonitor from service.dialog_policy import DialogPolicy from service.natural_languag_understanding import NaturalLanguageUnderstanding -from service.natural_language_generation import NaturalLanguageGeneration +from service.natural_language_generation import NaturalLanguageGeneration, parse_frame from service.templates import templates +from convlab.dialog_agent import PipelineAgent # initialize classes - -nlu = NaturalLanguageUnderstanding(use_mocks=False) # NLU +nlu = NaturalLanguageUnderstanding() # NLU monitor = DialogStateMonitor() # DSM dialog_policy = DialogPolicy() # DP language_generation = NaturalLanguageGeneration(templates) # NLG + # Main loop dial_num = 0 print("CTRL+C aby zakończyć program.") @@ -23,25 +24,20 @@ while True: while True: # NLU frame = nlu.predict(user_input) - # print(frame) + print("Frame: ", frame) # DSM monitor.update(frame) # DP system_action = dialog_policy.predict(monitor) - + print("System action: ", system_action) # NLG - response = language_generation.generate(frame) + act, slots = parse_frame(frame) + response = language_generation.generate(act, slots) print(response) if frame.act == "bye": break - user_input = input(">\n") - - - - - - + user_input = input(">\n") \ No newline at end of file diff --git a/src/service/dialog_policy.py b/src/service/dialog_policy.py index 18597b1..cb8a1de 100644 --- a/src/service/dialog_policy.py +++ b/src/service/dialog_policy.py @@ -7,27 +7,27 @@ class DialogPolicy: def predict(self, dsm): system_action = defaultdict(list) last_frame = dsm.state['history'][-1] - if(dsm.was_previous_order_invalid==False): + if(dsm.state['was_previous_order_invalid']==False): match(last_frame.act.strip('/')): case "inform" | "affirm": current_active_status = dsm.get_current_active_stage() match(current_active_status): case "collect_food": - system_action["inform"].append([Slot("menu", dsm.state['constants']['menu'])]) + system_action["inform"].append(["menu", dsm.state['constants']['menu']]) case "collect_drinks": - system_action["inform"].append([Slot("drink"), dsm.state['constants']['drink']]) + system_action["inform"].append(["drink", dsm.state['constants']['drink']]) case "collect_address": - system_action["request"].append([Slot("address")]) + system_action["request"].append(["address"]) if(current_active_status == None): - system_action["inform"].append([Slot("order"), dsm.belief_state]) + system_action["inform"].append(["order", dsm.belief_state]) case "request": for slot in last_frame.slots: - system_action["inform"].append([Slot(slot), dsm.state['constants'][slot]]) + system_action["inform"].append([slot, dsm.state['constants'][slot]]) case "welcomemsg": - system_action["inform"].append([Slot("menu"), dsm.state['constants']['menu']]) + system_action["inform"].append(["menu", dsm.state['constants']['menu']]) case "bye": - system_action["bye"].append([], None) + system_action["bye"].append([[], None]) case "negate": - system_action["affirm"].append([], None) + system_action["affirm"].append([[], None]) return system_action \ No newline at end of file