integrate chatbot
This commit is contained in:
parent
0d3b7d0a31
commit
8a3669e2e6
1
ai_talks/AMUseBotBackend
Submodule
1
ai_talks/AMUseBotBackend
Submodule
@ -0,0 +1 @@
|
|||||||
|
Subproject commit 01d547dc221b2bd81ccbe24da37b792e9b176b37
|
107
ai_talks/chat.py
107
ai_talks/chat.py
@ -1,56 +1,69 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from random import randrange
|
from random import randrange
|
||||||
|
|
||||||
|
from AMUseBotBackend.src.DP.dp import DP
|
||||||
|
from AMUseBotBackend.src.DST.dst import DST
|
||||||
|
from AMUseBotBackend.src.NLU.nlu import NLU
|
||||||
|
|
||||||
import graphviz
|
import graphviz
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation
|
from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation
|
||||||
from src.utils.lang import en
|
from src.utils.lang import en
|
||||||
|
|
||||||
# --- PATH SETTINGS ---
|
|
||||||
current_dir: Path = Path(__file__).parent if "__file__" in locals() else Path.cwd()
|
|
||||||
css_file: Path = current_dir / "src/styles/.css"
|
|
||||||
assets_dir: Path = current_dir / "assets"
|
|
||||||
icons_dir: Path = assets_dir / "icons"
|
|
||||||
img_dir: Path = assets_dir / "img"
|
|
||||||
tg_svg: Path = icons_dir / "tg.svg"
|
|
||||||
favicon: Path = icons_dir / "favicons/0.png"
|
|
||||||
# --- GENERAL SETTINGS ---
|
|
||||||
LANG_PL: str = "Pl"
|
|
||||||
AI_MODEL_OPTIONS: list[str] = [
|
|
||||||
"gpt-3.5-turbo",
|
|
||||||
"gpt-4",
|
|
||||||
"gpt-4-32k",
|
|
||||||
]
|
|
||||||
|
|
||||||
CONFIG = {"page_title": "AMUsebot", "page_icon": Image.open(favicon)}
|
if __name__ == '__main__':
|
||||||
|
|
||||||
st.set_page_config(**CONFIG)
|
# --- PATH SETTINGS ---
|
||||||
|
current_dir: Path = Path(__file__).parent if "__file__" in locals() else Path.cwd()
|
||||||
|
css_file: Path = current_dir / "src/styles/.css"
|
||||||
|
assets_dir: Path = current_dir / "assets"
|
||||||
|
icons_dir: Path = assets_dir / "icons"
|
||||||
|
img_dir: Path = assets_dir / "img"
|
||||||
|
tg_svg: Path = icons_dir / "tg.svg"
|
||||||
|
favicon: Path = icons_dir / "favicons/0.png"
|
||||||
|
# --- GENERAL SETTINGS ---
|
||||||
|
LANG_PL: str = "Pl"
|
||||||
|
AI_MODEL_OPTIONS: list[str] = [
|
||||||
|
"gpt-3.5-turbo",
|
||||||
|
"gpt-4",
|
||||||
|
"gpt-4-32k",
|
||||||
|
]
|
||||||
|
|
||||||
# --- LOAD CSS ---
|
CONFIG = {"page_title": "AMUsebot", "page_icon": Image.open(favicon)}
|
||||||
with open(css_file) as f:
|
|
||||||
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
|
|
||||||
|
|
||||||
# Storing The Context
|
st.set_page_config(**CONFIG)
|
||||||
if "locale" not in st.session_state:
|
|
||||||
st.session_state.locale = en
|
|
||||||
if "generated" not in st.session_state:
|
|
||||||
st.session_state.generated = []
|
|
||||||
if "past" not in st.session_state:
|
|
||||||
st.session_state.past = []
|
|
||||||
if "messages" not in st.session_state:
|
|
||||||
st.session_state.messages = []
|
|
||||||
if "user_text" not in st.session_state:
|
|
||||||
st.session_state.user_text = ""
|
|
||||||
if "input_kind" not in st.session_state:
|
|
||||||
st.session_state.input_kind = st.session_state.locale.input_kind_1
|
|
||||||
if "seed" not in st.session_state:
|
|
||||||
st.session_state.seed = randrange(10 ** 3) # noqa: S311
|
|
||||||
if "costs" not in st.session_state:
|
|
||||||
st.session_state.costs = []
|
|
||||||
if "total_tokens" not in st.session_state:
|
|
||||||
st.session_state.total_tokens = []
|
|
||||||
|
|
||||||
|
# --- LOAD CSS ---
|
||||||
|
with open(css_file) as f:
|
||||||
|
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
|
||||||
|
|
||||||
|
# Storing The Context
|
||||||
|
if "locale" not in st.session_state:
|
||||||
|
st.session_state.locale = en
|
||||||
|
if "generated" not in st.session_state:
|
||||||
|
st.session_state.generated = []
|
||||||
|
if "past" not in st.session_state:
|
||||||
|
st.session_state.past = []
|
||||||
|
if "messages" not in st.session_state:
|
||||||
|
st.session_state.messages = []
|
||||||
|
if "user_text" not in st.session_state:
|
||||||
|
st.session_state.user_text = ""
|
||||||
|
if "input_kind" not in st.session_state:
|
||||||
|
st.session_state.input_kind = st.session_state.locale.input_kind_1
|
||||||
|
if "seed" not in st.session_state:
|
||||||
|
st.session_state.seed = randrange(10 ** 3) # noqa: S311
|
||||||
|
if "costs" not in st.session_state:
|
||||||
|
st.session_state.costs = []
|
||||||
|
if "total_tokens" not in st.session_state:
|
||||||
|
st.session_state.total_tokens = []
|
||||||
|
if "dst" not in st.session_state:
|
||||||
|
st.session_state.dst = DST(recipe_path="AMUseBotFront/ai_talks/AMUseBotBackend/recipe/", dialog_path="AMUseBotFront/ai_talks/AMUseBotBackend/dialog/")
|
||||||
|
if "dp" not in st.session_state:
|
||||||
|
st.session_state.dp = DP(dst=st.session_state.dst)
|
||||||
|
if "nlu" not in st.session_state:
|
||||||
|
st.session_state.nlu = NLU(intent_dict_path='AMUseBotFront/ai_talks/AMUseBotBackend/utils/intent_dict.json',
|
||||||
|
model_identifier_path='AMUseBotFront/ai_talks/AMUseBotBackend/models/NLU/roberta-base-cookdial.txt')
|
||||||
|
|
||||||
def show_graph():
|
def show_graph():
|
||||||
# Create a graphlib graph object
|
# Create a graphlib graph object
|
||||||
@ -66,14 +79,12 @@ def show_graph():
|
|||||||
graph.edge(st.session_state.generated[x], st.session_state.past[x+1])
|
graph.edge(st.session_state.generated[x], st.session_state.past[x+1])
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
st.graphviz_chart(graph)
|
||||||
st.graphviz_chart(graph)
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
c1, c2 = st.columns(2)
|
c1, c2 = st.columns(2)
|
||||||
with c1, c2:
|
with c1, c2:
|
||||||
c1.selectbox(label=st.session_state.locale.select_placeholder1, key="model", options=AI_MODEL_OPTIONS)
|
|
||||||
st.session_state.input_kind = c2.radio(
|
st.session_state.input_kind = c2.radio(
|
||||||
label=st.session_state.locale.input_kind,
|
label=st.session_state.locale.input_kind,
|
||||||
options=(st.session_state.locale.input_kind_1, st.session_state.locale.input_kind_2),
|
options=(st.session_state.locale.input_kind_1, st.session_state.locale.input_kind_2),
|
||||||
@ -84,19 +95,17 @@ def main() -> None:
|
|||||||
options=(st.session_state.locale.radio_text1, st.session_state.locale.radio_text2),
|
options=(st.session_state.locale.radio_text1, st.session_state.locale.radio_text2),
|
||||||
horizontal=True,
|
horizontal=True,
|
||||||
)
|
)
|
||||||
match role_kind:
|
if role_kind == st.session_state.locale.radio_text1:
|
||||||
case st.session_state.locale.radio_text1:
|
c2.selectbox(label=st.session_state.locale.select_placeholder2, key="role",
|
||||||
c2.selectbox(label=st.session_state.locale.select_placeholder2, key="role",
|
|
||||||
options=st.session_state.locale.ai_role_options)
|
options=st.session_state.locale.ai_role_options)
|
||||||
case st.session_state.locale.radio_text2:
|
elif role_kind == st.session_state.locale.radio_text2:
|
||||||
c2.text_input(label=st.session_state.locale.select_placeholder3, key="role")
|
c2.text_input(label=st.session_state.locale.select_placeholder3, key="role")
|
||||||
|
|
||||||
if st.session_state.user_text:
|
if st.session_state.user_text:
|
||||||
show_graph()
|
show_graph()
|
||||||
show_conversation()
|
show_conversation()
|
||||||
|
|
||||||
get_user_input()
|
get_user_input()
|
||||||
|
|
||||||
show_chat_buttons()
|
show_chat_buttons()
|
||||||
|
|
||||||
|
|
||||||
|
@ -65,10 +65,17 @@ def show_conversation() -> None:
|
|||||||
{"role": "user", "content": st.session_state.user_text},
|
{"role": "user", "content": st.session_state.user_text},
|
||||||
]
|
]
|
||||||
|
|
||||||
ai_content = "Dummy respone from AI"
|
# ai_content = "Dummy respone from AI"
|
||||||
|
intents = st.session_state.nlu.predict(st.session_state.user_text)
|
||||||
|
st.session_state.dst.update_dialog_history(
|
||||||
|
system_message='',
|
||||||
|
user_message=st.session_state.user_text,
|
||||||
|
intents=intents,
|
||||||
|
)
|
||||||
|
system_message = st.session_state.dp.generate_response(intents)
|
||||||
# delete random before deploying with our model
|
# delete random before deploying with our model
|
||||||
random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5))
|
#random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5))
|
||||||
ai_content += random_str
|
ai_content = system_message
|
||||||
st.session_state.messages.append({"role": "assistant", "content": ai_content})
|
st.session_state.messages.append({"role": "assistant", "content": ai_content})
|
||||||
if ai_content:
|
if ai_content:
|
||||||
show_chat(ai_content, st.session_state.user_text)
|
show_chat(ai_content, st.session_state.user_text)
|
||||||
|
Loading…
Reference in New Issue
Block a user