Compare commits
6 Commits
test_branc
...
master
Author | SHA1 | Date | |
---|---|---|---|
|
3464a04349 | ||
1dcb6c7343 | |||
29700baaf9 | |||
|
e24d1894a3 | ||
7bb36b5bac | |||
e9ffbd646f |
@ -1,12 +1,20 @@
|
||||
from collections import defaultdict
|
||||
import json
|
||||
import random
|
||||
import string
|
||||
from copy import deepcopy
|
||||
from convlab.policy.policy import Policy
|
||||
|
||||
db_path = './hotels_data.json'
|
||||
|
||||
|
||||
def generate_reference_number(length=8):
|
||||
letters_and_digits = string.ascii_uppercase + string.digits
|
||||
reference_number = ''.join(random.choice(letters_and_digits) for _ in range(length))
|
||||
return reference_number
|
||||
|
||||
class DialoguePolicy(Policy):
|
||||
info_dict = None
|
||||
def __init__(self):
|
||||
Policy.__init__(self)
|
||||
self.db = self.load_database(db_path)
|
||||
@ -40,7 +48,8 @@ class DialoguePolicy(Policy):
|
||||
if any(True for slots in user_action.values() for (slot, _) in slots if
|
||||
slot in ['book stay', 'book day', 'book people']):
|
||||
if self.results:
|
||||
system_action = {('Booking', 'Book'): [["Ref", self.results[0].get('Ref', 'N/A')]]}
|
||||
reference_number = generate_reference_number()
|
||||
system_action = {('Booking', 'Book'): [["Ref", reference_number]]}
|
||||
|
||||
system_acts = [[intent, domain, slot, value] for (domain, intent), slots in system_action.items() for
|
||||
slot, value in slots]
|
||||
@ -49,28 +58,29 @@ class DialoguePolicy(Policy):
|
||||
|
||||
def update_system_action(self, user_act, user_action, state, system_action):
|
||||
domain, intent = user_act
|
||||
constraints = [(slot, value) for slot, value in state['belief_state'][domain]['info'].items() if value != '']
|
||||
# print(f"Constraints: {constraints}")
|
||||
self.results = deepcopy(self.query(domain.lower(), constraints))
|
||||
# print(f"Query results: {self.results}")
|
||||
if domain in state['belief_state']:
|
||||
constraints = [(slot, value) for slot, value in state['belief_state'][domain]['info'].items() if value != '']
|
||||
# print(f"Constraints: {constraints}")
|
||||
self.results = deepcopy(self.query(domain.lower(), constraints))
|
||||
# print(f"Query results: {self.results}")
|
||||
|
||||
if intent == 'request':
|
||||
if len(self.results) == 0:
|
||||
system_action[(domain, 'NoOffer')] = []
|
||||
else:
|
||||
for slot in user_action[user_act]:
|
||||
if slot[0] in self.results[0]:
|
||||
system_action[(domain, 'Inform')].append([slot[0], self.results[0].get(slot[0], 'unknown')])
|
||||
if intent == 'request':
|
||||
if len(self.results) == 0:
|
||||
system_action[(domain, 'NoOffer')] = []
|
||||
else:
|
||||
for slot in user_action[user_act]:
|
||||
if slot[0] in self.results[0]:
|
||||
system_action[(domain, 'Inform')].append([slot[0], self.results[0].get(slot[0], 'unknown')])
|
||||
|
||||
elif intent == 'inform':
|
||||
if len(self.results) == 0:
|
||||
system_action[(domain, 'NoOffer')] = []
|
||||
else:
|
||||
system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])
|
||||
choice = self.results[0]
|
||||
elif intent == 'inform':
|
||||
if len(self.results) == 0:
|
||||
system_action[(domain, 'NoOffer')] = []
|
||||
else:
|
||||
system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])
|
||||
choice = self.results[0]
|
||||
|
||||
if domain in ["hotel"]:
|
||||
system_action[(domain, 'Recommend')].append(['Name', choice['name']])
|
||||
for slot in state['belief_state'][domain]['info']:
|
||||
if choice.get(slot):
|
||||
state['belief_state'][domain]['info'][slot] = choice[slot]
|
||||
if domain in ["hotel"]:
|
||||
system_action[(domain, 'Recommend')].append(['Name', choice['name']])
|
||||
for slot in state['belief_state'][domain]['info']:
|
||||
if choice.get(slot):
|
||||
state['belief_state'][domain]['info'][slot] = choice[slot]
|
||||
|
@ -27,11 +27,12 @@ def default_state():
|
||||
'user_action': [],
|
||||
'system_action': [],
|
||||
'terminated': False,
|
||||
'booked': []
|
||||
'booked': {}
|
||||
}
|
||||
|
||||
|
||||
class DialogueStateTracker(DST):
|
||||
info_dict = None
|
||||
def __init__(self):
|
||||
DST.__init__(self)
|
||||
self.state = default_state()
|
||||
|
16
Main.py
16
Main.py
@ -1,16 +1,22 @@
|
||||
from NaturalLanguageAnalyzer import NaturalLanguageAnalyzer
|
||||
from DialoguePolicy import DialoguePolicy
|
||||
from DialogueStateTracker import DialogueStateTracker
|
||||
from NaturalLanguageGeneration import NaturalLanguageGeneration
|
||||
from convlab.dialog_agent import PipelineAgent
|
||||
from MachineLearningNLG import MachineLearningNLG # Importujemy nowy komponent NLG
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
text = "chciałbym zarezerwować drogi hotel z parkingiem 1 stycznia w Warszawie w centrum"
|
||||
nlu = NaturalLanguageAnalyzer()
|
||||
dst = DialogueStateTracker()
|
||||
policy = DialoguePolicy()
|
||||
nlg = MachineLearningNLG()
|
||||
nlg = NaturalLanguageGeneration()
|
||||
|
||||
agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')
|
||||
response = agent.response(text)
|
||||
print(response)
|
||||
|
||||
print("Witam, jestem systemem do rezerwowania pokoi hotelowych. W czym mogę Ci pomóc?")
|
||||
while True:
|
||||
text = input()
|
||||
response = agent.response(text)
|
||||
print(response)
|
||||
|
@ -14,6 +14,7 @@ def translate_text(text, target_language='en'):
|
||||
|
||||
|
||||
class NaturalLanguageAnalyzer:
|
||||
info_dict = None
|
||||
def predict(self, text, context=None):
|
||||
# Inicjalizacja modelu NLU
|
||||
model_name = "ConvLab/t5-small-nlu-multiwoz21"
|
||||
|
@ -1,12 +1,28 @@
|
||||
import re
|
||||
import requests
|
||||
|
||||
from convlab.nlg.template.multiwoz import TemplateNLG
|
||||
|
||||
|
||||
def translate_text(text, target_language='pl'):
|
||||
url = 'https://translate.googleapis.com/translate_a/single?client=gtx&sl=auto&tl={}&dt=t&q={}'.format(
|
||||
target_language, text)
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
translated_text = response.json()[0]
|
||||
translated_text_joined = ''.join([sentence[0] for sentence in translated_text])
|
||||
return translated_text_joined
|
||||
else:
|
||||
return None
|
||||
|
||||
class NaturalLanguageGeneration:
|
||||
info_dict = None
|
||||
def generate(self, system_act):
|
||||
if len(system_act) == 0:
|
||||
return "Nie rozumiem."
|
||||
tnlg = TemplateNLG(is_user=False)
|
||||
response_en = tnlg.generate(system_act)
|
||||
return translate_text(response_en)
|
||||
|
||||
def nlg(self, system_act):
|
||||
response = None
|
||||
pattern = r'inform\(name=([^\)]+)\)'
|
||||
matching = re.search(pattern, system_act)
|
||||
if matching:
|
||||
name = matching.group(1)
|
||||
response = f"Witaj, nazywam się {name}"
|
||||
return response
|
||||
|
||||
def init_session(self):
|
||||
pass # Dodanie pustej metody init_session
|
||||
|
Loading…
Reference in New Issue
Block a user