move code from jupyter to python classes
This commit is contained in:
parent
c71bbc071d
commit
513f09e06b
@ -1,8 +1,76 @@
|
||||
class DialoguePolicy:
|
||||
from collections import defaultdict
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from convlab.policy.policy import Policy
|
||||
|
||||
def policy(self, state):
|
||||
system_act = None
|
||||
name = "James"
|
||||
if state == "what name":
|
||||
system_act = f"inform(name={name})"
|
||||
return system_act
|
||||
db_path = './hotels_data.json'
|
||||
|
||||
|
||||
class DialoguePolicy(Policy):
|
||||
def __init__(self):
|
||||
Policy.__init__(self)
|
||||
self.db = self.load_database(db_path)
|
||||
|
||||
def load_database(self, db_path):
|
||||
with open(db_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
|
||||
def query(self, domain, constraints):
|
||||
if domain != 'hotel':
|
||||
return []
|
||||
|
||||
results = []
|
||||
for entry in self.db:
|
||||
match = all(entry.get(key) == value for key, value in constraints)
|
||||
if match:
|
||||
results.append(entry)
|
||||
return results
|
||||
|
||||
def predict(self, state):
|
||||
self.results = []
|
||||
system_action = defaultdict(list)
|
||||
user_action = defaultdict(list)
|
||||
|
||||
for intent, domain, slot, value in state['user_action']:
|
||||
user_action[(domain.lower(), intent.lower())].append((slot.lower(), value))
|
||||
|
||||
for user_act in user_action:
|
||||
self.update_system_action(user_act, user_action, state, system_action)
|
||||
|
||||
if any(True for slots in user_action.values() for (slot, _) in slots if
|
||||
slot in ['book stay', 'book day', 'book people']):
|
||||
if self.results:
|
||||
system_action = {('Booking', 'Book'): [["Ref", self.results[0].get('Ref', 'N/A')]]}
|
||||
|
||||
system_acts = [[intent, domain, slot, value] for (domain, intent), slots in system_action.items() for
|
||||
slot, value in slots]
|
||||
state['system_action'] = system_acts
|
||||
return system_acts
|
||||
|
||||
def update_system_action(self, user_act, user_action, state, system_action):
|
||||
domain, intent = user_act
|
||||
constraints = [(slot, value) for slot, value in state['belief_state'][domain]['info'].items() if value != '']
|
||||
# print(f"Constraints: {constraints}")
|
||||
self.results = deepcopy(self.query(domain.lower(), constraints))
|
||||
# print(f"Query results: {self.results}")
|
||||
|
||||
if intent == 'request':
|
||||
if len(self.results) == 0:
|
||||
system_action[(domain, 'NoOffer')] = []
|
||||
else:
|
||||
for slot in user_action[user_act]:
|
||||
if slot[0] in self.results[0]:
|
||||
system_action[(domain, 'Inform')].append([slot[0], self.results[0].get(slot[0], 'unknown')])
|
||||
|
||||
elif intent == 'inform':
|
||||
if len(self.results) == 0:
|
||||
system_action[(domain, 'NoOffer')] = []
|
||||
else:
|
||||
system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])
|
||||
choice = self.results[0]
|
||||
|
||||
if domain in ["hotel"]:
|
||||
system_action[(domain, 'Recommend')].append(['Name', choice['name']])
|
||||
for slot in state['belief_state'][domain]['info']:
|
||||
if choice.get(slot):
|
||||
state['belief_state'][domain]['info'][slot] = choice[slot]
|
@ -1,7 +1,77 @@
|
||||
class DialogueStateTracker:
|
||||
import json
|
||||
from convlab.dst.dst import DST
|
||||
|
||||
def dst(self, user_act):
|
||||
state = None
|
||||
if user_act == "request(firstname)":
|
||||
state = "what name"
|
||||
return state
|
||||
|
||||
def default_state():
|
||||
return {
|
||||
'belief_state': {
|
||||
'hotel': {
|
||||
'info': {
|
||||
'name': '',
|
||||
'area': '',
|
||||
'parking': '',
|
||||
'price range': '',
|
||||
'stars': '',
|
||||
'internet': '',
|
||||
'type': ''
|
||||
},
|
||||
'booking': {
|
||||
'book stay': '',
|
||||
'book day': '',
|
||||
'book people': ''
|
||||
}
|
||||
}
|
||||
},
|
||||
'request_state': {},
|
||||
'history': [],
|
||||
'user_action': [],
|
||||
'system_action': [],
|
||||
'terminated': False,
|
||||
'booked': []
|
||||
}
|
||||
|
||||
|
||||
class DialogueStateTracker(DST):
|
||||
def __init__(self):
|
||||
DST.__init__(self)
|
||||
self.state = default_state()
|
||||
with open('./hotels_data.json') as f:
|
||||
self.value_dict = json.load(f)
|
||||
|
||||
def update(self, user_act=None):
|
||||
for intent, domain, slot, value in user_act:
|
||||
domain = domain.lower()
|
||||
intent = intent.lower()
|
||||
slot = slot.lower()
|
||||
|
||||
if domain not in self.state['belief_state']:
|
||||
continue
|
||||
|
||||
if intent == 'inform':
|
||||
if slot == 'none' or slot == '' or value == 'dontcare':
|
||||
continue
|
||||
|
||||
domain_dic = self.state['belief_state'][domain]['info']
|
||||
|
||||
if slot in domain_dic:
|
||||
nvalue = self.normalize_value(self.value_dict, domain, slot, value)
|
||||
self.state['belief_state'][domain]['info'][slot] = nvalue
|
||||
|
||||
elif intent == 'request':
|
||||
if domain not in self.state['request_state']:
|
||||
self.state['request_state'][domain] = {}
|
||||
if slot not in self.state['request_state'][domain]:
|
||||
self.state['request_state'][domain][slot] = 0
|
||||
|
||||
return self.state
|
||||
|
||||
def normalize_value(self, value_dict, domain, slot, value):
|
||||
normalized_value = value.lower().strip()
|
||||
if domain in value_dict and slot in value_dict[domain]:
|
||||
possible_values = value_dict[domain][slot]
|
||||
if isinstance(possible_values, dict) and normalized_value in possible_values:
|
||||
return possible_values[normalized_value]
|
||||
return value
|
||||
|
||||
def init_session(self):
|
||||
self.state = default_state()
|
30
Main.py
30
Main.py
@ -1,24 +1,20 @@
|
||||
from NaturalLanguageAnalyzer import NaturalLanguageAnalyzer
|
||||
from DialogueStateTracker import DialogueStateTracker
|
||||
from DialoguePolicy import DialoguePolicy
|
||||
from NaturalLanguageGeneration import NaturalLanguageGeneration
|
||||
from DialogueStateTracker import DialogueStateTracker
|
||||
from convlab.dialog_agent import PipelineAgent
|
||||
from convlab.nlg.template.multiwoz import TemplateNLG
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
text = "chciałbym zarezerwować pokój z balkonem 1 stycznia w Warszawie"
|
||||
nla = NaturalLanguageAnalyzer()
|
||||
user_act = nla.process(text)
|
||||
print(user_act)
|
||||
text = "chciałbym zarezerwować drogi hotel z parkingiem 1 stycznia w Warszawie w centrum"
|
||||
nlu = NaturalLanguageAnalyzer()
|
||||
dst = DialogueStateTracker()
|
||||
policy = DialoguePolicy()
|
||||
nlg = TemplateNLG(is_user=False)
|
||||
|
||||
# dst = DialogueStateTracker()
|
||||
# state = dst.dst(user_act)
|
||||
# print(state)
|
||||
#
|
||||
# dp = DialoguePolicy()
|
||||
# system_act = dp.policy(state)
|
||||
# print(system_act)
|
||||
#
|
||||
# nlg = NaturalLanguageGeneration()
|
||||
# response = nlg.nlg(system_act)
|
||||
# print(response)
|
||||
agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')
|
||||
response = agent.response(text)
|
||||
print(response)
|
||||
|
||||
|
||||
|
@ -14,7 +14,7 @@ def translate_text(text, target_language='en'):
|
||||
|
||||
|
||||
class NaturalLanguageAnalyzer:
|
||||
def process(self, text):
|
||||
def predict(self, text, context=None):
|
||||
# Inicjalizacja modelu NLU
|
||||
model_name = "ConvLab/t5-small-nlu-multiwoz21"
|
||||
nlu_model = T5NLU(speaker='user', context_window_size=0, model_name_or_path=model_name)
|
||||
@ -26,3 +26,7 @@ class NaturalLanguageAnalyzer:
|
||||
nlu_output = nlu_model.predict(translated_input)
|
||||
|
||||
return nlu_output
|
||||
|
||||
def init_session(self):
|
||||
# Inicjalizacja sesji (jeśli konieczne)
|
||||
pass
|
Loading…
Reference in New Issue
Block a user