diff --git a/ai_talks/AMUseBotBackend b/ai_talks/AMUseBotBackend
new file mode 160000
index 0000000..01d547d
--- /dev/null
+++ b/ai_talks/AMUseBotBackend
@@ -0,0 +1 @@
+Subproject commit 01d547dc221b2bd81ccbe24da37b792e9b176b37
diff --git a/ai_talks/chat.py b/ai_talks/chat.py
index 40c6bb8..cc4aea7 100644
--- a/ai_talks/chat.py
+++ b/ai_talks/chat.py
@@ -1,56 +1,69 @@
from pathlib import Path
from random import randrange
+from AMUseBotBackend.src.DP.dp import DP
+from AMUseBotBackend.src.DST.dst import DST
+from AMUseBotBackend.src.NLU.nlu import NLU
+
import graphviz
import streamlit as st
from PIL import Image
from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation
from src.utils.lang import en
-# --- PATH SETTINGS ---
-current_dir: Path = Path(__file__).parent if "__file__" in locals() else Path.cwd()
-css_file: Path = current_dir / "src/styles/.css"
-assets_dir: Path = current_dir / "assets"
-icons_dir: Path = assets_dir / "icons"
-img_dir: Path = assets_dir / "img"
-tg_svg: Path = icons_dir / "tg.svg"
-favicon: Path = icons_dir / "favicons/0.png"
-# --- GENERAL SETTINGS ---
-LANG_PL: str = "Pl"
-AI_MODEL_OPTIONS: list[str] = [
- "gpt-3.5-turbo",
- "gpt-4",
- "gpt-4-32k",
-]
-CONFIG = {"page_title": "AMUsebot", "page_icon": Image.open(favicon)}
+if __name__ == '__main__':
-st.set_page_config(**CONFIG)
+ # --- PATH SETTINGS ---
+ current_dir: Path = Path(__file__).parent if "__file__" in locals() else Path.cwd()
+ css_file: Path = current_dir / "src/styles/.css"
+ assets_dir: Path = current_dir / "assets"
+ icons_dir: Path = assets_dir / "icons"
+ img_dir: Path = assets_dir / "img"
+ tg_svg: Path = icons_dir / "tg.svg"
+ favicon: Path = icons_dir / "favicons/0.png"
+ # --- GENERAL SETTINGS ---
+ LANG_PL: str = "Pl"
+ AI_MODEL_OPTIONS: list[str] = [
+ "gpt-3.5-turbo",
+ "gpt-4",
+ "gpt-4-32k",
+ ]
-# --- LOAD CSS ---
-with open(css_file) as f:
- st.markdown(f"", unsafe_allow_html=True)
+ CONFIG = {"page_title": "AMUsebot", "page_icon": Image.open(favicon)}
-# Storing The Context
-if "locale" not in st.session_state:
- st.session_state.locale = en
-if "generated" not in st.session_state:
- st.session_state.generated = []
-if "past" not in st.session_state:
- st.session_state.past = []
-if "messages" not in st.session_state:
- st.session_state.messages = []
-if "user_text" not in st.session_state:
- st.session_state.user_text = ""
-if "input_kind" not in st.session_state:
- st.session_state.input_kind = st.session_state.locale.input_kind_1
-if "seed" not in st.session_state:
- st.session_state.seed = randrange(10 ** 3) # noqa: S311
-if "costs" not in st.session_state:
- st.session_state.costs = []
-if "total_tokens" not in st.session_state:
- st.session_state.total_tokens = []
+ st.set_page_config(**CONFIG)
+ # --- LOAD CSS ---
+ with open(css_file) as f:
+ st.markdown(f"", unsafe_allow_html=True)
+
+ # Storing The Context
+ if "locale" not in st.session_state:
+ st.session_state.locale = en
+ if "generated" not in st.session_state:
+ st.session_state.generated = []
+ if "past" not in st.session_state:
+ st.session_state.past = []
+ if "messages" not in st.session_state:
+ st.session_state.messages = []
+ if "user_text" not in st.session_state:
+ st.session_state.user_text = ""
+ if "input_kind" not in st.session_state:
+ st.session_state.input_kind = st.session_state.locale.input_kind_1
+ if "seed" not in st.session_state:
+ st.session_state.seed = randrange(10 ** 3) # noqa: S311
+ if "costs" not in st.session_state:
+ st.session_state.costs = []
+ if "total_tokens" not in st.session_state:
+ st.session_state.total_tokens = []
+ if "dst" not in st.session_state:
+ st.session_state.dst = DST(recipe_path="AMUseBotFront/ai_talks/AMUseBotBackend/recipe/", dialog_path="AMUseBotFront/ai_talks/AMUseBotBackend/dialog/")
+ if "dp" not in st.session_state:
+ st.session_state.dp = DP(dst=st.session_state.dst)
+ if "nlu" not in st.session_state:
+ st.session_state.nlu = NLU(intent_dict_path='AMUseBotFront/ai_talks/AMUseBotBackend/utils/intent_dict.json',
+ model_identifier_path='AMUseBotFront/ai_talks/AMUseBotBackend/models/NLU/roberta-base-cookdial.txt')
def show_graph():
# Create a graphlib graph object
@@ -66,14 +79,12 @@ def show_graph():
graph.edge(st.session_state.generated[x], st.session_state.past[x+1])
except:
pass
-
- st.graphviz_chart(graph)
+ st.graphviz_chart(graph)
def main() -> None:
c1, c2 = st.columns(2)
with c1, c2:
- c1.selectbox(label=st.session_state.locale.select_placeholder1, key="model", options=AI_MODEL_OPTIONS)
st.session_state.input_kind = c2.radio(
label=st.session_state.locale.input_kind,
options=(st.session_state.locale.input_kind_1, st.session_state.locale.input_kind_2),
@@ -84,19 +95,17 @@ def main() -> None:
options=(st.session_state.locale.radio_text1, st.session_state.locale.radio_text2),
horizontal=True,
)
- match role_kind:
- case st.session_state.locale.radio_text1:
- c2.selectbox(label=st.session_state.locale.select_placeholder2, key="role",
+ if role_kind == st.session_state.locale.radio_text1:
+ c2.selectbox(label=st.session_state.locale.select_placeholder2, key="role",
options=st.session_state.locale.ai_role_options)
- case st.session_state.locale.radio_text2:
- c2.text_input(label=st.session_state.locale.select_placeholder3, key="role")
-
+ elif role_kind == st.session_state.locale.radio_text2:
+ c2.text_input(label=st.session_state.locale.select_placeholder3, key="role")
+
if st.session_state.user_text:
show_graph()
show_conversation()
get_user_input()
-
show_chat_buttons()
diff --git a/ai_talks/src/utils/conversation.py b/ai_talks/src/utils/conversation.py
index 349ffc4..35516ae 100644
--- a/ai_talks/src/utils/conversation.py
+++ b/ai_talks/src/utils/conversation.py
@@ -65,10 +65,17 @@ def show_conversation() -> None:
{"role": "user", "content": st.session_state.user_text},
]
- ai_content = "Dummy respone from AI"
+ # ai_content = "Dummy respone from AI"
+ intents = st.session_state.nlu.predict(st.session_state.user_text)
+ st.session_state.dst.update_dialog_history(
+ system_message='',
+ user_message=st.session_state.user_text,
+ intents=intents,
+ )
+ system_message = st.session_state.dp.generate_response(intents)
# delete random before deploying with our model
- random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5))
- ai_content += random_str
+ #random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5))
+ ai_content = system_message
st.session_state.messages.append({"role": "assistant", "content": ai_content})
if ai_content:
show_chat(ai_content, st.session_state.user_text)