Compare commits

..

4 Commits

Author SHA1 Message Date
PawelDopierala
893733025b Change to TemplateNLG 2024-06-04 13:07:41 +02:00
99306f0532 Merge remote-tracking branch 'origin/test_branch' into test_branch 2024-06-04 11:44:58 +02:00
7a39f24359 test_learning 2024-06-04 11:43:14 +02:00
Maciej Matusz
7306fdb83a test- commit 2024-06-04 10:56:40 +02:00
8 changed files with 45 additions and 77 deletions

1
ConvLab-3 Submodule

@ -0,0 +1 @@
Subproject commit 60f4e5641f93e99b8d61b49cf5fd6dc818a83c4c

View File

@ -1,20 +1,12 @@
from collections import defaultdict
import json
import random
import string
from copy import deepcopy
from convlab.policy.policy import Policy
db_path = './hotels_data.json'
def generate_reference_number(length=8):
letters_and_digits = string.ascii_uppercase + string.digits
reference_number = ''.join(random.choice(letters_and_digits) for _ in range(length))
return reference_number
class DialoguePolicy(Policy):
info_dict = None
def __init__(self):
Policy.__init__(self)
self.db = self.load_database(db_path)
@ -48,8 +40,7 @@ class DialoguePolicy(Policy):
if any(True for slots in user_action.values() for (slot, _) in slots if
slot in ['book stay', 'book day', 'book people']):
if self.results:
reference_number = generate_reference_number()
system_action = {('Booking', 'Book'): [["Ref", reference_number]]}
system_action = {('Booking', 'Book'): [["Ref", self.results[0].get('Ref', 'N/A')]]}
system_acts = [[intent, domain, slot, value] for (domain, intent), slots in system_action.items() for
slot, value in slots]
@ -58,7 +49,6 @@ class DialoguePolicy(Policy):
def update_system_action(self, user_act, user_action, state, system_action):
domain, intent = user_act
if domain in state['belief_state']:
constraints = [(slot, value) for slot, value in state['belief_state'][domain]['info'].items() if value != '']
# print(f"Constraints: {constraints}")
self.results = deepcopy(self.query(domain.lower(), constraints))

View File

@ -27,12 +27,11 @@ def default_state():
'user_action': [],
'system_action': [],
'terminated': False,
'booked': {}
'booked': []
}
class DialogueStateTracker(DST):
info_dict = None
def __init__(self):
DST.__init__(self)
self.state = default_state()

View File

@ -33,15 +33,14 @@ class MachineLearningNLG:
return translated_response
def generate(self, action):
# Przyjmujemy, że 'action' jest formatowanym stringiem, który jest przekazywany do self.nlg
return self.nlg(action)
def init_session(self):
pass # Dodanie pustej metody init_session
pass
# Przykład użycia
if __name__ == "__main__":
nlg = MachineLearningNLG()
system_act = "inform(date.from=15.07, date.to=22.07)"
system_act = "inform(hotel='Four Seasons Hotel')"
print(nlg.nlg(system_act))

13
Main.py
View File

@ -1,22 +1,17 @@
from NaturalLanguageAnalyzer import NaturalLanguageAnalyzer
from DialoguePolicy import DialoguePolicy
from DialogueStateTracker import DialogueStateTracker
from NaturalLanguageGeneration import NaturalLanguageGeneration
from convlab.dialog_agent import PipelineAgent
import warnings
warnings.filterwarnings("ignore")
from MachineLearningNLG import MachineLearningNLG # Importujemy nowy komponent NLG
from convlab.nlg.template.multiwoz import TemplateNLG
if __name__ == "__main__":
text = "chciałbym zarezerwować drogi hotel z parkingiem 1 stycznia w Warszawie w centrum"
nlu = NaturalLanguageAnalyzer()
dst = DialogueStateTracker()
policy = DialoguePolicy()
nlg = NaturalLanguageGeneration()
nlg = TemplateNLG(is_user=False)
agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')
print("Witam, jestem systemem do rezerwowania pokoi hotelowych. W czym mogę Ci pomóc?")
while True:
text = input()
response = agent.response(text)
print(response)

View File

@ -14,7 +14,6 @@ def translate_text(text, target_language='en'):
class NaturalLanguageAnalyzer:
info_dict = None
def predict(self, text, context=None):
# Inicjalizacja modelu NLU
model_name = "ConvLab/t5-small-nlu-multiwoz21"

View File

@ -1,28 +1,12 @@
import requests
from convlab.nlg.template.multiwoz import TemplateNLG
def translate_text(text, target_language='pl'):
url = 'https://translate.googleapis.com/translate_a/single?client=gtx&sl=auto&tl={}&dt=t&q={}'.format(
target_language, text)
response = requests.get(url)
if response.status_code == 200:
translated_text = response.json()[0]
translated_text_joined = ''.join([sentence[0] for sentence in translated_text])
return translated_text_joined
else:
return None
import re
class NaturalLanguageGeneration:
info_dict = None
def generate(self, system_act):
if len(system_act) == 0:
return "Nie rozumiem."
tnlg = TemplateNLG(is_user=False)
response_en = tnlg.generate(system_act)
return translate_text(response_en)
def init_session(self):
pass # Dodanie pustej metody init_session
def nlg(self, system_act):
response = None
pattern = r'inform\(name=([^\)]+)\)'
matching = re.search(pattern, system_act)
if matching:
name = matching.group(1)
response = f"Witaj, nazywam się {name}"
return response

View File

@ -11,7 +11,8 @@ for file_name in os.listdir(translated_data_directory):
if file_name.endswith('.tsv'):
file_path = os.path.join(translated_data_directory, file_name)
df = pd.read_csv(file_path, sep='\t')
dfs.append(df)
df_user = df[df['role'] == 'system'].drop('role', axis=1)
dfs.append(df_user)
combined_df = pd.concat(dfs, ignore_index=True)
# Przygotowanie zbioru danych do trenowania
@ -49,7 +50,7 @@ training_args = Seq2SeqTrainingArguments(
per_device_eval_batch_size=16,
predict_with_generate=True,
learning_rate=5e-5,
num_train_epochs=3,
num_train_epochs=10,
evaluation_strategy="epoch",
save_strategy="epoch",
save_total_limit=None, # Wyłącz rotację punktów kontrolnych