add: nlg model
This commit is contained in:
parent
88741c2d2b
commit
60a298fa49
295
dialog_with_nlg.py
Normal file
295
dialog_with_nlg.py
Normal file
@ -0,0 +1,295 @@
|
|||||||
|
import string
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import jsgf
|
||||||
|
from unidecode import unidecode
|
||||||
|
|
||||||
|
from convlab.dst import dst
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
AutoModelForSeq2SeqLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
pipeline,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def default_state():
|
||||||
|
return dict(
|
||||||
|
user_action=[],
|
||||||
|
system_action=[],
|
||||||
|
belief_state={
|
||||||
|
'address': '',
|
||||||
|
'payment_method': '',
|
||||||
|
'dish': [],
|
||||||
|
'time': ''
|
||||||
|
},
|
||||||
|
booked={},
|
||||||
|
request_state=[],
|
||||||
|
terminated=False,
|
||||||
|
history=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Model:
|
||||||
|
def __init__(self):
|
||||||
|
self.state = default_state()
|
||||||
|
self.nlu = NLU()
|
||||||
|
self.dst = DST(self.state)
|
||||||
|
self.dp = DP(self.state)
|
||||||
|
self.nlg = NLG(self.state)
|
||||||
|
|
||||||
|
def __call__(self, prompt) -> Any:
|
||||||
|
print(prompt)
|
||||||
|
msg = prompt.lower()
|
||||||
|
|
||||||
|
r = self.nlu(msg)
|
||||||
|
slots = r['slots']
|
||||||
|
#print(r)
|
||||||
|
r = self.dst(r)
|
||||||
|
#print(r)
|
||||||
|
r = self.dp()
|
||||||
|
#print(r)
|
||||||
|
r = self.nlg(r, slots)
|
||||||
|
print(r)
|
||||||
|
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
class NLU():
|
||||||
|
def __init__(self):
|
||||||
|
self.book_grammar = jsgf.parse_grammar_file('book.jsgf')
|
||||||
|
|
||||||
|
def get_dialog_act(self, rule):
|
||||||
|
slots = []
|
||||||
|
self.get_slots(rule.expansion, slots)
|
||||||
|
return {'act': rule.name, 'slots': slots}
|
||||||
|
|
||||||
|
def get_slots(self, expansion, slots):
|
||||||
|
if expansion.tag != '':
|
||||||
|
slots.append((expansion.tag, expansion.current_match))
|
||||||
|
return
|
||||||
|
|
||||||
|
for child in expansion.children:
|
||||||
|
self.get_slots(child, slots)
|
||||||
|
|
||||||
|
if not expansion.children and isinstance(expansion, jsgf.NamedRuleRef):
|
||||||
|
self.get_slots(expansion.referenced_rule.expansion, slots)
|
||||||
|
|
||||||
|
def __call__(self, prompt) -> Any:
|
||||||
|
book_grammar = jsgf.parse_grammar_file('book.jsgf')
|
||||||
|
|
||||||
|
prompt = unidecode(prompt)
|
||||||
|
translator = str.maketrans('', '', string.punctuation)
|
||||||
|
prompt = prompt.translate(translator)
|
||||||
|
|
||||||
|
matched = book_grammar.find_matching_rules(prompt)
|
||||||
|
|
||||||
|
if matched:
|
||||||
|
return self.get_dialog_act(matched[0])
|
||||||
|
else:
|
||||||
|
return {'act': 'null', 'slots': []}
|
||||||
|
|
||||||
|
|
||||||
|
class DST(dst.DST):
|
||||||
|
|
||||||
|
def __init__(self, state):
|
||||||
|
dst.DST.__init__(self)
|
||||||
|
self.state = state
|
||||||
|
|
||||||
|
def __call__(self, user_act) -> Any:
|
||||||
|
if len(user_act['slots']) == 0:
|
||||||
|
user_act = [(user_act['act'], None, None)]
|
||||||
|
else:
|
||||||
|
user_act = [(user_act['act'], k, v) for k, v in user_act['slots'] if v is not None]
|
||||||
|
|
||||||
|
self.state['request_state'] = {}
|
||||||
|
for act, slot, value in user_act:
|
||||||
|
self.state['user_action'].append(act)
|
||||||
|
|
||||||
|
if act == "platnosc":
|
||||||
|
self.state['belief_state']['payment_method'] = value
|
||||||
|
self.state['request_state'] = ['payment_method']
|
||||||
|
|
||||||
|
elif act == "offer":
|
||||||
|
self.state['request_state'] = ['menu']
|
||||||
|
|
||||||
|
elif act == 'select':
|
||||||
|
if slot == 'dish':
|
||||||
|
self.state['belief_state']['dish'].append(value)
|
||||||
|
else:
|
||||||
|
self.state['belief_state'][slot] = value
|
||||||
|
self.state['request_state'] = [slot]
|
||||||
|
|
||||||
|
elif act == 'inform':
|
||||||
|
pass
|
||||||
|
|
||||||
|
elif act == 'request':
|
||||||
|
pass
|
||||||
|
elif act == 'restart':
|
||||||
|
self.state["belief_state"] = default_state()["belief_state"]
|
||||||
|
self.state["booked"] = {}
|
||||||
|
self.state["request_state"] = []
|
||||||
|
self.state["terminated"] = False
|
||||||
|
self.state["history"] = []
|
||||||
|
return self.state
|
||||||
|
|
||||||
|
|
||||||
|
class DP():
|
||||||
|
def __init__(self, state):
|
||||||
|
self.state = state
|
||||||
|
|
||||||
|
def __call__(self) -> Any:
|
||||||
|
system_action = None
|
||||||
|
|
||||||
|
if self.state['user_action'][-1] == 'hello':
|
||||||
|
system_action = 'welcomemsg'
|
||||||
|
# przywitaj uzytkownika (i pokaz menu)
|
||||||
|
|
||||||
|
elif self.state['user_action'][-1] == 'select':
|
||||||
|
system_action = 'inform'
|
||||||
|
# poinformuj o wybranych slotach z "request_state"
|
||||||
|
|
||||||
|
elif (self.state['user_action'][-1] == 'help'
|
||||||
|
or self.state['user_action'][-1] == 'offer'
|
||||||
|
or self.state['user_action'][-1] == 'reqmore'
|
||||||
|
or (self.state['user_action'][-1] == 'request' and len(self.state['request_state']) == 0)
|
||||||
|
):
|
||||||
|
system_action = 'offer'
|
||||||
|
# zaoferuj cale menu
|
||||||
|
|
||||||
|
elif self.state['user_action'][-1] == 'ack':
|
||||||
|
address = self.state["belief_state"]["address"]
|
||||||
|
payment_method = self.state["belief_state"]["payment_method"]
|
||||||
|
dish = self.state["belief_state"]["dish"]
|
||||||
|
# W przypadku braku szczegolnej informacji o czasie zamówienia zamawiamy natychmiast
|
||||||
|
|
||||||
|
if address and payment_method and dish:
|
||||||
|
system_action = 'bye'
|
||||||
|
self.state['terminated'] = True
|
||||||
|
# potwierdz i zakoncz, podsumuj zamowienie
|
||||||
|
else:
|
||||||
|
system_action = 'canthelp.missing_slot_value'
|
||||||
|
elif self.state['user_action'][-1] == 'restart':
|
||||||
|
system_action = 'welcomemsg'
|
||||||
|
# zachowaj sie jak na poczatku rozmowy
|
||||||
|
else:
|
||||||
|
system_action = 'inform'
|
||||||
|
# poinformuj o wybranych slotach z "request_state"
|
||||||
|
# lub o wszystkich jezeli nic nie ma w request state
|
||||||
|
|
||||||
|
self.state['system_action'].append(system_action)
|
||||||
|
return system_action
|
||||||
|
|
||||||
|
|
||||||
|
class NLG():
|
||||||
|
def __init__(self, state):
|
||||||
|
self.model = AutoModelForSeq2SeqLM.from_pretrained("filnow/nlg-umt5-pol")
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
|
||||||
|
self.nlg_pipeline = pipeline('summarization', model=self.model, tokenizer=self.tokenizer)
|
||||||
|
|
||||||
|
def __call__(self, act, slots) -> Any:
|
||||||
|
if act == 'welcomemsg':
|
||||||
|
return "Witaj w naszej restauracji! Jak mogę Ci pomóc?"
|
||||||
|
|
||||||
|
elif act == "offer":
|
||||||
|
if slots == []:
|
||||||
|
return "Przepraszam nie rozumiem. Podaj więcej informacji."
|
||||||
|
|
||||||
|
elif act == "inform":
|
||||||
|
if slots == []:
|
||||||
|
return "Przepraszam nie rozumiem. Podaj więcej informacji."
|
||||||
|
else:
|
||||||
|
text = []
|
||||||
|
for i in slots:
|
||||||
|
if i[1] != None:
|
||||||
|
text.append(f"{i[0]}[{i[1]}]")
|
||||||
|
return self.nlg_pipeline(f'generate text: {", ".join(text)}')[0]['summary_text']
|
||||||
|
|
||||||
|
elif act == "canthelp.missing_slot_value":
|
||||||
|
return "Przepraszam, ale nie mogę zrealizować zamówienia. Brakuje mi niektórych informacji. Czy mogę pomóc w czymś innym?"
|
||||||
|
|
||||||
|
elif act == "bye":
|
||||||
|
return "Dziękujemy za zamówienie! Smacznego!"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
model = Model()
|
||||||
|
|
||||||
|
# jezeli sie przywita to przywitaj uzytkownika (i pokaz menu)
|
||||||
|
# response = model("Cześć")
|
||||||
|
# response = model("Witam")
|
||||||
|
# response = model("Witam system")
|
||||||
|
# response = model("Hej, jakim botem jesteś?")
|
||||||
|
# response = model("Hej, czym się zajmujesz?")
|
||||||
|
# response = model("Hej, w czym mi możesz pomóc?")
|
||||||
|
response = model("Siema, w czym możesz mi pomóc?")
|
||||||
|
assert response == "welcomemsg"
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
# jezeli prosi o pomoc lub po prostu o menu to zaoferuj cale menu
|
||||||
|
# response = model("Pokaz menu")
|
||||||
|
# response = model("A co do picia proponujesz?")
|
||||||
|
# response = model("Jakie inne desery oferujesz?")
|
||||||
|
response = model("Interesują mnie dania kuchni włoskiej oraz meksykanskiej.")
|
||||||
|
assert response == "offer"
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
# jezeli wybierze danie to zapisz wybor i poinformuj o nim
|
||||||
|
# response = model("Wezmę rybe")
|
||||||
|
# response = model("Poproszę tatara")
|
||||||
|
response = model("Chciałbym zjesc tatara")
|
||||||
|
assert response == "inform"
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
# jezeli poda adres to zapisze wybor i poinformuj o nim
|
||||||
|
# response = model('Poproszę na poznańską 2')
|
||||||
|
response = model("uniwersytetu poznanskiego 4 61-614 poznan")
|
||||||
|
assert response == "inform"
|
||||||
|
|
||||||
|
# jezeli sprobuje dokonac zamowienia bez podania potrzebnych informacji prosimy o nie
|
||||||
|
#response = model("Dobrze, nie mogę się już doczekać.")
|
||||||
|
response = model("Super, to zatem wszystko!")
|
||||||
|
assert response == "canthelp.missing_slot_value"
|
||||||
|
|
||||||
|
# jezeli wybierze rodzaj platnosci to zapisz wybor i poinformuj o nim
|
||||||
|
# response = model("karta")
|
||||||
|
# response = model("Poproszę blikiem z góry")
|
||||||
|
response = model("Zapłacę kartą przy odbiorze")
|
||||||
|
assert response == "inform"
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
# jezeli potwiedzi zamowienie to zakoncz zamawianie sukcesem i wypisz calosc
|
||||||
|
# response = model("Potwierdzam!")
|
||||||
|
# response = model("Tak!")
|
||||||
|
# response = model("Tak to wszystko!")
|
||||||
|
# response = model("Super, to zatem wszystko!")
|
||||||
|
response = model("Dobrze, nie mogę się już doczekać.")
|
||||||
|
assert response == "bye"
|
||||||
|
|
||||||
|
print("----Konwersacja z restartem-------")
|
||||||
|
|
||||||
|
model = Model()
|
||||||
|
response = model("Siema, w czym możesz mi pomóc?")
|
||||||
|
assert response == "welcomemsg"
|
||||||
|
response = model("Interesują mnie dania kuchni włoskiej oraz meksykanskiej.")
|
||||||
|
assert response == "offer"
|
||||||
|
response = model("Chciałbym zjesc tatara")
|
||||||
|
assert response == "inform"
|
||||||
|
response = model("uniwersytetu poznanskiego 4 61-614 poznan")
|
||||||
|
assert response == "inform"
|
||||||
|
response = model("od nowa")
|
||||||
|
assert response == "welcomemsg"
|
||||||
|
response = model("Interesują mnie dania kuchni włoskiej oraz meksykanskiej.")
|
||||||
|
assert response == "offer"
|
||||||
|
response = model("Chciałbym zjesc tatara")
|
||||||
|
assert response == "inform"
|
||||||
|
response = model("uniwersytetu poznanskiego 4 61-614 poznan")
|
||||||
|
assert response == "inform"
|
||||||
|
response = model("Zapłacę kartą przy odbiorze")
|
||||||
|
assert response == "inform"
|
||||||
|
response = model("Dobrze, nie mogę się już doczekać.")
|
||||||
|
assert response == "bye"
|
231
nlg_train.ipynb
231
nlg_train.ipynb
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user