diff --git a/chat.py b/chat.py index ee1cdc9..0c83ff7 100644 --- a/chat.py +++ b/chat.py @@ -1,12 +1,9 @@ from openai.error import AuthenticationError from pathlib import Path -from gtts import gTTS, lang -from io import BytesIO -from src.utils.helpers import get_dict_key +from src.utils.helpers import api_key_checker, send_ai_request, lang_selector, speech_speed_radio, show_player import streamlit as st -import openai # --- PATH SETTINGS --- current_dir = Path(__file__).parent if "__file__" in locals() else Path.cwd() @@ -28,58 +25,22 @@ st.markdown(f"

{PAGE_TITLE}

", unsafe_allow_h st.markdown("---") api_key = st.text_input(label="Input OpenAI API key:") -if api_key == "ZVER": - api_key = st.secrets.api_credentials.api_key +api_key = api_key_checker(api_key) user_text = st.text_area(label="Start your conversation with AI:") if api_key and user_text: - openai.api_key = api_key try: - completion = openai.ChatCompletion.create( - model="gpt-3.5-turbo", - messages=[ - { - "role": "user", - "content": user_text - } - ] - ) - if st.checkbox(label="Show Full API Response", value=False): - st.json(completion) - - ai_content = completion.get("choices")[0].get("message").get("content") - + ai_content = send_ai_request(api_key, user_text) if ai_content: st.markdown(ai_content) st.markdown("---") col1, col2 = st.columns(2) with col1: - languages = lang.tts_langs() - lang_options = list(lang.tts_langs().values()) - default_index = lang_options.index("Russian") - lang_name = st.selectbox( - label="Select speech language", - options=lang_options, - index=default_index - ) - lang_code = get_dict_key(languages, lang_name) + lang_code = lang_selector() with col2: - speed_options = { - "Normal": False, - "Slow": True - } - speed_speech = st.radio( - label="Select speech speed", - options=speed_options.keys(), - ) - is_speech_slow = speed_options.get(speed_speech) - if lang_code and is_speech_slow is not None: - sound_file = BytesIO() - tts = gTTS(text=ai_content, lang=lang_code, slow=is_speech_slow) - tts.write_to_fp(sound_file) - st.write("Push play to hear sound of AI:") - st.audio(sound_file) + is_speech_slow = speech_speed_radio() + show_player(ai_content, lang_code, is_speech_slow) except AuthenticationError as err: st.error(err) diff --git a/src/utils/helpers.py b/src/utils/helpers.py index 199337b..e8cee59 100644 --- a/src/utils/helpers.py +++ b/src/utils/helpers.py @@ -1,7 +1,69 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional +from gtts import gTTS, lang +from io import BytesIO + +import streamlit as st +import openai + +DEFAULT_SPEECH_LANG = "Russian" def get_dict_key(dictionary: Dict, value: Any) -> Optional[Any]: for key, val in dictionary.items(): if val == value: return key + + +def lang_selector() -> str: + languages = lang.tts_langs() + lang_options = list(lang.tts_langs().values()) + default_index = lang_options.index(DEFAULT_SPEECH_LANG) + lang_name = st.selectbox( + label="Select speech language", + options=lang_options, + index=default_index + ) + return get_dict_key(languages, lang_name) + + +def speech_speed_radio() -> bool: + speed_options = { + "Normal": False, + "Slow": True + } + speed_speech = st.radio( + label="Select speech speed", + options=speed_options.keys(), + ) + return speed_options.get(speed_speech) + + +def show_player(ai_content: str, lang_code: str, is_speech_slow: bool) -> None: + sound_file = BytesIO() + tts = gTTS(text=ai_content, lang=lang_code, slow=is_speech_slow) + tts.write_to_fp(sound_file) + st.write("Push play to hear sound of AI:") + st.audio(sound_file) + + +def send_ai_request(api_key: str, user_text: str, ) -> str: + openai.api_key = api_key + completion = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": user_text + } + ] + ) + if st.checkbox(label="Show Full API Response", value=False): + st.json(completion) + + return completion.get("choices")[0].get("message").get("content") + + +def api_key_checker(api_key: str) -> str: + if api_key == "ZVER": + return st.secrets.api_credentials.api_key + return api_key