add: nlg model
This commit is contained in:
parent
88741c2d2b
commit
60a298fa49
295
dialog_with_nlg.py
Normal file
295
dialog_with_nlg.py
Normal file
@ -0,0 +1,295 @@
|
||||
import string
|
||||
from typing import Any
|
||||
|
||||
import jsgf
|
||||
from unidecode import unidecode
|
||||
|
||||
from convlab.dst import dst
|
||||
|
||||
from transformers import (
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
pipeline,
|
||||
)
|
||||
|
||||
|
||||
def default_state():
|
||||
return dict(
|
||||
user_action=[],
|
||||
system_action=[],
|
||||
belief_state={
|
||||
'address': '',
|
||||
'payment_method': '',
|
||||
'dish': [],
|
||||
'time': ''
|
||||
},
|
||||
booked={},
|
||||
request_state=[],
|
||||
terminated=False,
|
||||
history=[]
|
||||
)
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self):
|
||||
self.state = default_state()
|
||||
self.nlu = NLU()
|
||||
self.dst = DST(self.state)
|
||||
self.dp = DP(self.state)
|
||||
self.nlg = NLG(self.state)
|
||||
|
||||
def __call__(self, prompt) -> Any:
|
||||
print(prompt)
|
||||
msg = prompt.lower()
|
||||
|
||||
r = self.nlu(msg)
|
||||
slots = r['slots']
|
||||
#print(r)
|
||||
r = self.dst(r)
|
||||
#print(r)
|
||||
r = self.dp()
|
||||
#print(r)
|
||||
r = self.nlg(r, slots)
|
||||
print(r)
|
||||
|
||||
return r
|
||||
|
||||
|
||||
class NLU():
|
||||
def __init__(self):
|
||||
self.book_grammar = jsgf.parse_grammar_file('book.jsgf')
|
||||
|
||||
def get_dialog_act(self, rule):
|
||||
slots = []
|
||||
self.get_slots(rule.expansion, slots)
|
||||
return {'act': rule.name, 'slots': slots}
|
||||
|
||||
def get_slots(self, expansion, slots):
|
||||
if expansion.tag != '':
|
||||
slots.append((expansion.tag, expansion.current_match))
|
||||
return
|
||||
|
||||
for child in expansion.children:
|
||||
self.get_slots(child, slots)
|
||||
|
||||
if not expansion.children and isinstance(expansion, jsgf.NamedRuleRef):
|
||||
self.get_slots(expansion.referenced_rule.expansion, slots)
|
||||
|
||||
def __call__(self, prompt) -> Any:
|
||||
book_grammar = jsgf.parse_grammar_file('book.jsgf')
|
||||
|
||||
prompt = unidecode(prompt)
|
||||
translator = str.maketrans('', '', string.punctuation)
|
||||
prompt = prompt.translate(translator)
|
||||
|
||||
matched = book_grammar.find_matching_rules(prompt)
|
||||
|
||||
if matched:
|
||||
return self.get_dialog_act(matched[0])
|
||||
else:
|
||||
return {'act': 'null', 'slots': []}
|
||||
|
||||
|
||||
class DST(dst.DST):
|
||||
|
||||
def __init__(self, state):
|
||||
dst.DST.__init__(self)
|
||||
self.state = state
|
||||
|
||||
def __call__(self, user_act) -> Any:
|
||||
if len(user_act['slots']) == 0:
|
||||
user_act = [(user_act['act'], None, None)]
|
||||
else:
|
||||
user_act = [(user_act['act'], k, v) for k, v in user_act['slots'] if v is not None]
|
||||
|
||||
self.state['request_state'] = {}
|
||||
for act, slot, value in user_act:
|
||||
self.state['user_action'].append(act)
|
||||
|
||||
if act == "platnosc":
|
||||
self.state['belief_state']['payment_method'] = value
|
||||
self.state['request_state'] = ['payment_method']
|
||||
|
||||
elif act == "offer":
|
||||
self.state['request_state'] = ['menu']
|
||||
|
||||
elif act == 'select':
|
||||
if slot == 'dish':
|
||||
self.state['belief_state']['dish'].append(value)
|
||||
else:
|
||||
self.state['belief_state'][slot] = value
|
||||
self.state['request_state'] = [slot]
|
||||
|
||||
elif act == 'inform':
|
||||
pass
|
||||
|
||||
elif act == 'request':
|
||||
pass
|
||||
elif act == 'restart':
|
||||
self.state["belief_state"] = default_state()["belief_state"]
|
||||
self.state["booked"] = {}
|
||||
self.state["request_state"] = []
|
||||
self.state["terminated"] = False
|
||||
self.state["history"] = []
|
||||
return self.state
|
||||
|
||||
|
||||
class DP():
|
||||
def __init__(self, state):
|
||||
self.state = state
|
||||
|
||||
def __call__(self) -> Any:
|
||||
system_action = None
|
||||
|
||||
if self.state['user_action'][-1] == 'hello':
|
||||
system_action = 'welcomemsg'
|
||||
# przywitaj uzytkownika (i pokaz menu)
|
||||
|
||||
elif self.state['user_action'][-1] == 'select':
|
||||
system_action = 'inform'
|
||||
# poinformuj o wybranych slotach z "request_state"
|
||||
|
||||
elif (self.state['user_action'][-1] == 'help'
|
||||
or self.state['user_action'][-1] == 'offer'
|
||||
or self.state['user_action'][-1] == 'reqmore'
|
||||
or (self.state['user_action'][-1] == 'request' and len(self.state['request_state']) == 0)
|
||||
):
|
||||
system_action = 'offer'
|
||||
# zaoferuj cale menu
|
||||
|
||||
elif self.state['user_action'][-1] == 'ack':
|
||||
address = self.state["belief_state"]["address"]
|
||||
payment_method = self.state["belief_state"]["payment_method"]
|
||||
dish = self.state["belief_state"]["dish"]
|
||||
# W przypadku braku szczegolnej informacji o czasie zamówienia zamawiamy natychmiast
|
||||
|
||||
if address and payment_method and dish:
|
||||
system_action = 'bye'
|
||||
self.state['terminated'] = True
|
||||
# potwierdz i zakoncz, podsumuj zamowienie
|
||||
else:
|
||||
system_action = 'canthelp.missing_slot_value'
|
||||
elif self.state['user_action'][-1] == 'restart':
|
||||
system_action = 'welcomemsg'
|
||||
# zachowaj sie jak na poczatku rozmowy
|
||||
else:
|
||||
system_action = 'inform'
|
||||
# poinformuj o wybranych slotach z "request_state"
|
||||
# lub o wszystkich jezeli nic nie ma w request state
|
||||
|
||||
self.state['system_action'].append(system_action)
|
||||
return system_action
|
||||
|
||||
|
||||
class NLG():
|
||||
def __init__(self, state):
|
||||
self.model = AutoModelForSeq2SeqLM.from_pretrained("filnow/nlg-umt5-pol")
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
|
||||
self.nlg_pipeline = pipeline('summarization', model=self.model, tokenizer=self.tokenizer)
|
||||
|
||||
def __call__(self, act, slots) -> Any:
|
||||
if act == 'welcomemsg':
|
||||
return "Witaj w naszej restauracji! Jak mogę Ci pomóc?"
|
||||
|
||||
elif act == "offer":
|
||||
if slots == []:
|
||||
return "Przepraszam nie rozumiem. Podaj więcej informacji."
|
||||
|
||||
elif act == "inform":
|
||||
if slots == []:
|
||||
return "Przepraszam nie rozumiem. Podaj więcej informacji."
|
||||
else:
|
||||
text = []
|
||||
for i in slots:
|
||||
if i[1] != None:
|
||||
text.append(f"{i[0]}[{i[1]}]")
|
||||
return self.nlg_pipeline(f'generate text: {", ".join(text)}')[0]['summary_text']
|
||||
|
||||
elif act == "canthelp.missing_slot_value":
|
||||
return "Przepraszam, ale nie mogę zrealizować zamówienia. Brakuje mi niektórych informacji. Czy mogę pomóc w czymś innym?"
|
||||
|
||||
elif act == "bye":
|
||||
return "Dziękujemy za zamówienie! Smacznego!"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = Model()
|
||||
|
||||
# jezeli sie przywita to przywitaj uzytkownika (i pokaz menu)
|
||||
# response = model("Cześć")
|
||||
# response = model("Witam")
|
||||
# response = model("Witam system")
|
||||
# response = model("Hej, jakim botem jesteś?")
|
||||
# response = model("Hej, czym się zajmujesz?")
|
||||
# response = model("Hej, w czym mi możesz pomóc?")
|
||||
response = model("Siema, w czym możesz mi pomóc?")
|
||||
assert response == "welcomemsg"
|
||||
|
||||
print()
|
||||
|
||||
# jezeli prosi o pomoc lub po prostu o menu to zaoferuj cale menu
|
||||
# response = model("Pokaz menu")
|
||||
# response = model("A co do picia proponujesz?")
|
||||
# response = model("Jakie inne desery oferujesz?")
|
||||
response = model("Interesują mnie dania kuchni włoskiej oraz meksykanskiej.")
|
||||
assert response == "offer"
|
||||
|
||||
print()
|
||||
|
||||
# jezeli wybierze danie to zapisz wybor i poinformuj o nim
|
||||
# response = model("Wezmę rybe")
|
||||
# response = model("Poproszę tatara")
|
||||
response = model("Chciałbym zjesc tatara")
|
||||
assert response == "inform"
|
||||
|
||||
print()
|
||||
|
||||
# jezeli poda adres to zapisze wybor i poinformuj o nim
|
||||
# response = model('Poproszę na poznańską 2')
|
||||
response = model("uniwersytetu poznanskiego 4 61-614 poznan")
|
||||
assert response == "inform"
|
||||
|
||||
# jezeli sprobuje dokonac zamowienia bez podania potrzebnych informacji prosimy o nie
|
||||
#response = model("Dobrze, nie mogę się już doczekać.")
|
||||
response = model("Super, to zatem wszystko!")
|
||||
assert response == "canthelp.missing_slot_value"
|
||||
|
||||
# jezeli wybierze rodzaj platnosci to zapisz wybor i poinformuj o nim
|
||||
# response = model("karta")
|
||||
# response = model("Poproszę blikiem z góry")
|
||||
response = model("Zapłacę kartą przy odbiorze")
|
||||
assert response == "inform"
|
||||
|
||||
print()
|
||||
|
||||
# jezeli potwiedzi zamowienie to zakoncz zamawianie sukcesem i wypisz calosc
|
||||
# response = model("Potwierdzam!")
|
||||
# response = model("Tak!")
|
||||
# response = model("Tak to wszystko!")
|
||||
# response = model("Super, to zatem wszystko!")
|
||||
response = model("Dobrze, nie mogę się już doczekać.")
|
||||
assert response == "bye"
|
||||
|
||||
print("----Konwersacja z restartem-------")
|
||||
|
||||
model = Model()
|
||||
response = model("Siema, w czym możesz mi pomóc?")
|
||||
assert response == "welcomemsg"
|
||||
response = model("Interesują mnie dania kuchni włoskiej oraz meksykanskiej.")
|
||||
assert response == "offer"
|
||||
response = model("Chciałbym zjesc tatara")
|
||||
assert response == "inform"
|
||||
response = model("uniwersytetu poznanskiego 4 61-614 poznan")
|
||||
assert response == "inform"
|
||||
response = model("od nowa")
|
||||
assert response == "welcomemsg"
|
||||
response = model("Interesują mnie dania kuchni włoskiej oraz meksykanskiej.")
|
||||
assert response == "offer"
|
||||
response = model("Chciałbym zjesc tatara")
|
||||
assert response == "inform"
|
||||
response = model("uniwersytetu poznanskiego 4 61-614 poznan")
|
||||
assert response == "inform"
|
||||
response = model("Zapłacę kartą przy odbiorze")
|
||||
assert response == "inform"
|
||||
response = model("Dobrze, nie mogę się już doczekać.")
|
||||
assert response == "bye"
|
231
nlg_train.ipynb
231
nlg_train.ipynb
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user