AMUseBot/ai_talks/chat.py
Adrian Charkiewicz 0d3b7d0a31 initial commit
2023-05-21 22:53:02 +02:00

106 lines
3.5 KiB
Python

from pathlib import Path
from random import randrange
import graphviz
import streamlit as st
from PIL import Image
from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation
from src.utils.lang import en
# --- PATH SETTINGS ---
current_dir: Path = Path(__file__).parent if "__file__" in locals() else Path.cwd()
css_file: Path = current_dir / "src/styles/.css"
assets_dir: Path = current_dir / "assets"
icons_dir: Path = assets_dir / "icons"
img_dir: Path = assets_dir / "img"
tg_svg: Path = icons_dir / "tg.svg"
favicon: Path = icons_dir / "favicons/0.png"
# --- GENERAL SETTINGS ---
LANG_PL: str = "Pl"
AI_MODEL_OPTIONS: list[str] = [
"gpt-3.5-turbo",
"gpt-4",
"gpt-4-32k",
]
CONFIG = {"page_title": "AMUsebot", "page_icon": Image.open(favicon)}
st.set_page_config(**CONFIG)
# --- LOAD CSS ---
with open(css_file) as f:
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
# Storing The Context
if "locale" not in st.session_state:
st.session_state.locale = en
if "generated" not in st.session_state:
st.session_state.generated = []
if "past" not in st.session_state:
st.session_state.past = []
if "messages" not in st.session_state:
st.session_state.messages = []
if "user_text" not in st.session_state:
st.session_state.user_text = ""
if "input_kind" not in st.session_state:
st.session_state.input_kind = st.session_state.locale.input_kind_1
if "seed" not in st.session_state:
st.session_state.seed = randrange(10 ** 3) # noqa: S311
if "costs" not in st.session_state:
st.session_state.costs = []
if "total_tokens" not in st.session_state:
st.session_state.total_tokens = []
def show_graph():
# Create a graphlib graph object
if st.session_state.generated:
user, chatbot = [], []
graph = graphviz.Digraph()
for i in range(len(st.session_state.generated)):
user.append(st.session_state.past[i])
chatbot.append(st.session_state.generated[i])
for x in range(len(user)):
graph.edge(st.session_state.past[x], st.session_state.generated[x])
try:
graph.edge(st.session_state.generated[x], st.session_state.past[x+1])
except:
pass
st.graphviz_chart(graph)
def main() -> None:
c1, c2 = st.columns(2)
with c1, c2:
c1.selectbox(label=st.session_state.locale.select_placeholder1, key="model", options=AI_MODEL_OPTIONS)
st.session_state.input_kind = c2.radio(
label=st.session_state.locale.input_kind,
options=(st.session_state.locale.input_kind_1, st.session_state.locale.input_kind_2),
horizontal=True,
)
role_kind = c1.radio(
label=st.session_state.locale.radio_placeholder,
options=(st.session_state.locale.radio_text1, st.session_state.locale.radio_text2),
horizontal=True,
)
match role_kind:
case st.session_state.locale.radio_text1:
c2.selectbox(label=st.session_state.locale.select_placeholder2, key="role",
options=st.session_state.locale.ai_role_options)
case st.session_state.locale.radio_text2:
c2.text_input(label=st.session_state.locale.select_placeholder3, key="role")
if st.session_state.user_text:
show_graph()
show_conversation()
get_user_input()
show_chat_buttons()
if __name__ == "__main__":
st.markdown(f"<h1 style='text-align: center;'>{st.session_state.locale.title}</h1>", unsafe_allow_html=True)
main()