diff --git a/chat.py b/chat.py
index 47fecee..aaaa0b1 100644
--- a/chat.py
+++ b/chat.py
@@ -2,7 +2,8 @@ from openai.error import OpenAIError
from pathlib import Path
from src.utils.ai import ai_settings, send_ai_request
-from src.utils.tts import lang_selector, speech_speed_radio, show_player
+from src.utils.tts import show_player
+from src.utils.conversation import get_user_input, clear_chat, show_conversation
import streamlit as st
@@ -25,32 +26,46 @@ with open(css_file) as f:
st.markdown(f"
{PAGE_TITLE}
", unsafe_allow_html=True)
st.markdown("---")
+# Storing The Context
+if "generated" not in st.session_state:
+ st.session_state["generated"] = []
+if "past" not in st.session_state:
+ st.session_state["past"] = []
+if "messages" not in st.session_state:
+ st.session_state["messages"] = []
+if "user_text" not in st.session_state:
+ st.session_state["user_text"] = ""
+
def main() -> None:
- user_text = st.text_area(label="Start your conversation with AI:")
- if st.button("Rerun"):
- st.cache_data.clear()
+ user_content = get_user_input()
+ b1, b2 = st.columns(2)
+ with b1, b2:
+ b1.button("Rerun", on_click=st.cache_data.clear)
+ b2.button("Clear Conversation", on_click=clear_chat)
model, role = ai_settings()
- if user_text:
+ if user_content:
+ if st.session_state["messages"]:
+ st.session_state["messages"].append({"role": "user", "content": user_content})
+ else:
+ st.session_state["messages"] = [
+ {"role": "system", "content": f"You are a {role}."},
+ {"role": "user", "content": user_content},
+ ]
try:
- completion = send_ai_request(user_text, model, role)
+ completion = send_ai_request(model, st.session_state["messages"])
if st.checkbox(label="Show Full API Response", value=False):
st.json(completion)
- ai_content = completion.get("choices")[0].get("message").get("content")
- if ai_content:
- st.markdown(ai_content)
- st.markdown("---")
-
- col1, col2 = st.columns(2)
- with col1:
- lang_code = lang_selector()
- with col2:
- is_speech_slow = speech_speed_radio()
- show_player(ai_content, lang_code, is_speech_slow)
except OpenAIError as err:
st.error(err)
+ ai_content = completion.get("choices")[0].get("message").get("content")
+ st.session_state["messages"].append({"role": "assistant", "content": ai_content})
+ if ai_content:
+ show_conversation(ai_content, user_content)
+ st.markdown("---")
+ show_player(ai_content)
if __name__ == "__main__":
diff --git a/requirements.txt b/requirements.txt
index 7a687c6..66e6bca 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,5 @@
streamlit==1.19.0
+streamlit-chat==0.0.2.1
openai==0.27.0
gtts==2.3.1
pip==23.0.1
diff --git a/src/utils/ai.py b/src/utils/ai.py
index efbe873..41da0a6 100644
--- a/src/utils/ai.py
+++ b/src/utils/ai.py
@@ -1,10 +1,11 @@
-from typing import Dict, Tuple
+from typing import List, Dict, Tuple
import streamlit as st
import openai
AI_MODEL_OPTIONS = [
"gpt-3.5-turbo",
+ "gpt-4.0",
]
AI_ROLE_OPTIONS = [
@@ -23,19 +24,16 @@ AI_ROLE_OPTIONS = [
def ai_settings() -> Tuple[str, str]:
c1, c2 = st.columns(2)
with c1, c2:
- model = c1.selectbox(label="Select AI model", options=AI_MODEL_OPTIONS)
- role = c2.selectbox(label="Select AI role", options=AI_ROLE_OPTIONS)
+ model = c1.selectbox(label="Select AI Model", options=AI_MODEL_OPTIONS)
+ role = c2.selectbox(label="Select AI Role", options=AI_ROLE_OPTIONS)
return model, role
@st.cache_data()
-def send_ai_request(user_text: str, ai_model: str, ai_role: str) -> Dict:
+def send_ai_request(ai_model: str, messages: List[Dict]) -> Dict:
openai.api_key = st.secrets.api_credentials.api_key
completion = openai.ChatCompletion.create(
model=ai_model,
- messages=[
- {"role": "system", "content": f"You are a {ai_role}."},
- {"role": "user", "content": user_text},
- ]
+ messages=messages,
)
return completion
diff --git a/src/utils/conversation.py b/src/utils/conversation.py
new file mode 100644
index 0000000..add64ad
--- /dev/null
+++ b/src/utils/conversation.py
@@ -0,0 +1,26 @@
+import streamlit as st
+from streamlit_chat import message
+
+
+def clear_chat() -> None:
+ st.session_state["generated"] = []
+ st.session_state["past"] = []
+ st.session_state["messages"] = []
+ st.session_state["user_text"] = ""
+
+
+def get_user_input() -> str:
+ user_text = st.text_area(label="Start Your Conversation With AI:", key="user_text")
+ return user_text
+
+
+def show_conversation(ai_content: str, user_text: str) -> None:
+ if ai_content not in st.session_state.generated:
+ # store the ai content
+ st.session_state.past.append(user_text)
+ st.session_state.generated.append(ai_content)
+ if st.session_state["generated"]:
+ for i in range(len(st.session_state["generated"]) - 1, -1, -1):
+ st.markdown(st.session_state["generated"][i])
+ # message(st.session_state["generated"][i], key=str(i))
+ message(st.session_state["past"][i], is_user=True, key=str(i) + "_user", avatar_style="micah")
diff --git a/src/utils/tts.py b/src/utils/tts.py
index 1d638ad..dc18bf0 100644
--- a/src/utils/tts.py
+++ b/src/utils/tts.py
@@ -18,7 +18,7 @@ def lang_selector() -> str:
lang_options = list(lang.tts_langs().values())
default_index = lang_options.index(DEFAULT_SPEECH_LANG)
lang_name = st.selectbox(
- label="Select speech language",
+ label="Select Speech Language",
options=lang_options,
index=default_index
)
@@ -31,18 +31,23 @@ def speech_speed_radio() -> bool:
"Slow": True
}
speed_speech = st.radio(
- label="Select speech speed",
+ label="Select Speech Speed",
options=speed_options.keys(),
)
return speed_options.get(speed_speech)
-def show_player(ai_content: str, lang_code: str, is_speech_slow: bool) -> None:
+def show_player(ai_content: str) -> None:
sound_file = BytesIO()
+ col1, col2 = st.columns(2)
+ with col1:
+ lang_code = lang_selector()
+ with col2:
+ is_speech_slow = speech_speed_radio()
try:
tts = gTTS(text=ai_content, lang=lang_code, slow=is_speech_slow)
tts.write_to_fp(sound_file)
- st.write("To hear the voice of AI, press the play button.")
+ st.write("To Hear The Voice Of AI, Press Play.")
st.audio(sound_file)
except gTTSError as err:
st.error(err)