add chat mode
This commit is contained in:
parent
3f88e79b20
commit
5e18ab9104
49
chat.py
49
chat.py
@ -2,7 +2,8 @@ from openai.error import OpenAIError
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from src.utils.ai import ai_settings, send_ai_request
|
from src.utils.ai import ai_settings, send_ai_request
|
||||||
from src.utils.tts import lang_selector, speech_speed_radio, show_player
|
from src.utils.tts import show_player
|
||||||
|
from src.utils.conversation import get_user_input, clear_chat, show_conversation
|
||||||
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
@ -25,32 +26,46 @@ with open(css_file) as f:
|
|||||||
st.markdown(f"<h1 style='text-align: center;'>{PAGE_TITLE}</h1>", unsafe_allow_html=True)
|
st.markdown(f"<h1 style='text-align: center;'>{PAGE_TITLE}</h1>", unsafe_allow_html=True)
|
||||||
st.markdown("---")
|
st.markdown("---")
|
||||||
|
|
||||||
|
# Storing The Context
|
||||||
|
if "generated" not in st.session_state:
|
||||||
|
st.session_state["generated"] = []
|
||||||
|
if "past" not in st.session_state:
|
||||||
|
st.session_state["past"] = []
|
||||||
|
if "messages" not in st.session_state:
|
||||||
|
st.session_state["messages"] = []
|
||||||
|
if "user_text" not in st.session_state:
|
||||||
|
st.session_state["user_text"] = ""
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
user_text = st.text_area(label="Start your conversation with AI:")
|
user_content = get_user_input()
|
||||||
if st.button("Rerun"):
|
b1, b2 = st.columns(2)
|
||||||
st.cache_data.clear()
|
with b1, b2:
|
||||||
|
b1.button("Rerun", on_click=st.cache_data.clear)
|
||||||
|
b2.button("Clear Conversation", on_click=clear_chat)
|
||||||
|
|
||||||
model, role = ai_settings()
|
model, role = ai_settings()
|
||||||
|
|
||||||
if user_text:
|
if user_content:
|
||||||
|
if st.session_state["messages"]:
|
||||||
|
st.session_state["messages"].append({"role": "user", "content": user_content})
|
||||||
|
else:
|
||||||
|
st.session_state["messages"] = [
|
||||||
|
{"role": "system", "content": f"You are a {role}."},
|
||||||
|
{"role": "user", "content": user_content},
|
||||||
|
]
|
||||||
try:
|
try:
|
||||||
completion = send_ai_request(user_text, model, role)
|
completion = send_ai_request(model, st.session_state["messages"])
|
||||||
if st.checkbox(label="Show Full API Response", value=False):
|
if st.checkbox(label="Show Full API Response", value=False):
|
||||||
st.json(completion)
|
st.json(completion)
|
||||||
ai_content = completion.get("choices")[0].get("message").get("content")
|
|
||||||
if ai_content:
|
|
||||||
st.markdown(ai_content)
|
|
||||||
st.markdown("---")
|
|
||||||
|
|
||||||
col1, col2 = st.columns(2)
|
|
||||||
with col1:
|
|
||||||
lang_code = lang_selector()
|
|
||||||
with col2:
|
|
||||||
is_speech_slow = speech_speed_radio()
|
|
||||||
show_player(ai_content, lang_code, is_speech_slow)
|
|
||||||
except OpenAIError as err:
|
except OpenAIError as err:
|
||||||
st.error(err)
|
st.error(err)
|
||||||
|
ai_content = completion.get("choices")[0].get("message").get("content")
|
||||||
|
st.session_state["messages"].append({"role": "assistant", "content": ai_content})
|
||||||
|
if ai_content:
|
||||||
|
show_conversation(ai_content, user_content)
|
||||||
|
st.markdown("---")
|
||||||
|
show_player(ai_content)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
streamlit==1.19.0
|
streamlit==1.19.0
|
||||||
|
streamlit-chat==0.0.2.1
|
||||||
openai==0.27.0
|
openai==0.27.0
|
||||||
gtts==2.3.1
|
gtts==2.3.1
|
||||||
pip==23.0.1
|
pip==23.0.1
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
from typing import Dict, Tuple
|
from typing import List, Dict, Tuple
|
||||||
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
AI_MODEL_OPTIONS = [
|
AI_MODEL_OPTIONS = [
|
||||||
"gpt-3.5-turbo",
|
"gpt-3.5-turbo",
|
||||||
|
"gpt-4.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
AI_ROLE_OPTIONS = [
|
AI_ROLE_OPTIONS = [
|
||||||
@ -23,19 +24,16 @@ AI_ROLE_OPTIONS = [
|
|||||||
def ai_settings() -> Tuple[str, str]:
|
def ai_settings() -> Tuple[str, str]:
|
||||||
c1, c2 = st.columns(2)
|
c1, c2 = st.columns(2)
|
||||||
with c1, c2:
|
with c1, c2:
|
||||||
model = c1.selectbox(label="Select AI model", options=AI_MODEL_OPTIONS)
|
model = c1.selectbox(label="Select AI Model", options=AI_MODEL_OPTIONS)
|
||||||
role = c2.selectbox(label="Select AI role", options=AI_ROLE_OPTIONS)
|
role = c2.selectbox(label="Select AI Role", options=AI_ROLE_OPTIONS)
|
||||||
return model, role
|
return model, role
|
||||||
|
|
||||||
|
|
||||||
@st.cache_data()
|
@st.cache_data()
|
||||||
def send_ai_request(user_text: str, ai_model: str, ai_role: str) -> Dict:
|
def send_ai_request(ai_model: str, messages: List[Dict]) -> Dict:
|
||||||
openai.api_key = st.secrets.api_credentials.api_key
|
openai.api_key = st.secrets.api_credentials.api_key
|
||||||
completion = openai.ChatCompletion.create(
|
completion = openai.ChatCompletion.create(
|
||||||
model=ai_model,
|
model=ai_model,
|
||||||
messages=[
|
messages=messages,
|
||||||
{"role": "system", "content": f"You are a {ai_role}."},
|
|
||||||
{"role": "user", "content": user_text},
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
return completion
|
return completion
|
||||||
|
26
src/utils/conversation.py
Normal file
26
src/utils/conversation.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import streamlit as st
|
||||||
|
from streamlit_chat import message
|
||||||
|
|
||||||
|
|
||||||
|
def clear_chat() -> None:
|
||||||
|
st.session_state["generated"] = []
|
||||||
|
st.session_state["past"] = []
|
||||||
|
st.session_state["messages"] = []
|
||||||
|
st.session_state["user_text"] = ""
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_input() -> str:
|
||||||
|
user_text = st.text_area(label="Start Your Conversation With AI:", key="user_text")
|
||||||
|
return user_text
|
||||||
|
|
||||||
|
|
||||||
|
def show_conversation(ai_content: str, user_text: str) -> None:
|
||||||
|
if ai_content not in st.session_state.generated:
|
||||||
|
# store the ai content
|
||||||
|
st.session_state.past.append(user_text)
|
||||||
|
st.session_state.generated.append(ai_content)
|
||||||
|
if st.session_state["generated"]:
|
||||||
|
for i in range(len(st.session_state["generated"]) - 1, -1, -1):
|
||||||
|
st.markdown(st.session_state["generated"][i])
|
||||||
|
# message(st.session_state["generated"][i], key=str(i))
|
||||||
|
message(st.session_state["past"][i], is_user=True, key=str(i) + "_user", avatar_style="micah")
|
@ -18,7 +18,7 @@ def lang_selector() -> str:
|
|||||||
lang_options = list(lang.tts_langs().values())
|
lang_options = list(lang.tts_langs().values())
|
||||||
default_index = lang_options.index(DEFAULT_SPEECH_LANG)
|
default_index = lang_options.index(DEFAULT_SPEECH_LANG)
|
||||||
lang_name = st.selectbox(
|
lang_name = st.selectbox(
|
||||||
label="Select speech language",
|
label="Select Speech Language",
|
||||||
options=lang_options,
|
options=lang_options,
|
||||||
index=default_index
|
index=default_index
|
||||||
)
|
)
|
||||||
@ -31,18 +31,23 @@ def speech_speed_radio() -> bool:
|
|||||||
"Slow": True
|
"Slow": True
|
||||||
}
|
}
|
||||||
speed_speech = st.radio(
|
speed_speech = st.radio(
|
||||||
label="Select speech speed",
|
label="Select Speech Speed",
|
||||||
options=speed_options.keys(),
|
options=speed_options.keys(),
|
||||||
)
|
)
|
||||||
return speed_options.get(speed_speech)
|
return speed_options.get(speed_speech)
|
||||||
|
|
||||||
|
|
||||||
def show_player(ai_content: str, lang_code: str, is_speech_slow: bool) -> None:
|
def show_player(ai_content: str) -> None:
|
||||||
sound_file = BytesIO()
|
sound_file = BytesIO()
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
|
with col1:
|
||||||
|
lang_code = lang_selector()
|
||||||
|
with col2:
|
||||||
|
is_speech_slow = speech_speed_radio()
|
||||||
try:
|
try:
|
||||||
tts = gTTS(text=ai_content, lang=lang_code, slow=is_speech_slow)
|
tts = gTTS(text=ai_content, lang=lang_code, slow=is_speech_slow)
|
||||||
tts.write_to_fp(sound_file)
|
tts.write_to_fp(sound_file)
|
||||||
st.write("To hear the voice of AI, press the play button.")
|
st.write("To Hear The Voice Of AI, Press Play.")
|
||||||
st.audio(sound_file)
|
st.audio(sound_file)
|
||||||
except gTTSError as err:
|
except gTTSError as err:
|
||||||
st.error(err)
|
st.error(err)
|
||||||
|
Loading…
Reference in New Issue
Block a user