Compare commits
4 Commits
master
...
test_branc
Author | SHA1 | Date | |
---|---|---|---|
|
893733025b | ||
99306f0532 | |||
7a39f24359 | |||
|
7306fdb83a |
1
ConvLab-3
Submodule
1
ConvLab-3
Submodule
@ -0,0 +1 @@
|
|||||||
|
Subproject commit 60f4e5641f93e99b8d61b49cf5fd6dc818a83c4c
|
@ -1,20 +1,12 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import json
|
import json
|
||||||
import random
|
|
||||||
import string
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from convlab.policy.policy import Policy
|
from convlab.policy.policy import Policy
|
||||||
|
|
||||||
db_path = './hotels_data.json'
|
db_path = './hotels_data.json'
|
||||||
|
|
||||||
|
|
||||||
def generate_reference_number(length=8):
|
|
||||||
letters_and_digits = string.ascii_uppercase + string.digits
|
|
||||||
reference_number = ''.join(random.choice(letters_and_digits) for _ in range(length))
|
|
||||||
return reference_number
|
|
||||||
|
|
||||||
class DialoguePolicy(Policy):
|
class DialoguePolicy(Policy):
|
||||||
info_dict = None
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
Policy.__init__(self)
|
Policy.__init__(self)
|
||||||
self.db = self.load_database(db_path)
|
self.db = self.load_database(db_path)
|
||||||
@ -48,8 +40,7 @@ class DialoguePolicy(Policy):
|
|||||||
if any(True for slots in user_action.values() for (slot, _) in slots if
|
if any(True for slots in user_action.values() for (slot, _) in slots if
|
||||||
slot in ['book stay', 'book day', 'book people']):
|
slot in ['book stay', 'book day', 'book people']):
|
||||||
if self.results:
|
if self.results:
|
||||||
reference_number = generate_reference_number()
|
system_action = {('Booking', 'Book'): [["Ref", self.results[0].get('Ref', 'N/A')]]}
|
||||||
system_action = {('Booking', 'Book'): [["Ref", reference_number]]}
|
|
||||||
|
|
||||||
system_acts = [[intent, domain, slot, value] for (domain, intent), slots in system_action.items() for
|
system_acts = [[intent, domain, slot, value] for (domain, intent), slots in system_action.items() for
|
||||||
slot, value in slots]
|
slot, value in slots]
|
||||||
@ -58,29 +49,28 @@ class DialoguePolicy(Policy):
|
|||||||
|
|
||||||
def update_system_action(self, user_act, user_action, state, system_action):
|
def update_system_action(self, user_act, user_action, state, system_action):
|
||||||
domain, intent = user_act
|
domain, intent = user_act
|
||||||
if domain in state['belief_state']:
|
constraints = [(slot, value) for slot, value in state['belief_state'][domain]['info'].items() if value != '']
|
||||||
constraints = [(slot, value) for slot, value in state['belief_state'][domain]['info'].items() if value != '']
|
# print(f"Constraints: {constraints}")
|
||||||
# print(f"Constraints: {constraints}")
|
self.results = deepcopy(self.query(domain.lower(), constraints))
|
||||||
self.results = deepcopy(self.query(domain.lower(), constraints))
|
# print(f"Query results: {self.results}")
|
||||||
# print(f"Query results: {self.results}")
|
|
||||||
|
|
||||||
if intent == 'request':
|
if intent == 'request':
|
||||||
if len(self.results) == 0:
|
if len(self.results) == 0:
|
||||||
system_action[(domain, 'NoOffer')] = []
|
system_action[(domain, 'NoOffer')] = []
|
||||||
else:
|
else:
|
||||||
for slot in user_action[user_act]:
|
for slot in user_action[user_act]:
|
||||||
if slot[0] in self.results[0]:
|
if slot[0] in self.results[0]:
|
||||||
system_action[(domain, 'Inform')].append([slot[0], self.results[0].get(slot[0], 'unknown')])
|
system_action[(domain, 'Inform')].append([slot[0], self.results[0].get(slot[0], 'unknown')])
|
||||||
|
|
||||||
elif intent == 'inform':
|
elif intent == 'inform':
|
||||||
if len(self.results) == 0:
|
if len(self.results) == 0:
|
||||||
system_action[(domain, 'NoOffer')] = []
|
system_action[(domain, 'NoOffer')] = []
|
||||||
else:
|
else:
|
||||||
system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])
|
system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])
|
||||||
choice = self.results[0]
|
choice = self.results[0]
|
||||||
|
|
||||||
if domain in ["hotel"]:
|
if domain in ["hotel"]:
|
||||||
system_action[(domain, 'Recommend')].append(['Name', choice['name']])
|
system_action[(domain, 'Recommend')].append(['Name', choice['name']])
|
||||||
for slot in state['belief_state'][domain]['info']:
|
for slot in state['belief_state'][domain]['info']:
|
||||||
if choice.get(slot):
|
if choice.get(slot):
|
||||||
state['belief_state'][domain]['info'][slot] = choice[slot]
|
state['belief_state'][domain]['info'][slot] = choice[slot]
|
@ -27,12 +27,11 @@ def default_state():
|
|||||||
'user_action': [],
|
'user_action': [],
|
||||||
'system_action': [],
|
'system_action': [],
|
||||||
'terminated': False,
|
'terminated': False,
|
||||||
'booked': {}
|
'booked': []
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class DialogueStateTracker(DST):
|
class DialogueStateTracker(DST):
|
||||||
info_dict = None
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
DST.__init__(self)
|
DST.__init__(self)
|
||||||
self.state = default_state()
|
self.state = default_state()
|
||||||
|
@ -33,15 +33,14 @@ class MachineLearningNLG:
|
|||||||
return translated_response
|
return translated_response
|
||||||
|
|
||||||
def generate(self, action):
|
def generate(self, action):
|
||||||
# Przyjmujemy, że 'action' jest formatowanym stringiem, który jest przekazywany do self.nlg
|
|
||||||
return self.nlg(action)
|
return self.nlg(action)
|
||||||
|
|
||||||
def init_session(self):
|
def init_session(self):
|
||||||
pass # Dodanie pustej metody init_session
|
pass
|
||||||
|
|
||||||
|
|
||||||
# Przykład użycia
|
# Przykład użycia
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
nlg = MachineLearningNLG()
|
nlg = MachineLearningNLG()
|
||||||
system_act = "inform(date.from=15.07, date.to=22.07)"
|
system_act = "inform(hotel='Four Seasons Hotel')"
|
||||||
print(nlg.nlg(system_act))
|
print(nlg.nlg(system_act))
|
||||||
|
17
Main.py
17
Main.py
@ -1,22 +1,17 @@
|
|||||||
from NaturalLanguageAnalyzer import NaturalLanguageAnalyzer
|
from NaturalLanguageAnalyzer import NaturalLanguageAnalyzer
|
||||||
from DialoguePolicy import DialoguePolicy
|
from DialoguePolicy import DialoguePolicy
|
||||||
from DialogueStateTracker import DialogueStateTracker
|
from DialogueStateTracker import DialogueStateTracker
|
||||||
from NaturalLanguageGeneration import NaturalLanguageGeneration
|
|
||||||
from convlab.dialog_agent import PipelineAgent
|
from convlab.dialog_agent import PipelineAgent
|
||||||
import warnings
|
from MachineLearningNLG import MachineLearningNLG # Importujemy nowy komponent NLG
|
||||||
warnings.filterwarnings("ignore")
|
from convlab.nlg.template.multiwoz import TemplateNLG
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
text = "chciałbym zarezerwować drogi hotel z parkingiem 1 stycznia w Warszawie w centrum"
|
||||||
nlu = NaturalLanguageAnalyzer()
|
nlu = NaturalLanguageAnalyzer()
|
||||||
dst = DialogueStateTracker()
|
dst = DialogueStateTracker()
|
||||||
policy = DialoguePolicy()
|
policy = DialoguePolicy()
|
||||||
nlg = NaturalLanguageGeneration()
|
nlg = TemplateNLG(is_user=False)
|
||||||
|
|
||||||
agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')
|
agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')
|
||||||
|
response = agent.response(text)
|
||||||
print("Witam, jestem systemem do rezerwowania pokoi hotelowych. W czym mogę Ci pomóc?")
|
print(response)
|
||||||
while True:
|
|
||||||
text = input()
|
|
||||||
response = agent.response(text)
|
|
||||||
print(response)
|
|
||||||
|
@ -14,7 +14,6 @@ def translate_text(text, target_language='en'):
|
|||||||
|
|
||||||
|
|
||||||
class NaturalLanguageAnalyzer:
|
class NaturalLanguageAnalyzer:
|
||||||
info_dict = None
|
|
||||||
def predict(self, text, context=None):
|
def predict(self, text, context=None):
|
||||||
# Inicjalizacja modelu NLU
|
# Inicjalizacja modelu NLU
|
||||||
model_name = "ConvLab/t5-small-nlu-multiwoz21"
|
model_name = "ConvLab/t5-small-nlu-multiwoz21"
|
||||||
|
@ -1,28 +1,12 @@
|
|||||||
import requests
|
import re
|
||||||
|
|
||||||
from convlab.nlg.template.multiwoz import TemplateNLG
|
|
||||||
|
|
||||||
|
|
||||||
def translate_text(text, target_language='pl'):
|
|
||||||
url = 'https://translate.googleapis.com/translate_a/single?client=gtx&sl=auto&tl={}&dt=t&q={}'.format(
|
|
||||||
target_language, text)
|
|
||||||
response = requests.get(url)
|
|
||||||
if response.status_code == 200:
|
|
||||||
translated_text = response.json()[0]
|
|
||||||
translated_text_joined = ''.join([sentence[0] for sentence in translated_text])
|
|
||||||
return translated_text_joined
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
class NaturalLanguageGeneration:
|
class NaturalLanguageGeneration:
|
||||||
info_dict = None
|
|
||||||
def generate(self, system_act):
|
|
||||||
if len(system_act) == 0:
|
|
||||||
return "Nie rozumiem."
|
|
||||||
tnlg = TemplateNLG(is_user=False)
|
|
||||||
response_en = tnlg.generate(system_act)
|
|
||||||
return translate_text(response_en)
|
|
||||||
|
|
||||||
|
def nlg(self, system_act):
|
||||||
def init_session(self):
|
response = None
|
||||||
pass # Dodanie pustej metody init_session
|
pattern = r'inform\(name=([^\)]+)\)'
|
||||||
|
matching = re.search(pattern, system_act)
|
||||||
|
if matching:
|
||||||
|
name = matching.group(1)
|
||||||
|
response = f"Witaj, nazywam się {name}"
|
||||||
|
return response
|
||||||
|
@ -11,7 +11,8 @@ for file_name in os.listdir(translated_data_directory):
|
|||||||
if file_name.endswith('.tsv'):
|
if file_name.endswith('.tsv'):
|
||||||
file_path = os.path.join(translated_data_directory, file_name)
|
file_path = os.path.join(translated_data_directory, file_name)
|
||||||
df = pd.read_csv(file_path, sep='\t')
|
df = pd.read_csv(file_path, sep='\t')
|
||||||
dfs.append(df)
|
df_user = df[df['role'] == 'system'].drop('role', axis=1)
|
||||||
|
dfs.append(df_user)
|
||||||
combined_df = pd.concat(dfs, ignore_index=True)
|
combined_df = pd.concat(dfs, ignore_index=True)
|
||||||
|
|
||||||
# Przygotowanie zbioru danych do trenowania
|
# Przygotowanie zbioru danych do trenowania
|
||||||
@ -49,7 +50,7 @@ training_args = Seq2SeqTrainingArguments(
|
|||||||
per_device_eval_batch_size=16,
|
per_device_eval_batch_size=16,
|
||||||
predict_with_generate=True,
|
predict_with_generate=True,
|
||||||
learning_rate=5e-5,
|
learning_rate=5e-5,
|
||||||
num_train_epochs=3,
|
num_train_epochs=10,
|
||||||
evaluation_strategy="epoch",
|
evaluation_strategy="epoch",
|
||||||
save_strategy="epoch",
|
save_strategy="epoch",
|
||||||
save_total_limit=None, # Wyłącz rotację punktów kontrolnych
|
save_total_limit=None, # Wyłącz rotację punktów kontrolnych
|
||||||
|
Loading…
Reference in New Issue
Block a user