Compare commits

...

12 Commits

Author SHA1 Message Date
Kacper E. Dudzic
4ff281954f
Update README.md 2023-06-29 13:26:39 +02:00
Kacper
04eac5ac2f add elements from the presentation version 2023-06-29 13:16:17 +02:00
s444417
8a1677d02a fix requirements 2023-06-29 12:45:32 +02:00
s444417
cadd564387 fix instruction bug 2023-06-28 23:28:38 +02:00
s444417
6fa7ff4820 add instruction 2023-06-28 23:00:54 +02:00
s444417
0c8e63d488 move files to root 2023-06-28 22:55:54 +02:00
6a8f83f2b7 unnecessary tab in graph generation deleted 2023-06-18 09:30:16 +02:00
0cb506fe38 select role from frontend 2023-06-17 17:54:27 +02:00
s444417
21710fccd2 change character format 2023-06-17 16:55:23 +02:00
s444417
04868e022d move api to state 2023-06-17 10:44:39 +02:00
“Kacper
44da2b26e2 add char dict 2023-06-16 21:56:35 +02:00
“Kacper
bd96acf7ea add generative component to dst/dp 2023-06-16 21:49:27 +02:00
536 changed files with 202 additions and 96 deletions

View File

@ -1,5 +1,7 @@
RECIPE_PATH=AMUseBotFront/ai_talks/AMUseBotBackend/recipe/ RECIPE_PATH=recipe/
DIALOG_PATH=AMUseBotFront/ai_talks/AMUseBotBackend/dialog/ DIALOG_PATH=dialog/
INTENT_DICT_PATH=ai_talks/AMUseBotBackend/utils/intent_dict.json INTENT_DICT_PATH=intent_dict.json
MODEL_IDENTIFIER_PATH=ai_talks/AMUseBotBackend/models/NLU/roberta-base-cookdial.txt MODEL_IDENTIFIER_PATH=roberta-base-cookdial-v1_1.txt
INGREDIENTS_RECIPES_MERGED= INGREDIENTS_RECIPES_MERGED=ingredients_recipes_merged.csv
CHARACTERS_DICT=characters_dict.json
API_KEY=

36
README.md Normal file
View File

