sklep-internetowy-systemy-d.../chatbot/modules/nlu.py
2024-06-10 22:27:30 +02:00

50 lines
1.5 KiB
Python

from flair.models import SequenceTagger
import sys
sys.path.append("..")
from models.nlu_train2 import predict_frame, predict_slot
import logging
logging.getLogger('flair').setLevel(logging.CRITICAL)
class NLU:
    """Chatbot NLU front end: predicts an utterance's frame and its slots.

    Wraps two flair ``SequenceTagger`` models, one used for frame (intent)
    prediction and one used for per-token slot tagging with BIO labels
    (``B-<type>`` / ``I-<type>`` / anything else = outside).
    """

    _logger = logging.getLogger(__name__)

    def __init__(self,
                 frame_model_path='../models/frame-model/final-model.pt',
                 slot_model_path='../models/slot-model/final-model.pt'):
        """Load the frame and slot tagging models.

        Args:
            frame_model_path: Path of the trained frame (intent) tagger.
            slot_model_path: Path of the trained slot tagger.

        The default paths are resolved against the current working
        directory, not against this file's location, so they only work when
        the process is started from the directory layout they assume.
        """
        self.frame_model = SequenceTagger.load(frame_model_path)
        self.slot_model = SequenceTagger.load(slot_model_path)

    def get_intent(self, text: str):
        """Predict the frame (intent) of *text*.

        The text is whitespace-tokenized and handed to ``predict_frame``;
        its result is returned unchanged (its exact shape is defined in
        ``models.nlu_train2``).
        """
        return predict_frame(self.frame_model, text.split(), 'frame')

    def get_slot(self, text: str):
        """Extract slot spans from *text*.

        Returns:
            A list of ``{'name': <slot type>, 'value': <space-joined tokens>}``
            dicts in the order the spans occur; empty when nothing is tagged.
        """
        tokens = text.split()
        if not tokens:
            # No tokens means no slots; don't hand an empty sentence to the tagger.
            return []
        pred = predict_slot(self.slot_model, tokens, 'slot')
        return self._decode_bio(pred)

    @staticmethod
    def _decode_bio(pred):
        """Group BIO-tagged tokens into slot spans.

        Args:
            pred: Iterable of dicts with a ``"form"`` (token text) and a
                ``"slot"`` (BIO tag) key.

        Returns:
            A list of ``{'name': ..., 'value': ...}`` dicts, one per span.

        Chunk boundaries follow the usual (conlleval-style) rules: a span is
        continued only by ``I-`` of the same type; any other tag (``O``,
        ``B-``, or ``I-`` of a different type) closes it, and an ``I-`` tag
        with no matching open span starts a new one.
        """
        slots = []
        current_slot = None
        current_tokens = []
        for token in pred:
            tag = token["slot"]
            form = token["form"]
            if current_slot is not None and tag == "I-" + current_slot:
                current_tokens.append(form)
                continue
            # Anything that does not continue the open span terminates it.
            if current_slot is not None:
                slots.append({'name': current_slot, 'value': " ".join(current_tokens)})
                current_slot, current_tokens = None, []
            if tag.startswith(("B-", "I-")):
                current_slot = tag[2:]
                current_tokens = [form]
        if current_slot is not None:
            slots.append({'name': current_slot, 'value': " ".join(current_tokens)})
        return slots

    def analyze(self, text: str):
        """Run frame and slot prediction on *text*.

        Returns:
            ``{'intent': <get_intent result>, 'slots': <get_slot result>}``.
        """
        result = {
            'intent': self.get_intent(text),
            'slots': self.get_slot(text),
        }
        # Trace via logging rather than print so callers don't get
        # unconditional stdout output from a library method.
        self._logger.debug("NLU result: %s", result)
        return result
# Shared module-level instance (kept for callers that import ``nlu``).
# Note that creating it loads both tagger models at import time.
nlu = NLU()

if __name__ == "__main__":
    # Manual smoke test; previously this ran (and printed) on every import.
    print(nlu.analyze("Chce kupic lakier do pazanokci"))