This commit is contained in:
Julian Zabłoński 2022-05-25 10:44:21 +02:00
parent 0da7ca614d
commit 5dc7e17ae4
4 changed files with 20 additions and 6 deletions

View File

@ -2,10 +2,17 @@ from collections import defaultdict
from copy import deepcopy from copy import deepcopy
from convlab2.policy.policy import Policy from convlab2.policy.policy import Policy
from convlab2.util.multiwoz.dbquery import Database from convlab2.util.multiwoz.dbquery import Database
from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA # from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
from convlab2.dialog_agent import PipelineAgent from convlab2.dialog_agent import PipelineAgent
from DST import DST from DST import DST
REF_SYS_DA = {
'Cinema': {
'Type': 'type','Price': 'price','Stars': 'stars',
'Name': 'name','Day': 'day','People': 'people','Movie': 'movie',
'E-mail': 'e-mail', 'none': None
},
}
# Taktyka prowadzenia dialogu # Taktyka prowadzenia dialogu
class DP(Policy): class DP(Policy):
@ -57,7 +64,7 @@ class DP(Policy):
system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))]) system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])
choice = self.results[0] choice = self.results[0]
if domain in ["Hotel", "Attraction", "Police", "Restaurant"]: if domain in ["Cinema", "Hotel", "Attraction", "Police", "Restaurant"]:
system_action[(domain, 'Recommend')].append(['Name', choice['name']]) system_action[(domain, 'Recommend')].append(['Name', choice['name']])
@ -66,5 +73,5 @@ class DP(Policy):
dst = DST() dst = DST()
dp = DP() dp = DP()
agent = PipelineAgent(nlu=None, dst=dst, policy=dp, nlg=None, name='sys') agent = PipelineAgent(nlu=None, dst=dst, policy=dp, nlg=None, name='sys')
print(agent.response([['Inform', 'Hotel', 'Price', 'cheap'], ['Inform', 'Hotel', 'Parking', 'yes']])) print(agent.response([['Inform', 'Cinema', 'Price', '15 zł'], ['Inform', 'Cinema', 'Movie', 'Batman']]))
""" """

View File

@ -2,8 +2,15 @@ from dialogue_state import default_state
import json import json
from convlab2.dst.dst import DST as CL2DST from convlab2.dst.dst import DST as CL2DST
from convlab2.dst.rule.multiwoz.dst_util import normalize_value from convlab2.dst.rule.multiwoz.dst_util import normalize_value
from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA # from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
REF_SYS_DA = {
'Cinema': {
'Type': 'type','Price': 'price','Stars': 'stars',
'Name': 'name','Day': 'day','People': 'people','Movie': 'movie',
'E-mail': 'e-mail', 'none': None
},
}
# Monitor stanu dialogu # Monitor stanu dialogu
class DST(CL2DST): class DST(CL2DST):
@ -54,6 +61,6 @@ class DST(CL2DST):
dst = DST() dst = DST()
print(dst.state) print(dst.state)
dst.update([['Inform', 'Hotel', 'Price', 'cheap'], ['Inform', 'Hotel', 'Parking', 'yes']]) dst.update([['Inform', 'Cinema', 'Price', '15 zł'], ['Inform', 'Cinema', 'Movie', 'Batman']])
print(dst.state['belief_state']['hotel']) print(dst.state['belief_state']['cinema'])
""" """

0
DST_DP_lab_9-10/dialogue_state.py Executable file → Normal file
View File

0
DST_DP_lab_9-10/value_dict.json Executable file → Normal file
View File