@ -0,0 +1,36 @@
# Cooking taskbot project
## Run system
#### With Conda
conda create -n "my_env" python=3.9.12 ipython
conda activate my_env
pip install -r requirements.txt
streamlit run ai_talks\chat.py
After running system, model saves in dir:
Linux
~/.cache/huggingface/transformers
Windows
C:\Users\username\.cache\huggingface\transformers
To use the purely experimental generative features, for now, an OpenAI API key is needed. Insert it into the following file:
AMUseBot/.env_template
## Requirements
Python 3.9.12
## Dataset
[YiweiJiang2015/CookDial](https://github.com/YiweiJiang2015/CookDial)
## NLU model HF repo
[kedudzic/roberta-base-cookdial](https://huggingface.co/AMUseBot/roberta-base-cookdial-v1_1)

View File

@ -4,14 +4,16 @@ from AMUseBotBackend.src.NLG.nlg import NLG
from AMUseBotBackend.src.tools.search import search_recipe from AMUseBotBackend.src.tools.search import search_recipe
import AMUseBotBackend.consts as c import AMUseBotBackend.consts as c
import json
import streamlit as st
class DP: class DP:
def __init__(self, dst: DST): def __init__(self, dst: DST, llm_rephrasing=True, character='default'): #TODO: a way to set llm_rephrasing status and a character
self.dst_module = dst self.dst_module = dst
self.llm_rephrasing = llm_rephrasing
self.character = character
def generate_response(self, intents: List[str]) -> str: def generate_response(self, intents: List[str]) -> str:
@ -31,8 +33,13 @@ class DP:
if found_recipe: if found_recipe:
recipe_name = self.dst_module.set_recipe(found_recipe) recipe_name = self.dst_module.set_recipe(found_recipe)
self.dst_module.set_next_step() self.dst_module.set_next_step()
return NLG.MESSAGE_CHOOSEN_RECIPE(recipe_name=recipe_name) + "\n" \ if self.llm_rephrasing:
+ self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)] return NLG.MESSAGE_CHOOSEN_RECIPE(recipe_name=recipe_name) + "\n" \
+ NLG.llm_rephrase_recipe(self.character, self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)])
else:
return NLG.MESSAGE_CHOOSEN_RECIPE(recipe_name=recipe_name) + "\n" \
+ self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)]
if not found_recipe: if not found_recipe:
return NLG.MESSAGE_NOT_UNDERSTAND_SUGGEST_RECIPE(self.dst_module.get_random_recipes(3)) return NLG.MESSAGE_NOT_UNDERSTAND_SUGGEST_RECIPE(self.dst_module.get_random_recipes(3))
# not understand ask recipe # not understand ask recipe
@ -41,6 +48,8 @@ class DP:
# Recipe choosen # Recipe choosen
if (None != self.dst_module.generate_state(c.RECIPE_ID_KEY) and "" != self.dst_module.generate_state( if (None != self.dst_module.generate_state(c.RECIPE_ID_KEY) and "" != self.dst_module.generate_state(
c.RECIPE_ID_KEY)): c.RECIPE_ID_KEY)):
if ("req_substitute" in intents):
return NLG.llm_substitute_product(self.character, self.dst_module.generate_state(c.DIALOG_HISTORY_KEY)[-1][c.USER_MESSAGE_KEY])
if ("req_ingredient_list" in intents if ("req_ingredient_list" in intents
or "req_ingredient" in intents): or "req_ingredient" in intents):
return NLG.MESSAGE_INGREDIENTS(self.dst_module.generate_state(c.INGREDIENTS_KEY)) return NLG.MESSAGE_INGREDIENTS(self.dst_module.generate_state(c.INGREDIENTS_KEY))
@ -51,7 +60,10 @@ class DP:
or "req_instruction" in intents): or "req_instruction" in intents):
next_step = self.dst_module.set_next_step() next_step = self.dst_module.set_next_step()
if (next_step): if (next_step):
return self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)] if self.llm_rephrasing:
return NLG.llm_rephrase_recipe(self.character, self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)])
else:
return self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)]
if (not next_step): if (not next_step):
self.dst_module.restart() self.dst_module.restart()
return NLG.RECIPE_OVER_ANSWER return NLG.RECIPE_OVER_ANSWER

View File

@ -79,18 +79,20 @@ class DST:
def __set_steps(self): def __set_steps(self):
dialog_files = [] dialog_files = []
steps = {} steps = {}
for (_, _, filenames) in walk(self.__dialog_path): for (_, _, filenames) in walk(self.__recipe_path):
dialog_files.extend(filenames) dialog_files.extend(filenames)
break break
for dialog_title in dialog_files: for dialog_title in dialog_files:
if dialog_title.startswith(f"{self.__recipe_id:03d}"): if dialog_title.startswith(f"{self.__recipe_id:03d}"):
with open(self.__dialog_path + "/" + dialog_title) as f: with open(self.__recipe_path + dialog_title) as f:
data = json.load(f) data = json.load(f)
for message in data["messages"]: for row in data['content']:
if "inform_instruction" in message["annotations"]: if row['type']=='instruction':
steps[len(steps)] = message["utterance"] steps[len(steps)] = row['text'].split(maxsplit=1)[1]
self.__steps = steps self.__steps = steps
def __set_ingredients(self): def __set_ingredients(self):
dialog_files = [] dialog_files = []
ingredients = [] ingredients = []

View File

