fix context_length_exceeded

This commit is contained in:
if 2023-03-28 02:57:29 +03:00
parent 852047ea38
commit 424fa164ba

View File

@ -1,5 +1,5 @@
import streamlit as st import streamlit as st
from openai.error import OpenAIError from openai.error import InvalidRequestError, OpenAIError
from streamlit_chat import message from streamlit_chat import message
from src.utils.ai_interaction import send_ai_request from src.utils.ai_interaction import send_ai_request
@ -35,28 +35,34 @@ def show_chat(ai_content: str, user_text: str) -> None:
# store the ai content # store the ai content
st.session_state.past.append(user_text) st.session_state.past.append(user_text)
st.session_state.generated.append(ai_content) st.session_state.generated.append(ai_content)
if st.session_state["generated"]: if st.session_state.generated:
for i in range(len(st.session_state["generated"])): for i in range(len(st.session_state.generated)):
message(st.session_state["past"][i], is_user=True, key=str(i) + "_user", avatar_style="micah") message(st.session_state["past"][i], is_user=True, key=str(i) + "_user", avatar_style="micah")
message("", key=str(i)) message("", key=str(i))
st.markdown(st.session_state["generated"][i]) st.markdown(st.session_state.generated[i])
def show_conversation(user_content: str, model: str, role: str) -> None: def show_conversation(user_content: str, model: str, role: str) -> None:
if st.session_state.messages: if st.session_state.messages:
st.session_state.messages.append({"role": "user", "content": user_content}) st.session_state.messages.append({"role": "user", "content": user_content})
else: else:
st.session_state["messages"] = [ st.session_state.messages = [
{"role": "system", "content": f"{st.session_state.locale.ai_role_prefix} {role}."}, {"role": "system", "content": f"{st.session_state.locale.ai_role_prefix} {role}."},
{"role": "user", "content": user_content}, {"role": "user", "content": user_content},
] ]
try: try:
completion = send_ai_request(model, st.session_state.messages) completion = send_ai_request(model, st.session_state.messages)
ai_content = completion.get("choices")[0].get("message").get("content") ai_content = completion.get("choices")[0].get("message").get("content")
st.session_state["messages"].append({"role": "assistant", "content": ai_content}) st.session_state.messages.append({"role": "assistant", "content": ai_content})
if ai_content: if ai_content:
show_chat(ai_content, user_content) show_chat(ai_content, user_content)
st.markdown("---") st.markdown("---")
show_player(ai_content) show_player(ai_content)
except InvalidRequestError as e:
if e.code == "context_length_exceeded":
st.session_state.messages.pop(1)
if len(st.session_state.messages) == 1:
st.session_state.user_text = ""
show_conversation(st.session_state.user_text, st.session_state.model, st.session_state.role)
except (OpenAIError, UnboundLocalError) as err: except (OpenAIError, UnboundLocalError) as err:
st.error(err) st.error(err)