integrate chatbot

This commit is contained in:
s444417 2023-05-31 21:11:09 +02:00
parent 0d3b7d0a31
commit 8a3669e2e6
3 changed files with 69 additions and 52 deletions

@ -0,0 +1 @@
Subproject commit 01d547dc221b2bd81ccbe24da37b792e9b176b37

View File

@ -1,56 +1,69 @@
from pathlib import Path
from random import randrange
from AMUseBotBackend.src.DP.dp import DP
from AMUseBotBackend.src.DST.dst import DST
from AMUseBotBackend.src.NLU.nlu import NLU
import graphviz
import streamlit as st
from PIL import Image
from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation
from src.utils.lang import en
# --- PATH SETTINGS ---
current_dir: Path = Path(__file__).parent if "__file__" in locals() else Path.cwd()
css_file: Path = current_dir / "src/styles/.css"
assets_dir: Path = current_dir / "assets"
icons_dir: Path = assets_dir / "icons"
img_dir: Path = assets_dir / "img"
tg_svg: Path = icons_dir / "tg.svg"
favicon: Path = icons_dir / "favicons/0.png"
# --- GENERAL SETTINGS ---
LANG_PL: str = "Pl"
AI_MODEL_OPTIONS: list[str] = [
if __name__ == '__main__':
# --- PATH SETTINGS ---
current_dir: Path = Path(__file__).parent if "__file__" in locals() else Path.cwd()
css_file: Path = current_dir / "src/styles/.css"
assets_dir: Path = current_dir / "assets"
icons_dir: Path = assets_dir / "icons"
img_dir: Path = assets_dir / "img"
tg_svg: Path = icons_dir / "tg.svg"
favicon: Path = icons_dir / "favicons/0.png"
# --- GENERAL SETTINGS ---
LANG_PL: str = "Pl"
AI_MODEL_OPTIONS: list[str] = [
"gpt-3.5-turbo",
"gpt-4",
"gpt-4-32k",
]
]
CONFIG = {"page_title": "AMUsebot", "page_icon": Image.open(favicon)}
CONFIG = {"page_title": "AMUsebot", "page_icon": Image.open(favicon)}
st.set_page_config(**CONFIG)
st.set_page_config(**CONFIG)
# --- LOAD CSS ---
with open(css_file) as f:
# --- LOAD CSS ---
with open(css_file) as f:
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
# Storing The Context
if "locale" not in st.session_state:
# Storing The Context
if "locale" not in st.session_state:
st.session_state.locale = en
if "generated" not in st.session_state:
if "generated" not in st.session_state:
st.session_state.generated = []
if "past" not in st.session_state:
if "past" not in st.session_state:
st.session_state.past = []
if "messages" not in st.session_state:
if "messages" not in st.session_state:
st.session_state.messages = []
if "user_text" not in st.session_state:
if "user_text" not in st.session_state:
st.session_state.user_text = ""
if "input_kind" not in st.session_state:
if "input_kind" not in st.session_state:
st.session_state.input_kind = st.session_state.locale.input_kind_1
if "seed" not in st.session_state:
if "seed" not in st.session_state:
st.session_state.seed = randrange(10 ** 3) # noqa: S311
if "costs" not in st.session_state:
if "costs" not in st.session_state:
st.session_state.costs = []
if "total_tokens" not in st.session_state:
if "total_tokens" not in st.session_state:
st.session_state.total_tokens = []
if "dst" not in st.session_state:
st.session_state.dst = DST(recipe_path="AMUseBotFront/ai_talks/AMUseBotBackend/recipe/", dialog_path="AMUseBotFront/ai_talks/AMUseBotBackend/dialog/")
if "dp" not in st.session_state:
st.session_state.dp = DP(dst=st.session_state.dst)
if "nlu" not in st.session_state:
st.session_state.nlu = NLU(intent_dict_path='AMUseBotFront/ai_talks/AMUseBotBackend/utils/intent_dict.json',
model_identifier_path='AMUseBotFront/ai_talks/AMUseBotBackend/models/NLU/roberta-base-cookdial.txt')
def show_graph():
# Create a graphlib graph object
@ -66,14 +79,12 @@ def show_graph():
graph.edge(st.session_state.generated[x], st.session_state.past[x+1])
except:
pass
st.graphviz_chart(graph)
def main() -> None:
c1, c2 = st.columns(2)
with c1, c2:
c1.selectbox(label=st.session_state.locale.select_placeholder1, key="model", options=AI_MODEL_OPTIONS)
st.session_state.input_kind = c2.radio(
label=st.session_state.locale.input_kind,
options=(st.session_state.locale.input_kind_1, st.session_state.locale.input_kind_2),
@ -84,11 +95,10 @@ def main() -> None:
options=(st.session_state.locale.radio_text1, st.session_state.locale.radio_text2),
horizontal=True,
)
match role_kind:
case st.session_state.locale.radio_text1:
if role_kind == st.session_state.locale.radio_text1:
c2.selectbox(label=st.session_state.locale.select_placeholder2, key="role",
options=st.session_state.locale.ai_role_options)
case st.session_state.locale.radio_text2:
elif role_kind == st.session_state.locale.radio_text2:
c2.text_input(label=st.session_state.locale.select_placeholder3, key="role")
if st.session_state.user_text:
@ -96,7 +106,6 @@ def main() -> None:
show_conversation()
get_user_input()
show_chat_buttons()

View File

@ -65,10 +65,17 @@ def show_conversation() -> None:
{"role": "user", "content": st.session_state.user_text},
]
ai_content = "Dummy respone from AI"
# ai_content = "Dummy respone from AI"
intents = st.session_state.nlu.predict(st.session_state.user_text)
st.session_state.dst.update_dialog_history(
system_message='',
user_message=st.session_state.user_text,
intents=intents,
)
system_message = st.session_state.dp.generate_response(intents)
# delete random before deploying with our model
random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5))
ai_content += random_str
#random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5))
ai_content = system_message
st.session_state.messages.append({"role": "assistant", "content": ai_content})
if ai_content:
show_chat(ai_content, st.session_state.user_text)