@ -1,3 +1,4 @@
import streamlit as st
class NLG: class NLG:
MESSAGE_PROMPT = "Hello! I'm AMUseBot, a virtual cooking assistant. Please tell me the name of the dish that you'd like to prepare today." MESSAGE_PROMPT = "Hello! I'm AMUseBot, a virtual cooking assistant. Please tell me the name of the dish that you'd like to prepare today."
MESSAGE_HI = "Hi! What do you want to make today?" MESSAGE_HI = "Hi! What do you want to make today?"
@ -5,6 +6,7 @@ class NLG:
BYE_ANSWER = "Bye, hope to see you soon!" BYE_ANSWER = "Bye, hope to see you soon!"
RECIPE_OVER_ANSWER = "Congratulations! You finished preparing the dish, bon appetit!" RECIPE_OVER_ANSWER = "Congratulations! You finished preparing the dish, bon appetit!"
NOT_UNDERSTAND_ANSWER = "I'm sorry, I don't understand. Could you rephrase?" NOT_UNDERSTAND_ANSWER = "I'm sorry, I don't understand. Could you rephrase?"
CANNOT_HELP_ANSWER = "I'm sorry I can't help you with that."
@staticmethod @staticmethod
def MESSAGE_INGREDIENTS(ingr_list): def MESSAGE_INGREDIENTS(ingr_list):
@ -24,3 +26,38 @@ class NLG:
suggestions = ", ".join(recipes_list[0:-1]) + f" or {recipes_list[-1]}" suggestions = ", ".join(recipes_list[0:-1]) + f" or {recipes_list[-1]}"
return f"I'm sorry, I don't know a recipe like that. Instead, I can suggest you {suggestions}." return f"I'm sorry, I don't know a recipe like that. Instead, I can suggest you {suggestions}."
def llm_create_response(character, input):
model = st.session_state.characters_dict['model']
prompt = st.session_state.characters_dict['characters'][character]['prompt']
message = [{'role': 'system', 'content': prompt}, {'role': 'user', 'content': input}]
response = st.session_state.openai.ChatCompletion.create(
model=model, messages=message, temperature=1, max_tokens=128
)
rephrased_response = response.choices[0].message.content
return rephrased_response
def llm_rephrase_recipe(character, response):
input = st.session_state.characters_dict['task_paraphrase'] + f'"{response}".' + st.session_state.characters_dict['characters'][character]['task_specification']
try:
return NLG.llm_create_response(character, input)
except:
print('OpenAI API call failed during response paraphrasing! Returning input response')
return response
def llm_substitute_product(character, user_message):
input = st.session_state.characters_dict['task_substitute'] + f'"{user_message}".' + st.session_state.characters_dict['characters'][character]['task_specification']
try:
return NLG.llm_create_response(character, input)
except:
print('OpenAI API call failed during response paraphrasing! Returning input response')
return NLG.CANNOT_HELP_ANSWER

View File

@ -7,7 +7,7 @@ from rank_bm25 import BM25Okapi
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv('.env_template')
INGREDIENTS_RECIPES_MERGED = os.getenv('INGREDIENTS_RECIPES_MERGED') INGREDIENTS_RECIPES_MERGED = os.getenv('INGREDIENTS_RECIPES_MERGED')

View File

