diff --git a/src/service/convlab/policy/policy.py b/src/service/convlab/policy/policy.py deleted file mode 100644 index 5ac85a8..0000000 --- a/src/service/convlab/policy/policy.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Policy Interface""" -from convlab.util.module import Module - - -class Policy(Module): - """Policy module interface.""" - - def predict(self, state): - """Predict the next agent action given dialog state. - - Args: - state (dict or list of list): - when the policy takes dialogue state as input, the type is dict. - else when the policy takes dialogue act as input, the type is list of list. - Returns: - action (list of list or str): - when the policy outputs dialogue act, the type is list of list. - else when the policy outputs utterance directly, the type is str. - """ - return [] - - def update_memory(self, utterance_list, state_list, action_list, reward_list): - pass - diff --git a/src/service/convlab/util/module.py b/src/service/convlab/util/module.py deleted file mode 100644 index 9d280ce..0000000 --- a/src/service/convlab/util/module.py +++ /dev/null @@ -1,25 +0,0 @@ -"""module interface.""" -from abc import ABC - - -class Module(ABC): - - def train(self, *args, **kwargs): - """Model training entry point""" - pass - - def test(self, *args, **kwargs): - """Model testing entry point""" - pass - - def from_cache(self, *args, **kwargs): - """restore internal state for multi-turn dialog""" - return None - - def to_cache(self, *args, **kwargs): - """save internal state for multi-turn dialog""" - return None - - def init_session(self): - """Init the class variables for a new session.""" - pass diff --git a/src/service/convlab/util/restaurant/dbquery.py b/src/service/convlab/util/restaurant/dbquery.py deleted file mode 100644 index 735874d..0000000 --- a/src/service/convlab/util/restaurant/dbquery.py +++ /dev/null @@ -1,38 +0,0 @@ -""" -""" -import json -import os -import random -from fuzzywuzzy import fuzz -from itertools import chain -from copy import deepcopy - - -class Database(object): - def __init__(self): - super(Database, self).__init__() - # loading databases - domains = ['menu', 'pizza', 'drink', 'size'] - self.dbs = {} - for domain in domains: - with open(os.path.join(os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))), - 'data/restaurant/db/{}_db.json'.format(domain))) as f: - self.dbs[domain] = json.load(f) - - def query(self, domain): - """Returns the list of entities for a given domain - based on the annotation of the belief state""" - # query the db - if domain == 'pizza': - return [{'Name': random.choice(self.dbs[domain]['name'])}] - if domain == 'menu': - return deepcopy(self.dbs[domain]) - if domain == 'drink': - return [{'Name': random.choice(self.dbs[domain]['name'])}] - if domain == 'size': - return [{'Size': random.choice(self.dbs[domain]['size'])}] - - -if __name__ == '__main__': - db = Database() \ No newline at end of file diff --git a/src/service/dialog_policy.py b/src/service/dialog_policy.py index cafa615..12eaacc 100644 --- a/src/service/dialog_policy.py +++ b/src/service/dialog_policy.py @@ -56,4 +56,6 @@ class SimpleRulePolicy(Policy): if domain in ["pizza", "drink"]: system_action[(domain, 'Recommend')].append(['Name', choice['name']]) if domain in ["size"]: - system_action[(domain, 'Recommend')].append(['Size', choice['size']]) \ No newline at end of file + system_action[(domain, 'Recommend')].append(['Size', choice['size']]) + +dialogPolicy = SimpleRulePolicy() \ No newline at end of file