from pathlib import Path from random import randrange from AMUseBotBackend.src.DP.dp import DP from AMUseBotBackend.src.DST.dst import DST import graphviz import streamlit as st from PIL import Image from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation from src.utils.lang import en import openai import copy import json import os from dotenv import load_dotenv if __name__ == '__main__': # --- PATH SETTINGS --- current_dir: Path = Path(__file__).parent if "__file__" in locals() else Path.cwd() css_file: Path = current_dir / "src/styles/.css" assets_dir: Path = current_dir / "assets" icons_dir: Path = assets_dir / "icons" img_dir: Path = assets_dir / "img" tg_svg: Path = icons_dir / "tg.svg" favicon: Path = icons_dir / "favicons/0.png" # --- GENERAL SETTINGS --- LANG_PL: str = "Pl" AI_MODEL_OPTIONS: list[str] = [ "gpt-3.5-turbo", "gpt-4", "gpt-4-32k", ] CONFIG = {"page_title": "AMUsebot", "page_icon": Image.open(favicon)} st.set_page_config(**CONFIG) # --- LOAD CSS --- with open(css_file) as f: st.markdown(f"", unsafe_allow_html=True) load_dotenv() DIALOG_PATH = os.getenv('DIALOG_PATH') RECIPE_PATH = os.getenv('RECIPE_PATH') CHARACTERS_DICT = os.getenv('CHARACTERS_DICT') API_KEY = os.getenv('API_KEY') # Storing The Context if "locale" not in st.session_state: st.session_state.locale = en if "generated" not in st.session_state: st.session_state.generated = ["Hello! I'm AMUseBot, a virtual cooking assistant. Please tell me the name of the dish that you'd like to prepare today."] if "past" not in st.session_state: st.session_state.past = [] if "messages" not in st.session_state: st.session_state.messages = [] if "user_text" not in st.session_state: st.session_state.user_text = "" if "input_kind" not in st.session_state: st.session_state.input_kind = st.session_state.locale.input_kind_1 if "seed" not in st.session_state: st.session_state.seed = randrange(10 ** 3) # noqa: S311 if "costs" not in st.session_state: st.session_state.costs = [] if "total_tokens" not in st.session_state: st.session_state.total_tokens = [] if "dst" not in st.session_state: st.session_state.dst = DST(recipe_path=RECIPE_PATH, dialog_path=DIALOG_PATH) if "dp" not in st.session_state: st.session_state.dp = DP(dst=st.session_state.dst) if "openai" not in st.session_state: st.session_state.openai = openai st.session_state.openai.api_key = API_KEY if "characters_dict" not in st.session_state: with open(CHARACTERS_DICT) as f: st.session_state.characters_dict = json.load(f) def show_graph(): # Create a graphlib graph object if st.session_state.generated: user, chatbot = [], [] graph = graphviz.Digraph() chatbot = copy.deepcopy(st.session_state.generated) user = copy.deepcopy(st.session_state.past) for x in range(len(user)): chatbot_text = [word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(chatbot[x].split(' '))] user_text = [word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(user[x].split(' '))] graph.edge(' '.join(chatbot_text), ' '.join(user_text)) try: graph.edge(' '.join(user_text), ' '.join([word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(chatbot[x + 1].split(' '))])) except: pass st.graphviz_chart(graph) def main() -> None: c1, c2 = st.columns(2) with c1, c2: st.session_state.input_kind = c2.radio( label=st.session_state.locale.input_kind, options=(st.session_state.locale.input_kind_1, st.session_state.locale.input_kind_2), horizontal=True, ) role_kind = c1.radio( label=st.session_state.locale.radio_placeholder, options=(st.session_state.locale.radio_text1, st.session_state.locale.radio_text2), horizontal=True, ) if role_kind == st.session_state.locale.radio_text1: character_type = c2.selectbox(label=st.session_state.locale.select_placeholder2, key="role", options=st.session_state.locale.ai_role_options) st.session_state.dp.character = character_type if character_type == 'default': st.session_state.dp.llm_rephrasing = False else: st.session_state.dp.llm_rephrasing = True elif role_kind == st.session_state.locale.radio_text2: c2.text_input(label=st.session_state.locale.select_placeholder3, key="role") get_user_input() show_chat_buttons() show_conversation() with st.sidebar: show_graph() if __name__ == "__main__": st.markdown(f"

{st.session_state.locale.title}

", unsafe_allow_html=True) main()