add chat mode
This commit is contained in:
parent
3f88e79b20
commit
5e18ab9104
49
chat.py
49
chat.py
@ -2,7 +2,8 @@ from openai.error import OpenAIError
|
||||
from pathlib import Path
|
||||
|
||||
from src.utils.ai import ai_settings, send_ai_request
|
||||
from src.utils.tts import lang_selector, speech_speed_radio, show_player
|
||||
from src.utils.tts import show_player
|
||||
from src.utils.conversation import get_user_input, clear_chat, show_conversation
|
||||
|
||||
import streamlit as st
|
||||
|
||||
@ -25,32 +26,46 @@ with open(css_file) as f:
|
||||
st.markdown(f"<h1 style='text-align: center;'>{PAGE_TITLE}</h1>", unsafe_allow_html=True)
|
||||
st.markdown("---")
|
||||
|
||||
# Storing The Context
|
||||
if "generated" not in st.session_state:
|
||||
st.session_state["generated"] = []
|
||||
if "past" not in st.session_state:
|
||||
st.session_state["past"] = []
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state["messages"] = []
|
||||
if "user_text" not in st.session_state:
|
||||
st.session_state["user_text"] = ""
|
||||
|
||||
|
||||
def main() -> None:
|
||||
user_text = st.text_area(label="Start your conversation with AI:")
|
||||
if st.button("Rerun"):
|
||||
st.cache_data.clear()
|
||||
user_content = get_user_input()
|
||||
b1, b2 = st.columns(2)
|
||||
with b1, b2:
|
||||
b1.button("Rerun", on_click=st.cache_data.clear)
|
||||
b2.button("Clear Conversation", on_click=clear_chat)
|
||||
|
||||
model, role = ai_settings()
|
||||
|
||||
if user_text:
|
||||
if user_content:
|
||||
if st.session_state["messages"]:
|
||||
st.session_state["messages"].append({"role": "user", "content": user_content})
|
||||
else:
|
||||
st.session_state["messages"] = [
|
||||
{"role": "system", "content": f"You are a {role}."},
|
||||
{"role": "user", "content": user_content},
|
||||
]
|
||||
try:
|
||||
completion = send_ai_request(user_text, model, role)
|
||||
completion = send_ai_request(model, st.session_state["messages"])
|
||||
if st.checkbox(label="Show Full API Response", value=False):
|
||||
st.json(completion)
|
||||
ai_content = completion.get("choices")[0].get("message").get("content")
|
||||
if ai_content:
|
||||
st.markdown(ai_content)
|
||||
st.markdown("---")
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
lang_code = lang_selector()
|
||||
with col2:
|
||||
is_speech_slow = speech_speed_radio()
|
||||
show_player(ai_content, lang_code, is_speech_slow)
|
||||
except OpenAIError as err:
|
||||
st.error(err)
|
||||
ai_content = completion.get("choices")[0].get("message").get("content")
|
||||
st.session_state["messages"].append({"role": "assistant", "content": ai_content})
|
||||
if ai_content:
|
||||
show_conversation(ai_content, user_content)
|
||||
st.markdown("---")
|
||||
show_player(ai_content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,4 +1,5 @@
|
||||
streamlit==1.19.0
|
||||
streamlit-chat==0.0.2.1
|
||||
openai==0.27.0
|
||||
gtts==2.3.1
|
||||
pip==23.0.1
|
||||
|
@ -1,10 +1,11 @@
|
||||
from typing import Dict, Tuple
|
||||
from typing import List, Dict, Tuple
|
||||
|
||||
import streamlit as st
|
||||
import openai
|
||||
|
||||
AI_MODEL_OPTIONS = [
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-4.0",
|
||||
]
|
||||
|
||||
AI_ROLE_OPTIONS = [
|
||||
@ -23,19 +24,16 @@ AI_ROLE_OPTIONS = [
|
||||
def ai_settings() -> Tuple[str, str]:
|
||||
c1, c2 = st.columns(2)
|
||||
with c1, c2:
|
||||
model = c1.selectbox(label="Select AI model", options=AI_MODEL_OPTIONS)
|
||||
role = c2.selectbox(label="Select AI role", options=AI_ROLE_OPTIONS)
|
||||
model = c1.selectbox(label="Select AI Model", options=AI_MODEL_OPTIONS)
|
||||
role = c2.selectbox(label="Select AI Role", options=AI_ROLE_OPTIONS)
|
||||
return model, role
|
||||
|
||||
|
||||
@st.cache_data()
|
||||
def send_ai_request(user_text: str, ai_model: str, ai_role: str) -> Dict:
|
||||
def send_ai_request(ai_model: str, messages: List[Dict]) -> Dict:
|
||||
openai.api_key = st.secrets.api_credentials.api_key
|
||||
completion = openai.ChatCompletion.create(
|
||||
model=ai_model,
|
||||
messages=[
|
||||
{"role": "system", "content": f"You are a {ai_role}."},
|
||||
{"role": "user", "content": user_text},
|
||||
]
|
||||
messages=messages,
|
||||
)
|
||||
return completion
|
||||
|
26
src/utils/conversation.py
Normal file
26
src/utils/conversation.py
Normal file
@ -0,0 +1,26 @@
|
||||
import streamlit as st
|
||||
from streamlit_chat import message
|
||||
|
||||
|
||||
def clear_chat() -> None:
|
||||
st.session_state["generated"] = []
|
||||
st.session_state["past"] = []
|
||||
st.session_state["messages"] = []
|
||||
st.session_state["user_text"] = ""
|
||||
|
||||
|
||||
def get_user_input() -> str:
|
||||
user_text = st.text_area(label="Start Your Conversation With AI:", key="user_text")
|
||||
return user_text
|
||||
|
||||
|
||||
def show_conversation(ai_content: str, user_text: str) -> None:
|
||||
if ai_content not in st.session_state.generated:
|
||||
# store the ai content
|
||||
st.session_state.past.append(user_text)
|
||||
st.session_state.generated.append(ai_content)
|
||||
if st.session_state["generated"]:
|
||||
for i in range(len(st.session_state["generated"]) - 1, -1, -1):
|
||||
st.markdown(st.session_state["generated"][i])
|
||||
# message(st.session_state["generated"][i], key=str(i))
|
||||
message(st.session_state["past"][i], is_user=True, key=str(i) + "_user", avatar_style="micah")
|
@ -18,7 +18,7 @@ def lang_selector() -> str:
|
||||
lang_options = list(lang.tts_langs().values())
|
||||
default_index = lang_options.index(DEFAULT_SPEECH_LANG)
|
||||
lang_name = st.selectbox(
|
||||
label="Select speech language",
|
||||
label="Select Speech Language",
|
||||
options=lang_options,
|
||||
index=default_index
|
||||
)
|
||||
@ -31,18 +31,23 @@ def speech_speed_radio() -> bool:
|
||||
"Slow": True
|
||||
}
|
||||
speed_speech = st.radio(
|
||||
label="Select speech speed",
|
||||
label="Select Speech Speed",
|
||||
options=speed_options.keys(),
|
||||
)
|
||||
return speed_options.get(speed_speech)
|
||||
|
||||
|
||||
def show_player(ai_content: str, lang_code: str, is_speech_slow: bool) -> None:
|
||||
def show_player(ai_content: str) -> None:
|
||||
sound_file = BytesIO()
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
lang_code = lang_selector()
|
||||
with col2:
|
||||
is_speech_slow = speech_speed_radio()
|
||||
try:
|
||||
tts = gTTS(text=ai_content, lang=lang_code, slow=is_speech_slow)
|
||||
tts.write_to_fp(sound_file)
|
||||
st.write("To hear the voice of AI, press the play button.")
|
||||
st.write("To Hear The Voice Of AI, Press Play.")
|
||||
st.audio(sound_file)
|
||||
except gTTSError as err:
|
||||
st.error(err)
|
||||
|
Loading…
Reference in New Issue
Block a user