Przygotowania pod convlaba

This commit is contained in:
s495727 2024-06-09 21:29:38 +02:00
parent 72e17d2106
commit caea45471c
5 changed files with 35 additions and 21 deletions

View File

@ -2,5 +2,5 @@ flair==0.13.1
conllu==4.5.3 conllu==4.5.3
pandas==1.5.3 pandas==1.5.3
numpy==1.26.4 numpy==1.26.4
torch==2.3.0 torch==1.13
convlab==3.0.2a0 convlab==3.0.2a0

View File

@ -3,6 +3,7 @@ from service.dialog_policy import DialogPolicy
from service.natural_languag_understanding import NaturalLanguageUnderstanding from service.natural_languag_understanding import NaturalLanguageUnderstanding
from service.natural_language_generation import NaturalLanguageGeneration, parse_frame from service.natural_language_generation import NaturalLanguageGeneration, parse_frame
from service.templates import templates from service.templates import templates
from convlab.dialog_agent import PipelineAgent
# initialize classes # initialize classes
nlu = NaturalLanguageUnderstanding() # NLU nlu = NaturalLanguageUnderstanding() # NLU
@ -10,11 +11,21 @@ monitor = DialogStateMonitor() # DSM
dialog_policy = DialogPolicy() # DP dialog_policy = DialogPolicy() # DP
language_generation = NaturalLanguageGeneration(templates) # NLG language_generation = NaturalLanguageGeneration(templates) # NLG
agent = PipelineAgent(nlu=nlu, dst=monitor, policy=None, nlg=language_generation, name='sys')
resp = agent.response("Dzień dobry")
print(resp)
# Main loop # Main loop
user_input = input("Możesz zacząć pisać.\n") dial_num = 0
print("CTRL+C aby zakończyć program.")
while True: while True:
monitor.reset()
print(f"\n==== Rozpoczynasz rozmowę nr {dial_num} ====\n")
user_input = input("Możesz zacząć pisać.\n")
while True:
# NLU # NLU
frame = nlu.process_input(user_input) frame = nlu.predict(user_input)
# print(frame) # print(frame)
# DSM # DSM

View File

@ -1,4 +1,4 @@
from src.model.frame import Frame from model.frame import Frame
from convlab.dst.dst import DST from convlab.dst.dst import DST
import copy import copy

View File

@ -1,3 +1,4 @@
from convlab.nlu.nlu import NLU
from flair.models import SequenceTagger from flair.models import SequenceTagger
from utils.nlu_utils import predict_single, predict_and_annotate from utils.nlu_utils import predict_single, predict_and_annotate
from model.frame import Frame, Slot from model.frame import Frame, Slot
@ -41,7 +42,7 @@ SLOTS:
sauce sauce
""" """
class NaturalLanguageUnderstanding: class NaturalLanguageUnderstanding(NLU):
def __init__(self): def __init__(self):
print("\n========================================================") print("\n========================================================")
print("Models are loading, it may take a moment, please wait...") print("Models are loading, it may take a moment, please wait...")
@ -85,7 +86,7 @@ class NaturalLanguageUnderstanding:
return slots return slots
def process_input(self, text: str): def predict(self, text: str, context: list):
act = self.__predict_intention(text) act = self.__predict_intention(text)
slots = self.__predict_slot(text) slots = self.__predict_slot(text)
frame = Frame(source = 'user', act = act, slots = slots) frame = Frame(source = 'user', act = act, slots = slots)

View File

@ -1,6 +1,8 @@
import re import re
from service.template_selector import select_template from service.template_selector import select_template
import random import random
from convlab.nlg.nlg import NLG
# from service.templates import templates # from service.templates import templates
def parse_frame(frame): def parse_frame(frame):
@ -12,7 +14,7 @@ def parse_frame(frame):
return act, slots return act, slots
class NaturalLanguageGeneration: class NaturalLanguageGeneration(NLG):
def __init__(self, templates): def __init__(self, templates):
self.templates = templates self.templates = templates