move code from jupyter to python classes
This commit is contained in:
parent
c71bbc071d
commit
513f09e06b
@ -1,8 +1,76 @@
|
|||||||
class DialoguePolicy:
|
from collections import defaultdict
|
||||||
|
import json
|
||||||
|
from copy import deepcopy
|
||||||
|
from convlab.policy.policy import Policy
|
||||||
|
|
||||||
def policy(self, state):
|
db_path = './hotels_data.json'
|
||||||
system_act = None
|
|
||||||
name = "James"
|
|
||||||
if state == "what name":
|
class DialoguePolicy(Policy):
|
||||||
system_act = f"inform(name={name})"
|
def __init__(self):
|
||||||
return system_act
|
Policy.__init__(self)
|
||||||
|
self.db = self.load_database(db_path)
|
||||||
|
|
||||||
|
def load_database(self, db_path):
|
||||||
|
with open(db_path, 'r', encoding='utf-8') as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
def query(self, domain, constraints):
|
||||||
|
if domain != 'hotel':
|
||||||
|
return []
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for entry in self.db:
|
||||||
|
match = all(entry.get(key) == value for key, value in constraints)
|
||||||
|
if match:
|
||||||
|
results.append(entry)
|
||||||
|
return results
|
||||||
|
|
||||||
|
def predict(self, state):
|
||||||
|
self.results = []
|
||||||
|
system_action = defaultdict(list)
|
||||||
|
user_action = defaultdict(list)
|
||||||
|
|
||||||
|
for intent, domain, slot, value in state['user_action']:
|
||||||
|
user_action[(domain.lower(), intent.lower())].append((slot.lower(), value))
|
||||||
|
|
||||||
|
for user_act in user_action:
|
||||||
|
self.update_system_action(user_act, user_action, state, system_action)
|
||||||
|
|
||||||
|
if any(True for slots in user_action.values() for (slot, _) in slots if
|
||||||
|
slot in ['book stay', 'book day', 'book people']):
|
||||||
|
if self.results:
|
||||||
|
system_action = {('Booking', 'Book'): [["Ref", self.results[0].get('Ref', 'N/A')]]}
|
||||||
|
|
||||||
|
system_acts = [[intent, domain, slot, value] for (domain, intent), slots in system_action.items() for
|
||||||
|
slot, value in slots]
|
||||||
|
state['system_action'] = system_acts
|
||||||
|
return system_acts
|
||||||
|
|
||||||
|
def update_system_action(self, user_act, user_action, state, system_action):
|
||||||
|
domain, intent = user_act
|
||||||
|
constraints = [(slot, value) for slot, value in state['belief_state'][domain]['info'].items() if value != '']
|
||||||
|
# print(f"Constraints: {constraints}")
|
||||||
|
self.results = deepcopy(self.query(domain.lower(), constraints))
|
||||||
|
# print(f"Query results: {self.results}")
|
||||||
|
|
||||||
|
if intent == 'request':
|
||||||
|
if len(self.results) == 0:
|
||||||
|
system_action[(domain, 'NoOffer')] = []
|
||||||
|
else:
|
||||||
|
for slot in user_action[user_act]:
|
||||||
|
if slot[0] in self.results[0]:
|
||||||
|
system_action[(domain, 'Inform')].append([slot[0], self.results[0].get(slot[0], 'unknown')])
|
||||||
|
|
||||||
|
elif intent == 'inform':
|
||||||
|
if len(self.results) == 0:
|
||||||
|
system_action[(domain, 'NoOffer')] = []
|
||||||
|
else:
|
||||||
|
system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])
|
||||||
|
choice = self.results[0]
|
||||||
|
|
||||||
|
if domain in ["hotel"]:
|
||||||
|
system_action[(domain, 'Recommend')].append(['Name', choice['name']])
|
||||||
|
for slot in state['belief_state'][domain]['info']:
|
||||||
|
if choice.get(slot):
|
||||||
|
state['belief_state'][domain]['info'][slot] = choice[slot]
|
@ -1,7 +1,77 @@
|
|||||||
class DialogueStateTracker:
|
import json
|
||||||
|
from convlab.dst.dst import DST
|
||||||
|
|
||||||
def dst(self, user_act):
|
|
||||||
state = None
|
def default_state():
|
||||||
if user_act == "request(firstname)":
|
return {
|
||||||
state = "what name"
|
'belief_state': {
|
||||||
return state
|
'hotel': {
|
||||||
|
'info': {
|
||||||
|
'name': '',
|
||||||
|
'area': '',
|
||||||
|
'parking': '',
|
||||||
|
'price range': '',
|
||||||
|
'stars': '',
|
||||||
|
'internet': '',
|
||||||
|
'type': ''
|
||||||
|
},
|
||||||
|
'booking': {
|
||||||
|
'book stay': '',
|
||||||
|
'book day': '',
|
||||||
|
'book people': ''
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'request_state': {},
|
||||||
|
'history': [],
|
||||||
|
'user_action': [],
|
||||||
|
'system_action': [],
|
||||||
|
'terminated': False,
|
||||||
|
'booked': []
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DialogueStateTracker(DST):
|
||||||
|
def __init__(self):
|
||||||
|
DST.__init__(self)
|
||||||
|
self.state = default_state()
|
||||||
|
with open('./hotels_data.json') as f:
|
||||||
|
self.value_dict = json.load(f)
|
||||||
|
|
||||||
|
def update(self, user_act=None):
|
||||||
|
for intent, domain, slot, value in user_act:
|
||||||
|
domain = domain.lower()
|
||||||
|
intent = intent.lower()
|
||||||
|
slot = slot.lower()
|
||||||
|
|
||||||
|
if domain not in self.state['belief_state']:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if intent == 'inform':
|
||||||
|
if slot == 'none' or slot == '' or value == 'dontcare':
|
||||||
|
continue
|
||||||
|
|
||||||
|
domain_dic = self.state['belief_state'][domain]['info']
|
||||||
|
|
||||||
|
if slot in domain_dic:
|
||||||
|
nvalue = self.normalize_value(self.value_dict, domain, slot, value)
|
||||||
|
self.state['belief_state'][domain]['info'][slot] = nvalue
|
||||||
|
|
||||||
|
elif intent == 'request':
|
||||||
|
if domain not in self.state['request_state']:
|
||||||
|
self.state['request_state'][domain] = {}
|
||||||
|
if slot not in self.state['request_state'][domain]:
|
||||||
|
self.state['request_state'][domain][slot] = 0
|
||||||
|
|
||||||
|
return self.state
|
||||||
|
|
||||||
|
def normalize_value(self, value_dict, domain, slot, value):
|
||||||
|
normalized_value = value.lower().strip()
|
||||||
|
if domain in value_dict and slot in value_dict[domain]:
|
||||||
|
possible_values = value_dict[domain][slot]
|
||||||
|
if isinstance(possible_values, dict) and normalized_value in possible_values:
|
||||||
|
return possible_values[normalized_value]
|
||||||
|
return value
|
||||||
|
|
||||||
|
def init_session(self):
|
||||||
|
self.state = default_state()
|
30
Main.py
30
Main.py
@ -1,24 +1,20 @@
|
|||||||
from NaturalLanguageAnalyzer import NaturalLanguageAnalyzer
|
from NaturalLanguageAnalyzer import NaturalLanguageAnalyzer
|
||||||
from DialogueStateTracker import DialogueStateTracker
|
|
||||||
from DialoguePolicy import DialoguePolicy
|
from DialoguePolicy import DialoguePolicy
|
||||||
from NaturalLanguageGeneration import NaturalLanguageGeneration
|
from DialogueStateTracker import DialogueStateTracker
|
||||||
|
from convlab.dialog_agent import PipelineAgent
|
||||||
|
from convlab.nlg.template.multiwoz import TemplateNLG
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
text = "chciałbym zarezerwować pokój z balkonem 1 stycznia w Warszawie"
|
text = "chciałbym zarezerwować drogi hotel z parkingiem 1 stycznia w Warszawie w centrum"
|
||||||
nla = NaturalLanguageAnalyzer()
|
nlu = NaturalLanguageAnalyzer()
|
||||||
user_act = nla.process(text)
|
dst = DialogueStateTracker()
|
||||||
print(user_act)
|
policy = DialoguePolicy()
|
||||||
|
nlg = TemplateNLG(is_user=False)
|
||||||
|
|
||||||
# dst = DialogueStateTracker()
|
agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')
|
||||||
# state = dst.dst(user_act)
|
response = agent.response(text)
|
||||||
# print(state)
|
print(response)
|
||||||
#
|
|
||||||
# dp = DialoguePolicy()
|
|
||||||
# system_act = dp.policy(state)
|
|
||||||
# print(system_act)
|
|
||||||
#
|
|
||||||
# nlg = NaturalLanguageGeneration()
|
|
||||||
# response = nlg.nlg(system_act)
|
|
||||||
# print(response)
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ def translate_text(text, target_language='en'):
|
|||||||
|
|
||||||
|
|
||||||
class NaturalLanguageAnalyzer:
|
class NaturalLanguageAnalyzer:
|
||||||
def process(self, text):
|
def predict(self, text, context=None):
|
||||||
# Inicjalizacja modelu NLU
|
# Inicjalizacja modelu NLU
|
||||||
model_name = "ConvLab/t5-small-nlu-multiwoz21"
|
model_name = "ConvLab/t5-small-nlu-multiwoz21"
|
||||||
nlu_model = T5NLU(speaker='user', context_window_size=0, model_name_or_path=model_name)
|
nlu_model = T5NLU(speaker='user', context_window_size=0, model_name_or_path=model_name)
|
||||||
@ -26,3 +26,7 @@ class NaturalLanguageAnalyzer:
|
|||||||
nlu_output = nlu_model.predict(translated_input)
|
nlu_output = nlu_model.predict(translated_input)
|
||||||
|
|
||||||
return nlu_output
|
return nlu_output
|
||||||
|
|
||||||
|
def init_session(self):
|
||||||
|
# Inicjalizacja sesji (jeśli konieczne)
|
||||||
|
pass
|
Loading…
Reference in New Issue
Block a user