Dodanie stanów i ograniczeń do DSM

This commit is contained in:
Patryk Osiński 2024-06-10 00:19:13 +02:00
parent 4483747021
commit a90535a68b
3 changed files with 70 additions and 37 deletions

View File

@ -36,16 +36,5 @@
"ingredient": ["tomato", "tuna", "onion", "cheese"],
"price": 40
}
},
"size": {
"m": {
"price_multiplier": 1
},
"l": {
"price_multiplier": 1.2
},
"xl": {
"price_multiplier": 1.4
}
}
}

View File

@ -10,19 +10,52 @@ def normalize(value):
class DialogStateMonitor:
def __init__(self):
self.__initial_state = dict(belief_state={
'order': [],
'address': {},
'order-complete': False,
'phone': {},
'delivery': {},
'payment': {},
'time': {},
'name': {},
},
history=[])
'order': [],
'address': {},
'order-complete': False,
'phone': {},
'delivery': {},
'payment': {},
'time': {},
'name': {},
},
stages=[
{'completed': False, 'name': 'collect_food'},
{'completed': False, 'name': 'collect_drinks'},
{'completed': False, 'name': 'collect_address'},
],
constraints={
'order': [
'sauce',
'pizza',
],
},
history=[])
self.state = copy.deepcopy(self.__initial_state)
def update(self, frame: Frame):
def get_last_order_missing_fields(self) -> list[str]:
result = []
try:
last_order = self.state['belief_state']['order'][-1]
except IndexError:
raise RuntimeError('No orders are placed')
for constraint in self.state['constraints']['order']:
if constraint not in last_order:
result.append(constraint)
return result
def get_current_active_stage(self) -> str | None:
for stage in self.state['stages']:
if stage['completed'] is False:
return stage['name']
def mark_current_stage_completed(self) -> None:
for stage in self.state['stages']:
if stage['completed'] is False:
stage['completed'] = True
return
def update(self, frame: Frame) -> None:
self.state['history'].append(frame)
if frame.source != 'user':
return
@ -52,5 +85,5 @@ class DialogStateMonitor:
for slot in frame.slots:
self.state['belief_state']['name'][slot.name] = normalize(slot.value)
def reset(self):
def reset(self) -> None:
self.state = copy.deepcopy(self.__initial_state)

View File

@ -2,24 +2,35 @@ from src.service.dialog_state_monitor import DialogStateMonitor
from src.model.frame import Frame
from src.model.slot import Slot
dst = DialogStateMonitor()
dsm = DialogStateMonitor()
frame1 = Frame('user', 'inform/order', [Slot('pizza', 'margaritta'), Slot('sauce', 'ketchup')])
dsm.update(frame1)
assert dsm.get_last_order_missing_fields() == []
frame2 = Frame('user', 'inform/order', [Slot('pizza', 'carbonara')])
dsm.update(frame2)
assert dsm.get_last_order_missing_fields() == ['sauce']
frame3 = Frame('user', 'inform/order-complete', [])
dst.update(frame1)
dst.update(frame2)
dst.update(frame3)
dsm.update(frame3)
assert dst.state['belief_state']['order'][0]['pizza'] == 'margaritta'
assert dst.state['belief_state']['order'][0]['sauce'] == 'ketchup'
assert dst.state['belief_state']['order-complete'] is True
assert dst.state['history'][0] == frame1
assert dst.state['history'][1] == frame2
assert dst.state['history'][2] == frame3
assert dsm.state['belief_state']['order'][0]['pizza'] == 'margaritta'
assert dsm.state['belief_state']['order'][0]['sauce'] == 'ketchup'
assert dsm.state['belief_state']['order-complete'] is True
assert dsm.state['history'][0] == frame1
assert dsm.state['history'][1] == frame2
assert dsm.state['history'][2] == frame3
dst.reset()
assert dsm.get_current_active_stage() == 'collect_food'
dsm.mark_current_stage_completed()
assert dsm.get_current_active_stage() == 'collect_drinks'
dsm.mark_current_stage_completed()
assert dsm.get_current_active_stage() == 'collect_address'
dsm.mark_current_stage_completed()
assert dsm.get_current_active_stage() is None
assert dst.state['belief_state']['order'] == []
assert dst.state['belief_state']['order-complete'] is False
assert len(dst.state['history']) == 0
dsm.reset()
assert dsm.get_current_active_stage() == 'collect_food'
assert dsm.state['belief_state']['order'] == []
assert dsm.state['belief_state']['order-complete'] is False
assert len(dsm.state['history']) == 0