50 lines
1.5 KiB
Python
50 lines
1.5 KiB
Python
from flair.models import SequenceTagger
|
|
import sys
|
|
sys.path.append("..")
|
|
from models.nlu_train2 import predict_frame, predict_slot
|
|
import logging
|
|
|
|
# Raise the threshold of the 'flair' logger to CRITICAL so that flair's
# lower-severity log records are suppressed while this module runs.
logging.getLogger('flair').setLevel(logging.CRITICAL)
|
|
|
|
class NLU:
    """Intent and slot extraction built on two flair ``SequenceTagger`` models.

    One tagger (the "frame" model) predicts the intent of an utterance. The
    other (the "slot" model) predicts per-token BIO slot tags, which are
    decoded into named slot values.
    """

    def __init__(self,
                 frame_model_path: str = '../models/frame-model/final-model.pt',
                 slot_model_path: str = '../models/slot-model/final-model.pt'):
        """Load the frame (intent) and slot tagging models.

        Args:
            frame_model_path: Location of the serialized frame tagger.
            slot_model_path: Location of the serialized slot tagger.

        Both defaults are relative paths and are passed to
        ``SequenceTagger.load`` unchanged, so they depend on the process's
        current working directory.
        """
        self.frame_model = SequenceTagger.load(frame_model_path)
        self.slot_model = SequenceTagger.load(slot_model_path)

    def get_intent(self, text: str):
        """Return the intent predicted by the frame model for *text*.

        *text* is tokenized on whitespace. The return value is whatever
        ``predict_frame`` yields for the ``'frame'`` label type.
        """
        return predict_frame(self.frame_model, text.split(), 'frame')

    @staticmethod
    def decode_bio(tagged) -> list:
        """Group BIO-tagged tokens into slot spans.

        Args:
            tagged: Iterable of per-token dicts, each with a ``"form"`` entry
                (the token text) and a ``"slot"`` entry (a BIO tag such as
                ``"B-color"``, ``"I-color"`` or ``"O"``).

        Returns:
            A list of ``{'name': label, 'value': space-joined tokens}`` dicts,
            in order of appearance.

        Well-formed BIO input decodes in the usual way. Malformed sequences
        (common in taggers without transition constraints) follow the
        conlleval convention:
        - An ``I-`` tag that does not continue an open span of the same label
          starts a new span.
        - Any tag that is neither ``B-`` nor ``I-`` (e.g. ``"O"``) closes the
          open span.
        """
        slots = []
        current_slot = None
        current_slot_value = []

        for token in tagged:
            tag = token["slot"]
            prefix, label = tag[:2], tag[2:]

            if prefix == "I-" and label == current_slot:
                # Continuation of the currently open span.
                current_slot_value.append(token["form"])
                continue

            # Every other tag ends the open span (if there is one)...
            if current_slot is not None:
                slots.append({'name': current_slot, 'value': " ".join(current_slot_value)})

            if prefix in ("B-", "I-"):
                # ...and B- (or an I- that cannot continue) opens a new span.
                current_slot = label
                current_slot_value = [token["form"]]
            else:
                # Outside tag ("O" or anything unrecognised): no span is open.
                current_slot = None
                current_slot_value = []

        if current_slot is not None:
            slots.append({'name': current_slot, 'value': " ".join(current_slot_value)})

        return slots

    def get_slot(self, text: str):
        """Return the slot spans the slot model finds in *text*.

        *text* is tokenized on whitespace. The per-token predictions from
        ``predict_slot`` (for the ``'slot'`` label type) are decoded with
        :meth:`decode_bio`.
        """
        pred = predict_slot(self.slot_model, text.split(), 'slot')
        return self.decode_bio(pred)

    def analyze(self, text: str):
        """Return the intent and slots extracted from *text*.

        Returns:
            ``{'intent': <get_intent result>, 'slots': <get_slot result>}``.

        NOTE: the result is also printed to stdout. This is kept for
        compatibility with existing callers.
        """
        intent = self.get_intent(text)
        slots = self.get_slot(text)
        print({'intent': intent,
               'slots': slots})
        return {
            'intent': intent,
            'slots': slots
        }
|
|
|
|
# Shared module-level instance, kept so that other modules can import ``nlu``.
# NOTE: constructing it calls SequenceTagger.load for both models, so that
# work happens when this module is imported.
nlu = NLU()

if __name__ == "__main__":
    # Demo utterance. It only runs when this file is executed as a script,
    # so importing the module no longer prints a prediction as a side effect.
    nlu.analyze("Chce kupic lakier do pazanokci")