import logging
import sys

from flair.models import SequenceTagger

# Make the parent directory importable so the project-local ``models`` package
# below can be found.
# NOTE(review): ".." is resolved relative to the current working directory,
# not to this file's location — confirm the expected launch directory.
sys.path.append("..")

from models.nlu_train2 import predict_frame, predict_slot  # noqa: E402  (needs the sys.path tweak above)

# Only let CRITICAL records from the 'flair' logger through. This is done after
# flair has been imported — presumably so it overrides any logger setup flair
# performs at import time (TODO confirm).
logging.getLogger('flair').setLevel(logging.CRITICAL)
|
2024-05-07 19:25:33 +02:00
|
|
|
|
2024-06-11 00:51:22 +02:00
|
|
|
|
|
|
|
class Slot:
    """A named slot and the value that fills it.

    Args:
        name: the slot label.
        value: the slot's filler; ``None`` when no value was supplied.
    """

    def __init__(self, name, value=None):
        self.name, self.value = name, value

    def __str__(self) -> str:
        # Human-readable form, e.g. "Name: city, Value: ['New', 'York']".
        return "Name: {}, Value: {}".format(self.name, self.value)
|
|
|
|
|
|
|
|
|
2024-06-11 18:10:51 +02:00
|
|
|
class Act:
    """A dialogue act: a recognised intent plus the slots filled for it.

    Args:
        intent: the intent label.
        slots: the slots belonging to this act. When omitted (or ``None``), a
            fresh empty list is created for this instance. A list passed in is
            stored as-is, not copied.
    """

    def __init__(self, intent: str, slots: "list[Slot] | None" = None):
        # A ``[]`` default would be evaluated once and shared by every Act
        # created without explicit slots, so mutating one act's slots would
        # leak into all the others. Build a new list per instance instead.
        self.slots = [] if slots is None else slots
        self.intent = intent

    def __str__(self):
        # e.g. "Act: greet, Slots: [(Name: a, Value: b), (Name: c, Value: d)]"
        # join() avoids the trailing ", " the old concatenation loop produced.
        inner = ", ".join(f"({slot})" for slot in self.slots)
        return f"Act: {self.intent}, Slots: [{inner}]"
|
|
|
|
|
|
|
|
|
2024-05-07 19:25:33 +02:00
|
|
|
class NLU:
    """Natural-language-understanding front end backed by two flair taggers.

    One tagger predicts the utterance's frame (intent); the other assigns
    per-token BIO slot tags, which are grouped into ``Slot`` objects.
    """

    def __init__(
        self,
        frame_model_path: str = 'models/frame-model/final-model.pt',
        slot_model_path: str = 'models/slot-model/final-model.pt',
    ):
        """Load the frame (intent) tagger and the slot tagger.

        Args:
            frame_model_path: checkpoint passed to ``SequenceTagger.load`` for
                the frame model. The default is the original hard-coded path.
            slot_model_path: checkpoint passed to ``SequenceTagger.load`` for
                the slot model. The default is the original hard-coded path.

        NOTE(review): relative paths are presumably resolved against the
        current working directory — confirm against how the app is launched.
        """
        self.frame_model = SequenceTagger.load(frame_model_path)
        self.slot_model = SequenceTagger.load(slot_model_path)

    def get_intent(self, text: str):
        """Predict the intent (frame) of ``text``.

        The text is split on whitespace and handed to ``predict_frame`` with
        the ``'frame'`` column name. Its result is returned unchanged.
        """
        return predict_frame(self.frame_model, text.split(), 'frame')

    def get_slot(self, text: str):
        """Extract the slots mentioned in ``text``.

        The text is split on whitespace and tagged by ``predict_slot``. The
        per-token tags are then grouped into spans (see ``_decode_bio``).

        Returns:
            A list of ``Slot``, one per span. Each ``value`` is the list of
            token forms covered by that span.
        """
        pred = predict_slot(self.slot_model, text.split(), 'slot')
        return [Slot(name=name, value=forms) for name, forms in self._decode_bio(pred)]

    @staticmethod
    def _decode_bio(pred):
        """Group per-token BIO tags into ``(slot_name, [token_forms])`` spans.

        Each item of ``pred`` must provide a ``"form"`` (the token) and a
        ``"slot"`` (its tag) entry. Tags are handled as follows:

        - ``B-X`` closes any open span and opens a new span of type ``X``.
        - ``I-X`` extends the open span if it is also of type ``X``. Otherwise
          it opens a new span, so a stray or mismatched ``I-`` tag is neither
          dropped nor merged into an unrelated slot.
        - Any other tag (e.g. ``O``) closes the open span.
        """
        spans = []
        name = None   # label of the span being built; None when no span is open
        forms = []    # token forms collected for the open span

        for token in pred:
            tag = token["slot"]
            if tag.startswith("I-") and name == tag[2:]:
                forms.append(token["form"])
                continue
            # Any other tag ends the open span (if there is one).
            if name is not None:
                spans.append((name, forms))
            if tag.startswith(("B-", "I-")):
                name, forms = tag[2:], [token["form"]]
            else:
                name, forms = None, []

        if name is not None:
            spans.append((name, forms))
        return spans

    def analyze(self, text: str):
        """Run intent and slot prediction on ``text`` and bundle them into an ``Act``."""
        intent = self.get_intent(text)
        slots = self.get_slot(text)
        return Act(intent=intent, slots=slots)
|