dp
This commit is contained in:
parent
4ef41c18d3
commit
18bebd05a1
@ -1,62 +1,54 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
class DP:
|
class DP:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.questionManager = UserQuestionModule()
|
self.database = DBMock()
|
||||||
|
|
||||||
def getAction(self, lastUserAct, emptySlots, systemSlots, filledSlots):
|
def getAction(self, lastUserAct, emptySlots, systemSlots, slotsWithValues):
|
||||||
systemAct = None
|
systemAct = None
|
||||||
slotVal = None
|
slotVal = None
|
||||||
|
|
||||||
if (lastUserAct == "request" | lastUserAct == "reqmore"):
|
if ((lastUserAct == "request") | (lastUserAct == "reqmore")):
|
||||||
# TODO policy for user request
|
title = None
|
||||||
return ["Cinema", "select", "", []]
|
date = None
|
||||||
|
time = None
|
||||||
|
if "title" in slotsWithValues.keys(): title = slotsWithValues["title"]
|
||||||
|
if "date" in slotsWithValues.keys(): date = slotsWithValues["date"]
|
||||||
|
if "time" in slotsWithValues.keys(): time = slotsWithValues["time"]
|
||||||
|
return ["Cinema", "select", "", self.database.getShows(title=title,date=date,time=time)]
|
||||||
|
|
||||||
elif ((lastUserAct == "hello") | (lastUserAct == "inform") | (lastUserAct == None)):
|
elif ((lastUserAct == "hello") | (lastUserAct == "inform") | (lastUserAct == None)):
|
||||||
# there are no empty slots
|
# there are no empty slots
|
||||||
|
systemAct = None
|
||||||
|
slotName = None
|
||||||
|
value = None
|
||||||
if not emptySlots:
|
if not emptySlots:
|
||||||
# TODO generate reservation id
|
|
||||||
|
reservationId = max(self.database.reservations.keys()) + 1
|
||||||
|
# TODO add reservation
|
||||||
systemAct = "inform"
|
systemAct = "inform"
|
||||||
slotVal = systemSlots
|
slotName = systemSlots
|
||||||
|
value = reservationId
|
||||||
# there are empty slots
|
# there are empty slots
|
||||||
else:
|
else:
|
||||||
for slot in emptySlots:
|
for slot in emptySlots:
|
||||||
systemAct = "request"
|
systemAct = "request"
|
||||||
slotVal = slot
|
slotName = slot
|
||||||
break
|
break
|
||||||
return ["Cinema", systemAct, slotVal, ""]
|
return ["Cinema", systemAct, slotName, value]
|
||||||
else:
|
else:
|
||||||
systemAct = "repeat"
|
systemAct = "repeat"
|
||||||
return ["Cinema", systemAct, "", ""]
|
return ["Cinema", systemAct, "", ""]
|
||||||
|
|
||||||
|
|
||||||
class UserQuestionModule():
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.questionSlots = ["title", "date", "time"]
|
|
||||||
|
|
||||||
def getQuestionTypes(self):
|
|
||||||
return self.questionTypes
|
|
||||||
|
|
||||||
def getQuestionType(self, filledSlots):
|
|
||||||
filledSlots = []
|
|
||||||
# for slot in self.questionSlots
|
|
||||||
|
|
||||||
# if "title" in emptySlots:
|
|
||||||
# questionType = self.questionTypes[0]
|
|
||||||
# elif "date" in emptySlots:
|
|
||||||
# questionType = self.questionTypes[1]
|
|
||||||
# elif "time" in emptySlots:
|
|
||||||
# questionType = self.getQuestionTypes[2]
|
|
||||||
|
|
||||||
# return questionType,
|
|
||||||
|
|
||||||
class DBMock():
|
class DBMock():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.shows = {
|
self.shows = {
|
||||||
1: {
|
1: {
|
||||||
"title": "Batman",
|
"title": "Batman",
|
||||||
"date": "21.06",
|
"date": "08.06",
|
||||||
"time": 19,
|
"time": "19:00",
|
||||||
"seats": ["a1", "a2", "a3", "a4", "a5",
|
"seats": ["a1", "a2", "a3", "a4", "a5",
|
||||||
"b1", "b2", "b3", "b4", "b5",
|
"b1", "b2", "b3", "b4", "b5",
|
||||||
"c1", "c2", "c3", "c4", "c5",
|
"c1", "c2", "c3", "c4", "c5",
|
||||||
@ -70,8 +62,8 @@ class DBMock():
|
|||||||
},
|
},
|
||||||
2: {
|
2: {
|
||||||
"title": "Batman",
|
"title": "Batman",
|
||||||
"date": "22.06",
|
"date": "08.06",
|
||||||
"time": 20,
|
"time": "20:00",
|
||||||
"seats": ["a1", "a2", "a3", "a4", "a5",
|
"seats": ["a1", "a2", "a3", "a4", "a5",
|
||||||
"b1", "b2", "b3", "b4", "b5",
|
"b1", "b2", "b3", "b4", "b5",
|
||||||
"c1", "c2", "c3", "c4", "c5",
|
"c1", "c2", "c3", "c4", "c5",
|
||||||
@ -85,8 +77,8 @@ class DBMock():
|
|||||||
},
|
},
|
||||||
3: {
|
3: {
|
||||||
"title": "Zorro",
|
"title": "Zorro",
|
||||||
"date": "23.06",
|
"date": "09.06",
|
||||||
"time": 21,
|
"time": "21:00",
|
||||||
"seats": ["a1", "a2", "a3", "a4", "a5",
|
"seats": ["a1", "a2", "a3", "a4", "a5",
|
||||||
"b1", "b2", "b3", "b4", "b5",
|
"b1", "b2", "b3", "b4", "b5",
|
||||||
"c1", "c2", "c3", "c4", "c5",
|
"c1", "c2", "c3", "c4", "c5",
|
||||||
@ -132,43 +124,27 @@ class DBMock():
|
|||||||
}
|
}
|
||||||
|
|
||||||
def getShows(self, title = None, date = None, time = None,):
|
def getShows(self, title = None, date = None, time = None,):
|
||||||
|
result = []
|
||||||
|
for key in self.shows.keys():
|
||||||
# title is None
|
# title is None
|
||||||
if(title is None):
|
if(title is None):
|
||||||
titles = []
|
if ((date is not None) & (time is not None)):
|
||||||
for e in self.shows:
|
if self.shows[key]["date"] == str(date):
|
||||||
if (date is not None & time is not None):
|
if self.shows[key]["time"] == str(time):
|
||||||
if e["date"] == str(date):
|
result.append([self.shows[key]["title"], self.shows[key]["date"]])
|
||||||
if e["time"] == str(time):
|
elif ((date is not None) & (time is None)):
|
||||||
titles.append([e["title"], e["date"]])
|
if self.shows[key]["date"] == str(date):
|
||||||
elif (date is not None & time is None):
|
result.append([self.shows[key]["title"], self.shows[key]["date"]])
|
||||||
if e["date"] == str(date):
|
elif ((date is None) & (time is not None)):
|
||||||
titles.append([e["title"], e["date"]])
|
if self.shows[key]["time"] == str(time):
|
||||||
elif (date is None & time is not None):
|
result.append([self.shows[key]["title"], self.shows[key]["date"]])
|
||||||
if e["time"] == str(time):
|
|
||||||
titles.append([e["title"], e["date"]])
|
|
||||||
return set(titles)
|
|
||||||
# title is not None
|
# title is not None
|
||||||
elif(title is not None):
|
elif(title is not None):
|
||||||
if(date is None):
|
if(date is None):
|
||||||
dates = []
|
if self.shows[key]["title"] == str(title):
|
||||||
for e in self.shows:
|
result.append(self.shows[key]["date"])
|
||||||
if e["title"] == str(date):
|
|
||||||
dates.append(e["date"])
|
|
||||||
elif(date is not None):
|
elif(date is not None):
|
||||||
if(time is None):
|
if(time is None):
|
||||||
|
if self.shows[key]["date"] == str(date):
|
||||||
|
result.append(self.shows[key]["time"])
|
||||||
return set(titles)
|
return set(result)
|
||||||
# time slot is not None
|
|
||||||
elif(title is None & time is not None & date is None):
|
|
||||||
titles = []
|
|
||||||
for e in self.shows:
|
|
||||||
if e["time"] == str(time):
|
|
||||||
titles.append(e["title"], e["time"])
|
|
||||||
return set(titles)
|
|
||||||
elif(title is None & time is not None & date is None):
|
|
||||||
titles = []
|
|
||||||
for e in self.shows:
|
|
||||||
if e["time"] == str(time):
|
|
||||||
titles.append(e["title"], e["time"])
|
|
||||||
return set(titles)
|
|
@ -10,7 +10,7 @@ class DST:
|
|||||||
for intent, domain, slot, value in user_act:
|
for intent, domain, slot, value in user_act:
|
||||||
domain = domain.lower()
|
domain = domain.lower()
|
||||||
intent = intent.lower()
|
intent = intent.lower()
|
||||||
value = value.lower()
|
value = value
|
||||||
slot = slot.lower()
|
slot = slot.lower()
|
||||||
|
|
||||||
# all intents are same
|
# all intents are same
|
||||||
@ -48,6 +48,13 @@ class DST:
|
|||||||
result.append(key)
|
result.append(key)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def getSlotsWithValues(self):
|
||||||
|
result = {}
|
||||||
|
for key in self.state['belief_state']["cinema"]["book"].keys():
|
||||||
|
if self.state['belief_state']["cinema"]["book"][key] != "":
|
||||||
|
result[key] = self.state['belief_state']["cinema"]["book"][key]
|
||||||
|
return result
|
||||||
|
|
||||||
def getSystemSlots(self):
|
def getSystemSlots(self):
|
||||||
result = []
|
result = []
|
||||||
for key in self.state['belief_state']["cinema"]["book"].keys():
|
for key in self.state['belief_state']["cinema"]["book"].keys():
|
||||||
|
@ -3,6 +3,8 @@ from flair.data import Sentence, Token
|
|||||||
from flair.datasets import SentenceDataset
|
from flair.datasets import SentenceDataset
|
||||||
from flair.models import SequenceTagger, TextClassifier
|
from flair.models import SequenceTagger, TextClassifier
|
||||||
|
|
||||||
|
from .chane import getDate, getTitle
|
||||||
|
|
||||||
class NLU:
|
class NLU:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -83,6 +85,11 @@ class NLU:
|
|||||||
# slotValue = value
|
# slotValue = value
|
||||||
|
|
||||||
if slotValue is not None:
|
if slotValue is not None:
|
||||||
|
# normalise input
|
||||||
|
if slot == "title":
|
||||||
|
slotValue = getTitle(slotValue)
|
||||||
|
elif slot == "date":
|
||||||
|
slotValue = getDate(slotValue)
|
||||||
result.append([intent, 'Cinema', slot, slotValue])
|
result.append([intent, 'Cinema', slot, slotValue])
|
||||||
if len(result) == 0: result.append([intent, 'Cinema', "", ""])
|
if len(result) == 0: result.append([intent, 'Cinema', "", ""])
|
||||||
return result
|
return result
|
||||||
|
@ -1,11 +1,8 @@
|
|||||||
from difflib import SequenceMatcher
|
from difflib import SequenceMatcher
|
||||||
from datetime import date
|
from datetime import date
|
||||||
|
import datetime
|
||||||
from dateutil.parser import parse
|
from dateutil.parser import parse
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def getDate(user_date):
|
def getDate(user_date):
|
||||||
#jeżeli w dacie są jakieś liczby, to uznajemy ją za poprawną datę
|
#jeżeli w dacie są jakieś liczby, to uznajemy ją za poprawną datę
|
||||||
if any(char.isdigit() for char in user_date):
|
if any(char.isdigit() for char in user_date):
|
||||||
@ -22,17 +19,17 @@ def getDate(user_date):
|
|||||||
|
|
||||||
#zwrócenie wyniku dzisiaj, dzisiaj+1 (jutro), dzisiaj+2 (pojutrze)
|
#zwrócenie wyniku dzisiaj, dzisiaj+1 (jutro), dzisiaj+2 (pojutrze)
|
||||||
if result_today > result_tommorow and result_today > result_day_after_tomorrow:
|
if result_today > result_tommorow and result_today > result_day_after_tomorrow:
|
||||||
return date.today().strftime("%m.%d")
|
return date.today().strftime("%d.%m")
|
||||||
elif result_tommorow > result_day_after_tomorrow:
|
elif result_tommorow > result_day_after_tomorrow:
|
||||||
return (date.today() + datetime.timedelta(days=1)).strftime("%m.%d")
|
return (date.today() + datetime.timedelta(days=1)).strftime("%d.%m")
|
||||||
else:
|
else:
|
||||||
return (date.today() + datetime.timedelta(days=2)).strftime("%m.%d")
|
return (date.today() + datetime.timedelta(days=2)).strftime("%d.%m")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def getTitle(user_title):
|
def getTitle(user_title):
|
||||||
titles=["Batman", "Na Noże", "Uncharted", "Ambulans", "Minionki", "Fantastyczne Zwierzęta", "To Nie Wypanda",
|
titles=["Batman", "Na Noże", "Uncharted", "Ambulans", "Minionki", "Fantastyczne Zwierzęta", "To Nie Wypanda",
|
||||||
"Inni Ludzie"]
|
"Inni Ludzie", "Zorro"]
|
||||||
number_list = list(map(lambda x: SequenceMatcher(a=user_title, b=x).ratio(), titles))
|
number_list = list(map(lambda x: SequenceMatcher(a=user_title, b=x).ratio(), titles))
|
||||||
max_value = max(number_list)
|
max_value = max(number_list)
|
||||||
max_index = number_list.index(max_value)
|
max_index = number_list.index(max_value)
|
||||||
|
@ -22,12 +22,14 @@ def chatbot():
|
|||||||
if userMessage == "/exit":
|
if userMessage == "/exit":
|
||||||
print("Do usłyszenia")
|
print("Do usłyszenia")
|
||||||
isActive = False
|
isActive = False
|
||||||
|
elif userMessage == "/reset":
|
||||||
|
chatbot()
|
||||||
else:
|
else:
|
||||||
nluPred = nlu.predict(sentence=userMessage)
|
nluPred = nlu.predict(sentence=userMessage)
|
||||||
print(nluPred)
|
print(nluPred)
|
||||||
dst.update(nluPred)
|
dst.update(nluPred)
|
||||||
# print(dst.state)
|
# print(dst.state)
|
||||||
dpAct = dp.getAction(dst.getLastUserAct(), dst.getEmptySlots(), dst.getSystemSlots())
|
dpAct = dp.getAction(dst.getLastUserAct(), dst.getEmptySlots(), dst.getSystemSlots(), dst.getSlotsWithValues())
|
||||||
print(dpAct)
|
print(dpAct)
|
||||||
# TODO update DST system act
|
# TODO update DST system act
|
||||||
chatbot()
|
chatbot()
|
Loading…
Reference in New Issue
Block a user