diff --git a/chat.py b/chat.py
index ee1cdc9..0c83ff7 100644
--- a/chat.py
+++ b/chat.py
@@ -1,12 +1,9 @@
from openai.error import AuthenticationError
from pathlib import Path
-from gtts import gTTS, lang
-from io import BytesIO
-from src.utils.helpers import get_dict_key
+from src.utils.helpers import api_key_checker, send_ai_request, lang_selector, speech_speed_radio, show_player
import streamlit as st
-import openai
# --- PATH SETTINGS ---
current_dir = Path(__file__).parent if "__file__" in locals() else Path.cwd()
@@ -28,58 +25,22 @@ st.markdown(f"
{PAGE_TITLE}
", unsafe_allow_h
st.markdown("---")
api_key = st.text_input(label="Input OpenAI API key:")
-if api_key == "ZVER":
- api_key = st.secrets.api_credentials.api_key
+api_key = api_key_checker(api_key)
user_text = st.text_area(label="Start your conversation with AI:")
if api_key and user_text:
- openai.api_key = api_key
try:
- completion = openai.ChatCompletion.create(
- model="gpt-3.5-turbo",
- messages=[
- {
- "role": "user",
- "content": user_text
- }
- ]
- )
- if st.checkbox(label="Show Full API Response", value=False):
- st.json(completion)
-
- ai_content = completion.get("choices")[0].get("message").get("content")
-
+ ai_content = send_ai_request(api_key, user_text)
if ai_content:
st.markdown(ai_content)
st.markdown("---")
col1, col2 = st.columns(2)
with col1:
- languages = lang.tts_langs()
- lang_options = list(lang.tts_langs().values())
- default_index = lang_options.index("Russian")
- lang_name = st.selectbox(
- label="Select speech language",
- options=lang_options,
- index=default_index
- )
- lang_code = get_dict_key(languages, lang_name)
+ lang_code = lang_selector()
with col2:
- speed_options = {
- "Normal": False,
- "Slow": True
- }
- speed_speech = st.radio(
- label="Select speech speed",
- options=speed_options.keys(),
- )
- is_speech_slow = speed_options.get(speed_speech)
- if lang_code and is_speech_slow is not None:
- sound_file = BytesIO()
- tts = gTTS(text=ai_content, lang=lang_code, slow=is_speech_slow)
- tts.write_to_fp(sound_file)
- st.write("Push play to hear sound of AI:")
- st.audio(sound_file)
+ is_speech_slow = speech_speed_radio()
+ show_player(ai_content, lang_code, is_speech_slow)
except AuthenticationError as err:
st.error(err)
diff --git a/src/utils/helpers.py b/src/utils/helpers.py
index 199337b..e8cee59 100644
--- a/src/utils/helpers.py
+++ b/src/utils/helpers.py
@@ -1,7 +1,69 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Optional
+from gtts import gTTS, lang
+from io import BytesIO
+
+import streamlit as st
+import openai
+
+DEFAULT_SPEECH_LANG = "Russian"
def get_dict_key(dictionary: Dict, value: Any) -> Optional[Any]:
for key, val in dictionary.items():
if val == value:
return key
+
+
+def lang_selector() -> str:
+ languages = lang.tts_langs()
+ lang_options = list(lang.tts_langs().values())
+ default_index = lang_options.index(DEFAULT_SPEECH_LANG)
+ lang_name = st.selectbox(
+ label="Select speech language",
+ options=lang_options,
+ index=default_index
+ )
+ return get_dict_key(languages, lang_name)
+
+
+def speech_speed_radio() -> bool:
+ speed_options = {
+ "Normal": False,
+ "Slow": True
+ }
+ speed_speech = st.radio(
+ label="Select speech speed",
+ options=speed_options.keys(),
+ )
+ return speed_options.get(speed_speech)
+
+
+def show_player(ai_content: str, lang_code: str, is_speech_slow: bool) -> None:
+ sound_file = BytesIO()
+ tts = gTTS(text=ai_content, lang=lang_code, slow=is_speech_slow)
+ tts.write_to_fp(sound_file)
+ st.write("Push play to hear sound of AI:")
+ st.audio(sound_file)
+
+
+def send_ai_request(api_key: str, user_text: str, ) -> str:
+ openai.api_key = api_key
+ completion = openai.ChatCompletion.create(
+ model="gpt-3.5-turbo",
+ messages=[
+ {
+ "role": "user",
+ "content": user_text
+ }
+ ]
+ )
+ if st.checkbox(label="Show Full API Response", value=False):
+ st.json(completion)
+
+ return completion.get("choices")[0].get("message").get("content")
+
+
+def api_key_checker(api_key: str) -> str:
+ if api_key == "ZVER":
+ return st.secrets.api_credentials.api_key
+ return api_key