{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "Transformer.ipynb", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "gpuClass": "standard", "accelerator": "GPU", "widgets": { "application/vnd.jupyter.widget-state+json": { "dd6f1c77ea87429597b5d9a34b9b3ec6": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_aaa78178103e4a40bea1047a584fcadc", "IPY_MODEL_91b8a0c785fe4426ab63ac8f4f473618", "IPY_MODEL_257ad1eb241c4231beea30fd2cb9b99d" ], "layout": "IPY_MODEL_a56e21a336d7489396022f73cf7e0743" } }, "aaa78178103e4a40bea1047a584fcadc": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_00b475bdd1184f2c809ff221bba760bf", "placeholder": "​", "style": "IPY_MODEL_7922cc0562ce4e499c02fb9e095069ee", "value": "Downloading: 100%" } }, "91b8a0c785fe4426ab63ac8f4f473618": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": 
"1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_67c896a3ece04e0d863ac5ba14ec336d", "max": 570, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_7abe840d4f1b4b54b6887ddc0d4ab8c6", "value": 570 } }, "257ad1eb241c4231beea30fd2cb9b99d": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_40a494d432014f928515afafe81712cb", "placeholder": "​", "style": "IPY_MODEL_b818456816674a5aa138483901437f1d", "value": " 570/570 [00:00<00:00, 16.1kB/s]" } }, "a56e21a336d7489396022f73cf7e0743": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, 
"right": null, "top": null, "visibility": null, "width": null } }, "00b475bdd1184f2c809ff221bba760bf": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "7922cc0562ce4e499c02fb9e095069ee": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "67c896a3ece04e0d863ac5ba14ec336d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", 
"_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "7abe840d4f1b4b54b6887ddc0d4ab8c6": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "40a494d432014f928515afafe81712cb": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": 
null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "b818456816674a5aa138483901437f1d": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "e3af588bc95741c58241073f1bbb7329": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_dbfdc6e40f0c44678c335efc727d73a6", "IPY_MODEL_ce89511005494833b216838d1a536af9", "IPY_MODEL_0184159543a74048942795d91bf0d98d" ], "layout": "IPY_MODEL_0ebacfd58f6545929cff0a0a78188a66" } }, "dbfdc6e40f0c44678c335efc727d73a6": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_18136d44a88245148e96d498556db23b", "placeholder": "​", "style": 
"IPY_MODEL_88f32857f7bf4fb595d412042635e421", "value": "Downloading: 100%" } }, "ce89511005494833b216838d1a536af9": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_9cd7c765093c4bdd853c0cf66ca445d4", "max": 440473133, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_d30aa86296c34269b735b45dbc21b6a3", "value": 440473133 } }, "0184159543a74048942795d91bf0d98d": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_40ef73551c894b8aaca8af92e5b494d0", "placeholder": "​", "style": "IPY_MODEL_051e77b213ee4e12b1ca08f59c3e6b7e", "value": " 420M/420M [00:20<00:00, 20.8MB/s]" } }, "0ebacfd58f6545929cff0a0a78188a66": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, 
"grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "18136d44a88245148e96d498556db23b": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "88f32857f7bf4fb595d412042635e421": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", 
"_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "9cd7c765093c4bdd853c0cf66ca445d4": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "d30aa86296c34269b735b45dbc21b6a3": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "40ef73551c894b8aaca8af92e5b494d0": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", 
"_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "051e77b213ee4e12b1ca08f59c3e6b7e": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "ba0cc9d7efb84f04ab8c3c9e80a8bce8": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_32c47a52f0c14fa7a509a20e2bb2a63c", "IPY_MODEL_61d66b55dc8746239b3622b193591efb", "IPY_MODEL_e8eeff96cc8c47e8b71315ccc58204f6" ], "layout": "IPY_MODEL_8231c078232b4a45a66623890feb3ed8" } }, "32c47a52f0c14fa7a509a20e2bb2a63c": { 
"model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_c317a5d8ae894ebe95a1f1201ebd3328", "placeholder": "​", "style": "IPY_MODEL_46915aadfdd44407af77edc45b3a8955", "value": "Downloading: 100%" } }, "61d66b55dc8746239b3622b193591efb": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_c79d655ba15a4635b01307a9c4a11530", "max": 28, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_8693cd493adc4459b8b91b057a785479", "value": 28 } }, "e8eeff96cc8c47e8b71315ccc58204f6": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_99dcab149e6849c490b0ae1708736d5f", "placeholder": "​", "style": "IPY_MODEL_a9b3df23850341439d05ad573ae52d29", "value": " 28.0/28.0 [00:00<00:00, 845B/s]" } }, "8231c078232b4a45a66623890feb3ed8": { "model_module": "@jupyter-widgets/base", "model_name": 
"LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c317a5d8ae894ebe95a1f1201ebd3328": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": 
null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "46915aadfdd44407af77edc45b3a8955": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "c79d655ba15a4635b01307a9c4a11530": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "8693cd493adc4459b8b91b057a785479": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": 
"1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "99dcab149e6849c490b0ae1708736d5f": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "a9b3df23850341439d05ad573ae52d29": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "3efd33efe633402583686ebc692864b5": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": 
"@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_ed660ca670a0408fb7196e679e9d39a1", "IPY_MODEL_a50c7d82d87b49d1b1ae9cca3f35f4cd", "IPY_MODEL_6eb3b02838044377a9aca91a8412b3e4" ], "layout": "IPY_MODEL_0f32391f1b60417ea41e3e20fb66b10e" } }, "ed660ca670a0408fb7196e679e9d39a1": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_0a120b6d26c24799b1ac0dce323185e7", "placeholder": "​", "style": "IPY_MODEL_f71cfeab2791409abc2bfcfc7c6d7b56", "value": "Downloading: 100%" } }, "a50c7d82d87b49d1b1ae9cca3f35f4cd": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_baa1e3a083524ff7885c9be9112cc149", "max": 231508, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_111f035ca3744086bd99a2339e82c58b", "value": 231508 } }, "6eb3b02838044377a9aca91a8412b3e4": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": 
"1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_795450795bba45658bd6841b072578fa", "placeholder": "​", "style": "IPY_MODEL_0e16985134ff46bb9d9201efc7fab653", "value": " 226k/226k [00:00<00:00, 569kB/s]" } }, "0f32391f1b60417ea41e3e20fb66b10e": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "0a120b6d26c24799b1ac0dce323185e7": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, 
"bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f71cfeab2791409abc2bfcfc7c6d7b56": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "baa1e3a083524ff7885c9be9112cc149": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, 
"max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "111f035ca3744086bd99a2339e82c58b": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "795450795bba45658bd6841b072578fa": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "0e16985134ff46bb9d9201efc7fab653": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": 
"1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "099918165ca1455b88ea3fc4f8fae020": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_8cd150725d404a16be61deb89992a8fd", "IPY_MODEL_a10fb1acbd4d40b89813533ca51f8297", "IPY_MODEL_56685835f4d7459ab4801e068af57799" ], "layout": "IPY_MODEL_922eb6dedf3f47988d0cf5cc8b3b4aae" } }, "8cd150725d404a16be61deb89992a8fd": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_a854871fb2734de1ae15003eddfa351e", "placeholder": "​", "style": "IPY_MODEL_ef113d3fa1b942c896abf1673e977db9", "value": "Downloading: 100%" } }, "a10fb1acbd4d40b89813533ca51f8297": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": 
"success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_2d18af4b4b284dfb86b504fceb94096d", "max": 466062, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_893537d3d1d340d88458f4bff8741327", "value": 466062 } }, "56685835f4d7459ab4801e068af57799": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_ce033673385b4894b10cc8074bc71a18", "placeholder": "​", "style": "IPY_MODEL_5ba9f0fc73644d07835f284f3dd8f0a4", "value": " 455k/455k [00:00<00:00, 835kB/s]" } }, "922eb6dedf3f47988d0cf5cc8b3b4aae": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": 
null, "width": null } }, "a854871fb2734de1ae15003eddfa351e": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "ef113d3fa1b942c896abf1673e977db9": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "2d18af4b4b284dfb86b504fceb94096d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": 
null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "893537d3d1d340d88458f4bff8741327": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "ce033673385b4894b10cc8074bc71a18": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, 
"justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "5ba9f0fc73644d07835f284f3dd8f0a4": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } } } } }, "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "fzHUPGMyWNxK", "outputId": "0158026b-4198-4a4f-e41e-347f9c57f7cd" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Collecting transformers\n", " Downloading transformers-4.20.1-py3-none-any.whl (4.4 MB)\n", "\u001b[K |████████████████████████████████| 4.4 MB 7.9 MB/s \n", "\u001b[?25hCollecting tokenizers!=0.11.3,<0.13,>=0.11.1\n", " Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)\n", "\u001b[K |████████████████████████████████| 6.6 MB 47.5 MB/s \n", "\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.64.0)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.21.6)\n", "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.11.4)\n", "Collecting huggingface-hub<1.0,>=0.1.0\n", " Downloading 
huggingface_hub-0.8.1-py3-none-any.whl (101 kB)\n", "\u001b[K |████████████████████████████████| 101 kB 12.9 MB/s \n", "\u001b[?25hRequirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2022.6.2)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.7.1)\n", "Collecting pyyaml>=5.1\n", " Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)\n", "\u001b[K |████████████████████████████████| 596 kB 67.7 MB/s \n", "\u001b[?25hRequirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (21.3)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers) (4.1.1)\n", "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers) (3.0.9)\n", "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.8.0)\n", "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.24.3)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2022.6.15)\n", "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)\n", "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)\n", "Installing collected packages: pyyaml, tokenizers, huggingface-hub, transformers\n", " Attempting uninstall: pyyaml\n", " Found existing 
installation: PyYAML 3.13\n", " Uninstalling PyYAML-3.13:\n", " Successfully uninstalled PyYAML-3.13\n", "Successfully installed huggingface-hub-0.8.1 pyyaml-6.0 tokenizers-0.12.1 transformers-4.20.1\n" ] } ], "source": [ "!pip install transformers\n", "import re\n", "import torch\n", "import torch.nn as nn\n", "import pandas as pd\n", "import numpy as np\n", "from transformers import pipeline, set_seed\n", "from transformers import RobertaTokenizer, RobertaModel\n", "from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n", "from transformers import AutoModel, BertTokenizerFast" ] }, { "cell_type": "code", "source": [ "def load_data(path):\n", " #return pd.read_csv(path, sep='\\t', header=None)\n", " with open(path, 'r', encoding='utf8') as f:\n", " return f.readlines()" ], "metadata": { "id": "fXQPC07mWhqf" }, "execution_count": 2, "outputs": [] }, { "cell_type": "code", "source": [ "def write_res(data, path):\n", " with open(path, 'w') as f:\n", " for line in data:\n", " f.write(f'{line}\\n')\n", " print(f\"Data written {path}/out.tsv\")" ], "metadata": { "id": "crkUIjgLWiiO" }, "execution_count": 3, "outputs": [] }, { "cell_type": "code", "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "device" ], "metadata": { "id": "mR2UfBRYdlMW", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "ae8eeb0d-bf68-4575-af0c-da82a71808ee" }, "execution_count": 4, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "device(type='cuda')" ] }, "metadata": {}, "execution_count": 4 } ] }, { "cell_type": "code", "source": [ "import pandas as pd\n", "train_input = pd.read_csv(\"/content/drive/MyDrive/paranormal-or-skeptic/train/in.tsv\", sep = '\\t', names = ['text', 'label'], header=None, nrows=10000)\n", "# train['text'] = train['text'].apply(lambda x: tokenizer(x, return_tensors='pt'))\n", "train_input['label'] = 
pd.read_csv(\"/content/drive/MyDrive/paranormal-or-skeptic/train/expected.tsv\", header=None, nrows=10000)\n", "train_input" ], "metadata": { "id": "pTjGxxAu6b-v", "colab": { "base_uri": "https://localhost:8080/", "height": 424 }, "outputId": "61f5911c-4dd6-4cde-dbeb-43b58fde3f7e" }, "execution_count": 5, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " text label\n", "0 have you had an medical issues recently? 1\n", "1 It's supposedly aluminum, barium, and strontiu... 0\n", "2 Nobel prizes don't make you rich. 0\n", "3 I came for the article, I stayed for the doctor. 0\n", "4 you resorted to insults AND got owned directly... 0\n", "... ... ...\n", "9995 >a very very very very very liberal college... 0\n", "9996 To be fair, most of Newton's writings were on ... 0\n", "9997 your elementary idea is brilliant 0\n", "9998 I know! I was like ...Simon Pegg?? 1\n", "9999 You seem to have missed the purpose of my post... 0\n", "\n", "[10000 rows x 2 columns]" ], "text/html": [ "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
textlabel
0have you had an medical issues recently?1
1It's supposedly aluminum, barium, and strontiu...0
2Nobel prizes don't make you rich.0
3I came for the article, I stayed for the doctor.0
4you resorted to insults AND got owned directly...0
.........
9995&gt;a very very very very very liberal college...0
9996To be fair, most of Newton's writings were on ...0
9997your elementary idea is brilliant0
9998I know! I was like ...Simon Pegg??1
9999You seem to have missed the purpose of my post...0
\n", "

10000 rows × 2 columns

\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", "
\n", " " ] }, "metadata": {}, "execution_count": 5 } ] }, { "cell_type": "code", "source": [ "dev = pd.read_csv(\"/content/drive/MyDrive/paranormal-or-skeptic/dev-0/in.tsv\", sep = '\\t', names = ['text', 'label'], header=None)\n", "# test['text'] = test['text'].apply(lambda x: tokenizer(x, return_tensors='pt'))\n", "dev['label'] = pd.read_csv(\"/content/drive/MyDrive/paranormal-or-skeptic/dev-0/expected.tsv\", header=None)\n", "dev" ], "metadata": { "id": "11xq68SHnw6Q", "colab": { "base_uri": "https://localhost:8080/", "height": 424 }, "outputId": "86c69ec6-4c95-4e03-b1e1-8cbff3f3bfab" }, "execution_count": 6, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " text label\n", "0 In which case, tell them I'm in work, or dead,... 0\n", "1 Put me down as another for Mysterious Universe... 1\n", "2 The military of any country would never admit ... 1\n", "3 An example would have been more productive tha... 0\n", "4 sorry, but the authors of this article admit t... 0\n", "... ... ...\n", "5267 Your fault for going at all. That's how we get... 0\n", "5268 EVP....that's a shot in the GH drinking game. 1\n", "5269 i think a good hard massage is good for you. t... 0\n", "5270 Interesting theory. Makes my imagination run w... 1\n", "5271 Tampering of candy? More like cooking somethin... 0\n", "\n", "[5272 rows x 2 columns]" ], "text/html": [ "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
textlabel
0In which case, tell them I'm in work, or dead,...0
1Put me down as another for Mysterious Universe...1
2The military of any country would never admit ...1
3An example would have been more productive tha...0
4sorry, but the authors of this article admit t...0
.........
5267Your fault for going at all. That's how we get...0
5268EVP....that's a shot in the GH drinking game.1
5269i think a good hard massage is good for you. t...0
5270Interesting theory. Makes my imagination run w...1
5271Tampering of candy? More like cooking somethin...0
\n", "

5272 rows × 2 columns

\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", "
\n", " " ] }, "metadata": {}, "execution_count": 6 } ] }, { "cell_type": "code", "source": [ "test = pd.read_csv(\"/content/drive/MyDrive/paranormal-or-skeptic/test-A/in.tsv\", sep = '\\t', names = ['text', 'label'], header=None)\n", "# test['text'] = test['text'].apply(lambda x: tokenizer(x, return_tensors='pt'))\n", "# test['label'] = pd.read_csv(\"/content/drive/MyDrive/paranormal-or-skeptic/test-A/expected.tsv\", header=None)\n", "test = test.drop(['label'], axis=1)\n", "test" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 424 }, "id": "5S172kriE_za", "outputId": "b0e772a6-b007-400a-a010-93e370619c9b" }, "execution_count": 7, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " text\n", "0 Gentleman, I believe we can agree that this is...\n", "1 The problem is that it will just turn it r/nos...\n", "2 Well, according to some Christian apologists, ...\n", "3 Don't know if this is what you are looking for...\n", "4 I respect what you're saying completely. I jus...\n", "... ...\n", "5147 GAMBIT\n", "5148 >Joe Rogan is no snake oil salesman.\\n\\nHe ...\n", "5149 Reading further, Sagan does seem to agree with...\n", "5150 Notice that they never invoke god, or any othe...\n", "5151 They might co-ordinate an anniversary attack o...\n", "\n", "[5152 rows x 1 columns]" ], "text/html": [ "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
text
0Gentleman, I believe we can agree that this is...
1The problem is that it will just turn it r/nos...
2Well, according to some Christian apologists, ...
3Don't know if this is what you are looking for...
4I respect what you're saying completely. I jus...
......
5147GAMBIT
5148&gt;Joe Rogan is no snake oil salesman.\\n\\nHe ...
5149Reading further, Sagan does seem to agree with...
5150Notice that they never invoke god, or any othe...
5151They might co-ordinate an anniversary attack o...
\n", "

5152 rows × 1 columns

\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", "
\n", " " ] }, "metadata": {}, "execution_count": 7 } ] }, { "cell_type": "code", "source": [ "# import BERT-base pretrained model\n", "bert = AutoModel.from_pretrained('bert-base-uncased')\n", "\n", "# Load the BERT tokenizer\n", "tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 249, "referenced_widgets": [ "dd6f1c77ea87429597b5d9a34b9b3ec6", "aaa78178103e4a40bea1047a584fcadc", "91b8a0c785fe4426ab63ac8f4f473618", "257ad1eb241c4231beea30fd2cb9b99d", "a56e21a336d7489396022f73cf7e0743", "00b475bdd1184f2c809ff221bba760bf", "7922cc0562ce4e499c02fb9e095069ee", "67c896a3ece04e0d863ac5ba14ec336d", "7abe840d4f1b4b54b6887ddc0d4ab8c6", "40a494d432014f928515afafe81712cb", "b818456816674a5aa138483901437f1d", "e3af588bc95741c58241073f1bbb7329", "dbfdc6e40f0c44678c335efc727d73a6", "ce89511005494833b216838d1a536af9", "0184159543a74048942795d91bf0d98d", "0ebacfd58f6545929cff0a0a78188a66", "18136d44a88245148e96d498556db23b", "88f32857f7bf4fb595d412042635e421", "9cd7c765093c4bdd853c0cf66ca445d4", "d30aa86296c34269b735b45dbc21b6a3", "40ef73551c894b8aaca8af92e5b494d0", "051e77b213ee4e12b1ca08f59c3e6b7e", "ba0cc9d7efb84f04ab8c3c9e80a8bce8", "32c47a52f0c14fa7a509a20e2bb2a63c", "61d66b55dc8746239b3622b193591efb", "e8eeff96cc8c47e8b71315ccc58204f6", "8231c078232b4a45a66623890feb3ed8", "c317a5d8ae894ebe95a1f1201ebd3328", "46915aadfdd44407af77edc45b3a8955", "c79d655ba15a4635b01307a9c4a11530", "8693cd493adc4459b8b91b057a785479", "99dcab149e6849c490b0ae1708736d5f", "a9b3df23850341439d05ad573ae52d29", "3efd33efe633402583686ebc692864b5", "ed660ca670a0408fb7196e679e9d39a1", "a50c7d82d87b49d1b1ae9cca3f35f4cd", "6eb3b02838044377a9aca91a8412b3e4", "0f32391f1b60417ea41e3e20fb66b10e", "0a120b6d26c24799b1ac0dce323185e7", "f71cfeab2791409abc2bfcfc7c6d7b56", "baa1e3a083524ff7885c9be9112cc149", "111f035ca3744086bd99a2339e82c58b", "795450795bba45658bd6841b072578fa", "0e16985134ff46bb9d9201efc7fab653", 
"099918165ca1455b88ea3fc4f8fae020", "8cd150725d404a16be61deb89992a8fd", "a10fb1acbd4d40b89813533ca51f8297", "56685835f4d7459ab4801e068af57799", "922eb6dedf3f47988d0cf5cc8b3b4aae", "a854871fb2734de1ae15003eddfa351e", "ef113d3fa1b942c896abf1673e977db9", "2d18af4b4b284dfb86b504fceb94096d", "893537d3d1d340d88458f4bff8741327", "ce033673385b4894b10cc8074bc71a18", "5ba9f0fc73644d07835f284f3dd8f0a4" ] }, "id": "IffbKW5BEAlb", "outputId": "de303710-5400-429c-b596-47636b278250" }, "execution_count": 8, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Downloading: 0%| | 0.00/570 [00:00" ] }, "metadata": {}, "execution_count": 9 }, { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAYSUlEQVR4nO3df5Ac5X3n8fcnyKCE9WmlEO8pkuokn2WniCljaQOinEvtolgS2GWRKofCpTrWnK6UynGOnfgSRFxECeA6kUB8pi7B3op0ETb2WqeYoJJxqM2aSUp/8MOyscwPK1qQsKXCKGaFnBHECeSbP/pZGG92tT2tmdkJz+dVNTXdTz/T/e2W5tMzz/TMKiIwM7M8/MRcF2BmZp3j0Dczy4hD38wsIw59M7OMOPTNzDIyb64LOJMLLrggli9fXumxp0+f5vzzz29tQS3SrbW5ruZ0a13QvbW5ruZUrevAgQM/iIifmXZhRHTtbfXq1VHVgw8+WPmx7dattbmu5nRrXRHdW5vrak7VuoCvxwy56uEdM7OMOPTNzDLi0Dczy4hD38wsIw59M7OMOPTNzDLi0Dczy4hD38wsIw59M7OMdPXPMJyt5Vu/Uqrf0e3va3MlZmbdwa/0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCOzhr6kd0h6rOH2Q0kfk7RI0qikw+l+YeovSXdKGpd0UNKqhnUNpf6HJQ21c8fMzOzfmjX0I+JQRFwcERcDq4GXgHuBrcBYRKwExtI8wBXAynTbAtwFIGkRsA24FLgE2DZ5ojAzs85odnhnLfB0RDwLbAR2pfZdwFVpeiNwd/pTjQ8BvZIWA+uB0YiYiIiTwCiw4az3wMzMSlPxN3RLdpZ2At+IiP8r6cWI6E3tAk5GRK+kfcD2iNiflo0BNwADwPyIuDW13wS8HBG3T9nGFop3CPT19a0eGRmptGP1ep0jp14t1feiJQsqbaOqer1OT09PR7dZhutqTrfWBd1bm+tqTtW6BgcHD0RE/3TLSv/2jqRzgQ8AN05dFhEhqfzZ4wwiYhgYBujv74+BgYFK66nVatyx/3Spvkc3VdtGVbVajar71U6uqzndWhd0b22uqzntqKuZ4Z0rKF7lP5/mn0/DNqT7E6n9OLCs4XFLU9tM7WZm1iHNhP6HgC82zO8FJq/AGQLua2i/Nl3FswY4FRHPAQ8A6yQtTB/grkttZmbWIaWGdySdD7wX+LWG5u3AbkmbgWeBq1P7/cCVwDjFlT7XAUTEhKRbgEdTv5sjYuKs98DMzEorFfoRcRr46SltL1BczTO1bwDXz7CencDO5ss0M7NW8Ddyzcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCOlQl9Sr6Q9kr4j6SlJl0laJGlU0uF0vzD1laQ7JY1LOihpVcN6hlL/w5KG2rVTZmY2vbKv9D8N/FVE/BzwLuApYCswFhErgbE0D3AFsDLdtgB3AUhaBGwDLgUuAbZNnijMzKwzZg19SQuAXwJ2AETEP0XEi8BGYFfqtgu4Kk1vBO6OwkNAr6TFwHpgNCImIuIkMApsaOnemJnZGSkiztxBuhgYBp6keJV/APgocDwielMfAScjolfSPmB7ROxPy8aAG4ABYH5E3JrabwJejojbp2xvC8U7BPr6+laPjIxU2rF6vc6RU6+W6nvRkgWVtlFVvV6np6eno9ssw3U1p1vrgu6tzXU1p2pdg4ODByKif7pl80o8fh6wCvhIRDws6dO8PpQDQESEpDOfPUqKiGGKkwz9/f0xMDBQaT21Wo079p8u1ffopmrbqKpWq1F1v9rJdTWnW+uC7q3NdTWnHXWVGdM/BhyLiIfT/B6Kk
8DzadiGdH8iLT8OLGt4/NLUNlO7mZl1yKyhHxHfB74n6R2paS3FUM9eYPIKnCHgvjS9F7g2XcWzBjgVEc8BDwDrJC1MH+CuS21mZtYhZYZ3AD4C3CPpXOAZ4DqKE8ZuSZuBZ4GrU9/7gSuBceCl1JeImJB0C/Bo6ndzREy0ZC/MzKyUUqEfEY8B030osHaavgFcP8N6dgI7mynQzMxax9/INTPLiEPfzCwjDn0zs4w49M3MMuLQNzPLiEPfzCwjDn0zs4w49M3MMuLQNzPLiEPfzCwjDn0zs4w49M3MMuLQNzPLiEPfzCwjDn0zs4w49M3MMuLQNzPLiEPfzCwjDn0zs4yUCn1JRyV9W9Jjkr6e2hZJGpV0ON0vTO2SdKekcUkHJa1qWM9Q6n9Y0lB7dsnMzGbSzCv9wYi4OCIm/0D6VmAsIlYCY2ke4ApgZbptAe6C4iQBbAMuBS4Btk2eKMzMrDPOZnhnI7ArTe8CrmpovzsKDwG9khYD64HRiJiIiJPAKLDhLLZvZmZNUkTM3kk6ApwEAvhsRAxLejEietNyAScjolfSPmB7ROxPy8aAG4ABYH5E3JrabwJejojbp2xrC8U7BPr6+laPjIxU2rF6vc6RU6+W6nvRkgWVtlFVvV6np6eno9ssw3U1p1vrgu6tzXU1p2pdg4ODBxpGZX7MvJLr+MWIOC7pLcCopO80LoyIkDT72aOEiBgGhgH6+/tjYGCg0npqtRp37D9dqu/RTdW2UVWtVqPqfrWT62pOt9YF3Vub62pOO+oqNbwTEcfT/QngXoox+efTsA3p/kTqfhxY1vDwpaltpnYzM+uQWUNf0vmS3jw5DawDHgf2ApNX4AwB96XpvcC16SqeNcCpiHgOeABYJ2lh+gB3XWozM7MOKTO80wfcWwzbMw/4QkT8laRHgd2SNgPPAlen/vcDVwLjwEvAdQARMSHpFuDR1O/miJho2Z6YmdmsZg39iHgGeNc07S8Aa6dpD+D6Gda1E9jZfJlmZtYK/kaumVlGHPpmZhlx6JuZZcShb2aWEYe+mVlGHPpmZhlx6JuZZcShb2aWEYe+mVlGHPpmZhlx6JuZZcShb2aWEYe+mVlGHPpmZhlx6JuZZcShb2aWEYe+mVlGHPpmZhlx6JuZZaR06Es6R9I3Je1L8yskPSxpXNKXJJ2b2s9L8+Np+fKGddyY2g9JWt/qnTEzszNr5pX+R4GnGuZvAz4VEW8DTgKbU/tm4GRq/1Tqh6QLgWuAnwc2AH8q6ZyzK9/MzJpRKvQlLQXeB/xZmhdwObAnddkFXJWmN6Z50vK1qf9GYCQifhQRR4Bx4JJW7ISZmZWjiJi9k7QH+N/Am4H/BXwYeCi9mkfSMuCrEfFOSY8DGyLiWFr2NHAp8PvpMZ9P7TvSY/ZM2dYWYAtAX1/f6pGRkUo7Vq/XOXLq1VJ9L1qyoNI2qqrX6/T09HR0m2W4ruZ0a13QvbW5ruZUrWtwcPBARPRPt2zebA+W9H7gREQckDTQ9NabFBHDwDBAf39/DAxU22StVuOO/adL9T26qdo2qqrValTdr3ZyXc3p1rqge2tzXc1pR12zhj7wHuADkq4E5gP/Afg00CtpXkS8AiwFjqf+x4FlwDFJ84AFwAsN7ZMaH2NmZh0w65h+RNwYEUsjYjnFB7Ffi4hNwIPAB1O3IeC+NL03zZOWfy2KMaS9wDXp6p4VwErgkZbtiZmZzarMK/2Z3ACMSLoV+CawI7XvAD4naRyYoDhREBFPSNoNPAm8AlwfEeUG3c3MrCWaCv2IqAG1NP0M01x9ExH/CPzqDI//JPDJZos0M7PW8Ddyzcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCOzhr6k+ZIekfQtSU9I+oPUvkLSw5LGJX1J0rmp/bw0P
56WL29Y142p/ZCk9e3aKTMzm16ZV/o/Ai6PiHcBFwMbJK0BbgM+FRFvA04Cm1P/zcDJ1P6p1A9JFwLXAD8PbAD+VNI5rdwZMzM7s1lDPwr1NPumdAvgcmBPat8FXJWmN6Z50vK1kpTaRyLiRxFxBBgHLmnJXpiZWSmKiNk7Fa/IDwBvA/4E+CPgofRqHknLgK9GxDslPQ5siIhjadnTwKXA76fHfD6170iP2TNlW1uALQB9fX2rR0ZGKu1YvV7nyKlXS/W9aMmCStuoql6v09PT09FtluG6mtOtdUH31ua6mlO1rsHBwQMR0T/dsnllVhARrwIXS+oF7gV+rukqSoqIYWAYoL+/PwYGBiqtp1arccf+06X6Ht1UbRtV1Wo1qu5XO7mu5nRrXdC9tbmu5rSjrqau3omIF4EHgcuAXkmTJ42lwPE0fRxYBpCWLwBeaGyf5jFmZtYBZa7e+Zn0Ch9JPwm8F3iKIvw/mLoNAfel6b1pnrT8a1GMIe0FrklX96wAVgKPtGpHzMxsdmWGdxYDu9K4/k8AuyNin6QngRFJtwLfBHak/juAz0kaByYortghIp6QtBt4EngFuD4NG5mZWYfMGvoRcRB49zTtzzDN1TcR8Y/Ar86wrk8Cn2y+TDMzawV/I9fMLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwyMmvoS1om6UFJT0p6QtJHU/siSaOSDqf7haldku6UNC7poKRVDesaSv0PSxpq326Zmdl0yrzSfwX4eERcCKwBrpd0IbAVGIuIlcBYmge4AliZbluAu6A4SQDbgEsp/qD6tskThZmZdcasoR8Rz0XEN9L0PwBPAUuAjcCu1G0XcFWa3gjcHYWHgF5Ji4H1wGhETETESWAU2NDSvTEzszNSRJTvLC0H/hZ4J/DdiOhN7QJORkSvpH3A9ojYn5aNATcAA8D8iLg1td8EvBwRt0/ZxhaKdwj09fWtHhkZqbRj9XqdI6deLdX3oiULKm2jqnq9Tk9PT0e3WYbrak631gXdW5vrak7VugYHBw9ERP90y+aVXYmkHuAvgI9FxA+LnC9EREgqf/Y4g4gYBoYB+vv7Y2BgoNJ6arUad+w/Xarv0U3VtlFVrVaj6n61k+tqTrfWBd1bm+tqTjvqKnX1jqQ3UQT+PRHx5dT8fBq2Id2fSO3HgWUND1+a2mZqNzOzDilz9Y6AHcBTEfHHDYv2ApNX4AwB9zW0X5uu4lkDnIqI54AHgHWSFqYPcNelNjMz65AywzvvAf4r8G1Jj6W23wW2A7slbQaeBa5Oy+4HrgTGgZeA6wAiYkLSLcCjqd/NETHRkr0wM7NSZg399IGsZli8dpr+AVw/w7p2AjubKdDMzFrH38g1M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8tI6V/ZfCNbvvUrpfod3f6+NldiZtZefqVvZpYRh76ZWUYc+mZmGXHom5llxKFvZpYRh76ZWUYc+mZmGXHom5llxKFvZpaRWUNf0k5JJyQ93tC2SNKopMPpfmFql6Q7JY1LOihpVcNjhlL/w5KG2rM7ZmZ2JmVe6f85sGFK21ZgLCJWAmNpHuAKYGW6bQHuguIkAWwDLgUuAbZNnijMzKxzZg39iPhbYGJK80ZgV5reBVzV0H53FB4CeiUtBtYDoxExEREngVH+7YnEzMzaTBExeydpObAvIt6Z5l+MiN40LeBkRPRK2gdsj4j9adkYcAMwAMyPiFtT+03AyxFx+zTb2kLxLoG+vr7VIyMjlXasXq9z5NSrlR47k4uWLGjJeur1Oj09PS1ZVyu5ruZ0a13QvbW5ruZUrWtwcPBARPRPt+ysf2UzIkLS7GeO8usbBoYB+vv7Y2BgoNJ6arUad+w/3aqyA
Di6qVotU9VqNaruVzu5ruZ0a13QvbW5rua0o66qV+88n4ZtSPcnUvtxYFlDv6WpbaZ2MzProKqhvxeYvAJnCLivof3adBXPGuBURDwHPACsk7QwfYC7LrWZmVkHzTq8I+mLFGPyF0g6RnEVznZgt6TNwLPA1an7/cCVwDjwEnAdQERMSLoFeDT1uzkipn44bGZmbTZr6EfEh2ZYtHaavgFcP8N6dgI7m6rOzMxayt/INTPLiEPfzCwjDn0zs4w49M3MMuLQNzPLiEPfzCwjDn0zs4yc9W/v5GT51q+U6nd0+/vaXImZWTV+pW9mlhGHvplZRhz6ZmYZceibmWXEoW9mlhGHvplZRhz6ZmYZ8XX6bTDb9fwfv+gVPrz1K76e38w6zq/0zcwy4tA3M8uIh3fmkH/Wwcw6reOv9CVtkHRI0rikrZ3evplZzjr6Sl/SOcCfAO8FjgGPStobEU92so5/b/yOwMxapdPDO5cA4xHxDICkEWAj4NBvgbInh5lMXlXUDXwCM2uPTof+EuB7DfPHgEsbO0jaAmxJs3VJhypu6wLgBxUf21a/0aW1dVNduu3HZrumrim6tS7o3tpcV3Oq1vWfZlrQdR/kRsQwMHy265H09Yjob0FJLdettbmu5nRrXdC9tbmu5rSjrk5/kHscWNYwvzS1mZlZB3Q69B8FVkpaIelc4Bpgb4drMDPLVkeHdyLiFUn/E3gAOAfYGRFPtGlzZz1E1EbdWpvrak631gXdW5vrak7L61JEtHqdZmbWpfwzDGZmGXHom5ll5A0Z+nP5Uw+Slkl6UNKTkp6Q9NHUvkjSqKTD6X5hapekO1OtByWtanN950j6pqR9aX6FpIfT9r+UPmBH0nlpfjwtX97Gmnol7ZH0HUlPSbqsi47Xb6Z/x8clfVHS/Lk4ZpJ2Sjoh6fGGtqaPkaSh1P+wpKE21fVH6d/yoKR7JfU2LLsx1XVI0vqG9pY/Z6errWHZxyWFpAvS/Jwes9T+kXTcnpD0hw3trT1mEfGGulF8QPw08FbgXOBbwIUd3P5iYFWafjPwd8CFwB8CW1P7VuC2NH0l8FVAwBrg4TbX91vAF4B9aX43cE2a/gzw62n6fwCfSdPXAF9qY027gP+eps8FervheFF8mfAI8JMNx+rDc3HMgF8CVgGPN7Q1dYyARcAz6X5hml7YhrrWAfPS9G0NdV2Yno/nASvS8/Scdj1np6sttS+juJjkWeCCLjlmg8BfA+el+be065i15ckylzfgMuCBhvkbgRvnsJ77KH5r6BCwOLUtBg6l6c8CH2ro/1q/NtSyFBgDLgf2pf/gP2h4gr527NKT4rI0PS/1UxtqWkARrJrS3g3Ha/Ib5IvSMdgHrJ+rYwYsnxIUTR0j4EPAZxvaf6xfq+qasuxXgHvS9I89FyePVzufs9PVBuwB3gUc5fXQn9NjRvFC4pen6dfyY/ZGHN6Z7qcelsxFIent/buBh4G+iHguLfo+0JemO1nv/wF+B/iXNP/TwIsR8co0236trrT8VOrfaiuAvwf+Xxp2+jNJ59MFxysijgO3A98FnqM4BgeY+2M2qdljNBfPjf9G8Qq6K+qStBE4HhHfmrJormt7O/Bf0rDg30j6hXbV9UYM/a4gqQf4C+BjEfHDxmVRnJo7eq2spPcDJyLiQCe3W8I8ire6d0XEu4HTFEMVr5mL4wWQxsg3UpyYfhY4H9jQ6TrKmKtjdCaSPgG8Atwz17UASPop4HeB35vrWqYxj+Id5Rrgt4HdktSODb0RQ3/Of+pB0psoAv+eiPhyan5e0uK0fDFwIrV3qt73AB+QdBQYoRji+TTQK2nyS3qN236trrR8AfBCG+o6BhyLiIfT/B6Kk8BcHy+AXwaORMTfR8Q/A1+mOI5zfcwmNXuMOnbsJH0YeD+wKZ2QuqGu/0xxAv9Weh4sBb4h6T92QW3HgC9H4RGKd+MXtKOuN2Loz+lPPaSz8w7gqYj444ZFe4HJT/6HKMb6J9uvT
VcPrAFONbxlb5mIuDEilkbEcopj8rWI2AQ8CHxwhrom6/1g6t/yV5IR8X3ge5LekZrWUvzU9pwer+S7wBpJP5X+XSdrm9Nj1qDZY/QAsE7SwvQuZl1qaylJGyiGET8QES9NqfcaFVc5rQBWAo/QoedsRHw7It4SEcvT8+AYxUUX32eOjxnwlxQf5iLp7RQfzv6AdhyzVnxY0m03ik/i/47i0+1PdHjbv0jxNvsg8Fi6XUkxtjsGHKb4lH5R6i+KPyzzNPBtoL8DNQ7w+tU7b03/icaB/8/rVw/MT/Pjaflb21jPxcDX0zH7S4qrJLrieAF/AHwHeBz4HMVVFB0/ZsAXKT5X+GeKsNpc5RhRjLGPp9t1baprnGK8efL//2ca+n8i1XUIuKKhveXP2elqm7L8KK9/kDvXx+xc4PPp/9k3gMvbdcz8MwxmZhl5Iw7vmJnZDBz6ZmYZceibmWXEoW9mlhGHvplZRhz6ZmYZceibmWXkXwGtZZtE+IxaZwAAAABJRU5ErkJggg==\n" }, "metadata": { "needs_background": "light" } } ] }, { "cell_type": "code", "source": [ "max_seq_len = 100" ], "metadata": { "id": "JB3VM8WuEQ4O" }, "execution_count": 10, "outputs": [] }, { "cell_type": "code", "source": [ "# tokenize and encode sequences in the sets\n", "tokens_train = tokenizer.batch_encode_plus(\n", " train_input['text'].tolist(), \n", " max_length = max_seq_len,\n", " pad_to_max_length=True,\n", " truncation=True,\n", " return_token_type_ids=False\n", ")\n", "\n", "tokens_dev = tokenizer.batch_encode_plus(\n", " dev['text'].tolist(), \n", " max_length = max_seq_len,\n", " pad_to_max_length=True,\n", " truncation=True,\n", " return_token_type_ids=False\n", ")\n", "\n", "tokens_test = tokenizer.batch_encode_plus(\n", " test['text'].tolist(), \n", " max_length = max_seq_len,\n", " pad_to_max_length=True,\n", " truncation=True,\n", " return_token_type_ids=False\n", ")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "921Msq7VEUH4", "outputId": "0a92a3da-eedc-44e8-b6b9-9d4f28958e72" }, "execution_count": 11, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.7/dist-packages/transformers/tokenization_utils_base.py:2307: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. 
In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n", " FutureWarning,\n" ] } ] }, { "cell_type": "code", "source": [
 "# convert the tokenizer output and the labels to tensors\n",
 "\n",
 "# for train set\n",
 "train_seq = torch.tensor(tokens_train['input_ids'])\n",
 "train_mask = torch.tensor(tokens_train['attention_mask'])\n",
 "train_y = torch.tensor(train_input['label'].tolist())\n",
 "\n",
 "# for validation set\n",
 "val_seq = torch.tensor(tokens_dev['input_ids'])\n",
 "val_mask = torch.tensor(tokens_dev['attention_mask'])\n",
 "val_y = torch.tensor(dev['label'].tolist())\n",
 "\n",
 "# for test set\n",
 "test_seq = torch.tensor(tokens_test['input_ids'])\n",
 "test_mask = torch.tensor(tokens_test['attention_mask'])\n",
 "# placeholder only: no test labels are read here and no later cell uses test_y\n",
 "test_y = torch.tensor([])\n",
 "\n",
 "# every example needs one mask row, and every labelled example exactly one label\n",
 "assert train_seq.shape == train_mask.shape and len(train_seq) == len(train_y)\n",
 "assert val_seq.shape == val_mask.shape and len(val_seq) == len(val_y)\n",
 "assert test_seq.shape == test_mask.shape"
 ], "metadata": { "id": "_6hI_UZdFtEn" }, "execution_count": 12, "outputs": [] }, { "cell_type": "code", "source": [
 "from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler\n",
 "\n",
 "# define a batch size\n",
 "batch_size = 16\n",
 "\n",
 "# wrap tensors\n",
 "train_data = TensorDataset(train_seq, train_mask, train_y)\n",
 "\n",
 "# random sampler: training examples are reshuffled every epoch\n",
 "train_sampler = RandomSampler(train_data)\n",
 "\n",
 "# dataLoader for train set\n",
 "train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)\n",
 "\n",
 "# wrap tensors\n",
 "val_data = TensorDataset(val_seq, val_mask, val_y)\n",
 "\n",
 "# sequential sampler: validation examples are visited in their original order (no shuffling)\n",
 "val_sampler = SequentialSampler(val_data)\n",
 "\n",
 "# dataLoader for validation set\n",
 "val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=batch_size)\n"
 ], "metadata": { "id": "KSNdUlPLHiw0" }, "execution_count": 13, "outputs": [] }, { "cell_type": "code", "source": [
 "# freeze all the BERT parameters; only the classification head added on top is trained\n",
 "for param in bert.parameters():\n",
 "    param.requires_grad = False"
 ], "metadata": { "id": "4rU2lm6MHqin" }, "execution_count":
14, "outputs": [] }, { "cell_type": "code", "source": [
 "class BERT_Arch(nn.Module):\n",
 "    \"\"\"Two-layer classification head on top of a pre-trained BERT encoder.\n",
 "\n",
 "    The pooled encoder output is passed through\n",
 "    Linear -> ReLU -> Dropout -> Linear -> LogSoftmax.\n",
 "    The defaults give a 768 -> 512 -> 2 head with dropout 0.1;\n",
 "    `bert_dim` must match the hidden size of `bert`.\n",
 "\n",
 "    forward() returns log-probabilities, which is what nn.NLLLoss expects.\n",
 "    \"\"\"\n",
 "\n",
 "    def __init__(self, bert, bert_dim=768, hidden_dim=512, num_classes=2, dropout=0.1):\n",
 "        super(BERT_Arch, self).__init__()\n",
 "\n",
 "        self.bert = bert\n",
 "\n",
 "        # dropout layer\n",
 "        self.dropout = nn.Dropout(dropout)\n",
 "\n",
 "        # relu activation function\n",
 "        self.relu = nn.ReLU()\n",
 "\n",
 "        # dense layer 1: pooled encoder output -> hidden layer\n",
 "        self.fc1 = nn.Linear(bert_dim, hidden_dim)\n",
 "\n",
 "        # dense layer 2 (output layer): hidden layer -> one score per class\n",
 "        self.fc2 = nn.Linear(hidden_dim, num_classes)\n",
 "\n",
 "        # log-softmax over the class dimension (pairs with nn.NLLLoss)\n",
 "        self.softmax = nn.LogSoftmax(dim=1)\n",
 "\n",
 "    def forward(self, sent_id, mask):\n",
 "        \"\"\"Return class log-probabilities for token ids `sent_id` and attention `mask`.\"\"\"\n",
 "        # with return_dict=False the encoder returns a tuple; for a Hugging Face\n",
 "        # BertModel its second element is the pooled [CLS] output\n",
 "        _, cls_hs = self.bert(sent_id, attention_mask=mask, return_dict=False)\n",
 "\n",
 "        x = self.fc1(cls_hs)\n",
 "        x = self.relu(x)\n",
 "        x = self.dropout(x)\n",
 "\n",
 "        # output layer\n",
 "        x = self.fc2(x)\n",
 "\n",
 "        # apply log-softmax activation\n",
 "        x = self.softmax(x)\n",
 "\n",
 "        return x"
 ], "metadata": { "id": "YFohRUDdHryu" }, "execution_count": 15, "outputs": [] }, { "cell_type": "code", "source": [
 "# wrap the pre-trained BERT in the classification architecture defined above\n",
 "model = BERT_Arch(bert)\n",
 "\n",
 "# move the model to the selected `device`\n",
 "model = model.to(device)"
 ], "metadata": { "id": "rDPdkzf0HtTH" }, "execution_count": 16, "outputs": [] }, { "cell_type": "code", "source": [
 "# transformers.AdamW is deprecated; torch.optim.AdamW replaces it.\n",
 "# eps=1e-6 and weight_decay=0.0 mirror the transformers defaults (torch's own\n",
 "# defaults are eps=1e-8, weight_decay=1e-2), keeping the optimisation close to before.\n",
 "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, eps=1e-6, weight_decay=0.0)"
 ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "k-Rj8zh-HyNv", "outputId": "f15878b3-3a36-4966-a2d0-58ab8268b3d1" }, "execution_count": 17, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.7/dist-packages/transformers/optimization.py:310: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version.
Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " FutureWarning,\n" ] } ] }, { "cell_type": "code", "source": [
 "from sklearn.utils.class_weight import compute_class_weight\n",
 "\n",
 "# 'balanced' class weights: inversely proportional to each label's frequency in the training set\n",
 "class_wts = compute_class_weight('balanced', classes=np.unique(train_input['label']), y=train_input['label'])\n",
 "\n",
 "print(class_wts)"
 ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "mNAZd44jH3rf", "outputId": "bdb3aed8-31f4-4653-c6c5-43864f546ae1" }, "execution_count": 18, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[0.7732756 1.41482739]\n" ] } ] }, { "cell_type": "code", "source": [
 "# convert class weights to a float tensor on the model's device\n",
 "weights = torch.tensor(class_wts, dtype=torch.float).to(device)\n",
 "\n",
 "# BERT_Arch ends in LogSoftmax, so NLLLoss on its output is a class-weighted cross-entropy\n",
 "cross_entropy = nn.NLLLoss(weight=weights)\n",
 "\n",
 "# number of training epochs\n",
 "epochs = 3"
 ], "metadata": { "id": "JAjsPk2QH46H" }, "execution_count": 19, "outputs": [] }, { "cell_type": "code", "source": [
 "def train(dataloader=None):\n",
 "    \"\"\"Run one training epoch and return (average loss, stacked model outputs).\n",
 "\n",
 "    Uses the notebook globals `model`, `optimizer`, `cross_entropy` and `device`.\n",
 "\n",
 "    Args:\n",
 "        dataloader: batches of (token ids, attention mask, labels);\n",
 "            defaults to `train_dataloader`.\n",
 "\n",
 "    Returns:\n",
 "        avg_loss: mean of the per-batch losses.\n",
 "        total_preds: numpy array of model outputs, one row per example.\n",
 "    \"\"\"\n",
 "    if dataloader is None:\n",
 "        dataloader = train_dataloader\n",
 "\n",
 "    # enable dropout\n",
 "    model.train()\n",
 "\n",
 "    total_loss = 0\n",
 "    # per-batch model outputs, concatenated after the loop\n",
 "    total_preds = []\n",
 "\n",
 "    for step, batch in enumerate(dataloader):\n",
 "        # progress update after every 50 batches.\n",
 "        if step % 50 == 0 and not step == 0:\n",
 "            print(' Batch {:>5,} of {:>5,}.'.format(step, len(dataloader)))\n",
 "\n",
 "        # move the batch to `device`\n",
 "        sent_id, mask, labels = [t.to(device) for t in batch]\n",
 "\n",
 "        # clear previously calculated gradients\n",
 "        model.zero_grad()\n",
 "\n",
 "        # forward pass and loss for the current batch\n",
 "        preds = model(sent_id, mask)\n",
 "        loss = cross_entropy(preds, labels)\n",
 "        total_loss += loss.item()\n",
 "\n",
 "        # backward pass to calculate the gradients\n",
 "        loss.backward()\n",
 "\n",
 "        # clip the gradient norm to 1.0 to guard against exploding gradients\n",
 "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
 "\n",
 "        # update parameters\n",
 "        optimizer.step()\n",
 "\n",
 "        # detach from the graph and copy to the CPU before accumulating\n",
 "        total_preds.append(preds.detach().cpu().numpy())\n",
 "\n",
 "    # compute the training loss of the epoch\n",
 "    avg_loss = total_loss / len(dataloader)\n",
 "\n",
 "    # (no. of batches, batch size, no. of classes) -> (no. of samples, no. of classes)\n",
 "    total_preds = np.concatenate(total_preds, axis=0)\n",
 "\n",
 "    return avg_loss, total_preds"
 ], "metadata": { "id": "VFzCbFmmIhhd" }, "execution_count": 20, "outputs": [] }, { "cell_type": "code", "source": [
 "def evaluate(dataloader=None):\n",
 "    \"\"\"Compute the average loss and model outputs on a labelled dataloader.\n",
 "\n",
 "    Uses the notebook globals `model`, `cross_entropy` and `device`;\n",
 "    runs in eval mode without tracking gradients.\n",
 "\n",
 "    Args:\n",
 "        dataloader: batches of (token ids, attention mask, labels);\n",
 "            defaults to `val_dataloader`.\n",
 "\n",
 "    Returns:\n",
 "        avg_loss: mean of the per-batch losses.\n",
 "        total_preds: numpy array of model outputs, one row per example.\n",
 "    \"\"\"\n",
 "    if dataloader is None:\n",
 "        dataloader = val_dataloader\n",
 "\n",
 "    print(\"\\nEvaluating...\")\n",
 "\n",
 "    # deactivate dropout layers\n",
 "    model.eval()\n",
 "\n",
 "    total_loss = 0\n",
 "    total_preds = []\n",
 "\n",
 "    for step, batch in enumerate(dataloader):\n",
 "        # Progress update every 50 batches.\n",
 "        if step % 50 == 0 and not step == 0:\n",
 "            print(' Batch {:>5,} of {:>5,}.'.format(step, len(dataloader)))\n",
 "\n",
 "        # move the batch to `device`\n",
 "        sent_id, mask, labels = [t.to(device) for t in batch]\n",
 "\n",
 "        # deactivate autograd\n",
 "        with torch.no_grad():\n",
 "            preds = model(sent_id, mask)\n",
 "            loss = cross_entropy(preds, labels)\n",
 "            total_loss += loss.item()\n",
 "            total_preds.append(preds.cpu().numpy())\n",
 "\n",
 "    # compute the validation loss of the epoch\n",
 "    avg_loss = total_loss / len(dataloader)\n",
 "\n",
 "    # reshape the predictions in form of (number of samples, no. of classes)\n",
 "    total_preds = np.concatenate(total_preds, axis=0)\n",
 "\n",
 "    return avg_loss, total_preds"
 ], "metadata": { "id": "lnVTBlprIjE_" }, "execution_count": 21, "outputs": [] }, { "cell_type": "code", "source": [
 "# set initial loss to infinite\n",
 "best_valid_loss = float('inf')\n",
 "\n",
 "# empty lists to store training and validation loss of each epoch\n",
 "train_losses = []\n",
 "valid_losses = []\n",
 "\n",
 "for epoch in range(epochs):\n",
 "\n",
 "    print('\\n Epoch {:} / {:}'.format(epoch + 1, epochs))\n",
 "\n",
 "    train_loss, _ = train()\n",
 "    valid_loss, _ = evaluate()\n",
 "\n",
 "    # checkpoint the weights whenever the validation loss improves\n",
 "    if valid_loss < best_valid_loss:\n",
 "        best_valid_loss = valid_loss\n",
 "        torch.save(model.state_dict(), 'saved_weights_10k.pt')\n",
 "\n",
 "    # append training and validation loss\n",
 "    train_losses.append(train_loss)\n",
 "    valid_losses.append(valid_loss)\n",
 "\n",
 "    print(f'\\nTraining Loss: {train_loss:.3f}')\n",
 "    print(f'Validation Loss: {valid_loss:.3f}')"
 ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "S7R_IWk1Ilk_", "outputId": "c4ee3b7c-ecae-4baf-91b9-6cc6f77dc481" }, "execution_count": 22, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", " Epoch 1 / 3\n", " Batch 50 of 625.\n", " Batch 100 of 625.\n", " Batch 150 of 625.\n", " Batch 200 of 625.\n", " Batch 250 of 625.\n", " Batch 300 of 625.\n", " Batch 350 of 625.\n", " Batch 400 of 625.\n", " Batch 450 of 625.\n", " Batch 500 of
625.\n", " Batch 550 of 625.\n", " Batch 600 of 625.\n", "\n", "Evaluating...\n", " Batch 50 of 330.\n", " Batch 100 of 330.\n", " Batch 150 of 330.\n", " Batch 200 of 330.\n", " Batch 250 of 330.\n", " Batch 300 of 330.\n", "\n", "Training Loss: 0.663\n", "Validation Loss: 0.617\n", "\n", " Epoch 2 / 3\n", " Batch 50 of 625.\n", " Batch 100 of 625.\n", " Batch 150 of 625.\n", " Batch 200 of 625.\n", " Batch 250 of 625.\n", " Batch 300 of 625.\n", " Batch 350 of 625.\n", " Batch 400 of 625.\n", " Batch 450 of 625.\n", " Batch 500 of 625.\n", " Batch 550 of 625.\n", " Batch 600 of 625.\n", "\n", "Evaluating...\n", " Batch 50 of 330.\n", " Batch 100 of 330.\n", " Batch 150 of 330.\n", " Batch 200 of 330.\n", " Batch 250 of 330.\n", " Batch 300 of 330.\n", "\n", "Training Loss: 0.608\n", "Validation Loss: 0.586\n", "\n", " Epoch 3 / 3\n", " Batch 50 of 625.\n", " Batch 100 of 625.\n", " Batch 150 of 625.\n", " Batch 200 of 625.\n", " Batch 250 of 625.\n", " Batch 300 of 625.\n", " Batch 350 of 625.\n", " Batch 400 of 625.\n", " Batch 450 of 625.\n", " Batch 500 of 625.\n", " Batch 550 of 625.\n", " Batch 600 of 625.\n", "\n", "Evaluating...\n", " Batch 50 of 330.\n", " Batch 100 of 330.\n", " Batch 150 of 330.\n", " Batch 200 of 330.\n", " Batch 250 of 330.\n", " Batch 300 of 330.\n", "\n", "Training Loss: 0.585\n", "Validation Loss: 0.598\n" ] } ] }, { "cell_type": "code", "source": [
 "# load the weights of the best checkpoint saved during training\n",
 "path = './saved_weights_10k.pt'\n",
 "# map_location lets a checkpoint saved on a GPU be restored on whatever `device` is now\n",
 "model.load_state_dict(torch.load(path, map_location=device))"
 ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "OMNZqvqzIozP", "outputId": "533d1d3e-0004-4df3-95b2-b1b57c65f55c" }, "execution_count": 48, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 48 } ] }, { "cell_type": "code", "source": [
 "def predict(model, seq, mask, batch_size=16):\n",
 "    \"\"\"Return the argmax class index for every row of `seq`.\n",
 "\n",
 "    Rows are fed to `model` in slices of `batch_size` under torch.no_grad().\n",
 "\n",
 "    Args:\n",
 "        model: module called as model(seq_batch, mask_batch) that returns\n",
 "            one score per class along dim 1.\n",
 "        seq: token-id tensor, one row per example.\n",
 "        mask: attention-mask tensor aligned row-for-row with `seq`.\n",
 "        batch_size: rows per forward pass (default 16).\n",
 "\n",
 "    Returns:\n",
 "        list with one predicted class index per row of `seq`.\n",
 "    \"\"\"\n",
 "    # disable dropout so predictions are deterministic\n",
 "    model.eval()\n",
 "    result = []\n",
 "    with torch.no_grad():\n",
 "        for start in range(0, len(seq), batch_size):\n",
 "            stop = start + batch_size\n",
 "            # pass the whole slice as one batch (not a single row of it)\n",
 "            scores = model(seq[start:stop].to(device), mask[start:stop].to(device))\n",
 "            result.extend(np.argmax(scores.cpu().numpy(), axis=1))\n",
 "    return result"
 ], "metadata": { "id": "zm0g6fChWCvq" }, "execution_count": 43, "outputs": [] }, { "cell_type": "code", "source": [
 "# batched, gradient-free inference over the whole test set\n",
 "result = predict(model, test_seq, test_mask)\n",
 "assert len(result) == len(test_seq)\n",
 "result[:10]"
 ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "slZZ0gkTXxLm", "outputId": "014e4da6-5d53-4307-f3ac-d10870b6ef58" }, "execution_count": 48, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "[1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", "
0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 
1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 
1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 
1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 
0,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 1,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " ...]" ] }, "metadata": {}, "execution_count": 48 } ] }, { "cell_type": "code", "source": [
 "# NOTE(review): write_res appears to treat its second argument as a directory\n",
 "# (it reported writing './test_out.tsv/out.tsv') - confirm the intended output path\n",
 "write_res(result, './test_out.tsv')"
 ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "lqDJ2A29ZwZ3", "outputId": "d22ec4f9-fbde-45d8-ea1d-970f6ec57fba" }, "execution_count": 53, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Data written ./test_out.tsv/out.tsv\n" ] } ] }, { "cell_type": "code", "source": [
 "# batched, gradient-free inference over the validation set; slicing in steps of\n",
 "# batch_size covers every row exactly once and never forwards an empty batch\n",
 "model.eval()\n",
 "result_dev = []\n",
 "with torch.no_grad():\n",
 "    for start in range(0, len(val_seq), batch_size):\n",
 "        stop = start + batch_size\n",
 "        scores = model(val_seq[start:stop].to(device), val_mask[start:stop].to(device))\n",
 "        result_dev.extend(np.argmax(scores.cpu().numpy(), axis=1))\n",
 "assert len(result_dev) == len(val_seq)\n",
 "\n",
 "# NOTE(review): as above, write_res seems to append '/out.tsv' to this path - confirm\n",
 "write_res(result_dev, './dev_out.tsv')\n",
 "len(result_dev)"
 ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Ktm90qCvapDH", "outputId": "077fa023-890a-4a29-e881-e6a95366e7ad" }, "execution_count": 55, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Data written ./dev_out.tsv/out.tsv\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "5272" ] }, "metadata": {}, "execution_count": 55 } ] } ] }