From 18bebd05a1e9c208cc0f147f8b13c6d4cd1bb6eb Mon Sep 17 00:00:00 2001 From: s444417 Date: Tue, 7 Jun 2022 22:33:57 +0200 Subject: [PATCH] dp --- src/components/DP.py | 130 ++++++++++++++++------------------------ src/components/DST.py | 9 ++- src/components/NLU.py | 7 +++ src/components/chane.py | 13 ++-- src/dialogue_system.py | 4 +- 5 files changed, 76 insertions(+), 87 deletions(-) diff --git a/src/components/DP.py b/src/components/DP.py index 6cd36f6..e0b5a56 100644 --- a/src/components/DP.py +++ b/src/components/DP.py @@ -1,62 +1,54 @@ +from dataclasses import dataclass + + class DP: def __init__(self): - self.questionManager = UserQuestionModule() + self.database = DBMock() - def getAction(self, lastUserAct, emptySlots, systemSlots, filledSlots): + def getAction(self, lastUserAct, emptySlots, systemSlots, slotsWithValues): systemAct = None slotVal = None - if (lastUserAct == "request" | lastUserAct == "reqmore"): - # TODO policy for user request - return ["Cinema", "select", "", []] + if ((lastUserAct == "request") | (lastUserAct == "reqmore")): + title = None + date = None + time = None + if "title" in slotsWithValues.keys(): title = slotsWithValues["title"] + if "date" in slotsWithValues.keys(): date = slotsWithValues["date"] + if "time" in slotsWithValues.keys(): time = slotsWithValues["time"] + return ["Cinema", "select", "", self.database.getShows(title=title,date=date,time=time)] elif ((lastUserAct == "hello") | (lastUserAct == "inform") | (lastUserAct == None)): # there are no empty slots + systemAct = None + slotName = None + value = None if not emptySlots: - # TODO generate reservation id + + reservationId = max(self.database.reservations.keys()) + 1 + # TODO add reservation systemAct = "inform" - slotVal = systemSlots + slotName = systemSlots + value = reservationId # there are empty slots else: for slot in emptySlots: systemAct = "request" - slotVal = slot + slotName = slot break - return ["Cinema", systemAct, slotVal, ""] + return ["Cinema", systemAct, slotName, value] else: systemAct = "repeat" return ["Cinema", systemAct, "", ""] - -class UserQuestionModule(): - - def __init__(self): - self.questionSlots = ["title", "date", "time"] - - def getQuestionTypes(self): - return self.questionTypes - - def getQuestionType(self, filledSlots): - filledSlots = [] - # for slot in self.questionSlots - - # if "title" in emptySlots: - # questionType = self.questionTypes[0] - # elif "date" in emptySlots: - # questionType = self.questionTypes[1] - # elif "time" in emptySlots: - # questionType = self.getQuestionTypes[2] - - # return questionType, - class DBMock(): def __init__(self): self.shows = { 1: { "title": "Batman", - "date": "21.06", - "time": 19, + "date": "08.06", + "time": "19:00", "seats": ["a1", "a2", "a3", "a4", "a5", "b1", "b2", "b3", "b4", "b5", "c1", "c2", "c3", "c4", "c5", @@ -70,8 +62,8 @@ class DBMock(): }, 2: { "title": "Batman", - "date": "22.06", - "time": 20, + "date": "08.06", + "time": "20:00", "seats": ["a1", "a2", "a3", "a4", "a5", "b1", "b2", "b3", "b4", "b5", "c1", "c2", "c3", "c4", "c5", @@ -85,8 +77,8 @@ class DBMock(): }, 3: { "title": "Zorro", - "date": "23.06", - "time": 21, + "date": "09.06", + "time": "21:00", "seats": ["a1", "a2", "a3", "a4", "a5", "b1", "b2", "b3", "b4", "b5", "c1", "c2", "c3", "c4", "c5", @@ -132,43 +124,27 @@ class DBMock(): } def getShows(self, title = None, date = None, time = None,): - # title is None - if(title is None): - titles = [] - for e in self.shows: - if (date is not None & time is not None): - if e["date"] == str(date): - if e["time"] == str(time): - titles.append([e["title"], e["date"]]) - elif (date is not None & time is None): - if e["date"] == str(date): - titles.append([e["title"], e["date"]]) - elif (date is None & time is not None): - if e["time"] == str(time): - titles.append([e["title"], e["date"]]) - return set(titles) - # title is not None - elif(title is not None): - if(date is None): - dates = [] - for e in self.shows: - if e["title"] == str(date): - dates.append(e["date"]) - elif(date is not None): - if(time is None): - - - return set(titles) - # time slot is not None - elif(title is None & time is not None & date is None): - titles = [] - for e in self.shows: - if e["time"] == str(time): - titles.append(e["title"], e["time"]) - return set(titles) - elif(title is None & time is not None & date is None): - titles = [] - for e in self.shows: - if e["time"] == str(time): - titles.append(e["title"], e["time"]) - return set(titles) \ No newline at end of file + result = [] + for key in self.shows.keys(): + # title is None + if(title is None): + if ((date is not None) & (time is not None)): + if self.shows[key]["date"] == str(date): + if self.shows[key]["time"] == str(time): + result.append([self.shows[key]["title"], self.shows[key]["date"]]) + elif ((date is not None) & (time is None)): + if self.shows[key]["date"] == str(date): + result.append([self.shows[key]["title"], self.shows[key]["date"]]) + elif ((date is None) & (time is not None)): + if self.shows[key]["time"] == str(time): + result.append([self.shows[key]["title"], self.shows[key]["date"]]) + # title is not None + elif(title is not None): + if(date is None): + if self.shows[key]["title"] == str(title): + result.append(self.shows[key]["date"]) + elif(date is not None): + if(time is None): + if self.shows[key]["date"] == str(date): + result.append(self.shows[key]["time"]) + return set(result) \ No newline at end of file diff --git a/src/components/DST.py b/src/components/DST.py index cde7e0d..a0c46af 100644 --- a/src/components/DST.py +++ b/src/components/DST.py @@ -10,7 +10,7 @@ class DST: for intent, domain, slot, value in user_act: domain = domain.lower() intent = intent.lower() - value = value.lower() + value = value slot = slot.lower() # all intents are same @@ -48,6 +48,13 @@ class DST: result.append(key) return result + def getSlotsWithValues(self): + result = {} + for key in self.state['belief_state']["cinema"]["book"].keys(): + if self.state['belief_state']["cinema"]["book"][key] != "": + result[key] = self.state['belief_state']["cinema"]["book"][key] + return result + def getSystemSlots(self): result = [] for key in self.state['belief_state']["cinema"]["book"].keys(): diff --git a/src/components/NLU.py b/src/components/NLU.py index 103b716..18395a5 100644 --- a/src/components/NLU.py +++ b/src/components/NLU.py @@ -3,6 +3,8 @@ from flair.data import Sentence, Token from flair.datasets import SentenceDataset from flair.models import SequenceTagger, TextClassifier +from .chane import getDate, getTitle + class NLU: def __init__(self): @@ -83,6 +85,11 @@ class NLU: # slotValue = value if slotValue is not None: + # normalise input + if slot == "title": + slotValue = getTitle(slotValue) + elif slot == "date": + slotValue = getDate(slotValue) result.append([intent, 'Cinema', slot, slotValue]) if len(result) == 0: result.append([intent, 'Cinema', "", ""]) return result diff --git a/src/components/chane.py b/src/components/chane.py index 585d6ca..ae90782 100644 --- a/src/components/chane.py +++ b/src/components/chane.py @@ -1,11 +1,8 @@ from difflib import SequenceMatcher from datetime import date +import datetime from dateutil.parser import parse - - - - def getDate(user_date): #jeżeli w dacie są jakieś liczby, to uznajemy ją za poprawną datę if any(char.isdigit() for char in user_date): @@ -22,17 +19,17 @@ def getDate(user_date): #zwrócenie wyniku dzisiaj, dzisiaj+1 (jutro), dzisiaj+2 (pojutrze) if result_today > result_tommorow and result_today > result_day_after_tomorrow: - return date.today().strftime("%m.%d") + return date.today().strftime("%d.%m") elif result_tommorow > result_day_after_tomorrow: - return (date.today() + datetime.timedelta(days=1)).strftime("%m.%d") + return (date.today() + datetime.timedelta(days=1)).strftime("%d.%m") else: - return (date.today() + datetime.timedelta(days=2)).strftime("%m.%d") + return (date.today() + datetime.timedelta(days=2)).strftime("%d.%m") def getTitle(user_title): titles=["Batman", "Na Noże", "Uncharted", "Ambulans", "Minionki", "Fantastyczne Zwierzęta", "To Nie Wypanda", - "Inni Ludzie"] + "Inni Ludzie", "Zorro"] number_list = list(map(lambda x: SequenceMatcher(a=user_title, b=x).ratio(), titles)) max_value = max(number_list) max_index = number_list.index(max_value) diff --git a/src/dialogue_system.py b/src/dialogue_system.py index e0abf19..48aa975 100644 --- a/src/dialogue_system.py +++ b/src/dialogue_system.py @@ -22,12 +22,14 @@ def chatbot(): if userMessage == "/exit": print("Do usłyszenia") isActive = False + elif userMessage == "/reset": + chatbot() else: nluPred = nlu.predict(sentence=userMessage) print(nluPred) dst.update(nluPred) # print(dst.state) - dpAct = dp.getAction(dst.getLastUserAct(), dst.getEmptySlots(), dst.getSystemSlots()) + dpAct = dp.getAction(dst.getLastUserAct(), dst.getEmptySlots(), dst.getSystemSlots(), dst.getSlotsWithValues()) print(dpAct) # TODO update DST system act chatbot() \ No newline at end of file