343 lines
11 KiB
Python
343 lines
11 KiB
Python
import string
|
|
from typing import Any, List, Tuple
|
|
|
|
import jsgf
|
|
from unidecode import unidecode
|
|
import random
|
|
from convlab.dst import dst
|
|
|
|
from transformers import (
|
|
AutoModelForSeq2SeqLM,
|
|
AutoTokenizer,
|
|
pipeline,
|
|
)
|
|
|
|
from transformers.utils import logging
|
|
|
|
# Lower the transformers library's log level to ERROR so that only its errors
# (not info/warning messages) are printed alongside the console chat.
logging.set_verbosity_error()
|
|
|
|
|
|
def default_state():
    """Build and return a brand-new, empty dialogue state.

    Every call creates new inner lists/dicts, so one conversation's state can
    be mutated without affecting another's.

    Keys:
        user_action: user dialogue acts recorded by the tracker.
        system_action: system acts chosen by the policy.
        belief_state: order slots; ``address``, ``payment_method`` and
            ``time`` start as empty strings, ``dish`` as an empty list.
        booked, request_state, history: empty containers.
        terminated: False until the policy finishes the order.
    """
    empty_beliefs = {
        'address': '',
        'payment_method': '',
        'dish': [],
        'time': '',
    }
    return {
        'user_action': [],
        'system_action': [],
        'belief_state': empty_beliefs,
        'booked': {},
        'request_state': [],
        'terminated': False,
        'history': [],
    }
|
|
|
|
|
|
class Model:
    """Complete dialogue pipeline for one conversation: NLU -> DST -> DP -> NLG.

    DST, DP and NLG are all constructed with the same state dict (built by
    ``default_state()``), so updates made by the tracker are visible to the
    policy.
    """

    def __init__(self):
        self.state = default_state()
        self.nlu = NLU()
        self.dst = DST(self.state)
        self.dp = DP(self.state)
        self.nlg = NLG(self.state)

    def __call__(self, prompt, debug=True) -> Any:
        """Process one user utterance and return the system's reply.

        Args:
            prompt: raw user text; it is lower-cased before being passed to NLU.
            debug: if true, echo *prompt* and print the bare reply;
                otherwise print only ``"JARVIS: <reply>"``.

        Returns:
            The reply produced by the NLG stage.
        """
        if debug:
            print(prompt)

        user_act = self.nlu(prompt.lower())
        turn_slots = user_act['slots']  # NLG verbalises the NLU slots directly
        self.dst(user_act)  # updates the shared state dict
        system_act = self.dp()
        reply = self.nlg(system_act, turn_slots)

        print(reply if debug else f"JARVIS: {reply}")
        return reply
