fix path issue

This commit is contained in:
s444417 2023-06-12 14:17:16 +02:00
parent 7c39778ea1
commit 5633019411
5 changed files with 31 additions and 5 deletions

4
.env_template Normal file
View File

@ -0,0 +1,4 @@
RECIPE_PATH=AMUseBotFront/ai_talks/AMUseBotBackend/recipe/
DIALOG_PATH=AMUseBotFront/ai_talks/AMUseBotBackend/dialog/
INTENT_DICT_PATH=ai_talks/AMUseBotBackend/utils/intent_dict.json
MODEL_IDENTIFIER_PATH=ai_talks/AMUseBotBackend/models/NLU/roberta-base-cookdial.txt

View File

@ -10,6 +10,8 @@ from PIL import Image
from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation
from src.utils.lang import en from src.utils.lang import en
import os
from dotenv import load_dotenv
if __name__ == '__main__': if __name__ == '__main__':
@ -37,6 +39,11 @@ if __name__ == '__main__':
with open(css_file) as f: with open(css_file) as f:
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True) st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
load_dotenv()
DIALOG_PATH = os.getenv('DIALOG_PATH')
RECIPE_PATH = os.getenv('RECIPE_PATH')
# Storing The Context # Storing The Context
if "locale" not in st.session_state: if "locale" not in st.session_state:
st.session_state.locale = en st.session_state.locale = en
@ -57,7 +64,7 @@ if __name__ == '__main__':
if "total_tokens" not in st.session_state: if "total_tokens" not in st.session_state:
st.session_state.total_tokens = [] st.session_state.total_tokens = []
if "dst" not in st.session_state: if "dst" not in st.session_state:
st.session_state.dst = DST(recipe_path="AMUseBotFront/ai_talks/AMUseBotBackend/recipe/", dialog_path="AMUseBotFront/ai_talks/AMUseBotBackend/dialog/") st.session_state.dst = DST(recipe_path=RECIPE_PATH, dialog_path=DIALOG_PATH)
if "dp" not in st.session_state: if "dp" not in st.session_state:
st.session_state.dp = DP(dst=st.session_state.dst) st.session_state.dp = DP(dst=st.session_state.dst)

View File

@ -7,12 +7,18 @@ from streamlit_chat import message
from .stt import show_voice_input from .stt import show_voice_input
from .tts import show_audio_player from .tts import show_audio_player
from AMUseBotBackend.src.DP.dp import DP
from AMUseBotBackend.src.DST.dst import DST
from AMUseBotBackend.src.NLU.nlu import NLU from AMUseBotBackend.src.NLU.nlu import NLU
import os
from dotenv import load_dotenv
load_dotenv()
INTENT_DICT_PATH = os.getenv('INTENT_DICT_PATH')
MODEL_IDENTIFIER_PATH = os.getenv('MODEL_IDENTIFIER_PATH')
@st.cache_resource @st.cache_resource
def get_nlu_model(intent_dict_path = 'AMUseBotFront/ai_talks/AMUseBotBackend/utils/intent_dict.json', model_identifier_path = 'AMUseBotFront/ai_talks/AMUseBotBackend/models/NLU/roberta-base-cookdial.txt'): def get_nlu_model(intent_dict_path = INTENT_DICT_PATH, model_identifier_path = MODEL_IDENTIFIER_PATH):
return NLU(intent_dict_path=intent_dict_path, return NLU(intent_dict_path=intent_dict_path,
model_identifier_path=model_identifier_path) model_identifier_path=model_identifier_path)

View File

@ -10,3 +10,12 @@ watchdog>=3.0.0
setuptools~=65.5.0 setuptools~=65.5.0
graphviz~=0.20.1 graphviz~=0.20.1
Pillow~=9.5.0 Pillow~=9.5.0
pandas==1.5.2
scikit_learn==1.2.0
simpletransformers==0.63.9
torch==1.11.0
spacy==3.5.0
rank_bm25==0.2.2
tqdm==4.64.1
nlp==0.4.0
python-dotenv==1.0.0