121 lines
3.5 KiB
Python
121 lines
3.5 KiB
Python
from typing import Any
|
|
import jsgf
|
|
from unidecode import unidecode
|
|
import string
|
|
import argparse
|
|
|
|
|
|
class Model():
    """End-to-end dialogue pipeline: NLU -> DST -> DP -> NLG.

    The four components are created once, when the model is constructed,
    and reused for every prompt. This lets the dialogue state tracker
    accumulate the frames of all turns (a tracker rebuilt per call would
    always start empty) and avoids re-creating the NLU, whose constructor
    parses the grammar file, for each message.
    """

    def __init__(self):
        self.nlu = NLU()
        self.dst = DST()
        self.dp = DP()
        self.nlg = NLG()

    def __call__(self, prompt) -> Any:
        """Run one user turn through the pipeline and return the reply.

        Args:
            prompt: The raw user utterance (a string; it is lower-cased
                before being handed to the NLU).

        Returns:
            Whatever the NLG component produces for the chosen action.
        """
        msg = prompt.lower()

        frame = self.nlu(msg)
        # Debug trace of the recognised frame, kept from the original script.
        print(frame)
        state = self.dst(frame)
        action = self.dp(state)

        return self.nlg(action)
|
|
|
|
|
|
class NLU():
    """Grammar-based natural language understanding.

    A JSGF grammar is parsed once at construction time and reused for
    every utterance; the first matching rule is converted into a frame of
    the form ``{'act': <rule name>, 'slots': [(tag, matched text), ...]}``.
    """

    # Translation table that deletes every ASCII punctuation character.
    _STRIP_PUNCTUATION = str.maketrans('', '', string.punctuation)

    def __init__(self, grammar_path='book.jsgf'):
        """Load the grammar.

        Args:
            grammar_path: Path of the JSGF grammar file to parse. Defaults
                to ``'book.jsgf'``, the file the original code hard-coded.
        """
        self.book_grammar = jsgf.parse_grammar_file(grammar_path)

    def get_dialog_act(self, rule):
        """Build the frame for a matched rule.

        Args:
            rule: An object exposing ``name`` and ``expansion``.

        Returns:
            ``{'act': rule.name, 'slots': [...]}`` where the slot list is
            filled by :meth:`get_slots`.
        """
        slots = []
        self.get_slots(rule.expansion, slots)
        return {'act': rule.name, 'slots': slots}

    def get_slots(self, expansion, slots):
        """Collect ``(tag, current_match)`` pairs from an expansion tree.

        The tree is walked depth-first. A tagged expansion contributes one
        pair and is not descended into. A childless rule reference is
        followed into the expansion of the rule it refers to.

        Args:
            expansion: Node exposing ``tag``, ``current_match`` and
                ``children``.
            slots: List that is extended in place.
        """
        if expansion.tag != '':
            slots.append((expansion.tag, expansion.current_match))
            return

        for child in expansion.children:
            self.get_slots(child, slots)

        # Rule references carry no children of their own; the slots live
        # in the referenced rule's expansion.
        if not expansion.children and isinstance(expansion, jsgf.NamedRuleRef):
            self.get_slots(expansion.referenced_rule.expansion, slots)

    def __call__(self, prompt) -> Any:
        """Parse an utterance into a dialogue-act frame.

        The text is transliterated to ASCII and stripped of punctuation
        before matching. The grammar loaded in ``__init__`` is reused
        instead of re-reading the grammar file on every call.

        Args:
            prompt: The user utterance.

        Returns:
            The frame of the first matching rule, or
            ``{'act': 'null', 'slots': []}`` when no rule matches.
        """
        prompt = unidecode(prompt)
        prompt = prompt.translate(self._STRIP_PUNCTUATION)

        matched = self.book_grammar.find_matching_rules(prompt)

        if not matched:
            return {'act': 'null', 'slots': []}
        return self.get_dialog_act(matched[0])
|
|
|
|
|
|
class DST():
    """Dialogue state tracker that remembers every frame it is given."""

    def __init__(self):
        # Frames seen so far, oldest first.
        self.msgs = list()

    def __call__(self, msg) -> Any:
        """Record ``msg`` in the history and hand it back unchanged."""
        history = self.msgs
        history.append(msg)
        return msg
|
|
|
|
|
|
class DP():
    """Dialogue policy: maps the tracked state to a system action."""

    def __init__(self):
        pass

    def __call__(self, msg) -> Any:
        """Choose the system action for ``msg``.

        Args:
            msg: Either a plain utterance string or a frame dict of the
                form ``{'act': ..., 'slots': [...]}``. ``None`` is
                tolerated and yields no action.

        Returns:
            ``"imieMSG"`` when the input mentions ``"imie"``, otherwise
            ``None``.
        """
        if msg is None:
            return None

        if isinstance(msg, dict):
            # ``"imie" in some_dict`` only tests the keys ('act', 'slots'),
            # so a frame could never trigger the action. Look at the act
            # name as well.
            # NOTE(review): assumes the grammar rule for the name question
            # has "imie" in its name - confirm against the grammar file.
            act = msg.get('act')
            if isinstance(act, str) and "imie" in act:
                return "imieMSG"

        # Original behaviour: substring test for strings, membership test
        # for any other container.
        if "imie" in msg:
            return "imieMSG"
        return None
|
|
|
|
|
|
class NLG():
    """Natural language generation: turns a system action into text."""

    def __init__(self):
        pass

    def __call__(self, msg) -> Any:
        """Return the reply for action ``msg``.

        The ``"imieMSG"`` action yields the bot's self-introduction; any
        other value yields the fallback "not understood" reply.
        """
        introduces_itself = msg == "imieMSG"
        return "Mam na imie JARVIS" if introduces_itself else "Nie rozumiem"
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke run: feed a fixed list of utterances through the model
    # and print each reply.
    model = Model()

    # Command-line input is currently disabled:
    # parser = argparse.ArgumentParser()
    # parser.add_argument("--msg")
    # args = parser.parse_args()

    # Further example utterances, currently disabled:
    #   "chcialbym zarezerwowac stolik na jutro na dziesiata dla trzech osob"
    #   "Cześć"
    #   "Hej, jakim botem jesteś?"
    #   "Hej, w czym mi możesz pomóc?"
    #   "Siema, w czym możesz mi pomóc?"
    #   "Witam"
    #   "Witam system"
    #   "Hej, czym się zajmujesz?"
    #   "Czesc, jestem agentem dialogowym przyjmujacym zamowienia w restauracji. Moge doradzic ci w wyborze odpowiedniej pozycji z menu. W czym moge ci pomoc?"
    #   "oki"
    #   "Potwierdzam!"
    #   "Tak!"
    #   "Tak to wszystko!"
    sample_prompts = (
        "interesuja mnie dania kuchni woskiej oraz meksykanskiej",
        "poprosze ryba",
        "poprosze tatara",
        "poprosze 2 porcje",
        "zaplace karta przy odbiorze",
        "dobrze nie moge sie juz doczekac",
        "uniwersytetu poznanskiego 4 61-614 poznan",
    )
    for sample in sample_prompts:
        print(model(prompt=sample))
|