sklep-internetowy-systemy-d.../chatbot/modules/nlu.py
2024-06-11 18:10:51 +02:00

66 lines
1.8 KiB
Python

from flair.models import SequenceTagger
import sys
# NOTE(review): ".." is resolved against the current working directory, not
# against this file's location, so whether `models.nlu_train2` becomes
# importable depends on where the process is launched from — confirm the
# intended launch directory.
sys.path.append("..")
from models.nlu_train2 import predict_frame, predict_slot
import logging
# Suppress flair's own log output below CRITICAL level.
logging.getLogger('flair').setLevel(logging.CRITICAL)
class Slot:
    """One named slot attached to a dialogue act.

    ``name`` is the slot label and ``value`` is whatever the caller supplies
    (``None`` when omitted). ``NLU.get_slot`` in this module passes the list
    of tokens that make up the slot.
    """

    def __init__(self, name, value=None):
        self.name, self.value = name, value

    def __str__(self) -> str:
        label, content = self.name, self.value
        return f"Name: {label}, Value: {content}"
class Act:
    """A dialogue act: the predicted intent plus the slots found in an utterance."""

    def __init__(self, intent: str, slots: "list[Slot] | None" = None):
        """Create an act.

        Args:
            intent: the intent label for the utterance.
            slots: slots for this act; a fresh empty list is used when omitted.
                A list passed in is stored as-is (not copied).
        """
        # A mutable default (`= []`) would be one list shared by every Act
        # created without slots; build a new one per instance instead.
        self.slots = [] if slots is None else slots
        self.intent = intent

    def __str__(self):
        # Join with a separator so no dangling ", " precedes the closing bracket.
        rendered = ", ".join(f"({slot})" for slot in self.slots)
        return f"Act: {self.intent}, Slots: [{rendered}]"
class NLU:
    """Natural-language-understanding front end: turns raw text into an ``Act``.

    Holds two flair ``SequenceTagger`` models: a frame model used to predict the
    intent, and a slot model whose per-token tags are decoded into slots.
    """

    def __init__(
        self,
        frame_model_path: str = 'models/frame-model/final-model.pt',
        slot_model_path: str = 'models/slot-model/final-model.pt',
    ):
        """Load the frame and slot taggers.

        Args:
            frame_model_path: checkpoint for the frame (intent) tagger.
            slot_model_path: checkpoint for the slot tagger.

        The defaults are the previously hard-coded relative paths; they resolve
        against the current working directory, not against this file.
        """
        self.frame_model = SequenceTagger.load(frame_model_path)
        self.slot_model = SequenceTagger.load(slot_model_path)

    def get_intent(self, text: str):
        """Predict the intent of *text* (split on whitespace) with the frame model.

        Returns whatever ``predict_frame`` yields for the ``'frame'`` column.
        """
        return predict_frame(self.frame_model, text.split(), 'frame')

    @staticmethod
    def _bio_spans(pred) -> list[tuple[str, list[str]]]:
        """Group per-token BIO predictions into ``(slot_name, tokens)`` spans.

        *pred* is an iterable of dicts; only the ``"slot"`` (tag) and ``"form"``
        (presumably the surface token — confirm against ``predict_slot``) keys
        are read. Decoding follows the lenient conlleval convention:

        * ``B-X`` always starts a new ``X`` span;
        * ``I-X`` extends the open span when it is also ``X``; otherwise it
          starts a new ``X`` span, so ``I-`` tokens are never silently dropped;
        * any other tag (e.g. ``O``) closes the open span.

        NOTE(review): assumes a plain BIO tag set; BIOES-style ``S-``/``E-``
        tags would be treated like ``O``.
        """
        spans = []
        name = None
        tokens = []
        for frame in pred:
            tag = frame["slot"]
            form = frame["form"]
            if tag.startswith("I-") and tag[2:] == name:
                # Continuation of the currently open span.
                tokens.append(form)
                continue
            # Anything else ends the open span (if any) ...
            if name:
                spans.append((name, tokens))
            # ... and either opens a new one or leaves none open.
            if tag.startswith(("B-", "I-")):
                name, tokens = tag[2:], [form]
            else:
                name, tokens = None, []
        if name:
            spans.append((name, tokens))
        return spans

    def get_slot(self, text: str):
        """Return the slots found in *text* as a list of ``Slot`` objects.

        Each ``Slot.value`` is the list of tokens of one decoded span, in
        utterance order.
        """
        pred = predict_slot(self.slot_model, text.split(), 'slot')
        return [Slot(name=name, value=tokens) for name, tokens in NLU._bio_spans(pred)]

    def analyze(self, text: str):
        """Return an ``Act`` holding the intent and slots predicted for *text*."""
        intent = self.get_intent(text)
        slots = self.get_slot(text)
        return Act(intent=intent, slots=slots)