Compare commits
2 Commits
master
...
feature/di
Author | SHA1 | Date | |
---|---|---|---|
8b55dbf73b | |||
ee1d7e45d4 |
@ -40,14 +40,7 @@ Agent powinien wykazywać elastyczność, adaptując się do potrzeb klienta, np
|
|||||||
|
|
||||||
- Python 3.10.12
|
- Python 3.10.12
|
||||||
- Instalacja dependencies `pip3 install -r requirements.txt`
|
- Instalacja dependencies `pip3 install -r requirements.txt`
|
||||||
- Centralna część systemu - uruchamiamy `python3 src/main.py` - wymagane są wyuczone modele (patrz niżej)
|
- Centralna część systemu - uruchamiamy `python3 src/main.py`
|
||||||
- NLU:
|
- NLU:
|
||||||
- uczenie modeli od zera `python3 nlu_train.py`
|
- uczenie modeli od zera `python3 nlu_train.py`
|
||||||
- Ewaluacja `python3 evaluate.py`
|
- Ewaluacja `python3 evaluate.py`
|
||||||
|
|
||||||
# Gotowe modele NLU
|
|
||||||
|
|
||||||
- [frame-model-prod](https://1drv.ms/f/s!Ar75ftQiNIxxhcgPS1EOLu0zC_WWzg?e=tJRqbB)
|
|
||||||
- [slot-model-prod](https://1drv.ms/f/s!Ar75ftQiNIxxhcgb2X6pFioRxXHVew?e=ZC6LFI)
|
|
||||||
|
|
||||||
Nazwa folderów jest istotna - muszą byc odpowiednio `frame-model-prod` i `slot-model-prod` oraz znajdować się w głównym katalogu repozytorium.
|
|
133
attributes.json
133
attributes.json
@ -1,92 +1,51 @@
|
|||||||
{
|
{
|
||||||
"size": ["M", "L", "XL"],
|
"dough": ["thick"],
|
||||||
"dough": [
|
"drink": ["pepsi", "cola", "water"],
|
||||||
"thick"
|
"food": ["pizza"],
|
||||||
],
|
"meat": ["chicken", "ham", "tuna"],
|
||||||
"drink": {
|
"sauce": ["garlic", "1000w"],
|
||||||
"woda": {
|
"ingredient": {
|
||||||
"price": 5
|
"chicken": {},
|
||||||
|
"tuna": {},
|
||||||
|
"pineapple": {},
|
||||||
|
"onion": {},
|
||||||
|
"cheese": {},
|
||||||
|
"tomato": {},
|
||||||
|
"ham": {},
|
||||||
|
"pepper": {}
|
||||||
},
|
},
|
||||||
"pepsi": {
|
"menu": ["capri", "margarita", "hawajska", "barcelona", "tuna"],
|
||||||
"price": 7
|
"pizza": {
|
||||||
|
"capri": {
|
||||||
|
"ingredient": ["tomato", "ham", "mushrooms", "cheese"],
|
||||||
|
"price": 25
|
||||||
|
},
|
||||||
|
"margarita": {
|
||||||
|
"ingredient": ["tomato", "cheese"],
|
||||||
|
"price": 20
|
||||||
|
},
|
||||||
|
"hawajska": {
|
||||||
|
"ingredient": ["tomato", "pineapple", "chicken", "cheese"],
|
||||||
|
"price": 30
|
||||||
|
},
|
||||||
|
"barcelona": {
|
||||||
|
"ingredient": ["tomato", "onion", "ham", "pepper", "cheese"],
|
||||||
|
"price": 40
|
||||||
|
},
|
||||||
|
"tuna": {
|
||||||
|
"ingredient": ["tomato", "tuna", "onion", "cheese"],
|
||||||
|
"price": 40
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"cola": {
|
"size": {
|
||||||
"price": 8
|
"m": {
|
||||||
|
"price_multiplier": 1
|
||||||
|
},
|
||||||
|
"l": {
|
||||||
|
"price_multiplier": 1.2
|
||||||
|
},
|
||||||
|
"xl": {
|
||||||
|
"price_multiplier": 1.4
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
|
||||||
"food": [
|
|
||||||
"pizza"
|
|
||||||
],
|
|
||||||
"meat": [
|
|
||||||
"kurczak",
|
|
||||||
"szynka",
|
|
||||||
"tuna"
|
|
||||||
],
|
|
||||||
"sauce": [
|
|
||||||
"garlic",
|
|
||||||
"1000w"
|
|
||||||
],
|
|
||||||
"ingredients": {
|
|
||||||
"kurczak": {},
|
|
||||||
"tuna": {},
|
|
||||||
"ananas": {},
|
|
||||||
"cebula": {},
|
|
||||||
"ser": {},
|
|
||||||
"pomidor": {},
|
|
||||||
"szynka": {},
|
|
||||||
"papryka": {}
|
|
||||||
},
|
|
||||||
"menu": [
|
|
||||||
"capri",
|
|
||||||
"margarita",
|
|
||||||
"hawajska",
|
|
||||||
"barcelona",
|
|
||||||
"tuna"
|
|
||||||
],
|
|
||||||
"pizza": {
|
|
||||||
"capri": {
|
|
||||||
"ingredient": [
|
|
||||||
"tomato",
|
|
||||||
"ham",
|
|
||||||
"mushrooms",
|
|
||||||
"cheese"
|
|
||||||
],
|
|
||||||
"price": 25
|
|
||||||
},
|
|
||||||
"margarita": {
|
|
||||||
"ingredient": [
|
|
||||||
"tomato",
|
|
||||||
"cheese"
|
|
||||||
],
|
|
||||||
"price": 20
|
|
||||||
},
|
|
||||||
"hawajska": {
|
|
||||||
"ingredient": [
|
|
||||||
"tomato",
|
|
||||||
"pineapple",
|
|
||||||
"chicken",
|
|
||||||
"cheese"
|
|
||||||
],
|
|
||||||
"price": 30
|
|
||||||
},
|
|
||||||
"barcelona": {
|
|
||||||
"ingredient": [
|
|
||||||
"tomato",
|
|
||||||
"onion",
|
|
||||||
"ham",
|
|
||||||
"pepper",
|
|
||||||
"cheese"
|
|
||||||
],
|
|
||||||
"price": 40
|
|
||||||
},
|
|
||||||
"tuna": {
|
|
||||||
"ingredient": [
|
|
||||||
"tomato",
|
|
||||||
"tuna",
|
|
||||||
"onion",
|
|
||||||
"cheese"
|
|
||||||
],
|
|
||||||
"price": 40
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
@ -4,4 +4,3 @@ pandas==1.5.3
|
|||||||
numpy==1.26.4
|
numpy==1.26.4
|
||||||
torch==1.13
|
torch==1.13
|
||||||
convlab==3.0.2a0
|
convlab==3.0.2a0
|
||||||
scipy==1.12
|
|
54
src/main.py
54
src/main.py
@ -5,52 +5,36 @@ from service.natural_language_generation import NaturalLanguageGeneration, parse
|
|||||||
from service.templates import templates
|
from service.templates import templates
|
||||||
|
|
||||||
# initialize classes
|
# initialize classes
|
||||||
nlu = NaturalLanguageUnderstanding(use_mocks=False) # NLU
|
nlu = NaturalLanguageUnderstanding() # NLU
|
||||||
monitor = DialogStateMonitor() # DSM
|
monitor = DialogStateMonitor() # DSM
|
||||||
dialog_policy = DialogPolicy() # DP
|
dialog_policy = DialogPolicy() # DP
|
||||||
language_generation = NaturalLanguageGeneration(templates) # NLG
|
language_generation = NaturalLanguageGeneration(templates) # NLG
|
||||||
|
|
||||||
def frame_to_dict(frame):
|
|
||||||
return {
|
|
||||||
"act": frame.act,
|
|
||||||
"slots": [{"name": slot.name, "value": slot.value} for slot in frame.slots],
|
|
||||||
"act_understood": frame.act_understood,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Main loop
|
# Main loop
|
||||||
dial_num = 0
|
user_input = input("Możesz zacząć pisać.\n")
|
||||||
print("CTRL+C aby zakończyć program.")
|
|
||||||
while True:
|
while True:
|
||||||
monitor.reset()
|
# NLU
|
||||||
|
frame = nlu.process_input(user_input)
|
||||||
|
# print(frame)
|
||||||
|
|
||||||
print(f"\n==== Rozpoczynasz rozmowę nr {dial_num} ====\n")
|
# DSM
|
||||||
user_input = input("Witamy w naszej pizza-przez-internet. W czym mogę pomóc?\n")
|
monitor.update(frame)
|
||||||
|
|
||||||
while True:
|
# DP
|
||||||
# NLU
|
# print(dialog_policy.next_dialogue_act(monitor.read()).act)
|
||||||
frame = nlu.predict(user_input)
|
|
||||||
# print("Frame: ", frame)
|
# NLG
|
||||||
|
act, slots = parse_frame(frame)
|
||||||
|
response = language_generation.generate(act, slots)
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
if frame.act == "bye":
|
||||||
|
break
|
||||||
|
|
||||||
|
user_input = input(">\n")
|
||||||
|
|
||||||
# DSM
|
|
||||||
monitor.update(frame)
|
|
||||||
|
|
||||||
# DP
|
|
||||||
system_action = dialog_policy.predict(monitor)
|
|
||||||
system_action_dict = frame_to_dict(system_action) # Ensure system_action is a dictionary
|
|
||||||
# print("System action: ", system_action_dict)
|
|
||||||
|
|
||||||
# NLG
|
|
||||||
response = language_generation.generate(frame, system_action_dict)
|
|
||||||
print(response)
|
|
||||||
|
|
||||||
if system_action.act == "bye_and_thanks" or system_action.act == "bye":
|
|
||||||
monitor.print_order()
|
|
||||||
break
|
|
||||||
|
|
||||||
if frame.act == "bye":
|
|
||||||
print(monitor.print_order())
|
|
||||||
break
|
|
||||||
|
|
||||||
user_input = input(">\n")
|
|
||||||
dial_num += 1
|
|
@ -1,11 +1,10 @@
|
|||||||
from .slot import Slot
|
from .slot import Slot
|
||||||
|
|
||||||
class Frame:
|
class Frame:
|
||||||
def __init__(self, source: str, act: str, slots: list[Slot] = [], act_understood = None):
|
def __init__(self, source: str, act: str, slots: list[Slot]):
|
||||||
self.source = source
|
self.source = source
|
||||||
self.slots = slots
|
self.slots = slots
|
||||||
self.act = act
|
self.act = act
|
||||||
self.act_understood = act_understood
|
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
msg = f"Act: {self.act}, Slots: ["
|
msg = f"Act: {self.act}, Slots: ["
|
||||||
|
4
src/service/data/restaurant/db/confirm_db.json
Normal file
4
src/service/data/restaurant/db/confirm_db.json
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
[
|
||||||
|
"true",
|
||||||
|
"false"
|
||||||
|
]
|
5
src/service/data/restaurant/db/dough_db.json
Normal file
5
src/service/data/restaurant/db/dough_db.json
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
[
|
||||||
|
"pepsi",
|
||||||
|
"cola",
|
||||||
|
"water"
|
||||||
|
]
|
11
src/service/data/restaurant/db/drink_db.json
Normal file
11
src/service/data/restaurant/db/drink_db.json
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"name":"pepsi"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name":"cola"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name":"water"
|
||||||
|
}
|
||||||
|
]
|
3
src/service/data/restaurant/db/food_db.json
Normal file
3
src/service/data/restaurant/db/food_db.json
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
[
|
||||||
|
"pizza"
|
||||||
|
]
|
5
src/service/data/restaurant/db/meat_db.json
Normal file
5
src/service/data/restaurant/db/meat_db.json
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
[
|
||||||
|
"chicken",
|
||||||
|
"ham",
|
||||||
|
"tuna"
|
||||||
|
]
|
7
src/service/data/restaurant/db/menu_db.json
Normal file
7
src/service/data/restaurant/db/menu_db.json
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
[
|
||||||
|
"capri",
|
||||||
|
"margarita",
|
||||||
|
"hawajska",
|
||||||
|
"barcelona",
|
||||||
|
"tuna"
|
||||||
|
]
|
51
src/service/data/restaurant/db/pizza_db.json
Normal file
51
src/service/data/restaurant/db/pizza_db.json
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "capri",
|
||||||
|
"ingredient": [
|
||||||
|
"tomato",
|
||||||
|
"ham",
|
||||||
|
"mushrooms",
|
||||||
|
"cheese"
|
||||||
|
],
|
||||||
|
"price": 25
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "margarita",
|
||||||
|
"ingredient": [
|
||||||
|
"tomato",
|
||||||
|
"cheese"
|
||||||
|
],
|
||||||
|
"price": 20
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "hawajska",
|
||||||
|
"ingredient": [
|
||||||
|
"tomato",
|
||||||
|
"pineapple",
|
||||||
|
"chicken",
|
||||||
|
"cheese"
|
||||||
|
],
|
||||||
|
"price": 30
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "barcelona",
|
||||||
|
"ingredient": [
|
||||||
|
"tomato",
|
||||||
|
"onion",
|
||||||
|
"ham",
|
||||||
|
"pepper",
|
||||||
|
"cheese"
|
||||||
|
],
|
||||||
|
"price": 40
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "tuna",
|
||||||
|
"ingredient": [
|
||||||
|
"tomato",
|
||||||
|
"tuna",
|
||||||
|
"onion",
|
||||||
|
"cheese"
|
||||||
|
],
|
||||||
|
"price": 40
|
||||||
|
}
|
||||||
|
]
|
4
src/service/data/restaurant/db/sauce_db.json
Normal file
4
src/service/data/restaurant/db/sauce_db.json
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
[
|
||||||
|
"garlic",
|
||||||
|
"1000w"
|
||||||
|
]
|
14
src/service/data/restaurant/db/size_db.json
Normal file
14
src/service/data/restaurant/db/size_db.json
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"size": "m",
|
||||||
|
"price_multiplier": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"size": "l",
|
||||||
|
"price_multiplier": 1.2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"size": "xl",
|
||||||
|
"price_multiplier": 1.4
|
||||||
|
}
|
||||||
|
]
|
91
src/service/dbquery.py
Normal file
91
src/service/dbquery.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
"""
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from fuzzywuzzy import fuzz
|
||||||
|
from itertools import chain
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
|
||||||
|
class Database(object):
|
||||||
|
def __init__(self):
|
||||||
|
super(Database, self).__init__()
|
||||||
|
# loading databases
|
||||||
|
domains = ['restaurant', 'hotel', 'attraction', 'train', 'hospital', 'taxi', 'police']
|
||||||
|
self.dbs = {}
|
||||||
|
for domain in domains:
|
||||||
|
with open(os.path.join(os.path.dirname(
|
||||||
|
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))),
|
||||||
|
'data/restaurant/db/{}_db.json'.format(domain))) as f:
|
||||||
|
self.dbs[domain] = json.load(f)
|
||||||
|
|
||||||
|
def query(self, domain, constraints, ignore_open=False, soft_contraints=(), fuzzy_match_ratio=60):
|
||||||
|
"""Returns the list of entities for a given domain
|
||||||
|
based on the annotation of the belief state"""
|
||||||
|
# query the db
|
||||||
|
if domain == 'taxi':
|
||||||
|
return [{'taxi_colors': random.choice(self.dbs[domain]['taxi_colors']),
|
||||||
|
'taxi_types': random.choice(self.dbs[domain]['taxi_types']),
|
||||||
|
'taxi_phone': ''.join([str(random.randint(1, 9)) for _ in range(11)])}]
|
||||||
|
if domain == 'police':
|
||||||
|
return deepcopy(self.dbs['police'])
|
||||||
|
if domain == 'hospital':
|
||||||
|
department = None
|
||||||
|
for key, val in constraints:
|
||||||
|
if key == 'department':
|
||||||
|
department = val
|
||||||
|
if not department:
|
||||||
|
return deepcopy(self.dbs['hospital'])
|
||||||
|
else:
|
||||||
|
return [deepcopy(x) for x in self.dbs['hospital'] if x['department'].lower() == department.strip().lower()]
|
||||||
|
constraints = list(map(lambda ele: ele if not(ele[0] == 'area' and ele[1] == 'center') else ('area', 'centre'), constraints))
|
||||||
|
|
||||||
|
found = []
|
||||||
|
for i, record in enumerate(self.dbs[domain]):
|
||||||
|
constraints_iterator = zip(constraints, [False] * len(constraints))
|
||||||
|
soft_contraints_iterator = zip(soft_contraints, [True] * len(soft_contraints))
|
||||||
|
for (key, val), fuzzy_match in chain(constraints_iterator, soft_contraints_iterator):
|
||||||
|
if val == "" or val == "dont care" or val == 'not mentioned' or val == "don't care" or val == "dontcare" or val == "do n't care":
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
record_keys = [k.lower() for k in record]
|
||||||
|
if key.lower() not in record_keys:
|
||||||
|
continue
|
||||||
|
if key == 'leaveAt':
|
||||||
|
val1 = int(val.split(':')[0]) * 100 + int(val.split(':')[1])
|
||||||
|
val2 = int(record['leaveAt'].split(':')[0]) * 100 + int(record['leaveAt'].split(':')[1])
|
||||||
|
if val1 > val2:
|
||||||
|
break
|
||||||
|
elif key == 'arriveBy':
|
||||||
|
val1 = int(val.split(':')[0]) * 100 + int(val.split(':')[1])
|
||||||
|
val2 = int(record['arriveBy'].split(':')[0]) * 100 + int(record['arriveBy'].split(':')[1])
|
||||||
|
if val1 < val2:
|
||||||
|
break
|
||||||
|
# elif ignore_open and key in ['destination', 'departure', 'name']:
|
||||||
|
elif ignore_open and key in ['destination', 'departure']:
|
||||||
|
continue
|
||||||
|
elif record[key].strip() == '?':
|
||||||
|
# '?' matches any constraint
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
if not fuzzy_match:
|
||||||
|
if val.strip().lower() != record[key].strip().lower():
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if fuzz.partial_ratio(val.strip().lower(), record[key].strip().lower()) < fuzzy_match_ratio:
|
||||||
|
break
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
res = deepcopy(record)
|
||||||
|
res['Ref'] = '{0:08d}'.format(i)
|
||||||
|
found.append(res)
|
||||||
|
|
||||||
|
return found
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
db = Database()
|
||||||
|
print(db.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arriveBy', '11:15']]))
|
@ -1,66 +1,59 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from model.frame import Frame
|
import copy
|
||||||
from model.slot import Slot
|
import json
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
class DialogPolicy:
|
from convlab.policy.policy import Policy
|
||||||
def _predict(self, dsm):
|
from dbquery import Database
|
||||||
last_frame = dsm.state['history'][-1]
|
|
||||||
act_processed = dsm.state['was_system_act_processed']
|
class SimpleRulePolicy(Policy):
|
||||||
if(dsm.state['was_previous_order_invalid'] == False):
|
def __init__(self):
|
||||||
if last_frame.act == "inform/order-complete":
|
Policy.__init__(self)
|
||||||
act = last_frame.act
|
self.db = Database()
|
||||||
elif ("inform" in last_frame.act):
|
|
||||||
act = last_frame.act.split('/')[0]
|
def predict(self, state):
|
||||||
|
self.results = []
|
||||||
|
system_action = defaultdict(list)
|
||||||
|
user_action = defaultdict(list)
|
||||||
|
|
||||||
|
for intent, domain, slot, value in state['user_action']:
|
||||||
|
user_action[(domain.lower(), intent.lower())].append((slot.lower(), value))
|
||||||
|
|
||||||
|
for user_act in user_action:
|
||||||
|
self.update_system_action(user_act, user_action, state, system_action)
|
||||||
|
|
||||||
|
# Reguła 3
|
||||||
|
if any(True for slots in user_action.values() for (slot, _) in slots if slot in ['book stay', 'book day', 'book people']):
|
||||||
|
if self.results:
|
||||||
|
system_action = {('Booking', 'Book'): [["Ref", self.results[0].get('Ref', 'N/A')]]}
|
||||||
|
|
||||||
|
system_acts = [[intent, domain, slot, value] for (domain, intent), slots in system_action.items() for slot, value in slots]
|
||||||
|
state['system_action'] = system_acts
|
||||||
|
return system_acts
|
||||||
|
|
||||||
|
def update_system_action(self, user_act, user_action, state, system_action):
|
||||||
|
domain, intent = user_act
|
||||||
|
constraints = [(slot, value) for slot, value in state['belief_state'][domain.lower()].items() if value != '']
|
||||||
|
self.results = deepcopy(self.db.query(domain.lower(), constraints))
|
||||||
|
|
||||||
|
# Reguła 1
|
||||||
|
if intent == 'request':
|
||||||
|
if len(self.results) == 0:
|
||||||
|
system_action[(domain, 'NoOffer')] = []
|
||||||
else:
|
else:
|
||||||
act = last_frame.act
|
for slot in user_action[user_act]:
|
||||||
match(act):
|
if slot[0] in self.results[0]:
|
||||||
case "bye":
|
system_action[(domain, 'Inform')].append([slot[0], self.results[0].get(slot[0], 'unknown')])
|
||||||
return Frame(source="system", act = "bye")
|
|
||||||
case "inform/order-complete":
|
|
||||||
return Frame(source="system", act = "bye_and_thanks")
|
|
||||||
case "inform" | "affirm" | "negate":
|
|
||||||
current_active_stage = dsm.get_current_active_stage()
|
|
||||||
if current_active_stage == None:
|
|
||||||
return Frame(source="system", act = "bye_and_thanks")
|
|
||||||
match(current_active_stage['name']):
|
|
||||||
case "collect_food":
|
|
||||||
return Frame(source="system", act = "request/food", slots = [Slot("menu", dsm.state['constants']['menu'])], act_understood=act_processed)
|
|
||||||
case "collect_drinks":
|
|
||||||
return Frame(source="system", act = "request/drinks", slots = [Slot("drink", dsm.state['constants']['drink'])], act_understood=act_processed)
|
|
||||||
case "more_food":
|
|
||||||
return Frame(source="system", act = "request/food-more", slots = [Slot("menu", dsm.state['constants']['menu'])], act_understood=act_processed)
|
|
||||||
case "more_drinks":
|
|
||||||
return Frame(source="system", act = "request/drinks-more", slots = [Slot("drink", dsm.state['constants']['drink'])], act_understood=act_processed)
|
|
||||||
case "collect_address":
|
|
||||||
return Frame(source="system", act = "request/address", act_understood=act_processed)
|
|
||||||
case "collect_payment_method":
|
|
||||||
return Frame(source="system", act = "request/payment-method", act_understood=act_processed)
|
|
||||||
case "collect_phone":
|
|
||||||
return Frame(source="system", act = "request/phone", act_understood=act_processed)
|
|
||||||
case "request/menu":
|
|
||||||
return Frame(source="system", act = "inform", slots = [Slot("menu", dsm.state['constants']['menu'])])
|
|
||||||
case "request/price":
|
|
||||||
return Frame(source="system", act = "inform", slots = [Slot("price", dsm.state['total_cost'])])
|
|
||||||
case "request/ingredients":
|
|
||||||
return Frame(source="system", act = "inform", slots = [Slot("ingredients", dsm.state['constants']['ingredients'])])
|
|
||||||
case "request/sauce":
|
|
||||||
return Frame(source="system", act = "inform", slots = [Slot("sauce", dsm.state['constants']['sauce'])])
|
|
||||||
case "request/time":
|
|
||||||
return Frame(source="system", act = "inform", slots = [Slot("time", dsm.state['belief_state']['time'])])
|
|
||||||
case "request/size":
|
|
||||||
return Frame(source="system", act = "inform", slots = [Slot("size", dsm.state['constants']['size'])])
|
|
||||||
case "request/delivery-price":
|
|
||||||
return Frame(source="system", act = "inform", slots = [Slot("delivery-price", "10")])
|
|
||||||
case "request/drinks":
|
|
||||||
return Frame(source="system", act = "inform", slots = [Slot("drink", dsm.state['constants']['drink'])])
|
|
||||||
case "welcomemsg":
|
|
||||||
return Frame(source="system", act = "inform", slots = [Slot("menu", dsm.state['constants']['menu'])])
|
|
||||||
case "repeat":
|
|
||||||
return Frame(source="system", act = "repeat")
|
|
||||||
|
|
||||||
return Frame(source="system", act = "repeat")
|
# Reguła 2
|
||||||
|
elif intent == 'inform':
|
||||||
|
if len(self.results) == 0:
|
||||||
|
system_action[(domain, 'NoOffer')] = []
|
||||||
|
else:
|
||||||
|
system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])
|
||||||
|
choice = self.results[0]
|
||||||
|
|
||||||
def predict(self, dsm):
|
if domain in ["hotel", "attraction", "police", "restaurant"]:
|
||||||
frame = self._predict(dsm)
|
system_action[(domain, 'Recommend')].append(['Name', choice['name']])
|
||||||
dsm.state["system_history"].append(frame)
|
|
||||||
return frame
|
dialogPolicy = SimpleRulePolicy()
|
@ -1,183 +1,55 @@
|
|||||||
from model.frame import Frame
|
from src.model.frame import Frame
|
||||||
|
from convlab.dst.dst import DST
|
||||||
import copy
|
import copy
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
def normalize(value):
|
def normalize(value):
|
||||||
value = value.lower()
|
value = value.lower()
|
||||||
if value[-1] in [",", "!", "?" ,"." ,")" ,"(",")"]:
|
|
||||||
value = value[:-1]
|
|
||||||
if value[-2:] == "ie":
|
|
||||||
value = value[:-2] + "a"
|
|
||||||
if value[-1] in ["ę", "ą", "y"]:
|
|
||||||
value = value[:-1] + "a"
|
|
||||||
if value in ["pizz", "pizzy", "pizze", "picce"]:
|
|
||||||
value = "pizza"
|
|
||||||
|
|
||||||
return ' '.join(value.split())
|
return ' '.join(value.split())
|
||||||
|
|
||||||
|
|
||||||
class DialogStateMonitor:
|
class DialogStateMonitor(DST):
|
||||||
def __init__(self, initial_state_file: str = 'attributes.json'):
|
domain = 'restaurant'
|
||||||
with open(initial_state_file) as file:
|
|
||||||
constants = json.load(file)
|
def __init__(self):
|
||||||
self.__initial_state = dict(
|
DST.__init__(self)
|
||||||
belief_state={
|
self.__initial_state = dict(user_action=[],
|
||||||
'order': [],
|
system_action=[],
|
||||||
'address': {},
|
belief_state={
|
||||||
'order-complete': False,
|
'order': [],
|
||||||
'phone': {},
|
'address': {},
|
||||||
'delivery': {},
|
'order-complete': False,
|
||||||
'payment': {},
|
'phone': {},
|
||||||
'time': {},
|
'delivery': {},
|
||||||
'name': {},
|
'payment': {},
|
||||||
},
|
'time': {},
|
||||||
total_cost=0,
|
'name': {},
|
||||||
stages=[
|
},
|
||||||
{'completed': False, 'name': 'collect_food'},
|
booked={},
|
||||||
{'completed': False, 'name': 'more_food'},
|
request_state={},
|
||||||
{'completed': False, 'name': 'collect_drinks'},
|
terminated=False,
|
||||||
{'completed': False, 'name': 'more_drinks'},
|
history=[])
|
||||||
{'completed': False, 'name': 'collect_address'},
|
|
||||||
{'completed': False, 'name': 'collect_phone'},
|
|
||||||
],
|
|
||||||
was_previous_order_invalid=False,
|
|
||||||
was_system_act_processed=False,
|
|
||||||
constants=constants,
|
|
||||||
history=[],
|
|
||||||
system_history=[],
|
|
||||||
)
|
|
||||||
self.state = copy.deepcopy(self.__initial_state)
|
self.state = copy.deepcopy(self.__initial_state)
|
||||||
|
|
||||||
def get_current_active_stage(self) -> str | None:
|
def update(self, frame: Frame):
|
||||||
for stage in self.state['stages']:
|
|
||||||
if stage['completed'] is False:
|
|
||||||
# print("Current stage: ", stage['name'])
|
|
||||||
return stage
|
|
||||||
self.state['belief_state']['order-complete'] = True
|
|
||||||
|
|
||||||
def mark_current_stage_completed(self) -> None:
|
|
||||||
for stage in self.state['stages']:
|
|
||||||
if stage['completed'] is False:
|
|
||||||
# print("Stage completed: ", stage['name'])
|
|
||||||
stage['completed'] = True
|
|
||||||
return
|
|
||||||
|
|
||||||
def complete_stage_if_valid(self, stage_name):
|
|
||||||
for stage in self.state['stages']:
|
|
||||||
if stage['name'] != stage_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if stage['name'] == "collect_food":
|
|
||||||
for order in self.state['belief_state']['order']:
|
|
||||||
if order.get("pizza"):
|
|
||||||
stage['completed'] = True
|
|
||||||
return
|
|
||||||
elif stage["name"] == "collect_drinks":
|
|
||||||
for order in self.state['belief_state']['order']:
|
|
||||||
if order.get("drink"):
|
|
||||||
stage['completed'] = True
|
|
||||||
return
|
|
||||||
elif stage["name"] == "collect_address":
|
|
||||||
if not len(self.state['belief_state']['address']):
|
|
||||||
return
|
|
||||||
stage['completed'] = True
|
|
||||||
return
|
|
||||||
elif stage["name"] == "collect_phone":
|
|
||||||
if self.state['belief_state']["phone"].get("phone"):
|
|
||||||
stage['completed'] = True
|
|
||||||
return
|
|
||||||
pass
|
|
||||||
|
|
||||||
def item_exists(self, type: str, name: str) -> bool:
|
|
||||||
return normalize(name) in self.state['constants'][type]
|
|
||||||
|
|
||||||
def drink_exists(self, name: str) -> bool:
|
|
||||||
return normalize(name) in self.state['constants']['pizza']
|
|
||||||
|
|
||||||
def get_total_cost(self) -> int:
|
|
||||||
return self.state['total_cost']
|
|
||||||
|
|
||||||
def slot_augmentation(self, slot, value):
|
|
||||||
drink_normalize = ["woda", "pepsi", "cola", "coca cola", "cole", "coca"]
|
|
||||||
if slot.name in ["food", "pizza", "ingredient"]:
|
|
||||||
if value in drink_normalize:
|
|
||||||
slot.name = 'drink'
|
|
||||||
return slot
|
|
||||||
|
|
||||||
def slot_valid(self, slot, act):
|
|
||||||
if act == "inform/order":
|
|
||||||
if slot.name in ["address", "payment-method", 'delivery', 'phone', 'time', 'name']:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def value_valid(self, slot_name, value):
|
|
||||||
if slot_name == "food":
|
|
||||||
if value != "pizza":
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def update(self, frame: Frame) -> None:
|
|
||||||
self.state['was_system_act_processed'] = False
|
|
||||||
belief_state_copy = copy.deepcopy(self.state["belief_state"])
|
|
||||||
self.state['history'].append(frame)
|
|
||||||
if frame.source != 'user':
|
if frame.source != 'user':
|
||||||
return
|
return
|
||||||
if frame.act == 'inform/order':
|
if frame.act == 'inform/order':
|
||||||
new_order = dict()
|
new_order = dict()
|
||||||
for slot in frame.slots:
|
for slot in frame.slots:
|
||||||
value = normalize(slot.value)
|
new_order[slot.name] = normalize(slot.value)
|
||||||
slot = self.slot_augmentation(slot, value)
|
self.state['belief_state']['order'].append(new_order)
|
||||||
if not self.slot_valid(slot, frame.act):
|
|
||||||
continue
|
|
||||||
if not self.value_valid(slot.name, value):
|
|
||||||
continue
|
|
||||||
stage_name = self.get_current_active_stage()['name']
|
|
||||||
is_collect_food = (slot.name == 'pizza' and stage_name == 'collect_food')
|
|
||||||
is_more_food = (slot.name == 'pizza' and stage_name == 'more_food')
|
|
||||||
is_collect_drinks = (slot.name == 'drink' and stage_name == 'collect_drinks')
|
|
||||||
is_more_drinks = (slot.name == 'drink' and stage_name == 'more_drinks')
|
|
||||||
if is_collect_food or is_collect_drinks or is_more_food or stage_name == is_more_drinks:
|
|
||||||
if self.item_exists(slot.name, value) is False:
|
|
||||||
self.state['was_previous_order_invalid'] = True
|
|
||||||
return
|
|
||||||
self.state['was_previous_order_invalid'] = False
|
|
||||||
self.state['total_cost'] += self.state['constants'][slot.name][value]['price']
|
|
||||||
if slot.name == "pizza":
|
|
||||||
if new_order.get("pizza"):
|
|
||||||
new_order[slot.name].append({"name": value, "ingredient": [], "ingredient/neg": []})
|
|
||||||
else:
|
|
||||||
new_order[slot.name] = [{"name": value, "ingredient": [], "ingredient/neg": []}]
|
|
||||||
elif slot.name == "drink":
|
|
||||||
if new_order.get("drink"):
|
|
||||||
new_order[slot.name].append({"drink": value})
|
|
||||||
else:
|
|
||||||
new_order[slot.name] = [{"drink": value}]
|
|
||||||
|
|
||||||
elif slot.name in ["ingredient", "ingredient/neg"]:
|
|
||||||
pizzas_list = new_order.get("pizza")
|
|
||||||
if pizzas_list:
|
|
||||||
pizzas_list[-1][slot.name].append(value)
|
|
||||||
|
|
||||||
if len(new_order) > 0:
|
|
||||||
self.state['belief_state']['order'].append(new_order)
|
|
||||||
self.complete_stage_if_valid('collect_food')
|
|
||||||
self.complete_stage_if_valid('collect_drinks')
|
|
||||||
elif frame.act == 'inform/address':
|
elif frame.act == 'inform/address':
|
||||||
for slot in frame.slots:
|
for slot in frame.slots:
|
||||||
self.state['belief_state']['address'][slot.name] = normalize(slot.value)
|
self.state['belief_state']['address'][slot.name] = normalize(slot.value)
|
||||||
self.complete_stage_if_valid('collect_address')
|
|
||||||
elif frame.act == 'inform/phone':
|
elif frame.act == 'inform/phone':
|
||||||
for slot in frame.slots:
|
for slot in frame.slots:
|
||||||
self.state['belief_state']['phone'][slot.name] = normalize(slot.value)
|
self.state['belief_state']['phone'][slot.name] = normalize(slot.value)
|
||||||
self.complete_stage_if_valid('collect_phone')
|
|
||||||
elif frame.act == 'inform/order-complete':
|
elif frame.act == 'inform/order-complete':
|
||||||
self.state['belief_state']['order-complete'] = True
|
self.state['belief_state']['order-complete'] = True
|
||||||
elif frame.act == 'inform/delivery':
|
elif frame.act == 'inform/delivery':
|
||||||
for slot in frame.slots:
|
for slot in frame.slots:
|
||||||
self.state['belief_state']['delivery'][slot.name] = normalize(slot.value)
|
self.state['belief_state']['delivery'][slot.name] = normalize(slot.value)
|
||||||
self.complete_stage_if_valid('collect_address')
|
|
||||||
elif frame.act == 'inform/payment':
|
elif frame.act == 'inform/payment':
|
||||||
for slot in frame.slots:
|
for slot in frame.slots:
|
||||||
self.state['belief_state']['payment'][slot.name] = normalize(slot.value)
|
self.state['belief_state']['payment'][slot.name] = normalize(slot.value)
|
||||||
@ -187,23 +59,6 @@ class DialogStateMonitor:
|
|||||||
elif frame.act == 'inform/name':
|
elif frame.act == 'inform/name':
|
||||||
for slot in frame.slots:
|
for slot in frame.slots:
|
||||||
self.state['belief_state']['name'][slot.name] = normalize(slot.value)
|
self.state['belief_state']['name'][slot.name] = normalize(slot.value)
|
||||||
elif frame.act == 'negate':
|
|
||||||
if "request" in self.state["system_history"][-1].act:
|
|
||||||
self.mark_current_stage_completed()
|
|
||||||
self.state['was_system_act_processed'] = True
|
|
||||||
|
|
||||||
if self.state["belief_state"] != belief_state_copy and frame.act not in ["repeat", 'affirm', 'negate']:
|
def reset(self):
|
||||||
self.state['was_system_act_processed'] = True
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
self.state = copy.deepcopy(self.__initial_state)
|
self.state = copy.deepcopy(self.__initial_state)
|
||||||
|
|
||||||
def print_order(self) -> dict:
|
|
||||||
print("\n=== Oto podsumowanie Twojego zamówienia ===")
|
|
||||||
for o in self.state['belief_state']['order']:
|
|
||||||
print(o)
|
|
||||||
print("Adres dostawy: ")
|
|
||||||
print(self.state['belief_state']['address'])
|
|
||||||
print("Numer telefonu: ")
|
|
||||||
print(self.state['belief_state']['phone'])
|
|
||||||
print(f"Czy zostało pomyślnie zrealizowane: {self.state['belief_state']['order-complete']}")
|
|
@ -1,6 +1,7 @@
|
|||||||
|
from flair.models import SequenceTagger
|
||||||
from utils.nlu_utils import predict_single, predict_and_annotate
|
from utils.nlu_utils import predict_single, predict_and_annotate
|
||||||
from model.frame import Frame, Slot
|
from model.frame import Frame, Slot
|
||||||
import random
|
|
||||||
"""
|
"""
|
||||||
ACTS:
|
ACTS:
|
||||||
inform/order
|
inform/order
|
||||||
@ -40,13 +41,8 @@ SLOTS:
|
|||||||
sauce
|
sauce
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class NaturalLanguageUnderstanding():
|
class NaturalLanguageUnderstanding:
|
||||||
def __init__(self, use_mocks=False):
|
def __init__(self):
|
||||||
self.use_mocks = use_mocks
|
|
||||||
if self.use_mocks:
|
|
||||||
return
|
|
||||||
|
|
||||||
from flair.models import SequenceTagger
|
|
||||||
print("\n========================================================")
|
print("\n========================================================")
|
||||||
print("Models are loading, it may take a moment, please wait...")
|
print("Models are loading, it may take a moment, please wait...")
|
||||||
print("========================================================\n")
|
print("========================================================\n")
|
||||||
@ -89,20 +85,8 @@ class NaturalLanguageUnderstanding():
|
|||||||
|
|
||||||
return slots
|
return slots
|
||||||
|
|
||||||
def predict(self, text: str):
|
def process_input(self, text: str):
|
||||||
if not self.use_mocks:
|
act = self.__predict_intention(text)
|
||||||
try:
|
slots = self.__predict_slot(text)
|
||||||
act = self.__predict_intention(text)
|
frame = Frame(source = 'user', act = act, slots = slots)
|
||||||
slots = self.__predict_slot(text)
|
return frame
|
||||||
frame = Frame(source = 'user', act = act, slots = slots)
|
|
||||||
return frame
|
|
||||||
except:
|
|
||||||
return Frame(source="user", act = "repeat", slots=[])
|
|
||||||
else:
|
|
||||||
frames = [
|
|
||||||
Frame(source="user", act = "inform/order", slots=[Slot(name="pizza", value="barcelona")]),
|
|
||||||
Frame(source="user", act = "welcomemsg", slots=[]),
|
|
||||||
Frame(source="user", act = "request/menu", slots=[]),
|
|
||||||
]
|
|
||||||
return random.choice(frames)
|
|
||||||
|
|
@ -1,37 +1,7 @@
|
|||||||
import re
|
import re
|
||||||
from service.template_selector import select_template
|
from service.template_selector import select_template
|
||||||
import random
|
import random
|
||||||
|
# from service.templates import templates
|
||||||
class NaturalLanguageGeneration:
|
|
||||||
def __init__(self, templates):
|
|
||||||
self.templates = templates
|
|
||||||
|
|
||||||
def generate(self, frame, system_action):
|
|
||||||
# Parsowanie frame
|
|
||||||
act, slots = parse_frame(frame)
|
|
||||||
|
|
||||||
# Wybierz szablon na podstawie system_action
|
|
||||||
template = select_template(system_action['act'], system_action['slots'])
|
|
||||||
if template is None:
|
|
||||||
print(f"Brak szablonu dla act: {system_action['act']} z slotami: {system_action['slots']}")
|
|
||||||
template = random.choice(self.templates["default/template"])
|
|
||||||
|
|
||||||
# Zamień sloty na wartości
|
|
||||||
slot_dict = {}
|
|
||||||
for slot in system_action['slots']:
|
|
||||||
if isinstance(slot['value'], list):
|
|
||||||
slot_dict[slot['name']] = ', '.join(slot['value'])
|
|
||||||
elif isinstance(slot['value'], dict):
|
|
||||||
slot_dict[slot['name']] = ', '.join([f"{k}: {v}" for k, v in slot['value'].items()])
|
|
||||||
else:
|
|
||||||
slot_dict[slot['name']] = slot['value']
|
|
||||||
|
|
||||||
response = template.format(**slot_dict)
|
|
||||||
if system_action['act_understood'] == True:
|
|
||||||
response = f"Zrozumiałem. {response}"
|
|
||||||
elif system_action['act_understood'] == False:
|
|
||||||
response = f"Nie zrozumiałem. {response}"
|
|
||||||
return response
|
|
||||||
|
|
||||||
def parse_frame(frame):
|
def parse_frame(frame):
|
||||||
if not hasattr(frame, 'act') or not hasattr(frame, 'slots'):
|
if not hasattr(frame, 'act') or not hasattr(frame, 'slots'):
|
||||||
@ -41,3 +11,14 @@ def parse_frame(frame):
|
|||||||
slots = [{"name": slot.name, "value": slot.value} for slot in frame.slots]
|
slots = [{"name": slot.name, "value": slot.value} for slot in frame.slots]
|
||||||
|
|
||||||
return act, slots
|
return act, slots
|
||||||
|
|
||||||
|
class NaturalLanguageGeneration:
|
||||||
|
def __init__(self, templates):
|
||||||
|
self.templates = templates
|
||||||
|
|
||||||
|
def generate(self, act, slots):
|
||||||
|
template = select_template(act, slots)
|
||||||
|
if template == "default/template":
|
||||||
|
template = random.choice(self.templates["default/template"])
|
||||||
|
slot_dict = {slot['name']: slot['value'] for slot in slots}
|
||||||
|
return template.format(**slot_dict)
|
@ -1,216 +1,35 @@
|
|||||||
import random
|
import random
|
||||||
from service.templates import templates
|
from service.templates import templates
|
||||||
|
|
||||||
def generate_pizza_info(slots):
|
|
||||||
pizza_name = None
|
|
||||||
query_type = None
|
|
||||||
|
|
||||||
for slot in slots:
|
|
||||||
if slot['name'] == 'pizza':
|
|
||||||
pizza_name = slot['value'].lower()
|
|
||||||
elif slot['name'] == 'ingredient':
|
|
||||||
query_type = 'ingredient'
|
|
||||||
elif slot['name'] == 'price':
|
|
||||||
query_type = 'price'
|
|
||||||
|
|
||||||
if pizza_name not in pizza_data:
|
|
||||||
return f"Nie mamy w ofercie pizzy o nazwie {pizza_name}."
|
|
||||||
|
|
||||||
if query_type == 'ingredient':
|
|
||||||
ingredients = pizza_data[pizza_name]['ingredient']
|
|
||||||
return f"Składniki pizzy {pizza_name} to: {', '.join(ingredients)}."
|
|
||||||
elif query_type == 'price':
|
|
||||||
price = pizza_data[pizza_name]['price']
|
|
||||||
return f"Cena pizzy {pizza_name} to {price} zł."
|
|
||||||
|
|
||||||
return f"Informacje o pizzy {pizza_name}: składniki to {', '.join(pizza_data[pizza_name]['ingredient'])}, cena to {pizza_data[pizza_name]['price']} zł."
|
|
||||||
|
|
||||||
|
|
||||||
def generate_ingredients_response(slots):
|
|
||||||
ingredients = [slot['value'] for slot in slots if slot['name'] == 'ingredients']
|
|
||||||
if ingredients:
|
|
||||||
ingredient_list = []
|
|
||||||
for ingredient in ingredients:
|
|
||||||
ingredient_list.extend(ingredient.keys())
|
|
||||||
response = f"Dostępne składniki to: {', '.join(ingredient_list)}."
|
|
||||||
return response
|
|
||||||
return "Nie podano składników."
|
|
||||||
|
|
||||||
|
|
||||||
def generate_drinks_response(slots, act):
|
|
||||||
more = 'drinks-more' in act
|
|
||||||
drinks = [slot['value'] for slot in slots if slot['name'] == 'drink']
|
|
||||||
if drinks:
|
|
||||||
drink_details = []
|
|
||||||
for drink in drinks:
|
|
||||||
for name, details in drink.items():
|
|
||||||
# price = details.get('price', 'unknown')
|
|
||||||
drink_details.append(f"{name}") # w cenie {price} zł")
|
|
||||||
if not more:
|
|
||||||
if len(drink_details) > 1:
|
|
||||||
response = f"Czy chcesz coś do picia? Dostępne napoje to: {', '.join(drink_details[:-1])} oraz {drink_details[-1]}."
|
|
||||||
else:
|
|
||||||
response = f"Czy chcesz coś do picia? Dostępne napoje to: {drink_details[0]}."
|
|
||||||
return response
|
|
||||||
else:
|
|
||||||
if len(drink_details) > 1:
|
|
||||||
response = f"Czy chcesz coś jeszcze picia?"
|
|
||||||
else:
|
|
||||||
response = f"Czy chcesz coś jeszcze picia?"
|
|
||||||
return response
|
|
||||||
return "Nie podano napojów."
|
|
||||||
|
|
||||||
def generate_size_response(slots):
|
|
||||||
sizes = [slot['value'] for slot in slots if slot['name'] == 'size'][0]
|
|
||||||
if len(sizes) != 0:
|
|
||||||
return F"Dostępne rozmiary to: {', '.join(sizes)}"
|
|
||||||
return "Nie podano rozmiarów."
|
|
||||||
|
|
||||||
def generate_sauce_response(slots):
|
|
||||||
sauces = [slot['value'] for slot in slots if slot['name'] == 'sauce']
|
|
||||||
if sauces:
|
|
||||||
sauce_list = []
|
|
||||||
for sauce in sauces:
|
|
||||||
if isinstance(sauce, list):
|
|
||||||
sauce_list.extend(sauce)
|
|
||||||
else:
|
|
||||||
sauce_list.append(sauce)
|
|
||||||
return f"Dostępne sosy to: {', '.join(sauce_list)}."
|
|
||||||
return "Nie podano sosów."
|
|
||||||
|
|
||||||
|
|
||||||
def select_template(act, slots):
|
def select_template(act, slots):
|
||||||
slot_names = {slot['name'] for slot in slots}
|
slot_names = {slot['name'] for slot in slots}
|
||||||
|
|
||||||
if act == "welcomemsg":
|
if act == "welcomemsg":
|
||||||
return random.choice(templates["welcomemsg"])
|
return random.choice(templates["welcomemsg"])
|
||||||
|
if act == "request/menu":
|
||||||
if "ingredients" in slot_names:
|
return random.choice(templates["request/menu"])
|
||||||
return generate_ingredients_response(slots)
|
if act == "inform/address":
|
||||||
elif "drink" in slot_names:
|
return random.choice(templates["inform/address"])
|
||||||
return generate_drinks_response(slots, act)
|
if act == "inform/delivery":
|
||||||
elif "sauce" in slot_names:
|
return random.choice(templates["inform/delivery"])
|
||||||
return generate_sauce_response(slots)
|
if act == "inform/payment":
|
||||||
elif "size" in slot_names:
|
return random.choice(templates["inform/payment"])
|
||||||
return generate_size_response(slots)
|
if act == "affirm":
|
||||||
elif "price" in slot_names:
|
return random.choice(templates["affirm"])
|
||||||
return random.choice(templates["inform/price"])
|
if act == "request/drinks":
|
||||||
elif "food" in slot_names:
|
return random.choice(templates["request/drinks"])
|
||||||
return random.choice(templates["inform/menu"])
|
if act == "bye":
|
||||||
|
return random.choice(templates["bye"])
|
||||||
if act == "inform":
|
if act == "inform/order":
|
||||||
if "menu" in slot_names:
|
if "quantity" in slot_names and "food" in slot_names and "pizza" in slot_names:
|
||||||
return random.choice(templates["inform/menu"])
|
|
||||||
elif "address" in slot_names:
|
|
||||||
return random.choice(templates["inform/address"])
|
|
||||||
elif "phone" in slot_names:
|
|
||||||
return random.choice(templates["inform/phone"])
|
|
||||||
elif "order-complete" in slot_names:
|
|
||||||
return random.choice(templates["inform/order-complete"])
|
|
||||||
elif "delivery" in slot_names:
|
|
||||||
return random.choice(templates["inform/delivery"])
|
|
||||||
elif "payment" in slot_names:
|
|
||||||
return random.choice(templates["inform/payment"])
|
|
||||||
elif "time" in slot_names:
|
|
||||||
return random.choice(templates["inform/time"])
|
|
||||||
elif "name" in slot_names:
|
|
||||||
return random.choice(templates["inform/name"])
|
|
||||||
elif "price" in slot_names and "pizza" in slot_names:
|
|
||||||
return generate_pizza_info(slots)
|
|
||||||
elif act == "inform/order":
|
|
||||||
if "quantity" in slot_names and "pizza" in slot_names and "size" in slot_names:
|
|
||||||
return templates["inform/order"][1]
|
return templates["inform/order"][1]
|
||||||
elif "quantity" in slot_names and "pizza" in slot_names:
|
elif "quantity" in slot_names and "pizza" in slot_names:
|
||||||
|
return templates["inform/order"][4]
|
||||||
|
elif "food" in slot_names and "pizza" in slot_names:
|
||||||
return templates["inform/order"][2]
|
return templates["inform/order"][2]
|
||||||
elif "quantity" in slot_names and "food" in slot_names:
|
|
||||||
return templates["inform/order"][0]
|
|
||||||
elif "food" in slot_names and "pizza" in slot_names and "price" in slot_names:
|
|
||||||
return templates["inform/order"][5]
|
|
||||||
elif "quantity" in slot_names:
|
elif "quantity" in slot_names:
|
||||||
return templates["inform/order"][3]
|
return templates["inform/order"][3]
|
||||||
else:
|
else:
|
||||||
return templates["inform/order"][4]
|
return templates["inform/order"][4]
|
||||||
elif act == "request/menu":
|
|
||||||
return random.choice(templates["request/menu"])
|
|
||||||
elif act == "inform/address":
|
|
||||||
return random.choice(templates["inform/address"])
|
|
||||||
elif act == "request/price":
|
|
||||||
return random.choice(templates["request/price"])
|
|
||||||
elif act == "inform/menu":
|
|
||||||
return random.choice(templates["inform/menu"])
|
|
||||||
elif act == "request/ingredients":
|
|
||||||
return random.choice(templates["request/ingredients"])
|
|
||||||
elif act == "request/sauce":
|
|
||||||
return random.choice(templates["request/sauce"])
|
|
||||||
elif act == "inform/phone":
|
|
||||||
return random.choice(templates["inform/phone"])
|
|
||||||
elif act == "inform/order-complete":
|
|
||||||
return random.choice(templates["inform/order-complete"])
|
|
||||||
elif act == "request/time":
|
|
||||||
return random.choice(templates["request/time"])
|
|
||||||
elif act == "request/size":
|
|
||||||
return random.choice(templates["request/size"])
|
|
||||||
elif act == "inform/delivery":
|
|
||||||
return random.choice(templates["inform/delivery"])
|
|
||||||
elif act == "inform/payment":
|
|
||||||
return random.choice(templates["inform/payment"])
|
|
||||||
elif act == "request/delivery-price":
|
|
||||||
return random.choice(templates["request/delivery-price"])
|
|
||||||
elif act == "inform/time":
|
|
||||||
return random.choice(templates["inform/time"])
|
|
||||||
elif act == "request/drinks":
|
|
||||||
return random.choice(templates["request/drinks"])
|
|
||||||
elif act == "request/food":
|
|
||||||
return random.choice(templates["inform/menu"])
|
|
||||||
elif act == "request/food-more":
|
|
||||||
return random.choice(templates["inform/menu-more"])
|
|
||||||
elif act == "inform/name":
|
|
||||||
return random.choice(templates["inform/name"])
|
|
||||||
elif act == "bye":
|
|
||||||
return random.choice(templates["bye"]) # TODO force end?
|
|
||||||
elif act == "bye_and_thanks":
|
|
||||||
return random.choice(templates["bye_and_thanks"])
|
|
||||||
elif act == "repeat":
|
|
||||||
return random.choice(templates["repeat"])
|
|
||||||
elif act == "request/address":
|
|
||||||
return random.choice(templates["request/address"])
|
|
||||||
elif act == "request/payment-method":
|
|
||||||
return random.choice(templates["request/payment-method"])
|
|
||||||
elif act == "request/phone":
|
|
||||||
return random.choice(templates["request/phone"])
|
|
||||||
|
|
||||||
return None
|
return "default/template"
|
||||||
|
|
||||||
|
|
||||||
# def select_template(act, slots):
|
|
||||||
# slot_names = {slot['name'] for slot in slots}
|
|
||||||
|
|
||||||
# if act == "welcomemsg":
|
|
||||||
# return random.choice(templates["welcomemsg"])
|
|
||||||
# if act == "request/menu":
|
|
||||||
# return random.choice(templates["request/menu"])
|
|
||||||
# if act == "inform/address":
|
|
||||||
# return random.choice(templates["inform/address"])
|
|
||||||
# if act == "inform/delivery":
|
|
||||||
# return random.choice(templates["inform/delivery"])
|
|
||||||
# if act == "inform/payment":
|
|
||||||
# return random.choice(templates["inform/payment"])
|
|
||||||
# if act == "affirm":
|
|
||||||
# return random.choice(templates["affirm"])
|
|
||||||
# if act == "request/drinks":
|
|
||||||
# return random.choice(templates["request/drinks"])
|
|
||||||
# if act == "bye":
|
|
||||||
# return random.choice(templates["bye"])
|
|
||||||
# if act == "inform/order":
|
|
||||||
# if "quantity" in slot_names and "food" in slot_names and "pizza" in slot_names:
|
|
||||||
# return templates["inform/order"][1]
|
|
||||||
# elif "quantity" in slot_names and "pizza" in slot_names:
|
|
||||||
# return templates["inform/order"][4]
|
|
||||||
# elif "food" in slot_names and "pizza" in slot_names:
|
|
||||||
# return templates["inform/order"][2]
|
|
||||||
# elif "quantity" in slot_names:
|
|
||||||
# return templates["inform/order"][3]
|
|
||||||
# else:
|
|
||||||
# return templates["inform/order"][4]
|
|
||||||
|
|
||||||
# return "default/template"
|
|
@ -6,24 +6,12 @@ templates = {
|
|||||||
"Dziękujemy za zamówienie {quantity} x {food}.",
|
"Dziękujemy za zamówienie {quantity} x {food}.",
|
||||||
"Na jaką pizzę masz ochotę?"
|
"Na jaką pizzę masz ochotę?"
|
||||||
],
|
],
|
||||||
"inform/menu": [
|
"request/menu": [
|
||||||
"Oferujemy następujące pizze: {menu}.",
|
"Oto nasze menu: {menu}.",
|
||||||
],
|
"Nasze menu obejmuje: {menu}.",
|
||||||
"inform/menu-more": [
|
"Proszę, oto lista dostępnych dań: {menu}.",
|
||||||
"Czy chcesz jeszcze jakąś pizzę?",
|
"Dostępne dania to: {menu}.",
|
||||||
],
|
"W naszym menu znajdziesz: {menu}."
|
||||||
"inform/name": [
|
|
||||||
"Twoje imię to {name}.",
|
|
||||||
"Podane imię: {name}.",
|
|
||||||
"Twoje imię: {name}.",
|
|
||||||
"Masz na imię {name}.",
|
|
||||||
"Imię: {name}."
|
|
||||||
],
|
|
||||||
"inform/price": [
|
|
||||||
"Cena wybranej pizzy to {price} zł.",
|
|
||||||
"Koszt pizzy to {price} zł.",
|
|
||||||
"Wybrana pizza kosztuje {price} zł.",
|
|
||||||
"Cena pizzy wynosi {price} zł."
|
|
||||||
],
|
],
|
||||||
"inform/address": [
|
"inform/address": [
|
||||||
"Twój adres to: {address}.",
|
"Twój adres to: {address}.",
|
||||||
@ -32,73 +20,6 @@ templates = {
|
|||||||
"Dostarczymy na adres: {address}.",
|
"Dostarczymy na adres: {address}.",
|
||||||
"Twój podany adres to: {address}."
|
"Twój podany adres to: {address}."
|
||||||
],
|
],
|
||||||
"inform/sauce": [
|
|
||||||
"Dostępne sosy to: {sauce}.",
|
|
||||||
"Możesz wybrać spośród następujących sosów: {sauce}.",
|
|
||||||
"Oferujemy następujące sosy: {sauce}."
|
|
||||||
],
|
|
||||||
"inform/phone": [
|
|
||||||
"Twój numer telefonu to: {phone}.",
|
|
||||||
"Podany numer telefonu: {phone}.",
|
|
||||||
"Numer telefonu, który podałeś, to: {phone}.",
|
|
||||||
"Twoje dane kontaktowe: {phone}.",
|
|
||||||
"Telefon kontaktowy: {phone}."
|
|
||||||
],
|
|
||||||
"inform/order-complete": [
|
|
||||||
"Twoje zamówienie zostało zrealizowane. Dziękujemy!",
|
|
||||||
"Zamówienie zakończone. Dziękujemy za zakupy!",
|
|
||||||
"Zamówienie zrealizowane. Czekaj na dostawę!",
|
|
||||||
"Twoje zamówienie jest gotowe. Dziękujemy!",
|
|
||||||
"Realizacja zamówienia zakończona. Dziękujemy!"
|
|
||||||
],
|
|
||||||
"inform/delivery": [
|
|
||||||
"Twoje zamówienie zostanie dostarczone na {address}.",
|
|
||||||
"Dostarczymy zamówienie na adres: {address}.",
|
|
||||||
"Dostawa na adres: {address}.",
|
|
||||||
"Twoje zamówienie jedzie na {address}.",
|
|
||||||
"Adres dostawy: {address}."
|
|
||||||
],
|
|
||||||
"inform/address": [
|
|
||||||
"Twoje zamówienie zostanie dostarczone na adres: {address}.",
|
|
||||||
"Dostawa będzie na adres: {address}.",
|
|
||||||
"Adres dostawy: {address}."
|
|
||||||
],
|
|
||||||
"inform/payment": [
|
|
||||||
"Metoda płatności to: {payment-method}.",
|
|
||||||
"Płatność realizujesz przez: {payment-method}.",
|
|
||||||
"Wybrałeś metodę płatności: {payment-method}.",
|
|
||||||
"Płatność: {payment-method}.",
|
|
||||||
"Możesz zapłacić kartą, gotówką lub blikiem"
|
|
||||||
],
|
|
||||||
"affirm": [
|
|
||||||
"Świetnie! Napisz co Ci chodzi po głowie.",
|
|
||||||
"Dobrze! Co dalej?",
|
|
||||||
"OK! Co chciałbyś zamówić?",
|
|
||||||
"Super! Co dalej?",
|
|
||||||
"Dobrze! Jakie dalsze zamówienia?"
|
|
||||||
],
|
|
||||||
"request/delivery-price": [
|
|
||||||
"Koszt dostawy wynosi {delivery-price} zł.",
|
|
||||||
"Cena dostawy to {delivery-price} zł.",
|
|
||||||
"Za dostawę zapłacisz {delivery-price} zł.",
|
|
||||||
"Dostawa kosztuje {delivery-price} zł.",
|
|
||||||
"Koszt dostawy: {delivery-price} zł."
|
|
||||||
],
|
|
||||||
"request/menu": [
|
|
||||||
"W naszym menu znajdziesz: {menu}."
|
|
||||||
],
|
|
||||||
"inform/time": [
|
|
||||||
"Aktualny czas to {time}.",
|
|
||||||
"Jest teraz {time}.",
|
|
||||||
"Czas: {time}.",
|
|
||||||
"Godzina: {time}.",
|
|
||||||
"Obecny czas: {time}."
|
|
||||||
],
|
|
||||||
"welcomemsg": [
|
|
||||||
"Witaj w naszej wspaniałej pizzerii. W czym mogę pomóc?",
|
|
||||||
"Halo, halo, tu najlepsza pizza w mieście. Masz głoda?",
|
|
||||||
"Dzieńdoberek, gdyby wszyscy jedli nasze pizze, na świecie nie byłoby wojen. Jaką pizzę sobie dziś gruchniesz?",
|
|
||||||
],
|
|
||||||
"request/price": [
|
"request/price": [
|
||||||
"Cena za {food} wynosi {price} zł.",
|
"Cena za {food} wynosi {price} zł.",
|
||||||
"Koszt {food} to {price} zł.",
|
"Koszt {food} to {price} zł.",
|
||||||
@ -120,10 +41,19 @@ templates = {
|
|||||||
"Które sosy mają być do {food}?",
|
"Które sosy mają być do {food}?",
|
||||||
"Wybierz sosy do {food}."
|
"Wybierz sosy do {food}."
|
||||||
],
|
],
|
||||||
"request/food": [
|
"inform/phone": [
|
||||||
"Co chciałbyś zamówić?",
|
"Twój numer telefonu to: {phone}.",
|
||||||
"Proszę podać na jaką pizzę masz ochotę",
|
"Podany numer telefonu: {phone}.",
|
||||||
"Którą pizzę wybrałeś?"
|
"Numer telefonu, który podałeś, to: {phone}.",
|
||||||
|
"Twoje dane kontaktowe: {phone}.",
|
||||||
|
"Telefon kontaktowy: {phone}."
|
||||||
|
],
|
||||||
|
"inform/order-complete": [
|
||||||
|
"Twoje zamówienie zostało zrealizowane. Dziękujemy!",
|
||||||
|
"Zamówienie zakończone. Dziękujemy za zakupy!",
|
||||||
|
"Zamówienie zrealizowane. Czekaj na dostawę!",
|
||||||
|
"Twoje zamówienie jest gotowe. Dziękujemy!",
|
||||||
|
"Realizacja zamówienia zakończona. Dziękujemy!"
|
||||||
],
|
],
|
||||||
"request/time": [
|
"request/time": [
|
||||||
"Oczekiwany czas dostawy to {time} minut.",
|
"Oczekiwany czas dostawy to {time} minut.",
|
||||||
@ -139,6 +69,73 @@ templates = {
|
|||||||
"Proszę wybrać rozmiar {pizza}: {sizes}.",
|
"Proszę wybrać rozmiar {pizza}: {sizes}.",
|
||||||
"Mamy następujące rozmiary {pizza}: {sizes}."
|
"Mamy następujące rozmiary {pizza}: {sizes}."
|
||||||
],
|
],
|
||||||
|
"affirm": [
|
||||||
|
"Świetnie! Napisz co Ci chodzi po głowie.",
|
||||||
|
"Dobrze! Co dalej?",
|
||||||
|
"OK! Co chciałbyś zamówić?",
|
||||||
|
"Super! Co dalej?",
|
||||||
|
"Dobrze! Jakie dalsze zamówienia?"
|
||||||
|
],
|
||||||
|
"inform/delivery": [
|
||||||
|
"Twoje zamówienie zostanie dostarczone na {address}.",
|
||||||
|
"Dostarczymy zamówienie na adres: {address}.",
|
||||||
|
"Dostawa na adres: {address}.",
|
||||||
|
"Twoje zamówienie jedzie na {address}.",
|
||||||
|
"Adres dostawy: {address}."
|
||||||
|
],
|
||||||
|
"inform/address": [
|
||||||
|
"Twoje zamówienie zostanie dostarczone na adres: {address}.",
|
||||||
|
"Dostawa będzie na adres: {address}.",
|
||||||
|
"Adres dostawy: {address}."
|
||||||
|
],
|
||||||
|
"inform/payment": [
|
||||||
|
"Metoda płatności to: {payment-method}.",
|
||||||
|
"Płatność realizujesz przez: {payment-method}.",
|
||||||
|
"Wybrałeś metodę płatności: {payment-method}.",
|
||||||
|
"Płatność: {payment-method}.",
|
||||||
|
"Możesz zapłacić kartą, gotówką lub blikiem"
|
||||||
|
],
|
||||||
|
"request/delivery-price": [
|
||||||
|
"Koszt dostawy wynosi {delivery-price} zł.",
|
||||||
|
"Cena dostawy to {delivery-price} zł.",
|
||||||
|
"Za dostawę zapłacisz {delivery-price} zł.",
|
||||||
|
"Dostawa kosztuje {delivery-price} zł.",
|
||||||
|
"Koszt dostawy: {delivery-price} zł."
|
||||||
|
],
|
||||||
|
"inform/time": [
|
||||||
|
"Aktualny czas to {time}.",
|
||||||
|
"Jest teraz {time}.",
|
||||||
|
"Czas: {time}.",
|
||||||
|
"Godzina: {time}.",
|
||||||
|
"Obecny czas: {time}."
|
||||||
|
],
|
||||||
|
"request/drinks": [
|
||||||
|
"Jakie napoje chciałbyś zamówić?",
|
||||||
|
"Proszę wybrać napoje do zamówienia.",
|
||||||
|
"Jakie napoje dołączamy do zamówienia?",
|
||||||
|
"Co chciałbyś pić?",
|
||||||
|
"Proszę podać napoje do zamówienia."
|
||||||
|
],
|
||||||
|
"inform/name": [
|
||||||
|
"Twoje imię to {name}.",
|
||||||
|
"Podane imię: {name}.",
|
||||||
|
"Twoje imię: {name}.",
|
||||||
|
"Masz na imię {name}.",
|
||||||
|
"Imię: {name}."
|
||||||
|
],
|
||||||
|
"welcomemsg": [
|
||||||
|
"Witaj w naszej wspaniałej pizzerii. W czym mogę pomóc?",
|
||||||
|
"Halo, halo, tu najlepsza pizza w mieście. Masz głoda?",
|
||||||
|
"Dzieńdoberek, gdyby wszyscy jedli nasze pizze, na świecie nie byłoby wojen. Jaką pizzę sobie dziś gruchniesz?",
|
||||||
|
],
|
||||||
|
"request/menu": [
|
||||||
|
"W naszym menu znajdują się pizze, spaghetti, gnocci oraz aranchini. Polecam potrawkę śląską po grecku.",
|
||||||
|
"Smażymy, gotujemy, prażymy, ale najlepiej nam wychodzi pizza. Na co masz ochotę?",
|
||||||
|
],
|
||||||
|
"request/drink": [
|
||||||
|
"Oferujemy napoje zimne, ciepłe i letnie. Cola, fanta, woda mineralna, kawa, herbata lub frappe.",
|
||||||
|
"Może z alkoholem? Mamy świeżo warzone piwo",
|
||||||
|
],
|
||||||
"bye": [
|
"bye": [
|
||||||
"Dziękujemy i do zobaczenia wkrótce.",
|
"Dziękujemy i do zobaczenia wkrótce.",
|
||||||
"Polecamy się na przyszłość. Do zobaczenia!",
|
"Polecamy się na przyszłość. Do zobaczenia!",
|
||||||
@ -152,18 +149,9 @@ templates = {
|
|||||||
"Niestety, nie mamy {ingredient/neg}."
|
"Niestety, nie mamy {ingredient/neg}."
|
||||||
],
|
],
|
||||||
"default/template": [
|
"default/template": [
|
||||||
"Nie zrozumiałem, spróbuj inaczej sformułować zdanie."
|
"Przepraszamy, ale nie rozumiemy Twojego zapytania.",
|
||||||
],
|
"Proszę spróbować ponownie później.",
|
||||||
"repeat": [
|
"Nie rozpoznajemy Twojej prośby, spróbuj ponownie.",
|
||||||
"Nie zrozumiałem, spróbuj inaczej sformułować zdanie.",
|
"Strasznie szybko to napisałeś, nie zrozumiałem...."
|
||||||
],
|
|
||||||
"request/address": [
|
|
||||||
"Jaki adres dostawy?"
|
|
||||||
],
|
|
||||||
"bye_and_thanks": [
|
|
||||||
"Dziękujęmy, przekazaliśmy zamówienie do realizacji. Do zobaczenia!",
|
|
||||||
],
|
|
||||||
"request/phone": [
|
|
||||||
"Podaj proszę numer telefonu do kontaktu dla kuriera."
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
@ -2,49 +2,17 @@ from src.service.dialog_state_monitor import DialogStateMonitor
|
|||||||
from src.model.frame import Frame
|
from src.model.frame import Frame
|
||||||
from src.model.slot import Slot
|
from src.model.slot import Slot
|
||||||
|
|
||||||
dsm = DialogStateMonitor()
|
dst = DialogStateMonitor()
|
||||||
|
|
||||||
assert dsm.item_exists('pizza', 'capri') is True
|
dst.update(Frame('user', 'inform/order', [Slot('pizza', 'margaritta'), Slot('sauce', 'ketchup')]))
|
||||||
assert dsm.item_exists('pizza', 'buraczana') is False
|
dst.update(Frame('user', 'inform/order', [Slot('pizza', 'carbonara')]))
|
||||||
assert dsm.item_exists('drink', 'cola') is True
|
dst.update(Frame('user', 'inform/order-complete', []))
|
||||||
assert dsm.state['was_previous_order_invalid'] is False
|
|
||||||
|
|
||||||
assert dsm.get_current_active_stage() == 'collect_food'
|
assert dst.state['belief_state']['order'][0]['pizza'] == 'margaritta'
|
||||||
frame1 = Frame('user', 'inform/order', [Slot('pizza', 'margarita'), Slot('sauce', 'ketchup')])
|
assert dst.state['belief_state']['order'][0]['sauce'] == 'ketchup'
|
||||||
dsm.update(frame1)
|
assert dst.state['belief_state']['order-complete'] == True
|
||||||
assert dsm.get_current_active_stage() == 'collect_drinks'
|
|
||||||
assert dsm.get_total_cost() == 20
|
|
||||||
frame2 = Frame('user', 'inform/order', [Slot('pizza', 'tuna')])
|
|
||||||
dsm.update(frame2)
|
|
||||||
assert dsm.get_current_active_stage() == 'collect_drinks'
|
|
||||||
assert dsm.get_total_cost() == 20 # Pizza is not added, as previous stage is closed already
|
|
||||||
frame3 = Frame('user', 'inform/order-complete', [])
|
|
||||||
dsm.update(frame3)
|
|
||||||
frame4 = Frame('user', 'inform/order', [Slot('drink', 'cola')])
|
|
||||||
dsm.update(frame4)
|
|
||||||
assert dsm.get_current_active_stage() == 'collect_address'
|
|
||||||
assert dsm.get_total_cost() == 30
|
|
||||||
|
|
||||||
assert dsm.state['belief_state']['order'][0]['pizza'] == 'margarita'
|
dst.reset()
|
||||||
assert dsm.state['belief_state']['order'][0]['sauce'] == 'ketchup'
|
|
||||||
assert dsm.state['belief_state']['order-complete'] is True
|
|
||||||
assert dsm.state['history'][0] == frame1
|
|
||||||
assert dsm.state['history'][1] == frame2
|
|
||||||
assert dsm.state['history'][2] == frame3
|
|
||||||
assert dsm.state['history'][3] == frame4
|
|
||||||
|
|
||||||
dsm.reset()
|
assert dst.state['belief_state']['order'] == []
|
||||||
|
assert dst.state['belief_state']['order-complete'] == False
|
||||||
assert dsm.get_total_cost() == 0
|
|
||||||
assert dsm.get_current_active_stage() == 'collect_food'
|
|
||||||
assert dsm.state['belief_state']['order'] == []
|
|
||||||
assert dsm.state['belief_state']['order-complete'] is False
|
|
||||||
assert len(dsm.state['history']) == 0
|
|
||||||
|
|
||||||
dsm.reset()
|
|
||||||
|
|
||||||
frame1 = Frame('user', 'inform/order', [Slot('pizza', 'buraczana')])
|
|
||||||
dsm.update(frame1)
|
|
||||||
assert dsm.state['was_previous_order_invalid'] is True
|
|
||||||
assert dsm.state['belief_state']['order'] == []
|
|
||||||
assert dsm.get_total_cost() == 0
|
|
||||||
|
Loading…
Reference in New Issue
Block a user