cache model

This commit is contained in:
s444417 2023-06-11 22:25:49 +02:00
parent b4ee3f02af
commit 7c39778ea1
2 changed files with 10 additions and 7 deletions

View File

@ -3,7 +3,6 @@ from random import randrange
from AMUseBotBackend.src.DP.dp import DP from AMUseBotBackend.src.DP.dp import DP
from AMUseBotBackend.src.DST.dst import DST from AMUseBotBackend.src.DST.dst import DST
from AMUseBotBackend.src.NLU.nlu import NLU
import graphviz import graphviz
import streamlit as st import streamlit as st
@ -13,7 +12,7 @@ from src.utils.lang import en
if __name__ == '__main__': if __name__ == '__main__':
# --- PATH SETTINGS --- # --- PATH SETTINGS ---
current_dir: Path = Path(__file__).parent if "__file__" in locals() else Path.cwd() current_dir: Path = Path(__file__).parent if "__file__" in locals() else Path.cwd()
css_file: Path = current_dir / "src/styles/.css" css_file: Path = current_dir / "src/styles/.css"
@ -61,9 +60,6 @@ if __name__ == '__main__':
st.session_state.dst = DST(recipe_path="AMUseBotFront/ai_talks/AMUseBotBackend/recipe/", dialog_path="AMUseBotFront/ai_talks/AMUseBotBackend/dialog/") st.session_state.dst = DST(recipe_path="AMUseBotFront/ai_talks/AMUseBotBackend/recipe/", dialog_path="AMUseBotFront/ai_talks/AMUseBotBackend/dialog/")
if "dp" not in st.session_state: if "dp" not in st.session_state:
st.session_state.dp = DP(dst=st.session_state.dst) st.session_state.dp = DP(dst=st.session_state.dst)
if "nlu" not in st.session_state:
st.session_state.nlu = NLU(intent_dict_path='AMUseBotFront/ai_talks/AMUseBotBackend/utils/intent_dict.json',
model_identifier_path='AMUseBotFront/ai_talks/AMUseBotBackend/models/NLU/roberta-base-cookdial.txt')
def show_graph(): def show_graph():
# Create a graphlib graph object # Create a graphlib graph object

View File

@ -7,6 +7,14 @@ from streamlit_chat import message
from .stt import show_voice_input from .stt import show_voice_input
from .tts import show_audio_player from .tts import show_audio_player
from AMUseBotBackend.src.DP.dp import DP
from AMUseBotBackend.src.DST.dst import DST
from AMUseBotBackend.src.NLU.nlu import NLU
@st.cache_resource
def get_nlu_model(intent_dict_path = 'AMUseBotFront/ai_talks/AMUseBotBackend/utils/intent_dict.json', model_identifier_path = 'AMUseBotFront/ai_talks/AMUseBotBackend/models/NLU/roberta-base-cookdial.txt'):
return NLU(intent_dict_path=intent_dict_path,
model_identifier_path=model_identifier_path)
def clear_chat() -> None: def clear_chat() -> None:
st.session_state.generated = [] st.session_state.generated = []
@ -17,7 +25,6 @@ def clear_chat() -> None:
st.session_state.costs = [] st.session_state.costs = []
st.session_state.total_tokens = [] st.session_state.total_tokens = []
def get_user_input(): def get_user_input():
# match st.session_state.input_kind: # match st.session_state.input_kind:
# case st.session_state.locale.input_kind_1: # case st.session_state.locale.input_kind_1:
@ -84,7 +91,7 @@ def show_conversation() -> None:
if len(st.session_state.past): if len(st.session_state.past):
user_message = st.session_state.past[-1] user_message = st.session_state.past[-1]
# ai_content = "Dummy respone from AI" # ai_content = "Dummy respone from AI"
intents = st.session_state.nlu.predict(user_message) intents = get_nlu_model().predict(user_message)
st.session_state.dst.update_dialog_history( st.session_state.dst.update_dialog_history(
system_message='', system_message='',
user_message=user_message, user_message=user_message,