From 424fa164ba7aebcd9d16aa52abd65020f30aa70c Mon Sep 17 00:00:00 2001 From: if Date: Tue, 28 Mar 2023 02:57:29 +0300 Subject: [PATCH] fix context_length_exceeded --- src/utils/conversation.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/utils/conversation.py b/src/utils/conversation.py index c2aa447..08b075e 100644 --- a/src/utils/conversation.py +++ b/src/utils/conversation.py @@ -1,5 +1,5 @@ import streamlit as st -from openai.error import OpenAIError +from openai.error import InvalidRequestError, OpenAIError from streamlit_chat import message from src.utils.ai_interaction import send_ai_request @@ -35,28 +35,34 @@ def show_chat(ai_content: str, user_text: str) -> None: # store the ai content st.session_state.past.append(user_text) st.session_state.generated.append(ai_content) - if st.session_state["generated"]: - for i in range(len(st.session_state["generated"])): + if st.session_state.generated: + for i in range(len(st.session_state.generated)): message(st.session_state["past"][i], is_user=True, key=str(i) + "_user", avatar_style="micah") message("", key=str(i)) - st.markdown(st.session_state["generated"][i]) + st.markdown(st.session_state.generated[i]) def show_conversation(user_content: str, model: str, role: str) -> None: if st.session_state.messages: st.session_state.messages.append({"role": "user", "content": user_content}) else: - st.session_state["messages"] = [ + st.session_state.messages = [ {"role": "system", "content": f"{st.session_state.locale.ai_role_prefix} {role}."}, {"role": "user", "content": user_content}, ] try: completion = send_ai_request(model, st.session_state.messages) ai_content = completion.get("choices")[0].get("message").get("content") - st.session_state["messages"].append({"role": "assistant", "content": ai_content}) + st.session_state.messages.append({"role": "assistant", "content": ai_content}) if ai_content: show_chat(ai_content, user_content) st.markdown("---") show_player(ai_content) + except InvalidRequestError as e: + if e.code == "context_length_exceeded": + st.session_state.messages.pop(1) + if len(st.session_state.messages) == 1: + st.session_state.user_text = "" + show_conversation(st.session_state.user_text, st.session_state.model, st.session_state.role) except (OpenAIError, UnboundLocalError) as err: st.error(err)