add generative component to dst/dp
This commit is contained in:
parent
31bf14063d
commit
bd96acf7ea
@ -4,12 +4,39 @@ from AMUseBotBackend.src.NLG.nlg import NLG
|
|||||||
from AMUseBotBackend.src.tools.search import search_recipe
|
from AMUseBotBackend.src.tools.search import search_recipe
|
||||||
|
|
||||||
import AMUseBotBackend.consts as c
|
import AMUseBotBackend.consts as c
|
||||||
|
import openai
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class DP:
|
class DP:
|
||||||
|
|
||||||
def __init__(self, dst: DST):
|
def __init__(self, dst: DST, llm_rephrasing=False, character='default'): #TODO: a way to set llm_rephrasing status and a character
|
||||||
self.dst_module = dst
|
self.dst_module = dst
|
||||||
|
self.llm_rephrasing = llm_rephrasing
|
||||||
|
with open('ai_talks/AMUseBotBackend/utils/characters_dict.json') as f:
|
||||||
|
characters_dict = json.load(f)
|
||||||
|
self.character = characters_dict[character]
|
||||||
|
|
||||||
|
|
||||||
|
def llm_rephrase(self, character, response):
|
||||||
|
model = character['model']
|
||||||
|
openai.api_key = character['api_key']
|
||||||
|
prompt = character['prompt']
|
||||||
|
input = character['leftside_input'] + response + character['rightside_input']
|
||||||
|
|
||||||
|
message = [{'role': 'system', 'content': prompt}, {'role': 'user', 'content': input}]
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = openai.ChatCompletion.create(
|
||||||
|
model=model, messages=message, temperature=1, max_tokens=128
|
||||||
|
)
|
||||||
|
rephrased_response = response.choices[0].message.content
|
||||||
|
except:
|
||||||
|
print('OpenAI API call failed during response paraphrasing! Returning input response')
|
||||||
|
rephrased_response = response
|
||||||
|
|
||||||
|
return rephrased_response
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -31,8 +58,14 @@ class DP:
|
|||||||
if found_recipe:
|
if found_recipe:
|
||||||
recipe_name = self.dst_module.set_recipe(found_recipe)
|
recipe_name = self.dst_module.set_recipe(found_recipe)
|
||||||
self.dst_module.set_next_step()
|
self.dst_module.set_next_step()
|
||||||
|
if self.llm_rephrasing:
|
||||||
|
return NLG.MESSAGE_CHOOSEN_RECIPE(recipe_name=recipe_name) + "\n" \
|
||||||
|
+ self.llm_rephrase(self.character, self.dst_module.generate_state(c.STEPS_KEY)[
|
||||||
|
self.dst_module.generate_state(c.CURR_STEP_KEY)])
|
||||||
|
else:
|
||||||
return NLG.MESSAGE_CHOOSEN_RECIPE(recipe_name=recipe_name) + "\n" \
|
return NLG.MESSAGE_CHOOSEN_RECIPE(recipe_name=recipe_name) + "\n" \
|
||||||
+ self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)]
|
+ self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)]
|
||||||
|
|
||||||
if not found_recipe:
|
if not found_recipe:
|
||||||
return NLG.MESSAGE_NOT_UNDERSTAND_SUGGEST_RECIPE(self.dst_module.get_random_recipes(3))
|
return NLG.MESSAGE_NOT_UNDERSTAND_SUGGEST_RECIPE(self.dst_module.get_random_recipes(3))
|
||||||
# not understand ask recipe
|
# not understand ask recipe
|
||||||
@ -51,6 +84,9 @@ class DP:
|
|||||||
or "req_instruction" in intents):
|
or "req_instruction" in intents):
|
||||||
next_step = self.dst_module.set_next_step()
|
next_step = self.dst_module.set_next_step()
|
||||||
if (next_step):
|
if (next_step):
|
||||||
|
if self.llm_rephrasing:
|
||||||
|
return self.llm_rephrase(self.character, self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)])
|
||||||
|
else:
|
||||||
return self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)]
|
return self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)]
|
||||||
if (not next_step):
|
if (not next_step):
|
||||||
self.dst_module.restart()
|
self.dst_module.restart()
|
||||||
|
@ -79,18 +79,20 @@ class DST:
|
|||||||
def __set_steps(self):
|
def __set_steps(self):
|
||||||
dialog_files = []
|
dialog_files = []
|
||||||
steps = {}
|
steps = {}
|
||||||
for (_, _, filenames) in walk(self.__dialog_path):
|
for (_, _, filenames) in walk(self.__recipe_path):
|
||||||
dialog_files.extend(filenames)
|
dialog_files.extend(filenames)
|
||||||
break
|
break
|
||||||
for dialog_title in dialog_files:
|
for dialog_title in dialog_files:
|
||||||
if dialog_title.startswith(f"{self.__recipe_id:03d}"):
|
if dialog_title.startswith(f"{self.__recipe_id:03d}"):
|
||||||
with open(self.__dialog_path + "/" + dialog_title) as f:
|
with open(self.__recipe_path + dialog_title) as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
for message in data["messages"]:
|
for row in data['content']:
|
||||||
if "inform_instruction" in message["annotations"]:
|
if row['type']=='instruction':
|
||||||
steps[len(steps)] = message["utterance"]
|
steps[len(steps)] = row['text'].split(maxsplit=1)[1]
|
||||||
|
|
||||||
self.__steps = steps
|
self.__steps = steps
|
||||||
|
|
||||||
|
|
||||||
def __set_ingredients(self):
|
def __set_ingredients(self):
|
||||||
dialog_files = []
|
dialog_files = []
|
||||||
ingredients = []
|
ingredients = []
|
||||||
|
Loading…
Reference in New Issue
Block a user