Changes for dialog_policy

This commit is contained in:
s495728 2024-06-10 17:48:55 +02:00
parent f4d9bff809
commit f7ee0ba059
3 changed files with 20 additions and 23 deletions

View File

@ -4,3 +4,4 @@ pandas==1.5.3
numpy==1.26.4 numpy==1.26.4
torch==1.13 torch==1.13
convlab==3.0.2a0 convlab==3.0.2a0
scipy==1.12

View File

@ -1,16 +1,17 @@
from service.dialog_state_monitor import DialogStateMonitor from service.dialog_state_monitor import DialogStateMonitor
from service.dialog_policy import DialogPolicy from service.dialog_policy import DialogPolicy
from service.natural_languag_understanding import NaturalLanguageUnderstanding from service.natural_languag_understanding import NaturalLanguageUnderstanding
from service.natural_language_generation import NaturalLanguageGeneration from service.natural_language_generation import NaturalLanguageGeneration, parse_frame
from service.templates import templates from service.templates import templates
from convlab.dialog_agent import PipelineAgent
# initialize classes # initialize classes
nlu = NaturalLanguageUnderstanding() # NLU
nlu = NaturalLanguageUnderstanding(use_mocks=False) # NLU
monitor = DialogStateMonitor() # DSM monitor = DialogStateMonitor() # DSM
dialog_policy = DialogPolicy() # DP dialog_policy = DialogPolicy() # DP
language_generation = NaturalLanguageGeneration(templates) # NLG language_generation = NaturalLanguageGeneration(templates) # NLG
# Main loop # Main loop
dial_num = 0 dial_num = 0
print("CTRL+C aby zakończyć program.") print("CTRL+C aby zakończyć program.")
@ -23,25 +24,20 @@ while True:
while True: while True:
# NLU # NLU
frame = nlu.predict(user_input) frame = nlu.predict(user_input)
# print(frame) print("Frame: ", frame)
# DSM # DSM
monitor.update(frame) monitor.update(frame)
# DP # DP
system_action = dialog_policy.predict(monitor) system_action = dialog_policy.predict(monitor)
print("System action: ", system_action)
# NLG # NLG
response = language_generation.generate(frame) act, slots = parse_frame(frame)
response = language_generation.generate(act, slots)
print(response) print(response)
if frame.act == "bye": if frame.act == "bye":
break break
user_input = input(">\n") user_input = input(">\n")

View File

@ -7,27 +7,27 @@ class DialogPolicy:
def predict(self, dsm): def predict(self, dsm):
system_action = defaultdict(list) system_action = defaultdict(list)
last_frame = dsm.state['history'][-1] last_frame = dsm.state['history'][-1]
if(dsm.was_previous_order_invalid==False): if(dsm.state['was_previous_order_invalid']==False):
match(last_frame.act.strip('/')): match(last_frame.act.strip('/')):
case "inform" | "affirm": case "inform" | "affirm":
current_active_status = dsm.get_current_active_stage() current_active_status = dsm.get_current_active_stage()
match(current_active_status): match(current_active_status):
case "collect_food": case "collect_food":
system_action["inform"].append([Slot("menu", dsm.state['constants']['menu'])]) system_action["inform"].append(["menu", dsm.state['constants']['menu']])
case "collect_drinks": case "collect_drinks":
system_action["inform"].append([Slot("drink"), dsm.state['constants']['drink']]) system_action["inform"].append(["drink", dsm.state['constants']['drink']])
case "collect_address": case "collect_address":
system_action["request"].append([Slot("address")]) system_action["request"].append(["address"])
if(current_active_status == None): if(current_active_status == None):
system_action["inform"].append([Slot("order"), dsm.belief_state]) system_action["inform"].append(["order", dsm.belief_state])
case "request": case "request":
for slot in last_frame.slots: for slot in last_frame.slots:
system_action["inform"].append([Slot(slot), dsm.state['constants'][slot]]) system_action["inform"].append([slot, dsm.state['constants'][slot]])
case "welcomemsg": case "welcomemsg":
system_action["inform"].append([Slot("menu"), dsm.state['constants']['menu']]) system_action["inform"].append(["menu", dsm.state['constants']['menu']])
case "bye": case "bye":
system_action["bye"].append([], None) system_action["bye"].append([[], None])
case "negate": case "negate":
system_action["affirm"].append([], None) system_action["affirm"].append([[], None])
return system_action return system_action