This commit is contained in:
s444417 2022-06-07 22:33:57 +02:00
parent 4ef41c18d3
commit 18bebd05a1
5 changed files with 76 additions and 87 deletions

View File

@ -1,62 +1,54 @@
from dataclasses import dataclass
class DP:
def __init__(self):
self.questionManager = UserQuestionModule()
self.database = DBMock()
def getAction(self, lastUserAct, emptySlots, systemSlots, filledSlots):
def getAction(self, lastUserAct, emptySlots, systemSlots, slotsWithValues):
systemAct = None
slotVal = None
if (lastUserAct == "request" | lastUserAct == "reqmore"):
# TODO policy for user request
return ["Cinema", "select", "", []]
if ((lastUserAct == "request") | (lastUserAct == "reqmore")):
title = None
date = None
time = None
if "title" in slotsWithValues.keys(): title = slotsWithValues["title"]
if "date" in slotsWithValues.keys(): date = slotsWithValues["date"]
if "time" in slotsWithValues.keys(): time = slotsWithValues["time"]
return ["Cinema", "select", "", self.database.getShows(title=title,date=date,time=time)]
elif ((lastUserAct == "hello") | (lastUserAct == "inform") | (lastUserAct == None)):
# there are no empty slots
systemAct = None
slotName = None
value = None
if not emptySlots:
# TODO generate reservation id
reservationId = max(self.database.reservations.keys()) + 1
# TODO add reservation
systemAct = "inform"
slotVal = systemSlots
slotName = systemSlots
value = reservationId
# there are empty slots
else:
for slot in emptySlots:
systemAct = "request"
slotVal = slot
slotName = slot
break
return ["Cinema", systemAct, slotVal, ""]
return ["Cinema", systemAct, slotName, value]
else:
systemAct = "repeat"
return ["Cinema", systemAct, "", ""]
class UserQuestionModule():
def __init__(self):
self.questionSlots = ["title", "date", "time"]
def getQuestionTypes(self):
return self.questionTypes
def getQuestionType(self, filledSlots):
filledSlots = []
# for slot in self.questionSlots
# if "title" in emptySlots:
# questionType = self.questionTypes[0]
# elif "date" in emptySlots:
# questionType = self.questionTypes[1]
# elif "time" in emptySlots:
# questionType = self.getQuestionTypes[2]
# return questionType,
class DBMock():
def __init__(self):
self.shows = {
1: {
"title": "Batman",
"date": "21.06",
"time": 19,
"date": "08.06",
"time": "19:00",
"seats": ["a1", "a2", "a3", "a4", "a5",
"b1", "b2", "b3", "b4", "b5",
"c1", "c2", "c3", "c4", "c5",
@ -70,8 +62,8 @@ class DBMock():
},
2: {
"title": "Batman",
"date": "22.06",
"time": 20,
"date": "08.06",
"time": "20:00",
"seats": ["a1", "a2", "a3", "a4", "a5",
"b1", "b2", "b3", "b4", "b5",
"c1", "c2", "c3", "c4", "c5",
@ -85,8 +77,8 @@ class DBMock():
},
3: {
"title": "Zorro",
"date": "23.06",
"time": 21,
"date": "09.06",
"time": "21:00",
"seats": ["a1", "a2", "a3", "a4", "a5",
"b1", "b2", "b3", "b4", "b5",
"c1", "c2", "c3", "c4", "c5",
@ -132,43 +124,27 @@ class DBMock():
}
def getShows(self, title = None, date = None, time = None,):
# title is None
if(title is None):
titles = []
for e in self.shows:
if (date is not None & time is not None):
if e["date"] == str(date):
if e["time"] == str(time):
titles.append([e["title"], e["date"]])
elif (date is not None & time is None):
if e["date"] == str(date):
titles.append([e["title"], e["date"]])
elif (date is None & time is not None):
if e["time"] == str(time):
titles.append([e["title"], e["date"]])
return set(titles)
# title is not None
elif(title is not None):
if(date is None):
dates = []
for e in self.shows:
if e["title"] == str(date):
dates.append(e["date"])
elif(date is not None):
if(time is None):
return set(titles)
# time slot is not None
elif(title is None & time is not None & date is None):
titles = []
for e in self.shows:
if e["time"] == str(time):
titles.append(e["title"], e["time"])
return set(titles)
elif(title is None & time is not None & date is None):
titles = []
for e in self.shows:
if e["time"] == str(time):
titles.append(e["title"], e["time"])
return set(titles)
result = []
for key in self.shows.keys():
# title is None
if(title is None):
if ((date is not None) & (time is not None)):
if self.shows[key]["date"] == str(date):
if self.shows[key]["time"] == str(time):
result.append([self.shows[key]["title"], self.shows[key]["date"]])
elif ((date is not None) & (time is None)):
if self.shows[key]["date"] == str(date):
result.append([self.shows[key]["title"], self.shows[key]["date"]])
elif ((date is None) & (time is not None)):
if self.shows[key]["time"] == str(time):
result.append([self.shows[key]["title"], self.shows[key]["date"]])
# title is not None
elif(title is not None):
if(date is None):
if self.shows[key]["title"] == str(title):
result.append(self.shows[key]["date"])
elif(date is not None):
if(time is None):
if self.shows[key]["date"] == str(date):
result.append(self.shows[key]["time"])
return set(result)

View File

@ -10,7 +10,7 @@ class DST:
for intent, domain, slot, value in user_act:
domain = domain.lower()
intent = intent.lower()
value = value.lower()
value = value
slot = slot.lower()
# all intents are same
@ -48,6 +48,13 @@ class DST:
result.append(key)
return result
def getSlotsWithValues(self):
result = {}
for key in self.state['belief_state']["cinema"]["book"].keys():
if self.state['belief_state']["cinema"]["book"][key] != "":
result[key] = self.state['belief_state']["cinema"]["book"][key]
return result
def getSystemSlots(self):
result = []
for key in self.state['belief_state']["cinema"]["book"].keys():

View File

@ -3,6 +3,8 @@ from flair.data import Sentence, Token
from flair.datasets import SentenceDataset
from flair.models import SequenceTagger, TextClassifier
from .chane import getDate, getTitle
class NLU:
def __init__(self):
@ -83,6 +85,11 @@ class NLU:
# slotValue = value
if slotValue is not None:
# normalise input
if slot == "title":
slotValue = getTitle(slotValue)
elif slot == "date":
slotValue = getDate(slotValue)
result.append([intent, 'Cinema', slot, slotValue])
if len(result) == 0: result.append([intent, 'Cinema', "", ""])
return result

View File

@ -1,11 +1,8 @@
from difflib import SequenceMatcher
from datetime import date
import datetime
from dateutil.parser import parse
def getDate(user_date):
#jeżeli w dacie są jakieś liczby, to uznajemy ją za poprawną datę
if any(char.isdigit() for char in user_date):
@ -22,17 +19,17 @@ def getDate(user_date):
#zwrócenie wyniku dzisiaj, dzisiaj+1 (jutro), dzisiaj+2 (pojutrze)
if result_today > result_tommorow and result_today > result_day_after_tomorrow:
return date.today().strftime("%m.%d")
return date.today().strftime("%d.%m")
elif result_tommorow > result_day_after_tomorrow:
return (date.today() + datetime.timedelta(days=1)).strftime("%m.%d")
return (date.today() + datetime.timedelta(days=1)).strftime("%d.%m")
else:
return (date.today() + datetime.timedelta(days=2)).strftime("%m.%d")
return (date.today() + datetime.timedelta(days=2)).strftime("%d.%m")
def getTitle(user_title):
titles=["Batman", "Na Noże", "Uncharted", "Ambulans", "Minionki", "Fantastyczne Zwierzęta", "To Nie Wypanda",
"Inni Ludzie"]
"Inni Ludzie", "Zorro"]
number_list = list(map(lambda x: SequenceMatcher(a=user_title, b=x).ratio(), titles))
max_value = max(number_list)
max_index = number_list.index(max_value)

View File

@ -22,12 +22,14 @@ def chatbot():
if userMessage == "/exit":
print("Do usłyszenia")
isActive = False
elif userMessage == "/reset":
chatbot()
else:
nluPred = nlu.predict(sentence=userMessage)
print(nluPred)
dst.update(nluPred)
# print(dst.state)
dpAct = dp.getAction(dst.getLastUserAct(), dst.getEmptySlots(), dst.getSystemSlots())
dpAct = dp.getAction(dst.getLastUserAct(), dst.getEmptySlots(), dst.getSystemSlots(), dst.getSlotsWithValues())
print(dpAct)
# TODO update DST system act
chatbot()