diff --git a/ai_talks/AMUseBotBackend b/ai_talks/AMUseBotBackend new file mode 160000 index 0000000..01d547d --- /dev/null +++ b/ai_talks/AMUseBotBackend @@ -0,0 +1 @@ +Subproject commit 01d547dc221b2bd81ccbe24da37b792e9b176b37 diff --git a/ai_talks/chat.py b/ai_talks/chat.py index 40c6bb8..cc4aea7 100644 --- a/ai_talks/chat.py +++ b/ai_talks/chat.py @@ -1,56 +1,69 @@ from pathlib import Path from random import randrange +from AMUseBotBackend.src.DP.dp import DP +from AMUseBotBackend.src.DST.dst import DST +from AMUseBotBackend.src.NLU.nlu import NLU + import graphviz import streamlit as st from PIL import Image from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation from src.utils.lang import en -# --- PATH SETTINGS --- -current_dir: Path = Path(__file__).parent if "__file__" in locals() else Path.cwd() -css_file: Path = current_dir / "src/styles/.css" -assets_dir: Path = current_dir / "assets" -icons_dir: Path = assets_dir / "icons" -img_dir: Path = assets_dir / "img" -tg_svg: Path = icons_dir / "tg.svg" -favicon: Path = icons_dir / "favicons/0.png" -# --- GENERAL SETTINGS --- -LANG_PL: str = "Pl" -AI_MODEL_OPTIONS: list[str] = [ - "gpt-3.5-turbo", - "gpt-4", - "gpt-4-32k", -] -CONFIG = {"page_title": "AMUsebot", "page_icon": Image.open(favicon)} +if __name__ == '__main__': -st.set_page_config(**CONFIG) + # --- PATH SETTINGS --- + current_dir: Path = Path(__file__).parent if "__file__" in locals() else Path.cwd() + css_file: Path = current_dir / "src/styles/.css" + assets_dir: Path = current_dir / "assets" + icons_dir: Path = assets_dir / "icons" + img_dir: Path = assets_dir / "img" + tg_svg: Path = icons_dir / "tg.svg" + favicon: Path = icons_dir / "favicons/0.png" + # --- GENERAL SETTINGS --- + LANG_PL: str = "Pl" + AI_MODEL_OPTIONS: list[str] = [ + "gpt-3.5-turbo", + "gpt-4", + "gpt-4-32k", + ] -# --- LOAD CSS --- -with open(css_file) as f: - st.markdown(f"", unsafe_allow_html=True) + CONFIG = {"page_title": "AMUsebot", "page_icon": Image.open(favicon)} -# Storing The Context -if "locale" not in st.session_state: - st.session_state.locale = en -if "generated" not in st.session_state: - st.session_state.generated = [] -if "past" not in st.session_state: - st.session_state.past = [] -if "messages" not in st.session_state: - st.session_state.messages = [] -if "user_text" not in st.session_state: - st.session_state.user_text = "" -if "input_kind" not in st.session_state: - st.session_state.input_kind = st.session_state.locale.input_kind_1 -if "seed" not in st.session_state: - st.session_state.seed = randrange(10 ** 3) # noqa: S311 -if "costs" not in st.session_state: - st.session_state.costs = [] -if "total_tokens" not in st.session_state: - st.session_state.total_tokens = [] + st.set_page_config(**CONFIG) + # --- LOAD CSS --- + with open(css_file) as f: + st.markdown(f"", unsafe_allow_html=True) + + # Storing The Context + if "locale" not in st.session_state: + st.session_state.locale = en + if "generated" not in st.session_state: + st.session_state.generated = [] + if "past" not in st.session_state: + st.session_state.past = [] + if "messages" not in st.session_state: + st.session_state.messages = [] + if "user_text" not in st.session_state: + st.session_state.user_text = "" + if "input_kind" not in st.session_state: + st.session_state.input_kind = st.session_state.locale.input_kind_1 + if "seed" not in st.session_state: + st.session_state.seed = randrange(10 ** 3) # noqa: S311 + if "costs" not in st.session_state: + st.session_state.costs = [] + if "total_tokens" not in st.session_state: + st.session_state.total_tokens = [] + if "dst" not in st.session_state: + st.session_state.dst = DST(recipe_path="AMUseBotFront/ai_talks/AMUseBotBackend/recipe/", dialog_path="AMUseBotFront/ai_talks/AMUseBotBackend/dialog/") + if "dp" not in st.session_state: + st.session_state.dp = DP(dst=st.session_state.dst) + if "nlu" not in st.session_state: + st.session_state.nlu = NLU(intent_dict_path='AMUseBotFront/ai_talks/AMUseBotBackend/utils/intent_dict.json', + model_identifier_path='AMUseBotFront/ai_talks/AMUseBotBackend/models/NLU/roberta-base-cookdial.txt') def show_graph(): # Create a graphlib graph object @@ -66,14 +79,12 @@ def show_graph(): graph.edge(st.session_state.generated[x], st.session_state.past[x+1]) except: pass - - st.graphviz_chart(graph) + st.graphviz_chart(graph) def main() -> None: c1, c2 = st.columns(2) with c1, c2: - c1.selectbox(label=st.session_state.locale.select_placeholder1, key="model", options=AI_MODEL_OPTIONS) st.session_state.input_kind = c2.radio( label=st.session_state.locale.input_kind, options=(st.session_state.locale.input_kind_1, st.session_state.locale.input_kind_2), @@ -84,19 +95,17 @@ def main() -> None: options=(st.session_state.locale.radio_text1, st.session_state.locale.radio_text2), horizontal=True, ) - match role_kind: - case st.session_state.locale.radio_text1: - c2.selectbox(label=st.session_state.locale.select_placeholder2, key="role", + if role_kind == st.session_state.locale.radio_text1: + c2.selectbox(label=st.session_state.locale.select_placeholder2, key="role", options=st.session_state.locale.ai_role_options) - case st.session_state.locale.radio_text2: - c2.text_input(label=st.session_state.locale.select_placeholder3, key="role") - + elif role_kind == st.session_state.locale.radio_text2: + c2.text_input(label=st.session_state.locale.select_placeholder3, key="role") + if st.session_state.user_text: show_graph() show_conversation() get_user_input() - show_chat_buttons() diff --git a/ai_talks/src/utils/conversation.py b/ai_talks/src/utils/conversation.py index 349ffc4..35516ae 100644 --- a/ai_talks/src/utils/conversation.py +++ b/ai_talks/src/utils/conversation.py @@ -65,10 +65,17 @@ def show_conversation() -> None: {"role": "user", "content": st.session_state.user_text}, ] - ai_content = "Dummy respone from AI" + # ai_content = "Dummy respone from AI" + intents = st.session_state.nlu.predict(st.session_state.user_text) + st.session_state.dst.update_dialog_history( + system_message='', + user_message=st.session_state.user_text, + intents=intents, + ) + system_message = st.session_state.dp.generate_response(intents) # delete random before deploying with our model - random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5)) - ai_content += random_str + #random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5)) + ai_content = system_message st.session_state.messages.append({"role": "assistant", "content": ai_content}) if ai_content: show_chat(ai_content, st.session_state.user_text)