Compare commits
2 Commits
5d83b7b0e8
...
60a298fa49
Author | SHA1 | Date | |
---|---|---|---|
60a298fa49 | |||
88741c2d2b |
295
dialog_with_nlg.py
Normal file
295
dialog_with_nlg.py
Normal file
@ -0,0 +1,295 @@
|
|||||||
|
import string
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import jsgf
|
||||||
|
from unidecode import unidecode
|
||||||
|
|
||||||
|
from convlab.dst import dst
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
AutoModelForSeq2SeqLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
pipeline,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def default_state():
|
||||||
|
return dict(
|
||||||
|
user_action=[],
|
||||||
|
system_action=[],
|
||||||
|
belief_state={
|
||||||
|
'address': '',
|
||||||
|
'payment_method': '',
|
||||||
|
'dish': [],
|
||||||
|
'time': ''
|
||||||
|
},
|
||||||
|
booked={},
|
||||||
|
request_state=[],
|
||||||
|
terminated=False,
|
||||||
|
history=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Model:
|
||||||
|
def __init__(self):
|
||||||
|
self.state = default_state()
|
||||||
|
self.nlu = NLU()
|
||||||
|
self.dst = DST(self.state)
|
||||||
|
self.dp = DP(self.state)
|
||||||
|
self.nlg = NLG(self.state)
|
||||||
|
|
||||||
|
def __call__(self, prompt) -> Any:
|
||||||
|
print(prompt)
|
||||||
|
msg = prompt.lower()
|
||||||
|
|
||||||
|
r = self.nlu(msg)
|
||||||
|
slots = r['slots']
|
||||||
|
#print(r)
|
||||||
|
r = self.dst(r)
|
||||||
|
#print(r)
|
||||||
|
r = self.dp()
|
||||||
|
#print(r)
|
||||||
|
r = self.nlg(r, slots)
|
||||||
|
print(r)
|
||||||
|
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
class NLU():
|
||||||
|
def __init__(self):
|
||||||
|
self.book_grammar = jsgf.parse_grammar_file('book.jsgf')
|
||||||
|
|
||||||
|
def get_dialog_act(self, rule):
|
||||||
|
slots = []
|
||||||
|
self.get_slots(rule.expansion, slots)
|
||||||
|
return {'act': rule.name, 'slots': slots}
|
||||||
|
|
||||||
|
def get_slots(self, expansion, slots):
|
||||||
|
if expansion.tag != '':
|
||||||
|
slots.append((expansion.tag, expansion.current_match))
|
||||||
|
return
|
||||||
|
|
||||||
|
for child in expansion.children:
|
||||||
|
self.get_slots(child, slots)
|
||||||
|
|
||||||
|
if not expansion.children and isinstance(expansion, jsgf.NamedRuleRef):
|
||||||
|
self.get_slots(expansion.referenced_rule.expansion, slots)
|
||||||
|
|
||||||
|
def __call__(self, prompt) -> Any:
|
||||||
|
book_grammar = jsgf.parse_grammar_file('book.jsgf')
|
||||||
|
|
||||||
|
prompt = unidecode(prompt)
|
||||||
|
translator = str.maketrans('', '', string.punctuation)
|
||||||
|
prompt = prompt.translate(translator)
|
||||||
|
|
||||||
|
matched = book_grammar.find_matching_rules(prompt)
|
||||||
|
|
||||||
|
if matched:
|
||||||
|
return self.get_dialog_act(matched[0])
|
||||||
|
else:
|
||||||
|
return {'act': 'null', 'slots': []}
|
||||||
|
|
||||||
|
|
||||||
|
class DST(dst.DST):
|
||||||
|
|
||||||
|
def __init__(self, state):
|
||||||
|
dst.DST.__init__(self)
|
||||||
|
self.state = state
|
||||||
|
|
||||||
|
def __call__(self, user_act) -> Any:
|
||||||
|
if len(user_act['slots']) == 0:
|
||||||
|
user_act = [(user_act['act'], None, None)]
|
||||||
|
else:
|
||||||
|
user_act = [(user_act['act'], k, v) for k, v in user_act['slots'] if v is not None]
|
||||||
|
|
||||||
|
self.state['request_state'] = {}
|
||||||
|
for act, slot, value in user_act:
|
||||||
|
self.state['user_action'].append(act)
|
||||||
|
|
||||||
|
if act == "platnosc":
|
||||||
|
self.state['belief_state']['payment_method'] = value
|
||||||
|
self.state['request_state'] = ['payment_method']
|
||||||
|
|
||||||
|
elif act == "offer":
|
||||||
|
self.state['request_state'] = ['menu']
|
||||||
|
|
||||||
|
elif act == 'select':
|
||||||
|
if slot == 'dish':
|
||||||
|
self.state['belief_state']['dish'].append(value)
|
||||||
|
else:
|
||||||
|
self.state['belief_state'][slot] = value
|
||||||
|
self.state['request_state'] = [slot]
|
||||||
|
|
||||||
|
elif act == 'inform':
|
||||||
|
pass
|
||||||
|
|
||||||
|
elif act == 'request':
|
||||||
|
pass
|
||||||
|
elif act == 'restart':
|
||||||
|
self.state["belief_state"] = default_state()["belief_state"]
|
||||||
|
self.state["booked"] = {}
|
||||||
|
self.state["request_state"] = []
|
||||||
|
self.state["terminated"] = False
|
||||||
|
self.state["history"] = []
|
||||||
|
return self.state
|
||||||
|
|
||||||
|
|
||||||
|
class DP():
|
||||||
|
def __init__(self, state):
|
||||||
|
self.state = state
|
||||||
|
|
||||||
|
def __call__(self) -> Any:
|
||||||
|
system_action = None
|
||||||
|
|
||||||
|
if self.state['user_action'][-1] == 'hello':
|
||||||
|
system_action = 'welcomemsg'
|
||||||
|
# przywitaj uzytkownika (i pokaz menu)
|
||||||
|
|
||||||
|
elif self.state['user_action'][-1] == 'select':
|
||||||
|
system_action = 'inform'
|
||||||
|
# poinformuj o wybranych slotach z "request_state"
|
||||||
|
|
||||||
|
elif (self.state['user_action'][-1] == 'help'
|
||||||
|
or self.state['user_action'][-1] == 'offer'
|
||||||
|
or self.state['user_action'][-1] == 'reqmore'
|
||||||
|
or (self.state['user_action'][-1] == 'request' and len(self.state['request_state']) == 0)
|
||||||
|
):
|
||||||
|
system_action = 'offer'
|
||||||
|
# zaoferuj cale menu
|
||||||
|
|
||||||
|
elif self.state['user_action'][-1] == 'ack':
|
||||||
|
address = self.state["belief_state"]["address"]
|
||||||
|
payment_method = self.state["belief_state"]["payment_method"]
|
||||||
|
dish = self.state["belief_state"]["dish"]
|
||||||
|
# W przypadku braku szczegolnej informacji o czasie zamówienia zamawiamy natychmiast
|
||||||
|
|
||||||
|
if address and payment_method and dish:
|
||||||
|
system_action = 'bye'
|
||||||
|
self.state['terminated'] = True
|
||||||
|
# potwierdz i zakoncz, podsumuj zamowienie
|
||||||
|
else:
|
||||||
|
system_action = 'canthelp.missing_slot_value'
|
||||||
|
elif self.state['user_action'][-1] == 'restart':
|
||||||
|
system_action = 'welcomemsg'
|
||||||
|
# zachowaj sie jak na poczatku rozmowy
|
||||||
|
else:
|
||||||
|
system_action = 'inform'
|
||||||
|
# poinformuj o wybranych slotach z "request_state"
|
||||||
|
# lub o wszystkich jezeli nic nie ma w request state
|
||||||
|
|
||||||
|
self.state['system_action'].append(system_action)
|
||||||
|
return system_action
|
||||||
|
|
||||||
|
|
||||||
|
class NLG():
|
||||||
|
def __init__(self, state):
|
||||||
|
self.model = AutoModelForSeq2SeqLM.from_pretrained("filnow/nlg-umt5-pol")
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
|
||||||
|
self.nlg_pipeline = pipeline('summarization', model=self.model, tokenizer=self.tokenizer)
|
||||||
|
|
||||||
|
def __call__(self, act, slots) -> Any:
|
||||||
|
if act == 'welcomemsg':
|
||||||
|
return "Witaj w naszej restauracji! Jak mogę Ci pomóc?"
|
||||||
|
|
||||||
|
elif act == "offer":
|
||||||
|
if slots == []:
|
||||||
|
return "Przepraszam nie rozumiem. Podaj więcej informacji."
|
||||||
|
|
||||||
|
elif act == "inform":
|
||||||
|
if slots == []:
|
||||||
|
return "Przepraszam nie rozumiem. Podaj więcej informacji."
|
||||||
|
else:
|
||||||
|
text = []
|
||||||
|
for i in slots:
|
||||||
|
if i[1] != None:
|
||||||
|
text.append(f"{i[0]}[{i[1]}]")
|
||||||
|
return self.nlg_pipeline(f'generate text: {", ".join(text)}')[0]['summary_text']
|
||||||
|
|
||||||
|
elif act == "canthelp.missing_slot_value":
|
||||||
|
return "Przepraszam, ale nie mogę zrealizować zamówienia. Brakuje mi niektórych informacji. Czy mogę pomóc w czymś innym?"
|
||||||
|
|
||||||
|
elif act == "bye":
|
||||||
|
return "Dziękujemy za zamówienie! Smacznego!"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
model = Model()
|
||||||
|
|
||||||
|
# jezeli sie przywita to przywitaj uzytkownika (i pokaz menu)
|
||||||
|
# response = model("Cześć")
|
||||||
|
# response = model("Witam")
|
||||||
|
# response = model("Witam system")
|
||||||
|
# response = model("Hej, jakim botem jesteś?")
|
||||||
|
# response = model("Hej, czym się zajmujesz?")
|
||||||
|
# response = model("Hej, w czym mi możesz pomóc?")
|
||||||
|
response = model("Siema, w czym możesz mi pomóc?")
|
||||||
|
assert response == "welcomemsg"
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
# jezeli prosi o pomoc lub po prostu o menu to zaoferuj cale menu
|
||||||
|
# response = model("Pokaz menu")
|
||||||
|
# response = model("A co do picia proponujesz?")
|
||||||
|
# response = model("Jakie inne desery oferujesz?")
|
||||||
|
response = model("Interesują mnie dania kuchni włoskiej oraz meksykanskiej.")
|
||||||
|
assert response == "offer"
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
# jezeli wybierze danie to zapisz wybor i poinformuj o nim
|
||||||
|
# response = model("Wezmę rybe")
|
||||||
|
# response = model("Poproszę tatara")
|
||||||
|
response = model("Chciałbym zjesc tatara")
|
||||||
|
assert response == "inform"
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
# jezeli poda adres to zapisze wybor i poinformuj o nim
|
||||||
|
# response = model('Poproszę na poznańską 2')
|
||||||
|
response = model("uniwersytetu poznanskiego 4 61-614 poznan")
|
||||||
|
assert response == "inform"
|
||||||
|
|
||||||
|
# jezeli sprobuje dokonac zamowienia bez podania potrzebnych informacji prosimy o nie
|
||||||
|
#response = model("Dobrze, nie mogę się już doczekać.")
|
||||||
|
response = model("Super, to zatem wszystko!")
|
||||||
|
assert response == "canthelp.missing_slot_value"
|
||||||
|
|
||||||
|
# jezeli wybierze rodzaj platnosci to zapisz wybor i poinformuj o nim
|
||||||
|
# response = model("karta")
|
||||||
|
# response = model("Poproszę blikiem z góry")
|
||||||
|
response = model("Zapłacę kartą przy odbiorze")
|
||||||
|
assert response == "inform"
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
# jezeli potwiedzi zamowienie to zakoncz zamawianie sukcesem i wypisz calosc
|
||||||
|
# response = model("Potwierdzam!")
|
||||||
|
# response = model("Tak!")
|
||||||
|
# response = model("Tak to wszystko!")
|
||||||
|
# response = model("Super, to zatem wszystko!")
|
||||||
|
response = model("Dobrze, nie mogę się już doczekać.")
|
||||||
|
assert response == "bye"
|
||||||
|
|
||||||
|
print("----Konwersacja z restartem-------")
|
||||||
|
|
||||||
|
model = Model()
|
||||||
|
response = model("Siema, w czym możesz mi pomóc?")
|
||||||
|
assert response == "welcomemsg"
|
||||||
|
response = model("Interesują mnie dania kuchni włoskiej oraz meksykanskiej.")
|
||||||
|
assert response == "offer"
|
||||||
|
response = model("Chciałbym zjesc tatara")
|
||||||
|
assert response == "inform"
|
||||||
|
response = model("uniwersytetu poznanskiego 4 61-614 poznan")
|
||||||
|
assert response == "inform"
|
||||||
|
response = model("od nowa")
|
||||||
|
assert response == "welcomemsg"
|
||||||
|
response = model("Interesują mnie dania kuchni włoskiej oraz meksykanskiej.")
|
||||||
|
assert response == "offer"
|
||||||
|
response = model("Chciałbym zjesc tatara")
|
||||||
|
assert response == "inform"
|
||||||
|
response = model("uniwersytetu poznanskiego 4 61-614 poznan")
|
||||||
|
assert response == "inform"
|
||||||
|
response = model("Zapłacę kartą przy odbiorze")
|
||||||
|
assert response == "inform"
|
||||||
|
response = model("Dobrze, nie mogę się już doczekać.")
|
||||||
|
assert response == "bye"
|
41244
nlg_data.csv
41244
nlg_data.csv
File diff suppressed because it is too large
Load Diff
@ -2,36 +2,36 @@ import pandas as pd
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
|
|
||||||
addresses = ["ulica Zielona 15", "ulica Czerwona 20", "ulica Niebieska 30", "ulica Biala 5", "ulica Czarna 10", "ulica Fioletowa 25", "ulica Pomaranczowa 35", "ulica Zolta 40", "ulica Rozowa 45", "ulica Szara 50", "ulica Brzowa 55", "ulica Srebrna 60", "ulica Zlota 65", "ulica Platynowa 70", "ulica Miedziana 75", "ulica Niklowa 80", "ulica Aluminium 85", "ulica Stalowa 90", "ulica Zelazna 95", "ulica Miedziana 100"]
|
addresses = ["ulica Zielona 15", "ulica Czerwona 20", "ulica Niebieska 30", "ulica Biala 5", "ulica Czarna 10", "ulica Fioletowa 25", "ulica Pomaranczowa 35", "ulica Zolta 40", "ulica Różowa 45", "ulica Szara 50", "ulica Brązowa 55", "ulica Srebrna 60", "ulica Złota 65", "ulica Platynowa 70", "ulica Miedziana 75", "ulica Niklowa 80", "ulica Aluminium 85", "ulica Stalowa 90", "ulica Żelazna 95", "ulica Miedziana 100"]
|
||||||
payment_methods = ["karta kredytowa", "gotowka", "blik", "przelew", "google pay"]
|
payment_methods = ["karta kredytowa", "gotowka", "blik", "przelew", "google pay"]
|
||||||
dishes = ["spaghetti", "pierogi", "schabowy", "pizza", "burger", "tatar", "poledwica", "tiramisu", "zrazy", "pyzy", "placki", "makaron", "zupa", "ryba", "cole", "tiramisu", "zupa grzybowa", "stek", "soki", "napoj"]
|
dishes = ["spaghetti", "pierogi", "schabowy", "pizza", "burger", "tatar", "poledwica", "tiramisu", "zrazy", "pyzy", "placki", "makaron", "zupa", "ryba", "cole", "tiramisu", "zupa grzybowa", "stek", "soki", "napoj"]
|
||||||
times = ["8:00", "12:00", "18:00", "20:00", "10:00", "dziesiata", "dziewiata", "osma", "siodma", "szosta", "czwarta", "trzecia", "druga", "pierwsza", "poludnie", "polnoc", "wschod", "zachod", "poludniowy wschod", "poludniowy zachod", "polnocny wschod", "polnocny zachod", "rano", "wieczor", "noc", "popoludnie", "przedpoludnie", "po poludniu", "po polnocy", "przed polnoca", "przed poludniem"]
|
times = ["8:00", "12:00", "18:00", "20:00", "10:00", "dziesiata", "dziewiata", "osma", "siodma", "szosta", "czwarta", "trzecia", "druga", "pierwsza", "poludnie", "polnoc", "wschod", "zachod", "poludniowy wschod", "poludniowy zachod", "polnocny wschod", "polnocny zachod", "rano", "wieczor", "noc", "popoludnie", "przedpoludnie", "po poludniu", "po polnocy", "przed polnoca", "przed poludniem"]
|
||||||
portion_sizes = ["mala", "srednia", "duza", "gigantyczna", "mini"]
|
portion_sizes = ["mała", "średnia", "duża", "gigantyczna", "mini"]
|
||||||
price = ["10", "50", "100", "150", "tanio", "drogo"]
|
price = ["10", "50", "100", "150", "tanio", "drogo"]
|
||||||
ingredient = ["mieso", "mleko", "jajka", "maka", "cukier", "sol", "pieprz", "oliwa", "maslo", "ser", "warzywa", "owoce", "ryz", "makaron", "zupa", "ryba", "sos", "przyprawy", "soki", "napoje", "alkohol", "kawa", "herbata", "deser", "ciasto", "chleb", "pasta", "sos", "danie", "potrawa", "zupa", "salatka", "kanapka", "tost", "jajecznica", "omlet", "placki", "pierogi", "schabowy", "kotlet", "kotlet schabowy", "kotlet mielony", "kotlet z kurczaka", "kotlet z indyka", "kotlet z ryby", "kotlet z warzyw", "kotlet ziemniaczany", "kotlet z kaszy", "kotlet z makaronu", "kotlet z ziemniakow", "kotlet z ryzu"]
|
ingredient = ["mięso", "mleko", "jajka", "mąka", "cukier", "sól", "pieprz", "oliwa", "masło", "ser", "warzywa", "owoce", "ryż", "makaron", "zupa", "ryba", "sos", "przyprawy", "soki", "napoje", "alkohol", "kawa", "herbata", "deser", "ciasto", "chleb", "pasta", "sos", "danie", "potrawa", "zupa", "sałatka", "kanapka", "tost", "jajecznica", "omlet", "placki", "pierogi", "schabowy", "kotlet", "kotlet schabowy", "kotlet mielony", "kotlet z kurczaka", "kotlet z indyka", "kotlet z ryby", "kotlet z warzyw", "kotlet ziemniaczany", "kotlet z kaszy", "kotlet z makaronu", "kotlet z ziemniaków", "kotlet z ryżu"]
|
||||||
allergy = ["gluten", "laktoza", "jajka", "orzechy", "soja", "ryby", "skorupiaki", "mleko", "seler", "gorczyca", "sezam", "siarczyny", "lubin", "migdaly", "orzechy laskowe", "orzechy wloskie", "orzechy nerkowca", "orzechy ziemne", "orzechy brazylijskie", "orzechy makadamia", "orzechy pecan", "orzechy pistacjowe", "orzechy kasztanowe", "orzechy pinii", "orzechy arachidowe", "orzechy ziemne", "orzechy brazylijskie", "orzechy makadamia", "orzechy pecan", "orzechy pistacjowe", "orzechy kasztanowe", "orzechy pinii", "orzechy arachidowe", "orzechy ziemne", "orzechy brazylijskie", "orzechy makadamia", "orzechy pecan"]
|
allergy = ["gluten", "laktoza", "jajka", "orzechy", "soja", "ryby", "skorupiaki", "mleko", "seler", "gorczyca", "sezam", "siarczyny", "łubin", "migdały", "orzechy laskowe", "orzechy włoskie", "orzechy nerkowca", "orzechy ziemne", "orzechy brazylijskie", "orzechy makadamia", "orzechy pecan", "orzechy pistacjowe", "orzechy kasztanowe", "orzechy pinii", "orzechy arachidowe", "orzechy ziemne", "orzechy brazylijskie", "orzechy makadamia", "orzechy pecan", "orzechy pistacjowe", "orzechy kasztanowe", "orzechy pinii", "orzechy arachidowe", "orzechy ziemne", "orzechy brazylijskie", "orzechy makadamia", "orzechy pecan", "orzechy pistacjowe"]
|
||||||
|
|
||||||
def create_ref(slot, value):
|
def create_ref(slot, value):
|
||||||
ref_templates = {
|
ref_templates = {
|
||||||
"address": [
|
"address": [
|
||||||
f"Zamowienie zostanie dostarczone na {value}.",
|
f"Zamówienie zostanie dostarczone na {value}.",
|
||||||
f"Niestety nie dostarczamy na {value}.",
|
f"Niestety nie dostarczamy na {value}.",
|
||||||
f"Oczywiscie, dostarczymy na {value}.",
|
f"Oczywiście, dostarczymy na {value}.",
|
||||||
f"Dostawa mozliwa na {value}.",
|
f"Dostawa możliwa na {value}.",
|
||||||
f"Nie dostarczamy na {value}."
|
f"Nie dostarczamy na {value}."
|
||||||
],
|
],
|
||||||
"payment_method": [
|
"payment_method": [
|
||||||
f"Akceptujemy platnosc {value}.",
|
f"Akceptujemy płatność {value}.",
|
||||||
f"Nie akceptujemy platnosci {value}.",
|
f"Nie akceptujemy płatności {value}.",
|
||||||
f"Mozesz placic {value}.",
|
f"Możesz płacić {value}.",
|
||||||
f"Platnosc {value} jest mozliwa.",
|
f"Płatność {value} jest możliwa.",
|
||||||
f"Nie obslugujemy platnosci {value}."
|
f"Nie obsługujemy płatności {value}."
|
||||||
],
|
],
|
||||||
"dish": [
|
"dish": [
|
||||||
f"Specjalnoscia jest {value}.",
|
f"Specjalnością jest {value}.",
|
||||||
f"Nie mamy w ofercie {value}.",
|
f"Nie mamy w ofercie {value}.",
|
||||||
f"Zapraszamy na {value}.",
|
f"Zapraszamy na {value}.",
|
||||||
f"{value} jest dostepne.",
|
f"{value} jest dostępne.",
|
||||||
f"Nie mamy {value} w menu."
|
f"Nie mamy {value} w menu."
|
||||||
],
|
],
|
||||||
"time": [
|
"time": [
|
||||||
@ -39,19 +39,19 @@ def create_ref(slot, value):
|
|||||||
f"Nieczynne o {value}.",
|
f"Nieczynne o {value}.",
|
||||||
f"Zapraszamy o {value}.",
|
f"Zapraszamy o {value}.",
|
||||||
f"Otwarte od {value}.",
|
f"Otwarte od {value}.",
|
||||||
f"Zamkniete o {value}."
|
f"Zamknięte o {value}."
|
||||||
],
|
],
|
||||||
"portion_size": [
|
"portion_size": [
|
||||||
f"Dostepne porcje: {value}.",
|
f"Dostępne porcje: {value}.",
|
||||||
f"Brak porcji {value}.",
|
f"Brak porcji {value}.",
|
||||||
f"Dostepne porcje: {value}.",
|
f"Dostępne porcje: {value}.",
|
||||||
f"Porcja {value} jest dostepna.",
|
f"Porcja {value} jest dostępna.",
|
||||||
f"Nie mamy porcji {value}."
|
f"Nie mamy porcji {value}."
|
||||||
],
|
],
|
||||||
"price": [
|
"price": [
|
||||||
f"Cena to {value}.",
|
f"Cena to {value}.",
|
||||||
f"Nie mamy ceny {value}.",
|
f"Nie mamy ceny {value}.",
|
||||||
f"Mozesz kupic za {value}.",
|
f"Możesz kupić za {value}.",
|
||||||
f"Cena wynosi {value}.",
|
f"Cena wynosi {value}.",
|
||||||
],
|
],
|
||||||
"ingredient": [
|
"ingredient": [
|
||||||
@ -99,7 +99,7 @@ def generate_sample(num_slots):
|
|||||||
mr_list.append(f"{slot}[{value}]")
|
mr_list.append(f"{slot}[{value}]")
|
||||||
ref_list.append(random.choice(create_ref(slot, value)))
|
ref_list.append(random.choice(create_ref(slot, value)))
|
||||||
if len(mr_list) == 1:
|
if len(mr_list) == 1:
|
||||||
return {"mr": repr(mr_list[0]), "ref": ref_list[0]}
|
return {"mr": repr(mr_list[0]), "ref": " ".join(ref_list)}
|
||||||
else:
|
else:
|
||||||
return {"mr": ", ".join(mr_list), "ref": " ".join(ref_list)}
|
return {"mr": ", ".join(mr_list), "ref": " ".join(ref_list)}
|
||||||
|
|
||||||
@ -108,4 +108,4 @@ for num_slots in range(0, 6):
|
|||||||
data.append(generate_sample(num_slots))
|
data.append(generate_sample(num_slots))
|
||||||
|
|
||||||
df = pd.DataFrame(remove_duplicates(data))
|
df = pd.DataFrame(remove_duplicates(data))
|
||||||
df.to_csv('nlg_data.csv', index=False)
|
df.to_csv('abc.csv', index=False)
|
||||||
|
273
nlg_train.ipynb
273
nlg_train.ipynb
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user