|
|
|
|
|
|
class NLU():
    """Rule-based natural-language understanding backed by a JSGF grammar.

    The name of the first grammar rule matching the user's utterance becomes
    the dialogue act; tagged expansions inside that rule become
    ``(tag, matched_text)`` slots.
    """

    # Translation table that deletes every ASCII punctuation character.
    # NOTE: this also removes '-', so e.g. "61-614" becomes "61614".
    _STRIP_PUNCTUATION = str.maketrans('', '', string.punctuation)

    def __init__(self, grammar_path='book.jsgf'):
        """Parse the grammar once, up front.

        Args:
            grammar_path: JSGF grammar file; the default ``book.jsgf`` is a
                relative path, i.e. resolved against the working directory.
        """
        self.book_grammar = jsgf.parse_grammar_file(grammar_path)

    def get_dialog_act(self, rule):
        """Build ``{'act': rule.name, 'slots': [(tag, match), ...]}`` for a matched rule."""
        slots = []
        self.get_slots(rule.expansion, slots)
        return {'act': rule.name, 'slots': slots}

    def get_slots(self, expansion, slots):
        """Append ``(tag, current_match)`` pairs found under *expansion* to *slots*.

        A tagged expansion yields exactly one slot, and its children are not
        visited. Untagged expansions are searched depth-first. A child-less
        ``jsgf.NamedRuleRef`` is followed into the expansion of the rule it
        references. ``current_match`` may be None (presumably for parts that
        took no part in the match); DST drops such slots.
        """
        if expansion.tag:
            slots.append((expansion.tag, expansion.current_match))
            return

        for child in expansion.children:
            self.get_slots(child, slots)

        if not expansion.children and isinstance(expansion, jsgf.NamedRuleRef):
            self.get_slots(expansion.referenced_rule.expansion, slots)

    def __call__(self, prompt) -> Any:
        """Map a user utterance to a dialogue-act dict.

        Before matching, the prompt is transliterated to ASCII with
        ``unidecode`` and stripped of punctuation.

        Returns:
            ``{'act': ..., 'slots': [...]}`` for the first matching rule, or
            ``{'act': 'null', 'slots': []}`` when no rule matches.
        """
        normalized = unidecode(prompt).translate(self._STRIP_PUNCTUATION)
        # Reuse the grammar parsed in __init__ instead of re-reading the file each turn.
        matched = self.book_grammar.find_matching_rules(normalized)
        if not matched:
            return {'act': 'null', 'slots': []}
        return self.get_dialog_act(matched[0])
|
|
|
|
|
|
class DST(dst.DST):
    """Rule-based dialogue state tracker.

    Folds each NLU result into the shared dialogue-state dict (the one built
    by ``default_state()``) and returns that same dict.
    """

    def __init__(self, state):
        super().__init__()
        # Keep a reference, not a copy: the policy (DP) reads the same dict.
        self.state = state

    def __call__(self, user_act) -> Any:
        """Record one user turn in the shared state.

        Args:
            user_act: NLU output, ``{'act': str, 'slots': [(slot, value), ...]}``.

        Returns:
            The shared state dict, updated in place.

        The act is appended to ``state['user_action']`` once per slot that
        carries a value. If no slot does (no slots at all, or only None
        values), it is appended exactly once, so the policy always sees this
        turn's act.
        """
        act_name = user_act['act']
        turn = [(act_name, slot, value)
                for slot, value in user_act['slots'] if value is not None]
        if not turn:
            turn = [(act_name, None, None)]

        # Slots touched this turn; always a list, matching default_state().
        self.state['request_state'] = []

        for act, slot, value in turn:
            self.state['user_action'].append(act)

            if act == "platnosc":
                # Don't overwrite an already chosen payment method with a missing value.
                if value is not None:
                    self.state['belief_state']['payment_method'] = value
                self.state['request_state'] = ['payment_method']

            elif act == "offer":
                self.state['request_state'] = ['menu']

            elif act == 'select':
                if slot is None:
                    # Nothing was captured; never write a None key into belief_state.
                    continue
                if slot == 'dish':
                    # Dishes accumulate; every other slot is overwritten.
                    self.state['belief_state']['dish'].append(value)
                else:
                    self.state['belief_state'][slot] = value
                self.state['request_state'] = [slot]

            elif act in ('inform', 'request'):
                pass  # recognised acts that do not change the state

            elif act == 'restart':
                # Start the order over; the user/system action histories are kept.
                self.state["belief_state"] = default_state()["belief_state"]
                self.state["booked"] = {}
                self.state["request_state"] = []
                self.state["terminated"] = False
                self.state["history"] = []

        return self.state
