cache model
parent b4ee3f02af
commit 7c39778ea1
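The diff below moves construction of the NLU model out of per-session st.session_state initialization and behind a get_nlu_model() helper decorated with @st.cache_resource, so the roberta-base-cookdial model is loaded once per Streamlit process and the cached instance is reused across reruns and sessions. Below is a minimal, self-contained sketch of that caching pattern; the DummyNLU class, paths, and widget labels are placeholders for illustration, not the project's actual API.

import streamlit as st


class DummyNLU:
    """Placeholder for an expensive-to-load model such as an intent classifier."""

    def __init__(self, model_path: str):
        # Pretend this loads a large checkpoint; the message appears in the
        # server log only on the first, uncached call.
        print(f"Loading model from {model_path}")
        self.model_path = model_path

    def predict(self, text: str):
        # Toy intent prediction so the example runs end to end.
        return ["greeting"] if "hi" in text.lower() else ["other"]


@st.cache_resource  # one shared instance per process, keyed on the arguments
def get_model(model_path: str = "models/dummy-checkpoint.txt") -> DummyNLU:
    return DummyNLU(model_path)


user_text = st.text_input("Say something")
if user_text:
    # Every rerun reuses the cached DummyNLU instance instead of reloading it.
    st.write(get_model().predict(user_text))

Unlike st.session_state, which is scoped to a single browser session, st.cache_resource shares the returned object across all sessions and reruns, which suits a read-only model.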
@@ -3,7 +3,6 @@ from random import randrange
 
 from AMUseBotBackend.src.DP.dp import DP
 from AMUseBotBackend.src.DST.dst import DST
-from AMUseBotBackend.src.NLU.nlu import NLU
 
 import graphviz
 import streamlit as st
@@ -61,9 +60,6 @@ if __name__ == '__main__':
         st.session_state.dst = DST(recipe_path="AMUseBotFront/ai_talks/AMUseBotBackend/recipe/", dialog_path="AMUseBotFront/ai_talks/AMUseBotBackend/dialog/")
     if "dp" not in st.session_state:
         st.session_state.dp = DP(dst=st.session_state.dst)
-    if "nlu" not in st.session_state:
-        st.session_state.nlu = NLU(intent_dict_path='AMUseBotFront/ai_talks/AMUseBotBackend/utils/intent_dict.json',
-                                   model_identifier_path='AMUseBotFront/ai_talks/AMUseBotBackend/models/NLU/roberta-base-cookdial.txt')
 
     def show_graph():
         # Create a graphlib graph object
@@ -7,6 +7,14 @@ from streamlit_chat import message
 from .stt import show_voice_input
 from .tts import show_audio_player
 
+from AMUseBotBackend.src.DP.dp import DP
+from AMUseBotBackend.src.DST.dst import DST
+from AMUseBotBackend.src.NLU.nlu import NLU
+
+@st.cache_resource
+def get_nlu_model(intent_dict_path = 'AMUseBotFront/ai_talks/AMUseBotBackend/utils/intent_dict.json', model_identifier_path = 'AMUseBotFront/ai_talks/AMUseBotBackend/models/NLU/roberta-base-cookdial.txt'):
+    return NLU(intent_dict_path=intent_dict_path,
+               model_identifier_path=model_identifier_path)
 
 def clear_chat() -> None:
     st.session_state.generated = []
@@ -17,7 +25,6 @@ def clear_chat() -> None:
     st.session_state.costs = []
     st.session_state.total_tokens = []
 
-
 def get_user_input():
     # match st.session_state.input_kind:
     #     case st.session_state.locale.input_kind_1:
@@ -84,7 +91,7 @@ def show_conversation() -> None:
     if len(st.session_state.past):
         user_message = st.session_state.past[-1]
         # ai_content = "Dummy respone from AI"
-        intents = st.session_state.nlu.predict(user_message)
+        intents = get_nlu_model().predict(user_message)
         st.session_state.dst.update_dialog_history(
             system_message='',
             user_message=user_message,