AMUseBot/ai_talks/chat.py

142 lines
5.2 KiB
Python
Raw Permalink Normal View History

2023-03-02 15:32:39 +01:00
from pathlib import Path
2023-04-23 21:33:19 +02:00
from random import randrange
2023-03-02 18:39:03 +01:00
2023-05-31 21:11:09 +02:00
from AMUseBotBackend.src.DP.dp import DP
from AMUseBotBackend.src.DST.dst import DST
2023-06-15 23:18:09 +02:00
import streamlit.components.v1 as components
2023-05-31 21:11:09 +02:00
2023-05-21 22:53:02 +02:00
import graphviz
2023-03-02 15:32:39 +01:00
import streamlit as st
2023-05-21 22:53:02 +02:00
from PIL import Image
2023-04-23 21:33:19 +02:00
from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation
2023-05-21 22:53:02 +02:00
from src.utils.lang import en
2023-03-02 15:32:39 +01:00
2023-06-15 23:18:09 +02:00
# from live_asr import LiveWav2Vec2
# english_model = "facebook/wav2vec2-large-960h-lv60-self"
# asr = LiveWav2Vec2(english_model,device_name="default")
# asr.start()
2023-06-12 14:17:16 +02:00
import os
from dotenv import load_dotenv
2023-03-11 15:39:23 +01:00
2023-05-31 21:11:09 +02:00
if __name__ == '__main__':
    # --- PATH SETTINGS ---
    # Resolve assets relative to this script; fall back to the current working
    # directory when __file__ is not defined.
    current_dir: Path = Path(__file__).parent if "__file__" in locals() else Path.cwd()
    css_file: Path = current_dir / "src/styles/.css"
    assets_dir: Path = current_dir / "assets"
    icons_dir: Path = assets_dir / "icons"
    img_dir: Path = assets_dir / "img"
    tg_svg: Path = icons_dir / "tg.svg"
    favicon: Path = icons_dir / "favicons/0.png"

    # --- GENERAL SETTINGS ---
    # NOTE(review): LANG_PL, AI_MODEL_OPTIONS, img_dir and tg_svg are not
    # referenced elsewhere in this file — confirm whether they are still needed.
    LANG_PL: str = "Pl"
    AI_MODEL_OPTIONS: list[str] = [
        "gpt-3.5-turbo",
        "gpt-4",
        "gpt-4-32k",
    ]
    # Browser tab title and favicon for the Streamlit page.
    CONFIG = {"page_title": "AMUsebot", "page_icon": Image.open(favicon)}
    st.set_page_config(**CONFIG)

    # --- LOAD CSS ---
    # Decode the stylesheet explicitly as UTF-8 so it reads the same way
    # regardless of the host's locale encoding.
    with open(css_file, encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

    # --- BACKEND DATA PATHS ---
    # load_dotenv() lets a .env file supply the variables read below.
    load_dotenv()
    # NOTE(review): os.getenv returns None for an unset variable, in which case
    # DST below is constructed with None paths — confirm DST handles that, or
    # validate these values here.
    DIALOG_PATH = os.getenv('DIALOG_PATH')
    RECIPE_PATH = os.getenv('RECIPE_PATH')

    # --- SESSION STATE ---
    # Each key is initialised only when it is not already present, so values
    # stored earlier in the session are kept. Order matters: input_kind reads
    # the locale, and dp is built from dst.
    if "locale" not in st.session_state:
        st.session_state.locale = en
    if "generated" not in st.session_state:
        # Bot messages; seeded with the opening greeting.
        st.session_state.generated = ["Hello! I'm AMUseBot, a virtual cooking assistant. Please tell me the name of the dish that you'd like to prepare today."]
    if "past" not in st.session_state:
        # User messages.
        st.session_state.past = []
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "user_text" not in st.session_state:
        st.session_state.user_text = ""
    if "input_kind" not in st.session_state:
        st.session_state.input_kind = st.session_state.locale.input_kind_1
    if "seed" not in st.session_state:
        st.session_state.seed = randrange(10 ** 3)  # noqa: S311
    if "costs" not in st.session_state:
        st.session_state.costs = []
    if "total_tokens" not in st.session_state:
        st.session_state.total_tokens = []
    if "dst" not in st.session_state:
        st.session_state.dst = DST(recipe_path=RECIPE_PATH, dialog_path=DIALOG_PATH)
    if "dp" not in st.session_state:
        # DP is constructed from the DST instance stored above.
        st.session_state.dp = DP(dst=st.session_state.dst)
2023-03-02 19:32:09 +01:00
2023-05-21 22:53:02 +02:00
def _wrap_words(text: str, every: int = 5) -> str:
    """Return *text* with a line break inserted after every *every*-th word.

    Words are split on single spaces. A newline is appended to the words at
    indices ``every``, ``2 * every``, ... (the first word is index 0), and the
    words are re-joined with single spaces. As a result, the first line holds
    ``every + 1`` words and each later line starts with a space.
    """
    words = text.split(' ')
    return ' '.join(
        word + '\n' if i % every == 0 and i > 0 else word
        for i, word in enumerate(words)
    )


def show_graph() -> None:
    """Render the conversation so far as a graphviz digraph.

    Each bot message (``st.session_state.generated``) gets an edge to the user
    reply that followed it (``st.session_state.past``). Each user reply gets an
    edge to the next bot message when one exists. Messages whose wrapped text
    is identical share a single node. Nothing is rendered when there are no bot
    messages.
    """
    generated = st.session_state.generated
    past = st.session_state.past
    if not generated:
        return

    graph = graphviz.Digraph()
    # Maps wrapped message text -> node id. Messages are used as node labels
    # rather than node names because Digraph.edge() parses ':' in a node name
    # as a port separator, which would truncate any message containing a colon.
    node_ids: dict[str, str] = {}

    def node_for(text: str) -> str:
        """Return the node id for *text*, declaring the node on first use."""
        label = _wrap_words(text)
        node_id = node_ids.get(label)
        if node_id is None:
            node_id = f"n{len(node_ids)}"
            node_ids[label] = node_id
            graph.node(node_id, label=label)
        return node_id

    # zip() stops at the shorter history, so unequal lengths no longer raise
    # IndexError.
    for turn, (bot_text, user_text) in enumerate(zip(generated, past)):
        bot_node = node_for(bot_text)
        user_node = node_for(user_text)
        graph.edge(bot_node, user_node)
        # Link the user reply to the bot's answer only if that answer exists.
        if turn + 1 < len(generated):
            graph.edge(user_node, node_for(generated[turn + 1]))

    st.graphviz_chart(graph)
2023-05-21 22:53:02 +02:00
2023-06-15 23:18:09 +02:00
# def mermaid(code: str) -> None:
# components.html(
# f"""
# <pre class="mermaid">
# {code}
# </pre>
#
# <script type="module">
# import mermaid from 'https://cdn.jsdelivr.net/npm/mermaid@10/dist/mermaid.esm.min.mjs';
# mermaid.initialize({{ startOnLoad: true }});
# </script>
# """
# )
2023-05-21 22:53:02 +02:00
2023-03-02 19:32:09 +01:00
def main() -> None:
    """Build the page: settings row, chat input and controls, conversation, sidebar graph.

    The left column holds the role-mode radio. The right column holds the
    input-kind radio plus either a predefined-role selectbox or a free-text
    role field. Both role widgets write to the "role" session-state key.
    """
    locale = st.session_state.locale
    left_col, right_col = st.columns(2)

    # Every widget below is addressed to its column explicitly, so no `with`
    # context is needed to route the output.
    st.session_state.input_kind = right_col.radio(
        label=locale.input_kind,
        options=(locale.input_kind_1, locale.input_kind_2),
        horizontal=True,
    )
    role_kind = left_col.radio(
        label=locale.radio_placeholder,
        options=(locale.radio_text1, locale.radio_text2),
        horizontal=True,
    )

    if role_kind == locale.radio_text1:
        right_col.selectbox(
            label=locale.select_placeholder2,
            key="role",
            options=locale.ai_role_options,
        )
    elif role_kind == locale.radio_text2:
        right_col.text_input(label=locale.select_placeholder3, key="role")

    get_user_input()
    show_chat_buttons()
    show_conversation()
    with st.sidebar:
        show_graph()
2023-04-16 02:01:02 +02:00
if __name__ == "__main__":
    # Centered page heading taken from the active locale, then the app body.
    title_html = f"<h1 style='text-align: center;'>{st.session_state.locale.title}</h1>"
    st.markdown(title_html, unsafe_allow_html=True)
    main()