Compare commits

..

4 Commits

Author SHA1 Message Date
PawelDopierala
893733025b Change to TemplateNLG 2024-06-04 13:07:41 +02:00
99306f0532 Merge remote-tracking branch 'origin/test_branch' into test_branch 2024-06-04 11:44:58 +02:00
7a39f24359 test_learning 2024-06-04 11:43:14 +02:00
Maciej Matusz
7306fdb83a test- commit 2024-06-04 10:56:40 +02:00
8 changed files with 45 additions and 77 deletions

1
ConvLab-3 Submodule

@ -0,0 +1 @@
Subproject commit 60f4e5641f93e99b8d61b49cf5fd6dc818a83c4c

View File

@ -1,20 +1,12 @@
from collections import defaultdict from collections import defaultdict
import json import json
import random
import string
from copy import deepcopy from copy import deepcopy
from convlab.policy.policy import Policy from convlab.policy.policy import Policy
db_path = './hotels_data.json' db_path = './hotels_data.json'
def generate_reference_number(length=8):
letters_and_digits = string.ascii_uppercase + string.digits
reference_number = ''.join(random.choice(letters_and_digits) for _ in range(length))
return reference_number
class DialoguePolicy(Policy): class DialoguePolicy(Policy):
info_dict = None
def __init__(self): def __init__(self):
Policy.__init__(self) Policy.__init__(self)
self.db = self.load_database(db_path) self.db = self.load_database(db_path)
@ -48,8 +40,7 @@ class DialoguePolicy(Policy):
if any(True for slots in user_action.values() for (slot, _) in slots if if any(True for slots in user_action.values() for (slot, _) in slots if
slot in ['book stay', 'book day', 'book people']): slot in ['book stay', 'book day', 'book people']):
if self.results: if self.results:
reference_number = generate_reference_number() system_action = {('Booking', 'Book'): [["Ref", self.results[0].get('Ref', 'N/A')]]}
system_action = {('Booking', 'Book'): [["Ref", reference_number]]}
system_acts = [[intent, domain, slot, value] for (domain, intent), slots in system_action.items() for system_acts = [[intent, domain, slot, value] for (domain, intent), slots in system_action.items() for
slot, value in slots] slot, value in slots]
@ -58,29 +49,28 @@ class DialoguePolicy(Policy):
def update_system_action(self, user_act, user_action, state, system_action): def update_system_action(self, user_act, user_action, state, system_action):
domain, intent = user_act domain, intent = user_act
if domain in state['belief_state']: constraints = [(slot, value) for slot, value in state['belief_state'][domain]['info'].items() if value != '']
constraints = [(slot, value) for slot, value in state['belief_state'][domain]['info'].items() if value != ''] # print(f"Constraints: {constraints}")
# print(f"Constraints: {constraints}") self.results = deepcopy(self.query(domain.lower(), constraints))
self.results = deepcopy(self.query(domain.lower(), constraints)) # print(f"Query results: {self.results}")
# print(f"Query results: {self.results}")
if intent == 'request': if intent == 'request':
if len(self.results) == 0: if len(self.results) == 0:
system_action[(domain, 'NoOffer')] = [] system_action[(domain, 'NoOffer')] = []
else: else:
for slot in user_action[user_act]: for slot in user_action[user_act]:
if slot[0] in self.results[0]: if slot[0] in self.results[0]:
system_action[(domain, 'Inform')].append([slot[0], self.results[0].get(slot[0], 'unknown')]) system_action[(domain, 'Inform')].append([slot[0], self.results[0].get(slot[0], 'unknown')])
elif intent == 'inform': elif intent == 'inform':
if len(self.results) == 0: if len(self.results) == 0:
system_action[(domain, 'NoOffer')] = [] system_action[(domain, 'NoOffer')] = []
else: else:
system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))]) system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])
choice = self.results[0] choice = self.results[0]
if domain in ["hotel"]: if domain in ["hotel"]:
system_action[(domain, 'Recommend')].append(['Name', choice['name']]) system_action[(domain, 'Recommend')].append(['Name', choice['name']])
for slot in state['belief_state'][domain]['info']: for slot in state['belief_state'][domain]['info']:
if choice.get(slot): if choice.get(slot):
state['belief_state'][domain]['info'][slot] = choice[slot] state['belief_state'][domain]['info'][slot] = choice[slot]

View File

@ -27,12 +27,11 @@ def default_state():
'user_action': [], 'user_action': [],
'system_action': [], 'system_action': [],
'terminated': False, 'terminated': False,
'booked': {} 'booked': []
} }
class DialogueStateTracker(DST): class DialogueStateTracker(DST):
info_dict = None
def __init__(self): def __init__(self):
DST.__init__(self) DST.__init__(self)
self.state = default_state() self.state = default_state()

View File

