Changes for dialog_policy

This commit is contained in:
s495728 2024-06-10 17:48:55 +02:00
parent f4d9bff809
commit f7ee0ba059
3 changed files with 20 additions and 23 deletions

View File

@ -3,4 +3,5 @@ conllu==4.5.3
pandas==1.5.3
numpy==1.26.4
torch==1.13
convlab==3.0.2a0
convlab==3.0.2a0
scipy==1.12

View File

@ -1,16 +1,17 @@
from service.dialog_state_monitor import DialogStateMonitor
from service.dialog_policy import DialogPolicy
from service.natural_languag_understanding import NaturalLanguageUnderstanding
from service.natural_language_generation import NaturalLanguageGeneration
from service.natural_language_generation import NaturalLanguageGeneration, parse_frame
from service.templates import templates
from convlab.dialog_agent import PipelineAgent
# initialize classes
nlu = NaturalLanguageUnderstanding(use_mocks=False) # NLU
nlu = NaturalLanguageUnderstanding() # NLU
monitor = DialogStateMonitor() # DSM
dialog_policy = DialogPolicy() # DP
language_generation = NaturalLanguageGeneration(templates) # NLG
# Main loop
dial_num = 0
print("CTRL+C aby zakończyć program.")
@ -23,25 +24,20 @@ while True:
while True:
# NLU
frame = nlu.predict(user_input)
# print(frame)
print("Frame: ", frame)
# DSM
monitor.update(frame)
# DP
system_action = dialog_policy.predict(monitor)
print("System action: ", system_action)
# NLG
response = language_generation.generate(frame)
act, slots = parse_frame(frame)
response = language_generation.generate(act, slots)
print(response)
if frame.act == "bye":
break
user_input = input(">\n")
user_input = input(">\n")

View File

@ -7,27 +7,27 @@ class DialogPolicy:
def predict(self, dsm):
system_action = defaultdict(list)
last_frame = dsm.state['history'][-1]
if(dsm.was_previous_order_invalid==False):
if(dsm.state['was_previous_order_invalid']==False):
match(last_frame.act.strip('/')):
case "inform" | "affirm":
current_active_status = dsm.get_current_active_stage()
match(current_active_status):
case "collect_food":
system_action["inform"].append([Slot("menu", dsm.state['constants']['menu'])])
system_action["inform"].append(["menu", dsm.state['constants']['menu']])
case "collect_drinks":
system_action["inform"].append([Slot("drink"), dsm.state['constants']['drink']])
system_action["inform"].append(["drink", dsm.state['constants']['drink']])
case "collect_address":
system_action["request"].append([Slot("address")])
system_action["request"].append(["address"])
if(current_active_status == None):
system_action["inform"].append([Slot("order"), dsm.belief_state])
system_action["inform"].append(["order", dsm.belief_state])
case "request":
for slot in last_frame.slots:
system_action["inform"].append([Slot(slot), dsm.state['constants'][slot]])
system_action["inform"].append([slot, dsm.state['constants'][slot]])
case "welcomemsg":
system_action["inform"].append([Slot("menu"), dsm.state['constants']['menu']])
system_action["inform"].append(["menu", dsm.state['constants']['menu']])
case "bye":
system_action["bye"].append([], None)
system_action["bye"].append([[], None])
case "negate":
system_action["affirm"].append([], None)
system_action["affirm"].append([[], None])
return system_action