cache model
This commit is contained in:
parent
b4ee3f02af
commit
7c39778ea1
@ -3,7 +3,6 @@ from random import randrange
|
|||||||
|
|
||||||
from AMUseBotBackend.src.DP.dp import DP
|
from AMUseBotBackend.src.DP.dp import DP
|
||||||
from AMUseBotBackend.src.DST.dst import DST
|
from AMUseBotBackend.src.DST.dst import DST
|
||||||
from AMUseBotBackend.src.NLU.nlu import NLU
|
|
||||||
|
|
||||||
import graphviz
|
import graphviz
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
@ -13,7 +12,7 @@ from src.utils.lang import en
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
# --- PATH SETTINGS ---
|
# --- PATH SETTINGS ---
|
||||||
current_dir: Path = Path(__file__).parent if "__file__" in locals() else Path.cwd()
|
current_dir: Path = Path(__file__).parent if "__file__" in locals() else Path.cwd()
|
||||||
css_file: Path = current_dir / "src/styles/.css"
|
css_file: Path = current_dir / "src/styles/.css"
|
||||||
@ -61,9 +60,6 @@ if __name__ == '__main__':
|
|||||||
st.session_state.dst = DST(recipe_path="AMUseBotFront/ai_talks/AMUseBotBackend/recipe/", dialog_path="AMUseBotFront/ai_talks/AMUseBotBackend/dialog/")
|
st.session_state.dst = DST(recipe_path="AMUseBotFront/ai_talks/AMUseBotBackend/recipe/", dialog_path="AMUseBotFront/ai_talks/AMUseBotBackend/dialog/")
|
||||||
if "dp" not in st.session_state:
|
if "dp" not in st.session_state:
|
||||||
st.session_state.dp = DP(dst=st.session_state.dst)
|
st.session_state.dp = DP(dst=st.session_state.dst)
|
||||||
if "nlu" not in st.session_state:
|
|
||||||
st.session_state.nlu = NLU(intent_dict_path='AMUseBotFront/ai_talks/AMUseBotBackend/utils/intent_dict.json',
|
|
||||||
model_identifier_path='AMUseBotFront/ai_talks/AMUseBotBackend/models/NLU/roberta-base-cookdial.txt')
|
|
||||||
|
|
||||||
def show_graph():
|
def show_graph():
|
||||||
# Create a graphlib graph object
|
# Create a graphlib graph object
|
||||||
|
@ -7,6 +7,14 @@ from streamlit_chat import message
|
|||||||
from .stt import show_voice_input
|
from .stt import show_voice_input
|
||||||
from .tts import show_audio_player
|
from .tts import show_audio_player
|
||||||
|
|
||||||
|
from AMUseBotBackend.src.DP.dp import DP
|
||||||
|
from AMUseBotBackend.src.DST.dst import DST
|
||||||
|
from AMUseBotBackend.src.NLU.nlu import NLU
|
||||||
|
|
||||||
|
@st.cache_resource
|
||||||
|
def get_nlu_model(intent_dict_path = 'AMUseBotFront/ai_talks/AMUseBotBackend/utils/intent_dict.json', model_identifier_path = 'AMUseBotFront/ai_talks/AMUseBotBackend/models/NLU/roberta-base-cookdial.txt'):
|
||||||
|
return NLU(intent_dict_path=intent_dict_path,
|
||||||
|
model_identifier_path=model_identifier_path)
|
||||||
|
|
||||||
def clear_chat() -> None:
|
def clear_chat() -> None:
|
||||||
st.session_state.generated = []
|
st.session_state.generated = []
|
||||||
@ -17,7 +25,6 @@ def clear_chat() -> None:
|
|||||||
st.session_state.costs = []
|
st.session_state.costs = []
|
||||||
st.session_state.total_tokens = []
|
st.session_state.total_tokens = []
|
||||||
|
|
||||||
|
|
||||||
def get_user_input():
|
def get_user_input():
|
||||||
# match st.session_state.input_kind:
|
# match st.session_state.input_kind:
|
||||||
# case st.session_state.locale.input_kind_1:
|
# case st.session_state.locale.input_kind_1:
|
||||||
@ -84,7 +91,7 @@ def show_conversation() -> None:
|
|||||||
if len(st.session_state.past):
|
if len(st.session_state.past):
|
||||||
user_message = st.session_state.past[-1]
|
user_message = st.session_state.past[-1]
|
||||||
# ai_content = "Dummy respone from AI"
|
# ai_content = "Dummy respone from AI"
|
||||||
intents = st.session_state.nlu.predict(user_message)
|
intents = get_nlu_model().predict(user_message)
|
||||||
st.session_state.dst.update_dialog_history(
|
st.session_state.dst.update_dialog_history(
|
||||||
system_message='',
|
system_message='',
|
||||||
user_message=user_message,
|
user_message=user_message,
|
||||||
|
Loading…
Reference in New Issue
Block a user