@ -33,15 +33,14 @@ class MachineLearningNLG:
return translated_response return translated_response
def generate(self, action): def generate(self, action):
# Przyjmujemy, że 'action' jest formatowanym stringiem, który jest przekazywany do self.nlg
return self.nlg(action) return self.nlg(action)
def init_session(self): def init_session(self):
pass # Dodanie pustej metody init_session pass
# Przykład użycia # Przykład użycia
if __name__ == "__main__": if __name__ == "__main__":
nlg = MachineLearningNLG() nlg = MachineLearningNLG()
system_act = "inform(date.from=15.07, date.to=22.07)" system_act = "inform(hotel='Four Seasons Hotel')"
print(nlg.nlg(system_act)) print(nlg.nlg(system_act))

17
Main.py
View File

@ -1,22 +1,17 @@
from NaturalLanguageAnalyzer import NaturalLanguageAnalyzer from NaturalLanguageAnalyzer import NaturalLanguageAnalyzer
from DialoguePolicy import DialoguePolicy from DialoguePolicy import DialoguePolicy
from DialogueStateTracker import DialogueStateTracker from DialogueStateTracker import DialogueStateTracker
from NaturalLanguageGeneration import NaturalLanguageGeneration
from convlab.dialog_agent import PipelineAgent from convlab.dialog_agent import PipelineAgent
import warnings from MachineLearningNLG import MachineLearningNLG # Importujemy nowy komponent NLG
warnings.filterwarnings("ignore") from convlab.nlg.template.multiwoz import TemplateNLG
if __name__ == "__main__": if __name__ == "__main__":
text = "chciałbym zarezerwować drogi hotel z parkingiem 1 stycznia w Warszawie w centrum"
nlu = NaturalLanguageAnalyzer() nlu = NaturalLanguageAnalyzer()
dst = DialogueStateTracker() dst = DialogueStateTracker()
policy = DialoguePolicy() policy = DialoguePolicy()
nlg = NaturalLanguageGeneration() nlg = TemplateNLG(is_user=False)
agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys') agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')
response = agent.response(text)
print("Witam, jestem systemem do rezerwowania pokoi hotelowych. W czym mogę Ci pomóc?") print(response)
while True:
text = input()
response = agent.response(text)
print(response)

View File

@ -14,7 +14,6 @@ def translate_text(text, target_language='en'):
class NaturalLanguageAnalyzer: class NaturalLanguageAnalyzer:
info_dict = None
def predict(self, text, context=None): def predict(self, text, context=None):
# Inicjalizacja modelu NLU # Inicjalizacja modelu NLU
model_name = "ConvLab/t5-small-nlu-multiwoz21" model_name = "ConvLab/t5-small-nlu-multiwoz21"

View File

@ -1,28 +1,12 @@
import requests import re
from convlab.nlg.template.multiwoz import TemplateNLG
def translate_text(text, target_language='pl'):
url = 'https://translate.googleapis.com/translate_a/single?client=gtx&sl=auto&tl={}&dt=t&q={}'.format(
target_language, text)
response = requests.get(url)
if response.status_code == 200:
translated_text = response.json()[0]
translated_text_joined = ''.join([sentence[0] for sentence in translated_text])
return translated_text_joined
else:
return None
class NaturalLanguageGeneration: class NaturalLanguageGeneration:
info_dict = None
def generate(self, system_act):
if len(system_act) == 0:
return "Nie rozumiem."
tnlg = TemplateNLG(is_user=False)
response_en = tnlg.generate(system_act)
return translate_text(response_en)
def nlg(self, system_act):
def init_session(self): response = None
pass # Dodanie pustej metody init_session pattern = r'inform\(name=([^\)]+)\)'
matching = re.search(pattern, system_act)
if matching:
name = matching.group(1)
response = f"Witaj, nazywam się {name}"
return response

View File

@ -11,7 +11,8 @@ for file_name in os.listdir(translated_data_directory):
if file_name.endswith('.tsv'): if file_name.endswith('.tsv'):
file_path = os.path.join(translated_data_directory, file_name) file_path = os.path.join(translated_data_directory, file_name)
df = pd.read_csv(file_path, sep='\t') df = pd.read_csv(file_path, sep='\t')
dfs.append(df) df_user = df[df['role'] == 'system'].drop('role', axis=1)
dfs.append(df_user)
combined_df = pd.concat(dfs, ignore_index=True) combined_df = pd.concat(dfs, ignore_index=True)
# Przygotowanie zbioru danych do trenowania # Przygotowanie zbioru danych do trenowania
@ -49,7 +50,7 @@ training_args = Seq2SeqTrainingArguments(
per_device_eval_batch_size=16, per_device_eval_batch_size=16,
predict_with_generate=True, predict_with_generate=True,
learning_rate=5e-5, learning_rate=5e-5,
num_train_epochs=3, num_train_epochs=10,
evaluation_strategy="epoch", evaluation_strategy="epoch",
save_strategy="epoch", save_strategy="epoch",
save_total_limit=None, # Wyłącz rotację punktów kontrolnych save_total_limit=None, # Wyłącz rotację punktów kontrolnych