{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "ZXsOR6oJOJbd" }, "source": [ "# Instalacja pakietów" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8l0hzptKNiZS", "outputId": "00b4e80b-9d2a-42f1-e087-1412429b63bd" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Collecting transformers\n", " Downloading transformers-4.26.1-py3-none-any.whl (6.3 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.3/6.3 MB\u001b[0m \u001b[31m39.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting datasets\n", " Downloading datasets-2.9.0-py3-none-any.whl (462 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m462.8/462.8 KB\u001b[0m \u001b[31m22.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: torch in /usr/local/lib/python3.8/dist-packages (1.13.1+cu116)\n", "Collecting sentencepiece\n", " Downloading sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m28.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from transformers) (23.0)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from transformers) (1.21.6)\n", "Collecting tokenizers!=0.11.3,<0.14,>=0.11.1\n", " Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m23.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement 
already satisfied: tqdm>=4.27 in /usr/local/lib/python3.8/dist-packages (from transformers) (4.64.1)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from transformers) (3.9.0)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from transformers) (6.0)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.8/dist-packages (from transformers) (2022.6.2)\n", "Collecting huggingface-hub<1.0,>=0.11.0\n", " Downloading huggingface_hub-0.12.1-py3-none-any.whl (190 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.3/190.3 KB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from transformers) (2.25.1)\n", "Collecting responses<0.19\n", " Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n", "Collecting multiprocess\n", " Downloading multiprocess-0.70.14-py38-none-any.whl (132 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.0/132.0 KB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (9.0.0)\n", "Collecting xxhash\n", " Downloading xxhash-3.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (213 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m213.0/213.0 KB\u001b[0m \u001b[31m14.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: dill<0.3.7 in /usr/local/lib/python3.8/dist-packages (from datasets) (0.3.6)\n", "Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (2023.1.0)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.8/dist-packages (from datasets) (3.8.4)\n", 
"Requirement already satisfied: pandas in /usr/local/lib/python3.8/dist-packages (from datasets) (1.3.5)\n", "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch) (4.5.0)\n", "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (3.0.1)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.3.3)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.8.2)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (22.2.0)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (4.0.2)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (6.0.4)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (4.0.0)\n", "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (1.24.3)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (2022.12.7)\n", "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (2.10)\n", "Collecting urllib3<1.27,>=1.21.1\n", " Downloading urllib3-1.26.14-py2.py3-none-any.whl (140 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.6/140.6 KB\u001b[0m \u001b[31m8.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: 
python-dateutil>=2.7.3 in /usr/local/lib/python3.8/dist-packages (from pandas->datasets) (2.8.2)\n", "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.8/dist-packages (from pandas->datasets) (2022.7.1)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)\n", "Installing collected packages: tokenizers, sentencepiece, xxhash, urllib3, multiprocess, responses, huggingface-hub, transformers, datasets\n", " Attempting uninstall: urllib3\n", " Found existing installation: urllib3 1.24.3\n", " Uninstalling urllib3-1.24.3:\n", " Successfully uninstalled urllib3-1.24.3\n", "Successfully installed datasets-2.9.0 huggingface-hub-0.12.1 multiprocess-0.70.14 responses-0.18.0 sentencepiece-0.1.97 tokenizers-0.13.2 transformers-4.26.1 urllib3-1.26.14 xxhash-3.2.0\n" ] } ], "source": [ "!pip install transformers datasets torch sentencepiece" ] }, { "cell_type": "markdown", "metadata": { "id": "dhN0rmb5Oi3d" }, "source": [ "# Załadowanie datasetu" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "tnaDkwZ2Pbnn" }, "outputs": [], "source": [ "from datasets import load_dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 263, "referenced_widgets": [ "c396a3f65bb947ffa33130c424d9d93b", "fe2dd0bc42b84c5890d2c3dccaf66992", "245239c79fa74387bed598565bbc24a4", "926ce941393e4004bec99d38e82ea879", "65433acf8e5345d580c6bf8c949c0064", "293dc96088f942a3afc9af735d4c7117", "a7dc0ef4c814401f8b2ff982063b66cc", "790101eedd824ab893a7c7a1c0039163", "6889bbe799b54d9eade42250c2e5caa6", "5cc03bba39a74eaab7da908c6c24b1fb", "db67393d589c4d5bbbdab58adf51f970", "2c5a4622661a4465910b5f1f95bea742", "3d4560baaec44c40b5d5c27ed8eba68a", "dffa3d04bfe548e9aea5d4327ecca77a", "01ce10b9b22f48839824dda0a40ec5e8", "b839f01904ce41398b5286a801bbf4a7", "e36e4ad05a5040a4b67e0a133156358e", 
"892dbc4c003941409202f31523589835", "4a82770611cc4ba29aea5c462ce0c5be", "a3767cd0d3cd454595230f7933a0b2fe", "a77ad82dae14426b97ecc94166511b5c", "98e29231304e4d9ba32e0368ec5b3fd5", "d64108ba247a4ac5a93b3bdeede7fd9a", "aea219a8097b4d989664c09fcff9eb93", "b2721b8b11ed42e29c623a10a8c8e13f", "3d9d73ddd88446c5aeef5b2362bb878b", "b955aea22f9e4ce2882b5e722ae1dda8", "2a1ebe1ed1c64921bb7cdb0ed1e57b2b", "9980c1ddad4a475e97130c0efc2f3efe", "57feff796f4241f292cb4617ed85cfe0", "9297c35ec1f7494caab95714d65e34ab", "dd110977a96d4f19b37e923c4296fbdd", "953c5391bd5e49ad9b3752794929e08c", "b8c9fbbce7b84bf4a2f5b90c9d35ce0f", "51ab0b4ff03f4bdab8f5ba3fef868d4a", "fca8377c5d6e480180d22925109d431c", "a8854ee753ff4d179718690c3361b5a0", "72918e6a9c7e46f291f85c7adf237eb9", "fd3d449f08554666887c54c329baecd9", "8c83da8d3d54491a9f4c535bdb5611e1", "fb3f2df569ff42c89dea846f0df4b62f", "e298b598c6474c3fa2b7bada616110cc", "7838193eee55441d865d6f056454b841", "9ff537b0b2c745968c7d22e90cbe3894", "6b52b5d926bc43a3b70238c4fbfad7da", "733b55697ed8437c8d1d180756f7d3bb", "0d67db8eb2504316819ba0b4691aecb0", "6c402b861e4b49ad8b7d07c053e069bd", "0ea8f8e157634292b17babbdee9f46e0", "5016ba950fa447ae8d0f3a8191c8a34c", "8af86d8fcbbc41bd9e4bfee243bdb759", "16ddb3f4068040c9bbe5fa74c1b0177c", "f1424f2fde924cb88e05fc4a98d8e354", "afd518d5660e49b599d930c7f6040f87", "8a7c3caf8629499fbff1966e7544fc85", "bca3b484a8374617836c78a5a7247f19", "a6ad21ce2df942cc87a7009a76913bd7", "991524a5133c4ae9a1e6067be42f6ee4", "d4ed92b329c74cda9e38dc380fee1b71", "0e7dd3e98f64447dbe54df8d80fed381", "3034d7c759324b91bf84b704a023dd77", "4c3563f49ed0497ab2dc8739e4570d5b", "5788c4c9888e4def886297d645c670a3", "b713a8211c324197a5b66b05b4bfbb10", "9ab2665c20b8408490cc3fa43762b225", "b640800aca9f431883d8c309dd3b1daf" ] }, "id": "cCiAuRqrOkvV", "outputId": "f0e3ddd0-5cc7-47e2-9910-8b6b84cbd896" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Downloading builder script: 0%| | 0.00/3.21k [00:00, ?B/s]" ], 
"application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "c396a3f65bb947ffa33130c424d9d93b" } }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "Downloading metadata: 0%| | 0.00/1.69k [00:00, ?B/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "2c5a4622661a4465910b5f1f95bea742" } }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "Downloading readme: 0%| | 0.00/4.87k [00:00, ?B/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "d64108ba247a4ac5a93b3bdeede7fd9a" } }, "metadata": {} }, { "output_type": "stream", "name": "stdout", "text": [ "Downloading and preparing dataset sms_spam/plain_text to /root/.cache/huggingface/datasets/sms_spam/plain_text/1.0.0/53f051d3b5f62d99d61792c91acefe4f1577ad3e4c216fb0ad39e30b9f20019c...\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "Downloading data: 0%| | 0.00/203k [00:00, ?B/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "b8c9fbbce7b84bf4a2f5b90c9d35ce0f" } }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "Generating train split: 0%| | 0/5574 [00:00, ? examples/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "6b52b5d926bc43a3b70238c4fbfad7da" } }, "metadata": {} }, { "output_type": "stream", "name": "stdout", "text": [ "Dataset sms_spam downloaded and prepared to /root/.cache/huggingface/datasets/sms_spam/plain_text/1.0.0/53f051d3b5f62d99d61792c91acefe4f1577ad3e4c216fb0ad39e30b9f20019c. 
Subsequent calls will reuse this data.\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ " 0%| | 0/1 [00:00, ?it/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "bca3b484a8374617836c78a5a7247f19" } }, "metadata": {} } ], "source": [ "dataset = load_dataset(\"sms_spam\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "JKFHPko3OnAV", "outputId": "6c5513f7-90f2-4977-a938-539c6f623aaa" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "{'sms': 'Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...\\n',\n", " 'label': 0}" ] }, "metadata": {}, "execution_count": 4 } ], "source": [ "dataset['train'][0]" ] }, { "cell_type": "markdown", "metadata": { "id": "l140vJrgYxPr" }, "source": [ "# Modyfikacja datasetu - klasyfikacja" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "1boUF-YiY3_y", "outputId": "23bb86a0-9015-46b4-b36e-84007cad246e" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "{'sms': 'binary classification: Go until jurong point, crazy.. Available only in bugis n great world la e buffet... 
Cine there got amore wat...',\n", " 'label': '0'}" ] }, "metadata": {}, "execution_count": 5 } ], "source": [ "parsed_dataset = []\n", "\n", "for row in dataset['train']:\n", " text = \"binary classification: \" + row['sms'].replace(\"\\n\", \"\")\n", " new_row = {}\n", " new_row['sms'] = text\n", " if row['label'] == 0:\n", " new_row['label'] = \"0\"\n", " else:\n", " new_row['label'] = \"1\"\n", " parsed_dataset.append(new_row)\n", "\n", "parsed_dataset[0]" ] }, { "cell_type": "markdown", "metadata": { "id": "O-J-jBDxPJcn" }, "source": [ "# Tokenizer T5" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "P23AYPX1PZ6g" }, "outputs": [], "source": [ "from transformers import T5Tokenizer" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 203, "referenced_widgets": [ "8263f80bfd30477389a7d24450a41aa9", "d46c7aa929e24e5da4a3508e1ec82795", "84a7f048c34541708b01a64726037e94", "c545335c355d43c2b158eb7ea1032b68", "45f93d1c8d574efc8573cd9c30be2fa4", "6bc7141208cf4f2787890daa6dd900b5", "5fd95d4231824740a4df2c7d7cb015a0", "154741e5c910415f9e87da6cb5e1c578", "7dda689c19df4ff9b854cede266b804b", "863b386cb073419e96e8ae0d01554a36", "663902a693b5405c85be17bbe46e0650", "e76d348dd73d4e88994fa53449b69a0c", "bd65017865934166878adf8aa6c352c9", "0dc9cdebd0bd480d86fd2b7151f8617f", "ac095333a156479bbba127d424b48943", "2b2a144c8b434a0eb6b91c532965a956", "9674fadc0e6448f1ad34f2e47d6dec14", "1ed148605a1c4d97a8ff1bbab36b0f8e", "3ed7755f2175486d8407dea96ecd8898", "9afc3775ccc440c4b58e606e5daa9e75", "9d94d64ee8c4465489f4b645080030e5", "4fdd5febab8d42e6b79438a69d39622e" ] }, "id": "q5Jz0E_oPMBr", "outputId": "1c5a4105-22c9-41d1-9d46-19120868ae9e" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Downloading (…)ve/main/spiece.model: 0%| | 0.00/792k [00:00, ?B/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": 
"8263f80bfd30477389a7d24450a41aa9" } }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "Downloading (…)lve/main/config.json: 0%| | 0.00/1.21k [00:00, ?B/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "e76d348dd73d4e88994fa53449b69a0c" } }, "metadata": {} }, { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.8/dist-packages/transformers/models/t5/tokenization_t5.py:163: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n", "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n", "- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n", "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n", "- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n", " warnings.warn(\n" ] } ], "source": [ "tokenizer = T5Tokenizer.from_pretrained('t5-base')" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "dfxJQpoePsvI", "outputId": "a4b4cfa8-5334-4be6-c3ec-124840ecdcfa" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Original: binary classification: Go until jurong point, crazy.. Available only in bugis n great world la e buffet... 
Cine there got amore wat...\n", "Tokenized: ['▁binary', '▁classification', ':', '▁Go', '▁until', '▁jur', 'ong', '▁point', ',', '▁crazy', '.', '.', '▁Available', '▁only', '▁in', '▁bug', 'is', '▁', 'n', '▁great', '▁world', '▁la', '▁', 'e', '▁buffet', '...', '▁Cine', '▁there', '▁got', '▁', 'a', 'more', '▁wa', 't', '...']\n", "Token IDs: [14865, 13774, 10, 1263, 552, 10081, 2444, 500, 6, 6139, 5, 5, 8144, 163, 16, 8143, 159, 3, 29, 248, 296, 50, 3, 15, 15385, 233, 17270, 132, 530, 3, 9, 3706, 8036, 17, 233]\n" ] } ], "source": [ "sms = parsed_dataset[0]['sms']\n", "print('Original: ', sms)\n", "print('Tokenized: ', tokenizer.tokenize(sms))\n", "print('Token IDs: ', tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sms)))" ] }, { "cell_type": "markdown", "metadata": { "id": "UpluhM8cU5Ir" }, "source": [ "# Check maximum length of a sentence" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "7uNUkixPU85O", "outputId": "b34a2f27-9478-4fc9-cdbb-23081472ec92" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Max sentence length: 341\n" ] } ], "source": [ "# Find the longest tokenized SMS so max_length for padding can be set exactly.\n", "max_len = 0\n", "\n", "for sentence in parsed_dataset:\n", "    input_ids = tokenizer.encode(sentence['sms'], add_special_tokens=True)\n", "    max_len = max(max_len, len(input_ids))\n", "\n", "print('Max sentence length: ', max_len)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "lj0issBznZfK", "outputId": "d406d5a6-e278-47aa-b03e-8ee33c5871ac" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Max label length: 3\n" ] } ], "source": [ "# Same check for the tokenized target labels ('0' / '1').\n", "max_label_len = 0\n", "\n", "for sentence in parsed_dataset:\n", "    input_ids = tokenizer.encode(sentence['label'], add_special_tokens=True)\n", "    max_label_len = max(max_label_len, len(input_ids))\n", "\n", "print('Max label length: ', max_label_len)" ] }, { "cell_type": "markdown", "metadata": {
"id": "nfw62HdgSERb" }, "source": [ "# Pre train tokenization" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "KTXYalS1VLqH" }, "outputs": [], "source": [ "import torch" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Z28QYfLnSGxR", "outputId": "aa3c2dce-488c-48a5-f47d-18026ac678d6" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Original: {'sms': 'binary classification: Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...', 'label': '0'}\n", "Token IDs: tensor([14865, 13774, 10, 1263, 552, 10081, 2444, 500, 6, 6139,\n", " 5, 5, 8144, 163, 16, 8143, 159, 3, 29, 248,\n", " 296, 50, 3, 15, 15385, 233, 17270, 132, 530, 3,\n", " 9, 3706, 8036, 17, 233, 1, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0])\n", "Label token IDs: tensor([ 3, 632, 1])\n" ] } ], 
"source": [ "input_ids = []\n", "target_ids = []\n", "attention_masks = []\n", "\n", "for sentence in parsed_dataset:\n", " encoded_dict = tokenizer.encode_plus(\n", " sentence['sms'],\n", " add_special_tokens = True,\n", " max_length = 341,\n", " padding = 'max_length',\n", " truncation=True,\n", " return_attention_mask = True,\n", " return_tensors = 'pt',\n", " )\n", " \n", " encoded_target_dict = tokenizer.encode_plus(\n", " sentence['label'],\n", " add_special_tokens = True,\n", " max_length = 3,\n", " padding = 'max_length',\n", " truncation=True,\n", " return_attention_mask = True,\n", " return_tensors = 'pt',\n", " )\n", " \n", " input_ids.append(encoded_dict['input_ids'])\n", " target_ids.append(encoded_target_dict['input_ids'])\n", " attention_masks.append(encoded_dict['attention_mask'])\n", "\n", "input_ids = torch.cat(input_ids, dim=0)\n", "target_ids = torch.cat(target_ids, dim=0)\n", "attention_masks = torch.cat(attention_masks, dim=0)\n", "\n", "print('Original: ', parsed_dataset[0])\n", "print('Token IDs:', input_ids[0])\n", "print('Label token IDs:', target_ids[0])" ] }, { "cell_type": "code", "source": [ "print('Label token IDs:', target_ids[123])" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Ld1xH-BD0G-M", "outputId": "67aca9e1-5dca-48d7-97b7-26f9781bf51f" }, "execution_count": 13, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Label token IDs: tensor([209, 1, 0])\n" ] } ] }, { "cell_type": "markdown", "metadata": { "id": "qD_t0y0KVVSy" }, "source": [ "# Split dataset\n", "Class balance ratio should be similar to base dataset ratio." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "vN_SatRIVa4c" }, "outputs": [], "source": [ "from torch.utils.data import TensorDataset, random_split" ] }, { "cell_type": "code", "source": [ "def check_class_balance(dataset):\n", "    \"\"\"Return the spam / not-spam ratio of a (input_ids, mask, target_ids) dataset.\n", "\n", "    The tokenized label in row[2] is tensor([209, 1, 0]) for spam ('1') and\n", "    tensor([3, 632, 1]) for not spam ('0'), so the EOS token id 1 appears at\n", "    index 1 only for spam rows.\n", "\n", "    Returns float('inf') for an all-spam subset (0.0 for an empty one)\n", "    instead of raising ZeroDivisionError.\n", "    \"\"\"\n", "    spam_count = 0.0\n", "    not_spam_count = 0.0\n", "    for row in dataset:\n", "        # EOS (id 1) at position 1 of the target sequence marks label '1' (spam)\n", "        if row[2][1].item() == 1:\n", "            spam_count += 1.0\n", "        else:\n", "            not_spam_count += 1.0\n", "    if not_spam_count == 0.0:\n", "        return float('inf') if spam_count > 0.0 else 0.0\n", "    return spam_count / not_spam_count" ], "metadata": { "id": "oo9C8ATt0dTq" }, "execution_count": 15, "outputs": [] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Mm6vc6lLVW3l", "outputId": "e7223b64-86a7-459d-b681-1ea1e0db02d8" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Spam to not spam messages ratio: 0.15475450590428838\n", "\n", "1,000 test samples\n", "Ratio: 0.15074798619102417\n", "\n", "4,116 training samples\n", "Ratio: 0.15455820476858345\n", "\n", "  458 validation samples\n", "Ratio: 0.16539440203562342\n", "\n" ] } ], "source": [ "dataset = TensorDataset(input_ids, attention_masks, target_ids)\n", "print(\"Spam to not spam messages ratio: {}\\n\".format(check_class_balance(dataset)))\n", "\n", "test_size = 1000\n", "dataset_len = len(dataset)\n", "train_size = int(0.9 * (dataset_len-test_size))\n", "val_size = (dataset_len-test_size) - train_size\n", "\n", "test_dataset, train_dataset, val_dataset = random_split(dataset, [test_size, train_size, val_size])\n", "\n", "print('{:>5,} test samples'.format(test_size))\n", "print(\"Ratio: {}\\n\".format(check_class_balance(test_dataset)))\n", "print('{:>5,} training samples'.format(train_size))\n", "print(\"Ratio: {}\\n\".format(check_class_balance(train_dataset)))\n", "print('{:>5,} validation samples'.format(val_size))\n", "print(\"Ratio: {}\\n\".format(check_class_balance(val_dataset)))" ] }, { "cell_type": "markdown", "metadata": { "id": "bmgQOP4EVfA1" }, "source": [ "# Create train
and validation loaders" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CxnQ3cmIVlNh" }, "outputs": [], "source": [ "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0hcpO_onVjEC" }, "outputs": [], "source": [ "batch_size = 16\n", "\n", "train_dataloader = DataLoader(\n", " train_dataset,\n", " sampler = RandomSampler(train_dataset),\n", " batch_size = batch_size\n", " )\n", "\n", "validation_dataloader = DataLoader(\n", " val_dataset,\n", " sampler = SequentialSampler(val_dataset),\n", " batch_size = batch_size\n", " )" ] }, { "cell_type": "markdown", "metadata": { "id": "efwhqLyyVu9z" }, "source": [ "# Device check" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ANBCfNGnVwVk", "outputId": "ff2ff959-f0e9-47f3-d504-9daa45f870c2" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "There are 1 GPU(s) available.\n", "We will use the GPU: Tesla T4\n" ] } ], "source": [ "if torch.cuda.is_available(): \n", " device = torch.device(\"cuda\")\n", "\n", " print('There are %d GPU(s) available.' 
% torch.cuda.device_count())\n", " print('We will use the GPU:', torch.cuda.get_device_name(0))\n", "\n", "else:\n", " print('No GPU available, using the CPU instead.')\n", " device = torch.device(\"cpu\")" ] }, { "cell_type": "markdown", "metadata": { "id": "okTx_ynMV0rH" }, "source": [ "# Load T5 model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Eu-7Eed8WgN0" }, "outputs": [], "source": [ "from transformers import T5ForConditionalGeneration" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000, "referenced_widgets": [ "68418b4f08654a2c8a19bdefa31ef7e2", "f59f1fe74df84329baa0137729651d7e", "4e6666f32de94c14973b2f5895c4f4ec", "9a8b0e9cf614453789dceff586f47682", "a4e1407e1a42416087a3138812851afa", "1813bc00d8db4de5a7bb7cd276346312", "ab6b0613a4934f34aad4d28cd855362d", "7514dfc8c5c34f29ab9a246ba6b45dc2", "017b00a3a26743d3a761a5b05f72fe73", "1cfe23326f964bb0a2925456aea14ad5", "384aac4ea3274eebbb43ea847036793a", "17986d272156460f8e9bcee2559088d9", "f1c7c8e7770848dabf155be27b342c6f", "719b8ebc46884edd9b36829f49680c98", "f28050af08f947678a41e1ea5611067f", "2ff5d9e91bf64330a2747c9c518ba31c", "85bd410d586b4ac98b8df72f980c0194", "feb7905c359e4acd9c9f848fb63d5d55", "b1d4154a8b054c8380a9ac70c311755b", "2fee9e3e54ae41c8977beaae6802010f", "42dc0f0578ed4105abeee4362667a98a", "04bb3488deec4565a0864049b122437d" ] }, "id": "JKv9O8kfV2zZ", "outputId": "ad88a39b-bdc7-4325-b588-ed5feb453c3e" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Downloading (…)\"pytorch_model.bin\";: 0%| | 0.00/892M [00:00, ?B/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "68418b4f08654a2c8a19bdefa31ef7e2" } }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "Downloading (…)neration_config.json: 0%| | 0.00/147 [00:00, ?B/s]" ], "application/vnd.jupyter.widget-view+json": { 
"version_major": 2, "version_minor": 0, "model_id": "17986d272156460f8e9bcee2559088d9" } }, "metadata": {} }, { "output_type": "execute_result", "data": { "text/plain": [ "T5ForConditionalGeneration(\n", " (shared): Embedding(32128, 768)\n", " (encoder): T5Stack(\n", " (embed_tokens): Embedding(32128, 768)\n", " (block): ModuleList(\n", " (0): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " (relative_attention_bias): Embedding(32, 12)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (1): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " 
(layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (2): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (3): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (4): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): 
Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (5): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (6): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): 
T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (7): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (8): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", 
" (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (9): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (10): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (11): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, 
out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " )\n", " (final_layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (decoder): T5Stack(\n", " (embed_tokens): Embedding(32128, 768)\n", " (block): ModuleList(\n", " (0): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " (relative_attention_bias): Embedding(32, 12)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerCrossAttention(\n", " (EncDecAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (2): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " 
(dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (1): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerCrossAttention(\n", " (EncDecAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (2): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (2): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerCrossAttention(\n", " (EncDecAttention): 
T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (2): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (3): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerCrossAttention(\n", " (EncDecAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (2): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, 
inplace=False)\n", " )\n", " )\n", " )\n", " (4): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerCrossAttention(\n", " (EncDecAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (2): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (5): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerCrossAttention(\n", " (EncDecAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " 
(v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (2): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (6): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerCrossAttention(\n", " (EncDecAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (2): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (7): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " 
(q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerCrossAttention(\n", " (EncDecAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (2): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (8): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerCrossAttention(\n", " (EncDecAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): 
T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (2): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (9): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerCrossAttention(\n", " (EncDecAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (2): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (10): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): 
Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerCrossAttention(\n", " (EncDecAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (2): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): ReLU()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (11): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerCrossAttention(\n", " (EncDecAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (2): T5LayerFF(\n", " (DenseReluDense): T5DenseActDense(\n", " (wi): 
def calculate_accuracy(preds, target):
    """Return the fraction of predictions that exactly match their targets.

    Args:
        preds: sequence of predicted values (here: decoded strings).
        target: sequence of reference values, aligned index-by-index
            with ``preds``.

    Returns:
        float in [0, 1]. Returns 0.0 for empty input (the previous
        version raised ZeroDivisionError in that case).
    """
    if not preds:
        return 0.0
    # Count exact matches pairwise; sum() over a generator replaces the
    # manual ok/false counters of the original implementation.
    matches = sum(1.0 for pred, ref in zip(preds, target) if pred == ref)
    return matches / len(preds)

def format_time(elapsed):
    """Convert a duration in seconds to an ``h:mm:ss`` string.

    Args:
        elapsed: elapsed time in seconds (int or float).

    Returns:
        str such as ``'0:07:23'``, with seconds rounded to the nearest
        whole second before formatting via ``datetime.timedelta``.
    """
    elapsed_rounded = int(round(elapsed))
    return str(datetime.timedelta(seconds=elapsed_rounded))
Training" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Hoa7NlU0bI7G" }, "outputs": [], "source": [ "import random\n", "import time" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "xsHxfslka1u5", "outputId": "e40d00a1-baf8-4554-e5ec-aeb87ee35f66" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "======== Epoch 1 / 4 ========\n", "Training...\n", " Batch 40 of 258. Elapsed: 0:01:12.\n", " Batch 80 of 258. Elapsed: 0:02:20.\n", " Batch 120 of 258. Elapsed: 0:03:28.\n", " Batch 160 of 258. Elapsed: 0:04:36.\n", " Batch 200 of 258. Elapsed: 0:05:45.\n", " Batch 240 of 258. Elapsed: 0:06:53.\n", "\n", " Average training loss: 0.09\n", " Average training acc: 0.81\n", " Training epcoh took: 0:07:23\n", "\n", "Running Validation...\n", " Accuracy: 0.83\n", " Validation took: 0:00:27\n", " Validation Loss: 0.00\n", "\n", "======== Epoch 2 / 4 ========\n", "Training...\n", " Batch 40 of 258. Elapsed: 0:01:09.\n", " Batch 80 of 258. Elapsed: 0:02:17.\n", " Batch 120 of 258. Elapsed: 0:03:25.\n", " Batch 160 of 258. Elapsed: 0:04:33.\n", " Batch 200 of 258. Elapsed: 0:05:42.\n", " Batch 240 of 258. Elapsed: 0:06:50.\n", "\n", " Average training loss: 0.00\n", " Average training acc: 0.86\n", " Training epcoh took: 0:07:19\n", "\n", "Running Validation...\n", " Accuracy: 0.83\n", " Validation took: 0:00:26\n", " Validation Loss: 0.00\n", "\n", "======== Epoch 3 / 4 ========\n", "Training...\n", " Batch 40 of 258. Elapsed: 0:01:08.\n", " Batch 80 of 258. Elapsed: 0:02:16.\n", " Batch 120 of 258. Elapsed: 0:03:24.\n", " Batch 160 of 258. Elapsed: 0:04:32.\n", " Batch 200 of 258. Elapsed: 0:05:41.\n", " Batch 240 of 258. 
# This training code is based on the `run_glue.py` script here:
# https://github.com/huggingface/transformers/blob/5bfcd0485ece086ebcbed2d008813037968a9e58/examples/run_glue.py#L128

# Seed every RNG in play (python, numpy, torch CPU and all CUDA devices)
# so the run is reproducible.
seed_val = 42

random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

training_stats = []   # one summary dict per epoch, appended below
total_t0 = time.time()

for epoch_i in range(0, epochs):

    # ========================================
    #               Training
    # ========================================

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')

    t0 = time.time()
    total_train_loss = 0
    total_train_acc = 0

    model.train()

    for step, batch in enumerate(train_dataloader):
        # Progress report every 40 batches (skipping batch 0).
        if step % 40 == 0 and not step == 0:
            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))

        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)

        # Teacher forcing: decoder input is the target without its last
        # token; labels are the target shifted left by one, with pad
        # positions set to -100 so the loss ignores them.
        y = batch[2].to(device)
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone().detach()
        lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100

        outputs = model(
            input_ids=b_input_ids,
            attention_mask=b_input_mask,
            decoder_input_ids=y_ids,
            labels=lm_labels
        )

        # NOTE(review): generation runs while the model is in train()
        # mode, so dropout is active — the training accuracy below is a
        # noisy estimate, not a clean eval metric.
        generated_ids = model.generate(
            input_ids = b_input_ids,
            attention_mask = b_input_mask,
            max_length=3,
            num_beams=2,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True
        )

        preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
        target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in y]
        total_train_acc += calculate_accuracy(preds, target)

        loss = outputs[0]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()

    # Per-batch means; assumes roughly equal batch sizes (a smaller last
    # batch slightly biases the accuracy average).
    avg_train_loss = total_train_loss / len(train_dataloader)
    avg_train_acc = total_train_acc / len(train_dataloader)
    training_time = format_time(time.time() - t0)

    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Average training acc: {0:.2f}".format(avg_train_acc))
    # Fixed typo in the progress message: "epcoh" -> "epoch".
    print("  Training epoch took: {:}".format(training_time))

    # ========================================
    #               Validation
    # ========================================

    print("")
    print("Running Validation...")

    t0 = time.time()
    model.eval()

    total_eval_loss = 0
    total_eval_accuracy = 0

    for batch in validation_dataloader:
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)

        # Same teacher-forcing label construction as in training.
        y = batch[2].to(device)
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone().detach()
        lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100

        # No gradients needed during evaluation.
        with torch.no_grad():

            outputs = model(
                input_ids=b_input_ids,
                attention_mask=b_input_mask,
                decoder_input_ids=y_ids,
                labels=lm_labels
            )

            loss = outputs[0]
            total_eval_loss += loss.item()

            generated_ids = model.generate(
                input_ids = b_input_ids,
                attention_mask = b_input_mask,
                max_length=3,
                num_beams=2,
                repetition_penalty=2.5,
                length_penalty=1.0,
                early_stopping=True
            )

            preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
            target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in y]
            total_eval_accuracy += calculate_accuracy(preds, target)

    avg_val_loss = total_eval_loss / len(validation_dataloader)

    avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)
    print("  Accuracy: {0:.2f}".format(avg_val_accuracy))

    validation_time = format_time(time.time() - t0)
    print("  Validation took: {:}".format(validation_time))
    print("  Validation Loss: {0:.2f}".format(avg_val_loss))

    # Record this epoch's metrics for the summary table below.
    training_stats.append(
        {
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Training Accur.': avg_train_acc,
            'Valid. Loss': avg_val_loss,
            'Valid. Accur.': avg_val_accuracy,
            'Training Time': training_time,
            'Validation Time': validation_time
        }
    )

print("")
print("Training complete!")

print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))
\n", " | Training Loss | \n", "Training Accur. | \n", "Valid. Loss | \n", "Valid. Accur. | \n", "Training Time | \n", "Validation Time | \n", "
---|---|---|---|---|---|---|
epoch | \n", "\n", " | \n", " | \n", " | \n", " | \n", " | \n", " |
1 | \n", "9.03e-02 | \n", "0.81 | \n", "9.89e-07 | \n", "0.83 | \n", "0:07:23 | \n", "0:00:27 | \n", "
2 | \n", "1.30e-05 | \n", "0.86 | \n", "2.26e-08 | \n", "0.83 | \n", "0:07:19 | \n", "0:00:26 | \n", "
3 | \n", "3.05e-06 | \n", "0.85 | \n", "0.00e+00 | \n", "0.83 | \n", "0:07:18 | \n", "0:00:26 | \n", "
4 | \n", "5.13e-06 | \n", "0.86 | \n", "0.00e+00 | \n", "0.83 | \n", "0:07:18 | \n", "0:00:26 | \n", "