aitech-moj/wyk/12_Model_transormer_autoregresywny.ipynb
Jakub Pokrywka 01d9a34265 12
2022-06-05 22:32:44 +02:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![Logo 1](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech1.jpg)\n",
"<div class=\"alert alert-block alert-info\">\n",
"<h1> Language Modeling</h1>\n",
"<h2> 12. <i>Autoregressive transformer model</i> [exercises]</h2>\n",
"<h3> Jakub Pokrywka (2022)</h3>\n",
"</div>\n",
"\n",
"![Logo 2](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech2.jpg)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"https://huggingface.co/gpt2"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: transformers in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (4.19.2)\n",
"Requirement already satisfied: tqdm>=4.27 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from transformers) (4.64.0)\n",
"Requirement already satisfied: numpy>=1.17 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from transformers) (1.22.3)\n",
"Requirement already satisfied: requests in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from transformers) (2.27.1)\n",
"Requirement already satisfied: packaging>=20.0 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from transformers) (21.3)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.13,>=0.11.1 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from transformers) (0.12.1)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.1.0 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from transformers) (0.6.0)\n",
"Requirement already satisfied: filelock in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from transformers) (3.7.0)\n",
"Requirement already satisfied: regex!=2019.12.17 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from transformers) (2022.4.24)\n",
"Requirement already satisfied: pyyaml>=5.1 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from transformers) (6.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.1.0->transformers) (4.1.1)\n",
"Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from packaging>=20.0->transformers) (3.0.8)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from requests->transformers) (2020.6.20)\n",
"Requirement already satisfied: idna<4,>=2.5 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from requests->transformers) (3.3)\n",
"Requirement already satisfied: charset-normalizer~=2.0.0 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from requests->transformers) (2.0.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from requests->transformers) (1.26.9)\n"
]
}
],
"source": [
"!pip install transformers"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from transformers import pipeline, set_seed, AutoTokenizer, AutoModel, AutoModelForCausalLM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example text"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"TEXT = 'Today, on my way to the university,'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using the model with the transformers library"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"model_name = \"gpt2\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If inference takes too long or you run out of RAM, use a smaller model:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# model_name = 'distilgpt2'"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"encoding = tokenizer(TEXT)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': [8888, 11, 319, 616, 835, 284, 262, 6403, 11], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"encoding"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"8888 \t Today\n",
"11 \t ,\n",
"319 \t on\n",
"616 \t my\n",
"835 \t way\n",
"284 \t to\n",
"262 \t the\n",
"6403 \t university\n",
"11 \t ,\n"
]
}
],
"source": [
"for token in encoding['input_ids']:\n",
"    print(token, '\\t', tokenizer.decode(token))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"pt_model = AutoModel.from_pretrained(model_name)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': [8888, 11, 319, 616, 835, 284, 262, 6403, 11], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"encoding"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The call below raises an error because the model expects tensors as input, while the tokenizer returned plain Python lists."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"scrolled": true
},
"outputs": [
{
"ename": "AttributeError",
"evalue": "'list' object has no attribute 'size'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"Input \u001b[0;32mIn [15]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mpt_model\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mencoding\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/anaconda3/envs/zajeciaei/lib/python3.10/site-packages/torch/nn/modules/module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1109\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[0;32m~/anaconda3/envs/zajeciaei/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py:769\u001b[0m, in \u001b[0;36mGPT2Model.forward\u001b[0;34m(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 767\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou cannot specify both input_ids and inputs_embeds at the same time\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 768\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m input_ids \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 769\u001b[0m input_shape \u001b[38;5;241m=\u001b[39m \u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msize\u001b[49m()\n\u001b[1;32m 770\u001b[0m input_ids \u001b[38;5;241m=\u001b[39m input_ids\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, input_shape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m 771\u001b[0m batch_size \u001b[38;5;241m=\u001b[39m input_ids\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]\n",
"\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'size'"
]
}
],
"source": [
"pt_model(**encoding)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Today, on my way to the university,'"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"TEXT"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"encoding = tokenizer(TEXT, return_tensors='pt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"?pt_model.forward"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"output = pt_model(**encoding, output_hidden_states=True)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[ 0.0502, 0.0018, -0.1750, ..., -0.1020, -0.0257, -0.1292],\n",
" [ 0.1300, 0.1757, 0.2934, ..., 0.0794, 0.1164, -0.3280],\n",
" [ 0.0021, -0.2481, 0.2638, ..., 0.1507, 0.4056, 0.2376],\n",
" ...,\n",
" [ 0.1611, -0.4680, 0.7029, ..., 0.1209, 0.3803, 0.2864],\n",
" [ 0.1791, -0.3507, -1.2709, ..., -0.1535, -0.7109, -0.2459],\n",
" [ 0.2872, -0.0504, 0.0839, ..., 0.3417, -0.0518, -0.3151]]],\n",
" grad_fn=<ViewBackward0>), past_key_values=((tensor([[[[-7.0634e-01, 1.9011e+00, 7.7253e-01, ..., -1.3028e+00,\n",
" -5.0432e-01, 1.6823e+00],\n",
" [-1.6482e+00, 3.0222e+00, 1.2789e+00, ..., -9.0779e-01,\n",
" -1.7395e+00, 2.4237e+00],\n",
" [-2.3128e+00, 2.8957e+00, 1.8368e+00, ..., -7.0370e-01,\n",
" -1.6305e+00, 2.4407e+00],\n",
" ...,\n",
" [-2.4337e+00, 2.5271e+00, 2.1513e+00, ..., -5.8053e-01,\n",
" -1.6483e+00, 2.0594e+00],\n",
" [-3.8223e+00, 2.1391e+00, 1.7587e+00, ..., -1.0668e+00,\n",
" -1.6278e+00, 1.1729e+00],\n",
" [-1.9238e+00, 2.7944e+00, 1.6292e+00, ..., -8.9733e-01,\n",
" -2.2193e+00, 2.6272e+00]],\n",
"\n",
" [[-9.6153e-02, 8.9928e-01, -1.4324e+00, ..., -3.8667e-03,\n",
" 1.7698e+00, 6.0074e-01],\n",
" [ 2.7222e-01, -1.2016e+00, -1.9081e+00, ..., -1.3531e+00,\n",
" 1.2823e+00, -4.3198e-01],\n",
" [-1.1722e+00, -3.6670e-01, -1.6921e+00, ..., -1.2359e+00,\n",
" 2.5243e+00, 1.0228e+00],\n",
" ...,\n",
" [-1.6694e-01, -1.0159e+00, -2.5232e+00, ..., -9.7920e-01,\n",
" 4.8265e+00, -1.7799e+00],\n",
" [-1.1981e-01, -2.6784e+00, -2.9551e+00, ..., -1.9840e-01,\n",
" 3.3916e+00, -1.9762e-02],\n",
" [ 3.2722e-01, -1.2197e+00, -2.1079e+00, ..., -1.6297e+00,\n",
" 9.2404e-01, -7.6080e-01]],\n",
"\n",
" [[-1.4670e-01, 2.1407e-01, 1.1498e+00, ..., -1.3128e+00,\n",
" -2.1007e+00, 5.6910e-01],\n",
" [ 5.5608e-01, -4.6297e-01, 7.4483e-01, ..., -1.8272e+00,\n",
" 5.4572e-01, 1.0119e+00],\n",
" [ 9.2851e-01, 4.6049e-03, 4.1324e-01, ..., -2.4987e+00,\n",
" 5.2423e-01, 1.5260e+00],\n",
" ...,\n",
" [ 3.2328e-01, 3.5316e-01, 3.2756e-02, ..., -3.2780e+00,\n",
" 8.1692e-01, 1.4566e+00],\n",
" [-2.1528e-01, -2.2490e-01, -1.4536e+00, ..., -3.7075e+00,\n",
" 1.6835e+00, 1.6085e+00],\n",
" [ 7.6672e-01, -5.3757e-01, 4.2462e-01, ..., -2.2908e+00,\n",
" 1.7213e+00, 1.0240e+00]],\n",
"\n",
" ...,\n",
"\n",
" [[ 5.4733e-01, 4.7672e-01, -2.2749e-01, ..., 2.9014e-01,\n",
" 7.7821e-01, 7.8295e-01],\n",
" [ 1.6820e-01, -9.1829e-02, -5.0034e-02, ..., 7.3646e-01,\n",
" 6.1343e-01, 5.4442e-01],\n",
" [ 2.9530e-02, -5.3167e-02, -6.1709e-02, ..., 1.0934e+00,\n",
" 3.7083e-01, 3.8425e-01],\n",
" ...,\n",
" [-1.3203e-02, -2.6465e-01, 4.4834e-02, ..., 1.2205e+00,\n",
" 5.4265e-01, 3.7732e-01],\n",
" [ 8.5854e-02, -2.3791e-01, -1.1271e-01, ..., 1.8211e+00,\n",
" -5.7249e-01, -7.4493e-01],\n",
" [-3.6544e-02, -1.4250e-01, 6.6582e-02, ..., 1.0489e+00,\n",
" 4.8485e-01, 4.6476e-01]],\n",
"\n",
" [[ 1.4700e+00, 1.3564e+00, -4.9892e-01, ..., -6.4925e-02,\n",
" 1.4507e+00, -1.2267e+00],\n",
" [ 1.0113e+00, 7.0108e-01, -5.7364e-01, ..., -7.1721e-01,\n",
" 1.0731e+00, -1.0718e+00],\n",
" [ 1.1010e+00, 4.8299e-01, -9.3231e-01, ..., -1.5044e+00,\n",
" 1.2941e+00, -3.3869e-01],\n",
" ...,\n",
" [ 1.1745e+00, 6.3323e-01, -6.1605e-01, ..., -8.1925e-01,\n",
" 5.2691e-01, -7.5443e-01],\n",
" [ 1.7895e+00, 5.7095e-01, -3.5775e-01, ..., -1.3193e+00,\n",
" 5.5676e-01, -1.6293e-01],\n",
" [ 9.6151e-01, 2.9245e-02, -5.3493e-01, ..., -7.8683e-01,\n",
" 3.7355e-01, -2.4032e-01]],\n",
"\n",
" [[ 7.1643e-01, -3.1278e-01, 1.4058e-01, ..., -2.0734e-01,\n",
" 2.5946e-01, 1.7684e+00],\n",
" [-5.6619e-01, 7.8687e-01, 2.5152e-02, ..., 6.2100e-01,\n",
" 4.7592e-01, 5.4321e-01],\n",
" [-6.2611e-01, 3.3320e-01, 1.1092e-01, ..., 6.4703e-01,\n",
" 6.4159e-01, 7.2777e-01],\n",
" ...,\n",
" [-1.7180e-01, 1.1778e+00, -2.3931e-01, ..., -6.3932e-01,\n",
" 1.1654e+00, 4.0462e-01],\n",
" [-4.8319e-01, 2.8237e-01, -4.4490e-01, ..., -1.2013e-01,\n",
" 4.8413e-01, -4.5133e-01],\n",
" [-1.1252e+00, 7.6533e-01, -6.0320e-02, ..., 1.8912e-01,\n",
" 7.8018e-01, -5.4733e-01]]]], grad_fn=<PermuteBackward0>), tensor([[[[ 0.1900, 0.0015, -0.0517, ..., 0.0536, 0.0312, -0.0694],\n",
" [-0.0800, 0.0181, -0.0534, ..., -0.0419, -0.0365, 0.0151],\n",
" [ 0.0448, 0.1912, -0.1849, ..., -0.0062, -0.1420, 0.1609],\n",
" ...,\n",
" [-0.1635, 0.0196, 0.1185, ..., 0.0794, 0.0980, -0.1084],\n",
" [-0.2303, 0.1991, -0.1576, ..., 0.2774, -0.1813, -0.2463],\n",
" [-0.1009, 0.0410, -0.0970, ..., -0.0684, -0.0763, 0.0260]],\n",
"\n",
" [[ 0.4406, 0.1176, -0.2136, ..., -0.6839, -0.2371, 0.2999],\n",
" [ 0.5926, 0.0197, 0.1107, ..., 0.1253, 0.5675, -0.2665],\n",
" [ 0.6762, 0.0459, -0.3685, ..., 0.0744, 0.5420, -0.1240],\n",
" ...,\n",
" [ 0.8509, -0.0962, 0.0762, ..., -0.1705, 0.1339, 0.1068],\n",
" [ 0.2928, -0.2582, 0.1735, ..., 0.0800, 0.2879, -0.0139],\n",
" [ 0.5969, 0.0592, 0.0263, ..., -0.0100, 0.5129, -0.1905]],\n",
"\n",
" [[ 0.0810, -0.1910, 0.1092, ..., -0.0283, 0.0408, 0.0961],\n",
" [-0.3257, 0.0398, -0.1531, ..., 0.0411, -0.0413, 0.0745],\n",
" [ 0.5201, 0.0126, 0.3504, ..., 0.1020, 0.0543, -0.2188],\n",
" ...,\n",
" [-0.5288, -0.0025, -0.5926, ..., -0.1874, -0.0674, 0.3113],\n",
" [ 0.1521, 0.0271, -0.2514, ..., -0.0465, -0.0565, -0.3401],\n",
" [-0.2885, 0.0590, -0.1736, ..., 0.0685, -0.1112, 0.0604]],\n",
"\n",
" ...,\n",
"\n",
" [[ 0.0111, -0.0168, 0.0263, ..., -0.2135, 0.2054, 0.0729],\n",
" [-0.3022, -0.0878, 0.1001, ..., 0.0262, -0.1647, 0.1682],\n",
" [-0.1587, -0.0666, 0.0826, ..., -0.0416, 0.0812, 0.2067],\n",
" ...,\n",
" [-0.0925, -0.4836, 0.0332, ..., 0.0641, -0.1597, 0.2375],\n",
" [-0.0742, 0.8589, 0.0336, ..., -0.3268, -0.2455, 0.3080],\n",
" [-0.0869, -0.4287, 0.1231, ..., -0.0474, -0.1705, 0.0347]],\n",
"\n",
" [[ 0.2081, -0.2399, -0.1318, ..., 0.1471, 0.1123, -0.0316],\n",
" [-0.2119, 0.0589, 0.0997, ..., 0.0038, 0.1331, 0.0930],\n",
" [-0.1213, 0.1404, 0.1775, ..., 0.1688, -0.0020, 0.0829],\n",
" ...,\n",
" [-0.2325, 0.1252, -0.0345, ..., 0.2837, 0.0686, -0.0089],\n",
" [ 0.1896, 0.0282, -0.0740, ..., 0.1655, -0.3020, 0.2837],\n",
" [ 0.0298, 0.0086, -0.1626, ..., 0.1976, 0.0970, -0.0014]],\n",
"\n",
" [[-0.0689, -0.3955, 0.2328, ..., 0.1539, -0.1823, -0.0845],\n",
" [ 0.0538, -0.2648, -0.0146, ..., 0.2331, 0.0516, 0.0924],\n",
" [-0.0647, 0.0062, 0.1329, ..., 0.1026, 0.1185, 0.0463],\n",
" ...,\n",
" [ 0.0186, 0.1904, -0.0966, ..., 0.0714, -0.0321, -0.0059],\n",
" [ 0.0219, 0.4180, -0.1580, ..., -0.0072, -0.2708, 0.1529],\n",
" [ 0.1236, -0.3671, -0.0392, ..., 0.1061, -0.0278, -0.0074]]]],\n",
" grad_fn=<PermuteBackward0>)), (tensor([[[[-3.5429e-01, 2.2092e+00, -1.5580e+00, ..., 1.4397e+00,\n",
" -1.1504e+00, 1.4646e+00],\n",
" [ 7.3885e-01, 1.8177e+00, -1.4766e+00, ..., -4.6761e-01,\n",
" -1.6869e+00, 5.0785e-01],\n",
" [ 1.6962e+00, 1.1427e+00, -1.1112e+00, ..., 1.2764e-01,\n",
" -2.5909e+00, 7.2933e-01],\n",
" ...,\n",
" [-1.9130e-03, 1.6441e+00, -3.0120e-01, ..., 3.8508e-01,\n",
" -1.0645e+00, -4.5135e-01],\n",
" [-3.9438e-01, 1.6005e+00, 9.6257e-01, ..., 5.8858e-01,\n",
" -1.8425e+00, -9.6318e-01],\n",
" [-4.9488e-01, 1.1094e+00, 5.2522e-02, ..., 5.6471e-01,\n",
" -1.3969e+00, -3.0882e-01]],\n",
"\n",
" [[-1.0087e+00, -4.5958e-01, -7.4797e-01, ..., -3.7310e-01,\n",
" 7.9809e-01, -2.3881e-01],\n",
" [-6.6438e-02, 4.8658e-01, -8.2457e-01, ..., -9.4308e-01,\n",
" 1.8907e-01, -1.5256e-02],\n",
" [-1.7392e-01, 1.1992e+00, -1.5513e+00, ..., -3.2774e-01,\n",
" 7.3627e-01, -3.6968e-01],\n",
" ...,\n",
" [-1.1986e-01, 6.0111e-01, -1.4226e+00, ..., -6.1346e-01,\n",
" 1.3460e-01, -6.1240e-01],\n",
" [ 1.8174e-01, 3.1973e-01, -2.2986e+00, ..., -4.1319e-01,\n",
" -1.0757e+00, -4.7605e-01],\n",
" [-2.4593e-01, 1.1035e+00, -1.4215e+00, ..., -6.2691e-01,\n",
" -1.1097e+00, -6.3956e-01]],\n",
"\n",
" [[ 3.2591e-01, -1.6143e-02, -2.0098e-01, ..., -1.3362e+00,\n",
" 3.3876e-01, -1.6542e-01],\n",
" [-1.0002e-02, 3.9666e-01, -9.3499e-02, ..., -1.0921e+00,\n",
" 5.6914e-02, 4.1318e-01],\n",
" [-1.1656e-02, 2.1262e-01, -2.3546e-01, ..., -9.7254e-01,\n",
" 1.4688e-01, 2.7869e-01],\n",
" ...,\n",
" [-8.3349e-02, 3.9433e-02, -9.7432e-03, ..., -7.0562e-01,\n",
" 4.2687e-01, 2.3274e-01],\n",
" [ 1.0450e-01, -2.0783e-01, -2.8860e-01, ..., -1.0073e+00,\n",
" -1.2179e-01, 3.5471e-01],\n",
" [-1.4484e-01, -5.0447e-02, -3.9541e-03, ..., -1.0255e+00,\n",
" 1.9039e-01, 3.3890e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[ 2.1528e-01, -4.6627e-01, -5.9642e-01, ..., -4.2178e-01,\n",
" 4.3739e-01, -8.5899e-01],\n",
" [-5.0305e-02, 1.2479e+00, 1.8768e+00, ..., 6.8503e-01,\n",
" -7.3186e-01, -3.4076e-01],\n",
" [-4.0512e-01, 1.6082e+00, 1.8570e+00, ..., 1.2636e+00,\n",
" -1.1781e+00, -8.1034e-01],\n",
" ...,\n",
" [-5.1299e-01, 2.6865e-01, 7.6903e-01, ..., -1.3940e+00,\n",
" 8.1194e-01, -1.8763e-01],\n",
" [ 2.3526e-01, -5.7615e-01, 1.3541e+00, ..., 1.4708e+00,\n",
" -2.9934e-01, -3.9407e-01],\n",
" [ 5.0755e-02, 7.0489e-01, 1.9166e+00, ..., 6.6883e-01,\n",
" -9.1450e-01, -2.5584e-01]],\n",
"\n",
" [[-1.1473e+00, -2.7966e+00, 1.4438e-01, ..., 1.7208e+00,\n",
" 1.5965e+00, -1.4860e+00],\n",
" [ 3.5231e-01, 7.5960e-01, -4.7429e-01, ..., -8.1442e-01,\n",
" 4.5442e-01, -2.9752e-01],\n",
" [ 2.1113e-01, 7.5264e-01, -4.5093e-01, ..., -9.6233e-01,\n",
" 5.8766e-01, 9.0545e-02],\n",
" ...,\n",
" [ 1.6897e-01, 2.5023e-01, -7.4581e-01, ..., -1.2799e-01,\n",
" 7.1349e-01, -8.5998e-02],\n",
" [-2.3828e-01, 5.9684e-01, -7.5936e-01, ..., -6.6564e-01,\n",
" 7.3313e-01, 1.8287e-01],\n",
" [-1.6440e-01, 2.5931e-01, -8.1777e-01, ..., -3.5322e-01,\n",
" 8.3564e-01, -5.9446e-02]],\n",
"\n",
" [[ 1.3976e+00, 1.6241e+00, 5.4245e-01, ..., -7.8420e-01,\n",
" 1.1678e-01, 3.7706e-01],\n",
" [ 8.8908e-01, 2.1345e+00, 1.0939e+00, ..., 1.1961e-01,\n",
" -7.5297e-01, -1.4081e-01],\n",
" [ 6.7893e-01, 1.8408e+00, 1.5060e+00, ..., 5.9498e-01,\n",
" -2.2553e+00, -1.8270e+00],\n",
" ...,\n",
" [-5.1015e-02, 2.4946e+00, -1.6883e-01, ..., 5.4761e-01,\n",
" -2.8891e-01, -6.7954e-01],\n",
" [-1.6942e-01, 4.9026e-01, 1.1144e+00, ..., 9.3912e-03,\n",
" -8.0171e-01, -1.4243e-01],\n",
" [ 8.4424e-01, 1.7401e+00, 9.2639e-01, ..., -1.4967e-01,\n",
" -3.8360e-01, -1.5520e-01]]]], grad_fn=<PermuteBackward0>), tensor([[[[ 3.3872e-01, 1.3968e-01, -1.7938e-01, ..., 1.5467e-01,\n",
" -1.2589e-01, 7.0887e-02],\n",
" [ 3.7346e-01, 2.8615e-01, 7.3073e-02, ..., -1.7334e-01,\n",
" -1.7929e-01, 8.0809e-02],\n",
" [ 1.3121e-01, 1.3779e-01, 9.8802e-02, ..., 1.7611e-01,\n",
" -6.5489e-01, -3.7171e-01],\n",
" ...,\n",
" [ 4.5774e-01, 6.2110e-02, 4.7204e-02, ..., 2.1876e-01,\n",
" -1.9506e-01, 1.5526e-01],\n",
" [-1.6503e-01, 7.2050e-02, -4.4076e-01, ..., 9.3966e-02,\n",
" -8.1660e-02, -2.9702e-01],\n",
" [ 3.7986e-01, 3.8336e-01, 1.0341e-01, ..., -1.9899e-01,\n",
" -2.3373e-01, -1.3201e-01]],\n",
"\n",
" [[-7.9321e-02, -6.6966e-02, -2.2227e-01, ..., -1.4152e-02,\n",
" -4.5964e-01, 2.7340e-01],\n",
" [-2.0632e-01, -2.7675e-01, 9.3918e-02, ..., -9.7495e-02,\n",
" 2.0266e-01, 3.4913e-02],\n",
" [-3.6562e-01, -2.8439e-01, 2.9782e-01, ..., -1.0605e+00,\n",
" 2.7564e-01, 3.3809e-01],\n",
" ...,\n",
" [ 5.1779e-01, 2.3170e-01, -3.0248e-01, ..., 4.6880e-01,\n",
" 4.3330e-01, -6.2105e-01],\n",
" [-1.9805e-01, 6.8445e-02, -5.7586e-02, ..., 1.3844e-01,\n",
" -6.2666e-02, 1.8667e-01],\n",
" [ 6.9782e-02, -1.5278e-01, 6.9243e-02, ..., -1.0944e-01,\n",
" 1.1224e-01, 1.1524e-01]],\n",
"\n",
" [[ 7.9376e-02, -1.4863e-02, -4.4028e-02, ..., -6.2825e-01,\n",
" 6.7840e-02, 1.0440e-02],\n",
" [ 4.2720e-01, 2.4379e-01, 2.3040e-01, ..., -5.0812e-01,\n",
" 3.7279e-02, -1.3192e-01],\n",
" [ 6.2018e-01, 1.7793e-01, 2.9474e-01, ..., -7.6162e-01,\n",
" -2.8552e-01, -1.4080e-01],\n",
" ...,\n",
" [ 5.8184e-01, 5.9326e-02, 2.5048e-03, ..., -6.1473e-01,\n",
" -3.0034e-02, 4.4224e-02],\n",
" [ 6.7462e-01, 1.3863e-01, -5.1645e-02, ..., -5.6261e-01,\n",
" -2.2474e-01, -1.2376e-01],\n",
" [ 6.0415e-01, 9.6460e-02, 1.1331e-01, ..., -2.8026e-01,\n",
" 2.4650e-02, -2.4321e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[ 1.0567e-01, 6.7946e-01, -1.7619e-01, ..., 1.2480e-02,\n",
" -9.7338e-01, -2.5708e-01],\n",
" [-5.0101e-04, -7.4670e-01, 1.4215e-01, ..., 2.6520e-02,\n",
" -9.1824e-01, -4.4347e-01],\n",
" [ 5.7162e-02, -6.6084e-01, -1.7225e-01, ..., -6.7773e-02,\n",
" -6.9370e-01, 2.2682e-01],\n",
" ...,\n",
" [ 3.6897e-01, 4.0040e-01, 1.3203e-01, ..., 5.9832e-02,\n",
" -4.3946e-01, 3.3851e-02],\n",
" [-1.9931e-01, 4.7522e-01, 6.5326e-01, ..., 8.5060e-01,\n",
" -1.5948e-01, 2.6952e-01],\n",
" [ 4.5483e-02, -7.9412e-01, 2.0943e-01, ..., 6.4299e-02,\n",
" -6.5777e-01, -2.0458e-01]],\n",
"\n",
" [[ 4.7333e-02, -1.1130e-02, -1.4608e-01, ..., 3.8364e-01,\n",
" -3.4244e+00, 6.6758e-02],\n",
" [ 5.0051e-01, 8.4673e-03, 1.9747e-01, ..., 2.1474e-01,\n",
" -7.4449e-03, -2.8373e-01],\n",
" [-2.0428e-01, 2.4512e-01, -2.7017e-01, ..., 4.5577e-02,\n",
" 2.1612e-02, -1.3106e-01],\n",
" ...,\n",
" [ 7.3244e-02, -1.5794e-01, 1.7578e-01, ..., -2.2690e-01,\n",
" -6.3669e-02, -1.8729e-02],\n",
" [ 1.3369e-01, 4.0795e-01, -6.9403e-02, ..., -2.8477e-02,\n",
" 8.1580e-02, -3.7645e-01],\n",
" [ 3.2948e-01, 2.4525e-01, 3.1002e-02, ..., 1.4547e-03,\n",
" -2.0459e-01, -1.3566e-02]],\n",
"\n",
" [[ 2.4439e-02, -2.3092e-01, 1.1163e-02, ..., -3.4285e-01,\n",
" 2.7007e-01, -3.4211e-02],\n",
" [ 2.0095e-01, -4.9356e-01, 5.3058e-01, ..., -2.7157e-01,\n",
" 4.2807e-01, 3.2917e-01],\n",
" [-1.0993e-01, -4.1360e-01, 1.9816e-02, ..., -1.7917e-01,\n",
" 3.6033e-01, 2.2954e-01],\n",
" ...,\n",
" [ 4.2263e-02, 1.5875e-02, -3.0871e-01, ..., -3.1441e-01,\n",
" 2.9030e-01, 2.2213e-01],\n",
" [-4.9536e-02, 8.3578e-02, 7.2786e-02, ..., -2.5493e-01,\n",
" 4.7891e-02, 3.4251e-01],\n",
" [ 5.0301e-02, -1.8544e-01, 5.7551e-01, ..., -3.4349e-01,\n",
" 1.5927e-01, 4.2942e-01]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[-1.5217e-01, -1.1477e+00, 2.3295e-01, ..., -6.4279e-01,\n",
" -1.1349e-01, 4.0799e-02],\n",
" [ 4.5919e-01, -2.0374e+00, -7.9378e-01, ..., 4.4668e-02,\n",
" -8.8579e-01, -9.0097e-01],\n",
" [ 3.8866e-01, -1.6082e+00, -3.9608e-01, ..., 3.1908e-01,\n",
" -4.2160e-01, -1.1912e-01],\n",
" ...,\n",
" [-8.1627e-02, -1.0257e+00, -6.6449e-01, ..., 6.6261e-01,\n",
" -1.8242e-01, -5.9660e-02],\n",
" [ 9.9366e-01, -2.8990e+00, -4.2770e-01, ..., 1.5473e+00,\n",
" -2.7730e-01, 1.0212e+00],\n",
" [ 3.7402e-01, -1.2451e+00, -8.3321e-01, ..., 1.5307e+00,\n",
" -6.0831e-01, -1.0434e+00]],\n",
"\n",
" [[-5.0563e-01, 3.4884e-01, -4.0126e-01, ..., 1.2945e+00,\n",
" -5.5872e-01, -4.4031e-01],\n",
" [-1.0783e+00, -1.0583e+00, -8.7019e-01, ..., 9.3939e-01,\n",
" 6.1988e-01, -3.6133e-01],\n",
" [-1.4605e+00, 7.9834e-04, -1.6445e+00, ..., 8.5405e-01,\n",
" 1.1266e+00, 2.1244e-01],\n",
" ...,\n",
" [-1.7653e+00, -4.5490e-01, 5.8049e-01, ..., 1.3604e-01,\n",
" -2.6502e-01, 1.4497e+00],\n",
" [-2.7539e+00, -1.9189e+00, -6.1803e-01, ..., 2.3083e+00,\n",
" -6.2625e-01, -5.0954e-01],\n",
" [-8.4786e-01, -9.9176e-01, -1.4226e+00, ..., 1.0424e+00,\n",
" 1.2138e+00, -6.2367e-01]],\n",
"\n",
" [[ 1.3477e+00, 3.0343e+00, 3.7258e+00, ..., 6.1286e-01,\n",
" 1.7142e+00, -7.4960e-01],\n",
" [-3.4424e+00, 2.1578e+00, -3.4773e+00, ..., -1.7704e+00,\n",
" 3.4858e+00, 9.8086e-01],\n",
" [-3.3403e+00, 7.3066e-01, -4.6132e+00, ..., -3.2065e+00,\n",
" 5.3039e+00, 7.1677e-01],\n",
" ...,\n",
" [-4.8998e+00, -5.9784e-01, -2.9574e+00, ..., -4.1010e+00,\n",
" 2.4786e+00, 2.7664e-02],\n",
" [-3.3274e+00, -1.2454e+00, -5.1031e+00, ..., -3.2964e+00,\n",
" 3.3057e+00, 1.4853e+00],\n",
" [-4.2024e+00, -1.7287e+00, -5.1702e+00, ..., -2.7123e+00,\n",
" 2.8922e+00, 1.8391e+00]],\n",
"\n",
" ...,\n",
"\n",
" [[ 1.3818e+00, -2.7867e+00, -2.6519e+00, ..., 9.1555e-01,\n",
" 4.4077e-01, 2.7028e+00],\n",
" [-2.4026e+00, 1.6620e+00, -4.5219e-01, ..., 1.2064e-01,\n",
" -1.6484e+00, 5.6717e-01],\n",
" [-1.7379e+00, 2.8888e+00, 2.1535e-01, ..., -7.8397e-02,\n",
" -2.7045e+00, -3.0823e-03],\n",
" ...,\n",
" [-2.9426e+00, 3.5565e+00, 1.0280e+00, ..., -3.5420e-01,\n",
" -3.7917e+00, -7.8773e-01],\n",
" [-2.8640e+00, 2.8314e+00, 2.3865e+00, ..., -2.2468e+00,\n",
" -4.0705e+00, -1.2861e+00],\n",
" [-3.9137e+00, 4.3675e+00, 1.5171e+00, ..., -6.0161e-01,\n",
" -2.7414e+00, -1.2265e+00]],\n",
"\n",
" [[ 1.7415e+00, 4.5990e-01, 9.3163e-01, ..., 1.2650e-03,\n",
" -9.8961e-01, -2.9552e-01],\n",
" [ 2.2626e+00, 1.0377e+00, 1.1163e+00, ..., 3.4995e-01,\n",
" -2.5767e+00, -1.2164e+00],\n",
" [ 2.0896e+00, 6.8649e-01, 1.2068e+00, ..., 4.1762e-01,\n",
" -2.1005e+00, -1.2765e+00],\n",
" ...,\n",
" [ 1.8625e+00, 5.6272e-01, 1.1284e+00, ..., 3.5132e-01,\n",
" -2.0787e+00, -1.0202e+00],\n",
" [ 2.2705e+00, 3.2166e-01, 1.1907e+00, ..., 2.6156e-01,\n",
" -1.2966e+00, -9.9152e-01],\n",
" [ 2.3024e+00, 4.0813e-01, 9.6441e-01, ..., 4.9377e-01,\n",
" -2.5960e+00, -6.9144e-01]],\n",
"\n",
" [[-2.2407e-01, 1.4293e-01, -5.5406e-01, ..., 3.1676e-01,\n",
" 2.7494e-01, 1.6436e-01],\n",
" [-5.7508e-01, 6.1265e-01, -2.6713e-01, ..., 8.0278e-01,\n",
" 8.5041e-01, 1.8214e-01],\n",
" [ 6.2629e-01, 3.5029e-02, 8.6408e-02, ..., 4.6667e-01,\n",
" 1.6070e-01, 1.2988e-01],\n",
" ...,\n",
" [ 1.5542e-01, -2.5139e-01, -8.1318e-01, ..., 2.1838e-01,\n",
" 2.0266e-01, 6.9734e-01],\n",
" [-2.4867e-01, 4.2143e-01, -4.6590e-01, ..., 3.0348e-01,\n",
" 5.7653e-01, -5.7979e-01],\n",
" [-4.1779e-01, -4.9530e-01, -6.0749e-01, ..., 5.8660e-01,\n",
" 9.1405e-01, -3.4966e-02]]]], grad_fn=<PermuteBackward0>), tensor([[[[-1.5059e-02, -2.1934e-02, -1.3257e-01, ..., -3.3233e-03,\n",
" 5.6872e-03, -5.5921e-01],\n",
" [-4.4076e-01, 4.7031e-01, -2.1116e-01, ..., 5.7315e-01,\n",
" -3.8024e-01, 2.5338e-01],\n",
" [ 2.7640e-01, 1.0290e-01, -1.5030e-01, ..., 8.0443e-02,\n",
" -1.0340e-02, 6.5651e-01],\n",
" ...,\n",
" [ 7.7904e-01, 1.2082e+00, 3.0358e-01, ..., 4.4578e-01,\n",
" -4.0582e-02, 8.5044e-01],\n",
" [-2.0731e-01, -5.8119e-01, 4.1100e-01, ..., -1.7157e-01,\n",
" 2.8487e-01, 6.4911e-01],\n",
" [-8.6411e-01, 5.4967e-01, -4.1298e-01, ..., 9.2813e-01,\n",
" -4.2606e-01, -3.4161e-01]],\n",
"\n",
" [[ 3.8557e-02, 3.3662e-03, 5.4482e-02, ..., -5.7578e-02,\n",
" -7.4123e-02, 2.2392e-02],\n",
" [ 1.9386e-01, 1.8534e-01, 3.0680e-01, ..., -1.2764e-03,\n",
" -2.5348e-01, 8.6118e-02],\n",
" [-1.4242e-01, 3.2992e-01, 7.6395e-02, ..., 9.8633e-02,\n",
" -5.6915e-02, 4.4799e-02],\n",
" ...,\n",
" [-7.1944e-02, 3.8884e-02, 1.0161e-01, ..., -2.7253e-01,\n",
" 1.3398e-01, 1.1796e-01],\n",
" [-1.0896e+00, 2.1403e+00, -1.3890e-01, ..., 1.0035e+00,\n",
" 6.1333e-01, -1.1536e+00],\n",
" [ 6.1611e-02, 7.1527e-02, 2.0043e-01, ..., -3.5723e-01,\n",
" -1.4230e-01, 8.4502e-02]],\n",
"\n",
" [[ 1.1201e-02, -7.6654e-01, -1.1583e-02, ..., 4.3143e-02,\n",
" 1.5736e-02, -5.8100e-02],\n",
" [ 2.8462e-01, -1.0610e+00, 1.2486e-01, ..., 3.1588e-02,\n",
" -1.1913e-01, -4.8153e-02],\n",
" [ 2.6008e-01, -6.3008e-01, -8.1709e-01, ..., 1.8586e-01,\n",
" 3.4370e-01, 9.2477e-01],\n",
" ...,\n",
" [-1.9891e-01, -1.9001e+00, -4.4621e-02, ..., 7.8242e-02,\n",
" 2.2361e-02, 1.3589e-02],\n",
" [-2.8968e-01, -1.5899e+00, 9.2801e-02, ..., -2.7827e-01,\n",
" 1.6159e-01, -4.6007e-01],\n",
" [ 1.6971e-01, -1.5136e+00, 1.2845e-01, ..., -6.2768e-02,\n",
" -2.5769e-01, -1.5622e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[ 1.6522e-02, -7.7326e-02, 1.3163e+00, ..., -5.6423e-02,\n",
" 1.7141e-01, 2.1386e-02],\n",
" [-4.3988e-01, -2.9255e-01, 2.4116e+00, ..., -1.8846e-01,\n",
" 1.0912e-01, 1.4147e-01],\n",
" [ 2.3190e-01, -1.5369e-01, 2.5701e+00, ..., 6.3039e-01,\n",
" -1.0088e-01, 5.1586e-01],\n",
" ...,\n",
" [ 1.7250e-02, 6.7580e-01, 2.5971e+00, ..., -5.2273e-01,\n",
" 4.5050e-01, -6.9956e-01],\n",
" [-4.9545e-02, 6.1819e-01, 3.8825e-01, ..., -1.4691e-01,\n",
" 4.5526e-01, 7.1271e-01],\n",
" [-1.9639e-01, -1.2515e-01, 2.5813e+00, ..., -1.8536e-01,\n",
" -1.3485e-01, -8.7375e-02]],\n",
"\n",
" [[ 7.4395e-02, -8.7165e-02, -1.8260e-01, ..., 1.3185e-01,\n",
" 1.2575e-01, 1.7169e-01],\n",
" [ 6.5960e-01, 1.0117e+00, 7.1659e-01, ..., 8.3512e-02,\n",
" -6.5585e-01, -3.3111e-01],\n",
" [ 3.2666e-01, -1.2571e-01, 8.1719e-01, ..., 9.9527e-01,\n",
" -1.0291e+00, -5.0537e-01],\n",
" ...,\n",
" [-7.2666e-01, 1.0662e-01, -7.2195e-02, ..., -2.7005e-01,\n",
" 5.2628e-01, 2.3005e-01],\n",
" [-2.0959e-01, -2.3959e-01, -3.0772e-01, ..., 4.6964e-01,\n",
" -1.8979e-01, -2.7418e-01],\n",
" [ 1.7468e-01, 1.0415e+00, 7.5772e-01, ..., -4.9262e-01,\n",
" -8.0868e-01, 4.5074e-01]],\n",
"\n",
" [[ 1.1606e-02, 2.1828e-02, 2.7971e-02, ..., -3.3218e-02,\n",
" 2.2172e-01, -2.3344e-03],\n",
" [ 1.1778e-01, -3.0263e-01, 3.5408e-01, ..., -3.3052e-01,\n",
" -1.9086e+00, 4.3385e-01],\n",
" [-7.0245e-01, 4.2293e-02, -1.3216e-01, ..., 3.4737e-01,\n",
" -1.4905e+00, 3.5105e-01],\n",
" ...,\n",
" [ 2.1967e-01, -6.0979e-01, -6.8996e-01, ..., 4.4944e-01,\n",
" -1.9601e+00, -1.7819e-01],\n",
" [ 3.8903e-01, 1.9728e-01, -9.0256e-01, ..., 1.3781e-01,\n",
" -2.0059e+00, 3.0071e-01],\n",
" [ 5.9661e-01, -3.1890e-01, -2.2125e-01, ..., 2.8531e-01,\n",
" -1.8048e+00, 2.1086e-01]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[ 0.0532, -0.2197, 0.1445, ..., -0.8884, 0.7361, -1.2044],\n",
" [-0.6462, -0.7026, -1.4285, ..., 0.2179, -0.3014, 0.1623],\n",
" [-0.8909, -1.9166, -1.3314, ..., -1.8027, -2.7636, 2.9528],\n",
" ...,\n",
" [-1.9169, -0.2602, -0.2397, ..., -0.4901, -0.8816, 0.7061],\n",
" [-2.0792, 0.1064, 0.6011, ..., 0.5948, -0.5403, 1.4379],\n",
" [-0.4271, -0.4968, -0.0297, ..., 1.0395, -0.3829, 0.3067]],\n",
"\n",
" [[ 0.7842, 0.1905, 0.0089, ..., -0.1612, -1.0898, -0.1939],\n",
" [-1.3909, -1.5235, -0.5037, ..., 0.9582, 4.2044, 1.1825],\n",
" [ 0.1689, -1.8025, 0.8404, ..., 1.5177, 5.7815, 2.1470],\n",
" ...,\n",
" [ 1.2462, -1.4013, -1.2263, ..., 0.5912, 6.0711, 1.7328],\n",
" [ 1.4548, -2.0760, -2.0483, ..., -1.5971, 5.6172, 2.5548],\n",
" [-1.1053, -0.8554, -2.0471, ..., 0.8743, 6.2095, 1.1606]],\n",
"\n",
" [[ 0.3413, -0.3572, -0.3331, ..., 0.3294, 1.4604, 0.2755],\n",
" [ 0.0960, -6.2139, -0.6779, ..., -2.8446, -1.4388, -4.4836],\n",
" [-0.8714, -7.8835, -1.6969, ..., -2.1200, -2.1704, -7.2160],\n",
" ...,\n",
" [-3.2255, -7.0802, -1.8176, ..., -2.8620, -2.7388, -5.1880],\n",
" [-2.2788, -5.5723, -1.6649, ..., -3.3594, -2.4676, -5.1028],\n",
" [-2.9788, -7.2411, -1.0434, ..., -3.2540, -2.9263, -5.0116]],\n",
"\n",
" ...,\n",
"\n",
" [[ 0.2148, 1.7719, 0.5129, ..., 0.2612, 0.4477, -1.6895],\n",
" [-0.2874, -5.8026, 1.1293, ..., -2.2826, -1.7007, 5.5452],\n",
" [-2.4104, -6.5778, 1.1952, ..., -2.4193, -0.3969, 3.8159],\n",
" ...,\n",
" [-1.4026, -7.7514, 1.2659, ..., -3.4256, -2.3786, 6.9488],\n",
" [-1.0623, -5.7453, 0.1012, ..., -0.5622, -2.4292, 6.8565],\n",
" [-0.3079, -7.9204, 1.8029, ..., -3.2453, -2.3462, 7.0537]],\n",
"\n",
" [[ 0.0559, -0.0269, 0.1386, ..., -0.1165, -0.0882, -0.1612],\n",
" [ 0.1342, -0.5329, -0.2255, ..., -1.0159, 0.1003, -0.4600],\n",
" [-0.7412, -0.2755, 0.1787, ..., -0.8159, -0.9071, -0.1041],\n",
" ...,\n",
" [-0.0215, -0.5192, -0.2004, ..., 0.3272, -0.3216, 0.5758],\n",
" [ 0.2406, -0.3252, 0.3839, ..., -0.2115, 0.3593, -0.6457],\n",
" [-0.6898, -1.1861, -0.0238, ..., 0.5217, 0.0940, 0.9089]],\n",
"\n",
" [[ 0.3939, -0.0741, 1.9091, ..., -0.2314, -0.2112, -0.9825],\n",
" [ 2.5678, 1.8706, -2.0184, ..., 0.0582, 0.5182, 2.5282],\n",
" [ 3.1803, 2.0001, -2.9358, ..., 2.6552, 1.0590, 4.2195],\n",
" ...,\n",
" [ 2.6593, 1.2215, -2.5623, ..., 1.4338, 0.6112, 3.2894],\n",
" [ 1.1448, 0.9766, -2.1789, ..., 1.8788, 0.3242, 3.7226],\n",
" [ 2.8828, 1.7918, -3.5229, ..., -0.0936, 0.5881, 4.5368]]]],\n",
" grad_fn=<PermuteBackward0>), tensor([[[[ 4.5583e-02, 6.3733e-02, -3.4908e-03, ..., 3.8117e-03,\n",
" 1.0385e-01, 2.3468e-02],\n",
" [ 2.7193e-01, -2.7436e-01, 8.2051e-01, ..., -5.7602e-01,\n",
" -6.8246e-02, -1.7190e-02],\n",
" [-3.3605e-01, -9.2270e-01, -2.9339e-01, ..., 3.2747e-02,\n",
" 5.3266e-01, -1.1793e+00],\n",
" ...,\n",
" [ 4.6659e-01, 1.6959e-01, -6.8990e-02, ..., 3.2092e-01,\n",
" 2.9894e-02, 4.5212e-03],\n",
" [-3.8838e-01, -2.8303e-01, 6.4867e-01, ..., 5.4443e-01,\n",
" -3.8750e-03, -7.7317e-01],\n",
" [-1.4669e-01, -2.2234e-01, 5.0309e-01, ..., -2.0195e-01,\n",
" -3.4870e-02, 1.0260e+00]],\n",
"\n",
" [[-3.8964e-02, -8.6139e-03, 9.1636e-02, ..., -4.5061e-02,\n",
" -1.8257e-02, -4.4496e-02],\n",
" [-1.2398e-01, -4.6354e-01, 6.3162e-02, ..., 4.1472e-01,\n",
" -8.8383e-02, -6.1835e-02],\n",
" [ 2.3124e-01, -4.1944e-01, -5.5628e-02, ..., -6.5586e-01,\n",
" -2.9434e-01, 1.1322e-01],\n",
" ...,\n",
" [ 9.0615e-02, -2.5366e-01, -1.7453e-01, ..., 3.6981e-02,\n",
" 9.6252e-02, 2.8861e-01],\n",
" [ 2.6449e-01, -1.1997e+00, -2.9121e-01, ..., 1.8929e-01,\n",
" 8.9705e-01, 5.2265e-02],\n",
" [ 1.8653e-01, -4.1886e-01, -2.5386e-01, ..., 5.6907e-01,\n",
" -5.6461e-01, -2.9499e-01]],\n",
"\n",
" [[ 4.5761e-02, -1.1113e-01, -6.0327e-02, ..., -1.7311e-02,\n",
" 8.8352e-02, -1.4918e-01],\n",
" [ 3.5832e-01, 1.0048e-01, -3.5981e-01, ..., 4.7004e-01,\n",
" -1.0480e-01, -9.6169e-01],\n",
" [-1.2025e+00, -4.9562e-01, -5.6530e-01, ..., -7.7073e-02,\n",
" -1.8603e-01, 4.5677e-02],\n",
" ...,\n",
" [-1.1527e-01, -1.2046e-02, 7.9755e-01, ..., 2.0678e-01,\n",
" -1.6562e-01, -9.4135e-02],\n",
" [ 3.0203e-01, -5.3025e-02, 1.0025e-01, ..., -1.3117e-01,\n",
" -3.9940e-01, 2.0309e-01],\n",
" [ 5.4948e-01, -3.1714e-03, -9.9666e-01, ..., 3.6800e-01,\n",
" 2.6345e-01, -6.6638e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[-2.2513e-02, 1.1954e-01, -1.7875e-02, ..., -1.4198e-02,\n",
" 6.4433e-02, -5.2401e-02],\n",
" [-6.7643e-03, 1.3038e-01, -3.1770e-02, ..., -2.8075e-02,\n",
" -7.0123e-02, 2.9359e-01],\n",
" [ 7.8513e-01, -7.9053e-01, -1.5511e-01, ..., -3.0193e-01,\n",
" -5.3295e-02, 5.1889e-01],\n",
" ...,\n",
" [-1.9707e-01, 5.0177e-02, -1.1185e-01, ..., -3.0111e-01,\n",
" 2.1017e-01, -2.7775e-01],\n",
" [-3.1374e-01, -3.5912e-02, -2.5133e-01, ..., -1.2073e-01,\n",
" 1.3938e-01, -1.4568e-01],\n",
" [-1.2432e-01, 3.0442e-01, 1.0542e-01, ..., 2.1967e-02,\n",
" 3.2316e-02, 1.2676e-01]],\n",
"\n",
" [[-1.7366e-01, -1.3407e-01, -6.7815e-02, ..., -2.3521e-01,\n",
" -1.8675e-02, -5.1927e-02],\n",
" [ 4.8318e-01, -4.9988e-01, 7.3483e-01, ..., 1.7037e-01,\n",
" 6.2192e-01, 2.3596e-01],\n",
" [ 1.1730e-01, 3.0694e-02, 7.3273e-01, ..., 5.0575e-01,\n",
" 3.1356e-02, -5.0081e-01],\n",
" ...,\n",
" [ 6.1899e-01, -9.2282e-01, 1.6701e-01, ..., -2.4323e-02,\n",
" 1.7694e-01, -3.4102e-01],\n",
" [ 8.4867e-01, 1.2311e-01, 3.3463e-01, ..., 3.2204e-01,\n",
" 8.6678e-01, 5.9980e-01],\n",
" [ 4.1040e-01, 1.7545e-01, 2.0518e-01, ..., -9.3810e-01,\n",
" 4.8850e-01, -5.4087e-01]],\n",
"\n",
" [[ 1.1481e-01, -7.4767e-02, -2.5446e-02, ..., -1.8679e-02,\n",
" -9.1254e-02, -9.6947e-02],\n",
" [ 5.5079e-01, 1.9193e-01, 1.8251e-04, ..., -1.0992e-02,\n",
" -2.6968e-01, -3.8421e-02],\n",
" [-1.8607e-01, -8.5692e-02, 3.1742e-01, ..., -3.9823e-01,\n",
" 4.3919e-01, -8.0165e-02],\n",
" ...,\n",
" [-1.6626e-01, -1.0646e+00, -1.0149e-02, ..., -9.7871e-02,\n",
" 1.4443e-01, -1.5419e-01],\n",
" [-4.4313e-01, -1.3310e-01, 4.2125e-01, ..., 4.0301e-02,\n",
" -1.7659e-01, 3.1838e-01],\n",
" [ 8.1519e-01, 2.4844e-01, 1.2036e-01, ..., -9.9506e-02,\n",
" -2.9214e-01, 5.8580e-02]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[-8.8678e-01, -1.3593e-01, 3.3093e-01, ..., -9.5576e-01,\n",
" 2.5192e-02, -2.9464e+00],\n",
" [ 1.5584e+00, -5.3821e-01, -2.4421e+00, ..., -1.7774e+00,\n",
" -1.4069e+00, 7.2554e+00],\n",
" [-6.1613e-02, -9.6410e-01, -3.3367e+00, ..., -1.7228e+00,\n",
" -7.6467e+00, 6.5063e+00],\n",
" ...,\n",
" [ 1.2191e+00, -3.6478e-01, -1.8077e+00, ..., -1.4126e+00,\n",
" -3.4429e+00, 1.1099e+01],\n",
" [ 1.2810e+00, -4.1117e-01, -4.4152e+00, ..., -1.0298e+00,\n",
" -2.3506e+00, 1.1191e+01],\n",
" [ 1.5495e+00, -1.9605e+00, -3.1807e+00, ..., -9.8794e-01,\n",
" -2.1888e+00, 9.4760e+00]],\n",
"\n",
" [[ 3.7499e-01, -6.6046e-02, 4.5773e-01, ..., -1.2836e-01,\n",
" -7.7381e-02, -2.2161e+00],\n",
" [-1.9084e+00, -5.1770e-01, 3.3306e+00, ..., -1.0169e-01,\n",
" -2.0618e+00, 7.5854e+00],\n",
" [-3.1865e+00, -5.3798e-01, 3.4467e+00, ..., 8.8427e-02,\n",
" -4.1777e+00, 7.7792e+00],\n",
" ...,\n",
" [-2.9382e+00, -8.8965e-01, 3.4723e+00, ..., -1.4002e+00,\n",
" -5.7932e-01, 6.9011e+00],\n",
" [-3.7302e+00, -1.4835e+00, 7.7318e-01, ..., -1.4177e+00,\n",
" -1.5522e+00, 7.3279e+00],\n",
" [-2.4526e+00, -1.8321e+00, 3.6389e+00, ..., -4.4448e-01,\n",
" -1.6136e+00, 6.6650e+00]],\n",
"\n",
" [[ 1.2211e-01, -6.5015e-01, -2.2831e-01, ..., 1.4110e-01,\n",
" 2.7893e-01, -1.7424e-01],\n",
" [ 1.7771e-01, 1.7629e+00, 6.3257e-01, ..., -2.6582e-01,\n",
" 6.2577e-01, 5.0930e-02],\n",
" [ 2.2530e-01, 3.0012e+00, 5.3516e-01, ..., -3.2276e-01,\n",
" 5.9087e-01, -3.6453e-02],\n",
" ...,\n",
" [-6.4210e-01, 3.1597e+00, 2.3032e-01, ..., 6.4203e-01,\n",
" 1.9326e-01, 5.4560e-01],\n",
" [-4.8734e-01, 2.4240e+00, 1.1159e-01, ..., 9.6528e-01,\n",
" 1.2245e+00, -1.7901e+00],\n",
" [ 2.7319e-01, 2.8160e+00, 6.3444e-01, ..., -5.1675e-01,\n",
" -1.5301e-01, -8.1118e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[-4.0181e-01, 1.2737e-02, -1.1140e-02, ..., 1.2548e+00,\n",
" 4.3199e-02, 1.8033e+00],\n",
" [-4.6139e-01, -1.3921e+00, -1.4511e+00, ..., -2.5093e+00,\n",
" -1.6920e+00, -2.7131e-01],\n",
" [-1.3954e-01, 3.9872e-01, -5.5181e-01, ..., -4.0252e+00,\n",
" -1.2034e+00, -8.0604e-01],\n",
" ...,\n",
" [ 3.8913e-01, -9.2129e-01, 6.7512e-01, ..., -3.2734e+00,\n",
" -3.7855e-01, -1.2775e+00],\n",
" [ 3.6478e-01, 1.1098e+00, 1.9589e+00, ..., -1.2581e+00,\n",
" -9.2984e-01, -1.5476e+00],\n",
" [-2.0390e-01, -6.6112e-01, -9.6914e-01, ..., -3.2531e+00,\n",
" -3.5533e-01, -3.5020e-01]],\n",
"\n",
" [[-3.3790e-01, -1.2825e-01, 2.2242e-01, ..., 2.6358e-01,\n",
" -2.9314e-02, 3.1528e-02],\n",
" [-6.0304e-01, -1.1295e+00, 1.4573e+00, ..., 7.0224e-01,\n",
" -8.5480e-01, 1.8017e-01],\n",
" [ 9.3104e-01, -2.1456e+00, 3.8324e-01, ..., 9.3967e-01,\n",
" -8.2110e-01, 1.3123e-01],\n",
" ...,\n",
" [-7.5492e-01, -1.8400e-01, 3.3456e-01, ..., 1.7404e+00,\n",
" 7.1590e-01, 1.3268e+00],\n",
" [-1.5429e-01, 5.3506e-01, 2.4561e+00, ..., 1.2834e+00,\n",
" 5.7729e-01, 1.3149e+00],\n",
" [-7.7036e-01, -6.9287e-01, 1.1238e+00, ..., 1.0106e+00,\n",
" -5.3742e-01, 1.3852e+00]],\n",
"\n",
" [[ 3.4402e+00, 2.1226e+00, -2.1050e+00, ..., -2.8555e+00,\n",
" -3.9038e+00, -1.2060e+00],\n",
" [-3.0643e+00, -1.6132e+00, 4.7811e+00, ..., -2.6905e+00,\n",
" 9.4376e+00, -3.7636e+00],\n",
" [-2.8029e+00, 9.2815e-01, 2.2908e+00, ..., -3.5372e+00,\n",
" 9.2503e+00, 2.0644e+00],\n",
" ...,\n",
" [-4.9774e+00, -1.8169e+00, 4.4703e+00, ..., -4.3005e+00,\n",
" 1.5492e+01, 3.7749e+00],\n",
" [-2.4577e+00, -1.8796e+00, 6.0842e+00, ..., -4.6722e+00,\n",
" 9.1210e+00, 1.8122e+00],\n",
" [-4.4732e+00, -2.0733e+00, 7.2062e+00, ..., -3.7151e+00,\n",
" 1.2814e+01, -2.2193e+00]]]], grad_fn=<PermuteBackward0>), tensor([[[[-0.0028, -0.0602, 0.0219, ..., 0.0593, 0.0264, 0.0681],\n",
" [ 0.6692, -0.1774, 0.2994, ..., 0.1940, -0.3524, -0.1093],\n",
" [ 0.0441, -0.6776, -0.4458, ..., 0.2746, 0.9155, -0.5374],\n",
" ...,\n",
" [-0.3262, -0.0103, -0.0866, ..., 0.0454, -0.1561, 0.2205],\n",
" [-0.0552, -0.6212, -0.4492, ..., -0.2533, 0.0952, -0.2438],\n",
" [ 0.1740, 0.0146, -0.0917, ..., 0.1930, -0.1700, 0.1307]],\n",
"\n",
" [[-0.0538, -0.0195, -0.1417, ..., -0.0445, 0.0476, -0.0319],\n",
" [ 0.3175, -0.1990, -0.2276, ..., 0.1004, -0.0740, -0.1226],\n",
" [ 0.3296, -0.6555, -0.2850, ..., -0.8669, 0.2712, 0.0552],\n",
" ...,\n",
" [-0.0141, 0.1838, 0.2267, ..., 0.0249, -0.0362, 0.3883],\n",
" [-0.2939, -0.5590, 0.3243, ..., -0.0678, 0.0157, -0.5514],\n",
" [-0.0048, -0.0914, -0.2181, ..., -0.2868, 0.0018, -0.0651]],\n",
"\n",
" [[ 0.0639, 0.0961, 0.0831, ..., 0.0160, -0.0859, -0.0050],\n",
" [-0.8685, -0.1267, -0.8107, ..., 0.0526, -0.7176, -0.0689],\n",
" [ 0.1621, 0.2253, 0.0752, ..., 0.1041, -0.4005, 0.1818],\n",
" ...,\n",
" [-0.4981, 0.5339, -0.4980, ..., -0.2581, -0.8093, -0.3876],\n",
" [-0.6054, 1.6497, 1.0752, ..., -1.0363, 0.7149, -0.6451],\n",
" [-0.4706, 0.3250, 0.3061, ..., 0.4489, -0.6589, 0.0312]],\n",
"\n",
" ...,\n",
"\n",
" [[-0.0115, 0.0657, -0.0777, ..., 0.0440, 0.0456, -0.1384],\n",
" [ 0.0136, -0.3035, 0.8164, ..., -0.2084, -0.8236, 0.4428],\n",
" [ 0.2521, -0.4054, 0.2197, ..., 0.1480, 0.2216, 0.5164],\n",
" ...,\n",
" [-0.0778, -0.1247, -0.3227, ..., 0.1474, 0.1483, 0.3701],\n",
" [-0.3559, -0.8621, -0.0799, ..., -0.9994, 0.4109, 0.2198],\n",
" [-0.1967, 0.0573, 0.6049, ..., 0.1913, 0.0767, -0.0245]],\n",
"\n",
" [[-0.1315, -0.0534, 0.0947, ..., -0.0666, 0.0539, -0.0204],\n",
" [ 0.0918, -0.3386, -0.7173, ..., -0.2867, -0.0289, -0.1466],\n",
" [ 0.2971, 0.6579, -0.9279, ..., -0.0267, -1.3269, 0.6167],\n",
" ...,\n",
" [-0.1993, 0.8396, 0.5954, ..., -0.2100, 0.3891, 0.5287],\n",
" [ 1.5998, 0.6881, -0.2637, ..., 1.1610, 0.1208, -0.6552],\n",
" [ 0.5209, -0.3917, 0.1674, ..., -0.2824, 0.0700, -0.3138]],\n",
"\n",
" [[-0.0193, -0.0120, -0.0240, ..., -0.0300, 0.0080, -0.0136],\n",
" [-0.0432, -0.3667, -0.3346, ..., -0.1011, 0.0167, 0.1537],\n",
" [-0.3303, -0.6508, -0.2167, ..., -0.6360, -0.1999, -0.1340],\n",
" ...,\n",
" [-0.0058, -0.1530, -0.3235, ..., -0.3699, 0.0510, 0.1209],\n",
" [ 0.1009, 0.4467, -0.0791, ..., -0.2715, -0.2259, 0.5418],\n",
" [ 0.0141, 0.2831, -0.4868, ..., -0.1903, -0.1869, 0.7274]]]],\n",
" grad_fn=<PermuteBackward0>)), (tensor([[[[ 3.4512e-02, -2.8466e-01, 2.2210e-01, ..., 1.6982e+00,\n",
" -2.2029e-01, -8.0207e-02],\n",
" [ 1.1887e+00, 1.2521e+00, -2.0973e-02, ..., -2.7505e+00,\n",
" 1.1517e-01, -9.4738e-01],\n",
" [-3.7887e-01, 2.3147e-01, 7.8851e-01, ..., -3.8859e+00,\n",
" -1.2610e+00, -1.5381e+00],\n",
" ...,\n",
" [-3.8848e-01, -7.8692e-01, 6.0321e-01, ..., -1.5790e+00,\n",
" -4.4260e-01, -1.7360e+00],\n",
" [-1.3567e+00, -1.2212e-02, 4.0693e-01, ..., -2.6267e+00,\n",
" 3.1883e-01, -1.1768e+00],\n",
" [-1.3149e+00, 5.3910e-01, 8.4051e-01, ..., -2.6472e+00,\n",
" -8.0766e-02, -1.3063e+00]],\n",
"\n",
" [[ 1.5566e-01, 9.6884e-01, -1.4234e+00, ..., -1.1945e-01,\n",
" 2.6095e-01, 9.2861e-01],\n",
" [-1.1655e+00, -5.3317e+00, 7.2065e-01, ..., -1.4863e+00,\n",
" -2.2354e+00, -2.4988e+00],\n",
" [ 1.2192e+00, -4.3649e+00, 9.3857e-01, ..., 3.6005e-01,\n",
" -1.0827e+00, -2.1299e+00],\n",
" ...,\n",
" [-6.5003e-01, -3.6931e+00, 4.9255e-01, ..., -2.0790e+00,\n",
" -3.1514e-01, -2.7136e+00],\n",
" [ 5.9668e-01, -3.1527e+00, 7.6608e-01, ..., -4.4680e-01,\n",
" -1.1040e-01, -1.9393e+00],\n",
" [ 2.0418e+00, -5.3709e+00, 5.4901e+00, ..., -2.3439e-02,\n",
" 4.6572e-01, -3.8706e+00]],\n",
"\n",
" [[-6.7068e-01, 2.4994e-01, -5.6570e-02, ..., 1.7880e-01,\n",
" 5.6148e-02, -2.9901e-01],\n",
" [ 1.9676e+00, 2.9566e-02, -8.5660e-01, ..., -1.8619e+00,\n",
" -3.3802e-01, 1.6140e-01],\n",
" [ 2.1615e+00, -7.5559e-01, 3.4024e-01, ..., -1.4898e+00,\n",
" 4.2649e-01, 1.5977e+00],\n",
" ...,\n",
" [ 1.1094e+00, -8.7126e-01, 4.4787e-02, ..., -4.0946e-01,\n",
" -6.8646e-01, -5.1147e-01],\n",
" [ 1.3666e+00, -6.3472e-01, -6.9747e-01, ..., 6.0671e-01,\n",
" 2.1492e+00, -3.3250e-01],\n",
" [ 2.1474e+00, -4.8501e-02, -8.7131e-01, ..., -1.4417e+00,\n",
" 1.5616e+00, 1.8827e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[-4.0635e-02, 1.1188e-01, 1.4037e-01, ..., -9.7647e-02,\n",
" 1.2961e-02, 1.5000e-01],\n",
" [ 7.8730e-01, -5.6138e-01, -1.2585e+00, ..., 1.1703e+00,\n",
" -1.7229e-01, 1.2928e+00],\n",
" [-1.0394e-01, 1.4770e-01, 3.8454e-01, ..., 6.5685e-01,\n",
" -2.6355e-01, 1.3102e+00],\n",
" ...,\n",
" [ 1.1297e-01, 1.4229e+00, 2.8362e-01, ..., 9.3448e-01,\n",
" 2.5909e-01, 2.9945e-01],\n",
" [ 6.5882e-01, 7.3874e-01, -5.1318e-01, ..., 9.5171e-01,\n",
" 1.6892e-01, -1.7952e-01],\n",
" [ 4.4172e-01, 9.7651e-02, -1.4498e+00, ..., 1.2877e+00,\n",
" 7.8737e-01, 5.7300e-02]],\n",
"\n",
" [[-3.0020e+00, 4.0418e-01, -2.7798e-02, ..., -4.8566e-01,\n",
" -3.4500e-01, 1.2311e+00],\n",
" [ 4.8764e+00, 1.4500e+00, -1.1937e+00, ..., -1.6858e+00,\n",
" 3.0943e-01, -9.1063e-01],\n",
" [ 4.6146e+00, 9.6566e-01, -5.1178e-01, ..., -2.1980e-01,\n",
" 1.1130e+00, -1.2746e+00],\n",
" ...,\n",
" [ 4.9677e+00, 2.5583e-02, -1.3527e+00, ..., -1.8770e+00,\n",
" -6.6969e-01, -4.0065e-01],\n",
" [ 4.3137e+00, 1.0467e+00, -1.5161e+00, ..., -2.2238e+00,\n",
" -1.7302e-01, -1.6034e-01],\n",
" [ 4.6436e+00, 9.9926e-01, -5.2100e-01, ..., -1.5177e+00,\n",
" 1.9258e-01, -2.4487e-01]],\n",
"\n",
" [[-7.4442e-03, -2.5452e-01, -1.9922e-04, ..., -1.8494e-01,\n",
" 3.4208e-01, 9.0523e-02],\n",
" [ 8.8014e-01, -3.2005e+00, -2.3284e-01, ..., -5.6783e-01,\n",
" 5.3092e-01, 4.5332e-02],\n",
" [-3.2605e-01, -1.7599e+00, -5.3681e-01, ..., -5.2140e-01,\n",
" 1.7060e+00, -8.0691e-01],\n",
" ...,\n",
" [-1.1833e+00, -8.9443e-01, 5.9676e-01, ..., 3.0636e-01,\n",
" 5.0886e-01, -1.5048e+00],\n",
" [-1.2903e+00, -9.5492e-01, 2.1957e-01, ..., 2.2938e+00,\n",
" -5.0270e-01, -7.8764e-02],\n",
" [ 4.4758e-01, -1.5906e+00, 1.4957e-01, ..., 2.3779e+00,\n",
" -2.2358e-01, 4.7562e-01]]]], grad_fn=<PermuteBackward0>), tensor([[[[-0.0224, -0.0239, 0.0031, ..., -0.0027, -0.0209, 0.3535],\n",
" [ 1.1117, -0.0250, 1.0920, ..., -0.1345, -0.2322, -0.7385],\n",
" [ 1.2585, -0.5406, -0.8740, ..., 0.6211, 0.4854, -0.4785],\n",
" ...,\n",
" [ 0.8751, -0.3160, -0.5735, ..., -0.2102, -0.0831, -0.5934],\n",
" [ 0.0085, -0.3084, 0.1655, ..., 0.4398, 0.5114, -0.4383],\n",
" [-0.0725, -0.3939, 0.5899, ..., 0.7469, -0.3640, 0.0679]],\n",
"\n",
" [[ 0.0048, -0.0161, 0.0186, ..., -0.0150, 0.0150, 0.0090],\n",
" [ 0.4621, -0.6415, -0.2005, ..., 0.2446, 1.2697, -0.7838],\n",
" [ 0.6805, -1.2565, 0.0765, ..., -0.0242, 1.4869, 0.1836],\n",
" ...,\n",
" [-0.2538, -0.0022, 0.1847, ..., 0.4838, 1.5106, 0.7886],\n",
" [ 1.2671, -0.9662, -0.3248, ..., 0.5432, -0.0319, -0.1366],\n",
" [-0.1197, -1.6058, -0.3833, ..., 0.3964, 1.0133, -0.1477]],\n",
"\n",
" [[-0.0603, 0.0030, -0.0383, ..., -0.0468, 0.0119, -0.0780],\n",
" [ 0.5506, -0.3951, 0.6694, ..., -0.6748, 0.3026, 0.0286],\n",
" [ 0.4687, 0.1415, 0.0033, ..., 0.4084, 0.2910, 0.4103],\n",
" ...,\n",
" [ 0.4985, 0.4334, 0.3964, ..., -0.2184, -0.0373, -0.0717],\n",
" [ 0.0850, -0.4120, -0.2606, ..., -0.2593, 0.7614, -0.8139],\n",
" [ 0.0813, -0.2308, 0.9975, ..., -0.3412, -1.0508, -0.9304]],\n",
"\n",
" ...,\n",
"\n",
" [[-0.3063, -0.1904, -0.0540, ..., -0.4848, 0.2131, 0.1049],\n",
" [ 1.7674, -1.5249, 1.8613, ..., 1.4412, -0.3079, 0.1500],\n",
" [ 0.3092, -1.5787, 0.6095, ..., 0.5455, 0.1634, 1.3060],\n",
" ...,\n",
" [-1.9410, -1.8215, -0.4399, ..., 0.3221, -0.1979, -1.4136],\n",
" [ 1.2677, -1.9424, 0.0700, ..., -0.9788, -0.6381, -0.4399],\n",
" [ 0.8285, -1.8581, -0.3010, ..., -1.3209, 0.2318, -0.1750]],\n",
"\n",
" [[-0.0861, -0.1412, -0.0534, ..., -0.1797, -0.1466, 0.1142],\n",
" [-0.7113, -0.5252, -0.7349, ..., -0.0491, 0.5213, -0.7352],\n",
" [ 0.4967, -1.1247, -0.6529, ..., -0.4258, -0.1081, -0.2017],\n",
" ...,\n",
" [-0.4174, -1.3939, 0.0162, ..., -0.2306, -0.4274, 0.3158],\n",
" [-1.1609, -0.1209, -0.1991, ..., 1.2310, 0.5859, 0.6733],\n",
" [-1.1036, -0.5834, 0.1167, ..., 0.8276, -0.1767, 0.3441]],\n",
"\n",
" [[-0.0294, -0.0414, 0.1069, ..., 0.0614, -0.0412, 0.0239],\n",
" [-0.6127, -0.0583, -0.7644, ..., -1.4024, -0.9271, 0.9733],\n",
" [ 0.5288, 0.2919, 0.0434, ..., -0.4878, -0.6339, 0.4392],\n",
" ...,\n",
" [ 0.0769, -0.0123, -1.2272, ..., 0.3366, -0.2014, 0.2725],\n",
" [ 0.0642, 1.9300, -0.3253, ..., -1.0578, 0.4355, -1.4476],\n",
" [-0.5956, 0.2606, -0.5507, ..., -0.5284, -0.1602, -0.7526]]]],\n",
" grad_fn=<PermuteBackward0>)), (tensor([[[[-3.3938e-01, 8.6866e-01, -1.6351e-01, ..., 1.1267e+00,\n",
" -1.6784e-01, 1.2959e-01],\n",
" [-3.9849e-01, -4.8487e+00, -2.9290e-01, ..., -4.2862e+00,\n",
" 6.9337e-01, 6.4498e-01],\n",
" [ 7.8937e-01, -3.8062e+00, -7.5906e-01, ..., -4.4411e+00,\n",
" 9.3744e-01, 2.4774e+00],\n",
" ...,\n",
" [-2.9636e-01, -5.8016e+00, 1.4007e+00, ..., -3.1025e+00,\n",
" -2.7375e-01, 1.2819e+00],\n",
" [ 1.3025e+00, -3.8148e+00, 1.8926e+00, ..., -3.3508e+00,\n",
" 6.2647e-01, 5.1378e-01],\n",
" [-3.7934e-01, -4.5341e+00, 6.0715e-01, ..., -4.3161e+00,\n",
" 6.4808e-01, 7.9708e-01]],\n",
"\n",
" [[ 5.8217e-02, 8.7121e-01, -6.2251e-01, ..., -2.4310e-02,\n",
" 2.9330e-01, 1.3199e-02],\n",
" [ 1.3357e+00, -1.1482e+00, 1.2032e-01, ..., 1.5088e+00,\n",
" -1.0720e+00, -1.1527e+00],\n",
" [ 2.6168e+00, 9.3244e-03, -7.2926e-01, ..., 9.4531e-01,\n",
" -7.9178e-01, -1.6888e+00],\n",
" ...,\n",
" [ 8.0452e-01, -3.9176e-01, -3.0347e-01, ..., 1.3463e+00,\n",
" -3.1319e-01, -1.3556e-01],\n",
" [-7.1086e-01, 9.6997e-02, 1.2591e+00, ..., 2.0719e-01,\n",
" 4.2983e-01, -6.3391e-01],\n",
" [ 1.1039e-01, -1.3052e+00, 1.1124e-01, ..., 1.3074e+00,\n",
" 1.4712e+00, -2.7487e-01]],\n",
"\n",
" [[-3.1165e-01, 1.2165e-01, -9.8370e-01, ..., -3.5095e-01,\n",
" -6.3912e-02, -1.3616e-01],\n",
" [ 5.0049e-01, -6.6728e-01, 2.9285e+00, ..., -3.9263e-01,\n",
" 4.3198e-01, -2.3447e-01],\n",
" [ 1.2306e-01, -2.9766e-01, 3.6896e+00, ..., -1.0091e-01,\n",
" -2.5103e-01, -2.0315e-01],\n",
" ...,\n",
" [-2.1391e-01, -2.1547e+00, 2.8612e+00, ..., 5.8855e-01,\n",
" -1.9214e-01, 1.8883e+00],\n",
" [-2.1992e-01, -1.4360e+00, 3.3444e+00, ..., 9.8178e-01,\n",
" -1.9441e+00, 5.7364e-01],\n",
" [ 6.3090e-02, -1.4908e+00, 2.0854e+00, ..., 1.4157e-01,\n",
" -1.3972e-01, -6.9580e-02]],\n",
"\n",
" ...,\n",
"\n",
" [[ 3.7597e-01, 8.1398e-02, -6.4505e-02, ..., -4.8594e-02,\n",
" 2.2536e-01, 4.1931e-03],\n",
" [-1.2319e+00, 8.1079e-01, -5.4320e-01, ..., 1.2257e-01,\n",
" -7.8676e-02, -2.6823e-01],\n",
" [-1.9185e-02, 5.4915e-01, 9.4312e-01, ..., -2.6608e+00,\n",
" 3.8096e-01, -1.3816e+00],\n",
" ...,\n",
" [-1.5423e+00, -2.7545e-01, 2.9765e+00, ..., 5.4036e-01,\n",
" 1.6682e+00, -7.5562e-01],\n",
" [-1.2052e+00, -2.4065e-01, 4.7900e-02, ..., -1.5625e+00,\n",
" 2.8238e-01, -3.3910e-01],\n",
" [-1.7759e+00, 3.9760e-01, -1.0807e+00, ..., -1.9584e+00,\n",
" -1.1637e+00, 1.5918e+00]],\n",
"\n",
" [[ 2.0009e-01, 5.4941e-02, 3.2748e-01, ..., 4.1661e-01,\n",
" -3.4165e-03, 2.3171e-01],\n",
" [ 1.6163e+00, 1.2442e+00, 2.8373e-01, ..., -3.9689e-01,\n",
" 7.1320e-03, -1.1601e-01],\n",
" [ 1.3228e+00, 1.4674e-01, 6.3871e-01, ..., -5.9913e-02,\n",
" 1.6461e-01, 3.3509e-01],\n",
" ...,\n",
" [ 7.7162e-01, 7.9756e-01, 8.2908e-01, ..., -1.0911e+00,\n",
" 8.8888e-01, -1.1994e+00],\n",
" [ 1.6909e+00, 8.3524e-01, 6.7132e-01, ..., -1.1008e+00,\n",
" -7.2901e-01, 6.1303e-01],\n",
" [ 2.8334e+00, 4.6555e-01, 1.2473e+00, ..., -7.3844e-01,\n",
" -7.0963e-01, 1.0278e-01]],\n",
"\n",
" [[-3.0156e+00, 5.3756e-01, 5.6815e-01, ..., -9.3899e-01,\n",
" 3.2683e-01, 1.8463e-01],\n",
" [ 7.7879e+00, -9.7524e-01, -2.1850e+00, ..., 2.2429e+00,\n",
" -1.0887e+00, 5.6749e-01],\n",
" [ 6.9868e+00, -3.9651e-01, -9.7286e-01, ..., 1.0613e+00,\n",
" -7.0396e-01, 1.3823e+00],\n",
" ...,\n",
" [ 9.5361e+00, -8.7937e-01, -2.5252e+00, ..., 1.3820e+00,\n",
" -2.2409e+00, 2.4565e-01],\n",
" [ 8.8630e+00, -1.1387e+00, -1.7681e+00, ..., 1.0129e+00,\n",
" 2.0493e-01, -2.1170e-01],\n",
" [ 9.5194e+00, -2.0795e-01, -1.6476e+00, ..., 2.4340e+00,\n",
" -1.9197e+00, -2.9640e-01]]]], grad_fn=<PermuteBackward0>), tensor([[[[ 4.6474e-02, -5.0378e-02, 1.0945e-02, ..., -6.9955e-02,\n",
" 2.9789e-03, -1.0073e-01],\n",
" [ 5.1051e-01, -5.5772e-01, -3.8570e-01, ..., -3.2328e-01,\n",
" 2.3945e-01, -2.9826e-01],\n",
" [-2.8010e-01, 7.4962e-01, -5.4584e-01, ..., -3.6442e-01,\n",
" 4.2576e-01, -1.4805e+00],\n",
" ...,\n",
" [-9.0783e-01, -4.8128e-01, -1.8888e-01, ..., -2.2824e-01,\n",
" -7.4845e-02, -1.0972e+00],\n",
" [-5.0702e-01, 1.0603e-01, -1.0484e+00, ..., 5.5779e-01,\n",
" -4.9793e-01, -9.2837e-01],\n",
" [ 3.8714e-02, 4.2493e-01, -4.1890e-01, ..., 5.6050e-01,\n",
" -2.7279e-01, -1.3355e+00]],\n",
"\n",
" [[ 6.7674e-02, 3.0544e-02, -2.3115e-02, ..., -4.3823e-02,\n",
" 5.2575e-03, -1.6795e-03],\n",
" [-5.3669e-01, 1.7762e+00, -5.2043e-01, ..., 7.5157e-01,\n",
" -6.1868e-01, -7.3336e-01],\n",
" [-2.5054e-01, 4.9751e-03, -8.3214e-02, ..., -7.4598e-01,\n",
" -6.1617e-01, 3.3602e-01],\n",
" ...,\n",
" [-1.4067e-01, 5.9621e-01, 1.0898e+00, ..., 9.4066e-01,\n",
" -1.3745e+00, 1.1213e+00],\n",
" [-1.0630e+00, -5.0378e-01, 6.9651e-01, ..., -4.6445e-01,\n",
" -6.6259e-01, 1.7251e-01],\n",
" [-1.5972e+00, 3.2659e-01, 3.4644e-01, ..., 2.8986e-01,\n",
" -5.7299e-01, -2.2912e-01]],\n",
"\n",
" [[ 7.0952e-02, 8.2320e-03, -1.6572e-03, ..., 2.1678e-02,\n",
" -6.7437e-02, -5.0287e-02],\n",
" [ 7.4200e-01, -3.2418e-01, 4.1442e-01, ..., -1.4945e-02,\n",
" 2.5678e-01, 1.5392e-01],\n",
" [ 2.9304e-01, 5.7399e-01, -2.7184e-01, ..., -1.4044e-01,\n",
" 6.1588e-02, -1.5561e-01],\n",
" ...,\n",
" [ 7.1019e-01, -8.5043e-01, -3.1989e-01, ..., 2.5753e-01,\n",
" 2.2188e-01, 7.3108e-01],\n",
" [ 7.1561e-01, -8.6057e-01, 9.2320e-01, ..., 3.9957e-01,\n",
" 2.4226e+00, 1.6563e+00],\n",
" [-7.6132e-02, 2.4041e-01, 9.3365e-01, ..., -2.2613e-01,\n",
" 3.9552e-01, 1.0165e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[-8.4018e-04, 4.2945e-02, 2.0029e-02, ..., -6.6209e-02,\n",
" -1.8070e-02, 2.2869e-02],\n",
" [-1.4168e+00, 2.7825e-01, 3.5415e-02, ..., 2.2794e-01,\n",
" -1.8244e-01, 2.6631e-01],\n",
" [-1.5832e+00, 6.7589e-01, -1.3738e-01, ..., 7.5377e-01,\n",
" -8.9247e-01, 8.4118e-01],\n",
" ...,\n",
" [-1.0343e+00, 2.2096e-01, 1.8098e-01, ..., 1.5064e+00,\n",
" -9.4570e-01, -9.6457e-01],\n",
" [-5.5192e-01, 6.5732e-01, -7.3323e-01, ..., 8.2586e-01,\n",
" 1.0773e+00, -5.0690e-01],\n",
" [-6.9760e-01, -2.0758e-01, 2.9526e-01, ..., -1.6063e-02,\n",
" 1.6516e-02, 4.3263e-01]],\n",
"\n",
" [[ 5.6418e-02, -6.3642e-03, 2.3703e-02, ..., 1.7139e-02,\n",
" -1.5312e-02, 6.8112e-03],\n",
" [ 1.8381e+00, -1.3941e+00, -1.0189e+00, ..., -9.4177e-01,\n",
" 4.2883e-01, 8.2570e-01],\n",
" [ 8.8893e-01, -1.6692e+00, -4.3398e-01, ..., -1.2906e+00,\n",
" 1.0952e-01, 3.7169e-01],\n",
" ...,\n",
" [ 7.4024e-01, -1.4955e-01, -8.9148e-01, ..., -1.0267e+00,\n",
" -6.1569e-01, 5.8172e-01],\n",
" [-7.3008e-01, -4.7314e-01, 3.7697e-01, ..., 5.2418e-01,\n",
" -1.6633e-01, 3.0198e-01],\n",
" [ 6.6411e-02, -4.8074e-01, -4.0598e-01, ..., 1.1196e-01,\n",
" 1.0054e+00, -4.4949e-01]],\n",
"\n",
" [[ 6.7269e-02, -2.0375e-01, -7.5082e-02, ..., -4.0162e-02,\n",
" 1.9610e-01, -5.1942e-02],\n",
" [ 3.7243e-01, -9.5645e-01, -3.3796e-01, ..., -9.8523e-01,\n",
" -4.3307e-01, -2.3109e-01],\n",
" [-5.5909e-01, -9.8741e-01, -8.3997e-01, ..., -4.0350e-02,\n",
" 2.2590e-03, -1.1709e+00],\n",
" ...,\n",
" [ 2.6116e-01, -1.7003e+00, 9.9667e-03, ..., 2.5269e-01,\n",
" -6.5086e-01, -5.0987e-01],\n",
" [-2.2483e-01, -3.8567e-01, -1.6472e-01, ..., -7.8707e-01,\n",
" 3.2198e-01, -4.2609e-01],\n",
" [-1.5893e-01, -7.3543e-01, -4.9369e-01, ..., -1.5504e+00,\n",
" -3.8277e-01, -4.1377e-01]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[ 1.0315, -0.2591, -0.1501, ..., 0.6246, 0.7232, -0.3059],\n",
" [-3.2631, -1.9435, -1.2089, ..., 0.2806, -5.5510, 0.3489],\n",
" [-4.3399, -1.4013, -0.2186, ..., 0.7109, -5.7860, -0.6970],\n",
" ...,\n",
" [-5.3760, -3.7959, -0.3718, ..., 0.7804, -4.6301, -1.2402],\n",
" [-6.1584, -2.8255, 0.0772, ..., 0.3908, -4.4567, -0.1920],\n",
" [-4.6862, -2.2484, -0.4802, ..., 1.1911, -4.6985, -1.0555]],\n",
"\n",
" [[-0.1525, -0.0745, 0.1651, ..., -0.0464, -0.8761, -0.1921],\n",
" [-0.6493, 1.0436, -0.2845, ..., -0.2628, 0.0537, 0.8063],\n",
" [-1.7708, 1.3885, -0.9440, ..., 0.3637, 0.7435, 1.4247],\n",
" ...,\n",
" [-2.0241, 0.3328, -0.2828, ..., 0.8545, 0.5231, 2.4687],\n",
" [-2.8308, -0.2631, -0.4617, ..., -0.3337, 1.8320, 2.9475],\n",
" [-2.0453, 1.1846, -2.5580, ..., 0.5495, 1.1092, 1.8249]],\n",
"\n",
" [[ 0.1958, 0.3039, 1.1389, ..., -0.4691, 0.4513, -0.4878],\n",
" [-1.4815, -0.5524, -1.6846, ..., -0.5676, -1.8434, 2.4752],\n",
" [-3.5171, -1.7341, -1.0781, ..., -0.0126, -0.8584, 2.8363],\n",
" ...,\n",
" [-1.2945, -1.0943, -0.7373, ..., 0.2280, -2.9008, 2.5152],\n",
" [-2.2796, -0.5816, -0.3174, ..., 0.7422, -1.4116, 2.2355],\n",
" [-0.7958, 0.1943, -2.7152, ..., 1.7208, -1.5123, 0.9313]],\n",
"\n",
" ...,\n",
"\n",
" [[-0.1518, 0.0675, -0.2341, ..., 0.0125, 0.1685, 0.0227],\n",
" [-1.5855, 0.1968, -0.4700, ..., 0.9270, -1.3281, -0.1941],\n",
" [ 0.7663, -0.7921, -0.5326, ..., 0.9606, -0.0650, -0.2843],\n",
" ...,\n",
" [ 0.1914, -0.1551, -0.9815, ..., 1.8034, 0.1310, 0.7172],\n",
" [-1.2788, -1.7422, -0.4975, ..., 1.3406, -0.4531, -0.5256],\n",
" [-1.8526, 0.1496, -0.0816, ..., 0.8122, -1.0543, 0.1050]],\n",
"\n",
" [[-0.3515, -2.1836, 0.1103, ..., -0.0873, -0.0481, 0.9174],\n",
" [-0.3931, 1.7304, -1.0893, ..., -1.0898, -1.7984, 1.0287],\n",
" [ 0.3552, 3.3603, -1.5929, ..., -0.7109, -1.5203, 0.7090],\n",
" ...,\n",
" [ 0.4526, 3.6483, -3.1344, ..., 1.3756, -1.8511, 2.2068],\n",
" [ 1.4022, 2.2589, -2.0330, ..., 0.3515, -0.4796, 0.9019],\n",
" [ 0.7568, 2.8114, -2.1562, ..., 1.3476, -0.3658, 0.7552]],\n",
"\n",
" [[ 0.3682, 0.0657, -0.1320, ..., 0.6454, 0.1343, 0.2644],\n",
" [-0.8896, 0.3677, 0.1631, ..., -0.3916, -0.4439, -0.9719],\n",
" [ 0.4470, 0.5271, -0.4635, ..., -0.6886, -1.2558, 0.0390],\n",
" ...,\n",
" [-1.7867, -2.2049, 2.1719, ..., -0.8210, -0.2084, 1.6132],\n",
" [-1.4884, -1.5097, 0.1562, ..., 0.5166, 0.2819, -0.1415],\n",
" [-2.1183, 1.1049, 1.0999, ..., -0.3114, 0.2994, 0.8749]]]],\n",
" grad_fn=<PermuteBackward0>), tensor([[[[-2.8336e-02, 4.5010e-02, -5.7978e-02, ..., -1.9222e-02,\n",
" 1.2577e-02, 3.6269e-02],\n",
" [-4.6357e-01, 2.2473e-01, -5.6911e-01, ..., -1.1179e-01,\n",
" 5.4963e-01, 1.4621e-01],\n",
" [ 8.8791e-01, 1.3920e-01, -1.1074e+00, ..., -3.0970e-01,\n",
" 8.6369e-01, 2.8616e-01],\n",
" ...,\n",
" [ 9.8967e-01, 7.6810e-02, 3.6725e-01, ..., 1.0289e-01,\n",
" -7.9780e-01, -1.0472e-01],\n",
" [ 4.1472e-01, -3.0706e-01, -1.0118e-01, ..., -2.9164e-02,\n",
" 9.2894e-02, 2.6503e-01],\n",
" [-4.1391e-01, -4.3953e-01, 9.5461e-02, ..., -1.8622e-02,\n",
" 1.2946e-01, -4.0387e-01]],\n",
"\n",
" [[ 2.1536e-02, -2.8120e-02, 3.8532e-02, ..., 2.1765e-02,\n",
" -4.7212e-02, 5.3255e-03],\n",
" [-1.6248e-01, -4.5659e-01, -4.4525e-02, ..., 5.6903e-01,\n",
" -3.0144e-01, -1.2120e+00],\n",
" [-1.6019e-01, -3.1593e-01, 1.0682e+00, ..., -1.1746e-01,\n",
" -4.8418e-01, 4.2423e-01],\n",
" ...,\n",
" [-7.0670e-01, 1.4226e-01, -2.0767e-01, ..., -5.3785e-01,\n",
" -3.7916e-01, 2.9476e-01],\n",
" [ 3.5204e-01, 1.6746e-01, -1.8197e+00, ..., 1.8833e-01,\n",
" 2.5200e-01, 1.3326e+00],\n",
" [ 1.0614e-01, -5.6477e-01, -1.3717e+00, ..., 2.8329e-01,\n",
" -2.3432e-01, 5.8129e-01]],\n",
"\n",
" [[ 3.9084e-02, -2.6990e-02, 5.6189e-02, ..., 2.6549e-02,\n",
" -7.1806e-03, 1.9065e-02],\n",
" [ 8.1593e-01, 3.5473e-01, -1.9476e-01, ..., 7.1779e-01,\n",
" 1.7158e-01, 1.7037e-01],\n",
" [-3.0468e-01, 6.4740e-01, -1.1535e+00, ..., 2.5107e+00,\n",
" -1.3214e+00, 6.0931e-01],\n",
" ...,\n",
" [-3.8012e-01, -1.0693e+00, -4.3163e-01, ..., -1.2006e-01,\n",
" -4.7626e-01, -5.9241e-01],\n",
" [-6.6220e-01, 1.0321e+00, 6.1114e-01, ..., -1.0294e+00,\n",
" -5.9746e-02, -1.4874e+00],\n",
" [ 1.5239e+00, 1.7266e-01, -2.6497e-01, ..., -6.9278e-01,\n",
" 2.7154e-01, 1.1508e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[-1.8650e-01, 8.9365e-02, 5.7435e-02, ..., 4.6573e-02,\n",
" 3.7369e-02, -1.2676e-01],\n",
" [-6.2920e-01, -4.5253e-02, 1.5379e-01, ..., -8.5838e-01,\n",
" 2.2210e-01, -4.9222e-01],\n",
" [-2.1227e-01, 6.7216e-01, 5.8456e-01, ..., -4.8421e-02,\n",
" -4.2428e-01, -4.8305e-01],\n",
" ...,\n",
" [-9.4783e-01, -4.8206e-02, -1.2836e-01, ..., 1.8181e-01,\n",
" -4.6491e-01, -8.4671e-01],\n",
" [-7.2088e-01, 4.8839e-01, -1.6034e+00, ..., -3.5454e-01,\n",
" 8.5080e-02, -1.4271e+00],\n",
" [-1.0528e+00, 8.3454e-01, -9.8252e-01, ..., 1.1729e-01,\n",
" -1.4640e-01, -1.9143e+00]],\n",
"\n",
" [[-5.8489e-01, -4.5877e-03, 4.4912e-02, ..., -2.0796e-02,\n",
" 6.2989e-03, -6.4938e-03],\n",
" [-1.6445e+00, 4.2511e-02, -3.1403e-01, ..., -3.7935e-01,\n",
" 2.3561e-01, 5.9496e-02],\n",
" [-2.5505e+00, -2.0482e-01, -3.6240e-01, ..., -3.0201e-01,\n",
" -4.2028e-01, -1.8376e-02],\n",
" ...,\n",
" [-1.6757e+00, 4.2658e-01, -9.1740e-01, ..., 2.0202e-01,\n",
" 5.2352e-01, 3.1575e-01],\n",
" [-1.7608e+00, 5.6837e-01, 3.5225e-01, ..., 5.5874e-01,\n",
" -6.9264e-01, -1.8256e-01],\n",
" [-2.3731e+00, -2.8098e-01, 3.9676e-01, ..., -2.5406e-01,\n",
" 4.8834e-01, -6.1031e-01]],\n",
"\n",
" [[ 1.5471e-03, 8.2456e-02, -4.7513e-02, ..., 5.5853e-02,\n",
" 3.0368e-02, -4.6994e-02],\n",
" [-5.5504e-01, 7.3400e-01, -2.0816e-01, ..., -1.2824e-01,\n",
" 3.8586e-01, 8.0331e-01],\n",
" [ 6.3713e-01, 1.6547e+00, 2.6059e-01, ..., -1.1861e+00,\n",
" 6.3198e-01, -1.3541e-01],\n",
" ...,\n",
" [ 4.7463e-01, 1.1477e+00, 6.0258e-02, ..., -4.6058e-01,\n",
" -3.5489e-01, 7.9365e-02],\n",
" [ 8.1016e-02, -1.3944e-01, 4.1258e-01, ..., 1.1060e-01,\n",
" -2.8541e+00, 4.1492e-01],\n",
" [-1.2963e+00, 2.2384e-01, -2.4338e-01, ..., 2.2294e-01,\n",
" 1.0918e-01, 2.1425e+00]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[-0.0268, -2.3398, 0.1634, ..., -0.2365, -0.1944, 0.0645],\n",
" [-0.2799, 4.0417, -0.1959, ..., -0.1126, -0.3356, -0.5690],\n",
" [-0.0094, 4.9294, 0.5373, ..., -0.1404, -0.6815, 0.3025],\n",
" ...,\n",
" [ 0.0091, 3.8780, 0.7991, ..., -0.5989, 0.7071, 0.5137],\n",
" [-0.6218, 4.7662, -0.4088, ..., -0.8925, -0.0737, 0.7395],\n",
" [-0.8638, 5.1069, -0.1012, ..., -0.0097, 0.0632, -0.7295]],\n",
"\n",
" [[-0.8140, 0.2218, 0.4656, ..., -0.5189, 1.0732, 1.1234],\n",
" [ 0.2364, 0.2685, 1.0541, ..., 0.5500, 1.3914, 0.4962],\n",
" [ 1.5259, 1.0305, -0.6830, ..., -0.3595, 0.8213, -0.1596],\n",
" ...,\n",
" [-0.8893, 0.6401, 1.5340, ..., -0.3154, 0.9969, 0.1131],\n",
" [-1.4240, -0.5673, -0.9037, ..., -0.0334, 2.1567, -0.3555],\n",
" [-2.3116, 1.4069, 0.2116, ..., 0.7944, 2.6708, 0.1778]],\n",
"\n",
" [[-0.8504, 0.4700, 0.0232, ..., 0.4955, -0.2356, 1.1518],\n",
" [ 0.6655, -0.1374, 1.1604, ..., 0.2494, 1.0734, -0.9082],\n",
" [ 2.0262, 0.3311, 0.5329, ..., 0.2746, 0.6484, -1.2565],\n",
" ...,\n",
" [ 0.8666, 0.2080, 0.7423, ..., -0.0590, 0.7947, 0.2077],\n",
" [ 1.3274, -0.5878, 1.5562, ..., 1.2727, 0.8958, -0.8393],\n",
" [ 0.6793, -0.9115, 2.1432, ..., 1.5571, 1.7428, -0.3943]],\n",
"\n",
" ...,\n",
"\n",
" [[-0.3102, -0.1292, 0.1523, ..., 0.1793, 1.7438, -2.8696],\n",
" [ 1.2379, -0.5238, 0.3674, ..., -0.3042, -4.6049, 4.6856],\n",
" [ 0.6856, 0.3973, 0.9211, ..., -0.6994, -5.2863, 4.9465],\n",
" ...,\n",
" [-0.5532, -0.4212, 1.0728, ..., 0.4562, -5.7176, 5.1979],\n",
" [-0.9992, -1.4073, -0.8534, ..., 0.8452, -5.9484, 4.2105],\n",
" [ 0.1935, -1.2555, 1.2355, ..., -0.0070, -6.0872, 5.8807]],\n",
"\n",
" [[ 0.1957, 0.3617, 0.2155, ..., -0.2170, 0.0182, -0.1540],\n",
" [-0.6359, -0.7831, -0.5938, ..., 1.0413, -0.4280, 0.6407],\n",
" [-0.6033, -1.0964, -0.2818, ..., 0.2840, -0.2947, 0.6149],\n",
" ...,\n",
" [-0.2907, 0.0759, 0.5673, ..., 1.1031, -0.7398, 0.1992],\n",
" [-0.3487, -0.1916, 1.1144, ..., 0.6085, 0.1949, 1.1279],\n",
" [-1.1693, -0.8894, 0.6257, ..., 1.4145, -1.2843, 0.4372]],\n",
"\n",
" [[ 0.3722, 0.0987, 0.6134, ..., 0.5249, 0.5746, -0.3289],\n",
" [ 0.7276, -0.7879, -1.5108, ..., -1.7654, -3.2146, 0.1771],\n",
" [ 0.6286, -1.0423, -1.3390, ..., -2.0023, -2.7540, -0.0532],\n",
" ...,\n",
" [ 1.2008, -1.0047, -2.2047, ..., -2.5210, -4.7543, 1.0585],\n",
" [ 0.1571, -1.0960, -1.7899, ..., -3.0896, -4.1969, -0.4143],\n",
" [ 1.1982, -1.3326, -1.5329, ..., -1.6822, -4.4774, -0.5948]]]],\n",
" grad_fn=<PermuteBackward0>), tensor([[[[ 0.0512, -0.0071, -0.0159, ..., 0.1308, -0.0604, -0.0383],\n",
" [ 0.0743, 0.0098, 0.8985, ..., 0.3322, 1.0163, -0.2279],\n",
" [ 0.2299, -1.0595, -0.2036, ..., 0.4071, 1.0309, 0.7073],\n",
" ...,\n",
" [ 0.9942, 0.0985, -0.3045, ..., 0.3595, 1.2762, 0.1312],\n",
" [-1.1851, 0.1872, 2.5162, ..., -0.4091, -0.5504, -0.3313],\n",
" [ 0.4573, -0.2495, 1.1492, ..., 0.3916, 0.3092, -0.2549]],\n",
"\n",
" [[ 0.0178, 0.0383, 0.0396, ..., 0.0060, -0.0180, 0.0108],\n",
" [ 0.3077, 0.2800, -1.2484, ..., 0.1144, -0.0260, -0.6417],\n",
" [ 0.8365, 0.1942, -2.6429, ..., 1.4839, -2.4390, -1.1518],\n",
" ...,\n",
" [-1.0152, -1.3838, 0.4507, ..., 0.2284, 0.2643, 0.3901],\n",
" [-1.8002, -1.5104, -0.6286, ..., 1.0451, 0.2438, -0.3518],\n",
" [-0.4032, -0.3529, -1.6265, ..., 0.5828, 0.5720, -1.2572]],\n",
"\n",
" [[ 0.0495, -0.0389, 0.0613, ..., 0.0561, -0.0711, -0.0673],\n",
" [-0.6686, 1.1461, -0.4798, ..., 0.1773, 0.4573, 0.4967],\n",
" [-0.3811, 0.8968, -0.6061, ..., 0.0926, 0.3056, 0.9180],\n",
" ...,\n",
" [-0.3757, -0.0510, 0.0062, ..., 0.6064, 0.7972, 0.7227],\n",
" [-0.2685, -0.7850, 0.7441, ..., -0.8875, -0.0677, 1.0534],\n",
" [-0.7876, 1.0096, -0.0108, ..., -0.9138, -0.1195, -0.2942]],\n",
"\n",
" ...,\n",
"\n",
" [[-0.1028, -0.0452, 0.0346, ..., -0.0871, 0.0427, 0.0092],\n",
" [ 0.5169, 0.0966, 0.2483, ..., -0.4591, 0.3724, 0.6674],\n",
" [ 0.9085, 0.9305, -0.0286, ..., -0.8769, -0.3911, 0.3594],\n",
" ...,\n",
" [-0.0673, -0.2202, -0.2051, ..., 0.2041, -0.4487, 1.0220],\n",
" [-0.2218, -0.4037, 1.4038, ..., 1.5332, -1.2336, 0.4163],\n",
" [ 0.8637, -1.0940, 0.2482, ..., 0.3983, -1.4612, 0.6188]],\n",
"\n",
" [[ 0.1576, -0.0522, 0.1510, ..., 0.0776, 0.0389, -0.1486],\n",
" [-0.0612, 1.4222, 1.2901, ..., 1.0537, 1.9877, -1.2965],\n",
" [ 0.0701, 1.0599, 1.3164, ..., 1.8434, 1.7597, -0.8641],\n",
" ...,\n",
" [-0.0791, 0.1802, -0.2036, ..., 0.6063, 1.2652, 0.1763],\n",
" [ 0.4001, 1.6460, 1.1749, ..., -0.6267, 2.3732, -0.3538],\n",
" [ 0.2739, 1.4950, 0.8300, ..., 1.1957, 1.5808, -1.0777]],\n",
"\n",
" [[ 0.2067, -0.0439, -0.0680, ..., 0.0390, 0.0473, 0.0275],\n",
" [-0.6717, 0.2561, 0.7676, ..., -0.2872, -0.5916, -0.1957],\n",
" [-0.9239, 0.0464, 0.4365, ..., 0.6006, -0.4989, 0.7633],\n",
" ...,\n",
" [ 0.5723, 0.0787, 0.7033, ..., 0.3464, -0.7811, 1.3074],\n",
" [-0.8109, -0.4612, -1.6027, ..., -1.6367, -0.0065, -0.7756],\n",
" [-1.3609, 0.5702, 0.7531, ..., -0.1462, 0.1355, 0.2370]]]],\n",
" grad_fn=<PermuteBackward0>)), (tensor([[[[ 0.0436, -0.2509, -0.4550, ..., 0.3116, 0.3358, 0.3770],\n",
" [ 0.3000, -0.0478, -1.2507, ..., 0.3519, 0.7682, 0.7296],\n",
" [ 0.5719, -1.2229, -1.5268, ..., 0.2386, 0.1496, 0.8318],\n",
" ...,\n",
" [-0.7869, -1.6832, -0.6862, ..., 0.5248, -1.1760, -0.6061],\n",
" [ 1.1739, 0.7271, -1.4276, ..., 1.1409, -1.3880, -0.6762],\n",
" [ 1.7515, -0.1609, -0.0345, ..., 0.9718, -0.5132, 1.4921]],\n",
"\n",
" [[-0.2801, 0.1559, 0.1167, ..., 0.0214, -1.1384, -0.1501],\n",
" [-0.7174, -0.1200, -0.7961, ..., -0.4121, 0.7157, 0.5868],\n",
" [ 1.5377, 0.1651, -0.9257, ..., 0.3588, 1.3888, 0.1633],\n",
" ...,\n",
" [ 0.2032, 0.5659, -0.9297, ..., -1.1580, -1.0870, 1.0748],\n",
" [-0.0984, 1.5501, -1.2118, ..., -1.0350, 0.6500, 0.8747],\n",
" [-1.1498, 0.8479, -0.9318, ..., -1.2515, 0.5937, 0.4393]],\n",
"\n",
" [[-1.2411, -0.0878, 0.5490, ..., -0.6611, 0.4539, -0.2888],\n",
" [ 0.6556, 1.0735, -0.5900, ..., 0.0895, -0.3484, -0.2450],\n",
" [ 0.3530, 0.0116, 0.0702, ..., 0.7262, -1.4991, -0.5028],\n",
" ...,\n",
" [ 0.6693, 0.8831, -0.7045, ..., 1.2413, 0.0528, 0.1498],\n",
" [ 1.5144, 1.9988, -1.8167, ..., 1.0272, -0.5508, -0.2781],\n",
" [-0.2976, 1.1260, -1.6873, ..., 1.3365, -0.2020, -0.3461]],\n",
"\n",
" ...,\n",
"\n",
" [[ 0.7973, -0.8987, -0.3939, ..., -1.0369, -0.4123, 0.4803],\n",
" [ 1.4616, 0.0408, -1.0295, ..., 0.7219, 0.3444, -0.0145],\n",
" [ 0.2550, -1.1764, -0.3335, ..., 0.8036, 1.7228, -2.3128],\n",
" ...,\n",
" [ 0.6038, -0.3213, -0.9128, ..., 1.7723, 0.7332, -1.3456],\n",
" [ 1.5292, 0.8308, -1.5665, ..., 1.7068, 0.6255, -1.4453],\n",
" [ 2.1459, -0.1321, -0.5784, ..., 1.8690, 1.6415, 0.8508]],\n",
"\n",
" [[-0.9151, 2.5785, 0.3082, ..., 0.3579, 1.9421, -0.5408],\n",
" [-0.0171, -2.5663, 0.7328, ..., 0.3923, -4.1463, 1.9012],\n",
" [ 1.0185, -2.5828, -1.5448, ..., 1.0508, -4.9451, 1.7123],\n",
" ...,\n",
" [ 1.1092, -2.5339, 0.2730, ..., -0.9127, -3.6883, -0.9762],\n",
" [ 0.7417, -1.7092, 0.4430, ..., 0.6517, -4.0859, -0.6250],\n",
" [ 0.6957, -4.4839, -0.4944, ..., 1.2733, -5.0460, 2.7409]],\n",
"\n",
" [[-2.0221, -0.3681, -1.1042, ..., -0.3983, 0.0527, 0.2442],\n",
" [ 1.4425, 0.4368, 0.8613, ..., 1.2344, -0.1098, 0.1759],\n",
" [ 2.6873, -0.5718, 0.7670, ..., 1.4859, -0.9973, 1.5824],\n",
" ...,\n",
" [ 2.3481, -0.2267, 0.4736, ..., 1.0791, 0.1695, -0.6822],\n",
" [ 1.8121, 0.8181, 1.5002, ..., 1.3897, -1.1112, -0.6512],\n",
" [ 2.0455, 0.8276, 1.0394, ..., 1.7555, -0.0730, -0.0210]]]],\n",
" grad_fn=<PermuteBackward0>), tensor([[[[-4.8632e-02, -9.1396e-02, 3.1682e-02, ..., 1.2261e-01,\n",
" -3.6255e-02, 1.2526e-02],\n",
" [-3.1696e-01, 2.4710e-01, -5.5300e-02, ..., 2.4286e-02,\n",
" -3.1162e-01, 2.2300e-01],\n",
" [ 1.9266e-01, 5.7364e-01, 5.6620e-01, ..., 6.3398e-01,\n",
" -1.6994e-01, -3.2943e-01],\n",
" ...,\n",
" [ 8.0219e-01, -6.2467e-02, 7.5092e-01, ..., -2.6152e-01,\n",
" 6.4908e-01, 9.1121e-01],\n",
" [ 2.9819e-01, -1.1154e+00, 5.7111e-01, ..., -1.1155e+00,\n",
" 5.0150e-01, 3.6634e-01],\n",
" [ 7.2844e-01, 4.1041e-01, 6.7296e-01, ..., 2.8859e-01,\n",
" -9.5357e-01, 4.9752e-01]],\n",
"\n",
" [[ 1.1928e-02, 1.3112e-02, -2.6053e-02, ..., 4.6390e-02,\n",
" 2.8720e-02, 5.6897e-02],\n",
" [ 5.1804e-01, -8.6756e-03, 3.4240e-01, ..., -9.3518e-01,\n",
" -2.8230e-02, -1.6108e-01],\n",
" [-6.5553e-01, -1.4296e-01, 6.3211e-01, ..., -2.3726e+00,\n",
" -1.0325e+00, 1.1180e+00],\n",
" ...,\n",
" [-2.7697e-01, 4.7694e-01, 9.3078e-01, ..., -1.4985e-02,\n",
" -9.5630e-01, -1.0057e+00],\n",
" [-3.5304e-01, 7.6668e-01, -7.3687e-01, ..., 8.2464e-01,\n",
" 6.1313e-01, 1.4616e-01],\n",
" [ 6.2543e-02, 9.5850e-01, 9.9546e-02, ..., -4.1675e-01,\n",
" -3.1019e-01, 2.1785e-02]],\n",
"\n",
" [[ 3.9431e-02, 3.2304e-02, -6.9643e-02, ..., 3.1842e-03,\n",
" 1.5391e-02, 8.6383e-03],\n",
" [-4.5218e-02, 3.8015e-01, -7.4175e-03, ..., -8.6065e-02,\n",
" 1.9510e-01, 2.4301e-02],\n",
" [ 1.0227e+00, 7.7004e-02, 7.1903e-02, ..., 1.1994e+00,\n",
" 1.6976e-01, -4.0066e-01],\n",
" ...,\n",
" [ 1.1771e+00, 2.4422e-01, 7.0662e-01, ..., 1.1337e+00,\n",
" -8.5384e-01, -9.9605e-01],\n",
" [ 4.0196e-01, 3.7700e-01, 1.0244e+00, ..., -2.4000e-01,\n",
" -2.2166e-03, -8.7664e-01],\n",
" [ 6.3016e-01, 1.0653e-01, 6.7085e-01, ..., 1.8561e-01,\n",
" -1.0484e+00, -2.8506e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[-2.8633e-02, 2.3521e-02, -1.3071e-02, ..., -3.3836e-02,\n",
" -4.1805e-02, 1.6132e-02],\n",
" [ 2.9617e-01, 2.5753e-01, 6.4459e-01, ..., 6.7883e-01,\n",
" -1.1170e-01, 3.4354e-01],\n",
" [ 1.4840e-01, -2.1638e-01, 1.5988e-01, ..., 3.0029e-01,\n",
" -1.7462e+00, 2.2010e+00],\n",
" ...,\n",
" [-1.0735e-01, -2.7973e-01, 1.7696e-01, ..., 1.2454e-01,\n",
" 1.6533e+00, 4.6311e-02],\n",
" [-2.5303e-01, -5.3346e-01, -7.0970e-01, ..., 3.3254e-01,\n",
" -1.0337e-01, -1.5011e+00],\n",
" [ 9.4744e-01, 4.1239e-01, -1.0214e-01, ..., 1.0832e+00,\n",
" 1.1939e+00, 2.1364e-01]],\n",
"\n",
" [[-7.3958e-02, -4.4124e-02, 1.7760e-02, ..., 3.1321e-03,\n",
" -4.5881e-02, -1.0916e-01],\n",
" [-4.6492e-01, -6.5992e-02, -4.8427e-02, ..., 2.7765e-01,\n",
" 1.7094e-01, -2.1020e-01],\n",
" [-9.3265e-01, -1.7024e+00, 1.1011e-01, ..., -6.0777e-01,\n",
" 2.7326e-01, -1.2374e+00],\n",
" ...,\n",
" [-1.0394e-01, 3.4447e-02, -1.4004e+00, ..., 1.9303e-01,\n",
" -1.2038e+00, 5.6969e-01],\n",
" [-9.6140e-01, 5.8390e-01, -5.3376e-01, ..., 3.5307e-01,\n",
" 3.7874e-01, -4.8008e-02],\n",
" [-6.0081e-02, -5.1836e-01, -9.0043e-02, ..., 2.2977e-01,\n",
" -1.1964e-01, -4.6107e-01]],\n",
"\n",
" [[-7.7909e-03, 4.0206e-02, -5.6468e-02, ..., -3.0341e-02,\n",
" 2.4338e-02, 5.3261e-03],\n",
" [-2.4073e-01, -1.4607e-01, 6.8568e-01, ..., -9.4289e-01,\n",
" -1.0285e+00, -7.2268e-01],\n",
" [-8.9161e-01, 3.2033e-01, 2.2241e-01, ..., 7.4783e-01,\n",
" -1.8553e-01, -1.4143e+00],\n",
" ...,\n",
" [ 5.0678e-01, -8.7200e-01, 1.3745e+00, ..., 4.7279e-01,\n",
" 1.9468e-01, -3.0692e-01],\n",
" [ 9.3960e-02, -1.1271e+00, -3.2356e-01, ..., 4.6166e-01,\n",
" 1.1812e+00, -7.4736e-02],\n",
" [-1.3350e-01, -7.4492e-01, 7.1189e-01, ..., -1.8032e-01,\n",
" -1.5200e+00, -9.0480e-01]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[-0.5266, 0.4829, -0.8581, ..., -1.0129, -1.3275, 0.2040],\n",
" [ 0.6896, 0.9769, 0.7884, ..., 2.3917, -0.4978, -1.2931],\n",
" [ 0.4849, 0.2801, 2.0546, ..., 2.4861, 0.4516, -2.1291],\n",
" ...,\n",
" [ 1.1704, -0.3403, 1.9815, ..., 2.4192, 1.2849, -0.9075],\n",
" [ 0.5567, 0.0731, 0.2333, ..., 1.9754, -0.6718, 0.4945],\n",
" [ 1.1157, 0.0910, 1.9513, ..., 2.0806, -0.6777, -1.8277]],\n",
"\n",
" [[ 0.8645, -2.0846, 0.1532, ..., 0.2459, -2.4906, -0.4514],\n",
" [ 1.1494, 1.2483, -0.0495, ..., 0.1813, 0.8199, -0.5313],\n",
" [ 1.4952, 1.4661, -1.3266, ..., 0.6351, 0.5419, 0.3732],\n",
" ...,\n",
" [ 1.6627, 3.2121, -1.1410, ..., -0.1081, 2.2876, -1.0492],\n",
" [ 0.8725, 3.7180, -0.8677, ..., 0.5521, 0.0537, -2.0911],\n",
" [ 1.8427, 2.8496, -0.9180, ..., -0.0876, 2.9764, -1.0137]],\n",
"\n",
" [[ 1.0175, 0.3871, -0.1741, ..., -0.8094, -1.4149, -0.3730],\n",
" [-0.2747, 0.4294, -0.8148, ..., 0.7997, -1.0098, -0.2083],\n",
" [-0.1443, 0.1837, -0.6903, ..., 2.3234, -0.5142, -1.1581],\n",
" ...,\n",
" [-1.7553, -0.7940, -1.4744, ..., 1.9563, -0.3079, 0.2517],\n",
" [-1.1555, -0.9816, -1.4792, ..., 2.4893, -0.8572, 0.6439],\n",
" [ 0.2061, 0.6956, -1.2343, ..., 1.2946, -0.7649, -1.0596]],\n",
"\n",
" ...,\n",
"\n",
" [[ 0.2137, -0.5814, 0.4917, ..., -0.6758, 1.0594, 0.2809],\n",
" [-0.0787, 1.1178, -0.9665, ..., -2.9838, -0.0755, 0.3358],\n",
" [-0.3467, 0.6547, -1.9701, ..., -2.6404, -1.7759, 0.1484],\n",
" ...,\n",
" [-0.6088, 0.2404, -1.0831, ..., -2.5044, -0.5236, 0.2501],\n",
" [-1.2627, -0.4007, 0.0159, ..., -2.2715, -1.9617, 0.1351],\n",
" [-0.4629, 0.4004, -1.0877, ..., -3.2533, -0.1876, -0.2612]],\n",
"\n",
" [[ 0.2431, 0.5528, 0.5439, ..., 0.7452, 0.0856, 0.8468],\n",
" [ 0.3639, 2.4237, 0.9672, ..., 0.7770, -0.7330, 0.4097],\n",
" [-0.4982, 1.9386, -0.1103, ..., 1.4543, -0.3265, 0.4745],\n",
" ...,\n",
" [ 0.2191, 1.5633, -0.4826, ..., -0.9138, -0.7183, 0.2929],\n",
" [-2.4011, -0.7274, -0.1691, ..., 0.5614, -0.1154, 2.1418],\n",
" [ 1.8710, 2.7152, 0.3026, ..., 0.4339, -1.6067, 0.4278]],\n",
"\n",
" [[-0.7092, 0.3125, -1.6205, ..., -0.4008, 0.2350, -1.3048],\n",
" [ 0.0382, 0.8210, -1.6851, ..., 1.5476, 1.1133, -1.3639],\n",
" [-1.2253, -0.0602, -3.1185, ..., -0.4857, 1.8382, 0.9552],\n",
" ...,\n",
" [-3.0265, -0.1628, -0.6678, ..., 1.2046, 1.1136, -0.6637],\n",
" [-2.2680, -0.1403, -1.6040, ..., 0.0642, 0.6752, -0.0818],\n",
" [-0.6567, -0.4737, -1.8665, ..., 1.7928, 1.7230, -1.4443]]]],\n",
" grad_fn=<PermuteBackward0>), tensor([[[[ 3.4694e-03, 5.2180e-02, -7.1138e-02, ..., 5.9633e-02,\n",
" -5.0955e-02, -7.4279e-02],\n",
" [ 9.5441e-02, -8.2082e-04, 3.7786e-01, ..., -7.9814e-01,\n",
" -3.4941e-01, 5.3955e-01],\n",
" [-1.7982e-01, 1.0793e+00, 8.4480e-01, ..., -5.6335e-01,\n",
" 4.7423e-01, -1.3511e-01],\n",
" ...,\n",
" [-8.4265e-01, -2.8338e-02, -8.4992e-01, ..., -8.6247e-01,\n",
" 6.9610e-01, -1.6560e-01],\n",
" [-5.8879e-02, 7.0560e-01, 8.3837e-01, ..., -4.8124e-01,\n",
" -1.7102e+00, 4.3793e-01],\n",
" [ 4.7365e-02, 2.9308e-01, 2.9819e-01, ..., -3.3441e-01,\n",
" -6.1577e-01, 5.0968e-01]],\n",
"\n",
" [[ 3.7719e-02, 1.2977e-04, 4.9038e-02, ..., -3.9138e-02,\n",
" -2.6473e-02, -1.4142e-02],\n",
" [-1.9622e-01, 4.4304e-01, 9.3501e-02, ..., 8.6879e-01,\n",
" 5.8439e-01, 6.5467e-01],\n",
" [-8.1397e-01, -1.4557e+00, -3.4408e-01, ..., 1.0143e+00,\n",
" 1.6014e-01, -7.6486e-01],\n",
" ...,\n",
" [ 4.8213e-01, 1.1956e+00, -6.7466e-01, ..., 4.4558e-03,\n",
" -6.0745e-01, 1.5004e-01],\n",
" [-9.2434e-01, -9.9667e-02, -1.7371e-01, ..., 3.3668e-01,\n",
" 3.7452e-01, 9.1399e-01],\n",
" [-8.0525e-01, 2.7367e-01, 2.7182e-01, ..., 1.5725e+00,\n",
" 1.8934e-01, 9.1494e-01]],\n",
"\n",
" [[-6.9176e-03, 1.8243e-02, -3.3975e-02, ..., 8.5669e-03,\n",
" 2.7227e-02, 5.8461e-02],\n",
" [-2.8638e-01, 4.4393e-02, -2.4720e-01, ..., 5.8055e-01,\n",
" -1.1038e+00, -3.1214e-01],\n",
" [-4.1151e-02, 4.7980e-01, -8.1177e-01, ..., 2.5263e+00,\n",
" -6.2052e-01, -4.0801e-01],\n",
" ...,\n",
" [-6.4285e-01, 2.1790e-01, 7.1201e-01, ..., 7.6857e-01,\n",
" 1.9746e-02, -1.2292e-02],\n",
" [ 4.3683e-01, -2.0561e-01, 5.6170e-01, ..., -1.3195e+00,\n",
" -6.0955e-01, 8.5465e-01],\n",
" [-5.0826e-02, 2.0641e-01, 2.1014e-01, ..., -6.1202e-01,\n",
" -3.7409e-01, 5.8607e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[-7.8410e-02, 2.6667e-02, 1.1429e-02, ..., -3.5996e-02,\n",
" -7.8381e-03, 1.2273e-03],\n",
" [-2.4955e-01, 3.0179e-01, 2.2439e-01, ..., -7.0245e-01,\n",
" -4.7259e-01, -1.2154e-01],\n",
" [-1.1360e+00, 4.8186e-01, 6.9660e-01, ..., 5.2388e-02,\n",
" 4.9656e-01, -7.1202e-01],\n",
" ...,\n",
" [-2.2132e-01, 2.5862e-01, 5.4504e-01, ..., 6.4937e-01,\n",
" -1.4201e-01, -6.9701e-02],\n",
" [-9.8322e-01, -1.1579e-01, 1.4461e+00, ..., 4.0303e-01,\n",
" -8.9281e-01, 9.6826e-01],\n",
" [-1.2535e+00, 7.9669e-01, 2.3864e+00, ..., -1.1996e+00,\n",
" -1.2942e-02, 1.5757e+00]],\n",
"\n",
" [[ 7.4373e-02, -1.0839e-03, 4.7472e-02, ..., 2.5576e-02,\n",
" 5.5578e-02, 3.0725e-02],\n",
" [ 1.6561e-01, 1.1326e+00, 1.1021e+00, ..., -6.7084e-02,\n",
" 1.0625e+00, -7.9841e-01],\n",
" [-1.1934e+00, 1.3455e+00, 7.5402e-01, ..., 3.0290e+00,\n",
" 1.9807e+00, -1.6143e-01],\n",
" ...,\n",
" [ 8.7119e-01, 1.6007e+00, 9.8724e-01, ..., 6.2297e-01,\n",
" 9.5836e-01, -6.7591e-02],\n",
" [ 5.6550e-01, 7.5545e-01, -9.4622e-01, ..., 3.9639e-01,\n",
" -1.3479e-01, 1.4511e-01],\n",
" [-2.5438e-01, 1.3767e+00, 1.5838e+00, ..., 5.7618e-02,\n",
" 1.7279e+00, -7.0514e-01]],\n",
"\n",
" [[-1.1214e-01, 2.7461e-02, -6.8169e-02, ..., -8.8035e-02,\n",
" 7.2290e-02, -2.1984e-02],\n",
" [-1.5667e-01, 3.3572e-01, -2.9793e-01, ..., -3.2849e-01,\n",
" -6.0364e-02, 8.4579e-02],\n",
" [-3.0011e-01, 3.6599e-01, 4.1995e-01, ..., -5.6659e-01,\n",
" 1.6448e-01, 2.1300e-01],\n",
" ...,\n",
" [-5.3965e-01, 8.5568e-01, -1.0334e+00, ..., 1.6571e+00,\n",
" 1.2634e+00, 1.2663e-02],\n",
" [-1.1969e+00, 1.1998e-01, 7.4285e-01, ..., 2.5529e+00,\n",
" 2.4390e+00, -3.4413e-01],\n",
" [-7.4888e-01, 4.5366e-01, -1.2199e+00, ..., -4.5325e-01,\n",
" 4.5486e-01, 9.2945e-01]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[-1.7115, -0.3095, -0.3052, ..., 0.1569, 0.3295, -0.5102],\n",
" [-0.2954, -0.5369, -1.0378, ..., 1.1073, -0.8305, 0.1360],\n",
" [ 0.3707, -0.1892, -0.0631, ..., -0.2839, -0.1220, 0.7498],\n",
" ...,\n",
" [ 0.1536, 0.6239, -0.8671, ..., 0.7387, 0.5640, 0.1241],\n",
" [ 2.3072, -0.4214, -0.9284, ..., 1.0017, 0.4802, 0.5843],\n",
" [ 0.1397, 0.0057, -1.1350, ..., 1.1576, -1.0530, 0.4489]],\n",
"\n",
" [[ 0.1108, -0.0705, 2.3025, ..., 0.2415, 0.0896, -0.1951],\n",
" [ 0.7120, -1.0674, -0.9315, ..., 0.1600, 0.2404, -0.4726],\n",
" [ 1.4893, -0.9425, -1.5101, ..., -0.3591, 0.0335, 0.4421],\n",
" ...,\n",
" [ 0.1341, -0.0506, -1.3000, ..., -0.1105, -0.2529, 0.7670],\n",
" [-0.5547, -0.6913, -1.2921, ..., 0.2898, -0.2538, 0.9526],\n",
" [ 0.2336, -0.8297, -0.8416, ..., -0.0980, -0.2919, 0.7454]],\n",
"\n",
" [[-0.2041, 1.0503, 0.4759, ..., -0.5452, 0.3040, -0.1147],\n",
" [-0.4822, 0.5302, -0.5850, ..., 1.5575, 0.2531, -0.3998],\n",
" [-0.5382, 0.5110, 0.0939, ..., 0.6880, -0.2115, -0.7376],\n",
" ...,\n",
" [-0.3389, 0.6598, -0.4937, ..., 0.4478, -0.8530, -0.5765],\n",
" [-0.7257, 1.3182, 0.9792, ..., 1.7382, 0.5813, -1.0117],\n",
" [-0.9058, 1.0543, -0.6513, ..., 2.0381, -0.3831, 0.2467]],\n",
"\n",
" ...,\n",
"\n",
" [[ 0.5577, 0.9790, -0.8716, ..., -0.7184, 0.7245, 0.8611],\n",
" [ 0.3537, 0.1875, -0.5812, ..., -0.5651, -0.0493, 0.0897],\n",
" [ 0.4359, -0.7858, -0.9179, ..., -1.4713, 0.2309, -0.2758],\n",
" ...,\n",
" [ 0.9503, -0.6978, 0.7306, ..., -0.7847, -0.9335, -0.8081],\n",
" [ 0.1232, 1.2112, 1.0973, ..., 0.2584, 1.1175, -0.0057],\n",
" [-0.0062, -0.2483, 0.2463, ..., -0.3165, 0.3718, -0.2848]],\n",
"\n",
" [[-0.4014, 0.3733, 0.3393, ..., 0.7212, 0.0451, -0.0838],\n",
" [-0.5197, 1.3345, -1.5982, ..., 0.5380, -0.2475, -0.9776],\n",
" [ 0.0646, 0.0452, -0.4746, ..., 0.9874, -0.8139, 0.1726],\n",
" ...,\n",
" [-0.6432, -0.4941, 0.4357, ..., 0.5838, -1.3339, -0.0826],\n",
" [-0.9413, -1.2357, -0.4911, ..., 1.3679, -1.0148, -1.4263],\n",
" [-0.9545, 0.2418, -1.5970, ..., 0.3238, -0.9107, -0.7229]],\n",
"\n",
" [[-0.7459, -0.0075, 0.4400, ..., -0.1109, 0.0299, -0.0598],\n",
" [-0.2137, 0.3865, 1.1712, ..., 0.4425, -0.3584, 1.2832],\n",
" [ 0.5957, 0.1015, -0.1897, ..., 0.4039, -1.3808, 1.2112],\n",
" ...,\n",
" [ 0.7437, -1.3902, 0.2656, ..., 0.9423, -1.2780, 1.6726],\n",
" [-0.7614, 0.3624, 1.4484, ..., 0.2220, -1.0658, 1.0444],\n",
" [-0.4612, 0.8413, 1.7939, ..., 0.1289, -0.8518, 1.1819]]]],\n",
" grad_fn=<PermuteBackward0>), tensor([[[[ 7.7570e-02, -1.1777e-01, -1.6829e-01, ..., -3.0139e-01,\n",
" 2.8640e-01, -1.7741e-01],\n",
" [ 3.7056e-01, 6.3164e-01, 7.7662e-01, ..., 2.7495e+00,\n",
" -1.5492e+00, 1.1155e+00],\n",
" [ 1.6399e+00, 1.3236e+00, 5.0145e-01, ..., 2.7380e+00,\n",
" -2.5362e+00, 2.0660e+00],\n",
" ...,\n",
" [-2.2210e-03, 1.2618e-01, 1.9018e-01, ..., 2.5907e+00,\n",
" -1.5682e+00, 7.7443e-01],\n",
" [ 1.0661e+00, 1.8362e-01, 9.9011e-01, ..., 1.7970e+00,\n",
" -1.8210e-01, -7.9636e-01],\n",
" [ 1.1176e+00, 9.5490e-01, 4.2716e-01, ..., 2.4762e+00,\n",
" -1.8121e+00, 1.6125e+00]],\n",
"\n",
" [[ 1.0853e-01, -1.0814e-02, 5.5897e-02, ..., -9.3695e-03,\n",
" -8.4395e-02, 1.6578e-01],\n",
" [ 9.0288e-02, 4.3214e-01, 7.7907e-02, ..., 3.6511e-01,\n",
" 4.1462e-01, -3.7498e-01],\n",
" [ 4.8901e-02, 1.1972e+00, -1.0267e-01, ..., -2.4577e-01,\n",
" 3.2252e-01, 9.5713e-02],\n",
" ...,\n",
" [ 1.4289e+00, -4.0081e-01, 8.8847e-01, ..., -1.2688e-01,\n",
" -2.1349e-01, -1.5179e+00],\n",
" [-1.8024e-01, -5.9997e-01, 1.6811e+00, ..., 8.8114e-01,\n",
" -1.2796e+00, 8.0612e-01],\n",
" [ 3.5363e-01, 1.5338e-01, 1.0489e-01, ..., 7.1419e-01,\n",
" -2.5939e-01, 1.1640e-01]],\n",
"\n",
" [[-1.3536e-02, 2.5633e-02, -3.8610e-02, ..., 4.7447e-02,\n",
" 4.5465e-04, 7.3786e-02],\n",
" [ 3.7973e-01, -2.6919e-01, -4.5875e-01, ..., -1.4160e-01,\n",
" 3.0695e-01, -4.8341e-01],\n",
" [ 1.1969e+00, 1.2378e+00, -6.2153e-01, ..., -9.3299e-01,\n",
" 5.5717e-02, -2.5939e-02],\n",
" ...,\n",
" [ 1.0509e+00, -6.8117e-01, -5.0678e-01, ..., -5.8349e-01,\n",
" 1.6390e-01, -4.4167e-01],\n",
" [-5.3312e-01, 6.3160e-01, 2.2554e-01, ..., -1.1507e+00,\n",
" 6.4968e-01, 3.7368e-01],\n",
" [ 2.3626e-01, -1.7837e-01, 2.7653e-01, ..., -8.8951e-02,\n",
" -3.4488e-02, -6.5983e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[-3.0262e-02, -1.2759e-02, 8.2024e-02, ..., 4.1477e-02,\n",
" -3.4039e-02, 1.6534e-02],\n",
" [ 7.0146e-02, 3.9249e-01, 3.6694e-02, ..., 1.1981e-01,\n",
" -3.4416e-01, -1.2740e-01],\n",
" [-1.4357e+00, -8.1313e-01, 3.6240e-01, ..., 6.4624e-01,\n",
" -7.6324e-01, 1.4873e+00],\n",
" ...,\n",
" [ 1.0556e-01, -3.8366e-01, 1.2748e+00, ..., -3.6558e-01,\n",
" 4.0858e-01, 2.4199e-01],\n",
" [-2.5444e-01, 1.1958e+00, -1.7147e-01, ..., 6.1984e-01,\n",
" -2.2845e-01, -1.8110e+00],\n",
" [ 3.2427e-01, 8.9915e-01, 1.1141e+00, ..., 6.8071e-01,\n",
" -3.4533e-01, -1.7910e-01]],\n",
"\n",
" [[-1.8907e-01, -6.5480e-02, 7.6243e-02, ..., -5.9887e-02,\n",
" 5.6530e-02, -7.3080e-02],\n",
" [-8.2506e-01, -3.6656e-02, 4.9222e-01, ..., 2.5220e-01,\n",
" 3.1897e-01, 1.9113e-01],\n",
" [-4.6517e-01, -2.1911e-01, -6.4030e-01, ..., 7.2280e-01,\n",
" 7.5668e-01, 5.6131e-01],\n",
" ...,\n",
" [-1.0660e+00, -4.2479e-01, -5.0573e-01, ..., -5.8658e-02,\n",
" -6.6094e-02, -4.4752e-01],\n",
" [-8.6907e-02, 1.2486e-04, -5.2314e-01, ..., 1.1544e-01,\n",
" 4.3831e-01, -1.0179e-02],\n",
" [-1.0669e+00, -7.1475e-01, 8.0158e-01, ..., -1.1919e-01,\n",
" -2.0185e-01, 3.2946e-01]],\n",
"\n",
" [[ 1.2763e-01, -1.2701e-01, 1.6529e-01, ..., -1.4527e-01,\n",
" -8.5370e-03, -1.7278e-01],\n",
" [-6.5069e-02, 3.5000e-01, 5.6586e-01, ..., -3.5917e-01,\n",
" -4.1324e-01, 2.9987e-01],\n",
" [ 2.5123e-01, 5.5106e-01, 4.2795e-01, ..., -1.0718e+00,\n",
" -6.8236e-01, -4.2256e-01],\n",
" ...,\n",
" [-6.0648e-01, -5.4619e-01, 1.4942e-02, ..., -7.6836e-01,\n",
" -5.9767e-01, -1.3891e-02],\n",
" [-3.4398e-01, -8.0992e-01, 7.4776e-01, ..., -1.8947e+00,\n",
" -2.7473e-01, 4.0089e-01],\n",
" [ 8.6354e-02, -1.2515e-02, -2.7977e-01, ..., -4.1148e-01,\n",
" -5.5178e-01, 7.0079e-02]]]], grad_fn=<PermuteBackward0>))), hidden_states=(tensor([[[ 0.0231, -0.2904, 0.1120, ..., 0.2610, 0.0677, 0.0696],\n",
" [ 0.0355, -0.0567, -0.0626, ..., 0.0619, -0.0195, -0.0601],\n",
" [-0.0454, -0.0329, 0.0899, ..., -0.1099, -0.1276, 0.0272],\n",
" ...,\n",
" [-0.0366, -0.0155, 0.1617, ..., -0.0452, 0.0707, -0.0497],\n",
" [-0.0264, -0.0053, 0.2777, ..., 0.1060, 0.1664, 0.0183],\n",
" [ 0.0103, -0.0047, 0.1434, ..., 0.0254, -0.0255, -0.0704]]],\n",
" grad_fn=<AddBackward0>), tensor([[[ 1.5737e+00, -4.1554e-01, 4.5012e-01, ..., 4.3850e-02,\n",
" 7.4813e-01, -8.7114e-01],\n",
" [ 5.1361e-01, -6.6155e-01, 1.0332e-01, ..., 4.2718e-01,\n",
" 1.7186e-01, 3.6244e-01],\n",
" [ 1.2385e+00, 5.1269e-04, -1.1555e-01, ..., 3.3694e-01,\n",
" -2.2656e-01, 7.6178e-02],\n",
" ...,\n",
" [-1.5542e+00, 5.6012e-01, 3.0304e-01, ..., 2.0757e-01,\n",
" 3.6331e-01, -5.2796e-01],\n",
" [-8.0574e-01, 5.1341e-01, -1.3832e+00, ..., 8.7573e-01,\n",
" -3.1620e-01, -2.6355e+00],\n",
" [ 8.7906e-01, -4.0571e-01, 6.8713e-01, ..., 1.3655e+00,\n",
" -1.1660e-01, 2.1324e-01]]], grad_fn=<AddBackward0>), tensor([[[ 1.7136, -0.5216, 1.2041, ..., -0.4961, 0.3665, -0.9365],\n",
" [ 0.4630, -1.2140, 0.2936, ..., 0.0555, -0.1479, 0.3223],\n",
" [ 1.2810, -0.0626, -0.0681, ..., 0.6627, -0.5515, 0.0529],\n",
" ...,\n",
" [-1.3538, 0.9463, 0.3435, ..., -0.0469, 0.4996, -0.5079],\n",
" [-1.2018, 1.1568, -1.9729, ..., 0.3070, -0.0780, -2.1962],\n",
" [ 0.4538, -0.4325, 0.9298, ..., 1.6704, 0.1176, 0.5136]]],\n",
" grad_fn=<AddBackward0>), tensor([[[ 1.8131, -0.8204, 1.0690, ..., -0.6062, 0.4388, -0.8892],\n",
" [ 0.3553, -1.4214, 0.3465, ..., -0.1229, 0.1026, 0.6289],\n",
" [ 1.6588, -0.5855, -0.1310, ..., 1.0190, -0.4376, -0.4088],\n",
" ...,\n",
" [-1.1001, 1.4018, -0.0845, ..., -0.4871, 0.3749, -1.0466],\n",
" [-1.2557, 1.2836, -2.5036, ..., -0.1603, 0.0254, -2.3484],\n",
" [ 0.4166, -0.5125, 0.6953, ..., 1.8050, 0.6178, 0.6728]]],\n",
" grad_fn=<AddBackward0>), tensor([[[ 1.8357, -0.8387, 1.1291, ..., -0.5870, 0.4266, -1.0183],\n",
" [-0.3023, -1.8606, 1.0695, ..., 0.3596, -0.5872, 0.5146],\n",
" [ 1.5486, -1.3812, -0.1454, ..., 1.4216, -0.7276, -0.3115],\n",
" ...,\n",
" [-0.8990, 1.3792, -0.6556, ..., -0.6427, -0.1838, -1.0314],\n",
" [-0.6506, 1.4321, -3.7864, ..., 0.2906, -0.3390, -2.7433],\n",
" [ 0.5480, -0.9662, 0.9323, ..., 2.0826, -0.5486, 1.2011]]],\n",
" grad_fn=<AddBackward0>), tensor([[[ 1.6162, -0.8975, 1.0517, ..., -0.7185, 0.2539, -1.0555],\n",
" [-0.0308, -2.1858, 1.7953, ..., 0.5839, -1.0037, 0.0798],\n",
" [ 1.9824, -0.7727, -0.1712, ..., 1.7961, -1.0021, -0.3786],\n",
" ...,\n",
" [-1.1462, 1.0538, -1.0321, ..., -0.0505, -0.3385, -1.3392],\n",
" [-0.6031, 1.9507, -4.7104, ..., -0.0331, -1.0798, -2.4425],\n",
" [ 0.5712, -0.7698, 0.1273, ..., 2.8240, -0.8675, 2.1530]]],\n",
" grad_fn=<AddBackward0>), tensor([[[ 1.5710, -0.9778, 1.0983, ..., -0.8036, 0.1757, -1.0363],\n",
" [-0.5121, -2.1376, 1.7901, ..., -0.0355, -0.4783, 0.1833],\n",
" [ 2.8356, -1.5824, -0.2001, ..., 1.8292, -0.4691, -0.2781],\n",
" ...,\n",
" [-1.6092, 0.1276, -1.6480, ..., 0.7556, -2.2751, -1.2271],\n",
" [-0.3862, 2.8926, -5.3254, ..., 0.5635, -1.5554, -2.6868],\n",
" [ 0.6955, -0.6462, -0.3514, ..., 3.4493, -1.9874, 1.3638]]],\n",
" grad_fn=<AddBackward0>), tensor([[[ 1.4392, -0.9376, 1.1554, ..., -0.8639, 0.1171, -1.0310],\n",
" [-0.2120, -2.0884, 2.2357, ..., -0.8004, -0.2832, -0.2491],\n",
" [ 2.7662, -1.6102, -0.1855, ..., 2.3809, 0.2519, -0.4420],\n",
" ...,\n",
" [-1.4429, -0.1494, -0.8831, ..., 1.2360, -1.6377, -0.8880],\n",
" [-0.9246, 2.8136, -5.2786, ..., 0.1955, -1.6184, -2.6251],\n",
" [ 0.9074, -0.3075, 0.1530, ..., 3.1575, -1.6791, 2.0776]]],\n",
" grad_fn=<AddBackward0>), tensor([[[ 1.4151, -0.7950, 1.0212, ..., -0.8095, 0.0292, -1.1826],\n",
" [-0.0706, -2.0130, 1.8284, ..., -1.0185, -0.5239, -0.3039],\n",
" [ 2.8450, -2.3009, -0.5953, ..., 2.0502, 1.1716, -0.2201],\n",
" ...,\n",
" [-1.0831, -0.3495, -0.4953, ..., 0.7348, -1.0733, -0.3256],\n",
" [-0.6313, 2.8501, -5.5530, ..., -0.0141, -2.2424, -3.8297],\n",
" [ 2.0435, -0.2091, 0.7285, ..., 2.5350, -2.2868, 1.3605]]],\n",
" grad_fn=<AddBackward0>), tensor([[[ 1.4463, -0.7566, 0.9623, ..., -0.7003, 0.0289, -1.2995],\n",
" [ 0.4330, -1.4939, 2.7411, ..., -0.2542, 0.3714, -1.6697],\n",
" [ 2.4653, -2.0962, -0.6611, ..., 2.4599, 1.8867, -0.6674],\n",
" ...,\n",
" [ 1.2202, -2.0474, 1.7625, ..., -0.5113, 0.7804, 1.4529],\n",
" [ 1.0899, 0.4627, -6.6348, ..., -2.2547, -2.7966, -4.2566],\n",
" [ 2.2639, -0.6145, 0.7215, ..., 1.7289, -0.9348, -0.0800]]],\n",
" grad_fn=<AddBackward0>), tensor([[[ 1.3416, -0.6064, 0.6988, ..., -0.6046, 0.0922, -1.5941],\n",
" [-0.3715, -1.3355, 2.9444, ..., -0.1253, 1.5043, -2.8058],\n",
" [ 0.9600, -2.2277, -0.0108, ..., 2.9812, 3.4562, -1.3117],\n",
" ...,\n",
" [ 2.7550, -2.8540, 3.9844, ..., -0.4379, 2.8047, 0.9528],\n",
" [ 1.7625, -2.2070, -7.9801, ..., -2.1712, -3.5339, -4.6076],\n",
" [ 2.3666, -1.7680, 0.7266, ..., 4.0575, -0.2326, -2.1535]]],\n",
" grad_fn=<AddBackward0>), tensor([[[ 1.2195, -0.3806, 0.3530, ..., -0.5992, 0.3146, -1.7930],\n",
" [-0.1447, 0.0618, 2.7296, ..., 0.8753, 1.8019, -4.6930],\n",
" [-0.9555, -3.1084, 1.1448, ..., 3.5270, 4.3085, 1.1351],\n",
" ...,\n",
" [ 1.1198, -5.1489, 5.3349, ..., 1.5175, 3.6925, 1.5494],\n",
" [ 2.8521, -1.7178, -7.8211, ..., -2.2027, -6.7088, -5.0671],\n",
" [ 2.9345, -1.3891, 0.9643, ..., 3.5691, -0.1766, -3.9141]]],\n",
" grad_fn=<AddBackward0>), tensor([[[ 0.0502, 0.0018, -0.1750, ..., -0.1020, -0.0257, -0.1292],\n",
" [ 0.1300, 0.1757, 0.2934, ..., 0.0794, 0.1164, -0.3280],\n",
" [ 0.0021, -0.2481, 0.2638, ..., 0.1507, 0.4056, 0.2376],\n",
" ...,\n",
" [ 0.1611, -0.4680, 0.7029, ..., 0.1209, 0.3803, 0.2864],\n",
" [ 0.1791, -0.3507, -1.2709, ..., -0.1535, -0.7109, -0.2459],\n",
" [ 0.2872, -0.0504, 0.0839, ..., 0.3417, -0.0518, -0.3151]]],\n",
" grad_fn=<ViewBackward0>)), attentions=None, cross_attentions=None)"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 9, 768])"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output.hidden_states[0].shape"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 9, 768])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output.hidden_states[1].shape"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 9, 768])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output.hidden_states[2].shape"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"13"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(output.hidden_states)\n"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 9, 768])"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output.last_hidden_state.shape"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"pt_model = AutoModelForCausalLM.from_pretrained(model_name)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"output = pt_model(**encoding)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"CausalLMOutputWithCrossAttentions(loss=None, logits=tensor([[[ -36.3292, -36.3402, -40.4228, ..., -46.0234, -44.5284,\n",
" -37.1276],\n",
" [-114.9346, -116.5035, -117.9236, ..., -117.8857, -119.3379,\n",
" -112.9298],\n",
" [-123.5036, -123.0548, -127.3876, ..., -130.5238, -130.5279,\n",
" -123.2711],\n",
" ...,\n",
" [-101.3852, -101.2506, -103.6583, ..., -103.3747, -107.7192,\n",
" -99.4521],\n",
" [ -83.0701, -84.3884, -91.9513, ..., -91.7482, -93.3971,\n",
" -85.1204],\n",
" [ -91.2749, -93.1332, -93.6408, ..., -94.3482, -93.4517,\n",
" -90.1472]]], grad_fn=<UnsafeViewBackward0>), past_key_values=((tensor([[[[-7.0634e-01, 1.9011e+00, 7.7253e-01, ..., -1.3028e+00,\n",
" -5.0432e-01, 1.6823e+00],\n",
" [-1.6482e+00, 3.0222e+00, 1.2789e+00, ..., -9.0779e-01,\n",
" -1.7395e+00, 2.4237e+00],\n",
" [-2.3128e+00, 2.8957e+00, 1.8368e+00, ..., -7.0370e-01,\n",
" -1.6305e+00, 2.4407e+00],\n",
" ...,\n",
" [-2.4337e+00, 2.5271e+00, 2.1513e+00, ..., -5.8053e-01,\n",
" -1.6483e+00, 2.0594e+00],\n",
" [-3.8223e+00, 2.1391e+00, 1.7587e+00, ..., -1.0668e+00,\n",
" -1.6278e+00, 1.1729e+00],\n",
" [-1.9238e+00, 2.7944e+00, 1.6292e+00, ..., -8.9733e-01,\n",
" -2.2193e+00, 2.6272e+00]],\n",
"\n",
" [[-9.6153e-02, 8.9928e-01, -1.4324e+00, ..., -3.8667e-03,\n",
" 1.7698e+00, 6.0074e-01],\n",
" [ 2.7222e-01, -1.2016e+00, -1.9081e+00, ..., -1.3531e+00,\n",
" 1.2823e+00, -4.3198e-01],\n",
" [-1.1722e+00, -3.6670e-01, -1.6921e+00, ..., -1.2359e+00,\n",
" 2.5243e+00, 1.0228e+00],\n",
" ...,\n",
" [-1.6694e-01, -1.0159e+00, -2.5232e+00, ..., -9.7920e-01,\n",
" 4.8265e+00, -1.7799e+00],\n",
" [-1.1981e-01, -2.6784e+00, -2.9551e+00, ..., -1.9840e-01,\n",
" 3.3916e+00, -1.9762e-02],\n",
" [ 3.2722e-01, -1.2197e+00, -2.1079e+00, ..., -1.6297e+00,\n",
" 9.2404e-01, -7.6080e-01]],\n",
"\n",
" [[-1.4670e-01, 2.1407e-01, 1.1498e+00, ..., -1.3128e+00,\n",
" -2.1007e+00, 5.6910e-01],\n",
" [ 5.5608e-01, -4.6297e-01, 7.4483e-01, ..., -1.8272e+00,\n",
" 5.4572e-01, 1.0119e+00],\n",
" [ 9.2851e-01, 4.6049e-03, 4.1324e-01, ..., -2.4987e+00,\n",
" 5.2423e-01, 1.5260e+00],\n",
" ...,\n",
" [ 3.2328e-01, 3.5316e-01, 3.2756e-02, ..., -3.2780e+00,\n",
" 8.1692e-01, 1.4566e+00],\n",
" [-2.1528e-01, -2.2490e-01, -1.4536e+00, ..., -3.7075e+00,\n",
" 1.6835e+00, 1.6085e+00],\n",
" [ 7.6672e-01, -5.3757e-01, 4.2462e-01, ..., -2.2908e+00,\n",
" 1.7213e+00, 1.0240e+00]],\n",
"\n",
" ...,\n",
"\n",
" [[ 5.4733e-01, 4.7672e-01, -2.2749e-01, ..., 2.9014e-01,\n",
" 7.7821e-01, 7.8295e-01],\n",
" [ 1.6820e-01, -9.1829e-02, -5.0034e-02, ..., 7.3646e-01,\n",
" 6.1343e-01, 5.4442e-01],\n",
" [ 2.9530e-02, -5.3167e-02, -6.1709e-02, ..., 1.0934e+00,\n",
" 3.7083e-01, 3.8425e-01],\n",
" ...,\n",
" [-1.3203e-02, -2.6465e-01, 4.4834e-02, ..., 1.2205e+00,\n",
" 5.4265e-01, 3.7732e-01],\n",
" [ 8.5854e-02, -2.3791e-01, -1.1271e-01, ..., 1.8211e+00,\n",
" -5.7249e-01, -7.4493e-01],\n",
" [-3.6544e-02, -1.4250e-01, 6.6582e-02, ..., 1.0489e+00,\n",
" 4.8485e-01, 4.6476e-01]],\n",
"\n",
" [[ 1.4700e+00, 1.3564e+00, -4.9892e-01, ..., -6.4925e-02,\n",
" 1.4507e+00, -1.2267e+00],\n",
" [ 1.0113e+00, 7.0108e-01, -5.7364e-01, ..., -7.1721e-01,\n",
" 1.0731e+00, -1.0718e+00],\n",
" [ 1.1010e+00, 4.8299e-01, -9.3231e-01, ..., -1.5044e+00,\n",
" 1.2941e+00, -3.3869e-01],\n",
" ...,\n",
" [ 1.1745e+00, 6.3323e-01, -6.1605e-01, ..., -8.1925e-01,\n",
" 5.2691e-01, -7.5443e-01],\n",
" [ 1.7895e+00, 5.7095e-01, -3.5775e-01, ..., -1.3193e+00,\n",
" 5.5676e-01, -1.6293e-01],\n",
" [ 9.6151e-01, 2.9245e-02, -5.3493e-01, ..., -7.8683e-01,\n",
" 3.7355e-01, -2.4032e-01]],\n",
"\n",
" [[ 7.1643e-01, -3.1278e-01, 1.4058e-01, ..., -2.0734e-01,\n",
" 2.5946e-01, 1.7684e+00],\n",
" [-5.6619e-01, 7.8687e-01, 2.5152e-02, ..., 6.2100e-01,\n",
" 4.7592e-01, 5.4321e-01],\n",
" [-6.2611e-01, 3.3320e-01, 1.1092e-01, ..., 6.4703e-01,\n",
" 6.4159e-01, 7.2777e-01],\n",
" ...,\n",
" [-1.7180e-01, 1.1778e+00, -2.3931e-01, ..., -6.3932e-01,\n",
" 1.1654e+00, 4.0462e-01],\n",
" [-4.8319e-01, 2.8237e-01, -4.4490e-01, ..., -1.2013e-01,\n",
" 4.8413e-01, -4.5133e-01],\n",
" [-1.1252e+00, 7.6533e-01, -6.0320e-02, ..., 1.8912e-01,\n",
" 7.8018e-01, -5.4733e-01]]]], grad_fn=<PermuteBackward0>), tensor([[[[ 0.1900, 0.0015, -0.0517, ..., 0.0536, 0.0312, -0.0694],\n",
" [-0.0800, 0.0181, -0.0534, ..., -0.0419, -0.0365, 0.0151],\n",
" [ 0.0448, 0.1912, -0.1849, ..., -0.0062, -0.1420, 0.1609],\n",
" ...,\n",
" [-0.1635, 0.0196, 0.1185, ..., 0.0794, 0.0980, -0.1084],\n",
" [-0.2303, 0.1991, -0.1576, ..., 0.2774, -0.1813, -0.2463],\n",
" [-0.1009, 0.0410, -0.0970, ..., -0.0684, -0.0763, 0.0260]],\n",
"\n",
" [[ 0.4406, 0.1176, -0.2136, ..., -0.6839, -0.2371, 0.2999],\n",
" [ 0.5926, 0.0197, 0.1107, ..., 0.1253, 0.5675, -0.2665],\n",
" [ 0.6762, 0.0459, -0.3685, ..., 0.0744, 0.5420, -0.1240],\n",
" ...,\n",
" [ 0.8509, -0.0962, 0.0762, ..., -0.1705, 0.1339, 0.1068],\n",
" [ 0.2928, -0.2582, 0.1735, ..., 0.0800, 0.2879, -0.0139],\n",
" [ 0.5969, 0.0592, 0.0263, ..., -0.0100, 0.5129, -0.1905]],\n",
"\n",
" [[ 0.0810, -0.1910, 0.1092, ..., -0.0283, 0.0408, 0.0961],\n",
" [-0.3257, 0.0398, -0.1531, ..., 0.0411, -0.0413, 0.0745],\n",
" [ 0.5201, 0.0126, 0.3504, ..., 0.1020, 0.0543, -0.2188],\n",
" ...,\n",
" [-0.5288, -0.0025, -0.5926, ..., -0.1874, -0.0674, 0.3113],\n",
" [ 0.1521, 0.0271, -0.2514, ..., -0.0465, -0.0565, -0.3401],\n",
" [-0.2885, 0.0590, -0.1736, ..., 0.0685, -0.1112, 0.0604]],\n",
"\n",
" ...,\n",
"\n",
" [[ 0.0111, -0.0168, 0.0263, ..., -0.2135, 0.2054, 0.0729],\n",
" [-0.3022, -0.0878, 0.1001, ..., 0.0262, -0.1647, 0.1682],\n",
" [-0.1587, -0.0666, 0.0826, ..., -0.0416, 0.0812, 0.2067],\n",
" ...,\n",
" [-0.0925, -0.4836, 0.0332, ..., 0.0641, -0.1597, 0.2375],\n",
" [-0.0742, 0.8589, 0.0336, ..., -0.3268, -0.2455, 0.3080],\n",
" [-0.0869, -0.4287, 0.1231, ..., -0.0474, -0.1705, 0.0347]],\n",
"\n",
" [[ 0.2081, -0.2399, -0.1318, ..., 0.1471, 0.1123, -0.0316],\n",
" [-0.2119, 0.0589, 0.0997, ..., 0.0038, 0.1331, 0.0930],\n",
" [-0.1213, 0.1404, 0.1775, ..., 0.1688, -0.0020, 0.0829],\n",
" ...,\n",
" [-0.2325, 0.1252, -0.0345, ..., 0.2837, 0.0686, -0.0089],\n",
" [ 0.1896, 0.0282, -0.0740, ..., 0.1655, -0.3020, 0.2837],\n",
" [ 0.0298, 0.0086, -0.1626, ..., 0.1976, 0.0970, -0.0014]],\n",
"\n",
" [[-0.0689, -0.3955, 0.2328, ..., 0.1539, -0.1823, -0.0845],\n",
" [ 0.0538, -0.2648, -0.0146, ..., 0.2331, 0.0516, 0.0924],\n",
" [-0.0647, 0.0062, 0.1329, ..., 0.1026, 0.1185, 0.0463],\n",
" ...,\n",
" [ 0.0186, 0.1904, -0.0966, ..., 0.0714, -0.0321, -0.0059],\n",
" [ 0.0219, 0.4180, -0.1580, ..., -0.0072, -0.2708, 0.1529],\n",
" [ 0.1236, -0.3671, -0.0392, ..., 0.1061, -0.0278, -0.0074]]]],\n",
" grad_fn=<PermuteBackward0>)), (tensor([[[[-3.5429e-01, 2.2092e+00, -1.5580e+00, ..., 1.4397e+00,\n",
" -1.1504e+00, 1.4646e+00],\n",
" [ 7.3885e-01, 1.8177e+00, -1.4766e+00, ..., -4.6761e-01,\n",
" -1.6869e+00, 5.0785e-01],\n",
" [ 1.6962e+00, 1.1427e+00, -1.1112e+00, ..., 1.2764e-01,\n",
" -2.5909e+00, 7.2933e-01],\n",
" ...,\n",
" [-1.9130e-03, 1.6441e+00, -3.0120e-01, ..., 3.8508e-01,\n",
" -1.0645e+00, -4.5135e-01],\n",
" [-3.9438e-01, 1.6005e+00, 9.6257e-01, ..., 5.8858e-01,\n",
" -1.8425e+00, -9.6318e-01],\n",
" [-4.9488e-01, 1.1094e+00, 5.2522e-02, ..., 5.6471e-01,\n",
" -1.3969e+00, -3.0882e-01]],\n",
"\n",
" [[-1.0087e+00, -4.5958e-01, -7.4797e-01, ..., -3.7310e-01,\n",
" 7.9809e-01, -2.3881e-01],\n",
" [-6.6438e-02, 4.8658e-01, -8.2457e-01, ..., -9.4308e-01,\n",
" 1.8907e-01, -1.5256e-02],\n",
" [-1.7392e-01, 1.1992e+00, -1.5513e+00, ..., -3.2774e-01,\n",
" 7.3627e-01, -3.6968e-01],\n",
" ...,\n",
" [-1.1986e-01, 6.0111e-01, -1.4226e+00, ..., -6.1346e-01,\n",
" 1.3460e-01, -6.1240e-01],\n",
" [ 1.8174e-01, 3.1973e-01, -2.2986e+00, ..., -4.1319e-01,\n",
" -1.0757e+00, -4.7605e-01],\n",
" [-2.4593e-01, 1.1035e+00, -1.4215e+00, ..., -6.2691e-01,\n",
" -1.1097e+00, -6.3956e-01]],\n",
"\n",
" [[ 3.2591e-01, -1.6143e-02, -2.0098e-01, ..., -1.3362e+00,\n",
" 3.3876e-01, -1.6542e-01],\n",
" [-1.0002e-02, 3.9666e-01, -9.3499e-02, ..., -1.0921e+00,\n",
" 5.6914e-02, 4.1318e-01],\n",
" [-1.1656e-02, 2.1262e-01, -2.3546e-01, ..., -9.7254e-01,\n",
" 1.4688e-01, 2.7869e-01],\n",
" ...,\n",
" [-8.3349e-02, 3.9433e-02, -9.7432e-03, ..., -7.0562e-01,\n",
" 4.2687e-01, 2.3274e-01],\n",
" [ 1.0450e-01, -2.0783e-01, -2.8860e-01, ..., -1.0073e+00,\n",
" -1.2179e-01, 3.5471e-01],\n",
" [-1.4484e-01, -5.0447e-02, -3.9541e-03, ..., -1.0255e+00,\n",
" 1.9039e-01, 3.3890e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[ 2.1528e-01, -4.6627e-01, -5.9642e-01, ..., -4.2178e-01,\n",
" 4.3739e-01, -8.5899e-01],\n",
" [-5.0305e-02, 1.2479e+00, 1.8768e+00, ..., 6.8503e-01,\n",
" -7.3186e-01, -3.4076e-01],\n",
" [-4.0512e-01, 1.6082e+00, 1.8570e+00, ..., 1.2636e+00,\n",
" -1.1781e+00, -8.1034e-01],\n",
" ...,\n",
" [-5.1299e-01, 2.6865e-01, 7.6903e-01, ..., -1.3940e+00,\n",
" 8.1194e-01, -1.8763e-01],\n",
" [ 2.3526e-01, -5.7615e-01, 1.3541e+00, ..., 1.4708e+00,\n",
" -2.9934e-01, -3.9407e-01],\n",
" [ 5.0755e-02, 7.0489e-01, 1.9166e+00, ..., 6.6883e-01,\n",
" -9.1450e-01, -2.5584e-01]],\n",
"\n",
" [[-1.1473e+00, -2.7966e+00, 1.4438e-01, ..., 1.7208e+00,\n",
" 1.5965e+00, -1.4860e+00],\n",
" [ 3.5231e-01, 7.5960e-01, -4.7429e-01, ..., -8.1442e-01,\n",
" 4.5442e-01, -2.9752e-01],\n",
" [ 2.1113e-01, 7.5264e-01, -4.5093e-01, ..., -9.6233e-01,\n",
" 5.8766e-01, 9.0545e-02],\n",
" ...,\n",
" [ 1.6897e-01, 2.5023e-01, -7.4581e-01, ..., -1.2799e-01,\n",
" 7.1349e-01, -8.5998e-02],\n",
" [-2.3828e-01, 5.9684e-01, -7.5936e-01, ..., -6.6564e-01,\n",
" 7.3313e-01, 1.8287e-01],\n",
" [-1.6440e-01, 2.5931e-01, -8.1777e-01, ..., -3.5322e-01,\n",
" 8.3564e-01, -5.9446e-02]],\n",
"\n",
" [[ 1.3976e+00, 1.6241e+00, 5.4245e-01, ..., -7.8420e-01,\n",
" 1.1678e-01, 3.7706e-01],\n",
" [ 8.8908e-01, 2.1345e+00, 1.0939e+00, ..., 1.1961e-01,\n",
" -7.5297e-01, -1.4081e-01],\n",
" [ 6.7893e-01, 1.8408e+00, 1.5060e+00, ..., 5.9498e-01,\n",
" -2.2553e+00, -1.8270e+00],\n",
" ...,\n",
" [-5.1015e-02, 2.4946e+00, -1.6883e-01, ..., 5.4761e-01,\n",
" -2.8891e-01, -6.7954e-01],\n",
" [-1.6942e-01, 4.9026e-01, 1.1144e+00, ..., 9.3912e-03,\n",
" -8.0171e-01, -1.4243e-01],\n",
" [ 8.4424e-01, 1.7401e+00, 9.2639e-01, ..., -1.4967e-01,\n",
" -3.8360e-01, -1.5520e-01]]]], grad_fn=<PermuteBackward0>), tensor([[[[ 3.3872e-01, 1.3968e-01, -1.7938e-01, ..., 1.5467e-01,\n",
" -1.2589e-01, 7.0887e-02],\n",
" [ 3.7346e-01, 2.8615e-01, 7.3073e-02, ..., -1.7334e-01,\n",
" -1.7929e-01, 8.0809e-02],\n",
" [ 1.3121e-01, 1.3779e-01, 9.8802e-02, ..., 1.7611e-01,\n",
" -6.5489e-01, -3.7171e-01],\n",
" ...,\n",
" [ 4.5774e-01, 6.2110e-02, 4.7204e-02, ..., 2.1876e-01,\n",
" -1.9506e-01, 1.5526e-01],\n",
" [-1.6503e-01, 7.2050e-02, -4.4076e-01, ..., 9.3966e-02,\n",
" -8.1660e-02, -2.9702e-01],\n",
" [ 3.7986e-01, 3.8336e-01, 1.0341e-01, ..., -1.9899e-01,\n",
" -2.3373e-01, -1.3201e-01]],\n",
"\n",
" [[-7.9321e-02, -6.6966e-02, -2.2227e-01, ..., -1.4152e-02,\n",
" -4.5964e-01, 2.7340e-01],\n",
" [-2.0632e-01, -2.7675e-01, 9.3918e-02, ..., -9.7495e-02,\n",
" 2.0266e-01, 3.4913e-02],\n",
" [-3.6562e-01, -2.8439e-01, 2.9782e-01, ..., -1.0605e+00,\n",
" 2.7564e-01, 3.3809e-01],\n",
" ...,\n",
" [ 5.1779e-01, 2.3170e-01, -3.0248e-01, ..., 4.6880e-01,\n",
" 4.3330e-01, -6.2105e-01],\n",
" [-1.9805e-01, 6.8445e-02, -5.7586e-02, ..., 1.3844e-01,\n",
" -6.2666e-02, 1.8667e-01],\n",
" [ 6.9782e-02, -1.5278e-01, 6.9243e-02, ..., -1.0944e-01,\n",
" 1.1224e-01, 1.1524e-01]],\n",
"\n",
" [[ 7.9376e-02, -1.4863e-02, -4.4028e-02, ..., -6.2825e-01,\n",
" 6.7840e-02, 1.0440e-02],\n",
" [ 4.2720e-01, 2.4379e-01, 2.3040e-01, ..., -5.0812e-01,\n",
" 3.7279e-02, -1.3192e-01],\n",
" [ 6.2018e-01, 1.7793e-01, 2.9474e-01, ..., -7.6162e-01,\n",
" -2.8552e-01, -1.4080e-01],\n",
" ...,\n",
" [ 5.8184e-01, 5.9326e-02, 2.5048e-03, ..., -6.1473e-01,\n",
" -3.0034e-02, 4.4224e-02],\n",
" [ 6.7462e-01, 1.3863e-01, -5.1645e-02, ..., -5.6261e-01,\n",
" -2.2474e-01, -1.2376e-01],\n",
" [ 6.0415e-01, 9.6460e-02, 1.1331e-01, ..., -2.8026e-01,\n",
" 2.4650e-02, -2.4321e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[ 1.0567e-01, 6.7946e-01, -1.7619e-01, ..., 1.2480e-02,\n",
" -9.7338e-01, -2.5708e-01],\n",
" [-5.0101e-04, -7.4670e-01, 1.4215e-01, ..., 2.6520e-02,\n",
" -9.1824e-01, -4.4347e-01],\n",
" [ 5.7162e-02, -6.6084e-01, -1.7225e-01, ..., -6.7773e-02,\n",
" -6.9370e-01, 2.2682e-01],\n",
" ...,\n",
" [ 3.6897e-01, 4.0040e-01, 1.3203e-01, ..., 5.9832e-02,\n",
" -4.3946e-01, 3.3851e-02],\n",
" [-1.9931e-01, 4.7522e-01, 6.5326e-01, ..., 8.5060e-01,\n",
" -1.5948e-01, 2.6952e-01],\n",
" [ 4.5483e-02, -7.9412e-01, 2.0943e-01, ..., 6.4299e-02,\n",
" -6.5777e-01, -2.0458e-01]],\n",
"\n",
" [[ 4.7333e-02, -1.1130e-02, -1.4608e-01, ..., 3.8364e-01,\n",
" -3.4244e+00, 6.6758e-02],\n",
" [ 5.0051e-01, 8.4673e-03, 1.9747e-01, ..., 2.1474e-01,\n",
" -7.4449e-03, -2.8373e-01],\n",
" [-2.0428e-01, 2.4512e-01, -2.7017e-01, ..., 4.5577e-02,\n",
" 2.1612e-02, -1.3106e-01],\n",
" ...,\n",
" [ 7.3244e-02, -1.5794e-01, 1.7578e-01, ..., -2.2690e-01,\n",
" -6.3669e-02, -1.8729e-02],\n",
" [ 1.3369e-01, 4.0795e-01, -6.9403e-02, ..., -2.8477e-02,\n",
" 8.1580e-02, -3.7645e-01],\n",
" [ 3.2948e-01, 2.4525e-01, 3.1002e-02, ..., 1.4547e-03,\n",
" -2.0459e-01, -1.3566e-02]],\n",
"\n",
" [[ 2.4439e-02, -2.3092e-01, 1.1163e-02, ..., -3.4285e-01,\n",
" 2.7007e-01, -3.4211e-02],\n",
" [ 2.0095e-01, -4.9356e-01, 5.3058e-01, ..., -2.7157e-01,\n",
" 4.2807e-01, 3.2917e-01],\n",
" [-1.0993e-01, -4.1360e-01, 1.9816e-02, ..., -1.7917e-01,\n",
" 3.6033e-01, 2.2954e-01],\n",
" ...,\n",
" [ 4.2263e-02, 1.5875e-02, -3.0871e-01, ..., -3.1441e-01,\n",
" 2.9030e-01, 2.2213e-01],\n",
" [-4.9536e-02, 8.3578e-02, 7.2786e-02, ..., -2.5493e-01,\n",
" 4.7891e-02, 3.4251e-01],\n",
" [ 5.0301e-02, -1.8544e-01, 5.7551e-01, ..., -3.4349e-01,\n",
" 1.5927e-01, 4.2942e-01]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[-1.5217e-01, -1.1477e+00, 2.3295e-01, ..., -6.4279e-01,\n",
" -1.1349e-01, 4.0799e-02],\n",
" [ 4.5919e-01, -2.0374e+00, -7.9378e-01, ..., 4.4668e-02,\n",
" -8.8579e-01, -9.0097e-01],\n",
" [ 3.8866e-01, -1.6082e+00, -3.9608e-01, ..., 3.1908e-01,\n",
" -4.2160e-01, -1.1912e-01],\n",
" ...,\n",
" [-8.1627e-02, -1.0257e+00, -6.6449e-01, ..., 6.6261e-01,\n",
" -1.8242e-01, -5.9660e-02],\n",
" [ 9.9366e-01, -2.8990e+00, -4.2770e-01, ..., 1.5473e+00,\n",
" -2.7730e-01, 1.0212e+00],\n",
" [ 3.7402e-01, -1.2451e+00, -8.3321e-01, ..., 1.5307e+00,\n",
" -6.0831e-01, -1.0434e+00]],\n",
"\n",
" [[-5.0563e-01, 3.4884e-01, -4.0126e-01, ..., 1.2945e+00,\n",
" -5.5872e-01, -4.4031e-01],\n",
" [-1.0783e+00, -1.0583e+00, -8.7019e-01, ..., 9.3939e-01,\n",
" 6.1988e-01, -3.6133e-01],\n",
" [-1.4605e+00, 7.9834e-04, -1.6445e+00, ..., 8.5405e-01,\n",
" 1.1266e+00, 2.1244e-01],\n",
" ...,\n",
" [-1.7653e+00, -4.5490e-01, 5.8049e-01, ..., 1.3604e-01,\n",
" -2.6502e-01, 1.4497e+00],\n",
" [-2.7539e+00, -1.9189e+00, -6.1803e-01, ..., 2.3083e+00,\n",
" -6.2625e-01, -5.0954e-01],\n",
" [-8.4786e-01, -9.9176e-01, -1.4226e+00, ..., 1.0424e+00,\n",
" 1.2138e+00, -6.2367e-01]],\n",
"\n",
" [[ 1.3477e+00, 3.0343e+00, 3.7258e+00, ..., 6.1286e-01,\n",
" 1.7142e+00, -7.4960e-01],\n",
" [-3.4424e+00, 2.1578e+00, -3.4773e+00, ..., -1.7704e+00,\n",
" 3.4858e+00, 9.8086e-01],\n",
" [-3.3403e+00, 7.3066e-01, -4.6132e+00, ..., -3.2065e+00,\n",
" 5.3039e+00, 7.1677e-01],\n",
" ...,\n",
" [-4.8998e+00, -5.9784e-01, -2.9574e+00, ..., -4.1010e+00,\n",
" 2.4786e+00, 2.7664e-02],\n",
" [-3.3274e+00, -1.2454e+00, -5.1031e+00, ..., -3.2964e+00,\n",
" 3.3057e+00, 1.4853e+00],\n",
" [-4.2024e+00, -1.7287e+00, -5.1702e+00, ..., -2.7123e+00,\n",
" 2.8922e+00, 1.8391e+00]],\n",
"\n",
" ...,\n",
"\n",
" [[ 1.3818e+00, -2.7867e+00, -2.6519e+00, ..., 9.1555e-01,\n",
" 4.4077e-01, 2.7028e+00],\n",
" [-2.4026e+00, 1.6620e+00, -4.5219e-01, ..., 1.2064e-01,\n",
" -1.6484e+00, 5.6717e-01],\n",
" [-1.7379e+00, 2.8888e+00, 2.1535e-01, ..., -7.8397e-02,\n",
" -2.7045e+00, -3.0823e-03],\n",
" ...,\n",
" [-2.9426e+00, 3.5565e+00, 1.0280e+00, ..., -3.5420e-01,\n",
" -3.7917e+00, -7.8773e-01],\n",
" [-2.8640e+00, 2.8314e+00, 2.3865e+00, ..., -2.2468e+00,\n",
" -4.0705e+00, -1.2861e+00],\n",
" [-3.9137e+00, 4.3675e+00, 1.5171e+00, ..., -6.0161e-01,\n",
" -2.7414e+00, -1.2265e+00]],\n",
"\n",
" [[ 1.7415e+00, 4.5990e-01, 9.3163e-01, ..., 1.2650e-03,\n",
" -9.8961e-01, -2.9552e-01],\n",
" [ 2.2626e+00, 1.0377e+00, 1.1163e+00, ..., 3.4995e-01,\n",
" -2.5767e+00, -1.2164e+00],\n",
" [ 2.0896e+00, 6.8649e-01, 1.2068e+00, ..., 4.1762e-01,\n",
" -2.1005e+00, -1.2765e+00],\n",
" ...,\n",
" [ 1.8625e+00, 5.6272e-01, 1.1284e+00, ..., 3.5132e-01,\n",
" -2.0787e+00, -1.0202e+00],\n",
" [ 2.2705e+00, 3.2166e-01, 1.1907e+00, ..., 2.6156e-01,\n",
" -1.2966e+00, -9.9152e-01],\n",
" [ 2.3024e+00, 4.0813e-01, 9.6441e-01, ..., 4.9377e-01,\n",
" -2.5960e+00, -6.9144e-01]],\n",
"\n",
" [[-2.2407e-01, 1.4293e-01, -5.5406e-01, ..., 3.1676e-01,\n",
" 2.7494e-01, 1.6436e-01],\n",
" [-5.7508e-01, 6.1265e-01, -2.6713e-01, ..., 8.0278e-01,\n",
" 8.5041e-01, 1.8214e-01],\n",
" [ 6.2629e-01, 3.5029e-02, 8.6408e-02, ..., 4.6667e-01,\n",
" 1.6070e-01, 1.2988e-01],\n",
" ...,\n",
" [ 1.5542e-01, -2.5139e-01, -8.1318e-01, ..., 2.1838e-01,\n",
" 2.0266e-01, 6.9734e-01],\n",
" [-2.4867e-01, 4.2143e-01, -4.6590e-01, ..., 3.0348e-01,\n",
" 5.7653e-01, -5.7979e-01],\n",
" [-4.1779e-01, -4.9530e-01, -6.0749e-01, ..., 5.8660e-01,\n",
" 9.1405e-01, -3.4966e-02]]]], grad_fn=<PermuteBackward0>), tensor([[[[-1.5059e-02, -2.1934e-02, -1.3257e-01, ..., -3.3233e-03,\n",
" 5.6872e-03, -5.5921e-01],\n",
" [-4.4076e-01, 4.7031e-01, -2.1116e-01, ..., 5.7315e-01,\n",
" -3.8024e-01, 2.5338e-01],\n",
" [ 2.7640e-01, 1.0290e-01, -1.5030e-01, ..., 8.0443e-02,\n",
" -1.0340e-02, 6.5651e-01],\n",
" ...,\n",
" [ 7.7904e-01, 1.2082e+00, 3.0358e-01, ..., 4.4578e-01,\n",
" -4.0582e-02, 8.5044e-01],\n",
" [-2.0731e-01, -5.8119e-01, 4.1100e-01, ..., -1.7157e-01,\n",
" 2.8487e-01, 6.4911e-01],\n",
" [-8.6411e-01, 5.4967e-01, -4.1298e-01, ..., 9.2813e-01,\n",
" -4.2606e-01, -3.4161e-01]],\n",
"\n",
" [[ 3.8557e-02, 3.3662e-03, 5.4482e-02, ..., -5.7578e-02,\n",
" -7.4123e-02, 2.2392e-02],\n",
" [ 1.9386e-01, 1.8534e-01, 3.0680e-01, ..., -1.2764e-03,\n",
" -2.5348e-01, 8.6118e-02],\n",
" [-1.4242e-01, 3.2992e-01, 7.6395e-02, ..., 9.8633e-02,\n",
" -5.6915e-02, 4.4799e-02],\n",
" ...,\n",
" [-7.1944e-02, 3.8884e-02, 1.0161e-01, ..., -2.7253e-01,\n",
" 1.3398e-01, 1.1796e-01],\n",
" [-1.0896e+00, 2.1403e+00, -1.3890e-01, ..., 1.0035e+00,\n",
" 6.1333e-01, -1.1536e+00],\n",
" [ 6.1611e-02, 7.1527e-02, 2.0043e-01, ..., -3.5723e-01,\n",
" -1.4230e-01, 8.4502e-02]],\n",
"\n",
" [[ 1.1201e-02, -7.6654e-01, -1.1583e-02, ..., 4.3143e-02,\n",
" 1.5736e-02, -5.8100e-02],\n",
" [ 2.8462e-01, -1.0610e+00, 1.2486e-01, ..., 3.1588e-02,\n",
" -1.1913e-01, -4.8153e-02],\n",
" [ 2.6008e-01, -6.3008e-01, -8.1709e-01, ..., 1.8586e-01,\n",
" 3.4370e-01, 9.2477e-01],\n",
" ...,\n",
" [-1.9891e-01, -1.9001e+00, -4.4621e-02, ..., 7.8242e-02,\n",
" 2.2361e-02, 1.3589e-02],\n",
" [-2.8968e-01, -1.5899e+00, 9.2801e-02, ..., -2.7827e-01,\n",
" 1.6159e-01, -4.6007e-01],\n",
" [ 1.6971e-01, -1.5136e+00, 1.2845e-01, ..., -6.2768e-02,\n",
" -2.5769e-01, -1.5622e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[ 1.6522e-02, -7.7326e-02, 1.3163e+00, ..., -5.6423e-02,\n",
" 1.7141e-01, 2.1386e-02],\n",
" [-4.3988e-01, -2.9255e-01, 2.4116e+00, ..., -1.8846e-01,\n",
" 1.0912e-01, 1.4147e-01],\n",
" [ 2.3190e-01, -1.5369e-01, 2.5701e+00, ..., 6.3039e-01,\n",
" -1.0088e-01, 5.1586e-01],\n",
" ...,\n",
" [ 1.7250e-02, 6.7580e-01, 2.5971e+00, ..., -5.2273e-01,\n",
" 4.5050e-01, -6.9956e-01],\n",
" [-4.9545e-02, 6.1819e-01, 3.8825e-01, ..., -1.4691e-01,\n",
" 4.5526e-01, 7.1271e-01],\n",
" [-1.9639e-01, -1.2515e-01, 2.5813e+00, ..., -1.8536e-01,\n",
" -1.3485e-01, -8.7375e-02]],\n",
"\n",
" [[ 7.4395e-02, -8.7165e-02, -1.8260e-01, ..., 1.3185e-01,\n",
" 1.2575e-01, 1.7169e-01],\n",
" [ 6.5960e-01, 1.0117e+00, 7.1659e-01, ..., 8.3512e-02,\n",
" -6.5585e-01, -3.3111e-01],\n",
" [ 3.2666e-01, -1.2571e-01, 8.1719e-01, ..., 9.9527e-01,\n",
" -1.0291e+00, -5.0537e-01],\n",
" ...,\n",
" [-7.2666e-01, 1.0662e-01, -7.2195e-02, ..., -2.7005e-01,\n",
" 5.2628e-01, 2.3005e-01],\n",
" [-2.0959e-01, -2.3959e-01, -3.0772e-01, ..., 4.6964e-01,\n",
" -1.8979e-01, -2.7418e-01],\n",
" [ 1.7468e-01, 1.0415e+00, 7.5772e-01, ..., -4.9262e-01,\n",
" -8.0868e-01, 4.5074e-01]],\n",
"\n",
" [[ 1.1606e-02, 2.1828e-02, 2.7971e-02, ..., -3.3218e-02,\n",
" 2.2172e-01, -2.3344e-03],\n",
" [ 1.1778e-01, -3.0263e-01, 3.5408e-01, ..., -3.3052e-01,\n",
" -1.9086e+00, 4.3385e-01],\n",
" [-7.0245e-01, 4.2293e-02, -1.3216e-01, ..., 3.4737e-01,\n",
" -1.4905e+00, 3.5105e-01],\n",
" ...,\n",
" [ 2.1967e-01, -6.0979e-01, -6.8996e-01, ..., 4.4944e-01,\n",
" -1.9601e+00, -1.7819e-01],\n",
" [ 3.8903e-01, 1.9728e-01, -9.0256e-01, ..., 1.3781e-01,\n",
" -2.0059e+00, 3.0071e-01],\n",
" [ 5.9661e-01, -3.1890e-01, -2.2125e-01, ..., 2.8531e-01,\n",
" -1.8048e+00, 2.1086e-01]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[ 0.0532, -0.2197, 0.1445, ..., -0.8884, 0.7361, -1.2044],\n",
" [-0.6462, -0.7026, -1.4285, ..., 0.2179, -0.3014, 0.1623],\n",
" [-0.8909, -1.9166, -1.3314, ..., -1.8027, -2.7636, 2.9528],\n",
" ...,\n",
" [-1.9169, -0.2602, -0.2397, ..., -0.4901, -0.8816, 0.7061],\n",
" [-2.0792, 0.1064, 0.6011, ..., 0.5948, -0.5403, 1.4379],\n",
" [-0.4271, -0.4968, -0.0297, ..., 1.0395, -0.3829, 0.3067]],\n",
"\n",
" [[ 0.7842, 0.1905, 0.0089, ..., -0.1612, -1.0898, -0.1939],\n",
" [-1.3909, -1.5235, -0.5037, ..., 0.9582, 4.2044, 1.1825],\n",
" [ 0.1689, -1.8025, 0.8404, ..., 1.5177, 5.7815, 2.1470],\n",
" ...,\n",
" [ 1.2462, -1.4013, -1.2263, ..., 0.5912, 6.0711, 1.7328],\n",
" [ 1.4548, -2.0760, -2.0483, ..., -1.5971, 5.6172, 2.5548],\n",
" [-1.1053, -0.8554, -2.0471, ..., 0.8743, 6.2095, 1.1606]],\n",
"\n",
" [[ 0.3413, -0.3572, -0.3331, ..., 0.3294, 1.4604, 0.2755],\n",
" [ 0.0960, -6.2139, -0.6779, ..., -2.8446, -1.4388, -4.4836],\n",
" [-0.8714, -7.8835, -1.6969, ..., -2.1200, -2.1704, -7.2160],\n",
" ...,\n",
" [-3.2255, -7.0802, -1.8176, ..., -2.8620, -2.7388, -5.1880],\n",
" [-2.2788, -5.5723, -1.6649, ..., -3.3594, -2.4676, -5.1028],\n",
" [-2.9788, -7.2411, -1.0434, ..., -3.2540, -2.9263, -5.0116]],\n",
"\n",
" ...,\n",
"\n",
" [[ 0.2148, 1.7719, 0.5129, ..., 0.2612, 0.4477, -1.6895],\n",
" [-0.2874, -5.8026, 1.1293, ..., -2.2826, -1.7007, 5.5452],\n",
" [-2.4104, -6.5778, 1.1952, ..., -2.4193, -0.3969, 3.8159],\n",
" ...,\n",
" [-1.4026, -7.7514, 1.2659, ..., -3.4256, -2.3786, 6.9488],\n",
" [-1.0623, -5.7453, 0.1012, ..., -0.5622, -2.4292, 6.8565],\n",
" [-0.3079, -7.9204, 1.8029, ..., -3.2453, -2.3462, 7.0537]],\n",
"\n",
" [[ 0.0559, -0.0269, 0.1386, ..., -0.1165, -0.0882, -0.1612],\n",
" [ 0.1342, -0.5329, -0.2255, ..., -1.0159, 0.1003, -0.4600],\n",
" [-0.7412, -0.2755, 0.1787, ..., -0.8159, -0.9071, -0.1041],\n",
" ...,\n",
" [-0.0215, -0.5192, -0.2004, ..., 0.3272, -0.3216, 0.5758],\n",
" [ 0.2406, -0.3252, 0.3839, ..., -0.2115, 0.3593, -0.6457],\n",
" [-0.6898, -1.1861, -0.0238, ..., 0.5217, 0.0940, 0.9089]],\n",
"\n",
" [[ 0.3939, -0.0741, 1.9091, ..., -0.2314, -0.2112, -0.9825],\n",
" [ 2.5678, 1.8706, -2.0184, ..., 0.0582, 0.5182, 2.5282],\n",
" [ 3.1803, 2.0001, -2.9358, ..., 2.6552, 1.0590, 4.2195],\n",
" ...,\n",
" [ 2.6593, 1.2215, -2.5623, ..., 1.4338, 0.6112, 3.2894],\n",
" [ 1.1448, 0.9766, -2.1789, ..., 1.8788, 0.3242, 3.7226],\n",
" [ 2.8828, 1.7918, -3.5229, ..., -0.0936, 0.5881, 4.5368]]]],\n",
" grad_fn=<PermuteBackward0>), tensor([[[[ 4.5583e-02, 6.3733e-02, -3.4908e-03, ..., 3.8117e-03,\n",
" 1.0385e-01, 2.3468e-02],\n",
" [ 2.7193e-01, -2.7436e-01, 8.2051e-01, ..., -5.7602e-01,\n",
" -6.8246e-02, -1.7190e-02],\n",
" [-3.3605e-01, -9.2270e-01, -2.9339e-01, ..., 3.2747e-02,\n",
" 5.3266e-01, -1.1793e+00],\n",
" ...,\n",
" [ 4.6659e-01, 1.6959e-01, -6.8990e-02, ..., 3.2092e-01,\n",
" 2.9894e-02, 4.5212e-03],\n",
" [-3.8838e-01, -2.8303e-01, 6.4867e-01, ..., 5.4443e-01,\n",
" -3.8750e-03, -7.7317e-01],\n",
" [-1.4669e-01, -2.2234e-01, 5.0309e-01, ..., -2.0195e-01,\n",
" -3.4870e-02, 1.0260e+00]],\n",
"\n",
" [[-3.8964e-02, -8.6139e-03, 9.1636e-02, ..., -4.5061e-02,\n",
" -1.8257e-02, -4.4496e-02],\n",
" [-1.2398e-01, -4.6354e-01, 6.3162e-02, ..., 4.1472e-01,\n",
" -8.8383e-02, -6.1835e-02],\n",
" [ 2.3124e-01, -4.1944e-01, -5.5628e-02, ..., -6.5586e-01,\n",
" -2.9434e-01, 1.1322e-01],\n",
" ...,\n",
" [ 9.0615e-02, -2.5366e-01, -1.7453e-01, ..., 3.6981e-02,\n",
" 9.6252e-02, 2.8861e-01],\n",
" [ 2.6449e-01, -1.1997e+00, -2.9121e-01, ..., 1.8929e-01,\n",
" 8.9705e-01, 5.2265e-02],\n",
" [ 1.8653e-01, -4.1886e-01, -2.5386e-01, ..., 5.6907e-01,\n",
" -5.6461e-01, -2.9499e-01]],\n",
"\n",
" [[ 4.5761e-02, -1.1113e-01, -6.0327e-02, ..., -1.7311e-02,\n",
" 8.8352e-02, -1.4918e-01],\n",
" [ 3.5832e-01, 1.0048e-01, -3.5981e-01, ..., 4.7004e-01,\n",
" -1.0480e-01, -9.6169e-01],\n",
" [-1.2025e+00, -4.9562e-01, -5.6530e-01, ..., -7.7073e-02,\n",
" -1.8603e-01, 4.5677e-02],\n",
" ...,\n",
" [-1.1527e-01, -1.2046e-02, 7.9755e-01, ..., 2.0678e-01,\n",
" -1.6562e-01, -9.4135e-02],\n",
" [ 3.0203e-01, -5.3025e-02, 1.0025e-01, ..., -1.3117e-01,\n",
" -3.9940e-01, 2.0309e-01],\n",
" [ 5.4948e-01, -3.1714e-03, -9.9666e-01, ..., 3.6800e-01,\n",
" 2.6345e-01, -6.6638e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[-2.2513e-02, 1.1954e-01, -1.7875e-02, ..., -1.4198e-02,\n",
" 6.4433e-02, -5.2401e-02],\n",
" [-6.7643e-03, 1.3038e-01, -3.1770e-02, ..., -2.8075e-02,\n",
" -7.0123e-02, 2.9359e-01],\n",
" [ 7.8513e-01, -7.9053e-01, -1.5511e-01, ..., -3.0193e-01,\n",
" -5.3295e-02, 5.1889e-01],\n",
" ...,\n",
" [-1.9707e-01, 5.0177e-02, -1.1185e-01, ..., -3.0111e-01,\n",
" 2.1017e-01, -2.7775e-01],\n",
" [-3.1374e-01, -3.5912e-02, -2.5133e-01, ..., -1.2073e-01,\n",
" 1.3938e-01, -1.4568e-01],\n",
" [-1.2432e-01, 3.0442e-01, 1.0542e-01, ..., 2.1967e-02,\n",
" 3.2316e-02, 1.2676e-01]],\n",
"\n",
" [[-1.7366e-01, -1.3407e-01, -6.7815e-02, ..., -2.3521e-01,\n",
" -1.8675e-02, -5.1927e-02],\n",
" [ 4.8318e-01, -4.9988e-01, 7.3483e-01, ..., 1.7037e-01,\n",
" 6.2192e-01, 2.3596e-01],\n",
" [ 1.1730e-01, 3.0694e-02, 7.3273e-01, ..., 5.0575e-01,\n",
" 3.1356e-02, -5.0081e-01],\n",
" ...,\n",
" [ 6.1899e-01, -9.2282e-01, 1.6701e-01, ..., -2.4323e-02,\n",
" 1.7694e-01, -3.4102e-01],\n",
" [ 8.4867e-01, 1.2311e-01, 3.3463e-01, ..., 3.2204e-01,\n",
" 8.6678e-01, 5.9980e-01],\n",
" [ 4.1040e-01, 1.7545e-01, 2.0518e-01, ..., -9.3810e-01,\n",
" 4.8850e-01, -5.4087e-01]],\n",
"\n",
" [[ 1.1481e-01, -7.4767e-02, -2.5446e-02, ..., -1.8679e-02,\n",
" -9.1254e-02, -9.6947e-02],\n",
" [ 5.5079e-01, 1.9193e-01, 1.8251e-04, ..., -1.0992e-02,\n",
" -2.6968e-01, -3.8421e-02],\n",
" [-1.8607e-01, -8.5692e-02, 3.1742e-01, ..., -3.9823e-01,\n",
" 4.3919e-01, -8.0165e-02],\n",
" ...,\n",
" [-1.6626e-01, -1.0646e+00, -1.0149e-02, ..., -9.7871e-02,\n",
" 1.4443e-01, -1.5419e-01],\n",
" [-4.4313e-01, -1.3310e-01, 4.2125e-01, ..., 4.0301e-02,\n",
" -1.7659e-01, 3.1838e-01],\n",
" [ 8.1519e-01, 2.4844e-01, 1.2036e-01, ..., -9.9506e-02,\n",
" -2.9214e-01, 5.8580e-02]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[-8.8678e-01, -1.3593e-01, 3.3093e-01, ..., -9.5576e-01,\n",
" 2.5192e-02, -2.9464e+00],\n",
" [ 1.5584e+00, -5.3821e-01, -2.4421e+00, ..., -1.7774e+00,\n",
" -1.4069e+00, 7.2554e+00],\n",
" [-6.1613e-02, -9.6410e-01, -3.3367e+00, ..., -1.7228e+00,\n",
" -7.6467e+00, 6.5063e+00],\n",
" ...,\n",
" [ 1.2191e+00, -3.6478e-01, -1.8077e+00, ..., -1.4126e+00,\n",
" -3.4429e+00, 1.1099e+01],\n",
" [ 1.2810e+00, -4.1117e-01, -4.4152e+00, ..., -1.0298e+00,\n",
" -2.3506e+00, 1.1191e+01],\n",
" [ 1.5495e+00, -1.9605e+00, -3.1807e+00, ..., -9.8794e-01,\n",
" -2.1888e+00, 9.4760e+00]],\n",
"\n",
" [[ 3.7499e-01, -6.6046e-02, 4.5773e-01, ..., -1.2836e-01,\n",
" -7.7381e-02, -2.2161e+00],\n",
" [-1.9084e+00, -5.1770e-01, 3.3306e+00, ..., -1.0169e-01,\n",
" -2.0618e+00, 7.5854e+00],\n",
" [-3.1865e+00, -5.3798e-01, 3.4467e+00, ..., 8.8427e-02,\n",
" -4.1777e+00, 7.7792e+00],\n",
" ...,\n",
" [-2.9382e+00, -8.8965e-01, 3.4723e+00, ..., -1.4002e+00,\n",
" -5.7932e-01, 6.9011e+00],\n",
" [-3.7302e+00, -1.4835e+00, 7.7318e-01, ..., -1.4177e+00,\n",
" -1.5522e+00, 7.3279e+00],\n",
" [-2.4526e+00, -1.8321e+00, 3.6389e+00, ..., -4.4448e-01,\n",
" -1.6136e+00, 6.6650e+00]],\n",
"\n",
" [[ 1.2211e-01, -6.5015e-01, -2.2831e-01, ..., 1.4110e-01,\n",
" 2.7893e-01, -1.7424e-01],\n",
" [ 1.7771e-01, 1.7629e+00, 6.3257e-01, ..., -2.6582e-01,\n",
" 6.2577e-01, 5.0930e-02],\n",
" [ 2.2530e-01, 3.0012e+00, 5.3516e-01, ..., -3.2276e-01,\n",
" 5.9087e-01, -3.6453e-02],\n",
" ...,\n",
" [-6.4210e-01, 3.1597e+00, 2.3032e-01, ..., 6.4203e-01,\n",
" 1.9326e-01, 5.4560e-01],\n",
" [-4.8734e-01, 2.4240e+00, 1.1159e-01, ..., 9.6528e-01,\n",
" 1.2245e+00, -1.7901e+00],\n",
" [ 2.7319e-01, 2.8160e+00, 6.3444e-01, ..., -5.1675e-01,\n",
" -1.5301e-01, -8.1118e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[-4.0181e-01, 1.2737e-02, -1.1140e-02, ..., 1.2548e+00,\n",
" 4.3199e-02, 1.8033e+00],\n",
" [-4.6139e-01, -1.3921e+00, -1.4511e+00, ..., -2.5093e+00,\n",
" -1.6920e+00, -2.7131e-01],\n",
" [-1.3954e-01, 3.9872e-01, -5.5181e-01, ..., -4.0252e+00,\n",
" -1.2034e+00, -8.0604e-01],\n",
" ...,\n",
" [ 3.8913e-01, -9.2129e-01, 6.7512e-01, ..., -3.2734e+00,\n",
" -3.7855e-01, -1.2775e+00],\n",
" [ 3.6478e-01, 1.1098e+00, 1.9589e+00, ..., -1.2581e+00,\n",
" -9.2984e-01, -1.5476e+00],\n",
" [-2.0390e-01, -6.6112e-01, -9.6914e-01, ..., -3.2531e+00,\n",
" -3.5533e-01, -3.5020e-01]],\n",
"\n",
" [[-3.3790e-01, -1.2825e-01, 2.2242e-01, ..., 2.6358e-01,\n",
" -2.9314e-02, 3.1528e-02],\n",
" [-6.0304e-01, -1.1295e+00, 1.4573e+00, ..., 7.0224e-01,\n",
" -8.5480e-01, 1.8017e-01],\n",
" [ 9.3104e-01, -2.1456e+00, 3.8324e-01, ..., 9.3967e-01,\n",
" -8.2110e-01, 1.3123e-01],\n",
" ...,\n",
" [-7.5492e-01, -1.8400e-01, 3.3456e-01, ..., 1.7404e+00,\n",
" 7.1590e-01, 1.3268e+00],\n",
" [-1.5429e-01, 5.3506e-01, 2.4561e+00, ..., 1.2834e+00,\n",
" 5.7729e-01, 1.3149e+00],\n",
" [-7.7036e-01, -6.9287e-01, 1.1238e+00, ..., 1.0106e+00,\n",
" -5.3742e-01, 1.3852e+00]],\n",
"\n",
" [[ 3.4402e+00, 2.1226e+00, -2.1050e+00, ..., -2.8555e+00,\n",
" -3.9038e+00, -1.2060e+00],\n",
" [-3.0643e+00, -1.6132e+00, 4.7811e+00, ..., -2.6905e+00,\n",
" 9.4376e+00, -3.7636e+00],\n",
" [-2.8029e+00, 9.2815e-01, 2.2908e+00, ..., -3.5372e+00,\n",
" 9.2503e+00, 2.0644e+00],\n",
" ...,\n",
" [-4.9774e+00, -1.8169e+00, 4.4703e+00, ..., -4.3005e+00,\n",
" 1.5492e+01, 3.7749e+00],\n",
" [-2.4577e+00, -1.8796e+00, 6.0842e+00, ..., -4.6722e+00,\n",
" 9.1210e+00, 1.8122e+00],\n",
" [-4.4732e+00, -2.0733e+00, 7.2062e+00, ..., -3.7151e+00,\n",
" 1.2814e+01, -2.2193e+00]]]], grad_fn=<PermuteBackward0>), tensor([[[[-0.0028, -0.0602, 0.0219, ..., 0.0593, 0.0264, 0.0681],\n",
" [ 0.6692, -0.1774, 0.2994, ..., 0.1940, -0.3524, -0.1093],\n",
" [ 0.0441, -0.6776, -0.4458, ..., 0.2746, 0.9155, -0.5374],\n",
" ...,\n",
" [-0.3262, -0.0103, -0.0866, ..., 0.0454, -0.1561, 0.2205],\n",
" [-0.0552, -0.6212, -0.4492, ..., -0.2533, 0.0952, -0.2438],\n",
" [ 0.1740, 0.0146, -0.0917, ..., 0.1930, -0.1700, 0.1307]],\n",
"\n",
" [[-0.0538, -0.0195, -0.1417, ..., -0.0445, 0.0476, -0.0319],\n",
" [ 0.3175, -0.1990, -0.2276, ..., 0.1004, -0.0740, -0.1226],\n",
" [ 0.3296, -0.6555, -0.2850, ..., -0.8669, 0.2712, 0.0552],\n",
" ...,\n",
" [-0.0141, 0.1838, 0.2267, ..., 0.0249, -0.0362, 0.3883],\n",
" [-0.2939, -0.5590, 0.3243, ..., -0.0678, 0.0157, -0.5514],\n",
" [-0.0048, -0.0914, -0.2181, ..., -0.2868, 0.0018, -0.0651]],\n",
"\n",
" [[ 0.0639, 0.0961, 0.0831, ..., 0.0160, -0.0859, -0.0050],\n",
" [-0.8685, -0.1267, -0.8107, ..., 0.0526, -0.7176, -0.0689],\n",
" [ 0.1621, 0.2253, 0.0752, ..., 0.1041, -0.4005, 0.1818],\n",
" ...,\n",
" [-0.4981, 0.5339, -0.4980, ..., -0.2581, -0.8093, -0.3876],\n",
" [-0.6054, 1.6497, 1.0752, ..., -1.0363, 0.7149, -0.6451],\n",
" [-0.4706, 0.3250, 0.3061, ..., 0.4489, -0.6589, 0.0312]],\n",
"\n",
" ...,\n",
"\n",
" [[-0.0115, 0.0657, -0.0777, ..., 0.0440, 0.0456, -0.1384],\n",
" [ 0.0136, -0.3035, 0.8164, ..., -0.2084, -0.8236, 0.4428],\n",
" [ 0.2521, -0.4054, 0.2197, ..., 0.1480, 0.2216, 0.5164],\n",
" ...,\n",
" [-0.0778, -0.1247, -0.3227, ..., 0.1474, 0.1483, 0.3701],\n",
" [-0.3559, -0.8621, -0.0799, ..., -0.9994, 0.4109, 0.2198],\n",
" [-0.1967, 0.0573, 0.6049, ..., 0.1913, 0.0767, -0.0245]],\n",
"\n",
" [[-0.1315, -0.0534, 0.0947, ..., -0.0666, 0.0539, -0.0204],\n",
" [ 0.0918, -0.3386, -0.7173, ..., -0.2867, -0.0289, -0.1466],\n",
" [ 0.2971, 0.6579, -0.9279, ..., -0.0267, -1.3269, 0.6167],\n",
" ...,\n",
" [-0.1993, 0.8396, 0.5954, ..., -0.2100, 0.3891, 0.5287],\n",
" [ 1.5998, 0.6881, -0.2637, ..., 1.1610, 0.1208, -0.6552],\n",
" [ 0.5209, -0.3917, 0.1674, ..., -0.2824, 0.0700, -0.3138]],\n",
"\n",
" [[-0.0193, -0.0120, -0.0240, ..., -0.0300, 0.0080, -0.0136],\n",
" [-0.0432, -0.3667, -0.3346, ..., -0.1011, 0.0167, 0.1537],\n",
" [-0.3303, -0.6508, -0.2167, ..., -0.6360, -0.1999, -0.1340],\n",
" ...,\n",
" [-0.0058, -0.1530, -0.3235, ..., -0.3699, 0.0510, 0.1209],\n",
" [ 0.1009, 0.4467, -0.0791, ..., -0.2715, -0.2259, 0.5418],\n",
" [ 0.0141, 0.2831, -0.4868, ..., -0.1903, -0.1869, 0.7274]]]],\n",
" grad_fn=<PermuteBackward0>)), (tensor([[[[ 3.4512e-02, -2.8466e-01, 2.2210e-01, ..., 1.6982e+00,\n",
" -2.2029e-01, -8.0207e-02],\n",
" [ 1.1887e+00, 1.2521e+00, -2.0973e-02, ..., -2.7505e+00,\n",
" 1.1517e-01, -9.4738e-01],\n",
" [-3.7887e-01, 2.3147e-01, 7.8851e-01, ..., -3.8859e+00,\n",
" -1.2610e+00, -1.5381e+00],\n",
" ...,\n",
" [-3.8848e-01, -7.8692e-01, 6.0321e-01, ..., -1.5790e+00,\n",
" -4.4260e-01, -1.7360e+00],\n",
" [-1.3567e+00, -1.2212e-02, 4.0693e-01, ..., -2.6267e+00,\n",
" 3.1883e-01, -1.1768e+00],\n",
" [-1.3149e+00, 5.3910e-01, 8.4051e-01, ..., -2.6472e+00,\n",
" -8.0766e-02, -1.3063e+00]],\n",
"\n",
" [[ 1.5566e-01, 9.6884e-01, -1.4234e+00, ..., -1.1945e-01,\n",
" 2.6095e-01, 9.2861e-01],\n",
" [-1.1655e+00, -5.3317e+00, 7.2065e-01, ..., -1.4863e+00,\n",
" -2.2354e+00, -2.4988e+00],\n",
" [ 1.2192e+00, -4.3649e+00, 9.3857e-01, ..., 3.6005e-01,\n",
" -1.0827e+00, -2.1299e+00],\n",
" ...,\n",
" [-6.5003e-01, -3.6931e+00, 4.9255e-01, ..., -2.0790e+00,\n",
" -3.1514e-01, -2.7136e+00],\n",
" [ 5.9668e-01, -3.1527e+00, 7.6608e-01, ..., -4.4680e-01,\n",
" -1.1040e-01, -1.9393e+00],\n",
" [ 2.0418e+00, -5.3709e+00, 5.4901e+00, ..., -2.3439e-02,\n",
" 4.6572e-01, -3.8706e+00]],\n",
"\n",
" [[-6.7068e-01, 2.4994e-01, -5.6570e-02, ..., 1.7880e-01,\n",
" 5.6148e-02, -2.9901e-01],\n",
" [ 1.9676e+00, 2.9566e-02, -8.5660e-01, ..., -1.8619e+00,\n",
" -3.3802e-01, 1.6140e-01],\n",
" [ 2.1615e+00, -7.5559e-01, 3.4024e-01, ..., -1.4898e+00,\n",
" 4.2649e-01, 1.5977e+00],\n",
" ...,\n",
" [ 1.1094e+00, -8.7126e-01, 4.4787e-02, ..., -4.0946e-01,\n",
" -6.8646e-01, -5.1147e-01],\n",
" [ 1.3666e+00, -6.3472e-01, -6.9747e-01, ..., 6.0671e-01,\n",
" 2.1492e+00, -3.3250e-01],\n",
" [ 2.1474e+00, -4.8501e-02, -8.7131e-01, ..., -1.4417e+00,\n",
" 1.5616e+00, 1.8827e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[-4.0635e-02, 1.1188e-01, 1.4037e-01, ..., -9.7647e-02,\n",
" 1.2961e-02, 1.5000e-01],\n",
" [ 7.8730e-01, -5.6138e-01, -1.2585e+00, ..., 1.1703e+00,\n",
" -1.7229e-01, 1.2928e+00],\n",
" [-1.0394e-01, 1.4770e-01, 3.8454e-01, ..., 6.5685e-01,\n",
" -2.6355e-01, 1.3102e+00],\n",
" ...,\n",
" [ 1.1297e-01, 1.4229e+00, 2.8362e-01, ..., 9.3448e-01,\n",
" 2.5909e-01, 2.9945e-01],\n",
" [ 6.5882e-01, 7.3874e-01, -5.1318e-01, ..., 9.5171e-01,\n",
" 1.6892e-01, -1.7952e-01],\n",
" [ 4.4172e-01, 9.7651e-02, -1.4498e+00, ..., 1.2877e+00,\n",
" 7.8737e-01, 5.7300e-02]],\n",
"\n",
" [[-3.0020e+00, 4.0418e-01, -2.7798e-02, ..., -4.8566e-01,\n",
" -3.4500e-01, 1.2311e+00],\n",
" [ 4.8764e+00, 1.4500e+00, -1.1937e+00, ..., -1.6858e+00,\n",
" 3.0943e-01, -9.1063e-01],\n",
" [ 4.6146e+00, 9.6566e-01, -5.1178e-01, ..., -2.1980e-01,\n",
" 1.1130e+00, -1.2746e+00],\n",
" ...,\n",
" [ 4.9677e+00, 2.5583e-02, -1.3527e+00, ..., -1.8770e+00,\n",
" -6.6969e-01, -4.0065e-01],\n",
" [ 4.3137e+00, 1.0467e+00, -1.5161e+00, ..., -2.2238e+00,\n",
" -1.7302e-01, -1.6034e-01],\n",
" [ 4.6436e+00, 9.9926e-01, -5.2100e-01, ..., -1.5177e+00,\n",
" 1.9258e-01, -2.4487e-01]],\n",
"\n",
" [[-7.4442e-03, -2.5452e-01, -1.9922e-04, ..., -1.8494e-01,\n",
" 3.4208e-01, 9.0523e-02],\n",
" [ 8.8014e-01, -3.2005e+00, -2.3284e-01, ..., -5.6783e-01,\n",
" 5.3092e-01, 4.5332e-02],\n",
" [-3.2605e-01, -1.7599e+00, -5.3681e-01, ..., -5.2140e-01,\n",
" 1.7060e+00, -8.0691e-01],\n",
" ...,\n",
" [-1.1833e+00, -8.9443e-01, 5.9676e-01, ..., 3.0636e-01,\n",
" 5.0886e-01, -1.5048e+00],\n",
" [-1.2903e+00, -9.5492e-01, 2.1957e-01, ..., 2.2938e+00,\n",
" -5.0270e-01, -7.8764e-02],\n",
" [ 4.4758e-01, -1.5906e+00, 1.4957e-01, ..., 2.3779e+00,\n",
" -2.2358e-01, 4.7562e-01]]]], grad_fn=<PermuteBackward0>), tensor([[[[-0.0224, -0.0239, 0.0031, ..., -0.0027, -0.0209, 0.3535],\n",
" [ 1.1117, -0.0250, 1.0920, ..., -0.1345, -0.2322, -0.7385],\n",
" [ 1.2585, -0.5406, -0.8740, ..., 0.6211, 0.4854, -0.4785],\n",
" ...,\n",
" [ 0.8751, -0.3160, -0.5735, ..., -0.2102, -0.0831, -0.5934],\n",
" [ 0.0085, -0.3084, 0.1655, ..., 0.4398, 0.5114, -0.4383],\n",
" [-0.0725, -0.3939, 0.5899, ..., 0.7469, -0.3640, 0.0679]],\n",
"\n",
" [[ 0.0048, -0.0161, 0.0186, ..., -0.0150, 0.0150, 0.0090],\n",
" [ 0.4621, -0.6415, -0.2005, ..., 0.2446, 1.2697, -0.7838],\n",
" [ 0.6805, -1.2565, 0.0765, ..., -0.0242, 1.4869, 0.1836],\n",
" ...,\n",
" [-0.2538, -0.0022, 0.1847, ..., 0.4838, 1.5106, 0.7886],\n",
" [ 1.2671, -0.9662, -0.3248, ..., 0.5432, -0.0319, -0.1366],\n",
" [-0.1197, -1.6058, -0.3833, ..., 0.3964, 1.0133, -0.1477]],\n",
"\n",
" [[-0.0603, 0.0030, -0.0383, ..., -0.0468, 0.0119, -0.0780],\n",
" [ 0.5506, -0.3951, 0.6694, ..., -0.6748, 0.3026, 0.0286],\n",
" [ 0.4687, 0.1415, 0.0033, ..., 0.4084, 0.2910, 0.4103],\n",
" ...,\n",
" [ 0.4985, 0.4334, 0.3964, ..., -0.2184, -0.0373, -0.0717],\n",
" [ 0.0850, -0.4120, -0.2606, ..., -0.2593, 0.7614, -0.8139],\n",
" [ 0.0813, -0.2308, 0.9975, ..., -0.3412, -1.0508, -0.9304]],\n",
"\n",
" ...,\n",
"\n",
" [[-0.3063, -0.1904, -0.0540, ..., -0.4848, 0.2131, 0.1049],\n",
" [ 1.7674, -1.5249, 1.8613, ..., 1.4412, -0.3079, 0.1500],\n",
" [ 0.3092, -1.5787, 0.6095, ..., 0.5455, 0.1634, 1.3060],\n",
" ...,\n",
" [-1.9410, -1.8215, -0.4399, ..., 0.3221, -0.1979, -1.4136],\n",
" [ 1.2677, -1.9424, 0.0700, ..., -0.9788, -0.6381, -0.4399],\n",
" [ 0.8285, -1.8581, -0.3010, ..., -1.3209, 0.2318, -0.1750]],\n",
"\n",
" [[-0.0861, -0.1412, -0.0534, ..., -0.1797, -0.1466, 0.1142],\n",
" [-0.7113, -0.5252, -0.7349, ..., -0.0491, 0.5213, -0.7352],\n",
" [ 0.4967, -1.1247, -0.6529, ..., -0.4258, -0.1081, -0.2017],\n",
" ...,\n",
" [-0.4174, -1.3939, 0.0162, ..., -0.2306, -0.4274, 0.3158],\n",
" [-1.1609, -0.1209, -0.1991, ..., 1.2310, 0.5859, 0.6733],\n",
" [-1.1036, -0.5834, 0.1167, ..., 0.8276, -0.1767, 0.3441]],\n",
"\n",
" [[-0.0294, -0.0414, 0.1069, ..., 0.0614, -0.0412, 0.0239],\n",
" [-0.6127, -0.0583, -0.7644, ..., -1.4024, -0.9271, 0.9733],\n",
" [ 0.5288, 0.2919, 0.0434, ..., -0.4878, -0.6339, 0.4392],\n",
" ...,\n",
" [ 0.0769, -0.0123, -1.2272, ..., 0.3366, -0.2014, 0.2725],\n",
" [ 0.0642, 1.9300, -0.3253, ..., -1.0578, 0.4355, -1.4476],\n",
" [-0.5956, 0.2606, -0.5507, ..., -0.5284, -0.1602, -0.7526]]]],\n",
" grad_fn=<PermuteBackward0>)), (tensor([[[[-3.3938e-01, 8.6866e-01, -1.6351e-01, ..., 1.1267e+00,\n",
" -1.6784e-01, 1.2959e-01],\n",
" [-3.9849e-01, -4.8487e+00, -2.9290e-01, ..., -4.2862e+00,\n",
" 6.9337e-01, 6.4498e-01],\n",
" [ 7.8937e-01, -3.8062e+00, -7.5906e-01, ..., -4.4411e+00,\n",
" 9.3744e-01, 2.4774e+00],\n",
" ...,\n",
" [-2.9636e-01, -5.8016e+00, 1.4007e+00, ..., -3.1025e+00,\n",
" -2.7375e-01, 1.2819e+00],\n",
" [ 1.3025e+00, -3.8148e+00, 1.8926e+00, ..., -3.3508e+00,\n",
" 6.2647e-01, 5.1378e-01],\n",
" [-3.7934e-01, -4.5341e+00, 6.0715e-01, ..., -4.3161e+00,\n",
" 6.4808e-01, 7.9708e-01]],\n",
"\n",
" [[ 5.8217e-02, 8.7121e-01, -6.2251e-01, ..., -2.4310e-02,\n",
" 2.9330e-01, 1.3199e-02],\n",
" [ 1.3357e+00, -1.1482e+00, 1.2032e-01, ..., 1.5088e+00,\n",
" -1.0720e+00, -1.1527e+00],\n",
" [ 2.6168e+00, 9.3244e-03, -7.2926e-01, ..., 9.4531e-01,\n",
" -7.9178e-01, -1.6888e+00],\n",
" ...,\n",
" [ 8.0452e-01, -3.9176e-01, -3.0347e-01, ..., 1.3463e+00,\n",
" -3.1319e-01, -1.3556e-01],\n",
" [-7.1086e-01, 9.6997e-02, 1.2591e+00, ..., 2.0719e-01,\n",
" 4.2983e-01, -6.3391e-01],\n",
" [ 1.1039e-01, -1.3052e+00, 1.1124e-01, ..., 1.3074e+00,\n",
" 1.4712e+00, -2.7487e-01]],\n",
"\n",
" [[-3.1165e-01, 1.2165e-01, -9.8370e-01, ..., -3.5095e-01,\n",
" -6.3912e-02, -1.3616e-01],\n",
" [ 5.0049e-01, -6.6728e-01, 2.9285e+00, ..., -3.9263e-01,\n",
" 4.3198e-01, -2.3447e-01],\n",
" [ 1.2306e-01, -2.9766e-01, 3.6896e+00, ..., -1.0091e-01,\n",
" -2.5103e-01, -2.0315e-01],\n",
" ...,\n",
" [-2.1391e-01, -2.1547e+00, 2.8612e+00, ..., 5.8855e-01,\n",
" -1.9214e-01, 1.8883e+00],\n",
" [-2.1992e-01, -1.4360e+00, 3.3444e+00, ..., 9.8178e-01,\n",
" -1.9441e+00, 5.7364e-01],\n",
" [ 6.3090e-02, -1.4908e+00, 2.0854e+00, ..., 1.4157e-01,\n",
" -1.3972e-01, -6.9580e-02]],\n",
"\n",
" ...,\n",
"\n",
" [[ 3.7597e-01, 8.1398e-02, -6.4505e-02, ..., -4.8594e-02,\n",
" 2.2536e-01, 4.1931e-03],\n",
" [-1.2319e+00, 8.1079e-01, -5.4320e-01, ..., 1.2257e-01,\n",
" -7.8676e-02, -2.6823e-01],\n",
" [-1.9185e-02, 5.4915e-01, 9.4312e-01, ..., -2.6608e+00,\n",
" 3.8096e-01, -1.3816e+00],\n",
" ...,\n",
" [-1.5423e+00, -2.7545e-01, 2.9765e+00, ..., 5.4036e-01,\n",
" 1.6682e+00, -7.5562e-01],\n",
" [-1.2052e+00, -2.4065e-01, 4.7900e-02, ..., -1.5625e+00,\n",
" 2.8238e-01, -3.3910e-01],\n",
" [-1.7759e+00, 3.9760e-01, -1.0807e+00, ..., -1.9584e+00,\n",
" -1.1637e+00, 1.5918e+00]],\n",
"\n",
" [[ 2.0009e-01, 5.4941e-02, 3.2748e-01, ..., 4.1661e-01,\n",
" -3.4165e-03, 2.3171e-01],\n",
" [ 1.6163e+00, 1.2442e+00, 2.8373e-01, ..., -3.9689e-01,\n",
" 7.1320e-03, -1.1601e-01],\n",
" [ 1.3228e+00, 1.4674e-01, 6.3871e-01, ..., -5.9913e-02,\n",
" 1.6461e-01, 3.3509e-01],\n",
" ...,\n",
" [ 7.7162e-01, 7.9756e-01, 8.2908e-01, ..., -1.0911e+00,\n",
" 8.8888e-01, -1.1994e+00],\n",
" [ 1.6909e+00, 8.3524e-01, 6.7132e-01, ..., -1.1008e+00,\n",
" -7.2901e-01, 6.1303e-01],\n",
" [ 2.8334e+00, 4.6555e-01, 1.2473e+00, ..., -7.3844e-01,\n",
" -7.0963e-01, 1.0278e-01]],\n",
"\n",
" [[-3.0156e+00, 5.3756e-01, 5.6815e-01, ..., -9.3899e-01,\n",
" 3.2683e-01, 1.8463e-01],\n",
" [ 7.7879e+00, -9.7524e-01, -2.1850e+00, ..., 2.2429e+00,\n",
" -1.0887e+00, 5.6749e-01],\n",
" [ 6.9868e+00, -3.9651e-01, -9.7286e-01, ..., 1.0613e+00,\n",
" -7.0396e-01, 1.3823e+00],\n",
" ...,\n",
" [ 9.5361e+00, -8.7937e-01, -2.5252e+00, ..., 1.3820e+00,\n",
" -2.2409e+00, 2.4565e-01],\n",
" [ 8.8630e+00, -1.1387e+00, -1.7681e+00, ..., 1.0129e+00,\n",
" 2.0493e-01, -2.1170e-01],\n",
" [ 9.5194e+00, -2.0795e-01, -1.6476e+00, ..., 2.4340e+00,\n",
" -1.9197e+00, -2.9640e-01]]]], grad_fn=<PermuteBackward0>), tensor([[[[ 4.6474e-02, -5.0378e-02, 1.0945e-02, ..., -6.9955e-02,\n",
" 2.9789e-03, -1.0073e-01],\n",
" [ 5.1051e-01, -5.5772e-01, -3.8570e-01, ..., -3.2328e-01,\n",
" 2.3945e-01, -2.9826e-01],\n",
" [-2.8010e-01, 7.4962e-01, -5.4584e-01, ..., -3.6442e-01,\n",
" 4.2576e-01, -1.4805e+00],\n",
" ...,\n",
" [-9.0783e-01, -4.8128e-01, -1.8888e-01, ..., -2.2824e-01,\n",
" -7.4845e-02, -1.0972e+00],\n",
" [-5.0702e-01, 1.0603e-01, -1.0484e+00, ..., 5.5779e-01,\n",
" -4.9793e-01, -9.2837e-01],\n",
" [ 3.8714e-02, 4.2493e-01, -4.1890e-01, ..., 5.6050e-01,\n",
" -2.7279e-01, -1.3355e+00]],\n",
"\n",
" [[ 6.7674e-02, 3.0544e-02, -2.3115e-02, ..., -4.3823e-02,\n",
" 5.2575e-03, -1.6795e-03],\n",
" [-5.3669e-01, 1.7762e+00, -5.2043e-01, ..., 7.5157e-01,\n",
" -6.1868e-01, -7.3336e-01],\n",
" [-2.5054e-01, 4.9751e-03, -8.3214e-02, ..., -7.4598e-01,\n",
" -6.1617e-01, 3.3602e-01],\n",
" ...,\n",
" [-1.4067e-01, 5.9621e-01, 1.0898e+00, ..., 9.4066e-01,\n",
" -1.3745e+00, 1.1213e+00],\n",
" [-1.0630e+00, -5.0378e-01, 6.9651e-01, ..., -4.6445e-01,\n",
" -6.6259e-01, 1.7251e-01],\n",
" [-1.5972e+00, 3.2659e-01, 3.4644e-01, ..., 2.8986e-01,\n",
" -5.7299e-01, -2.2912e-01]],\n",
"\n",
" [[ 7.0952e-02, 8.2320e-03, -1.6572e-03, ..., 2.1678e-02,\n",
" -6.7437e-02, -5.0287e-02],\n",
" [ 7.4200e-01, -3.2418e-01, 4.1442e-01, ..., -1.4945e-02,\n",
" 2.5678e-01, 1.5392e-01],\n",
" [ 2.9304e-01, 5.7399e-01, -2.7184e-01, ..., -1.4044e-01,\n",
" 6.1588e-02, -1.5561e-01],\n",
" ...,\n",
" [ 7.1019e-01, -8.5043e-01, -3.1989e-01, ..., 2.5753e-01,\n",
" 2.2188e-01, 7.3108e-01],\n",
" [ 7.1561e-01, -8.6057e-01, 9.2320e-01, ..., 3.9957e-01,\n",
" 2.4226e+00, 1.6563e+00],\n",
" [-7.6132e-02, 2.4041e-01, 9.3365e-01, ..., -2.2613e-01,\n",
" 3.9552e-01, 1.0165e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[-8.4018e-04, 4.2945e-02, 2.0029e-02, ..., -6.6209e-02,\n",
" -1.8070e-02, 2.2869e-02],\n",
" [-1.4168e+00, 2.7825e-01, 3.5415e-02, ..., 2.2794e-01,\n",
" -1.8244e-01, 2.6631e-01],\n",
" [-1.5832e+00, 6.7589e-01, -1.3738e-01, ..., 7.5377e-01,\n",
" -8.9247e-01, 8.4118e-01],\n",
" ...,\n",
" [-1.0343e+00, 2.2096e-01, 1.8098e-01, ..., 1.5064e+00,\n",
" -9.4570e-01, -9.6457e-01],\n",
" [-5.5192e-01, 6.5732e-01, -7.3323e-01, ..., 8.2586e-01,\n",
" 1.0773e+00, -5.0690e-01],\n",
" [-6.9760e-01, -2.0758e-01, 2.9526e-01, ..., -1.6063e-02,\n",
" 1.6516e-02, 4.3263e-01]],\n",
"\n",
" [[ 5.6418e-02, -6.3642e-03, 2.3703e-02, ..., 1.7139e-02,\n",
" -1.5312e-02, 6.8112e-03],\n",
" [ 1.8381e+00, -1.3941e+00, -1.0189e+00, ..., -9.4177e-01,\n",
" 4.2883e-01, 8.2570e-01],\n",
" [ 8.8893e-01, -1.6692e+00, -4.3398e-01, ..., -1.2906e+00,\n",
" 1.0952e-01, 3.7169e-01],\n",
" ...,\n",
" [ 7.4024e-01, -1.4955e-01, -8.9148e-01, ..., -1.0267e+00,\n",
" -6.1569e-01, 5.8172e-01],\n",
" [-7.3008e-01, -4.7314e-01, 3.7697e-01, ..., 5.2418e-01,\n",
" -1.6633e-01, 3.0198e-01],\n",
" [ 6.6411e-02, -4.8074e-01, -4.0598e-01, ..., 1.1196e-01,\n",
" 1.0054e+00, -4.4949e-01]],\n",
"\n",
" [[ 6.7269e-02, -2.0375e-01, -7.5082e-02, ..., -4.0162e-02,\n",
" 1.9610e-01, -5.1942e-02],\n",
" [ 3.7243e-01, -9.5645e-01, -3.3796e-01, ..., -9.8523e-01,\n",
" -4.3307e-01, -2.3109e-01],\n",
" [-5.5909e-01, -9.8741e-01, -8.3997e-01, ..., -4.0350e-02,\n",
" 2.2590e-03, -1.1709e+00],\n",
" ...,\n",
" [ 2.6116e-01, -1.7003e+00, 9.9667e-03, ..., 2.5269e-01,\n",
" -6.5086e-01, -5.0987e-01],\n",
" [-2.2483e-01, -3.8567e-01, -1.6472e-01, ..., -7.8707e-01,\n",
" 3.2198e-01, -4.2609e-01],\n",
" [-1.5893e-01, -7.3543e-01, -4.9369e-01, ..., -1.5504e+00,\n",
" -3.8277e-01, -4.1377e-01]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[ 1.0315, -0.2591, -0.1501, ..., 0.6246, 0.7232, -0.3059],\n",
" [-3.2631, -1.9435, -1.2089, ..., 0.2806, -5.5510, 0.3489],\n",
" [-4.3399, -1.4013, -0.2186, ..., 0.7109, -5.7860, -0.6970],\n",
" ...,\n",
" [-5.3760, -3.7959, -0.3718, ..., 0.7804, -4.6301, -1.2402],\n",
" [-6.1584, -2.8255, 0.0772, ..., 0.3908, -4.4567, -0.1920],\n",
" [-4.6862, -2.2484, -0.4802, ..., 1.1911, -4.6985, -1.0555]],\n",
"\n",
" [[-0.1525, -0.0745, 0.1651, ..., -0.0464, -0.8761, -0.1921],\n",
" [-0.6493, 1.0436, -0.2845, ..., -0.2628, 0.0537, 0.8063],\n",
" [-1.7708, 1.3885, -0.9440, ..., 0.3637, 0.7435, 1.4247],\n",
" ...,\n",
" [-2.0241, 0.3328, -0.2828, ..., 0.8545, 0.5231, 2.4687],\n",
" [-2.8308, -0.2631, -0.4617, ..., -0.3337, 1.8320, 2.9475],\n",
" [-2.0453, 1.1846, -2.5580, ..., 0.5495, 1.1092, 1.8249]],\n",
"\n",
" [[ 0.1958, 0.3039, 1.1389, ..., -0.4691, 0.4513, -0.4878],\n",
" [-1.4815, -0.5524, -1.6846, ..., -0.5676, -1.8434, 2.4752],\n",
" [-3.5171, -1.7341, -1.0781, ..., -0.0126, -0.8584, 2.8363],\n",
" ...,\n",
" [-1.2945, -1.0943, -0.7373, ..., 0.2280, -2.9008, 2.5152],\n",
" [-2.2796, -0.5816, -0.3174, ..., 0.7422, -1.4116, 2.2355],\n",
" [-0.7958, 0.1943, -2.7152, ..., 1.7208, -1.5123, 0.9313]],\n",
"\n",
" ...,\n",
"\n",
" [[-0.1518, 0.0675, -0.2341, ..., 0.0125, 0.1685, 0.0227],\n",
" [-1.5855, 0.1968, -0.4700, ..., 0.9270, -1.3281, -0.1941],\n",
" [ 0.7663, -0.7921, -0.5326, ..., 0.9606, -0.0650, -0.2843],\n",
" ...,\n",
" [ 0.1914, -0.1551, -0.9815, ..., 1.8034, 0.1310, 0.7172],\n",
" [-1.2788, -1.7422, -0.4975, ..., 1.3406, -0.4531, -0.5256],\n",
" [-1.8526, 0.1496, -0.0816, ..., 0.8122, -1.0543, 0.1050]],\n",
"\n",
" [[-0.3515, -2.1836, 0.1103, ..., -0.0873, -0.0481, 0.9174],\n",
" [-0.3931, 1.7304, -1.0893, ..., -1.0898, -1.7984, 1.0287],\n",
" [ 0.3552, 3.3603, -1.5929, ..., -0.7109, -1.5203, 0.7090],\n",
" ...,\n",
" [ 0.4526, 3.6483, -3.1344, ..., 1.3756, -1.8511, 2.2068],\n",
" [ 1.4022, 2.2589, -2.0330, ..., 0.3515, -0.4796, 0.9019],\n",
" [ 0.7568, 2.8114, -2.1562, ..., 1.3476, -0.3658, 0.7552]],\n",
"\n",
" [[ 0.3682, 0.0657, -0.1320, ..., 0.6454, 0.1343, 0.2644],\n",
" [-0.8896, 0.3677, 0.1631, ..., -0.3916, -0.4439, -0.9719],\n",
" [ 0.4470, 0.5271, -0.4635, ..., -0.6886, -1.2558, 0.0390],\n",
" ...,\n",
" [-1.7867, -2.2049, 2.1719, ..., -0.8210, -0.2084, 1.6132],\n",
" [-1.4884, -1.5097, 0.1562, ..., 0.5166, 0.2819, -0.1415],\n",
" [-2.1183, 1.1049, 1.0999, ..., -0.3114, 0.2994, 0.8749]]]],\n",
" grad_fn=<PermuteBackward0>), tensor([[[[-2.8336e-02, 4.5010e-02, -5.7978e-02, ..., -1.9222e-02,\n",
" 1.2577e-02, 3.6269e-02],\n",
" [-4.6357e-01, 2.2473e-01, -5.6911e-01, ..., -1.1179e-01,\n",
" 5.4963e-01, 1.4621e-01],\n",
" [ 8.8791e-01, 1.3920e-01, -1.1074e+00, ..., -3.0970e-01,\n",
" 8.6369e-01, 2.8616e-01],\n",
" ...,\n",
" [ 9.8967e-01, 7.6810e-02, 3.6725e-01, ..., 1.0289e-01,\n",
" -7.9780e-01, -1.0472e-01],\n",
" [ 4.1472e-01, -3.0706e-01, -1.0118e-01, ..., -2.9164e-02,\n",
" 9.2894e-02, 2.6503e-01],\n",
" [-4.1391e-01, -4.3953e-01, 9.5461e-02, ..., -1.8622e-02,\n",
" 1.2946e-01, -4.0387e-01]],\n",
"\n",
" [[ 2.1536e-02, -2.8120e-02, 3.8532e-02, ..., 2.1765e-02,\n",
" -4.7212e-02, 5.3255e-03],\n",
" [-1.6248e-01, -4.5659e-01, -4.4525e-02, ..., 5.6903e-01,\n",
" -3.0144e-01, -1.2120e+00],\n",
" [-1.6019e-01, -3.1593e-01, 1.0682e+00, ..., -1.1746e-01,\n",
" -4.8418e-01, 4.2423e-01],\n",
" ...,\n",
" [-7.0670e-01, 1.4226e-01, -2.0767e-01, ..., -5.3785e-01,\n",
" -3.7916e-01, 2.9476e-01],\n",
" [ 3.5204e-01, 1.6746e-01, -1.8197e+00, ..., 1.8833e-01,\n",
" 2.5200e-01, 1.3326e+00],\n",
" [ 1.0614e-01, -5.6477e-01, -1.3717e+00, ..., 2.8329e-01,\n",
" -2.3432e-01, 5.8129e-01]],\n",
"\n",
" [[ 3.9084e-02, -2.6990e-02, 5.6189e-02, ..., 2.6549e-02,\n",
" -7.1806e-03, 1.9065e-02],\n",
" [ 8.1593e-01, 3.5473e-01, -1.9476e-01, ..., 7.1779e-01,\n",
" 1.7158e-01, 1.7037e-01],\n",
" [-3.0468e-01, 6.4740e-01, -1.1535e+00, ..., 2.5107e+00,\n",
" -1.3214e+00, 6.0931e-01],\n",
" ...,\n",
" [-3.8012e-01, -1.0693e+00, -4.3163e-01, ..., -1.2006e-01,\n",
" -4.7626e-01, -5.9241e-01],\n",
" [-6.6220e-01, 1.0321e+00, 6.1114e-01, ..., -1.0294e+00,\n",
" -5.9746e-02, -1.4874e+00],\n",
" [ 1.5239e+00, 1.7266e-01, -2.6497e-01, ..., -6.9278e-01,\n",
" 2.7154e-01, 1.1508e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[-1.8650e-01, 8.9365e-02, 5.7435e-02, ..., 4.6573e-02,\n",
" 3.7369e-02, -1.2676e-01],\n",
" [-6.2920e-01, -4.5253e-02, 1.5379e-01, ..., -8.5838e-01,\n",
" 2.2210e-01, -4.9222e-01],\n",
" [-2.1227e-01, 6.7216e-01, 5.8456e-01, ..., -4.8421e-02,\n",
" -4.2428e-01, -4.8305e-01],\n",
" ...,\n",
" [-9.4783e-01, -4.8206e-02, -1.2836e-01, ..., 1.8181e-01,\n",
" -4.6491e-01, -8.4671e-01],\n",
" [-7.2088e-01, 4.8839e-01, -1.6034e+00, ..., -3.5454e-01,\n",
" 8.5080e-02, -1.4271e+00],\n",
" [-1.0528e+00, 8.3454e-01, -9.8252e-01, ..., 1.1729e-01,\n",
" -1.4640e-01, -1.9143e+00]],\n",
"\n",
" [[-5.8489e-01, -4.5877e-03, 4.4912e-02, ..., -2.0796e-02,\n",
" 6.2989e-03, -6.4938e-03],\n",
" [-1.6445e+00, 4.2511e-02, -3.1403e-01, ..., -3.7935e-01,\n",
" 2.3561e-01, 5.9496e-02],\n",
" [-2.5505e+00, -2.0482e-01, -3.6240e-01, ..., -3.0201e-01,\n",
" -4.2028e-01, -1.8376e-02],\n",
" ...,\n",
" [-1.6757e+00, 4.2658e-01, -9.1740e-01, ..., 2.0202e-01,\n",
" 5.2352e-01, 3.1575e-01],\n",
" [-1.7608e+00, 5.6837e-01, 3.5225e-01, ..., 5.5874e-01,\n",
" -6.9264e-01, -1.8256e-01],\n",
" [-2.3731e+00, -2.8098e-01, 3.9676e-01, ..., -2.5406e-01,\n",
" 4.8834e-01, -6.1031e-01]],\n",
"\n",
" [[ 1.5471e-03, 8.2456e-02, -4.7513e-02, ..., 5.5853e-02,\n",
" 3.0368e-02, -4.6994e-02],\n",
" [-5.5504e-01, 7.3400e-01, -2.0816e-01, ..., -1.2824e-01,\n",
" 3.8586e-01, 8.0331e-01],\n",
" [ 6.3713e-01, 1.6547e+00, 2.6059e-01, ..., -1.1861e+00,\n",
" 6.3198e-01, -1.3541e-01],\n",
" ...,\n",
" [ 4.7463e-01, 1.1477e+00, 6.0258e-02, ..., -4.6058e-01,\n",
" -3.5489e-01, 7.9365e-02],\n",
" [ 8.1016e-02, -1.3944e-01, 4.1258e-01, ..., 1.1060e-01,\n",
" -2.8541e+00, 4.1492e-01],\n",
" [-1.2963e+00, 2.2384e-01, -2.4338e-01, ..., 2.2294e-01,\n",
" 1.0918e-01, 2.1425e+00]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[-0.0268, -2.3398, 0.1634, ..., -0.2365, -0.1944, 0.0645],\n",
" [-0.2799, 4.0417, -0.1959, ..., -0.1126, -0.3356, -0.5690],\n",
" [-0.0094, 4.9294, 0.5373, ..., -0.1404, -0.6815, 0.3025],\n",
" ...,\n",
" [ 0.0091, 3.8780, 0.7991, ..., -0.5989, 0.7071, 0.5137],\n",
" [-0.6218, 4.7662, -0.4088, ..., -0.8925, -0.0737, 0.7395],\n",
" [-0.8638, 5.1069, -0.1012, ..., -0.0097, 0.0632, -0.7295]],\n",
"\n",
" [[-0.8140, 0.2218, 0.4656, ..., -0.5189, 1.0732, 1.1234],\n",
" [ 0.2364, 0.2685, 1.0541, ..., 0.5500, 1.3914, 0.4962],\n",
" [ 1.5259, 1.0305, -0.6830, ..., -0.3595, 0.8213, -0.1596],\n",
" ...,\n",
" [-0.8893, 0.6401, 1.5340, ..., -0.3154, 0.9969, 0.1131],\n",
" [-1.4240, -0.5673, -0.9037, ..., -0.0334, 2.1567, -0.3555],\n",
" [-2.3116, 1.4069, 0.2116, ..., 0.7944, 2.6708, 0.1778]],\n",
"\n",
" [[-0.8504, 0.4700, 0.0232, ..., 0.4955, -0.2356, 1.1518],\n",
" [ 0.6655, -0.1374, 1.1604, ..., 0.2494, 1.0734, -0.9082],\n",
" [ 2.0262, 0.3311, 0.5329, ..., 0.2746, 0.6484, -1.2565],\n",
" ...,\n",
" [ 0.8666, 0.2080, 0.7423, ..., -0.0590, 0.7947, 0.2077],\n",
" [ 1.3274, -0.5878, 1.5562, ..., 1.2727, 0.8958, -0.8393],\n",
" [ 0.6793, -0.9115, 2.1432, ..., 1.5571, 1.7428, -0.3943]],\n",
"\n",
" ...,\n",
"\n",
" [[-0.3102, -0.1292, 0.1523, ..., 0.1793, 1.7438, -2.8696],\n",
" [ 1.2379, -0.5238, 0.3674, ..., -0.3042, -4.6049, 4.6856],\n",
" [ 0.6856, 0.3973, 0.9211, ..., -0.6994, -5.2863, 4.9465],\n",
" ...,\n",
" [-0.5532, -0.4212, 1.0728, ..., 0.4562, -5.7176, 5.1979],\n",
" [-0.9992, -1.4073, -0.8534, ..., 0.8452, -5.9484, 4.2105],\n",
" [ 0.1935, -1.2555, 1.2355, ..., -0.0070, -6.0872, 5.8807]],\n",
"\n",
" [[ 0.1957, 0.3617, 0.2155, ..., -0.2170, 0.0182, -0.1540],\n",
" [-0.6359, -0.7831, -0.5938, ..., 1.0413, -0.4280, 0.6407],\n",
" [-0.6033, -1.0964, -0.2818, ..., 0.2840, -0.2947, 0.6149],\n",
" ...,\n",
" [-0.2907, 0.0759, 0.5673, ..., 1.1031, -0.7398, 0.1992],\n",
" [-0.3487, -0.1916, 1.1144, ..., 0.6085, 0.1949, 1.1279],\n",
" [-1.1693, -0.8894, 0.6257, ..., 1.4145, -1.2843, 0.4372]],\n",
"\n",
" [[ 0.3722, 0.0987, 0.6134, ..., 0.5249, 0.5746, -0.3289],\n",
" [ 0.7276, -0.7879, -1.5108, ..., -1.7654, -3.2146, 0.1771],\n",
" [ 0.6286, -1.0423, -1.3390, ..., -2.0023, -2.7540, -0.0532],\n",
" ...,\n",
" [ 1.2008, -1.0047, -2.2047, ..., -2.5210, -4.7543, 1.0585],\n",
" [ 0.1571, -1.0960, -1.7899, ..., -3.0896, -4.1969, -0.4143],\n",
" [ 1.1982, -1.3326, -1.5329, ..., -1.6822, -4.4774, -0.5948]]]],\n",
" grad_fn=<PermuteBackward0>), tensor([[[[ 0.0512, -0.0071, -0.0159, ..., 0.1308, -0.0604, -0.0383],\n",
" [ 0.0743, 0.0098, 0.8985, ..., 0.3322, 1.0163, -0.2279],\n",
" [ 0.2299, -1.0595, -0.2036, ..., 0.4071, 1.0309, 0.7073],\n",
" ...,\n",
" [ 0.9942, 0.0985, -0.3045, ..., 0.3595, 1.2762, 0.1312],\n",
" [-1.1851, 0.1872, 2.5162, ..., -0.4091, -0.5504, -0.3313],\n",
" [ 0.4573, -0.2495, 1.1492, ..., 0.3916, 0.3092, -0.2549]],\n",
"\n",
" [[ 0.0178, 0.0383, 0.0396, ..., 0.0060, -0.0180, 0.0108],\n",
" [ 0.3077, 0.2800, -1.2484, ..., 0.1144, -0.0260, -0.6417],\n",
" [ 0.8365, 0.1942, -2.6429, ..., 1.4839, -2.4390, -1.1518],\n",
" ...,\n",
" [-1.0152, -1.3838, 0.4507, ..., 0.2284, 0.2643, 0.3901],\n",
" [-1.8002, -1.5104, -0.6286, ..., 1.0451, 0.2438, -0.3518],\n",
" [-0.4032, -0.3529, -1.6265, ..., 0.5828, 0.5720, -1.2572]],\n",
"\n",
" [[ 0.0495, -0.0389, 0.0613, ..., 0.0561, -0.0711, -0.0673],\n",
" [-0.6686, 1.1461, -0.4798, ..., 0.1773, 0.4573, 0.4967],\n",
" [-0.3811, 0.8968, -0.6061, ..., 0.0926, 0.3056, 0.9180],\n",
" ...,\n",
" [-0.3757, -0.0510, 0.0062, ..., 0.6064, 0.7972, 0.7227],\n",
" [-0.2685, -0.7850, 0.7441, ..., -0.8875, -0.0677, 1.0534],\n",
" [-0.7876, 1.0096, -0.0108, ..., -0.9138, -0.1195, -0.2942]],\n",
"\n",
" ...,\n",
"\n",
" [[-0.1028, -0.0452, 0.0346, ..., -0.0871, 0.0427, 0.0092],\n",
" [ 0.5169, 0.0966, 0.2483, ..., -0.4591, 0.3724, 0.6674],\n",
" [ 0.9085, 0.9305, -0.0286, ..., -0.8769, -0.3911, 0.3594],\n",
" ...,\n",
" [-0.0673, -0.2202, -0.2051, ..., 0.2041, -0.4487, 1.0220],\n",
" [-0.2218, -0.4037, 1.4038, ..., 1.5332, -1.2336, 0.4163],\n",
" [ 0.8637, -1.0940, 0.2482, ..., 0.3983, -1.4612, 0.6188]],\n",
"\n",
" [[ 0.1576, -0.0522, 0.1510, ..., 0.0776, 0.0389, -0.1486],\n",
" [-0.0612, 1.4222, 1.2901, ..., 1.0537, 1.9877, -1.2965],\n",
" [ 0.0701, 1.0599, 1.3164, ..., 1.8434, 1.7597, -0.8641],\n",
" ...,\n",
" [-0.0791, 0.1802, -0.2036, ..., 0.6063, 1.2652, 0.1763],\n",
" [ 0.4001, 1.6460, 1.1749, ..., -0.6267, 2.3732, -0.3538],\n",
" [ 0.2739, 1.4950, 0.8300, ..., 1.1957, 1.5808, -1.0777]],\n",
"\n",
" [[ 0.2067, -0.0439, -0.0680, ..., 0.0390, 0.0473, 0.0275],\n",
" [-0.6717, 0.2561, 0.7676, ..., -0.2872, -0.5916, -0.1957],\n",
" [-0.9239, 0.0464, 0.4365, ..., 0.6006, -0.4989, 0.7633],\n",
" ...,\n",
" [ 0.5723, 0.0787, 0.7033, ..., 0.3464, -0.7811, 1.3074],\n",
" [-0.8109, -0.4612, -1.6027, ..., -1.6367, -0.0065, -0.7756],\n",
" [-1.3609, 0.5702, 0.7531, ..., -0.1462, 0.1355, 0.2370]]]],\n",
" grad_fn=<PermuteBackward0>)), (tensor([[[[ 0.0436, -0.2509, -0.4550, ..., 0.3116, 0.3358, 0.3770],\n",
" [ 0.3000, -0.0478, -1.2507, ..., 0.3519, 0.7682, 0.7296],\n",
" [ 0.5719, -1.2229, -1.5268, ..., 0.2386, 0.1496, 0.8318],\n",
" ...,\n",
" [-0.7869, -1.6832, -0.6862, ..., 0.5248, -1.1760, -0.6061],\n",
" [ 1.1739, 0.7271, -1.4276, ..., 1.1409, -1.3880, -0.6762],\n",
" [ 1.7515, -0.1609, -0.0345, ..., 0.9718, -0.5132, 1.4921]],\n",
"\n",
" [[-0.2801, 0.1559, 0.1167, ..., 0.0214, -1.1384, -0.1501],\n",
" [-0.7174, -0.1200, -0.7961, ..., -0.4121, 0.7157, 0.5868],\n",
" [ 1.5377, 0.1651, -0.9257, ..., 0.3588, 1.3888, 0.1633],\n",
" ...,\n",
" [ 0.2032, 0.5659, -0.9297, ..., -1.1580, -1.0870, 1.0748],\n",
" [-0.0984, 1.5501, -1.2118, ..., -1.0350, 0.6500, 0.8747],\n",
" [-1.1498, 0.8479, -0.9318, ..., -1.2515, 0.5937, 0.4393]],\n",
"\n",
" [[-1.2411, -0.0878, 0.5490, ..., -0.6611, 0.4539, -0.2888],\n",
" [ 0.6556, 1.0735, -0.5900, ..., 0.0895, -0.3484, -0.2450],\n",
" [ 0.3530, 0.0116, 0.0702, ..., 0.7262, -1.4991, -0.5028],\n",
" ...,\n",
" [ 0.6693, 0.8831, -0.7045, ..., 1.2413, 0.0528, 0.1498],\n",
" [ 1.5144, 1.9988, -1.8167, ..., 1.0272, -0.5508, -0.2781],\n",
" [-0.2976, 1.1260, -1.6873, ..., 1.3365, -0.2020, -0.3461]],\n",
"\n",
" ...,\n",
"\n",
" [[ 0.7973, -0.8987, -0.3939, ..., -1.0369, -0.4123, 0.4803],\n",
" [ 1.4616, 0.0408, -1.0295, ..., 0.7219, 0.3444, -0.0145],\n",
" [ 0.2550, -1.1764, -0.3335, ..., 0.8036, 1.7228, -2.3128],\n",
" ...,\n",
" [ 0.6038, -0.3213, -0.9128, ..., 1.7723, 0.7332, -1.3456],\n",
" [ 1.5292, 0.8308, -1.5665, ..., 1.7068, 0.6255, -1.4453],\n",
" [ 2.1459, -0.1321, -0.5784, ..., 1.8690, 1.6415, 0.8508]],\n",
"\n",
" [[-0.9151, 2.5785, 0.3082, ..., 0.3579, 1.9421, -0.5408],\n",
" [-0.0171, -2.5663, 0.7328, ..., 0.3923, -4.1463, 1.9012],\n",
" [ 1.0185, -2.5828, -1.5448, ..., 1.0508, -4.9451, 1.7123],\n",
" ...,\n",
" [ 1.1092, -2.5339, 0.2730, ..., -0.9127, -3.6883, -0.9762],\n",
" [ 0.7417, -1.7092, 0.4430, ..., 0.6517, -4.0859, -0.6250],\n",
" [ 0.6957, -4.4839, -0.4944, ..., 1.2733, -5.0460, 2.7409]],\n",
"\n",
" [[-2.0221, -0.3681, -1.1042, ..., -0.3983, 0.0527, 0.2442],\n",
" [ 1.4425, 0.4368, 0.8613, ..., 1.2344, -0.1098, 0.1759],\n",
" [ 2.6873, -0.5718, 0.7670, ..., 1.4859, -0.9973, 1.5824],\n",
" ...,\n",
" [ 2.3481, -0.2267, 0.4736, ..., 1.0791, 0.1695, -0.6822],\n",
" [ 1.8121, 0.8181, 1.5002, ..., 1.3897, -1.1112, -0.6512],\n",
" [ 2.0455, 0.8276, 1.0394, ..., 1.7555, -0.0730, -0.0210]]]],\n",
" grad_fn=<PermuteBackward0>), tensor([[[[-4.8632e-02, -9.1396e-02, 3.1682e-02, ..., 1.2261e-01,\n",
" -3.6255e-02, 1.2526e-02],\n",
" [-3.1696e-01, 2.4710e-01, -5.5300e-02, ..., 2.4286e-02,\n",
" -3.1162e-01, 2.2300e-01],\n",
" [ 1.9266e-01, 5.7364e-01, 5.6620e-01, ..., 6.3398e-01,\n",
" -1.6994e-01, -3.2943e-01],\n",
" ...,\n",
" [ 8.0219e-01, -6.2467e-02, 7.5092e-01, ..., -2.6152e-01,\n",
" 6.4908e-01, 9.1121e-01],\n",
" [ 2.9819e-01, -1.1154e+00, 5.7111e-01, ..., -1.1155e+00,\n",
" 5.0150e-01, 3.6634e-01],\n",
" [ 7.2844e-01, 4.1041e-01, 6.7296e-01, ..., 2.8859e-01,\n",
" -9.5357e-01, 4.9752e-01]],\n",
"\n",
" [[ 1.1928e-02, 1.3112e-02, -2.6053e-02, ..., 4.6390e-02,\n",
" 2.8720e-02, 5.6897e-02],\n",
" [ 5.1804e-01, -8.6756e-03, 3.4240e-01, ..., -9.3518e-01,\n",
" -2.8230e-02, -1.6108e-01],\n",
" [-6.5553e-01, -1.4296e-01, 6.3211e-01, ..., -2.3726e+00,\n",
" -1.0325e+00, 1.1180e+00],\n",
" ...,\n",
" [-2.7697e-01, 4.7694e-01, 9.3078e-01, ..., -1.4985e-02,\n",
" -9.5630e-01, -1.0057e+00],\n",
" [-3.5304e-01, 7.6668e-01, -7.3687e-01, ..., 8.2464e-01,\n",
" 6.1313e-01, 1.4616e-01],\n",
" [ 6.2543e-02, 9.5850e-01, 9.9546e-02, ..., -4.1675e-01,\n",
" -3.1019e-01, 2.1785e-02]],\n",
"\n",
" [[ 3.9431e-02, 3.2304e-02, -6.9643e-02, ..., 3.1842e-03,\n",
" 1.5391e-02, 8.6383e-03],\n",
" [-4.5218e-02, 3.8015e-01, -7.4175e-03, ..., -8.6065e-02,\n",
" 1.9510e-01, 2.4301e-02],\n",
" [ 1.0227e+00, 7.7004e-02, 7.1903e-02, ..., 1.1994e+00,\n",
" 1.6976e-01, -4.0066e-01],\n",
" ...,\n",
" [ 1.1771e+00, 2.4422e-01, 7.0662e-01, ..., 1.1337e+00,\n",
" -8.5384e-01, -9.9605e-01],\n",
" [ 4.0196e-01, 3.7700e-01, 1.0244e+00, ..., -2.4000e-01,\n",
" -2.2166e-03, -8.7664e-01],\n",
" [ 6.3016e-01, 1.0653e-01, 6.7085e-01, ..., 1.8561e-01,\n",
" -1.0484e+00, -2.8506e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[-2.8633e-02, 2.3521e-02, -1.3071e-02, ..., -3.3836e-02,\n",
" -4.1805e-02, 1.6132e-02],\n",
" [ 2.9617e-01, 2.5753e-01, 6.4459e-01, ..., 6.7883e-01,\n",
" -1.1170e-01, 3.4354e-01],\n",
" [ 1.4840e-01, -2.1638e-01, 1.5988e-01, ..., 3.0029e-01,\n",
" -1.7462e+00, 2.2010e+00],\n",
" ...,\n",
" [-1.0735e-01, -2.7973e-01, 1.7696e-01, ..., 1.2454e-01,\n",
" 1.6533e+00, 4.6311e-02],\n",
" [-2.5303e-01, -5.3346e-01, -7.0970e-01, ..., 3.3254e-01,\n",
" -1.0337e-01, -1.5011e+00],\n",
" [ 9.4744e-01, 4.1239e-01, -1.0214e-01, ..., 1.0832e+00,\n",
" 1.1939e+00, 2.1364e-01]],\n",
"\n",
" [[-7.3958e-02, -4.4124e-02, 1.7760e-02, ..., 3.1321e-03,\n",
" -4.5881e-02, -1.0916e-01],\n",
" [-4.6492e-01, -6.5992e-02, -4.8427e-02, ..., 2.7765e-01,\n",
" 1.7094e-01, -2.1020e-01],\n",
" [-9.3265e-01, -1.7024e+00, 1.1011e-01, ..., -6.0777e-01,\n",
" 2.7326e-01, -1.2374e+00],\n",
" ...,\n",
" [-1.0394e-01, 3.4447e-02, -1.4004e+00, ..., 1.9303e-01,\n",
" -1.2038e+00, 5.6969e-01],\n",
" [-9.6140e-01, 5.8390e-01, -5.3376e-01, ..., 3.5307e-01,\n",
" 3.7874e-01, -4.8008e-02],\n",
" [-6.0081e-02, -5.1836e-01, -9.0043e-02, ..., 2.2977e-01,\n",
" -1.1964e-01, -4.6107e-01]],\n",
"\n",
" [[-7.7909e-03, 4.0206e-02, -5.6468e-02, ..., -3.0341e-02,\n",
" 2.4338e-02, 5.3261e-03],\n",
" [-2.4073e-01, -1.4607e-01, 6.8568e-01, ..., -9.4289e-01,\n",
" -1.0285e+00, -7.2268e-01],\n",
" [-8.9161e-01, 3.2033e-01, 2.2241e-01, ..., 7.4783e-01,\n",
" -1.8553e-01, -1.4143e+00],\n",
" ...,\n",
" [ 5.0678e-01, -8.7200e-01, 1.3745e+00, ..., 4.7279e-01,\n",
" 1.9468e-01, -3.0692e-01],\n",
" [ 9.3960e-02, -1.1271e+00, -3.2356e-01, ..., 4.6166e-01,\n",
" 1.1812e+00, -7.4736e-02],\n",
" [-1.3350e-01, -7.4492e-01, 7.1189e-01, ..., -1.8032e-01,\n",
" -1.5200e+00, -9.0480e-01]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[-0.5266, 0.4829, -0.8581, ..., -1.0129, -1.3275, 0.2040],\n",
" [ 0.6896, 0.9769, 0.7884, ..., 2.3917, -0.4978, -1.2931],\n",
" [ 0.4849, 0.2801, 2.0546, ..., 2.4861, 0.4516, -2.1291],\n",
" ...,\n",
" [ 1.1704, -0.3403, 1.9815, ..., 2.4192, 1.2849, -0.9075],\n",
" [ 0.5567, 0.0731, 0.2333, ..., 1.9754, -0.6718, 0.4945],\n",
" [ 1.1157, 0.0910, 1.9513, ..., 2.0806, -0.6777, -1.8277]],\n",
"\n",
" [[ 0.8645, -2.0846, 0.1532, ..., 0.2459, -2.4906, -0.4514],\n",
" [ 1.1494, 1.2483, -0.0495, ..., 0.1813, 0.8199, -0.5313],\n",
" [ 1.4952, 1.4661, -1.3266, ..., 0.6351, 0.5419, 0.3732],\n",
" ...,\n",
" [ 1.6627, 3.2121, -1.1410, ..., -0.1081, 2.2876, -1.0492],\n",
" [ 0.8725, 3.7180, -0.8677, ..., 0.5521, 0.0537, -2.0911],\n",
" [ 1.8427, 2.8496, -0.9180, ..., -0.0876, 2.9764, -1.0137]],\n",
"\n",
" [[ 1.0175, 0.3871, -0.1741, ..., -0.8094, -1.4149, -0.3730],\n",
" [-0.2747, 0.4294, -0.8148, ..., 0.7997, -1.0098, -0.2083],\n",
" [-0.1443, 0.1837, -0.6903, ..., 2.3234, -0.5142, -1.1581],\n",
" ...,\n",
" [-1.7553, -0.7940, -1.4744, ..., 1.9563, -0.3079, 0.2517],\n",
" [-1.1555, -0.9816, -1.4792, ..., 2.4893, -0.8572, 0.6439],\n",
" [ 0.2061, 0.6956, -1.2343, ..., 1.2946, -0.7649, -1.0596]],\n",
"\n",
" ...,\n",
"\n",
" [[ 0.2137, -0.5814, 0.4917, ..., -0.6758, 1.0594, 0.2809],\n",
" [-0.0787, 1.1178, -0.9665, ..., -2.9838, -0.0755, 0.3358],\n",
" [-0.3467, 0.6547, -1.9701, ..., -2.6404, -1.7759, 0.1484],\n",
" ...,\n",
" [-0.6088, 0.2404, -1.0831, ..., -2.5044, -0.5236, 0.2501],\n",
" [-1.2627, -0.4007, 0.0159, ..., -2.2715, -1.9617, 0.1351],\n",
" [-0.4629, 0.4004, -1.0877, ..., -3.2533, -0.1876, -0.2612]],\n",
"\n",
" [[ 0.2431, 0.5528, 0.5439, ..., 0.7452, 0.0856, 0.8468],\n",
" [ 0.3639, 2.4237, 0.9672, ..., 0.7770, -0.7330, 0.4097],\n",
" [-0.4982, 1.9386, -0.1103, ..., 1.4543, -0.3265, 0.4745],\n",
" ...,\n",
" [ 0.2191, 1.5633, -0.4826, ..., -0.9138, -0.7183, 0.2929],\n",
" [-2.4011, -0.7274, -0.1691, ..., 0.5614, -0.1154, 2.1418],\n",
" [ 1.8710, 2.7152, 0.3026, ..., 0.4339, -1.6067, 0.4278]],\n",
"\n",
" [[-0.7092, 0.3125, -1.6205, ..., -0.4008, 0.2350, -1.3048],\n",
" [ 0.0382, 0.8210, -1.6851, ..., 1.5476, 1.1133, -1.3639],\n",
" [-1.2253, -0.0602, -3.1185, ..., -0.4857, 1.8382, 0.9552],\n",
" ...,\n",
" [-3.0265, -0.1628, -0.6678, ..., 1.2046, 1.1136, -0.6637],\n",
" [-2.2680, -0.1403, -1.6040, ..., 0.0642, 0.6752, -0.0818],\n",
" [-0.6567, -0.4737, -1.8665, ..., 1.7928, 1.7230, -1.4443]]]],\n",
" grad_fn=<PermuteBackward0>), tensor([[[[ 3.4694e-03, 5.2180e-02, -7.1138e-02, ..., 5.9633e-02,\n",
" -5.0955e-02, -7.4279e-02],\n",
" [ 9.5441e-02, -8.2082e-04, 3.7786e-01, ..., -7.9814e-01,\n",
" -3.4941e-01, 5.3955e-01],\n",
" [-1.7982e-01, 1.0793e+00, 8.4480e-01, ..., -5.6335e-01,\n",
" 4.7423e-01, -1.3511e-01],\n",
" ...,\n",
" [-8.4265e-01, -2.8338e-02, -8.4992e-01, ..., -8.6247e-01,\n",
" 6.9610e-01, -1.6560e-01],\n",
" [-5.8879e-02, 7.0560e-01, 8.3837e-01, ..., -4.8124e-01,\n",
" -1.7102e+00, 4.3793e-01],\n",
" [ 4.7365e-02, 2.9308e-01, 2.9819e-01, ..., -3.3441e-01,\n",
" -6.1577e-01, 5.0968e-01]],\n",
"\n",
" [[ 3.7719e-02, 1.2977e-04, 4.9038e-02, ..., -3.9138e-02,\n",
" -2.6473e-02, -1.4142e-02],\n",
" [-1.9622e-01, 4.4304e-01, 9.3501e-02, ..., 8.6879e-01,\n",
" 5.8439e-01, 6.5467e-01],\n",
" [-8.1397e-01, -1.4557e+00, -3.4408e-01, ..., 1.0143e+00,\n",
" 1.6014e-01, -7.6486e-01],\n",
" ...,\n",
" [ 4.8213e-01, 1.1956e+00, -6.7466e-01, ..., 4.4558e-03,\n",
" -6.0745e-01, 1.5004e-01],\n",
" [-9.2434e-01, -9.9667e-02, -1.7371e-01, ..., 3.3668e-01,\n",
" 3.7452e-01, 9.1399e-01],\n",
" [-8.0525e-01, 2.7367e-01, 2.7182e-01, ..., 1.5725e+00,\n",
" 1.8934e-01, 9.1494e-01]],\n",
"\n",
" [[-6.9176e-03, 1.8243e-02, -3.3975e-02, ..., 8.5669e-03,\n",
" 2.7227e-02, 5.8461e-02],\n",
" [-2.8638e-01, 4.4393e-02, -2.4720e-01, ..., 5.8055e-01,\n",
" -1.1038e+00, -3.1214e-01],\n",
" [-4.1151e-02, 4.7980e-01, -8.1177e-01, ..., 2.5263e+00,\n",
" -6.2052e-01, -4.0801e-01],\n",
" ...,\n",
" [-6.4285e-01, 2.1790e-01, 7.1201e-01, ..., 7.6857e-01,\n",
" 1.9746e-02, -1.2292e-02],\n",
" [ 4.3683e-01, -2.0561e-01, 5.6170e-01, ..., -1.3195e+00,\n",
" -6.0955e-01, 8.5465e-01],\n",
" [-5.0826e-02, 2.0641e-01, 2.1014e-01, ..., -6.1202e-01,\n",
" -3.7409e-01, 5.8607e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[-7.8410e-02, 2.6667e-02, 1.1429e-02, ..., -3.5996e-02,\n",
" -7.8381e-03, 1.2273e-03],\n",
" [-2.4955e-01, 3.0179e-01, 2.2439e-01, ..., -7.0245e-01,\n",
" -4.7259e-01, -1.2154e-01],\n",
" [-1.1360e+00, 4.8186e-01, 6.9660e-01, ..., 5.2388e-02,\n",
" 4.9656e-01, -7.1202e-01],\n",
" ...,\n",
" [-2.2132e-01, 2.5862e-01, 5.4504e-01, ..., 6.4937e-01,\n",
" -1.4201e-01, -6.9701e-02],\n",
" [-9.8322e-01, -1.1579e-01, 1.4461e+00, ..., 4.0303e-01,\n",
" -8.9281e-01, 9.6826e-01],\n",
" [-1.2535e+00, 7.9669e-01, 2.3864e+00, ..., -1.1996e+00,\n",
" -1.2942e-02, 1.5757e+00]],\n",
"\n",
" [[ 7.4373e-02, -1.0839e-03, 4.7472e-02, ..., 2.5576e-02,\n",
" 5.5578e-02, 3.0725e-02],\n",
" [ 1.6561e-01, 1.1326e+00, 1.1021e+00, ..., -6.7084e-02,\n",
" 1.0625e+00, -7.9841e-01],\n",
" [-1.1934e+00, 1.3455e+00, 7.5402e-01, ..., 3.0290e+00,\n",
" 1.9807e+00, -1.6143e-01],\n",
" ...,\n",
" [ 8.7119e-01, 1.6007e+00, 9.8724e-01, ..., 6.2297e-01,\n",
" 9.5836e-01, -6.7591e-02],\n",
" [ 5.6550e-01, 7.5545e-01, -9.4622e-01, ..., 3.9639e-01,\n",
" -1.3479e-01, 1.4511e-01],\n",
" [-2.5438e-01, 1.3767e+00, 1.5838e+00, ..., 5.7618e-02,\n",
" 1.7279e+00, -7.0514e-01]],\n",
"\n",
" [[-1.1214e-01, 2.7461e-02, -6.8169e-02, ..., -8.8035e-02,\n",
" 7.2290e-02, -2.1984e-02],\n",
" [-1.5667e-01, 3.3572e-01, -2.9793e-01, ..., -3.2849e-01,\n",
" -6.0364e-02, 8.4579e-02],\n",
" [-3.0011e-01, 3.6599e-01, 4.1995e-01, ..., -5.6659e-01,\n",
" 1.6448e-01, 2.1300e-01],\n",
" ...,\n",
" [-5.3965e-01, 8.5568e-01, -1.0334e+00, ..., 1.6571e+00,\n",
" 1.2634e+00, 1.2663e-02],\n",
" [-1.1969e+00, 1.1998e-01, 7.4285e-01, ..., 2.5529e+00,\n",
" 2.4390e+00, -3.4413e-01],\n",
" [-7.4888e-01, 4.5366e-01, -1.2199e+00, ..., -4.5325e-01,\n",
" 4.5486e-01, 9.2945e-01]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[-1.7115, -0.3095, -0.3052, ..., 0.1569, 0.3295, -0.5102],\n",
" [-0.2954, -0.5369, -1.0378, ..., 1.1073, -0.8305, 0.1360],\n",
" [ 0.3707, -0.1892, -0.0631, ..., -0.2839, -0.1220, 0.7498],\n",
" ...,\n",
" [ 0.1536, 0.6239, -0.8671, ..., 0.7387, 0.5640, 0.1241],\n",
" [ 2.3072, -0.4214, -0.9284, ..., 1.0017, 0.4802, 0.5843],\n",
" [ 0.1397, 0.0057, -1.1350, ..., 1.1576, -1.0530, 0.4489]],\n",
"\n",
" [[ 0.1108, -0.0705, 2.3025, ..., 0.2415, 0.0896, -0.1951],\n",
" [ 0.7120, -1.0674, -0.9315, ..., 0.1600, 0.2404, -0.4726],\n",
" [ 1.4893, -0.9425, -1.5101, ..., -0.3591, 0.0335, 0.4421],\n",
" ...,\n",
" [ 0.1341, -0.0506, -1.3000, ..., -0.1105, -0.2529, 0.7670],\n",
" [-0.5547, -0.6913, -1.2921, ..., 0.2898, -0.2538, 0.9526],\n",
" [ 0.2336, -0.8297, -0.8416, ..., -0.0980, -0.2919, 0.7454]],\n",
"\n",
" [[-0.2041, 1.0503, 0.4759, ..., -0.5452, 0.3040, -0.1147],\n",
" [-0.4822, 0.5302, -0.5850, ..., 1.5575, 0.2531, -0.3998],\n",
" [-0.5382, 0.5110, 0.0939, ..., 0.6880, -0.2115, -0.7376],\n",
" ...,\n",
" [-0.3389, 0.6598, -0.4937, ..., 0.4478, -0.8530, -0.5765],\n",
" [-0.7257, 1.3182, 0.9792, ..., 1.7382, 0.5813, -1.0117],\n",
" [-0.9058, 1.0543, -0.6513, ..., 2.0381, -0.3831, 0.2467]],\n",
"\n",
" ...,\n",
"\n",
" [[ 0.5577, 0.9790, -0.8716, ..., -0.7184, 0.7245, 0.8611],\n",
" [ 0.3537, 0.1875, -0.5812, ..., -0.5651, -0.0493, 0.0897],\n",
" [ 0.4359, -0.7858, -0.9179, ..., -1.4713, 0.2309, -0.2758],\n",
" ...,\n",
" [ 0.9503, -0.6978, 0.7306, ..., -0.7847, -0.9335, -0.8081],\n",
" [ 0.1232, 1.2112, 1.0973, ..., 0.2584, 1.1175, -0.0057],\n",
" [-0.0062, -0.2483, 0.2463, ..., -0.3165, 0.3718, -0.2848]],\n",
"\n",
" [[-0.4014, 0.3733, 0.3393, ..., 0.7212, 0.0451, -0.0838],\n",
" [-0.5197, 1.3345, -1.5982, ..., 0.5380, -0.2475, -0.9776],\n",
" [ 0.0646, 0.0452, -0.4746, ..., 0.9874, -0.8139, 0.1726],\n",
" ...,\n",
" [-0.6432, -0.4941, 0.4357, ..., 0.5838, -1.3339, -0.0826],\n",
" [-0.9413, -1.2357, -0.4911, ..., 1.3679, -1.0148, -1.4263],\n",
" [-0.9545, 0.2418, -1.5970, ..., 0.3238, -0.9107, -0.7229]],\n",
"\n",
" [[-0.7459, -0.0075, 0.4400, ..., -0.1109, 0.0299, -0.0598],\n",
" [-0.2137, 0.3865, 1.1712, ..., 0.4425, -0.3584, 1.2832],\n",
" [ 0.5957, 0.1015, -0.1897, ..., 0.4039, -1.3808, 1.2112],\n",
" ...,\n",
" [ 0.7437, -1.3902, 0.2656, ..., 0.9423, -1.2780, 1.6726],\n",
" [-0.7614, 0.3624, 1.4484, ..., 0.2220, -1.0658, 1.0444],\n",
" [-0.4612, 0.8413, 1.7939, ..., 0.1289, -0.8518, 1.1819]]]],\n",
" grad_fn=<PermuteBackward0>), tensor([[[[ 7.7570e-02, -1.1777e-01, -1.6829e-01, ..., -3.0139e-01,\n",
" 2.8640e-01, -1.7741e-01],\n",
" [ 3.7056e-01, 6.3164e-01, 7.7662e-01, ..., 2.7495e+00,\n",
" -1.5492e+00, 1.1155e+00],\n",
" [ 1.6399e+00, 1.3236e+00, 5.0145e-01, ..., 2.7380e+00,\n",
" -2.5362e+00, 2.0660e+00],\n",
" ...,\n",
" [-2.2210e-03, 1.2618e-01, 1.9018e-01, ..., 2.5907e+00,\n",
" -1.5682e+00, 7.7443e-01],\n",
" [ 1.0661e+00, 1.8362e-01, 9.9011e-01, ..., 1.7970e+00,\n",
" -1.8210e-01, -7.9636e-01],\n",
" [ 1.1176e+00, 9.5490e-01, 4.2716e-01, ..., 2.4762e+00,\n",
" -1.8121e+00, 1.6125e+00]],\n",
"\n",
" [[ 1.0853e-01, -1.0814e-02, 5.5897e-02, ..., -9.3695e-03,\n",
" -8.4395e-02, 1.6578e-01],\n",
" [ 9.0288e-02, 4.3214e-01, 7.7907e-02, ..., 3.6511e-01,\n",
" 4.1462e-01, -3.7498e-01],\n",
" [ 4.8901e-02, 1.1972e+00, -1.0267e-01, ..., -2.4577e-01,\n",
" 3.2252e-01, 9.5713e-02],\n",
" ...,\n",
" [ 1.4289e+00, -4.0081e-01, 8.8847e-01, ..., -1.2688e-01,\n",
" -2.1349e-01, -1.5179e+00],\n",
" [-1.8024e-01, -5.9997e-01, 1.6811e+00, ..., 8.8114e-01,\n",
" -1.2796e+00, 8.0612e-01],\n",
" [ 3.5363e-01, 1.5338e-01, 1.0489e-01, ..., 7.1419e-01,\n",
" -2.5939e-01, 1.1640e-01]],\n",
"\n",
" [[-1.3536e-02, 2.5633e-02, -3.8610e-02, ..., 4.7447e-02,\n",
" 4.5465e-04, 7.3786e-02],\n",
" [ 3.7973e-01, -2.6919e-01, -4.5875e-01, ..., -1.4160e-01,\n",
" 3.0695e-01, -4.8341e-01],\n",
" [ 1.1969e+00, 1.2378e+00, -6.2153e-01, ..., -9.3299e-01,\n",
" 5.5717e-02, -2.5939e-02],\n",
" ...,\n",
" [ 1.0509e+00, -6.8117e-01, -5.0678e-01, ..., -5.8349e-01,\n",
" 1.6390e-01, -4.4167e-01],\n",
" [-5.3312e-01, 6.3160e-01, 2.2554e-01, ..., -1.1507e+00,\n",
" 6.4968e-01, 3.7368e-01],\n",
" [ 2.3626e-01, -1.7837e-01, 2.7653e-01, ..., -8.8951e-02,\n",
" -3.4488e-02, -6.5983e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[-3.0262e-02, -1.2759e-02, 8.2024e-02, ..., 4.1477e-02,\n",
" -3.4039e-02, 1.6534e-02],\n",
" [ 7.0146e-02, 3.9249e-01, 3.6694e-02, ..., 1.1981e-01,\n",
" -3.4416e-01, -1.2740e-01],\n",
" [-1.4357e+00, -8.1313e-01, 3.6240e-01, ..., 6.4624e-01,\n",
" -7.6324e-01, 1.4873e+00],\n",
" ...,\n",
" [ 1.0556e-01, -3.8366e-01, 1.2748e+00, ..., -3.6558e-01,\n",
" 4.0858e-01, 2.4199e-01],\n",
" [-2.5444e-01, 1.1958e+00, -1.7147e-01, ..., 6.1984e-01,\n",
" -2.2845e-01, -1.8110e+00],\n",
" [ 3.2427e-01, 8.9915e-01, 1.1141e+00, ..., 6.8071e-01,\n",
" -3.4533e-01, -1.7910e-01]],\n",
"\n",
" [[-1.8907e-01, -6.5480e-02, 7.6243e-02, ..., -5.9887e-02,\n",
" 5.6530e-02, -7.3080e-02],\n",
" [-8.2506e-01, -3.6656e-02, 4.9222e-01, ..., 2.5220e-01,\n",
" 3.1897e-01, 1.9113e-01],\n",
" [-4.6517e-01, -2.1911e-01, -6.4030e-01, ..., 7.2280e-01,\n",
" 7.5668e-01, 5.6131e-01],\n",
" ...,\n",
" [-1.0660e+00, -4.2479e-01, -5.0573e-01, ..., -5.8658e-02,\n",
" -6.6094e-02, -4.4752e-01],\n",
" [-8.6907e-02, 1.2486e-04, -5.2314e-01, ..., 1.1544e-01,\n",
" 4.3831e-01, -1.0179e-02],\n",
" [-1.0669e+00, -7.1475e-01, 8.0158e-01, ..., -1.1919e-01,\n",
" -2.0185e-01, 3.2946e-01]],\n",
"\n",
" [[ 1.2763e-01, -1.2701e-01, 1.6529e-01, ..., -1.4527e-01,\n",
" -8.5370e-03, -1.7278e-01],\n",
" [-6.5069e-02, 3.5000e-01, 5.6586e-01, ..., -3.5917e-01,\n",
" -4.1324e-01, 2.9987e-01],\n",
" [ 2.5123e-01, 5.5106e-01, 4.2795e-01, ..., -1.0718e+00,\n",
" -6.8236e-01, -4.2256e-01],\n",
" ...,\n",
" [-6.0648e-01, -5.4619e-01, 1.4942e-02, ..., -7.6836e-01,\n",
" -5.9767e-01, -1.3891e-02],\n",
" [-3.4398e-01, -8.0992e-01, 7.4776e-01, ..., -1.8947e+00,\n",
" -2.7473e-01, 4.0089e-01],\n",
" [ 8.6354e-02, -1.2515e-02, -2.7977e-01, ..., -4.1148e-01,\n",
" -5.5178e-01, 7.0079e-02]]]], grad_fn=<PermuteBackward0>))), hidden_states=None, attentions=None, cross_attentions=None)"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ -36.3292, -36.3402, -40.4228, ..., -46.0234, -44.5284,\n",
" -37.1276],\n",
" [-114.9346, -116.5035, -117.9236, ..., -117.8857, -119.3379,\n",
" -112.9298],\n",
" [-123.5036, -123.0548, -127.3876, ..., -130.5238, -130.5279,\n",
" -123.2711],\n",
" ...,\n",
" [-101.3852, -101.2506, -103.6583, ..., -103.3747, -107.7192,\n",
" -99.4521],\n",
" [ -83.0701, -84.3884, -91.9513, ..., -91.7482, -93.3971,\n",
" -85.1204],\n",
" [ -91.2749, -93.1332, -93.6408, ..., -94.3482, -93.4517,\n",
" -90.1472]]], grad_fn=<UnsafeViewBackward0>)"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output[0]"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 9, 50257])"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output[0].shape"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.return_types.topk(\n",
"values=tensor([[ -32.8755, -33.1021, -33.9975, -34.4861, -34.5463],\n",
" [-105.5972, -106.3818, -106.3978, -106.9693, -107.0778],\n",
" [-113.2521, -114.7346, -114.8781, -114.9605, -115.0834],\n",
" [-118.2435, -119.2980, -119.5907, -119.6229, -119.7969],\n",
" [ -83.6241, -84.6822, -84.8526, -85.4978, -86.6938],\n",
" [ -79.9051, -80.3284, -81.6157, -81.8538, -82.9018],\n",
" [ -90.4443, -90.7053, -91.9059, -92.0003, -92.1531],\n",
" [ -75.2650, -76.9698, -77.5753, -77.6700, -77.8095],\n",
" [ -78.7985, -81.5545, -81.6846, -81.8984, -82.5938]],\n",
" grad_fn=<TopkBackward0>),\n",
"indices=tensor([[ 11, 13, 198, 290, 286],\n",
" [ 262, 356, 314, 340, 257],\n",
" [ 262, 257, 1737, 2901, 2805],\n",
" [ 835, 717, 938, 10955, 1218],\n",
" [ 284, 736, 1363, 503, 422],\n",
" [ 670, 262, 616, 257, 1524],\n",
" [ 9003, 2607, 11550, 4436, 4495],\n",
" [ 11, 314, 338, 284, 287],\n",
" [ 314, 616, 257, 262, 612]]))"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.topk(output[0][0],5)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([8888, 11, 319, 616, 835, 284, 262, 6403, 11])"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"encoding.input_ids[0]"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Today, \t→ the\n",
"Today, on \t→ the\n",
"Today, on my \t→ way\n",
"Today, on my way \t→ to\n",
"Today, on my way to \t→ work\n",
"Today, on my way to the \t→ airport\n",
"Today, on my way to the university \t→ ,\n",
"Today, on my way to the university, \t→ I\n"
]
}
],
"source": [
    "# For each prefix of the input, print the model's most probable next token\n",
    "for i in range(1, len(encoding.input_ids[0])):\n",
    "    print(tokenizer.decode(encoding.input_ids[0][:i+1]), '\\t→', tokenizer.decode(torch.topk(output[0][0], 1).indices[i]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "### Text generation"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': tensor([[8888, 11, 319, 616, 835, 284, 262, 6403, 11]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])}"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"encoding"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"text = TEXT"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Today, on my way to the university,'"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': tensor([[8888, 11, 319, 616, 835, 284, 262, 6403, 11]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])}"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"encoding"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"encoding = tokenizer(text, return_tensors='pt')"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
    "# Greedy decoding: append the single most probable next token, 10 times\n",
    "for i in range(10):\n",
    "    output = pt_model(**encoding)\n",
    "    text += tokenizer.decode(torch.topk(output[0][0][-1], 1).indices)\n",
    "    encoding = tokenizer(text, return_tensors='pt')"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Today, on my way to the university, I was approached by a man who was a student'"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"https://towardsdatascience.com/decoding-strategies-that-you-need-to-know-for-response-generation-ba95ee0faadc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "What can be done to improve the result? Decoding strategies:\n",
    "\n",
    "- greedy search\n",
    "- random sampling\n",
    "- random sampling with temperature\n",
    "- top-k sampling, or top-k sampling with temperature\n",
    "- top-p sampling (also known as nucleus sampling), or top-p sampling with temperature\n"
]
},
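  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sampling strategies listed above can be sketched in plain PyTorch. This is a minimal illustration, not the implementation used by `transformers`: `sample_next_token` is a hypothetical helper, and the values `temperature=0.8`, `top_k=50`, `top_p=0.95` are arbitrary assumptions. It relies on `pt_model`, `tokenizer` and `TEXT` from the earlier cells. Greedy search corresponds to taking the argmax instead of sampling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def sample_next_token(logits, temperature=1.0, top_k=0, top_p=1.0):\n",
    "    # logits: 1-D tensor of unnormalized scores over the vocabulary\n",
    "    # temperature < 1 sharpens the distribution, > 1 flattens it\n",
    "    logits = logits.detach() / temperature\n",
    "    if top_k > 0:\n",
    "        # top-k: mask everything below the k-th highest score\n",
    "        kth_value = torch.topk(logits, top_k).values[-1]\n",
    "        logits = logits.masked_fill(logits < kth_value, float('-inf'))\n",
    "    if top_p < 1.0:\n",
    "        # top-p (nucleus): keep the smallest set of most probable tokens\n",
    "        # whose total probability reaches top_p\n",
    "        sorted_logits, sorted_idx = torch.sort(logits, descending=True)\n",
    "        sorted_probs = torch.softmax(sorted_logits, dim=-1)\n",
    "        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs\n",
    "        logits[sorted_idx[mass_before > top_p]] = float('-inf')\n",
    "    probs = torch.softmax(logits, dim=-1)\n",
    "    # random sampling from the (possibly truncated) distribution\n",
    "    return torch.multinomial(probs, num_samples=1)\n",
    "\n",
    "sampled_text = TEXT\n",
    "for _ in range(10):\n",
    "    sampled_encoding = tokenizer(sampled_text, return_tensors='pt')\n",
    "    last_logits = pt_model(**sampled_encoding)[0][0][-1]\n",
    "    sampled_text += tokenizer.decode(sample_next_token(last_logits, temperature=0.8, top_k=50, top_p=0.95))\n",
    "sampled_text"
   ]
  },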
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### https://huggingface.co/tasks"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## pipeline"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"generator = pipeline('text-generation', model=model_name)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Today, on my way to the university,'"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"TEXT"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
},
{
"data": {
"text/plain": [
"[{'generated_text': 'Today, on my way to the university, some of them would have been very pleased, and I'},\n",
" {'generated_text': 'Today, on my way to the university, and he made me dinner, and he called me back'},\n",
" {'generated_text': 'Today, on my way to the university, I saw three white girls who seemed a bit different—'},\n",
" {'generated_text': 'Today, on my way to the university, I drove through the town, past trees and bushes,'},\n",
" {'generated_text': 'Today, on my way to the university, I saw an elderly lady come up behind me.\"\\n'}]"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generator(TEXT, max_length=20, num_return_sequences=5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"https://huggingface.co/docs/transformers/main_classes/text_generation"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
},
{
"data": {
"text/plain": [
"[{'generated_text': 'Today, on my way to the university, I was approached by a man who was a student at'}]"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generator(TEXT, max_length=20, num_beams=1, do_sample=False)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
},
{
"data": {
"text/plain": [
"[{'generated_text': 'Today, on my way to the university, I was approached by a man who was very nice and'}]"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generator(TEXT, max_length=20, num_beams=10, top_p = 0.2)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
},
{
"data": {
"text/plain": [
"[{'generated_text': 'Today, on my way to the university, I was approached by a group of students who asked me'}]"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generator(TEXT, max_length=20, num_beams=10, temperature = 1.0 )"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
},
{
"data": {
"text/plain": [
"[{'generated_text': 'Today, on my way to the university, I noticed some young boys who was very active on campus'}]"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generator(TEXT, max_length=20, num_beams=10, temperature = 10.0 )"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
},
{
"data": {
"text/plain": [
"[{'generated_text': 'Today, on my way to the university, the trainees have noticed how a car could become an'}]"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generator(TEXT, max_length=20, num_beams=10, temperature = 100.0 )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "Other options:\n",
    "\n",
"- repetition_penalty\n",
"- length_penalty\n",
"- no_repeat_ngram_size\n",
"- bad_words_ids\n",
"- force_words_ids"
]
},
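  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hedged illustration, two of these options can be passed to the pipeline in the same way as `num_beams` or `temperature` in the cells above. The values `no_repeat_ngram_size=2` and `repetition_penalty=1.2` are arbitrary assumptions chosen only for demonstration, and `generator` is assumed to be the GPT-2 pipeline created earlier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forbid repeating any 2-gram and penalize tokens that were already generated\n",
    "generator(TEXT, max_length=20, num_beams=10, no_repeat_ngram_size=2, repetition_penalty=1.2)"
   ]
  },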
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## huggingface API"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"https://huggingface.co/gpt2?text=Today%2C+on+my+way+to+the+university"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"from transformers import CTRLTokenizer, CTRLModel"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = CTRLTokenizer.from_pretrained(\"ctrl\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## CTRL"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"inputs = tokenizer(\"Opinion My dog is cute\", return_tensors=\"pt\")"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': tensor([[43213, 586, 3153, 8, 83781]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'Pregnancy': 168629,\n",
" 'Christianity': 7675,\n",
" 'Explain': 106423,\n",
" 'Fitness': 63440,\n",
" 'Saving': 63163,\n",
" 'Ask': 27171,\n",
" 'Ass': 95985,\n",
" 'Joke': 163509,\n",
" 'Questions': 45622,\n",
" 'Thoughts': 49605,\n",
" 'Retail': 52342,\n",
" 'Feminism': 164338,\n",
" 'Writing': 11992,\n",
" 'Atheism': 192263,\n",
" 'Netflix': 48616,\n",
" 'Computing': 39639,\n",
" 'Opinion': 43213,\n",
" 'Alone': 44967,\n",
" 'Funny': 58917,\n",
" 'Gaming': 40358,\n",
" 'Human': 4088,\n",
" 'India': 1331,\n",
" 'Joker': 77138,\n",
" 'Diet': 36206,\n",
" 'Legal': 11859,\n",
" 'Norman': 4939,\n",
" 'Tip': 72689,\n",
" 'Weight': 52343,\n",
" 'Movies': 46273,\n",
" 'Running': 23425,\n",
" 'Science': 2090,\n",
" 'Horror': 37793,\n",
" 'Confession': 60572,\n",
" 'Finance': 12250,\n",
" 'Politics': 16360,\n",
" 'Scary': 191985,\n",
" 'Support': 12654,\n",
" 'Technologies': 32516,\n",
" 'Teenage': 66160,\n",
" 'Event': 32769,\n",
" 'Learned': 67460,\n",
" 'Notion': 182770,\n",
" 'Wikipedia': 37583,\n",
" 'Books': 6665,\n",
" 'Extract': 76050,\n",
" 'Confessions': 102701,\n",
" 'Conspiracy': 75932,\n",
" 'Links': 63674,\n",
" 'Narcissus': 150425,\n",
" 'Relationship': 54766,\n",
" 'Relationships': 134796,\n",
" 'Reviews': 41671,\n",
" 'News': 4256,\n",
" 'Translation': 26820,\n",
" 'multilingual': 128406}"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer.control_codes"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages/transformers/models/ctrl/modeling_ctrl.py:43: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
" angle_rates = 1 / torch.pow(10000, (2 * (i // 2)) / d_model_size)\n"
]
}
],
"source": [
"generator = pipeline('text-generation', model=\"ctrl\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"TEXT = \"Today\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"generator(\"Opinion \" + TEXT, max_length = 50)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[{'generated_text': 'Opinion Today I learned that the US government has been spying on the citizens of other countries for years. \\n Score: 6 \\n \\n Title: CMV: I think that the US should not be involved in the Middle East \\n Text: I think that the US'}]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"generator(\"Technologies \" + TEXT, max_length = 50)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[{'generated_text': 'Technologies Today \\n Score: 6 \\n \\n Title: The Internet is a great tool for the average person to get information and to share it with others. But it is also a great tool for the government to spy on us. \\n Score: 6 \\n \\n Title: The'}]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"generator(\"Gaming \" + TEXT, max_length = 50)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[{'generated_text': 'Gaming Today \\n Score: 6 \\n \\n Title: I just got a new gaming pc and I have a question \\n Text: I just got a new gaming pc and I have a question \\n \\n I have a monitor that I bought a while back'}]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "## Task\n",
    "\n",
    "Using GPT-2 or DistilGPT2, generate answers for the Challenging America challenge. Fine-tuning the model is not required."
]
}
],
"metadata": {
"author": "Jakub Pokrywka",
"email": "kubapok@wmi.amu.edu.pl",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"lang": "pl",
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"subtitle": "0.Informacje na temat przedmiotu[ćwiczenia]",
"title": "Ekstrakcja informacji",
"year": "2021"
},
"nbformat": 4,
"nbformat_minor": 4
}