diff --git a/requirements.txt b/requirements.txt index 75af938..a4fc1ff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,5 @@ flair==0.13.1 conllu==4.5.3 pandas==1.5.3 numpy==1.26.4 -torch==2.3.0 +torch==1.13 convlab==3.0.2a0 \ No newline at end of file diff --git a/src/service/convlab/policy/policy.py b/src/service/convlab/policy/policy.py deleted file mode 100644 index 5ac85a8..0000000 --- a/src/service/convlab/policy/policy.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Policy Interface""" -from convlab.util.module import Module - - -class Policy(Module): - """Policy module interface.""" - - def predict(self, state): - """Predict the next agent action given dialog state. - - Args: - state (dict or list of list): - when the policy takes dialogue state as input, the type is dict. - else when the policy takes dialogue act as input, the type is list of list. - Returns: - action (list of list or str): - when the policy outputs dialogue act, the type is list of list. - else when the policy outputs utterance directly, the type is str. 
- """ - return [] - - def update_memory(self, utterance_list, state_list, action_list, reward_list): - pass - diff --git a/src/service/convlab/util/module.py b/src/service/convlab/util/module.py deleted file mode 100644 index 9d280ce..0000000 --- a/src/service/convlab/util/module.py +++ /dev/null @@ -1,25 +0,0 @@ -"""module interface.""" -from abc import ABC - - -class Module(ABC): - - def train(self, *args, **kwargs): - """Model training entry point""" - pass - - def test(self, *args, **kwargs): - """Model testing entry point""" - pass - - def from_cache(self, *args, **kwargs): - """restore internal state for multi-turn dialog""" - return None - - def to_cache(self, *args, **kwargs): - """save internal state for multi-turn dialog""" - return None - - def init_session(self): - """Init the class variables for a new session.""" - pass diff --git a/src/service/convlab/util/restaurant/dbquery.py b/src/service/convlab/util/restaurant/dbquery.py deleted file mode 100644 index 735874d..0000000 --- a/src/service/convlab/util/restaurant/dbquery.py +++ /dev/null @@ -1,38 +0,0 @@ -""" -""" -import json -import os -import random -from fuzzywuzzy import fuzz -from itertools import chain -from copy import deepcopy - - -class Database(object): - def __init__(self): - super(Database, self).__init__() - # loading databases - domains = ['menu', 'pizza', 'drink', 'size'] - self.dbs = {} - for domain in domains: - with open(os.path.join(os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))), - 'data/restaurant/db/{}_db.json'.format(domain))) as f: - self.dbs[domain] = json.load(f) - - def query(self, domain): - """Returns the list of entities for a given domain - based on the annotation of the belief state""" - # query the db - if domain == 'pizza': - return [{'Name': random.choice(self.dbs[domain]['name'])}] - if domain == 'menu': - return deepcopy(self.dbs[domain]) - if domain == 'drink': - return [{'Name': 
class Database(object):
    """In-memory lookup over the per-domain JSON databases.

    Each domain is loaded from ``<db_dir>/<domain>_db.json`` into
    ``self.dbs[domain]`` and can then be filtered with :meth:`query`.
    """

    # Domains whose ``<domain>_db.json`` file is loaded at construction time.
    _DOMAINS = ('restaurant', 'hotel', 'attraction', 'train', 'hospital', 'taxi', 'police')

    # Constraint values that mean "no preference"; such constraints are skipped.
    # Kept as a tuple (not a set) so membership uses ``==`` and never hashes ``val``.
    _WILDCARD_VALUES = ("", "dont care", 'not mentioned', "don't care", "dontcare", "do n't care")

    def __init__(self, db_dir=None):
        """Load every domain database into ``self.dbs``.

        Args:
            db_dir: Directory containing the ``<domain>_db.json`` files. When
                omitted, the historical location is used: ``data/restaurant/db``
                under the directory four levels above this file.

        Raises:
            OSError: If a database file cannot be opened.
            json.JSONDecodeError: If a database file is not valid JSON.
        """
        super(Database, self).__init__()
        if db_dir is None:
            # NOTE(review): the four-level climb was inherited from the module's
            # previous, deeper location -- confirm it still lands on the
            # directory that actually holds the data, or pass ``db_dir``.
            root = os.path.abspath(__file__)
            for _ in range(4):
                root = os.path.dirname(root)
            db_dir = os.path.join(root, 'data/restaurant/db')

        self.dbs = {}
        for domain in self._DOMAINS:
            db_path = os.path.join(db_dir, '{}_db.json'.format(domain))
            # JSON is UTF-8 by specification; do not depend on the locale default.
            with open(db_path, encoding='utf-8') as f:
                self.dbs[domain] = json.load(f)

    @staticmethod
    def _hhmm(value):
        """Turn an ``'HH:MM'`` string into the comparable integer ``HH * 100 + MM``."""
        parts = value.split(':')
        return int(parts[0]) * 100 + int(parts[1])

    def _matches(self, record, checks, ignore_open, fuzzy_match_ratio):
        """Return True when ``record`` violates none of ``checks``.

        ``checks`` is a sequence of ``((key, val), fuzzy_match)`` pairs. A check
        that cannot be evaluated (missing key, non-string value, malformed
        time, ...) is skipped rather than treated as a mismatch.
        """
        record_keys = None
        for (key, val), fuzzy_match in checks:
            if val in self._WILDCARD_VALUES:
                continue
            try:
                if record_keys is None:
                    # Computed lazily and once per record.
                    record_keys = {k.lower() for k in record}
                if key.lower() not in record_keys:
                    # The record has no such slot: the constraint does not apply.
                    continue
                if key == 'leaveAt':
                    # Reject entries that leave before the requested time.
                    if self._hhmm(val) > self._hhmm(record['leaveAt']):
                        return False
                elif key == 'arriveBy':
                    # Reject entries that arrive after the requested time.
                    if self._hhmm(val) < self._hhmm(record['arriveBy']):
                        return False
                elif ignore_open and key in ('destination', 'departure'):
                    continue
                elif record[key].strip() == '?':
                    # '?' matches any constraint
                    continue
                elif not fuzzy_match:
                    if val.strip().lower() != record[key].strip().lower():
                        return False
                elif fuzz.partial_ratio(val.strip().lower(), record[key].strip().lower()) < fuzzy_match_ratio:
                    return False
            except (AttributeError, IndexError, KeyError, TypeError, ValueError):
                # Best effort: an unevaluable constraint is ignored. Only the
                # data-shape errors are caught so that KeyboardInterrupt,
                # SystemExit and genuine programming errors still propagate.
                continue
        return True

    def query(self, domain, constraints, ignore_open=False, soft_contraints=(), fuzzy_match_ratio=60):
        """Return the entities of ``domain`` that satisfy the given constraints.

        Args:
            domain: Name of the database to search (a key of ``self.dbs``).
            constraints: Iterable of ``(slot, value)`` pairs compared exactly
                (case-insensitively, after stripping whitespace).
            ignore_open: When true, ``destination`` and ``departure``
                constraints are not enforced.
            soft_contraints: Iterable of ``(slot, value)`` pairs compared with
                ``fuzz.partial_ratio`` instead of exact equality.
            fuzzy_match_ratio: Minimum ``partial_ratio`` score for a soft
                constraint to count as satisfied.

        Returns:
            A list of deep-copied records. For ``taxi`` a single randomly
            generated entry is returned; for ``police`` the whole table; for
            ``hospital`` the whole table or the rows of the requested
            ``department``. For every other domain each matching record gets
            an extra ``'Ref'`` key holding its zero-padded index in the table.
        """
        if domain == 'taxi':
            taxi_db = self.dbs[domain]
            return [{'taxi_colors': random.choice(taxi_db['taxi_colors']),
                     'taxi_types': random.choice(taxi_db['taxi_types']),
                     'taxi_phone': ''.join(str(random.randint(1, 9)) for _ in range(11))}]
        if domain == 'police':
            return deepcopy(self.dbs['police'])
        if domain == 'hospital':
            department = None
            for key, val in constraints:
                if key == 'department':
                    department = val
            if not department:
                return deepcopy(self.dbs['hospital'])
            wanted = department.strip().lower()
            return [deepcopy(x) for x in self.dbs['hospital'] if x['department'].lower() == wanted]

        # The databases spell the area "centre"; accept the "center" spelling too.
        constraints = [
            ('area', 'centre') if (ele[0] == 'area' and ele[1] == 'center') else ele
            for ele in constraints
        ]
        # Pair every constraint with its "fuzzy" flag once, not once per record.
        checks = [(c, False) for c in constraints] + [(c, True) for c in soft_contraints]

        found = []
        for i, record in enumerate(self.dbs[domain]):
            if self._matches(record, checks, ignore_open, fuzzy_match_ratio):
                res = deepcopy(record)
                res['Ref'] = '{0:08d}'.format(i)
                found.append(res)

        return found
domain in ["size"]: - system_action[(domain, 'Recommend')].append(['Size', choice['size']]) \ No newline at end of file + +dialogPolicy = SimpleRulePolicy() \ No newline at end of file