add generative component to dst/dp

This commit is contained in:
“Kacper 2023-06-16 21:49:27 +02:00
parent 31bf14063d
commit bd96acf7ea
2 changed files with 48 additions and 10 deletions

View File

@ -4,12 +4,39 @@ from AMUseBotBackend.src.NLG.nlg import NLG
from AMUseBotBackend.src.tools.search import search_recipe
import AMUseBotBackend.consts as c
import openai
import json
class DP:
def __init__(self, dst: DST):
def __init__(self, dst: DST, llm_rephrasing=False, character='default'): #TODO: a way to set llm_rephrasing status and a character
self.dst_module = dst
self.llm_rephrasing = llm_rephrasing
with open('ai_talks/AMUseBotBackend/utils/characters_dict.json') as f:
characters_dict = json.load(f)
self.character = characters_dict[character]
def llm_rephrase(self, character, response):
model = character['model']
openai.api_key = character['api_key']
prompt = character['prompt']
input = character['leftside_input'] + response + character['rightside_input']
message = [{'role': 'system', 'content': prompt}, {'role': 'user', 'content': input}]
try:
response = openai.ChatCompletion.create(
model=model, messages=message, temperature=1, max_tokens=128
)
rephrased_response = response.choices[0].message.content
except:
print('OpenAI API call failed during response paraphrasing! Returning input response')
rephrased_response = response
return rephrased_response
@ -31,8 +58,14 @@ class DP:
if found_recipe:
recipe_name = self.dst_module.set_recipe(found_recipe)
self.dst_module.set_next_step()
if self.llm_rephrasing:
return NLG.MESSAGE_CHOOSEN_RECIPE(recipe_name=recipe_name) + "\n" \
+ self.llm_rephrase(self.character, self.dst_module.generate_state(c.STEPS_KEY)[
self.dst_module.generate_state(c.CURR_STEP_KEY)])
else:
return NLG.MESSAGE_CHOOSEN_RECIPE(recipe_name=recipe_name) + "\n" \
+ self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)]
if not found_recipe:
return NLG.MESSAGE_NOT_UNDERSTAND_SUGGEST_RECIPE(self.dst_module.get_random_recipes(3))
# not understand ask recipe
@ -51,6 +84,9 @@ class DP:
or "req_instruction" in intents):
next_step = self.dst_module.set_next_step()
if (next_step):
if self.llm_rephrasing:
return self.llm_rephrase(self.character, self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)])
else:
return self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)]
if (not next_step):
self.dst_module.restart()

View File

@ -79,18 +79,20 @@ class DST:
def __set_steps(self):
dialog_files = []
steps = {}
for (_, _, filenames) in walk(self.__dialog_path):
for (_, _, filenames) in walk(self.__recipe_path):
dialog_files.extend(filenames)
break
for dialog_title in dialog_files:
if dialog_title.startswith(f"{self.__recipe_id:03d}"):
with open(self.__dialog_path + "/" + dialog_title) as f:
with open(self.__recipe_path + dialog_title) as f:
data = json.load(f)
for message in data["messages"]:
if "inform_instruction" in message["annotations"]:
steps[len(steps)] = message["utterance"]
for row in data['content']:
if row['type']=='instruction':
steps[len(steps)] = row['text'].split(maxsplit=1)[1]
self.__steps = steps
def __set_ingredients(self):
dialog_files = []
ingredients = []