fix chat display
This commit is contained in:
parent
8a3669e2e6
commit
828578d6bf
@ -42,7 +42,7 @@ if __name__ == '__main__':
|
||||
if "locale" not in st.session_state:
|
||||
st.session_state.locale = en
|
||||
if "generated" not in st.session_state:
|
||||
st.session_state.generated = []
|
||||
st.session_state.generated = ["Hello! I'm AMUseBot, a virtual cooking assistant. Please tell me the name of the dish that you'd like to prepare today."]
|
||||
if "past" not in st.session_state:
|
||||
st.session_state.past = []
|
||||
if "messages" not in st.session_state:
|
||||
@ -70,16 +70,18 @@ def show_graph():
|
||||
if st.session_state.generated:
|
||||
user, chatbot = [], []
|
||||
graph = graphviz.Digraph()
|
||||
for i in range(len(st.session_state.generated)):
|
||||
user.append(st.session_state.past[i])
|
||||
for i in range(len(st.session_state.past)):
|
||||
chatbot.append(st.session_state.generated[i])
|
||||
user.append(st.session_state.past[i])
|
||||
for x in range(len(user)):
|
||||
graph.edge(st.session_state.past[x], st.session_state.generated[x])
|
||||
chatbot_text = [word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.generated[x].split(' '))]
|
||||
user_text = [word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.past[x].split(' '))]
|
||||
graph.edge(' '.join(chatbot_text), ' '.join(user_text))
|
||||
try:
|
||||
graph.edge(st.session_state.generated[x], st.session_state.past[x+1])
|
||||
graph.edge(' '.join(user_text), ' '.join([word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.generated[x + 1].split(' '))]))
|
||||
except:
|
||||
pass
|
||||
st.graphviz_chart(graph)
|
||||
st.graphviz_chart(graph)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
@ -101,12 +103,12 @@ def main() -> None:
|
||||
elif role_kind == st.session_state.locale.radio_text2:
|
||||
c2.text_input(label=st.session_state.locale.select_placeholder3, key="role")
|
||||
|
||||
if st.session_state.user_text:
|
||||
show_graph()
|
||||
show_conversation()
|
||||
|
||||
get_user_input()
|
||||
show_chat_buttons()
|
||||
|
||||
show_conversation()
|
||||
with st.sidebar:
|
||||
show_graph()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,19 +0,0 @@
|
||||
import logging
|
||||
from typing import List # NOQA: UP035
|
||||
|
||||
import openai
|
||||
import streamlit as st
|
||||
|
||||
|
||||
@st.cache_data()
|
||||
def create_gpt_completion(ai_model: str, messages: List[dict]) -> dict:
|
||||
#logging.info(f"{messages=}")
|
||||
# completion = openai.ChatCompletion.create(
|
||||
# model=ai_model,
|
||||
# messages=messages,
|
||||
# # stream=True,
|
||||
# # temperature=0.7,
|
||||
# )
|
||||
#logging.info(f"{completion=}")
|
||||
# return completion
|
||||
pass
|
@ -4,7 +4,6 @@ from random import randrange, choices
|
||||
import streamlit as st
|
||||
from openai.error import InvalidRequestError, OpenAIError
|
||||
from streamlit_chat import message
|
||||
from .agi.chat_gpt import create_gpt_completion
|
||||
from .stt import show_voice_input
|
||||
from .tts import show_audio_player
|
||||
|
||||
@ -29,11 +28,13 @@ def get_user_input():
|
||||
# show_text_input()
|
||||
st.session_state.user_text = st.text_input("You: ", "Hello, how are you?", key="primary")
|
||||
|
||||
def on_send():
|
||||
st.session_state.past.append(st.session_state.user_text)
|
||||
|
||||
def show_chat_buttons() -> None:
|
||||
b0, b1, b2 = st.columns(3)
|
||||
with b0, b1, b2:
|
||||
b0.button(label=st.session_state.locale.chat_run_btn)
|
||||
b0.button(label=st.session_state.locale.chat_run_btn, on_click=on_send)
|
||||
b1.button(label=st.session_state.locale.chat_clear_btn, on_click=clear_chat)
|
||||
b2.download_button(
|
||||
label=st.session_state.locale.chat_save_btn,
|
||||
@ -42,19 +43,34 @@ def show_chat_buttons() -> None:
|
||||
mime="application/json",
|
||||
)
|
||||
|
||||
|
||||
def show_chat(ai_content: str, user_text: str) -> None:
|
||||
if user_text not in st.session_state.past:
|
||||
# # store the ai content
|
||||
st.session_state.past.append(user_text)
|
||||
st.session_state.generated.append(ai_content)
|
||||
|
||||
if st.session_state.generated:
|
||||
for i in range(len(st.session_state.generated)):
|
||||
message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed)
|
||||
message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed)
|
||||
|
||||
|
||||
# def show_chat(ai_content: str, user_text: str) -> None:
|
||||
# first_message = True
|
||||
#
|
||||
# if user_text not in st.session_state.past:
|
||||
# if len(st.session_state.past) == 0:
|
||||
# first_message = False
|
||||
# print('message 1')
|
||||
# message(st.session_state.generated[0], key=str(0), seed=st.session_state.seed)
|
||||
# else:
|
||||
# # # store the ai content
|
||||
# st.session_state.past.append(user_text)
|
||||
# st.session_state.generated.append(ai_content)
|
||||
#
|
||||
# if st.session_state.generated:
|
||||
# for i in range(len(st.session_state.past)):
|
||||
# print('message 2')
|
||||
# message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed)
|
||||
# message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed)
|
||||
# if first_message:
|
||||
# print('message 3')
|
||||
# message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed)
|
||||
|
||||
def show_chat() -> None:
|
||||
for i in range(len(st.session_state.past)):
|
||||
message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed)
|
||||
message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed)
|
||||
message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed)
|
||||
|
||||
def show_conversation() -> None:
|
||||
if st.session_state.messages:
|
||||
st.session_state.messages.append({"role": "user", "content": st.session_state.user_text})
|
||||
@ -64,20 +80,23 @@ def show_conversation() -> None:
|
||||
{"role": "system", "content": ai_role},
|
||||
{"role": "user", "content": st.session_state.user_text},
|
||||
]
|
||||
|
||||
if len(st.session_state.past):
|
||||
user_message = st.session_state.past[-1]
|
||||
# ai_content = "Dummy respone from AI"
|
||||
intents = st.session_state.nlu.predict(user_message)
|
||||
st.session_state.dst.update_dialog_history(
|
||||
system_message='',
|
||||
user_message=user_message,
|
||||
intents=intents,
|
||||
)
|
||||
system_message = st.session_state.dp.generate_response(intents)
|
||||
st.session_state.generated.append(system_message)
|
||||
|
||||
# ai_content = "Dummy respone from AI"
|
||||
intents = st.session_state.nlu.predict(st.session_state.user_text)
|
||||
st.session_state.dst.update_dialog_history(
|
||||
system_message='',
|
||||
user_message=st.session_state.user_text,
|
||||
intents=intents,
|
||||
)
|
||||
system_message = st.session_state.dp.generate_response(intents)
|
||||
# delete random before deploying with our model
|
||||
#random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5))
|
||||
ai_content = system_message
|
||||
ai_content = st.session_state.generated[-1]
|
||||
st.session_state.messages.append({"role": "assistant", "content": ai_content})
|
||||
if ai_content:
|
||||
show_chat(ai_content, st.session_state.user_text)
|
||||
st.divider()
|
||||
show_audio_player(ai_content)
|
||||
show_chat()
|
||||
st.divider()
|
||||
show_audio_player(ai_content)
|
||||
|
Loading…
Reference in New Issue
Block a user