This commit is contained in:
s487194 2024-02-01 02:41:55 +01:00
parent b7e9ff1cbf
commit b9f590cf3f

653
T5_encoder_decoder.ipynb Normal file
View File

@ -0,0 +1,653 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"7cd0b9f562c3497fa8521900361aa1b6": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_6a752c40c5ed4a848718fc5515ba3ff6",
"IPY_MODEL_a3c49c91d53b4a0fbaa06cf9edc53902",
"IPY_MODEL_8f42cebfbb2c44b5b65dc403950ab7b6"
],
"layout": "IPY_MODEL_66fdc95b84a7421f8394ff0ccabfd4ee"
}
},
"6a752c40c5ed4a848718fc5515ba3ff6": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_bab672d9adef477a8d6db292e5217ae2",
"placeholder": "",
"style": "IPY_MODEL_bb31a82ad4df4a0bbdcdc7fee5c333f0",
"value": "Map: 100%"
}
},
"a3c49c91d53b4a0fbaa06cf9edc53902": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_44645ba3efc546949f8de88a69109319",
"max": 1024,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_e2d98d9e52f5402db83ac5290b6b5e66",
"value": 1024
}
},
"8f42cebfbb2c44b5b65dc403950ab7b6": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_24cca48948be45779949b96cd63a70c8",
"placeholder": "",
"style": "IPY_MODEL_8647bedfaa684f139d3534552a3fb493",
"value": " 1024/1024 [00:00<00:00, 1828.40 examples/s]"
}
},
"66fdc95b84a7421f8394ff0ccabfd4ee": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"bab672d9adef477a8d6db292e5217ae2": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"bb31a82ad4df4a0bbdcdc7fee5c333f0": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"44645ba3efc546949f8de88a69109319": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"e2d98d9e52f5402db83ac5290b6b5e66": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"24cca48948be45779949b96cd63a70c8": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"8647bedfaa684f139d3534552a3fb493": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
}
}
}
},
"cells": [
{
"cell_type": "code",
"source": [
"!pip install transformers[torch] datasets evaluate"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0ju_uJOjZLE3",
"outputId": "b3836512-aed0-4e7f-81c8-895d595b589a"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: transformers[torch] in /usr/local/lib/python3.10/dist-packages (4.35.2)\n",
"Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.16.1)\n",
"Requirement already satisfied: evaluate in /usr/local/lib/python3.10/dist-packages (0.4.1)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (3.13.1)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (0.20.3)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (1.23.5)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (23.2)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (6.0.1)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (2023.6.3)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (2.31.0)\n",
"Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (0.15.1)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (0.4.2)\n",
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (4.66.1)\n",
"Requirement already satisfied: torch!=1.12.0,>=1.10 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (2.1.0+cu121)\n",
"Requirement already satisfied: accelerate>=0.20.3 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (0.26.1)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (10.0.1)\n",
"Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets) (0.6)\n",
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.7)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n",
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n",
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.15)\n",
"Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.9.1)\n",
"Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.10/dist-packages (from evaluate) (0.18.0)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.20.3->transformers[torch]) (5.9.5)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.2.0)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.4)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers[torch]) (4.5.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[torch]) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[torch]) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[torch]) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[torch]) (2023.11.17)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.10->transformers[torch]) (1.12)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.10->transformers[torch]) (3.2.1)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.10->transformers[torch]) (3.1.3)\n",
"Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.10->transformers[torch]) (2.1.0)\n",
"Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.3.post1)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch!=1.12.0,>=1.10->transformers[torch]) (2.1.4)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch!=1.12.0,>=1.10->transformers[torch]) (1.3.0)\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "XLFWhjbUYt_I"
},
"outputs": [],
"source": [
"from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSequenceClassification\n"
]
},
{
"cell_type": "code",
"source": [
"import datasets\n",
"data = \"ag_news\"\n",
"dataset = datasets.load_dataset(data)\n",
"\n",
"train_dataset=dataset['train']\n",
"eval_dataset=dataset['test']\n",
"shuffled_train_dataset = train_dataset.shuffle(seed=42).select(range(5052))\n",
"shuffled_eval_dataset = eval_dataset.shuffle(seed=42).select(range(1024))\n",
"\n",
"\n"
],
"metadata": {
"id": "yECUnWmvZChI",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "d34d1f53-77af-4e37-d501-f8144ad417e0"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: \n",
"The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
"To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
"You will be able to reuse this secret in all of your notebooks.\n",
"Please note that authentication is recommended but still optional to access public models or datasets.\n",
" warnings.warn(\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"model_name = \"t5-base\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length = 128)\n",
"\n",
"def tokenize_function(examples):\n",
" return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True)\n",
"\n",
"\n",
"tokenized_dataset_train = shuffled_train_dataset.map(tokenize_function, batched=True)\n",
"tokenized_dataset_eval = shuffled_eval_dataset.map(tokenize_function, batched=True)\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 49,
"referenced_widgets": [
"7cd0b9f562c3497fa8521900361aa1b6",
"6a752c40c5ed4a848718fc5515ba3ff6",
"a3c49c91d53b4a0fbaa06cf9edc53902",
"8f42cebfbb2c44b5b65dc403950ab7b6",
"66fdc95b84a7421f8394ff0ccabfd4ee",
"bab672d9adef477a8d6db292e5217ae2",
"bb31a82ad4df4a0bbdcdc7fee5c333f0",
"44645ba3efc546949f8de88a69109319",
"e2d98d9e52f5402db83ac5290b6b5e66",
"24cca48948be45779949b96cd63a70c8",
"8647bedfaa684f139d3534552a3fb493"
]
},
"id": "zFaiMEblaTFy",
"outputId": "d397d1d4-6b4d-4913-de53-788d2d4c698c"
},
"execution_count": 4,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"Map: 0%| | 0/1024 [00:00<?, ? examples/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "7cd0b9f562c3497fa8521900361aa1b6"
}
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"num_classes = 4\n",
"model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_classes)"
],
"metadata": {
"id": "ni4XQU0ia5Kd",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "75930083-b0a0-4f4c-840d-35222e898719"
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Some weights of T5ForSequenceClassification were not initialized from the model checkpoint at t5-base and are newly initialized: ['classification_head.out_proj.bias', 'classification_head.out_proj.weight', 'classification_head.dense.weight', 'classification_head.dense.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import evaluate\n",
"\n",
"accuracy = evaluate.load(\"accuracy\")\n",
"\n",
"def compute_metrics(p):\n",
" predictions, labels = p.predictions, p.label_ids\n",
" predictions = np.squeeze(predictions[0])\n",
" predictions = predictions.reshape(-1, predictions.shape[-1])\n",
" predictions = np.argmax(predictions, axis=-1)\n",
" return accuracy.compute(predictions=predictions, references=labels)\n"
],
"metadata": {
"id": "hpRwZ0QhrV92"
},
"execution_count": 44,
"outputs": []
},
{
"cell_type": "code",
"source": [
"train_arguments = TrainingArguments(num_train_epochs=1, per_device_train_batch_size=32, output_dir=\"./Output\", evaluation_strategy=\"epoch\",)\n",
"trainer=Trainer(model=model, args=train_arguments, train_dataset=tokenized_dataset_train, eval_dataset=tokenized_dataset_eval, compute_metrics=compute_metrics)"
],
"metadata": {
"id": "P4hpYtvHbRex"
},
"execution_count": 45,
"outputs": []
},
{
"cell_type": "code",
"source": [
"trainer.train()"
],
"metadata": {
"id": "VkPSt1vPb4eO",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 141
},
"outputId": "ed0f490c-f168-4382-8211-c05ee12be011"
},
"execution_count": 46,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='158' max='158' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [158/158 04:34, Epoch 1/1]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Epoch</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" <th>Accuracy</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>No log</td>\n",
" <td>1.616516</td>\n",
" <td>0.888672</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
]
},
"metadata": {}
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TrainOutput(global_step=158, training_loss=0.03489681135250043, metrics={'train_runtime': 276.2183, 'train_samples_per_second': 18.29, 'train_steps_per_second': 0.572, 'total_flos': 771417209622528.0, 'train_loss': 0.03489681135250043, 'epoch': 1.0})"
]
},
"metadata": {},
"execution_count": 46
}
]
}
]
}