@ -9,7 +9,12 @@ import streamlit as st
from PIL import Image from PIL import Image
from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation
from src.utils.lang import en from src.utils.lang import en
import openai
import copy
import json
import string
import streamlit.components.v1 as components
import re
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
@ -25,11 +30,6 @@ if __name__ == '__main__':
favicon: Path = icons_dir / "favicons/0.png" favicon: Path = icons_dir / "favicons/0.png"
# --- GENERAL SETTINGS --- # --- GENERAL SETTINGS ---
LANG_PL: str = "Pl" LANG_PL: str = "Pl"
AI_MODEL_OPTIONS: list[str] = [
"gpt-3.5-turbo",
"gpt-4",
"gpt-4-32k",
]
CONFIG = {"page_title": "AMUsebot", "page_icon": Image.open(favicon)} CONFIG = {"page_title": "AMUsebot", "page_icon": Image.open(favicon)}
@ -39,10 +39,12 @@ if __name__ == '__main__':
with open(css_file) as f: with open(css_file) as f:
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True) st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
load_dotenv() load_dotenv('.env_template')
DIALOG_PATH = os.getenv('DIALOG_PATH') DIALOG_PATH = os.getenv('DIALOG_PATH')
RECIPE_PATH = os.getenv('RECIPE_PATH') RECIPE_PATH = os.getenv('RECIPE_PATH')
CHARACTERS_DICT = os.getenv('CHARACTERS_DICT')
API_KEY = os.getenv('API_KEY')
# Storing The Context # Storing The Context
if "locale" not in st.session_state: if "locale" not in st.session_state:
@ -55,8 +57,6 @@ if __name__ == '__main__':
st.session_state.messages = [] st.session_state.messages = []
if "user_text" not in st.session_state: if "user_text" not in st.session_state:
st.session_state.user_text = "" st.session_state.user_text = ""
if "input_kind" not in st.session_state:
st.session_state.input_kind = st.session_state.locale.input_kind_1
if "seed" not in st.session_state: if "seed" not in st.session_state:
st.session_state.seed = randrange(10 ** 3) # noqa: S311 st.session_state.seed = randrange(10 ** 3) # noqa: S311
if "costs" not in st.session_state: if "costs" not in st.session_state:
@ -67,51 +67,69 @@ if __name__ == '__main__':
st.session_state.dst = DST(recipe_path=RECIPE_PATH, dialog_path=DIALOG_PATH) st.session_state.dst = DST(recipe_path=RECIPE_PATH, dialog_path=DIALOG_PATH)
if "dp" not in st.session_state: if "dp" not in st.session_state:
st.session_state.dp = DP(dst=st.session_state.dst) st.session_state.dp = DP(dst=st.session_state.dst)
if "openai" not in st.session_state:
st.session_state.openai = openai
st.session_state.openai.api_key = API_KEY
if "characters_dict" not in st.session_state:
with open(CHARACTERS_DICT) as f:
st.session_state.characters_dict = json.load(f)
def show_graph(): def mermaid(code: str) -> None:
components.html(
f"""
<pre class="mermaid">
%%{{init: {{'themeVariables': {{ 'edgeLabelBackground': 'transparent'}}}}}}%%
flowchart TD;
{code}
linkStyle default fill:white,color:white,stroke-width:2px,background-color:lime;
</pre>
<script type="module">
import mermaid from 'https://cdn.jsdelivr.net/npm/mermaid@10/dist/mermaid.esm.min.mjs';
mermaid.initialize({{ startOnLoad: true }});
</script>
""", height=1000
)
def graph():
# Create a graphlib graph object # Create a graphlib graph object
if st.session_state.generated: if st.session_state.generated:
user, chatbot = [], [] system = [utterance for utterance in st.session_state.generated][-3:]
graph = graphviz.Digraph() user = [utterance for utterance in st.session_state.past][-2:]
for i in range(len(st.session_state.past)): graph = ""
chatbot.append(st.session_state.generated[i]) for i, utterance in enumerate(system):
user.append(st.session_state.past[i]) utterance = utterance.strip('\n')
for x in range(len(user)): utterance = " ".join([word + '<br>' if i % 5 == 0 and i > 0 else word for i, word in enumerate(utterance.split(" "))])
chatbot_text = [word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.generated[x].split(' '))] utterance = utterance.replace('\"', '')
user_text = [word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.past[x].split(' '))] if i < len(user):
graph.edge(' '.join(chatbot_text), ' '.join(user_text)) user[i] = user[i].strip('\n')
try: user[i] = user[i].replace('\"', '')
graph.edge(' '.join(user_text), ' '.join([word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.generated[x + 1].split(' '))])) user[i] = " ".join([word + '<br>' if i % 5 == 0 and i > 0 else word for i, word in enumerate(user[i].split(' '))])
except: graph += f"{string.ascii_uppercase[i]}(\"{utterance}\") --> |{user[i]}| {string.ascii_uppercase[i+1]};"
pass else:
st.graphviz_chart(graph) graph += f"{string.ascii_uppercase[i]}(\"{utterance}\") --> {string.ascii_uppercase[i+1]}(...);style {string.ascii_uppercase[i+1]} fill:none,color:white;"
graph = graph.replace('\n', ' ')#replace(')','').replace('(','')
#print(graph)
return graph
def main() -> None: def main() -> None:
c1, c2 = st.columns(2) c1, c2 = st.columns(2)
with c1, c2: with c1, c2:
st.session_state.input_kind = c2.radio( character_type = c1.selectbox(label=st.session_state.locale.select_placeholder2, key="role",
label=st.session_state.locale.input_kind, options=st.session_state.locale.ai_role_options)
options=(st.session_state.locale.input_kind_1, st.session_state.locale.input_kind_2), st.session_state.dp.character = character_type
horizontal=True, if character_type == 'default':
) st.session_state.dp.llm_rephrasing = False
role_kind = c1.radio( else:
label=st.session_state.locale.radio_placeholder, st.session_state.dp.llm_rephrasing = True
options=(st.session_state.locale.radio_text1, st.session_state.locale.radio_text2),
horizontal=True,
)
if role_kind == st.session_state.locale.radio_text1:
c2.selectbox(label=st.session_state.locale.select_placeholder2, key="role",
options=st.session_state.locale.ai_role_options)
elif role_kind == st.session_state.locale.radio_text2:
c2.text_input(label=st.session_state.locale.select_placeholder3, key="role")
get_user_input() get_user_input()
show_chat_buttons() show_chat_buttons()
show_conversation() show_conversation()
with st.sidebar: with st.sidebar:
show_graph() mermaid(graph())
#show_graph()
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,16 +1,6 @@
AI_ROLE_OPTIONS_EN: list[str] = [ AI_ROLE_OPTIONS_EN: list[str] = [
"helpful assistant", "default",
"code assistant", "helpful_chef",
"code reviewer", "ramsay",
"text improver",
"cinema expert",
"sport expert",
"online games expert",
"food recipes expert",
"English grammar expert",
"friendly and helpful teaching assistant",
"laconic assistant",
"helpful, pattern-following assistant",
"translate corporate jargon into plain English",
] ]