|
|
|
|
|
|
class DP():
    """Rule-based dialogue policy.

    Chooses the system act from the most recent user act (consulting the
    request/belief state for some acts). The chosen act is appended to
    ``state['system_action']`` and returned.
    """

    # User acts answered by offering the whole menu.
    _MENU_ACTS = ('help', 'offer', 'reqmore')

    def __init__(self, state):
        # Shared dialogue state, updated by the tracker before each call.
        self.state = state

    def __call__(self) -> Any:
        """Pick the system act for the current turn.

        Returns:
            One of ``'welcomemsg'``, ``'inform'``, ``'offer'``, ``'bye'`` or
            ``'canthelp.missing_slot_value'``. ``'bye'`` also sets
            ``state['terminated']`` to True.
        """
        user_actions = self.state['user_action']
        # With no recorded user act, fall through to the default branch
        # instead of raising IndexError.
        last_act = user_actions[-1] if user_actions else None

        if last_act in ('hello', 'restart'):
            # Greet the user; after a restart, behave as at the start of the chat.
            system_action = 'welcomemsg'
        elif last_act == 'select':
            system_action = 'inform'
        elif last_act in self._MENU_ACTS or (
                last_act == 'request' and not self.state['request_state']):
            # Offer the whole menu.
            system_action = 'offer'
        elif last_act == 'ack':
            beliefs = self.state['belief_state']
            # Time is optional: without a specific time the order is placed immediately.
            if beliefs['address'] and beliefs['payment_method'] and beliefs['dish']:
                # Order complete: confirm and end the conversation.
                system_action = 'bye'
                self.state['terminated'] = True
            else:
                system_action = 'canthelp.missing_slot_value'
        else:
            # Anything else (including no act yet) is answered with 'inform'.
            system_action = 'inform'

        self.state['system_action'].append(system_action)
        return system_action
|
|
|
|
|
|
class NLG():
    """Natural-language generation: turn a system act into a Polish reply.

    Most acts get a random canned sentence from ``self.messages``.
    'inform', 'request' and 'select' verbalise the user's slots with the
    seq2seq model ``filnow/nlg-umt5-pol``, run through a transformers
    summarization pipeline.
    """

    # System act -> key of self.messages whose sentences answer it.
    _CANNED_REPLIES = {
        "welcomemsg": "welcomemsg",
        "canthelp": "canthelp",
        "canthelp.missing_slot_value": "canthelp",
        "bye": "bye",
        "affirm": "affirm",
        "repeat": "repeat",
        "reqmore": "reqmore",
    }

    # Acts whose reply is generated by the model from the slot values.
    _GENERATED_ACTS = ("inform", "request", "select")

    def __init__(self, state):
        # ``state`` is accepted for symmetry with DST/DP but is not used here.
        self.model = AutoModelForSeq2SeqLM.from_pretrained("filnow/nlg-umt5-pol")
        self.tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
        self.nlg_pipeline = pipeline('summarization', model=self.model, tokenizer=self.tokenizer)

        self.messages = {
            "welcomemsg": [
                "Witaj w naszej restauracji! Jak mogę Ci pomóc?",
                "Witaj! W czym mogę pomóc?",
                "Hej! Co mogę dla Ciebie zrobić?"
            ],
            "canthelp": [
                "Przepraszam, nie mogę pomóc w tej chwili.",
                "Nie jestem w stanie pomóc.",
                "Przepraszam, nie rozumiem."
            ],
            "bye": [
                "Dziękujemy za zamówienie! Smacznego!",
                "Smacznego! Do zobaczenia!",
                "Dziękujemy za skorzystanie z naszych usług!"
            ],
            "affirm": [
                "Zamówienie zostało złożone!",
                "Potwierdzam zamówienie!",
                "Skladam zamówienie!"
            ],
            "repeat": [
                "Możesz powtórzyć?",
                "Nie zrozumiałem, możesz powtórzyć?",
                "Nie zrozumiałem, możesz powtórzyć jeszcze raz?"
            ],
            "reqmore": [
                "Potrzebujesz więcej informacji?",
                "Czy mogę pomóc w czymś jeszcze?",
                "Czy mogę zaoferować coś jeszcze?"
            ]
        }

    def __call__(self, act: str, slots: List[Tuple[str, str]]) -> str:
        """Return the reply text for system act *act*.

        Args:
            act: system act chosen by the policy.
            slots: ``(slot, value)`` pairs from the user's utterance; pairs
                whose value is None are ignored.

        Returns:
            The reply sentence, or None for an act this class does not handle.
        """
        if act in self._CANNED_REPLIES:
            return random.choice(self.messages[self._CANNED_REPLIES[act]])

        if act in self._GENERATED_ACTS:
            slot_texts = [f"{name}[{value}]" for name, value in slots if value is not None]
            if not slot_texts:
                # Nothing usable to verbalise (no slots, or only empty ones):
                # don't feed an empty prompt to the model.
                return "Przepraszam nie rozumiem. Podaj więcej informacji."
            prompt = f'generate text: {", ".join(slot_texts)}'
            return self.nlg_pipeline(prompt)[0]['summary_text']

        if act == "offer":
            return "Proszę oto menu zeskanuj kod QR aby je zobaczyć."

        return None
