dp
This commit is contained in:
parent
4ef41c18d3
commit
18bebd05a1
@ -1,62 +1,54 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class DP:
|
||||
|
||||
def __init__(self):
|
||||
self.questionManager = UserQuestionModule()
|
||||
self.database = DBMock()
|
||||
|
||||
def getAction(self, lastUserAct, emptySlots, systemSlots, filledSlots):
|
||||
def getAction(self, lastUserAct, emptySlots, systemSlots, slotsWithValues):
|
||||
systemAct = None
|
||||
slotVal = None
|
||||
|
||||
if (lastUserAct == "request" | lastUserAct == "reqmore"):
|
||||
# TODO policy for user request
|
||||
return ["Cinema", "select", "", []]
|
||||
if ((lastUserAct == "request") | (lastUserAct == "reqmore")):
|
||||
title = None
|
||||
date = None
|
||||
time = None
|
||||
if "title" in slotsWithValues.keys(): title = slotsWithValues["title"]
|
||||
if "date" in slotsWithValues.keys(): date = slotsWithValues["date"]
|
||||
if "time" in slotsWithValues.keys(): time = slotsWithValues["time"]
|
||||
return ["Cinema", "select", "", self.database.getShows(title=title,date=date,time=time)]
|
||||
|
||||
elif ((lastUserAct == "hello") | (lastUserAct == "inform") | (lastUserAct == None)):
|
||||
# there are no empty slots
|
||||
systemAct = None
|
||||
slotName = None
|
||||
value = None
|
||||
if not emptySlots:
|
||||
# TODO generate reservation id
|
||||
|
||||
reservationId = max(self.database.reservations.keys()) + 1
|
||||
# TODO add reservation
|
||||
systemAct = "inform"
|
||||
slotVal = systemSlots
|
||||
slotName = systemSlots
|
||||
value = reservationId
|
||||
# there are empty slots
|
||||
else:
|
||||
for slot in emptySlots:
|
||||
systemAct = "request"
|
||||
slotVal = slot
|
||||
slotName = slot
|
||||
break
|
||||
return ["Cinema", systemAct, slotVal, ""]
|
||||
return ["Cinema", systemAct, slotName, value]
|
||||
else:
|
||||
systemAct = "repeat"
|
||||
return ["Cinema", systemAct, "", ""]
|
||||
|
||||
|
||||
class UserQuestionModule():
|
||||
|
||||
def __init__(self):
|
||||
self.questionSlots = ["title", "date", "time"]
|
||||
|
||||
def getQuestionTypes(self):
|
||||
return self.questionTypes
|
||||
|
||||
def getQuestionType(self, filledSlots):
|
||||
filledSlots = []
|
||||
# for slot in self.questionSlots
|
||||
|
||||
# if "title" in emptySlots:
|
||||
# questionType = self.questionTypes[0]
|
||||
# elif "date" in emptySlots:
|
||||
# questionType = self.questionTypes[1]
|
||||
# elif "time" in emptySlots:
|
||||
# questionType = self.getQuestionTypes[2]
|
||||
|
||||
# return questionType,
|
||||
|
||||
class DBMock():
|
||||
def __init__(self):
|
||||
self.shows = {
|
||||
1: {
|
||||
"title": "Batman",
|
||||
"date": "21.06",
|
||||
"time": 19,
|
||||
"date": "08.06",
|
||||
"time": "19:00",
|
||||
"seats": ["a1", "a2", "a3", "a4", "a5",
|
||||
"b1", "b2", "b3", "b4", "b5",
|
||||
"c1", "c2", "c3", "c4", "c5",
|
||||
@ -70,8 +62,8 @@ class DBMock():
|
||||
},
|
||||
2: {
|
||||
"title": "Batman",
|
||||
"date": "22.06",
|
||||
"time": 20,
|
||||
"date": "08.06",
|
||||
"time": "20:00",
|
||||
"seats": ["a1", "a2", "a3", "a4", "a5",
|
||||
"b1", "b2", "b3", "b4", "b5",
|
||||
"c1", "c2", "c3", "c4", "c5",
|
||||
@ -85,8 +77,8 @@ class DBMock():
|
||||
},
|
||||
3: {
|
||||
"title": "Zorro",
|
||||
"date": "23.06",
|
||||
"time": 21,
|
||||
"date": "09.06",
|
||||
"time": "21:00",
|
||||
"seats": ["a1", "a2", "a3", "a4", "a5",
|
||||
"b1", "b2", "b3", "b4", "b5",
|
||||
"c1", "c2", "c3", "c4", "c5",
|
||||
@ -132,43 +124,27 @@ class DBMock():
|
||||
}
|
||||
|
||||
def getShows(self, title = None, date = None, time = None,):
|
||||
result = []
|
||||
for key in self.shows.keys():
|
||||
# title is None
|
||||
if(title is None):
|
||||
titles = []
|
||||
for e in self.shows:
|
||||
if (date is not None & time is not None):
|
||||
if e["date"] == str(date):
|
||||
if e["time"] == str(time):
|
||||
titles.append([e["title"], e["date"]])
|
||||
elif (date is not None & time is None):
|
||||
if e["date"] == str(date):
|
||||
titles.append([e["title"], e["date"]])
|
||||
elif (date is None & time is not None):
|
||||
if e["time"] == str(time):
|
||||
titles.append([e["title"], e["date"]])
|
||||
return set(titles)
|
||||
if ((date is not None) & (time is not None)):
|
||||
if self.shows[key]["date"] == str(date):
|
||||
if self.shows[key]["time"] == str(time):
|
||||
result.append([self.shows[key]["title"], self.shows[key]["date"]])
|
||||
elif ((date is not None) & (time is None)):
|
||||
if self.shows[key]["date"] == str(date):
|
||||
result.append([self.shows[key]["title"], self.shows[key]["date"]])
|
||||
elif ((date is None) & (time is not None)):
|
||||
if self.shows[key]["time"] == str(time):
|
||||
result.append([self.shows[key]["title"], self.shows[key]["date"]])
|
||||
# title is not None
|
||||
elif(title is not None):
|
||||
if(date is None):
|
||||
dates = []
|
||||
for e in self.shows:
|
||||
if e["title"] == str(date):
|
||||
dates.append(e["date"])
|
||||
if self.shows[key]["title"] == str(title):
|
||||
result.append(self.shows[key]["date"])
|
||||
elif(date is not None):
|
||||
if(time is None):
|
||||
|
||||
|
||||
return set(titles)
|
||||
# time slot is not None
|
||||
elif(title is None & time is not None & date is None):
|
||||
titles = []
|
||||
for e in self.shows:
|
||||
if e["time"] == str(time):
|
||||
titles.append(e["title"], e["time"])
|
||||
return set(titles)
|
||||
elif(title is None & time is not None & date is None):
|
||||
titles = []
|
||||
for e in self.shows:
|
||||
if e["time"] == str(time):
|
||||
titles.append(e["title"], e["time"])
|
||||
return set(titles)
|
||||
if self.shows[key]["date"] == str(date):
|
||||
result.append(self.shows[key]["time"])
|
||||
return set(result)
|
@ -10,7 +10,7 @@ class DST:
|
||||
for intent, domain, slot, value in user_act:
|
||||
domain = domain.lower()
|
||||
intent = intent.lower()
|
||||
value = value.lower()
|
||||
value = value
|
||||
slot = slot.lower()
|
||||
|
||||
# all intents are same
|
||||
@ -48,6 +48,13 @@ class DST:
|
||||
result.append(key)
|
||||
return result
|
||||
|
||||
def getSlotsWithValues(self):
|
||||
result = {}
|
||||
for key in self.state['belief_state']["cinema"]["book"].keys():
|
||||
if self.state['belief_state']["cinema"]["book"][key] != "":
|
||||
result[key] = self.state['belief_state']["cinema"]["book"][key]
|
||||
return result
|
||||
|
||||
def getSystemSlots(self):
|
||||
result = []
|
||||
for key in self.state['belief_state']["cinema"]["book"].keys():
|
||||
|
@ -3,6 +3,8 @@ from flair.data import Sentence, Token
|
||||
from flair.datasets import SentenceDataset
|
||||
from flair.models import SequenceTagger, TextClassifier
|
||||
|
||||
from .chane import getDate, getTitle
|
||||
|
||||
class NLU:
|
||||
|
||||
def __init__(self):
|
||||
@ -83,6 +85,11 @@ class NLU:
|
||||
# slotValue = value
|
||||
|
||||
if slotValue is not None:
|
||||
# normalise input
|
||||
if slot == "title":
|
||||
slotValue = getTitle(slotValue)
|
||||
elif slot == "date":
|
||||
slotValue = getDate(slotValue)
|
||||
result.append([intent, 'Cinema', slot, slotValue])
|
||||
if len(result) == 0: result.append([intent, 'Cinema', "", ""])
|
||||
return result
|
||||
|
@ -1,11 +1,8 @@
|
||||
from difflib import SequenceMatcher
|
||||
from datetime import date
|
||||
import datetime
|
||||
from dateutil.parser import parse
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def getDate(user_date):
|
||||
#jeżeli w dacie są jakieś liczby, to uznajemy ją za poprawną datę
|
||||
if any(char.isdigit() for char in user_date):
|
||||
@ -22,17 +19,17 @@ def getDate(user_date):
|
||||
|
||||
#zwrócenie wyniku dzisiaj, dzisiaj+1 (jutro), dzisiaj+2 (pojutrze)
|
||||
if result_today > result_tommorow and result_today > result_day_after_tomorrow:
|
||||
return date.today().strftime("%m.%d")
|
||||
return date.today().strftime("%d.%m")
|
||||
elif result_tommorow > result_day_after_tomorrow:
|
||||
return (date.today() + datetime.timedelta(days=1)).strftime("%m.%d")
|
||||
return (date.today() + datetime.timedelta(days=1)).strftime("%d.%m")
|
||||
else:
|
||||
return (date.today() + datetime.timedelta(days=2)).strftime("%m.%d")
|
||||
return (date.today() + datetime.timedelta(days=2)).strftime("%d.%m")
|
||||
|
||||
|
||||
|
||||
def getTitle(user_title):
|
||||
titles=["Batman", "Na Noże", "Uncharted", "Ambulans", "Minionki", "Fantastyczne Zwierzęta", "To Nie Wypanda",
|
||||
"Inni Ludzie"]
|
||||
"Inni Ludzie", "Zorro"]
|
||||
number_list = list(map(lambda x: SequenceMatcher(a=user_title, b=x).ratio(), titles))
|
||||
max_value = max(number_list)
|
||||
max_index = number_list.index(max_value)
|
||||
|
@ -22,12 +22,14 @@ def chatbot():
|
||||
if userMessage == "/exit":
|
||||
print("Do usłyszenia")
|
||||
isActive = False
|
||||
elif userMessage == "/reset":
|
||||
chatbot()
|
||||
else:
|
||||
nluPred = nlu.predict(sentence=userMessage)
|
||||
print(nluPred)
|
||||
dst.update(nluPred)
|
||||
# print(dst.state)
|
||||
dpAct = dp.getAction(dst.getLastUserAct(), dst.getEmptySlots(), dst.getSystemSlots())
|
||||
dpAct = dp.getAction(dst.getLastUserAct(), dst.getEmptySlots(), dst.getSystemSlots(), dst.getSlotsWithValues())
|
||||
print(dpAct)
|
||||
# TODO update DST system act
|
||||
chatbot()
|
Loading…
Reference in New Issue
Block a user