diff --git a/chat.py b/chat.py index 47fecee..aaaa0b1 100644 --- a/chat.py +++ b/chat.py @@ -2,7 +2,8 @@ from openai.error import OpenAIError from pathlib import Path from src.utils.ai import ai_settings, send_ai_request -from src.utils.tts import lang_selector, speech_speed_radio, show_player +from src.utils.tts import show_player +from src.utils.conversation import get_user_input, clear_chat, show_conversation import streamlit as st @@ -25,32 +26,46 @@ with open(css_file) as f: st.markdown(f"

{PAGE_TITLE}

", unsafe_allow_html=True) st.markdown("---") +# Storing The Context +if "generated" not in st.session_state: + st.session_state["generated"] = [] +if "past" not in st.session_state: + st.session_state["past"] = [] +if "messages" not in st.session_state: + st.session_state["messages"] = [] +if "user_text" not in st.session_state: + st.session_state["user_text"] = "" + def main() -> None: - user_text = st.text_area(label="Start your conversation with AI:") - if st.button("Rerun"): - st.cache_data.clear() + user_content = get_user_input() + b1, b2 = st.columns(2) + with b1, b2: + b1.button("Rerun", on_click=st.cache_data.clear) + b2.button("Clear Conversation", on_click=clear_chat) model, role = ai_settings() - if user_text: + if user_content: + if st.session_state["messages"]: + st.session_state["messages"].append({"role": "user", "content": user_content}) + else: + st.session_state["messages"] = [ + {"role": "system", "content": f"You are a {role}."}, + {"role": "user", "content": user_content}, + ] try: - completion = send_ai_request(user_text, model, role) + completion = send_ai_request(model, st.session_state["messages"]) if st.checkbox(label="Show Full API Response", value=False): st.json(completion) - ai_content = completion.get("choices")[0].get("message").get("content") - if ai_content: - st.markdown(ai_content) - st.markdown("---") - - col1, col2 = st.columns(2) - with col1: - lang_code = lang_selector() - with col2: - is_speech_slow = speech_speed_radio() - show_player(ai_content, lang_code, is_speech_slow) except OpenAIError as err: st.error(err) + ai_content = completion.get("choices")[0].get("message").get("content") + st.session_state["messages"].append({"role": "assistant", "content": ai_content}) + if ai_content: + show_conversation(ai_content, user_content) + st.markdown("---") + show_player(ai_content) if __name__ == "__main__": diff --git a/requirements.txt b/requirements.txt index 7a687c6..66e6bca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ streamlit==1.19.0 +streamlit-chat==0.0.2.1 openai==0.27.0 gtts==2.3.1 pip==23.0.1 diff --git a/src/utils/ai.py b/src/utils/ai.py index efbe873..41da0a6 100644 --- a/src/utils/ai.py +++ b/src/utils/ai.py @@ -1,10 +1,11 @@ -from typing import Dict, Tuple +from typing import List, Dict, Tuple import streamlit as st import openai AI_MODEL_OPTIONS = [ "gpt-3.5-turbo", + "gpt-4.0", ] AI_ROLE_OPTIONS = [ @@ -23,19 +24,16 @@ AI_ROLE_OPTIONS = [ def ai_settings() -> Tuple[str, str]: c1, c2 = st.columns(2) with c1, c2: - model = c1.selectbox(label="Select AI model", options=AI_MODEL_OPTIONS) - role = c2.selectbox(label="Select AI role", options=AI_ROLE_OPTIONS) + model = c1.selectbox(label="Select AI Model", options=AI_MODEL_OPTIONS) + role = c2.selectbox(label="Select AI Role", options=AI_ROLE_OPTIONS) return model, role @st.cache_data() -def send_ai_request(user_text: str, ai_model: str, ai_role: str) -> Dict: +def send_ai_request(ai_model: str, messages: List[Dict]) -> Dict: openai.api_key = st.secrets.api_credentials.api_key completion = openai.ChatCompletion.create( model=ai_model, - messages=[ - {"role": "system", "content": f"You are a {ai_role}."}, - {"role": "user", "content": user_text}, - ] + messages=messages, ) return completion diff --git a/src/utils/conversation.py b/src/utils/conversation.py new file mode 100644 index 0000000..add64ad --- /dev/null +++ b/src/utils/conversation.py @@ -0,0 +1,26 @@ +import streamlit as st +from streamlit_chat import message + + +def clear_chat() -> None: + st.session_state["generated"] = [] + st.session_state["past"] = [] + st.session_state["messages"] = [] + st.session_state["user_text"] = "" + + +def get_user_input() -> str: + user_text = st.text_area(label="Start Your Conversation With AI:", key="user_text") + return user_text + + +def show_conversation(ai_content: str, user_text: str) -> None: + if ai_content not in st.session_state.generated: + # store the ai content + st.session_state.past.append(user_text) + st.session_state.generated.append(ai_content) + if st.session_state["generated"]: + for i in range(len(st.session_state["generated"]) - 1, -1, -1): + st.markdown(st.session_state["generated"][i]) + # message(st.session_state["generated"][i], key=str(i)) + message(st.session_state["past"][i], is_user=True, key=str(i) + "_user", avatar_style="micah") diff --git a/src/utils/tts.py b/src/utils/tts.py index 1d638ad..dc18bf0 100644 --- a/src/utils/tts.py +++ b/src/utils/tts.py @@ -18,7 +18,7 @@ def lang_selector() -> str: lang_options = list(lang.tts_langs().values()) default_index = lang_options.index(DEFAULT_SPEECH_LANG) lang_name = st.selectbox( - label="Select speech language", + label="Select Speech Language", options=lang_options, index=default_index ) @@ -31,18 +31,23 @@ def speech_speed_radio() -> bool: "Slow": True } speed_speech = st.radio( - label="Select speech speed", + label="Select Speech Speed", options=speed_options.keys(), ) return speed_options.get(speed_speech) -def show_player(ai_content: str, lang_code: str, is_speech_slow: bool) -> None: +def show_player(ai_content: str) -> None: sound_file = BytesIO() + col1, col2 = st.columns(2) + with col1: + lang_code = lang_selector() + with col2: + is_speech_slow = speech_speed_radio() try: tts = gTTS(text=ai_content, lang=lang_code, slow=is_speech_slow) tts.write_to_fp(sound_file) - st.write("To hear the voice of AI, press the play button.") + st.write("To Hear The Voice Of AI, Press Play.") st.audio(sound_file) except gTTSError as err: st.error(err)