add more acts to nlg

This commit is contained in:
filnow 2024-06-04 10:46:12 +02:00
parent 5922f0e074
commit 4b1c33bf98
2 changed files with 55 additions and 14 deletions

View File

@ -1,9 +1,9 @@
import string import string
from typing import Any from typing import Any, List, Tuple
import jsgf import jsgf
from unidecode import unidecode from unidecode import unidecode
import random
from convlab.dst import dst from convlab.dst import dst
from transformers import ( from transformers import (
@ -191,26 +191,67 @@ class NLG():
self.tokenizer = AutoTokenizer.from_pretrained("google/umt5-small") self.tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
self.nlg_pipeline = pipeline('summarization', model=self.model, tokenizer=self.tokenizer) self.nlg_pipeline = pipeline('summarization', model=self.model, tokenizer=self.tokenizer)
def __call__(self, act, slots) -> Any: self.messages = {
"welcomemsg": [
"Witaj w naszej restauracji! Jak mogę Ci pomóc?",
"Witaj! W czym mogę pomóc?",
"Hej! Co mogę dla Ciebie zrobić?"
],
"canthelp": [
"Przepraszam, nie mogę pomóc w tej chwili.",
"Nie jestem w stanie pomóc.",
"Przepraszam, nie rozumiem."
],
"bye": [
"Dziękujemy za zamówienie! Smacznego!",
"Smacznego! Do zobaczenia!",
"Dziękujemy za skorzystanie z naszych usług!"
],
"affirm": [
"Zamówienie zostało złożone!",
"Potwierdzam zamówienie!",
"Skladam zamówienie!"
],
"repeat": [
"Możesz powtórzyć?",
"Nie zrozumiałem, możesz powtórzyć?",
"Nie zrozumiałem, możesz powtórzyć jeszcze raz?"
],
"reqmore": [
"Potrzebujesz więcej informacji?",
"Czy mogę pomóc w czymś jeszcze?",
"Czy mogę zaoferować coś jeszcze?"
]
}
def __call__(self, act: str, slots: List[Tuple[str, str]]) -> str:
if act == 'welcomemsg': if act == 'welcomemsg':
return "Witaj w naszej restauracji! Jak mogę Ci pomóc?" return random.choice(self.messages["welcomemsg"])
elif act == "offer": elif act in ["inform", "request", "select"]:
if slots == []:
return "Przepraszam nie rozumiem. Podaj więcej informacji."
elif act == "inform":
if slots == []: if slots == []:
return "Przepraszam nie rozumiem. Podaj więcej informacji." return "Przepraszam nie rozumiem. Podaj więcej informacji."
else: else:
text = [f"{slot[0]}[{slot[1]}]" for slot in slots if slot[1] is not None] text = [f"{slot[0]}[{slot[1]}]" for slot in slots if slot[1] is not None]
return self.nlg_pipeline(f'generate text: {", ".join(text)}')[0]['summary_text'] return self.nlg_pipeline(f'generate text: {", ".join(text)}')[0]['summary_text']
elif act == "canthelp.missing_slot_value": elif act == "canthelp.missing_slot_value" or act == "canthelp":
return "Cieszę się, że mogłem pomóc. Czy mogę zrobić coś jeszcze?" return random.choice(self.messages["canthelp"])
elif act == "bye": elif act == "bye":
return "Dziękujemy za zamówienie! Smacznego!" return random.choice(self.messages["bye"])
elif act == 'affirm':
return random.choice(self.messages["affirm"])
elif act == "repeat":
return random.choice(self.messages["repeat"])
elif act == "reqmore":
return random.choice(self.messages["reqmore"])
elif act == "offer":
return "Proszę oto menu zeskanuj kod QR aby je zobaczyć."
if __name__ == "__main__": if __name__ == "__main__":

File diff suppressed because one or more lines are too long