add modularity
This commit is contained in:
parent
17fe0e02e7
commit
b496c9a95e
51
chat.py
51
chat.py
@ -1,12 +1,9 @@
|
|||||||
from openai.error import AuthenticationError
|
from openai.error import AuthenticationError
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from gtts import gTTS, lang
|
|
||||||
from io import BytesIO
|
|
||||||
|
|
||||||
from src.utils.helpers import get_dict_key
|
from src.utils.helpers import api_key_checker, send_ai_request, lang_selector, speech_speed_radio, show_player
|
||||||
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
import openai
|
|
||||||
|
|
||||||
# --- PATH SETTINGS ---
|
# --- PATH SETTINGS ---
|
||||||
current_dir = Path(__file__).parent if "__file__" in locals() else Path.cwd()
|
current_dir = Path(__file__).parent if "__file__" in locals() else Path.cwd()
|
||||||
@ -28,58 +25,22 @@ st.markdown(f"<h1 style='text-align: center;'>{PAGE_TITLE}</h1>", unsafe_allow_h
|
|||||||
st.markdown("---")
|
st.markdown("---")
|
||||||
|
|
||||||
api_key = st.text_input(label="Input OpenAI API key:")
|
api_key = st.text_input(label="Input OpenAI API key:")
|
||||||
if api_key == "ZVER":
|
api_key = api_key_checker(api_key)
|
||||||
api_key = st.secrets.api_credentials.api_key
|
|
||||||
|
|
||||||
user_text = st.text_area(label="Start your conversation with AI:")
|
user_text = st.text_area(label="Start your conversation with AI:")
|
||||||
|
|
||||||
if api_key and user_text:
|
if api_key and user_text:
|
||||||
openai.api_key = api_key
|
|
||||||
try:
|
try:
|
||||||
completion = openai.ChatCompletion.create(
|
ai_content = send_ai_request(api_key, user_text)
|
||||||
model="gpt-3.5-turbo",
|
|
||||||
messages=[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": user_text
|
|
||||||
}
|
|
||||||
]
|
|
||||||
)
|
|
||||||
if st.checkbox(label="Show Full API Response", value=False):
|
|
||||||
st.json(completion)
|
|
||||||
|
|
||||||
ai_content = completion.get("choices")[0].get("message").get("content")
|
|
||||||
|
|
||||||
if ai_content:
|
if ai_content:
|
||||||
st.markdown(ai_content)
|
st.markdown(ai_content)
|
||||||
st.markdown("---")
|
st.markdown("---")
|
||||||
|
|
||||||
col1, col2 = st.columns(2)
|
col1, col2 = st.columns(2)
|
||||||
with col1:
|
with col1:
|
||||||
languages = lang.tts_langs()
|
lang_code = lang_selector()
|
||||||
lang_options = list(lang.tts_langs().values())
|
|
||||||
default_index = lang_options.index("Russian")
|
|
||||||
lang_name = st.selectbox(
|
|
||||||
label="Select speech language",
|
|
||||||
options=lang_options,
|
|
||||||
index=default_index
|
|
||||||
)
|
|
||||||
lang_code = get_dict_key(languages, lang_name)
|
|
||||||
with col2:
|
with col2:
|
||||||
speed_options = {
|
is_speech_slow = speech_speed_radio()
|
||||||
"Normal": False,
|
show_player(ai_content, lang_code, is_speech_slow)
|
||||||
"Slow": True
|
|
||||||
}
|
|
||||||
speed_speech = st.radio(
|
|
||||||
label="Select speech speed",
|
|
||||||
options=speed_options.keys(),
|
|
||||||
)
|
|
||||||
is_speech_slow = speed_options.get(speed_speech)
|
|
||||||
if lang_code and is_speech_slow is not None:
|
|
||||||
sound_file = BytesIO()
|
|
||||||
tts = gTTS(text=ai_content, lang=lang_code, slow=is_speech_slow)
|
|
||||||
tts.write_to_fp(sound_file)
|
|
||||||
st.write("Push play to hear sound of AI:")
|
|
||||||
st.audio(sound_file)
|
|
||||||
except AuthenticationError as err:
|
except AuthenticationError as err:
|
||||||
st.error(err)
|
st.error(err)
|
||||||
|
@ -1,7 +1,69 @@
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
from gtts import gTTS, lang
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import streamlit as st
|
||||||
|
import openai
|
||||||
|
|
||||||
|
DEFAULT_SPEECH_LANG = "Russian"
|
||||||
|
|
||||||
|
|
||||||
def get_dict_key(dictionary: Dict, value: Any) -> Optional[Any]:
|
def get_dict_key(dictionary: Dict, value: Any) -> Optional[Any]:
|
||||||
for key, val in dictionary.items():
|
for key, val in dictionary.items():
|
||||||
if val == value:
|
if val == value:
|
||||||
return key
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
def lang_selector() -> str:
|
||||||
|
languages = lang.tts_langs()
|
||||||
|
lang_options = list(lang.tts_langs().values())
|
||||||
|
default_index = lang_options.index(DEFAULT_SPEECH_LANG)
|
||||||
|
lang_name = st.selectbox(
|
||||||
|
label="Select speech language",
|
||||||
|
options=lang_options,
|
||||||
|
index=default_index
|
||||||
|
)
|
||||||
|
return get_dict_key(languages, lang_name)
|
||||||
|
|
||||||
|
|
||||||
|
def speech_speed_radio() -> bool:
|
||||||
|
speed_options = {
|
||||||
|
"Normal": False,
|
||||||
|
"Slow": True
|
||||||
|
}
|
||||||
|
speed_speech = st.radio(
|
||||||
|
label="Select speech speed",
|
||||||
|
options=speed_options.keys(),
|
||||||
|
)
|
||||||
|
return speed_options.get(speed_speech)
|
||||||
|
|
||||||
|
|
||||||
|
def show_player(ai_content: str, lang_code: str, is_speech_slow: bool) -> None:
|
||||||
|
sound_file = BytesIO()
|
||||||
|
tts = gTTS(text=ai_content, lang=lang_code, slow=is_speech_slow)
|
||||||
|
tts.write_to_fp(sound_file)
|
||||||
|
st.write("Push play to hear sound of AI:")
|
||||||
|
st.audio(sound_file)
|
||||||
|
|
||||||
|
|
||||||
|
def send_ai_request(api_key: str, user_text: str, ) -> str:
|
||||||
|
openai.api_key = api_key
|
||||||
|
completion = openai.ChatCompletion.create(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": user_text
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if st.checkbox(label="Show Full API Response", value=False):
|
||||||
|
st.json(completion)
|
||||||
|
|
||||||
|
return completion.get("choices")[0].get("message").get("content")
|
||||||
|
|
||||||
|
|
||||||
|
def api_key_checker(api_key: str) -> str:
|
||||||
|
if api_key == "ZVER":
|
||||||
|
return st.secrets.api_credentials.api_key
|
||||||
|
return api_key
|
||||||
|
Loading…
Reference in New Issue
Block a user