JSGF evaluation + small NLU optimization
This commit is contained in:
parent
9dffcd9369
commit
5c18bc5b9a
@ -26,7 +26,7 @@ class Model():
|
|||||||
|
|
||||||
class NLU():
|
class NLU():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
self.book_grammar = jsgf.parse_grammar_file('book.jsgf')
|
||||||
|
|
||||||
def get_dialog_act(self, rule):
|
def get_dialog_act(self, rule):
|
||||||
slots = []
|
slots = []
|
||||||
@ -45,10 +45,7 @@ class NLU():
|
|||||||
self.get_slots(expansion.referenced_rule.expansion, slots)
|
self.get_slots(expansion.referenced_rule.expansion, slots)
|
||||||
|
|
||||||
def __call__(self, prompt) -> Any:
|
def __call__(self, prompt) -> Any:
|
||||||
book_grammar = jsgf.parse_grammar_file('book.jsgf')
|
matched = self.book_grammar.find_matching_rules(prompt)
|
||||||
|
|
||||||
matched = book_grammar.find_matching_rules(prompt)
|
|
||||||
|
|
||||||
if matched:
|
if matched:
|
||||||
return self.get_dialog_act(matched[0])
|
return self.get_dialog_act(matched[0])
|
||||||
else:
|
else:
|
||||||
|
34
evaluate.py
34
evaluate.py
@ -0,0 +1,34 @@
|
|||||||
|
import os
|
||||||
|
import pandas as pd
|
||||||
|
import jsgf
|
||||||
|
|
||||||
|
# Evaluate how many user utterances in the tab-separated dialogue files under
# "data" are matched by at least one rule of the JSGF grammar in 'book.jsgf'.
grammar = jsgf.parse_grammar_file('book.jsgf')

# Load every regular file found directly in "data" as a header-less TSV table.
data_files = []
for filename in os.listdir("data"):
    path = os.path.join("data", filename)
    if os.path.isfile(path):
        data_files.append(pd.read_csv(path, sep='\t', header=None))

recognized = 0
unrecognized = 0

for df in data_files:
    # Only two- and three-column layouts are supported; anything else is skipped.
    column_count = len(df.columns)
    if column_count == 3:
        df.columns = ["agent", "message", "act"]
    elif column_count == 2:
        df.columns = ["agent", "message"]
    else:
        continue

    # Keep only the rows whose "agent" column equals "user".
    user_speeches = df.loc[df["agent"] == "user", "message"]

    # NOTE(review): an empty "message" cell is presumably read by pandas as NaN
    # and passed to the grammar unchanged - confirm the data has no such rows.
    parsed = user_speeches.apply(lambda x: bool(grammar.find_matching_rules(x)))

    # Convert to plain ints so the running totals stay ordinary Python integers.
    true_count = int(parsed.sum())
    recognized += true_count
    unrecognized += len(parsed) - true_count

print(f"Recognized user utterances: {recognized}")
print(f"Unrecognized user utterances: {unrecognized}")

# Guard the ratio: with no user utterances at all the division is undefined.
total = recognized + unrecognized
if total:
    print(f"Accuracy: {recognized/total}")
else:
    print("Accuracy: n/a (no user utterances found)")
|
Loading…
Reference in New Issue
Block a user