add chat mode

This commit is contained in:
kosarevskiydp 2023-03-11 17:39:23 +03:00
parent 3f88e79b20
commit 5e18ab9104
5 changed files with 74 additions and 29 deletions

49
chat.py
View File

@ -2,7 +2,8 @@ from openai.error import OpenAIError
from pathlib import Path
from src.utils.ai import ai_settings, send_ai_request
from src.utils.tts import lang_selector, speech_speed_radio, show_player
from src.utils.tts import show_player
from src.utils.conversation import get_user_input, clear_chat, show_conversation
import streamlit as st
@ -25,32 +26,46 @@ with open(css_file) as f:
st.markdown(f"<h1 style='text-align: center;'>{PAGE_TITLE}</h1>", unsafe_allow_html=True)
st.markdown("---")
# Storing The Context
if "generated" not in st.session_state:
st.session_state["generated"] = []
if "past" not in st.session_state:
st.session_state["past"] = []
if "messages" not in st.session_state:
st.session_state["messages"] = []
if "user_text" not in st.session_state:
st.session_state["user_text"] = ""
def main() -> None:
user_text = st.text_area(label="Start your conversation with AI:")
if st.button("Rerun"):
st.cache_data.clear()
user_content = get_user_input()
b1, b2 = st.columns(2)
with b1, b2:
b1.button("Rerun", on_click=st.cache_data.clear)
b2.button("Clear Conversation", on_click=clear_chat)
model, role = ai_settings()
if user_text:
if user_content:
if st.session_state["messages"]:
st.session_state["messages"].append({"role": "user", "content": user_content})
else:
st.session_state["messages"] = [
{"role": "system", "content": f"You are a {role}."},
{"role": "user", "content": user_content},
]
try:
completion = send_ai_request(user_text, model, role)
completion = send_ai_request(model, st.session_state["messages"])
if st.checkbox(label="Show Full API Response", value=False):
st.json(completion)
ai_content = completion.get("choices")[0].get("message").get("content")
if ai_content:
st.markdown(ai_content)
st.markdown("---")
col1, col2 = st.columns(2)
with col1:
lang_code = lang_selector()
with col2:
is_speech_slow = speech_speed_radio()
show_player(ai_content, lang_code, is_speech_slow)
except OpenAIError as err:
st.error(err)
ai_content = completion.get("choices")[0].get("message").get("content")
st.session_state["messages"].append({"role": "assistant", "content": ai_content})
if ai_content:
show_conversation(ai_content, user_content)
st.markdown("---")
show_player(ai_content)
if __name__ == "__main__":

View File

@ -1,4 +1,5 @@
streamlit==1.19.0
streamlit-chat==0.0.2.1
openai==0.27.0
gtts==2.3.1
pip==23.0.1

View File

@ -1,10 +1,11 @@
from typing import Dict, Tuple
from typing import List, Dict, Tuple
import streamlit as st
import openai
AI_MODEL_OPTIONS = [
"gpt-3.5-turbo",
"gpt-4.0",
]
AI_ROLE_OPTIONS = [
@ -23,19 +24,16 @@ AI_ROLE_OPTIONS = [
def ai_settings() -> Tuple[str, str]:
c1, c2 = st.columns(2)
with c1, c2:
model = c1.selectbox(label="Select AI model", options=AI_MODEL_OPTIONS)
role = c2.selectbox(label="Select AI role", options=AI_ROLE_OPTIONS)
model = c1.selectbox(label="Select AI Model", options=AI_MODEL_OPTIONS)
role = c2.selectbox(label="Select AI Role", options=AI_ROLE_OPTIONS)
return model, role
@st.cache_data()
def send_ai_request(user_text: str, ai_model: str, ai_role: str) -> Dict:
def send_ai_request(ai_model: str, messages: List[Dict]) -> Dict:
openai.api_key = st.secrets.api_credentials.api_key
completion = openai.ChatCompletion.create(
model=ai_model,
messages=[
{"role": "system", "content": f"You are a {ai_role}."},
{"role": "user", "content": user_text},
]
messages=messages,
)
return completion

26
src/utils/conversation.py Normal file
View File

@ -0,0 +1,26 @@
import streamlit as st
from streamlit_chat import message
def clear_chat() -> None:
st.session_state["generated"] = []
st.session_state["past"] = []
st.session_state["messages"] = []
st.session_state["user_text"] = ""
def get_user_input() -> str:
user_text = st.text_area(label="Start Your Conversation With AI:", key="user_text")
return user_text
def show_conversation(ai_content: str, user_text: str) -> None:
if ai_content not in st.session_state.generated:
# store the ai content
st.session_state.past.append(user_text)
st.session_state.generated.append(ai_content)
if st.session_state["generated"]:
for i in range(len(st.session_state["generated"]) - 1, -1, -1):
st.markdown(st.session_state["generated"][i])
# message(st.session_state["generated"][i], key=str(i))
message(st.session_state["past"][i], is_user=True, key=str(i) + "_user", avatar_style="micah")

View File

@ -18,7 +18,7 @@ def lang_selector() -> str:
lang_options = list(lang.tts_langs().values())
default_index = lang_options.index(DEFAULT_SPEECH_LANG)
lang_name = st.selectbox(
label="Select speech language",
label="Select Speech Language",
options=lang_options,
index=default_index
)
@ -31,18 +31,23 @@ def speech_speed_radio() -> bool:
"Slow": True
}
speed_speech = st.radio(
label="Select speech speed",
label="Select Speech Speed",
options=speed_options.keys(),
)
return speed_options.get(speed_speech)
def show_player(ai_content: str, lang_code: str, is_speech_slow: bool) -> None:
def show_player(ai_content: str) -> None:
sound_file = BytesIO()
col1, col2 = st.columns(2)
with col1:
lang_code = lang_selector()
with col2:
is_speech_slow = speech_speed_radio()
try:
tts = gTTS(text=ai_content, lang=lang_code, slow=is_speech_slow)
tts.write_to_fp(sound_file)
st.write("To hear the voice of AI, press the play button.")
st.write("To Hear The Voice Of AI, Press Play.")
st.audio(sound_file)
except gTTSError as err:
st.error(err)