add generative component to dst/dp
This commit is contained in:
parent
31bf14063d
commit
bd96acf7ea
@ -4,12 +4,39 @@ from AMUseBotBackend.src.NLG.nlg import NLG
|
||||
from AMUseBotBackend.src.tools.search import search_recipe
|
||||
|
||||
import AMUseBotBackend.consts as c
|
||||
import openai
|
||||
import json
|
||||
|
||||
|
||||
|
||||
class DP:
|
||||
|
||||
def __init__(self, dst: DST):
|
||||
def __init__(self, dst: DST, llm_rephrasing=False, character='default'): #TODO: a way to set llm_rephrasing status and a character
|
||||
self.dst_module = dst
|
||||
self.llm_rephrasing = llm_rephrasing
|
||||
with open('ai_talks/AMUseBotBackend/utils/characters_dict.json') as f:
|
||||
characters_dict = json.load(f)
|
||||
self.character = characters_dict[character]
|
||||
|
||||
|
||||
def llm_rephrase(self, character, response):
|
||||
model = character['model']
|
||||
openai.api_key = character['api_key']
|
||||
prompt = character['prompt']
|
||||
input = character['leftside_input'] + response + character['rightside_input']
|
||||
|
||||
message = [{'role': 'system', 'content': prompt}, {'role': 'user', 'content': input}]
|
||||
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
model=model, messages=message, temperature=1, max_tokens=128
|
||||
)
|
||||
rephrased_response = response.choices[0].message.content
|
||||
except:
|
||||
print('OpenAI API call failed during response paraphrasing! Returning input response')
|
||||
rephrased_response = response
|
||||
|
||||
return rephrased_response
|
||||
|
||||
|
||||
|
||||
@ -31,8 +58,14 @@ class DP:
|
||||
if found_recipe:
|
||||
recipe_name = self.dst_module.set_recipe(found_recipe)
|
||||
self.dst_module.set_next_step()
|
||||
return NLG.MESSAGE_CHOOSEN_RECIPE(recipe_name=recipe_name) + "\n" \
|
||||
+ self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)]
|
||||
if self.llm_rephrasing:
|
||||
return NLG.MESSAGE_CHOOSEN_RECIPE(recipe_name=recipe_name) + "\n" \
|
||||
+ self.llm_rephrase(self.character, self.dst_module.generate_state(c.STEPS_KEY)[
|
||||
self.dst_module.generate_state(c.CURR_STEP_KEY)])
|
||||
else:
|
||||
return NLG.MESSAGE_CHOOSEN_RECIPE(recipe_name=recipe_name) + "\n" \
|
||||
+ self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)]
|
||||
|
||||
if not found_recipe:
|
||||
return NLG.MESSAGE_NOT_UNDERSTAND_SUGGEST_RECIPE(self.dst_module.get_random_recipes(3))
|
||||
# not understand ask recipe
|
||||
@ -51,7 +84,10 @@ class DP:
|
||||
or "req_instruction" in intents):
|
||||
next_step = self.dst_module.set_next_step()
|
||||
if (next_step):
|
||||
return self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)]
|
||||
if self.llm_rephrasing:
|
||||
return self.llm_rephrase(self.character, self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)])
|
||||
else:
|
||||
return self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)]
|
||||
if (not next_step):
|
||||
self.dst_module.restart()
|
||||
return NLG.RECIPE_OVER_ANSWER
|
||||
|
@ -79,18 +79,20 @@ class DST:
|
||||
def __set_steps(self):
|
||||
dialog_files = []
|
||||
steps = {}
|
||||
for (_, _, filenames) in walk(self.__dialog_path):
|
||||
for (_, _, filenames) in walk(self.__recipe_path):
|
||||
dialog_files.extend(filenames)
|
||||
break
|
||||
for dialog_title in dialog_files:
|
||||
if dialog_title.startswith(f"{self.__recipe_id:03d}"):
|
||||
with open(self.__dialog_path + "/" + dialog_title) as f:
|
||||
with open(self.__recipe_path + dialog_title) as f:
|
||||
data = json.load(f)
|
||||
for message in data["messages"]:
|
||||
if "inform_instruction" in message["annotations"]:
|
||||
steps[len(steps)] = message["utterance"]
|
||||
for row in data['content']:
|
||||
if row['type']=='instruction':
|
||||
steps[len(steps)] = row['text'].split(maxsplit=1)[1]
|
||||
|
||||
self.__steps = steps
|
||||
|
||||
|
||||
def __set_ingredients(self):
|
||||
dialog_files = []
|
||||
ingredients = []
|
||||
|
Loading…
Reference in New Issue
Block a user