{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# T5 model classification training" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Links:\n", "- Tensorboard training: https://tensorboard.dev/experiment/CgaWd9pATZeuquRT7TZp7w/#scalars\n", "- Huggingface dataset edited: https://huggingface.co/datasets/Zombely/sst2-project-dataset\n", "- Huggingface Trained model: https://huggingface.co/Zombely/t5-model\n", "- Huggingface Tokenizer: https://huggingface.co/Zombely/t5-tokenizer" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "moS_Sk1kPztc", "outputId": "de22adf0-80a7-47f8-9df4-b6ac3732ded6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Collecting datasets\n", " Downloading datasets-2.9.0-py3-none-any.whl (462 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m462.8/462.8 KB\u001b[0m \u001b[31m8.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting transformers\n", " Downloading transformers-4.26.1-py3-none-any.whl (6.3 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.3/6.3 MB\u001b[0m \u001b[31m66.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting sentencepiece\n", " Downloading sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m49.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from datasets) (1.21.6)\n", "Collecting huggingface-hub<1.0.0,>=0.2.0\n", " Downloading huggingface_hub-0.12.0-py3-none-any.whl (190 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.3/190.3 KB\u001b[0m \u001b[31m14.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (2.25.1)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.8/dist-packages (from datasets) (1.3.5)\n", "Requirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (9.0.0)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.8/dist-packages (from datasets) (23.0)\n", "Collecting responses<0.19\n", " Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n", "Collecting xxhash\n", " Downloading xxhash-3.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (213 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m213.0/213.0 KB\u001b[0m \u001b[31m10.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: dill<0.3.7 in /usr/local/lib/python3.8/dist-packages (from datasets) (0.3.6)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.8/dist-packages (from datasets) (3.8.3)\n", "Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (2023.1.0)\n", "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (4.64.1)\n", "Requirement already satisfied: pyyaml>=5.1 in 
 { "cell_type": "code", "execution_count": 2, "metadata": { "id": "Itn0ce_3P-Cv" }, "outputs": [], "source": [ "from datasets import load_dataset\n", "import torch\n", "from transformers import T5ForConditionalGeneration, T5Tokenizer, TrainingArguments, Trainer\n", "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n", "import tensorflow as tf\n", "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset\n", "import random\n", "import time\n", "import numpy as np\n", "import datetime\n", "import sklearn\n", "from tqdm.notebook import tqdm\n", "import os" ] },
 { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Load data and transform dataset\n", "\n", "The official SST-2 test split is unlabeled, so the labeled validation split is reused as the test set, and 1,600 examples are held out of the train split as a new validation set." ] },
 { "cell_type": "code", "execution_count": 3, "metadata": { "id": "GQmeuSdNQkB7" }, "outputs": [], "source": [ "def load_and_process_dataset():\n", "    dataset = load_dataset(\"sst2\")\n", "    # remove_columns returns a new DatasetDict, so the result must be reassigned\n", "    dataset = dataset.remove_columns('idx')\n", "    # replace the unlabeled test split with the labeled validation split\n", "    del dataset['test']\n", "    dataset['test'] = dataset['validation']\n", "    del dataset['validation']\n", "    # hold out 1600 train examples as the new validation set\n", "    split_dataset = dataset['train'].train_test_split(test_size=1600)\n", "    dataset['train'] = split_dataset['train']\n", "    dataset['validation'] = split_dataset['test']\n", "    return dataset" ] },
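 { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "T5 is a text-to-text model: instead of predicting a class index, it generates the class label as text. The sketch below shows one way an SST-2 example can be encoded for T5; the `sst2 sentence:` task prefix and the `negative`/`positive` label words are assumptions for illustration, since this section does not show the notebook's exact preprocessing." ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative sketch, not the notebook's original preprocessing.\n", "# Assumptions: the \"sst2 sentence: \" task prefix and the\n", "# \"negative\"/\"positive\" label verbalization.\n", "LABEL_TEXT = {0: \"negative\", 1: \"positive\"}\n", "\n", "def encode_example(example, tokenizer, max_length=128):\n", "    # encode the input sentence with the task prefix\n", "    inputs = tokenizer(\n", "        \"sst2 sentence: \" + example[\"sentence\"],\n", "        max_length=max_length, padding=\"max_length\",\n", "        truncation=True, return_tensors=\"pt\")\n", "    # encode the label word as the target token sequence\n", "    targets = tokenizer(\n", "        LABEL_TEXT[example[\"label\"]],\n", "        max_length=4, padding=\"max_length\",\n", "        truncation=True, return_tensors=\"pt\")\n", "    return inputs.input_ids, inputs.attention_mask, targets.input_ids" ] },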
 { "cell_type": "code", "execution_count": 29, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 329, "referenced_widgets": [ "1f0d8bc4c93048baa73f3a1cf00eae63", "c026ef4af9c145b9a8315e60e01c8db3", "17ccc44770fc407f9e2619b4080eced1", "502b22b7634a4862b3f083fe5dc5efaf", "9e4d054b2f5742d287837daecdf3c1d3", "c8fd183dd6694c5788b50922a326e3fa", "6857e7133d024d5f9e0a54528a06f164", "e6a081c879df438da42c384e665e606d", "b36171be00c54e718cc546f2f8b28250", "8d8167938ae14287b214f177b0a3d5be", "dd9e3eb27fc643d9ad20c35491afa209" ] }, "id": "BWonEzhAQmnF", "outputId": "3bfe3731-974e-4a77-d7b1-5f3bfb806a4b" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.builder:Found cached dataset sst2 (/root/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1f0d8bc4c93048baa73f3a1cf00eae63", "version_major": 2, "version_minor": 0 }, "text/plain": [ "  0%|          | 0/3 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dataset = load_and_process_dataset()" ] },
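 { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The cells below reference a `t5model` object exposing `.model` and `.tokenizer`, and upload TensorBoard logs from a `logs` directory, but the training cells themselves are not part of this section. The following is a minimal sketch of what such a wrapper and a single fine-tuning step could look like; the class name, model size, device handling, and optimizer are assumptions, not the notebook's exact training code." ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Minimal sketch of a wrapper matching the t5model.model /\n", "# t5model.tokenizer accesses below; all names and settings here\n", "# are illustrative assumptions.\n", "class T5SST2Model:\n", "    def __init__(self, model_name=\"t5-base\", device=\"cuda\"):\n", "        self.tokenizer = T5Tokenizer.from_pretrained(model_name)\n", "        self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)\n", "        self.device = device\n", "\n", "    def training_step(self, input_ids, attention_mask, labels, optimizer):\n", "        # T5 computes the cross-entropy loss internally when `labels`\n", "        # (the tokenized label words) are provided.\n", "        labels = labels.clone()\n", "        labels[labels == self.tokenizer.pad_token_id] = -100  # ignore padding in the loss\n", "        loss = self.model(input_ids=input_ids.to(self.device),\n", "                          attention_mask=attention_mask.to(self.device),\n", "                          labels=labels.to(self.device)).loss\n", "        loss.backward()\n", "        optimizer.step()\n", "        optimizer.zero_grad()\n", "        return loss.item()\n", "\n", "# Example wiring (illustrative):\n", "# t5model = T5SST2Model()\n", "# optimizer = torch.optim.AdamW(t5model.model.parameters(), lr=1e-4)" ] },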
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Upload started and will continue reading any new data as it's added to the logdir.\n", "\n", "To stop uploading, press Ctrl-C.\n", "\n", "New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/CgaWd9pATZeuquRT7TZp7w/\n", "\n", "[2023-02-13T11:50:01] Started scanning logdir.\n", "[2023-02-13T11:50:04] Total uploaded: 12630 scalars, 0 tensors, 0 binary objects\n", "\n", "Interrupted. View your TensorBoard at https://tensorboard.dev/experiment/CgaWd9pATZeuquRT7TZp7w/\n" ] } ], "source": [ "!tensorboard dev upload --logdir logs --name t5-sst2" ] },
 { "cell_type": "code", "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "CHSu41ZvnaBB", "outputId": "2948e2b2-54a1-43aa-9d29-27a7b40538ed" }, "outputs": [ { "data": { "text/plain": [ "('./model/tokenizer_config.json',\n", " './model/special_tokens_map.json',\n", " " './model/spiece.model',\n", " './model/added_tokens.json')" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t5model.model.save_pretrained(\"./model\")\n", "t5model.tokenizer.save_pretrained(\"./model\")\n" ] },
 { "cell_type": "code", "execution_count": 23, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Xja8cK6yoHcM", "outputId": "df987295-d175-4d17-928e-d9b03207169c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "  adding: model/ (stored 0%)\n", "  adding: model/tokenizer_config.json (deflated 82%)\n", "  adding: model/config.json (deflated 62%)\n", "  adding: model/generation_config.json (deflated 29%)\n", "  adding: model/pytorch_model.bin (deflated 8%)\n", "  adding: model/special_tokens_map.json (deflated 86%)\n", "  adding: model/spiece.model (deflated 48%)\n" ] } ], "source": [ "!zip -r /content/model model" ] },
 { "cell_type": "code", "execution_count": 27, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "16jB6CpRoslW", "outputId": "15561efe-747b-4f21-aefc-84fb913c3037" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .\n", "Token: \n", "Add token as git credential? (Y/n) y\n", "Token is valid.\n", "Cannot authenticate through git-credential as no helper is defined on your machine.\n", "You might have to re-authenticate when pushing to the Hugging Face Hub.\n", "Token has not been saved to git credential helper.\n", "Your token has been saved to /root/.cache/huggingface/token\n", "Login successful\n" ] } ], "source": [ "!huggingface-cli login" ] },
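 { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "With the model saved and the CLI authenticated, the artifacts can be pushed to the Hub. A minimal sketch of such a push, assuming the repository ids from the Links section (the widget in the next cell shows the resulting LFS upload):" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative sketch: push the fine-tuned model and tokenizer to the\n", "# Hugging Face Hub; the repo ids are taken from the Links section.\n", "t5model.model.push_to_hub(\"Zombely/t5-model\")\n", "t5model.tokenizer.push_to_hub(\"Zombely/t5-tokenizer\")" ] },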
 { "cell_type": "code", "execution_count": 33, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 133, "referenced_widgets": [ "7e8cfa36fed049e2a7b7d8666bd57cdf", "3c3d6d206d4844b199743baffb14d6c8", "7e4c949d6ed246359c2e683f7697886d", "52a2d3aa6abc4ac0a74ec9a2fa400f2c", "2b964660bb9f41f599c6282547e23bc2", "a0a534ea53ba41c2bab6ca1e99fee542", "5170299b7a664c34b5aad7ea5d0f85c0", "27cd138062d94582b0323d477c2df364", "9d1f7fa6d41e46ba8a3a4d330d73d05b", "992bd52b3e474ce297dfa09fb0007a01", "6a17c291ee0547f2a31e430db846397c", "5a2c81a908b94bb8b2a4422effbce25c", "c22999770c184330a75f5efa68e61cab", "2f9ccc2d649f4b45a30f3851beff2629", "b6f2f601496c4018ba14e5615e1e300f", "1e9fb6e057014cc78f64006ef96d9cb1", "0f244fbff6de4511883e953c494f0a22", "5266c0225cc246eab466d882a47e60dd", "fda5959a3987462b8e75ec01c0eb2fda", "256c1b04cfd94ec984e6a844c60e881c", "500de4e8a3c5444aa8395a6bb3fa90f3", "25b3a226518f44f983cd12cdbde7e434" ] }, "id": "DptjUZgypVN9", "outputId": "339f2bc3-0ccc-4e72-9bb9-f90d056a6044" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7e8cfa36fed049e2a7b7d8666bd57cdf", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Upload 1 LFS files:   0%|          | 0/1 [00:00