diff --git a/3_RNN.ipynb b/3_RNN.ipynb index ab2646c..b268e30 100644 --- a/3_RNN.ipynb +++ b/3_RNN.ipynb @@ -17,7 +17,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 42, "metadata": {}, "outputs": [ { @@ -27,33 +27,33 @@ "Defaulting to user installation because normal site-packages is not writeable\n", "Requirement already satisfied: torch in /home/pawel/.local/lib/python3.10/site-packages (2.3.0)\n", "Requirement already satisfied: torchtext in /home/pawel/.local/lib/python3.10/site-packages (0.18.0)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (11.4.5.107)\n", + "Requirement already satisfied: jinja2 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (3.1.3)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (4.10.0)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (12.1.105)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (2.20.5)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (12.1.105)\n", - "Requirement already satisfied: sympy in /home/pawel/.local/lib/python3.10/site-packages (from torch) (1.12)\n", - "Requirement already satisfied: jinja2 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (3.1.3)\n", - "Requirement already satisfied: filelock in /home/pawel/.local/lib/python3.10/site-packages (from torch) (3.13.1)\n", - "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (11.0.2.54)\n", - "Requirement already satisfied: triton==2.3.0 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (2.3.0)\n", - "Requirement already satisfied: fsspec in /home/pawel/.local/lib/python3.10/site-packages (from torch) (2024.2.0)\n", - "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (12.1.105)\n", - "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (12.1.0.106)\n", - "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (11.4.5.107)\n", - "Requirement already satisfied: typing-extensions>=4.8.0 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (4.10.0)\n", - "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (8.9.2.26)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (12.1.105)\n", - "Requirement already satisfied: networkx in /home/pawel/.local/lib/python3.10/site-packages (from torch) (3.3)\n", - "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (12.1.105)\n", + "Requirement already satisfied: filelock in /home/pawel/.local/lib/python3.10/site-packages (from torch) (3.13.1)\n", + "Requirement already satisfied: sympy in /home/pawel/.local/lib/python3.10/site-packages (from torch) (1.12)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (12.1.105)\n", + "Requirement already satisfied: triton==2.3.0 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (2.3.0)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (8.9.2.26)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (11.0.2.54)\n", + "Requirement already satisfied: fsspec in /home/pawel/.local/lib/python3.10/site-packages (from torch) (2024.2.0)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (12.1.0.106)\n", + "Requirement already satisfied: networkx in /home/pawel/.local/lib/python3.10/site-packages (from torch) (3.3)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12 in /home/pawel/.local/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch) (12.4.127)\n", - "Requirement already satisfied: tqdm in /home/pawel/.local/lib/python3.10/site-packages (from torchtext) (4.66.2)\n", "Requirement already satisfied: requests in /home/pawel/.local/lib/python3.10/site-packages (from torchtext) (2.31.0)\n", "Requirement already satisfied: numpy in /home/pawel/.local/lib/python3.10/site-packages (from torchtext) (1.26.4)\n", + "Requirement already satisfied: tqdm in /home/pawel/.local/lib/python3.10/site-packages (from torchtext) (4.66.2)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /home/pawel/.local/lib/python3.10/site-packages (from jinja2->torch) (2.1.5)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /home/pawel/.local/lib/python3.10/site-packages (from requests->torchtext) (2024.2.2)\n", - "Requirement already satisfied: idna<4,>=2.5 in /home/pawel/.local/lib/python3.10/site-packages (from requests->torchtext) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/pawel/.local/lib/python3.10/site-packages (from requests->torchtext) (2.2.1)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /home/pawel/.local/lib/python3.10/site-packages (from requests->torchtext) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /home/pawel/.local/lib/python3.10/site-packages (from requests->torchtext) (3.6)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /home/pawel/.local/lib/python3.10/site-packages (from requests->torchtext) (2024.2.2)\n", "Requirement already satisfied: mpmath>=0.19 in /home/pawel/.local/lib/python3.10/site-packages (from sympy->torch) (1.3.0)\n" ] } @@ -64,7 +64,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 43, "metadata": {}, "outputs": [], "source": [ @@ -85,7 +85,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 44, "metadata": { "scrolled": true }, @@ -109,7 +109,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 45, "metadata": {}, "outputs": [], "source": [ @@ -122,7 +122,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 46, "metadata": {}, "outputs": [], "source": [ @@ -131,7 +131,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 47, "metadata": {}, "outputs": [], "source": [ @@ -140,7 +140,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 48, "metadata": {}, "outputs": [ { @@ -157,7 +157,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 49, "metadata": {}, "outputs": [ { @@ -166,7 +166,7 @@ "23627" ] }, - "execution_count": 18, + "execution_count": 49, "metadata": {}, "output_type": "execute_result" } @@ -177,7 +177,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 50, "metadata": {}, "outputs": [ { @@ -186,7 +186,7 @@ "5" ] }, - "execution_count": 19, + "execution_count": 50, "metadata": {}, "output_type": "execute_result" } @@ -197,7 +197,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 51, "metadata": {}, "outputs": [ { @@ -206,7 +206,7 @@ "0" ] }, - "execution_count": 20, + "execution_count": 51, "metadata": {}, "output_type": "execute_result" } @@ -224,7 +224,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 52, "metadata": {}, "outputs": [], "source": [ @@ -233,7 +233,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 53, "metadata": {}, "outputs": [], "source": [ @@ -250,7 +250,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 54, "metadata": {}, "outputs": [], "source": [ @@ -268,7 +268,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 55, "metadata": {}, "outputs": [], "source": [ @@ -277,7 +277,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 56, "metadata": {}, "outputs": [], "source": [ @@ -286,7 +286,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 57, "metadata": {}, "outputs": [], "source": [ @@ -295,7 +295,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 58, "metadata": { "scrolled": true }, @@ -306,7 +306,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 59, "metadata": {}, "outputs": [], "source": [ @@ -315,7 +315,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 60, "metadata": {}, "outputs": [], "source": [ @@ -331,7 +331,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 61, "metadata": {}, "outputs": [ { @@ -340,7 +340,7 @@ "tensor([ 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 3])" ] }, - "execution_count": 30, + "execution_count": 61, "metadata": {}, "output_type": "execute_result" } @@ -351,7 +351,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 62, "metadata": {}, "outputs": [ { @@ -372,7 +372,7 @@ " 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0]}" ] }, - "execution_count": 31, + "execution_count": 62, "metadata": {}, "output_type": "execute_result" } @@ -383,7 +383,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 63, "metadata": { "scrolled": true }, @@ -394,7 +394,7 @@ "tensor([0, 3, 0, 7, 0, 0, 0, 7, 0, 0, 0])" ] }, - "execution_count": 32, + "execution_count": 63, "metadata": {}, "output_type": "execute_result" } @@ -412,7 +412,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 64, "metadata": {}, "outputs": [], "source": [ @@ -464,7 +464,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 65, "metadata": {}, "outputs": [ { @@ -489,7 +489,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 66, "metadata": {}, "outputs": [], "source": [ @@ -517,7 +517,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 67, "metadata": {}, "outputs": [], "source": [ @@ -533,7 +533,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 68, "metadata": {}, "outputs": [], "source": [ @@ -549,7 +549,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 69, "metadata": {}, "outputs": [], "source": [ @@ -565,7 +565,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 70, "metadata": {}, "outputs": [], "source": [ @@ -593,7 +593,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 71, "metadata": {}, "outputs": [], "source": [ @@ -602,7 +602,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 72, "metadata": { "scrolled": false }, @@ -610,7 +610,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "6b09e994ea71476e8e8600f708aca2ea", + "model_id": "2f51ce06b51e454987745fdce27fcdca", "version_major": 2, "version_minor": 0 }, @@ -622,19 +622,164 @@ "output_type": "display_data" }, { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[41], line 16\u001b[0m\n\u001b[1;32m 13\u001b[0m loss \u001b[38;5;241m=\u001b[39m criterion(predicted_tags\u001b[38;5;241m.\u001b[39msqueeze(\u001b[38;5;241m0\u001b[39m),tags\u001b[38;5;241m.\u001b[39msqueeze(\u001b[38;5;241m1\u001b[39m))\n\u001b[1;32m 15\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[0;32m---> 16\u001b[0m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 18\u001b[0m lstm\u001b[38;5;241m.\u001b[39meval()\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28mprint\u001b[39m(eval_model(validation_tokens_ids, validation_labels, lstm))\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/optim/optimizer.py:391\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 386\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 387\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 388\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must return None or a tuple of (new_args, new_kwargs), but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 389\u001b[0m )\n\u001b[0;32m--> 391\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 392\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m 394\u001b[0m \u001b[38;5;66;03m# call optimizer step post hooks\u001b[39;00m\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/optim/optimizer.py:76\u001b[0m, in \u001b[0;36m_use_grad_for_differentiable.._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 74\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefaults[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdifferentiable\u001b[39m\u001b[38;5;124m'\u001b[39m])\n\u001b[1;32m 75\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n\u001b[0;32m---> 76\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 77\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 78\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/optim/adam.py:168\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 157\u001b[0m beta1, beta2 \u001b[38;5;241m=\u001b[39m group[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbetas\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[1;32m 159\u001b[0m has_complex \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_init_group(\n\u001b[1;32m 160\u001b[0m group,\n\u001b[1;32m 161\u001b[0m params_with_grad,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 165\u001b[0m max_exp_avg_sqs,\n\u001b[1;32m 166\u001b[0m state_steps)\n\u001b[0;32m--> 168\u001b[0m \u001b[43madam\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 169\u001b[0m \u001b[43m \u001b[49m\u001b[43mparams_with_grad\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 170\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 171\u001b[0m \u001b[43m \u001b[49m\u001b[43mexp_avgs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 172\u001b[0m \u001b[43m \u001b[49m\u001b[43mexp_avg_sqs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 173\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_exp_avg_sqs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 174\u001b[0m \u001b[43m \u001b[49m\u001b[43mstate_steps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 175\u001b[0m \u001b[43m \u001b[49m\u001b[43mamsgrad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mamsgrad\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 176\u001b[0m \u001b[43m \u001b[49m\u001b[43mhas_complex\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhas_complex\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 177\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta1\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbeta1\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 178\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta2\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbeta2\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 179\u001b[0m \u001b[43m \u001b[49m\u001b[43mlr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mlr\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 180\u001b[0m \u001b[43m \u001b[49m\u001b[43mweight_decay\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mweight_decay\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 181\u001b[0m \u001b[43m \u001b[49m\u001b[43meps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43meps\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 182\u001b[0m \u001b[43m \u001b[49m\u001b[43mmaximize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmaximize\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 183\u001b[0m \u001b[43m \u001b[49m\u001b[43mforeach\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mforeach\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 184\u001b[0m \u001b[43m \u001b[49m\u001b[43mcapturable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mcapturable\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 185\u001b[0m \u001b[43m \u001b[49m\u001b[43mdifferentiable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mdifferentiable\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 186\u001b[0m \u001b[43m \u001b[49m\u001b[43mfused\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mfused\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 187\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_scale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mgetattr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mgrad_scale\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 188\u001b[0m \u001b[43m \u001b[49m\u001b[43mfound_inf\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mgetattr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mfound_inf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 189\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 191\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/optim/adam.py:318\u001b[0m, in \u001b[0;36madam\u001b[0;34m(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, foreach, capturable, differentiable, fused, grad_scale, found_inf, has_complex, amsgrad, beta1, beta2, lr, weight_decay, eps, maximize)\u001b[0m\n\u001b[1;32m 315\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 316\u001b[0m func \u001b[38;5;241m=\u001b[39m _single_tensor_adam\n\u001b[0;32m--> 318\u001b[0m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 319\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 320\u001b[0m \u001b[43m \u001b[49m\u001b[43mexp_avgs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 321\u001b[0m \u001b[43m \u001b[49m\u001b[43mexp_avg_sqs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 322\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_exp_avg_sqs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 323\u001b[0m \u001b[43m \u001b[49m\u001b[43mstate_steps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 324\u001b[0m \u001b[43m \u001b[49m\u001b[43mamsgrad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mamsgrad\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 325\u001b[0m \u001b[43m \u001b[49m\u001b[43mhas_complex\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhas_complex\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 326\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta1\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbeta1\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 327\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta2\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbeta2\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 328\u001b[0m \u001b[43m \u001b[49m\u001b[43mlr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlr\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 329\u001b[0m \u001b[43m \u001b[49m\u001b[43mweight_decay\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mweight_decay\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 330\u001b[0m \u001b[43m \u001b[49m\u001b[43meps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43meps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 331\u001b[0m \u001b[43m \u001b[49m\u001b[43mmaximize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmaximize\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 332\u001b[0m \u001b[43m \u001b[49m\u001b[43mcapturable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcapturable\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 333\u001b[0m \u001b[43m \u001b[49m\u001b[43mdifferentiable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdifferentiable\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 334\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_scale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgrad_scale\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 335\u001b[0m \u001b[43m \u001b[49m\u001b[43mfound_inf\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfound_inf\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/optim/adam.py:393\u001b[0m, in \u001b[0;36m_single_tensor_adam\u001b[0;34m(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, grad_scale, found_inf, amsgrad, has_complex, beta1, beta2, lr, weight_decay, eps, maximize, capturable, differentiable)\u001b[0m\n\u001b[1;32m 390\u001b[0m param \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mview_as_real(param)\n\u001b[1;32m 392\u001b[0m \u001b[38;5;66;03m# Decay the first and second moment running average coefficient\u001b[39;00m\n\u001b[0;32m--> 393\u001b[0m \u001b[43mexp_avg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlerp_\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgrad\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mbeta1\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 394\u001b[0m exp_avg_sq\u001b[38;5;241m.\u001b[39mmul_(beta2)\u001b[38;5;241m.\u001b[39maddcmul_(grad, grad\u001b[38;5;241m.\u001b[39mconj(), value\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m beta2)\n\u001b[1;32m 396\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m capturable \u001b[38;5;129;01mor\u001b[39;00m differentiable:\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c47750eda30c48a7849a31dadd912572", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3250 [00:00