SD-project-22/trailminator/nlu.py

59 lines
1.5 KiB
Python
Raw Normal View History

2022-04-13 13:23:47 +02:00
import re
2022-06-01 12:46:43 +02:00
import jsgf
2022-04-13 13:23:47 +02:00
class Nlu:
    """Rule-based natural-language-understanding front end backed by a JSGF grammar.

    An utterance is normalised with :meth:`get_str_cleaned`, matched against
    the loaded grammar, and turned into a dialog act of the form
    ``{'act': <grammar name>, 'slots': [(tag, matched_text), ...]}``.
    """

    # Translation table that deletes every ASCII punctuation character in a
    # single pass (same character set the original per-character loop removed).
    _PUNCTUATION_TABLE = str.maketrans('', '', '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~')

    def __init__(self, grammar_path='rules.jsgf'):
        """Load the JSGF grammar used for matching.

        Args:
            grammar_path: Path of the grammar file handed to
                ``jsgf.parse_grammar_file``. Defaults to ``'rules.jsgf'``,
                which is resolved against the current working directory.
        """
        self.rules_grammar = jsgf.parse_grammar_file(grammar_path)
        # Trigger/parameter table; no method of this class reads it, it is
        # kept only so that external code accessing ``acts`` keeps working.
        self.acts = {
            "request": {
                'triggers': ['jak', 'kiedy'],
                'parameters': ['imie'],
            },
        }

    def get_slots(self, expansion, slots):
        """Collect ``(tag, current_match)`` pairs from an expansion tree.

        The tree is walked depth-first. A tagged expansion contributes one
        pair and is not descended into; an untagged expansion is searched
        through its children. An untagged, childless ``jsgf.NamedRuleRef``
        is followed into the expansion of the rule it references.

        Args:
            expansion: Root of the (sub)tree to inspect.
            slots: List that receives the pairs; it is mutated in place.
        """
        if expansion.tag != '':
            slots.append((expansion.tag, expansion.current_match))
            return

        for child in expansion.children:
            self.get_slots(child, slots)

        # Rule references carry no children of their own, so follow the
        # reference to reach tags defined inside the referenced rule.
        if not expansion.children and isinstance(expansion, jsgf.NamedRuleRef):
            self.get_slots(expansion.referenced_rule.expansion, slots)

    def get_dialog_act(self, rule):
        """Build the dialog-act dictionary for a matched rule.

        Args:
            rule: Matched rule; its ``grammar.name`` becomes the act name and
                its expansion tree supplies the slots.

        Returns:
            ``{'act': <grammar name>, 'slots': [(tag, match), ...]}``.
        """
        slots = []
        self.get_slots(rule.expansion, slots)
        return {'act': rule.grammar.name, 'slots': slots}

    def tokenize(self, string):
        """Convert a raw utterance into a dialog act.

        Args:
            string: Raw user utterance.

        Returns:
            The dialog act built from the first matching grammar rule, or
            ``{'act': 'null', 'slots': []}`` when no rule matches.
        """
        clean_string = self.get_str_cleaned(string)
        matched = self.rules_grammar.find_matching_rules(clean_string)
        if not matched:
            return {'act': 'null', 'slots': []}
        # Only the first matching rule is used.
        return self.get_dialog_act(matched[0])

    def get_str_cleaned(self, str_dirty):
        """Normalise an utterance before grammar matching.

        The text is lower-cased, ASCII punctuation is removed, and every run
        of whitespace is collapsed to a single space with leading/trailing
        whitespace dropped.

        Args:
            str_dirty: Raw input text.

        Returns:
            The normalised text.
        """
        without_punctuation = str_dirty.lower().translate(self._PUNCTUATION_TABLE)
        # Whitespace is normalised AFTER punctuation removal: deleting a
        # punctuation mark that sits between two spaces (e.g. "a , b") would
        # otherwise leave a double space in the result.
        return ' '.join(without_punctuation.split())

    # TODO: Refactor
2022-06-01 12:46:43 +02:00
# Module-level NLU instance; constructing it loads the grammar file.
nlu = Nlu()

# Smoke check: run one sample utterance through the pipeline and print the
# resulting dialog act.
print(
    nlu.tokenize('chciałbym kupić bilet do Krakow')
)