diff --git a/DBManager.py b/DBManager.py index 79e42bf..3efa59f 100644 --- a/DBManager.py +++ b/DBManager.py @@ -43,6 +43,10 @@ class calender_db: with open(self.db_file_path, 'r+') as f: self.db = json.load(f) + def clear_db(self): + self.db = {} + self.save_db() + def save_db(self): with open(self.db_file_path, 'w+') as f: json.dump(self.db, f) @@ -72,19 +76,10 @@ class calender_db: self.db[format_date(date_time)] = meetings self.save_db() - def update_meeting(self, meeting_dict): - date_time = get_date( - meeting_dict['date'].lower(), meeting_dict['time'].lower()) - if format_date(date_time) in self.db.keys(): - meetings = self.db[format_date(date_time)] - for key, meeting in enumerate(meetings): - if format_time(meeting['time']) == format_time(meeting_dict['time'].lower()): - meetings.remove(meeting) - - meetings.append(meeting_dict) - self.db[format_date(date_time)] = meetings - self.save_db() + def update_meeting(self, old_meeting_date, old_meeting_time, new_meeting_dict): + self.delete_meeting(old_meeting_date, old_meeting_time) + self.create_meeting(new_meeting_dict) def find_meeting(self, date, time): if date in self.db.keys(): @@ -180,10 +175,15 @@ class calender_db: # Tests # db = calender_db() -# db.create_meeting({"date": "16.06.2021", "time": "15:00", -# "description": "chuj"}) +# db.clear_db() +# db.create_meeting({"date": "16.06.2021", "time": "15:00", "description": "ciastko"}) # db.create_meeting({"date": "14.06.2021", "time": "13:00-18:00"}) -# db.delete_meeting("16.06.2021", "15:00") -# print(db.find_meeting("16.06.2021", "13:00-14:00")) -# print(db.get_meetings(["16.06.2021", "14.06.2021"])) -# print(db.is_collision("16.06.2021", "12:30-13")) +#db.create_meeting({"date": "16.06.2021", "time": "12:00-13:00", "description": "costam"}) +#print(db.get_meetings(["16.06.2021", "14.06.2021"])) +#db.update_meeting("16.06.2021", "12:00-13:00", {"date": "14.06.2021", "time": "11:00-12:00"}) +#print(db.get_meetings(["16.06.2021", "14.06.2021"])) +#db.delete_meeting("16.06.2021", "15:00") +#print(db.find_meeting("16.06.2021", "13:00-14:00")) +#print(db.get_meetings(["16.06.2021", "14.06.2021"])) +#print(db.is_collision("16.06.2021", "12:30-13")) +#db.clear_db() \ No newline at end of file diff --git a/DialoguePolicy.py b/DialoguePolicy.py index fe46bcb..fb7153a 100644 --- a/DialoguePolicy.py +++ b/DialoguePolicy.py @@ -11,8 +11,9 @@ class DP: Wyjście: Akt systemu (rama) """ - def __init__(self, dst): + def __init__(self, dst, db): self.DST = dst + self.DB = db self.meeting_to_update = False def chooseTactic(self) -> SystemAct: @@ -71,6 +72,7 @@ class DP: self.DST.system_update(system_act) return system_act else: + # TODO sprawdzanie czy spotkanie nie koliduje system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, slots) self.DST.system_update(system_act) return system_act @@ -78,6 +80,7 @@ class DP: if last_user_act == UserActType.CONFIRM: system_act = SystemAct(SystemActType.AFFIRM, ['create_meeting']) # implementacja wpisywanie spotkania do bazy + self.DB.create_meeting(slots) self.DST.clear() return system_act elif last_user_act == UserActType.NEGATE: @@ -201,7 +204,10 @@ class DP: return system_act else: # implementacja wyszukiwania odpowiedniego spotkania w bazie - system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_cancel']) + slots_to_delete = self.DB.find_meeting(slots['date'], slots['time']) + self.DST.update_slots(slots_to_delete) + #system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_cancel']) + system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, slots_to_delete) self.DST.system_update(system_act) return system_act elif last_system_act.getActType() == SystemActType.REQUEST: @@ -215,13 +221,17 @@ class DP: return system_act else: # implementacja wyszukiwania odpowiedniego spotkania w bazie - system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_cancel']) + slots_to_delete = self.DB.find_meeting(slots['date'], slots['time']) + self.DST.update_slots(slots_to_delete) + # system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, ['meeting_to_cancel']) + system_act = SystemAct(SystemActType.CONFIRM_DOMAIN, slots_to_delete) self.DST.system_update(system_act) return system_act elif last_system_act.getActType() == SystemActType.CONFIRM_DOMAIN: if last_user_act == UserActType.CONFIRM: system_act = SystemAct(SystemActType.AFFIRM, ['cancel_meeting']) # implementacja usuwania spotkania z bazy + self.DB.delete_meeting(slots['date'], ['time']) self.DST.clear() return system_act elif last_user_act == UserActType.NEGATE: @@ -231,18 +241,42 @@ class DP: return SystemAct(SystemActType.NOT_UNDERSTOOD, []) # stan prośby o listę spotkań elif dialogue_state == UserActType.MEETING_LIST: - if last_user_act == UserActType.NEGATE: - self.DST.clear() - return SystemAct(SystemActType.REQMORE, ['meeting_list']) - else: - if 'date' in slots: - system_act = SystemAct(SystemActType.MEETING_LIST, slots) - self.DST.clear() - return system_act - else: + if not last_system_act: + if 'date' not in slots: system_act = SystemAct(SystemActType.REQUEST, ['date']) self.DST.system_update(system_act) return system_act + else: + # implementacja wyszukiwania spotkań w bazie + meetings_slots = self.DB.get_meetings([slots['date']]) + system_act = SystemAct(SystemActType.MEETING_LIST, meetings_slots) + self.DST.system_update(system_act) + return system_act + elif last_system_act.getActType() == SystemActType.REQUEST: + if 'date' not in slots: + system_act = SystemAct(SystemActType.REQUEST, ['date']) + self.DST.system_update(system_act) + return system_act + else: + # implementacja wyszukiwania spotkań w bazie + meetings_slots = self.DB.get_meetings([slots['date']]) + system_act = SystemAct(SystemActType.MEETING_LIST, meetings_slots) + self.DST.system_update(system_act) + return system_act + else: + return SystemAct(SystemActType.NOT_UNDERSTOOD, []) + # if last_user_act == UserActType.NEGATE: + # self.DST.clear() + # return SystemAct(SystemActType.REQMORE, ['meeting_list']) + # else: + # if 'date' in slots: + # system_act = SystemAct(SystemActType.MEETING_LIST, slots) + # self.DST.clear() + # return system_act + # else: + # system_act = SystemAct(SystemActType.REQUEST, ['date']) + # self.DST.system_update(system_act) + # return system_act # stan prośby o czas wolny elif dialogue_state == UserActType.FREE_TIME: if last_user_act == UserActType.NEGATE: diff --git a/DialogueStateTracker.py b/DialogueStateTracker.py index 0eddd49..dd0c36d 100644 --- a/DialogueStateTracker.py +++ b/DialogueStateTracker.py @@ -45,6 +45,9 @@ class DST: def clear_slots(self): self.slots = {} + def update_slots(self, slots): + self.slots = slots + def get_dialogue_state(self): return self.state, self.last_user_act, self.last_system_act diff --git a/NaturalLanguageGeneration.py b/NaturalLanguageGeneration.py index cbe8f1c..189d03b 100644 --- a/NaturalLanguageGeneration.py +++ b/NaturalLanguageGeneration.py @@ -69,26 +69,24 @@ class NLG: return "W jakim dniu miało się odbyć to spotkanie?" if "time" in systemAct.getActParams(): return "W jakim czasie miało się odbyć to spotkanie?" - # TODO dopracować po dodaniu DB if systemAct.getActType() == SystemActType.CONFIRM_DOMAIN: date = slots['date'] time = slots['time'] - # place = slots['place'] - # part_list = slots['participants'] - # part = "" - # for p in part_list: - # part += p - # part += ", " - # part = part[:-2] - # desc = slots['description'] - return f'Spotkanie:\n' \ - f'Dzień: {date}\nCzas: {time}' + place = slots['place'] + part_list = slots['participants'] + part = "" + for p in part_list: + part += p + part += ", " + part = part[:-2] + desc = slots['description'] + return f'Odwołać te spotkanie?:\n' \ + f'Dzień: {date}\nCzas: {time}\nMiejsce: {place}\nUczestnicy: {part}\nOpis: {desc}' elif dialogue_state == UserActType.MEETING_LIST: if systemAct.getActType() == SystemActType.REQUEST: if "date" in systemAct.getActParams(): return "Z jakiego okresu chcesz przejrzeć spotkania?" - # TODO: dopracować po dodaniu DB if systemAct.getActType() == SystemActType.MEETING_LIST: response = "" for s in slots: @@ -104,6 +102,7 @@ class NLG: desc = s['description'] response += f'Spotkanie:\nDzień: {date}\nCzas: {time}\nMiejsce: {place}\nUczestnicy: {part}\nOpis: {desc}\n' response += "--------------------" + self.DST.clear_slots() return response elif dialogue_state == UserActType.FREE_TIME: diff --git a/main.py b/main.py index 7673aa1..0052992 100644 --- a/main.py +++ b/main.py @@ -3,36 +3,36 @@ from NaturalLanguageUnderstanding import NLU from NaturalLanguageGeneration import NLG from DialogueStateTracker import DST from DialoguePolicy import DP +from DBManager import calender_db if __name__ == "__main__": - + db = calender_db() nlu = NLU() dst = DST() - dp = DP(dst) + dp = DP(dst, db) nlg = NLG(dst) while(1): - user_input = input("Wpisz tekst: ") + user_input = input("\nWpisz tekst: ") user_frame = nlu.parse_user_input(user_input) - print('------ rozpoznany user frame ------') + print('\n------ rozpoznany user frame ------') print(user_frame) + dst.user_update(user_frame) state, last_user_act, last_system_act = dst.get_dialogue_state() slots = dst.get_dialogue_slots() system_act = dp.chooseTactic() - print('------ stan ------') + print('\n------ stan ------') print(state, last_user_act, last_system_act) - print('------ przechowywane sloty ------') + print('\n------ przechowywane sloty ------') print(slots) - print('------ wybrana akcja systemu ------') + print('\n------ wybrana akcja systemu ------') print(system_act) system_response = nlg.generateResponse(system_act) - print('------ wygenerowana odpowiedź systemu ------') + print('\n------ wygenerowana odpowiedź systemu ------') print(system_response) - print('-----------------------------------') - print('-----------------------------------') if system_act.getActType() == SystemActType.BYE: break