diff --git a/chat.py b/chat.py index 8f9c1c9..58734d7 100644 --- a/chat.py +++ b/chat.py @@ -53,17 +53,18 @@ if "user_text" not in st.session_state: def main() -> None: - user_content = get_user_input() - show_chat_buttons() + if st.session_state.user_text: + show_conversation(st.session_state.user_text, st.session_state.model, st.session_state.role) + st.session_state.user_text = "" c1, c2 = st.columns(2) with c1, c2: - model = c1.selectbox(label=st.session_state.locale.select_placeholder1, options=AI_MODEL_OPTIONS) - role = c2.selectbox(label=st.session_state.locale.select_placeholder2, - options=st.session_state.locale.ai_role_options) + c1.selectbox(label=st.session_state.locale.select_placeholder1, key="model", options=AI_MODEL_OPTIONS) + c2.selectbox(label=st.session_state.locale.select_placeholder2, key="role", + options=st.session_state.locale.ai_role_options) - if user_content: - show_conversation(user_content, model, role) + get_user_input() + show_chat_buttons() if __name__ == "__main__": diff --git a/src/utils/conversation.py b/src/utils/conversation.py index 28f4281..36e4e40 100644 --- a/src/utils/conversation.py +++ b/src/utils/conversation.py @@ -13,9 +13,8 @@ def clear_chat() -> None: st.session_state["user_text"] = "" -def get_user_input() -> str: - user_text = st.text_area(label=st.session_state.locale.chat_placeholder, key="user_text") - return user_text +def get_user_input(): + st.text_area(label=st.session_state.locale.chat_placeholder, value=st.session_state.user_text, key="user_text") def show_chat_buttons() -> None: @@ -32,10 +31,10 @@ def show_chat(ai_content: str, user_text: str) -> None: st.session_state.past.append(user_text) st.session_state.generated.append(ai_content) if st.session_state["generated"]: - for i in range(len(st.session_state["generated"]) - 1, -1, -1): + for i in range(len(st.session_state["generated"])): + message(st.session_state["past"][i], is_user=True, key=str(i) + "_user", avatar_style="micah") message("", key=str(i)) st.markdown(st.session_state["generated"][i]) - message(st.session_state["past"][i], is_user=True, key=str(i) + "_user", avatar_style="micah") def show_conversation(user_content: str, model: str, role: str) -> None: