From 60a298fa4932fec48254f6fc75436fe87b9ae0b4 Mon Sep 17 00:00:00 2001 From: filnow Date: Mon, 3 Jun 2024 14:49:59 +0200 Subject: [PATCH] add: nlg model --- dialog_with_nlg.py | 295 +++++++++++++++++++++++++++++++++++++++++++++ nlg_train.ipynb | 231 +---------------------------------- 2 files changed, 296 insertions(+), 230 deletions(-) create mode 100644 dialog_with_nlg.py diff --git a/dialog_with_nlg.py b/dialog_with_nlg.py new file mode 100644 index 0000000..e9268fb --- /dev/null +++ b/dialog_with_nlg.py @@ -0,0 +1,295 @@ +import string +from typing import Any + +import jsgf +from unidecode import unidecode + +from convlab.dst import dst + +from transformers import ( + AutoModelForSeq2SeqLM, + AutoTokenizer, + pipeline, +) + + +def default_state(): + return dict( + user_action=[], + system_action=[], + belief_state={ + 'address': '', + 'payment_method': '', + 'dish': [], + 'time': '' + }, + booked={}, + request_state=[], + terminated=False, + history=[] + ) + + +class Model: + def __init__(self): + self.state = default_state() + self.nlu = NLU() + self.dst = DST(self.state) + self.dp = DP(self.state) + self.nlg = NLG(self.state) + + def __call__(self, prompt) -> Any: + print(prompt) + msg = prompt.lower() + + r = self.nlu(msg) + slots = r['slots'] + #print(r) + r = self.dst(r) + #print(r) + r = self.dp() + #print(r) + r = self.nlg(r, slots) + print(r) + + return r + + +class NLU(): + def __init__(self): + self.book_grammar = jsgf.parse_grammar_file('book.jsgf') + + def get_dialog_act(self, rule): + slots = [] + self.get_slots(rule.expansion, slots) + return {'act': rule.name, 'slots': slots} + + def get_slots(self, expansion, slots): + if expansion.tag != '': + slots.append((expansion.tag, expansion.current_match)) + return + + for child in expansion.children: + self.get_slots(child, slots) + + if not expansion.children and isinstance(expansion, jsgf.NamedRuleRef): + self.get_slots(expansion.referenced_rule.expansion, slots) + + def __call__(self, prompt) -> Any: + book_grammar = jsgf.parse_grammar_file('book.jsgf') + + prompt = unidecode(prompt) + translator = str.maketrans('', '', string.punctuation) + prompt = prompt.translate(translator) + + matched = book_grammar.find_matching_rules(prompt) + + if matched: + return self.get_dialog_act(matched[0]) + else: + return {'act': 'null', 'slots': []} + + +class DST(dst.DST): + + def __init__(self, state): + dst.DST.__init__(self) + self.state = state + + def __call__(self, user_act) -> Any: + if len(user_act['slots']) == 0: + user_act = [(user_act['act'], None, None)] + else: + user_act = [(user_act['act'], k, v) for k, v in user_act['slots'] if v is not None] + + self.state['request_state'] = {} + for act, slot, value in user_act: + self.state['user_action'].append(act) + + if act == "platnosc": + self.state['belief_state']['payment_method'] = value + self.state['request_state'] = ['payment_method'] + + elif act == "offer": + self.state['request_state'] = ['menu'] + + elif act == 'select': + if slot == 'dish': + self.state['belief_state']['dish'].append(value) + else: + self.state['belief_state'][slot] = value + self.state['request_state'] = [slot] + + elif act == 'inform': + pass + + elif act == 'request': + pass + elif act == 'restart': + self.state["belief_state"] = default_state()["belief_state"] + self.state["booked"] = {} + self.state["request_state"] = [] + self.state["terminated"] = False + self.state["history"] = [] + return self.state + + +class DP(): + def __init__(self, state): + self.state = state + + def __call__(self) -> Any: + system_action = None + + if self.state['user_action'][-1] == 'hello': + system_action = 'welcomemsg' + # przywitaj uzytkownika (i pokaz menu) + + elif self.state['user_action'][-1] == 'select': + system_action = 'inform' + # poinformuj o wybranych slotach z "request_state" + + elif (self.state['user_action'][-1] == 'help' + or self.state['user_action'][-1] == 'offer' + or self.state['user_action'][-1] == 'reqmore' + or (self.state['user_action'][-1] == 'request' and len(self.state['request_state']) == 0) + ): + system_action = 'offer' + # zaoferuj cale menu + + elif self.state['user_action'][-1] == 'ack': + address = self.state["belief_state"]["address"] + payment_method = self.state["belief_state"]["payment_method"] + dish = self.state["belief_state"]["dish"] + # W przypadku braku szczegolnej informacji o czasie zamówienia zamawiamy natychmiast + + if address and payment_method and dish: + system_action = 'bye' + self.state['terminated'] = True + # potwierdz i zakoncz, podsumuj zamowienie + else: + system_action = 'canthelp.missing_slot_value' + elif self.state['user_action'][-1] == 'restart': + system_action = 'welcomemsg' + # zachowaj sie jak na poczatku rozmowy + else: + system_action = 'inform' + # poinformuj o wybranych slotach z "request_state" + # lub o wszystkich jezeli nic nie ma w request state + + self.state['system_action'].append(system_action) + return system_action + + +class NLG(): + def __init__(self, state): + self.model = AutoModelForSeq2SeqLM.from_pretrained("filnow/nlg-umt5-pol") + self.tokenizer = AutoTokenizer.from_pretrained("google/umt5-small") + self.nlg_pipeline = pipeline('summarization', model=self.model, tokenizer=self.tokenizer) + + def __call__(self, act, slots) -> Any: + if act == 'welcomemsg': + return "Witaj w naszej restauracji! Jak mogę Ci pomóc?" + + elif act == "offer": + if slots == []: + return "Przepraszam nie rozumiem. Podaj więcej informacji." + + elif act == "inform": + if slots == []: + return "Przepraszam nie rozumiem. Podaj więcej informacji." + else: + text = [] + for i in slots: + if i[1] != None: + text.append(f"{i[0]}[{i[1]}]") + return self.nlg_pipeline(f'generate text: {", ".join(text)}')[0]['summary_text'] + + elif act == "canthelp.missing_slot_value": + return "Przepraszam, ale nie mogę zrealizować zamówienia. Brakuje mi niektórych informacji. Czy mogę pomóc w czymś innym?" + + elif act == "bye": + return "Dziękujemy za zamówienie! Smacznego!" + + +if __name__ == "__main__": + model = Model() + + # jezeli sie przywita to przywitaj uzytkownika (i pokaz menu) + # response = model("Cześć") + # response = model("Witam") + # response = model("Witam system") + # response = model("Hej, jakim botem jesteś?") + # response = model("Hej, czym się zajmujesz?") + # response = model("Hej, w czym mi możesz pomóc?") + response = model("Siema, w czym możesz mi pomóc?") + assert response == "welcomemsg" + + print() + + # jezeli prosi o pomoc lub po prostu o menu to zaoferuj cale menu + # response = model("Pokaz menu") + # response = model("A co do picia proponujesz?") + # response = model("Jakie inne desery oferujesz?") + response = model("Interesują mnie dania kuchni włoskiej oraz meksykanskiej.") + assert response == "offer" + + print() + + # jezeli wybierze danie to zapisz wybor i poinformuj o nim + # response = model("Wezmę rybe") + # response = model("Poproszę tatara") + response = model("Chciałbym zjesc tatara") + assert response == "inform" + + print() + + # jezeli poda adres to zapisze wybor i poinformuj o nim + # response = model('Poproszę na poznańską 2') + response = model("uniwersytetu poznanskiego 4 61-614 poznan") + assert response == "inform" + + # jezeli sprobuje dokonac zamowienia bez podania potrzebnych informacji prosimy o nie + #response = model("Dobrze, nie mogę się już doczekać.") + response = model("Super, to zatem wszystko!") + assert response == "canthelp.missing_slot_value" + + # jezeli wybierze rodzaj platnosci to zapisz wybor i poinformuj o nim + # response = model("karta") + # response = model("Poproszę blikiem z góry") + response = model("Zapłacę kartą przy odbiorze") + assert response == "inform" + + print() + + # jezeli potwiedzi zamowienie to zakoncz zamawianie sukcesem i wypisz calosc + # response = model("Potwierdzam!") + # response = model("Tak!") + # response = model("Tak to wszystko!") + # response = model("Super, to zatem wszystko!") + response = model("Dobrze, nie mogę się już doczekać.") + assert response == "bye" + + print("----Konwersacja z restartem-------") + + model = Model() + response = model("Siema, w czym możesz mi pomóc?") + assert response == "welcomemsg" + response = model("Interesują mnie dania kuchni włoskiej oraz meksykanskiej.") + assert response == "offer" + response = model("Chciałbym zjesc tatara") + assert response == "inform" + response = model("uniwersytetu poznanskiego 4 61-614 poznan") + assert response == "inform" + response = model("od nowa") + assert response == "welcomemsg" + response = model("Interesują mnie dania kuchni włoskiej oraz meksykanskiej.") + assert response == "offer" + response = model("Chciałbym zjesc tatara") + assert response == "inform" + response = model("uniwersytetu poznanskiego 4 61-614 poznan") + assert response == "inform" + response = model("Zapłacę kartą przy odbiorze") + assert response == "inform" + response = model("Dobrze, nie mogę się już doczekać.") + assert response == "bye" \ No newline at end of file diff --git a/nlg_train.ipynb b/nlg_train.ipynb index 64a375f..273df38 100644 --- a/nlg_train.ipynb +++ b/nlg_train.ipynb @@ -1,230 +1 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from transformers import (\n", - " AutoModelForSeq2SeqLM,\n", - " AutoTokenizer,\n", - " DataCollatorForSeq2Seq,\n", - " Seq2SeqTrainer,\n", - " Seq2SeqTrainingArguments,\n", - " pipeline,\n", - ")\n", - "\n", - "from datasets import load_dataset\n", - "\n", - "model_name = \"google/umt5-small\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dataset = load_dataset('csv', data_files='/kaggle/input/ngl-data/nlg_data.csv', split='train').train_test_split(test_size=0.1)\n", - "dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", - "\n", - "\n", - "def tokenize_samples(samples):\n", - " inputs = [f\"generate text: {mr}\" for mr in samples[\"mr\"]]\n", - "\n", - " tokenized_inputs = tokenizer(\n", - " inputs,\n", - " max_length=128,\n", - " padding=\"max_length\",\n", - " truncation=True,\n", - " )\n", - "\n", - " labels = tokenizer(\n", - " text_target=samples[\"ref\"],\n", - " max_length=128,\n", - " padding=\"max_length\",\n", - " truncation=True,\n", - " )\n", - "\n", - " labels[\"input_ids\"] = [\n", - " [\n", - " (token_id if token_id != tokenizer.pad_token_id else -100)\n", - " for token_id in label\n", - " ]\n", - " for label in labels[\"input_ids\"]\n", - " ]\n", - "\n", - " tokenized_inputs[\"labels\"] = labels[\"input_ids\"]\n", - " return tokenized_inputs\n", - "\n", - "\n", - "tokenized_dataset = dataset.map(\n", - " tokenize_samples,\n", - " batched=True,\n", - " remove_columns=[\"mr\", \"ref\"],\n", - ")\n", - "\n", - "tokenized_dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n", - "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, label_pad_token_id=-100, pad_to_multiple_of=8)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "training_args = Seq2SeqTrainingArguments(\n", - " output_dir=\"/kaggle/working\",\n", - " per_device_train_batch_size=8,\n", - " per_device_eval_batch_size=16,\n", - " predict_with_generate=True,\n", - " learning_rate=5e-5,\n", - " num_train_epochs=3,\n", - " evaluation_strategy=\"epoch\",\n", - " save_strategy=\"epoch\",\n", - " save_total_limit=1,\n", - " load_best_model_at_end=True,\n", - ")\n", - "\n", - "trainer = Seq2SeqTrainer(\n", - " model=model,\n", - " args=training_args,\n", - " data_collator=data_collator,\n", - " train_dataset=tokenized_dataset[\"train\"],\n", - " eval_dataset=tokenized_dataset[\"test\"],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trainer.train()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "nlg = pipeline('summarization', model=model, tokenizer=tokenizer)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "nlg(f'generate text: dish[tatar], price[50], ingredient[wolowina]')[0]['summary_text']" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "nlg(f'generate text: payment_methods[gotowka], price[150], addresses[ulica Dluga 5]')[0]['summary_text']" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "nlg(f'generate text: dish[tiramisu], ingredient[mleko], allergy[laktoza]')[0]['summary_text']" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "nlg(f'generate text: time[dziesiata]')[0]['summary_text']" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "nlg(f'generate text: dish[spaghetti], ingredient[ser]')[0]['summary_text']" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "nlg(f'generate text: dish[pierogi], ingredient[kozi ser]')[0]['summary_text']" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "nlg(f'generate text: time[23:00], adres[ul Krótka 256]')[0]['summary_text']" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model.save_pretrained(\"/kaggle/working\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "jarvis", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.19" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} +{"metadata":{"kernelspec":{"name":"python3","display_name":"Python 3","language":"python"},"language_info":{"name":"python","version":"3.10.13","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"gpu","dataSources":[{"sourceId":8587424,"sourceType":"datasetVersion","datasetId":5135632}],"dockerImageVersionId":30716,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"from transformers import (\n AutoModelForSeq2SeqLM,\n AutoTokenizer,\n DataCollatorForSeq2Seq,\n Seq2SeqTrainer,\n Seq2SeqTrainingArguments,\n pipeline,\n)\n\nfrom datasets import load_dataset\n\nmodel_name = \"google/umt5-small\"","metadata":{"execution":{"iopub.status.busy":"2024-06-03T11:18:55.032642Z","iopub.execute_input":"2024-06-03T11:18:55.033345Z","iopub.status.idle":"2024-06-03T11:19:13.773777Z","shell.execute_reply.started":"2024-06-03T11:18:55.033313Z","shell.execute_reply":"2024-06-03T11:19:13.772989Z"},"trusted":true},"execution_count":1,"outputs":[{"name":"stderr","text":"2024-06-03 11:19:02.256736: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n2024-06-03 11:19:02.256864: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n2024-06-03 11:19:02.368948: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n","output_type":"stream"}]},{"cell_type":"code","source":"dataset = load_dataset('csv', data_files='/kaggle/input/ngl-data/nlg_data.csv', split='train').train_test_split(test_size=0.1)\ndataset","metadata":{"execution":{"iopub.status.busy":"2024-06-03T11:19:13.775364Z","iopub.execute_input":"2024-06-03T11:19:13.775904Z","iopub.status.idle":"2024-06-03T11:19:14.356839Z","shell.execute_reply.started":"2024-06-03T11:19:13.775878Z","shell.execute_reply":"2024-06-03T11:19:14.355976Z"},"trusted":true},"execution_count":2,"outputs":[{"output_type":"display_data","data":{"text/plain":"Generating train split: 0 examples [00:00, ? examples/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"fdd37b65a44d42b2931bdc0db8229fa7"}},"metadata":{}},{"execution_count":2,"output_type":"execute_result","data":{"text/plain":"DatasetDict({\n train: Dataset({\n features: ['mr', 'ref'],\n num_rows: 18564\n })\n test: Dataset({\n features: ['mr', 'ref'],\n num_rows: 2063\n })\n})"},"metadata":{}}]},{"cell_type":"code","source":"tokenizer = AutoTokenizer.from_pretrained(model_name)\n\n\ndef tokenize_samples(samples):\n inputs = [f\"generate text: {mr}\" for mr in samples[\"mr\"]]\n\n tokenized_inputs = tokenizer(\n inputs,\n max_length=128,\n padding=\"max_length\",\n truncation=True,\n )\n\n labels = tokenizer(\n text_target=samples[\"ref\"],\n max_length=128,\n padding=\"max_length\",\n truncation=True,\n )\n\n labels[\"input_ids\"] = [\n [\n (token_id if token_id != tokenizer.pad_token_id else -100)\n for token_id in label\n ]\n for label in labels[\"input_ids\"]\n ]\n\n tokenized_inputs[\"labels\"] = labels[\"input_ids\"]\n return tokenized_inputs\n\n\ntokenized_dataset = dataset.map(\n tokenize_samples,\n batched=True,\n remove_columns=[\"mr\", \"ref\"],\n)\n\ntokenized_dataset","metadata":{"execution":{"iopub.status.busy":"2024-06-03T11:19:14.357803Z","iopub.execute_input":"2024-06-03T11:19:14.358052Z","iopub.status.idle":"2024-06-03T11:19:24.614600Z","shell.execute_reply.started":"2024-06-03T11:19:14.358030Z","shell.execute_reply":"2024-06-03T11:19:24.613696Z"},"trusted":true},"execution_count":3,"outputs":[{"output_type":"display_data","data":{"text/plain":"tokenizer_config.json: 0%| | 0.00/6.84k [00:00","text/html":"Tracking run with wandb version 0.17.0"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"","text/html":"Run data is saved locally in /kaggle/working/wandb/run-20240603_111947-zd4tutif"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"","text/html":"Syncing run /kaggle/working to Weights & Biases (docs)
"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"","text/html":" View project at https://wandb.ai/filnow42/huggingface"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"","text/html":" View run at https://wandb.ai/filnow42/huggingface/runs/zd4tutif"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"","text/html":"\n
\n \n \n [6963/6963 38:47, Epoch 3/3]\n
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
EpochTraining LossValidation Loss
10.7329000.331611
20.3731000.246366
30.3269000.231167

