Dodanie stanów i ograniczeń do DSM

This commit is contained in:
Patryk Osiński 2024-06-10 00:19:13 +02:00
parent 4483747021
commit a90535a68b
3 changed files with 70 additions and 37 deletions

View File

@ -36,16 +36,5 @@
"ingredient": ["tomato", "tuna", "onion", "cheese"], "ingredient": ["tomato", "tuna", "onion", "cheese"],
"price": 40 "price": 40
} }
},
"size": {
"m": {
"price_multiplier": 1
},
"l": {
"price_multiplier": 1.2
},
"xl": {
"price_multiplier": 1.4
}
} }
} }

View File

@ -10,19 +10,52 @@ def normalize(value):
class DialogStateMonitor: class DialogStateMonitor:
def __init__(self): def __init__(self):
self.__initial_state = dict(belief_state={ self.__initial_state = dict(belief_state={
'order': [], 'order': [],
'address': {}, 'address': {},
'order-complete': False, 'order-complete': False,
'phone': {}, 'phone': {},
'delivery': {}, 'delivery': {},
'payment': {}, 'payment': {},
'time': {}, 'time': {},
'name': {}, 'name': {},
}, },
history=[]) stages=[
{'completed': False, 'name': 'collect_food'},
{'completed': False, 'name': 'collect_drinks'},
{'completed': False, 'name': 'collect_address'},
],
constraints={
'order': [
'sauce',
'pizza',
],
},
history=[])
self.state = copy.deepcopy(self.__initial_state) self.state = copy.deepcopy(self.__initial_state)
def update(self, frame: Frame): def get_last_order_missing_fields(self) -> list[str]:
result = []
try:
last_order = self.state['belief_state']['order'][-1]
except IndexError:
raise RuntimeError('No orders are placed')
for constraint in self.state['constraints']['order']:
if constraint not in last_order:
result.append(constraint)
return result
def get_current_active_stage(self) -> str | None:
for stage in self.state['stages']:
if stage['completed'] is False:
return stage['name']
def mark_current_stage_completed(self) -> None:
for stage in self.state['stages']:
if stage['completed'] is False:
stage['completed'] = True
return
def update(self, frame: Frame) -> None:
self.state['history'].append(frame) self.state['history'].append(frame)
if frame.source != 'user': if frame.source != 'user':
return return
@ -52,5 +85,5 @@ class DialogStateMonitor:
for slot in frame.slots: for slot in frame.slots:
self.state['belief_state']['name'][slot.name] = normalize(slot.value) self.state['belief_state']['name'][slot.name] = normalize(slot.value)
def reset(self): def reset(self) -> None:
self.state = copy.deepcopy(self.__initial_state) self.state = copy.deepcopy(self.__initial_state)

View File

@ -2,24 +2,35 @@ from src.service.dialog_state_monitor import DialogStateMonitor
from src.model.frame import Frame from src.model.frame import Frame
from src.model.slot import Slot from src.model.slot import Slot
dst = DialogStateMonitor() dsm = DialogStateMonitor()
frame1 = Frame('user', 'inform/order', [Slot('pizza', 'margaritta'), Slot('sauce', 'ketchup')]) frame1 = Frame('user', 'inform/order', [Slot('pizza', 'margaritta'), Slot('sauce', 'ketchup')])
dsm.update(frame1)
assert dsm.get_last_order_missing_fields() == []
frame2 = Frame('user', 'inform/order', [Slot('pizza', 'carbonara')]) frame2 = Frame('user', 'inform/order', [Slot('pizza', 'carbonara')])
dsm.update(frame2)
assert dsm.get_last_order_missing_fields() == ['sauce']
frame3 = Frame('user', 'inform/order-complete', []) frame3 = Frame('user', 'inform/order-complete', [])
dst.update(frame1) dsm.update(frame3)
dst.update(frame2)
dst.update(frame3)
assert dst.state['belief_state']['order'][0]['pizza'] == 'margaritta' assert dsm.state['belief_state']['order'][0]['pizza'] == 'margaritta'
assert dst.state['belief_state']['order'][0]['sauce'] == 'ketchup' assert dsm.state['belief_state']['order'][0]['sauce'] == 'ketchup'
assert dst.state['belief_state']['order-complete'] is True assert dsm.state['belief_state']['order-complete'] is True
assert dst.state['history'][0] == frame1 assert dsm.state['history'][0] == frame1
assert dst.state['history'][1] == frame2 assert dsm.state['history'][1] == frame2
assert dst.state['history'][2] == frame3 assert dsm.state['history'][2] == frame3
dst.reset() assert dsm.get_current_active_stage() == 'collect_food'
dsm.mark_current_stage_completed()
assert dsm.get_current_active_stage() == 'collect_drinks'
dsm.mark_current_stage_completed()
assert dsm.get_current_active_stage() == 'collect_address'
dsm.mark_current_stage_completed()
assert dsm.get_current_active_stage() is None
assert dst.state['belief_state']['order'] == [] dsm.reset()
assert dst.state['belief_state']['order-complete'] is False
assert len(dst.state['history']) == 0 assert dsm.get_current_active_stage() == 'collect_food'
assert dsm.state['belief_state']['order'] == []
assert dsm.state['belief_state']['order-complete'] is False
assert len(dsm.state['history']) == 0