From 9dffcd9369c50c428f4a49cbf7812b2fdaeecefb Mon Sep 17 00:00:00 2001 From: Cyganik Date: Tue, 23 Apr 2024 15:28:20 +0200 Subject: [PATCH] add nlu --- book.jsgf | 99 +++++++++++++++++++++++++++++++++++++++++++++++++ dialog_model.py | 37 +++++++++++++++--- evaluate.py | 0 3 files changed, 131 insertions(+), 5 deletions(-) create mode 100644 book.jsgf create mode 100644 evaluate.py diff --git a/book.jsgf b/book.jsgf new file mode 100644 index 0000000..b0561c5 --- /dev/null +++ b/book.jsgf @@ -0,0 +1,99 @@ +#JSGF V1.0 UTF-8 pl; + +grammar book; + +public = chcialbym zarezerwowac stolik ; + + = (na | w) {day}; + + = dzisiaj | jutro | poniedzialek | wtorek | srode | czwartek | piatek | sobote | niedziele; + + = (na | o) [godzine] {hour}; + + = []; + + = dziewiata | dziesiata | jedenasta | dwunasta; + + = pietnas | trzydziesci; + + = (na | dla) {size} osob; + + = dwie | dwoch | trzy | trzech | cztery | czterech | piec | pieciu | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10; + + = | | | ; + + = Tatar wolowy | Carpaccio z buraka | Salatka caprese; + + = Krem z dyni | Rosol z makaronem | Zupa grzybowa; + + = Stek z poledwicy wolowej | Pieczona kaczka z jablkami | Lasagne | Ryba z grilla | Risotto z kurczakiem i grzybami | Pierogi ruskie | Placki ziemniaczane; + + = Tiramisu | Szarlotka | Lody waniliowe; + + = * []; + + = | ; + + = cola | woda | lemoniada | piwo; + + = (potwierdzam | ok | dobrze | zgoda); + + = (czekaj | prosze czekac | odbior za minut); + + = z sosem smietanowym; + + = bez cebuli; + + = Tatar wolowy | Zupa grzybowa | Stek z poledwicy wolowej; + + = na ; + + = na ; + + = porcje; + + = Pamietajcie zeby bylo cieple; + + = minut; + + = Dziekuje | Dziekuje to wszystko | Okej to wszystko | To wszystko | Dziekuje za pomoc; + + = Pelen profesjonalizm | Swietnie; + + = Za drogo | Jednak nie zjem | To nie to; + + = z odbiorem osobistym | z dowozem | na wynos; + + = na pojutrze; + + = Twoje zamowienie bedzie gotowe ; + + = Wroc do nas pojutrze; + + = | | | | | | | | | | | | | | | ; + + = | | ; + + = ; + + = + ; + + = ; + + = ; + + = ; + + = ; + + = ; + + = + | + ; + + = ; + + = + ; + + = ; + + = + ; diff --git a/dialog_model.py b/dialog_model.py index 270b37e..72f979a 100644 --- a/dialog_model.py +++ b/dialog_model.py @@ -1,4 +1,6 @@ from typing import Any +import jsgf +import argparse class Model(): @@ -11,9 +13,10 @@ class Model(): dp = DP() nlg = NLG() - msg = prompt + msg = prompt.lower() rama_nlu = nlu(msg) + print(rama_nlu) rama_dst = dst(rama_nlu) rama_dp = dp(rama_dst) text = nlg(rama_dp) @@ -25,10 +28,31 @@ class NLU(): def __init__(self): pass + def get_dialog_act(self, rule): + slots = [] + self.get_slots(rule.expansion, slots) + return {'act': rule.grammar.name, 'slots': slots} + + def get_slots(self, expansion, slots): + if expansion.tag != '': + slots.append((expansion.tag, expansion.current_match)) + return + + for child in expansion.children: + self.get_slots(child, slots) + + if not expansion.children and isinstance(expansion, jsgf.NamedRuleRef): + self.get_slots(expansion.referenced_rule.expansion, slots) + def __call__(self, prompt) -> Any: - msg = prompt - if "imie" in msg: - return "jakie imie" + book_grammar = jsgf.parse_grammar_file('book.jsgf') + + matched = book_grammar.find_matching_rules(prompt) + + if matched: + return self.get_dialog_act(matched[0]) + else: + return {'act': 'null', 'slots': []} class DST(): @@ -64,6 +88,9 @@ class NLG(): if __name__ == "__main__": model = Model() + parser = argparse.ArgumentParser() + parser.add_argument("--msg") + args = parser.parse_args() - ans = model(prompt="Jak masz na imie") + ans = model(prompt=args.msg) print(ans) diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..e69de29