Changes for dialog_policy
This commit is contained in:
parent
f4d9bff809
commit
f7ee0ba059
@ -4,3 +4,4 @@ pandas==1.5.3
|
|||||||
numpy==1.26.4
|
numpy==1.26.4
|
||||||
torch==1.13
|
torch==1.13
|
||||||
convlab==3.0.2a0
|
convlab==3.0.2a0
|
||||||
|
scipy==1.12
|
20
src/main.py
20
src/main.py
@ -1,16 +1,17 @@
|
|||||||
from service.dialog_state_monitor import DialogStateMonitor
|
from service.dialog_state_monitor import DialogStateMonitor
|
||||||
from service.dialog_policy import DialogPolicy
|
from service.dialog_policy import DialogPolicy
|
||||||
from service.natural_languag_understanding import NaturalLanguageUnderstanding
|
from service.natural_languag_understanding import NaturalLanguageUnderstanding
|
||||||
from service.natural_language_generation import NaturalLanguageGeneration
|
from service.natural_language_generation import NaturalLanguageGeneration, parse_frame
|
||||||
from service.templates import templates
|
from service.templates import templates
|
||||||
|
from convlab.dialog_agent import PipelineAgent
|
||||||
|
|
||||||
# initialize classes
|
# initialize classes
|
||||||
|
nlu = NaturalLanguageUnderstanding() # NLU
|
||||||
nlu = NaturalLanguageUnderstanding(use_mocks=False) # NLU
|
|
||||||
monitor = DialogStateMonitor() # DSM
|
monitor = DialogStateMonitor() # DSM
|
||||||
dialog_policy = DialogPolicy() # DP
|
dialog_policy = DialogPolicy() # DP
|
||||||
language_generation = NaturalLanguageGeneration(templates) # NLG
|
language_generation = NaturalLanguageGeneration(templates) # NLG
|
||||||
|
|
||||||
|
|
||||||
# Main loop
|
# Main loop
|
||||||
dial_num = 0
|
dial_num = 0
|
||||||
print("CTRL+C aby zakończyć program.")
|
print("CTRL+C aby zakończyć program.")
|
||||||
@ -23,25 +24,20 @@ while True:
|
|||||||
while True:
|
while True:
|
||||||
# NLU
|
# NLU
|
||||||
frame = nlu.predict(user_input)
|
frame = nlu.predict(user_input)
|
||||||
# print(frame)
|
print("Frame: ", frame)
|
||||||
|
|
||||||
# DSM
|
# DSM
|
||||||
monitor.update(frame)
|
monitor.update(frame)
|
||||||
|
|
||||||
# DP
|
# DP
|
||||||
system_action = dialog_policy.predict(monitor)
|
system_action = dialog_policy.predict(monitor)
|
||||||
|
print("System action: ", system_action)
|
||||||
# NLG
|
# NLG
|
||||||
response = language_generation.generate(frame)
|
act, slots = parse_frame(frame)
|
||||||
|
response = language_generation.generate(act, slots)
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
if frame.act == "bye":
|
if frame.act == "bye":
|
||||||
break
|
break
|
||||||
|
|
||||||
user_input = input(">\n")
|
user_input = input(">\n")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,27 +7,27 @@ class DialogPolicy:
|
|||||||
def predict(self, dsm):
|
def predict(self, dsm):
|
||||||
system_action = defaultdict(list)
|
system_action = defaultdict(list)
|
||||||
last_frame = dsm.state['history'][-1]
|
last_frame = dsm.state['history'][-1]
|
||||||
if(dsm.was_previous_order_invalid==False):
|
if(dsm.state['was_previous_order_invalid']==False):
|
||||||
match(last_frame.act.strip('/')):
|
match(last_frame.act.strip('/')):
|
||||||
case "inform" | "affirm":
|
case "inform" | "affirm":
|
||||||
current_active_status = dsm.get_current_active_stage()
|
current_active_status = dsm.get_current_active_stage()
|
||||||
match(current_active_status):
|
match(current_active_status):
|
||||||
case "collect_food":
|
case "collect_food":
|
||||||
system_action["inform"].append([Slot("menu", dsm.state['constants']['menu'])])
|
system_action["inform"].append(["menu", dsm.state['constants']['menu']])
|
||||||
case "collect_drinks":
|
case "collect_drinks":
|
||||||
system_action["inform"].append([Slot("drink"), dsm.state['constants']['drink']])
|
system_action["inform"].append(["drink", dsm.state['constants']['drink']])
|
||||||
case "collect_address":
|
case "collect_address":
|
||||||
system_action["request"].append([Slot("address")])
|
system_action["request"].append(["address"])
|
||||||
if(current_active_status == None):
|
if(current_active_status == None):
|
||||||
system_action["inform"].append([Slot("order"), dsm.belief_state])
|
system_action["inform"].append(["order", dsm.belief_state])
|
||||||
case "request":
|
case "request":
|
||||||
for slot in last_frame.slots:
|
for slot in last_frame.slots:
|
||||||
system_action["inform"].append([Slot(slot), dsm.state['constants'][slot]])
|
system_action["inform"].append([slot, dsm.state['constants'][slot]])
|
||||||
case "welcomemsg":
|
case "welcomemsg":
|
||||||
system_action["inform"].append([Slot("menu"), dsm.state['constants']['menu']])
|
system_action["inform"].append(["menu", dsm.state['constants']['menu']])
|
||||||
case "bye":
|
case "bye":
|
||||||
system_action["bye"].append([], None)
|
system_action["bye"].append([[], None])
|
||||||
case "negate":
|
case "negate":
|
||||||
system_action["affirm"].append([], None)
|
system_action["affirm"].append([[], None])
|
||||||
|
|
||||||
return system_action
|
return system_action
|
Loading…
Reference in New Issue
Block a user