diff --git a/requirements.txt b/requirements.txt
index 75af938..a4fc1ff 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,5 +2,5 @@ flair==0.13.1
 conllu==4.5.3
 pandas==1.5.3
 numpy==1.26.4
-torch==2.3.0
+torch==1.13
 convlab==3.0.2a0
\ No newline at end of file
diff --git a/src/main.py b/src/main.py
index aa0abb9..0d6d7cd 100644
--- a/src/main.py
+++ b/src/main.py
@@ -3,6 +3,7 @@ from service.dialog_policy import DialogPolicy
 from service.natural_languag_understanding import NaturalLanguageUnderstanding
 from service.natural_language_generation import NaturalLanguageGeneration, parse_frame
 from service.templates import templates
+from convlab.dialog_agent import PipelineAgent
 
 # initialize classes
 nlu = NaturalLanguageUnderstanding() # NLU
@@ -10,28 +11,38 @@
 monitor = DialogStateMonitor() # DSM
 dialog_policy = DialogPolicy() # DP
 language_generation = NaturalLanguageGeneration(templates) # NLG
 
+agent = PipelineAgent(nlu=nlu, dst=monitor, policy=None, nlg=language_generation, name='sys')
+resp = agent.response("Dzień dobry")
+print(resp)
 
 # Main loop
-user_input = input("Możesz zacząć pisać.\n")
+dial_num = 0
+print("CTRL+C aby zakończyć program.")
 while True:
-    # NLU
-    frame = nlu.process_input(user_input)
-    # print(frame)
+    monitor.reset()
 
-    # DSM
-    monitor.update(frame)
+    print(f"\n==== Rozpoczynasz rozmowę nr {dial_num} ====\n")
+    user_input = input("Możesz zacząć pisać.\n")
 
-    # DP
-    # print(dialog_policy.next_dialogue_act(monitor.read()).act)
+    while True:
+        # NLU
+        frame = nlu.predict(user_input)
+        # print(frame)
 
-    # NLG
-    act, slots = parse_frame(frame)
-    response = language_generation.generate(act, slots)
-    print(response)
+        # DSM
+        monitor.update(frame)
 
-    if frame.act == "bye":
-        break
-
-    user_input = input(">\n")
+        # DP
+        # print(dialog_policy.next_dialogue_act(monitor.read()).act)
+
+        # NLG
+        act, slots = parse_frame(frame)
+        response = language_generation.generate(act, slots)
+        print(response)
+
+        if frame.act == "bye":
+            break
+
+        user_input = input(">\n")
 
diff --git a/src/service/dialog_state_monitor.py b/src/service/dialog_state_monitor.py
index dfeb970..c4728d0 100644
--- a/src/service/dialog_state_monitor.py
+++ b/src/service/dialog_state_monitor.py
@@ -1,4 +1,4 @@
-from src.model.frame import Frame
+from model.frame import Frame
 from convlab.dst.dst import DST
 import copy
 
diff --git a/src/service/natural_languag_understanding.py b/src/service/natural_languag_understanding.py
index 394ff55..a76f2b5 100644
--- a/src/service/natural_languag_understanding.py
+++ b/src/service/natural_languag_understanding.py
@@ -1,3 +1,4 @@
+from convlab.nlu.nlu import NLU
 from flair.models import SequenceTagger
 from utils.nlu_utils import predict_single, predict_and_annotate
 from model.frame import Frame, Slot
@@ -41,7 +42,7 @@ SLOTS:
 sauce
 """
 
-class NaturalLanguageUnderstanding:
+class NaturalLanguageUnderstanding(NLU):
     def __init__(self):
         print("\n========================================================")
         print("Models are loading, it may take a moment, please wait...")
@@ -85,7 +86,7 @@
 
         return slots
 
-    def process_input(self, text: str):
+    def predict(self, text: str, context: list):
         act = self.__predict_intention(text)
         slots = self.__predict_slot(text)
         frame = Frame(source = 'user', act = act, slots = slots)
diff --git a/src/service/natural_language_generation.py b/src/service/natural_language_generation.py
index 3beec00..14c37ba 100644
--- a/src/service/natural_language_generation.py
+++ b/src/service/natural_language_generation.py
@@ -1,6 +1,8 @@
 import re
 from service.template_selector import select_template
 import random
+from convlab.nlg.nlg import NLG
+
 # from service.templates import templates
 
 def parse_frame(frame):
@@ -12,7 +14,7 @@
 
     return act, slots
 
-class NaturalLanguageGeneration:
+class NaturalLanguageGeneration(NLG):
     def __init__(self, templates):
         self.templates = templates
 