fix chat display

This commit is contained in:
s444417 2023-06-05 20:54:15 +02:00
parent 8a3669e2e6
commit 828578d6bf
3 changed files with 59 additions and 57 deletions

View File

@ -42,7 +42,7 @@ if __name__ == '__main__':
if "locale" not in st.session_state:
st.session_state.locale = en
if "generated" not in st.session_state:
st.session_state.generated = []
st.session_state.generated = ["Hello! I'm AMUseBot, a virtual cooking assistant. Please tell me the name of the dish that you'd like to prepare today."]
if "past" not in st.session_state:
st.session_state.past = []
if "messages" not in st.session_state:
@ -70,16 +70,18 @@ def show_graph():
if st.session_state.generated:
user, chatbot = [], []
graph = graphviz.Digraph()
for i in range(len(st.session_state.generated)):
user.append(st.session_state.past[i])
for i in range(len(st.session_state.past)):
chatbot.append(st.session_state.generated[i])
user.append(st.session_state.past[i])
for x in range(len(user)):
graph.edge(st.session_state.past[x], st.session_state.generated[x])
chatbot_text = [word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.generated[x].split(' '))]
user_text = [word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.past[x].split(' '))]
graph.edge(' '.join(chatbot_text), ' '.join(user_text))
try:
graph.edge(st.session_state.generated[x], st.session_state.past[x+1])
graph.edge(' '.join(user_text), ' '.join([word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.generated[x + 1].split(' '))]))
except:
pass
st.graphviz_chart(graph)
st.graphviz_chart(graph)
def main() -> None:
@ -101,12 +103,12 @@ def main() -> None:
elif role_kind == st.session_state.locale.radio_text2:
c2.text_input(label=st.session_state.locale.select_placeholder3, key="role")
if st.session_state.user_text:
show_graph()
show_conversation()
get_user_input()
show_chat_buttons()
show_conversation()
with st.sidebar:
show_graph()
if __name__ == "__main__":

View File

@ -1,19 +0,0 @@
import logging
from typing import List # NOQA: UP035
import openai
import streamlit as st
@st.cache_data()
def create_gpt_completion(ai_model: str, messages: List[dict]) -> dict:
#logging.info(f"{messages=}")
# completion = openai.ChatCompletion.create(
# model=ai_model,
# messages=messages,
# # stream=True,
# # temperature=0.7,
# )
#logging.info(f"{completion=}")
# return completion
pass

View File

@ -4,7 +4,6 @@ from random import randrange, choices
import streamlit as st
from openai.error import InvalidRequestError, OpenAIError
from streamlit_chat import message
from .agi.chat_gpt import create_gpt_completion
from .stt import show_voice_input
from .tts import show_audio_player
@ -29,11 +28,13 @@ def get_user_input():
# show_text_input()
st.session_state.user_text = st.text_input("You: ", "Hello, how are you?", key="primary")
def on_send():
st.session_state.past.append(st.session_state.user_text)
def show_chat_buttons() -> None:
b0, b1, b2 = st.columns(3)
with b0, b1, b2:
b0.button(label=st.session_state.locale.chat_run_btn)
b0.button(label=st.session_state.locale.chat_run_btn, on_click=on_send)
b1.button(label=st.session_state.locale.chat_clear_btn, on_click=clear_chat)
b2.download_button(
label=st.session_state.locale.chat_save_btn,
@ -42,19 +43,34 @@ def show_chat_buttons() -> None:
mime="application/json",
)
def show_chat(ai_content: str, user_text: str) -> None:
if user_text not in st.session_state.past:
# # store the ai content
st.session_state.past.append(user_text)
st.session_state.generated.append(ai_content)
if st.session_state.generated:
for i in range(len(st.session_state.generated)):
message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed)
message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed)
# def show_chat(ai_content: str, user_text: str) -> None:
# first_message = True
#
# if user_text not in st.session_state.past:
# if len(st.session_state.past) == 0:
# first_message = False
# print('message 1')
# message(st.session_state.generated[0], key=str(0), seed=st.session_state.seed)
# else:
# # # store the ai content
# st.session_state.past.append(user_text)
# st.session_state.generated.append(ai_content)
#
# if st.session_state.generated:
# for i in range(len(st.session_state.past)):
# print('message 2')
# message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed)
# message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed)
# if first_message:
# print('message 3')
# message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed)
def show_chat() -> None:
for i in range(len(st.session_state.past)):
message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed)
message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed)
message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed)
def show_conversation() -> None:
if st.session_state.messages:
st.session_state.messages.append({"role": "user", "content": st.session_state.user_text})
@ -64,20 +80,23 @@ def show_conversation() -> None:
{"role": "system", "content": ai_role},
{"role": "user", "content": st.session_state.user_text},
]
if len(st.session_state.past):
user_message = st.session_state.past[-1]
# ai_content = "Dummy respone from AI"
intents = st.session_state.nlu.predict(user_message)
st.session_state.dst.update_dialog_history(
system_message='',
user_message=user_message,
intents=intents,
)
system_message = st.session_state.dp.generate_response(intents)
st.session_state.generated.append(system_message)
# ai_content = "Dummy respone from AI"
intents = st.session_state.nlu.predict(st.session_state.user_text)
st.session_state.dst.update_dialog_history(
system_message='',
user_message=st.session_state.user_text,
intents=intents,
)
system_message = st.session_state.dp.generate_response(intents)
# delete random before deploying with our model
#random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5))
ai_content = system_message
ai_content = st.session_state.generated[-1]
st.session_state.messages.append({"role": "assistant", "content": ai_content})
if ai_content:
show_chat(ai_content, st.session_state.user_text)
st.divider()
show_audio_player(ai_content)
show_chat()
st.divider()
show_audio_player(ai_content)