projekt-glebokie/FLAN_T5.ipynb

1083 lines
48 KiB
Plaintext
Raw Permalink Normal View History

2023-02-12 23:22:40 +01:00
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
2023-02-14 00:25:11 +01:00
"accelerator": "GPU",
2023-02-12 23:22:40 +01:00
"gpuClass": "standard",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
2023-02-14 00:25:11 +01:00
"7e615ede17554aecbadc0b8ca5b2ff5a": {
2023-02-12 23:22:40 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
2023-02-14 00:25:11 +01:00
"IPY_MODEL_7ba3b7e45ae440668d42ab119c5b2cc0",
"IPY_MODEL_2702f15d084b43d3b8fb0e1c6a9f6b48",
"IPY_MODEL_179fc4e46e244e8fb3dbf861ee55db3d"
2023-02-12 23:22:40 +01:00
],
2023-02-14 00:25:11 +01:00
"layout": "IPY_MODEL_fda6f92b0b274978ab504247eda27f14"
2023-02-12 23:22:40 +01:00
}
},
2023-02-14 00:25:11 +01:00
"7ba3b7e45ae440668d42ab119c5b2cc0": {
2023-02-12 23:22:40 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
2023-02-14 00:25:11 +01:00
"layout": "IPY_MODEL_e56d169996cb42b9bc855d59aec82ba2",
2023-02-12 23:22:40 +01:00
"placeholder": "",
2023-02-14 00:25:11 +01:00
"style": "IPY_MODEL_f7567129f0c64039afe13288341bdeb7",
2023-02-12 23:22:40 +01:00
"value": "100%"
}
},
2023-02-14 00:25:11 +01:00
"2702f15d084b43d3b8fb0e1c6a9f6b48": {
2023-02-12 23:22:40 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
2023-02-14 00:25:11 +01:00
"layout": "IPY_MODEL_68cf2956c07f42cdb61ba6f38afc7009",
2023-02-12 23:22:40 +01:00
"max": 3,
"min": 0,
"orientation": "horizontal",
2023-02-14 00:25:11 +01:00
"style": "IPY_MODEL_aaff5b55a35b49d9808ccc1661358d5e",
2023-02-12 23:22:40 +01:00
"value": 3
}
},
2023-02-14 00:25:11 +01:00
"179fc4e46e244e8fb3dbf861ee55db3d": {
2023-02-12 23:22:40 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
2023-02-14 00:25:11 +01:00
"layout": "IPY_MODEL_20041d370edb4814855de96e66c24881",
2023-02-12 23:22:40 +01:00
"placeholder": "",
2023-02-14 00:25:11 +01:00
"style": "IPY_MODEL_b1a7113efcc94094be3b5e7dea0c1236",
"value": " 3/3 [00:00<00:00, 54.10it/s]"
2023-02-12 23:22:40 +01:00
}
},
2023-02-14 00:25:11 +01:00
"fda6f92b0b274978ab504247eda27f14": {
2023-02-12 23:22:40 +01:00
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
2023-02-14 00:25:11 +01:00
"e56d169996cb42b9bc855d59aec82ba2": {
2023-02-12 23:22:40 +01:00
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
2023-02-14 00:25:11 +01:00
"f7567129f0c64039afe13288341bdeb7": {
2023-02-12 23:22:40 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
2023-02-14 00:25:11 +01:00
"68cf2956c07f42cdb61ba6f38afc7009": {
2023-02-12 23:22:40 +01:00
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
2023-02-14 00:25:11 +01:00
"aaff5b55a35b49d9808ccc1661358d5e": {
2023-02-12 23:22:40 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
2023-02-14 00:25:11 +01:00
"20041d370edb4814855de96e66c24881": {
2023-02-12 23:22:40 +01:00
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
2023-02-14 00:25:11 +01:00
"b1a7113efcc94094be3b5e7dea0c1236": {
2023-02-12 23:22:40 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# Setup"
],
"metadata": {
"id": "n2A5EThJNiAy"
}
},
{
"cell_type": "markdown",
"source": [
"## Requirements"
],
"metadata": {
"id": "tPp2_1rDOFYA"
}
},
{
"cell_type": "code",
2023-02-14 00:25:11 +01:00
"execution_count": 67,
2023-02-12 23:22:40 +01:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "OmsX3kG4bLTg",
2023-02-14 00:25:11 +01:00
"outputId": "2ac1de01-0123-43c7-bd34-c2864b1bac57"
2023-02-12 23:22:40 +01:00
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.8/dist-packages (1.13.1+cu116)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch) (4.4.0)\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: datasets in /usr/local/lib/python3.8/dist-packages (2.9.0)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (2023.1.0)\n",
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (4.64.1)\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: xxhash in /usr/local/lib/python3.8/dist-packages (from datasets) (3.2.0)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from datasets) (1.21.6)\n",
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.8/dist-packages (from datasets) (0.70.14)\n",
"Requirement already satisfied: huggingface-hub<1.0.0,>=0.2.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (0.12.0)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: pandas in /usr/local/lib/python3.8/dist-packages (from datasets) (1.3.5)\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: packaging in /usr/local/lib/python3.8/dist-packages (from datasets) (23.0)\n",
"Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.8/dist-packages (from datasets) (0.18.0)\n",
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (2.25.1)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.8/dist-packages (from datasets) (3.8.3)\n",
"Requirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (9.0.0)\n",
"Requirement already satisfied: dill<0.3.7 in /usr/local/lib/python3.8/dist-packages (from datasets) (0.3.6)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (6.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.3.1)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.8.2)\n",
"Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (2.1.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (6.0.4)\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (4.0.2)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (22.2.0)\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.3.3)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0.0,>=0.2.0->datasets) (3.9.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0.0,>=0.2.0->datasets) (4.4.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (2022.12.7)\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (1.26.14)\n",
"Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (4.0.0)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (2.10)\n",
"Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.8/dist-packages (from pandas->datasets) (2.8.2)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.8/dist-packages (from pandas->datasets) (2022.7.1)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: transformers in /usr/local/lib/python3.8/dist-packages (4.26.1)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from transformers) (6.0)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from transformers) (1.21.6)\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.8/dist-packages (from transformers) (2022.6.2)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.8/dist-packages (from transformers) (0.13.2)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /usr/local/lib/python3.8/dist-packages (from transformers) (0.12.0)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from transformers) (3.9.0)\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from transformers) (2.25.1)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from transformers) (23.0)\n",
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.8/dist-packages (from transformers) (4.64.1)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (4.4.0)\n",
"Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (4.0.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (2022.12.7)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (1.26.14)\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (2.10)\n",
2023-02-12 23:22:40 +01:00
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.8/dist-packages (1.0.2)\n",
"Requirement already satisfied: numpy>=1.14.6 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (1.21.6)\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: scipy>=1.1.0 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (1.7.3)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (1.2.0)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (3.1.0)\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: evaluate in /usr/local/lib/python3.8/dist-packages (0.4.0)\n",
"Requirement already satisfied: xxhash in /usr/local/lib/python3.8/dist-packages (from evaluate) (3.2.0)\n",
"Requirement already satisfied: datasets>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from evaluate) (2.9.0)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: fsspec[http]>=2021.05.0 in /usr/local/lib/python3.8/dist-packages (from evaluate) (2023.1.0)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.8/dist-packages (from evaluate) (23.0)\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.8/dist-packages (from evaluate) (4.64.1)\n",
"Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.8/dist-packages (from evaluate) (0.18.0)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.8/dist-packages (from evaluate) (2.25.1)\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from evaluate) (1.21.6)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.8/dist-packages (from evaluate) (1.3.5)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: huggingface-hub>=0.7.0 in /usr/local/lib/python3.8/dist-packages (from evaluate) (0.12.0)\n",
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.8/dist-packages (from evaluate) (0.70.14)\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: dill in /usr/local/lib/python3.8/dist-packages (from evaluate) (0.3.6)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.8/dist-packages (from datasets>=2.0.0->evaluate) (9.0.0)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.8/dist-packages (from datasets>=2.0.0->evaluate) (3.8.3)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from datasets>=2.0.0->evaluate) (6.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub>=0.7.0->evaluate) (4.4.0)\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from huggingface-hub>=0.7.0->evaluate) (3.9.0)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->evaluate) (2022.12.7)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->evaluate) (2.10)\n",
"Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->evaluate) (4.0.0)\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->evaluate) (1.26.14)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.8/dist-packages (from pandas->evaluate) (2022.7.1)\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.8/dist-packages (from pandas->evaluate) (2.8.2)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (22.2.0)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.3.3)\n",
"Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (2.1.1)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (4.0.2)\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.3.1)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.8.2)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (6.0.4)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.7.3->pandas->evaluate) (1.15.0)\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: accelerate in /usr/local/lib/python3.8/dist-packages (0.16.0)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: psutil in /usr/local/lib/python3.8/dist-packages (from accelerate) (5.4.8)\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from accelerate) (1.21.6)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from accelerate) (23.0)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.8/dist-packages (from accelerate) (6.0)\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: torch>=1.4.0 in /usr/local/lib/python3.8/dist-packages (from accelerate) (1.13.1+cu116)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch>=1.4.0->accelerate) (4.4.0)\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.8/dist-packages (0.1.97)\n",
2023-02-12 23:22:40 +01:00
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: protobuf in /usr/local/lib/python3.8/dist-packages (3.19.6)\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: sacrebleu in /usr/local/lib/python3.8/dist-packages (2.3.1)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: lxml in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (4.9.2)\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: regex in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (2022.6.2)\n",
2023-02-12 23:22:40 +01:00
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (1.21.6)\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: colorama in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (0.4.6)\n",
"Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (0.8.10)\n",
"Requirement already satisfied: portalocker in /usr/local/lib/python3.8/dist-packages (from sacrebleu) (2.7.0)\n",
2023-02-12 23:22:40 +01:00
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
2023-02-14 00:25:11 +01:00
"Requirement already satisfied: py7zr in /usr/local/lib/python3.8/dist-packages (0.20.4)\n",
"Requirement already satisfied: texttable in /usr/local/lib/python3.8/dist-packages (from py7zr) (1.6.7)\n",
"Requirement already satisfied: multivolumefile>=0.2.3 in /usr/local/lib/python3.8/dist-packages (from py7zr) (0.2.3)\n",
"Requirement already satisfied: brotli>=1.0.9 in /usr/local/lib/python3.8/dist-packages (from py7zr) (1.0.9)\n",
"Requirement already satisfied: pybcj>=0.6.0 in /usr/local/lib/python3.8/dist-packages (from py7zr) (1.0.1)\n",
"Requirement already satisfied: pyzstd>=0.14.4 in /usr/local/lib/python3.8/dist-packages (from py7zr) (0.15.3)\n",
"Requirement already satisfied: pyppmd<1.1.0,>=0.18.1 in /usr/local/lib/python3.8/dist-packages (from py7zr) (1.0.0)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.8/dist-packages (from py7zr) (5.4.8)\n",
"Requirement already satisfied: pycryptodomex>=3.6.6 in /usr/local/lib/python3.8/dist-packages (from py7zr) (3.17)\n",
"Requirement already satisfied: inflate64>=0.3.1 in /usr/local/lib/python3.8/dist-packages (from py7zr) (0.3.1)\n"
2023-02-12 23:22:40 +01:00
]
}
],
"source": [
"# Install the notebook's dependencies with the %pip magic, which targets the\n",
"# environment of the running kernel (a `!pip` subshell may resolve to a different\n",
"# interpreter). -q keeps the cell output short.\n",
"%pip install -q torch\n",
"%pip install -q datasets\n",
"%pip install -q transformers\n",
"%pip install -q scikit-learn\n",
"%pip install -q evaluate\n",
"%pip install -q accelerate\n",
"%pip install -q sentencepiece\n",
"%pip install -q protobuf\n",
"%pip install -q sacrebleu\n",
"%pip install -q py7zr\n"
]
},
{
"cell_type": "markdown",
"source": [
"## Imports"
],
"metadata": {
"id": "o3Kj9IzuOKMi"
}
},
{
"cell_type": "code",
"source": [
"import os\n",
"import json\n",
"import torch\n",
"from google.colab import drive\n",
"from pathlib import Path\n",
"from typing import Dict, List\n",
"from datasets import load_dataset\n",
"from transformers import T5Tokenizer"
],
"metadata": {
"id": "r92S06noeSWE"
},
2023-02-14 00:25:11 +01:00
"execution_count": 68,
2023-02-12 23:22:40 +01:00
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Loading data"
],
"metadata": {
"id": "2UzLo91gNnsA"
}
},
{
"cell_type": "code",
"source": [
"# Load the `emotion` dataset through `datasets.load_dataset`.\n",
"loaded_data = load_dataset('emotion')\n",
"\n",
"# JSON-lines files that will hold the three splits.\n",
"train_path = Path('data/train.json')\n",
"valid_path = Path('data/valid.json')\n",
"test_path = Path('data/test.json')\n",
"\n",
"# Create the output directory from Python instead of shelling out to `mkdir`;\n",
"# exist_ok makes the cell safe to re-run.\n",
"train_path.parent.mkdir(parents=True, exist_ok=True)\n",
"\n",
"# In-memory copies of the splits; they start empty and are filled later on.\n",
"data_train, data_valid, data_test = [], [], []"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 0,
"referenced_widgets": [
2023-02-14 00:25:11 +01:00
"7e615ede17554aecbadc0b8ca5b2ff5a",
"7ba3b7e45ae440668d42ab119c5b2cc0",
"2702f15d084b43d3b8fb0e1c6a9f6b48",
"179fc4e46e244e8fb3dbf861ee55db3d",
"fda6f92b0b274978ab504247eda27f14",
"e56d169996cb42b9bc855d59aec82ba2",
"f7567129f0c64039afe13288341bdeb7",
"68cf2956c07f42cdb61ba6f38afc7009",
"aaff5b55a35b49d9808ccc1661358d5e",
"20041d370edb4814855de96e66c24881",
"b1a7113efcc94094be3b5e7dea0c1236"
2023-02-12 23:22:40 +01:00
]
},
"id": "n_miey7eb2Xr",
2023-02-14 00:25:11 +01:00
"outputId": "7ec0c9cd-92b6-4c6f-eaa9-8418e1d904c9"
2023-02-12 23:22:40 +01:00
},
2023-02-14 00:25:11 +01:00
"execution_count": 69,
2023-02-12 23:22:40 +01:00
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
2023-02-14 00:25:11 +01:00
"WARNING:datasets.builder:No config specified, defaulting to: emotion/split\n",
"WARNING:datasets.builder:Found cached dataset emotion (/root/.cache/huggingface/datasets/emotion/split/1.0.0/cca5efe2dfeb58c1d098e0f9eeb200e9927d889b5a03c67097275dfb5fe463bd)\n"
2023-02-12 23:22:40 +01:00
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
" 0%| | 0/3 [00:00<?, ?it/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
2023-02-14 00:25:11 +01:00
"model_id": "7e615ede17554aecbadc0b8ca5b2ff5a"
2023-02-12 23:22:40 +01:00
}
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"# Copy every split into plain Python dicts. `max_size` optionally caps the number\n",
"# of examples taken from a split (None keeps the whole split).\n",
"for source_data, dataset, max_size in [\n",
"    (loaded_data['train'], data_train, None),\n",
"    (loaded_data['validation'], data_valid, None),\n",
"    (loaded_data['test'], data_test, None),\n",
"]:\n",
"    # Start from an empty list so that re-running this cell rebuilds the split\n",
"    # instead of appending a second copy of every example.\n",
"    dataset.clear()\n",
"    for i, data in enumerate(source_data):\n",
"        if max_size is not None and i >= max_size:\n",
"            break\n",
"        data_line = {\n",
"            'label': int(data['label']),\n",
"            'text': data['text'],\n",
"        }\n",
"        dataset.append(data_line)\n",
"\n",
"print(f'Train: {len(data_train):6d}')\n",
"print(f'Valid: {len(data_valid):6d}')\n",
"print(f'Test: {len(data_test):6d}')"
],
"metadata": {
"colab": {
2023-02-14 00:25:11 +01:00
"base_uri": "https://localhost:8080/"
2023-02-12 23:22:40 +01:00
},
"id": "BZ6afaRzGsxS",
2023-02-14 00:25:11 +01:00
"outputId": "400213ed-60a7-4079-d7e9-b99d0a6b1a19"
2023-02-12 23:22:40 +01:00
},
2023-02-14 00:25:11 +01:00
"execution_count": 70,
2023-02-12 23:22:40 +01:00
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Train: 16000\n",
"Valid: 2000\n",
"Test: 2000\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Emotion name for each integer label; the position in the tuple is the label id.\n",
"MAP_LABEL_TRANSLATION = dict(enumerate((\n",
"    'sadness',\n",
"    'joy',\n",
"    'love',\n",
"    'anger',\n",
"    'fear',\n",
"    'surprise',\n",
")))"
],
"metadata": {
"id": "w0KyM4TrGxQY"
},
2023-02-14 00:25:11 +01:00
"execution_count": 71,
2023-02-12 23:22:40 +01:00
"outputs": []
},
{
"cell_type": "code",
"source": [
"def save_as_translations(original_save_path: Path, data_to_save: List[Dict]) -> None:\n",
"    \"\"\"Write `data_to_save` as JSON lines with labels translated to emotion names.\n",
"\n",
"    The output file is placed next to `original_save_path` and carries the same\n",
"    file name prefixed with `s2s-`.\n",
"\n",
"    Args:\n",
"        original_save_path: Path of the untranslated file; only its directory and\n",
"            file name are used.\n",
"        data_to_save: Records whose `label` value is a key of\n",
"            `MAP_LABEL_TRANSLATION`. The records are not modified.\n",
"\n",
"    Raises:\n",
"        KeyError: If a record's label is missing from `MAP_LABEL_TRANSLATION`.\n",
"    \"\"\"\n",
"    file_name = 's2s-' + original_save_path.name\n",
"    file_path = original_save_path.parent / file_name\n",
"\n",
"    print(f'Saving into: {file_path}')\n",
"    with open(file_path, 'wt') as f_write:\n",
"        for data_line in data_to_save:\n",
"            # Translate on a shallow copy: overwriting `data_line['label']` would leave\n",
"            # the caller's list with string labels, so a second call on the same data\n",
"            # would raise KeyError and later dumps would no longer hold integer labels.\n",
"            translated_line = {**data_line, 'label': MAP_LABEL_TRANSLATION[data_line['label']]}\n",
"            data_line_str = json.dumps(translated_line)\n",
"            f_write.write(f'{data_line_str}\\n')"
],
"metadata": {
"id": "-EFRYeAYHIKN"
},
2023-02-14 00:25:11 +01:00
"execution_count": 72,
2023-02-12 23:22:40 +01:00
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Write every split as JSON lines, then hand it to save_as_translations, which\n",
"# writes the `s2s-` prefixed variant next to it.\n",
"for file_path, data_to_save in zip(\n",
"    (train_path, valid_path, test_path),\n",
"    (data_train, data_valid, data_test),\n",
"):\n",
"    print(f'Saving into: {file_path}')\n",
"    with open(file_path, 'wt') as f_write:\n",
"        f_write.writelines(json.dumps(data_line) + '\\n' for data_line in data_to_save)\n",
"\n",
"    save_as_translations(file_path, data_to_save)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "7RsrTNGCHIqc",
2023-02-14 00:25:11 +01:00
"outputId": "41f73dae-2ac1-4da9-f9c6-fd56d9f9e819"
2023-02-12 23:22:40 +01:00
},
2023-02-14 00:25:11 +01:00
"execution_count": 73,
2023-02-12 23:22:40 +01:00
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Saving into: data/train.json\n",
"Saving into: data/s2s-train.json\n",
"Saving into: data/valid.json\n",
"Saving into: data/s2s-valid.json\n",
"Saving into: data/test.json\n",
"Saving into: data/s2s-test.json\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!head data/train.json"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Svu6YYSaHK4t",
2023-02-14 00:25:11 +01:00
"outputId": "9d3623ba-baf8-4cbe-deed-3efa4cbe7d9f"
2023-02-12 23:22:40 +01:00
},
2023-02-14 00:25:11 +01:00
"execution_count": 74,
2023-02-12 23:22:40 +01:00
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"{\"label\": 0, \"text\": \"i didnt feel humiliated\"}\n",
"{\"label\": 0, \"text\": \"i can go from feeling so hopeless to so damned hopeful just from being around someone who cares and is awake\"}\n",
"{\"label\": 3, \"text\": \"im grabbing a minute to post i feel greedy wrong\"}\n",
"{\"label\": 2, \"text\": \"i am ever feeling nostalgic about the fireplace i will know that it is still on the property\"}\n",
"{\"label\": 3, \"text\": \"i am feeling grouchy\"}\n",
"{\"label\": 0, \"text\": \"ive been feeling a little burdened lately wasnt sure why that was\"}\n",
"{\"label\": 5, \"text\": \"ive been taking or milligrams or times recommended amount and ive fallen asleep a lot faster but i also feel like so funny\"}\n",
"{\"label\": 4, \"text\": \"i feel as confused about life as a teenager or as jaded as a year old man\"}\n",
"{\"label\": 1, \"text\": \"i have been with petronas for years i feel that petronas has performed well and made a huge profit\"}\n",
"{\"label\": 2, \"text\": \"i feel romantic too\"}\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!head data/s2s-train.json"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5INZa4ZJHQbn",
2023-02-14 00:25:11 +01:00
"outputId": "1a3c9934-d738-4339-f8e7-419499ab3867"
2023-02-12 23:22:40 +01:00
},
2023-02-14 00:25:11 +01:00
"execution_count": 75,
2023-02-12 23:22:40 +01:00
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"{\"label\": \"sadness\", \"text\": \"i didnt feel humiliated\"}\n",
"{\"label\": \"sadness\", \"text\": \"i can go from feeling so hopeless to so damned hopeful just from being around someone who cares and is awake\"}\n",
"{\"label\": \"anger\", \"text\": \"im grabbing a minute to post i feel greedy wrong\"}\n",
"{\"label\": \"love\", \"text\": \"i am ever feeling nostalgic about the fireplace i will know that it is still on the property\"}\n",
"{\"label\": \"anger\", \"text\": \"i am feeling grouchy\"}\n",
"{\"label\": \"sadness\", \"text\": \"ive been feeling a little burdened lately wasnt sure why that was\"}\n",
"{\"label\": \"surprise\", \"text\": \"ive been taking or milligrams or times recommended amount and ive fallen asleep a lot faster but i also feel like so funny\"}\n",
"{\"label\": \"fear\", \"text\": \"i feel as confused about life as a teenager or as jaded as a year old man\"}\n",
"{\"label\": \"joy\", \"text\": \"i have been with petronas for years i feel that petronas has performed well and made a huge profit\"}\n",
"{\"label\": \"love\", \"text\": \"i feel romantic too\"}\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# create tiny datasets for debugging purposes\n",
"for file_name in [\"s2s-train\", \"s2s-valid\", \"s2s-test\"]:\n",
" print(f\"=== {file_name} ===\")\n",
" all_text = Path(f\"data/{file_name}.json\").read_text().split('\\n')\n",
" text = all_text[:250] + all_text[-250:]\n",
" Path(f\"data/{file_name}-500.json\").write_text(\"\\n\".join(text))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "OYeI-JvepSf7",
2023-02-14 00:25:11 +01:00
"outputId": "d3f86e7a-a691-498d-b6b2-61f698c2218e"
2023-02-12 23:22:40 +01:00
},
2023-02-14 00:25:11 +01:00
"execution_count": 76,
2023-02-12 23:22:40 +01:00
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"=== s2s-train ===\n",
"=== s2s-valid ===\n",
"=== s2s-test ===\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!wc -l data/*"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_WSOgm50LI0m",
2023-02-14 00:25:11 +01:00
"outputId": "0ad89231-0b18-4956-e10e-9975d2bb1f72"
2023-02-12 23:22:40 +01:00
},
2023-02-14 00:25:11 +01:00
"execution_count": 77,
2023-02-12 23:22:40 +01:00
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" 499 data/s2s-test-500.json\n",
" 2000 data/s2s-test.json\n",
" 499 data/s2s-train-500.json\n",
" 16000 data/s2s-train.json\n",
" 499 data/s2s-valid-500.json\n",
" 2000 data/s2s-valid.json\n",
" 2000 data/test.json\n",
" 16000 data/train.json\n",
" 2000 data/valid.json\n",
" 41497 total\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
2023-02-14 00:25:11 +01:00
"# Zero Shot"
2023-02-12 23:22:40 +01:00
],
"metadata": {
2023-02-14 00:25:11 +01:00
"id": "6_unwNzOsl8i"
2023-02-12 23:22:40 +01:00
}
},
2023-02-14 00:25:11 +01:00
{
"cell_type": "code",
"source": [
"from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM\n",
"import json\n",
"import time"
],
"metadata": {
"id": "pYe_v630tK8M"
},
"execution_count": 78,
"outputs": []
},
2023-02-12 23:22:40 +01:00
{
"cell_type": "code",
"source": [
"!nvidia-smi"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
2023-02-14 00:25:11 +01:00
"id": "xP_v-YiXAw5y",
"outputId": "8415f15f-cce6-4d1b-d148-2641ede4ff98"
2023-02-12 23:22:40 +01:00
},
2023-02-14 00:25:11 +01:00
"execution_count": 96,
2023-02-12 23:22:40 +01:00
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
2023-02-14 00:25:11 +01:00
"Mon Feb 13 23:18:24 2023 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 510.47.03 Driver Version: 510.47.03 CUDA Version: 11.6 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 71C P0 31W / 70W | 7320MiB / 15360MiB | 0% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| 0 N/A N/A 5402 C 7317MiB |\n",
"+-----------------------------------------------------------------------------+\n"
2023-02-12 23:22:40 +01:00
]
}
]
},
{
"cell_type": "code",
"source": [
2023-02-14 00:25:11 +01:00
"if torch.cuda.is_available():\n",
" device = 0\n",
"else:\n",
" device = -1"
2023-02-12 23:22:40 +01:00
],
"metadata": {
2023-02-14 00:25:11 +01:00
"id": "tVvf2ZjwCsS2"
2023-02-12 23:22:40 +01:00
},
2023-02-14 00:25:11 +01:00
"execution_count": 89,
2023-02-12 23:22:40 +01:00
"outputs": []
},
{
"cell_type": "code",
"source": [
2023-02-14 00:25:11 +01:00
"def get_pipeline(pipeline_type: str, model_name: str, torch_dtype: torch.dtype=\"auto\"):\n",
" class_type = AutoModelForSeq2SeqLM\n",
" model = class_type.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float32)\n",
" tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
" return pipeline(pipeline_type, model=model, tokenizer=tokenizer, device=device)"
2023-02-12 23:22:40 +01:00
],
"metadata": {
2023-02-14 00:25:11 +01:00
"id": "1V60Aax5tJt1"
2023-02-12 23:22:40 +01:00
},
2023-02-14 00:25:11 +01:00
"execution_count": 91,
"outputs": []
2023-02-12 23:22:40 +01:00
},
{
"cell_type": "code",
"source": [
2023-02-14 00:25:11 +01:00
"lm_pipeline = get_pipeline('text2text-generation', 'google/flan-t5-large')"
2023-02-12 23:22:40 +01:00
],
"metadata": {
2023-02-14 00:25:11 +01:00
"id": "BkhC-gr2soFF"
2023-02-12 23:22:40 +01:00
},
2023-02-14 00:25:11 +01:00
"execution_count": 92,
2023-02-12 23:22:40 +01:00
"outputs": []
},
{
"cell_type": "code",
"source": [
2023-02-14 00:25:11 +01:00
"def generate_prompt(text):\n",
" labels = \"possible labels: sadness, joy, love, anger, surprise, fear\"\n",
" prompt = labels + '\\n' + f'text: {text}' + '\\n' + 'label: '\n",
" return prompt"
2023-02-12 23:22:40 +01:00
],
"metadata": {
2023-02-14 00:25:11 +01:00
"id": "6F30kPAqvYwb"
2023-02-12 23:22:40 +01:00
},
2023-02-14 00:25:11 +01:00
"execution_count": 97,
"outputs": []
2023-02-12 23:22:40 +01:00
},
{
2023-02-14 00:25:11 +01:00
"cell_type": "code",
2023-02-12 23:22:40 +01:00
"source": [
2023-02-14 00:25:11 +01:00
"def predict(text):\n",
" return lm_pipeline(generate_prompt(text), do_sample=False)[0]['generated_text']"
2023-02-12 23:22:40 +01:00
],
"metadata": {
2023-02-14 00:25:11 +01:00
"id": "pdlOh5x3zvOT"
},
"execution_count": 98,
"outputs": []
2023-02-12 23:22:40 +01:00
},
{
"cell_type": "code",
"source": [
2023-02-14 00:25:11 +01:00
"with open('data/s2s-test.json') as f:\n",
" time_start = time.time()\n",
" total = 0\n",
" correct = 0\n",
" lines = f.readlines()\n",
" test_cases_amount = len(lines)\n",
" for line in lines:\n",
" item = json.loads(line)\n",
" text = item['text']\n",
" label = item['label']\n",
" total += 1\n",
" if total % 50 == 0:\n",
" print(f'{total}/{test_cases_amount}')\n",
" if predict(text) == label:\n",
" correct += 1\n",
" time_end = time.time()\n",
" print(f'Minutes elapsed: {(time_end - time_start) / 60}')\n",
" print(f'Accuracy: {correct/total}')"
2023-02-12 23:22:40 +01:00
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
2023-02-14 00:25:11 +01:00
"id": "yP2fKz87tqGr",
"outputId": "89b43c0d-9c1f-4623-80ce-210f3448adff"
2023-02-12 23:22:40 +01:00
},
2023-02-14 00:25:11 +01:00
"execution_count": 99,
2023-02-12 23:22:40 +01:00
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
2023-02-14 00:25:11 +01:00
"50/2000\n",
"100/2000\n",
"150/2000\n",
"200/2000\n",
"250/2000\n",
"300/2000\n",
"350/2000\n",
"400/2000\n",
"450/2000\n",
"500/2000\n",
"550/2000\n",
"600/2000\n",
"650/2000\n",
"700/2000\n",
"750/2000\n",
"800/2000\n",
"850/2000\n",
"900/2000\n",
"950/2000\n",
"1000/2000\n",
"1050/2000\n",
"1100/2000\n",
"1150/2000\n",
"1200/2000\n",
"1250/2000\n",
"1300/2000\n",
"1350/2000\n",
"1400/2000\n",
"1450/2000\n",
"1500/2000\n",
"1550/2000\n",
"1600/2000\n",
"1650/2000\n",
"1700/2000\n",
"1750/2000\n",
"1800/2000\n",
"1850/2000\n",
"1900/2000\n",
"1950/2000\n",
"2000/2000\n",
"Minutes elapsed: 3.088933833440145\n",
"Accuracy: 0.6505\n"
2023-02-12 23:22:40 +01:00
]
}
]
}
]
}