integrate chatbot

This commit is contained in:
s444417 2023-05-31 21:11:09 +02:00
parent 0d3b7d0a31
commit 8a3669e2e6
3 changed files with 69 additions and 52 deletions

@ -0,0 +1 @@
Subproject commit 01d547dc221b2bd81ccbe24da37b792e9b176b37

View File

@ -1,56 +1,69 @@
from pathlib import Path from pathlib import Path
from random import randrange from random import randrange
from AMUseBotBackend.src.DP.dp import DP
from AMUseBotBackend.src.DST.dst import DST
from AMUseBotBackend.src.NLU.nlu import NLU
import graphviz import graphviz
import streamlit as st import streamlit as st
from PIL import Image from PIL import Image
from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation
from src.utils.lang import en from src.utils.lang import en
# --- PATH SETTINGS ---
current_dir: Path = Path(__file__).parent if "__file__" in locals() else Path.cwd()
css_file: Path = current_dir / "src/styles/.css"
assets_dir: Path = current_dir / "assets"
icons_dir: Path = assets_dir / "icons"
img_dir: Path = assets_dir / "img"
tg_svg: Path = icons_dir / "tg.svg"
favicon: Path = icons_dir / "favicons/0.png"
# --- GENERAL SETTINGS ---
LANG_PL: str = "Pl"
AI_MODEL_OPTIONS: list[str] = [
"gpt-3.5-turbo",
"gpt-4",
"gpt-4-32k",
]
CONFIG = {"page_title": "AMUsebot", "page_icon": Image.open(favicon)} if __name__ == '__main__':
st.set_page_config(**CONFIG) # --- PATH SETTINGS ---
current_dir: Path = Path(__file__).parent if "__file__" in locals() else Path.cwd()
css_file: Path = current_dir / "src/styles/.css"
assets_dir: Path = current_dir / "assets"
icons_dir: Path = assets_dir / "icons"
img_dir: Path = assets_dir / "img"
tg_svg: Path = icons_dir / "tg.svg"
favicon: Path = icons_dir / "favicons/0.png"
# --- GENERAL SETTINGS ---
LANG_PL: str = "Pl"
AI_MODEL_OPTIONS: list[str] = [
"gpt-3.5-turbo",
"gpt-4",
"gpt-4-32k",
]
# --- LOAD CSS --- CONFIG = {"page_title": "AMUsebot", "page_icon": Image.open(favicon)}
with open(css_file) as f:
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
# Storing The Context st.set_page_config(**CONFIG)
if "locale" not in st.session_state:
st.session_state.locale = en
if "generated" not in st.session_state:
st.session_state.generated = []
if "past" not in st.session_state:
st.session_state.past = []
if "messages" not in st.session_state:
st.session_state.messages = []
if "user_text" not in st.session_state:
st.session_state.user_text = ""
if "input_kind" not in st.session_state:
st.session_state.input_kind = st.session_state.locale.input_kind_1
if "seed" not in st.session_state:
st.session_state.seed = randrange(10 ** 3) # noqa: S311
if "costs" not in st.session_state:
st.session_state.costs = []
if "total_tokens" not in st.session_state:
st.session_state.total_tokens = []
# --- LOAD CSS ---
with open(css_file) as f:
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
# Storing The Context
if "locale" not in st.session_state:
st.session_state.locale = en
if "generated" not in st.session_state:
st.session_state.generated = []
if "past" not in st.session_state:
st.session_state.past = []
if "messages" not in st.session_state:
st.session_state.messages = []
if "user_text" not in st.session_state:
st.session_state.user_text = ""
if "input_kind" not in st.session_state:
st.session_state.input_kind = st.session_state.locale.input_kind_1
if "seed" not in st.session_state:
st.session_state.seed = randrange(10 ** 3) # noqa: S311
if "costs" not in st.session_state:
st.session_state.costs = []
if "total_tokens" not in st.session_state:
st.session_state.total_tokens = []
if "dst" not in st.session_state:
st.session_state.dst = DST(recipe_path="AMUseBotFront/ai_talks/AMUseBotBackend/recipe/", dialog_path="AMUseBotFront/ai_talks/AMUseBotBackend/dialog/")
if "dp" not in st.session_state:
st.session_state.dp = DP(dst=st.session_state.dst)
if "nlu" not in st.session_state:
st.session_state.nlu = NLU(intent_dict_path='AMUseBotFront/ai_talks/AMUseBotBackend/utils/intent_dict.json',
model_identifier_path='AMUseBotFront/ai_talks/AMUseBotBackend/models/NLU/roberta-base-cookdial.txt')
def show_graph(): def show_graph():
# Create a graphlib graph object # Create a graphlib graph object
@ -66,14 +79,12 @@ def show_graph():
graph.edge(st.session_state.generated[x], st.session_state.past[x+1]) graph.edge(st.session_state.generated[x], st.session_state.past[x+1])
except: except:
pass pass
st.graphviz_chart(graph)
st.graphviz_chart(graph)
def main() -> None: def main() -> None:
c1, c2 = st.columns(2) c1, c2 = st.columns(2)
with c1, c2: with c1, c2:
c1.selectbox(label=st.session_state.locale.select_placeholder1, key="model", options=AI_MODEL_OPTIONS)
st.session_state.input_kind = c2.radio( st.session_state.input_kind = c2.radio(
label=st.session_state.locale.input_kind, label=st.session_state.locale.input_kind,
options=(st.session_state.locale.input_kind_1, st.session_state.locale.input_kind_2), options=(st.session_state.locale.input_kind_1, st.session_state.locale.input_kind_2),
@ -84,19 +95,17 @@ def main() -> None:
options=(st.session_state.locale.radio_text1, st.session_state.locale.radio_text2), options=(st.session_state.locale.radio_text1, st.session_state.locale.radio_text2),
horizontal=True, horizontal=True,
) )
match role_kind: if role_kind == st.session_state.locale.radio_text1:
case st.session_state.locale.radio_text1: c2.selectbox(label=st.session_state.locale.select_placeholder2, key="role",
c2.selectbox(label=st.session_state.locale.select_placeholder2, key="role",
options=st.session_state.locale.ai_role_options) options=st.session_state.locale.ai_role_options)
case st.session_state.locale.radio_text2: elif role_kind == st.session_state.locale.radio_text2:
c2.text_input(label=st.session_state.locale.select_placeholder3, key="role") c2.text_input(label=st.session_state.locale.select_placeholder3, key="role")
if st.session_state.user_text: if st.session_state.user_text:
show_graph() show_graph()
show_conversation() show_conversation()
get_user_input() get_user_input()
show_chat_buttons() show_chat_buttons()

View File

@ -65,10 +65,17 @@ def show_conversation() -> None:
{"role": "user", "content": st.session_state.user_text}, {"role": "user", "content": st.session_state.user_text},
] ]
ai_content = "Dummy respone from AI" # ai_content = "Dummy respone from AI"
intents = st.session_state.nlu.predict(st.session_state.user_text)
st.session_state.dst.update_dialog_history(
system_message='',
user_message=st.session_state.user_text,
intents=intents,
)
system_message = st.session_state.dp.generate_response(intents)
# delete random before deploying with our model # delete random before deploying with our model
random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5)) #random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5))
ai_content += random_str ai_content = system_message
st.session_state.messages.append({"role": "assistant", "content": ai_content}) st.session_state.messages.append({"role": "assistant", "content": ai_content})
if ai_content: if ai_content:
show_chat(ai_content, st.session_state.user_text) show_chat(ai_content, st.session_state.user_text)