fix chat display
This commit is contained in:
parent
8a3669e2e6
commit
828578d6bf
@ -42,7 +42,7 @@ if __name__ == '__main__':
|
|||||||
if "locale" not in st.session_state:
|
if "locale" not in st.session_state:
|
||||||
st.session_state.locale = en
|
st.session_state.locale = en
|
||||||
if "generated" not in st.session_state:
|
if "generated" not in st.session_state:
|
||||||
st.session_state.generated = []
|
st.session_state.generated = ["Hello! I'm AMUseBot, a virtual cooking assistant. Please tell me the name of the dish that you'd like to prepare today."]
|
||||||
if "past" not in st.session_state:
|
if "past" not in st.session_state:
|
||||||
st.session_state.past = []
|
st.session_state.past = []
|
||||||
if "messages" not in st.session_state:
|
if "messages" not in st.session_state:
|
||||||
@ -70,16 +70,18 @@ def show_graph():
|
|||||||
if st.session_state.generated:
|
if st.session_state.generated:
|
||||||
user, chatbot = [], []
|
user, chatbot = [], []
|
||||||
graph = graphviz.Digraph()
|
graph = graphviz.Digraph()
|
||||||
for i in range(len(st.session_state.generated)):
|
for i in range(len(st.session_state.past)):
|
||||||
user.append(st.session_state.past[i])
|
|
||||||
chatbot.append(st.session_state.generated[i])
|
chatbot.append(st.session_state.generated[i])
|
||||||
|
user.append(st.session_state.past[i])
|
||||||
for x in range(len(user)):
|
for x in range(len(user)):
|
||||||
graph.edge(st.session_state.past[x], st.session_state.generated[x])
|
chatbot_text = [word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.generated[x].split(' '))]
|
||||||
|
user_text = [word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.past[x].split(' '))]
|
||||||
|
graph.edge(' '.join(chatbot_text), ' '.join(user_text))
|
||||||
try:
|
try:
|
||||||
graph.edge(st.session_state.generated[x], st.session_state.past[x+1])
|
graph.edge(' '.join(user_text), ' '.join([word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.generated[x + 1].split(' '))]))
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
st.graphviz_chart(graph)
|
st.graphviz_chart(graph)
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
@ -101,13 +103,13 @@ def main() -> None:
|
|||||||
elif role_kind == st.session_state.locale.radio_text2:
|
elif role_kind == st.session_state.locale.radio_text2:
|
||||||
c2.text_input(label=st.session_state.locale.select_placeholder3, key="role")
|
c2.text_input(label=st.session_state.locale.select_placeholder3, key="role")
|
||||||
|
|
||||||
if st.session_state.user_text:
|
|
||||||
show_graph()
|
|
||||||
show_conversation()
|
|
||||||
|
|
||||||
get_user_input()
|
get_user_input()
|
||||||
show_chat_buttons()
|
show_chat_buttons()
|
||||||
|
|
||||||
|
show_conversation()
|
||||||
|
with st.sidebar:
|
||||||
|
show_graph()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
st.markdown(f"<h1 style='text-align: center;'>{st.session_state.locale.title}</h1>", unsafe_allow_html=True)
|
st.markdown(f"<h1 style='text-align: center;'>{st.session_state.locale.title}</h1>", unsafe_allow_html=True)
|
||||||
|
@ -1,19 +0,0 @@
|
|||||||
import logging
|
|
||||||
from typing import List # NOQA: UP035
|
|
||||||
|
|
||||||
import openai
|
|
||||||
import streamlit as st
|
|
||||||
|
|
||||||
|
|
||||||
@st.cache_data()
|
|
||||||
def create_gpt_completion(ai_model: str, messages: List[dict]) -> dict:
|
|
||||||
#logging.info(f"{messages=}")
|
|
||||||
# completion = openai.ChatCompletion.create(
|
|
||||||
# model=ai_model,
|
|
||||||
# messages=messages,
|
|
||||||
# # stream=True,
|
|
||||||
# # temperature=0.7,
|
|
||||||
# )
|
|
||||||
#logging.info(f"{completion=}")
|
|
||||||
# return completion
|
|
||||||
pass
|
|
@ -4,7 +4,6 @@ from random import randrange, choices
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
from openai.error import InvalidRequestError, OpenAIError
|
from openai.error import InvalidRequestError, OpenAIError
|
||||||
from streamlit_chat import message
|
from streamlit_chat import message
|
||||||
from .agi.chat_gpt import create_gpt_completion
|
|
||||||
from .stt import show_voice_input
|
from .stt import show_voice_input
|
||||||
from .tts import show_audio_player
|
from .tts import show_audio_player
|
||||||
|
|
||||||
@ -29,11 +28,13 @@ def get_user_input():
|
|||||||
# show_text_input()
|
# show_text_input()
|
||||||
st.session_state.user_text = st.text_input("You: ", "Hello, how are you?", key="primary")
|
st.session_state.user_text = st.text_input("You: ", "Hello, how are you?", key="primary")
|
||||||
|
|
||||||
|
def on_send():
|
||||||
|
st.session_state.past.append(st.session_state.user_text)
|
||||||
|
|
||||||
def show_chat_buttons() -> None:
|
def show_chat_buttons() -> None:
|
||||||
b0, b1, b2 = st.columns(3)
|
b0, b1, b2 = st.columns(3)
|
||||||
with b0, b1, b2:
|
with b0, b1, b2:
|
||||||
b0.button(label=st.session_state.locale.chat_run_btn)
|
b0.button(label=st.session_state.locale.chat_run_btn, on_click=on_send)
|
||||||
b1.button(label=st.session_state.locale.chat_clear_btn, on_click=clear_chat)
|
b1.button(label=st.session_state.locale.chat_clear_btn, on_click=clear_chat)
|
||||||
b2.download_button(
|
b2.download_button(
|
||||||
label=st.session_state.locale.chat_save_btn,
|
label=st.session_state.locale.chat_save_btn,
|
||||||
@ -42,18 +43,33 @@ def show_chat_buttons() -> None:
|
|||||||
mime="application/json",
|
mime="application/json",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# def show_chat(ai_content: str, user_text: str) -> None:
|
||||||
|
# first_message = True
|
||||||
|
#
|
||||||
|
# if user_text not in st.session_state.past:
|
||||||
|
# if len(st.session_state.past) == 0:
|
||||||
|
# first_message = False
|
||||||
|
# print('message 1')
|
||||||
|
# message(st.session_state.generated[0], key=str(0), seed=st.session_state.seed)
|
||||||
|
# else:
|
||||||
|
# # # store the ai content
|
||||||
|
# st.session_state.past.append(user_text)
|
||||||
|
# st.session_state.generated.append(ai_content)
|
||||||
|
#
|
||||||
|
# if st.session_state.generated:
|
||||||
|
# for i in range(len(st.session_state.past)):
|
||||||
|
# print('message 2')
|
||||||
|
# message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed)
|
||||||
|
# message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed)
|
||||||
|
# if first_message:
|
||||||
|
# print('message 3')
|
||||||
|
# message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed)
|
||||||
|
|
||||||
def show_chat(ai_content: str, user_text: str) -> None:
|
def show_chat() -> None:
|
||||||
if user_text not in st.session_state.past:
|
for i in range(len(st.session_state.past)):
|
||||||
# # store the ai content
|
message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed)
|
||||||
st.session_state.past.append(user_text)
|
message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed)
|
||||||
st.session_state.generated.append(ai_content)
|
message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed)
|
||||||
|
|
||||||
if st.session_state.generated:
|
|
||||||
for i in range(len(st.session_state.generated)):
|
|
||||||
message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed)
|
|
||||||
message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed)
|
|
||||||
|
|
||||||
|
|
||||||
def show_conversation() -> None:
|
def show_conversation() -> None:
|
||||||
if st.session_state.messages:
|
if st.session_state.messages:
|
||||||
@ -65,19 +81,22 @@ def show_conversation() -> None:
|
|||||||
{"role": "user", "content": st.session_state.user_text},
|
{"role": "user", "content": st.session_state.user_text},
|
||||||
]
|
]
|
||||||
|
|
||||||
# ai_content = "Dummy respone from AI"
|
if len(st.session_state.past):
|
||||||
intents = st.session_state.nlu.predict(st.session_state.user_text)
|
user_message = st.session_state.past[-1]
|
||||||
st.session_state.dst.update_dialog_history(
|
# ai_content = "Dummy respone from AI"
|
||||||
system_message='',
|
intents = st.session_state.nlu.predict(user_message)
|
||||||
user_message=st.session_state.user_text,
|
st.session_state.dst.update_dialog_history(
|
||||||
intents=intents,
|
system_message='',
|
||||||
)
|
user_message=user_message,
|
||||||
system_message = st.session_state.dp.generate_response(intents)
|
intents=intents,
|
||||||
|
)
|
||||||
|
system_message = st.session_state.dp.generate_response(intents)
|
||||||
|
st.session_state.generated.append(system_message)
|
||||||
|
|
||||||
# delete random before deploying with our model
|
# delete random before deploying with our model
|
||||||
#random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5))
|
#random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5))
|
||||||
ai_content = system_message
|
ai_content = st.session_state.generated[-1]
|
||||||
st.session_state.messages.append({"role": "assistant", "content": ai_content})
|
st.session_state.messages.append({"role": "assistant", "content": ai_content})
|
||||||
if ai_content:
|
show_chat()
|
||||||
show_chat(ai_content, st.session_state.user_text)
|
st.divider()
|
||||||
st.divider()
|
show_audio_player(ai_content)
|
||||||
show_audio_player(ai_content)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user