JARVIS/nlg_train.ipynb

231 lines
5.5 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import (\n",
" AutoModelForSeq2SeqLM,\n",
" AutoTokenizer,\n",
" DataCollatorForSeq2Seq,\n",
" Seq2SeqTrainer,\n",
" Seq2SeqTrainingArguments,\n",
" pipeline,\n",
")\n",
"\n",
"from datasets import load_dataset\n",
"\n",
"model_name = \"google/umt5-small\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hold out 10% of the CSV rows for evaluation. The fixed seed keeps the\n",
"# train/test split identical across kernel restarts, so evaluation results\n",
"# from different runs stay comparable.\n",
"dataset = load_dataset(\n",
"    'csv',\n",
"    data_files='/kaggle/input/ngl-data/nlg_data.csv',\n",
"    split='train',\n",
").train_test_split(test_size=0.1, seed=42)\n",
"dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
"\n",
"def tokenize_samples(samples):\n",
"    \"\"\"Tokenize a batch of inputs (\"mr\") and target texts (\"ref\").\n",
"\n",
"    Args:\n",
"        samples: Batch mapping with an \"mr\" column (model inputs) and a\n",
"            \"ref\" column (target texts), as supplied by ``dataset.map``\n",
"            when called with ``batched=True``.\n",
"\n",
"    Returns:\n",
"        The tokenized inputs with an added \"labels\" entry holding the\n",
"        token ids of the target texts.\n",
"    \"\"\"\n",
"    # The same task prefix is used by the inference prompts in this notebook.\n",
"    inputs = [f\"generate text: {mr}\" for mr in samples[\"mr\"]]\n",
"\n",
"    # No padding here: the DataCollatorForSeq2Seq used by the trainer pads\n",
"    # each batch to its own longest sequence and fills label padding with\n",
"    # -100, so padding every example to 128 tokens up front only adds\n",
"    # wasted computation. Truncation to 128 tokens is kept.\n",
"    tokenized_inputs = tokenizer(inputs, max_length=128, truncation=True)\n",
"    labels = tokenizer(text_target=samples[\"ref\"], max_length=128, truncation=True)\n",
"\n",
"    tokenized_inputs[\"labels\"] = labels[\"input_ids\"]\n",
"    return tokenized_inputs\n",
"\n",
"\n",
"# Tokenize both splits in batches and drop the raw \"mr\"/\"ref\" text columns.\n",
"tokenized_dataset = dataset.map(tokenize_samples, batched=True, remove_columns=[\"mr\", \"ref\"])\n",
"\n",
"tokenized_dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the pretrained seq2seq checkpoint that will be fine-tuned.\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
"\n",
"# Batch collator: label padding uses -100 and padded lengths are rounded up\n",
"# to a multiple of 8.\n",
"data_collator = DataCollatorForSeq2Seq(\n",
"    tokenizer,\n",
"    model=model,\n",
"    label_pad_token_id=-100,\n",
"    pad_to_multiple_of=8,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"training_args = Seq2SeqTrainingArguments(\n",
"    # Directory that receives the checkpoints.\n",
"    output_dir=\"/kaggle/working\",\n",
"    # Optimisation settings.\n",
"    num_train_epochs=3,\n",
"    learning_rate=5e-5,\n",
"    per_device_train_batch_size=8,\n",
"    per_device_eval_batch_size=16,\n",
"    # Evaluate and checkpoint once per epoch, limit retained checkpoints\n",
"    # (save_total_limit=1), and reload the best model when training ends.\n",
"    evaluation_strategy=\"epoch\",\n",
"    save_strategy=\"epoch\",\n",
"    save_total_limit=1,\n",
"    load_best_model_at_end=True,\n",
"    predict_with_generate=True,\n",
")\n",
"\n",
"# Trainer wired to the tokenized train/test splits and the seq2seq collator.\n",
"trainer = Seq2SeqTrainer(\n",
"    args=training_args,\n",
"    model=model,\n",
"    train_dataset=tokenized_dataset[\"train\"],\n",
"    eval_dataset=tokenized_dataset[\"test\"],\n",
"    data_collator=data_collator,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Wrap the fine-tuned model and tokenizer in an inference pipeline; the probe\n",
"# cells read the generated text from the 'summary_text' key of its results.\n",
"# NOTE(review): the task here is MR-to-text generation rather than\n",
"# summarization; consider whether another pipeline task fits better - TODO confirm.\n",
"nlg = pipeline(\n",
"    'summarization',\n",
"    model=model,\n",
"    tokenizer=tokenizer,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Probe: dish, price and ingredient slots.\n",
"nlg('generate text: dish[tatar], price[50], ingredient[wolowina]')[0]['summary_text']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Probe: payment method, price and address slots.\n",
"nlg('generate text: payment_methods[gotowka], price[150], addresses[ulica Dluga 5]')[0]['summary_text']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Probe: dish, ingredient and allergy slots.\n",
"nlg('generate text: dish[tiramisu], ingredient[mleko], allergy[laktoza]')[0]['summary_text']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Probe: a single time slot.\n",
"nlg('generate text: time[dziesiata]')[0]['summary_text']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Probe: dish and ingredient slots.\n",
"nlg('generate text: dish[spaghetti], ingredient[ser]')[0]['summary_text']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Probe: dish with a multi-word ingredient value.\n",
"nlg('generate text: dish[pierogi], ingredient[kozi ser]')[0]['summary_text']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Probe: time and address slots.\n",
"# NOTE(review): the slot key is `adres` here but `addresses` in the payment\n",
"# probe - confirm which spelling matches the training data.\n",
"nlg('generate text: time[23:00], adres[ul Krótka 256]')[0]['summary_text']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Persist the fine-tuned weights together with the tokenizer files so the\n",
"# output directory can be reloaded on its own with `from_pretrained`\n",
"# (the trainer was not given the tokenizer, so it is saved explicitly).\n",
"model.save_pretrained(\"/kaggle/working\")\n",
"tokenizer.save_pretrained(\"/kaggle/working\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "jarvis",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 2
}