fix
This commit is contained in:
parent
0da7ca614d
commit
5dc7e17ae4
@ -2,10 +2,17 @@ from collections import defaultdict
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from convlab2.policy.policy import Policy
|
from convlab2.policy.policy import Policy
|
||||||
from convlab2.util.multiwoz.dbquery import Database
|
from convlab2.util.multiwoz.dbquery import Database
|
||||||
from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
|
# from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
|
||||||
from convlab2.dialog_agent import PipelineAgent
|
from convlab2.dialog_agent import PipelineAgent
|
||||||
from DST import DST
|
from DST import DST
|
||||||
|
|
||||||
|
REF_SYS_DA = {
|
||||||
|
'Cinema': {
|
||||||
|
'Type': 'type','Price': 'price','Stars': 'stars',
|
||||||
|
'Name': 'name','Day': 'day','People': 'people','Movie': 'movie',
|
||||||
|
'E-mail': 'e-mail', 'none': None
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
# Taktyka prowadzenia dialogu
|
# Taktyka prowadzenia dialogu
|
||||||
class DP(Policy):
|
class DP(Policy):
|
||||||
@ -57,7 +64,7 @@ class DP(Policy):
|
|||||||
system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])
|
system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])
|
||||||
choice = self.results[0]
|
choice = self.results[0]
|
||||||
|
|
||||||
if domain in ["Hotel", "Attraction", "Police", "Restaurant"]:
|
if domain in ["Cinema", "Hotel", "Attraction", "Police", "Restaurant"]:
|
||||||
system_action[(domain, 'Recommend')].append(['Name', choice['name']])
|
system_action[(domain, 'Recommend')].append(['Name', choice['name']])
|
||||||
|
|
||||||
|
|
||||||
@ -66,5 +73,5 @@ class DP(Policy):
|
|||||||
dst = DST()
|
dst = DST()
|
||||||
dp = DP()
|
dp = DP()
|
||||||
agent = PipelineAgent(nlu=None, dst=dst, policy=dp, nlg=None, name='sys')
|
agent = PipelineAgent(nlu=None, dst=dst, policy=dp, nlg=None, name='sys')
|
||||||
print(agent.response([['Inform', 'Hotel', 'Price', 'cheap'], ['Inform', 'Hotel', 'Parking', 'yes']]))
|
print(agent.response([['Inform', 'Cinema', 'Price', '15 zł'], ['Inform', 'Cinema', 'Movie', 'Batman']]))
|
||||||
"""
|
"""
|
||||||
|
@ -2,8 +2,15 @@ from dialogue_state import default_state
|
|||||||
import json
|
import json
|
||||||
from convlab2.dst.dst import DST as CL2DST
|
from convlab2.dst.dst import DST as CL2DST
|
||||||
from convlab2.dst.rule.multiwoz.dst_util import normalize_value
|
from convlab2.dst.rule.multiwoz.dst_util import normalize_value
|
||||||
from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
|
# from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
|
||||||
|
|
||||||
|
REF_SYS_DA = {
|
||||||
|
'Cinema': {
|
||||||
|
'Type': 'type','Price': 'price','Stars': 'stars',
|
||||||
|
'Name': 'name','Day': 'day','People': 'people','Movie': 'movie',
|
||||||
|
'E-mail': 'e-mail', 'none': None
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
# Monitor stanu dialogu
|
# Monitor stanu dialogu
|
||||||
class DST(CL2DST):
|
class DST(CL2DST):
|
||||||
@ -54,6 +61,6 @@ class DST(CL2DST):
|
|||||||
dst = DST()
|
dst = DST()
|
||||||
print(dst.state)
|
print(dst.state)
|
||||||
|
|
||||||
dst.update([['Inform', 'Hotel', 'Price', 'cheap'], ['Inform', 'Hotel', 'Parking', 'yes']])
|
dst.update([['Inform', 'Cinema', 'Price', '15 zł'], ['Inform', 'Cinema', 'Movie', 'Batman']])
|
||||||
print(dst.state['belief_state']['hotel'])
|
print(dst.state['belief_state']['cinema'])
|
||||||
"""
|
"""
|
||||||
|
0
DST_DP_lab_9-10/dialogue_state.py
Executable file → Normal file
0
DST_DP_lab_9-10/dialogue_state.py
Executable file → Normal file
0
DST_DP_lab_9-10/value_dict.json
Executable file → Normal file
0
DST_DP_lab_9-10/value_dict.json
Executable file → Normal file
Loading…
Reference in New Issue
Block a user