diff --git a/GPT_2.ipynb b/GPT_2.ipynb index a8ad179..e3af2ce 100644 --- a/GPT_2.ipynb +++ b/GPT_2.ipynb @@ -16,7 +16,7 @@ "gpuClass": "standard", "widgets": { "application/vnd.jupyter.widget-state+json": { - "16f1b324020d48c3a8fd4487c42bbd6b": { + "f0fc084b95e0408a9d77d4051a540f2d": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", @@ -31,14 +31,14 @@ "_view_name": "HBoxView", "box_style": "", "children": [ - "IPY_MODEL_df036cb95c56454998cb7c788d341584", - "IPY_MODEL_4a0387dee622459498ddc9d7bf201187", - "IPY_MODEL_604840d710474c71a676c8368c9b3f2f" + "IPY_MODEL_f32509354c994a148ece1bf2f5d2fb66", + "IPY_MODEL_1cfa03eaaa7f4750af69da815f3f8360", + "IPY_MODEL_69e5f2a83b884fc7a640accaa27b5600" ], - "layout": "IPY_MODEL_7e3d18de5d554030bd6aa801ac7f3192" + "layout": "IPY_MODEL_b16b3d7a825a4435bab3dd8bdb26702d" } }, - "df036cb95c56454998cb7c788d341584": { + "f32509354c994a148ece1bf2f5d2fb66": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", @@ -53,13 +53,13 @@ "_view_name": "HTMLView", "description": "", "description_tooltip": null, - "layout": "IPY_MODEL_419e5a2aab8147d79490748f633675cd", + "layout": "IPY_MODEL_0f169cc9432649b9bc990ebed23faa47", "placeholder": "​", - "style": "IPY_MODEL_41452374c2a64afd82d30744b36dd801", + "style": "IPY_MODEL_e5e7f54b635748da9fb170c6819e6368", "value": "100%" } }, - "4a0387dee622459498ddc9d7bf201187": { + "1cfa03eaaa7f4750af69da815f3f8360": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", @@ -75,15 +75,15 @@ "bar_style": "success", "description": "", "description_tooltip": null, - "layout": "IPY_MODEL_ea94ef66672d4b69bf0d5eac6f7dada3", + "layout": "IPY_MODEL_c3361a78031047bca9494db148aa9c60", "max": 3, "min": 0, "orientation": "horizontal", - "style": "IPY_MODEL_a7e7dd3a259a4a878cfcd6a66ed35c7c", + "style": "IPY_MODEL_c0376b60cd6643a4b14c5f88f1feabfd", "value": 3 } }, - "604840d710474c71a676c8368c9b3f2f": { + "69e5f2a83b884fc7a640accaa27b5600": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", @@ -98,13 +98,13 @@ "_view_name": "HTMLView", "description": "", "description_tooltip": null, - "layout": "IPY_MODEL_275a9313a7e14c57b66cbe484499c8ec", + "layout": "IPY_MODEL_8582e82344404f68a3f89033e0f4987e", "placeholder": "​", - "style": "IPY_MODEL_c01651c9c010429bbf5507770ce6b6ce", - "value": " 3/3 [00:00<00:00, 136.35it/s]" + "style": "IPY_MODEL_ba03ab4c843c42909fbeb4ff411186d6", + "value": " 3/3 [00:00<00:00, 31.73it/s]" } }, - "7e3d18de5d554030bd6aa801ac7f3192": { + "b16b3d7a825a4435bab3dd8bdb26702d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", @@ -156,7 +156,7 @@ "width": null } }, - "419e5a2aab8147d79490748f633675cd": { + "0f169cc9432649b9bc990ebed23faa47": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", @@ -208,7 +208,7 @@ "width": null } }, - "41452374c2a64afd82d30744b36dd801": { + "e5e7f54b635748da9fb170c6819e6368": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", @@ -223,7 +223,7 @@ "description_width": "" } }, - "ea94ef66672d4b69bf0d5eac6f7dada3": { + "c3361a78031047bca9494db148aa9c60": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", @@ -275,7 +275,7 @@ "width": null } }, - 
"a7e7dd3a259a4a878cfcd6a66ed35c7c": { + "c0376b60cd6643a4b14c5f88f1feabfd": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", @@ -291,7 +291,7 @@ "description_width": "" } }, - "275a9313a7e14c57b66cbe484499c8ec": { + "8582e82344404f68a3f89033e0f4987e": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", @@ -343,7 +343,7 @@ "width": null } }, - "c01651c9c010429bbf5507770ce6b6ce": { + "ba03ab4c843c42909fbeb4ff411186d6": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", @@ -363,274 +363,243 @@ }, "cells": [ { - "cell_type": "code", - "source": [], + "cell_type": "markdown", + "source": [ + "# Setup" + ], "metadata": { - "id": "JErLYXsaYy8-" - }, - "execution_count": null, - "outputs": [] + "id": "n2A5EThJNiAy" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Requirements" + ], + "metadata": { + "id": "tPp2_1rDOFYA" + } }, { "cell_type": "code", - "source": [ - "! pip install datasets transformers torch scikit-learn evaluate" - ], + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, - "id": "u29i-U30zRjY", - "outputId": "55534ca2-097f-4e7a-a517-463f974148cf" + "id": "OmsX3kG4bLTg", + "outputId": "cd31b31c-3840-490c-b57f-18edfe8d847a" }, - "execution_count": 1, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting datasets\n", - " Downloading datasets-2.9.0-py3-none-any.whl (462 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m462.8/462.8 KB\u001b[0m \u001b[31m8.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting transformers\n", - " Downloading transformers-4.26.1-py3-none-any.whl (6.3 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.3/6.3 MB\u001b[0m \u001b[31m55.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: torch in /usr/local/lib/python3.8/dist-packages (1.13.1+cu116)\n", - "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.8/dist-packages (1.0.2)\n", - "Collecting evaluate\n", - " Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m81.4/81.4 KB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from datasets) (1.21.6)\n", - "Collecting xxhash\n", - " Downloading xxhash-3.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (213 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m213.0/213.0 KB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (2023.1.0)\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.8/dist-packages (1.13.1+cu116)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch) (4.4.0)\n", + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Requirement already satisfied: datasets in 
/usr/local/lib/python3.8/dist-packages (2.9.0)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.8/dist-packages (from datasets) (3.8.3)\n", - "Collecting responses<0.19\n", - " Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n", - "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (6.0)\n", - "Requirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (9.0.0)\n", - "Collecting huggingface-hub<1.0.0,>=0.2.0\n", - " Downloading huggingface_hub-0.12.0-py3-none-any.whl (190 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.3/190.3 KB\u001b[0m \u001b[31m16.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: dill<0.3.7 in /usr/local/lib/python3.8/dist-packages (from datasets) (0.3.6)\n", - "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (4.64.1)\n", "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (2.25.1)\n", + "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (4.64.1)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.8/dist-packages (from datasets) (1.3.5)\n", + "Requirement already satisfied: huggingface-hub<1.0.0,>=0.2.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (0.12.0)\n", + "Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.8/dist-packages (from datasets) (0.18.0)\n", + "Requirement already satisfied: dill<0.3.7 in /usr/local/lib/python3.8/dist-packages (from datasets) (0.3.6)\n", + "Requirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (9.0.0)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.8/dist-packages (from datasets) (3.2.0)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from datasets) (1.21.6)\n", + "Requirement already satisfied: multiprocess in /usr/local/lib/python3.8/dist-packages (from datasets) (0.70.14)\n", + "Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (2023.1.0)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.8/dist-packages (from datasets) (23.0)\n", - "Collecting multiprocess\n", - " Downloading multiprocess-0.70.14-py38-none-any.whl (132 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.0/132.0 KB\u001b[0m \u001b[31m8.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.8/dist-packages (from transformers) (2022.6.2)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from transformers) (3.9.0)\n", - "Collecting tokenizers!=0.11.3,<0.14,>=0.11.1\n", - " Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m51.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch) (4.4.0)\n", - "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (3.1.0)\n", - 
"Requirement already satisfied: scipy>=1.1.0 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (1.7.3)\n", - "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (1.2.0)\n", - "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.8.2)\n", - "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (22.2.0)\n", - "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.3.3)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (6.0)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (4.0.2)\n", - "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (6.0.4)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.3.3)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (22.2.0)\n", "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (2.1.1)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.3.1)\n", - "Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (4.0.0)\n", - "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (1.24.3)\n", - "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (2.10)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.8.2)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (6.0.4)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0.0,>=0.2.0->datasets) (4.4.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0.0,>=0.2.0->datasets) (3.9.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (2022.12.7)\n", - "Collecting urllib3<1.27,>=1.21.1\n", - " Downloading urllib3-1.26.14-py2.py3-none-any.whl (140 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.6/140.6 KB\u001b[0m \u001b[31m9.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.8/dist-packages (from pandas->datasets) (2022.7.1)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (1.26.14)\n", + "Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (4.0.0)\n", + "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (2.10)\n", "Requirement already satisfied: 
python-dateutil>=2.7.3 in /usr/local/lib/python3.8/dist-packages (from pandas->datasets) (2.8.2)\n", + "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.8/dist-packages (from pandas->datasets) (2022.7.1)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)\n", - "Installing collected packages: tokenizers, xxhash, urllib3, multiprocess, responses, huggingface-hub, transformers, datasets, evaluate\n", - " Attempting uninstall: urllib3\n", - " Found existing installation: urllib3 1.24.3\n", - " Uninstalling urllib3-1.24.3:\n", - " Successfully uninstalled urllib3-1.24.3\n", - "Successfully installed datasets-2.9.0 evaluate-0.4.0 huggingface-hub-0.12.0 multiprocess-0.70.14 responses-0.18.0 tokenizers-0.13.2 transformers-4.26.1 urllib3-1.26.14 xxhash-3.2.0\n" + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Requirement already satisfied: transformers in /usr/local/lib/python3.8/dist-packages (4.26.1)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.8/dist-packages (from transformers) (2022.6.2)\n", + "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.8/dist-packages (from transformers) (0.13.2)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from transformers) (6.0)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from transformers) (1.21.6)\n", + "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.8/dist-packages (from transformers) (4.64.1)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from transformers) (3.9.0)\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /usr/local/lib/python3.8/dist-packages (from transformers) (0.12.0)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from transformers) (2.25.1)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from transformers) (23.0)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (4.4.0)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (1.26.14)\n", + "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (2.10)\n", + "Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (4.0.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (2022.12.7)\n", + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.8/dist-packages (1.0.2)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (3.1.0)\n", + "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (1.2.0)\n", + "Requirement already satisfied: scipy>=1.1.0 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (1.7.3)\n", + "Requirement already satisfied: numpy>=1.14.6 in 
/usr/local/lib/python3.8/dist-packages (from scikit-learn) (1.21.6)\n", + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Requirement already satisfied: evaluate in /usr/local/lib/python3.8/dist-packages (0.4.0)\n", + "Requirement already satisfied: fsspec[http]>=2021.05.0 in /usr/local/lib/python3.8/dist-packages (from evaluate) (2023.1.0)\n", + "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.8/dist-packages (from evaluate) (2.25.1)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.8/dist-packages (from evaluate) (1.3.5)\n", + "Requirement already satisfied: multiprocess in /usr/local/lib/python3.8/dist-packages (from evaluate) (0.70.14)\n", + "Requirement already satisfied: huggingface-hub>=0.7.0 in /usr/local/lib/python3.8/dist-packages (from evaluate) (0.12.0)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.8/dist-packages (from evaluate) (23.0)\n", + "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.8/dist-packages (from evaluate) (4.64.1)\n", + "Requirement already satisfied: dill in /usr/local/lib/python3.8/dist-packages (from evaluate) (0.3.6)\n", + "Requirement already satisfied: datasets>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from evaluate) (2.9.0)\n", + "Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.8/dist-packages (from evaluate) (0.18.0)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from evaluate) (1.21.6)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.8/dist-packages (from evaluate) (3.2.0)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from datasets>=2.0.0->evaluate) (6.0)\n", + "Requirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.8/dist-packages (from datasets>=2.0.0->evaluate) (9.0.0)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.8/dist-packages (from datasets>=2.0.0->evaluate) (3.8.3)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from huggingface-hub>=0.7.0->evaluate) (3.9.0)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub>=0.7.0->evaluate) (4.4.0)\n", + "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->evaluate) (2.10)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->evaluate) (2022.12.7)\n", + "Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->evaluate) (4.0.0)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->evaluate) (1.26.14)\n", + "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.8/dist-packages (from pandas->evaluate) (2022.7.1)\n", + "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.8/dist-packages (from pandas->evaluate) (2.8.2)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (6.0.4)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (22.2.0)\n", + "Requirement already satisfied: 
aiosignal>=1.1.2 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.3.1)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.3.3)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.8.2)\n", + "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (2.1.1)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (4.0.2)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.7.3->pandas->evaluate) (1.15.0)\n", + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Requirement already satisfied: accelerate in /usr/local/lib/python3.8/dist-packages (0.16.0)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from accelerate) (23.0)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.8/dist-packages (from accelerate) (5.4.8)\n", + "Requirement already satisfied: torch>=1.4.0 in /usr/local/lib/python3.8/dist-packages (from accelerate) (1.13.1+cu116)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from accelerate) (1.21.6)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.8/dist-packages (from accelerate) (6.0)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch>=1.4.0->accelerate) (4.4.0)\n", + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.8/dist-packages (0.1.97)\n", + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Requirement already satisfied: protobuf in /usr/local/lib/python3.8/dist-packages (3.19.6)\n", + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Requirement already satisfied: sacrebleu in /usr/local/lib/python3.8/dist-packages (2.3.1)\n", + "Requirement already satisfied: regex in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (2022.6.2)\n", + "Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (0.8.10)\n", + "Requirement already satisfied: colorama in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (0.4.6)\n", + "Requirement already satisfied: portalocker in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (2.7.0)\n", + "Requirement already satisfied: lxml in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (4.9.2)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (1.21.6)\n", + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Requirement already satisfied: py7zr in /usr/local/lib/python3.8/dist-packages (0.20.4)\n", + "Requirement already satisfied: inflate64>=0.3.1 in /usr/local/lib/python3.8/dist-packages (from py7zr) (0.3.1)\n", + "Requirement already satisfied: pybcj>=0.6.0 in /usr/local/lib/python3.8/dist-packages (from py7zr) (1.0.1)\n", + 
"Requirement already satisfied: pyzstd>=0.14.4 in /usr/local/lib/python3.8/dist-packages (from py7zr) (0.15.3)\n", + "Requirement already satisfied: multivolumefile>=0.2.3 in /usr/local/lib/python3.8/dist-packages (from py7zr) (0.2.3)\n", + "Requirement already satisfied: pycryptodomex>=3.6.6 in /usr/local/lib/python3.8/dist-packages (from py7zr) (3.17)\n", + "Requirement already satisfied: brotli>=1.0.9 in /usr/local/lib/python3.8/dist-packages (from py7zr) (1.0.9)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.8/dist-packages (from py7zr) (5.4.8)\n", + "Requirement already satisfied: texttable in /usr/local/lib/python3.8/dist-packages (from py7zr) (1.6.7)\n", + "Requirement already satisfied: pyppmd<1.1.0,>=0.18.1 in /usr/local/lib/python3.8/dist-packages (from py7zr) (1.0.0)\n" ] } + ], + "source": [ + "!pip install torch\n", + "!pip install datasets\n", + "!pip install transformers\n", + "!pip install scikit-learn\n", + "!pip install evaluate\n", + "!pip install accelerate\n", + "!pip install sentencepiece\n", + "!pip install protobuf\n", + "!pip install sacrebleu\n", + "!pip install py7zr\n" ] }, { "cell_type": "markdown", - "source": [], + "source": [ + "## Imports" + ], "metadata": { - "id": "a_f-yno_zity" + "id": "o3Kj9IzuOKMi" } }, { "cell_type": "code", "source": [ - "!wget 'https://git.wmi.amu.edu.pl/s444465/projekt-glebokie/raw/branch/master/run_glue.py' -O 'run_glue.py'\n", - "!wget 'https://git.wmi.amu.edu.pl/s444465/projekt-glebokie/raw/branch/master/roberta.py' -O 'roberta.py'\n", - "!wget 'https://git.wmi.amu.edu.pl/s444465/projekt-glebokie/raw/branch/master/gpt2.py' -O 'gpt2.py'" + "import os\n", + "import json\n", + "import torch\n", + "from google.colab import drive\n", + "from pathlib import Path\n", + "from typing import Dict, List\n", + "from datasets import load_dataset\n", + "from transformers import T5Tokenizer" ], "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "V_HmRNcmzhsw", - "outputId": "feafb930-4dbf-436c-8e37-de4e8b8a32cc" + "id": "r92S06noeSWE" }, - "execution_count": 2, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "--2023-02-12 21:57:57-- https://git.wmi.amu.edu.pl/s444465/projekt-glebokie/raw/branch/master/run_glue.py\n", - "Resolving git.wmi.amu.edu.pl (git.wmi.amu.edu.pl)... 150.254.78.40\n", - "Connecting to git.wmi.amu.edu.pl (git.wmi.amu.edu.pl)|150.254.78.40|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 30601 (30K) [text/plain]\n", - "Saving to: ‘run_glue.py’\n", - "\n", - "run_glue.py 100%[===================>] 29.88K --.-KB/s in 0.1s \n", - "\n", - "2023-02-12 21:57:58 (248 KB/s) - ‘run_glue.py’ saved [30601/30601]\n", - "\n", - "--2023-02-12 21:57:58-- https://git.wmi.amu.edu.pl/s444465/projekt-glebokie/raw/branch/master/roberta.py\n", - "Resolving git.wmi.amu.edu.pl (git.wmi.amu.edu.pl)... 150.254.78.40\n", - "Connecting to git.wmi.amu.edu.pl (git.wmi.amu.edu.pl)|150.254.78.40|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 12783 (12K) [text/plain]\n", - "Saving to: ‘roberta.py’\n", - "\n", - "roberta.py 100%[===================>] 12.48K --.-KB/s in 0s \n", - "\n", - "2023-02-12 21:57:58 (265 MB/s) - ‘roberta.py’ saved [12783/12783]\n", - "\n", - "--2023-02-12 21:57:58-- https://git.wmi.amu.edu.pl/s444465/projekt-glebokie/raw/branch/master/gpt2.py\n", - "Resolving git.wmi.amu.edu.pl (git.wmi.amu.edu.pl)... 
150.254.78.40\n", - "Connecting to git.wmi.amu.edu.pl (git.wmi.amu.edu.pl)|150.254.78.40|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 7976 (7.8K) [text/plain]\n", - "Saving to: ‘gpt2.py’\n", - "\n", - "gpt2.py 100%[===================>] 7.79K --.-KB/s in 0s \n", - "\n", - "2023-02-12 21:57:59 (1.37 GB/s) - ‘gpt2.py’ saved [7976/7976]\n", - "\n" - ] - } - ] + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Loading data" + ], + "metadata": { + "id": "2UzLo91gNnsA" + } }, { "cell_type": "code", "source": [ - "import json\n", - "from pathlib import Path\n", - "from typing import Dict, List\n", - "from datasets import load_dataset\n", - "\n", "loaded_data = load_dataset('emotion')\n", "!mkdir -v -p data\n", - "\n", "train_path = Path('data/train.json')\n", "valid_path = Path('data/valid.json')\n", "test_path = Path('data/test.json')\n", - "data_train, data_valid, data_test = [], [], []\n", - "\n", - "for source_data, dataset, max_size in [\n", - " (loaded_data['train'], data_train, None),\n", - " (loaded_data['validation'], data_valid, None),\n", - " (loaded_data['test'], data_test, None),\n", - "]:\n", - " for i, data in enumerate(source_data):\n", - " if max_size is not None and i >= max_size:\n", - " break\n", - " data_line = {\n", - " 'label': int(data['label']),\n", - " 'text': data['text'],\n", - " }\n", - " dataset.append(data_line)\n", - "\n", - "print(f'Train: {len(data_train):6d}')\n", - "print(f'Valid: {len(data_valid):6d}')\n", - "\n", - "data_class_1, data_class_2 = [], []\n", - "\n", - "\"\"\"for data in data_valid:\n", - " label = data['label']\n", - " if label == 0:\n", - " data_class_1.append(data)\n", - " elif label == 1:\n", - " data_class_2.append(data)\n", - "\n", - "print(f'Label 1: {len(data_class_1):6d}')\n", - "print(f'Label 2: {len(data_class_2):6d}')\n", - "\n", - "size_half_class_1 = int(len(data_class_1) / 2)\n", - "size_half_class_2 = int(len(data_class_2) / 2)\n", - "\n", - "data_valid = data_class_1[:size_half_class_1] + data_class_2[:size_half_class_2]\n", - "data_test = data_class_1[size_half_class_1:] + data_class_2[size_half_class_2:]\n", - "\"\"\"\n", - "\n", - "print(f'Valid: {len(data_valid):6d}')\n", - "print(f'Test : {len(data_test):6d}')\n", - "\n", - "MAP_LABEL_TRANSLATION = {\n", - " 0: 'sadness',\n", - " 1: 'joy',\n", - " 2: 'love',\n", - " 3: 'anger',\n", - " 4: 'fear',\n", - " 5: 'surprise',\n", - "}\n", - "\n", - "def save_as_translations(original_save_path: Path, data_to_save: List[Dict]) -> None:\n", - " file_name = 's2s-' + original_save_path.name\n", - " file_path = original_save_path.parent / file_name\n", - "\n", - " print(f'Saving into: {file_path}')\n", - " with open(file_path, 'wt') as f_write:\n", - " for data_line in data_to_save:\n", - " label = data_line['label']\n", - " new_label = MAP_LABEL_TRANSLATION[label]\n", - " data_line['label'] = new_label\n", - " data_line_str = json.dumps(data_line)\n", - " f_write.write(f'{data_line_str}\\n')\n", - "\n", - "for file_path, data_to_save in [(train_path, data_train), (valid_path, data_valid), (test_path, data_test)]:\n", - " print(f'Saving into: {file_path}')\n", - " with open(file_path, 'wt') as f_write:\n", - " for data_line in data_to_save:\n", - " data_line_str = json.dumps(data_line)\n", - " f_write.write(f'{data_line_str}\\n')\n", - " \n", - " save_as_translations(file_path, data_to_save)\n", - "\n" + "data_train, data_valid, data_test = [], [], []" ], "metadata": { "colab": { "base_uri": 
"https://localhost:8080/", - "height": 295, + "height": 84, "referenced_widgets": [ - "16f1b324020d48c3a8fd4487c42bbd6b", - "df036cb95c56454998cb7c788d341584", - "4a0387dee622459498ddc9d7bf201187", - "604840d710474c71a676c8368c9b3f2f", - "7e3d18de5d554030bd6aa801ac7f3192", - "419e5a2aab8147d79490748f633675cd", - "41452374c2a64afd82d30744b36dd801", - "ea94ef66672d4b69bf0d5eac6f7dada3", - "a7e7dd3a259a4a878cfcd6a66ed35c7c", - "275a9313a7e14c57b66cbe484499c8ec", - "c01651c9c010429bbf5507770ce6b6ce" + "f0fc084b95e0408a9d77d4051a540f2d", + "f32509354c994a148ece1bf2f5d2fb66", + "1cfa03eaaa7f4750af69da815f3f8360", + "69e5f2a83b884fc7a640accaa27b5600", + "b16b3d7a825a4435bab3dd8bdb26702d", + "0f169cc9432649b9bc990ebed23faa47", + "e5e7f54b635748da9fb170c6819e6368", + "c3361a78031047bca9494db148aa9c60", + "c0376b60cd6643a4b14c5f88f1feabfd", + "8582e82344404f68a3f89033e0f4987e", + "ba03ab4c843c42909fbeb4ff411186d6" ] }, - "id": "bcR4tWQl0rqt", - "outputId": "6a2bad78-8eb7-4a90-c839-b7cc470438d7" + "id": "n_miey7eb2Xr", + "outputId": "273a8199-b14f-4a19-f9e1-a2961c2653bc" }, - "execution_count": 4, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -649,20 +618,119 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "16f1b324020d48c3a8fd4487c42bbd6b" + "model_id": "f0fc084b95e0408a9d77d4051a540f2d" } }, "metadata": {} + } + ] + }, + { + "cell_type": "code", + "source": [ + "for source_data, dataset, max_size in [\n", + " (loaded_data['train'], data_train, None),\n", + " (loaded_data['validation'], data_valid, None),\n", + " (loaded_data['test'], data_test, None),\n", + "]:\n", + " for i, data in enumerate(source_data):\n", + " if max_size is not None and i >= max_size:\n", + " break\n", + " data_line = {\n", + " 'label': int(data['label']),\n", + " 'text': data['text'],\n", + " }\n", + " dataset.append(data_line)\n", + "\n", + "print(f'Train: {len(data_train):6d}')\n", + "print(f'Valid: {len(data_valid):6d}')\n", + "print(f'Test: {len(data_test):6d}')" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "BZ6afaRzGsxS", + "outputId": "139aaaf0-ea67-4ed2-bfa4-68caa7dd61e8" + }, + "execution_count": null, + "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "mkdir: created directory 'data'\n", "Train: 16000\n", "Valid: 2000\n", - "Valid: 2000\n", - "Test : 2000\n", + "Test: 2000\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "MAP_LABEL_TRANSLATION = {\n", + " 0: 'sadness',\n", + " 1: 'joy',\n", + " 2: 'love',\n", + " 3: 'anger',\n", + " 4: 'fear',\n", + " 5: 'surprise',\n", + "}" + ], + "metadata": { + "id": "w0KyM4TrGxQY" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "def save_as_translations(original_save_path: Path, data_to_save: List[Dict]) -> None:\n", + " file_name = 's2s-' + original_save_path.name\n", + " file_path = original_save_path.parent / file_name\n", + "\n", + " print(f'Saving into: {file_path}')\n", + " with open(file_path, 'wt') as f_write:\n", + " for data_line in data_to_save:\n", + " label = data_line['label']\n", + " new_label = MAP_LABEL_TRANSLATION[label]\n", + " data_line['label'] = new_label\n", + " data_line_str = json.dumps(data_line)\n", + " f_write.write(f'{data_line_str}\\n')" + ], + "metadata": { + "id": "-EFRYeAYHIKN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "for file_path, data_to_save in [(train_path, data_train), (valid_path, 
data_valid), (test_path, data_test)]:\n", + " print(f'Saving into: {file_path}')\n", + " with open(file_path, 'wt') as f_write:\n", + " for data_line in data_to_save:\n", + " data_line_str = json.dumps(data_line)\n", + " f_write.write(f'{data_line_str}\\n')\n", + " \n", + " save_as_translations(file_path, data_to_save)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "7RsrTNGCHIqc", + "outputId": "5cc59bc4-f71a-4b7b-ff27-f0f638f19fc9" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ "Saving into: data/train.json\n", "Saving into: data/s2s-train.json\n", "Saving into: data/valid.json\n", @@ -676,25 +744,31 @@ { "cell_type": "code", "source": [ - "\n", - "!head -n 4500 data/train.json > data/train-5k.json\n", - "!tail -n 2500 data/train.json >> data/train-5k.json\n", - "!wc -l data/train-5k.json" + "!head data/train.json" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, - "id": "pRmHIvyB0fZe", - "outputId": "a6f163f0-a393-431c-92e9-aaaf04601832" + "id": "Svu6YYSaHK4t", + "outputId": "3d90aaa5-7477-4d26-a1ce-8d830fe51178" }, - "execution_count": 5, + "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "7000 data/train-5k.json\n" + "{\"label\": 0, \"text\": \"i didnt feel humiliated\"}\n", + "{\"label\": 0, \"text\": \"i can go from feeling so hopeless to so damned hopeful just from being around someone who cares and is awake\"}\n", + "{\"label\": 3, \"text\": \"im grabbing a minute to post i feel greedy wrong\"}\n", + "{\"label\": 2, \"text\": \"i am ever feeling nostalgic about the fireplace i will know that it is still on the property\"}\n", + "{\"label\": 3, \"text\": \"i am feeling grouchy\"}\n", + "{\"label\": 0, \"text\": \"ive been feeling a little burdened lately wasnt sure why that was\"}\n", + "{\"label\": 5, \"text\": \"ive been taking or milligrams or times recommended amount and ive fallen asleep a lot faster but i also feel like so funny\"}\n", + "{\"label\": 4, \"text\": \"i feel as confused about life as a teenager or as jaded as a year old man\"}\n", + "{\"label\": 1, \"text\": \"i have been with petronas for years i feel that petronas has performed well and made a huge profit\"}\n", + "{\"label\": 2, \"text\": \"i feel romantic too\"}\n" ] } ] @@ -702,22 +776,53 @@ { "cell_type": "code", "source": [ - "from pathlib import Path\n", - "\n", - "for file_name in [\"train\", \"valid\", \"test\", \"s2s-train\", \"s2s-valid\", \"s2s-test\"]:\n", - " print(f\"=== {file_name} ===\")\n", - " all_text = Path(f\"data/{file_name}.json\").read_text().split('\\n')\n", - " text = all_text[:2500] + all_text[-2500:]\n", - " Path(f\"data/{file_name}-5k.json\").write_text(\"\\n\".join(text))" + "!head data/s2s-train.json" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, - "id": "rFa6ijdx2L28", - "outputId": "3303cd99-beba-4685-d8a0-80f819b1b50d" + "id": "5INZa4ZJHQbn", + "outputId": "12a2bbf0-fe51-4d63-de46-de63182657a9" }, - "execution_count": 6, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{\"label\": \"sadness\", \"text\": \"i didnt feel humiliated\"}\n", + "{\"label\": \"sadness\", \"text\": \"i can go from feeling so hopeless to so damned hopeful just from being around someone who cares and is awake\"}\n", + "{\"label\": \"anger\", \"text\": \"im grabbing a minute to post i feel greedy wrong\"}\n", + "{\"label\": \"love\", \"text\": \"i am ever 
feeling nostalgic about the fireplace i will know that it is still on the property\"}\n", + "{\"label\": \"anger\", \"text\": \"i am feeling grouchy\"}\n", + "{\"label\": \"sadness\", \"text\": \"ive been feeling a little burdened lately wasnt sure why that was\"}\n", + "{\"label\": \"surprise\", \"text\": \"ive been taking or milligrams or times recommended amount and ive fallen asleep a lot faster but i also feel like so funny\"}\n", + "{\"label\": \"fear\", \"text\": \"i feel as confused about life as a teenager or as jaded as a year old man\"}\n", + "{\"label\": \"joy\", \"text\": \"i have been with petronas for years i feel that petronas has performed well and made a huge profit\"}\n", + "{\"label\": \"love\", \"text\": \"i feel romantic too\"}\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# create tiny datasets for debugging purposes\n", + "for file_name in [\"train\", \"valid\", \"test\"]:\n", + " print(f\"=== {file_name} ===\")\n", + " all_text = Path(f\"data/{file_name}.json\").read_text().split('\\n')\n", + " text = all_text[:250] + all_text[-250:]\n", + " Path(f\"data/{file_name}-500.json\").write_text(\"\\n\".join(text))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "OYeI-JvepSf7", + "outputId": "9f2a4bf8-a8c5-4ffb-b3f1-b1fe1157d997" + }, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -725,10 +830,7 @@ "text": [ "=== train ===\n", "=== valid ===\n", - "=== test ===\n", - "=== s2s-train ===\n", - "=== s2s-valid ===\n", - "=== s2s-test ===\n" + "=== test ===\n" ] } ] @@ -736,21 +838,212 @@ { "cell_type": "code", "source": [ - "import os\n", - "\n", - "os.environ['TOKENIZERS_PARALLELISM'] = 'true'" + "!wc -l data/*" ], "metadata": { - "id": "8opbDvBv3ZlK" + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "_WSOgm50LI0m", + "outputId": "2d4df642-b657-4e00-9b3b-c1408f7beb40" }, - "execution_count": 7, - "outputs": [] + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " 2000 data/s2s-test.json\n", + " 16000 data/s2s-train.json\n", + " 2000 data/s2s-valid.json\n", + " 499 data/test-500.json\n", + " 2000 data/test.json\n", + " 499 data/train-500.json\n", + " 16000 data/train.json\n", + " 499 data/valid-500.json\n", + " 2000 data/valid.json\n", + " 41497 total\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# GPU Info" + ], + "metadata": { + "id": "b78jArQhN2Jb" + } }, { "cell_type": "code", - "source": [], + "source": [ + "!nvidia-smi" + ], "metadata": { - "id": "pxuxjHt8P57X" + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "TZk2ZwJML4Wz", + "outputId": "4fd092bf-813e-4e83-9b13-cf3d62baf56f" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Sun Feb 12 23:30:18 2023 \n", + "+-----------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 510.47.03 Driver Version: 510.47.03 CUDA Version: 11.6 |\n", + "|-------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. 
|\n", + "|===============================+======================+======================|\n", + "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", + "| N/A 56C P0 26W / 70W | 0MiB / 15360MiB | 0% Default |\n", + "| | | N/A |\n", + "+-------------------------------+----------------------+----------------------+\n", + " \n", + "+-----------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=============================================================================|\n", + "| No running processes found |\n", + "+-----------------------------------------------------------------------------+\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "os.environ['TOKENIZERS_PARALLELISM'] = 'true'" + ], + "metadata": { + "id": "e-ssYW1WL71Y" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Run" + ], + "metadata": { + "id": "gMK8qKF_dq5s" + } + }, + { + "cell_type": "code", + "source": [ + "!wget 'https://git.wmi.amu.edu.pl/s444465/projekt-glebokie/raw/branch/master/run_glue.py' -O 'run_glue.py'" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "f-NS2jDZdsMd", + "outputId": "445b069d-8628-4924-8d57-a0aaa0e8b964" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "--2023-02-12 23:30:18-- https://git.wmi.amu.edu.pl/s444465/projekt-glebokie/raw/branch/master/run_glue.py\n", + "Resolving git.wmi.amu.edu.pl (git.wmi.amu.edu.pl)... 150.254.78.40\n", + "Connecting to git.wmi.amu.edu.pl (git.wmi.amu.edu.pl)|150.254.78.40|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 30601 (30K) [text/plain]\n", + "Saving to: ‘run_glue.py’\n", + "\n", + "run_glue.py 100%[===================>] 29.88K --.-KB/s in 0.03s \n", + "\n", + "2023-02-12 23:30:18 (982 KB/s) - ‘run_glue.py’ saved [30601/30601]\n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "!wget 'https://git.wmi.amu.edu.pl/s444465/projekt-glebokie/raw/branch/master/roberta.py' -O 'roberta.py'" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "rdUCXArmhliH", + "outputId": "f01832ae-2206-4f47-ae10-e50ed0d71c45" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "--2023-02-12 23:30:18-- https://git.wmi.amu.edu.pl/s444465/projekt-glebokie/raw/branch/master/roberta.py\n", + "Resolving git.wmi.amu.edu.pl (git.wmi.amu.edu.pl)... 150.254.78.40\n", + "Connecting to git.wmi.amu.edu.pl (git.wmi.amu.edu.pl)|150.254.78.40|:443... connected.\n", + "HTTP request sent, awaiting response... 
200 OK\n", + "Length: 12783 (12K) [text/plain]\n", + "Saving to: ‘roberta.py’\n", + "\n", + "roberta.py 100%[===================>] 12.48K --.-KB/s in 0s \n", + "\n", + "2023-02-12 23:30:18 (263 MB/s) - ‘roberta.py’ saved [12783/12783]\n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "!wget 'https://git.wmi.amu.edu.pl/s444465/projekt-glebokie/raw/branch/master/gpt2.py' -O 'gpt2.py'" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "nw9Y56QukENR", + "outputId": "fe01c608-1dfe-4e4c-fd79-d6301fe0a2fe" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "--2023-02-12 23:30:18-- https://git.wmi.amu.edu.pl/s444465/projekt-glebokie/raw/branch/master/gpt2.py\n", + "Resolving git.wmi.amu.edu.pl (git.wmi.amu.edu.pl)... 150.254.78.40\n", + "Connecting to git.wmi.amu.edu.pl (git.wmi.amu.edu.pl)|150.254.78.40|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 8017 (7.8K) [text/plain]\n", + "Saving to: ‘gpt2.py’\n", + "\n", + "gpt2.py 100%[===================>] 7.83K --.-KB/s in 0s \n", + "\n", + "2023-02-12 23:30:19 (1.42 GB/s) - ‘gpt2.py’ saved [8017/8017]\n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "torch.cuda.empty_cache()" + ], + "metadata": { + "id": "2iIR3yh8dyPZ" }, "execution_count": null, "outputs": [] @@ -758,42 +1051,41 @@ { "cell_type": "code", "source": [ - "!python run_glue.py \\\n", - "--cache_dir .cache_training \\\n", - "--model_name_or_path gpt2 \\\n", - "--custom_model gpt2_hidden \\\n", - "--freeze_weights \\\n", - "--train_file data/s2s-train.json \\\n", - "--validation_file data/s2s-valid.json \\\n", - "--test_file data/s2s-test.json \\\n", - "--per_device_train_batch_size 24 \\\n", - "--per_device_eval_batch_size 24 \\\n", - "--do_train \\\n", - "--do_eval \\\n", - "--do_predict \\\n", - "--max_seq_length 128 \\\n", - "--learning_rate 2e-5 \\\n", - "--num_train_epochs 5 \\\n", - "--output_dir out/imdb-5k/gpt2" + "! 
python run_glue.py \\\n", + " --cache_dir .cache_training \\\n", + " --model_name_or_path gpt2 \\\n", + " --custom_model gpt2_hidden \\\n", + " --train_file data/train.json \\\n", + " --validation_file data/valid.json \\\n", + " --test_file data/test.json \\\n", + " --per_device_train_batch_size 8 \\\n", + " --per_device_eval_batch_size 8 \\\n", + " --do_train \\\n", + " --do_eval \\\n", + " --do_predict \\\n", + " --max_seq_length 128 \\\n", + " --num_train_epochs 1 \\\n", + " --metric_for_best_model accuracy \\\n", + " --greater_is_better True \\\n", + " --overwrite_output_dir \\\n", + " --output_dir out/emotion/gpt2" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, - "id": "XkkeRPG_z3Jc", - "outputId": "ffdacc37-5c06-401a-a588-a1d272dd72b0" + "id": "6KFVQFCqdyw6", + "outputId": "09b3934c-bc30-4349-e25a-af24544f86f3" }, - "execution_count": 8, + "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "2023-02-12 22:00:15.880386: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n", - "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", - "2023-02-12 22:00:16.771169: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/lib64-nvidia\n", - "2023-02-12 22:00:16.771276: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/lib64-nvidia\n", - "2023-02-12 22:00:16.771294: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n", + "2023-02-12 23:30:29.286531: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/lib64-nvidia\n", + "2023-02-12 23:30:29.287316: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/lib64-nvidia\n", + "2023-02-12 23:30:29.287348: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. 
If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n", "WARNING:__main__:Process rank: -1, device: cuda:0, n_gpu: 1distributed training: False, 16-bits training: False\n", "INFO:__main__:Training/evaluation parameters TrainingArguments(\n", "_n_gpu=1,\n", @@ -831,7 +1123,7 @@ "full_determinism=False,\n", "gradient_accumulation_steps=1,\n", "gradient_checkpointing=False,\n", - "greater_is_better=None,\n", + "greater_is_better=True,\n", "group_by_length=False,\n", "half_precision_backend=auto,\n", "hub_model_id=None,\n", @@ -843,14 +1135,14 @@ "jit_mode_eval=False,\n", "label_names=None,\n", "label_smoothing_factor=0.0,\n", - "learning_rate=2e-05,\n", + "learning_rate=5e-05,\n", "length_column_name=length,\n", "load_best_model_at_end=False,\n", "local_rank=-1,\n", "log_level=passive,\n", "log_level_replica=passive,\n", "log_on_each_node=True,\n", - "logging_dir=out/imdb-5k/gpt2/runs/Feb12_22-00-19_506c7abe63fb,\n", + "logging_dir=out/emotion/gpt2/runs/Feb12_23-30-34_2740c0a1a5dc,\n", "logging_first_step=False,\n", "logging_nan_inf_filter=True,\n", "logging_steps=500,\n", @@ -858,17 +1150,17 @@ "lr_scheduler_type=linear,\n", "max_grad_norm=1.0,\n", "max_steps=-1,\n", - "metric_for_best_model=None,\n", + "metric_for_best_model=accuracy,\n", "mp_parameters=,\n", "no_cuda=False,\n", - "num_train_epochs=5.0,\n", + "num_train_epochs=1.0,\n", "optim=adamw_hf,\n", "optim_args=None,\n", - "output_dir=out/imdb-5k/gpt2,\n", - "overwrite_output_dir=False,\n", + "output_dir=out/emotion/gpt2,\n", + "overwrite_output_dir=True,\n", "past_index=-1,\n", - "per_device_eval_batch_size=24,\n", - "per_device_train_batch_size=24,\n", + "per_device_eval_batch_size=8,\n", + "per_device_train_batch_size=8,\n", "prediction_loss_only=False,\n", "push_to_hub=False,\n", "push_to_hub_model_id=None,\n", @@ -878,7 +1170,7 @@ "remove_unused_columns=True,\n", "report_to=['tensorboard'],\n", "resume_from_checkpoint=None,\n", - "run_name=out/imdb-5k/gpt2,\n", + "run_name=out/emotion/gpt2,\n", "save_on_each_node=False,\n", "save_steps=500,\n", "save_strategy=steps,\n", @@ -901,27 +1193,26 @@ "weight_decay=0.0,\n", "xpu_backend=None,\n", ")\n", - "INFO:__main__:load a local file for train: data/s2s-train.json\n", - "INFO:__main__:load a local file for validation: data/s2s-valid.json\n", - "INFO:__main__:load a local file for test: data/s2s-test.json\n", - "WARNING:datasets.builder:Using custom data configuration default-623c8a7b15a2e58a\n", + "INFO:__main__:load a local file for train: data/train.json\n", + "INFO:__main__:load a local file for validation: data/valid.json\n", + "INFO:__main__:load a local file for test: data/test.json\n", + "WARNING:datasets.builder:Using custom data configuration default-79a9e082059ced07\n", "INFO:datasets.info:Loading Dataset Infos from /usr/local/lib/python3.8/dist-packages/datasets/packaged_modules/json\n", - "INFO:datasets.builder:Generating dataset json (/content/.cache_training/json/default-623c8a7b15a2e58a/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)\n", - "Downloading and preparing dataset json/default to /content/.cache_training/json/default-623c8a7b15a2e58a/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51...\n", - "Downloading data files: 100% 3/3 [00:00<00:00, 14873.42it/s]\n", + "INFO:datasets.builder:Generating dataset json (/content/.cache_training/json/default-79a9e082059ced07/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)\n", + 
"Downloading and preparing dataset json/default to /content/.cache_training/json/default-79a9e082059ced07/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51...\n", + "Downloading data files: 100% 3/3 [00:00<00:00, 12384.76it/s]\n", "INFO:datasets.download.download_manager:Downloading took 0.0 min\n", "INFO:datasets.download.download_manager:Checksum Computation took 0.0 min\n", - "Extracting data files: 100% 3/3 [00:00<00:00, 1763.55it/s]\n", + "Extracting data files: 100% 3/3 [00:00<00:00, 1936.13it/s]\n", "INFO:datasets.utils.info_utils:Unable to verify checksums.\n", "INFO:datasets.builder:Generating train split\n", "INFO:datasets.builder:Generating validation split\n", "INFO:datasets.builder:Generating test split\n", "INFO:datasets.utils.info_utils:Unable to verify splits sizes.\n", - "Dataset json downloaded and prepared to /content/.cache_training/json/default-623c8a7b15a2e58a/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51. Subsequent calls will reuse this data.\n", - "100% 3/3 [00:00<00:00, 1028.86it/s]\n", - "Downloading (…)lve/main/config.json: 100% 665/665 [00:00<00:00, 126kB/s]\n", - "[INFO|configuration_utils.py:660] 2023-02-12 22:00:20,342 >> loading configuration file config.json from cache at .cache_training/models--gpt2/snapshots/e7da7f221d5bf496a48136c0cd264e630fe9fcc8/config.json\n", - "[INFO|configuration_utils.py:712] 2023-02-12 22:00:20,343 >> Model config GPT2Config {\n", + "Dataset json downloaded and prepared to /content/.cache_training/json/default-79a9e082059ced07/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51. Subsequent calls will reuse this data.\n", + "100% 3/3 [00:00<00:00, 989.92it/s]\n", + "[INFO|configuration_utils.py:660] 2023-02-12 23:30:36,613 >> loading configuration file config.json from cache at .cache_training/models--gpt2/snapshots/e7da7f221d5bf496a48136c0cd264e630fe9fcc8/config.json\n", + "[INFO|configuration_utils.py:712] 2023-02-12 23:30:36,614 >> Model config GPT2Config {\n", " \"_name_or_path\": \"gpt2\",\n", " \"activation_function\": \"gelu_new\",\n", " \"architectures\": [\n", @@ -976,9 +1267,9 @@ " \"vocab_size\": 50257\n", "}\n", "\n", - "[INFO|tokenization_auto.py:458] 2023-02-12 22:00:20,434 >> Could not locate the tokenizer configuration file, will try to use the model config instead.\n", - "[INFO|configuration_utils.py:660] 2023-02-12 22:00:20,525 >> loading configuration file config.json from cache at .cache_training/models--gpt2/snapshots/e7da7f221d5bf496a48136c0cd264e630fe9fcc8/config.json\n", - "[INFO|configuration_utils.py:712] 2023-02-12 22:00:20,525 >> Model config GPT2Config {\n", + "[INFO|tokenization_auto.py:458] 2023-02-12 23:30:36,976 >> Could not locate the tokenizer configuration file, will try to use the model config instead.\n", + "[INFO|configuration_utils.py:660] 2023-02-12 23:30:37,341 >> loading configuration file config.json from cache at .cache_training/models--gpt2/snapshots/e7da7f221d5bf496a48136c0cd264e630fe9fcc8/config.json\n", + "[INFO|configuration_utils.py:712] 2023-02-12 23:30:37,342 >> Model config GPT2Config {\n", " \"_name_or_path\": \"gpt2\",\n", " \"activation_function\": \"gelu_new\",\n", " \"architectures\": [\n", @@ -1017,17 +1308,14 @@ " \"vocab_size\": 50257\n", "}\n", "\n", - "Downloading (…)olve/main/vocab.json: 100% 1.04M/1.04M [00:00<00:00, 9.19MB/s]\n", - "Downloading (…)olve/main/merges.txt: 100% 456k/456k [00:00<00:00, 4.93MB/s]\n", - "Downloading (…)/main/tokenizer.json: 100% 1.36M/1.36M [00:00<00:00, 11.9MB/s]\n", - 
"[INFO|tokenization_utils_base.py:1802] 2023-02-12 22:00:21,743 >> loading file vocab.json from cache at .cache_training/models--gpt2/snapshots/e7da7f221d5bf496a48136c0cd264e630fe9fcc8/vocab.json\n", - "[INFO|tokenization_utils_base.py:1802] 2023-02-12 22:00:21,743 >> loading file merges.txt from cache at .cache_training/models--gpt2/snapshots/e7da7f221d5bf496a48136c0cd264e630fe9fcc8/merges.txt\n", - "[INFO|tokenization_utils_base.py:1802] 2023-02-12 22:00:21,743 >> loading file tokenizer.json from cache at .cache_training/models--gpt2/snapshots/e7da7f221d5bf496a48136c0cd264e630fe9fcc8/tokenizer.json\n", - "[INFO|tokenization_utils_base.py:1802] 2023-02-12 22:00:21,744 >> loading file added_tokens.json from cache at None\n", - "[INFO|tokenization_utils_base.py:1802] 2023-02-12 22:00:21,744 >> loading file special_tokens_map.json from cache at None\n", - "[INFO|tokenization_utils_base.py:1802] 2023-02-12 22:00:21,744 >> loading file tokenizer_config.json from cache at None\n", - "[INFO|configuration_utils.py:660] 2023-02-12 22:00:21,744 >> loading configuration file config.json from cache at .cache_training/models--gpt2/snapshots/e7da7f221d5bf496a48136c0cd264e630fe9fcc8/config.json\n", - "[INFO|configuration_utils.py:712] 2023-02-12 22:00:21,745 >> Model config GPT2Config {\n", + "[INFO|tokenization_utils_base.py:1802] 2023-02-12 23:30:38,088 >> loading file vocab.json from cache at .cache_training/models--gpt2/snapshots/e7da7f221d5bf496a48136c0cd264e630fe9fcc8/vocab.json\n", + "[INFO|tokenization_utils_base.py:1802] 2023-02-12 23:30:38,088 >> loading file merges.txt from cache at .cache_training/models--gpt2/snapshots/e7da7f221d5bf496a48136c0cd264e630fe9fcc8/merges.txt\n", + "[INFO|tokenization_utils_base.py:1802] 2023-02-12 23:30:38,088 >> loading file tokenizer.json from cache at .cache_training/models--gpt2/snapshots/e7da7f221d5bf496a48136c0cd264e630fe9fcc8/tokenizer.json\n", + "[INFO|tokenization_utils_base.py:1802] 2023-02-12 23:30:38,089 >> loading file added_tokens.json from cache at None\n", + "[INFO|tokenization_utils_base.py:1802] 2023-02-12 23:30:38,089 >> loading file special_tokens_map.json from cache at None\n", + "[INFO|tokenization_utils_base.py:1802] 2023-02-12 23:30:38,089 >> loading file tokenizer_config.json from cache at None\n", + "[INFO|configuration_utils.py:660] 2023-02-12 23:30:38,089 >> loading configuration file config.json from cache at .cache_training/models--gpt2/snapshots/e7da7f221d5bf496a48136c0cd264e630fe9fcc8/config.json\n", + "[INFO|configuration_utils.py:712] 2023-02-12 23:30:38,090 >> Model config GPT2Config {\n", " \"_name_or_path\": \"gpt2\",\n", " \"activation_function\": \"gelu_new\",\n", " \"architectures\": [\n", @@ -1068,269 +1356,132 @@ "\n", "INFO:__main__:Using hidden states in model: True\n", "INFO:__main__:Using implementation from class: GPT2ForSequenceClassificationCustom\n", - "Downloading (…)\"pytorch_model.bin\";: 100% 548M/548M [00:05<00:00, 103MB/s]\n", - "[INFO|modeling_utils.py:2275] 2023-02-12 22:00:27,304 >> loading weights file pytorch_model.bin from cache at .cache_training/models--gpt2/snapshots/e7da7f221d5bf496a48136c0cd264e630fe9fcc8/pytorch_model.bin\n", - "[INFO|modeling_utils.py:2857] 2023-02-12 22:00:30,150 >> All model checkpoint weights were used when initializing GPT2ForSequenceClassificationCustom.\n", + "[INFO|modeling_utils.py:2275] 2023-02-12 23:30:38,214 >> loading weights file pytorch_model.bin from cache at .cache_training/models--gpt2/snapshots/e7da7f221d5bf496a48136c0cd264e630fe9fcc8/pytorch_model.bin\n", + 
"[INFO|modeling_utils.py:2857] 2023-02-12 23:30:43,108 >> All model checkpoint weights were used when initializing GPT2ForSequenceClassificationCustom.\n", "\n", - "[WARNING|modeling_utils.py:2859] 2023-02-12 22:00:30,150 >> Some weights of GPT2ForSequenceClassificationCustom were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.dense_1_hidden.weight', 'score.dense_2.weight', 'score.out_proj.weight', 'score.dense_2.bias', 'score.dense_1_hidden.bias', 'score.dense_1_input.bias', 'score.dense_1_input.weight']\n", + "[WARNING|modeling_utils.py:2859] 2023-02-12 23:30:43,108 >> Some weights of GPT2ForSequenceClassificationCustom were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.out_proj.weight', 'score.dense_1_hidden.bias', 'score.dense_3.bias', 'score.dense_1_input.weight', 'score.dense_1_hidden.weight', 'score.dense_3.weight', 'score.dense_1_input.bias', 'score.dense_2.weight', 'score.dense_2.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", - "INFO:__main__:Freezing encoder weights\n", - "INFO:__main__:Freezing layer 1\n", - "INFO:__main__:Freezing layer 2\n", - "INFO:__main__:Freezing layer 3\n", - "INFO:__main__:Freezing layer 4\n", - "INFO:__main__:Freezing layer 5\n", - "INFO:__main__:Freezing layer 6\n", - "INFO:__main__:Freezing layer 7\n", - "INFO:__main__:Freezing layer 8\n", - "INFO:__main__:Freezing layer 9\n", - "INFO:__main__:Freezing layer 10\n", - "INFO:__main__:Freezing layer 11\n", - "INFO:__main__:Freezing layer 12\n", - "INFO:__main__:Freezing layer 13\n", - "INFO:__main__:Freezing layer 14\n", - "INFO:__main__:Freezing layer 15\n", - "INFO:__main__:Freezing layer 16\n", - "INFO:__main__:Freezing layer 17\n", - "INFO:__main__:Freezing layer 18\n", - "INFO:__main__:Freezing layer 19\n", - "INFO:__main__:Freezing layer 20\n", - "INFO:__main__:Freezing layer 21\n", - "INFO:__main__:Freezing layer 22\n", - "INFO:__main__:Freezing layer 23\n", - "INFO:__main__:Freezing layer 24\n", - "INFO:__main__:Freezing layer 25\n", - "INFO:__main__:Freezing layer 26\n", - "INFO:__main__:Freezing layer 27\n", - "INFO:__main__:Freezing layer 28\n", - "INFO:__main__:Freezing layer 29\n", - "INFO:__main__:Freezing layer 30\n", - "INFO:__main__:Freezing layer 31\n", - "INFO:__main__:Freezing layer 32\n", - "INFO:__main__:Freezing layer 33\n", - "INFO:__main__:Freezing layer 34\n", - "INFO:__main__:Freezing layer 35\n", - "INFO:__main__:Freezing layer 36\n", - "INFO:__main__:Freezing layer 37\n", - "INFO:__main__:Freezing layer 38\n", - "INFO:__main__:Freezing layer 39\n", - "INFO:__main__:Freezing layer 40\n", - "INFO:__main__:Ignoring layer 41\n", - "INFO:__main__:Ignoring layer 42\n", - "INFO:__main__:Ignoring layer 43\n", - "INFO:__main__:Ignoring layer 44\n", - "INFO:__main__:Ignoring layer 45\n", - "INFO:__main__:Ignoring layer 46\n", - "INFO:__main__:Ignoring layer 47\n", - "INFO:__main__:Ignoring layer 48\n", - "INFO:__main__:Ignoring layer 49\n", - "INFO:__main__:Ignoring layer 50\n", - "INFO:__main__:Ignoring layer 51\n", - "INFO:__main__:Ignoring layer 52\n", - "INFO:__main__:Ignoring layer 53\n", - "INFO:__main__:Ignoring layer 54\n", - "INFO:__main__:Ignoring layer 55\n", - "INFO:__main__:Ignoring layer 56\n", - "INFO:__main__:Ignoring layer 57\n", - "INFO:__main__:Ignoring layer 58\n", - "INFO:__main__:Ignoring layer 59\n", - "INFO:__main__:Ignoring layer 60\n", - "INFO:__main__:Ignoring layer 61\n", - "INFO:__main__:Ignoring layer 
62\n", - "INFO:__main__:Ignoring layer 63\n", - "INFO:__main__:Ignoring layer 64\n", - "INFO:__main__:Ignoring layer 65\n", - "INFO:__main__:Ignoring layer 66\n", - "INFO:__main__:Ignoring layer 67\n", - "INFO:__main__:Ignoring layer 68\n", - "INFO:__main__:Ignoring layer 69\n", - "INFO:__main__:Ignoring layer 70\n", - "INFO:__main__:Ignoring layer 71\n", - "INFO:__main__:Ignoring layer 72\n", - "INFO:__main__:Ignoring layer 73\n", - "INFO:__main__:Ignoring layer 74\n", - "INFO:__main__:Ignoring layer 75\n", - "INFO:__main__:Ignoring layer 76\n", - "INFO:__main__:Ignoring layer 77\n", - "INFO:__main__:Ignoring layer 78\n", - "INFO:__main__:Ignoring layer 79\n", - "INFO:__main__:Ignoring layer 80\n", - "INFO:__main__:Ignoring layer 81\n", - "INFO:__main__:Ignoring layer 82\n", - "INFO:__main__:Ignoring layer 83\n", - "INFO:__main__:Ignoring layer 84\n", - "INFO:__main__:Ignoring layer 85\n", - "INFO:__main__:Ignoring layer 86\n", - "INFO:__main__:Ignoring layer 87\n", - "INFO:__main__:Ignoring layer 88\n", - "INFO:__main__:Ignoring layer 89\n", - "INFO:__main__:Ignoring layer 90\n", - "INFO:__main__:Ignoring layer 91\n", - "INFO:__main__:Ignoring layer 92\n", - "INFO:__main__:Ignoring layer 93\n", - "INFO:__main__:Ignoring layer 94\n", - "INFO:__main__:Ignoring layer 95\n", - "INFO:__main__:Ignoring layer 96\n", - "INFO:__main__:Ignoring layer 97\n", - "INFO:__main__:Ignoring layer 98\n", - "INFO:__main__:Ignoring layer 99\n", - "INFO:__main__:Ignoring layer 100\n", - "INFO:__main__:Ignoring layer 101\n", - "INFO:__main__:Ignoring layer 102\n", - "INFO:__main__:Ignoring layer 103\n", - "INFO:__main__:Ignoring layer 104\n", - "INFO:__main__:Ignoring layer 105\n", - "INFO:__main__:Ignoring layer 106\n", - "INFO:__main__:Ignoring layer 107\n", - "INFO:__main__:Ignoring layer 108\n", - "INFO:__main__:Ignoring layer 109\n", - "INFO:__main__:Ignoring layer 110\n", - "INFO:__main__:Ignoring layer 111\n", - "INFO:__main__:Ignoring layer 112\n", - "INFO:__main__:Ignoring layer 113\n", - "INFO:__main__:Ignoring layer 114\n", - "INFO:__main__:Ignoring layer 115\n", - "INFO:__main__:Ignoring layer 116\n", - "INFO:__main__:Ignoring layer 117\n", - "INFO:__main__:Ignoring layer 118\n", - "INFO:__main__:Ignoring layer 119\n", - "INFO:__main__:Ignoring layer 120\n", - "INFO:__main__:Ignoring layer 121\n", - "INFO:__main__:Ignoring layer 122\n", - "INFO:__main__:Ignoring layer 123\n", - "INFO:__main__:Ignoring layer 124\n", - "INFO:__main__:Ignoring layer 125\n", - "INFO:__main__:Ignoring layer 126\n", - "INFO:__main__:Ignoring layer 127\n", - "INFO:__main__:Ignoring layer 128\n", - "INFO:__main__:Ignoring layer 129\n", - "INFO:__main__:Ignoring layer 130\n", - "INFO:__main__:Ignoring layer 131\n", - "INFO:__main__:Ignoring layer 132\n", - "INFO:__main__:Ignoring layer 133\n", - "INFO:__main__:Ignoring layer 134\n", - "INFO:__main__:Ignoring layer 135\n", - "INFO:__main__:Ignoring layer 136\n", - "INFO:__main__:Ignoring layer 137\n", - "INFO:__main__:Ignoring layer 138\n", - "INFO:__main__:Ignoring layer 139\n", - "INFO:__main__:Ignoring layer 140\n", - "INFO:__main__:Ignoring layer 141\n", - "INFO:__main__:Ignoring layer 142\n", - "INFO:__main__:Ignoring layer 143\n", - "INFO:__main__:Ignoring layer 144\n", - "INFO:__main__:Ignoring layer 145\n", - "INFO:__main__:Ignoring layer 146\n", - "INFO:__main__:Ignoring layer 147\n", - "INFO:__main__:Ignoring layer 148\n", - "INFO:__main__:Ignoring layer 149\n", - "INFO:__main__:Ignoring layer 150\n", - "INFO:__main__:Ignoring layer 151\n", - 
"INFO:__main__:Ignoring layer 152\n", - "INFO:__main__:Ignoring layer 153\n", - "INFO:__main__:Ignoring layer 154\n", - "INFO:__main__:Ignoring layer 155\n", - "[ERROR|tokenization_utils_base.py:1042] 2023-02-12 22:00:30,162 >> Using pad_token, but it is not set yet.\n", + "[ERROR|tokenization_utils_base.py:1042] 2023-02-12 23:30:43,118 >> Using pad_token, but it is not set yet.\n", "INFO:__main__:Set PAD token to EOS: <|endoftext|>\n", - "Running tokenizer on dataset: 0% 0/16 [00:00> The following columns in the training set don't have a corresponding argument in `GPT2ForSequenceClassificationCustom.forward` and have been ignored: text. If text are not expected by `GPT2ForSequenceClassificationCustom.forward`, you can safely ignore this message.\n", + "Running tokenizer on dataset: 0% 0/16 [00:00> The following columns in the training set don't have a corresponding argument in `GPT2ForSequenceClassificationCustom.forward` and have been ignored: text. If text are not expected by `GPT2ForSequenceClassificationCustom.forward`, you can safely ignore this message.\n", "/usr/local/lib/python3.8/dist-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n", - "[INFO|trainer.py:1650] 2023-02-12 22:00:38,993 >> ***** Running training *****\n", - "[INFO|trainer.py:1651] 2023-02-12 22:00:38,993 >> Num examples = 16000\n", - "[INFO|trainer.py:1652] 2023-02-12 22:00:38,993 >> Num Epochs = 5\n", - "[INFO|trainer.py:1653] 2023-02-12 22:00:38,993 >> Instantaneous batch size per device = 24\n", - "[INFO|trainer.py:1654] 2023-02-12 22:00:38,993 >> Total train batch size (w. 
parallel, distributed & accumulation) = 24\n", - "[INFO|trainer.py:1655] 2023-02-12 22:00:38,993 >> Gradient Accumulation steps = 1\n", - "[INFO|trainer.py:1656] 2023-02-12 22:00:38,993 >> Total optimization steps = 3335\n", - "[INFO|trainer.py:1657] 2023-02-12 22:00:38,994 >> Number of trainable parameters = 68517888\n", - "{'loss': 1.0593, 'learning_rate': 1.7001499250374815e-05, 'epoch': 0.75}\n", - " 15% 500/3335 [04:16<24:04, 1.96it/s][INFO|trainer.py:2709] 2023-02-12 22:04:55,709 >> Saving model checkpoint to out/imdb-5k/gpt2/checkpoint-500\n", - "[INFO|configuration_utils.py:453] 2023-02-12 22:04:55,710 >> Configuration saved in out/imdb-5k/gpt2/checkpoint-500/config.json\n", - "[INFO|modeling_utils.py:1704] 2023-02-12 22:04:57,444 >> Model weights saved in out/imdb-5k/gpt2/checkpoint-500/pytorch_model.bin\n", - "[INFO|tokenization_utils_base.py:2160] 2023-02-12 22:04:57,444 >> tokenizer config file saved in out/imdb-5k/gpt2/checkpoint-500/tokenizer_config.json\n", - "[INFO|tokenization_utils_base.py:2167] 2023-02-12 22:04:57,444 >> Special tokens file saved in out/imdb-5k/gpt2/checkpoint-500/special_tokens_map.json\n", - "{'loss': 0.3829, 'learning_rate': 1.4002998500749626e-05, 'epoch': 1.5}\n", - " 30% 1000/3335 [08:36<19:51, 1.96it/s][INFO|trainer.py:2709] 2023-02-12 22:09:15,813 >> Saving model checkpoint to out/imdb-5k/gpt2/checkpoint-1000\n", - "[INFO|configuration_utils.py:453] 2023-02-12 22:09:15,814 >> Configuration saved in out/imdb-5k/gpt2/checkpoint-1000/config.json\n", - "[INFO|modeling_utils.py:1704] 2023-02-12 22:09:17,628 >> Model weights saved in out/imdb-5k/gpt2/checkpoint-1000/pytorch_model.bin\n", - "[INFO|tokenization_utils_base.py:2160] 2023-02-12 22:09:17,629 >> tokenizer config file saved in out/imdb-5k/gpt2/checkpoint-1000/tokenizer_config.json\n", - "[INFO|tokenization_utils_base.py:2167] 2023-02-12 22:09:17,630 >> Special tokens file saved in out/imdb-5k/gpt2/checkpoint-1000/special_tokens_map.json\n", - "{'loss': 0.256, 'learning_rate': 1.100449775112444e-05, 'epoch': 2.25}\n", - " 45% 1500/3335 [12:56<15:43, 1.95it/s][INFO|trainer.py:2709] 2023-02-12 22:13:36,008 >> Saving model checkpoint to out/imdb-5k/gpt2/checkpoint-1500\n", - "[INFO|configuration_utils.py:453] 2023-02-12 22:13:36,009 >> Configuration saved in out/imdb-5k/gpt2/checkpoint-1500/config.json\n", - "[INFO|modeling_utils.py:1704] 2023-02-12 22:13:37,703 >> Model weights saved in out/imdb-5k/gpt2/checkpoint-1500/pytorch_model.bin\n", - "[INFO|tokenization_utils_base.py:2160] 2023-02-12 22:13:37,704 >> tokenizer config file saved in out/imdb-5k/gpt2/checkpoint-1500/tokenizer_config.json\n", - "[INFO|tokenization_utils_base.py:2167] 2023-02-12 22:13:37,704 >> Special tokens file saved in out/imdb-5k/gpt2/checkpoint-1500/special_tokens_map.json\n", - "{'loss': 0.2101, 'learning_rate': 8.005997001499251e-06, 'epoch': 3.0}\n", - " 60% 2000/3335 [17:17<11:23, 1.95it/s][INFO|trainer.py:2709] 2023-02-12 22:17:56,308 >> Saving model checkpoint to out/imdb-5k/gpt2/checkpoint-2000\n", - "[INFO|configuration_utils.py:453] 2023-02-12 22:17:56,309 >> Configuration saved in out/imdb-5k/gpt2/checkpoint-2000/config.json\n", - "[INFO|modeling_utils.py:1704] 2023-02-12 22:17:58,005 >> Model weights saved in out/imdb-5k/gpt2/checkpoint-2000/pytorch_model.bin\n", - "[INFO|tokenization_utils_base.py:2160] 2023-02-12 22:17:58,006 >> tokenizer config file saved in out/imdb-5k/gpt2/checkpoint-2000/tokenizer_config.json\n", - "[INFO|tokenization_utils_base.py:2167] 2023-02-12 22:17:58,006 >> Special tokens file 
saved in out/imdb-5k/gpt2/checkpoint-2000/special_tokens_map.json\n", - "{'loss': 0.17, 'learning_rate': 5.0074962518740634e-06, 'epoch': 3.75}\n", - " 75% 2500/3335 [21:37<07:05, 1.96it/s][INFO|trainer.py:2709] 2023-02-12 22:22:16,686 >> Saving model checkpoint to out/imdb-5k/gpt2/checkpoint-2500\n", - "[INFO|configuration_utils.py:453] 2023-02-12 22:22:16,687 >> Configuration saved in out/imdb-5k/gpt2/checkpoint-2500/config.json\n", - "[INFO|modeling_utils.py:1704] 2023-02-12 22:22:18,356 >> Model weights saved in out/imdb-5k/gpt2/checkpoint-2500/pytorch_model.bin\n", - "[INFO|tokenization_utils_base.py:2160] 2023-02-12 22:22:18,357 >> tokenizer config file saved in out/imdb-5k/gpt2/checkpoint-2500/tokenizer_config.json\n", - "[INFO|tokenization_utils_base.py:2167] 2023-02-12 22:22:18,357 >> Special tokens file saved in out/imdb-5k/gpt2/checkpoint-2500/special_tokens_map.json\n", - "{'loss': 0.1569, 'learning_rate': 2.008995502248876e-06, 'epoch': 4.5}\n", - " 90% 3000/3335 [25:57<02:51, 1.95it/s][INFO|trainer.py:2709] 2023-02-12 22:26:36,938 >> Saving model checkpoint to out/imdb-5k/gpt2/checkpoint-3000\n", - "[INFO|configuration_utils.py:453] 2023-02-12 22:26:36,939 >> Configuration saved in out/imdb-5k/gpt2/checkpoint-3000/config.json\n", - "[INFO|modeling_utils.py:1704] 2023-02-12 22:26:38,608 >> Model weights saved in out/imdb-5k/gpt2/checkpoint-3000/pytorch_model.bin\n", - "[INFO|tokenization_utils_base.py:2160] 2023-02-12 22:26:38,608 >> tokenizer config file saved in out/imdb-5k/gpt2/checkpoint-3000/tokenizer_config.json\n", - "[INFO|tokenization_utils_base.py:2167] 2023-02-12 22:26:38,608 >> Special tokens file saved in out/imdb-5k/gpt2/checkpoint-3000/special_tokens_map.json\n", - "100% 3335/3335 [28:53<00:00, 2.15it/s][INFO|trainer.py:1901] 2023-02-12 22:29:32,259 >> \n", + "[INFO|trainer.py:1650] 2023-02-12 23:30:52,252 >> ***** Running training *****\n", + "[INFO|trainer.py:1651] 2023-02-12 23:30:52,253 >> Num examples = 16000\n", + "[INFO|trainer.py:1652] 2023-02-12 23:30:52,253 >> Num Epochs = 1\n", + "[INFO|trainer.py:1653] 2023-02-12 23:30:52,253 >> Instantaneous batch size per device = 8\n", + "[INFO|trainer.py:1654] 2023-02-12 23:30:52,253 >> Total train batch size (w. 
parallel, distributed & accumulation) = 8\n", + "[INFO|trainer.py:1655] 2023-02-12 23:30:52,253 >> Gradient Accumulation steps = 1\n", + "[INFO|trainer.py:1656] 2023-02-12 23:30:52,253 >> Total optimization steps = 2000\n", + "[INFO|trainer.py:1657] 2023-02-12 23:30:52,254 >> Number of trainable parameters = 137425920\n", + "{'loss': 0.9449, 'learning_rate': 3.7500000000000003e-05, 'epoch': 0.25}\n", + " 25% 500/2000 [02:07<06:07, 4.08it/s][INFO|trainer.py:2709] 2023-02-12 23:32:59,613 >> Saving model checkpoint to out/emotion/gpt2/checkpoint-500\n", + "[INFO|configuration_utils.py:453] 2023-02-12 23:32:59,615 >> Configuration saved in out/emotion/gpt2/checkpoint-500/config.json\n", + "[INFO|modeling_utils.py:1704] 2023-02-12 23:33:01,554 >> Model weights saved in out/emotion/gpt2/checkpoint-500/pytorch_model.bin\n", + "[INFO|tokenization_utils_base.py:2160] 2023-02-12 23:33:01,555 >> tokenizer config file saved in out/emotion/gpt2/checkpoint-500/tokenizer_config.json\n", + "[INFO|tokenization_utils_base.py:2167] 2023-02-12 23:33:01,555 >> Special tokens file saved in out/emotion/gpt2/checkpoint-500/special_tokens_map.json\n", + "{'loss': 0.3705, 'learning_rate': 2.5e-05, 'epoch': 0.5}\n", + " 50% 1000/2000 [04:17<04:09, 4.01it/s][INFO|trainer.py:2709] 2023-02-12 23:35:09,781 >> Saving model checkpoint to out/emotion/gpt2/checkpoint-1000\n", + "[INFO|configuration_utils.py:453] 2023-02-12 23:35:09,783 >> Configuration saved in out/emotion/gpt2/checkpoint-1000/config.json\n", + "[INFO|modeling_utils.py:1704] 2023-02-12 23:35:11,881 >> Model weights saved in out/emotion/gpt2/checkpoint-1000/pytorch_model.bin\n", + "[INFO|tokenization_utils_base.py:2160] 2023-02-12 23:35:11,882 >> tokenizer config file saved in out/emotion/gpt2/checkpoint-1000/tokenizer_config.json\n", + "[INFO|tokenization_utils_base.py:2167] 2023-02-12 23:35:11,882 >> Special tokens file saved in out/emotion/gpt2/checkpoint-1000/special_tokens_map.json\n", + "{'loss': 0.264, 'learning_rate': 1.25e-05, 'epoch': 0.75}\n", + " 75% 1500/2000 [06:27<02:03, 4.06it/s][INFO|trainer.py:2709] 2023-02-12 23:37:20,141 >> Saving model checkpoint to out/emotion/gpt2/checkpoint-1500\n", + "[INFO|configuration_utils.py:453] 2023-02-12 23:37:20,142 >> Configuration saved in out/emotion/gpt2/checkpoint-1500/config.json\n", + "[INFO|modeling_utils.py:1704] 2023-02-12 23:37:22,060 >> Model weights saved in out/emotion/gpt2/checkpoint-1500/pytorch_model.bin\n", + "[INFO|tokenization_utils_base.py:2160] 2023-02-12 23:37:22,061 >> tokenizer config file saved in out/emotion/gpt2/checkpoint-1500/tokenizer_config.json\n", + "[INFO|tokenization_utils_base.py:2167] 2023-02-12 23:37:22,061 >> Special tokens file saved in out/emotion/gpt2/checkpoint-1500/special_tokens_map.json\n", + "{'loss': 0.2223, 'learning_rate': 0.0, 'epoch': 1.0}\n", + "100% 2000/2000 [08:38<00:00, 4.06it/s][INFO|trainer.py:2709] 2023-02-12 23:39:30,550 >> Saving model checkpoint to out/emotion/gpt2/checkpoint-2000\n", + "[INFO|configuration_utils.py:453] 2023-02-12 23:39:30,551 >> Configuration saved in out/emotion/gpt2/checkpoint-2000/config.json\n", + "[INFO|modeling_utils.py:1704] 2023-02-12 23:39:32,522 >> Model weights saved in out/emotion/gpt2/checkpoint-2000/pytorch_model.bin\n", + "[INFO|tokenization_utils_base.py:2160] 2023-02-12 23:39:32,523 >> tokenizer config file saved in out/emotion/gpt2/checkpoint-2000/tokenizer_config.json\n", + "[INFO|tokenization_utils_base.py:2167] 2023-02-12 23:39:32,524 >> Special tokens file saved in 
out/emotion/gpt2/checkpoint-2000/special_tokens_map.json\n", + "[INFO|trainer.py:1901] 2023-02-12 23:39:36,929 >> \n", "\n", "Training completed. Do not forget to share your model on huggingface.co/models =)\n", "\n", "\n", - "{'train_runtime': 1733.281, 'train_samples_per_second': 46.155, 'train_steps_per_second': 1.924, 'train_loss': 0.35007504373118614, 'epoch': 5.0}\n", - "100% 3335/3335 [28:53<00:00, 1.92it/s]\n", - "[INFO|trainer.py:2709] 2023-02-12 22:29:32,277 >> Saving model checkpoint to out/imdb-5k/gpt2\n", - "[INFO|configuration_utils.py:453] 2023-02-12 22:29:32,278 >> Configuration saved in out/imdb-5k/gpt2/config.json\n", - "[INFO|modeling_utils.py:1704] 2023-02-12 22:29:33,934 >> Model weights saved in out/imdb-5k/gpt2/pytorch_model.bin\n", - "[INFO|tokenization_utils_base.py:2160] 2023-02-12 22:29:33,934 >> tokenizer config file saved in out/imdb-5k/gpt2/tokenizer_config.json\n", - "[INFO|tokenization_utils_base.py:2167] 2023-02-12 22:29:33,934 >> Special tokens file saved in out/imdb-5k/gpt2/special_tokens_map.json\n", + "{'train_runtime': 524.6759, 'train_samples_per_second': 30.495, 'train_steps_per_second': 3.812, 'train_loss': 0.4504347610473633, 'epoch': 1.0}\n", + "100% 2000/2000 [08:44<00:00, 3.81it/s]\n", + "[INFO|trainer.py:2709] 2023-02-12 23:39:36,932 >> Saving model checkpoint to out/emotion/gpt2\n", + "[INFO|configuration_utils.py:453] 2023-02-12 23:39:36,934 >> Configuration saved in out/emotion/gpt2/config.json\n", + "[INFO|modeling_utils.py:1704] 2023-02-12 23:39:39,121 >> Model weights saved in out/emotion/gpt2/pytorch_model.bin\n", + "[INFO|tokenization_utils_base.py:2160] 2023-02-12 23:39:39,122 >> tokenizer config file saved in out/emotion/gpt2/tokenizer_config.json\n", + "[INFO|tokenization_utils_base.py:2167] 2023-02-12 23:39:39,122 >> Special tokens file saved in out/emotion/gpt2/special_tokens_map.json\n", "***** train metrics *****\n", - " epoch = 5.0\n", - " train_loss = 0.3501\n", - " train_runtime = 0:28:53.28\n", + " epoch = 1.0\n", + " train_loss = 0.4504\n", + " train_runtime = 0:08:44.67\n", " train_samples = 16000\n", - " train_samples_per_second = 46.155\n", - " train_steps_per_second = 1.924\n", + " train_samples_per_second = 30.495\n", + " train_steps_per_second = 3.812\n", "INFO:__main__:*** Evaluate ***\n", - "[INFO|trainer.py:710] 2023-02-12 22:29:34,047 >> The following columns in the evaluation set don't have a corresponding argument in `GPT2ForSequenceClassificationCustom.forward` and have been ignored: text. If text are not expected by `GPT2ForSequenceClassificationCustom.forward`, you can safely ignore this message.\n", - "[INFO|trainer.py:2964] 2023-02-12 22:29:34,108 >> ***** Running Evaluation *****\n", - "[INFO|trainer.py:2966] 2023-02-12 22:29:34,108 >> Num examples = 2000\n", - "[INFO|trainer.py:2969] 2023-02-12 22:29:34,108 >> Batch size = 24\n", - "100% 84/84 [00:17<00:00, 4.83it/s]\n", + "[INFO|trainer.py:710] 2023-02-12 23:39:39,296 >> The following columns in the evaluation set don't have a corresponding argument in `GPT2ForSequenceClassificationCustom.forward` and have been ignored: text. 
If text are not expected by `GPT2ForSequenceClassificationCustom.forward`, you can safely ignore this message.\n", + "[INFO|trainer.py:2964] 2023-02-12 23:39:39,300 >> ***** Running Evaluation *****\n", + "[INFO|trainer.py:2966] 2023-02-12 23:39:39,301 >> Num examples = 2000\n", + "[INFO|trainer.py:2969] 2023-02-12 23:39:39,301 >> Batch size = 8\n", + "100% 250/250 [00:16<00:00, 14.71it/s]\n", "***** eval metrics *****\n", - " epoch = 5.0\n", - " eval_accuracy = 0.93\n", - " eval_loss = 0.1531\n", - " eval_runtime = 0:00:17.72\n", + " epoch = 1.0\n", + " eval_accuracy = 0.9355\n", + " eval_loss = 0.1925\n", + " eval_runtime = 0:00:17.11\n", " eval_samples = 2000\n", - " eval_samples_per_second = 112.855\n", - " eval_steps_per_second = 4.74\n", + " eval_samples_per_second = 116.846\n", + " eval_steps_per_second = 14.606\n", "INFO:__main__:*** Predict ***\n", - "[INFO|trainer.py:710] 2023-02-12 22:29:51,834 >> The following columns in the test set don't have a corresponding argument in `GPT2ForSequenceClassificationCustom.forward` and have been ignored: text. If text are not expected by `GPT2ForSequenceClassificationCustom.forward`, you can safely ignore this message.\n", - "[INFO|trainer.py:2964] 2023-02-12 22:29:51,836 >> ***** Running Prediction *****\n", - "[INFO|trainer.py:2966] 2023-02-12 22:29:51,836 >> Num examples = 2000\n", - "[INFO|trainer.py:2969] 2023-02-12 22:29:51,836 >> Batch size = 24\n", - "100% 84/84 [00:17<00:00, 4.81it/s]\n", + "[INFO|trainer.py:710] 2023-02-12 23:39:56,431 >> The following columns in the test set don't have a corresponding argument in `GPT2ForSequenceClassificationCustom.forward` and have been ignored: text. If text are not expected by `GPT2ForSequenceClassificationCustom.forward`, you can safely ignore this message.\n", + "[INFO|trainer.py:2964] 2023-02-12 23:39:56,432 >> ***** Running Prediction *****\n", + "[INFO|trainer.py:2966] 2023-02-12 23:39:56,433 >> Num examples = 2000\n", + "[INFO|trainer.py:2969] 2023-02-12 23:39:56,433 >> Batch size = 8\n", + "100% 250/250 [00:17<00:00, 14.46it/s]\n", "INFO:__main__:***** Predict results None *****\n", - "[INFO|modelcard.py:449] 2023-02-12 22:30:09,657 >> Dropping the following result as it does not have all the necessary fields:\n", - "{'task': {'name': 'Text Classification', 'type': 'text-classification'}, 'metrics': [{'name': 'Accuracy', 'type': 'accuracy', 'value': 0.9300000071525574}]}\n" + "[INFO|modelcard.py:449] 2023-02-12 23:40:14,252 >> Dropping the following result as it does not have all the necessary fields:\n", + "{'task': {'name': 'Text Classification', 'type': 'text-classification'}, 'metrics': [{'name': 'Accuracy', 'type': 'accuracy', 'value': 0.9355000257492065}]}\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Save model" + ], + "metadata": { + "id": "L55P7rx6nYE2" + } + }, + { + "cell_type": "code", + "source": [ + "drive.mount('/content/drive')\n", + "!cp -r /content/out/emotion /content/drive/MyDrive/models" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "QuuflS4qnZiw", + "outputId": "39ad5b6f-9019-49dc-a517-1e224d51a0bb" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n" ] } ] diff --git a/README.md b/README.md index 9355880..5bede05 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,8 @@ # Transformer Decoder - GPT-2 ## Modyfikacje 
-1. Zamrożenie pierwszych 40 warstw -2. Zmiana głowy klasyfikacyjnej poprzez dodanie po 2 warstwy dropout i relu() +1. Dodanie dodatkowej warstwy Linear do głowy +2. Wykorzystanie ukrytych stanów z t ostatnich warstw # Transformer Encoder-Decoder - T5 diff --git a/models/gpt2/README.md b/models/gpt2/README.md index 52c16b9..fdbd1d3 100644 --- a/models/gpt2/README.md +++ b/models/gpt2/README.md @@ -16,8 +16,8 @@ should probably proofread and complete it, then remove this comment. --> This model is a fine-tuned version of [gpt2](https://huggingface.co/gpt2) on an unknown dataset. It achieves the following results on the evaluation set: -- Loss: 0.2178 -- Accuracy: 0.9231 +- Loss: 0.1925 +- Accuracy: 0.9355 ## Model description @@ -36,13 +36,13 @@ More information needed ### Training hyperparameters The following hyperparameters were used during training: -- learning_rate: 2e-05 -- train_batch_size: 24 -- eval_batch_size: 24 +- learning_rate: 5e-05 +- train_batch_size: 8 +- eval_batch_size: 8 - seed: 42 - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08 - lr_scheduler_type: linear -- num_epochs: 5.0 +- num_epochs: 1.0 ### Training results diff --git a/models/gpt2/all_results.json b/models/gpt2/all_results.json index 4fde65d..71d6992 100644 --- a/models/gpt2/all_results.json +++ b/models/gpt2/all_results.json @@ -1,14 +1,14 @@ { - "epoch": 5.0, - "eval_accuracy": 0.9230769276618958, - "eval_loss": 0.2177695333957672, - "eval_runtime": 10.0539, - "eval_samples": 1274, - "eval_samples_per_second": 126.717, - "eval_steps_per_second": 5.371, - "train_loss": 0.689463275015069, - "train_runtime": 490.8844, - "train_samples": 4999, - "train_samples_per_second": 50.918, - "train_steps_per_second": 2.129 + "epoch": 1.0, + "eval_accuracy": 0.9355000257492065, + "eval_loss": 0.19254431128501892, + "eval_runtime": 17.1165, + "eval_samples": 2000, + "eval_samples_per_second": 116.846, + "eval_steps_per_second": 14.606, + "train_loss": 0.4504347610473633, + "train_runtime": 524.6759, + "train_samples": 16000, + "train_samples_per_second": 30.495, + "train_steps_per_second": 3.812 } \ No newline at end of file diff --git a/models/gpt2/eval_results.json b/models/gpt2/eval_results.json index 82a29d4..44770be 100644 --- a/models/gpt2/eval_results.json +++ b/models/gpt2/eval_results.json @@ -1,9 +1,9 @@ { - "epoch": 5.0, - "eval_accuracy": 0.9230769276618958, - "eval_loss": 0.2177695333957672, - "eval_runtime": 10.0539, - "eval_samples": 1274, - "eval_samples_per_second": 126.717, - "eval_steps_per_second": 5.371 + "epoch": 1.0, + "eval_accuracy": 0.9355000257492065, + "eval_loss": 0.19254431128501892, + "eval_runtime": 17.1165, + "eval_samples": 2000, + "eval_samples_per_second": 116.846, + "eval_steps_per_second": 14.606 } \ No newline at end of file diff --git a/models/gpt2/predict_results_None.txt b/models/gpt2/predict_results_None.txt index affc69c..1d58b2f 100644 --- a/models/gpt2/predict_results_None.txt +++ b/models/gpt2/predict_results_None.txt @@ -2,1278 +2,2000 @@ index prediction 0 0 1 0 2 0 -3 0 +3 1 4 0 -5 0 -6 0 -7 0 +5 4 +6 3 +7 1 8 1 -9 0 -10 4 +9 3 +10 3 11 0 -12 0 -13 3 -14 0 +12 4 +13 1 +14 2 15 0 -16 0 +16 1 17 0 -18 0 -19 0 +18 3 +19 1 20 0 -21 0 -22 0 +21 1 +22 1 23 0 24 0 -25 0 -26 0 +25 4 +26 3 27 0 -28 0 -29 0 -30 0 -31 0 +28 4 +29 3 +30 4 +31 3 32 0 -33 0 +33 3 34 0 -35 0 -36 0 +35 1 +36 1 37 0 -38 0 -39 0 -40 0 +38 1 +39 1 +40 3 41 0 -42 0 +42 1 43 0 44 1 -45 0 -46 0 +45 3 +46 1 47 1 -48 0 +48 4 49 4 50 0 -51 0 -52 0 +51 4 +52 1 53 0 -54 0 -55 4 +54 1 +55 0 56 0 -57 0 +57 1 58 
0 -59 0 -60 1 +59 3 +60 0 61 0 -62 0 -63 0 +62 1 +63 1 64 0 -65 0 +65 5 66 0 67 0 -68 0 -69 0 -70 0 -71 0 -72 0 -73 0 -74 0 -75 0 -76 0 +68 4 +69 5 +70 1 +71 2 +72 4 +73 1 +74 2 +75 3 +76 1 77 0 -78 0 -79 0 -80 0 -81 0 +78 1 +79 2 +80 1 +81 3 82 0 -83 0 +83 1 84 0 85 0 -86 0 -87 0 -88 0 +86 2 +87 1 +88 1 89 0 -90 0 -91 3 -92 0 -93 0 -94 0 -95 0 -96 0 -97 4 +90 1 +91 4 +92 3 +93 4 +94 4 +95 3 +96 2 +97 0 98 3 99 0 100 0 -101 4 +101 0 102 0 -103 0 -104 0 -105 0 -106 4 -107 0 -108 0 +103 3 +104 3 +105 3 +106 1 +107 1 +108 4 109 0 -110 0 -111 0 -112 0 -113 0 +110 1 +111 2 +112 4 +113 1 114 0 -115 0 -116 0 -117 0 +115 1 +116 1 +117 4 118 0 119 0 -120 0 -121 0 +120 3 +121 1 122 0 -123 0 -124 1 -125 0 +123 3 +124 0 +125 2 126 0 -127 2 +127 4 128 0 129 0 -130 0 -131 0 +130 1 +131 2 132 0 -133 0 -134 0 -135 0 -136 0 -137 0 +133 3 +134 3 +135 1 +136 4 +137 4 138 0 -139 0 -140 0 +139 1 +140 1 141 0 -142 0 -143 0 -144 0 +142 4 +143 1 +144 1 145 0 -146 0 -147 0 +146 1 +147 4 148 4 -149 0 -150 0 -151 0 -152 0 +149 2 +150 3 +151 2 +152 4 153 0 -154 0 +154 1 155 0 -156 0 -157 0 -158 0 +156 1 +157 1 +158 3 159 0 -160 0 -161 0 -162 0 -163 0 -164 0 -165 0 -166 0 -167 0 -168 0 +160 3 +161 3 +162 1 +163 4 +164 4 +165 1 +166 2 +167 2 +168 2 169 0 -170 0 -171 0 -172 0 -173 0 +170 2 +171 3 +172 1 +173 1 174 0 -175 0 -176 0 -177 0 +175 3 +176 1 +177 1 178 0 -179 4 -180 0 -181 0 +179 0 +180 4 +181 1 182 0 -183 0 -184 0 -185 4 -186 0 -187 0 -188 0 -189 0 -190 0 -191 4 -192 0 +183 1 +184 4 +185 0 +186 1 +187 1 +188 4 +189 3 +190 1 +191 0 +192 1 193 0 -194 3 +194 0 195 0 -196 0 -197 0 -198 0 -199 3 -200 0 -201 0 -202 0 -203 0 -204 0 -205 0 -206 0 -207 0 -208 0 -209 0 -210 0 +196 4 +197 4 +198 1 +199 1 +200 1 +201 2 +202 2 +203 1 +204 1 +205 1 +206 2 +207 4 +208 4 +209 1 +210 3 211 0 -212 0 -213 0 -214 0 +212 1 +213 4 +214 1 215 0 -216 0 -217 1 -218 0 -219 0 -220 0 -221 0 -222 1 -223 0 -224 0 -225 0 -226 0 -227 0 -228 0 -229 3 +216 3 +217 0 +218 3 +219 3 +220 1 +221 4 +222 5 +223 1 +224 1 +225 1 +226 3 +227 1 +228 2 +229 4 230 0 231 0 -232 0 -233 0 +232 5 +233 1 234 0 -235 0 +235 1 236 0 -237 0 +237 1 238 0 -239 3 +239 1 240 0 -241 0 -242 0 -243 0 +241 2 +242 4 +243 1 244 0 245 0 246 0 -247 0 -248 0 -249 0 +247 3 +248 1 +249 5 250 0 251 0 252 0 -253 0 -254 0 +253 1 +254 1 255 0 -256 3 -257 0 +256 1 +257 1 258 0 -259 0 -260 0 -261 0 -262 0 -263 0 -264 0 -265 0 +259 3 +260 3 +261 3 +262 1 +263 3 +264 3 +265 1 266 1 -267 0 -268 0 +267 1 +268 3 269 0 -270 0 -271 0 -272 0 -273 3 -274 0 +270 3 +271 3 +272 4 +273 0 +274 3 275 0 -276 0 -277 0 -278 0 +276 1 +277 4 +278 3 279 0 -280 0 +280 1 281 0 282 0 283 1 -284 0 -285 0 -286 0 +284 1 +285 1 +286 2 287 0 -288 0 -289 0 +288 1 +289 1 290 0 -291 1 -292 1 -293 1 -294 1 -295 1 -296 1 -297 1 -298 1 -299 1 +291 4 +292 2 +293 0 +294 2 +295 4 +296 3 +297 0 +298 0 +299 3 300 1 -301 1 -302 1 -303 1 -304 1 -305 1 -306 1 -307 1 -308 1 -309 3 +301 3 +302 2 +303 2 +304 2 +305 5 +306 3 +307 2 +308 3 +309 4 310 1 -311 1 -312 1 +311 0 +312 4 313 1 -314 1 -315 1 -316 1 -317 1 -318 1 -319 1 -320 1 -321 1 -322 1 -323 2 -324 1 -325 1 +314 3 +315 3 +316 2 +317 5 +318 3 +319 5 +320 0 +321 0 +322 0 +323 1 +324 0 +325 0 326 1 -327 1 -328 2 -329 1 -330 1 -331 1 -332 1 -333 1 -334 1 -335 1 -336 1 +327 3 +328 0 +329 0 +330 4 +331 0 +332 0 +333 2 +334 0 +335 2 +336 0 337 1 338 1 -339 1 -340 1 -341 1 -342 1 +339 0 +340 3 +341 2 +342 5 343 1 344 1 -345 2 +345 1 346 1 -347 1 -348 1 +347 4 +348 4 349 1 -350 1 -351 1 -352 1 -353 1 +350 0 +351 0 +352 0 +353 2 354 1 -355 1 -356 1 -357 1 -358 1 +355 2 +356 2 +357 
0 +358 0 359 1 -360 1 -361 1 -362 1 -363 1 +360 3 +361 0 +362 0 +363 3 364 1 365 1 -366 1 -367 1 -368 1 +366 2 +367 0 +368 3 369 2 -370 1 +370 0 371 1 372 1 -373 0 +373 3 374 1 -375 1 -376 1 -377 1 -378 1 -379 1 -380 1 -381 1 -382 1 -383 1 +375 0 +376 2 +377 0 +378 3 +379 0 +380 4 +381 3 +382 5 +383 4 384 1 -385 1 -386 1 -387 1 +385 3 +386 3 +387 4 388 1 -389 1 -390 1 -391 1 -392 1 -393 1 -394 1 -395 2 -396 1 +389 0 +390 0 +391 2 +392 0 +393 0 +394 4 +395 0 +396 3 397 1 -398 1 -399 1 -400 1 +398 0 +399 4 +400 4 401 1 -402 1 -403 1 -404 1 +402 5 +403 0 +404 2 405 1 406 1 -407 1 -408 1 -409 1 -410 1 -411 1 +407 0 +408 2 +409 0 +410 4 +411 3 412 1 413 1 -414 1 -415 1 -416 1 -417 1 -418 1 -419 1 +414 3 +415 5 +416 4 +417 0 +418 0 +419 0 420 1 -421 1 -422 1 +421 4 +422 4 423 1 -424 1 +424 0 425 1 -426 1 +426 4 427 1 -428 1 -429 1 -430 1 -431 1 -432 1 -433 1 -434 1 -435 1 -436 1 -437 1 -438 4 -439 1 +428 3 +429 4 +430 3 +431 0 +432 4 +433 4 +434 4 +435 0 +436 2 +437 3 +438 3 +439 0 440 1 -441 1 -442 1 -443 1 -444 1 +441 5 +442 3 +443 0 +444 4 445 1 -446 1 +446 0 447 0 448 1 -449 1 -450 1 -451 1 -452 1 -453 3 -454 1 -455 3 +449 3 +450 0 +451 3 +452 0 +453 0 +454 4 +455 0 456 1 -457 1 +457 4 458 1 -459 1 -460 1 +459 3 +460 5 461 1 -462 1 -463 1 -464 1 -465 4 -466 1 -467 1 +462 0 +463 0 +464 0 +465 1 +466 2 +467 0 468 1 -469 1 -470 1 -471 1 +469 5 +470 4 +471 0 472 1 473 1 -474 1 +474 3 475 1 -476 1 -477 2 -478 0 -479 1 -480 1 -481 1 -482 1 +476 4 +477 3 +478 1 +479 4 +480 0 +481 0 +482 3 483 1 484 1 -485 1 -486 1 -487 1 +485 0 +486 4 +487 3 488 1 -489 1 +489 2 490 1 -491 1 +491 0 492 1 -493 1 +493 4 494 1 495 1 -496 1 -497 1 +496 2 +497 0 498 1 -499 2 -500 1 -501 1 +499 4 +500 3 +501 3 502 0 503 1 504 1 505 1 -506 1 -507 1 -508 1 -509 1 +506 0 +507 3 +508 0 +509 0 510 1 511 1 512 1 513 1 -514 1 +514 0 515 1 516 1 517 1 -518 1 +518 3 519 1 -520 1 -521 1 -522 1 -523 1 +520 2 +521 3 +522 4 +523 3 524 1 -525 1 -526 1 -527 1 -528 1 +525 0 +526 4 +527 0 +528 3 529 1 530 1 531 1 532 1 533 1 534 1 -535 1 -536 2 -537 1 +535 2 +536 4 +537 3 538 1 -539 1 +539 0 540 1 -541 0 +541 1 542 1 -543 1 -544 2 -545 1 +543 5 +544 0 +545 3 546 1 547 1 548 1 -549 1 +549 0 550 1 -551 1 -552 1 -553 1 -554 2 -555 4 -556 1 -557 1 +551 0 +552 5 +553 0 +554 1 +555 1 +556 0 +557 3 558 1 559 1 560 1 -561 1 -562 1 +561 3 +562 0 563 1 564 0 -565 1 -566 1 -567 1 -568 1 +565 0 +566 0 +567 0 +568 0 569 1 570 1 -571 1 +571 0 572 1 -573 1 -574 1 -575 1 +573 0 +574 2 +575 2 576 1 -577 1 -578 1 +577 3 +578 4 579 1 580 1 581 1 -582 1 +582 0 583 1 -584 1 -585 1 -586 1 -587 1 -588 1 +584 0 +585 4 +586 0 +587 0 +588 0 589 1 -590 1 -591 1 -592 1 -593 1 -594 1 -595 1 -596 1 +590 4 +591 4 +592 0 +593 2 +594 0 +595 2 +596 5 597 1 -598 1 -599 1 -600 1 -601 1 -602 1 -603 2 -604 1 -605 1 -606 2 +598 0 +599 4 +600 0 +601 3 +602 0 +603 3 +604 3 +605 0 +606 1 607 1 -608 1 -609 0 -610 1 -611 1 +608 2 +609 1 +610 3 +611 0 612 1 -613 1 -614 1 -615 1 -616 1 +613 5 +614 2 +615 2 +616 0 617 1 618 1 -619 0 -620 1 -621 1 +619 1 +620 0 +621 2 622 1 -623 1 -624 1 -625 1 -626 1 +623 4 +624 3 +625 5 +626 4 627 2 628 1 -629 1 -630 1 +629 2 +630 2 631 1 -632 1 +632 4 633 1 -634 1 +634 0 635 1 636 1 637 1 638 1 -639 0 -640 0 +639 2 +640 1 641 0 -642 0 -643 0 -644 0 -645 0 -646 0 +642 1 +643 2 +644 1 +645 1 +646 1 647 1 -648 0 -649 4 -650 0 -651 0 -652 3 -653 0 -654 0 -655 0 +648 3 +649 0 +650 2 +651 1 +652 0 +653 1 +654 3 +655 1 656 0 -657 0 -658 0 -659 0 -660 0 +657 3 +658 4 +659 1 +660 3 661 0 -662 0 +662 3 663 0 -664 0 -665 0 -666 0 +664 3 +665 3 +666 4 667 0 
-668 0 +668 4 669 0 -670 0 -671 0 -672 0 -673 0 -674 0 -675 0 -676 0 +670 2 +671 4 +672 1 +673 1 +674 1 +675 1 +676 3 677 0 -678 0 +678 1 679 0 680 0 -681 0 -682 0 -683 1 +681 5 +682 1 +683 0 684 0 -685 0 -686 1 -687 0 -688 4 +685 4 +686 0 +687 4 +688 3 689 0 690 0 691 0 -692 0 +692 4 693 0 -694 4 -695 0 +694 5 +695 1 696 0 -697 0 -698 0 +697 4 +698 3 699 1 -700 0 -701 0 +700 3 +701 1 702 0 703 0 704 0 -705 0 -706 0 -707 0 -708 0 +705 5 +706 1 +707 1 +708 1 709 0 710 0 -711 0 +711 3 712 0 -713 0 -714 0 -715 0 -716 0 -717 0 -718 0 -719 0 -720 0 -721 0 -722 0 -723 0 -724 0 +713 2 +714 4 +715 2 +716 1 +717 4 +718 3 +719 1 +720 1 +721 4 +722 1 +723 1 +724 3 725 0 -726 0 +726 4 727 0 -728 0 -729 0 -730 3 +728 1 +729 1 +730 1 731 0 732 0 -733 0 +733 1 734 0 -735 0 -736 4 -737 3 -738 0 -739 0 -740 4 -741 0 -742 0 +735 2 +736 0 +737 4 +738 3 +739 3 +740 1 +741 1 +742 1 743 0 -744 0 -745 4 -746 0 -747 0 +744 1 +745 1 +746 2 +747 1 748 0 749 0 750 0 -751 0 -752 0 +751 1 +752 1 753 0 -754 0 +754 3 755 0 -756 0 -757 0 -758 0 -759 0 -760 0 -761 0 -762 0 +756 3 +757 3 +758 1 +759 2 +760 1 +761 1 +762 3 763 1 764 0 -765 0 +765 1 766 2 -767 0 +767 1 768 0 -769 0 -770 0 -771 0 -772 0 -773 0 +769 3 +770 1 +771 1 +772 3 +773 1 774 0 -775 0 +775 1 776 0 777 0 -778 0 -779 0 -780 0 -781 0 -782 0 -783 0 +778 1 +779 1 +780 4 +781 1 +782 1 +783 3 784 0 -785 0 -786 0 -787 4 -788 0 -789 0 -790 0 +785 1 +786 1 +787 1 +788 1 +789 2 +790 1 791 0 -792 0 -793 0 -794 0 -795 0 +792 4 +793 1 +794 1 +795 4 796 0 -797 0 +797 1 798 0 -799 0 -800 0 -801 0 -802 0 +799 1 +800 1 +801 1 +802 2 803 0 804 0 -805 0 +805 1 806 0 807 0 -808 0 +808 1 809 0 -810 0 +810 3 811 0 -812 0 -813 0 -814 0 -815 0 -816 0 +812 1 +813 5 +814 1 +815 1 +816 1 817 0 -818 4 -819 0 -820 0 -821 0 -822 0 -823 0 +818 0 +819 1 +820 5 +821 2 +822 4 +823 3 824 4 -825 0 +825 1 826 0 827 0 -828 0 +828 4 829 0 830 4 -831 0 -832 0 -833 3 -834 0 +831 1 +832 4 +833 1 +834 4 835 0 -836 0 +836 3 837 0 -838 3 +838 0 839 0 -840 0 +840 1 841 0 -842 0 -843 0 -844 0 -845 0 +842 1 +843 1 +844 5 +845 2 846 0 847 0 -848 0 -849 0 +848 4 +849 1 850 0 851 0 -852 0 -853 0 +852 1 +853 1 854 0 855 0 -856 1 -857 0 +856 0 +857 1 858 0 859 0 860 0 -861 1 +861 4 862 0 863 0 -864 0 +864 1 865 0 866 0 -867 0 -868 3 -869 0 -870 0 -871 0 +867 1 +868 0 +869 1 +870 1 +871 4 872 0 873 0 -874 0 -875 0 +874 1 +875 5 876 0 -877 0 -878 3 -879 0 -880 0 -881 0 +877 1 +878 1 +879 5 +880 4 +881 3 882 0 -883 0 +883 4 884 0 -885 0 -886 0 +885 1 +886 2 887 0 -888 0 -889 0 -890 0 -891 0 -892 0 -893 0 -894 0 -895 3 -896 0 -897 0 +888 2 +889 1 +890 2 +891 1 +892 3 +893 1 +894 2 +895 2 +896 1 +897 5 898 0 -899 0 -900 0 -901 0 -902 0 -903 0 +899 1 +900 3 +901 1 +902 4 +903 1 904 0 -905 1 +905 3 906 0 -907 0 +907 3 908 0 -909 0 +909 2 910 0 -911 0 -912 3 -913 0 -914 0 +911 1 +912 0 +913 1 +914 1 915 0 -916 0 -917 0 +916 4 +917 2 918 0 919 0 -920 0 +920 1 921 0 -922 1 -923 0 -924 0 -925 0 +922 4 +923 2 +924 3 +925 1 926 0 -927 0 -928 0 -929 0 -930 1 +927 4 +928 1 +929 1 +930 3 931 1 -932 1 -933 1 +932 0 +933 3 934 1 -935 1 -936 1 -937 1 +935 0 +936 2 +937 3 938 1 939 1 940 1 -941 1 -942 1 -943 1 -944 1 -945 1 -946 1 -947 1 -948 3 -949 1 -950 1 +941 4 +942 4 +943 0 +944 0 +945 5 +946 0 +947 0 +948 1 +949 0 +950 0 951 1 952 1 -953 1 -954 1 +953 0 +954 4 955 1 956 1 -957 1 +957 2 958 1 -959 1 -960 1 -961 1 -962 2 +959 2 +960 2 +961 0 +962 0 963 1 964 1 -965 1 +965 0 966 1 -967 2 -968 1 -969 1 -970 1 +967 3 +968 3 +969 5 +970 3 971 1 972 1 -973 1 -974 1 +973 3 +974 0 975 1 -976 1 -977 1 -978 1 -979 1 +976 3 +977 4 
+978 2 +979 3 980 1 -981 1 -982 1 -983 1 -984 2 -985 1 +981 0 +982 2 +983 0 +984 3 +985 2 986 1 987 1 988 1 -989 1 -990 1 -991 1 -992 1 +989 0 +990 0 +991 0 +992 0 993 1 -994 1 -995 1 -996 1 -997 1 +994 2 +995 4 +996 4 +997 2 998 1 -999 1 +999 0 1000 1 1001 1 -1002 1 -1003 1 -1004 1 +1002 3 +1003 5 +1004 3 1005 1 -1006 1 -1007 1 -1008 2 +1006 4 +1007 3 +1008 0 1009 1 -1010 1 +1010 3 1011 1 -1012 0 -1013 1 -1014 1 +1012 1 +1013 2 +1014 0 1015 1 -1016 1 -1017 1 -1018 1 +1016 0 +1017 2 +1018 4 1019 1 -1020 1 -1021 1 -1022 1 +1020 0 +1021 3 +1022 0 1023 1 -1024 1 -1025 1 -1026 1 -1027 1 +1024 4 +1025 3 +1026 4 +1027 4 1028 1 -1029 1 +1029 0 1030 1 -1031 1 -1032 1 -1033 1 -1034 2 +1031 0 +1032 0 +1033 3 +1034 0 1035 1 1036 1 -1037 1 -1038 1 +1037 4 +1038 4 1039 1 -1040 1 +1040 4 1041 1 1042 1 -1043 1 +1043 0 1044 1 -1045 1 -1046 1 +1045 0 +1046 3 1047 1 -1048 1 -1049 1 +1048 4 +1049 5 1050 1 -1051 1 +1051 5 1052 1 -1053 1 -1054 1 +1053 3 +1054 5 1055 1 -1056 1 +1056 0 1057 1 -1058 1 -1059 1 -1060 1 -1061 1 -1062 1 +1058 0 +1059 2 +1060 0 +1061 2 +1062 5 1063 1 -1064 1 -1065 1 +1064 2 +1065 0 1066 1 -1067 1 +1067 3 1068 1 1069 1 -1070 1 -1071 1 +1070 4 +1071 0 1072 1 1073 1 -1074 1 -1075 1 -1076 1 -1077 4 -1078 1 -1079 1 -1080 1 -1081 1 -1082 1 -1083 1 -1084 1 -1085 1 +1074 3 +1075 0 +1076 0 +1077 1 +1078 3 +1079 3 +1080 0 +1081 2 +1082 3 +1083 2 +1084 0 +1085 3 1086 0 -1087 1 -1088 1 +1087 2 +1088 0 1089 1 -1090 1 -1091 1 -1092 3 -1093 1 -1094 3 -1095 1 -1096 1 -1097 1 -1098 1 -1099 1 -1100 1 -1101 1 -1102 1 -1103 1 +1090 4 +1091 3 +1092 2 +1093 0 +1094 0 +1095 4 +1096 0 +1097 3 +1098 0 +1099 3 +1100 5 +1101 3 +1102 0 +1103 0 1104 4 1105 1 -1106 1 +1106 0 1107 1 -1108 1 -1109 1 -1110 1 -1111 1 -1112 1 -1113 1 +1108 4 +1109 3 +1110 0 +1111 2 +1112 4 +1113 3 1114 1 -1115 1 -1116 2 -1117 0 -1118 1 -1119 1 -1120 1 -1121 1 +1115 0 +1116 1 +1117 2 +1118 0 +1119 3 +1120 0 +1121 3 1122 1 -1123 1 -1124 1 -1125 1 +1123 2 +1124 0 +1125 0 1126 1 -1127 1 +1127 0 1128 1 1129 1 -1130 1 -1131 1 +1130 3 +1131 0 1132 1 -1133 1 -1134 1 -1135 1 +1133 4 +1134 0 +1135 0 1136 1 -1137 1 -1138 2 +1137 0 +1138 3 1139 1 1140 1 -1141 0 +1141 3 1142 1 -1143 1 -1144 1 +1143 0 +1144 3 1145 1 1146 1 1147 1 1148 1 -1149 1 -1150 1 -1151 1 +1149 3 +1150 3 +1151 0 1152 1 -1153 1 -1154 1 -1155 1 +1153 0 +1154 2 +1155 0 1156 1 1157 1 -1158 1 +1158 3 1159 1 1160 1 1161 1 -1162 1 +1162 4 1163 1 -1164 1 +1164 3 1165 1 -1166 1 -1167 1 -1168 1 +1166 0 +1167 0 +1168 4 1169 1 1170 1 1171 1 -1172 1 -1173 1 +1172 4 +1173 4 1174 1 -1175 2 +1175 0 1176 1 -1177 1 +1177 3 1178 1 1179 1 -1180 0 +1180 4 1181 1 1182 1 -1183 2 -1184 1 -1185 1 -1186 1 -1187 1 -1188 1 -1189 1 -1190 1 -1191 1 -1192 1 -1193 2 -1194 4 -1195 1 -1196 1 -1197 1 -1198 1 +1183 1 +1184 4 +1185 5 +1186 0 +1187 5 +1188 4 +1189 0 +1190 0 +1191 0 +1192 0 +1193 0 +1194 2 +1195 0 +1196 0 +1197 0 +1198 4 1199 1 -1200 1 -1201 1 -1202 1 +1200 3 +1201 0 +1202 5 1203 0 -1204 1 -1205 1 -1206 1 -1207 1 -1208 1 +1204 0 +1205 5 +1206 2 +1207 2 +1208 0 1209 1 1210 1 -1211 1 -1212 1 -1213 1 -1214 1 +1211 0 +1212 0 +1213 0 +1214 0 1215 1 -1216 1 +1216 2 1217 1 1218 1 1219 1 1220 1 1221 1 -1222 1 -1223 1 -1224 1 -1225 1 +1222 0 +1223 0 +1224 0 +1225 0 1226 1 -1227 1 +1227 0 1228 1 1229 1 1230 1 1231 1 -1232 1 +1232 3 1233 1 -1234 1 -1235 1 -1236 1 -1237 1 -1238 1 +1234 4 +1235 3 +1236 0 +1237 0 +1238 0 1239 1 -1240 1 -1241 1 -1242 2 -1243 1 +1240 5 +1241 0 +1242 0 +1243 0 1244 1 1245 2 -1246 1 +1246 0 1247 1 -1248 0 -1249 1 -1250 1 +1248 1 +1249 0 +1250 0 1251 1 -1252 1 -1253 1 +1252 3 +1253 0 
1254 1 1255 1 -1256 1 -1257 1 -1258 0 -1259 1 +1256 4 +1257 3 +1258 4 +1259 0 1260 1 -1261 1 +1261 4 1262 1 -1263 1 +1263 0 1264 1 1265 1 -1266 2 +1266 4 1267 1 1268 1 1269 1 -1270 1 +1270 0 1271 1 1272 1 -1273 1 +1273 4 1274 1 -1275 1 +1275 4 1276 1 1277 1 +1278 4 +1279 4 +1280 1 +1281 3 +1282 4 +1283 1 +1284 0 +1285 0 +1286 1 +1287 4 +1288 0 +1289 3 +1290 0 +1291 1 +1292 0 +1293 0 +1294 4 +1295 0 +1296 3 +1297 0 +1298 0 +1299 3 +1300 3 +1301 2 +1302 1 +1303 1 +1304 0 +1305 4 +1306 5 +1307 1 +1308 5 +1309 1 +1310 1 +1311 1 +1312 4 +1313 1 +1314 4 +1315 1 +1316 1 +1317 1 +1318 3 +1319 3 +1320 3 +1321 4 +1322 5 +1323 3 +1324 1 +1325 4 +1326 0 +1327 3 +1328 2 +1329 0 +1330 2 +1331 1 +1332 1 +1333 2 +1334 0 +1335 0 +1336 1 +1337 0 +1338 0 +1339 0 +1340 2 +1341 0 +1342 1 +1343 1 +1344 0 +1345 1 +1346 2 +1347 2 +1348 1 +1349 4 +1350 1 +1351 1 +1352 3 +1353 1 +1354 3 +1355 4 +1356 1 +1357 1 +1358 2 +1359 1 +1360 2 +1361 1 +1362 1 +1363 1 +1364 4 +1365 0 +1366 0 +1367 4 +1368 0 +1369 1 +1370 2 +1371 3 +1372 0 +1373 1 +1374 0 +1375 0 +1376 3 +1377 2 +1378 0 +1379 1 +1380 1 +1381 2 +1382 0 +1383 3 +1384 4 +1385 1 +1386 2 +1387 2 +1388 1 +1389 1 +1390 1 +1391 1 +1392 1 +1393 0 +1394 5 +1395 2 +1396 0 +1397 3 +1398 4 +1399 3 +1400 1 +1401 3 +1402 1 +1403 4 +1404 0 +1405 3 +1406 1 +1407 1 +1408 1 +1409 4 +1410 1 +1411 1 +1412 4 +1413 3 +1414 0 +1415 0 +1416 1 +1417 1 +1418 1 +1419 3 +1420 4 +1421 4 +1422 0 +1423 0 +1424 3 +1425 4 +1426 3 +1427 1 +1428 1 +1429 1 +1430 0 +1431 1 +1432 1 +1433 0 +1434 1 +1435 0 +1436 1 +1437 1 +1438 0 +1439 0 +1440 1 +1441 4 +1442 4 +1443 0 +1444 1 +1445 5 +1446 0 +1447 2 +1448 1 +1449 1 +1450 1 +1451 5 +1452 3 +1453 0 +1454 1 +1455 0 +1456 0 +1457 1 +1458 2 +1459 1 +1460 1 +1461 2 +1462 0 +1463 2 +1464 4 +1465 5 +1466 1 +1467 3 +1468 1 +1469 4 +1470 3 +1471 0 +1472 1 +1473 1 +1474 0 +1475 0 +1476 2 +1477 4 +1478 2 +1479 1 +1480 1 +1481 1 +1482 0 +1483 0 +1484 3 +1485 5 +1486 1 +1487 1 +1488 2 +1489 0 +1490 1 +1491 3 +1492 4 +1493 1 +1494 0 +1495 1 +1496 4 +1497 2 +1498 1 +1499 3 +1500 1 +1501 1 +1502 1 +1503 0 +1504 0 +1505 0 +1506 1 +1507 3 +1508 4 +1509 3 +1510 3 +1511 0 +1512 1 +1513 2 +1514 1 +1515 1 +1516 1 +1517 4 +1518 1 +1519 0 +1520 1 +1521 0 +1522 0 +1523 3 +1524 2 +1525 1 +1526 1 +1527 0 +1528 4 +1529 1 +1530 3 +1531 4 +1532 1 +1533 0 +1534 4 +1535 3 +1536 1 +1537 2 +1538 4 +1539 1 +1540 0 +1541 2 +1542 3 +1543 3 +1544 1 +1545 2 +1546 0 +1547 0 +1548 0 +1549 1 +1550 1 +1551 0 +1552 4 +1553 1 +1554 0 +1555 0 +1556 3 +1557 0 +1558 2 +1559 1 +1560 4 +1561 4 +1562 1 +1563 4 +1564 0 +1565 3 +1566 1 +1567 0 +1568 1 +1569 0 +1570 3 +1571 2 +1572 4 +1573 1 +1574 0 +1575 4 +1576 4 +1577 4 +1578 0 +1579 1 +1580 0 +1581 0 +1582 0 +1583 2 +1584 3 +1585 2 +1586 3 +1587 0 +1588 2 +1589 1 +1590 1 +1591 3 +1592 2 +1593 0 +1594 1 +1595 0 +1596 0 +1597 1 +1598 3 +1599 5 +1600 0 +1601 1 +1602 3 +1603 0 +1604 1 +1605 0 +1606 4 +1607 2 +1608 1 +1609 4 +1610 4 +1611 1 +1612 3 +1613 3 +1614 1 +1615 0 +1616 0 +1617 1 +1618 4 +1619 1 +1620 1 +1621 0 +1622 0 +1623 1 +1624 4 +1625 3 +1626 3 +1627 1 +1628 3 +1629 3 +1630 1 +1631 0 +1632 4 +1633 4 +1634 5 +1635 1 +1636 1 +1637 5 +1638 1 +1639 1 +1640 1 +1641 3 +1642 2 +1643 4 +1644 0 +1645 0 +1646 1 +1647 1 +1648 1 +1649 3 +1650 2 +1651 0 +1652 1 +1653 4 +1654 3 +1655 4 +1656 1 +1657 4 +1658 0 +1659 0 +1660 0 +1661 0 +1662 1 +1663 0 +1664 0 +1665 0 +1666 0 +1667 3 +1668 0 +1669 0 +1670 1 +1671 1 +1672 1 +1673 4 +1674 1 +1675 3 +1676 0 +1677 4 +1678 3 +1679 3 +1680 1 +1681 3 +1682 0 +1683 1 +1684 2 +1685 2 +1686 1 +1687 1 +1688 1 +1689 1 
+1690 0 +1691 3 +1692 1 +1693 1 +1694 1 +1695 3 +1696 0 +1697 0 +1698 0 +1699 3 +1700 1 +1701 2 +1702 0 +1703 1 +1704 0 +1705 1 +1706 0 +1707 1 +1708 2 +1709 2 +1710 3 +1711 1 +1712 3 +1713 0 +1714 2 +1715 0 +1716 4 +1717 1 +1718 4 +1719 1 +1720 0 +1721 5 +1722 1 +1723 3 +1724 1 +1725 3 +1726 1 +1727 1 +1728 0 +1729 2 +1730 3 +1731 2 +1732 2 +1733 1 +1734 5 +1735 4 +1736 5 +1737 3 +1738 3 +1739 2 +1740 4 +1741 3 +1742 0 +1743 4 +1744 0 +1745 2 +1746 1 +1747 4 +1748 0 +1749 0 +1750 1 +1751 0 +1752 0 +1753 1 +1754 1 +1755 0 +1756 2 +1757 1 +1758 1 +1759 0 +1760 4 +1761 1 +1762 4 +1763 1 +1764 2 +1765 2 +1766 0 +1767 1 +1768 0 +1769 1 +1770 1 +1771 0 +1772 1 +1773 1 +1774 5 +1775 5 +1776 1 +1777 1 +1778 0 +1779 0 +1780 0 +1781 0 +1782 1 +1783 3 +1784 4 +1785 4 +1786 0 +1787 0 +1788 4 +1789 0 +1790 0 +1791 5 +1792 1 +1793 1 +1794 3 +1795 0 +1796 3 +1797 5 +1798 0 +1799 1 +1800 1 +1801 1 +1802 0 +1803 0 +1804 0 +1805 0 +1806 0 +1807 1 +1808 4 +1809 3 +1810 4 +1811 4 +1812 0 +1813 0 +1814 0 +1815 1 +1816 3 +1817 1 +1818 2 +1819 3 +1820 0 +1821 0 +1822 1 +1823 3 +1824 1 +1825 1 +1826 1 +1827 0 +1828 0 +1829 4 +1830 1 +1831 1 +1832 3 +1833 2 +1834 1 +1835 3 +1836 2 +1837 3 +1838 0 +1839 1 +1840 1 +1841 0 +1842 5 +1843 1 +1844 2 +1845 1 +1846 0 +1847 1 +1848 1 +1849 2 +1850 1 +1851 2 +1852 1 +1853 4 +1854 1 +1855 0 +1856 4 +1857 1 +1858 3 +1859 1 +1860 0 +1861 1 +1862 3 +1863 1 +1864 0 +1865 0 +1866 1 +1867 0 +1868 3 +1869 1 +1870 1 +1871 4 +1872 1 +1873 1 +1874 4 +1875 1 +1876 0 +1877 4 +1878 0 +1879 0 +1880 1 +1881 0 +1882 1 +1883 1 +1884 0 +1885 1 +1886 0 +1887 2 +1888 1 +1889 2 +1890 3 +1891 1 +1892 3 +1893 0 +1894 1 +1895 0 +1896 0 +1897 0 +1898 1 +1899 2 +1900 0 +1901 0 +1902 2 +1903 0 +1904 1 +1905 0 +1906 0 +1907 3 +1908 0 +1909 3 +1910 1 +1911 4 +1912 4 +1913 3 +1914 1 +1915 1 +1916 1 +1917 3 +1918 0 +1919 0 +1920 1 +1921 1 +1922 0 +1923 0 +1924 1 +1925 4 +1926 0 +1927 0 +1928 0 +1929 3 +1930 0 +1931 1 +1932 0 +1933 1 +1934 1 +1935 4 +1936 1 +1937 0 +1938 3 +1939 1 +1940 0 +1941 1 +1942 4 +1943 3 +1944 1 +1945 2 +1946 3 +1947 0 +1948 0 +1949 1 +1950 0 +1951 1 +1952 4 +1953 1 +1954 0 +1955 0 +1956 5 +1957 5 +1958 1 +1959 1 +1960 0 +1961 2 +1962 1 +1963 1 +1964 0 +1965 0 +1966 1 +1967 1 +1968 0 +1969 2 +1970 2 +1971 1 +1972 1 +1973 1 +1974 3 +1975 4 +1976 1 +1977 0 +1978 0 +1979 3 +1980 0 +1981 3 +1982 3 +1983 0 +1984 3 +1985 1 +1986 4 +1987 0 +1988 0 +1989 0 +1990 2 +1991 2 +1992 3 +1993 3 +1994 0 +1995 3 +1996 3 +1997 1 +1998 1 +1999 4 diff --git a/models/gpt2/pytorch_model.bin b/models/gpt2/pytorch_model.bin new file mode 100644 index 0000000..33de568 Binary files /dev/null and b/models/gpt2/pytorch_model.bin differ diff --git a/models/gpt2/train_results.json b/models/gpt2/train_results.json index 0965139..8150ecb 100644 --- a/models/gpt2/train_results.json +++ b/models/gpt2/train_results.json @@ -1,8 +1,8 @@ { - "epoch": 5.0, - "train_loss": 0.689463275015069, - "train_runtime": 490.8844, - "train_samples": 4999, - "train_samples_per_second": 50.918, - "train_steps_per_second": 2.129 + "epoch": 1.0, + "train_loss": 0.4504347610473633, + "train_runtime": 524.6759, + "train_samples": 16000, + "train_samples_per_second": 30.495, + "train_steps_per_second": 3.812 } \ No newline at end of file diff --git a/models/gpt2/trainer_state.json b/models/gpt2/trainer_state.json index a260586..dafe3cc 100644 --- a/models/gpt2/trainer_state.json +++ b/models/gpt2/trainer_state.json @@ -1,37 +1,49 @@ { "best_metric": null, "best_model_checkpoint": null, - "epoch": 5.0, - "global_step": 1045, + "epoch": 
1.0, + "global_step": 2000, "is_hyper_param_search": false, "is_local_process_zero": true, "is_world_process_zero": true, "log_history": [ { - "epoch": 2.39, - "learning_rate": 1.0430622009569378e-05, - "loss": 1.0247, + "epoch": 0.25, + "learning_rate": 3.7500000000000003e-05, + "loss": 0.9449, "step": 500 }, { - "epoch": 4.78, - "learning_rate": 8.612440191387561e-07, - "loss": 0.3843, + "epoch": 0.5, + "learning_rate": 2.5e-05, + "loss": 0.3705, "step": 1000 }, { - "epoch": 5.0, - "step": 1045, - "total_flos": 1723489601126400.0, - "train_loss": 0.689463275015069, - "train_runtime": 490.8844, - "train_samples_per_second": 50.918, - "train_steps_per_second": 2.129 + "epoch": 0.75, + "learning_rate": 1.25e-05, + "loss": 0.264, + "step": 1500 + }, + { + "epoch": 1.0, + "learning_rate": 0.0, + "loss": 0.2223, + "step": 2000 + }, + { + "epoch": 1.0, + "step": 2000, + "total_flos": 1204741472256000.0, + "train_loss": 0.4504347610473633, + "train_runtime": 524.6759, + "train_samples_per_second": 30.495, + "train_steps_per_second": 3.812 } ], - "max_steps": 1045, - "num_train_epochs": 5, - "total_flos": 1723489601126400.0, + "max_steps": 2000, + "num_train_epochs": 1, + "total_flos": 1204741472256000.0, "trial_name": null, "trial_params": null } diff --git a/models/gpt2/training_args.bin b/models/gpt2/training_args.bin index 8627cca..43e68f0 100644 Binary files a/models/gpt2/training_args.bin and b/models/gpt2/training_args.bin differ
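
The hyperparameters recorded in the training log and in models/gpt2/README.md (learning rate 5e-05, per-device batch size 8, one epoch, linear scheduler, seed 42, accuracy as the model-selection metric) correspond to a Hugging Face `TrainingArguments` along the lines sketched below; this is a reconstruction from the logged values, not the exact invocation used in the repository.

```python
from transformers import TrainingArguments

# Values taken from the training log / model card above: lr 5e-5, batch size 8,
# one epoch, linear LR decay, checkpoints and logging every 500 steps, seed 42.
training_args = TrainingArguments(
    output_dir="out/emotion/gpt2",
    overwrite_output_dir=True,
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=1.0,
    lr_scheduler_type="linear",
    logging_steps=500,
    save_steps=500,
    seed=42,
    metric_for_best_model="accuracy",
    greater_is_better=True,
)
```

With 16,000 training examples and a batch size of 8, a single epoch gives exactly the 2,000 optimization steps reported in the trainer state above.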
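
For context on the `GPT2ForSequenceClassificationCustom` head whose weights the log reports as newly initialized (`score.dense_1_input`, `score.dense_1_hidden`, `score.dense_2`, `score.dense_3`, `score.out_proj`), and on the README notes about adding an extra Linear layer and using hidden states from the last t layers, here is a minimal PyTorch sketch. The class name, wiring, and pooling strategy are assumptions for illustration only, not the repository's actual implementation.

```python
import torch
from torch import nn
from transformers import GPT2Config, GPT2Model, GPT2Tokenizer

class GPT2ClassificationHeadSketch(nn.Module):
    """Hypothetical head combining the final layer with the last `n_last` hidden states.

    Layer names mirror the newly initialized weights in the training log; the exact
    wiring is an assumption, not the repository's GPT2ForSequenceClassificationCustom.
    """

    def __init__(self, config: GPT2Config, num_labels: int, n_last: int = 2):
        super().__init__()
        hidden = config.n_embd
        self.n_last = n_last
        self.dense_1_input = nn.Linear(hidden, hidden)            # final-layer pathway
        self.dense_1_hidden = nn.Linear(hidden * n_last, hidden)  # last-n-layers pathway
        self.dense_2 = nn.Linear(2 * hidden, hidden)
        self.dense_3 = nn.Linear(hidden, hidden)
        self.out_proj = nn.Linear(hidden, num_labels)
        self.act = nn.ReLU()

    def forward(self, last_hidden_state, hidden_states):
        # Simplification: pool the final position; the real model selects the last
        # non-padding token instead.
        pooled = last_hidden_state[:, -1, :]
        stacked = torch.cat([h[:, -1, :] for h in hidden_states[-self.n_last:]], dim=-1)
        x = torch.cat(
            [self.act(self.dense_1_input(pooled)), self.act(self.dense_1_hidden(stacked))],
            dim=-1,
        )
        x = self.act(self.dense_2(x))
        x = self.act(self.dense_3(x))
        return self.out_proj(x)

# Usage sketch: GPT-2 is run with output_hidden_states=True, and the pad token is
# mapped to EOS exactly as the log reports ("Set PAD token to EOS: <|endoftext|>").
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
backbone = GPT2Model.from_pretrained("gpt2", output_hidden_states=True)
head = GPT2ClassificationHeadSketch(backbone.config, num_labels=6, n_last=2)

enc = tokenizer(["i feel great today"], return_tensors="pt", padding=True)
out = backbone(**enc)
logits = head(out.last_hidden_state, out.hidden_states)  # shape: (1, 6)
```

The six output classes match the label range (0–5) visible in predict_results_None.txt above.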