projekt-glebokie/FLAN_T5.ipynb
2023-02-14 00:25:11 +01:00

1083 lines
48 KiB
Plaintext
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"gpuClass": "standard",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"7e615ede17554aecbadc0b8ca5b2ff5a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_7ba3b7e45ae440668d42ab119c5b2cc0",
"IPY_MODEL_2702f15d084b43d3b8fb0e1c6a9f6b48",
"IPY_MODEL_179fc4e46e244e8fb3dbf861ee55db3d"
],
"layout": "IPY_MODEL_fda6f92b0b274978ab504247eda27f14"
}
},
"7ba3b7e45ae440668d42ab119c5b2cc0": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_e56d169996cb42b9bc855d59aec82ba2",
"placeholder": "",
"style": "IPY_MODEL_f7567129f0c64039afe13288341bdeb7",
"value": "100%"
}
},
"2702f15d084b43d3b8fb0e1c6a9f6b48": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_68cf2956c07f42cdb61ba6f38afc7009",
"max": 3,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_aaff5b55a35b49d9808ccc1661358d5e",
"value": 3
}
},
"179fc4e46e244e8fb3dbf861ee55db3d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_20041d370edb4814855de96e66c24881",
"placeholder": "",
"style": "IPY_MODEL_b1a7113efcc94094be3b5e7dea0c1236",
"value": " 3/3 [00:00<00:00, 54.10it/s]"
}
},
"fda6f92b0b274978ab504247eda27f14": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"e56d169996cb42b9bc855d59aec82ba2": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"f7567129f0c64039afe13288341bdeb7": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"68cf2956c07f42cdb61ba6f38afc7009": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"aaff5b55a35b49d9808ccc1661358d5e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"20041d370edb4814855de96e66c24881": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b1a7113efcc94094be3b5e7dea0c1236": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# Setup"
],
"metadata": {
"id": "n2A5EThJNiAy"
}
},
{
"cell_type": "markdown",
"source": [
"## Requirements"
],
"metadata": {
"id": "tPp2_1rDOFYA"
}
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "OmsX3kG4bLTg",
"outputId": "2ac1de01-0123-43c7-bd34-c2864b1bac57"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.8/dist-packages (1.13.1+cu116)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch) (4.4.0)\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: datasets in /usr/local/lib/python3.8/dist-packages (2.9.0)\n",
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (2023.1.0)\n",
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (4.64.1)\n",
"Requirement already satisfied: xxhash in /usr/local/lib/python3.8/dist-packages (from datasets) (3.2.0)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from datasets) (1.21.6)\n",
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.8/dist-packages (from datasets) (0.70.14)\n",
"Requirement already satisfied: huggingface-hub<1.0.0,>=0.2.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (0.12.0)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.8/dist-packages (from datasets) (1.3.5)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.8/dist-packages (from datasets) (23.0)\n",
"Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.8/dist-packages (from datasets) (0.18.0)\n",
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (2.25.1)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.8/dist-packages (from datasets) (3.8.3)\n",
"Requirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (9.0.0)\n",
"Requirement already satisfied: dill<0.3.7 in /usr/local/lib/python3.8/dist-packages (from datasets) (0.3.6)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (6.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.3.1)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.8.2)\n",
"Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (2.1.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (4.0.2)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (22.2.0)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.3.3)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0.0,>=0.2.0->datasets) (3.9.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0.0,>=0.2.0->datasets) (4.4.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (2022.12.7)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (1.26.14)\n",
"Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (4.0.0)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (2.10)\n",
"Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.8/dist-packages (from pandas->datasets) (2.8.2)\n",
"Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.8/dist-packages (from pandas->datasets) (2022.7.1)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: transformers in /usr/local/lib/python3.8/dist-packages (4.26.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from transformers) (6.0)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from transformers) (1.21.6)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.8/dist-packages (from transformers) (2022.6.2)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.8/dist-packages (from transformers) (0.13.2)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /usr/local/lib/python3.8/dist-packages (from transformers) (0.12.0)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from transformers) (3.9.0)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from transformers) (2.25.1)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from transformers) (23.0)\n",
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.8/dist-packages (from transformers) (4.64.1)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (4.4.0)\n",
"Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (4.0.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (2022.12.7)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (1.26.14)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (2.10)\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.8/dist-packages (1.0.2)\n",
"Requirement already satisfied: numpy>=1.14.6 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (1.21.6)\n",
"Requirement already satisfied: scipy>=1.1.0 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (1.7.3)\n",
"Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (1.2.0)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (3.1.0)\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: evaluate in /usr/local/lib/python3.8/dist-packages (0.4.0)\n",
"Requirement already satisfied: xxhash in /usr/local/lib/python3.8/dist-packages (from evaluate) (3.2.0)\n",
"Requirement already satisfied: datasets>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from evaluate) (2.9.0)\n",
"Requirement already satisfied: fsspec[http]>=2021.05.0 in /usr/local/lib/python3.8/dist-packages (from evaluate) (2023.1.0)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.8/dist-packages (from evaluate) (23.0)\n",
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.8/dist-packages (from evaluate) (4.64.1)\n",
"Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.8/dist-packages (from evaluate) (0.18.0)\n",
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.8/dist-packages (from evaluate) (2.25.1)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from evaluate) (1.21.6)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.8/dist-packages (from evaluate) (1.3.5)\n",
"Requirement already satisfied: huggingface-hub>=0.7.0 in /usr/local/lib/python3.8/dist-packages (from evaluate) (0.12.0)\n",
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.8/dist-packages (from evaluate) (0.70.14)\n",
"Requirement already satisfied: dill in /usr/local/lib/python3.8/dist-packages (from evaluate) (0.3.6)\n",
"Requirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.8/dist-packages (from datasets>=2.0.0->evaluate) (9.0.0)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.8/dist-packages (from datasets>=2.0.0->evaluate) (3.8.3)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from datasets>=2.0.0->evaluate) (6.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub>=0.7.0->evaluate) (4.4.0)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from huggingface-hub>=0.7.0->evaluate) (3.9.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->evaluate) (2022.12.7)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->evaluate) (2.10)\n",
"Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->evaluate) (4.0.0)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->evaluate) (1.26.14)\n",
"Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.8/dist-packages (from pandas->evaluate) (2022.7.1)\n",
"Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.8/dist-packages (from pandas->evaluate) (2.8.2)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (22.2.0)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.3.3)\n",
"Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (2.1.1)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (4.0.2)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.3.1)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.8.2)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (6.0.4)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.7.3->pandas->evaluate) (1.15.0)\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: accelerate in /usr/local/lib/python3.8/dist-packages (0.16.0)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.8/dist-packages (from accelerate) (5.4.8)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from accelerate) (1.21.6)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from accelerate) (23.0)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.8/dist-packages (from accelerate) (6.0)\n",
"Requirement already satisfied: torch>=1.4.0 in /usr/local/lib/python3.8/dist-packages (from accelerate) (1.13.1+cu116)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch>=1.4.0->accelerate) (4.4.0)\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.8/dist-packages (0.1.97)\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: protobuf in /usr/local/lib/python3.8/dist-packages (3.19.6)\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: sacrebleu in /usr/local/lib/python3.8/dist-packages (2.3.1)\n",
"Requirement already satisfied: lxml in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (4.9.2)\n",
"Requirement already satisfied: regex in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (2022.6.2)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (1.21.6)\n",
"Requirement already satisfied: colorama in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (0.4.6)\n",
"Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (0.8.10)\n",
"Requirement already satisfied: portalocker in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (2.7.0)\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: py7zr in /usr/local/lib/python3.8/dist-packages (0.20.4)\n",
"Requirement already satisfied: texttable in /usr/local/lib/python3.8/dist-packages (from py7zr) (1.6.7)\n",
"Requirement already satisfied: multivolumefile>=0.2.3 in /usr/local/lib/python3.8/dist-packages (from py7zr) (0.2.3)\n",
"Requirement already satisfied: brotli>=1.0.9 in /usr/local/lib/python3.8/dist-packages (from py7zr) (1.0.9)\n",
"Requirement already satisfied: pybcj>=0.6.0 in /usr/local/lib/python3.8/dist-packages (from py7zr) (1.0.1)\n",
"Requirement already satisfied: pyzstd>=0.14.4 in /usr/local/lib/python3.8/dist-packages (from py7zr) (0.15.3)\n",
"Requirement already satisfied: pyppmd<1.1.0,>=0.18.1 in /usr/local/lib/python3.8/dist-packages (from py7zr) (1.0.0)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.8/dist-packages (from py7zr) (5.4.8)\n",
"Requirement already satisfied: pycryptodomex>=3.6.6 in /usr/local/lib/python3.8/dist-packages (from py7zr) (3.17)\n",
"Requirement already satisfied: inflate64>=0.3.1 in /usr/local/lib/python3.8/dist-packages (from py7zr) (0.3.1)\n"
]
}
],
"source": [
"!pip install torch\n",
"!pip install datasets\n",
"!pip install transformers\n",
"!pip install scikit-learn\n",
"!pip install evaluate\n",
"!pip install accelerate\n",
"!pip install sentencepiece\n",
"!pip install protobuf\n",
"!pip install sacrebleu\n",
"!pip install py7zr\n"
]
},
{
"cell_type": "markdown",
"source": [
"## Imports"
],
"metadata": {
"id": "o3Kj9IzuOKMi"
}
},
{
"cell_type": "code",
"source": [
"import os\n",
"import json\n",
"import torch\n",
"from google.colab import drive\n",
"from pathlib import Path\n",
"from typing import Dict, List\n",
"from datasets import load_dataset\n",
"from transformers import T5Tokenizer"
],
"metadata": {
"id": "r92S06noeSWE"
},
"execution_count": 68,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Loading data"
],
"metadata": {
"id": "2UzLo91gNnsA"
}
},
{
"cell_type": "code",
"source": [
"loaded_data = load_dataset('emotion')\n",
"!mkdir -v -p data\n",
"train_path = Path('data/train.json')\n",
"valid_path = Path('data/valid.json')\n",
"test_path = Path('data/test.json')\n",
"data_train, data_valid, data_test = [], [], []"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 0,
"referenced_widgets": [
"7e615ede17554aecbadc0b8ca5b2ff5a",
"7ba3b7e45ae440668d42ab119c5b2cc0",
"2702f15d084b43d3b8fb0e1c6a9f6b48",
"179fc4e46e244e8fb3dbf861ee55db3d",
"fda6f92b0b274978ab504247eda27f14",
"e56d169996cb42b9bc855d59aec82ba2",
"f7567129f0c64039afe13288341bdeb7",
"68cf2956c07f42cdb61ba6f38afc7009",
"aaff5b55a35b49d9808ccc1661358d5e",
"20041d370edb4814855de96e66c24881",
"b1a7113efcc94094be3b5e7dea0c1236"
]
},
"id": "n_miey7eb2Xr",
"outputId": "7ec0c9cd-92b6-4c6f-eaa9-8418e1d904c9"
},
"execution_count": 69,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"WARNING:datasets.builder:No config specified, defaulting to: emotion/split\n",
"WARNING:datasets.builder:Found cached dataset emotion (/root/.cache/huggingface/datasets/emotion/split/1.0.0/cca5efe2dfeb58c1d098e0f9eeb200e9927d889b5a03c67097275dfb5fe463bd)\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
" 0%| | 0/3 [00:00<?, ?it/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "7e615ede17554aecbadc0b8ca5b2ff5a"
}
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"for source_data, dataset, max_size in [\n",
" (loaded_data['train'], data_train, None),\n",
" (loaded_data['validation'], data_valid, None),\n",
" (loaded_data['test'], data_test, None),\n",
"]:\n",
" for i, data in enumerate(source_data):\n",
" if max_size is not None and i >= max_size:\n",
" break\n",
" data_line = {\n",
" 'label': int(data['label']),\n",
" 'text': data['text'],\n",
" }\n",
" dataset.append(data_line)\n",
"\n",
"print(f'Train: {len(data_train):6d}')\n",
"print(f'Valid: {len(data_valid):6d}')\n",
"print(f'Test: {len(data_test):6d}')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "BZ6afaRzGsxS",
"outputId": "400213ed-60a7-4079-d7e9-b99d0a6b1a19"
},
"execution_count": 70,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Train: 16000\n",
"Valid: 2000\n",
"Test: 2000\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"MAP_LABEL_TRANSLATION = {\n",
" 0: 'sadness',\n",
" 1: 'joy',\n",
" 2: 'love',\n",
" 3: 'anger',\n",
" 4: 'fear',\n",
" 5: 'surprise',\n",
"}"
],
"metadata": {
"id": "w0KyM4TrGxQY"
},
"execution_count": 71,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def save_as_translations(original_save_path: Path, data_to_save: List[Dict]) -> None:\n",
"    \"\"\"Write `data_to_save` as JSON lines with integer labels replaced by label names.\n",
"\n",
"    The output file is written next to `original_save_path` with an `s2s-` prefix\n",
"    (e.g. data/train.json -> data/s2s-train.json). The input records are NOT\n",
"    modified, so this function (and the cell calling it) can be re-run safely.\n",
"\n",
"    Args:\n",
"        original_save_path: path of the integer-label file; used to derive the output name.\n",
"        data_to_save: records with an integer 'label' key (a MAP_LABEL_TRANSLATION key).\n",
"    \"\"\"\n",
"    file_name = 's2s-' + original_save_path.name\n",
"    file_path = original_save_path.parent / file_name\n",
"\n",
"    print(f'Saving into: {file_path}')\n",
"    with open(file_path, 'wt') as f_write:\n",
"        for data_line in data_to_save:\n",
"            # Build a new record instead of mutating the caller's dict: mutating it\n",
"            # would leave string labels in data_train/valid/test, so a re-run would\n",
"            # write them to train.json and fail on MAP_LABEL_TRANSLATION lookup.\n",
"            translated_line = {**data_line, 'label': MAP_LABEL_TRANSLATION[data_line['label']]}\n",
"            data_line_str = json.dumps(translated_line)\n",
"            f_write.write(f'{data_line_str}\\n')"
],
"metadata": {
"id": "-EFRYeAYHIKN"
},
"execution_count": 72,
"outputs": []
},
{
"cell_type": "code",
"source": [
"for file_path, data_to_save in [(train_path, data_train), (valid_path, data_valid), (test_path, data_test)]:\n",
" print(f'Saving into: {file_path}')\n",
" with open(file_path, 'wt') as f_write:\n",
" for data_line in data_to_save:\n",
" data_line_str = json.dumps(data_line)\n",
" f_write.write(f'{data_line_str}\\n')\n",
" \n",
" save_as_translations(file_path, data_to_save)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "7RsrTNGCHIqc",
"outputId": "41f73dae-2ac1-4da9-f9c6-fd56d9f9e819"
},
"execution_count": 73,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Saving into: data/train.json\n",
"Saving into: data/s2s-train.json\n",
"Saving into: data/valid.json\n",
"Saving into: data/s2s-valid.json\n",
"Saving into: data/test.json\n",
"Saving into: data/s2s-test.json\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!head data/train.json"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Svu6YYSaHK4t",
"outputId": "9d3623ba-baf8-4cbe-deed-3efa4cbe7d9f"
},
"execution_count": 74,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"{\"label\": 0, \"text\": \"i didnt feel humiliated\"}\n",
"{\"label\": 0, \"text\": \"i can go from feeling so hopeless to so damned hopeful just from being around someone who cares and is awake\"}\n",
"{\"label\": 3, \"text\": \"im grabbing a minute to post i feel greedy wrong\"}\n",
"{\"label\": 2, \"text\": \"i am ever feeling nostalgic about the fireplace i will know that it is still on the property\"}\n",
"{\"label\": 3, \"text\": \"i am feeling grouchy\"}\n",
"{\"label\": 0, \"text\": \"ive been feeling a little burdened lately wasnt sure why that was\"}\n",
"{\"label\": 5, \"text\": \"ive been taking or milligrams or times recommended amount and ive fallen asleep a lot faster but i also feel like so funny\"}\n",
"{\"label\": 4, \"text\": \"i feel as confused about life as a teenager or as jaded as a year old man\"}\n",
"{\"label\": 1, \"text\": \"i have been with petronas for years i feel that petronas has performed well and made a huge profit\"}\n",
"{\"label\": 2, \"text\": \"i feel romantic too\"}\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!head data/s2s-train.json"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5INZa4ZJHQbn",
"outputId": "1a3c9934-d738-4339-f8e7-419499ab3867"
},
"execution_count": 75,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"{\"label\": \"sadness\", \"text\": \"i didnt feel humiliated\"}\n",
"{\"label\": \"sadness\", \"text\": \"i can go from feeling so hopeless to so damned hopeful just from being around someone who cares and is awake\"}\n",
"{\"label\": \"anger\", \"text\": \"im grabbing a minute to post i feel greedy wrong\"}\n",
"{\"label\": \"love\", \"text\": \"i am ever feeling nostalgic about the fireplace i will know that it is still on the property\"}\n",
"{\"label\": \"anger\", \"text\": \"i am feeling grouchy\"}\n",
"{\"label\": \"sadness\", \"text\": \"ive been feeling a little burdened lately wasnt sure why that was\"}\n",
"{\"label\": \"surprise\", \"text\": \"ive been taking or milligrams or times recommended amount and ive fallen asleep a lot faster but i also feel like so funny\"}\n",
"{\"label\": \"fear\", \"text\": \"i feel as confused about life as a teenager or as jaded as a year old man\"}\n",
"{\"label\": \"joy\", \"text\": \"i have been with petronas for years i feel that petronas has performed well and made a huge profit\"}\n",
"{\"label\": \"love\", \"text\": \"i feel romantic too\"}\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Create tiny datasets for debugging purposes: the first and the last SUBSET_HALF\n",
"# records of every split.\n",
"SUBSET_HALF = 250\n",
"for file_name in [\"s2s-train\", \"s2s-valid\", \"s2s-test\"]:\n",
"    print(f\"=== {file_name} ===\")\n",
"    # splitlines() does not yield the trailing empty string that split('\\n') produces\n",
"    # for a file ending in a newline; that empty string used to take one of the\n",
"    # \"last 250\" slots, so every subset held only 499 records.\n",
"    all_text = Path(f\"data/{file_name}.json\").read_text().splitlines()\n",
"    if len(all_text) <= 2 * SUBSET_HALF:\n",
"        text = all_text  # head and tail would overlap and duplicate records\n",
"    else:\n",
"        text = all_text[:SUBSET_HALF] + all_text[-SUBSET_HALF:]\n",
"    # terminate the last record with a newline so it is a complete line\n",
"    Path(f\"data/{file_name}-500.json\").write_text(\"\\n\".join(text) + \"\\n\" if text else \"\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "OYeI-JvepSf7",
"outputId": "d3f86e7a-a691-498d-b6b2-61f698c2218e"
},
"execution_count": 76,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"=== s2s-train ===\n",
"=== s2s-valid ===\n",
"=== s2s-test ===\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# sanity check: line counts of the full splits and of the tiny *-500 subsets\n",
"!wc -l data/*"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_WSOgm50LI0m",
"outputId": "0ad89231-0b18-4956-e10e-9975d2bb1f72"
},
"execution_count": 77,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" 499 data/s2s-test-500.json\n",
" 2000 data/s2s-test.json\n",
" 499 data/s2s-train-500.json\n",
" 16000 data/s2s-train.json\n",
" 499 data/s2s-valid-500.json\n",
" 2000 data/s2s-valid.json\n",
" 2000 data/test.json\n",
" 16000 data/train.json\n",
" 2000 data/valid.json\n",
" 41497 total\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
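"# Zero-shot classification\n",
"\n",
"Each test sentence is classified by `google/flan-t5-large` without any fine-tuning: the prompt lists the six possible emotion labels and the generated text is compared with the gold label."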
],
"metadata": {
"id": "6_unwNzOsl8i"
}
},
{
"cell_type": "code",
"source": [
"from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM\n",
"import json\n",
"import time"
],
"metadata": {
"id": "pYe_v630tK8M"
},
"execution_count": 78,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# show the attached GPU and its current memory usage\n",
"!nvidia-smi"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xP_v-YiXAw5y",
"outputId": "8415f15f-cce6-4d1b-d148-2641ede4ff98"
},
"execution_count": 96,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mon Feb 13 23:18:24 2023 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 510.47.03 Driver Version: 510.47.03 CUDA Version: 11.6 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 71C P0 31W / 70W | 7320MiB / 15360MiB | 0% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| 0 N/A N/A 5402 C 7317MiB |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# device index handed to transformers' pipeline(): 0 = first CUDA GPU, -1 = CPU\n",
"device = 0 if torch.cuda.is_available() else -1"
],
"metadata": {
"id": "tVvf2ZjwCsS2"
},
"execution_count": 89,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def get_pipeline(pipeline_type: str, model_name: str, torch_dtype: \"torch.dtype | str\" = torch.float32):\n",
"    \"\"\"Build a transformers pipeline around a seq2seq checkpoint.\n",
"\n",
"    Args:\n",
"        pipeline_type: task passed to `pipeline()`, e.g. 'text2text-generation'.\n",
"        model_name: model id or path loadable by AutoModelForSeq2SeqLM.\n",
"        torch_dtype: dtype the weights are loaded in (a torch dtype or \"auto\").\n",
"            Defaults to torch.float32, the dtype that was previously always used.\n",
"\n",
"    Returns:\n",
"        The pipeline, placed on the notebook-level `device`.\n",
"    \"\"\"\n",
"    # torch_dtype used to be ignored (float32 was hard-coded); forward it now.\n",
"    model = AutoModelForSeq2SeqLM.from_pretrained(\n",
"        model_name, low_cpu_mem_usage=True, torch_dtype=torch_dtype\n",
"    )\n",
"    tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
"    return pipeline(pipeline_type, model=model, tokenizer=tokenizer, device=device)"
],
"metadata": {
"id": "1V60Aax5tJt1"
},
"execution_count": 91,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# load google/flan-t5-large as a text2text-generation pipeline; it is used as-is (zero-shot)\n",
"lm_pipeline = get_pipeline(pipeline_type='text2text-generation', model_name='google/flan-t5-large')"
],
"metadata": {
"id": "BkhC-gr2soFF"
},
"execution_count": 92,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def generate_prompt(text, labels=(\"sadness\", \"joy\", \"love\", \"anger\", \"surprise\", \"fear\")):\n",
"    \"\"\"Build the zero-shot classification prompt for FLAN-T5.\n",
"\n",
"    The prompt has three lines: the candidate labels, the input text, and a\n",
"    final 'label: ' so that the model's continuation is the predicted label.\n",
"\n",
"    Args:\n",
"        text: sentence to classify.\n",
"        labels: candidate label names, listed in the prompt in this order.\n",
"            Defaults to the six emotion labels that were previously hard-coded.\n",
"\n",
"    Returns:\n",
"        The prompt string.\n",
"    \"\"\"\n",
"    return f\"possible labels: {', '.join(labels)}\\ntext: {text}\\nlabel: \""
],
"metadata": {
"id": "6F30kPAqvYwb"
},
"execution_count": 97,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def predict(text):\n",
"    \"\"\"Return the label FLAN-T5 generates for `text`, normalized for comparison.\n",
"\n",
"    Decoding is greedy (do_sample=False). The generated string is stripped and\n",
"    lower-cased so that e.g. ' Joy' still matches a gold label such as 'joy'\n",
"    (the labels seen in data/s2s-train.json are lower-case names).\n",
"    \"\"\"\n",
"    outputs = lm_pipeline(generate_prompt(text), do_sample=False)\n",
"    return outputs[0]['generated_text'].strip().lower()"
],
"metadata": {
"id": "pdlOh5x3zvOT"
},
"execution_count": 98,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Zero-shot accuracy of `predict` on the test split.\n",
"with open('data/s2s-test.json') as f:\n",
"    # read all records up front so the file is not held open during the slow inference loop;\n",
"    # blank lines (e.g. a stray empty last line) are skipped because json.loads would fail on them\n",
"    lines = [line for line in f if line.strip()]\n",
"test_cases_amount = len(lines)\n",
"time_start = time.perf_counter()  # monotonic clock, suited to measuring elapsed time\n",
"total = 0\n",
"correct = 0\n",
"for total, line in enumerate(lines, start=1):\n",
"    item = json.loads(line)\n",
"    text = item['text']\n",
"    label = item['label']\n",
"    if total % 50 == 0:\n",
"        print(f'{total}/{test_cases_amount}')\n",
"    if predict(text) == label:\n",
"        correct += 1\n",
"time_end = time.perf_counter()\n",
"print(f'Minutes elapsed: {(time_end - time_start) / 60}')\n",
"# guard against an empty test file, which would otherwise raise ZeroDivisionError\n",
"print(f'Accuracy: {correct / total}' if total else 'Accuracy: n/a (no test cases)')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "yP2fKz87tqGr",
"outputId": "89b43c0d-9c1f-4623-80ce-210f3448adff"
},
"execution_count": 99,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"50/2000\n",
"100/2000\n",
"150/2000\n",
"200/2000\n",
"250/2000\n",
"300/2000\n",
"350/2000\n",
"400/2000\n",
"450/2000\n",
"500/2000\n",
"550/2000\n",
"600/2000\n",
"650/2000\n",
"700/2000\n",
"750/2000\n",
"800/2000\n",
"850/2000\n",
"900/2000\n",
"950/2000\n",
"1000/2000\n",
"1050/2000\n",
"1100/2000\n",
"1150/2000\n",
"1200/2000\n",
"1250/2000\n",
"1300/2000\n",
"1350/2000\n",
"1400/2000\n",
"1450/2000\n",
"1500/2000\n",
"1550/2000\n",
"1600/2000\n",
"1650/2000\n",
"1700/2000\n",
"1750/2000\n",
"1800/2000\n",
"1850/2000\n",
"1900/2000\n",
"1950/2000\n",
"2000/2000\n",
"Minutes elapsed: 3.088933833440145\n",
"Accuracy: 0.6505\n"
]
}
]
}
]
}