cache model

This commit is contained in:
s444417 2023-06-11 22:25:49 +02:00
parent b4ee3f02af
commit 7c39778ea1
2 changed files with 10 additions and 7 deletions

View File

@ -3,7 +3,6 @@ from random import randrange
from AMUseBotBackend.src.DP.dp import DP
from AMUseBotBackend.src.DST.dst import DST
from AMUseBotBackend.src.NLU.nlu import NLU
import graphviz
import streamlit as st
@ -13,7 +12,7 @@ from src.utils.lang import en
if __name__ == '__main__':
# --- PATH SETTINGS ---
current_dir: Path = Path(__file__).parent if "__file__" in locals() else Path.cwd()
css_file: Path = current_dir / "src/styles/.css"
@ -61,9 +60,6 @@ if __name__ == '__main__':
st.session_state.dst = DST(recipe_path="AMUseBotFront/ai_talks/AMUseBotBackend/recipe/", dialog_path="AMUseBotFront/ai_talks/AMUseBotBackend/dialog/")
if "dp" not in st.session_state:
st.session_state.dp = DP(dst=st.session_state.dst)
if "nlu" not in st.session_state:
st.session_state.nlu = NLU(intent_dict_path='AMUseBotFront/ai_talks/AMUseBotBackend/utils/intent_dict.json',
model_identifier_path='AMUseBotFront/ai_talks/AMUseBotBackend/models/NLU/roberta-base-cookdial.txt')
def show_graph():
# Create a graphlib graph object

View File

@ -7,6 +7,14 @@ from streamlit_chat import message
from .stt import show_voice_input
from .tts import show_audio_player
from AMUseBotBackend.src.DP.dp import DP
from AMUseBotBackend.src.DST.dst import DST
from AMUseBotBackend.src.NLU.nlu import NLU
@st.cache_resource
def get_nlu_model(intent_dict_path = 'AMUseBotFront/ai_talks/AMUseBotBackend/utils/intent_dict.json', model_identifier_path = 'AMUseBotFront/ai_talks/AMUseBotBackend/models/NLU/roberta-base-cookdial.txt'):
return NLU(intent_dict_path=intent_dict_path,
model_identifier_path=model_identifier_path)
def clear_chat() -> None:
st.session_state.generated = []
@ -17,7 +25,6 @@ def clear_chat() -> None:
st.session_state.costs = []
st.session_state.total_tokens = []
def get_user_input():
# match st.session_state.input_kind:
# case st.session_state.locale.input_kind_1:
@ -84,7 +91,7 @@ def show_conversation() -> None:
if len(st.session_state.past):
user_message = st.session_state.past[-1]
# ai_content = "Dummy respone from AI"
intents = st.session_state.nlu.predict(user_message)
intents = get_nlu_model().predict(user_message)
st.session_state.dst.update_dialog_history(
system_message='',
user_message=user_message,