{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"gpuType":"A100","machine_shape":"hm"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"code","source":["import locale\n","locale.getpreferredencoding = lambda: \"UTF-8\""],"metadata":{"id":"0cKJSrCDIC5c","executionInfo":{"status":"ok","timestamp":1687337064661,"user_tz":-120,"elapsed":5,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"execution_count":1,"outputs":[]},{"cell_type":"code","source":["!pip install transformers torch accelerate"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"TVWZUBhyPfpa","outputId":"21e76e7e-6d88-41f1-8367-ec2f5862bfd0","executionInfo":{"status":"ok","timestamp":1687337068501,"user_tz":-120,"elapsed":3844,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"execution_count":2,"outputs":[{"output_type":"stream","name":"stdout","text":["Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.30.2)\n","Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.0.1+cu118)\n","Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (0.20.3)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.0)\n","Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.15.1)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.22.4)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) 
(6.0)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2022.10.31)\n","Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.27.1)\n","Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.13.3)\n","Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.3.1)\n","Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.65.0)\n","Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch) (4.5.0)\n","Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.11.1)\n","Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.1)\n","Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.2)\n","Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.0.0)\n","Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (3.25.2)\n","Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (16.0.5)\n","Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n","Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.4.0)\n","Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.2)\n","Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (1.26.15)\n","Requirement already satisfied: certifi>=2017.4.17 in 
/usr/local/lib/python3.10/dist-packages (from requests->transformers) (2022.12.7)\n","Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.12)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n","Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n"]}]},{"cell_type":"code","source":["import pandas as pd"],"metadata":{"id":"2NPC0SFrzVQS","executionInfo":{"status":"ok","timestamp":1687337068860,"user_tz":-120,"elapsed":365,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"execution_count":3,"outputs":[]},{"cell_type":"code","execution_count":4,"metadata":{"id":"LdRQU2xnOrst","executionInfo":{"status":"ok","timestamp":1687337077900,"user_tz":-120,"elapsed":9042,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"outputs":[],"source":["from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed\n","\n","model = AutoModelForCausalLM.from_pretrained('flax-community/papuGaPT2')\n","tokenizer = AutoTokenizer.from_pretrained('flax-community/papuGaPT2')\n","\n","# model = AutoModelForCausalLM.from_pretrained('sdadas/polish-gpt2-medium')\n","# tokenizer = AutoTokenizer.from_pretrained('sdadas/polish-gpt2-medium')\n","\n","tokenizer.pad_token = tokenizer.eos_token"]},{"cell_type":"markdown","source":["# Wczytanie danych do finetuningu\n","Dane stworzyliśmy ręcznie oraz za pomocą ChatGPT."],"metadata":{"id":"IY2e11OjS54T"}},{"cell_type":"code","source":["from google.colab import drive\n","\n","drive.mount('/content/gdrive/', force_remount=True)\n","working_dir = '/content/gdrive/My 
Drive/empatia/'"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"pSSQJy4zTDDr","outputId":"b8b98736-cc1a-4df5-9912-5ba2c5727749","executionInfo":{"status":"ok","timestamp":1687337080648,"user_tz":-120,"elapsed":2761,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"execution_count":5,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/gdrive/\n"]}]},{"cell_type":"code","source":["dialogs_df = pd.read_csv(working_dir + 'data/dialogs.csv')\n","dialogs2_df = pd.read_csv(working_dir + 'data/dialogs2.csv')\n","\n","dialogs_df = pd.concat([dialogs_df, dialogs2_df])\n","\n","texts = 'question: ' + dialogs_df['question'] + \"\\nanswer: \" + dialogs_df['answer']\n","texts = texts.tolist()\n","\n","print(texts[10])"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"tD7U4Qa5UhEf","outputId":"5b15a372-fb35-47a7-c005-588924925204","executionInfo":{"status":"ok","timestamp":1687337080649,"user_tz":-120,"elapsed":18,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"execution_count":6,"outputs":[{"output_type":"stream","name":"stdout","text":["question: powodzenia w szkole.\n","answer: Dziękuję bardzo.\n"]}]},{"cell_type":"code","source":["dialogs_df.sample(5)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":206},"id":"xRYVKkMw0EQd","executionInfo":{"status":"ok","timestamp":1687337080650,"user_tz":-120,"elapsed":14,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}},"outputId":"b53239ec-cf14-4da1-b9c2-ba910d3f5bc7"},"execution_count":7,"outputs":[{"output_type":"execute_result","data":{"text/plain":[" question \\\n","405 Szkoda, że ​​nie mogę pracować mniej. Czuję si... \n","548 Tak, to było o wiele prostsze. Cieszyliśmy się... \n","564 Dowiedziałem się więc czegoś, co bardzo mnie z... \n","142 moja wina, miałem obowiązki do zrobienia. \n","384 brzmi jakby to była bliska gra. 
\n","\n"," answer \n","405 Próbowałem tego, czego naprawdę potrzebuję, je... \n","548 życie było proste wtedy nie było! bardzo ładny. \n","564 Moje dziecko wyszło za moimi plecami i wymknęł... \n","142 w porządku. \n","384 dlatego była to tak świetna gra. "],"text/html":["\n","
\n","
\n","
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
questionanswer
405Szkoda, że ​​nie mogę pracować mniej. Czuję si...Próbowałem tego, czego naprawdę potrzebuję, je...
548Tak, to było o wiele prostsze. Cieszyliśmy się...życie było proste wtedy nie było! bardzo ładny.
564Dowiedziałem się więc czegoś, co bardzo mnie z...Moje dziecko wyszło za moimi plecami i wymknęł...
142moja wina, miałem obowiązki do zrobienia.w porządku.
384brzmi jakby to była bliska gra.dlatego była to tak świetna gra.
\n","
\n"," \n"," \n"," \n","\n"," \n","
\n","
\n"," "]},"metadata":{},"execution_count":7}]},{"cell_type":"markdown","source":["# Preprocessing"],"metadata":{"id":"CQw_oCFyUnY_"}},{"cell_type":"code","source":["from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler\n","import torch\n","\n","# Create custom dataset\n","class KolegaDataset(Dataset):\n"," def __init__(self, txt_list, tokenizer):\n"," self.tokenizer = tokenizer\n"," self.input_ids = []\n"," self.attn_masks = []\n","\n"," for txt in txt_list:\n"," encodings_dict = tokenizer(txt, padding=\"max_length\", truncation=True, max_length=512)\n"," self.input_ids.append(torch.tensor(encodings_dict['input_ids']))\n"," self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))\n","\n"," def __len__(self):\n"," return len(self.input_ids)\n","\n"," def __getitem__(self, idx):\n"," return self.input_ids[idx], self.attn_masks[idx]"],"metadata":{"id":"_AYrfmfGXMEV","executionInfo":{"status":"ok","timestamp":1687337080650,"user_tz":-120,"elapsed":11,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"execution_count":8,"outputs":[]},{"cell_type":"code","source":["dataset = KolegaDataset(texts, tokenizer)\n","\n","train_size = int(0.9 * len(dataset))\n","val_size = len(dataset) - train_size\n","\n","train_dataset, val_dataset = random_split(dataset, [train_size, val_size])\n","\n","print('Train dataset size: ', train_size)\n","print('Validation dataset size: ', val_size)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"yQp1opRYXPAv","outputId":"65861636-2f84-4fc7-ba69-39972c0532fd","executionInfo":{"status":"ok","timestamp":1687337081299,"user_tz":-120,"elapsed":659,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"execution_count":9,"outputs":[{"output_type":"stream","name":"stdout","text":["Train dataset size: 1349\n","Validation dataset size: 150\n"]}]},{"cell_type":"code","source":["batch_size = 8\n","\n","train_dataloader = 
DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=batch_size)\n","validation_dataloader = DataLoader(val_dataset, sampler=SequentialSampler(val_dataset), batch_size=batch_size)"],"metadata":{"id":"4LDKgbSAcPo8","executionInfo":{"status":"ok","timestamp":1687337081299,"user_tz":-120,"elapsed":8,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"execution_count":10,"outputs":[]},{"cell_type":"markdown","source":["# Fine-tuning"],"metadata":{"id":"a5NTJK7HVjYD"}},{"cell_type":"code","source":["# some parameters I cooked up that work reasonably well\n","\n","epochs = 20\n","learning_rate = 0.0005\n","warmup_steps = 1e2\n","epsilon = 1e-8"],"metadata":{"id":"TnPudHlZVmaA","executionInfo":{"status":"ok","timestamp":1687337081300,"user_tz":-120,"elapsed":9,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"execution_count":11,"outputs":[]},{"cell_type":"code","source":["# transformers.AdamW is deprecated (see the FutureWarning below); use the\n","# PyTorch implementation, which takes the same lr/eps keyword arguments.\n","from torch.optim import AdamW\n","from transformers import get_linear_schedule_with_warmup\n","\n","optimizer = AdamW(model.parameters(), lr = learning_rate, eps = epsilon)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ZPic7oqNdGcH","outputId":"0926637e-4be2-49e1-8e10-5c31d45f7b12","executionInfo":{"status":"ok","timestamp":1687337081300,"user_tz":-120,"elapsed":8,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"execution_count":12,"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. 
Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n"," warnings.warn(\n"]}]},{"cell_type":"code","source":["total_steps = len(train_dataloader) * epochs\n","\n","scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = warmup_steps, num_training_steps = total_steps)"],"metadata":{"id":"u-zq78GveBbk","executionInfo":{"status":"ok","timestamp":1687337081301,"user_tz":-120,"elapsed":6,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"execution_count":13,"outputs":[]},{"cell_type":"code","source":["import datetime\n","import time\n","import random\n","\n","def format_time(elapsed):\n"," return str(datetime.timedelta(seconds=int(round((elapsed)))))\n","\n","device = torch.device(\"cuda\")\n","model.cuda()"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"52TDlVRadJCq","outputId":"5851ed70-4237-4c1b-faf6-3f45dc1bb9d4","executionInfo":{"status":"ok","timestamp":1687337081689,"user_tz":-120,"elapsed":394,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"execution_count":14,"outputs":[{"output_type":"execute_result","data":{"text/plain":["GPT2LMHeadModel(\n"," (transformer): GPT2Model(\n"," (wte): Embedding(51200, 768)\n"," (wpe): Embedding(2048, 768)\n"," (drop): Dropout(p=0.1, inplace=False)\n"," (h): ModuleList(\n"," (0-11): 12 x GPT2Block(\n"," (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n"," (attn): GPT2Attention(\n"," (c_attn): Conv1D()\n"," (c_proj): Conv1D()\n"," (attn_dropout): Dropout(p=0.1, inplace=False)\n"," (resid_dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n"," (mlp): GPT2MLP(\n"," (c_fc): Conv1D()\n"," (c_proj): Conv1D()\n"," (act): FastGELUActivation()\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," )\n"," )\n"," (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n"," )\n"," (lm_head): 
Linear(in_features=768, out_features=51200, bias=False)\n",")"]},"metadata":{},"execution_count":14}]},{"cell_type":"code","source":["total_t0 = time.time()\n","\n","training_stats = []\n","\n","model = model.to(device)\n","\n","for epoch_i in range(0, epochs):\n","\n"," # ========================================\n"," # Training\n"," # ========================================\n","\n"," print(\"\")\n"," print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))\n"," print('Training...')\n","\n"," t0 = time.time()\n","\n"," total_train_loss = 0\n","\n"," model.train()\n","\n"," for step, batch in enumerate(train_dataloader):\n","\n"," b_input_ids = batch[0].to(device)\n"," b_labels = batch[0].to(device)\n"," b_masks = batch[1].to(device)\n","\n"," model.zero_grad()\n","\n"," outputs = model( b_input_ids,\n"," labels=b_labels,\n"," attention_mask = b_masks,\n"," token_type_ids=None\n"," )\n","\n"," loss = outputs[0]\n","\n"," batch_loss = loss.item()\n"," total_train_loss += batch_loss\n","\n"," loss.backward()\n","\n"," optimizer.step()\n","\n"," scheduler.step()\n","\n"," # Calculate the average loss over all of the batches.\n"," avg_train_loss = total_train_loss / len(train_dataloader)\n","\n"," # Measure how long this epoch took.\n"," training_time = format_time(time.time() - t0)\n","\n"," print(\"\")\n"," print(\" Average training loss: {0:.2f}\".format(avg_train_loss))\n"," print(\" Training epoch took: {:}\".format(training_time))\n","\n"," # ========================================\n"," # Validation\n"," # ========================================\n","\n"," print(\"\")\n"," print(\"Running Validation...\")\n","\n"," t0 = time.time()\n","\n"," model.eval()\n","\n"," total_eval_loss = 0\n"," nb_eval_steps = 0\n","\n"," # Evaluate data for one epoch\n"," for batch in validation_dataloader:\n","\n"," b_input_ids = batch[0].to(device)\n"," b_labels = batch[0].to(device)\n"," b_masks = batch[1].to(device)\n","\n"," with torch.no_grad():\n","\n"," outputs 
= model(b_input_ids,\n","# token_type_ids=None,\n"," attention_mask = b_masks,\n"," labels=b_labels)\n","\n"," loss = outputs[0]\n","\n"," batch_loss = loss.item()\n"," total_eval_loss += batch_loss\n","\n"," avg_val_loss = total_eval_loss / len(validation_dataloader)\n","\n"," validation_time = format_time(time.time() - t0)\n","\n"," print(\" Validation Loss: {0:.2f}\".format(avg_val_loss))\n"," print(\" Validation took: {:}\".format(validation_time))\n","\n"," # Record all statistics from this epoch.\n"," training_stats.append(\n"," {\n"," 'epoch': epoch_i + 1,\n"," 'Training Loss': avg_train_loss,\n"," 'Valid. Loss': avg_val_loss,\n"," 'Training Time': training_time,\n"," 'Validation Time': validation_time\n"," }\n"," )\n","\n","print(\"\")\n","print(\"Training complete!\")\n","print(\"Total training took {:} (h:mm:ss)\".format(format_time(time.time()-total_t0)))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"pPNGSJoadS9V","outputId":"d8ec8711-16c1-436c-d37a-d761e374e5a2","executionInfo":{"status":"ok","timestamp":1687338041146,"user_tz":-120,"elapsed":959462,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"execution_count":15,"outputs":[{"output_type":"stream","name":"stdout","text":["\n","======== Epoch 1 / 20 ========\n","Training...\n","\n"," Average training loss: 0.44\n"," Training epoch took: 0:00:46\n","\n","Running Validation...\n"," Validation Loss: 0.15\n"," Validation took: 0:00:02\n","\n","======== Epoch 2 / 20 ========\n","Training...\n","\n"," Average training loss: 0.12\n"," Training epoch took: 0:00:46\n","\n","Running Validation...\n"," Validation Loss: 0.11\n"," Validation took: 0:00:02\n","\n","======== Epoch 3 / 20 ========\n","Training...\n","\n"," Average training loss: 0.08\n"," Training epoch took: 0:00:46\n","\n","Running Validation...\n"," Validation Loss: 0.09\n"," Validation took: 0:00:02\n","\n","======== Epoch 4 / 20 ========\n","Training...\n","\n"," Average training loss: 0.05\n"," 
Training epoch took: 0:00:46\n","\n","Running Validation...\n"," Validation Loss: 0.08\n"," Validation took: 0:00:02\n","\n","======== Epoch 5 / 20 ========\n","Training...\n","\n"," Average training loss: 0.04\n"," Training epoch took: 0:00:46\n","\n","Running Validation...\n"," Validation Loss: 0.08\n"," Validation took: 0:00:02\n","\n","======== Epoch 6 / 20 ========\n","Training...\n","\n"," Average training loss: 0.03\n"," Training epoch took: 0:00:46\n","\n","Running Validation...\n"," Validation Loss: 0.08\n"," Validation took: 0:00:02\n","\n","======== Epoch 7 / 20 ========\n","Training...\n","\n"," Average training loss: 0.03\n"," Training epoch took: 0:00:46\n","\n","Running Validation...\n"," Validation Loss: 0.08\n"," Validation took: 0:00:02\n","\n","======== Epoch 8 / 20 ========\n","Training...\n","\n"," Average training loss: 0.02\n"," Training epoch took: 0:00:46\n","\n","Running Validation...\n"," Validation Loss: 0.08\n"," Validation took: 0:00:02\n","\n","======== Epoch 9 / 20 ========\n","Training...\n","\n"," Average training loss: 0.02\n"," Training epoch took: 0:00:46\n","\n","Running Validation...\n"," Validation Loss: 0.07\n"," Validation took: 0:00:02\n","\n","======== Epoch 10 / 20 ========\n","Training...\n","\n"," Average training loss: 0.02\n"," Training epoch took: 0:00:46\n","\n","Running Validation...\n"," Validation Loss: 0.07\n"," Validation took: 0:00:02\n","\n","======== Epoch 11 / 20 ========\n","Training...\n","\n"," Average training loss: 0.02\n"," Training epoch took: 0:00:46\n","\n","Running Validation...\n"," Validation Loss: 0.07\n"," Validation took: 0:00:02\n","\n","======== Epoch 12 / 20 ========\n","Training...\n","\n"," Average training loss: 0.02\n"," Training epoch took: 0:00:46\n","\n","Running Validation...\n"," Validation Loss: 0.07\n"," Validation took: 0:00:02\n","\n","======== Epoch 13 / 20 ========\n","Training...\n","\n"," Average training loss: 0.02\n"," Training epoch took: 0:00:46\n","\n","Running 
Validation...\n"," Validation Loss: 0.07\n"," Validation took: 0:00:02\n","\n","======== Epoch 14 / 20 ========\n","Training...\n","\n"," Average training loss: 0.02\n"," Training epoch took: 0:00:46\n","\n","Running Validation...\n"," Validation Loss: 0.07\n"," Validation took: 0:00:02\n","\n","======== Epoch 15 / 20 ========\n","Training...\n","\n"," Average training loss: 0.02\n"," Training epoch took: 0:00:46\n","\n","Running Validation...\n"," Validation Loss: 0.07\n"," Validation took: 0:00:02\n","\n","======== Epoch 16 / 20 ========\n","Training...\n","\n"," Average training loss: 0.02\n"," Training epoch took: 0:00:46\n","\n","Running Validation...\n"," Validation Loss: 0.07\n"," Validation took: 0:00:02\n","\n","======== Epoch 17 / 20 ========\n","Training...\n","\n"," Average training loss: 0.02\n"," Training epoch took: 0:00:46\n","\n","Running Validation...\n"," Validation Loss: 0.07\n"," Validation took: 0:00:02\n","\n","======== Epoch 18 / 20 ========\n","Training...\n","\n"," Average training loss: 0.02\n"," Training epoch took: 0:00:46\n","\n","Running Validation...\n"," Validation Loss: 0.07\n"," Validation took: 0:00:02\n","\n","======== Epoch 19 / 20 ========\n","Training...\n","\n"," Average training loss: 0.02\n"," Training epoch took: 0:00:46\n","\n","Running Validation...\n"," Validation Loss: 0.07\n"," Validation took: 0:00:02\n","\n","======== Epoch 20 / 20 ========\n","Training...\n","\n"," Average training loss: 0.01\n"," Training epoch took: 0:00:46\n","\n","Running Validation...\n"," Validation Loss: 0.08\n"," Validation took: 0:00:02\n","\n","Training complete!\n","Total training took 0:16:00 (h:mm:ss)\n"]}]},{"cell_type":"code","source":["model.eval()\n","\n","input_text = \"question: Cześć, byłem dziś w szkole i było źle\\nanswer:\"\n","input_ids = tokenizer.encode(input_text, return_tensors='pt')\n","input_ids = input_ids.to(device)\n","\n","output = model.generate(input_ids, max_length=100, 
early_stopping=True)\n","\n","generated_text = tokenizer.decode(output[0], skip_special_tokens=True)\n","print(generated_text)\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"YUAZReU3jPwm","outputId":"236980c3-db50-4b67-aca2-fcc104409a07","executionInfo":{"status":"ok","timestamp":1687338062190,"user_tz":-120,"elapsed":227,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"execution_count":20,"outputs":[{"output_type":"stream","name":"stderr","text":["The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n","Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"]},{"output_type":"stream","name":"stdout","text":["question: Cześć, byłem dziś w szkole i było źle\n","answer: Nie, byłem za młody, ale prawie płakałem\n"]}]},{"cell_type":"code","source":["model.save_pretrained('/content/gdrive/MyDrive/empatia/model')\n","tokenizer.save_pretrained('/content/gdrive/MyDrive/empatia/model')"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"_Bs1qTYC7ZbR","executionInfo":{"status":"ok","timestamp":1687338272091,"user_tz":-120,"elapsed":1701,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}},"outputId":"857f2dbc-d04f-4656-87ff-4972ba258cf6"},"execution_count":21,"outputs":[{"output_type":"execute_result","data":{"text/plain":["('/content/gdrive/MyDrive/empatia/model/tokenizer_config.json',\n"," '/content/gdrive/MyDrive/empatia/model/special_tokens_map.json',\n"," '/content/gdrive/MyDrive/empatia/model/vocab.json',\n"," '/content/gdrive/MyDrive/empatia/model/merges.txt',\n"," '/content/gdrive/MyDrive/empatia/model/added_tokens.json',\n"," '/content/gdrive/MyDrive/empatia/model/tokenizer.json')"]},"metadata":{},"execution_count":21}]}]}