2023-03-07 07:15:58 +01:00
|
|
|
from openai.error import OpenAIError
|
2023-03-02 15:32:39 +01:00
|
|
|
from pathlib import Path
|
|
|
|
|
2023-03-06 14:27:43 +01:00
|
|
|
from src.utils.ai import ai_settings, send_ai_request
|
2023-03-11 15:39:23 +01:00
|
|
|
from src.utils.tts import show_player
|
|
|
|
from src.utils.conversation import get_user_input, clear_chat, show_conversation
|
2023-03-02 18:39:03 +01:00
|
|
|
|
2023-03-02 15:32:39 +01:00
|
|
|
import streamlit as st
|
|
|
|
|
|
|
|
# --- PATH SETTINGS ---
# Base directory for project files: the directory containing this script,
# or the current working directory when __file__ is not defined.
if "__file__" in globals():
    current_dir = Path(__file__).parent
else:
    current_dir = Path.cwd()

# NOTE(review): the stylesheet name is literally ".css" (no stem) — confirm
# this is the intended file name.
css_file = current_dir / "src/styles/.css"
assets_dir = current_dir / "src/assets"
icons_dir = assets_dir / "icons"
|
|
|
|
|
|
|
|
# --- GENERAL SETTINGS ---
# Title and icon passed to Streamlit's page configuration; the title is also
# reused for the centered on-page heading.
PAGE_TITLE = "AI Talks"
PAGE_ICON = "🤖"

# NOTE: Streamlit documents set_page_config as needing to run before any other
# st.* call that renders output, so keep this near the top of the script.
st.set_page_config(page_icon=PAGE_ICON, page_title=PAGE_TITLE)
|
|
|
|
|
|
|
|
# --- LOAD CSS ---
# Inject the custom stylesheet into the page. The file is decoded explicitly
# as UTF-8 so the result does not depend on the platform's default locale
# encoding (a bare open() uses e.g. cp1252 on Windows).
css_text = css_file.read_text(encoding="utf-8")
st.markdown(f"<style>{css_text}</style>", unsafe_allow_html=True)

# Centered page heading followed by a horizontal rule.
st.markdown(f"<h1 style='text-align: center;'>{PAGE_TITLE}</h1>", unsafe_allow_html=True)
st.markdown("---")
|
|
|
|
|
2023-03-11 15:39:23 +01:00
|
|
|
# Storing The Context
# Create each per-session key with its initial value only when it is missing,
# so values already stored in this session survive Streamlit reruns.
for _state_key, _initial in (
    ("generated", []),
    ("past", []),
    ("messages", []),
    ("user_text", ""),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial
|
|
|
|
|
2023-03-02 19:32:09 +01:00
|
|
|
|
|
|
|
def main() -> None:
    """Render the chat controls and, if the user sent text, one AI exchange.

    On each Streamlit run:
      1. Read the user's input and draw the "Rerun" / "Clear Conversation"
         buttons.
      2. Read the selected model and role from ``ai_settings()``.
      3. If there is input, add it to ``st.session_state["messages"]``
         (seeding the history with a system prompt built from the role on the
         first turn), send the history to ``send_ai_request``, store the
         assistant reply, and, when the reply is non-empty, display it with
         ``show_conversation`` and ``show_player``.

    ``OpenAIError`` and ``UnboundLocalError`` raised while requesting or
    rendering the reply are shown on the page with ``st.error`` instead of
    propagating.
    """
    user_content = get_user_input()

    rerun_col, clear_col = st.columns(2)
    with rerun_col, clear_col:
        # "Rerun" clears st.cache_data; "Clear Conversation" delegates to clear_chat.
        rerun_col.button("Rerun", on_click=st.cache_data.clear)
        clear_col.button("Clear Conversation", on_click=clear_chat)

    model, role = ai_settings()

    # Nothing submitted on this run: only the controls are rendered.
    if not user_content:
        return

    if not st.session_state["messages"]:
        # First turn: start the history with the role-based system prompt.
        st.session_state["messages"] = [{"role": "system", "content": f"You are a {role}."}]
    st.session_state["messages"].append({"role": "user", "content": user_content})

    try:
        completion = send_ai_request(model, st.session_state["messages"])
        reply_message = completion.get("choices")[0].get("message")
        ai_content = reply_message.get("content")
        # The reply is stored even when empty; it is only displayed if non-empty.
        st.session_state["messages"].append({"role": "assistant", "content": ai_content})

        if not ai_content:
            return
        show_conversation(ai_content, user_content)
        st.markdown("---")
        show_player(ai_content)
    except (OpenAIError, UnboundLocalError) as err:
        st.error(err)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    main()
    # Resolve the banner image against the script's directory (the same base
    # css_file is built from) so it is found regardless of the working
    # directory Streamlit was launched from. st.image receives a str path.
    st.image(str(current_dir / "assets/ai.jpg"))
|