"},"metadata":{}},{"name":"stderr","text":"There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].\n","output_type":"stream"},{"execution_count":6,"output_type":"execute_result","data":{"text/plain":"TrainOutput(global_step=6963, training_loss=1.0388871652717377, metrics={'train_runtime': 2359.6292, 'train_samples_per_second': 23.602, 'train_steps_per_second': 2.951, 'total_flos': 7499132383002624.0, 'train_loss': 1.0388871652717377, 'epoch': 3.0})"},"metadata":{}}]},{"cell_type":"code","source":"nlg = pipeline('summarization', model=model, tokenizer=tokenizer)","metadata":{"execution":{"iopub.status.busy":"2024-06-03T11:58:53.891542Z","iopub.execute_input":"2024-06-03T11:58:53.891952Z","iopub.status.idle":"2024-06-03T11:58:53.897775Z","shell.execute_reply.started":"2024-06-03T11:58:53.891924Z","shell.execute_reply":"2024-06-03T11:58:53.896741Z"},"trusted":true},"execution_count":7,"outputs":[]},{"cell_type":"code","source":"nlg(f'generate text: dish[tatar], price[50], ingredient[wolowina]')[0]['summary_text']","metadata":{"execution":{"iopub.status.busy":"2024-06-03T11:58:53.898928Z","iopub.execute_input":"2024-06-03T11:58:53.899234Z","iopub.status.idle":"2024-06-03T11:59:05.979970Z","shell.execute_reply.started":"2024-06-03T11:58:53.899195Z","shell.execute_reply":"2024-06-03T11:59:05.978805Z"},"trusted":true},"execution_count":8,"outputs":[{"execution_count":8,"output_type":"execute_result","data":{"text/plain":"'Nie mamy tatar w menu. Cena wynosi 50. Składnik to owoce.'"},"metadata":{}}]},{"cell_type":"code","source":"nlg(f'generate text: payment_methods[gotowka], price[150], addresses[ulica Dluga 5]')[0]['summary_text']","metadata":{"execution":{"iopub.status.busy":"2024-06-03T11:59:05.981291Z","iopub.execute_input":"2024-06-03T11:59:05.981585Z","iopub.status.idle":"2024-06-03T11:59:06.533378Z","shell.execute_reply.started":"2024-06-03T11:59:05.981559Z","shell.execute_reply":"2024-06-03T11:59:06.532379Z"},"trusted":true},"execution_count":9,"outputs":[{"execution_count":9,"output_type":"execute_result","data":{"text/plain":"'Nie obsługujemy płatności gotowka. Cena wynosi 150. Oczywiście, dostarczymy na ulica Dluga 5.'"},"metadata":{}}]},{"cell_type":"code","source":"nlg(f'generate text: dish[tiramisu], ingredient[mleko], allergy[laktoza]')[0]['summary_text']","metadata":{"execution":{"iopub.status.busy":"2024-06-03T11:59:06.537427Z","iopub.execute_input":"2024-06-03T11:59:06.538123Z","iopub.status.idle":"2024-06-03T11:59:06.938435Z","shell.execute_reply.started":"2024-06-03T11:59:06.538081Z","shell.execute_reply":"2024-06-03T11:59:06.937299Z"},"trusted":true},"execution_count":10,"outputs":[{"execution_count":10,"output_type":"execute_result","data":{"text/plain":"'Nie mamy tiramisu w menu. Składnik mleko jest dostępny. Nie zawiera alergenu laktoza.'"},"metadata":{}}]},{"cell_type":"code","source":"nlg(f'generate text: time[dziesiata]')[0]['summary_text']","metadata":{"execution":{"iopub.status.busy":"2024-06-03T11:59:06.939929Z","iopub.execute_input":"2024-06-03T11:59:06.940331Z","iopub.status.idle":"2024-06-03T11:59:07.132913Z","shell.execute_reply.started":"2024-06-03T11:59:06.940292Z","shell.execute_reply":"2024-06-03T11:59:07.131901Z"},"trusted":true},"execution_count":11,"outputs":[{"name":"stderr","text":"Your max_length is set to 20, but your input_length is only 10. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=5)\n","output_type":"stream"},{"execution_count":11,"output_type":"execute_result","data":{"text/plain":"'Zamknięte o dziesiata.'"},"metadata":{}}]},{"cell_type":"code","source":"nlg(f'generate text: dish[spaghetti], ingredient[ser]')[0]['summary_text']","metadata":{"execution":{"iopub.status.busy":"2024-06-03T11:59:07.134067Z","iopub.execute_input":"2024-06-03T11:59:07.134671Z","iopub.status.idle":"2024-06-03T11:59:07.405347Z","shell.execute_reply.started":"2024-06-03T11:59:07.134642Z","shell.execute_reply":"2024-06-03T11:59:07.404117Z"},"trusted":true},"execution_count":12,"outputs":[{"name":"stderr","text":"Your max_length is set to 20, but your input_length is only 14. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=7)\n","output_type":"stream"},{"execution_count":12,"output_type":"execute_result","data":{"text/plain":"'Nie mamy spaghetti w menu. Składnik ser jest dostępny.'"},"metadata":{}}]},{"cell_type":"code","source":"nlg(f'generate text: dish[pierogi], ingredient[kozi ser]')[0]['summary_text']","metadata":{"execution":{"iopub.status.busy":"2024-06-03T11:59:07.407270Z","iopub.execute_input":"2024-06-03T11:59:07.410442Z","iopub.status.idle":"2024-06-03T11:59:07.697634Z","shell.execute_reply.started":"2024-06-03T11:59:07.410396Z","shell.execute_reply":"2024-06-03T11:59:07.695355Z"},"trusted":true},"execution_count":13,"outputs":[{"name":"stderr","text":"Your max_length is set to 20, but your input_length is only 16. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=8)\n","output_type":"stream"},{"execution_count":13,"output_type":"execute_result","data":{"text/plain":"'Nie mamy pierogi w menu. Składnik to koti ser.'"},"metadata":{}}]},{"cell_type":"code","source":"nlg(f'generate text: time[23:00], adres[ul Krótka 256]')[0]['summary_text']","metadata":{"execution":{"iopub.status.busy":"2024-06-03T11:59:07.698906Z","iopub.execute_input":"2024-06-03T11:59:07.699269Z","iopub.status.idle":"2024-06-03T11:59:08.138934Z","shell.execute_reply.started":"2024-06-03T11:59:07.699233Z","shell.execute_reply":"2024-06-03T11:59:08.137833Z"},"trusted":true},"execution_count":14,"outputs":[{"execution_count":14,"output_type":"execute_result","data":{"text/plain":"'Zamknięte o 23:00. Nie dostarczamy na ulica Krótka 256.'"},"metadata":{}}]},{"cell_type":"code","source":"model.save_pretrained(\"/kaggle/working\")","metadata":{"execution":{"iopub.status.busy":"2024-06-03T11:59:08.140399Z","iopub.execute_input":"2024-06-03T11:59:08.140718Z","iopub.status.idle":"2024-06-03T11:59:11.078579Z","shell.execute_reply.started":"2024-06-03T11:59:08.140689Z","shell.execute_reply":"2024-06-03T11:59:11.077378Z"},"trusted":true},"execution_count":15,"outputs":[]},{"cell_type":"code","source":"from kaggle_secrets import UserSecretsClient\nuser_secrets = UserSecretsClient()\nsecret_value_0 = user_secrets.get_secret(\"huggingface-write\")","metadata":{"execution":{"iopub.status.busy":"2024-06-03T12:03:34.283930Z","iopub.execute_input":"2024-06-03T12:03:34.284674Z","iopub.status.idle":"2024-06-03T12:03:34.468881Z","shell.execute_reply.started":"2024-06-03T12:03:34.284637Z","shell.execute_reply":"2024-06-03T12:03:34.467812Z"},"trusted":true},"execution_count":18,"outputs":[]},{"cell_type":"code","source":"from huggingface_hub import login\nlogin(secret_value_0)","metadata":{"execution":{"iopub.status.busy":"2024-06-03T12:03:38.979682Z","iopub.execute_input":"2024-06-03T12:03:38.980042Z","iopub.status.idle":"2024-06-03T12:03:39.119457Z","shell.execute_reply.started":"2024-06-03T12:03:38.980011Z","shell.execute_reply":"2024-06-03T12:03:39.118367Z"},"trusted":true},"execution_count":19,"outputs":[{"name":"stdout","text":"The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\nToken is valid (permission: write).\nYour token has been saved to /root/.cache/huggingface/token\nLogin successful\n","output_type":"stream"}]},{"cell_type":"code","source":"trainer.push_to_hub(\"filnow/nlg-umt5-pol\")","metadata":{"execution":{"iopub.status.busy":"2024-06-03T12:03:45.289755Z","iopub.execute_input":"2024-06-03T12:03:45.290131Z","iopub.status.idle":"2024-06-03T12:04:24.555639Z","shell.execute_reply.started":"2024-06-03T12:03:45.290099Z","shell.execute_reply":"2024-06-03T12:04:24.554427Z"},"trusted":true},"execution_count":20,"outputs":[{"output_type":"display_data","data":{"text/plain":"events.out.tfevents.1717413574.743112a2decd.34.0: 0%| | 0.00/9.10k [00:00