From bd96acf7ea02a71a01122f9cfa9036aea33ac57c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CKacper?= Date: Fri, 16 Jun 2023 21:49:27 +0200 Subject: [PATCH] add generative component to dst/dp --- ai_talks/AMUseBotBackend/src/DP/dp.py | 44 ++++++++++++++++++++++--- ai_talks/AMUseBotBackend/src/DST/dst.py | 14 ++++---- 2 files changed, 48 insertions(+), 10 deletions(-) diff --git a/ai_talks/AMUseBotBackend/src/DP/dp.py b/ai_talks/AMUseBotBackend/src/DP/dp.py index c362169..5d5543f 100644 --- a/ai_talks/AMUseBotBackend/src/DP/dp.py +++ b/ai_talks/AMUseBotBackend/src/DP/dp.py @@ -4,12 +4,39 @@ from AMUseBotBackend.src.NLG.nlg import NLG from AMUseBotBackend.src.tools.search import search_recipe import AMUseBotBackend.consts as c +import openai +import json + class DP: - def __init__(self, dst: DST): + def __init__(self, dst: DST, llm_rephrasing=False, character='default'): #TODO: a way to set llm_rephrasing status and a character self.dst_module = dst + self.llm_rephrasing = llm_rephrasing + with open('ai_talks/AMUseBotBackend/utils/characters_dict.json') as f: + characters_dict = json.load(f) + self.character = characters_dict[character] + + + def llm_rephrase(self, character, response): + model = character['model'] + openai.api_key = character['api_key'] + prompt = character['prompt'] + input = character['leftside_input'] + response + character['rightside_input'] + + message = [{'role': 'system', 'content': prompt}, {'role': 'user', 'content': input}] + + try: + response = openai.ChatCompletion.create( + model=model, messages=message, temperature=1, max_tokens=128 + ) + rephrased_response = response.choices[0].message.content + except: + print('OpenAI API call failed during response paraphrasing! Returning input response') + rephrased_response = response + + return rephrased_response @@ -31,8 +58,14 @@ class DP: if found_recipe: recipe_name = self.dst_module.set_recipe(found_recipe) self.dst_module.set_next_step() - return NLG.MESSAGE_CHOOSEN_RECIPE(recipe_name=recipe_name) + "\n" \ - + self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)] + if self.llm_rephrasing: + return NLG.MESSAGE_CHOOSEN_RECIPE(recipe_name=recipe_name) + "\n" \ + + self.llm_rephrase(self.character, self.dst_module.generate_state(c.STEPS_KEY)[ + self.dst_module.generate_state(c.CURR_STEP_KEY)]) + else: + return NLG.MESSAGE_CHOOSEN_RECIPE(recipe_name=recipe_name) + "\n" \ + + self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)] + if not found_recipe: return NLG.MESSAGE_NOT_UNDERSTAND_SUGGEST_RECIPE(self.dst_module.get_random_recipes(3)) # not understand ask recipe @@ -51,7 +84,10 @@ class DP: or "req_instruction" in intents): next_step = self.dst_module.set_next_step() if (next_step): - return self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)] + if self.llm_rephrasing: + return self.llm_rephrase(self.character, self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)]) + else: + return self.dst_module.generate_state(c.STEPS_KEY)[self.dst_module.generate_state(c.CURR_STEP_KEY)] if (not next_step): self.dst_module.restart() return NLG.RECIPE_OVER_ANSWER diff --git a/ai_talks/AMUseBotBackend/src/DST/dst.py b/ai_talks/AMUseBotBackend/src/DST/dst.py index 3eac32d..64eb9fb 100644 --- a/ai_talks/AMUseBotBackend/src/DST/dst.py +++ b/ai_talks/AMUseBotBackend/src/DST/dst.py @@ -77,19 +77,21 @@ class DST: return [self.recipes[id] for id in recipes_id] def __set_steps(self): - dialog_files = [] + dialog_files = [] steps = {} - for (_, _, filenames) in walk(self.__dialog_path): + for (_, _, filenames) in walk(self.__recipe_path): dialog_files.extend(filenames) break for dialog_title in dialog_files: if dialog_title.startswith(f"{self.__recipe_id:03d}"): - with open(self.__dialog_path + "/" + dialog_title) as f: + with open(self.__recipe_path + dialog_title) as f: data = json.load(f) - for message in data["messages"]: - if "inform_instruction" in message["annotations"]: - steps[len(steps)] = message["utterance"] + for row in data['content']: + if row['type']=='instruction': + steps[len(steps)] = row['text'].split(maxsplit=1)[1] + self.__steps = steps + def __set_ingredients(self): dialog_files = []