|
|
|
|
|
|
def dialogue_test():
    """Replay two scripted Polish conversations, each through a fresh Model.

    Manual smoke test: every turn runs with the default debug output (prompt
    and reply are printed) and is followed by an empty line; nothing is
    asserted. The second conversation restarts the order midway ("od nowa")
    and then completes it.
    """
    # Scripted turns, with the system behaviour each one is meant to exercise.
    greet = "Siema, w czym możesz mi pomóc?"  # greeting -> greet the user
    ask_menu = "Interesują mnie dania kuchni włoskiej oraz meksykanskiej."  # help/menu -> offer the menu
    pick_dish = "Chciałbym zjesc tatara"  # dish choice -> store it and inform
    give_address = "uniwersytetu poznanskiego 4 61-614 poznan"  # address -> store it and inform
    early_finish = "Super, to zatem wszystko!"  # order attempt with data missing -> ask for it
    pay = "Zapłacę kartą przy odbiorze"  # payment method -> store it and inform
    confirm = "Dobrze, nie mogę się już doczekać."  # confirmation -> finish the order

    # Other utterances worth trying for the same steps:
    #   greeting:     "Cześć", "Witam", "Witam system", "Hej, jakim botem jesteś?",
    #                 "Hej, czym się zajmujesz?", "Hej, w czym mi możesz pomóc?"
    #   menu:         "Pokaz menu", "A co do picia proponujesz?",
    #                 "Jakie inne desery oferujesz?"
    #   dish:         "Poproszę tatara"
    #   address:      "Poproszę na poznańską 2"
    #   early finish: "Dobrze, nie mogę się już doczekać."
    #   payment:      "karta", "Poproszę blikiem z góry"
    #   confirm:      "Potwierdzam!", "Tak!", "Tak to wszystko!",
    #                 "Super, to zatem wszystko!"

    full_order = (greet, ask_menu, pick_dish, give_address, early_finish, pay, confirm)
    restarted_order = (greet, ask_menu, pick_dish, give_address, "od nowa",
                       ask_menu, pick_dish, give_address, pay, confirm)

    bot = Model()
    print()
    for turn in full_order:
        bot(turn)
        print()

    print("----Konwersacja z restartem-------")

    bot = Model()
    for turn in restarted_order:
        bot(turn)
        print()
|
|
|
|
|
|
if __name__ == "__main__":
    # Interactive console chat. Typing SYSTEM_FINISH clears the screen and
    # starts a fresh conversation; end of input (Ctrl-D) or Ctrl-C at the
    # prompt exits cleanly.
    model = Model()
    print("Chatbot Jarvis\n--------------")
    while True:
        try:
            print("\nUżytkownik: ")
            user_input = input()
            while user_input == "SYSTEM_FINISH":
                print("\n\n\n\n\n\n\n\n\n\n\n\n")
                model = Model()
                print("Chatbot Jarvis\n--------------")
                print("\nUżytkownik: ")
                user_input = input()
        except (EOFError, KeyboardInterrupt):
            # Leave the chat quietly instead of dumping a traceback.
            print()
            break
        model(user_input, debug=False)
|