diff --git a/chatbot/main.py b/chatbot/main.py index 64b389a..b379079 100644 --- a/chatbot/main.py +++ b/chatbot/main.py @@ -1,5 +1,6 @@ from pathlib import Path -from modules.nlp import NaturalLanguageProcessor +from modules.nlu import NLU, Slot, UserAct +from modules.state_monitor import DialogStateMonitor from modules.generator import ResponseGenerator from modules.config import Config import colorama @@ -13,7 +14,8 @@ def main(): config_path = base_path / 'config' / 'config.json' config = Config.load_config(config_path) - nlp = NaturalLanguageProcessor(config) + nlu = NLU() + dst = DialogStateMonitor() generator = ResponseGenerator(config) print(Fore.CYAN + "Witaj w chatbocie! Rozpocznij rozmowÄ™.") @@ -25,10 +27,14 @@ def main(): print(Fore.RED + "Zamykanie chatbota...") break - intent = nlp.analyze(user_input) - response = generator.generate(intent) + user_act = nlu.analyze(user_input) + # user_act = UserAct(intent='inform', + # slots=[Slot(name='item', value='laptop'), Slot(name='item', value='kot'),Slot(name='address', value='123 Main St')]) + dst.update(user_act) + print(dst.state) + # response = generator.generate(intent) - print(Fore.CYAN + "Bot: " + response) + # print(Fore.CYAN + "Bot: " + response) if __name__ == "__main__": diff --git a/chatbot/modules/nlu.py b/chatbot/modules/nlu.py index 849a84b..42d4486 100644 --- a/chatbot/modules/nlu.py +++ b/chatbot/modules/nlu.py @@ -1,15 +1,39 @@ from flair.models import SequenceTagger import sys + sys.path.append("..") from models.nlu_train2 import predict_frame, predict_slot import logging logging.getLogger('flair').setLevel(logging.CRITICAL) + +class Slot: + def __init__(self, name, value=None): + self.name = name + self.value = value + + def __str__(self) -> str: + return f"Name: {self.name}, Value: {self.value}" + + +class UserAct: + def __init__(self, intent: str, slots: list[Slot] = []): + self.slots = slots + self.intent = intent + + def __str__(self): + msg = f"Act: {self.intent}, Slots: [" + for slot in self.slots: + msg += f"({slot}), " + msg += "]" + return msg + + class NLU: def __init__(self): - self.frame_model = SequenceTagger.load('../models/frame-model/final-model.pt') - self.slot_model = SequenceTagger.load('../models/slot-model/final-model.pt') + self.frame_model = SequenceTagger.load('models/frame-model/final-model.pt') + self.slot_model = SequenceTagger.load('models/slot-model/final-model.pt') def get_intent(self, text: str): return predict_frame(self.frame_model, text.split(), 'frame') @@ -24,14 +48,14 @@ class NLU: slot = frame["slot"] if slot.startswith("B-"): if current_slot: - slots.append({'name': current_slot, 'value': " ".join(current_slot_value)}) + slots.append(Slot(name=current_slot, value=current_slot_value)) current_slot = slot[2:] current_slot_value = [frame["form"]] elif slot.startswith("I-"): current_slot_value.append(frame["form"]) if current_slot: - slots.append({'name': current_slot, 'value': " ".join(current_slot_value)}) + slots.append(Slot(name=current_slot, value=current_slot_value)) return slots @@ -39,12 +63,5 @@ class NLU: intent = self.get_intent(text) slots = self.get_slot(text) print({'intent': intent, - 'slots': slots}) - return { - 'intent': intent, - 'slots': slots - } - -nlu = NLU() - -nlu.analyze("Chce kupic lakier do pazanokci") \ No newline at end of file + 'slots': slots}) + return UserAct(intent=intent, slots=slots) diff --git a/chatbot/modules/state_monitor.py b/chatbot/modules/state_monitor.py index 56ca56a..21d0994 100644 --- a/chatbot/modules/state_monitor.py +++ b/chatbot/modules/state_monitor.py @@ -1,6 +1,66 @@ -class DialogueStateMonitor: - def __init__(self) -> None: - self.state = {'last_intent': 'unknown'} +import copy +from modules.nlu import UserAct +import json - def update_state(self, intent: str) -> None: - self.state['last_intent'] = intent + +class DialogStateMonitor: + def __init__(self): + self.__initial_state = dict( + belief_state={ + 'item': {}, + 'address': {}, + 'card_nr': {}, + 'delivery_method': {}, + 'payment_method': {}, + 'email': {}, + 'order-complete': False, + }, + act='', + slot_names=[]) + + self.state = copy.deepcopy(self.__initial_state) + + def is_value_empty(self, d, key): + value = d.get(key, None) + if value in [None, '', [], {}]: + return True + return False + + def update_act(self, intent): + self.state['act'] = intent + + def update_slot_names(self, slots_names): + self.state['slot_names'] = slots_names + + def check_order_complete(self): + all_filled = all(bool(self.state['belief_state'][key]) for key in + ['item', 'address', 'card_nr', 'delivery_method', 'payment_method', 'email']) + self.state['belief_state']['order-complete'] = all_filled + + def update(self, act: UserAct) -> None: + print(act) + if act.intent == 'inform': + self.update_act(act.intent) + slots_mapping = { + 'item': [], + 'address': [], + 'card_nr': [], + 'delivery_method': [], + 'payment_method': [], + 'email': [] + } + for slot in act.slots: + if slot.name in slots_mapping and self.is_value_empty(self.state, slot.name): + slots_mapping[slot.name].append(slot.value) # To do: normalization + + for slot_name, values in slots_mapping.items(): + if values: + self.state['belief_state'][slot_name] = values + elif act.intent == 'request': + self.update_act(act.intent) + slots_names = [slot.name for slot in act.slots] + self.update_slot_names(slots_names) + elif act.intent == 'bye': + self.update_act(act.intent) + + self.check_order_complete()