add T5
This commit is contained in:
parent
b7e9ff1cbf
commit
b9f590cf3f
653
T5_encoder_decoder.ipynb
Normal file
653
T5_encoder_decoder.ipynb
Normal file
@ -0,0 +1,653 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": [],
|
||||
"gpuType": "T4"
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
},
|
||||
"accelerator": "GPU",
|
||||
"widgets": {
|
||||
"application/vnd.jupyter.widget-state+json": {
|
||||
"7cd0b9f562c3497fa8521900361aa1b6": {
|
||||
"model_module": "@jupyter-widgets/controls",
|
||||
"model_name": "HBoxModel",
|
||||
"model_module_version": "1.5.0",
|
||||
"state": {
|
||||
"_dom_classes": [],
|
||||
"_model_module": "@jupyter-widgets/controls",
|
||||
"_model_module_version": "1.5.0",
|
||||
"_model_name": "HBoxModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/controls",
|
||||
"_view_module_version": "1.5.0",
|
||||
"_view_name": "HBoxView",
|
||||
"box_style": "",
|
||||
"children": [
|
||||
"IPY_MODEL_6a752c40c5ed4a848718fc5515ba3ff6",
|
||||
"IPY_MODEL_a3c49c91d53b4a0fbaa06cf9edc53902",
|
||||
"IPY_MODEL_8f42cebfbb2c44b5b65dc403950ab7b6"
|
||||
],
|
||||
"layout": "IPY_MODEL_66fdc95b84a7421f8394ff0ccabfd4ee"
|
||||
}
|
||||
},
|
||||
"6a752c40c5ed4a848718fc5515ba3ff6": {
|
||||
"model_module": "@jupyter-widgets/controls",
|
||||
"model_name": "HTMLModel",
|
||||
"model_module_version": "1.5.0",
|
||||
"state": {
|
||||
"_dom_classes": [],
|
||||
"_model_module": "@jupyter-widgets/controls",
|
||||
"_model_module_version": "1.5.0",
|
||||
"_model_name": "HTMLModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/controls",
|
||||
"_view_module_version": "1.5.0",
|
||||
"_view_name": "HTMLView",
|
||||
"description": "",
|
||||
"description_tooltip": null,
|
||||
"layout": "IPY_MODEL_bab672d9adef477a8d6db292e5217ae2",
|
||||
"placeholder": "",
|
||||
"style": "IPY_MODEL_bb31a82ad4df4a0bbdcdc7fee5c333f0",
|
||||
"value": "Map: 100%"
|
||||
}
|
||||
},
|
||||
"a3c49c91d53b4a0fbaa06cf9edc53902": {
|
||||
"model_module": "@jupyter-widgets/controls",
|
||||
"model_name": "FloatProgressModel",
|
||||
"model_module_version": "1.5.0",
|
||||
"state": {
|
||||
"_dom_classes": [],
|
||||
"_model_module": "@jupyter-widgets/controls",
|
||||
"_model_module_version": "1.5.0",
|
||||
"_model_name": "FloatProgressModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/controls",
|
||||
"_view_module_version": "1.5.0",
|
||||
"_view_name": "ProgressView",
|
||||
"bar_style": "success",
|
||||
"description": "",
|
||||
"description_tooltip": null,
|
||||
"layout": "IPY_MODEL_44645ba3efc546949f8de88a69109319",
|
||||
"max": 1024,
|
||||
"min": 0,
|
||||
"orientation": "horizontal",
|
||||
"style": "IPY_MODEL_e2d98d9e52f5402db83ac5290b6b5e66",
|
||||
"value": 1024
|
||||
}
|
||||
},
|
||||
"8f42cebfbb2c44b5b65dc403950ab7b6": {
|
||||
"model_module": "@jupyter-widgets/controls",
|
||||
"model_name": "HTMLModel",
|
||||
"model_module_version": "1.5.0",
|
||||
"state": {
|
||||
"_dom_classes": [],
|
||||
"_model_module": "@jupyter-widgets/controls",
|
||||
"_model_module_version": "1.5.0",
|
||||
"_model_name": "HTMLModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/controls",
|
||||
"_view_module_version": "1.5.0",
|
||||
"_view_name": "HTMLView",
|
||||
"description": "",
|
||||
"description_tooltip": null,
|
||||
"layout": "IPY_MODEL_24cca48948be45779949b96cd63a70c8",
|
||||
"placeholder": "",
|
||||
"style": "IPY_MODEL_8647bedfaa684f139d3534552a3fb493",
|
||||
"value": " 1024/1024 [00:00<00:00, 1828.40 examples/s]"
|
||||
}
|
||||
},
|
||||
"66fdc95b84a7421f8394ff0ccabfd4ee": {
|
||||
"model_module": "@jupyter-widgets/base",
|
||||
"model_name": "LayoutModel",
|
||||
"model_module_version": "1.2.0",
|
||||
"state": {
|
||||
"_model_module": "@jupyter-widgets/base",
|
||||
"_model_module_version": "1.2.0",
|
||||
"_model_name": "LayoutModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/base",
|
||||
"_view_module_version": "1.2.0",
|
||||
"_view_name": "LayoutView",
|
||||
"align_content": null,
|
||||
"align_items": null,
|
||||
"align_self": null,
|
||||
"border": null,
|
||||
"bottom": null,
|
||||
"display": null,
|
||||
"flex": null,
|
||||
"flex_flow": null,
|
||||
"grid_area": null,
|
||||
"grid_auto_columns": null,
|
||||
"grid_auto_flow": null,
|
||||
"grid_auto_rows": null,
|
||||
"grid_column": null,
|
||||
"grid_gap": null,
|
||||
"grid_row": null,
|
||||
"grid_template_areas": null,
|
||||
"grid_template_columns": null,
|
||||
"grid_template_rows": null,
|
||||
"height": null,
|
||||
"justify_content": null,
|
||||
"justify_items": null,
|
||||
"left": null,
|
||||
"margin": null,
|
||||
"max_height": null,
|
||||
"max_width": null,
|
||||
"min_height": null,
|
||||
"min_width": null,
|
||||
"object_fit": null,
|
||||
"object_position": null,
|
||||
"order": null,
|
||||
"overflow": null,
|
||||
"overflow_x": null,
|
||||
"overflow_y": null,
|
||||
"padding": null,
|
||||
"right": null,
|
||||
"top": null,
|
||||
"visibility": null,
|
||||
"width": null
|
||||
}
|
||||
},
|
||||
"bab672d9adef477a8d6db292e5217ae2": {
|
||||
"model_module": "@jupyter-widgets/base",
|
||||
"model_name": "LayoutModel",
|
||||
"model_module_version": "1.2.0",
|
||||
"state": {
|
||||
"_model_module": "@jupyter-widgets/base",
|
||||
"_model_module_version": "1.2.0",
|
||||
"_model_name": "LayoutModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/base",
|
||||
"_view_module_version": "1.2.0",
|
||||
"_view_name": "LayoutView",
|
||||
"align_content": null,
|
||||
"align_items": null,
|
||||
"align_self": null,
|
||||
"border": null,
|
||||
"bottom": null,
|
||||
"display": null,
|
||||
"flex": null,
|
||||
"flex_flow": null,
|
||||
"grid_area": null,
|
||||
"grid_auto_columns": null,
|
||||
"grid_auto_flow": null,
|
||||
"grid_auto_rows": null,
|
||||
"grid_column": null,
|
||||
"grid_gap": null,
|
||||
"grid_row": null,
|
||||
"grid_template_areas": null,
|
||||
"grid_template_columns": null,
|
||||
"grid_template_rows": null,
|
||||
"height": null,
|
||||
"justify_content": null,
|
||||
"justify_items": null,
|
||||
"left": null,
|
||||
"margin": null,
|
||||
"max_height": null,
|
||||
"max_width": null,
|
||||
"min_height": null,
|
||||
"min_width": null,
|
||||
"object_fit": null,
|
||||
"object_position": null,
|
||||
"order": null,
|
||||
"overflow": null,
|
||||
"overflow_x": null,
|
||||
"overflow_y": null,
|
||||
"padding": null,
|
||||
"right": null,
|
||||
"top": null,
|
||||
"visibility": null,
|
||||
"width": null
|
||||
}
|
||||
},
|
||||
"bb31a82ad4df4a0bbdcdc7fee5c333f0": {
|
||||
"model_module": "@jupyter-widgets/controls",
|
||||
"model_name": "DescriptionStyleModel",
|
||||
"model_module_version": "1.5.0",
|
||||
"state": {
|
||||
"_model_module": "@jupyter-widgets/controls",
|
||||
"_model_module_version": "1.5.0",
|
||||
"_model_name": "DescriptionStyleModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/base",
|
||||
"_view_module_version": "1.2.0",
|
||||
"_view_name": "StyleView",
|
||||
"description_width": ""
|
||||
}
|
||||
},
|
||||
"44645ba3efc546949f8de88a69109319": {
|
||||
"model_module": "@jupyter-widgets/base",
|
||||
"model_name": "LayoutModel",
|
||||
"model_module_version": "1.2.0",
|
||||
"state": {
|
||||
"_model_module": "@jupyter-widgets/base",
|
||||
"_model_module_version": "1.2.0",
|
||||
"_model_name": "LayoutModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/base",
|
||||
"_view_module_version": "1.2.0",
|
||||
"_view_name": "LayoutView",
|
||||
"align_content": null,
|
||||
"align_items": null,
|
||||
"align_self": null,
|
||||
"border": null,
|
||||
"bottom": null,
|
||||
"display": null,
|
||||
"flex": null,
|
||||
"flex_flow": null,
|
||||
"grid_area": null,
|
||||
"grid_auto_columns": null,
|
||||
"grid_auto_flow": null,
|
||||
"grid_auto_rows": null,
|
||||
"grid_column": null,
|
||||
"grid_gap": null,
|
||||
"grid_row": null,
|
||||
"grid_template_areas": null,
|
||||
"grid_template_columns": null,
|
||||
"grid_template_rows": null,
|
||||
"height": null,
|
||||
"justify_content": null,
|
||||
"justify_items": null,
|
||||
"left": null,
|
||||
"margin": null,
|
||||
"max_height": null,
|
||||
"max_width": null,
|
||||
"min_height": null,
|
||||
"min_width": null,
|
||||
"object_fit": null,
|
||||
"object_position": null,
|
||||
"order": null,
|
||||
"overflow": null,
|
||||
"overflow_x": null,
|
||||
"overflow_y": null,
|
||||
"padding": null,
|
||||
"right": null,
|
||||
"top": null,
|
||||
"visibility": null,
|
||||
"width": null
|
||||
}
|
||||
},
|
||||
"e2d98d9e52f5402db83ac5290b6b5e66": {
|
||||
"model_module": "@jupyter-widgets/controls",
|
||||
"model_name": "ProgressStyleModel",
|
||||
"model_module_version": "1.5.0",
|
||||
"state": {
|
||||
"_model_module": "@jupyter-widgets/controls",
|
||||
"_model_module_version": "1.5.0",
|
||||
"_model_name": "ProgressStyleModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/base",
|
||||
"_view_module_version": "1.2.0",
|
||||
"_view_name": "StyleView",
|
||||
"bar_color": null,
|
||||
"description_width": ""
|
||||
}
|
||||
},
|
||||
"24cca48948be45779949b96cd63a70c8": {
|
||||
"model_module": "@jupyter-widgets/base",
|
||||
"model_name": "LayoutModel",
|
||||
"model_module_version": "1.2.0",
|
||||
"state": {
|
||||
"_model_module": "@jupyter-widgets/base",
|
||||
"_model_module_version": "1.2.0",
|
||||
"_model_name": "LayoutModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/base",
|
||||
"_view_module_version": "1.2.0",
|
||||
"_view_name": "LayoutView",
|
||||
"align_content": null,
|
||||
"align_items": null,
|
||||
"align_self": null,
|
||||
"border": null,
|
||||
"bottom": null,
|
||||
"display": null,
|
||||
"flex": null,
|
||||
"flex_flow": null,
|
||||
"grid_area": null,
|
||||
"grid_auto_columns": null,
|
||||
"grid_auto_flow": null,
|
||||
"grid_auto_rows": null,
|
||||
"grid_column": null,
|
||||
"grid_gap": null,
|
||||
"grid_row": null,
|
||||
"grid_template_areas": null,
|
||||
"grid_template_columns": null,
|
||||
"grid_template_rows": null,
|
||||
"height": null,
|
||||
"justify_content": null,
|
||||
"justify_items": null,
|
||||
"left": null,
|
||||
"margin": null,
|
||||
"max_height": null,
|
||||
"max_width": null,
|
||||
"min_height": null,
|
||||
"min_width": null,
|
||||
"object_fit": null,
|
||||
"object_position": null,
|
||||
"order": null,
|
||||
"overflow": null,
|
||||
"overflow_x": null,
|
||||
"overflow_y": null,
|
||||
"padding": null,
|
||||
"right": null,
|
||||
"top": null,
|
||||
"visibility": null,
|
||||
"width": null
|
||||
}
|
||||
},
|
||||
"8647bedfaa684f139d3534552a3fb493": {
|
||||
"model_module": "@jupyter-widgets/controls",
|
||||
"model_name": "DescriptionStyleModel",
|
||||
"model_module_version": "1.5.0",
|
||||
"state": {
|
||||
"_model_module": "@jupyter-widgets/controls",
|
||||
"_model_module_version": "1.5.0",
|
||||
"_model_name": "DescriptionStyleModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/base",
|
||||
"_view_module_version": "1.2.0",
|
||||
"_view_name": "StyleView",
|
||||
"description_width": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!pip install transformers[torch] datasets evaluate"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "0ju_uJOjZLE3",
|
||||
"outputId": "b3836512-aed0-4e7f-81c8-895d595b589a"
|
||||
},
|
||||
"execution_count": 1,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Requirement already satisfied: transformers[torch] in /usr/local/lib/python3.10/dist-packages (4.35.2)\n",
|
||||
"Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.16.1)\n",
|
||||
"Requirement already satisfied: evaluate in /usr/local/lib/python3.10/dist-packages (0.4.1)\n",
|
||||
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (3.13.1)\n",
|
||||
"Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (0.20.3)\n",
|
||||
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (1.23.5)\n",
|
||||
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (23.2)\n",
|
||||
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (6.0.1)\n",
|
||||
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (2023.6.3)\n",
|
||||
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (2.31.0)\n",
|
||||
"Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (0.15.1)\n",
|
||||
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (0.4.2)\n",
|
||||
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (4.66.1)\n",
|
||||
"Requirement already satisfied: torch!=1.12.0,>=1.10 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (2.1.0+cu121)\n",
|
||||
"Requirement already satisfied: accelerate>=0.20.3 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (0.26.1)\n",
|
||||
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (10.0.1)\n",
|
||||
"Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets) (0.6)\n",
|
||||
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.7)\n",
|
||||
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n",
|
||||
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n",
|
||||
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.15)\n",
|
||||
"Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n",
|
||||
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.9.1)\n",
|
||||
"Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.10/dist-packages (from evaluate) (0.18.0)\n",
|
||||
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.20.3->transformers[torch]) (5.9.5)\n",
|
||||
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.2.0)\n",
|
||||
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n",
|
||||
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.4)\n",
|
||||
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n",
|
||||
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n",
|
||||
"Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n",
|
||||
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers[torch]) (4.5.0)\n",
|
||||
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[torch]) (3.3.2)\n",
|
||||
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[torch]) (3.6)\n",
|
||||
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[torch]) (2.0.7)\n",
|
||||
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[torch]) (2023.11.17)\n",
|
||||
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.10->transformers[torch]) (1.12)\n",
|
||||
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.10->transformers[torch]) (3.2.1)\n",
|
||||
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.10->transformers[torch]) (3.1.3)\n",
|
||||
"Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.10->transformers[torch]) (2.1.0)\n",
|
||||
"Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n",
|
||||
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.3.post1)\n",
|
||||
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n",
|
||||
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch!=1.12.0,>=1.10->transformers[torch]) (2.1.4)\n",
|
||||
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch!=1.12.0,>=1.10->transformers[torch]) (1.3.0)\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"id": "XLFWhjbUYt_I"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSequenceClassification\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import datasets\n",
|
||||
"data = \"ag_news\"\n",
|
||||
"dataset = datasets.load_dataset(data)\n",
|
||||
"\n",
|
||||
"train_dataset=dataset['train']\n",
|
||||
"eval_dataset=dataset['test']\n",
|
||||
"shuffled_train_dataset = train_dataset.shuffle(seed=42).select(range(5052))\n",
|
||||
"shuffled_eval_dataset = eval_dataset.shuffle(seed=42).select(range(1024))\n",
|
||||
"\n",
|
||||
"\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "yECUnWmvZChI",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"outputId": "d34d1f53-77af-4e37-d501-f8144ad417e0"
|
||||
},
|
||||
"execution_count": 3,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: \n",
|
||||
"The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
|
||||
"To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
|
||||
"You will be able to reuse this secret in all of your notebooks.\n",
|
||||
"Please note that authentication is recommended but still optional to access public models or datasets.\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"model_name = \"t5-base\"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length = 128)\n",
|
||||
"\n",
|
||||
"def tokenize_function(examples):\n",
|
||||
" return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"tokenized_dataset_train = shuffled_train_dataset.map(tokenize_function, batched=True)\n",
|
||||
"tokenized_dataset_eval = shuffled_eval_dataset.map(tokenize_function, batched=True)\n"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 49,
|
||||
"referenced_widgets": [
|
||||
"7cd0b9f562c3497fa8521900361aa1b6",
|
||||
"6a752c40c5ed4a848718fc5515ba3ff6",
|
||||
"a3c49c91d53b4a0fbaa06cf9edc53902",
|
||||
"8f42cebfbb2c44b5b65dc403950ab7b6",
|
||||
"66fdc95b84a7421f8394ff0ccabfd4ee",
|
||||
"bab672d9adef477a8d6db292e5217ae2",
|
||||
"bb31a82ad4df4a0bbdcdc7fee5c333f0",
|
||||
"44645ba3efc546949f8de88a69109319",
|
||||
"e2d98d9e52f5402db83ac5290b6b5e66",
|
||||
"24cca48948be45779949b96cd63a70c8",
|
||||
"8647bedfaa684f139d3534552a3fb493"
|
||||
]
|
||||
},
|
||||
"id": "zFaiMEblaTFy",
|
||||
"outputId": "d397d1d4-6b4d-4913-de53-788d2d4c698c"
|
||||
},
|
||||
"execution_count": 4,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "display_data",
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Map: 0%| | 0/1024 [00:00<?, ? examples/s]"
|
||||
],
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"version_major": 2,
|
||||
"version_minor": 0,
|
||||
"model_id": "7cd0b9f562c3497fa8521900361aa1b6"
|
||||
}
|
||||
},
|
||||
"metadata": {}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"num_classes = 4\n",
|
||||
"model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_classes)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ni4XQU0ia5Kd",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"outputId": "75930083-b0a0-4f4c-840d-35222e898719"
|
||||
},
|
||||
"execution_count": 5,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"Some weights of T5ForSequenceClassification were not initialized from the model checkpoint at t5-base and are newly initialized: ['classification_head.out_proj.bias', 'classification_head.out_proj.weight', 'classification_head.dense.weight', 'classification_head.dense.bias']\n",
|
||||
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import evaluate\n",
|
||||
"\n",
|
||||
"accuracy = evaluate.load(\"accuracy\")\n",
|
||||
"\n",
|
||||
"def compute_metrics(p):\n",
|
||||
" predictions, labels = p.predictions, p.label_ids\n",
|
||||
" predictions = np.squeeze(predictions[0])\n",
|
||||
" predictions = predictions.reshape(-1, predictions.shape[-1])\n",
|
||||
" predictions = np.argmax(predictions, axis=-1)\n",
|
||||
" return accuracy.compute(predictions=predictions, references=labels)\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "hpRwZ0QhrV92"
|
||||
},
|
||||
"execution_count": 44,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"train_arguments = TrainingArguments(num_train_epochs=1, per_device_train_batch_size=32, output_dir=\"./Output\", evaluation_strategy=\"epoch\",)\n",
|
||||
"trainer=Trainer(model=model, args=train_arguments, train_dataset=tokenized_dataset_train, eval_dataset=tokenized_dataset_eval, compute_metrics=compute_metrics)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "P4hpYtvHbRex"
|
||||
},
|
||||
"execution_count": 45,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"trainer.train()"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "VkPSt1vPb4eO",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 141
|
||||
},
|
||||
"outputId": "ed0f490c-f168-4382-8211-c05ee12be011"
|
||||
},
|
||||
"execution_count": 46,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "display_data",
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
],
|
||||
"text/html": [
|
||||
"\n",
|
||||
" <div>\n",
|
||||
" \n",
|
||||
" <progress value='158' max='158' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
||||
" [158/158 04:34, Epoch 1/1]\n",
|
||||
" </div>\n",
|
||||
" <table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: left;\">\n",
|
||||
" <th>Epoch</th>\n",
|
||||
" <th>Training Loss</th>\n",
|
||||
" <th>Validation Loss</th>\n",
|
||||
" <th>Accuracy</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>No log</td>\n",
|
||||
" <td>1.616516</td>\n",
|
||||
" <td>0.888672</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table><p>"
|
||||
]
|
||||
},
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"TrainOutput(global_step=158, training_loss=0.03489681135250043, metrics={'train_runtime': 276.2183, 'train_samples_per_second': 18.29, 'train_steps_per_second': 0.572, 'total_flos': 771417209622528.0, 'train_loss': 0.03489681135250043, 'epoch': 1.0})"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"execution_count": 46
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
Loading…
Reference in New Issue
Block a user