add dst
This commit is contained in:
parent
166707dd02
commit
1926681255
@ -1,5 +1,6 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from modules.nlp import NaturalLanguageProcessor
|
from modules.nlu import NLU, Slot, UserAct
|
||||||
|
from modules.state_monitor import DialogStateMonitor
|
||||||
from modules.generator import ResponseGenerator
|
from modules.generator import ResponseGenerator
|
||||||
from modules.config import Config
|
from modules.config import Config
|
||||||
import colorama
|
import colorama
|
||||||
@ -13,7 +14,8 @@ def main():
|
|||||||
config_path = base_path / 'config' / 'config.json'
|
config_path = base_path / 'config' / 'config.json'
|
||||||
config = Config.load_config(config_path)
|
config = Config.load_config(config_path)
|
||||||
|
|
||||||
nlp = NaturalLanguageProcessor(config)
|
nlu = NLU()
|
||||||
|
dst = DialogStateMonitor()
|
||||||
generator = ResponseGenerator(config)
|
generator = ResponseGenerator(config)
|
||||||
|
|
||||||
print(Fore.CYAN + "Witaj w chatbocie! Rozpocznij rozmowę.")
|
print(Fore.CYAN + "Witaj w chatbocie! Rozpocznij rozmowę.")
|
||||||
@ -25,10 +27,14 @@ def main():
|
|||||||
print(Fore.RED + "Zamykanie chatbota...")
|
print(Fore.RED + "Zamykanie chatbota...")
|
||||||
break
|
break
|
||||||
|
|
||||||
intent = nlp.analyze(user_input)
|
user_act = nlu.analyze(user_input)
|
||||||
response = generator.generate(intent)
|
# user_act = UserAct(intent='inform',
|
||||||
|
# slots=[Slot(name='item', value='laptop'), Slot(name='item', value='kot'),Slot(name='address', value='123 Main St')])
|
||||||
|
dst.update(user_act)
|
||||||
|
print(dst.state)
|
||||||
|
# response = generator.generate(intent)
|
||||||
|
|
||||||
print(Fore.CYAN + "Bot: " + response)
|
# print(Fore.CYAN + "Bot: " + response)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1,15 +1,39 @@
|
|||||||
from flair.models import SequenceTagger
|
from flair.models import SequenceTagger
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.path.append("..")
|
sys.path.append("..")
|
||||||
from models.nlu_train2 import predict_frame, predict_slot
|
from models.nlu_train2 import predict_frame, predict_slot
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.getLogger('flair').setLevel(logging.CRITICAL)
|
logging.getLogger('flair').setLevel(logging.CRITICAL)
|
||||||
|
|
||||||
|
|
||||||
|
class Slot:
|
||||||
|
def __init__(self, name, value=None):
|
||||||
|
self.name = name
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"Name: {self.name}, Value: {self.value}"
|
||||||
|
|
||||||
|
|
||||||
|
class UserAct:
|
||||||
|
def __init__(self, intent: str, slots: list[Slot] = []):
|
||||||
|
self.slots = slots
|
||||||
|
self.intent = intent
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
msg = f"Act: {self.intent}, Slots: ["
|
||||||
|
for slot in self.slots:
|
||||||
|
msg += f"({slot}), "
|
||||||
|
msg += "]"
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
class NLU:
|
class NLU:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.frame_model = SequenceTagger.load('../models/frame-model/final-model.pt')
|
self.frame_model = SequenceTagger.load('models/frame-model/final-model.pt')
|
||||||
self.slot_model = SequenceTagger.load('../models/slot-model/final-model.pt')
|
self.slot_model = SequenceTagger.load('models/slot-model/final-model.pt')
|
||||||
|
|
||||||
def get_intent(self, text: str):
|
def get_intent(self, text: str):
|
||||||
return predict_frame(self.frame_model, text.split(), 'frame')
|
return predict_frame(self.frame_model, text.split(), 'frame')
|
||||||
@ -24,14 +48,14 @@ class NLU:
|
|||||||
slot = frame["slot"]
|
slot = frame["slot"]
|
||||||
if slot.startswith("B-"):
|
if slot.startswith("B-"):
|
||||||
if current_slot:
|
if current_slot:
|
||||||
slots.append({'name': current_slot, 'value': " ".join(current_slot_value)})
|
slots.append(Slot(name=current_slot, value=current_slot_value))
|
||||||
current_slot = slot[2:]
|
current_slot = slot[2:]
|
||||||
current_slot_value = [frame["form"]]
|
current_slot_value = [frame["form"]]
|
||||||
elif slot.startswith("I-"):
|
elif slot.startswith("I-"):
|
||||||
current_slot_value.append(frame["form"])
|
current_slot_value.append(frame["form"])
|
||||||
|
|
||||||
if current_slot:
|
if current_slot:
|
||||||
slots.append({'name': current_slot, 'value': " ".join(current_slot_value)})
|
slots.append(Slot(name=current_slot, value=current_slot_value))
|
||||||
|
|
||||||
return slots
|
return slots
|
||||||
|
|
||||||
@ -40,11 +64,4 @@ class NLU:
|
|||||||
slots = self.get_slot(text)
|
slots = self.get_slot(text)
|
||||||
print({'intent': intent,
|
print({'intent': intent,
|
||||||
'slots': slots})
|
'slots': slots})
|
||||||
return {
|
return UserAct(intent=intent, slots=slots)
|
||||||
'intent': intent,
|
|
||||||
'slots': slots
|
|
||||||
}
|
|
||||||
|
|
||||||
nlu = NLU()
|
|
||||||
|
|
||||||
nlu.analyze("Chce kupic lakier do pazanokci")
|
|
||||||
|
@ -1,6 +1,66 @@
|
|||||||
class DialogueStateMonitor:
|
import copy
|
||||||
def __init__(self) -> None:
|
from modules.nlu import UserAct
|
||||||
self.state = {'last_intent': 'unknown'}
|
import json
|
||||||
|
|
||||||
def update_state(self, intent: str) -> None:
|
|
||||||
self.state['last_intent'] = intent
|
class DialogStateMonitor:
|
||||||
|
def __init__(self):
|
||||||
|
self.__initial_state = dict(
|
||||||
|
belief_state={
|
||||||
|
'item': {},
|
||||||
|
'address': {},
|
||||||
|
'card_nr': {},
|
||||||
|
'delivery_method': {},
|
||||||
|
'payment_method': {},
|
||||||
|
'email': {},
|
||||||
|
'order-complete': False,
|
||||||
|
},
|
||||||
|
act='',
|
||||||
|
slot_names=[])
|
||||||
|
|
||||||
|
self.state = copy.deepcopy(self.__initial_state)
|
||||||
|
|
||||||
|
def is_value_empty(self, d, key):
|
||||||
|
value = d.get(key, None)
|
||||||
|
if value in [None, '', [], {}]:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def update_act(self, intent):
|
||||||
|
self.state['act'] = intent
|
||||||
|
|
||||||
|
def update_slot_names(self, slots_names):
|
||||||
|
self.state['slot_names'] = slots_names
|
||||||
|
|
||||||
|
def check_order_complete(self):
|
||||||
|
all_filled = all(bool(self.state['belief_state'][key]) for key in
|
||||||
|
['item', 'address', 'card_nr', 'delivery_method', 'payment_method', 'email'])
|
||||||
|
self.state['belief_state']['order-complete'] = all_filled
|
||||||
|
|
||||||
|
def update(self, act: UserAct) -> None:
|
||||||
|
print(act)
|
||||||
|
if act.intent == 'inform':
|
||||||
|
self.update_act(act.intent)
|
||||||
|
slots_mapping = {
|
||||||
|
'item': [],
|
||||||
|
'address': [],
|
||||||
|
'card_nr': [],
|
||||||
|
'delivery_method': [],
|
||||||
|
'payment_method': [],
|
||||||
|
'email': []
|
||||||
|
}
|
||||||
|
for slot in act.slots:
|
||||||
|
if slot.name in slots_mapping and self.is_value_empty(self.state, slot.name):
|
||||||
|
slots_mapping[slot.name].append(slot.value) # To do: normalization
|
||||||
|
|
||||||
|
for slot_name, values in slots_mapping.items():
|
||||||
|
if values:
|
||||||
|
self.state['belief_state'][slot_name] = values
|
||||||
|
elif act.intent == 'request':
|
||||||
|
self.update_act(act.intent)
|
||||||
|
slots_names = [slot.name for slot in act.slots]
|
||||||
|
self.update_slot_names(slots_names)
|
||||||
|
elif act.intent == 'bye':
|
||||||
|
self.update_act(act.intent)
|
||||||
|
|
||||||
|
self.check_order_complete()
|
||||||
|
Loading…
Reference in New Issue
Block a user