retro-gap/full_pipeline.ipynb

1808 lines
59 KiB
Plaintext
Raw Permalink Normal View History

2021-01-11 22:59:48 +01:00
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "full_pipeline.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
2021-01-12 23:00:18 +01:00
"0a62008eba914666a971f04ecfc5b04e": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
2021-01-12 23:00:18 +01:00
"layout": "IPY_MODEL_d7cb15cad3d44d5d9c13c280353d0621",
2021-01-11 22:59:48 +01:00
"_model_module": "@jupyter-widgets/controls",
"children": [
2021-01-12 23:00:18 +01:00
"IPY_MODEL_5039d309c081494f98402f44fdab7f30",
"IPY_MODEL_fe1aa54d6836412f8e56195267888e78"
2021-01-11 22:59:48 +01:00
]
}
},
2021-01-12 23:00:18 +01:00
"d7cb15cad3d44d5d9c13c280353d0621": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
2021-01-12 23:00:18 +01:00
"5039d309c081494f98402f44fdab7f30": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
2021-01-12 23:00:18 +01:00
"style": "IPY_MODEL_6eabcb1b506742b0bfd9d4704a2db177",
2021-01-11 22:59:48 +01:00
"_dom_classes": [],
"description": "100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
2021-01-12 23:00:18 +01:00
"max": 10000,
2021-01-11 22:59:48 +01:00
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
2021-01-12 23:00:18 +01:00
"value": 10000,
2021-01-11 22:59:48 +01:00
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
2021-01-12 23:00:18 +01:00
"layout": "IPY_MODEL_796cd8889b1c49038132f58517885135"
2021-01-11 22:59:48 +01:00
}
},
2021-01-12 23:00:18 +01:00
"fe1aa54d6836412f8e56195267888e78": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
2021-01-12 23:00:18 +01:00
"style": "IPY_MODEL_c7ef36717ed94d4087bb3c7cc86d7a36",
2021-01-11 22:59:48 +01:00
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
2021-01-12 23:00:18 +01:00
"value": " 10000/10000 [20:20<00:00, 8.19it/s]",
2021-01-11 22:59:48 +01:00
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
2021-01-12 23:00:18 +01:00
"layout": "IPY_MODEL_0a098ddf0398429abbe7709d59774415"
2021-01-11 22:59:48 +01:00
}
},
2021-01-12 23:00:18 +01:00
"6eabcb1b506742b0bfd9d4704a2db177": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
2021-01-12 23:00:18 +01:00
"796cd8889b1c49038132f58517885135": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
2021-01-12 23:00:18 +01:00
"c7ef36717ed94d4087bb3c7cc86d7a36": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
2021-01-12 23:00:18 +01:00
"0a098ddf0398429abbe7709d59774415": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
2021-01-12 23:00:18 +01:00
"dc2978e2f59c49caa9adf0c25c14b4a7": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
2021-01-12 23:00:18 +01:00
"layout": "IPY_MODEL_ce71a076d140449a895d92d8abb48ee3",
2021-01-11 22:59:48 +01:00
"_model_module": "@jupyter-widgets/controls",
"children": [
2021-01-12 23:00:18 +01:00
"IPY_MODEL_af5a765b075b4b0c8bd44c7c0a52eb4d",
"IPY_MODEL_b7725fc568634e3db37a3b88515b92fd"
2021-01-11 22:59:48 +01:00
]
}
},
2021-01-12 23:00:18 +01:00
"ce71a076d140449a895d92d8abb48ee3": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
2021-01-12 23:00:18 +01:00
"af5a765b075b4b0c8bd44c7c0a52eb4d": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
2021-01-12 23:00:18 +01:00
"style": "IPY_MODEL_dc1cc51b8ddd4c7fb911c6f0ab512d72",
2021-01-11 22:59:48 +01:00
"_dom_classes": [],
"description": "100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 19986,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 19986,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
2021-01-12 23:00:18 +01:00
"layout": "IPY_MODEL_e5ecdc7c36914e88b978c85a645d1f79"
2021-01-11 22:59:48 +01:00
}
},
2021-01-12 23:00:18 +01:00
"b7725fc568634e3db37a3b88515b92fd": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
2021-01-12 23:00:18 +01:00
"style": "IPY_MODEL_70727d3490584014b6cf5501cb77a106",
2021-01-11 22:59:48 +01:00
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
2021-01-12 23:00:18 +01:00
"value": " 19986/19986 [17:03<00:00, 19.53it/s]",
2021-01-11 22:59:48 +01:00
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
2021-01-12 23:00:18 +01:00
"layout": "IPY_MODEL_cf04ed225c9e44e195bcd81940a8d9bc"
2021-01-11 22:59:48 +01:00
}
},
2021-01-12 23:00:18 +01:00
"dc1cc51b8ddd4c7fb911c6f0ab512d72": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
2021-01-12 23:00:18 +01:00
"e5ecdc7c36914e88b978c85a645d1f79": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
2021-01-12 23:00:18 +01:00
"70727d3490584014b6cf5501cb77a106": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
2021-01-12 23:00:18 +01:00
"cf04ed225c9e44e195bcd81940a8d9bc": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
2021-01-12 23:00:18 +01:00
"b504a463e4cf460cb8379e93ab480d54": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
2021-01-12 23:00:18 +01:00
"layout": "IPY_MODEL_e2a75fdd4ce9485a93d1306155f229b7",
2021-01-11 22:59:48 +01:00
"_model_module": "@jupyter-widgets/controls",
"children": [
2021-01-12 23:00:18 +01:00
"IPY_MODEL_2e7b6949536a4f89a29c3b9de6a50d53",
"IPY_MODEL_880cdda5579e404a93ce1f18f892df65"
2021-01-11 22:59:48 +01:00
]
}
},
2021-01-12 23:00:18 +01:00
"e2a75fdd4ce9485a93d1306155f229b7": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
2021-01-12 23:00:18 +01:00
"2e7b6949536a4f89a29c3b9de6a50d53": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
2021-01-12 23:00:18 +01:00
"style": "IPY_MODEL_a29e9ab8158d4f56a97c4dc71e164b57",
2021-01-11 22:59:48 +01:00
"_dom_classes": [],
"description": "100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 11628,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 11628,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
2021-01-12 23:00:18 +01:00
"layout": "IPY_MODEL_b115341af9f244c08ed933f26ec9bb0a"
2021-01-11 22:59:48 +01:00
}
},
2021-01-12 23:00:18 +01:00
"880cdda5579e404a93ce1f18f892df65": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
2021-01-12 23:00:18 +01:00
"style": "IPY_MODEL_e90a93719ca542e99825c35457212fd0",
2021-01-11 22:59:48 +01:00
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
2021-01-12 23:00:18 +01:00
"value": " 11628/11628 [09:59<00:00, 19.40it/s]",
2021-01-11 22:59:48 +01:00
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
2021-01-12 23:00:18 +01:00
"layout": "IPY_MODEL_c2087c68702e412ca6f63ee004114256"
2021-01-11 22:59:48 +01:00
}
},
2021-01-12 23:00:18 +01:00
"a29e9ab8158d4f56a97c4dc71e164b57": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
2021-01-12 23:00:18 +01:00
"b115341af9f244c08ed933f26ec9bb0a": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
2021-01-12 23:00:18 +01:00
"e90a93719ca542e99825c35457212fd0": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
2021-01-12 23:00:18 +01:00
"c2087c68702e412ca6f63ee004114256": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
2021-01-12 23:00:18 +01:00
"5e23acee323f44649e4a8dd2bb9bdb7f": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
2021-01-12 23:00:18 +01:00
"layout": "IPY_MODEL_851974922d5849a6b675676e07465116",
2021-01-11 22:59:48 +01:00
"_model_module": "@jupyter-widgets/controls",
"children": [
2021-01-12 23:00:18 +01:00
"IPY_MODEL_0a46fdf361e046bf9c6a1128abd693fa",
"IPY_MODEL_92389827e8fa48d1b42b9963328df9d2"
2021-01-11 22:59:48 +01:00
]
}
},
2021-01-12 23:00:18 +01:00
"851974922d5849a6b675676e07465116": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
2021-01-12 23:00:18 +01:00
"0a46fdf361e046bf9c6a1128abd693fa": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
2021-01-12 23:00:18 +01:00
"style": "IPY_MODEL_45705971d34c41018199482aaf8a5d2d",
2021-01-11 22:59:48 +01:00
"_dom_classes": [],
"description": "100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 14132,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 14132,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
2021-01-12 23:00:18 +01:00
"layout": "IPY_MODEL_bdb0c1ff2eb743f284fdecbee5c46009"
2021-01-11 22:59:48 +01:00
}
},
2021-01-12 23:00:18 +01:00
"92389827e8fa48d1b42b9963328df9d2": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
2021-01-12 23:00:18 +01:00
"style": "IPY_MODEL_3b65b80af6dc4944839e2775fcb96e1f",
2021-01-11 22:59:48 +01:00
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
2021-01-12 23:00:18 +01:00
"value": " 14132/14132 [12:07<00:00, 19.44it/s]",
2021-01-11 22:59:48 +01:00
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
2021-01-12 23:00:18 +01:00
"layout": "IPY_MODEL_04fd8b42561a4b55b5a19152a4065044"
2021-01-11 22:59:48 +01:00
}
},
2021-01-12 23:00:18 +01:00
"45705971d34c41018199482aaf8a5d2d": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
2021-01-12 23:00:18 +01:00
"bdb0c1ff2eb743f284fdecbee5c46009": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
2021-01-12 23:00:18 +01:00
"3b65b80af6dc4944839e2775fcb96e1f": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
2021-01-12 23:00:18 +01:00
"04fd8b42561a4b55b5a19152a4065044": {
2021-01-11 22:59:48 +01:00
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
}
}
}
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "J5Q-tp0U3pHl"
},
"source": [
"import re\n",
"import numpy as np\n",
"from collections import defaultdict\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "9t4L-LbyOHNc"
},
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "yVPj34v-718x"
},
"source": [
"# Load the stop-word list, one word per line.  The file is only read, so it\n",
"# is opened in plain read mode (\"r+\" needs write permission and fails on\n",
"# read-only files).  Surrounding whitespace, including a stray '\\r' from\n",
"# CRLF files, is stripped and blank lines are dropped.\n",
"with open(\"stopwords.txt\", \"r\", encoding=\"utf-8\") as f:\n",
"    stop_words = [word for word in (line.strip() for line in f) if word]"
],
2021-01-12 23:00:18 +01:00
"execution_count": 4,
2021-01-11 22:59:48 +01:00
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "7hDKry1k46ZJ"
},
"source": [
"def clean_text(text):\n",
"    \"\"\"Lowercase ``text`` and split it into punctuation-free tokens.\n",
"\n",
"    The text is split on any run of whitespace (spaces, tabs, newlines), so a\n",
"    trailing line ending or a tab never stays attached to a token.  Every\n",
"    character that is neither a word character nor whitespace is removed from\n",
"    each token, and tokens left empty by that removal are dropped.\n",
"\n",
"    Args:\n",
"        text: the raw string to tokenise.\n",
"\n",
"    Returns:\n",
"        list of str: the cleaned tokens in their original order.\n",
"    \"\"\"\n",
"    stripped = (re.sub(r'[^\\w\\s]', '', raw) for raw in text.lower().split())\n",
"    # A token made only of punctuation (e.g. \",\") becomes empty; skip it.\n",
"    return [token for token in stripped if token]"
],
2021-01-12 23:00:18 +01:00
"execution_count": 5,
2021-01-11 22:59:48 +01:00
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "EZSzAKY8ALMK"
},
"source": [
"def prepare_corpus(texts, min_count=1, min_word_len=1):\n",
"    \"\"\"Build a token -> index vocabulary from tokenised texts.\n",
"\n",
"    A token is counted only when it has at least ``min_word_len`` characters\n",
"    and is not listed in the module-level ``stop_words``.  It is given the\n",
"    next free index (0, 1, 2, ...) as soon as its count reaches ``min_count``.\n",
"\n",
"    Args:\n",
"        texts: iterable of token sequences.\n",
"        min_count: occurrences required before a token is admitted; values\n",
"            below 1 behave like 1.\n",
"        min_word_len: minimum token length.\n",
"\n",
"    Returns:\n",
"        dict: token -> index, numbered in order of admission.\n",
"    \"\"\"\n",
"    # Hash the stop words once: a membership test against the list is linear\n",
"    # in its length and would otherwise run for every token of every text.\n",
"    ignored = set(stop_words)\n",
"    counts = defaultdict(int)\n",
"    corpus = {}\n",
"    for text in texts:\n",
"        for token in text:\n",
"            if len(token) < min_word_len or token in ignored:\n",
"                continue\n",
"            counts[token] += 1\n",
"            # \">=\" instead of \"==\": a count starts at 1, so min_count <= 0\n",
"            # could never be matched exactly and the corpus stayed empty.\n",
"            if token not in corpus and counts[token] >= min_count:\n",
"                corpus[token] = len(corpus)\n",
"    return corpus"
],
2021-01-12 23:00:18 +01:00
"execution_count": 6,
2021-01-11 22:59:48 +01:00
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "7Mpm7weQANWX"
},
"source": [
"counters = defaultdict(lambda: 0)\n",
"\n",
"class WordCorpus:\n",
"    \"\"\"Vocabulary wrapper producing one-hot and bag-of-words vectors.\n",
"\n",
"    ``self.corpus`` is a dict mapping a token to its integer index.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, corpus=None, texts=None, min_count=1, min_word_len=1):\n",
"        \"\"\"Adopt ``corpus`` when it is truthy, otherwise build one from ``texts``.\n",
"\n",
"        ``texts``, ``min_count`` and ``min_word_len`` are forwarded unchanged\n",
"        to ``prepare_corpus``.\n",
"        \"\"\"\n",
"        if corpus:\n",
"            self.corpus = corpus\n",
"        else:\n",
"            self.corpus = prepare_corpus(texts, min_count, min_word_len)\n",
"\n",
"    @staticmethod\n",
"    def _normalize(token):\n",
"        \"\"\"Lowercase ``token`` and drop every non-word, non-whitespace character.\"\"\"\n",
"        return re.sub(r'[^\\w\\s]', '', token.lower())\n",
"\n",
"    def get_word_idx(self, token):\n",
"        \"\"\"Return the index of ``token`` after normalisation, or ``None`` if unknown.\"\"\"\n",
"        return self.corpus.get(self._normalize(token), None)\n",
"\n",
"    def get_embedding(self, token, encode=False):\n",
"        \"\"\"Return a one-hot int32 vector of length ``len(self.corpus)``.\n",
"\n",
"        With ``encode=True`` ``token`` is used directly as the position to\n",
"        set.  Otherwise it is normalised and looked up in the corpus, and an\n",
"        all-zero vector is returned for an empty or unknown token.\n",
"        \"\"\"\n",
"        embedding = np.zeros(len(self.corpus), dtype=np.int32)\n",
"        if encode:\n",
"            token_idx = token\n",
"        else:\n",
"            token = self._normalize(token)\n",
"            if not token or token not in self.corpus:\n",
"                return embedding\n",
"            token_idx = self.corpus[token]\n",
"        embedding[token_idx] = 1\n",
"        return embedding\n",
"\n",
"    def get_bow(self, text, encode=False):\n",
"        \"\"\"Return the bag-of-words count vector for ``text``.\n",
"\n",
"        With ``encode=True`` the vector is the sum of the one-hot vectors from\n",
"        ``get_embedding(token, True)``; otherwise every token is used directly\n",
"        as an index into an int32 count vector.\n",
"        NOTE(review): both branches index by token, so ``text`` is presumably\n",
"        a sequence of integer indices - confirm against the callers.\n",
"        \"\"\"\n",
"        size = len(self.corpus)\n",
"        if encode:\n",
"            rows = [self.get_embedding(token, encode) for token in text]\n",
"            # Stack explicitly as (len(text), size): np.sum([], axis=0) is the\n",
"            # scalar 0.0, so an empty text used to lose the vector shape.\n",
"            stacked = np.array(rows, dtype=np.int32).reshape(len(rows), size)\n",
"            return stacked.sum(axis=0)\n",
"        bow = np.zeros(size, dtype=np.int32)\n",
"        for token in text:\n",
"            bow[token] += 1\n",
"        return bow"
],
2021-01-12 23:00:18 +01:00
"execution_count": 7,
2021-01-11 22:59:48 +01:00
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "IjOoR5qyAQSS"
},
"source": [
"def load_train_data(train_path):\n",
"    \"\"\"Read a tab-separated training file and return its cleaned texts.\n",
"\n",
"    The fifth tab-separated column of every line is passed through\n",
"    ``clean_text``.  The number of loaded texts is printed.\n",
"\n",
"    Args:\n",
"        train_path: path of the TSV file to read.\n",
"\n",
"    Returns:\n",
"        list: one ``clean_text`` result per line of the file.\n",
"\n",
"    Raises:\n",
"        ValueError: if a line has fewer than five tab-separated columns.\n",
"    \"\"\"\n",
"    texts = []\n",
"    # Plain read mode: nothing is written, and \"r+\" needs write permission,\n",
"    # so it fails on read-only files.\n",
"    with open(train_path, \"r\", encoding=\"utf-8\") as file:\n",
"        for line in file:\n",
"            _, _, _, _, text, *_ = line.split(\"\\t\")\n",
"            texts.append(clean_text(text))\n",
"    print(f\"Loaded {len(texts)} texts from train_set.\")\n",
"    return texts"
],
2021-01-12 23:00:18 +01:00
"execution_count": 8,
2021-01-11 22:59:48 +01:00
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "oYgV8yPJGqq3"
},
"source": [
"class LanguageNeuralModel(nn.Module):\n",
"    \"\"\"Fully connected network: corpus_size -> hidden_size -> hidden_size -> corpus_size.\n",
"\n",
"    ReLU is applied after the first two linear layers; the last layer's\n",
"    output is returned as is.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, corpus_size, hidden_size):\n",
"        super().__init__()\n",
"        self.input = nn.Linear(in_features=corpus_size, out_features=hidden_size)\n",
"        self.hidden = nn.Linear(in_features=hidden_size, out_features=hidden_size)\n",
"        self.output = nn.Linear(in_features=hidden_size, out_features=corpus_size)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Run ``x`` through the three layers and return the final activations.\"\"\"\n",
"        first = F.relu(self.input(x))\n",
"        second = F.relu(self.hidden(first))\n",
"        return self.output(second)"
],
2021-01-12 23:00:18 +01:00
"execution_count": 9,
2021-01-11 22:59:48 +01:00
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "vAmveHcBGtrf"
},
"source": [
"def get_random_word_with_contexts(text, context_size):\n",
"    \"\"\"Pick a random token that has ``context_size`` neighbours on each side.\n",
"\n",
"    Returns:\n",
"        tuple: ``(word, context)`` where ``context`` is the ``context_size``\n",
"        tokens before the word followed by the ``context_size`` tokens after\n",
"        it, or ``(None, None)`` when no position has enough neighbours.\n",
"    \"\"\"\n",
"    first, stop = context_size, len(text) - context_size\n",
"    candidates = np.arange(first, stop)\n",
"    if candidates.size == 0:\n",
"        return None, None\n",
"    center = np.random.choice(candidates)\n",
"    before = text[(center - context_size):center]\n",
"    after = text[(center + 1):(center + 1 + context_size)]\n",
"    return text[center], before + after"
],
2021-01-12 23:00:18 +01:00
"execution_count": 10,
2021-01-11 22:59:48 +01:00
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6N72wcXIIPFu",
2021-01-12 23:00:18 +01:00
"outputId": "15262de3-6bba-4d78-a0c3-1687163a78d2"
2021-01-11 22:59:48 +01:00
},
"source": [
"a = clean_text(\"Ala ma kota , kot pije mleko\")\n",
"get_random_word_with_contexts(a, 2)"
],
2021-01-12 23:00:18 +01:00
"execution_count": 11,
2021-01-11 22:59:48 +01:00
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
2021-01-12 23:00:18 +01:00
"('kot', ['ma', 'kota', 'pije', 'mleko'])"
2021-01-11 22:59:48 +01:00
]
},
"metadata": {
"tags": []
},
2021-01-12 23:00:18 +01:00
"execution_count": 11
2021-01-11 22:59:48 +01:00
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5cgusynnIQa7",
2021-01-12 23:00:18 +01:00
"outputId": "fc54e148-8032-4d47-dbc3-c06faab04300"
2021-01-11 22:59:48 +01:00
},
"source": [
"train_texts = load_train_data(\"drive/MyDrive/train.tsv\")"
],
2021-01-12 23:00:18 +01:00
"execution_count": 12,
2021-01-11 22:59:48 +01:00
"outputs": [
{
"output_type": "stream",
"text": [
"Loaded 107471 texts from train_set.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "qvOgmYr1KM10"
},
"source": [
"corpus = WordCorpus(texts=train_texts, min_count=20, min_word_len=5)"
],
2021-01-12 23:00:18 +01:00
"execution_count": 13,
2021-01-11 22:59:48 +01:00
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "u56IlKZuju9d",
2021-01-12 23:00:18 +01:00
"outputId": "5a0fb193-1a4e-47d5-9652-b090218ba22d"
2021-01-11 22:59:48 +01:00
},
"source": [
"len(corpus.corpus)"
],
2021-01-12 23:00:18 +01:00
"execution_count": 14,
2021-01-11 22:59:48 +01:00
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"111418"
]
},
"metadata": {
"tags": []
},
2021-01-12 23:00:18 +01:00
"execution_count": 14
2021-01-11 22:59:48 +01:00
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "0GO307zuHYLm"
},
"source": [
"def remove_words_outside_corpus_and_encode(text, corpus):\n",
" return [corpus.get_word_idx(token) for token in text if token in corpus.corpus]"
],
2021-01-12 23:00:18 +01:00
"execution_count": 15,
2021-01-11 22:59:48 +01:00
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "zFlfsgR3IQqX"
},
"source": [
"train_texts = [remove_words_outside_corpus_and_encode(text, corpus) for text in train_texts]"
],
2021-01-12 23:00:18 +01:00
"execution_count": 16,
2021-01-11 22:59:48 +01:00
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "NTOCBZssKwci"
},
"source": [
"BATCH_SIZE = 96\n",
"CONTEXT_SIZE = 15"
],
2021-01-12 23:00:18 +01:00
"execution_count": 17,
2021-01-11 22:59:48 +01:00
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "AtuP4xv7LF26"
},
"source": [
"import time\n",
"\n",
"def get_batch(texts):\n",
" X, y = [], []\n",
" size = len(texts)\n",
" for _ in range(BATCH_SIZE):\n",
" word_idx = None\n",
" while word_idx is None:\n",
" text_idx = np.random.randint(size)\n",
" text = texts[text_idx]\n",
" word_idx, context = get_random_word_with_contexts(text, CONTEXT_SIZE)\n",
" bow = corpus.get_bow(context, encode=False)\n",
" X.append(bow)\n",
" y.append(word_idx)\n",
" r = (np.array(X) / (CONTEXT_SIZE * 2)).astype(np.float32), np.array(y).astype(np.int64)\n",
" return r"
],
2021-01-12 23:00:18 +01:00
"execution_count": 18,
2021-01-11 22:59:48 +01:00
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "0LrDHSC-MF2g"
},
"source": [
"model = LanguageNeuralModel(len(corpus.corpus), 250)"
],
2021-01-12 23:00:18 +01:00
"execution_count": 19,
2021-01-11 22:59:48 +01:00
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "JhswE-B4MMBw"
},
"source": [
"model = model.to(device)"
],
2021-01-12 23:00:18 +01:00
"execution_count": 20,
2021-01-11 22:59:48 +01:00
"outputs": []
},
2021-01-12 23:00:18 +01:00
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "R9sS8bfVcjJ7",
"outputId": "cfd1403f-e070-4568-ff20-8e27689bab0f"
},
"source": [
"model.load_state_dict(torch.load(\"drive/MyDrive/model.pth\"))"
],
"execution_count": 21,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"metadata": {
"tags": []
},
"execution_count": 21
}
]
},
2021-01-11 22:59:48 +01:00
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6jsfrzQMOJHs",
2021-01-12 23:00:18 +01:00
"outputId": "f553b9bb-6bed-4425-f9bd-080feb6d7eb6"
2021-01-11 22:59:48 +01:00
},
"source": [
"model.train()"
],
2021-01-12 23:00:18 +01:00
"execution_count": 22,
2021-01-11 22:59:48 +01:00
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"LanguageNeuralModel(\n",
" (input): Linear(in_features=111418, out_features=250, bias=True)\n",
" (hidden): Linear(in_features=250, out_features=250, bias=True)\n",
" (output): Linear(in_features=250, out_features=111418, bias=True)\n",
")"
]
},
"metadata": {
"tags": []
},
2021-01-12 23:00:18 +01:00
"execution_count": 22
2021-01-11 22:59:48 +01:00
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "wgVAoHOjOZEP"
},
"source": [
"criterion = nn.CrossEntropyLoss().to(device)\n",
"optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)"
],
2021-01-12 23:00:18 +01:00
"execution_count": 23,
2021-01-11 22:59:48 +01:00
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "32o0ZAtkOwzY"
},
"source": [
"import tqdm"
],
2021-01-12 23:00:18 +01:00
"execution_count": 24,
2021-01-11 22:59:48 +01:00
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
2021-01-12 23:00:18 +01:00
"height": 457,
2021-01-11 22:59:48 +01:00
"referenced_widgets": [
2021-01-12 23:00:18 +01:00
"0a62008eba914666a971f04ecfc5b04e",
"d7cb15cad3d44d5d9c13c280353d0621",
"5039d309c081494f98402f44fdab7f30",
"fe1aa54d6836412f8e56195267888e78",
"6eabcb1b506742b0bfd9d4704a2db177",
"796cd8889b1c49038132f58517885135",
"c7ef36717ed94d4087bb3c7cc86d7a36",
"0a098ddf0398429abbe7709d59774415"
2021-01-11 22:59:48 +01:00
]
},
"id": "0zYz4HDuO3mC",
2021-01-12 23:00:18 +01:00
"outputId": "5f6e7af6-eaf5-492e-9f61-ea8c79f91f57"
2021-01-11 22:59:48 +01:00
},
"source": [
"running_loss = 0.0\n",
"\n",
2021-01-12 23:00:18 +01:00
"for i in tqdm.tqdm_notebook(range(10000)):\n",
2021-01-11 22:59:48 +01:00
" X, y = get_batch(train_texts)\n",
" X, y = torch.from_numpy(X).to(device), torch.from_numpy(y).to(device)\n",
"\n",
" optimizer.zero_grad()\n",
"\n",
" outputs = model(X)\n",
" loss = criterion(outputs, y)\n",
"\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" running_loss += loss.item()\n",
" if i % 500 == 499:\n",
" torch.save(model.state_dict(), \"model.pth\")\n",
" print('[%d, %5d] loss: %.3f' %\n",
" (1, i + 1, running_loss / 500))\n",
" running_loss = 0.0"
],
2021-01-12 23:00:18 +01:00
"execution_count": 27,
2021-01-11 22:59:48 +01:00
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:3: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
"Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n",
" This is separate from the ipykernel package so we can avoid doing imports until\n"
],
"name": "stderr"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
2021-01-12 23:00:18 +01:00
"model_id": "0a62008eba914666a971f04ecfc5b04e",
2021-01-11 22:59:48 +01:00
"version_minor": 0,
"version_major": 2
},
"text/plain": [
2021-01-12 23:00:18 +01:00
"HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))"
2021-01-11 22:59:48 +01:00
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
2021-01-12 23:00:18 +01:00
"[1, 500] loss: 11.095\n",
"[1, 1000] loss: 11.138\n",
"[1, 1500] loss: 11.202\n",
"[1, 2000] loss: 11.237\n",
"[1, 2500] loss: 11.209\n",
"[1, 3000] loss: 11.261\n",
"[1, 3500] loss: 11.302\n",
"[1, 4000] loss: 11.303\n",
"[1, 4500] loss: 11.283\n",
"[1, 5000] loss: 11.305\n",
"[1, 5500] loss: 11.321\n",
"[1, 6000] loss: 11.348\n",
"[1, 6500] loss: 11.335\n",
"[1, 7000] loss: 11.272\n",
"[1, 7500] loss: 11.347\n",
"[1, 8000] loss: 11.320\n",
"[1, 8500] loss: 11.301\n",
"[1, 9000] loss: 11.307\n",
"[1, 9500] loss: 11.310\n",
"[1, 10000] loss: 11.274\n",
2021-01-11 22:59:48 +01:00
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Yoe-By2iQANV",
"colab": {
"base_uri": "https://localhost:8080/"
},
2021-01-12 23:00:18 +01:00
"outputId": "6ec17d91-fae4-4d71-b114-e04ae7871dc9"
2021-01-11 22:59:48 +01:00
},
"source": [
"model.eval()"
],
2021-01-12 23:00:18 +01:00
"execution_count": 25,
2021-01-11 22:59:48 +01:00
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"LanguageNeuralModel(\n",
" (input): Linear(in_features=111418, out_features=250, bias=True)\n",
" (hidden): Linear(in_features=250, out_features=250, bias=True)\n",
" (output): Linear(in_features=250, out_features=111418, bias=True)\n",
")"
]
},
"metadata": {
"tags": []
},
2021-01-12 23:00:18 +01:00
"execution_count": 25
2021-01-11 22:59:48 +01:00
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "LX9xmKXwQdd7"
},
"source": [
"sets_to_eval = [\"drive/MyDrive/dev0/\", \"drive/MyDrive/dev1/\", \"drive/MyDrive/test/\"]"
],
2021-01-12 23:00:18 +01:00
"execution_count": 26,
2021-01-11 22:59:48 +01:00
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "OmPYkEsHQ_QL"
},
"source": [
"def load_test_data(test_path, corpus):\n",
" texts = []\n",
" with open(test_path, \"r+\") as file:\n",
" while True:\n",
" line = file.readline()\n",
" if not line:\n",
" break\n",
"\n",
" _, _, left, right, *_ = line.split(\"\\t\")\n",
" texts.append(\n",
" (\n",
" remove_words_outside_corpus_and_encode(clean_text(left), corpus),\n",
" remove_words_outside_corpus_and_encode(clean_text(right), corpus)\n",
" )\n",
" )\n",
" print(f\"Loaded {len(texts)} texts from train_set.\")\n",
" return texts"
],
2021-01-12 23:00:18 +01:00
"execution_count": 27,
2021-01-11 22:59:48 +01:00
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "6j2QUhPWSXyL",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 320,
"referenced_widgets": [
2021-01-12 23:00:18 +01:00
"dc2978e2f59c49caa9adf0c25c14b4a7",
"ce71a076d140449a895d92d8abb48ee3",
"af5a765b075b4b0c8bd44c7c0a52eb4d",
"b7725fc568634e3db37a3b88515b92fd",
"dc1cc51b8ddd4c7fb911c6f0ab512d72",
"e5ecdc7c36914e88b978c85a645d1f79",
"70727d3490584014b6cf5501cb77a106",
"cf04ed225c9e44e195bcd81940a8d9bc",
"b504a463e4cf460cb8379e93ab480d54",
"e2a75fdd4ce9485a93d1306155f229b7",
"2e7b6949536a4f89a29c3b9de6a50d53",
"880cdda5579e404a93ce1f18f892df65",
"a29e9ab8158d4f56a97c4dc71e164b57",
"b115341af9f244c08ed933f26ec9bb0a",
"e90a93719ca542e99825c35457212fd0",
"c2087c68702e412ca6f63ee004114256",
"5e23acee323f44649e4a8dd2bb9bdb7f",
"851974922d5849a6b675676e07465116",
"0a46fdf361e046bf9c6a1128abd693fa",
"92389827e8fa48d1b42b9963328df9d2",
"45705971d34c41018199482aaf8a5d2d",
"bdb0c1ff2eb743f284fdecbee5c46009",
"3b65b80af6dc4944839e2775fcb96e1f",
"04fd8b42561a4b55b5a19152a4065044"
2021-01-11 22:59:48 +01:00
]
},
2021-01-12 23:00:18 +01:00
"outputId": "3c257927-5809-4e5d-b16e-ed023d8362e5"
2021-01-11 22:59:48 +01:00
},
"source": [
"words = list(corpus.corpus)\n",
"\n",
"with torch.no_grad():\n",
" for path in sets_to_eval:\n",
" results = []\n",
2021-01-12 23:00:18 +01:00
" data = load_test_data(path + \"in.tsv\", corpus)\n",
2021-01-11 22:59:48 +01:00
" batch = []\n",
" for left, right in tqdm.tqdm_notebook(data):\n",
2021-01-12 23:00:18 +01:00
" context = left[-CONTEXT_SIZE:] + right[:CONTEXT_SIZE]\n",
" context = corpus.get_bow(context, encode=False)\n",
" batch.append(context)\n",
2021-01-11 22:59:48 +01:00
" if len(batch) < BATCH_SIZE:\n",
" continue\n",
2021-01-12 23:00:18 +01:00
" batch = (np.array(batch) / (2 * CONTEXT_SIZE)).astype(np.float32)\n",
2021-01-11 22:59:48 +01:00
" X = torch.from_numpy(batch).to(device)\n",
2021-01-12 23:00:18 +01:00
" out_all = F.softmax(model(X)).tolist()\n",
"\n",
" for pred_idx in range(BATCH_SIZE):\n",
" out = out_all[pred_idx]\n",
2021-01-11 22:59:48 +01:00
"\n",
2021-01-12 23:00:18 +01:00
" indexes = list(range(len(corpus.corpus)))\n",
" indexes = sorted(indexes, key=lambda x: out[x], reverse=True)\n",
2021-01-11 22:59:48 +01:00
"\n",
2021-01-12 23:00:18 +01:00
" with open(path + \"out.tsv\", \"a+\") as f:\n",
" res = \"\"\n",
" prob0 = 1.\n",
" for idx in indexes[:500]:\n",
" prob0 -= out[idx]\n",
" res += f\"{words[idx]}:{np.log(out[idx])} \"\n",
" res += f\":{np.log(prob0)}\\n\"\n",
" f.write(res)\n",
" batch = []"
2021-01-11 22:59:48 +01:00
],
2021-01-12 23:00:18 +01:00
"execution_count": 35,
2021-01-11 22:59:48 +01:00
"outputs": [
{
"output_type": "stream",
"text": [
"Loaded 19986 texts from train_set.\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:8: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
"Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n",
" \n"
],
"name": "stderr"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
2021-01-12 23:00:18 +01:00
"model_id": "dc2978e2f59c49caa9adf0c25c14b4a7",
2021-01-11 22:59:48 +01:00
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=19986.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:16: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" app.launch_new_instance()\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"\n",
"Loaded 11628 texts from train_set.\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
2021-01-12 23:00:18 +01:00
"model_id": "b504a463e4cf460cb8379e93ab480d54",
2021-01-11 22:59:48 +01:00
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=11628.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n",
"Loaded 14132 texts from train_set.\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
2021-01-12 23:00:18 +01:00
"model_id": "5e23acee323f44649e4a8dd2bb9bdb7f",
2021-01-11 22:59:48 +01:00
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=14132.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ZKLc8SZLt171"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}