change character format
This commit is contained in:
parent
04868e022d
commit
21710fccd2
@ -3,4 +3,5 @@ DIALOG_PATH=AMUseBotFront/ai_talks/AMUseBotBackend/dialog/
|
|||||||
INTENT_DICT_PATH=ai_talks/AMUseBotBackend/utils/intent_dict.json
|
INTENT_DICT_PATH=ai_talks/AMUseBotBackend/utils/intent_dict.json
|
||||||
MODEL_IDENTIFIER_PATH=ai_talks/AMUseBotBackend/models/NLU/roberta-base-cookdial.txt
|
MODEL_IDENTIFIER_PATH=ai_talks/AMUseBotBackend/models/NLU/roberta-base-cookdial.txt
|
||||||
INGREDIENTS_RECIPES_MERGED=
|
INGREDIENTS_RECIPES_MERGED=
|
||||||
|
CHARACTERS_DICT=
|
||||||
API_KEY=
|
API_KEY=
|
@ -11,33 +11,10 @@ import streamlit as st
|
|||||||
|
|
||||||
class DP:
|
class DP:
|
||||||
|
|
||||||
def __init__(self, dst: DST, llm_rephrasing=True, character='ramsay'): #TODO: a way to set llm_rephrasing status and a character
|
def __init__(self, dst: DST, llm_rephrasing=True, character='default'): #TODO: a way to set llm_rephrasing status and a character
|
||||||
self.dst_module = dst
|
self.dst_module = dst
|
||||||
self.llm_rephrasing = llm_rephrasing
|
self.llm_rephrasing = llm_rephrasing
|
||||||
with open('ai_talks/AMUseBotBackend/utils/characters_dict.json') as f:
|
self.character = character
|
||||||
characters_dict = json.load(f)
|
|
||||||
self.character = characters_dict[character]
|
|
||||||
|
|
||||||
|
|
||||||
def llm_rephrase(self, character, response):
|
|
||||||
model = character['model']
|
|
||||||
prompt = character['prompt']
|
|
||||||
input = character['leftside_input'] + response + character['rightside_input']
|
|
||||||
|
|
||||||
message = [{'role': 'system', 'content': prompt}, {'role': 'user', 'content': input}]
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = st.session_state.openai.ChatCompletion.create(
|
|
||||||
model=model, messages=message, temperature=1, max_tokens=128
|
|
||||||
)
|
|
||||||
rephrased_response = response.choices[0].message.content
|
|
||||||
except:
|
|
||||||
print('OpenAI API call failed during response paraphrasing! Returning input response')
|
|
||||||
rephrased_response = response
|
|
||||||
|
|
||||||
return rephrased_response
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def generate_response(self, intents: List[str]) -> str:
|
def generate_response(self, intents: List[str]) -> str:
|
||||||
|
|
||||||
@ -59,8 +36,7 @@ class DP:
|
|||||||
self.dst_module.set_next_step()
|
self.dst_module.set_next_step()
|
||||||
if self.llm_rephrasing:
|
if self.llm_rephrasing:
|
||||||
return NLG.MESSAGE_CHOOSEN_RECIPE(recipe_name=recipe_name) + "\n" \
|
return NLG.MESSAGE_CHOOSEN_RECIPE(recipe_name=recipe_name) + "\n" \
|
||||||
+ self.llm_rephrase(self.character, self.dst_module.generate_state(c.STEPS_KEY)[
|
+ NLG.llm_rephrase_recipe(self.character, self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)])
|
||||||
self.dst_module.generate_state(c.CURR_STEP_KEY)])
|
|
||||||
else:
|
else:
|
||||||
return NLG.MESSAGE_CHOOSEN_RECIPE(recipe_name=recipe_name) + "\n" \
|
return NLG.MESSAGE_CHOOSEN_RECIPE(recipe_name=recipe_name) + "\n" \
|
||||||
+ self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)]
|
+ self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)]
|
||||||
@ -73,6 +49,8 @@ class DP:
|
|||||||
# Recipe choosen
|
# Recipe choosen
|
||||||
if (None != self.dst_module.generate_state(c.RECIPE_ID_KEY) and "" != self.dst_module.generate_state(
|
if (None != self.dst_module.generate_state(c.RECIPE_ID_KEY) and "" != self.dst_module.generate_state(
|
||||||
c.RECIPE_ID_KEY)):
|
c.RECIPE_ID_KEY)):
|
||||||
|
if ("req_substitute" in intents):
|
||||||
|
return NLG.llm_substitute_product(self.character, self.dst_module.generate_state(c.DIALOG_HISTORY_KEY)[-1][c.USER_MESSAGE_KEY])
|
||||||
if ("req_ingredient_list" in intents
|
if ("req_ingredient_list" in intents
|
||||||
or "req_ingredient" in intents):
|
or "req_ingredient" in intents):
|
||||||
return NLG.MESSAGE_INGREDIENTS(self.dst_module.generate_state(c.INGREDIENTS_KEY))
|
return NLG.MESSAGE_INGREDIENTS(self.dst_module.generate_state(c.INGREDIENTS_KEY))
|
||||||
@ -84,7 +62,7 @@ class DP:
|
|||||||
next_step = self.dst_module.set_next_step()
|
next_step = self.dst_module.set_next_step()
|
||||||
if (next_step):
|
if (next_step):
|
||||||
if self.llm_rephrasing:
|
if self.llm_rephrasing:
|
||||||
return self.llm_rephrase(self.character, self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)])
|
return NLG.llm_rephrase_recipe(self.character, self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)])
|
||||||
else:
|
else:
|
||||||
return self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)]
|
return self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)]
|
||||||
if (not next_step):
|
if (not next_step):
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import streamlit as st
|
||||||
|
|
||||||
class NLG:
|
class NLG:
|
||||||
MESSAGE_PROMPT = "Hello! I'm AMUseBot, a virtual cooking assistant. Please tell me the name of the dish that you'd like to prepare today."
|
MESSAGE_PROMPT = "Hello! I'm AMUseBot, a virtual cooking assistant. Please tell me the name of the dish that you'd like to prepare today."
|
||||||
MESSAGE_HI = "Hi! What do you want to make today?"
|
MESSAGE_HI = "Hi! What do you want to make today?"
|
||||||
@ -5,6 +7,7 @@ class NLG:
|
|||||||
BYE_ANSWER = "Bye, hope to see you soon!"
|
BYE_ANSWER = "Bye, hope to see you soon!"
|
||||||
RECIPE_OVER_ANSWER = "Congratulations! You finished preparing the dish, bon appetit!"
|
RECIPE_OVER_ANSWER = "Congratulations! You finished preparing the dish, bon appetit!"
|
||||||
NOT_UNDERSTAND_ANSWER = "I'm sorry, I don't understand. Could you rephrase?"
|
NOT_UNDERSTAND_ANSWER = "I'm sorry, I don't understand. Could you rephrase?"
|
||||||
|
CANNOT_HELP_ANSWER = "I'm sorry I can't help you with that."
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def MESSAGE_INGREDIENTS(ingr_list):
|
def MESSAGE_INGREDIENTS(ingr_list):
|
||||||
@ -24,3 +27,38 @@ class NLG:
|
|||||||
suggestions = ", ".join(recipes_list[0:-1]) + f" or {recipes_list[-1]}"
|
suggestions = ", ".join(recipes_list[0:-1]) + f" or {recipes_list[-1]}"
|
||||||
|
|
||||||
return f"I'm sorry, I don't know a recipe like that. Instead, I can suggest you {suggestions}."
|
return f"I'm sorry, I don't know a recipe like that. Instead, I can suggest you {suggestions}."
|
||||||
|
|
||||||
|
|
||||||
|
def llm_create_response(character, input):
|
||||||
|
model = st.session_state.characters_dict['model']
|
||||||
|
prompt = st.session_state.characters_dict['characters'][character]['prompt']
|
||||||
|
|
||||||
|
message = [{'role': 'system', 'content': prompt}, {'role': 'user', 'content': input}]
|
||||||
|
|
||||||
|
response = st.session_state.openai.ChatCompletion.create(
|
||||||
|
model=model, messages=message, temperature=1, max_tokens=128
|
||||||
|
)
|
||||||
|
rephrased_response = response.choices[0].message.content
|
||||||
|
|
||||||
|
return rephrased_response
|
||||||
|
|
||||||
|
def llm_rephrase_recipe(character, response):
|
||||||
|
|
||||||
|
input = st.session_state.characters_dict['task_paraphrase'] + f'"{response}".' + st.session_state.characters_dict['characters'][character]['task_specification']
|
||||||
|
try:
|
||||||
|
return NLG.llm_create_response(character, input)
|
||||||
|
except:
|
||||||
|
print('OpenAI API call failed during response paraphrasing! Returning input response')
|
||||||
|
return response
|
||||||
|
|
||||||
|
def llm_substitute_product(character, user_message):
|
||||||
|
|
||||||
|
input = st.session_state.characters_dict['task_substitute'] + f'"{user_message}".'
|
||||||
|
|
||||||
|
try:
|
||||||
|
return NLG.llm_create_response(character, input)
|
||||||
|
except:
|
||||||
|
print('OpenAI API call failed during response paraphrasing! Returning input response')
|
||||||
|
return NLG.CANNOT_HELP_ANSWER
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,11 +1,16 @@
|
|||||||
{
|
{
|
||||||
"default": {"model": "gpt-3.5-turbo-0613", "prompt":
|
"task_paraphrase": "You're currently reading a step of a recipe, paraphrese it so that it matches your charater: ",
|
||||||
"You're a master chef known for treating everyone like your equal. You're currently reading steps of a recipe to your apprentice.",
|
"task_substitute": "A user has just asked for a substitute for a missing ingredient, answer him according to your character: ",
|
||||||
"leftside_input": "Rephrase this step of a recipe to make it sound more like a natural, full English sentence: '",
|
"model": "gpt-3.5-turbo-0613",
|
||||||
"rightside_input": "'."},
|
"characters": {
|
||||||
|
"default": {
|
||||||
|
"prompt": "You're a master chef known for treating everyone like your equal. ",
|
||||||
|
"task_specification": " Give your answer as a natural sounding, full English sentence."
|
||||||
|
|
||||||
"ramsay": {"model": "gpt-3.5-turbo-0613", "prompt":
|
},
|
||||||
"You're Gordon Ramsay, a famous British chef known for his short temper and routinely insulting people. You're currently reading steps of a recipe to your apprentice.",
|
"ramsay": {
|
||||||
"leftside_input": "Rephrase this step of a recipe to make it sound as if you said it, in your characteristic rude fashion: '",
|
"prompt": "You're Gordon Ramsay, a famous British chef known for his short temper and routinely insulting people. ",
|
||||||
"rightside_input": "'."}
|
"task_specification": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
@ -10,6 +10,8 @@ from PIL import Image
|
|||||||
from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation
|
from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation
|
||||||
from src.utils.lang import en
|
from src.utils.lang import en
|
||||||
import openai
|
import openai
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
@ -44,6 +46,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
DIALOG_PATH = os.getenv('DIALOG_PATH')
|
DIALOG_PATH = os.getenv('DIALOG_PATH')
|
||||||
RECIPE_PATH = os.getenv('RECIPE_PATH')
|
RECIPE_PATH = os.getenv('RECIPE_PATH')
|
||||||
|
CHARACTERS_DICT = os.getenv('CHARACTERS_DICT')
|
||||||
API_KEY = os.getenv('API_KEY')
|
API_KEY = os.getenv('API_KEY')
|
||||||
|
|
||||||
# Storing The Context
|
# Storing The Context
|
||||||
@ -72,24 +75,25 @@ if __name__ == '__main__':
|
|||||||
if "openai" not in st.session_state:
|
if "openai" not in st.session_state:
|
||||||
st.session_state.openai = openai
|
st.session_state.openai = openai
|
||||||
st.session_state.openai.api_key = API_KEY
|
st.session_state.openai.api_key = API_KEY
|
||||||
|
if "characters_dict" not in st.session_state:
|
||||||
|
with open(CHARACTERS_DICT) as f:
|
||||||
|
st.session_state.characters_dict = json.load(f)
|
||||||
|
|
||||||
def show_graph():
|
def show_graph():
|
||||||
# Create a graphlib graph object
|
# Create a graphlib graph object
|
||||||
if st.session_state.generated:
|
if st.session_state.generated:
|
||||||
user, chatbot = [], []
|
user, chatbot = [], []
|
||||||
graph = graphviz.Digraph()
|
graph = graphviz.Digraph()
|
||||||
for i in range(len(st.session_state.past)):
|
chatbot = copy.deepcopy(st.session_state.generated)
|
||||||
chatbot.append(st.session_state.generated[i])
|
user = copy.deepcopy(st.session_state.past)
|
||||||
user.append(st.session_state.past[i])
|
|
||||||
for x in range(len(user)):
|
for x in range(len(user)):
|
||||||
chatbot_text = [word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.generated[x].split(' '))]
|
chatbot_text = [word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(chatbot[x].split(' '))]
|
||||||
user_text = [word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.past[x].split(' '))]
|
user_text = [word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(user[x].split(' '))]
|
||||||
graph.edge(' '.join(chatbot_text), ' '.join(user_text))
|
graph.edge(' '.join(chatbot_text), ' '.join(user_text))
|
||||||
try:
|
try:
|
||||||
graph.edge(' '.join(user_text), ' '.join([word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(st.session_state.generated[x + 1].split(' '))]))
|
graph.edge(' '.join(user_text), ' '.join([word + '\n' if i % 5 == 0 and i > 0 else word for i, word in enumerate(chatbot[x + 1].split(' '))]))
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
st.graphviz_chart(graph)
|
st.graphviz_chart(graph)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user