{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tuning uMT5-small for data-to-text generation\n",
    "\n",
    "Fine-tunes `google/umt5-small` to map meaning representations (the `mr` column, e.g. `dish[tatar], price[50]`) to natural-language text (the `ref` column), then spot-checks generations on a few hand-written inputs and saves the model together with its tokenizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from transformers import (\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    AutoTokenizer,\n",
    "    DataCollatorForSeq2Seq,\n",
    "    Seq2SeqTrainer,\n",
    "    Seq2SeqTrainingArguments,\n",
    "    pipeline,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME = \"google/umt5-small\"\n",
    "DATA_PATH = \"/kaggle/input/ngl-data/nlg_data.csv\"\n",
    "OUTPUT_DIR = \"/kaggle/working\"\n",
    "\n",
    "# Prefix prepended to every MR; training and inference must use exactly the same string.\n",
    "TASK_PREFIX = \"generate text: \"\n",
    "# Token budget for inputs and labels during training, and for generated text at inference.\n",
    "MAX_LENGTH = 128\n",
    "SEED = 42"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data\n",
    "\n",
    "The CSV is loaded as a single split and 10% of its rows are held out for evaluation; the fixed seed keeps the split identical across runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset(\"csv\", data_files=DATA_PATH, split=\"train\").train_test_split(\n",
    "    test_size=0.1, seed=SEED\n",
    ")\n",
    "dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tokenize\n",
    "\n",
    "Sequences are truncated to `MAX_LENGTH` tokens but not padded here: `DataCollatorForSeq2Seq` pads each batch to its own longest sequence, so no compute is spent on padding every example to a fixed length."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "\n",
    "\n",
    "def tokenize_samples(samples):\n",
    "    \"\"\"Tokenize a batch of MR/reference pairs for seq2seq training.\n",
    "\n",
    "    Args:\n",
    "        samples: batch from `Dataset.map(batched=True)` with list-valued `mr` and `ref` columns.\n",
    "\n",
    "    Returns:\n",
    "        The tokenized, prefixed MRs with the tokenized references stored under `labels`.\n",
    "        Nothing is padded; padding (with -100 for labels) is left to the data collator.\n",
    "    \"\"\"\n",
    "    inputs = [f\"{TASK_PREFIX}{mr}\" for mr in samples[\"mr\"]]\n",
    "    tokenized_inputs = tokenizer(inputs, max_length=MAX_LENGTH, truncation=True)\n",
    "    labels = tokenizer(text_target=samples[\"ref\"], max_length=MAX_LENGTH, truncation=True)\n",
    "    tokenized_inputs[\"labels\"] = labels[\"input_ids\"]\n",
    "    return tokenized_inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_dataset = dataset.map(\n",
    "    tokenize_samples,\n",
    "    batched=True,\n",
    "    remove_columns=[\"mr\", \"ref\"],\n",
    ")\n",
    "tokenized_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model and training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)\n",
    "# Pads inputs and labels per batch (rounded up to a multiple of 8); label padding uses -100 so it is excluded from the loss.\n",
    "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, label_pad_token_id=-100, pad_to_multiple_of=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_args = Seq2SeqTrainingArguments(\n",
    "    output_dir=OUTPUT_DIR,\n",
    "    per_device_train_batch_size=8,\n",
    "    per_device_eval_batch_size=16,\n",
    "    predict_with_generate=True,\n",
    "    learning_rate=5e-5,\n",
    "    num_train_epochs=3,\n",
    "    # NOTE(review): transformers 4.41 renamed this to `eval_strategy` and later releases drop the old name;\n",
    "    # switch to `eval_strategy` if this raises a TypeError in your environment.\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    # With load_best_model_at_end the best checkpoint is kept in addition to the newest one.\n",
    "    save_total_limit=1,\n",
    "    load_best_model_at_end=True,\n",
    "    seed=SEED,\n",
    ")\n",
    "\n",
    "trainer = Seq2SeqTrainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=tokenized_dataset[\"train\"],\n",
    "    eval_dataset=tokenized_dataset[\"test\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Spot-check generations\n",
    "\n",
    "`generate_text` applies the same prefix used in training and allows up to `MAX_LENGTH` new tokens. Without an explicit budget, generation falls back to the model's generation config, whose `max_length` may be the library default of 20 tokens and would cut longer sentences short."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nlg = pipeline(\"summarization\", model=model, tokenizer=tokenizer)\n",
    "\n",
    "\n",
    "def generate_text(mr: str) -> str:\n",
    "    \"\"\"Generate text for a single meaning representation such as `dish[tatar], price[50]`.\"\"\"\n",
    "    outputs = nlg(f\"{TASK_PREFIX}{mr}\", max_new_tokens=MAX_LENGTH)\n",
    "    return outputs[0][\"summary_text\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "example_mrs = [\n",
    "    \"dish[tatar], price[50], ingredient[wolowina]\",\n",
    "    \"payment_methods[gotowka], price[150], addresses[ulica Dluga 5]\",\n",
    "    \"dish[tiramisu], ingredient[mleko], allergy[laktoza]\",\n",
    "    \"time[dziesiata]\",\n",
    "    \"dish[spaghetti], ingredient[ser]\",\n",
    "    \"dish[pierogi], ingredient[kozi ser]\",\n",
    "    \"time[23:00], adres[ul Krótka 256]\",\n",
    "]\n",
    "{mr: generate_text(mr) for mr in example_mrs}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save\n",
    "\n",
    "The tokenizer is saved next to the weights so the directory can be reloaded with `from_pretrained` on its own."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_pretrained(OUTPUT_DIR)\n",
    "tokenizer.save_pretrained(OUTPUT_DIR)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jarvis",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}