fix chat display

This commit is contained in:
s444417 2023-06-05 20:54:15 +02:00
parent 8a3669e2e6
commit 828578d6bf
3 changed files with 59 additions and 57 deletions

View File

@ -42,7 +42,7 @@ if __name__ == '__main__':
if "locale" not in st.session_state: if "locale" not in st.session_state:
st.session_state.locale = en st.session_state.locale = en
if "generated" not in st.session_state: if "generated" not in st.session_state:
st.session_state.generated = [] st.session_state.generated = ["Hello! I'm AMUseBot, a virtual cooking assistant. Please tell me the name of the dish that you'd like to prepare today."]
if "past" not in st.session_state: if "past" not in st.session_state:
st.session_state.past = [] st.session_state.past = []
if "messages" not in st.session_state: if "messages" not in st.session_state:
@ -70,13 +70,15 @@ def show_graph():
if st.session_state.generated: if st.session_state.generated:
user, chatbot = [], [] user, chatbot = [], []
graph = graphviz.Digraph() graph = graphviz.Digraph()
for i in range(len(st.session_state.generated)): for i in range(len(st.session_state.past)):
user.append(st.session_state.past[i])
chatbot.append(st.session_state.generated[i]) chatbot.append(st.session_state.generated[i])
user.append(st.session_state.past[i])
for x in range(len(user)): for x in range(len(user)):
graph.edge(st.session_state.past[x], st.session_state.generated[x]) chatbot_text = [word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.generated[x].split(' '))]
user_text = [word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.past[x].split(' '))]
graph.edge(' '.join(chatbot_text), ' '.join(user_text))
try: try:
graph.edge(st.session_state.generated[x], st.session_state.past[x+1]) graph.edge(' '.join(user_text), ' '.join([word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.generated[x + 1].split(' '))]))
except: except:
pass pass
st.graphviz_chart(graph) st.graphviz_chart(graph)
@ -101,13 +103,13 @@ def main() -> None:
elif role_kind == st.session_state.locale.radio_text2: elif role_kind == st.session_state.locale.radio_text2:
c2.text_input(label=st.session_state.locale.select_placeholder3, key="role") c2.text_input(label=st.session_state.locale.select_placeholder3, key="role")
if st.session_state.user_text:
show_graph()
show_conversation()
get_user_input() get_user_input()
show_chat_buttons() show_chat_buttons()
show_conversation()
with st.sidebar:
show_graph()
if __name__ == "__main__": if __name__ == "__main__":
st.markdown(f"<h1 style='text-align: center;'>{st.session_state.locale.title}</h1>", unsafe_allow_html=True) st.markdown(f"<h1 style='text-align: center;'>{st.session_state.locale.title}</h1>", unsafe_allow_html=True)

View File

@ -1,19 +0,0 @@
import logging
from typing import List # NOQA: UP035
import openai
import streamlit as st
@st.cache_data()
def create_gpt_completion(ai_model: str, messages: List[dict]) -> dict:
#logging.info(f"{messages=}")
# completion = openai.ChatCompletion.create(
# model=ai_model,
# messages=messages,
# # stream=True,
# # temperature=0.7,
# )
#logging.info(f"{completion=}")
# return completion
pass

View File

@ -4,7 +4,6 @@ from random import randrange, choices
import streamlit as st import streamlit as st
from openai.error import InvalidRequestError, OpenAIError from openai.error import InvalidRequestError, OpenAIError
from streamlit_chat import message from streamlit_chat import message
from .agi.chat_gpt import create_gpt_completion
from .stt import show_voice_input from .stt import show_voice_input
from .tts import show_audio_player from .tts import show_audio_player
@ -29,11 +28,13 @@ def get_user_input():
# show_text_input() # show_text_input()
st.session_state.user_text = st.text_input("You: ", "Hello, how are you?", key="primary") st.session_state.user_text = st.text_input("You: ", "Hello, how are you?", key="primary")
def on_send():
st.session_state.past.append(st.session_state.user_text)
def show_chat_buttons() -> None: def show_chat_buttons() -> None:
b0, b1, b2 = st.columns(3) b0, b1, b2 = st.columns(3)
with b0, b1, b2: with b0, b1, b2:
b0.button(label=st.session_state.locale.chat_run_btn) b0.button(label=st.session_state.locale.chat_run_btn, on_click=on_send)
b1.button(label=st.session_state.locale.chat_clear_btn, on_click=clear_chat) b1.button(label=st.session_state.locale.chat_clear_btn, on_click=clear_chat)
b2.download_button( b2.download_button(
label=st.session_state.locale.chat_save_btn, label=st.session_state.locale.chat_save_btn,
@ -42,18 +43,33 @@ def show_chat_buttons() -> None:
mime="application/json", mime="application/json",
) )
# def show_chat(ai_content: str, user_text: str) -> None:
# first_message = True
#
# if user_text not in st.session_state.past:
# if len(st.session_state.past) == 0:
# first_message = False
# print('message 1')
# message(st.session_state.generated[0], key=str(0), seed=st.session_state.seed)
# else:
# # # store the ai content
# st.session_state.past.append(user_text)
# st.session_state.generated.append(ai_content)
#
# if st.session_state.generated:
# for i in range(len(st.session_state.past)):
# print('message 2')
# message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed)
# message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed)
# if first_message:
# print('message 3')
# message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed)
def show_chat(ai_content: str, user_text: str) -> None: def show_chat() -> None:
if user_text not in st.session_state.past: for i in range(len(st.session_state.past)):
# # store the ai content
st.session_state.past.append(user_text)
st.session_state.generated.append(ai_content)
if st.session_state.generated:
for i in range(len(st.session_state.generated)):
message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed)
message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed) message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed)
message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed)
message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed)
def show_conversation() -> None: def show_conversation() -> None:
if st.session_state.messages: if st.session_state.messages:
@ -65,19 +81,22 @@ def show_conversation() -> None:
{"role": "user", "content": st.session_state.user_text}, {"role": "user", "content": st.session_state.user_text},
] ]
if len(st.session_state.past):
user_message = st.session_state.past[-1]
# ai_content = "Dummy respone from AI" # ai_content = "Dummy respone from AI"
intents = st.session_state.nlu.predict(st.session_state.user_text) intents = st.session_state.nlu.predict(user_message)
st.session_state.dst.update_dialog_history( st.session_state.dst.update_dialog_history(
system_message='', system_message='',
user_message=st.session_state.user_text, user_message=user_message,
intents=intents, intents=intents,
) )
system_message = st.session_state.dp.generate_response(intents) system_message = st.session_state.dp.generate_response(intents)
st.session_state.generated.append(system_message)
# delete random before deploying with our model # delete random before deploying with our model
#random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5)) #random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5))
ai_content = system_message ai_content = st.session_state.generated[-1]
st.session_state.messages.append({"role": "assistant", "content": ai_content}) st.session_state.messages.append({"role": "assistant", "content": ai_content})
if ai_content: show_chat()
show_chat(ai_content, st.session_state.user_text)
st.divider() st.divider()
show_audio_player(ai_content) show_audio_player(ai_content)