{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU", "gpuClass": "standard", "widgets": { "application/vnd.jupyter.widget-state+json": { "7e615ede17554aecbadc0b8ca5b2ff5a": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_7ba3b7e45ae440668d42ab119c5b2cc0", "IPY_MODEL_2702f15d084b43d3b8fb0e1c6a9f6b48", "IPY_MODEL_179fc4e46e244e8fb3dbf861ee55db3d" ], "layout": "IPY_MODEL_fda6f92b0b274978ab504247eda27f14" } }, "7ba3b7e45ae440668d42ab119c5b2cc0": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_e56d169996cb42b9bc855d59aec82ba2", "placeholder": "​", "style": "IPY_MODEL_f7567129f0c64039afe13288341bdeb7", "value": "100%" } }, "2702f15d084b43d3b8fb0e1c6a9f6b48": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", 
"description": "", "description_tooltip": null, "layout": "IPY_MODEL_68cf2956c07f42cdb61ba6f38afc7009", "max": 3, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_aaff5b55a35b49d9808ccc1661358d5e", "value": 3 } }, "179fc4e46e244e8fb3dbf861ee55db3d": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_20041d370edb4814855de96e66c24881", "placeholder": "​", "style": "IPY_MODEL_b1a7113efcc94094be3b5e7dea0c1236", "value": " 3/3 [00:00<00:00, 54.10it/s]" } }, "fda6f92b0b274978ab504247eda27f14": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, 
"e56d169996cb42b9bc855d59aec82ba2": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f7567129f0c64039afe13288341bdeb7": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "68cf2956c07f42cdb61ba6f38afc7009": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": 
null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "aaff5b55a35b49d9808ccc1661358d5e": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "20041d370edb4814855de96e66c24881": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": 
null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "b1a7113efcc94094be3b5e7dea0c1236": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } } } } }, "cells": [ { "cell_type": "markdown", "source": [ "# Setup" ], "metadata": { "id": "n2A5EThJNiAy" } }, { "cell_type": "markdown", "source": [ "## Requirements" ], "metadata": { "id": "tPp2_1rDOFYA" } }, { "cell_type": "code", "execution_count": 67, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "OmsX3kG4bLTg", "outputId": "2ac1de01-0123-43c7-bd34-c2864b1bac57" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Requirement already satisfied: torch in /usr/local/lib/python3.8/dist-packages (1.13.1+cu116)\n", "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch) (4.4.0)\n", "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Requirement already satisfied: datasets in /usr/local/lib/python3.8/dist-packages (2.9.0)\n", "Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (2023.1.0)\n", "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (4.64.1)\n", "Requirement 
already satisfied: xxhash in /usr/local/lib/python3.8/dist-packages (from datasets) (3.2.0)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from datasets) (1.21.6)\n", "Requirement already satisfied: multiprocess in /usr/local/lib/python3.8/dist-packages (from datasets) (0.70.14)\n", "Requirement already satisfied: huggingface-hub<1.0.0,>=0.2.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (0.12.0)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.8/dist-packages (from datasets) (1.3.5)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.8/dist-packages (from datasets) (23.0)\n", "Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.8/dist-packages (from datasets) (0.18.0)\n", "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (2.25.1)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.8/dist-packages (from datasets) (3.8.3)\n", "Requirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (9.0.0)\n", "Requirement already satisfied: dill<0.3.7 in /usr/local/lib/python3.8/dist-packages (from datasets) (0.3.6)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (6.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.8.2)\n", "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (2.1.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (6.0.4)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.8/dist-packages 
(from aiohttp->datasets) (4.0.2)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (22.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.3.3)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0.0,>=0.2.0->datasets) (3.9.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0.0,>=0.2.0->datasets) (4.4.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (2022.12.7)\n", "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (1.26.14)\n", "Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (4.0.0)\n", "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (2.10)\n", "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.8/dist-packages (from pandas->datasets) (2.8.2)\n", "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.8/dist-packages (from pandas->datasets) (2022.7.1)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)\n", "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Requirement already satisfied: transformers in /usr/local/lib/python3.8/dist-packages (4.26.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from transformers) (6.0)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from transformers) (1.21.6)\n", 
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.8/dist-packages (from transformers) (2022.6.2)\n", "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.8/dist-packages (from transformers) (0.13.2)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /usr/local/lib/python3.8/dist-packages (from transformers) (0.12.0)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from transformers) (3.9.0)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from transformers) (2.25.1)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from transformers) (23.0)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.8/dist-packages (from transformers) (4.64.1)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (4.4.0)\n", "Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (4.0.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (2022.12.7)\n", "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (1.26.14)\n", "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (2.10)\n", "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.8/dist-packages (1.0.2)\n", "Requirement already satisfied: numpy>=1.14.6 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (1.21.6)\n", "Requirement already satisfied: scipy>=1.1.0 in /usr/local/lib/python3.8/dist-packages (from 
scikit-learn) (1.7.3)\n", "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (1.2.0)\n", "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (3.1.0)\n", "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Requirement already satisfied: evaluate in /usr/local/lib/python3.8/dist-packages (0.4.0)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.8/dist-packages (from evaluate) (3.2.0)\n", "Requirement already satisfied: datasets>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from evaluate) (2.9.0)\n", "Requirement already satisfied: fsspec[http]>=2021.05.0 in /usr/local/lib/python3.8/dist-packages (from evaluate) (2023.1.0)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.8/dist-packages (from evaluate) (23.0)\n", "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.8/dist-packages (from evaluate) (4.64.1)\n", "Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.8/dist-packages (from evaluate) (0.18.0)\n", "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.8/dist-packages (from evaluate) (2.25.1)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from evaluate) (1.21.6)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.8/dist-packages (from evaluate) (1.3.5)\n", "Requirement already satisfied: huggingface-hub>=0.7.0 in /usr/local/lib/python3.8/dist-packages (from evaluate) (0.12.0)\n", "Requirement already satisfied: multiprocess in /usr/local/lib/python3.8/dist-packages (from evaluate) (0.70.14)\n", "Requirement already satisfied: dill in /usr/local/lib/python3.8/dist-packages (from evaluate) (0.3.6)\n", "Requirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.8/dist-packages (from datasets>=2.0.0->evaluate) (9.0.0)\n", 
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.8/dist-packages (from datasets>=2.0.0->evaluate) (3.8.3)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from datasets>=2.0.0->evaluate) (6.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub>=0.7.0->evaluate) (4.4.0)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from huggingface-hub>=0.7.0->evaluate) (3.9.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->evaluate) (2022.12.7)\n", "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->evaluate) (2.10)\n", "Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->evaluate) (4.0.0)\n", "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->evaluate) (1.26.14)\n", "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.8/dist-packages (from pandas->evaluate) (2022.7.1)\n", "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.8/dist-packages (from pandas->evaluate) (2.8.2)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (22.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.3.3)\n", "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (2.1.1)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (4.0.2)\n", "Requirement already satisfied: 
aiosignal>=1.1.2 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.3.1)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.8.2)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (6.0.4)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.7.3->pandas->evaluate) (1.15.0)\n", "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Requirement already satisfied: accelerate in /usr/local/lib/python3.8/dist-packages (0.16.0)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.8/dist-packages (from accelerate) (5.4.8)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from accelerate) (1.21.6)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from accelerate) (23.0)\n", "Requirement already satisfied: pyyaml in /usr/local/lib/python3.8/dist-packages (from accelerate) (6.0)\n", "Requirement already satisfied: torch>=1.4.0 in /usr/local/lib/python3.8/dist-packages (from accelerate) (1.13.1+cu116)\n", "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch>=1.4.0->accelerate) (4.4.0)\n", "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.8/dist-packages (0.1.97)\n", "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Requirement already satisfied: protobuf in /usr/local/lib/python3.8/dist-packages (3.19.6)\n", "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Requirement already 
satisfied: sacrebleu in /usr/local/lib/python3.8/dist-packages (2.3.1)\n", "Requirement already satisfied: lxml in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (4.9.2)\n", "Requirement already satisfied: regex in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (2022.6.2)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (1.21.6)\n", "Requirement already satisfied: colorama in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (0.4.6)\n", "Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (0.8.10)\n", "Requirement already satisfied: portalocker in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (2.7.0)\n", "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Requirement already satisfied: py7zr in /usr/local/lib/python3.8/dist-packages (0.20.4)\n", "Requirement already satisfied: texttable in /usr/local/lib/python3.8/dist-packages (from py7zr) (1.6.7)\n", "Requirement already satisfied: multivolumefile>=0.2.3 in /usr/local/lib/python3.8/dist-packages (from py7zr) (0.2.3)\n", "Requirement already satisfied: brotli>=1.0.9 in /usr/local/lib/python3.8/dist-packages (from py7zr) (1.0.9)\n", "Requirement already satisfied: pybcj>=0.6.0 in /usr/local/lib/python3.8/dist-packages (from py7zr) (1.0.1)\n", "Requirement already satisfied: pyzstd>=0.14.4 in /usr/local/lib/python3.8/dist-packages (from py7zr) (0.15.3)\n", "Requirement already satisfied: pyppmd<1.1.0,>=0.18.1 in /usr/local/lib/python3.8/dist-packages (from py7zr) (1.0.0)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.8/dist-packages (from py7zr) (5.4.8)\n", "Requirement already satisfied: pycryptodomex>=3.6.6 in /usr/local/lib/python3.8/dist-packages (from py7zr) (3.17)\n", "Requirement already satisfied: inflate64>=0.3.1 in /usr/local/lib/python3.8/dist-packages (from py7zr) (0.3.1)\n" ] } ], 
"source": [ "# Versions used for the recorded outputs below: torch 1.13.1, datasets 2.9.0,\n", "# transformers 4.26.1, scikit-learn 1.0.2, evaluate 0.4.0, accelerate 0.16.0,\n", "# sentencepiece 0.1.97, protobuf 3.19.6, sacrebleu 2.3.1, py7zr 0.20.4.\n", "# Pin them (pkg==version) to reproduce those results exactly.\n", "# %pip (unlike !pip) installs into the environment of the running kernel.\n", "%pip install -q torch datasets transformers scikit-learn evaluate accelerate sentencepiece protobuf sacrebleu py7zr\n" ] }, { "cell_type": "markdown", "source": [ "## Imports" ], "metadata": { "id": "o3Kj9IzuOKMi" } }, { "cell_type": "code", "source": [ "import os\n", "import json\n", "import torch\n", "from google.colab import drive\n", "from pathlib import Path\n", "from typing import Dict, List\n", "from datasets import load_dataset\n", "from transformers import T5Tokenizer" ], "metadata": { "id": "r92S06noeSWE" }, "execution_count": 68, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Loading data" ], "metadata": { "id": "2UzLo91gNnsA" } }, { "cell_type": "code", "source": [ "# 'split' is the config the datasets builder falls back to when none is given;\n", "# naming it explicitly documents the choice and silences the warning.\n", "loaded_data = load_dataset('emotion', 'split')\n", "Path('data').mkdir(parents=True, exist_ok=True)\n", "train_path = Path('data/train.json')\n", "valid_path = Path('data/valid.json')\n", "test_path = Path('data/test.json')\n", "data_train, data_valid, data_test = [], [], []" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 0, "referenced_widgets": [ "7e615ede17554aecbadc0b8ca5b2ff5a", "7ba3b7e45ae440668d42ab119c5b2cc0", "2702f15d084b43d3b8fb0e1c6a9f6b48", "179fc4e46e244e8fb3dbf861ee55db3d", "fda6f92b0b274978ab504247eda27f14", "e56d169996cb42b9bc855d59aec82ba2", "f7567129f0c64039afe13288341bdeb7", "68cf2956c07f42cdb61ba6f38afc7009", "aaff5b55a35b49d9808ccc1661358d5e", "20041d370edb4814855de96e66c24881", "b1a7113efcc94094be3b5e7dea0c1236" ] }, "id": "n_miey7eb2Xr", "outputId": "7ec0c9cd-92b6-4c6f-eaa9-8418e1d904c9" }, "execution_count": 69, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "WARNING:datasets.builder:No config specified, defaulting to: emotion/split\n", "WARNING:datasets.builder:Found cached dataset emotion 
(/root/.cache/huggingface/datasets/emotion/split/1.0.0/cca5efe2dfeb58c1d098e0f9eeb200e9927d889b5a03c67097275dfb5fe463bd)\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ " 0%| | 0/3 [00:00= max_size:\n", " break\n", " data_line = {\n", " 'label': int(data['label']),\n", " 'text': data['text'],\n", " }\n", " dataset.append(data_line)\n", "\n", "print(f'Train: {len(data_train):6d}')\n", "print(f'Valid: {len(data_valid):6d}')\n", "print(f'Test: {len(data_test):6d}')" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "BZ6afaRzGsxS", "outputId": "400213ed-60a7-4079-d7e9-b99d0a6b1a19" }, "execution_count": 70, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Train: 16000\n", "Valid: 2000\n", "Test: 2000\n" ] } ] }, { "cell_type": "code", "source": [ "MAP_LABEL_TRANSLATION = {\n", " 0: 'sadness',\n", " 1: 'joy',\n", " 2: 'love',\n", " 3: 'anger',\n", " 4: 'fear',\n", " 5: 'surprise',\n", "}" ], "metadata": { "id": "w0KyM4TrGxQY" }, "execution_count": 71, "outputs": [] }, { "cell_type": "code", "source": [ "def save_as_translations(original_save_path: Path, data_to_save: List[Dict]) -> None:\n", "    \"\"\"Write `data_to_save` as JSON lines whose integer labels are replaced by label names.\n", "\n", "    The output file sits next to `original_save_path` with an 's2s-' prefix\n", "    (e.g. data/train.json -> data/s2s-train.json). Label names come from\n", "    MAP_LABEL_TRANSLATION. Each record is copied before its label is translated,\n", "    so the caller's records keep their integer labels and the cell that calls\n", "    this function can be re-run safely.\n", "    \"\"\"\n", "    file_name = 's2s-' + original_save_path.name\n", "    file_path = original_save_path.parent / file_name\n", "\n", "    print(f'Saving into: {file_path}')\n", "    with open(file_path, 'wt') as f_write:\n", "        for data_line in data_to_save:\n", "            # Build a new dict instead of assigning into data_line: mutating it would leave\n", "            # string labels in data_train/data_valid/data_test, so a re-run would write them\n", "            # into the plain json files and then fail with KeyError on the lookup below.\n", "            translated_line = {**data_line, 'label': MAP_LABEL_TRANSLATION[data_line['label']]}\n", "            data_line_str = json.dumps(translated_line)\n", "            f_write.write(f'{data_line_str}\\n')" ], "metadata": { "id": "-EFRYeAYHIKN" }, "execution_count": 72, "outputs": [] }, { "cell_type": "code", "source": [ "for file_path, data_to_save in [(train_path, data_train), (valid_path, data_valid), (test_path, data_test)]:\n", " print(f'Saving into: {file_path}')\n", " with open(file_path, 'wt') as f_write:\n", " for data_line in data_to_save:\n", " data_line_str = 
json.dumps(data_line)\n", " f_write.write(f'{data_line_str}\\n')\n", " \n", " save_as_translations(file_path, data_to_save)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "7RsrTNGCHIqc", "outputId": "41f73dae-2ac1-4da9-f9c6-fd56d9f9e819" }, "execution_count": 73, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Saving into: data/train.json\n", "Saving into: data/s2s-train.json\n", "Saving into: data/valid.json\n", "Saving into: data/s2s-valid.json\n", "Saving into: data/test.json\n", "Saving into: data/s2s-test.json\n" ] } ] }, { "cell_type": "code", "source": [ "!head data/train.json" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Svu6YYSaHK4t", "outputId": "9d3623ba-baf8-4cbe-deed-3efa4cbe7d9f" }, "execution_count": 74, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "{\"label\": 0, \"text\": \"i didnt feel humiliated\"}\n", "{\"label\": 0, \"text\": \"i can go from feeling so hopeless to so damned hopeful just from being around someone who cares and is awake\"}\n", "{\"label\": 3, \"text\": \"im grabbing a minute to post i feel greedy wrong\"}\n", "{\"label\": 2, \"text\": \"i am ever feeling nostalgic about the fireplace i will know that it is still on the property\"}\n", "{\"label\": 3, \"text\": \"i am feeling grouchy\"}\n", "{\"label\": 0, \"text\": \"ive been feeling a little burdened lately wasnt sure why that was\"}\n", "{\"label\": 5, \"text\": \"ive been taking or milligrams or times recommended amount and ive fallen asleep a lot faster but i also feel like so funny\"}\n", "{\"label\": 4, \"text\": \"i feel as confused about life as a teenager or as jaded as a year old man\"}\n", "{\"label\": 1, \"text\": \"i have been with petronas for years i feel that petronas has performed well and made a huge profit\"}\n", "{\"label\": 2, \"text\": \"i feel romantic too\"}\n" ] } ] }, { "cell_type": "code", "source": [ "!head data/s2s-train.json" ], "metadata": { 
"colab": { "base_uri": "https://localhost:8080/" }, "id": "5INZa4ZJHQbn", "outputId": "1a3c9934-d738-4339-f8e7-419499ab3867" }, "execution_count": 75, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "{\"label\": \"sadness\", \"text\": \"i didnt feel humiliated\"}\n", "{\"label\": \"sadness\", \"text\": \"i can go from feeling so hopeless to so damned hopeful just from being around someone who cares and is awake\"}\n", "{\"label\": \"anger\", \"text\": \"im grabbing a minute to post i feel greedy wrong\"}\n", "{\"label\": \"love\", \"text\": \"i am ever feeling nostalgic about the fireplace i will know that it is still on the property\"}\n", "{\"label\": \"anger\", \"text\": \"i am feeling grouchy\"}\n", "{\"label\": \"sadness\", \"text\": \"ive been feeling a little burdened lately wasnt sure why that was\"}\n", "{\"label\": \"surprise\", \"text\": \"ive been taking or milligrams or times recommended amount and ive fallen asleep a lot faster but i also feel like so funny\"}\n", "{\"label\": \"fear\", \"text\": \"i feel as confused about life as a teenager or as jaded as a year old man\"}\n", "{\"label\": \"joy\", \"text\": \"i have been with petronas for years i feel that petronas has performed well and made a huge profit\"}\n", "{\"label\": \"love\", \"text\": \"i feel romantic too\"}\n" ] } ] }, { "cell_type": "code", "source": [ "# Create tiny datasets for debugging: the first and last TINY_HALF records of each split.\n", "TINY_HALF = 250\n", "for file_name in [\"s2s-train\", \"s2s-valid\", \"s2s-test\"]:\n", "    print(f\"=== {file_name} ===\")\n", "    # splitlines() does not yield the trailing '' that split('\\n') produces after the\n", "    # file's final newline; that '' took a slot in the tail slice, so the -500 files\n", "    # ended up with only 499 records.\n", "    all_lines = Path(f\"data/{file_name}.json\").read_text().splitlines()\n", "    if len(all_lines) <= 2 * TINY_HALF:\n", "        tiny_lines = all_lines  # head and tail slices would overlap and duplicate records\n", "    else:\n", "        tiny_lines = all_lines[:TINY_HALF] + all_lines[-TINY_HALF:]\n", "    # Trailing newline keeps one complete JSON record per line.\n", "    Path(f\"data/{file_name}-{2 * TINY_HALF}.json\").write_text(\"\\n\".join(tiny_lines) + \"\\n\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "OYeI-JvepSf7", "outputId": "d3f86e7a-a691-498d-b6b2-61f698c2218e" }, "execution_count": 76, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "=== s2s-train ===\n", "=== s2s-valid 
===\n", "=== s2s-test ===\n" ] } ] }, { "cell_type": "code", "source": [ "!wc -l data/*" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "_WSOgm50LI0m", "outputId": "0ad89231-0b18-4956-e10e-9975d2bb1f72" }, "execution_count": 77, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " 499 data/s2s-test-500.json\n", " 2000 data/s2s-test.json\n", " 499 data/s2s-train-500.json\n", " 16000 data/s2s-train.json\n", " 499 data/s2s-valid-500.json\n", " 2000 data/s2s-valid.json\n", " 2000 data/test.json\n", " 16000 data/train.json\n", " 2000 data/valid.json\n", " 41497 total\n" ] } ] }, { "cell_type": "markdown", "source": [ "# Zero Shot" ], "metadata": { "id": "6_unwNzOsl8i" } }, { "cell_type": "code", "source": [ "from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM\n", "import json\n", "import time" ], "metadata": { "id": "pYe_v630tK8M" }, "execution_count": 78, "outputs": [] }, { "cell_type": "code", "source": [ "!nvidia-smi" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "xP_v-YiXAw5y", "outputId": "8415f15f-cce6-4d1b-d148-2641ede4ff98" }, "execution_count": 96, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Mon Feb 13 23:18:24 2023 \n", "+-----------------------------------------------------------------------------+\n", "| NVIDIA-SMI 510.47.03 Driver Version: 510.47.03 CUDA Version: 11.6 |\n", "|-------------------------------+----------------------+----------------------+\n", "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", "| | | MIG M. 
|\n", "|===============================+======================+======================|\n", "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", "| N/A 71C P0 31W / 70W | 7320MiB / 15360MiB | 0% Default |\n", "| | | N/A |\n", "+-------------------------------+----------------------+----------------------+\n", " \n", "+-----------------------------------------------------------------------------+\n", "| Processes: |\n", "| GPU GI CI PID Type Process name GPU Memory |\n", "| ID ID Usage |\n", "|=============================================================================|\n", "| 0 N/A N/A 5402 C 7317MiB |\n", "+-----------------------------------------------------------------------------+\n" ] } ] },
{ "cell_type": "code", "source": [
"# transformers pipelines take a device index: 0 = first GPU, -1 = CPU\n",
"if torch.cuda.is_available():\n",
"    device = 0\n",
"else:\n",
"    device = -1"
], "metadata": { "id": "tVvf2ZjwCsS2" }, "execution_count": 89, "outputs": [] },
{ "cell_type": "code", "source": [
"def get_pipeline(pipeline_type: str, model_name: str, torch_dtype: \"torch.dtype | str\" = torch.float32):\n",
"    \"\"\"Load a seq2seq model and its tokenizer and wrap them in a transformers pipeline.\n",
"\n",
"    pipeline_type: task name passed to `pipeline`, e.g. 'text2text-generation'.\n",
"    model_name: Hugging Face model id (or local path) used for both model and tokenizer.\n",
"    torch_dtype: dtype the weights are loaded in -- a torch.dtype, or \"auto\" to keep\n",
"        the checkpoint's dtype. This argument used to be ignored (float32 was always\n",
"        used), so the default is float32 to keep existing calls unchanged.\n",
"\n",
"    The pipeline runs on the notebook-level `device` (0 = first GPU, -1 = CPU).\n",
"    \"\"\"\n",
"    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch_dtype)\n",
"    tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
"    return pipeline(pipeline_type, model=model, tokenizer=tokenizer, device=device)"
], "metadata": { "id": "1V60Aax5tJt1" }, "execution_count": 91, "outputs": [] },
{ "cell_type": "code", "source": [ "lm_pipeline = get_pipeline('text2text-generation', 'google/flan-t5-large')" ], "metadata": { "id": "BkhC-gr2soFF" }, "execution_count": 92, "outputs": [] },
{ "cell_type": "code", "source": [
"LABELS = ('sadness', 'joy', 'love', 'anger', 'surprise', 'fear')\n",
"\n",
"\n",
"def generate_prompt(text, labels=LABELS):\n",
"    \"\"\"Build the zero-shot prompt: candidate labels, the text, then 'label: ' for the model to complete.\"\"\"\n",
"    return f\"possible labels: {', '.join(labels)}\\ntext: {text}\\nlabel: \""
], "metadata": { "id": "6F30kPAqvYwb" }, "execution_count": 97, "outputs": [] },
{ "cell_type": "code", "source": [
"def predict(text):\n",
"    \"\"\"Return the model's raw generated label for `text` (greedy decoding, no sampling).\"\"\"\n",
"    return lm_pipeline(generate_prompt(text), do_sample=False)[0]['generated_text']"
], "metadata": { "id": "pdlOh5x3zvOT" }, "execution_count": 98, "outputs": [] },
{ "cell_type": "code", "source": [
"# Zero-shot evaluation: exact-match accuracy of the predicted label on a test split.\n",
"TEST_FILE = 'data/s2s-test.json'  # e.g. 'data/s2s-test-500.json' for a quick debug run\n",
"\n",
"# Read everything up front so the file is not held open during inference;\n",
"# blank lines are skipped instead of crashing json.loads.\n",
"with open(TEST_FILE) as f:\n",
"    test_items = [json.loads(line) for line in f if line.strip()]\n",
"\n",
"test_cases_amount = len(test_items)\n",
"correct = 0\n",
"time_start = time.perf_counter()\n",
"for total, item in enumerate(test_items, start=1):\n",
"    if total % 50 == 0:\n",
"        print(f'{total}/{test_cases_amount}')\n",
"    # normalise whitespace/case so e.g. ' Joy' still matches the label 'joy'\n",
"    if predict(item['text']).strip().lower() == item['label'].strip().lower():\n",
"        correct += 1\n",
"time_end = time.perf_counter()\n",
"print(f'Minutes elapsed: {(time_end - time_start) / 60}')\n",
"print(f'Accuracy: {correct / test_cases_amount}' if test_cases_amount else 'Accuracy: n/a (no test cases)')"
], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "yP2fKz87tqGr", "outputId": "89b43c0d-9c1f-4623-80ce-210f3448adff" }, "execution_count": 99, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "50/2000\n", "100/2000\n", "150/2000\n", "200/2000\n", "250/2000\n", "300/2000\n", "350/2000\n", "400/2000\n", "450/2000\n", "500/2000\n", "550/2000\n", "600/2000\n", "650/2000\n", "700/2000\n", "750/2000\n", "800/2000\n", "850/2000\n", "900/2000\n", "950/2000\n", "1000/2000\n", "1050/2000\n", "1100/2000\n", "1150/2000\n", "1200/2000\n", "1250/2000\n", "1300/2000\n", "1350/2000\n", "1400/2000\n", "1450/2000\n", "1500/2000\n", "1550/2000\n", "1600/2000\n", "1650/2000\n", "1700/2000\n", "1750/2000\n", "1800/2000\n", "1850/2000\n", "1900/2000\n", "1950/2000\n", "2000/2000\n", "Minutes elapsed: 3.088933833440145\n", "Accuracy: 0.6505\n" ] } ] } ] }