View File

@ -12,7 +12,7 @@ from AMUseBotBackend.src.NLU.nlu import NLU
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv('.env_template')
INTENT_DICT_PATH = os.getenv('INTENT_DICT_PATH') INTENT_DICT_PATH = os.getenv('INTENT_DICT_PATH')
MODEL_IDENTIFIER_PATH = os.getenv('MODEL_IDENTIFIER_PATH') MODEL_IDENTIFIER_PATH = os.getenv('MODEL_IDENTIFIER_PATH')
@ -23,7 +23,8 @@ def get_nlu_model(intent_dict_path = INTENT_DICT_PATH, model_identifier_path = M
model_identifier_path=model_identifier_path) model_identifier_path=model_identifier_path)
def clear_chat() -> None: def clear_chat() -> None:
st.session_state.generated = [] st.session_state.generated = ["Hello! I'm AMUseBot, a virtual cooking assistant. Please tell me the name of the dish that you'd like to prepare today."]
st.session_state.dst.restart()
st.session_state.past = [] st.session_state.past = []
st.session_state.messages = [] st.session_state.messages = []
st.session_state.user_text = "" st.session_state.user_text = ""

View File

@ -20,15 +20,8 @@ class Locale:
chat_clear_btn: str chat_clear_btn: str
chat_save_btn: str chat_save_btn: str
speak_btn: str speak_btn: str
input_kind: str
input_kind_1: str
input_kind_2: str
select_placeholder1: str select_placeholder1: str
select_placeholder2: str select_placeholder2: str
select_placeholder3: str
radio_placeholder: str
radio_text1: str
radio_text2: str
stt_placeholder: str stt_placeholder: str
footer_title: str footer_title: str
footer_option0: str footer_option0: str
@ -55,15 +48,8 @@ en = Locale(
chat_clear_btn="Clear", chat_clear_btn="Clear",
chat_save_btn="Save", chat_save_btn="Save",
speak_btn="Push to Speak", speak_btn="Push to Speak",
input_kind="Input Kind",
input_kind_1="Text",
input_kind_2="Voice [test mode]",
select_placeholder1="Select Model", select_placeholder1="Select Model",
select_placeholder2="Select Role", select_placeholder2="Select Role",
select_placeholder3="Create Role",
radio_placeholder="Role Interaction",
radio_text1="Select",
radio_text2="Create",
stt_placeholder="To Hear The Voice Of AI Press Play", stt_placeholder="To Hear The Voice Of AI Press Play",
footer_title="Support & Feedback", footer_title="Support & Feedback",
footer_option0="Chat", footer_option0="Chat",

21
characters_dict.json Normal file
View File

@ -0,0 +1,21 @@
{
"task_paraphrase": "You're currently reading a step of a recipe, paraphrase it so that it matches your character: ",
"task_substitute": "A user has just asked for a substitute for a missing ingredient, answer him according to your character in one short sentence with at most 3 alternatives: ",
"model": "gpt-3.5-turbo-0613",
"characters": {
"default": {
"prompt": "",
"task_specification": ""
},
"helpful_chef": {
"prompt": "You're a master chef known for treating everyone like your equal. ",
"task_specification": " Give your answer as a natural sounding, full English sentence. Keep the sentence length similar and do not make the language flowery."
},
"ramsay": {
"prompt": "You're Gordon Ramsay, a famous British chef known for his short temper and routinely insulting people. ",
"task_specification": ""
}
}
}

Some files were not shown because too many files have changed in this diff Show More