From 828578d6bf7a4923f35248ad11c913c3b3ba31b9 Mon Sep 17 00:00:00 2001
From: s444417
Date: Mon, 5 Jun 2023 20:54:15 +0200
Subject: [PATCH] fix chat display

---
 ai_talks/chat.py                   | 22 +++++----
 ai_talks/src/utils/agi/chat_gpt.py | 19 --------
 ai_talks/src/utils/conversation.py | 75 +++++++++++++++++++-----------
 3 files changed, 59 insertions(+), 57 deletions(-)
 delete mode 100644 ai_talks/src/utils/agi/chat_gpt.py

diff --git a/ai_talks/chat.py b/ai_talks/chat.py
index cc4aea7..75b90f6 100644
--- a/ai_talks/chat.py
+++ b/ai_talks/chat.py
@@ -42,7 +42,7 @@ if __name__ == '__main__':
 if "locale" not in st.session_state:
     st.session_state.locale = en
 if "generated" not in st.session_state:
-    st.session_state.generated = []
+    st.session_state.generated = ["Hello! I'm AMUseBot, a virtual cooking assistant. Please tell me the name of the dish that you'd like to prepare today."]
 if "past" not in st.session_state:
     st.session_state.past = []
 if "messages" not in st.session_state:
@@ -70,16 +70,18 @@ def show_graph():
     if st.session_state.generated:
         user, chatbot = [], []
         graph = graphviz.Digraph()
-        for i in range(len(st.session_state.generated)):
-            user.append(st.session_state.past[i])
+        for i in range(len(st.session_state.past)):
             chatbot.append(st.session_state.generated[i])
+            user.append(st.session_state.past[i])
         for x in range(len(user)):
-            graph.edge(st.session_state.past[x], st.session_state.generated[x])
+            chatbot_text = [word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.generated[x].split(' '))]
+            user_text = [word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.past[x].split(' '))]
+            graph.edge(' '.join(chatbot_text), ' '.join(user_text))
             try:
-                graph.edge(st.session_state.generated[x], st.session_state.past[x+1])
+                graph.edge(' '.join(user_text), ' '.join([word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.generated[x + 1].split(' '))]))
             except:
                 pass
-        st.graphviz_chart(graph)
+    st.graphviz_chart(graph)


 def main() -> None:
@@ -101,12 +103,12 @@ def main() -> None:
     elif role_kind == st.session_state.locale.radio_text2:
         c2.text_input(label=st.session_state.locale.select_placeholder3, key="role")

-    if st.session_state.user_text:
-        show_graph()
-        show_conversation()
-
     get_user_input()
     show_chat_buttons()
+
+    show_conversation()
+    with st.sidebar:
+        show_graph()


 if __name__ == "__main__":
diff --git a/ai_talks/src/utils/agi/chat_gpt.py b/ai_talks/src/utils/agi/chat_gpt.py
deleted file mode 100644
index 0b8e56e..0000000
--- a/ai_talks/src/utils/agi/chat_gpt.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import logging
-from typing import List  # NOQA: UP035
-
-import openai
-import streamlit as st
-
-
-@st.cache_data()
-def create_gpt_completion(ai_model: str, messages: List[dict]) -> dict:
-    #logging.info(f"{messages=}")
-    # completion = openai.ChatCompletion.create(
-    #     model=ai_model,
-    #     messages=messages,
-    #     # stream=True,
-    #     # temperature=0.7,
-    # )
-    #logging.info(f"{completion=}")
-    # return completion
-    pass
diff --git a/ai_talks/src/utils/conversation.py b/ai_talks/src/utils/conversation.py
index 35516ae..aa8b3ef 100644
--- a/ai_talks/src/utils/conversation.py
+++ b/ai_talks/src/utils/conversation.py
@@ -4,7 +4,6 @@ from random import randrange, choices
 import streamlit as st
 from openai.error import InvalidRequestError, OpenAIError
 from streamlit_chat import message
-from .agi.chat_gpt import create_gpt_completion
 from .stt import show_voice_input
 from .tts import show_audio_player

@@ -29,11 +28,13 @@ def get_user_input():
     # show_text_input()
     st.session_state.user_text = st.text_input("You: ", "Hello, how are you?", key="primary")

+def on_send():
+    st.session_state.past.append(st.session_state.user_text)

 def show_chat_buttons() -> None:
     b0, b1, b2 = st.columns(3)
     with b0, b1, b2:
-        b0.button(label=st.session_state.locale.chat_run_btn)
+        b0.button(label=st.session_state.locale.chat_run_btn, on_click=on_send)
         b1.button(label=st.session_state.locale.chat_clear_btn, on_click=clear_chat)
         b2.download_button(
             label=st.session_state.locale.chat_save_btn,
@@ -42,19 +43,34 @@ def show_chat_buttons() -> None:
             mime="application/json",
         )

-
-def show_chat(ai_content: str, user_text: str) -> None:
-    if user_text not in st.session_state.past:
-        # # store the ai content
-        st.session_state.past.append(user_text)
-        st.session_state.generated.append(ai_content)
-
-    if st.session_state.generated:
-        for i in range(len(st.session_state.generated)):
-            message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed)
-            message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed)
-
-
+# def show_chat(ai_content: str, user_text: str) -> None:
+#     first_message = True
+#
+#     if user_text not in st.session_state.past:
+#         if len(st.session_state.past) == 0:
+#             first_message = False
+#             print('message 1')
+#             message(st.session_state.generated[0], key=str(0), seed=st.session_state.seed)
+#         else:
+#             # # store the ai content
+#             st.session_state.past.append(user_text)
+#             st.session_state.generated.append(ai_content)
+#
+#     if st.session_state.generated:
+#         for i in range(len(st.session_state.past)):
+#             print('message 2')
+#             message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed)
+#             message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed)
+#         if first_message:
+#             print('message 3')
+#             message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed)
+
+def show_chat() -> None:
+    for i in range(len(st.session_state.past)):
+        message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed)
+        message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed)
+    message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed)
+
 def show_conversation() -> None:
     if st.session_state.messages:
         st.session_state.messages.append({"role": "user", "content": st.session_state.user_text})
@@ -64,20 +80,23 @@
             {"role": "system", "content": ai_role},
             {"role": "user", "content": st.session_state.user_text},
         ]
+
+    if len(st.session_state.past):
+        user_message = st.session_state.past[-1]
+        # ai_content = "Dummy respone from AI"
+        intents = st.session_state.nlu.predict(user_message)
+        st.session_state.dst.update_dialog_history(
+            system_message='',
+            user_message=user_message,
+            intents=intents,
+        )
+        system_message = st.session_state.dp.generate_response(intents)
+        st.session_state.generated.append(system_message)

-    # ai_content = "Dummy respone from AI"
-    intents = st.session_state.nlu.predict(st.session_state.user_text)
-    st.session_state.dst.update_dialog_history(
-        system_message='',
-        user_message=st.session_state.user_text,
-        intents=intents,
-    )
-    system_message = st.session_state.dp.generate_response(intents)
     # delete random before deploying with our model
     #random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5))
-    ai_content = system_message
+    ai_content = st.session_state.generated[-1]
     st.session_state.messages.append({"role": "assistant", "content": ai_content})
-    if ai_content:
-        show_chat(ai_content, st.session_state.user_text)
-        st.divider()
-        show_audio_player(ai_content)
+    show_chat()
+    st.divider()
+    show_audio_player(ai_content)