diff --git a/ai_talks/chat.py b/ai_talks/chat.py index 75b90f6..56fb34c 100644 --- a/ai_talks/chat.py +++ b/ai_talks/chat.py @@ -3,7 +3,6 @@ from random import randrange from AMUseBotBackend.src.DP.dp import DP from AMUseBotBackend.src.DST.dst import DST -from AMUseBotBackend.src.NLU.nlu import NLU import graphviz import streamlit as st @@ -13,7 +12,7 @@ from src.utils.lang import en if __name__ == '__main__': - + # --- PATH SETTINGS --- current_dir: Path = Path(__file__).parent if "__file__" in locals() else Path.cwd() css_file: Path = current_dir / "src/styles/.css" @@ -61,9 +60,6 @@ if __name__ == '__main__': st.session_state.dst = DST(recipe_path="AMUseBotFront/ai_talks/AMUseBotBackend/recipe/", dialog_path="AMUseBotFront/ai_talks/AMUseBotBackend/dialog/") if "dp" not in st.session_state: st.session_state.dp = DP(dst=st.session_state.dst) - if "nlu" not in st.session_state: - st.session_state.nlu = NLU(intent_dict_path='AMUseBotFront/ai_talks/AMUseBotBackend/utils/intent_dict.json', - model_identifier_path='AMUseBotFront/ai_talks/AMUseBotBackend/models/NLU/roberta-base-cookdial.txt') def show_graph(): # Create a graphlib graph object diff --git a/ai_talks/src/utils/conversation.py b/ai_talks/src/utils/conversation.py index aa8b3ef..28f2c46 100644 --- a/ai_talks/src/utils/conversation.py +++ b/ai_talks/src/utils/conversation.py @@ -7,6 +7,14 @@ from streamlit_chat import message from .stt import show_voice_input from .tts import show_audio_player +from AMUseBotBackend.src.DP.dp import DP +from AMUseBotBackend.src.DST.dst import DST +from AMUseBotBackend.src.NLU.nlu import NLU + +@st.cache_resource +def get_nlu_model(intent_dict_path = 'AMUseBotFront/ai_talks/AMUseBotBackend/utils/intent_dict.json', model_identifier_path = 'AMUseBotFront/ai_talks/AMUseBotBackend/models/NLU/roberta-base-cookdial.txt'): + return NLU(intent_dict_path=intent_dict_path, + model_identifier_path=model_identifier_path) def clear_chat() -> None: st.session_state.generated = [] @@ -17,7 +25,6 @@ def clear_chat() -> None: st.session_state.costs = [] st.session_state.total_tokens = [] - def get_user_input(): # match st.session_state.input_kind: # case st.session_state.locale.input_kind_1: @@ -84,7 +91,7 @@ def show_conversation() -> None: if len(st.session_state.past): user_message = st.session_state.past[-1] # ai_content = "Dummy respone from AI" - intents = st.session_state.nlu.predict(user_message) + intents = get_nlu_model().predict(user_message) st.session_state.dst.update_dialog_history( system_message='', user_message=user_message,