modelowanie-jezykowe-aitech-cw/cw/12_Model_transformer_autoregresywny.ipynb
Jakub Pokrywka dd4cf0f01e 12
2022-06-05 22:35:38 +02:00



Language Modeling

12. Autoregressive transformer model [exercises]

Jakub Pokrywka (2022)


!pip install transformers
Requirement already satisfied: transformers in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (4.19.2)
Requirement already satisfied: tqdm>=4.27 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from transformers) (4.64.0)
Requirement already satisfied: numpy>=1.17 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from transformers) (1.22.3)
Requirement already satisfied: requests in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from transformers) (2.27.1)
Requirement already satisfied: packaging>=20.0 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from transformers) (21.3)
Requirement already satisfied: tokenizers!=0.11.3,<0.13,>=0.11.1 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from transformers) (0.12.1)
Requirement already satisfied: huggingface-hub<1.0,>=0.1.0 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from transformers) (0.6.0)
Requirement already satisfied: filelock in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from transformers) (3.7.0)
Requirement already satisfied: regex!=2019.12.17 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from transformers) (2022.4.24)
Requirement already satisfied: pyyaml>=5.1 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from transformers) (6.0)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.1.0->transformers) (4.1.1)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from packaging>=20.0->transformers) (3.0.8)
Requirement already satisfied: certifi>=2017.4.17 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from requests->transformers) (2020.6.20)
Requirement already satisfied: idna<4,>=2.5 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from requests->transformers) (3.3)
Requirement already satisfied: charset-normalizer~=2.0.0 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from requests->transformers) (2.0.4)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages (from requests->transformers) (1.26.9)
import torch
from transformers import pipeline, set_seed, AutoTokenizer, AutoModel, AutoModelForCausalLM

Example text

TEXT = 'Today, on my way to the university,'

Using a model from the transformers library

model_name = "gpt2"

If inference takes too long or you do not have enough RAM, use a smaller model:

# model_name = 'distilgpt2'
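The advice above can be made concrete with a rough parameter count. The sketch below assumes the published configurations of `gpt2` (12 layers) and `distilgpt2` (6 layers), both with hidden size 768, a 50257-token vocabulary and 1024 positions; `gpt2_param_count` is a hypothetical helper that adds up the weights of the embeddings, of each transformer block and of the final LayerNorm. For 12 layers it gives 124,439,808, in line with the roughly 124M parameters usually quoted for GPT-2 small, and halving the depth brings the total down to about 82M.

```python
# Rough parameter count for a GPT-2-style decoder, assuming the published
# gpt2 / distilgpt2 shapes (vocab 50257, 1024 positions, width 768).
def gpt2_param_count(n_layer, n_embd=768, vocab=50257, n_positions=1024):
    d = n_embd
    embeddings = vocab * d + n_positions * d          # wte + wpe
    per_block = (
        2 * d                      # ln_1 (weight + bias)
        + d * 3 * d + 3 * d        # attention c_attn (Q, K, V projections)
        + d * d + d                # attention c_proj
        + 2 * d                    # ln_2
        + d * 4 * d + 4 * d        # MLP c_fc
        + 4 * d * d + d            # MLP c_proj
    )
    final_ln = 2 * d                                   # ln_f
    return embeddings + n_layer * per_block + final_ln

for name, n_layer in [("gpt2", 12), ("distilgpt2", 6)]:
    n = gpt2_param_count(n_layer)
    # float32 weights take 4 bytes per parameter
    print(f"{name}: {n:,} parameters, ~{n * 4 / 2**20:.0f} MiB in float32")
```

Weights are only part of the memory footprint; activations and the key/value cache also grow with the input length.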
tokenizer = AutoTokenizer.from_pretrained(model_name)
encoding = tokenizer(TEXT)
encoding
{'input_ids': [8888, 11, 319, 616, 835, 284, 262, 6403, 11], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}
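`input_ids` are the token indices in GPT-2's vocabulary, and `attention_mask` marks the positions the model may attend to. Here it is all ones because a single, unpadded sequence has no padding positions; zeros only appear when sequences of different lengths are batched and the shorter ones are padded. The sketch below imitates that step by hand. `pad_batch` is a hypothetical helper, and the second sequence is just the prefix "Today," (with the real tokenizer you would pass `padding=True`, after first setting a pad token, since GPT-2 defines none).

```python
# Hypothetical helper: right-pad a batch of token-id lists to equal length
# and build the matching attention masks (1 = real token, 0 = padding).
def pad_batch(batch, pad_id):
    max_len = max(len(ids) for ids in batch)
    input_ids, attention_mask = [], []
    for ids in batch:
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Ids from the cell above, plus the shorter prefix "Today," as a second sequence.
batch = [[8888, 11, 319, 616, 835, 284, 262, 6403, 11], [8888, 11]]
# GPT-2 has no padding token; reusing its end-of-text id (50256) is a common workaround.
padded = pad_batch(batch, pad_id=50256)
print(padded["attention_mask"][1])  # [1, 1, 0, 0, 0, 0, 0, 0, 0]
```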
for token in encoding['input_ids']:
    print(token, '\t', tokenizer.decode(token))
8888 	 Today
11 	 ,
319 	  on
616 	  my
835 	  way
284 	  to
262 	  the
6403 	  university
11 	 ,
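GPT-2 uses byte-level BPE, and the space in front of a word is stored inside the following token: token 319 decodes to `" on"`, not `"on"` (`tokenizer.convert_ids_to_tokens` shows that space as the symbol `Ġ`, e.g. `'Ġon'`). The double gap in the printout is `print`'s separator plus that leading space. As a result, concatenating the decoded pieces reproduces the input exactly, as this small check on the pieces copied from the output shows:

```python
# Decoded pieces copied from the output above (note the leading spaces).
pieces = ["Today", ",", " on", " my", " way", " to", " the", " university", ","]
TEXT = 'Today, on my way to the university,'

# Spaces live inside the tokens, so plain concatenation restores the text.
reconstructed = "".join(pieces)
print(reconstructed == TEXT)  # True

# Number of tokens carrying a leading space, i.e. words after the first one.
print(sum(p.startswith(" ") for p in pieces))  # 6
```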
pt_model = AutoModel.from_pretrained(model_name)
encoding
{'input_ids': [8888, 11, 319, 616, 835, 284, 262, 6403, 11], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}
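`AutoModel` loads the bare `GPT2Model`, which returns hidden states (one 768-dimensional vector per token) but no word probabilities. `AutoModelForCausalLM`, imported at the top, would load GPT-2 with a language-modelling head; in GPT-2 that head reuses (ties) the token-embedding matrix, so the score of each vocabulary entry is the dot product of the last hidden state with that entry's embedding. The sketch below runs this final step on toy numbers; the 4-token vocabulary, the 3-dimensional vectors and all values are made up for illustration.

```python
import math

# Toy, made-up values: vocabulary of 4 tokens, hidden width 3.
# Rows play the role of GPT-2's token-embedding matrix (wte), reused as the LM head.
wte = [
    [0.1, 0.3, 0.2],    # token 0
    [0.5, -0.1, 0.0],   # token 1
    [-0.3, 0.2, 0.4],   # token 2
    [0.0, 0.0, 0.1],    # token 3
]
# Made-up final hidden state at the last position of the prompt.
h_last = [1.0, 0.5, -1.0]

# logits[v] = <h_last, wte[v]>: one score per vocabulary entry (no bias term).
logits = [sum(h * w for h, w in zip(h_last, row)) for row in wte]

# Softmax (shifted by the max for numerical stability) -> next-token probabilities.
m = max(logits)
exps = [math.exp(z - m) for z in logits]
total = sum(exps)
probs = [e / total for e in exps]

# Greedy decoding picks the most probable token.
next_token = max(range(len(probs)), key=probs.__getitem__)
print(next_token)  # 1
```

Appending the chosen token to the input and repeating the step is what makes the model autoregressive.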

The call below raises an error, because the model expects tensors as input, whereas the tokenizer returned plain Python lists:

pt_model(**encoding)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Input In [15], in <cell line: 1>()
----> 1 pt_model(**encoding)

File ~/anaconda3/envs/zajeciaei/lib/python3.10/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/zajeciaei/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py:769, in GPT2Model.forward(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions, output_hidden_states, return_dict)
    767     raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    768 elif input_ids is not None:
--> 769     input_shape = input_ids.size()
    770     input_ids = input_ids.view(-1, input_shape[-1])
    771     batch_size = input_ids.shape[0]

AttributeError: 'list' object has no attribute 'size'
TEXT
'Today, on my way to the university,'
encoding = tokenizer(TEXT, return_tensors='pt')
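The traceback above shows why the first call failed: `GPT2Model.forward` calls `input_ids.size()` and then reshapes the ids with `input_ids.view(-1, input_shape[-1])`, i.e. it expects a tensor laid out as (batch, sequence length). With `return_tensors='pt'` the tokenizer returns PyTorch tensors that already include the batch dimension, so `input_ids` has shape (1, 9); wrapping the original list by hand with `torch.tensor([encoding['input_ids']])` would give the same shape. The stdlib sketch below (`nested_shape` is a hypothetical helper) only illustrates the two layouts:

```python
def nested_shape(x):
    """Shape of a rectangular nested list, e.g. [[1, 2, 3]] -> (1, 3)."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return tuple(shape)

ids = [8888, 11, 319, 616, 835, 284, 262, 6403, 11]  # encoding['input_ids'] from the list version
print(nested_shape(ids))    # (9,)   flat list: no batch dimension
print(nested_shape([ids]))  # (1, 9) batch of one sequence, as with return_tensors='pt'
```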
?pt_model.forward
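`?pt_model.forward` displays the signature and docstring of the forward pass. The next call sets `output_hidden_states=True`, so besides `last_hidden_state` and the `past_key_values` cache the result also contains `hidden_states`. Assuming the GPT-2 small configuration (12 layers, 12 attention heads, hidden size 768) and the 9-token input, the expected shapes can be worked out in advance, which helps when reading the long printout that follows:

```python
# Assumed GPT-2 small configuration and the 9-token input used in this notebook.
n_layer, n_head, n_embd = 12, 12, 768
batch, seq_len = 1, 9
head_dim = n_embd // n_head  # each head works on a 64-dimensional slice

last_hidden_state = (batch, seq_len, n_embd)
# past_key_values: one (key, value) pair per layer, keys/values split into heads
past_key_values = [((batch, n_head, seq_len, head_dim),) * 2 for _ in range(n_layer)]
# hidden_states: embedding output plus the output of each of the 12 blocks
hidden_states = [(batch, seq_len, n_embd)] * (n_layer + 1)

print(last_hidden_state)                          # (1, 9, 768)
print(past_key_values[0][0])                      # (1, 12, 9, 64)
print(len(past_key_values), len(hidden_states))   # 12 13
```

In `GPT2Model` (as of transformers 4.19, the version installed above) the last element of `hidden_states` is the same tensor as `last_hidden_state`, since both are taken after the final LayerNorm.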
output = pt_model(**encoding, output_hidden_states= True)
output
BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[ 0.0502,  0.0018, -0.1750,  ..., -0.1020, -0.0257, -0.1292],
         [ 0.1300,  0.1757,  0.2934,  ...,  0.0794,  0.1164, -0.3280],
         [ 0.0021, -0.2481,  0.2638,  ...,  0.1507,  0.4056,  0.2376],
         ...,
         [ 0.1611, -0.4680,  0.7029,  ...,  0.1209,  0.3803,  0.2864],
         [ 0.1791, -0.3507, -1.2709,  ..., -0.1535, -0.7109, -0.2459],
         [ 0.2872, -0.0504,  0.0839,  ...,  0.3417, -0.0518, -0.3151]]],
       grad_fn=<ViewBackward0>), past_key_values=((tensor([[[[-7.0634e-01,  1.9011e+00,  7.7253e-01,  ..., -1.3028e+00,
           -5.0432e-01,  1.6823e+00],
          [-1.6482e+00,  3.0222e+00,  1.2789e+00,  ..., -9.0779e-01,
           -1.7395e+00,  2.4237e+00],
          [-2.3128e+00,  2.8957e+00,  1.8368e+00,  ..., -7.0370e-01,
           -1.6305e+00,  2.4407e+00],
          ...,
          [-2.4337e+00,  2.5271e+00,  2.1513e+00,  ..., -5.8053e-01,
           -1.6483e+00,  2.0594e+00],
          [-3.8223e+00,  2.1391e+00,  1.7587e+00,  ..., -1.0668e+00,
           -1.6278e+00,  1.1729e+00],
          [-1.9238e+00,  2.7944e+00,  1.6292e+00,  ..., -8.9733e-01,
           -2.2193e+00,  2.6272e+00]],

         [[-9.6153e-02,  8.9928e-01, -1.4324e+00,  ..., -3.8667e-03,
            1.7698e+00,  6.0074e-01],
          [ 2.7222e-01, -1.2016e+00, -1.9081e+00,  ..., -1.3531e+00,
            1.2823e+00, -4.3198e-01],
          [-1.1722e+00, -3.6670e-01, -1.6921e+00,  ..., -1.2359e+00,
            2.5243e+00,  1.0228e+00],
          ...,
          [-1.6694e-01, -1.0159e+00, -2.5232e+00,  ..., -9.7920e-01,
            4.8265e+00, -1.7799e+00],
          [-1.1981e-01, -2.6784e+00, -2.9551e+00,  ..., -1.9840e-01,
            3.3916e+00, -1.9762e-02],
          [ 3.2722e-01, -1.2197e+00, -2.1079e+00,  ..., -1.6297e+00,
            9.2404e-01, -7.6080e-01]],

         [[-1.4670e-01,  2.1407e-01,  1.1498e+00,  ..., -1.3128e+00,
           -2.1007e+00,  5.6910e-01],
          [ 5.5608e-01, -4.6297e-01,  7.4483e-01,  ..., -1.8272e+00,
            5.4572e-01,  1.0119e+00],
          [ 9.2851e-01,  4.6049e-03,  4.1324e-01,  ..., -2.4987e+00,
            5.2423e-01,  1.5260e+00],
          ...,
          [ 3.2328e-01,  3.5316e-01,  3.2756e-02,  ..., -3.2780e+00,
            8.1692e-01,  1.4566e+00],
          [-2.1528e-01, -2.2490e-01, -1.4536e+00,  ..., -3.7075e+00,
            1.6835e+00,  1.6085e+00],
          [ 7.6672e-01, -5.3757e-01,  4.2462e-01,  ..., -2.2908e+00,
            1.7213e+00,  1.0240e+00]],

         ...,

         [[ 5.4733e-01,  4.7672e-01, -2.2749e-01,  ...,  2.9014e-01,
            7.7821e-01,  7.8295e-01],
          [ 1.6820e-01, -9.1829e-02, -5.0034e-02,  ...,  7.3646e-01,
            6.1343e-01,  5.4442e-01],
          [ 2.9530e-02, -5.3167e-02, -6.1709e-02,  ...,  1.0934e+00,
            3.7083e-01,  3.8425e-01],
          ...,
          [-1.3203e-02, -2.6465e-01,  4.4834e-02,  ...,  1.2205e+00,
            5.4265e-01,  3.7732e-01],
          [ 8.5854e-02, -2.3791e-01, -1.1271e-01,  ...,  1.8211e+00,
           -5.7249e-01, -7.4493e-01],
          [-3.6544e-02, -1.4250e-01,  6.6582e-02,  ...,  1.0489e+00,
            4.8485e-01,  4.6476e-01]],

         [[ 1.4700e+00,  1.3564e+00, -4.9892e-01,  ..., -6.4925e-02,
            1.4507e+00, -1.2267e+00],
          [ 1.0113e+00,  7.0108e-01, -5.7364e-01,  ..., -7.1721e-01,
            1.0731e+00, -1.0718e+00],
          [ 1.1010e+00,  4.8299e-01, -9.3231e-01,  ..., -1.5044e+00,
            1.2941e+00, -3.3869e-01],
          ...,
          [ 1.1745e+00,  6.3323e-01, -6.1605e-01,  ..., -8.1925e-01,
            5.2691e-01, -7.5443e-01],
          [ 1.7895e+00,  5.7095e-01, -3.5775e-01,  ..., -1.3193e+00,
            5.5676e-01, -1.6293e-01],
          [ 9.6151e-01,  2.9245e-02, -5.3493e-01,  ..., -7.8683e-01,
            3.7355e-01, -2.4032e-01]],

         [[ 7.1643e-01, -3.1278e-01,  1.4058e-01,  ..., -2.0734e-01,
            2.5946e-01,  1.7684e+00],
          [-5.6619e-01,  7.8687e-01,  2.5152e-02,  ...,  6.2100e-01,
            4.7592e-01,  5.4321e-01],
          [-6.2611e-01,  3.3320e-01,  1.1092e-01,  ...,  6.4703e-01,
            6.4159e-01,  7.2777e-01],
          ...,
          [-1.7180e-01,  1.1778e+00, -2.3931e-01,  ..., -6.3932e-01,
            1.1654e+00,  4.0462e-01],
          [-4.8319e-01,  2.8237e-01, -4.4490e-01,  ..., -1.2013e-01,
            4.8413e-01, -4.5133e-01],
          [-1.1252e+00,  7.6533e-01, -6.0320e-02,  ...,  1.8912e-01,
            7.8018e-01, -5.4733e-01]]]], grad_fn=<PermuteBackward0>), tensor([[[[ 0.1900,  0.0015, -0.0517,  ...,  0.0536,  0.0312, -0.0694],
          [-0.0800,  0.0181, -0.0534,  ..., -0.0419, -0.0365,  0.0151],
          [ 0.0448,  0.1912, -0.1849,  ..., -0.0062, -0.1420,  0.1609],
          ...,
          [-0.1635,  0.0196,  0.1185,  ...,  0.0794,  0.0980, -0.1084],
          [-0.2303,  0.1991, -0.1576,  ...,  0.2774, -0.1813, -0.2463],
          [-0.1009,  0.0410, -0.0970,  ..., -0.0684, -0.0763,  0.0260]],

         [[ 0.4406,  0.1176, -0.2136,  ..., -0.6839, -0.2371,  0.2999],
          [ 0.5926,  0.0197,  0.1107,  ...,  0.1253,  0.5675, -0.2665],
          [ 0.6762,  0.0459, -0.3685,  ...,  0.0744,  0.5420, -0.1240],
          ...,
          [ 0.8509, -0.0962,  0.0762,  ..., -0.1705,  0.1339,  0.1068],
          [ 0.2928, -0.2582,  0.1735,  ...,  0.0800,  0.2879, -0.0139],
          [ 0.5969,  0.0592,  0.0263,  ..., -0.0100,  0.5129, -0.1905]],

         [[ 0.0810, -0.1910,  0.1092,  ..., -0.0283,  0.0408,  0.0961],
          [-0.3257,  0.0398, -0.1531,  ...,  0.0411, -0.0413,  0.0745],
          [ 0.5201,  0.0126,  0.3504,  ...,  0.1020,  0.0543, -0.2188],
          ...,
          [-0.5288, -0.0025, -0.5926,  ..., -0.1874, -0.0674,  0.3113],
          [ 0.1521,  0.0271, -0.2514,  ..., -0.0465, -0.0565, -0.3401],
          [-0.2885,  0.0590, -0.1736,  ...,  0.0685, -0.1112,  0.0604]],

         ...,

         [[ 0.0111, -0.0168,  0.0263,  ..., -0.2135,  0.2054,  0.0729],
          [-0.3022, -0.0878,  0.1001,  ...,  0.0262, -0.1647,  0.1682],
          [-0.1587, -0.0666,  0.0826,  ..., -0.0416,  0.0812,  0.2067],
          ...,
          [-0.0925, -0.4836,  0.0332,  ...,  0.0641, -0.1597,  0.2375],
          [-0.0742,  0.8589,  0.0336,  ..., -0.3268, -0.2455,  0.3080],
          [-0.0869, -0.4287,  0.1231,  ..., -0.0474, -0.1705,  0.0347]],

         [[ 0.2081, -0.2399, -0.1318,  ...,  0.1471,  0.1123, -0.0316],
          [-0.2119,  0.0589,  0.0997,  ...,  0.0038,  0.1331,  0.0930],
          [-0.1213,  0.1404,  0.1775,  ...,  0.1688, -0.0020,  0.0829],
          ...,
          [-0.2325,  0.1252, -0.0345,  ...,  0.2837,  0.0686, -0.0089],
          [ 0.1896,  0.0282, -0.0740,  ...,  0.1655, -0.3020,  0.2837],
          [ 0.0298,  0.0086, -0.1626,  ...,  0.1976,  0.0970, -0.0014]],

         [[-0.0689, -0.3955,  0.2328,  ...,  0.1539, -0.1823, -0.0845],
          [ 0.0538, -0.2648, -0.0146,  ...,  0.2331,  0.0516,  0.0924],
          [-0.0647,  0.0062,  0.1329,  ...,  0.1026,  0.1185,  0.0463],
          ...,
          [ 0.0186,  0.1904, -0.0966,  ...,  0.0714, -0.0321, -0.0059],
          [ 0.0219,  0.4180, -0.1580,  ..., -0.0072, -0.2708,  0.1529],
          [ 0.1236, -0.3671, -0.0392,  ...,  0.1061, -0.0278, -0.0074]]]],
       grad_fn=<PermuteBackward0>)), (tensor([[[[-3.5429e-01,  2.2092e+00, -1.5580e+00,  ...,  1.4397e+00,
           -1.1504e+00,  1.4646e+00],
          [ 7.3885e-01,  1.8177e+00, -1.4766e+00,  ..., -4.6761e-01,
           -1.6869e+00,  5.0785e-01],
          [ 1.6962e+00,  1.1427e+00, -1.1112e+00,  ...,  1.2764e-01,
           -2.5909e+00,  7.2933e-01],
          ...,
          [-1.9130e-03,  1.6441e+00, -3.0120e-01,  ...,  3.8508e-01,
           -1.0645e+00, -4.5135e-01],
          [-3.9438e-01,  1.6005e+00,  9.6257e-01,  ...,  5.8858e-01,
           -1.8425e+00, -9.6318e-01],
          [-4.9488e-01,  1.1094e+00,  5.2522e-02,  ...,  5.6471e-01,
           -1.3969e+00, -3.0882e-01]],

         [[-1.0087e+00, -4.5958e-01, -7.4797e-01,  ..., -3.7310e-01,
            7.9809e-01, -2.3881e-01],
          [-6.6438e-02,  4.8658e-01, -8.2457e-01,  ..., -9.4308e-01,
            1.8907e-01, -1.5256e-02],
          [-1.7392e-01,  1.1992e+00, -1.5513e+00,  ..., -3.2774e-01,
            7.3627e-01, -3.6968e-01],
          ...,
          [-1.1986e-01,  6.0111e-01, -1.4226e+00,  ..., -6.1346e-01,
            1.3460e-01, -6.1240e-01],
          [ 1.8174e-01,  3.1973e-01, -2.2986e+00,  ..., -4.1319e-01,
           -1.0757e+00, -4.7605e-01],
          [-2.4593e-01,  1.1035e+00, -1.4215e+00,  ..., -6.2691e-01,
           -1.1097e+00, -6.3956e-01]],

         [[ 3.2591e-01, -1.6143e-02, -2.0098e-01,  ..., -1.3362e+00,
            3.3876e-01, -1.6542e-01],
          [-1.0002e-02,  3.9666e-01, -9.3499e-02,  ..., -1.0921e+00,
            5.6914e-02,  4.1318e-01],
          [-1.1656e-02,  2.1262e-01, -2.3546e-01,  ..., -9.7254e-01,
            1.4688e-01,  2.7869e-01],
          ...,
          [-8.3349e-02,  3.9433e-02, -9.7432e-03,  ..., -7.0562e-01,
            4.2687e-01,  2.3274e-01],
          [ 1.0450e-01, -2.0783e-01, -2.8860e-01,  ..., -1.0073e+00,
           -1.2179e-01,  3.5471e-01],
          [-1.4484e-01, -5.0447e-02, -3.9541e-03,  ..., -1.0255e+00,
            1.9039e-01,  3.3890e-01]],

         ...,

         [[ 2.1528e-01, -4.6627e-01, -5.9642e-01,  ..., -4.2178e-01,
            4.3739e-01, -8.5899e-01],
          [-5.0305e-02,  1.2479e+00,  1.8768e+00,  ...,  6.8503e-01,
           -7.3186e-01, -3.4076e-01],
          [-4.0512e-01,  1.6082e+00,  1.8570e+00,  ...,  1.2636e+00,
           -1.1781e+00, -8.1034e-01],
          ...,
          [-5.1299e-01,  2.6865e-01,  7.6903e-01,  ..., -1.3940e+00,
            8.1194e-01, -1.8763e-01],
          [ 2.3526e-01, -5.7615e-01,  1.3541e+00,  ...,  1.4708e+00,
           -2.9934e-01, -3.9407e-01],
          [ 5.0755e-02,  7.0489e-01,  1.9166e+00,  ...,  6.6883e-01,
           -9.1450e-01, -2.5584e-01]],

         [[-1.1473e+00, -2.7966e+00,  1.4438e-01,  ...,  1.7208e+00,
            1.5965e+00, -1.4860e+00],
          [ 3.5231e-01,  7.5960e-01, -4.7429e-01,  ..., -8.1442e-01,
            4.5442e-01, -2.9752e-01],
          [ 2.1113e-01,  7.5264e-01, -4.5093e-01,  ..., -9.6233e-01,
            5.8766e-01,  9.0545e-02],
          ...,
          [ 1.6897e-01,  2.5023e-01, -7.4581e-01,  ..., -1.2799e-01,
            7.1349e-01, -8.5998e-02],
          [-2.3828e-01,  5.9684e-01, -7.5936e-01,  ..., -6.6564e-01,
            7.3313e-01,  1.8287e-01],
          [-1.6440e-01,  2.5931e-01, -8.1777e-01,  ..., -3.5322e-01,
            8.3564e-01, -5.9446e-02]],

         [[ 1.3976e+00,  1.6241e+00,  5.4245e-01,  ..., -7.8420e-01,
            1.1678e-01,  3.7706e-01],
          [ 8.8908e-01,  2.1345e+00,  1.0939e+00,  ...,  1.1961e-01,
           -7.5297e-01, -1.4081e-01],
          [ 6.7893e-01,  1.8408e+00,  1.5060e+00,  ...,  5.9498e-01,
           -2.2553e+00, -1.8270e+00],
          ...,
          [-5.1015e-02,  2.4946e+00, -1.6883e-01,  ...,  5.4761e-01,
           -2.8891e-01, -6.7954e-01],
          [-1.6942e-01,  4.9026e-01,  1.1144e+00,  ...,  9.3912e-03,
           -8.0171e-01, -1.4243e-01],
          [ 8.4424e-01,  1.7401e+00,  9.2639e-01,  ..., -1.4967e-01,
           -3.8360e-01, -1.5520e-01]]]], grad_fn=<PermuteBackward0>), tensor([[[[ 3.3872e-01,  1.3968e-01, -1.7938e-01,  ...,  1.5467e-01,
           -1.2589e-01,  7.0887e-02],
          [ 3.7346e-01,  2.8615e-01,  7.3073e-02,  ..., -1.7334e-01,
           -1.7929e-01,  8.0809e-02],
          [ 1.3121e-01,  1.3779e-01,  9.8802e-02,  ...,  1.7611e-01,
           -6.5489e-01, -3.7171e-01],
          ...,
          [ 4.5774e-01,  6.2110e-02,  4.7204e-02,  ...,  2.1876e-01,
           -1.9506e-01,  1.5526e-01],
          [-1.6503e-01,  7.2050e-02, -4.4076e-01,  ...,  9.3966e-02,
           -8.1660e-02, -2.9702e-01],
          [ 3.7986e-01,  3.8336e-01,  1.0341e-01,  ..., -1.9899e-01,
           -2.3373e-01, -1.3201e-01]],

         [[-7.9321e-02, -6.6966e-02, -2.2227e-01,  ..., -1.4152e-02,
           -4.5964e-01,  2.7340e-01],
          [-2.0632e-01, -2.7675e-01,  9.3918e-02,  ..., -9.7495e-02,
            2.0266e-01,  3.4913e-02],
          [-3.6562e-01, -2.8439e-01,  2.9782e-01,  ..., -1.0605e+00,
            2.7564e-01,  3.3809e-01],
          ...,
          [ 5.1779e-01,  2.3170e-01, -3.0248e-01,  ...,  4.6880e-01,
            4.3330e-01, -6.2105e-01],
          [-1.9805e-01,  6.8445e-02, -5.7586e-02,  ...,  1.3844e-01,
           -6.2666e-02,  1.8667e-01],
          [ 6.9782e-02, -1.5278e-01,  6.9243e-02,  ..., -1.0944e-01,
            1.1224e-01,  1.1524e-01]],

         [[ 7.9376e-02, -1.4863e-02, -4.4028e-02,  ..., -6.2825e-01,
            6.7840e-02,  1.0440e-02],
          [ 4.2720e-01,  2.4379e-01,  2.3040e-01,  ..., -5.0812e-01,
            3.7279e-02, -1.3192e-01],
          [ 6.2018e-01,  1.7793e-01,  2.9474e-01,  ..., -7.6162e-01,
           -2.8552e-01, -1.4080e-01],
          ...,
          [ 5.8184e-01,  5.9326e-02,  2.5048e-03,  ..., -6.1473e-01,
           -3.0034e-02,  4.4224e-02],
          [ 6.7462e-01,  1.3863e-01, -5.1645e-02,  ..., -5.6261e-01,
           -2.2474e-01, -1.2376e-01],
          [ 6.0415e-01,  9.6460e-02,  1.1331e-01,  ..., -2.8026e-01,
            2.4650e-02, -2.4321e-01]],

         ...,

         [[ 1.0567e-01,  6.7946e-01, -1.7619e-01,  ...,  1.2480e-02,
           -9.7338e-01, -2.5708e-01],
          [-5.0101e-04, -7.4670e-01,  1.4215e-01,  ...,  2.6520e-02,
           -9.1824e-01, -4.4347e-01],
          [ 5.7162e-02, -6.6084e-01, -1.7225e-01,  ..., -6.7773e-02,
           -6.9370e-01,  2.2682e-01],
          ...,
          [ 3.6897e-01,  4.0040e-01,  1.3203e-01,  ...,  5.9832e-02,
           -4.3946e-01,  3.3851e-02],
          [-1.9931e-01,  4.7522e-01,  6.5326e-01,  ...,  8.5060e-01,
           -1.5948e-01,  2.6952e-01],
          [ 4.5483e-02, -7.9412e-01,  2.0943e-01,  ...,  6.4299e-02,
           -6.5777e-01, -2.0458e-01]],

         [[ 4.7333e-02, -1.1130e-02, -1.4608e-01,  ...,  3.8364e-01,
           -3.4244e+00,  6.6758e-02],
          [ 5.0051e-01,  8.4673e-03,  1.9747e-01,  ...,  2.1474e-01,
           -7.4449e-03, -2.8373e-01],
          [-2.0428e-01,  2.4512e-01, -2.7017e-01,  ...,  4.5577e-02,
            2.1612e-02, -1.3106e-01],
          ...,
          [ 7.3244e-02, -1.5794e-01,  1.7578e-01,  ..., -2.2690e-01,
           -6.3669e-02, -1.8729e-02],
          [ 1.3369e-01,  4.0795e-01, -6.9403e-02,  ..., -2.8477e-02,
            8.1580e-02, -3.7645e-01],
          [ 3.2948e-01,  2.4525e-01,  3.1002e-02,  ...,  1.4547e-03,
           -2.0459e-01, -1.3566e-02]],

         [[ 2.4439e-02, -2.3092e-01,  1.1163e-02,  ..., -3.4285e-01,
            2.7007e-01, -3.4211e-02],
          [ 2.0095e-01, -4.9356e-01,  5.3058e-01,  ..., -2.7157e-01,
            4.2807e-01,  3.2917e-01],
          [-1.0993e-01, -4.1360e-01,  1.9816e-02,  ..., -1.7917e-01,
            3.6033e-01,  2.2954e-01],
          ...,
          [ 4.2263e-02,  1.5875e-02, -3.0871e-01,  ..., -3.1441e-01,
            2.9030e-01,  2.2213e-01],
          [-4.9536e-02,  8.3578e-02,  7.2786e-02,  ..., -2.5493e-01,
            4.7891e-02,  3.4251e-01],
          [ 5.0301e-02, -1.8544e-01,  5.7551e-01,  ..., -3.4349e-01,
            1.5927e-01,  4.2942e-01]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[-1.5217e-01, -1.1477e+00,  2.3295e-01,  ..., -6.4279e-01,
           -1.1349e-01,  4.0799e-02],
          [ 4.5919e-01, -2.0374e+00, -7.9378e-01,  ...,  4.4668e-02,
           -8.8579e-01, -9.0097e-01],
          [ 3.8866e-01, -1.6082e+00, -3.9608e-01,  ...,  3.1908e-01,
           -4.2160e-01, -1.1912e-01],
          ...,
          [-8.1627e-02, -1.0257e+00, -6.6449e-01,  ...,  6.6261e-01,
           -1.8242e-01, -5.9660e-02],
          [ 9.9366e-01, -2.8990e+00, -4.2770e-01,  ...,  1.5473e+00,
           -2.7730e-01,  1.0212e+00],
          [ 3.7402e-01, -1.2451e+00, -8.3321e-01,  ...,  1.5307e+00,
           -6.0831e-01, -1.0434e+00]],

         [[-5.0563e-01,  3.4884e-01, -4.0126e-01,  ...,  1.2945e+00,
           -5.5872e-01, -4.4031e-01],
          [-1.0783e+00, -1.0583e+00, -8.7019e-01,  ...,  9.3939e-01,
            6.1988e-01, -3.6133e-01],
          [-1.4605e+00,  7.9834e-04, -1.6445e+00,  ...,  8.5405e-01,
            1.1266e+00,  2.1244e-01],
          ...,
          [-1.7653e+00, -4.5490e-01,  5.8049e-01,  ...,  1.3604e-01,
           -2.6502e-01,  1.4497e+00],
          [-2.7539e+00, -1.9189e+00, -6.1803e-01,  ...,  2.3083e+00,
           -6.2625e-01, -5.0954e-01],
          [-8.4786e-01, -9.9176e-01, -1.4226e+00,  ...,  1.0424e+00,
            1.2138e+00, -6.2367e-01]],

         [[ 1.3477e+00,  3.0343e+00,  3.7258e+00,  ...,  6.1286e-01,
            1.7142e+00, -7.4960e-01],
          [-3.4424e+00,  2.1578e+00, -3.4773e+00,  ..., -1.7704e+00,
            3.4858e+00,  9.8086e-01],
          [-3.3403e+00,  7.3066e-01, -4.6132e+00,  ..., -3.2065e+00,
            5.3039e+00,  7.1677e-01],
          ...,
          [-4.8998e+00, -5.9784e-01, -2.9574e+00,  ..., -4.1010e+00,
            2.4786e+00,  2.7664e-02],
          [-3.3274e+00, -1.2454e+00, -5.1031e+00,  ..., -3.2964e+00,
            3.3057e+00,  1.4853e+00],
          [-4.2024e+00, -1.7287e+00, -5.1702e+00,  ..., -2.7123e+00,
            2.8922e+00,  1.8391e+00]],

         ...,

         [[ 1.3818e+00, -2.7867e+00, -2.6519e+00,  ...,  9.1555e-01,
            4.4077e-01,  2.7028e+00],
          [-2.4026e+00,  1.6620e+00, -4.5219e-01,  ...,  1.2064e-01,
           -1.6484e+00,  5.6717e-01],
          [-1.7379e+00,  2.8888e+00,  2.1535e-01,  ..., -7.8397e-02,
           -2.7045e+00, -3.0823e-03],
          ...,
          [-2.9426e+00,  3.5565e+00,  1.0280e+00,  ..., -3.5420e-01,
           -3.7917e+00, -7.8773e-01],
          [-2.8640e+00,  2.8314e+00,  2.3865e+00,  ..., -2.2468e+00,
           -4.0705e+00, -1.2861e+00],
          [-3.9137e+00,  4.3675e+00,  1.5171e+00,  ..., -6.0161e-01,
           -2.7414e+00, -1.2265e+00]],

         [[ 1.7415e+00,  4.5990e-01,  9.3163e-01,  ...,  1.2650e-03,
           -9.8961e-01, -2.9552e-01],
          [ 2.2626e+00,  1.0377e+00,  1.1163e+00,  ...,  3.4995e-01,
           -2.5767e+00, -1.2164e+00],
          [ 2.0896e+00,  6.8649e-01,  1.2068e+00,  ...,  4.1762e-01,
           -2.1005e+00, -1.2765e+00],
          ...,
          [ 1.8625e+00,  5.6272e-01,  1.1284e+00,  ...,  3.5132e-01,
           -2.0787e+00, -1.0202e+00],
          [ 2.2705e+00,  3.2166e-01,  1.1907e+00,  ...,  2.6156e-01,
           -1.2966e+00, -9.9152e-01],
          [ 2.3024e+00,  4.0813e-01,  9.6441e-01,  ...,  4.9377e-01,
           -2.5960e+00, -6.9144e-01]],

         [[-2.2407e-01,  1.4293e-01, -5.5406e-01,  ...,  3.1676e-01,
            2.7494e-01,  1.6436e-01],
          [-5.7508e-01,  6.1265e-01, -2.6713e-01,  ...,  8.0278e-01,
            8.5041e-01,  1.8214e-01],
          [ 6.2629e-01,  3.5029e-02,  8.6408e-02,  ...,  4.6667e-01,
            1.6070e-01,  1.2988e-01],
          ...,
          [ 1.5542e-01, -2.5139e-01, -8.1318e-01,  ...,  2.1838e-01,
            2.0266e-01,  6.9734e-01],
          [-2.4867e-01,  4.2143e-01, -4.6590e-01,  ...,  3.0348e-01,
            5.7653e-01, -5.7979e-01],
          [-4.1779e-01, -4.9530e-01, -6.0749e-01,  ...,  5.8660e-01,
            9.1405e-01, -3.4966e-02]]]], grad_fn=<PermuteBackward0>), tensor([[[[-1.5059e-02, -2.1934e-02, -1.3257e-01,  ..., -3.3233e-03,
            5.6872e-03, -5.5921e-01],
          [-4.4076e-01,  4.7031e-01, -2.1116e-01,  ...,  5.7315e-01,
           -3.8024e-01,  2.5338e-01],
          [ 2.7640e-01,  1.0290e-01, -1.5030e-01,  ...,  8.0443e-02,
           -1.0340e-02,  6.5651e-01],
          ...,
          [ 7.7904e-01,  1.2082e+00,  3.0358e-01,  ...,  4.4578e-01,
           -4.0582e-02,  8.5044e-01],
          [-2.0731e-01, -5.8119e-01,  4.1100e-01,  ..., -1.7157e-01,
            2.8487e-01,  6.4911e-01],
          [-8.6411e-01,  5.4967e-01, -4.1298e-01,  ...,  9.2813e-01,
           -4.2606e-01, -3.4161e-01]],

         [[ 3.8557e-02,  3.3662e-03,  5.4482e-02,  ..., -5.7578e-02,
           -7.4123e-02,  2.2392e-02],
          [ 1.9386e-01,  1.8534e-01,  3.0680e-01,  ..., -1.2764e-03,
           -2.5348e-01,  8.6118e-02],
          [-1.4242e-01,  3.2992e-01,  7.6395e-02,  ...,  9.8633e-02,
           -5.6915e-02,  4.4799e-02],
          ...,
          [-7.1944e-02,  3.8884e-02,  1.0161e-01,  ..., -2.7253e-01,
            1.3398e-01,  1.1796e-01],
          [-1.0896e+00,  2.1403e+00, -1.3890e-01,  ...,  1.0035e+00,
            6.1333e-01, -1.1536e+00],
          [ 6.1611e-02,  7.1527e-02,  2.0043e-01,  ..., -3.5723e-01,
           -1.4230e-01,  8.4502e-02]],

         [[ 1.1201e-02, -7.6654e-01, -1.1583e-02,  ...,  4.3143e-02,
            1.5736e-02, -5.8100e-02],
          [ 2.8462e-01, -1.0610e+00,  1.2486e-01,  ...,  3.1588e-02,
           -1.1913e-01, -4.8153e-02],
          [ 2.6008e-01, -6.3008e-01, -8.1709e-01,  ...,  1.8586e-01,
            3.4370e-01,  9.2477e-01],
          ...,
          [-1.9891e-01, -1.9001e+00, -4.4621e-02,  ...,  7.8242e-02,
            2.2361e-02,  1.3589e-02],
          [-2.8968e-01, -1.5899e+00,  9.2801e-02,  ..., -2.7827e-01,
            1.6159e-01, -4.6007e-01],
          [ 1.6971e-01, -1.5136e+00,  1.2845e-01,  ..., -6.2768e-02,
           -2.5769e-01, -1.5622e-01]],

         ...,

         [[ 1.6522e-02, -7.7326e-02,  1.3163e+00,  ..., -5.6423e-02,
            1.7141e-01,  2.1386e-02],
          [-4.3988e-01, -2.9255e-01,  2.4116e+00,  ..., -1.8846e-01,
            1.0912e-01,  1.4147e-01],
          [ 2.3190e-01, -1.5369e-01,  2.5701e+00,  ...,  6.3039e-01,
           -1.0088e-01,  5.1586e-01],
          ...,
          [ 1.7250e-02,  6.7580e-01,  2.5971e+00,  ..., -5.2273e-01,
            4.5050e-01, -6.9956e-01],
          [-4.9545e-02,  6.1819e-01,  3.8825e-01,  ..., -1.4691e-01,
            4.5526e-01,  7.1271e-01],
          [-1.9639e-01, -1.2515e-01,  2.5813e+00,  ..., -1.8536e-01,
           -1.3485e-01, -8.7375e-02]],

         [[ 7.4395e-02, -8.7165e-02, -1.8260e-01,  ...,  1.3185e-01,
            1.2575e-01,  1.7169e-01],
          [ 6.5960e-01,  1.0117e+00,  7.1659e-01,  ...,  8.3512e-02,
           -6.5585e-01, -3.3111e-01],
          [ 3.2666e-01, -1.2571e-01,  8.1719e-01,  ...,  9.9527e-01,
           -1.0291e+00, -5.0537e-01],
          ...,
          [-7.2666e-01,  1.0662e-01, -7.2195e-02,  ..., -2.7005e-01,
            5.2628e-01,  2.3005e-01],
          [-2.0959e-01, -2.3959e-01, -3.0772e-01,  ...,  4.6964e-01,
           -1.8979e-01, -2.7418e-01],
          [ 1.7468e-01,  1.0415e+00,  7.5772e-01,  ..., -4.9262e-01,
           -8.0868e-01,  4.5074e-01]],

         [[ 1.1606e-02,  2.1828e-02,  2.7971e-02,  ..., -3.3218e-02,
            2.2172e-01, -2.3344e-03],
          [ 1.1778e-01, -3.0263e-01,  3.5408e-01,  ..., -3.3052e-01,
           -1.9086e+00,  4.3385e-01],
          [-7.0245e-01,  4.2293e-02, -1.3216e-01,  ...,  3.4737e-01,
           -1.4905e+00,  3.5105e-01],
          ...,
          [ 2.1967e-01, -6.0979e-01, -6.8996e-01,  ...,  4.4944e-01,
           -1.9601e+00, -1.7819e-01],
          [ 3.8903e-01,  1.9728e-01, -9.0256e-01,  ...,  1.3781e-01,
           -2.0059e+00,  3.0071e-01],
          [ 5.9661e-01, -3.1890e-01, -2.2125e-01,  ...,  2.8531e-01,
           -1.8048e+00,  2.1086e-01]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[ 0.0532, -0.2197,  0.1445,  ..., -0.8884,  0.7361, -1.2044],
          [-0.6462, -0.7026, -1.4285,  ...,  0.2179, -0.3014,  0.1623],
          [-0.8909, -1.9166, -1.3314,  ..., -1.8027, -2.7636,  2.9528],
          ...,
          [-1.9169, -0.2602, -0.2397,  ..., -0.4901, -0.8816,  0.7061],
          [-2.0792,  0.1064,  0.6011,  ...,  0.5948, -0.5403,  1.4379],
          [-0.4271, -0.4968, -0.0297,  ...,  1.0395, -0.3829,  0.3067]],

         [[ 0.7842,  0.1905,  0.0089,  ..., -0.1612, -1.0898, -0.1939],
          [-1.3909, -1.5235, -0.5037,  ...,  0.9582,  4.2044,  1.1825],
          [ 0.1689, -1.8025,  0.8404,  ...,  1.5177,  5.7815,  2.1470],
          ...,
          [ 1.2462, -1.4013, -1.2263,  ...,  0.5912,  6.0711,  1.7328],
          [ 1.4548, -2.0760, -2.0483,  ..., -1.5971,  5.6172,  2.5548],
          [-1.1053, -0.8554, -2.0471,  ...,  0.8743,  6.2095,  1.1606]],

         [[ 0.3413, -0.3572, -0.3331,  ...,  0.3294,  1.4604,  0.2755],
          [ 0.0960, -6.2139, -0.6779,  ..., -2.8446, -1.4388, -4.4836],
          [-0.8714, -7.8835, -1.6969,  ..., -2.1200, -2.1704, -7.2160],
          ...,
          [-3.2255, -7.0802, -1.8176,  ..., -2.8620, -2.7388, -5.1880],
          [-2.2788, -5.5723, -1.6649,  ..., -3.3594, -2.4676, -5.1028],
          [-2.9788, -7.2411, -1.0434,  ..., -3.2540, -2.9263, -5.0116]],

         ...,

         [[ 0.2148,  1.7719,  0.5129,  ...,  0.2612,  0.4477, -1.6895],
          [-0.2874, -5.8026,  1.1293,  ..., -2.2826, -1.7007,  5.5452],
          [-2.4104, -6.5778,  1.1952,  ..., -2.4193, -0.3969,  3.8159],
          ...,
          [-1.4026, -7.7514,  1.2659,  ..., -3.4256, -2.3786,  6.9488],
          [-1.0623, -5.7453,  0.1012,  ..., -0.5622, -2.4292,  6.8565],
          [-0.3079, -7.9204,  1.8029,  ..., -3.2453, -2.3462,  7.0537]],

         [[ 0.0559, -0.0269,  0.1386,  ..., -0.1165, -0.0882, -0.1612],
          [ 0.1342, -0.5329, -0.2255,  ..., -1.0159,  0.1003, -0.4600],
          [-0.7412, -0.2755,  0.1787,  ..., -0.8159, -0.9071, -0.1041],
          ...,
          [-0.0215, -0.5192, -0.2004,  ...,  0.3272, -0.3216,  0.5758],
          [ 0.2406, -0.3252,  0.3839,  ..., -0.2115,  0.3593, -0.6457],
          [-0.6898, -1.1861, -0.0238,  ...,  0.5217,  0.0940,  0.9089]],

         [[ 0.3939, -0.0741,  1.9091,  ..., -0.2314, -0.2112, -0.9825],
          [ 2.5678,  1.8706, -2.0184,  ...,  0.0582,  0.5182,  2.5282],
          [ 3.1803,  2.0001, -2.9358,  ...,  2.6552,  1.0590,  4.2195],
          ...,
          [ 2.6593,  1.2215, -2.5623,  ...,  1.4338,  0.6112,  3.2894],
          [ 1.1448,  0.9766, -2.1789,  ...,  1.8788,  0.3242,  3.7226],
          [ 2.8828,  1.7918, -3.5229,  ..., -0.0936,  0.5881,  4.5368]]]],
       grad_fn=<PermuteBackward0>), tensor([[[[ 4.5583e-02,  6.3733e-02, -3.4908e-03,  ...,  3.8117e-03,
            1.0385e-01,  2.3468e-02],
          [ 2.7193e-01, -2.7436e-01,  8.2051e-01,  ..., -5.7602e-01,
           -6.8246e-02, -1.7190e-02],
          [-3.3605e-01, -9.2270e-01, -2.9339e-01,  ...,  3.2747e-02,
            5.3266e-01, -1.1793e+00],
          ...,
          [ 4.6659e-01,  1.6959e-01, -6.8990e-02,  ...,  3.2092e-01,
            2.9894e-02,  4.5212e-03],
          [-3.8838e-01, -2.8303e-01,  6.4867e-01,  ...,  5.4443e-01,
           -3.8750e-03, -7.7317e-01],
          [-1.4669e-01, -2.2234e-01,  5.0309e-01,  ..., -2.0195e-01,
           -3.4870e-02,  1.0260e+00]],

         [[-3.8964e-02, -8.6139e-03,  9.1636e-02,  ..., -4.5061e-02,
           -1.8257e-02, -4.4496e-02],
          [-1.2398e-01, -4.6354e-01,  6.3162e-02,  ...,  4.1472e-01,
           -8.8383e-02, -6.1835e-02],
          [ 2.3124e-01, -4.1944e-01, -5.5628e-02,  ..., -6.5586e-01,
           -2.9434e-01,  1.1322e-01],
          ...,
          [ 9.0615e-02, -2.5366e-01, -1.7453e-01,  ...,  3.6981e-02,
            9.6252e-02,  2.8861e-01],
          [ 2.6449e-01, -1.1997e+00, -2.9121e-01,  ...,  1.8929e-01,
            8.9705e-01,  5.2265e-02],
          [ 1.8653e-01, -4.1886e-01, -2.5386e-01,  ...,  5.6907e-01,
           -5.6461e-01, -2.9499e-01]],

         [[ 4.5761e-02, -1.1113e-01, -6.0327e-02,  ..., -1.7311e-02,
            8.8352e-02, -1.4918e-01],
          [ 3.5832e-01,  1.0048e-01, -3.5981e-01,  ...,  4.7004e-01,
           -1.0480e-01, -9.6169e-01],
          [-1.2025e+00, -4.9562e-01, -5.6530e-01,  ..., -7.7073e-02,
           -1.8603e-01,  4.5677e-02],
          ...,
          [-1.1527e-01, -1.2046e-02,  7.9755e-01,  ...,  2.0678e-01,
           -1.6562e-01, -9.4135e-02],
          [ 3.0203e-01, -5.3025e-02,  1.0025e-01,  ..., -1.3117e-01,
           -3.9940e-01,  2.0309e-01],
          [ 5.4948e-01, -3.1714e-03, -9.9666e-01,  ...,  3.6800e-01,
            2.6345e-01, -6.6638e-01]],

         ...,

         [[-2.2513e-02,  1.1954e-01, -1.7875e-02,  ..., -1.4198e-02,
            6.4433e-02, -5.2401e-02],
          [-6.7643e-03,  1.3038e-01, -3.1770e-02,  ..., -2.8075e-02,
           -7.0123e-02,  2.9359e-01],
          [ 7.8513e-01, -7.9053e-01, -1.5511e-01,  ..., -3.0193e-01,
           -5.3295e-02,  5.1889e-01],
          ...,
          [-1.9707e-01,  5.0177e-02, -1.1185e-01,  ..., -3.0111e-01,
            2.1017e-01, -2.7775e-01],
          [-3.1374e-01, -3.5912e-02, -2.5133e-01,  ..., -1.2073e-01,
            1.3938e-01, -1.4568e-01],
          [-1.2432e-01,  3.0442e-01,  1.0542e-01,  ...,  2.1967e-02,
            3.2316e-02,  1.2676e-01]],

         [[-1.7366e-01, -1.3407e-01, -6.7815e-02,  ..., -2.3521e-01,
           -1.8675e-02, -5.1927e-02],
          [ 4.8318e-01, -4.9988e-01,  7.3483e-01,  ...,  1.7037e-01,
            6.2192e-01,  2.3596e-01],
          [ 1.1730e-01,  3.0694e-02,  7.3273e-01,  ...,  5.0575e-01,
            3.1356e-02, -5.0081e-01],
          ...,
          [ 6.1899e-01, -9.2282e-01,  1.6701e-01,  ..., -2.4323e-02,
            1.7694e-01, -3.4102e-01],
          [ 8.4867e-01,  1.2311e-01,  3.3463e-01,  ...,  3.2204e-01,
            8.6678e-01,  5.9980e-01],
          [ 4.1040e-01,  1.7545e-01,  2.0518e-01,  ..., -9.3810e-01,
            4.8850e-01, -5.4087e-01]],

         [[ 1.1481e-01, -7.4767e-02, -2.5446e-02,  ..., -1.8679e-02,
           -9.1254e-02, -9.6947e-02],
          [ 5.5079e-01,  1.9193e-01,  1.8251e-04,  ..., -1.0992e-02,
           -2.6968e-01, -3.8421e-02],
          [-1.8607e-01, -8.5692e-02,  3.1742e-01,  ..., -3.9823e-01,
            4.3919e-01, -8.0165e-02],
          ...,
          [-1.6626e-01, -1.0646e+00, -1.0149e-02,  ..., -9.7871e-02,
            1.4443e-01, -1.5419e-01],
          [-4.4313e-01, -1.3310e-01,  4.2125e-01,  ...,  4.0301e-02,
           -1.7659e-01,  3.1838e-01],
          [ 8.1519e-01,  2.4844e-01,  1.2036e-01,  ..., -9.9506e-02,
           -2.9214e-01,  5.8580e-02]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[-8.8678e-01, -1.3593e-01,  3.3093e-01,  ..., -9.5576e-01,
            2.5192e-02, -2.9464e+00],
          [ 1.5584e+00, -5.3821e-01, -2.4421e+00,  ..., -1.7774e+00,
           -1.4069e+00,  7.2554e+00],
          [-6.1613e-02, -9.6410e-01, -3.3367e+00,  ..., -1.7228e+00,
           -7.6467e+00,  6.5063e+00],
          ...,
          [ 1.2191e+00, -3.6478e-01, -1.8077e+00,  ..., -1.4126e+00,
           -3.4429e+00,  1.1099e+01],
          [ 1.2810e+00, -4.1117e-01, -4.4152e+00,  ..., -1.0298e+00,
           -2.3506e+00,  1.1191e+01],
          [ 1.5495e+00, -1.9605e+00, -3.1807e+00,  ..., -9.8794e-01,
           -2.1888e+00,  9.4760e+00]],

         [[ 3.7499e-01, -6.6046e-02,  4.5773e-01,  ..., -1.2836e-01,
           -7.7381e-02, -2.2161e+00],
          [-1.9084e+00, -5.1770e-01,  3.3306e+00,  ..., -1.0169e-01,
           -2.0618e+00,  7.5854e+00],
          [-3.1865e+00, -5.3798e-01,  3.4467e+00,  ...,  8.8427e-02,
           -4.1777e+00,  7.7792e+00],
          ...,
          [-2.9382e+00, -8.8965e-01,  3.4723e+00,  ..., -1.4002e+00,
           -5.7932e-01,  6.9011e+00],
          [-3.7302e+00, -1.4835e+00,  7.7318e-01,  ..., -1.4177e+00,
           -1.5522e+00,  7.3279e+00],
          [-2.4526e+00, -1.8321e+00,  3.6389e+00,  ..., -4.4448e-01,
           -1.6136e+00,  6.6650e+00]],

         [[ 1.2211e-01, -6.5015e-01, -2.2831e-01,  ...,  1.4110e-01,
            2.7893e-01, -1.7424e-01],
          [ 1.7771e-01,  1.7629e+00,  6.3257e-01,  ..., -2.6582e-01,
            6.2577e-01,  5.0930e-02],
          [ 2.2530e-01,  3.0012e+00,  5.3516e-01,  ..., -3.2276e-01,
            5.9087e-01, -3.6453e-02],
          ...,
          [-6.4210e-01,  3.1597e+00,  2.3032e-01,  ...,  6.4203e-01,
            1.9326e-01,  5.4560e-01],
          [-4.8734e-01,  2.4240e+00,  1.1159e-01,  ...,  9.6528e-01,
            1.2245e+00, -1.7901e+00],
          [ 2.7319e-01,  2.8160e+00,  6.3444e-01,  ..., -5.1675e-01,
           -1.5301e-01, -8.1118e-01]],

         ...,

         [[-4.0181e-01,  1.2737e-02, -1.1140e-02,  ...,  1.2548e+00,
            4.3199e-02,  1.8033e+00],
          [-4.6139e-01, -1.3921e+00, -1.4511e+00,  ..., -2.5093e+00,
           -1.6920e+00, -2.7131e-01],
          [-1.3954e-01,  3.9872e-01, -5.5181e-01,  ..., -4.0252e+00,
           -1.2034e+00, -8.0604e-01],
          ...,
          [ 3.8913e-01, -9.2129e-01,  6.7512e-01,  ..., -3.2734e+00,
           -3.7855e-01, -1.2775e+00],
          [ 3.6478e-01,  1.1098e+00,  1.9589e+00,  ..., -1.2581e+00,
           -9.2984e-01, -1.5476e+00],
          [-2.0390e-01, -6.6112e-01, -9.6914e-01,  ..., -3.2531e+00,
           -3.5533e-01, -3.5020e-01]],

         [[-3.3790e-01, -1.2825e-01,  2.2242e-01,  ...,  2.6358e-01,
           -2.9314e-02,  3.1528e-02],
          [-6.0304e-01, -1.1295e+00,  1.4573e+00,  ...,  7.0224e-01,
           -8.5480e-01,  1.8017e-01],
          [ 9.3104e-01, -2.1456e+00,  3.8324e-01,  ...,  9.3967e-01,
           -8.2110e-01,  1.3123e-01],
          ...,
          [-7.5492e-01, -1.8400e-01,  3.3456e-01,  ...,  1.7404e+00,
            7.1590e-01,  1.3268e+00],
          [-1.5429e-01,  5.3506e-01,  2.4561e+00,  ...,  1.2834e+00,
            5.7729e-01,  1.3149e+00],
          [-7.7036e-01, -6.9287e-01,  1.1238e+00,  ...,  1.0106e+00,
           -5.3742e-01,  1.3852e+00]],

         [[ 3.4402e+00,  2.1226e+00, -2.1050e+00,  ..., -2.8555e+00,
           -3.9038e+00, -1.2060e+00],
          [-3.0643e+00, -1.6132e+00,  4.7811e+00,  ..., -2.6905e+00,
            9.4376e+00, -3.7636e+00],
          [-2.8029e+00,  9.2815e-01,  2.2908e+00,  ..., -3.5372e+00,
            9.2503e+00,  2.0644e+00],
          ...,
          [-4.9774e+00, -1.8169e+00,  4.4703e+00,  ..., -4.3005e+00,
            1.5492e+01,  3.7749e+00],
          [-2.4577e+00, -1.8796e+00,  6.0842e+00,  ..., -4.6722e+00,
            9.1210e+00,  1.8122e+00],
          [-4.4732e+00, -2.0733e+00,  7.2062e+00,  ..., -3.7151e+00,
            1.2814e+01, -2.2193e+00]]]], grad_fn=<PermuteBackward0>), tensor([[[[-0.0028, -0.0602,  0.0219,  ...,  0.0593,  0.0264,  0.0681],
          [ 0.6692, -0.1774,  0.2994,  ...,  0.1940, -0.3524, -0.1093],
          [ 0.0441, -0.6776, -0.4458,  ...,  0.2746,  0.9155, -0.5374],
          ...,
          [-0.3262, -0.0103, -0.0866,  ...,  0.0454, -0.1561,  0.2205],
          [-0.0552, -0.6212, -0.4492,  ..., -0.2533,  0.0952, -0.2438],
          [ 0.1740,  0.0146, -0.0917,  ...,  0.1930, -0.1700,  0.1307]],

         [[-0.0538, -0.0195, -0.1417,  ..., -0.0445,  0.0476, -0.0319],
          [ 0.3175, -0.1990, -0.2276,  ...,  0.1004, -0.0740, -0.1226],
          [ 0.3296, -0.6555, -0.2850,  ..., -0.8669,  0.2712,  0.0552],
          ...,
          [-0.0141,  0.1838,  0.2267,  ...,  0.0249, -0.0362,  0.3883],
          [-0.2939, -0.5590,  0.3243,  ..., -0.0678,  0.0157, -0.5514],
          [-0.0048, -0.0914, -0.2181,  ..., -0.2868,  0.0018, -0.0651]],

         [[ 0.0639,  0.0961,  0.0831,  ...,  0.0160, -0.0859, -0.0050],
          [-0.8685, -0.1267, -0.8107,  ...,  0.0526, -0.7176, -0.0689],
          [ 0.1621,  0.2253,  0.0752,  ...,  0.1041, -0.4005,  0.1818],
          ...,
          [-0.4981,  0.5339, -0.4980,  ..., -0.2581, -0.8093, -0.3876],
          [-0.6054,  1.6497,  1.0752,  ..., -1.0363,  0.7149, -0.6451],
          [-0.4706,  0.3250,  0.3061,  ...,  0.4489, -0.6589,  0.0312]],

         ...,

         [[-0.0115,  0.0657, -0.0777,  ...,  0.0440,  0.0456, -0.1384],
          [ 0.0136, -0.3035,  0.8164,  ..., -0.2084, -0.8236,  0.4428],
          [ 0.2521, -0.4054,  0.2197,  ...,  0.1480,  0.2216,  0.5164],
          ...,
          [-0.0778, -0.1247, -0.3227,  ...,  0.1474,  0.1483,  0.3701],
          [-0.3559, -0.8621, -0.0799,  ..., -0.9994,  0.4109,  0.2198],
          [-0.1967,  0.0573,  0.6049,  ...,  0.1913,  0.0767, -0.0245]],

         [[-0.1315, -0.0534,  0.0947,  ..., -0.0666,  0.0539, -0.0204],
          [ 0.0918, -0.3386, -0.7173,  ..., -0.2867, -0.0289, -0.1466],
          [ 0.2971,  0.6579, -0.9279,  ..., -0.0267, -1.3269,  0.6167],
          ...,
          [-0.1993,  0.8396,  0.5954,  ..., -0.2100,  0.3891,  0.5287],
          [ 1.5998,  0.6881, -0.2637,  ...,  1.1610,  0.1208, -0.6552],
          [ 0.5209, -0.3917,  0.1674,  ..., -0.2824,  0.0700, -0.3138]],

         [[-0.0193, -0.0120, -0.0240,  ..., -0.0300,  0.0080, -0.0136],
          [-0.0432, -0.3667, -0.3346,  ..., -0.1011,  0.0167,  0.1537],
          [-0.3303, -0.6508, -0.2167,  ..., -0.6360, -0.1999, -0.1340],
          ...,
          [-0.0058, -0.1530, -0.3235,  ..., -0.3699,  0.0510,  0.1209],
          [ 0.1009,  0.4467, -0.0791,  ..., -0.2715, -0.2259,  0.5418],
          [ 0.0141,  0.2831, -0.4868,  ..., -0.1903, -0.1869,  0.7274]]]],
       grad_fn=<PermuteBackward0>)), (tensor([[[[ 3.4512e-02, -2.8466e-01,  2.2210e-01,  ...,  1.6982e+00,
           -2.2029e-01, -8.0207e-02],
          [ 1.1887e+00,  1.2521e+00, -2.0973e-02,  ..., -2.7505e+00,
            1.1517e-01, -9.4738e-01],
          [-3.7887e-01,  2.3147e-01,  7.8851e-01,  ..., -3.8859e+00,
           -1.2610e+00, -1.5381e+00],
          ...,
          [-3.8848e-01, -7.8692e-01,  6.0321e-01,  ..., -1.5790e+00,
           -4.4260e-01, -1.7360e+00],
          [-1.3567e+00, -1.2212e-02,  4.0693e-01,  ..., -2.6267e+00,
            3.1883e-01, -1.1768e+00],
          [-1.3149e+00,  5.3910e-01,  8.4051e-01,  ..., -2.6472e+00,
           -8.0766e-02, -1.3063e+00]],

         [[ 1.5566e-01,  9.6884e-01, -1.4234e+00,  ..., -1.1945e-01,
            2.6095e-01,  9.2861e-01],
          [-1.1655e+00, -5.3317e+00,  7.2065e-01,  ..., -1.4863e+00,
           -2.2354e+00, -2.4988e+00],
          [ 1.2192e+00, -4.3649e+00,  9.3857e-01,  ...,  3.6005e-01,
           -1.0827e+00, -2.1299e+00],
          ...,
          [-6.5003e-01, -3.6931e+00,  4.9255e-01,  ..., -2.0790e+00,
           -3.1514e-01, -2.7136e+00],
          [ 5.9668e-01, -3.1527e+00,  7.6608e-01,  ..., -4.4680e-01,
           -1.1040e-01, -1.9393e+00],
          [ 2.0418e+00, -5.3709e+00,  5.4901e+00,  ..., -2.3439e-02,
            4.6572e-01, -3.8706e+00]],

         [[-6.7068e-01,  2.4994e-01, -5.6570e-02,  ...,  1.7880e-01,
            5.6148e-02, -2.9901e-01],
          [ 1.9676e+00,  2.9566e-02, -8.5660e-01,  ..., -1.8619e+00,
           -3.3802e-01,  1.6140e-01],
          [ 2.1615e+00, -7.5559e-01,  3.4024e-01,  ..., -1.4898e+00,
            4.2649e-01,  1.5977e+00],
          ...,
          [ 1.1094e+00, -8.7126e-01,  4.4787e-02,  ..., -4.0946e-01,
           -6.8646e-01, -5.1147e-01],
          [ 1.3666e+00, -6.3472e-01, -6.9747e-01,  ...,  6.0671e-01,
            2.1492e+00, -3.3250e-01],
          [ 2.1474e+00, -4.8501e-02, -8.7131e-01,  ..., -1.4417e+00,
            1.5616e+00,  1.8827e-01]],

         ...,

         [[-4.0635e-02,  1.1188e-01,  1.4037e-01,  ..., -9.7647e-02,
            1.2961e-02,  1.5000e-01],
          [ 7.8730e-01, -5.6138e-01, -1.2585e+00,  ...,  1.1703e+00,
           -1.7229e-01,  1.2928e+00],
          [-1.0394e-01,  1.4770e-01,  3.8454e-01,  ...,  6.5685e-01,
           -2.6355e-01,  1.3102e+00],
          ...,
          [ 1.1297e-01,  1.4229e+00,  2.8362e-01,  ...,  9.3448e-01,
            2.5909e-01,  2.9945e-01],
          [ 6.5882e-01,  7.3874e-01, -5.1318e-01,  ...,  9.5171e-01,
            1.6892e-01, -1.7952e-01],
          [ 4.4172e-01,  9.7651e-02, -1.4498e+00,  ...,  1.2877e+00,
            7.8737e-01,  5.7300e-02]],

         [[-3.0020e+00,  4.0418e-01, -2.7798e-02,  ..., -4.8566e-01,
           -3.4500e-01,  1.2311e+00],
          [ 4.8764e+00,  1.4500e+00, -1.1937e+00,  ..., -1.6858e+00,
            3.0943e-01, -9.1063e-01],
          [ 4.6146e+00,  9.6566e-01, -5.1178e-01,  ..., -2.1980e-01,
            1.1130e+00, -1.2746e+00],
          ...,
          [ 4.9677e+00,  2.5583e-02, -1.3527e+00,  ..., -1.8770e+00,
           -6.6969e-01, -4.0065e-01],
          [ 4.3137e+00,  1.0467e+00, -1.5161e+00,  ..., -2.2238e+00,
           -1.7302e-01, -1.6034e-01],
          [ 4.6436e+00,  9.9926e-01, -5.2100e-01,  ..., -1.5177e+00,
            1.9258e-01, -2.4487e-01]],

         [[-7.4442e-03, -2.5452e-01, -1.9922e-04,  ..., -1.8494e-01,
            3.4208e-01,  9.0523e-02],
          [ 8.8014e-01, -3.2005e+00, -2.3284e-01,  ..., -5.6783e-01,
            5.3092e-01,  4.5332e-02],
          [-3.2605e-01, -1.7599e+00, -5.3681e-01,  ..., -5.2140e-01,
            1.7060e+00, -8.0691e-01],
          ...,
          [-1.1833e+00, -8.9443e-01,  5.9676e-01,  ...,  3.0636e-01,
            5.0886e-01, -1.5048e+00],
          [-1.2903e+00, -9.5492e-01,  2.1957e-01,  ...,  2.2938e+00,
           -5.0270e-01, -7.8764e-02],
          [ 4.4758e-01, -1.5906e+00,  1.4957e-01,  ...,  2.3779e+00,
           -2.2358e-01,  4.7562e-01]]]], grad_fn=<PermuteBackward0>), tensor([[[[-0.0224, -0.0239,  0.0031,  ..., -0.0027, -0.0209,  0.3535],
          [ 1.1117, -0.0250,  1.0920,  ..., -0.1345, -0.2322, -0.7385],
          [ 1.2585, -0.5406, -0.8740,  ...,  0.6211,  0.4854, -0.4785],
          ...,
          [ 0.8751, -0.3160, -0.5735,  ..., -0.2102, -0.0831, -0.5934],
          [ 0.0085, -0.3084,  0.1655,  ...,  0.4398,  0.5114, -0.4383],
          [-0.0725, -0.3939,  0.5899,  ...,  0.7469, -0.3640,  0.0679]],

         [[ 0.0048, -0.0161,  0.0186,  ..., -0.0150,  0.0150,  0.0090],
          [ 0.4621, -0.6415, -0.2005,  ...,  0.2446,  1.2697, -0.7838],
          [ 0.6805, -1.2565,  0.0765,  ..., -0.0242,  1.4869,  0.1836],
          ...,
          [-0.2538, -0.0022,  0.1847,  ...,  0.4838,  1.5106,  0.7886],
          [ 1.2671, -0.9662, -0.3248,  ...,  0.5432, -0.0319, -0.1366],
          [-0.1197, -1.6058, -0.3833,  ...,  0.3964,  1.0133, -0.1477]],

         [[-0.0603,  0.0030, -0.0383,  ..., -0.0468,  0.0119, -0.0780],
          [ 0.5506, -0.3951,  0.6694,  ..., -0.6748,  0.3026,  0.0286],
          [ 0.4687,  0.1415,  0.0033,  ...,  0.4084,  0.2910,  0.4103],
          ...,
          [ 0.4985,  0.4334,  0.3964,  ..., -0.2184, -0.0373, -0.0717],
          [ 0.0850, -0.4120, -0.2606,  ..., -0.2593,  0.7614, -0.8139],
          [ 0.0813, -0.2308,  0.9975,  ..., -0.3412, -1.0508, -0.9304]],

         ...,

         [[-0.3063, -0.1904, -0.0540,  ..., -0.4848,  0.2131,  0.1049],
          [ 1.7674, -1.5249,  1.8613,  ...,  1.4412, -0.3079,  0.1500],
          [ 0.3092, -1.5787,  0.6095,  ...,  0.5455,  0.1634,  1.3060],
          ...,
          [-1.9410, -1.8215, -0.4399,  ...,  0.3221, -0.1979, -1.4136],
          [ 1.2677, -1.9424,  0.0700,  ..., -0.9788, -0.6381, -0.4399],
          [ 0.8285, -1.8581, -0.3010,  ..., -1.3209,  0.2318, -0.1750]],

         [[-0.0861, -0.1412, -0.0534,  ..., -0.1797, -0.1466,  0.1142],
          [-0.7113, -0.5252, -0.7349,  ..., -0.0491,  0.5213, -0.7352],
          [ 0.4967, -1.1247, -0.6529,  ..., -0.4258, -0.1081, -0.2017],
          ...,
          [-0.4174, -1.3939,  0.0162,  ..., -0.2306, -0.4274,  0.3158],
          [-1.1609, -0.1209, -0.1991,  ...,  1.2310,  0.5859,  0.6733],
          [-1.1036, -0.5834,  0.1167,  ...,  0.8276, -0.1767,  0.3441]],

         [[-0.0294, -0.0414,  0.1069,  ...,  0.0614, -0.0412,  0.0239],
          [-0.6127, -0.0583, -0.7644,  ..., -1.4024, -0.9271,  0.9733],
          [ 0.5288,  0.2919,  0.0434,  ..., -0.4878, -0.6339,  0.4392],
          ...,
          [ 0.0769, -0.0123, -1.2272,  ...,  0.3366, -0.2014,  0.2725],
          [ 0.0642,  1.9300, -0.3253,  ..., -1.0578,  0.4355, -1.4476],
          [-0.5956,  0.2606, -0.5507,  ..., -0.5284, -0.1602, -0.7526]]]],
       grad_fn=<PermuteBackward0>)), (tensor([[[[-3.3938e-01,  8.6866e-01, -1.6351e-01,  ...,  1.1267e+00,
           -1.6784e-01,  1.2959e-01],
          [-3.9849e-01, -4.8487e+00, -2.9290e-01,  ..., -4.2862e+00,
            6.9337e-01,  6.4498e-01],
          [ 7.8937e-01, -3.8062e+00, -7.5906e-01,  ..., -4.4411e+00,
            9.3744e-01,  2.4774e+00],
          ...,
          [-2.9636e-01, -5.8016e+00,  1.4007e+00,  ..., -3.1025e+00,
           -2.7375e-01,  1.2819e+00],
          [ 1.3025e+00, -3.8148e+00,  1.8926e+00,  ..., -3.3508e+00,
            6.2647e-01,  5.1378e-01],
          [-3.7934e-01, -4.5341e+00,  6.0715e-01,  ..., -4.3161e+00,
            6.4808e-01,  7.9708e-01]],

         [[ 5.8217e-02,  8.7121e-01, -6.2251e-01,  ..., -2.4310e-02,
            2.9330e-01,  1.3199e-02],
          [ 1.3357e+00, -1.1482e+00,  1.2032e-01,  ...,  1.5088e+00,
           -1.0720e+00, -1.1527e+00],
          [ 2.6168e+00,  9.3244e-03, -7.2926e-01,  ...,  9.4531e-01,
           -7.9178e-01, -1.6888e+00],
          ...,
          [ 8.0452e-01, -3.9176e-01, -3.0347e-01,  ...,  1.3463e+00,
           -3.1319e-01, -1.3556e-01],
          [-7.1086e-01,  9.6997e-02,  1.2591e+00,  ...,  2.0719e-01,
            4.2983e-01, -6.3391e-01],
          [ 1.1039e-01, -1.3052e+00,  1.1124e-01,  ...,  1.3074e+00,
            1.4712e+00, -2.7487e-01]],

         [[-3.1165e-01,  1.2165e-01, -9.8370e-01,  ..., -3.5095e-01,
           -6.3912e-02, -1.3616e-01],
          [ 5.0049e-01, -6.6728e-01,  2.9285e+00,  ..., -3.9263e-01,
            4.3198e-01, -2.3447e-01],
          [ 1.2306e-01, -2.9766e-01,  3.6896e+00,  ..., -1.0091e-01,
           -2.5103e-01, -2.0315e-01],
          ...,
          [-2.1391e-01, -2.1547e+00,  2.8612e+00,  ...,  5.8855e-01,
           -1.9214e-01,  1.8883e+00],
          [-2.1992e-01, -1.4360e+00,  3.3444e+00,  ...,  9.8178e-01,
           -1.9441e+00,  5.7364e-01],
          [ 6.3090e-02, -1.4908e+00,  2.0854e+00,  ...,  1.4157e-01,
           -1.3972e-01, -6.9580e-02]],

         ...,

         [[ 3.7597e-01,  8.1398e-02, -6.4505e-02,  ..., -4.8594e-02,
            2.2536e-01,  4.1931e-03],
          [-1.2319e+00,  8.1079e-01, -5.4320e-01,  ...,  1.2257e-01,
           -7.8676e-02, -2.6823e-01],
          [-1.9185e-02,  5.4915e-01,  9.4312e-01,  ..., -2.6608e+00,
            3.8096e-01, -1.3816e+00],
          ...,
          [-1.5423e+00, -2.7545e-01,  2.9765e+00,  ...,  5.4036e-01,
            1.6682e+00, -7.5562e-01],
          [-1.2052e+00, -2.4065e-01,  4.7900e-02,  ..., -1.5625e+00,
            2.8238e-01, -3.3910e-01],
          [-1.7759e+00,  3.9760e-01, -1.0807e+00,  ..., -1.9584e+00,
           -1.1637e+00,  1.5918e+00]],

         [[ 2.0009e-01,  5.4941e-02,  3.2748e-01,  ...,  4.1661e-01,
           -3.4165e-03,  2.3171e-01],
          [ 1.6163e+00,  1.2442e+00,  2.8373e-01,  ..., -3.9689e-01,
            7.1320e-03, -1.1601e-01],
          [ 1.3228e+00,  1.4674e-01,  6.3871e-01,  ..., -5.9913e-02,
            1.6461e-01,  3.3509e-01],
          ...,
          [ 7.7162e-01,  7.9756e-01,  8.2908e-01,  ..., -1.0911e+00,
            8.8888e-01, -1.1994e+00],
          [ 1.6909e+00,  8.3524e-01,  6.7132e-01,  ..., -1.1008e+00,
           -7.2901e-01,  6.1303e-01],
          [ 2.8334e+00,  4.6555e-01,  1.2473e+00,  ..., -7.3844e-01,
           -7.0963e-01,  1.0278e-01]],

         [[-3.0156e+00,  5.3756e-01,  5.6815e-01,  ..., -9.3899e-01,
            3.2683e-01,  1.8463e-01],
          [ 7.7879e+00, -9.7524e-01, -2.1850e+00,  ...,  2.2429e+00,
           -1.0887e+00,  5.6749e-01],
          [ 6.9868e+00, -3.9651e-01, -9.7286e-01,  ...,  1.0613e+00,
           -7.0396e-01,  1.3823e+00],
          ...,
          [ 9.5361e+00, -8.7937e-01, -2.5252e+00,  ...,  1.3820e+00,
           -2.2409e+00,  2.4565e-01],
          [ 8.8630e+00, -1.1387e+00, -1.7681e+00,  ...,  1.0129e+00,
            2.0493e-01, -2.1170e-01],
          [ 9.5194e+00, -2.0795e-01, -1.6476e+00,  ...,  2.4340e+00,
           -1.9197e+00, -2.9640e-01]]]], grad_fn=<PermuteBackward0>), tensor([[[[ 4.6474e-02, -5.0378e-02,  1.0945e-02,  ..., -6.9955e-02,
            2.9789e-03, -1.0073e-01],
          [ 5.1051e-01, -5.5772e-01, -3.8570e-01,  ..., -3.2328e-01,
            2.3945e-01, -2.9826e-01],
          [-2.8010e-01,  7.4962e-01, -5.4584e-01,  ..., -3.6442e-01,
            4.2576e-01, -1.4805e+00],
          ...,
          [-9.0783e-01, -4.8128e-01, -1.8888e-01,  ..., -2.2824e-01,
           -7.4845e-02, -1.0972e+00],
          [-5.0702e-01,  1.0603e-01, -1.0484e+00,  ...,  5.5779e-01,
           -4.9793e-01, -9.2837e-01],
          [ 3.8714e-02,  4.2493e-01, -4.1890e-01,  ...,  5.6050e-01,
           -2.7279e-01, -1.3355e+00]],

         [[ 6.7674e-02,  3.0544e-02, -2.3115e-02,  ..., -4.3823e-02,
            5.2575e-03, -1.6795e-03],
          [-5.3669e-01,  1.7762e+00, -5.2043e-01,  ...,  7.5157e-01,
           -6.1868e-01, -7.3336e-01],
          [-2.5054e-01,  4.9751e-03, -8.3214e-02,  ..., -7.4598e-01,
           -6.1617e-01,  3.3602e-01],
          ...,
          [-1.4067e-01,  5.9621e-01,  1.0898e+00,  ...,  9.4066e-01,
           -1.3745e+00,  1.1213e+00],
          [-1.0630e+00, -5.0378e-01,  6.9651e-01,  ..., -4.6445e-01,
           -6.6259e-01,  1.7251e-01],
          [-1.5972e+00,  3.2659e-01,  3.4644e-01,  ...,  2.8986e-01,
           -5.7299e-01, -2.2912e-01]],

         [[ 7.0952e-02,  8.2320e-03, -1.6572e-03,  ...,  2.1678e-02,
           -6.7437e-02, -5.0287e-02],
          [ 7.4200e-01, -3.2418e-01,  4.1442e-01,  ..., -1.4945e-02,
            2.5678e-01,  1.5392e-01],
          [ 2.9304e-01,  5.7399e-01, -2.7184e-01,  ..., -1.4044e-01,
            6.1588e-02, -1.5561e-01],
          ...,
          [ 7.1019e-01, -8.5043e-01, -3.1989e-01,  ...,  2.5753e-01,
            2.2188e-01,  7.3108e-01],
          [ 7.1561e-01, -8.6057e-01,  9.2320e-01,  ...,  3.9957e-01,
            2.4226e+00,  1.6563e+00],
          [-7.6132e-02,  2.4041e-01,  9.3365e-01,  ..., -2.2613e-01,
            3.9552e-01,  1.0165e-01]],

         ...,

         [[-8.4018e-04,  4.2945e-02,  2.0029e-02,  ..., -6.6209e-02,
           -1.8070e-02,  2.2869e-02],
          [-1.4168e+00,  2.7825e-01,  3.5415e-02,  ...,  2.2794e-01,
           -1.8244e-01,  2.6631e-01],
          [-1.5832e+00,  6.7589e-01, -1.3738e-01,  ...,  7.5377e-01,
           -8.9247e-01,  8.4118e-01],
          ...,
          [-1.0343e+00,  2.2096e-01,  1.8098e-01,  ...,  1.5064e+00,
           -9.4570e-01, -9.6457e-01],
          [-5.5192e-01,  6.5732e-01, -7.3323e-01,  ...,  8.2586e-01,
            1.0773e+00, -5.0690e-01],
          [-6.9760e-01, -2.0758e-01,  2.9526e-01,  ..., -1.6063e-02,
            1.6516e-02,  4.3263e-01]],

         [[ 5.6418e-02, -6.3642e-03,  2.3703e-02,  ...,  1.7139e-02,
           -1.5312e-02,  6.8112e-03],
          [ 1.8381e+00, -1.3941e+00, -1.0189e+00,  ..., -9.4177e-01,
            4.2883e-01,  8.2570e-01],
          [ 8.8893e-01, -1.6692e+00, -4.3398e-01,  ..., -1.2906e+00,
            1.0952e-01,  3.7169e-01],
          ...,
          [ 7.4024e-01, -1.4955e-01, -8.9148e-01,  ..., -1.0267e+00,
           -6.1569e-01,  5.8172e-01],
          [-7.3008e-01, -4.7314e-01,  3.7697e-01,  ...,  5.2418e-01,
           -1.6633e-01,  3.0198e-01],
          [ 6.6411e-02, -4.8074e-01, -4.0598e-01,  ...,  1.1196e-01,
            1.0054e+00, -4.4949e-01]],

         [[ 6.7269e-02, -2.0375e-01, -7.5082e-02,  ..., -4.0162e-02,
            1.9610e-01, -5.1942e-02],
          [ 3.7243e-01, -9.5645e-01, -3.3796e-01,  ..., -9.8523e-01,
           -4.3307e-01, -2.3109e-01],
          [-5.5909e-01, -9.8741e-01, -8.3997e-01,  ..., -4.0350e-02,
            2.2590e-03, -1.1709e+00],
          ...,
          [ 2.6116e-01, -1.7003e+00,  9.9667e-03,  ...,  2.5269e-01,
           -6.5086e-01, -5.0987e-01],
          [-2.2483e-01, -3.8567e-01, -1.6472e-01,  ..., -7.8707e-01,
            3.2198e-01, -4.2609e-01],
          [-1.5893e-01, -7.3543e-01, -4.9369e-01,  ..., -1.5504e+00,
           -3.8277e-01, -4.1377e-01]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[ 1.0315, -0.2591, -0.1501,  ...,  0.6246,  0.7232, -0.3059],
          [-3.2631, -1.9435, -1.2089,  ...,  0.2806, -5.5510,  0.3489],
          [-4.3399, -1.4013, -0.2186,  ...,  0.7109, -5.7860, -0.6970],
          ...,
          [-5.3760, -3.7959, -0.3718,  ...,  0.7804, -4.6301, -1.2402],
          [-6.1584, -2.8255,  0.0772,  ...,  0.3908, -4.4567, -0.1920],
          [-4.6862, -2.2484, -0.4802,  ...,  1.1911, -4.6985, -1.0555]],

         [[-0.1525, -0.0745,  0.1651,  ..., -0.0464, -0.8761, -0.1921],
          [-0.6493,  1.0436, -0.2845,  ..., -0.2628,  0.0537,  0.8063],
          [-1.7708,  1.3885, -0.9440,  ...,  0.3637,  0.7435,  1.4247],
          ...,
          [-2.0241,  0.3328, -0.2828,  ...,  0.8545,  0.5231,  2.4687],
          [-2.8308, -0.2631, -0.4617,  ..., -0.3337,  1.8320,  2.9475],
          [-2.0453,  1.1846, -2.5580,  ...,  0.5495,  1.1092,  1.8249]],

         [[ 0.1958,  0.3039,  1.1389,  ..., -0.4691,  0.4513, -0.4878],
          [-1.4815, -0.5524, -1.6846,  ..., -0.5676, -1.8434,  2.4752],
          [-3.5171, -1.7341, -1.0781,  ..., -0.0126, -0.8584,  2.8363],
          ...,
          [-1.2945, -1.0943, -0.7373,  ...,  0.2280, -2.9008,  2.5152],
          [-2.2796, -0.5816, -0.3174,  ...,  0.7422, -1.4116,  2.2355],
          [-0.7958,  0.1943, -2.7152,  ...,  1.7208, -1.5123,  0.9313]],

         ...,

         [[-0.1518,  0.0675, -0.2341,  ...,  0.0125,  0.1685,  0.0227],
          [-1.5855,  0.1968, -0.4700,  ...,  0.9270, -1.3281, -0.1941],
          [ 0.7663, -0.7921, -0.5326,  ...,  0.9606, -0.0650, -0.2843],
          ...,
          [ 0.1914, -0.1551, -0.9815,  ...,  1.8034,  0.1310,  0.7172],
          [-1.2788, -1.7422, -0.4975,  ...,  1.3406, -0.4531, -0.5256],
          [-1.8526,  0.1496, -0.0816,  ...,  0.8122, -1.0543,  0.1050]],

         [[-0.3515, -2.1836,  0.1103,  ..., -0.0873, -0.0481,  0.9174],
          [-0.3931,  1.7304, -1.0893,  ..., -1.0898, -1.7984,  1.0287],
          [ 0.3552,  3.3603, -1.5929,  ..., -0.7109, -1.5203,  0.7090],
          ...,
          [ 0.4526,  3.6483, -3.1344,  ...,  1.3756, -1.8511,  2.2068],
          [ 1.4022,  2.2589, -2.0330,  ...,  0.3515, -0.4796,  0.9019],
          [ 0.7568,  2.8114, -2.1562,  ...,  1.3476, -0.3658,  0.7552]],

         [[ 0.3682,  0.0657, -0.1320,  ...,  0.6454,  0.1343,  0.2644],
          [-0.8896,  0.3677,  0.1631,  ..., -0.3916, -0.4439, -0.9719],
          [ 0.4470,  0.5271, -0.4635,  ..., -0.6886, -1.2558,  0.0390],
          ...,
          [-1.7867, -2.2049,  2.1719,  ..., -0.8210, -0.2084,  1.6132],
          [-1.4884, -1.5097,  0.1562,  ...,  0.5166,  0.2819, -0.1415],
          [-2.1183,  1.1049,  1.0999,  ..., -0.3114,  0.2994,  0.8749]]]],
       grad_fn=<PermuteBackward0>), tensor([[[[-2.8336e-02,  4.5010e-02, -5.7978e-02,  ..., -1.9222e-02,
            1.2577e-02,  3.6269e-02],
          [-4.6357e-01,  2.2473e-01, -5.6911e-01,  ..., -1.1179e-01,
            5.4963e-01,  1.4621e-01],
          [ 8.8791e-01,  1.3920e-01, -1.1074e+00,  ..., -3.0970e-01,
            8.6369e-01,  2.8616e-01],
          ...,
          [ 9.8967e-01,  7.6810e-02,  3.6725e-01,  ...,  1.0289e-01,
           -7.9780e-01, -1.0472e-01],
          [ 4.1472e-01, -3.0706e-01, -1.0118e-01,  ..., -2.9164e-02,
            9.2894e-02,  2.6503e-01],
          [-4.1391e-01, -4.3953e-01,  9.5461e-02,  ..., -1.8622e-02,
            1.2946e-01, -4.0387e-01]],

         [[ 2.1536e-02, -2.8120e-02,  3.8532e-02,  ...,  2.1765e-02,
           -4.7212e-02,  5.3255e-03],
          [-1.6248e-01, -4.5659e-01, -4.4525e-02,  ...,  5.6903e-01,
           -3.0144e-01, -1.2120e+00],
          [-1.6019e-01, -3.1593e-01,  1.0682e+00,  ..., -1.1746e-01,
           -4.8418e-01,  4.2423e-01],
          ...,
          [-7.0670e-01,  1.4226e-01, -2.0767e-01,  ..., -5.3785e-01,
           -3.7916e-01,  2.9476e-01],
          [ 3.5204e-01,  1.6746e-01, -1.8197e+00,  ...,  1.8833e-01,
            2.5200e-01,  1.3326e+00],
          [ 1.0614e-01, -5.6477e-01, -1.3717e+00,  ...,  2.8329e-01,
           -2.3432e-01,  5.8129e-01]],

         [[ 3.9084e-02, -2.6990e-02,  5.6189e-02,  ...,  2.6549e-02,
           -7.1806e-03,  1.9065e-02],
          [ 8.1593e-01,  3.5473e-01, -1.9476e-01,  ...,  7.1779e-01,
            1.7158e-01,  1.7037e-01],
          [-3.0468e-01,  6.4740e-01, -1.1535e+00,  ...,  2.5107e+00,
           -1.3214e+00,  6.0931e-01],
          ...,
          [-3.8012e-01, -1.0693e+00, -4.3163e-01,  ..., -1.2006e-01,
           -4.7626e-01, -5.9241e-01],
          [-6.6220e-01,  1.0321e+00,  6.1114e-01,  ..., -1.0294e+00,
           -5.9746e-02, -1.4874e+00],
          [ 1.5239e+00,  1.7266e-01, -2.6497e-01,  ..., -6.9278e-01,
            2.7154e-01,  1.1508e-01]],

         ...,

         [[-1.8650e-01,  8.9365e-02,  5.7435e-02,  ...,  4.6573e-02,
            3.7369e-02, -1.2676e-01],
          [-6.2920e-01, -4.5253e-02,  1.5379e-01,  ..., -8.5838e-01,
            2.2210e-01, -4.9222e-01],
          [-2.1227e-01,  6.7216e-01,  5.8456e-01,  ..., -4.8421e-02,
           -4.2428e-01, -4.8305e-01],
          ...,
          [-9.4783e-01, -4.8206e-02, -1.2836e-01,  ...,  1.8181e-01,
           -4.6491e-01, -8.4671e-01],
          [-7.2088e-01,  4.8839e-01, -1.6034e+00,  ..., -3.5454e-01,
            8.5080e-02, -1.4271e+00],
          [-1.0528e+00,  8.3454e-01, -9.8252e-01,  ...,  1.1729e-01,
           -1.4640e-01, -1.9143e+00]],

         [[-5.8489e-01, -4.5877e-03,  4.4912e-02,  ..., -2.0796e-02,
            6.2989e-03, -6.4938e-03],
          [-1.6445e+00,  4.2511e-02, -3.1403e-01,  ..., -3.7935e-01,
            2.3561e-01,  5.9496e-02],
          [-2.5505e+00, -2.0482e-01, -3.6240e-01,  ..., -3.0201e-01,
           -4.2028e-01, -1.8376e-02],
          ...,
          [-1.6757e+00,  4.2658e-01, -9.1740e-01,  ...,  2.0202e-01,
            5.2352e-01,  3.1575e-01],
          [-1.7608e+00,  5.6837e-01,  3.5225e-01,  ...,  5.5874e-01,
           -6.9264e-01, -1.8256e-01],
          [-2.3731e+00, -2.8098e-01,  3.9676e-01,  ..., -2.5406e-01,
            4.8834e-01, -6.1031e-01]],

         [[ 1.5471e-03,  8.2456e-02, -4.7513e-02,  ...,  5.5853e-02,
            3.0368e-02, -4.6994e-02],
          [-5.5504e-01,  7.3400e-01, -2.0816e-01,  ..., -1.2824e-01,
            3.8586e-01,  8.0331e-01],
          [ 6.3713e-01,  1.6547e+00,  2.6059e-01,  ..., -1.1861e+00,
            6.3198e-01, -1.3541e-01],
          ...,
          [ 4.7463e-01,  1.1477e+00,  6.0258e-02,  ..., -4.6058e-01,
           -3.5489e-01,  7.9365e-02],
          [ 8.1016e-02, -1.3944e-01,  4.1258e-01,  ...,  1.1060e-01,
           -2.8541e+00,  4.1492e-01],
          [-1.2963e+00,  2.2384e-01, -2.4338e-01,  ...,  2.2294e-01,
            1.0918e-01,  2.1425e+00]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[-0.0268, -2.3398,  0.1634,  ..., -0.2365, -0.1944,  0.0645],
          [-0.2799,  4.0417, -0.1959,  ..., -0.1126, -0.3356, -0.5690],
          [-0.0094,  4.9294,  0.5373,  ..., -0.1404, -0.6815,  0.3025],
          ...,
          [ 0.0091,  3.8780,  0.7991,  ..., -0.5989,  0.7071,  0.5137],
          [-0.6218,  4.7662, -0.4088,  ..., -0.8925, -0.0737,  0.7395],
          [-0.8638,  5.1069, -0.1012,  ..., -0.0097,  0.0632, -0.7295]],

         [[-0.8140,  0.2218,  0.4656,  ..., -0.5189,  1.0732,  1.1234],
          [ 0.2364,  0.2685,  1.0541,  ...,  0.5500,  1.3914,  0.4962],
          [ 1.5259,  1.0305, -0.6830,  ..., -0.3595,  0.8213, -0.1596],
          ...,
          [-0.8893,  0.6401,  1.5340,  ..., -0.3154,  0.9969,  0.1131],
          [-1.4240, -0.5673, -0.9037,  ..., -0.0334,  2.1567, -0.3555],
          [-2.3116,  1.4069,  0.2116,  ...,  0.7944,  2.6708,  0.1778]],

         [[-0.8504,  0.4700,  0.0232,  ...,  0.4955, -0.2356,  1.1518],
          [ 0.6655, -0.1374,  1.1604,  ...,  0.2494,  1.0734, -0.9082],
          [ 2.0262,  0.3311,  0.5329,  ...,  0.2746,  0.6484, -1.2565],
          ...,
          [ 0.8666,  0.2080,  0.7423,  ..., -0.0590,  0.7947,  0.2077],
          [ 1.3274, -0.5878,  1.5562,  ...,  1.2727,  0.8958, -0.8393],
          [ 0.6793, -0.9115,  2.1432,  ...,  1.5571,  1.7428, -0.3943]],

         ...,

         [[-0.3102, -0.1292,  0.1523,  ...,  0.1793,  1.7438, -2.8696],
          [ 1.2379, -0.5238,  0.3674,  ..., -0.3042, -4.6049,  4.6856],
          [ 0.6856,  0.3973,  0.9211,  ..., -0.6994, -5.2863,  4.9465],
          ...,
          [-0.5532, -0.4212,  1.0728,  ...,  0.4562, -5.7176,  5.1979],
          [-0.9992, -1.4073, -0.8534,  ...,  0.8452, -5.9484,  4.2105],
          [ 0.1935, -1.2555,  1.2355,  ..., -0.0070, -6.0872,  5.8807]],

         [[ 0.1957,  0.3617,  0.2155,  ..., -0.2170,  0.0182, -0.1540],
          [-0.6359, -0.7831, -0.5938,  ...,  1.0413, -0.4280,  0.6407],
          [-0.6033, -1.0964, -0.2818,  ...,  0.2840, -0.2947,  0.6149],
          ...,
          [-0.2907,  0.0759,  0.5673,  ...,  1.1031, -0.7398,  0.1992],
          [-0.3487, -0.1916,  1.1144,  ...,  0.6085,  0.1949,  1.1279],
          [-1.1693, -0.8894,  0.6257,  ...,  1.4145, -1.2843,  0.4372]],

         [[ 0.3722,  0.0987,  0.6134,  ...,  0.5249,  0.5746, -0.3289],
          [ 0.7276, -0.7879, -1.5108,  ..., -1.7654, -3.2146,  0.1771],
          [ 0.6286, -1.0423, -1.3390,  ..., -2.0023, -2.7540, -0.0532],
          ...,
          [ 1.2008, -1.0047, -2.2047,  ..., -2.5210, -4.7543,  1.0585],
          [ 0.1571, -1.0960, -1.7899,  ..., -3.0896, -4.1969, -0.4143],
          [ 1.1982, -1.3326, -1.5329,  ..., -1.6822, -4.4774, -0.5948]]]],
       grad_fn=<PermuteBackward0>), tensor([[[[ 0.0512, -0.0071, -0.0159,  ...,  0.1308, -0.0604, -0.0383],
          [ 0.0743,  0.0098,  0.8985,  ...,  0.3322,  1.0163, -0.2279],
          [ 0.2299, -1.0595, -0.2036,  ...,  0.4071,  1.0309,  0.7073],
          ...,
          [ 0.9942,  0.0985, -0.3045,  ...,  0.3595,  1.2762,  0.1312],
          [-1.1851,  0.1872,  2.5162,  ..., -0.4091, -0.5504, -0.3313],
          [ 0.4573, -0.2495,  1.1492,  ...,  0.3916,  0.3092, -0.2549]],

         [[ 0.0178,  0.0383,  0.0396,  ...,  0.0060, -0.0180,  0.0108],
          [ 0.3077,  0.2800, -1.2484,  ...,  0.1144, -0.0260, -0.6417],
          [ 0.8365,  0.1942, -2.6429,  ...,  1.4839, -2.4390, -1.1518],
          ...,
          [-1.0152, -1.3838,  0.4507,  ...,  0.2284,  0.2643,  0.3901],
          [-1.8002, -1.5104, -0.6286,  ...,  1.0451,  0.2438, -0.3518],
          [-0.4032, -0.3529, -1.6265,  ...,  0.5828,  0.5720, -1.2572]],

         [[ 0.0495, -0.0389,  0.0613,  ...,  0.0561, -0.0711, -0.0673],
          [-0.6686,  1.1461, -0.4798,  ...,  0.1773,  0.4573,  0.4967],
          [-0.3811,  0.8968, -0.6061,  ...,  0.0926,  0.3056,  0.9180],
          ...,
          [-0.3757, -0.0510,  0.0062,  ...,  0.6064,  0.7972,  0.7227],
          [-0.2685, -0.7850,  0.7441,  ..., -0.8875, -0.0677,  1.0534],
          [-0.7876,  1.0096, -0.0108,  ..., -0.9138, -0.1195, -0.2942]],

         ...,

         [[-0.1028, -0.0452,  0.0346,  ..., -0.0871,  0.0427,  0.0092],
          [ 0.5169,  0.0966,  0.2483,  ..., -0.4591,  0.3724,  0.6674],
          [ 0.9085,  0.9305, -0.0286,  ..., -0.8769, -0.3911,  0.3594],
          ...,
          [-0.0673, -0.2202, -0.2051,  ...,  0.2041, -0.4487,  1.0220],
          [-0.2218, -0.4037,  1.4038,  ...,  1.5332, -1.2336,  0.4163],
          [ 0.8637, -1.0940,  0.2482,  ...,  0.3983, -1.4612,  0.6188]],

         [[ 0.1576, -0.0522,  0.1510,  ...,  0.0776,  0.0389, -0.1486],
          [-0.0612,  1.4222,  1.2901,  ...,  1.0537,  1.9877, -1.2965],
          [ 0.0701,  1.0599,  1.3164,  ...,  1.8434,  1.7597, -0.8641],
          ...,
          [-0.0791,  0.1802, -0.2036,  ...,  0.6063,  1.2652,  0.1763],
          [ 0.4001,  1.6460,  1.1749,  ..., -0.6267,  2.3732, -0.3538],
          [ 0.2739,  1.4950,  0.8300,  ...,  1.1957,  1.5808, -1.0777]],

         [[ 0.2067, -0.0439, -0.0680,  ...,  0.0390,  0.0473,  0.0275],
          [-0.6717,  0.2561,  0.7676,  ..., -0.2872, -0.5916, -0.1957],
          [-0.9239,  0.0464,  0.4365,  ...,  0.6006, -0.4989,  0.7633],
          ...,
          [ 0.5723,  0.0787,  0.7033,  ...,  0.3464, -0.7811,  1.3074],
          [-0.8109, -0.4612, -1.6027,  ..., -1.6367, -0.0065, -0.7756],
          [-1.3609,  0.5702,  0.7531,  ..., -0.1462,  0.1355,  0.2370]]]],
       grad_fn=<PermuteBackward0>)), (tensor([[[[ 0.0436, -0.2509, -0.4550,  ...,  0.3116,  0.3358,  0.3770],
          [ 0.3000, -0.0478, -1.2507,  ...,  0.3519,  0.7682,  0.7296],
          [ 0.5719, -1.2229, -1.5268,  ...,  0.2386,  0.1496,  0.8318],
          ...,
          [-0.7869, -1.6832, -0.6862,  ...,  0.5248, -1.1760, -0.6061],
          [ 1.1739,  0.7271, -1.4276,  ...,  1.1409, -1.3880, -0.6762],
          [ 1.7515, -0.1609, -0.0345,  ...,  0.9718, -0.5132,  1.4921]],

         [[-0.2801,  0.1559,  0.1167,  ...,  0.0214, -1.1384, -0.1501],
          [-0.7174, -0.1200, -0.7961,  ..., -0.4121,  0.7157,  0.5868],
          [ 1.5377,  0.1651, -0.9257,  ...,  0.3588,  1.3888,  0.1633],
          ...,
          [ 0.2032,  0.5659, -0.9297,  ..., -1.1580, -1.0870,  1.0748],
          [-0.0984,  1.5501, -1.2118,  ..., -1.0350,  0.6500,  0.8747],
          [-1.1498,  0.8479, -0.9318,  ..., -1.2515,  0.5937,  0.4393]],

         [[-1.2411, -0.0878,  0.5490,  ..., -0.6611,  0.4539, -0.2888],
          [ 0.6556,  1.0735, -0.5900,  ...,  0.0895, -0.3484, -0.2450],
          [ 0.3530,  0.0116,  0.0702,  ...,  0.7262, -1.4991, -0.5028],
          ...,
          [ 0.6693,  0.8831, -0.7045,  ...,  1.2413,  0.0528,  0.1498],
          [ 1.5144,  1.9988, -1.8167,  ...,  1.0272, -0.5508, -0.2781],
          [-0.2976,  1.1260, -1.6873,  ...,  1.3365, -0.2020, -0.3461]],

         ...,

         [[ 0.7973, -0.8987, -0.3939,  ..., -1.0369, -0.4123,  0.4803],
          [ 1.4616,  0.0408, -1.0295,  ...,  0.7219,  0.3444, -0.0145],
          [ 0.2550, -1.1764, -0.3335,  ...,  0.8036,  1.7228, -2.3128],
          ...,
          [ 0.6038, -0.3213, -0.9128,  ...,  1.7723,  0.7332, -1.3456],
          [ 1.5292,  0.8308, -1.5665,  ...,  1.7068,  0.6255, -1.4453],
          [ 2.1459, -0.1321, -0.5784,  ...,  1.8690,  1.6415,  0.8508]],

         [[-0.9151,  2.5785,  0.3082,  ...,  0.3579,  1.9421, -0.5408],
          [-0.0171, -2.5663,  0.7328,  ...,  0.3923, -4.1463,  1.9012],
          [ 1.0185, -2.5828, -1.5448,  ...,  1.0508, -4.9451,  1.7123],
          ...,
          [ 1.1092, -2.5339,  0.2730,  ..., -0.9127, -3.6883, -0.9762],
          [ 0.7417, -1.7092,  0.4430,  ...,  0.6517, -4.0859, -0.6250],
          [ 0.6957, -4.4839, -0.4944,  ...,  1.2733, -5.0460,  2.7409]],

         [[-2.0221, -0.3681, -1.1042,  ..., -0.3983,  0.0527,  0.2442],
          [ 1.4425,  0.4368,  0.8613,  ...,  1.2344, -0.1098,  0.1759],
          [ 2.6873, -0.5718,  0.7670,  ...,  1.4859, -0.9973,  1.5824],
          ...,
          [ 2.3481, -0.2267,  0.4736,  ...,  1.0791,  0.1695, -0.6822],
          [ 1.8121,  0.8181,  1.5002,  ...,  1.3897, -1.1112, -0.6512],
          [ 2.0455,  0.8276,  1.0394,  ...,  1.7555, -0.0730, -0.0210]]]],
       grad_fn=<PermuteBackward0>), tensor([[[[-4.8632e-02, -9.1396e-02,  3.1682e-02,  ...,  1.2261e-01,
           -3.6255e-02,  1.2526e-02],
          [-3.1696e-01,  2.4710e-01, -5.5300e-02,  ...,  2.4286e-02,
           -3.1162e-01,  2.2300e-01],
          [ 1.9266e-01,  5.7364e-01,  5.6620e-01,  ...,  6.3398e-01,
           -1.6994e-01, -3.2943e-01],
          ...,
          [ 8.0219e-01, -6.2467e-02,  7.5092e-01,  ..., -2.6152e-01,
            6.4908e-01,  9.1121e-01],
          [ 2.9819e-01, -1.1154e+00,  5.7111e-01,  ..., -1.1155e+00,
            5.0150e-01,  3.6634e-01],
          [ 7.2844e-01,  4.1041e-01,  6.7296e-01,  ...,  2.8859e-01,
           -9.5357e-01,  4.9752e-01]],

         [[ 1.1928e-02,  1.3112e-02, -2.6053e-02,  ...,  4.6390e-02,
            2.8720e-02,  5.6897e-02],
          [ 5.1804e-01, -8.6756e-03,  3.4240e-01,  ..., -9.3518e-01,
           -2.8230e-02, -1.6108e-01],
          [-6.5553e-01, -1.4296e-01,  6.3211e-01,  ..., -2.3726e+00,
           -1.0325e+00,  1.1180e+00],
          ...,
          [-2.7697e-01,  4.7694e-01,  9.3078e-01,  ..., -1.4985e-02,
           -9.5630e-01, -1.0057e+00],
          [-3.5304e-01,  7.6668e-01, -7.3687e-01,  ...,  8.2464e-01,
            6.1313e-01,  1.4616e-01],
          [ 6.2543e-02,  9.5850e-01,  9.9546e-02,  ..., -4.1675e-01,
           -3.1019e-01,  2.1785e-02]],

         [[ 3.9431e-02,  3.2304e-02, -6.9643e-02,  ...,  3.1842e-03,
            1.5391e-02,  8.6383e-03],
          [-4.5218e-02,  3.8015e-01, -7.4175e-03,  ..., -8.6065e-02,
            1.9510e-01,  2.4301e-02],
          [ 1.0227e+00,  7.7004e-02,  7.1903e-02,  ...,  1.1994e+00,
            1.6976e-01, -4.0066e-01],
          ...,
          [ 1.1771e+00,  2.4422e-01,  7.0662e-01,  ...,  1.1337e+00,
           -8.5384e-01, -9.9605e-01],
          [ 4.0196e-01,  3.7700e-01,  1.0244e+00,  ..., -2.4000e-01,
           -2.2166e-03, -8.7664e-01],
          [ 6.3016e-01,  1.0653e-01,  6.7085e-01,  ...,  1.8561e-01,
           -1.0484e+00, -2.8506e-01]],

         ...,

         [[-2.8633e-02,  2.3521e-02, -1.3071e-02,  ..., -3.3836e-02,
           -4.1805e-02,  1.6132e-02],
          [ 2.9617e-01,  2.5753e-01,  6.4459e-01,  ...,  6.7883e-01,
           -1.1170e-01,  3.4354e-01],
          [ 1.4840e-01, -2.1638e-01,  1.5988e-01,  ...,  3.0029e-01,
           -1.7462e+00,  2.2010e+00],
          ...,
          [-1.0735e-01, -2.7973e-01,  1.7696e-01,  ...,  1.2454e-01,
            1.6533e+00,  4.6311e-02],
          [-2.5303e-01, -5.3346e-01, -7.0970e-01,  ...,  3.3254e-01,
           -1.0337e-01, -1.5011e+00],
          [ 9.4744e-01,  4.1239e-01, -1.0214e-01,  ...,  1.0832e+00,
            1.1939e+00,  2.1364e-01]],

         [[-7.3958e-02, -4.4124e-02,  1.7760e-02,  ...,  3.1321e-03,
           -4.5881e-02, -1.0916e-01],
          [-4.6492e-01, -6.5992e-02, -4.8427e-02,  ...,  2.7765e-01,
            1.7094e-01, -2.1020e-01],
          [-9.3265e-01, -1.7024e+00,  1.1011e-01,  ..., -6.0777e-01,
            2.7326e-01, -1.2374e+00],
          ...,
          [-1.0394e-01,  3.4447e-02, -1.4004e+00,  ...,  1.9303e-01,
           -1.2038e+00,  5.6969e-01],
          [-9.6140e-01,  5.8390e-01, -5.3376e-01,  ...,  3.5307e-01,
            3.7874e-01, -4.8008e-02],
          [-6.0081e-02, -5.1836e-01, -9.0043e-02,  ...,  2.2977e-01,
           -1.1964e-01, -4.6107e-01]],

         [[-7.7909e-03,  4.0206e-02, -5.6468e-02,  ..., -3.0341e-02,
            2.4338e-02,  5.3261e-03],
          [-2.4073e-01, -1.4607e-01,  6.8568e-01,  ..., -9.4289e-01,
           -1.0285e+00, -7.2268e-01],
          [-8.9161e-01,  3.2033e-01,  2.2241e-01,  ...,  7.4783e-01,
           -1.8553e-01, -1.4143e+00],
          ...,
          [ 5.0678e-01, -8.7200e-01,  1.3745e+00,  ...,  4.7279e-01,
            1.9468e-01, -3.0692e-01],
          [ 9.3960e-02, -1.1271e+00, -3.2356e-01,  ...,  4.6166e-01,
            1.1812e+00, -7.4736e-02],
          [-1.3350e-01, -7.4492e-01,  7.1189e-01,  ..., -1.8032e-01,
           -1.5200e+00, -9.0480e-01]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[-0.5266,  0.4829, -0.8581,  ..., -1.0129, -1.3275,  0.2040],
          [ 0.6896,  0.9769,  0.7884,  ...,  2.3917, -0.4978, -1.2931],
          [ 0.4849,  0.2801,  2.0546,  ...,  2.4861,  0.4516, -2.1291],
          ...,
          [ 1.1704, -0.3403,  1.9815,  ...,  2.4192,  1.2849, -0.9075],
          [ 0.5567,  0.0731,  0.2333,  ...,  1.9754, -0.6718,  0.4945],
          [ 1.1157,  0.0910,  1.9513,  ...,  2.0806, -0.6777, -1.8277]],

         [[ 0.8645, -2.0846,  0.1532,  ...,  0.2459, -2.4906, -0.4514],
          [ 1.1494,  1.2483, -0.0495,  ...,  0.1813,  0.8199, -0.5313],
          [ 1.4952,  1.4661, -1.3266,  ...,  0.6351,  0.5419,  0.3732],
          ...,
          [ 1.6627,  3.2121, -1.1410,  ..., -0.1081,  2.2876, -1.0492],
          [ 0.8725,  3.7180, -0.8677,  ...,  0.5521,  0.0537, -2.0911],
          [ 1.8427,  2.8496, -0.9180,  ..., -0.0876,  2.9764, -1.0137]],

         [[ 1.0175,  0.3871, -0.1741,  ..., -0.8094, -1.4149, -0.3730],
          [-0.2747,  0.4294, -0.8148,  ...,  0.7997, -1.0098, -0.2083],
          [-0.1443,  0.1837, -0.6903,  ...,  2.3234, -0.5142, -1.1581],
          ...,
          [-1.7553, -0.7940, -1.4744,  ...,  1.9563, -0.3079,  0.2517],
          [-1.1555, -0.9816, -1.4792,  ...,  2.4893, -0.8572,  0.6439],
          [ 0.2061,  0.6956, -1.2343,  ...,  1.2946, -0.7649, -1.0596]],

         ...,

         [[ 0.2137, -0.5814,  0.4917,  ..., -0.6758,  1.0594,  0.2809],
          [-0.0787,  1.1178, -0.9665,  ..., -2.9838, -0.0755,  0.3358],
          [-0.3467,  0.6547, -1.9701,  ..., -2.6404, -1.7759,  0.1484],
          ...,
          [-0.6088,  0.2404, -1.0831,  ..., -2.5044, -0.5236,  0.2501],
          [-1.2627, -0.4007,  0.0159,  ..., -2.2715, -1.9617,  0.1351],
          [-0.4629,  0.4004, -1.0877,  ..., -3.2533, -0.1876, -0.2612]],

         [[ 0.2431,  0.5528,  0.5439,  ...,  0.7452,  0.0856,  0.8468],
          [ 0.3639,  2.4237,  0.9672,  ...,  0.7770, -0.7330,  0.4097],
          [-0.4982,  1.9386, -0.1103,  ...,  1.4543, -0.3265,  0.4745],
          ...,
          [ 0.2191,  1.5633, -0.4826,  ..., -0.9138, -0.7183,  0.2929],
          [-2.4011, -0.7274, -0.1691,  ...,  0.5614, -0.1154,  2.1418],
          [ 1.8710,  2.7152,  0.3026,  ...,  0.4339, -1.6067,  0.4278]],

         [[-0.7092,  0.3125, -1.6205,  ..., -0.4008,  0.2350, -1.3048],
          [ 0.0382,  0.8210, -1.6851,  ...,  1.5476,  1.1133, -1.3639],
          [-1.2253, -0.0602, -3.1185,  ..., -0.4857,  1.8382,  0.9552],
          ...,
          [-3.0265, -0.1628, -0.6678,  ...,  1.2046,  1.1136, -0.6637],
          [-2.2680, -0.1403, -1.6040,  ...,  0.0642,  0.6752, -0.0818],
          [-0.6567, -0.4737, -1.8665,  ...,  1.7928,  1.7230, -1.4443]]]],
       grad_fn=<PermuteBackward0>), tensor([[[[ 3.4694e-03,  5.2180e-02, -7.1138e-02,  ...,  5.9633e-02,
           -5.0955e-02, -7.4279e-02],
          [ 9.5441e-02, -8.2082e-04,  3.7786e-01,  ..., -7.9814e-01,
           -3.4941e-01,  5.3955e-01],
          [-1.7982e-01,  1.0793e+00,  8.4480e-01,  ..., -5.6335e-01,
            4.7423e-01, -1.3511e-01],
          ...,
          [-8.4265e-01, -2.8338e-02, -8.4992e-01,  ..., -8.6247e-01,
            6.9610e-01, -1.6560e-01],
          [-5.8879e-02,  7.0560e-01,  8.3837e-01,  ..., -4.8124e-01,
           -1.7102e+00,  4.3793e-01],
          [ 4.7365e-02,  2.9308e-01,  2.9819e-01,  ..., -3.3441e-01,
           -6.1577e-01,  5.0968e-01]],

         [[ 3.7719e-02,  1.2977e-04,  4.9038e-02,  ..., -3.9138e-02,
           -2.6473e-02, -1.4142e-02],
          [-1.9622e-01,  4.4304e-01,  9.3501e-02,  ...,  8.6879e-01,
            5.8439e-01,  6.5467e-01],
          [-8.1397e-01, -1.4557e+00, -3.4408e-01,  ...,  1.0143e+00,
            1.6014e-01, -7.6486e-01],
          ...,
          [ 4.8213e-01,  1.1956e+00, -6.7466e-01,  ...,  4.4558e-03,
           -6.0745e-01,  1.5004e-01],
          [-9.2434e-01, -9.9667e-02, -1.7371e-01,  ...,  3.3668e-01,
            3.7452e-01,  9.1399e-01],
          [-8.0525e-01,  2.7367e-01,  2.7182e-01,  ...,  1.5725e+00,
            1.8934e-01,  9.1494e-01]],

         [[-6.9176e-03,  1.8243e-02, -3.3975e-02,  ...,  8.5669e-03,
            2.7227e-02,  5.8461e-02],
          [-2.8638e-01,  4.4393e-02, -2.4720e-01,  ...,  5.8055e-01,
           -1.1038e+00, -3.1214e-01],
          [-4.1151e-02,  4.7980e-01, -8.1177e-01,  ...,  2.5263e+00,
           -6.2052e-01, -4.0801e-01],
          ...,
          [-6.4285e-01,  2.1790e-01,  7.1201e-01,  ...,  7.6857e-01,
            1.9746e-02, -1.2292e-02],
          [ 4.3683e-01, -2.0561e-01,  5.6170e-01,  ..., -1.3195e+00,
           -6.0955e-01,  8.5465e-01],
          [-5.0826e-02,  2.0641e-01,  2.1014e-01,  ..., -6.1202e-01,
           -3.7409e-01,  5.8607e-01]],

         ...,

         [[-7.8410e-02,  2.6667e-02,  1.1429e-02,  ..., -3.5996e-02,
           -7.8381e-03,  1.2273e-03],
          [-2.4955e-01,  3.0179e-01,  2.2439e-01,  ..., -7.0245e-01,
           -4.7259e-01, -1.2154e-01],
          [-1.1360e+00,  4.8186e-01,  6.9660e-01,  ...,  5.2388e-02,
            4.9656e-01, -7.1202e-01],
          ...,
          [-2.2132e-01,  2.5862e-01,  5.4504e-01,  ...,  6.4937e-01,
           -1.4201e-01, -6.9701e-02],
          [-9.8322e-01, -1.1579e-01,  1.4461e+00,  ...,  4.0303e-01,
           -8.9281e-01,  9.6826e-01],
          [-1.2535e+00,  7.9669e-01,  2.3864e+00,  ..., -1.1996e+00,
           -1.2942e-02,  1.5757e+00]],

         [[ 7.4373e-02, -1.0839e-03,  4.7472e-02,  ...,  2.5576e-02,
            5.5578e-02,  3.0725e-02],
          [ 1.6561e-01,  1.1326e+00,  1.1021e+00,  ..., -6.7084e-02,
            1.0625e+00, -7.9841e-01],
          [-1.1934e+00,  1.3455e+00,  7.5402e-01,  ...,  3.0290e+00,
            1.9807e+00, -1.6143e-01],
          ...,
          [ 8.7119e-01,  1.6007e+00,  9.8724e-01,  ...,  6.2297e-01,
            9.5836e-01, -6.7591e-02],
          [ 5.6550e-01,  7.5545e-01, -9.4622e-01,  ...,  3.9639e-01,
           -1.3479e-01,  1.4511e-01],
          [-2.5438e-01,  1.3767e+00,  1.5838e+00,  ...,  5.7618e-02,
            1.7279e+00, -7.0514e-01]],

         [[-1.1214e-01,  2.7461e-02, -6.8169e-02,  ..., -8.8035e-02,
            7.2290e-02, -2.1984e-02],
          [-1.5667e-01,  3.3572e-01, -2.9793e-01,  ..., -3.2849e-01,
           -6.0364e-02,  8.4579e-02],
          [-3.0011e-01,  3.6599e-01,  4.1995e-01,  ..., -5.6659e-01,
            1.6448e-01,  2.1300e-01],
          ...,
          [-5.3965e-01,  8.5568e-01, -1.0334e+00,  ...,  1.6571e+00,
            1.2634e+00,  1.2663e-02],
          [-1.1969e+00,  1.1998e-01,  7.4285e-01,  ...,  2.5529e+00,
            2.4390e+00, -3.4413e-01],
          [-7.4888e-01,  4.5366e-01, -1.2199e+00,  ..., -4.5325e-01,
            4.5486e-01,  9.2945e-01]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[-1.7115, -0.3095, -0.3052,  ...,  0.1569,  0.3295, -0.5102],
          [-0.2954, -0.5369, -1.0378,  ...,  1.1073, -0.8305,  0.1360],
          [ 0.3707, -0.1892, -0.0631,  ..., -0.2839, -0.1220,  0.7498],
          ...,
          [ 0.1536,  0.6239, -0.8671,  ...,  0.7387,  0.5640,  0.1241],
          [ 2.3072, -0.4214, -0.9284,  ...,  1.0017,  0.4802,  0.5843],
          [ 0.1397,  0.0057, -1.1350,  ...,  1.1576, -1.0530,  0.4489]],

         [[ 0.1108, -0.0705,  2.3025,  ...,  0.2415,  0.0896, -0.1951],
          [ 0.7120, -1.0674, -0.9315,  ...,  0.1600,  0.2404, -0.4726],
          [ 1.4893, -0.9425, -1.5101,  ..., -0.3591,  0.0335,  0.4421],
          ...,
          [ 0.1341, -0.0506, -1.3000,  ..., -0.1105, -0.2529,  0.7670],
          [-0.5547, -0.6913, -1.2921,  ...,  0.2898, -0.2538,  0.9526],
          [ 0.2336, -0.8297, -0.8416,  ..., -0.0980, -0.2919,  0.7454]],

         [[-0.2041,  1.0503,  0.4759,  ..., -0.5452,  0.3040, -0.1147],
          [-0.4822,  0.5302, -0.5850,  ...,  1.5575,  0.2531, -0.3998],
          [-0.5382,  0.5110,  0.0939,  ...,  0.6880, -0.2115, -0.7376],
          ...,
          [-0.3389,  0.6598, -0.4937,  ...,  0.4478, -0.8530, -0.5765],
          [-0.7257,  1.3182,  0.9792,  ...,  1.7382,  0.5813, -1.0117],
          [-0.9058,  1.0543, -0.6513,  ...,  2.0381, -0.3831,  0.2467]],

         ...,

         [[ 0.5577,  0.9790, -0.8716,  ..., -0.7184,  0.7245,  0.8611],
          [ 0.3537,  0.1875, -0.5812,  ..., -0.5651, -0.0493,  0.0897],
          [ 0.4359, -0.7858, -0.9179,  ..., -1.4713,  0.2309, -0.2758],
          ...,
          [ 0.9503, -0.6978,  0.7306,  ..., -0.7847, -0.9335, -0.8081],
          [ 0.1232,  1.2112,  1.0973,  ...,  0.2584,  1.1175, -0.0057],
          [-0.0062, -0.2483,  0.2463,  ..., -0.3165,  0.3718, -0.2848]],

         [[-0.4014,  0.3733,  0.3393,  ...,  0.7212,  0.0451, -0.0838],
          [-0.5197,  1.3345, -1.5982,  ...,  0.5380, -0.2475, -0.9776],
          [ 0.0646,  0.0452, -0.4746,  ...,  0.9874, -0.8139,  0.1726],
          ...,
          [-0.6432, -0.4941,  0.4357,  ...,  0.5838, -1.3339, -0.0826],
          [-0.9413, -1.2357, -0.4911,  ...,  1.3679, -1.0148, -1.4263],
          [-0.9545,  0.2418, -1.5970,  ...,  0.3238, -0.9107, -0.7229]],

         [[-0.7459, -0.0075,  0.4400,  ..., -0.1109,  0.0299, -0.0598],
          [-0.2137,  0.3865,  1.1712,  ...,  0.4425, -0.3584,  1.2832],
          [ 0.5957,  0.1015, -0.1897,  ...,  0.4039, -1.3808,  1.2112],
          ...,
          [ 0.7437, -1.3902,  0.2656,  ...,  0.9423, -1.2780,  1.6726],
          [-0.7614,  0.3624,  1.4484,  ...,  0.2220, -1.0658,  1.0444],
          [-0.4612,  0.8413,  1.7939,  ...,  0.1289, -0.8518,  1.1819]]]],
       grad_fn=<PermuteBackward0>), tensor([[[[ 7.7570e-02, -1.1777e-01, -1.6829e-01,  ..., -3.0139e-01,
            2.8640e-01, -1.7741e-01],
          [ 3.7056e-01,  6.3164e-01,  7.7662e-01,  ...,  2.7495e+00,
           -1.5492e+00,  1.1155e+00],
          [ 1.6399e+00,  1.3236e+00,  5.0145e-01,  ...,  2.7380e+00,
           -2.5362e+00,  2.0660e+00],
          ...,
          [-2.2210e-03,  1.2618e-01,  1.9018e-01,  ...,  2.5907e+00,
           -1.5682e+00,  7.7443e-01],
          [ 1.0661e+00,  1.8362e-01,  9.9011e-01,  ...,  1.7970e+00,
           -1.8210e-01, -7.9636e-01],
          [ 1.1176e+00,  9.5490e-01,  4.2716e-01,  ...,  2.4762e+00,
           -1.8121e+00,  1.6125e+00]],

         [[ 1.0853e-01, -1.0814e-02,  5.5897e-02,  ..., -9.3695e-03,
           -8.4395e-02,  1.6578e-01],
          [ 9.0288e-02,  4.3214e-01,  7.7907e-02,  ...,  3.6511e-01,
            4.1462e-01, -3.7498e-01],
          [ 4.8901e-02,  1.1972e+00, -1.0267e-01,  ..., -2.4577e-01,
            3.2252e-01,  9.5713e-02],
          ...,
          [ 1.4289e+00, -4.0081e-01,  8.8847e-01,  ..., -1.2688e-01,
           -2.1349e-01, -1.5179e+00],
          [-1.8024e-01, -5.9997e-01,  1.6811e+00,  ...,  8.8114e-01,
           -1.2796e+00,  8.0612e-01],
          [ 3.5363e-01,  1.5338e-01,  1.0489e-01,  ...,  7.1419e-01,
           -2.5939e-01,  1.1640e-01]],

         [[-1.3536e-02,  2.5633e-02, -3.8610e-02,  ...,  4.7447e-02,
            4.5465e-04,  7.3786e-02],
          [ 3.7973e-01, -2.6919e-01, -4.5875e-01,  ..., -1.4160e-01,
            3.0695e-01, -4.8341e-01],
          [ 1.1969e+00,  1.2378e+00, -6.2153e-01,  ..., -9.3299e-01,
            5.5717e-02, -2.5939e-02],
          ...,
          [ 1.0509e+00, -6.8117e-01, -5.0678e-01,  ..., -5.8349e-01,
            1.6390e-01, -4.4167e-01],
          [-5.3312e-01,  6.3160e-01,  2.2554e-01,  ..., -1.1507e+00,
            6.4968e-01,  3.7368e-01],
          [ 2.3626e-01, -1.7837e-01,  2.7653e-01,  ..., -8.8951e-02,
           -3.4488e-02, -6.5983e-01]],

         ...,

         [[-3.0262e-02, -1.2759e-02,  8.2024e-02,  ...,  4.1477e-02,
           -3.4039e-02,  1.6534e-02],
          [ 7.0146e-02,  3.9249e-01,  3.6694e-02,  ...,  1.1981e-01,
           -3.4416e-01, -1.2740e-01],
          [-1.4357e+00, -8.1313e-01,  3.6240e-01,  ...,  6.4624e-01,
           -7.6324e-01,  1.4873e+00],
          ...,
          [ 1.0556e-01, -3.8366e-01,  1.2748e+00,  ..., -3.6558e-01,
            4.0858e-01,  2.4199e-01],
          [-2.5444e-01,  1.1958e+00, -1.7147e-01,  ...,  6.1984e-01,
           -2.2845e-01, -1.8110e+00],
          [ 3.2427e-01,  8.9915e-01,  1.1141e+00,  ...,  6.8071e-01,
           -3.4533e-01, -1.7910e-01]],

         [[-1.8907e-01, -6.5480e-02,  7.6243e-02,  ..., -5.9887e-02,
            5.6530e-02, -7.3080e-02],
          [-8.2506e-01, -3.6656e-02,  4.9222e-01,  ...,  2.5220e-01,
            3.1897e-01,  1.9113e-01],
          [-4.6517e-01, -2.1911e-01, -6.4030e-01,  ...,  7.2280e-01,
            7.5668e-01,  5.6131e-01],
          ...,
          [-1.0660e+00, -4.2479e-01, -5.0573e-01,  ..., -5.8658e-02,
           -6.6094e-02, -4.4752e-01],
          [-8.6907e-02,  1.2486e-04, -5.2314e-01,  ...,  1.1544e-01,
            4.3831e-01, -1.0179e-02],
          [-1.0669e+00, -7.1475e-01,  8.0158e-01,  ..., -1.1919e-01,
           -2.0185e-01,  3.2946e-01]],

         [[ 1.2763e-01, -1.2701e-01,  1.6529e-01,  ..., -1.4527e-01,
           -8.5370e-03, -1.7278e-01],
          [-6.5069e-02,  3.5000e-01,  5.6586e-01,  ..., -3.5917e-01,
           -4.1324e-01,  2.9987e-01],
          [ 2.5123e-01,  5.5106e-01,  4.2795e-01,  ..., -1.0718e+00,
           -6.8236e-01, -4.2256e-01],
          ...,
          [-6.0648e-01, -5.4619e-01,  1.4942e-02,  ..., -7.6836e-01,
           -5.9767e-01, -1.3891e-02],
          [-3.4398e-01, -8.0992e-01,  7.4776e-01,  ..., -1.8947e+00,
           -2.7473e-01,  4.0089e-01],
          [ 8.6354e-02, -1.2515e-02, -2.7977e-01,  ..., -4.1148e-01,
           -5.5178e-01,  7.0079e-02]]]], grad_fn=<PermuteBackward0>))), hidden_states=(tensor([[[ 0.0231, -0.2904,  0.1120,  ...,  0.2610,  0.0677,  0.0696],
         [ 0.0355, -0.0567, -0.0626,  ...,  0.0619, -0.0195, -0.0601],
         [-0.0454, -0.0329,  0.0899,  ..., -0.1099, -0.1276,  0.0272],
         ...,
         [-0.0366, -0.0155,  0.1617,  ..., -0.0452,  0.0707, -0.0497],
         [-0.0264, -0.0053,  0.2777,  ...,  0.1060,  0.1664,  0.0183],
         [ 0.0103, -0.0047,  0.1434,  ...,  0.0254, -0.0255, -0.0704]]],
       grad_fn=<AddBackward0>), tensor([[[ 1.5737e+00, -4.1554e-01,  4.5012e-01,  ...,  4.3850e-02,
           7.4813e-01, -8.7114e-01],
         [ 5.1361e-01, -6.6155e-01,  1.0332e-01,  ...,  4.2718e-01,
           1.7186e-01,  3.6244e-01],
         [ 1.2385e+00,  5.1269e-04, -1.1555e-01,  ...,  3.3694e-01,
          -2.2656e-01,  7.6178e-02],
         ...,
         [-1.5542e+00,  5.6012e-01,  3.0304e-01,  ...,  2.0757e-01,
           3.6331e-01, -5.2796e-01],
         [-8.0574e-01,  5.1341e-01, -1.3832e+00,  ...,  8.7573e-01,
          -3.1620e-01, -2.6355e+00],
         [ 8.7906e-01, -4.0571e-01,  6.8713e-01,  ...,  1.3655e+00,
          -1.1660e-01,  2.1324e-01]]], grad_fn=<AddBackward0>), tensor([[[ 1.7136, -0.5216,  1.2041,  ..., -0.4961,  0.3665, -0.9365],
         [ 0.4630, -1.2140,  0.2936,  ...,  0.0555, -0.1479,  0.3223],
         [ 1.2810, -0.0626, -0.0681,  ...,  0.6627, -0.5515,  0.0529],
         ...,
         [-1.3538,  0.9463,  0.3435,  ..., -0.0469,  0.4996, -0.5079],
         [-1.2018,  1.1568, -1.9729,  ...,  0.3070, -0.0780, -2.1962],
         [ 0.4538, -0.4325,  0.9298,  ...,  1.6704,  0.1176,  0.5136]]],
       grad_fn=<AddBackward0>), tensor([[[ 1.8131, -0.8204,  1.0690,  ..., -0.6062,  0.4388, -0.8892],
         [ 0.3553, -1.4214,  0.3465,  ..., -0.1229,  0.1026,  0.6289],
         [ 1.6588, -0.5855, -0.1310,  ...,  1.0190, -0.4376, -0.4088],
         ...,
         [-1.1001,  1.4018, -0.0845,  ..., -0.4871,  0.3749, -1.0466],
         [-1.2557,  1.2836, -2.5036,  ..., -0.1603,  0.0254, -2.3484],
         [ 0.4166, -0.5125,  0.6953,  ...,  1.8050,  0.6178,  0.6728]]],
       grad_fn=<AddBackward0>), tensor([[[ 1.8357, -0.8387,  1.1291,  ..., -0.5870,  0.4266, -1.0183],
         [-0.3023, -1.8606,  1.0695,  ...,  0.3596, -0.5872,  0.5146],
         [ 1.5486, -1.3812, -0.1454,  ...,  1.4216, -0.7276, -0.3115],
         ...,
         [-0.8990,  1.3792, -0.6556,  ..., -0.6427, -0.1838, -1.0314],
         [-0.6506,  1.4321, -3.7864,  ...,  0.2906, -0.3390, -2.7433],
         [ 0.5480, -0.9662,  0.9323,  ...,  2.0826, -0.5486,  1.2011]]],
       grad_fn=<AddBackward0>), tensor([[[ 1.6162, -0.8975,  1.0517,  ..., -0.7185,  0.2539, -1.0555],
         [-0.0308, -2.1858,  1.7953,  ...,  0.5839, -1.0037,  0.0798],
         [ 1.9824, -0.7727, -0.1712,  ...,  1.7961, -1.0021, -0.3786],
         ...,
         [-1.1462,  1.0538, -1.0321,  ..., -0.0505, -0.3385, -1.3392],
         [-0.6031,  1.9507, -4.7104,  ..., -0.0331, -1.0798, -2.4425],
         [ 0.5712, -0.7698,  0.1273,  ...,  2.8240, -0.8675,  2.1530]]],
       grad_fn=<AddBackward0>), tensor([[[ 1.5710, -0.9778,  1.0983,  ..., -0.8036,  0.1757, -1.0363],
         [-0.5121, -2.1376,  1.7901,  ..., -0.0355, -0.4783,  0.1833],
         [ 2.8356, -1.5824, -0.2001,  ...,  1.8292, -0.4691, -0.2781],
         ...,
         [-1.6092,  0.1276, -1.6480,  ...,  0.7556, -2.2751, -1.2271],
         [-0.3862,  2.8926, -5.3254,  ...,  0.5635, -1.5554, -2.6868],
         [ 0.6955, -0.6462, -0.3514,  ...,  3.4493, -1.9874,  1.3638]]],
       grad_fn=<AddBackward0>), tensor([[[ 1.4392, -0.9376,  1.1554,  ..., -0.8639,  0.1171, -1.0310],
         [-0.2120, -2.0884,  2.2357,  ..., -0.8004, -0.2832, -0.2491],
         [ 2.7662, -1.6102, -0.1855,  ...,  2.3809,  0.2519, -0.4420],
         ...,
         [-1.4429, -0.1494, -0.8831,  ...,  1.2360, -1.6377, -0.8880],
         [-0.9246,  2.8136, -5.2786,  ...,  0.1955, -1.6184, -2.6251],
         [ 0.9074, -0.3075,  0.1530,  ...,  3.1575, -1.6791,  2.0776]]],
       grad_fn=<AddBackward0>), tensor([[[ 1.4151, -0.7950,  1.0212,  ..., -0.8095,  0.0292, -1.1826],
         [-0.0706, -2.0130,  1.8284,  ..., -1.0185, -0.5239, -0.3039],
         [ 2.8450, -2.3009, -0.5953,  ...,  2.0502,  1.1716, -0.2201],
         ...,
         [-1.0831, -0.3495, -0.4953,  ...,  0.7348, -1.0733, -0.3256],
         [-0.6313,  2.8501, -5.5530,  ..., -0.0141, -2.2424, -3.8297],
         [ 2.0435, -0.2091,  0.7285,  ...,  2.5350, -2.2868,  1.3605]]],
       grad_fn=<AddBackward0>), tensor([[[ 1.4463, -0.7566,  0.9623,  ..., -0.7003,  0.0289, -1.2995],
         [ 0.4330, -1.4939,  2.7411,  ..., -0.2542,  0.3714, -1.6697],
         [ 2.4653, -2.0962, -0.6611,  ...,  2.4599,  1.8867, -0.6674],
         ...,
         [ 1.2202, -2.0474,  1.7625,  ..., -0.5113,  0.7804,  1.4529],
         [ 1.0899,  0.4627, -6.6348,  ..., -2.2547, -2.7966, -4.2566],
         [ 2.2639, -0.6145,  0.7215,  ...,  1.7289, -0.9348, -0.0800]]],
       grad_fn=<AddBackward0>), tensor([[[ 1.3416, -0.6064,  0.6988,  ..., -0.6046,  0.0922, -1.5941],
         [-0.3715, -1.3355,  2.9444,  ..., -0.1253,  1.5043, -2.8058],
         [ 0.9600, -2.2277, -0.0108,  ...,  2.9812,  3.4562, -1.3117],
         ...,
         [ 2.7550, -2.8540,  3.9844,  ..., -0.4379,  2.8047,  0.9528],
         [ 1.7625, -2.2070, -7.9801,  ..., -2.1712, -3.5339, -4.6076],
         [ 2.3666, -1.7680,  0.7266,  ...,  4.0575, -0.2326, -2.1535]]],
       grad_fn=<AddBackward0>), tensor([[[ 1.2195, -0.3806,  0.3530,  ..., -0.5992,  0.3146, -1.7930],
         [-0.1447,  0.0618,  2.7296,  ...,  0.8753,  1.8019, -4.6930],
         [-0.9555, -3.1084,  1.1448,  ...,  3.5270,  4.3085,  1.1351],
         ...,
         [ 1.1198, -5.1489,  5.3349,  ...,  1.5175,  3.6925,  1.5494],
         [ 2.8521, -1.7178, -7.8211,  ..., -2.2027, -6.7088, -5.0671],
         [ 2.9345, -1.3891,  0.9643,  ...,  3.5691, -0.1766, -3.9141]]],
       grad_fn=<AddBackward0>), tensor([[[ 0.0502,  0.0018, -0.1750,  ..., -0.1020, -0.0257, -0.1292],
         [ 0.1300,  0.1757,  0.2934,  ...,  0.0794,  0.1164, -0.3280],
         [ 0.0021, -0.2481,  0.2638,  ...,  0.1507,  0.4056,  0.2376],
         ...,
         [ 0.1611, -0.4680,  0.7029,  ...,  0.1209,  0.3803,  0.2864],
         [ 0.1791, -0.3507, -1.2709,  ..., -0.1535, -0.7109, -0.2459],
         [ 0.2872, -0.0504,  0.0839,  ...,  0.3417, -0.0518, -0.3151]]],
       grad_fn=<ViewBackward0>)), attentions=None, cross_attentions=None)
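The `past_key_values` field printed above holds one (key, value) pair per layer. In the Hugging Face GPT-2 implementation each of these tensors has shape `(batch, n_head, seq_len, head_dim)`. For this 9-token input that should be `(1, 12, 9, 64)`, assuming a GPT-2-small-sized model with 12 heads of size 64 (768 / 12). Caching keys and values means that in the next generation step the model only has to process the newly appended token. The NumPy sketch below checks this property. It is a simplification with random weights, a single attention layer, and no biases or output projection, and all names in it are hypothetical. It verifies that attending from the new token to the cached keys and values gives the same result as recomputing attention over the whole sequence.

```python
import numpy as np

def split_heads(x, n_head):
    # (batch, seq, d_model) -> (batch, n_head, seq, head_dim), the layout used in past_key_values
    b, t, d = x.shape
    return x.reshape(b, t, n_head, d // n_head).transpose(0, 2, 1, 3)

def causal_attention(q, k, v):
    # q: (b, h, tq, hd) are the queries for the LAST tq positions; k, v: (b, h, tk, hd)
    tq, tk = q.shape[2], k.shape[2]
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1])
    # query i sits at absolute position tk - tq + i and must not see later keys
    mask = np.triu(np.ones((tq, tk), dtype=bool), k=tk - tq + 1)
    scores = np.where(mask, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
batch, seq_len, d_model, n_head = 1, 9, 768, 12  # assumed GPT-2-small sizes
x = rng.standard_normal((batch, seq_len + 1, d_model))  # 9 "old" tokens + 1 new token
w_q, w_k, w_v = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(3))

# 1) Full recomputation over all 10 positions; keep only the new token's output
q_all, k_all, v_all = (split_heads(x @ w, n_head) for w in (w_q, w_k, w_v))
full = causal_attention(q_all, k_all, v_all)[:, :, -1:]

# 2) Cached path: keys/values of the first 9 tokens were stored in the previous step
past_key = split_heads(x[:, :seq_len] @ w_k, n_head)
past_value = split_heads(x[:, :seq_len] @ w_v, n_head)
new = x[:, seq_len:]
q_new, k_new, v_new = (split_heads(new @ w, n_head) for w in (w_q, w_k, w_v))
key = np.concatenate([past_key, k_new], axis=2)      # append along the sequence axis
value = np.concatenate([past_value, v_new], axis=2)
cached = causal_attention(q_new, key, value)

print(past_key.shape)              # (1, 12, 9, 64)
print(np.allclose(full, cached))   # True
```

Because the new token's query is attended against keys and values that never change for earlier positions, storing them once per layer avoids recomputing the whole prefix at every step.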
output.hidden_states[0].shape
torch.Size([1, 9, 768])
output.hidden_states[1].shape
torch.Size([1, 9, 768])
output.hidden_states[2].shape
torch.Size([1, 9, 768])
len(output.hidden_states)
13
output.last_hidden_state.shape
torch.Size([1, 9, 768])
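`hidden_states` has 13 entries for this 12-layer model. As far as we can tell from the Hugging Face GPT-2 code, entry 0 is the embedding output (token plus position embeddings). Entries 1 to 11 are the outputs of blocks 1 to 11. Entry 12 is the output of the last block after the final layer norm (`ln_f`), which is why it matches `last_hidden_state`. This is consistent with the dump above, where the last tensor has much smaller values than the ones before it. The toy NumPy sketch below reproduces only this bookkeeping. Its "blocks" are hypothetical random residual updates rather than real transformer layers, and its layer norm has no learned scale or shift.

```python
import numpy as np

rng = np.random.default_rng(0)
n_layer, seq_len, d_model = 12, 9, 768  # assumed sizes matching the shapes above

def layer_norm(h, eps=1e-5):
    # Final layer norm without learned parameters (a simplification)
    mu = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    return (h - mu) / np.sqrt(var + eps)

# Hypothetical stand-ins for transformer blocks: small random residual updates
weights = [rng.standard_normal((d_model, d_model)) * 0.01 for _ in range(n_layer)]
blocks = [lambda h, w=w: h + np.tanh(h @ w) for w in weights]

def forward(embeddings, blocks):
    all_hidden_states = ()
    h = embeddings
    for block in blocks:
        all_hidden_states += (h,)   # input to each block; index 0 is the embedding output
        h = block(h)
    h = layer_norm(h)               # plays the role of ln_f in GPT-2
    all_hidden_states += (h,)       # last entry: normalized output of the final block
    return h, all_hidden_states

embeddings = rng.standard_normal((1, seq_len, d_model))
last_hidden_state, hidden_states = forward(embeddings, blocks)
print(len(hidden_states))                                     # 13
print(hidden_states[0].shape)                                 # (1, 9, 768)
print(np.array_equal(hidden_states[-1], last_hidden_state))   # True
```

If you need the un-normalized output of the last block, it is not in this tuple; every other entry is the raw residual stream before the corresponding block.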
pt_model = AutoModelForCausalLM.from_pretrained(model_name)
output = pt_model(**encoding)
output
CausalLMOutputWithCrossAttentions(loss=None, logits=tensor([[[ -36.3292,  -36.3402,  -40.4228,  ...,  -46.0234,  -44.5284,
           -37.1276],
         [-114.9346, -116.5035, -117.9236,  ..., -117.8857, -119.3379,
          -112.9298],
         [-123.5036, -123.0548, -127.3876,  ..., -130.5238, -130.5279,
          -123.2711],
         ...,
         [-101.3852, -101.2506, -103.6583,  ..., -103.3747, -107.7192,
           -99.4521],
         [ -83.0701,  -84.3884,  -91.9513,  ...,  -91.7482,  -93.3971,
           -85.1204],
         [ -91.2749,  -93.1332,  -93.6408,  ...,  -94.3482,  -93.4517,
           -90.1472]]], grad_fn=<UnsafeViewBackward0>), past_key_values=((tensor([[[[-7.0634e-01,  1.9011e+00,  7.7253e-01,  ..., -1.3028e+00,
           -5.0432e-01,  1.6823e+00],
          [-1.6482e+00,  3.0222e+00,  1.2789e+00,  ..., -9.0779e-01,
           -1.7395e+00,  2.4237e+00],
          [-2.3128e+00,  2.8957e+00,  1.8368e+00,  ..., -7.0370e-01,
           -1.6305e+00,  2.4407e+00],
          ...,
          [-2.4337e+00,  2.5271e+00,  2.1513e+00,  ..., -5.8053e-01,
           -1.6483e+00,  2.0594e+00],
          [-3.8223e+00,  2.1391e+00,  1.7587e+00,  ..., -1.0668e+00,
           -1.6278e+00,  1.1729e+00],
          [-1.9238e+00,  2.7944e+00,  1.6292e+00,  ..., -8.9733e-01,
           -2.2193e+00,  2.6272e+00]],

         [[-9.6153e-02,  8.9928e-01, -1.4324e+00,  ..., -3.8667e-03,
            1.7698e+00,  6.0074e-01],
          [ 2.7222e-01, -1.2016e+00, -1.9081e+00,  ..., -1.3531e+00,
            1.2823e+00, -4.3198e-01],
          [-1.1722e+00, -3.6670e-01, -1.6921e+00,  ..., -1.2359e+00,
            2.5243e+00,  1.0228e+00],
          ...,
          [-1.6694e-01, -1.0159e+00, -2.5232e+00,  ..., -9.7920e-01,
            4.8265e+00, -1.7799e+00],
          [-1.1981e-01, -2.6784e+00, -2.9551e+00,  ..., -1.9840e-01,
            3.3916e+00, -1.9762e-02],
          [ 3.2722e-01, -1.2197e+00, -2.1079e+00,  ..., -1.6297e+00,
            9.2404e-01, -7.6080e-01]],

         [[-1.4670e-01,  2.1407e-01,  1.1498e+00,  ..., -1.3128e+00,
           -2.1007e+00,  5.6910e-01],
          [ 5.5608e-01, -4.6297e-01,  7.4483e-01,  ..., -1.8272e+00,
            5.4572e-01,  1.0119e+00],
          [ 9.2851e-01,  4.6049e-03,  4.1324e-01,  ..., -2.4987e+00,
            5.2423e-01,  1.5260e+00],
          ...,
          [ 3.2328e-01,  3.5316e-01,  3.2756e-02,  ..., -3.2780e+00,
            8.1692e-01,  1.4566e+00],
          [-2.1528e-01, -2.2490e-01, -1.4536e+00,  ..., -3.7075e+00,
            1.6835e+00,  1.6085e+00],
          [ 7.6672e-01, -5.3757e-01,  4.2462e-01,  ..., -2.2908e+00,
            1.7213e+00,  1.0240e+00]],

         ...,

         [[ 5.4733e-01,  4.7672e-01, -2.2749e-01,  ...,  2.9014e-01,
            7.7821e-01,  7.8295e-01],
          [ 1.6820e-01, -9.1829e-02, -5.0034e-02,  ...,  7.3646e-01,
            6.1343e-01,  5.4442e-01],
          [ 2.9530e-02, -5.3167e-02, -6.1709e-02,  ...,  1.0934e+00,
            3.7083e-01,  3.8425e-01],
          ...,
          [-1.3203e-02, -2.6465e-01,  4.4834e-02,  ...,  1.2205e+00,
            5.4265e-01,  3.7732e-01],
          [ 8.5854e-02, -2.3791e-01, -1.1271e-01,  ...,  1.8211e+00,
           -5.7249e-01, -7.4493e-01],
          [-3.6544e-02, -1.4250e-01,  6.6582e-02,  ...,  1.0489e+00,
            4.8485e-01,  4.6476e-01]],

         [[ 1.4700e+00,  1.3564e+00, -4.9892e-01,  ..., -6.4925e-02,
            1.4507e+00, -1.2267e+00],
          [ 1.0113e+00,  7.0108e-01, -5.7364e-01,  ..., -7.1721e-01,
            1.0731e+00, -1.0718e+00],
          [ 1.1010e+00,  4.8299e-01, -9.3231e-01,  ..., -1.5044e+00,
            1.2941e+00, -3.3869e-01],
          ...,
          [ 1.1745e+00,  6.3323e-01, -6.1605e-01,  ..., -8.1925e-01,
            5.2691e-01, -7.5443e-01],
          [ 1.7895e+00,  5.7095e-01, -3.5775e-01,  ..., -1.3193e+00,
            5.5676e-01, -1.6293e-01],
          [ 9.6151e-01,  2.9245e-02, -5.3493e-01,  ..., -7.8683e-01,
            3.7355e-01, -2.4032e-01]],

         [[ 7.1643e-01, -3.1278e-01,  1.4058e-01,  ..., -2.0734e-01,
            2.5946e-01,  1.7684e+00],
          [-5.6619e-01,  7.8687e-01,  2.5152e-02,  ...,  6.2100e-01,
            4.7592e-01,  5.4321e-01],
          [-6.2611e-01,  3.3320e-01,  1.1092e-01,  ...,  6.4703e-01,
            6.4159e-01,  7.2777e-01],
          ...,
          [-1.7180e-01,  1.1778e+00, -2.3931e-01,  ..., -6.3932e-01,
            1.1654e+00,  4.0462e-01],
          [-4.8319e-01,  2.8237e-01, -4.4490e-01,  ..., -1.2013e-01,
            4.8413e-01, -4.5133e-01],
          [-1.1252e+00,  7.6533e-01, -6.0320e-02,  ...,  1.8912e-01,
            7.8018e-01, -5.4733e-01]]]], grad_fn=<PermuteBackward0>), tensor([[[[ 0.1900,  0.0015, -0.0517,  ...,  0.0536,  0.0312, -0.0694],
          [-0.0800,  0.0181, -0.0534,  ..., -0.0419, -0.0365,  0.0151],
          [ 0.0448,  0.1912, -0.1849,  ..., -0.0062, -0.1420,  0.1609],
          ...,
          [-0.1635,  0.0196,  0.1185,  ...,  0.0794,  0.0980, -0.1084],
          [-0.2303,  0.1991, -0.1576,  ...,  0.2774, -0.1813, -0.2463],
          [-0.1009,  0.0410, -0.0970,  ..., -0.0684, -0.0763,  0.0260]],

         [[ 0.4406,  0.1176, -0.2136,  ..., -0.6839, -0.2371,  0.2999],
          [ 0.5926,  0.0197,  0.1107,  ...,  0.1253,  0.5675, -0.2665],
          [ 0.6762,  0.0459, -0.3685,  ...,  0.0744,  0.5420, -0.1240],
          ...,
          [ 0.8509, -0.0962,  0.0762,  ..., -0.1705,  0.1339,  0.1068],
          [ 0.2928, -0.2582,  0.1735,  ...,  0.0800,  0.2879, -0.0139],
          [ 0.5969,  0.0592,  0.0263,  ..., -0.0100,  0.5129, -0.1905]],

         [[ 0.0810, -0.1910,  0.1092,  ..., -0.0283,  0.0408,  0.0961],
          [-0.3257,  0.0398, -0.1531,  ...,  0.0411, -0.0413,  0.0745],
          [ 0.5201,  0.0126,  0.3504,  ...,  0.1020,  0.0543, -0.2188],
          ...,
          [-0.5288, -0.0025, -0.5926,  ..., -0.1874, -0.0674,  0.3113],
          [ 0.1521,  0.0271, -0.2514,  ..., -0.0465, -0.0565, -0.3401],
          [-0.2885,  0.0590, -0.1736,  ...,  0.0685, -0.1112,  0.0604]],

         ...,

         [[ 0.0111, -0.0168,  0.0263,  ..., -0.2135,  0.2054,  0.0729],
          [-0.3022, -0.0878,  0.1001,  ...,  0.0262, -0.1647,  0.1682],
          [-0.1587, -0.0666,  0.0826,  ..., -0.0416,  0.0812,  0.2067],
          ...,
          [-0.0925, -0.4836,  0.0332,  ...,  0.0641, -0.1597,  0.2375],
          [-0.0742,  0.8589,  0.0336,  ..., -0.3268, -0.2455,  0.3080],
          [-0.0869, -0.4287,  0.1231,  ..., -0.0474, -0.1705,  0.0347]],

         [[ 0.2081, -0.2399, -0.1318,  ...,  0.1471,  0.1123, -0.0316],
          [-0.2119,  0.0589,  0.0997,  ...,  0.0038,  0.1331,  0.0930],
          [-0.1213,  0.1404,  0.1775,  ...,  0.1688, -0.0020,  0.0829],
          ...,
          [-0.2325,  0.1252, -0.0345,  ...,  0.2837,  0.0686, -0.0089],
          [ 0.1896,  0.0282, -0.0740,  ...,  0.1655, -0.3020,  0.2837],
          [ 0.0298,  0.0086, -0.1626,  ...,  0.1976,  0.0970, -0.0014]],

         [[-0.0689, -0.3955,  0.2328,  ...,  0.1539, -0.1823, -0.0845],
          [ 0.0538, -0.2648, -0.0146,  ...,  0.2331,  0.0516,  0.0924],
          [-0.0647,  0.0062,  0.1329,  ...,  0.1026,  0.1185,  0.0463],
          ...,
          [ 0.0186,  0.1904, -0.0966,  ...,  0.0714, -0.0321, -0.0059],
          [ 0.0219,  0.4180, -0.1580,  ..., -0.0072, -0.2708,  0.1529],
          [ 0.1236, -0.3671, -0.0392,  ...,  0.1061, -0.0278, -0.0074]]]],
       grad_fn=<PermuteBackward0>)), (tensor([[[[-3.5429e-01,  2.2092e+00, -1.5580e+00,  ...,  1.4397e+00,
           -1.1504e+00,  1.4646e+00],
          [ 7.3885e-01,  1.8177e+00, -1.4766e+00,  ..., -4.6761e-01,
           -1.6869e+00,  5.0785e-01],
          [ 1.6962e+00,  1.1427e+00, -1.1112e+00,  ...,  1.2764e-01,
           -2.5909e+00,  7.2933e-01],
          ...,
          [-1.9130e-03,  1.6441e+00, -3.0120e-01,  ...,  3.8508e-01,
           -1.0645e+00, -4.5135e-01],
          [-3.9438e-01,  1.6005e+00,  9.6257e-01,  ...,  5.8858e-01,
           -1.8425e+00, -9.6318e-01],
          [-4.9488e-01,  1.1094e+00,  5.2522e-02,  ...,  5.6471e-01,
           -1.3969e+00, -3.0882e-01]],

         [[-1.0087e+00, -4.5958e-01, -7.4797e-01,  ..., -3.7310e-01,
            7.9809e-01, -2.3881e-01],
          [-6.6438e-02,  4.8658e-01, -8.2457e-01,  ..., -9.4308e-01,
            1.8907e-01, -1.5256e-02],
          [-1.7392e-01,  1.1992e+00, -1.5513e+00,  ..., -3.2774e-01,
            7.3627e-01, -3.6968e-01],
          ...,
          [-1.1986e-01,  6.0111e-01, -1.4226e+00,  ..., -6.1346e-01,
            1.3460e-01, -6.1240e-01],
          [ 1.8174e-01,  3.1973e-01, -2.2986e+00,  ..., -4.1319e-01,
           -1.0757e+00, -4.7605e-01],
          [-2.4593e-01,  1.1035e+00, -1.4215e+00,  ..., -6.2691e-01,
           -1.1097e+00, -6.3956e-01]],

         [[ 3.2591e-01, -1.6143e-02, -2.0098e-01,  ..., -1.3362e+00,
            3.3876e-01, -1.6542e-01],
          [-1.0002e-02,  3.9666e-01, -9.3499e-02,  ..., -1.0921e+00,
            5.6914e-02,  4.1318e-01],
          [-1.1656e-02,  2.1262e-01, -2.3546e-01,  ..., -9.7254e-01,
            1.4688e-01,  2.7869e-01],
          ...,
          [-8.3349e-02,  3.9433e-02, -9.7432e-03,  ..., -7.0562e-01,
            4.2687e-01,  2.3274e-01],
          [ 1.0450e-01, -2.0783e-01, -2.8860e-01,  ..., -1.0073e+00,
           -1.2179e-01,  3.5471e-01],
          [-1.4484e-01, -5.0447e-02, -3.9541e-03,  ..., -1.0255e+00,
            1.9039e-01,  3.3890e-01]],

         ...,

         [[ 2.1528e-01, -4.6627e-01, -5.9642e-01,  ..., -4.2178e-01,
            4.3739e-01, -8.5899e-01],
          [-5.0305e-02,  1.2479e+00,  1.8768e+00,  ...,  6.8503e-01,
           -7.3186e-01, -3.4076e-01],
          [-4.0512e-01,  1.6082e+00,  1.8570e+00,  ...,  1.2636e+00,
           -1.1781e+00, -8.1034e-01],
          ...,
          [-5.1299e-01,  2.6865e-01,  7.6903e-01,  ..., -1.3940e+00,
            8.1194e-01, -1.8763e-01],
          [ 2.3526e-01, -5.7615e-01,  1.3541e+00,  ...,  1.4708e+00,
           -2.9934e-01, -3.9407e-01],
          [ 5.0755e-02,  7.0489e-01,  1.9166e+00,  ...,  6.6883e-01,
           -9.1450e-01, -2.5584e-01]],

         [[-1.1473e+00, -2.7966e+00,  1.4438e-01,  ...,  1.7208e+00,
            1.5965e+00, -1.4860e+00],
          [ 3.5231e-01,  7.5960e-01, -4.7429e-01,  ..., -8.1442e-01,
            4.5442e-01, -2.9752e-01],
          [ 2.1113e-01,  7.5264e-01, -4.5093e-01,  ..., -9.6233e-01,
            5.8766e-01,  9.0545e-02],
          ...,
          [ 1.6897e-01,  2.5023e-01, -7.4581e-01,  ..., -1.2799e-01,
            7.1349e-01, -8.5998e-02],
          [-2.3828e-01,  5.9684e-01, -7.5936e-01,  ..., -6.6564e-01,
            7.3313e-01,  1.8287e-01],
          [-1.6440e-01,  2.5931e-01, -8.1777e-01,  ..., -3.5322e-01,
            8.3564e-01, -5.9446e-02]],

         [[ 1.3976e+00,  1.6241e+00,  5.4245e-01,  ..., -7.8420e-01,
            1.1678e-01,  3.7706e-01],
          [ 8.8908e-01,  2.1345e+00,  1.0939e+00,  ...,  1.1961e-01,
           -7.5297e-01, -1.4081e-01],
          [ 6.7893e-01,  1.8408e+00,  1.5060e+00,  ...,  5.9498e-01,
           -2.2553e+00, -1.8270e+00],
          ...,
          [-5.1015e-02,  2.4946e+00, -1.6883e-01,  ...,  5.4761e-01,
           -2.8891e-01, -6.7954e-01],
          [-1.6942e-01,  4.9026e-01,  1.1144e+00,  ...,  9.3912e-03,
           -8.0171e-01, -1.4243e-01],
          [ 8.4424e-01,  1.7401e+00,  9.2639e-01,  ..., -1.4967e-01,
           -3.8360e-01, -1.5520e-01]]]], grad_fn=<PermuteBackward0>), tensor([[[[ 3.3872e-01,  1.3968e-01, -1.7938e-01,  ...,  1.5467e-01,
           -1.2589e-01,  7.0887e-02],
          [ 3.7346e-01,  2.8615e-01,  7.3073e-02,  ..., -1.7334e-01,
           -1.7929e-01,  8.0809e-02],
          [ 1.3121e-01,  1.3779e-01,  9.8802e-02,  ...,  1.7611e-01,
           -6.5489e-01, -3.7171e-01],
          ...,
          [ 4.5774e-01,  6.2110e-02,  4.7204e-02,  ...,  2.1876e-01,
           -1.9506e-01,  1.5526e-01],
          [-1.6503e-01,  7.2050e-02, -4.4076e-01,  ...,  9.3966e-02,
           -8.1660e-02, -2.9702e-01],
          [ 3.7986e-01,  3.8336e-01,  1.0341e-01,  ..., -1.9899e-01,
           -2.3373e-01, -1.3201e-01]],

         [[-7.9321e-02, -6.6966e-02, -2.2227e-01,  ..., -1.4152e-02,
           -4.5964e-01,  2.7340e-01],
          [-2.0632e-01, -2.7675e-01,  9.3918e-02,  ..., -9.7495e-02,
            2.0266e-01,  3.4913e-02],
          [-3.6562e-01, -2.8439e-01,  2.9782e-01,  ..., -1.0605e+00,
            2.7564e-01,  3.3809e-01],
          ...,
          [ 5.1779e-01,  2.3170e-01, -3.0248e-01,  ...,  4.6880e-01,
            4.3330e-01, -6.2105e-01],
          [-1.9805e-01,  6.8445e-02, -5.7586e-02,  ...,  1.3844e-01,
           -6.2666e-02,  1.8667e-01],
          [ 6.9782e-02, -1.5278e-01,  6.9243e-02,  ..., -1.0944e-01,
            1.1224e-01,  1.1524e-01]],

         [[ 7.9376e-02, -1.4863e-02, -4.4028e-02,  ..., -6.2825e-01,
            6.7840e-02,  1.0440e-02],
          [ 4.2720e-01,  2.4379e-01,  2.3040e-01,  ..., -5.0812e-01,
            3.7279e-02, -1.3192e-01],
          [ 6.2018e-01,  1.7793e-01,  2.9474e-01,  ..., -7.6162e-01,
           -2.8552e-01, -1.4080e-01],
          ...,
          [ 5.8184e-01,  5.9326e-02,  2.5048e-03,  ..., -6.1473e-01,
           -3.0034e-02,  4.4224e-02],
          [ 6.7462e-01,  1.3863e-01, -5.1645e-02,  ..., -5.6261e-01,
           -2.2474e-01, -1.2376e-01],
          [ 6.0415e-01,  9.6460e-02,  1.1331e-01,  ..., -2.8026e-01,
            2.4650e-02, -2.4321e-01]],

         ...,

         [[ 1.0567e-01,  6.7946e-01, -1.7619e-01,  ...,  1.2480e-02,
           -9.7338e-01, -2.5708e-01],
          [-5.0101e-04, -7.4670e-01,  1.4215e-01,  ...,  2.6520e-02,
           -9.1824e-01, -4.4347e-01],
          [ 5.7162e-02, -6.6084e-01, -1.7225e-01,  ..., -6.7773e-02,
           -6.9370e-01,  2.2682e-01],
          ...,
          [ 3.6897e-01,  4.0040e-01,  1.3203e-01,  ...,  5.9832e-02,
           -4.3946e-01,  3.3851e-02],
          [-1.9931e-01,  4.7522e-01,  6.5326e-01,  ...,  8.5060e-01,
           -1.5948e-01,  2.6952e-01],
          [ 4.5483e-02, -7.9412e-01,  2.0943e-01,  ...,  6.4299e-02,
           -6.5777e-01, -2.0458e-01]],

         [[ 4.7333e-02, -1.1130e-02, -1.4608e-01,  ...,  3.8364e-01,
           -3.4244e+00,  6.6758e-02],
          [ 5.0051e-01,  8.4673e-03,  1.9747e-01,  ...,  2.1474e-01,
           -7.4449e-03, -2.8373e-01],
          [-2.0428e-01,  2.4512e-01, -2.7017e-01,  ...,  4.5577e-02,
            2.1612e-02, -1.3106e-01],
          ...,
          [ 7.3244e-02, -1.5794e-01,  1.7578e-01,  ..., -2.2690e-01,
           -6.3669e-02, -1.8729e-02],
          [ 1.3369e-01,  4.0795e-01, -6.9403e-02,  ..., -2.8477e-02,
            8.1580e-02, -3.7645e-01],
          [ 3.2948e-01,  2.4525e-01,  3.1002e-02,  ...,  1.4547e-03,
           -2.0459e-01, -1.3566e-02]],

         [[ 2.4439e-02, -2.3092e-01,  1.1163e-02,  ..., -3.4285e-01,
            2.7007e-01, -3.4211e-02],
          [ 2.0095e-01, -4.9356e-01,  5.3058e-01,  ..., -2.7157e-01,
            4.2807e-01,  3.2917e-01],
          [-1.0993e-01, -4.1360e-01,  1.9816e-02,  ..., -1.7917e-01,
            3.6033e-01,  2.2954e-01],
          ...,
          [ 4.2263e-02,  1.5875e-02, -3.0871e-01,  ..., -3.1441e-01,
            2.9030e-01,  2.2213e-01],
          [-4.9536e-02,  8.3578e-02,  7.2786e-02,  ..., -2.5493e-01,
            4.7891e-02,  3.4251e-01],
          [ 5.0301e-02, -1.8544e-01,  5.7551e-01,  ..., -3.4349e-01,
            1.5927e-01,  4.2942e-01]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[-1.5217e-01, -1.1477e+00,  2.3295e-01,  ..., -6.4279e-01,
           -1.1349e-01,  4.0799e-02],
          [ 4.5919e-01, -2.0374e+00, -7.9378e-01,  ...,  4.4668e-02,
           -8.8579e-01, -9.0097e-01],
          [ 3.8866e-01, -1.6082e+00, -3.9608e-01,  ...,  3.1908e-01,
           -4.2160e-01, -1.1912e-01],
          ...,
          [-8.1627e-02, -1.0257e+00, -6.6449e-01,  ...,  6.6261e-01,
           -1.8242e-01, -5.9660e-02],
          [ 9.9366e-01, -2.8990e+00, -4.2770e-01,  ...,  1.5473e+00,
           -2.7730e-01,  1.0212e+00],
          [ 3.7402e-01, -1.2451e+00, -8.3321e-01,  ...,  1.5307e+00,
           -6.0831e-01, -1.0434e+00]],

         [[-5.0563e-01,  3.4884e-01, -4.0126e-01,  ...,  1.2945e+00,
           -5.5872e-01, -4.4031e-01],
          [-1.0783e+00, -1.0583e+00, -8.7019e-01,  ...,  9.3939e-01,
            6.1988e-01, -3.6133e-01],
          [-1.4605e+00,  7.9834e-04, -1.6445e+00,  ...,  8.5405e-01,
            1.1266e+00,  2.1244e-01],
          ...,
          [-1.7653e+00, -4.5490e-01,  5.8049e-01,  ...,  1.3604e-01,
           -2.6502e-01,  1.4497e+00],
          [-2.7539e+00, -1.9189e+00, -6.1803e-01,  ...,  2.3083e+00,
           -6.2625e-01, -5.0954e-01],
          [-8.4786e-01, -9.9176e-01, -1.4226e+00,  ...,  1.0424e+00,
            1.2138e+00, -6.2367e-01]],

         [[ 1.3477e+00,  3.0343e+00,  3.7258e+00,  ...,  6.1286e-01,
            1.7142e+00, -7.4960e-01],
          [-3.4424e+00,  2.1578e+00, -3.4773e+00,  ..., -1.7704e+00,
            3.4858e+00,  9.8086e-01],
          [-3.3403e+00,  7.3066e-01, -4.6132e+00,  ..., -3.2065e+00,
            5.3039e+00,  7.1677e-01],
          ...,
          [-4.8998e+00, -5.9784e-01, -2.9574e+00,  ..., -4.1010e+00,
            2.4786e+00,  2.7664e-02],
          [-3.3274e+00, -1.2454e+00, -5.1031e+00,  ..., -3.2964e+00,
            3.3057e+00,  1.4853e+00],
          [-4.2024e+00, -1.7287e+00, -5.1702e+00,  ..., -2.7123e+00,
            2.8922e+00,  1.8391e+00]],

         ...,

         [[ 1.3818e+00, -2.7867e+00, -2.6519e+00,  ...,  9.1555e-01,
            4.4077e-01,  2.7028e+00],
          [-2.4026e+00,  1.6620e+00, -4.5219e-01,  ...,  1.2064e-01,
           -1.6484e+00,  5.6717e-01],
          [-1.7379e+00,  2.8888e+00,  2.1535e-01,  ..., -7.8397e-02,
           -2.7045e+00, -3.0823e-03],
          ...,
          [-2.9426e+00,  3.5565e+00,  1.0280e+00,  ..., -3.5420e-01,
           -3.7917e+00, -7.8773e-01],
          [-2.8640e+00,  2.8314e+00,  2.3865e+00,  ..., -2.2468e+00,
           -4.0705e+00, -1.2861e+00],
          [-3.9137e+00,  4.3675e+00,  1.5171e+00,  ..., -6.0161e-01,
           -2.7414e+00, -1.2265e+00]],

         [[ 1.7415e+00,  4.5990e-01,  9.3163e-01,  ...,  1.2650e-03,
           -9.8961e-01, -2.9552e-01],
          [ 2.2626e+00,  1.0377e+00,  1.1163e+00,  ...,  3.4995e-01,
           -2.5767e+00, -1.2164e+00],
          [ 2.0896e+00,  6.8649e-01,  1.2068e+00,  ...,  4.1762e-01,
           -2.1005e+00, -1.2765e+00],
          ...,
          [ 1.8625e+00,  5.6272e-01,  1.1284e+00,  ...,  3.5132e-01,
           -2.0787e+00, -1.0202e+00],
          [ 2.2705e+00,  3.2166e-01,  1.1907e+00,  ...,  2.6156e-01,
           -1.2966e+00, -9.9152e-01],
          [ 2.3024e+00,  4.0813e-01,  9.6441e-01,  ...,  4.9377e-01,
           -2.5960e+00, -6.9144e-01]],

         [[-2.2407e-01,  1.4293e-01, -5.5406e-01,  ...,  3.1676e-01,
            2.7494e-01,  1.6436e-01],
          [-5.7508e-01,  6.1265e-01, -2.6713e-01,  ...,  8.0278e-01,
            8.5041e-01,  1.8214e-01],
          [ 6.2629e-01,  3.5029e-02,  8.6408e-02,  ...,  4.6667e-01,
            1.6070e-01,  1.2988e-01],
          ...,
          [ 1.5542e-01, -2.5139e-01, -8.1318e-01,  ...,  2.1838e-01,
            2.0266e-01,  6.9734e-01],
          [-2.4867e-01,  4.2143e-01, -4.6590e-01,  ...,  3.0348e-01,
            5.7653e-01, -5.7979e-01],
          [-4.1779e-01, -4.9530e-01, -6.0749e-01,  ...,  5.8660e-01,
            9.1405e-01, -3.4966e-02]]]], grad_fn=<PermuteBackward0>), tensor([[[[-1.5059e-02, -2.1934e-02, -1.3257e-01,  ..., -3.3233e-03,
            5.6872e-03, -5.5921e-01],
          [-4.4076e-01,  4.7031e-01, -2.1116e-01,  ...,  5.7315e-01,
           -3.8024e-01,  2.5338e-01],
          [ 2.7640e-01,  1.0290e-01, -1.5030e-01,  ...,  8.0443e-02,
           -1.0340e-02,  6.5651e-01],
          ...,
          [ 7.7904e-01,  1.2082e+00,  3.0358e-01,  ...,  4.4578e-01,
           -4.0582e-02,  8.5044e-01],
          [-2.0731e-01, -5.8119e-01,  4.1100e-01,  ..., -1.7157e-01,
            2.8487e-01,  6.4911e-01],
          [-8.6411e-01,  5.4967e-01, -4.1298e-01,  ...,  9.2813e-01,
           -4.2606e-01, -3.4161e-01]],

         [[ 3.8557e-02,  3.3662e-03,  5.4482e-02,  ..., -5.7578e-02,
           -7.4123e-02,  2.2392e-02],
          [ 1.9386e-01,  1.8534e-01,  3.0680e-01,  ..., -1.2764e-03,
           -2.5348e-01,  8.6118e-02],
          [-1.4242e-01,  3.2992e-01,  7.6395e-02,  ...,  9.8633e-02,
           -5.6915e-02,  4.4799e-02],
          ...,
          [-7.1944e-02,  3.8884e-02,  1.0161e-01,  ..., -2.7253e-01,
            1.3398e-01,  1.1796e-01],
          [-1.0896e+00,  2.1403e+00, -1.3890e-01,  ...,  1.0035e+00,
            6.1333e-01, -1.1536e+00],
          [ 6.1611e-02,  7.1527e-02,  2.0043e-01,  ..., -3.5723e-01,
           -1.4230e-01,  8.4502e-02]],

         [[ 1.1201e-02, -7.6654e-01, -1.1583e-02,  ...,  4.3143e-02,
            1.5736e-02, -5.8100e-02],
          [ 2.8462e-01, -1.0610e+00,  1.2486e-01,  ...,  3.1588e-02,
           -1.1913e-01, -4.8153e-02],
          [ 2.6008e-01, -6.3008e-01, -8.1709e-01,  ...,  1.8586e-01,
            3.4370e-01,  9.2477e-01],
          ...,
          [-1.9891e-01, -1.9001e+00, -4.4621e-02,  ...,  7.8242e-02,
            2.2361e-02,  1.3589e-02],
          [-2.8968e-01, -1.5899e+00,  9.2801e-02,  ..., -2.7827e-01,
            1.6159e-01, -4.6007e-01],
          [ 1.6971e-01, -1.5136e+00,  1.2845e-01,  ..., -6.2768e-02,
           -2.5769e-01, -1.5622e-01]],

         ...,

         [[ 1.6522e-02, -7.7326e-02,  1.3163e+00,  ..., -5.6423e-02,
            1.7141e-01,  2.1386e-02],
          [-4.3988e-01, -2.9255e-01,  2.4116e+00,  ..., -1.8846e-01,
            1.0912e-01,  1.4147e-01],
          [ 2.3190e-01, -1.5369e-01,  2.5701e+00,  ...,  6.3039e-01,
           -1.0088e-01,  5.1586e-01],
          ...,
          [ 1.7250e-02,  6.7580e-01,  2.5971e+00,  ..., -5.2273e-01,
            4.5050e-01, -6.9956e-01],
          [-4.9545e-02,  6.1819e-01,  3.8825e-01,  ..., -1.4691e-01,
            4.5526e-01,  7.1271e-01],
          [-1.9639e-01, -1.2515e-01,  2.5813e+00,  ..., -1.8536e-01,
           -1.3485e-01, -8.7375e-02]],

         [[ 7.4395e-02, -8.7165e-02, -1.8260e-01,  ...,  1.3185e-01,
            1.2575e-01,  1.7169e-01],
          [ 6.5960e-01,  1.0117e+00,  7.1659e-01,  ...,  8.3512e-02,
           -6.5585e-01, -3.3111e-01],
          [ 3.2666e-01, -1.2571e-01,  8.1719e-01,  ...,  9.9527e-01,
           -1.0291e+00, -5.0537e-01],
          ...,
          [-7.2666e-01,  1.0662e-01, -7.2195e-02,  ..., -2.7005e-01,
            5.2628e-01,  2.3005e-01],
          [-2.0959e-01, -2.3959e-01, -3.0772e-01,  ...,  4.6964e-01,
           -1.8979e-01, -2.7418e-01],
          [ 1.7468e-01,  1.0415e+00,  7.5772e-01,  ..., -4.9262e-01,
           -8.0868e-01,  4.5074e-01]],

         [[ 1.1606e-02,  2.1828e-02,  2.7971e-02,  ..., -3.3218e-02,
            2.2172e-01, -2.3344e-03],
          [ 1.1778e-01, -3.0263e-01,  3.5408e-01,  ..., -3.3052e-01,
           -1.9086e+00,  4.3385e-01],
          [-7.0245e-01,  4.2293e-02, -1.3216e-01,  ...,  3.4737e-01,
           -1.4905e+00,  3.5105e-01],
          ...,
          [ 2.1967e-01, -6.0979e-01, -6.8996e-01,  ...,  4.4944e-01,
           -1.9601e+00, -1.7819e-01],
          [ 3.8903e-01,  1.9728e-01, -9.0256e-01,  ...,  1.3781e-01,
           -2.0059e+00,  3.0071e-01],
          [ 5.9661e-01, -3.1890e-01, -2.2125e-01,  ...,  2.8531e-01,
           -1.8048e+00,  2.1086e-01]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[ 0.0532, -0.2197,  0.1445,  ..., -0.8884,  0.7361, -1.2044],
          [-0.6462, -0.7026, -1.4285,  ...,  0.2179, -0.3014,  0.1623],
          [-0.8909, -1.9166, -1.3314,  ..., -1.8027, -2.7636,  2.9528],
          ...,
          [-1.9169, -0.2602, -0.2397,  ..., -0.4901, -0.8816,  0.7061],
          [-2.0792,  0.1064,  0.6011,  ...,  0.5948, -0.5403,  1.4379],
          [-0.4271, -0.4968, -0.0297,  ...,  1.0395, -0.3829,  0.3067]],

         [[ 0.7842,  0.1905,  0.0089,  ..., -0.1612, -1.0898, -0.1939],
          [-1.3909, -1.5235, -0.5037,  ...,  0.9582,  4.2044,  1.1825],
          [ 0.1689, -1.8025,  0.8404,  ...,  1.5177,  5.7815,  2.1470],
          ...,
          [ 1.2462, -1.4013, -1.2263,  ...,  0.5912,  6.0711,  1.7328],
          [ 1.4548, -2.0760, -2.0483,  ..., -1.5971,  5.6172,  2.5548],
          [-1.1053, -0.8554, -2.0471,  ...,  0.8743,  6.2095,  1.1606]],

         [[ 0.3413, -0.3572, -0.3331,  ...,  0.3294,  1.4604,  0.2755],
          [ 0.0960, -6.2139, -0.6779,  ..., -2.8446, -1.4388, -4.4836],
          [-0.8714, -7.8835, -1.6969,  ..., -2.1200, -2.1704, -7.2160],
          ...,
          [-3.2255, -7.0802, -1.8176,  ..., -2.8620, -2.7388, -5.1880],
          [-2.2788, -5.5723, -1.6649,  ..., -3.3594, -2.4676, -5.1028],
          [-2.9788, -7.2411, -1.0434,  ..., -3.2540, -2.9263, -5.0116]],

         ...,

         [[ 0.2148,  1.7719,  0.5129,  ...,  0.2612,  0.4477, -1.6895],
          [-0.2874, -5.8026,  1.1293,  ..., -2.2826, -1.7007,  5.5452],
          [-2.4104, -6.5778,  1.1952,  ..., -2.4193, -0.3969,  3.8159],
          ...,
          [-1.4026, -7.7514,  1.2659,  ..., -3.4256, -2.3786,  6.9488],
          [-1.0623, -5.7453,  0.1012,  ..., -0.5622, -2.4292,  6.8565],
          [-0.3079, -7.9204,  1.8029,  ..., -3.2453, -2.3462,  7.0537]],

         [[ 0.0559, -0.0269,  0.1386,  ..., -0.1165, -0.0882, -0.1612],
          [ 0.1342, -0.5329, -0.2255,  ..., -1.0159,  0.1003, -0.4600],
          [-0.7412, -0.2755,  0.1787,  ..., -0.8159, -0.9071, -0.1041],
          ...,
          [-0.0215, -0.5192, -0.2004,  ...,  0.3272, -0.3216,  0.5758],
          [ 0.2406, -0.3252,  0.3839,  ..., -0.2115,  0.3593, -0.6457],
          [-0.6898, -1.1861, -0.0238,  ...,  0.5217,  0.0940,  0.9089]],

         [[ 0.3939, -0.0741,  1.9091,  ..., -0.2314, -0.2112, -0.9825],
          [ 2.5678,  1.8706, -2.0184,  ...,  0.0582,  0.5182,  2.5282],
          [ 3.1803,  2.0001, -2.9358,  ...,  2.6552,  1.0590,  4.2195],
          ...,
          [ 2.6593,  1.2215, -2.5623,  ...,  1.4338,  0.6112,  3.2894],
          [ 1.1448,  0.9766, -2.1789,  ...,  1.8788,  0.3242,  3.7226],
          [ 2.8828,  1.7918, -3.5229,  ..., -0.0936,  0.5881,  4.5368]]]],
       grad_fn=<PermuteBackward0>), tensor([[[[ 4.5583e-02,  6.3733e-02, -3.4908e-03,  ...,  3.8117e-03,
            1.0385e-01,  2.3468e-02],
          [ 2.7193e-01, -2.7436e-01,  8.2051e-01,  ..., -5.7602e-01,
           -6.8246e-02, -1.7190e-02],
          [-3.3605e-01, -9.2270e-01, -2.9339e-01,  ...,  3.2747e-02,
            5.3266e-01, -1.1793e+00],
          ...,
          [ 4.6659e-01,  1.6959e-01, -6.8990e-02,  ...,  3.2092e-01,
            2.9894e-02,  4.5212e-03],
          [-3.8838e-01, -2.8303e-01,  6.4867e-01,  ...,  5.4443e-01,
           -3.8750e-03, -7.7317e-01],
          [-1.4669e-01, -2.2234e-01,  5.0309e-01,  ..., -2.0195e-01,
           -3.4870e-02,  1.0260e+00]],

         [[-3.8964e-02, -8.6139e-03,  9.1636e-02,  ..., -4.5061e-02,
           -1.8257e-02, -4.4496e-02],
          [-1.2398e-01, -4.6354e-01,  6.3162e-02,  ...,  4.1472e-01,
           -8.8383e-02, -6.1835e-02],
          [ 2.3124e-01, -4.1944e-01, -5.5628e-02,  ..., -6.5586e-01,
           -2.9434e-01,  1.1322e-01],
          ...,
          [ 9.0615e-02, -2.5366e-01, -1.7453e-01,  ...,  3.6981e-02,
            9.6252e-02,  2.8861e-01],
          [ 2.6449e-01, -1.1997e+00, -2.9121e-01,  ...,  1.8929e-01,
            8.9705e-01,  5.2265e-02],
          [ 1.8653e-01, -4.1886e-01, -2.5386e-01,  ...,  5.6907e-01,
           -5.6461e-01, -2.9499e-01]],

         [[ 4.5761e-02, -1.1113e-01, -6.0327e-02,  ..., -1.7311e-02,
            8.8352e-02, -1.4918e-01],
          [ 3.5832e-01,  1.0048e-01, -3.5981e-01,  ...,  4.7004e-01,
           -1.0480e-01, -9.6169e-01],
          [-1.2025e+00, -4.9562e-01, -5.6530e-01,  ..., -7.7073e-02,
           -1.8603e-01,  4.5677e-02],
          ...,
          [-1.1527e-01, -1.2046e-02,  7.9755e-01,  ...,  2.0678e-01,
           -1.6562e-01, -9.4135e-02],
          [ 3.0203e-01, -5.3025e-02,  1.0025e-01,  ..., -1.3117e-01,
           -3.9940e-01,  2.0309e-01],
          [ 5.4948e-01, -3.1714e-03, -9.9666e-01,  ...,  3.6800e-01,
            2.6345e-01, -6.6638e-01]],

         ...,

         [[-2.2513e-02,  1.1954e-01, -1.7875e-02,  ..., -1.4198e-02,
            6.4433e-02, -5.2401e-02],
          [-6.7643e-03,  1.3038e-01, -3.1770e-02,  ..., -2.8075e-02,
           -7.0123e-02,  2.9359e-01],
          [ 7.8513e-01, -7.9053e-01, -1.5511e-01,  ..., -3.0193e-01,
           -5.3295e-02,  5.1889e-01],
          ...,
          [-1.9707e-01,  5.0177e-02, -1.1185e-01,  ..., -3.0111e-01,
            2.1017e-01, -2.7775e-01],
          [-3.1374e-01, -3.5912e-02, -2.5133e-01,  ..., -1.2073e-01,
            1.3938e-01, -1.4568e-01],
          [-1.2432e-01,  3.0442e-01,  1.0542e-01,  ...,  2.1967e-02,
            3.2316e-02,  1.2676e-01]],

         [[-1.7366e-01, -1.3407e-01, -6.7815e-02,  ..., -2.3521e-01,
           -1.8675e-02, -5.1927e-02],
          [ 4.8318e-01, -4.9988e-01,  7.3483e-01,  ...,  1.7037e-01,
            6.2192e-01,  2.3596e-01],
          [ 1.1730e-01,  3.0694e-02,  7.3273e-01,  ...,  5.0575e-01,
            3.1356e-02, -5.0081e-01],
          ...,
          [ 6.1899e-01, -9.2282e-01,  1.6701e-01,  ..., -2.4323e-02,
            1.7694e-01, -3.4102e-01],
          [ 8.4867e-01,  1.2311e-01,  3.3463e-01,  ...,  3.2204e-01,
            8.6678e-01,  5.9980e-01],
          [ 4.1040e-01,  1.7545e-01,  2.0518e-01,  ..., -9.3810e-01,
            4.8850e-01, -5.4087e-01]],

         [[ 1.1481e-01, -7.4767e-02, -2.5446e-02,  ..., -1.8679e-02,
           -9.1254e-02, -9.6947e-02],
          [ 5.5079e-01,  1.9193e-01,  1.8251e-04,  ..., -1.0992e-02,
           -2.6968e-01, -3.8421e-02],
          [-1.8607e-01, -8.5692e-02,  3.1742e-01,  ..., -3.9823e-01,
            4.3919e-01, -8.0165e-02],
          ...,
          [-1.6626e-01, -1.0646e+00, -1.0149e-02,  ..., -9.7871e-02,
            1.4443e-01, -1.5419e-01],
          [-4.4313e-01, -1.3310e-01,  4.2125e-01,  ...,  4.0301e-02,
           -1.7659e-01,  3.1838e-01],
          [ 8.1519e-01,  2.4844e-01,  1.2036e-01,  ..., -9.9506e-02,
           -2.9214e-01,  5.8580e-02]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[-8.8678e-01, -1.3593e-01,  3.3093e-01,  ..., -9.5576e-01,
            2.5192e-02, -2.9464e+00],
          [ 1.5584e+00, -5.3821e-01, -2.4421e+00,  ..., -1.7774e+00,
           -1.4069e+00,  7.2554e+00],
          [-6.1613e-02, -9.6410e-01, -3.3367e+00,  ..., -1.7228e+00,
           -7.6467e+00,  6.5063e+00],
          ...,
          [ 1.2191e+00, -3.6478e-01, -1.8077e+00,  ..., -1.4126e+00,
           -3.4429e+00,  1.1099e+01],
          [ 1.2810e+00, -4.1117e-01, -4.4152e+00,  ..., -1.0298e+00,
           -2.3506e+00,  1.1191e+01],
          [ 1.5495e+00, -1.9605e+00, -3.1807e+00,  ..., -9.8794e-01,
           -2.1888e+00,  9.4760e+00]],

         [[ 3.7499e-01, -6.6046e-02,  4.5773e-01,  ..., -1.2836e-01,
           -7.7381e-02, -2.2161e+00],
          [-1.9084e+00, -5.1770e-01,  3.3306e+00,  ..., -1.0169e-01,
           -2.0618e+00,  7.5854e+00],
          [-3.1865e+00, -5.3798e-01,  3.4467e+00,  ...,  8.8427e-02,
           -4.1777e+00,  7.7792e+00],
          ...,
          [-2.9382e+00, -8.8965e-01,  3.4723e+00,  ..., -1.4002e+00,
           -5.7932e-01,  6.9011e+00],
          [-3.7302e+00, -1.4835e+00,  7.7318e-01,  ..., -1.4177e+00,
           -1.5522e+00,  7.3279e+00],
          [-2.4526e+00, -1.8321e+00,  3.6389e+00,  ..., -4.4448e-01,
           -1.6136e+00,  6.6650e+00]],

         [[ 1.2211e-01, -6.5015e-01, -2.2831e-01,  ...,  1.4110e-01,
            2.7893e-01, -1.7424e-01],
          [ 1.7771e-01,  1.7629e+00,  6.3257e-01,  ..., -2.6582e-01,
            6.2577e-01,  5.0930e-02],
          [ 2.2530e-01,  3.0012e+00,  5.3516e-01,  ..., -3.2276e-01,
            5.9087e-01, -3.6453e-02],
          ...,
          [-6.4210e-01,  3.1597e+00,  2.3032e-01,  ...,  6.4203e-01,
            1.9326e-01,  5.4560e-01],
          [-4.8734e-01,  2.4240e+00,  1.1159e-01,  ...,  9.6528e-01,
            1.2245e+00, -1.7901e+00],
          [ 2.7319e-01,  2.8160e+00,  6.3444e-01,  ..., -5.1675e-01,
           -1.5301e-01, -8.1118e-01]],

         ...,

         [[-4.0181e-01,  1.2737e-02, -1.1140e-02,  ...,  1.2548e+00,
            4.3199e-02,  1.8033e+00],
          [-4.6139e-01, -1.3921e+00, -1.4511e+00,  ..., -2.5093e+00,
           -1.6920e+00, -2.7131e-01],
          [-1.3954e-01,  3.9872e-01, -5.5181e-01,  ..., -4.0252e+00,
           -1.2034e+00, -8.0604e-01],
          ...,
          [ 3.8913e-01, -9.2129e-01,  6.7512e-01,  ..., -3.2734e+00,
           -3.7855e-01, -1.2775e+00],
          [ 3.6478e-01,  1.1098e+00,  1.9589e+00,  ..., -1.2581e+00,
           -9.2984e-01, -1.5476e+00],
          [-2.0390e-01, -6.6112e-01, -9.6914e-01,  ..., -3.2531e+00,
           -3.5533e-01, -3.5020e-01]],

         [[-3.3790e-01, -1.2825e-01,  2.2242e-01,  ...,  2.6358e-01,
           -2.9314e-02,  3.1528e-02],
          [-6.0304e-01, -1.1295e+00,  1.4573e+00,  ...,  7.0224e-01,
           -8.5480e-01,  1.8017e-01],
          [ 9.3104e-01, -2.1456e+00,  3.8324e-01,  ...,  9.3967e-01,
           -8.2110e-01,  1.3123e-01],
          ...,
          [-7.5492e-01, -1.8400e-01,  3.3456e-01,  ...,  1.7404e+00,
            7.1590e-01,  1.3268e+00],
          [-1.5429e-01,  5.3506e-01,  2.4561e+00,  ...,  1.2834e+00,
            5.7729e-01,  1.3149e+00],
          [-7.7036e-01, -6.9287e-01,  1.1238e+00,  ...,  1.0106e+00,
           -5.3742e-01,  1.3852e+00]],

         [[ 3.4402e+00,  2.1226e+00, -2.1050e+00,  ..., -2.8555e+00,
           -3.9038e+00, -1.2060e+00],
          [-3.0643e+00, -1.6132e+00,  4.7811e+00,  ..., -2.6905e+00,
            9.4376e+00, -3.7636e+00],
          [-2.8029e+00,  9.2815e-01,  2.2908e+00,  ..., -3.5372e+00,
            9.2503e+00,  2.0644e+00],
          ...,
          [-4.9774e+00, -1.8169e+00,  4.4703e+00,  ..., -4.3005e+00,
            1.5492e+01,  3.7749e+00],
          [-2.4577e+00, -1.8796e+00,  6.0842e+00,  ..., -4.6722e+00,
            9.1210e+00,  1.8122e+00],
          [-4.4732e+00, -2.0733e+00,  7.2062e+00,  ..., -3.7151e+00,
            1.2814e+01, -2.2193e+00]]]], grad_fn=<PermuteBackward0>), tensor([[[[-0.0028, -0.0602,  0.0219,  ...,  0.0593,  0.0264,  0.0681],
          [ 0.6692, -0.1774,  0.2994,  ...,  0.1940, -0.3524, -0.1093],
          [ 0.0441, -0.6776, -0.4458,  ...,  0.2746,  0.9155, -0.5374],
          ...,
          [-0.3262, -0.0103, -0.0866,  ...,  0.0454, -0.1561,  0.2205],
          [-0.0552, -0.6212, -0.4492,  ..., -0.2533,  0.0952, -0.2438],
          [ 0.1740,  0.0146, -0.0917,  ...,  0.1930, -0.1700,  0.1307]],

         [[-0.0538, -0.0195, -0.1417,  ..., -0.0445,  0.0476, -0.0319],
          [ 0.3175, -0.1990, -0.2276,  ...,  0.1004, -0.0740, -0.1226],
          [ 0.3296, -0.6555, -0.2850,  ..., -0.8669,  0.2712,  0.0552],
          ...,
          [-0.0141,  0.1838,  0.2267,  ...,  0.0249, -0.0362,  0.3883],
          [-0.2939, -0.5590,  0.3243,  ..., -0.0678,  0.0157, -0.5514],
          [-0.0048, -0.0914, -0.2181,  ..., -0.2868,  0.0018, -0.0651]],

         [[ 0.0639,  0.0961,  0.0831,  ...,  0.0160, -0.0859, -0.0050],
          [-0.8685, -0.1267, -0.8107,  ...,  0.0526, -0.7176, -0.0689],
          [ 0.1621,  0.2253,  0.0752,  ...,  0.1041, -0.4005,  0.1818],
          ...,
          [-0.4981,  0.5339, -0.4980,  ..., -0.2581, -0.8093, -0.3876],
          [-0.6054,  1.6497,  1.0752,  ..., -1.0363,  0.7149, -0.6451],
          [-0.4706,  0.3250,  0.3061,  ...,  0.4489, -0.6589,  0.0312]],

         ...,

         [[-0.0115,  0.0657, -0.0777,  ...,  0.0440,  0.0456, -0.1384],
          [ 0.0136, -0.3035,  0.8164,  ..., -0.2084, -0.8236,  0.4428],
          [ 0.2521, -0.4054,  0.2197,  ...,  0.1480,  0.2216,  0.5164],
          ...,
          [-0.0778, -0.1247, -0.3227,  ...,  0.1474,  0.1483,  0.3701],
          [-0.3559, -0.8621, -0.0799,  ..., -0.9994,  0.4109,  0.2198],
          [-0.1967,  0.0573,  0.6049,  ...,  0.1913,  0.0767, -0.0245]],

         [[-0.1315, -0.0534,  0.0947,  ..., -0.0666,  0.0539, -0.0204],
          [ 0.0918, -0.3386, -0.7173,  ..., -0.2867, -0.0289, -0.1466],
          [ 0.2971,  0.6579, -0.9279,  ..., -0.0267, -1.3269,  0.6167],
          ...,
          [-0.1993,  0.8396,  0.5954,  ..., -0.2100,  0.3891,  0.5287],
          [ 1.5998,  0.6881, -0.2637,  ...,  1.1610,  0.1208, -0.6552],
          [ 0.5209, -0.3917,  0.1674,  ..., -0.2824,  0.0700, -0.3138]],

         [[-0.0193, -0.0120, -0.0240,  ..., -0.0300,  0.0080, -0.0136],
          [-0.0432, -0.3667, -0.3346,  ..., -0.1011,  0.0167,  0.1537],
          [-0.3303, -0.6508, -0.2167,  ..., -0.6360, -0.1999, -0.1340],
          ...,
          [-0.0058, -0.1530, -0.3235,  ..., -0.3699,  0.0510,  0.1209],
          [ 0.1009,  0.4467, -0.0791,  ..., -0.2715, -0.2259,  0.5418],
          [ 0.0141,  0.2831, -0.4868,  ..., -0.1903, -0.1869,  0.7274]]]],
       grad_fn=<PermuteBackward0>)), (tensor([[[[ 3.4512e-02, -2.8466e-01,  2.2210e-01,  ...,  1.6982e+00,
           -2.2029e-01, -8.0207e-02],
          [ 1.1887e+00,  1.2521e+00, -2.0973e-02,  ..., -2.7505e+00,
            1.1517e-01, -9.4738e-01],
          [-3.7887e-01,  2.3147e-01,  7.8851e-01,  ..., -3.8859e+00,
           -1.2610e+00, -1.5381e+00],
          ...,
          [-3.8848e-01, -7.8692e-01,  6.0321e-01,  ..., -1.5790e+00,
           -4.4260e-01, -1.7360e+00],
          [-1.3567e+00, -1.2212e-02,  4.0693e-01,  ..., -2.6267e+00,
            3.1883e-01, -1.1768e+00],
          [-1.3149e+00,  5.3910e-01,  8.4051e-01,  ..., -2.6472e+00,
           -8.0766e-02, -1.3063e+00]],

         [[ 1.5566e-01,  9.6884e-01, -1.4234e+00,  ..., -1.1945e-01,
            2.6095e-01,  9.2861e-01],
          [-1.1655e+00, -5.3317e+00,  7.2065e-01,  ..., -1.4863e+00,
           -2.2354e+00, -2.4988e+00],
          [ 1.2192e+00, -4.3649e+00,  9.3857e-01,  ...,  3.6005e-01,
           -1.0827e+00, -2.1299e+00],
          ...,
          [-6.5003e-01, -3.6931e+00,  4.9255e-01,  ..., -2.0790e+00,
           -3.1514e-01, -2.7136e+00],
          [ 5.9668e-01, -3.1527e+00,  7.6608e-01,  ..., -4.4680e-01,
           -1.1040e-01, -1.9393e+00],
          [ 2.0418e+00, -5.3709e+00,  5.4901e+00,  ..., -2.3439e-02,
            4.6572e-01, -3.8706e+00]],

         [[-6.7068e-01,  2.4994e-01, -5.6570e-02,  ...,  1.7880e-01,
            5.6148e-02, -2.9901e-01],
          [ 1.9676e+00,  2.9566e-02, -8.5660e-01,  ..., -1.8619e+00,
           -3.3802e-01,  1.6140e-01],
          [ 2.1615e+00, -7.5559e-01,  3.4024e-01,  ..., -1.4898e+00,
            4.2649e-01,  1.5977e+00],
          ...,
          [ 1.1094e+00, -8.7126e-01,  4.4787e-02,  ..., -4.0946e-01,
           -6.8646e-01, -5.1147e-01],
          [ 1.3666e+00, -6.3472e-01, -6.9747e-01,  ...,  6.0671e-01,
            2.1492e+00, -3.3250e-01],
          [ 2.1474e+00, -4.8501e-02, -8.7131e-01,  ..., -1.4417e+00,
            1.5616e+00,  1.8827e-01]],

         ...,

         [[-4.0635e-02,  1.1188e-01,  1.4037e-01,  ..., -9.7647e-02,
            1.2961e-02,  1.5000e-01],
          [ 7.8730e-01, -5.6138e-01, -1.2585e+00,  ...,  1.1703e+00,
           -1.7229e-01,  1.2928e+00],
          [-1.0394e-01,  1.4770e-01,  3.8454e-01,  ...,  6.5685e-01,
           -2.6355e-01,  1.3102e+00],
          ...,
          [ 1.1297e-01,  1.4229e+00,  2.8362e-01,  ...,  9.3448e-01,
            2.5909e-01,  2.9945e-01],
          [ 6.5882e-01,  7.3874e-01, -5.1318e-01,  ...,  9.5171e-01,
            1.6892e-01, -1.7952e-01],
          [ 4.4172e-01,  9.7651e-02, -1.4498e+00,  ...,  1.2877e+00,
            7.8737e-01,  5.7300e-02]],

         [[-3.0020e+00,  4.0418e-01, -2.7798e-02,  ..., -4.8566e-01,
           -3.4500e-01,  1.2311e+00],
          [ 4.8764e+00,  1.4500e+00, -1.1937e+00,  ..., -1.6858e+00,
            3.0943e-01, -9.1063e-01],
          [ 4.6146e+00,  9.6566e-01, -5.1178e-01,  ..., -2.1980e-01,
            1.1130e+00, -1.2746e+00],
          ...,
          [ 4.9677e+00,  2.5583e-02, -1.3527e+00,  ..., -1.8770e+00,
           -6.6969e-01, -4.0065e-01],
          [ 4.3137e+00,  1.0467e+00, -1.5161e+00,  ..., -2.2238e+00,
           -1.7302e-01, -1.6034e-01],
          [ 4.6436e+00,  9.9926e-01, -5.2100e-01,  ..., -1.5177e+00,
            1.9258e-01, -2.4487e-01]],

         [[-7.4442e-03, -2.5452e-01, -1.9922e-04,  ..., -1.8494e-01,
            3.4208e-01,  9.0523e-02],
          [ 8.8014e-01, -3.2005e+00, -2.3284e-01,  ..., -5.6783e-01,
            5.3092e-01,  4.5332e-02],
          [-3.2605e-01, -1.7599e+00, -5.3681e-01,  ..., -5.2140e-01,
            1.7060e+00, -8.0691e-01],
          ...,
          [-1.1833e+00, -8.9443e-01,  5.9676e-01,  ...,  3.0636e-01,
            5.0886e-01, -1.5048e+00],
          [-1.2903e+00, -9.5492e-01,  2.1957e-01,  ...,  2.2938e+00,
           -5.0270e-01, -7.8764e-02],
          [ 4.4758e-01, -1.5906e+00,  1.4957e-01,  ...,  2.3779e+00,
           -2.2358e-01,  4.7562e-01]]]], grad_fn=<PermuteBackward0>), tensor([[[[-0.0224, -0.0239,  0.0031,  ..., -0.0027, -0.0209,  0.3535],
          [ 1.1117, -0.0250,  1.0920,  ..., -0.1345, -0.2322, -0.7385],
          [ 1.2585, -0.5406, -0.8740,  ...,  0.6211,  0.4854, -0.4785],
          ...,
          [ 0.8751, -0.3160, -0.5735,  ..., -0.2102, -0.0831, -0.5934],
          [ 0.0085, -0.3084,  0.1655,  ...,  0.4398,  0.5114, -0.4383],
          [-0.0725, -0.3939,  0.5899,  ...,  0.7469, -0.3640,  0.0679]],

         [[ 0.0048, -0.0161,  0.0186,  ..., -0.0150,  0.0150,  0.0090],
          [ 0.4621, -0.6415, -0.2005,  ...,  0.2446,  1.2697, -0.7838],
          [ 0.6805, -1.2565,  0.0765,  ..., -0.0242,  1.4869,  0.1836],
          ...,
          [-0.2538, -0.0022,  0.1847,  ...,  0.4838,  1.5106,  0.7886],
          [ 1.2671, -0.9662, -0.3248,  ...,  0.5432, -0.0319, -0.1366],
          [-0.1197, -1.6058, -0.3833,  ...,  0.3964,  1.0133, -0.1477]],

         [[-0.0603,  0.0030, -0.0383,  ..., -0.0468,  0.0119, -0.0780],
          [ 0.5506, -0.3951,  0.6694,  ..., -0.6748,  0.3026,  0.0286],
          [ 0.4687,  0.1415,  0.0033,  ...,  0.4084,  0.2910,  0.4103],
          ...,
          [ 0.4985,  0.4334,  0.3964,  ..., -0.2184, -0.0373, -0.0717],
          [ 0.0850, -0.4120, -0.2606,  ..., -0.2593,  0.7614, -0.8139],
          [ 0.0813, -0.2308,  0.9975,  ..., -0.3412, -1.0508, -0.9304]],

         ...,

         [[-0.3063, -0.1904, -0.0540,  ..., -0.4848,  0.2131,  0.1049],
          [ 1.7674, -1.5249,  1.8613,  ...,  1.4412, -0.3079,  0.1500],
          [ 0.3092, -1.5787,  0.6095,  ...,  0.5455,  0.1634,  1.3060],
          ...,
          [-1.9410, -1.8215, -0.4399,  ...,  0.3221, -0.1979, -1.4136],
          [ 1.2677, -1.9424,  0.0700,  ..., -0.9788, -0.6381, -0.4399],
          [ 0.8285, -1.8581, -0.3010,  ..., -1.3209,  0.2318, -0.1750]],

         [[-0.0861, -0.1412, -0.0534,  ..., -0.1797, -0.1466,  0.1142],
          [-0.7113, -0.5252, -0.7349,  ..., -0.0491,  0.5213, -0.7352],
          [ 0.4967, -1.1247, -0.6529,  ..., -0.4258, -0.1081, -0.2017],
          ...,
          [-0.4174, -1.3939,  0.0162,  ..., -0.2306, -0.4274,  0.3158],
          [-1.1609, -0.1209, -0.1991,  ...,  1.2310,  0.5859,  0.6733],
          [-1.1036, -0.5834,  0.1167,  ...,  0.8276, -0.1767,  0.3441]],

         [[-0.0294, -0.0414,  0.1069,  ...,  0.0614, -0.0412,  0.0239],
          [-0.6127, -0.0583, -0.7644,  ..., -1.4024, -0.9271,  0.9733],
          [ 0.5288,  0.2919,  0.0434,  ..., -0.4878, -0.6339,  0.4392],
          ...,
          [ 0.0769, -0.0123, -1.2272,  ...,  0.3366, -0.2014,  0.2725],
          [ 0.0642,  1.9300, -0.3253,  ..., -1.0578,  0.4355, -1.4476],
          [-0.5956,  0.2606, -0.5507,  ..., -0.5284, -0.1602, -0.7526]]]],
       grad_fn=<PermuteBackward0>)), (tensor([[[[-3.3938e-01,  8.6866e-01, -1.6351e-01,  ...,  1.1267e+00,
           -1.6784e-01,  1.2959e-01],
          [-3.9849e-01, -4.8487e+00, -2.9290e-01,  ..., -4.2862e+00,
            6.9337e-01,  6.4498e-01],
          [ 7.8937e-01, -3.8062e+00, -7.5906e-01,  ..., -4.4411e+00,
            9.3744e-01,  2.4774e+00],
          ...,
          [-2.9636e-01, -5.8016e+00,  1.4007e+00,  ..., -3.1025e+00,
           -2.7375e-01,  1.2819e+00],
          [ 1.3025e+00, -3.8148e+00,  1.8926e+00,  ..., -3.3508e+00,
            6.2647e-01,  5.1378e-01],
          [-3.7934e-01, -4.5341e+00,  6.0715e-01,  ..., -4.3161e+00,
            6.4808e-01,  7.9708e-01]],

         [[ 5.8217e-02,  8.7121e-01, -6.2251e-01,  ..., -2.4310e-02,
            2.9330e-01,  1.3199e-02],
          [ 1.3357e+00, -1.1482e+00,  1.2032e-01,  ...,  1.5088e+00,
           -1.0720e+00, -1.1527e+00],
          [ 2.6168e+00,  9.3244e-03, -7.2926e-01,  ...,  9.4531e-01,
           -7.9178e-01, -1.6888e+00],
          ...,
          [ 8.0452e-01, -3.9176e-01, -3.0347e-01,  ...,  1.3463e+00,
           -3.1319e-01, -1.3556e-01],
          [-7.1086e-01,  9.6997e-02,  1.2591e+00,  ...,  2.0719e-01,
            4.2983e-01, -6.3391e-01],
          [ 1.1039e-01, -1.3052e+00,  1.1124e-01,  ...,  1.3074e+00,
            1.4712e+00, -2.7487e-01]],

         [[-3.1165e-01,  1.2165e-01, -9.8370e-01,  ..., -3.5095e-01,
           -6.3912e-02, -1.3616e-01],
          [ 5.0049e-01, -6.6728e-01,  2.9285e+00,  ..., -3.9263e-01,
            4.3198e-01, -2.3447e-01],
          [ 1.2306e-01, -2.9766e-01,  3.6896e+00,  ..., -1.0091e-01,
           -2.5103e-01, -2.0315e-01],
          ...,
          [-2.1391e-01, -2.1547e+00,  2.8612e+00,  ...,  5.8855e-01,
           -1.9214e-01,  1.8883e+00],
          [-2.1992e-01, -1.4360e+00,  3.3444e+00,  ...,  9.8178e-01,
           -1.9441e+00,  5.7364e-01],
          [ 6.3090e-02, -1.4908e+00,  2.0854e+00,  ...,  1.4157e-01,
           -1.3972e-01, -6.9580e-02]],

         ...,

         [[ 3.7597e-01,  8.1398e-02, -6.4505e-02,  ..., -4.8594e-02,
            2.2536e-01,  4.1931e-03],
          [-1.2319e+00,  8.1079e-01, -5.4320e-01,  ...,  1.2257e-01,
           -7.8676e-02, -2.6823e-01],
          [-1.9185e-02,  5.4915e-01,  9.4312e-01,  ..., -2.6608e+00,
            3.8096e-01, -1.3816e+00],
          ...,
          [-1.5423e+00, -2.7545e-01,  2.9765e+00,  ...,  5.4036e-01,
            1.6682e+00, -7.5562e-01],
          [-1.2052e+00, -2.4065e-01,  4.7900e-02,  ..., -1.5625e+00,
            2.8238e-01, -3.3910e-01],
          [-1.7759e+00,  3.9760e-01, -1.0807e+00,  ..., -1.9584e+00,
           -1.1637e+00,  1.5918e+00]],

         [[ 2.0009e-01,  5.4941e-02,  3.2748e-01,  ...,  4.1661e-01,
           -3.4165e-03,  2.3171e-01],
          [ 1.6163e+00,  1.2442e+00,  2.8373e-01,  ..., -3.9689e-01,
            7.1320e-03, -1.1601e-01],
          [ 1.3228e+00,  1.4674e-01,  6.3871e-01,  ..., -5.9913e-02,
            1.6461e-01,  3.3509e-01],
          ...,
          [ 7.7162e-01,  7.9756e-01,  8.2908e-01,  ..., -1.0911e+00,
            8.8888e-01, -1.1994e+00],
          [ 1.6909e+00,  8.3524e-01,  6.7132e-01,  ..., -1.1008e+00,
           -7.2901e-01,  6.1303e-01],
          [ 2.8334e+00,  4.6555e-01,  1.2473e+00,  ..., -7.3844e-01,
           -7.0963e-01,  1.0278e-01]],

         [[-3.0156e+00,  5.3756e-01,  5.6815e-01,  ..., -9.3899e-01,
            3.2683e-01,  1.8463e-01],
          [ 7.7879e+00, -9.7524e-01, -2.1850e+00,  ...,  2.2429e+00,
           -1.0887e+00,  5.6749e-01],
          [ 6.9868e+00, -3.9651e-01, -9.7286e-01,  ...,  1.0613e+00,
           -7.0396e-01,  1.3823e+00],
          ...,
          [ 9.5361e+00, -8.7937e-01, -2.5252e+00,  ...,  1.3820e+00,
           -2.2409e+00,  2.4565e-01],
          [ 8.8630e+00, -1.1387e+00, -1.7681e+00,  ...,  1.0129e+00,
            2.0493e-01, -2.1170e-01],
          [ 9.5194e+00, -2.0795e-01, -1.6476e+00,  ...,  2.4340e+00,
           -1.9197e+00, -2.9640e-01]]]], grad_fn=<PermuteBackward0>), tensor([[[[ 4.6474e-02, -5.0378e-02,  1.0945e-02,  ..., -6.9955e-02,
            2.9789e-03, -1.0073e-01],
          [ 5.1051e-01, -5.5772e-01, -3.8570e-01,  ..., -3.2328e-01,
            2.3945e-01, -2.9826e-01],
          [-2.8010e-01,  7.4962e-01, -5.4584e-01,  ..., -3.6442e-01,
            4.2576e-01, -1.4805e+00],
          ...,
          [-9.0783e-01, -4.8128e-01, -1.8888e-01,  ..., -2.2824e-01,
           -7.4845e-02, -1.0972e+00],
          [-5.0702e-01,  1.0603e-01, -1.0484e+00,  ...,  5.5779e-01,
           -4.9793e-01, -9.2837e-01],
          [ 3.8714e-02,  4.2493e-01, -4.1890e-01,  ...,  5.6050e-01,
           -2.7279e-01, -1.3355e+00]],

         [[ 6.7674e-02,  3.0544e-02, -2.3115e-02,  ..., -4.3823e-02,
            5.2575e-03, -1.6795e-03],
          [-5.3669e-01,  1.7762e+00, -5.2043e-01,  ...,  7.5157e-01,
           -6.1868e-01, -7.3336e-01],
          [-2.5054e-01,  4.9751e-03, -8.3214e-02,  ..., -7.4598e-01,
           -6.1617e-01,  3.3602e-01],
          ...,
          [-1.4067e-01,  5.9621e-01,  1.0898e+00,  ...,  9.4066e-01,
           -1.3745e+00,  1.1213e+00],
          [-1.0630e+00, -5.0378e-01,  6.9651e-01,  ..., -4.6445e-01,
           -6.6259e-01,  1.7251e-01],
          [-1.5972e+00,  3.2659e-01,  3.4644e-01,  ...,  2.8986e-01,
           -5.7299e-01, -2.2912e-01]],

         [[ 7.0952e-02,  8.2320e-03, -1.6572e-03,  ...,  2.1678e-02,
           -6.7437e-02, -5.0287e-02],
          [ 7.4200e-01, -3.2418e-01,  4.1442e-01,  ..., -1.4945e-02,
            2.5678e-01,  1.5392e-01],
          [ 2.9304e-01,  5.7399e-01, -2.7184e-01,  ..., -1.4044e-01,
            6.1588e-02, -1.5561e-01],
          ...,
          [ 7.1019e-01, -8.5043e-01, -3.1989e-01,  ...,  2.5753e-01,
            2.2188e-01,  7.3108e-01],
          [ 7.1561e-01, -8.6057e-01,  9.2320e-01,  ...,  3.9957e-01,
            2.4226e+00,  1.6563e+00],
          [-7.6132e-02,  2.4041e-01,  9.3365e-01,  ..., -2.2613e-01,
            3.9552e-01,  1.0165e-01]],

         ...,

         [[-8.4018e-04,  4.2945e-02,  2.0029e-02,  ..., -6.6209e-02,
           -1.8070e-02,  2.2869e-02],
          [-1.4168e+00,  2.7825e-01,  3.5415e-02,  ...,  2.2794e-01,
           -1.8244e-01,  2.6631e-01],
          [-1.5832e+00,  6.7589e-01, -1.3738e-01,  ...,  7.5377e-01,
           -8.9247e-01,  8.4118e-01],
          ...,
          [-1.0343e+00,  2.2096e-01,  1.8098e-01,  ...,  1.5064e+00,
           -9.4570e-01, -9.6457e-01],
          [-5.5192e-01,  6.5732e-01, -7.3323e-01,  ...,  8.2586e-01,
            1.0773e+00, -5.0690e-01],
          [-6.9760e-01, -2.0758e-01,  2.9526e-01,  ..., -1.6063e-02,
            1.6516e-02,  4.3263e-01]],

         [[ 5.6418e-02, -6.3642e-03,  2.3703e-02,  ...,  1.7139e-02,
           -1.5312e-02,  6.8112e-03],
          [ 1.8381e+00, -1.3941e+00, -1.0189e+00,  ..., -9.4177e-01,
            4.2883e-01,  8.2570e-01],
          [ 8.8893e-01, -1.6692e+00, -4.3398e-01,  ..., -1.2906e+00,
            1.0952e-01,  3.7169e-01],
          ...,
          [ 7.4024e-01, -1.4955e-01, -8.9148e-01,  ..., -1.0267e+00,
           -6.1569e-01,  5.8172e-01],
          [-7.3008e-01, -4.7314e-01,  3.7697e-01,  ...,  5.2418e-01,
           -1.6633e-01,  3.0198e-01],
          [ 6.6411e-02, -4.8074e-01, -4.0598e-01,  ...,  1.1196e-01,
            1.0054e+00, -4.4949e-01]],

         [[ 6.7269e-02, -2.0375e-01, -7.5082e-02,  ..., -4.0162e-02,
            1.9610e-01, -5.1942e-02],
          [ 3.7243e-01, -9.5645e-01, -3.3796e-01,  ..., -9.8523e-01,
           -4.3307e-01, -2.3109e-01],
          [-5.5909e-01, -9.8741e-01, -8.3997e-01,  ..., -4.0350e-02,
            2.2590e-03, -1.1709e+00],
          ...,
          [ 2.6116e-01, -1.7003e+00,  9.9667e-03,  ...,  2.5269e-01,
           -6.5086e-01, -5.0987e-01],
          [-2.2483e-01, -3.8567e-01, -1.6472e-01,  ..., -7.8707e-01,
            3.2198e-01, -4.2609e-01],
          [-1.5893e-01, -7.3543e-01, -4.9369e-01,  ..., -1.5504e+00,
           -3.8277e-01, -4.1377e-01]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[ 1.0315, -0.2591, -0.1501,  ...,  0.6246,  0.7232, -0.3059],
          [-3.2631, -1.9435, -1.2089,  ...,  0.2806, -5.5510,  0.3489],
          [-4.3399, -1.4013, -0.2186,  ...,  0.7109, -5.7860, -0.6970],
          ...,
          [-5.3760, -3.7959, -0.3718,  ...,  0.7804, -4.6301, -1.2402],
          [-6.1584, -2.8255,  0.0772,  ...,  0.3908, -4.4567, -0.1920],
          [-4.6862, -2.2484, -0.4802,  ...,  1.1911, -4.6985, -1.0555]],

         [[-0.1525, -0.0745,  0.1651,  ..., -0.0464, -0.8761, -0.1921],
          [-0.6493,  1.0436, -0.2845,  ..., -0.2628,  0.0537,  0.8063],
          [-1.7708,  1.3885, -0.9440,  ...,  0.3637,  0.7435,  1.4247],
          ...,
          [-2.0241,  0.3328, -0.2828,  ...,  0.8545,  0.5231,  2.4687],
          [-2.8308, -0.2631, -0.4617,  ..., -0.3337,  1.8320,  2.9475],
          [-2.0453,  1.1846, -2.5580,  ...,  0.5495,  1.1092,  1.8249]],

         [[ 0.1958,  0.3039,  1.1389,  ..., -0.4691,  0.4513, -0.4878],
          [-1.4815, -0.5524, -1.6846,  ..., -0.5676, -1.8434,  2.4752],
          [-3.5171, -1.7341, -1.0781,  ..., -0.0126, -0.8584,  2.8363],
          ...,
          [-1.2945, -1.0943, -0.7373,  ...,  0.2280, -2.9008,  2.5152],
          [-2.2796, -0.5816, -0.3174,  ...,  0.7422, -1.4116,  2.2355],
          [-0.7958,  0.1943, -2.7152,  ...,  1.7208, -1.5123,  0.9313]],

         ...,

         [[-0.1518,  0.0675, -0.2341,  ...,  0.0125,  0.1685,  0.0227],
          [-1.5855,  0.1968, -0.4700,  ...,  0.9270, -1.3281, -0.1941],
          [ 0.7663, -0.7921, -0.5326,  ...,  0.9606, -0.0650, -0.2843],
          ...,
          [ 0.1914, -0.1551, -0.9815,  ...,  1.8034,  0.1310,  0.7172],
          [-1.2788, -1.7422, -0.4975,  ...,  1.3406, -0.4531, -0.5256],
          [-1.8526,  0.1496, -0.0816,  ...,  0.8122, -1.0543,  0.1050]],

         [[-0.3515, -2.1836,  0.1103,  ..., -0.0873, -0.0481,  0.9174],
          [-0.3931,  1.7304, -1.0893,  ..., -1.0898, -1.7984,  1.0287],
          [ 0.3552,  3.3603, -1.5929,  ..., -0.7109, -1.5203,  0.7090],
          ...,
          [ 0.4526,  3.6483, -3.1344,  ...,  1.3756, -1.8511,  2.2068],
          [ 1.4022,  2.2589, -2.0330,  ...,  0.3515, -0.4796,  0.9019],
          [ 0.7568,  2.8114, -2.1562,  ...,  1.3476, -0.3658,  0.7552]],

         [[ 0.3682,  0.0657, -0.1320,  ...,  0.6454,  0.1343,  0.2644],
          [-0.8896,  0.3677,  0.1631,  ..., -0.3916, -0.4439, -0.9719],
          [ 0.4470,  0.5271, -0.4635,  ..., -0.6886, -1.2558,  0.0390],
          ...,
          [-1.7867, -2.2049,  2.1719,  ..., -0.8210, -0.2084,  1.6132],
          [-1.4884, -1.5097,  0.1562,  ...,  0.5166,  0.2819, -0.1415],
          [-2.1183,  1.1049,  1.0999,  ..., -0.3114,  0.2994,  0.8749]]]],
       grad_fn=<PermuteBackward0>), tensor([[[[-2.8336e-02,  4.5010e-02, -5.7978e-02,  ..., -1.9222e-02,
            1.2577e-02,  3.6269e-02],
          [-4.6357e-01,  2.2473e-01, -5.6911e-01,  ..., -1.1179e-01,
            5.4963e-01,  1.4621e-01],
          [ 8.8791e-01,  1.3920e-01, -1.1074e+00,  ..., -3.0970e-01,
            8.6369e-01,  2.8616e-01],
          ...,
          [ 9.8967e-01,  7.6810e-02,  3.6725e-01,  ...,  1.0289e-01,
           -7.9780e-01, -1.0472e-01],
          [ 4.1472e-01, -3.0706e-01, -1.0118e-01,  ..., -2.9164e-02,
            9.2894e-02,  2.6503e-01],
          [-4.1391e-01, -4.3953e-01,  9.5461e-02,  ..., -1.8622e-02,
            1.2946e-01, -4.0387e-01]],

         [[ 2.1536e-02, -2.8120e-02,  3.8532e-02,  ...,  2.1765e-02,
           -4.7212e-02,  5.3255e-03],
          [-1.6248e-01, -4.5659e-01, -4.4525e-02,  ...,  5.6903e-01,
           -3.0144e-01, -1.2120e+00],
          [-1.6019e-01, -3.1593e-01,  1.0682e+00,  ..., -1.1746e-01,
           -4.8418e-01,  4.2423e-01],
          ...,
          [-7.0670e-01,  1.4226e-01, -2.0767e-01,  ..., -5.3785e-01,
           -3.7916e-01,  2.9476e-01],
          [ 3.5204e-01,  1.6746e-01, -1.8197e+00,  ...,  1.8833e-01,
            2.5200e-01,  1.3326e+00],
          [ 1.0614e-01, -5.6477e-01, -1.3717e+00,  ...,  2.8329e-01,
           -2.3432e-01,  5.8129e-01]],

         [[ 3.9084e-02, -2.6990e-02,  5.6189e-02,  ...,  2.6549e-02,
           -7.1806e-03,  1.9065e-02],
          [ 8.1593e-01,  3.5473e-01, -1.9476e-01,  ...,  7.1779e-01,
            1.7158e-01,  1.7037e-01],
          [-3.0468e-01,  6.4740e-01, -1.1535e+00,  ...,  2.5107e+00,
           -1.3214e+00,  6.0931e-01],
          ...,
          [-3.8012e-01, -1.0693e+00, -4.3163e-01,  ..., -1.2006e-01,
           -4.7626e-01, -5.9241e-01],
          [-6.6220e-01,  1.0321e+00,  6.1114e-01,  ..., -1.0294e+00,
           -5.9746e-02, -1.4874e+00],
          [ 1.5239e+00,  1.7266e-01, -2.6497e-01,  ..., -6.9278e-01,
            2.7154e-01,  1.1508e-01]],

         ...,

         [[-1.8650e-01,  8.9365e-02,  5.7435e-02,  ...,  4.6573e-02,
            3.7369e-02, -1.2676e-01],
          [-6.2920e-01, -4.5253e-02,  1.5379e-01,  ..., -8.5838e-01,
            2.2210e-01, -4.9222e-01],
          [-2.1227e-01,  6.7216e-01,  5.8456e-01,  ..., -4.8421e-02,
           -4.2428e-01, -4.8305e-01],
          ...,
          [-9.4783e-01, -4.8206e-02, -1.2836e-01,  ...,  1.8181e-01,
           -4.6491e-01, -8.4671e-01],
          [-7.2088e-01,  4.8839e-01, -1.6034e+00,  ..., -3.5454e-01,
            8.5080e-02, -1.4271e+00],
          [-1.0528e+00,  8.3454e-01, -9.8252e-01,  ...,  1.1729e-01,
           -1.4640e-01, -1.9143e+00]],

         [[-5.8489e-01, -4.5877e-03,  4.4912e-02,  ..., -2.0796e-02,
            6.2989e-03, -6.4938e-03],
          [-1.6445e+00,  4.2511e-02, -3.1403e-01,  ..., -3.7935e-01,
            2.3561e-01,  5.9496e-02],
          [-2.5505e+00, -2.0482e-01, -3.6240e-01,  ..., -3.0201e-01,
           -4.2028e-01, -1.8376e-02],
          ...,
          [-1.6757e+00,  4.2658e-01, -9.1740e-01,  ...,  2.0202e-01,
            5.2352e-01,  3.1575e-01],
          [-1.7608e+00,  5.6837e-01,  3.5225e-01,  ...,  5.5874e-01,
           -6.9264e-01, -1.8256e-01],
          [-2.3731e+00, -2.8098e-01,  3.9676e-01,  ..., -2.5406e-01,
            4.8834e-01, -6.1031e-01]],

         [[ 1.5471e-03,  8.2456e-02, -4.7513e-02,  ...,  5.5853e-02,
            3.0368e-02, -4.6994e-02],
          [-5.5504e-01,  7.3400e-01, -2.0816e-01,  ..., -1.2824e-01,
            3.8586e-01,  8.0331e-01],
          [ 6.3713e-01,  1.6547e+00,  2.6059e-01,  ..., -1.1861e+00,
            6.3198e-01, -1.3541e-01],
          ...,
          [ 4.7463e-01,  1.1477e+00,  6.0258e-02,  ..., -4.6058e-01,
           -3.5489e-01,  7.9365e-02],
          [ 8.1016e-02, -1.3944e-01,  4.1258e-01,  ...,  1.1060e-01,
           -2.8541e+00,  4.1492e-01],
          [-1.2963e+00,  2.2384e-01, -2.4338e-01,  ...,  2.2294e-01,
            1.0918e-01,  2.1425e+00]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[-0.0268, -2.3398,  0.1634,  ..., -0.2365, -0.1944,  0.0645],
          [-0.2799,  4.0417, -0.1959,  ..., -0.1126, -0.3356, -0.5690],
          [-0.0094,  4.9294,  0.5373,  ..., -0.1404, -0.6815,  0.3025],
          ...,
          [ 0.0091,  3.8780,  0.7991,  ..., -0.5989,  0.7071,  0.5137],
          [-0.6218,  4.7662, -0.4088,  ..., -0.8925, -0.0737,  0.7395],
          [-0.8638,  5.1069, -0.1012,  ..., -0.0097,  0.0632, -0.7295]],

         [[-0.8140,  0.2218,  0.4656,  ..., -0.5189,  1.0732,  1.1234],
          [ 0.2364,  0.2685,  1.0541,  ...,  0.5500,  1.3914,  0.4962],
          [ 1.5259,  1.0305, -0.6830,  ..., -0.3595,  0.8213, -0.1596],
          ...,
          [-0.8893,  0.6401,  1.5340,  ..., -0.3154,  0.9969,  0.1131],
          [-1.4240, -0.5673, -0.9037,  ..., -0.0334,  2.1567, -0.3555],
          [-2.3116,  1.4069,  0.2116,  ...,  0.7944,  2.6708,  0.1778]],

         [[-0.8504,  0.4700,  0.0232,  ...,  0.4955, -0.2356,  1.1518],
          [ 0.6655, -0.1374,  1.1604,  ...,  0.2494,  1.0734, -0.9082],
          [ 2.0262,  0.3311,  0.5329,  ...,  0.2746,  0.6484, -1.2565],
          ...,
          [ 0.8666,  0.2080,  0.7423,  ..., -0.0590,  0.7947,  0.2077],
          [ 1.3274, -0.5878,  1.5562,  ...,  1.2727,  0.8958, -0.8393],
          [ 0.6793, -0.9115,  2.1432,  ...,  1.5571,  1.7428, -0.3943]],

         ...,

         [[-0.3102, -0.1292,  0.1523,  ...,  0.1793,  1.7438, -2.8696],
          [ 1.2379, -0.5238,  0.3674,  ..., -0.3042, -4.6049,  4.6856],
          [ 0.6856,  0.3973,  0.9211,  ..., -0.6994, -5.2863,  4.9465],
          ...,
          [-0.5532, -0.4212,  1.0728,  ...,  0.4562, -5.7176,  5.1979],
          [-0.9992, -1.4073, -0.8534,  ...,  0.8452, -5.9484,  4.2105],
          [ 0.1935, -1.2555,  1.2355,  ..., -0.0070, -6.0872,  5.8807]],

         [[ 0.1957,  0.3617,  0.2155,  ..., -0.2170,  0.0182, -0.1540],
          [-0.6359, -0.7831, -0.5938,  ...,  1.0413, -0.4280,  0.6407],
          [-0.6033, -1.0964, -0.2818,  ...,  0.2840, -0.2947,  0.6149],
          ...,
          [-0.2907,  0.0759,  0.5673,  ...,  1.1031, -0.7398,  0.1992],
          [-0.3487, -0.1916,  1.1144,  ...,  0.6085,  0.1949,  1.1279],
          [-1.1693, -0.8894,  0.6257,  ...,  1.4145, -1.2843,  0.4372]],

         [[ 0.3722,  0.0987,  0.6134,  ...,  0.5249,  0.5746, -0.3289],
          [ 0.7276, -0.7879, -1.5108,  ..., -1.7654, -3.2146,  0.1771],
          [ 0.6286, -1.0423, -1.3390,  ..., -2.0023, -2.7540, -0.0532],
          ...,
          [ 1.2008, -1.0047, -2.2047,  ..., -2.5210, -4.7543,  1.0585],
          [ 0.1571, -1.0960, -1.7899,  ..., -3.0896, -4.1969, -0.4143],
          [ 1.1982, -1.3326, -1.5329,  ..., -1.6822, -4.4774, -0.5948]]]],
       grad_fn=<PermuteBackward0>), tensor([[[[ 0.0512, -0.0071, -0.0159,  ...,  0.1308, -0.0604, -0.0383],
          [ 0.0743,  0.0098,  0.8985,  ...,  0.3322,  1.0163, -0.2279],
          [ 0.2299, -1.0595, -0.2036,  ...,  0.4071,  1.0309,  0.7073],
          ...,
          [ 0.9942,  0.0985, -0.3045,  ...,  0.3595,  1.2762,  0.1312],
          [-1.1851,  0.1872,  2.5162,  ..., -0.4091, -0.5504, -0.3313],
          [ 0.4573, -0.2495,  1.1492,  ...,  0.3916,  0.3092, -0.2549]],

         [[ 0.0178,  0.0383,  0.0396,  ...,  0.0060, -0.0180,  0.0108],
          [ 0.3077,  0.2800, -1.2484,  ...,  0.1144, -0.0260, -0.6417],
          [ 0.8365,  0.1942, -2.6429,  ...,  1.4839, -2.4390, -1.1518],
          ...,
          [-1.0152, -1.3838,  0.4507,  ...,  0.2284,  0.2643,  0.3901],
          [-1.8002, -1.5104, -0.6286,  ...,  1.0451,  0.2438, -0.3518],
          [-0.4032, -0.3529, -1.6265,  ...,  0.5828,  0.5720, -1.2572]],

         [[ 0.0495, -0.0389,  0.0613,  ...,  0.0561, -0.0711, -0.0673],
          [-0.6686,  1.1461, -0.4798,  ...,  0.1773,  0.4573,  0.4967],
          [-0.3811,  0.8968, -0.6061,  ...,  0.0926,  0.3056,  0.9180],
          ...,
          [-0.3757, -0.0510,  0.0062,  ...,  0.6064,  0.7972,  0.7227],
          [-0.2685, -0.7850,  0.7441,  ..., -0.8875, -0.0677,  1.0534],
          [-0.7876,  1.0096, -0.0108,  ..., -0.9138, -0.1195, -0.2942]],

         ...,

         [[-0.1028, -0.0452,  0.0346,  ..., -0.0871,  0.0427,  0.0092],
          [ 0.5169,  0.0966,  0.2483,  ..., -0.4591,  0.3724,  0.6674],
          [ 0.9085,  0.9305, -0.0286,  ..., -0.8769, -0.3911,  0.3594],
          ...,
          [-0.0673, -0.2202, -0.2051,  ...,  0.2041, -0.4487,  1.0220],
          [-0.2218, -0.4037,  1.4038,  ...,  1.5332, -1.2336,  0.4163],
          [ 0.8637, -1.0940,  0.2482,  ...,  0.3983, -1.4612,  0.6188]],

         [[ 0.1576, -0.0522,  0.1510,  ...,  0.0776,  0.0389, -0.1486],
          [-0.0612,  1.4222,  1.2901,  ...,  1.0537,  1.9877, -1.2965],
          [ 0.0701,  1.0599,  1.3164,  ...,  1.8434,  1.7597, -0.8641],
          ...,
          [-0.0791,  0.1802, -0.2036,  ...,  0.6063,  1.2652,  0.1763],
          [ 0.4001,  1.6460,  1.1749,  ..., -0.6267,  2.3732, -0.3538],
          [ 0.2739,  1.4950,  0.8300,  ...,  1.1957,  1.5808, -1.0777]],

         [[ 0.2067, -0.0439, -0.0680,  ...,  0.0390,  0.0473,  0.0275],
          [-0.6717,  0.2561,  0.7676,  ..., -0.2872, -0.5916, -0.1957],
          [-0.9239,  0.0464,  0.4365,  ...,  0.6006, -0.4989,  0.7633],
          ...,
          [ 0.5723,  0.0787,  0.7033,  ...,  0.3464, -0.7811,  1.3074],
          [-0.8109, -0.4612, -1.6027,  ..., -1.6367, -0.0065, -0.7756],
          [-1.3609,  0.5702,  0.7531,  ..., -0.1462,  0.1355,  0.2370]]]],
       grad_fn=<PermuteBackward0>)), (tensor([[[[ 0.0436, -0.2509, -0.4550,  ...,  0.3116,  0.3358,  0.3770],
          [ 0.3000, -0.0478, -1.2507,  ...,  0.3519,  0.7682,  0.7296],
          [ 0.5719, -1.2229, -1.5268,  ...,  0.2386,  0.1496,  0.8318],
          ...,
          [-0.7869, -1.6832, -0.6862,  ...,  0.5248, -1.1760, -0.6061],
          [ 1.1739,  0.7271, -1.4276,  ...,  1.1409, -1.3880, -0.6762],
          [ 1.7515, -0.1609, -0.0345,  ...,  0.9718, -0.5132,  1.4921]],

         [[-0.2801,  0.1559,  0.1167,  ...,  0.0214, -1.1384, -0.1501],
          [-0.7174, -0.1200, -0.7961,  ..., -0.4121,  0.7157,  0.5868],
          [ 1.5377,  0.1651, -0.9257,  ...,  0.3588,  1.3888,  0.1633],
          ...,
          [ 0.2032,  0.5659, -0.9297,  ..., -1.1580, -1.0870,  1.0748],
          [-0.0984,  1.5501, -1.2118,  ..., -1.0350,  0.6500,  0.8747],
          [-1.1498,  0.8479, -0.9318,  ..., -1.2515,  0.5937,  0.4393]],

         [[-1.2411, -0.0878,  0.5490,  ..., -0.6611,  0.4539, -0.2888],
          [ 0.6556,  1.0735, -0.5900,  ...,  0.0895, -0.3484, -0.2450],
          [ 0.3530,  0.0116,  0.0702,  ...,  0.7262, -1.4991, -0.5028],
          ...,
          [ 0.6693,  0.8831, -0.7045,  ...,  1.2413,  0.0528,  0.1498],
          [ 1.5144,  1.9988, -1.8167,  ...,  1.0272, -0.5508, -0.2781],
          [-0.2976,  1.1260, -1.6873,  ...,  1.3365, -0.2020, -0.3461]],

         ...,

         [[ 0.7973, -0.8987, -0.3939,  ..., -1.0369, -0.4123,  0.4803],
          [ 1.4616,  0.0408, -1.0295,  ...,  0.7219,  0.3444, -0.0145],
          [ 0.2550, -1.1764, -0.3335,  ...,  0.8036,  1.7228, -2.3128],
          ...,
          [ 0.6038, -0.3213, -0.9128,  ...,  1.7723,  0.7332, -1.3456],
          [ 1.5292,  0.8308, -1.5665,  ...,  1.7068,  0.6255, -1.4453],
          [ 2.1459, -0.1321, -0.5784,  ...,  1.8690,  1.6415,  0.8508]],

         [[-0.9151,  2.5785,  0.3082,  ...,  0.3579,  1.9421, -0.5408],
          [-0.0171, -2.5663,  0.7328,  ...,  0.3923, -4.1463,  1.9012],
          [ 1.0185, -2.5828, -1.5448,  ...,  1.0508, -4.9451,  1.7123],
          ...,
          [ 1.1092, -2.5339,  0.2730,  ..., -0.9127, -3.6883, -0.9762],
          [ 0.7417, -1.7092,  0.4430,  ...,  0.6517, -4.0859, -0.6250],
          [ 0.6957, -4.4839, -0.4944,  ...,  1.2733, -5.0460,  2.7409]],

         [[-2.0221, -0.3681, -1.1042,  ..., -0.3983,  0.0527,  0.2442],
          [ 1.4425,  0.4368,  0.8613,  ...,  1.2344, -0.1098,  0.1759],
          [ 2.6873, -0.5718,  0.7670,  ...,  1.4859, -0.9973,  1.5824],
          ...,
          [ 2.3481, -0.2267,  0.4736,  ...,  1.0791,  0.1695, -0.6822],
          [ 1.8121,  0.8181,  1.5002,  ...,  1.3897, -1.1112, -0.6512],
          [ 2.0455,  0.8276,  1.0394,  ...,  1.7555, -0.0730, -0.0210]]]],
       grad_fn=<PermuteBackward0>), tensor([[[[-4.8632e-02, -9.1396e-02,  3.1682e-02,  ...,  1.2261e-01,
           -3.6255e-02,  1.2526e-02],
          [-3.1696e-01,  2.4710e-01, -5.5300e-02,  ...,  2.4286e-02,
           -3.1162e-01,  2.2300e-01],
          [ 1.9266e-01,  5.7364e-01,  5.6620e-01,  ...,  6.3398e-01,
           -1.6994e-01, -3.2943e-01],
          ...,
          [ 8.0219e-01, -6.2467e-02,  7.5092e-01,  ..., -2.6152e-01,
            6.4908e-01,  9.1121e-01],
          [ 2.9819e-01, -1.1154e+00,  5.7111e-01,  ..., -1.1155e+00,
            5.0150e-01,  3.6634e-01],
          [ 7.2844e-01,  4.1041e-01,  6.7296e-01,  ...,  2.8859e-01,
           -9.5357e-01,  4.9752e-01]],

         [[ 1.1928e-02,  1.3112e-02, -2.6053e-02,  ...,  4.6390e-02,
            2.8720e-02,  5.6897e-02],
          [ 5.1804e-01, -8.6756e-03,  3.4240e-01,  ..., -9.3518e-01,
           -2.8230e-02, -1.6108e-01],
          [-6.5553e-01, -1.4296e-01,  6.3211e-01,  ..., -2.3726e+00,
           -1.0325e+00,  1.1180e+00],
          ...,
          [-2.7697e-01,  4.7694e-01,  9.3078e-01,  ..., -1.4985e-02,
           -9.5630e-01, -1.0057e+00],
          [-3.5304e-01,  7.6668e-01, -7.3687e-01,  ...,  8.2464e-01,
            6.1313e-01,  1.4616e-01],
          [ 6.2543e-02,  9.5850e-01,  9.9546e-02,  ..., -4.1675e-01,
           -3.1019e-01,  2.1785e-02]],

         [[ 3.9431e-02,  3.2304e-02, -6.9643e-02,  ...,  3.1842e-03,
            1.5391e-02,  8.6383e-03],
          [-4.5218e-02,  3.8015e-01, -7.4175e-03,  ..., -8.6065e-02,
            1.9510e-01,  2.4301e-02],
          [ 1.0227e+00,  7.7004e-02,  7.1903e-02,  ...,  1.1994e+00,
            1.6976e-01, -4.0066e-01],
          ...,
          [ 1.1771e+00,  2.4422e-01,  7.0662e-01,  ...,  1.1337e+00,
           -8.5384e-01, -9.9605e-01],
          [ 4.0196e-01,  3.7700e-01,  1.0244e+00,  ..., -2.4000e-01,
           -2.2166e-03, -8.7664e-01],
          [ 6.3016e-01,  1.0653e-01,  6.7085e-01,  ...,  1.8561e-01,
           -1.0484e+00, -2.8506e-01]],

         ...,

         [[-2.8633e-02,  2.3521e-02, -1.3071e-02,  ..., -3.3836e-02,
           -4.1805e-02,  1.6132e-02],
          [ 2.9617e-01,  2.5753e-01,  6.4459e-01,  ...,  6.7883e-01,
           -1.1170e-01,  3.4354e-01],
          [ 1.4840e-01, -2.1638e-01,  1.5988e-01,  ...,  3.0029e-01,
           -1.7462e+00,  2.2010e+00],
          ...,
          [-1.0735e-01, -2.7973e-01,  1.7696e-01,  ...,  1.2454e-01,
            1.6533e+00,  4.6311e-02],
          [-2.5303e-01, -5.3346e-01, -7.0970e-01,  ...,  3.3254e-01,
           -1.0337e-01, -1.5011e+00],
          [ 9.4744e-01,  4.1239e-01, -1.0214e-01,  ...,  1.0832e+00,
            1.1939e+00,  2.1364e-01]],

         [[-7.3958e-02, -4.4124e-02,  1.7760e-02,  ...,  3.1321e-03,
           -4.5881e-02, -1.0916e-01],
          [-4.6492e-01, -6.5992e-02, -4.8427e-02,  ...,  2.7765e-01,
            1.7094e-01, -2.1020e-01],
          [-9.3265e-01, -1.7024e+00,  1.1011e-01,  ..., -6.0777e-01,
            2.7326e-01, -1.2374e+00],
          ...,
          [-1.0394e-01,  3.4447e-02, -1.4004e+00,  ...,  1.9303e-01,
           -1.2038e+00,  5.6969e-01],
          [-9.6140e-01,  5.8390e-01, -5.3376e-01,  ...,  3.5307e-01,
            3.7874e-01, -4.8008e-02],
          [-6.0081e-02, -5.1836e-01, -9.0043e-02,  ...,  2.2977e-01,
           -1.1964e-01, -4.6107e-01]],

         [[-7.7909e-03,  4.0206e-02, -5.6468e-02,  ..., -3.0341e-02,
            2.4338e-02,  5.3261e-03],
          [-2.4073e-01, -1.4607e-01,  6.8568e-01,  ..., -9.4289e-01,
           -1.0285e+00, -7.2268e-01],
          [-8.9161e-01,  3.2033e-01,  2.2241e-01,  ...,  7.4783e-01,
           -1.8553e-01, -1.4143e+00],
          ...,
          [ 5.0678e-01, -8.7200e-01,  1.3745e+00,  ...,  4.7279e-01,
            1.9468e-01, -3.0692e-01],
          [ 9.3960e-02, -1.1271e+00, -3.2356e-01,  ...,  4.6166e-01,
            1.1812e+00, -7.4736e-02],
          [-1.3350e-01, -7.4492e-01,  7.1189e-01,  ..., -1.8032e-01,
           -1.5200e+00, -9.0480e-01]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[-0.5266,  0.4829, -0.8581,  ..., -1.0129, -1.3275,  0.2040],
          [ 0.6896,  0.9769,  0.7884,  ...,  2.3917, -0.4978, -1.2931],
          [ 0.4849,  0.2801,  2.0546,  ...,  2.4861,  0.4516, -2.1291],
          ...,
          [ 1.1704, -0.3403,  1.9815,  ...,  2.4192,  1.2849, -0.9075],
          [ 0.5567,  0.0731,  0.2333,  ...,  1.9754, -0.6718,  0.4945],
          [ 1.1157,  0.0910,  1.9513,  ...,  2.0806, -0.6777, -1.8277]],

         [[ 0.8645, -2.0846,  0.1532,  ...,  0.2459, -2.4906, -0.4514],
          [ 1.1494,  1.2483, -0.0495,  ...,  0.1813,  0.8199, -0.5313],
          [ 1.4952,  1.4661, -1.3266,  ...,  0.6351,  0.5419,  0.3732],
          ...,
          [ 1.6627,  3.2121, -1.1410,  ..., -0.1081,  2.2876, -1.0492],
          [ 0.8725,  3.7180, -0.8677,  ...,  0.5521,  0.0537, -2.0911],
          [ 1.8427,  2.8496, -0.9180,  ..., -0.0876,  2.9764, -1.0137]],

         [[ 1.0175,  0.3871, -0.1741,  ..., -0.8094, -1.4149, -0.3730],
          [-0.2747,  0.4294, -0.8148,  ...,  0.7997, -1.0098, -0.2083],
          [-0.1443,  0.1837, -0.6903,  ...,  2.3234, -0.5142, -1.1581],
          ...,
          [-1.7553, -0.7940, -1.4744,  ...,  1.9563, -0.3079,  0.2517],
          [-1.1555, -0.9816, -1.4792,  ...,  2.4893, -0.8572,  0.6439],
          [ 0.2061,  0.6956, -1.2343,  ...,  1.2946, -0.7649, -1.0596]],

         ...,

         [[ 0.2137, -0.5814,  0.4917,  ..., -0.6758,  1.0594,  0.2809],
          [-0.0787,  1.1178, -0.9665,  ..., -2.9838, -0.0755,  0.3358],
          [-0.3467,  0.6547, -1.9701,  ..., -2.6404, -1.7759,  0.1484],
          ...,
          [-0.6088,  0.2404, -1.0831,  ..., -2.5044, -0.5236,  0.2501],
          [-1.2627, -0.4007,  0.0159,  ..., -2.2715, -1.9617,  0.1351],
          [-0.4629,  0.4004, -1.0877,  ..., -3.2533, -0.1876, -0.2612]],

         [[ 0.2431,  0.5528,  0.5439,  ...,  0.7452,  0.0856,  0.8468],
          [ 0.3639,  2.4237,  0.9672,  ...,  0.7770, -0.7330,  0.4097],
          [-0.4982,  1.9386, -0.1103,  ...,  1.4543, -0.3265,  0.4745],
          ...,
          [ 0.2191,  1.5633, -0.4826,  ..., -0.9138, -0.7183,  0.2929],
          [-2.4011, -0.7274, -0.1691,  ...,  0.5614, -0.1154,  2.1418],
          [ 1.8710,  2.7152,  0.3026,  ...,  0.4339, -1.6067,  0.4278]],

         [[-0.7092,  0.3125, -1.6205,  ..., -0.4008,  0.2350, -1.3048],
          [ 0.0382,  0.8210, -1.6851,  ...,  1.5476,  1.1133, -1.3639],
          [-1.2253, -0.0602, -3.1185,  ..., -0.4857,  1.8382,  0.9552],
          ...,
          [-3.0265, -0.1628, -0.6678,  ...,  1.2046,  1.1136, -0.6637],
          [-2.2680, -0.1403, -1.6040,  ...,  0.0642,  0.6752, -0.0818],
          [-0.6567, -0.4737, -1.8665,  ...,  1.7928,  1.7230, -1.4443]]]],
       grad_fn=<PermuteBackward0>), tensor([[[[ 3.4694e-03,  5.2180e-02, -7.1138e-02,  ...,  5.9633e-02,
           -5.0955e-02, -7.4279e-02],
          [ 9.5441e-02, -8.2082e-04,  3.7786e-01,  ..., -7.9814e-01,
           -3.4941e-01,  5.3955e-01],
          [-1.7982e-01,  1.0793e+00,  8.4480e-01,  ..., -5.6335e-01,
            4.7423e-01, -1.3511e-01],
          ...,
          [-8.4265e-01, -2.8338e-02, -8.4992e-01,  ..., -8.6247e-01,
            6.9610e-01, -1.6560e-01],
          [-5.8879e-02,  7.0560e-01,  8.3837e-01,  ..., -4.8124e-01,
           -1.7102e+00,  4.3793e-01],
          [ 4.7365e-02,  2.9308e-01,  2.9819e-01,  ..., -3.3441e-01,
           -6.1577e-01,  5.0968e-01]],

         [[ 3.7719e-02,  1.2977e-04,  4.9038e-02,  ..., -3.9138e-02,
           -2.6473e-02, -1.4142e-02],
          [-1.9622e-01,  4.4304e-01,  9.3501e-02,  ...,  8.6879e-01,
            5.8439e-01,  6.5467e-01],
          [-8.1397e-01, -1.4557e+00, -3.4408e-01,  ...,  1.0143e+00,
            1.6014e-01, -7.6486e-01],
          ...,
          [ 4.8213e-01,  1.1956e+00, -6.7466e-01,  ...,  4.4558e-03,
           -6.0745e-01,  1.5004e-01],
          [-9.2434e-01, -9.9667e-02, -1.7371e-01,  ...,  3.3668e-01,
            3.7452e-01,  9.1399e-01],
          [-8.0525e-01,  2.7367e-01,  2.7182e-01,  ...,  1.5725e+00,
            1.8934e-01,  9.1494e-01]],

         [[-6.9176e-03,  1.8243e-02, -3.3975e-02,  ...,  8.5669e-03,
            2.7227e-02,  5.8461e-02],
          [-2.8638e-01,  4.4393e-02, -2.4720e-01,  ...,  5.8055e-01,
           -1.1038e+00, -3.1214e-01],
          [-4.1151e-02,  4.7980e-01, -8.1177e-01,  ...,  2.5263e+00,
           -6.2052e-01, -4.0801e-01],
          ...,
          [-6.4285e-01,  2.1790e-01,  7.1201e-01,  ...,  7.6857e-01,
            1.9746e-02, -1.2292e-02],
          [ 4.3683e-01, -2.0561e-01,  5.6170e-01,  ..., -1.3195e+00,
           -6.0955e-01,  8.5465e-01],
          [-5.0826e-02,  2.0641e-01,  2.1014e-01,  ..., -6.1202e-01,
           -3.7409e-01,  5.8607e-01]],

         ...,

         [[-7.8410e-02,  2.6667e-02,  1.1429e-02,  ..., -3.5996e-02,
           -7.8381e-03,  1.2273e-03],
          [-2.4955e-01,  3.0179e-01,  2.2439e-01,  ..., -7.0245e-01,
           -4.7259e-01, -1.2154e-01],
          [-1.1360e+00,  4.8186e-01,  6.9660e-01,  ...,  5.2388e-02,
            4.9656e-01, -7.1202e-01],
          ...,
          [-2.2132e-01,  2.5862e-01,  5.4504e-01,  ...,  6.4937e-01,
           -1.4201e-01, -6.9701e-02],
          [-9.8322e-01, -1.1579e-01,  1.4461e+00,  ...,  4.0303e-01,
           -8.9281e-01,  9.6826e-01],
          [-1.2535e+00,  7.9669e-01,  2.3864e+00,  ..., -1.1996e+00,
           -1.2942e-02,  1.5757e+00]],

         [[ 7.4373e-02, -1.0839e-03,  4.7472e-02,  ...,  2.5576e-02,
            5.5578e-02,  3.0725e-02],
          [ 1.6561e-01,  1.1326e+00,  1.1021e+00,  ..., -6.7084e-02,
            1.0625e+00, -7.9841e-01],
          [-1.1934e+00,  1.3455e+00,  7.5402e-01,  ...,  3.0290e+00,
            1.9807e+00, -1.6143e-01],
          ...,
          [ 8.7119e-01,  1.6007e+00,  9.8724e-01,  ...,  6.2297e-01,
            9.5836e-01, -6.7591e-02],
          [ 5.6550e-01,  7.5545e-01, -9.4622e-01,  ...,  3.9639e-01,
           -1.3479e-01,  1.4511e-01],
          [-2.5438e-01,  1.3767e+00,  1.5838e+00,  ...,  5.7618e-02,
            1.7279e+00, -7.0514e-01]],

         [[-1.1214e-01,  2.7461e-02, -6.8169e-02,  ..., -8.8035e-02,
            7.2290e-02, -2.1984e-02],
          [-1.5667e-01,  3.3572e-01, -2.9793e-01,  ..., -3.2849e-01,
           -6.0364e-02,  8.4579e-02],
          [-3.0011e-01,  3.6599e-01,  4.1995e-01,  ..., -5.6659e-01,
            1.6448e-01,  2.1300e-01],
          ...,
          [-5.3965e-01,  8.5568e-01, -1.0334e+00,  ...,  1.6571e+00,
            1.2634e+00,  1.2663e-02],
          [-1.1969e+00,  1.1998e-01,  7.4285e-01,  ...,  2.5529e+00,
            2.4390e+00, -3.4413e-01],
          [-7.4888e-01,  4.5366e-01, -1.2199e+00,  ..., -4.5325e-01,
            4.5486e-01,  9.2945e-01]]]], grad_fn=<PermuteBackward0>)), (tensor([[[[-1.7115, -0.3095, -0.3052,  ...,  0.1569,  0.3295, -0.5102],
          [-0.2954, -0.5369, -1.0378,  ...,  1.1073, -0.8305,  0.1360],
          [ 0.3707, -0.1892, -0.0631,  ..., -0.2839, -0.1220,  0.7498],
          ...,
          [ 0.1536,  0.6239, -0.8671,  ...,  0.7387,  0.5640,  0.1241],
          [ 2.3072, -0.4214, -0.9284,  ...,  1.0017,  0.4802,  0.5843],
          [ 0.1397,  0.0057, -1.1350,  ...,  1.1576, -1.0530,  0.4489]],

         [[ 0.1108, -0.0705,  2.3025,  ...,  0.2415,  0.0896, -0.1951],
          [ 0.7120, -1.0674, -0.9315,  ...,  0.1600,  0.2404, -0.4726],
          [ 1.4893, -0.9425, -1.5101,  ..., -0.3591,  0.0335,  0.4421],
          ...,
          [ 0.1341, -0.0506, -1.3000,  ..., -0.1105, -0.2529,  0.7670],
          [-0.5547, -0.6913, -1.2921,  ...,  0.2898, -0.2538,  0.9526],
          [ 0.2336, -0.8297, -0.8416,  ..., -0.0980, -0.2919,  0.7454]],

         [[-0.2041,  1.0503,  0.4759,  ..., -0.5452,  0.3040, -0.1147],
          [-0.4822,  0.5302, -0.5850,  ...,  1.5575,  0.2531, -0.3998],
          [-0.5382,  0.5110,  0.0939,  ...,  0.6880, -0.2115, -0.7376],
          ...,
          [-0.3389,  0.6598, -0.4937,  ...,  0.4478, -0.8530, -0.5765],
          [-0.7257,  1.3182,  0.9792,  ...,  1.7382,  0.5813, -1.0117],
          [-0.9058,  1.0543, -0.6513,  ...,  2.0381, -0.3831,  0.2467]],

         ...,

         [[ 0.5577,  0.9790, -0.8716,  ..., -0.7184,  0.7245,  0.8611],
          [ 0.3537,  0.1875, -0.5812,  ..., -0.5651, -0.0493,  0.0897],
          [ 0.4359, -0.7858, -0.9179,  ..., -1.4713,  0.2309, -0.2758],
          ...,
          [ 0.9503, -0.6978,  0.7306,  ..., -0.7847, -0.9335, -0.8081],
          [ 0.1232,  1.2112,  1.0973,  ...,  0.2584,  1.1175, -0.0057],
          [-0.0062, -0.2483,  0.2463,  ..., -0.3165,  0.3718, -0.2848]],

         [[-0.4014,  0.3733,  0.3393,  ...,  0.7212,  0.0451, -0.0838],
          [-0.5197,  1.3345, -1.5982,  ...,  0.5380, -0.2475, -0.9776],
          [ 0.0646,  0.0452, -0.4746,  ...,  0.9874, -0.8139,  0.1726],
          ...,
          [-0.6432, -0.4941,  0.4357,  ...,  0.5838, -1.3339, -0.0826],
          [-0.9413, -1.2357, -0.4911,  ...,  1.3679, -1.0148, -1.4263],
          [-0.9545,  0.2418, -1.5970,  ...,  0.3238, -0.9107, -0.7229]],

         [[-0.7459, -0.0075,  0.4400,  ..., -0.1109,  0.0299, -0.0598],
          [-0.2137,  0.3865,  1.1712,  ...,  0.4425, -0.3584,  1.2832],
          [ 0.5957,  0.1015, -0.1897,  ...,  0.4039, -1.3808,  1.2112],
          ...,
          [ 0.7437, -1.3902,  0.2656,  ...,  0.9423, -1.2780,  1.6726],
          [-0.7614,  0.3624,  1.4484,  ...,  0.2220, -1.0658,  1.0444],
          [-0.4612,  0.8413,  1.7939,  ...,  0.1289, -0.8518,  1.1819]]]],
       grad_fn=<PermuteBackward0>), tensor([[[[ 7.7570e-02, -1.1777e-01, -1.6829e-01,  ..., -3.0139e-01,
            2.8640e-01, -1.7741e-01],
          [ 3.7056e-01,  6.3164e-01,  7.7662e-01,  ...,  2.7495e+00,
           -1.5492e+00,  1.1155e+00],
          [ 1.6399e+00,  1.3236e+00,  5.0145e-01,  ...,  2.7380e+00,
           -2.5362e+00,  2.0660e+00],
          ...,
          [-2.2210e-03,  1.2618e-01,  1.9018e-01,  ...,  2.5907e+00,
           -1.5682e+00,  7.7443e-01],
          [ 1.0661e+00,  1.8362e-01,  9.9011e-01,  ...,  1.7970e+00,
           -1.8210e-01, -7.9636e-01],
          [ 1.1176e+00,  9.5490e-01,  4.2716e-01,  ...,  2.4762e+00,
           -1.8121e+00,  1.6125e+00]],

         [[ 1.0853e-01, -1.0814e-02,  5.5897e-02,  ..., -9.3695e-03,
           -8.4395e-02,  1.6578e-01],
          [ 9.0288e-02,  4.3214e-01,  7.7907e-02,  ...,  3.6511e-01,
            4.1462e-01, -3.7498e-01],
          [ 4.8901e-02,  1.1972e+00, -1.0267e-01,  ..., -2.4577e-01,
            3.2252e-01,  9.5713e-02],
          ...,
          [ 1.4289e+00, -4.0081e-01,  8.8847e-01,  ..., -1.2688e-01,
           -2.1349e-01, -1.5179e+00],
          [-1.8024e-01, -5.9997e-01,  1.6811e+00,  ...,  8.8114e-01,
           -1.2796e+00,  8.0612e-01],
          [ 3.5363e-01,  1.5338e-01,  1.0489e-01,  ...,  7.1419e-01,
           -2.5939e-01,  1.1640e-01]],

         [[-1.3536e-02,  2.5633e-02, -3.8610e-02,  ...,  4.7447e-02,
            4.5465e-04,  7.3786e-02],
          [ 3.7973e-01, -2.6919e-01, -4.5875e-01,  ..., -1.4160e-01,
            3.0695e-01, -4.8341e-01],
          [ 1.1969e+00,  1.2378e+00, -6.2153e-01,  ..., -9.3299e-01,
            5.5717e-02, -2.5939e-02],
          ...,
          [ 1.0509e+00, -6.8117e-01, -5.0678e-01,  ..., -5.8349e-01,
            1.6390e-01, -4.4167e-01],
          [-5.3312e-01,  6.3160e-01,  2.2554e-01,  ..., -1.1507e+00,
            6.4968e-01,  3.7368e-01],
          [ 2.3626e-01, -1.7837e-01,  2.7653e-01,  ..., -8.8951e-02,
           -3.4488e-02, -6.5983e-01]],

         ...,

         [[-3.0262e-02, -1.2759e-02,  8.2024e-02,  ...,  4.1477e-02,
           -3.4039e-02,  1.6534e-02],
          [ 7.0146e-02,  3.9249e-01,  3.6694e-02,  ...,  1.1981e-01,
           -3.4416e-01, -1.2740e-01],
          [-1.4357e+00, -8.1313e-01,  3.6240e-01,  ...,  6.4624e-01,
           -7.6324e-01,  1.4873e+00],
          ...,
          [ 1.0556e-01, -3.8366e-01,  1.2748e+00,  ..., -3.6558e-01,
            4.0858e-01,  2.4199e-01],
          [-2.5444e-01,  1.1958e+00, -1.7147e-01,  ...,  6.1984e-01,
           -2.2845e-01, -1.8110e+00],
          [ 3.2427e-01,  8.9915e-01,  1.1141e+00,  ...,  6.8071e-01,
           -3.4533e-01, -1.7910e-01]],

         [[-1.8907e-01, -6.5480e-02,  7.6243e-02,  ..., -5.9887e-02,
            5.6530e-02, -7.3080e-02],
          [-8.2506e-01, -3.6656e-02,  4.9222e-01,  ...,  2.5220e-01,
            3.1897e-01,  1.9113e-01],
          [-4.6517e-01, -2.1911e-01, -6.4030e-01,  ...,  7.2280e-01,
            7.5668e-01,  5.6131e-01],
          ...,
          [-1.0660e+00, -4.2479e-01, -5.0573e-01,  ..., -5.8658e-02,
           -6.6094e-02, -4.4752e-01],
          [-8.6907e-02,  1.2486e-04, -5.2314e-01,  ...,  1.1544e-01,
            4.3831e-01, -1.0179e-02],
          [-1.0669e+00, -7.1475e-01,  8.0158e-01,  ..., -1.1919e-01,
           -2.0185e-01,  3.2946e-01]],

         [[ 1.2763e-01, -1.2701e-01,  1.6529e-01,  ..., -1.4527e-01,
           -8.5370e-03, -1.7278e-01],
          [-6.5069e-02,  3.5000e-01,  5.6586e-01,  ..., -3.5917e-01,
           -4.1324e-01,  2.9987e-01],
          [ 2.5123e-01,  5.5106e-01,  4.2795e-01,  ..., -1.0718e+00,
           -6.8236e-01, -4.2256e-01],
          ...,
          [-6.0648e-01, -5.4619e-01,  1.4942e-02,  ..., -7.6836e-01,
           -5.9767e-01, -1.3891e-02],
          [-3.4398e-01, -8.0992e-01,  7.4776e-01,  ..., -1.8947e+00,
           -2.7473e-01,  4.0089e-01],
          [ 8.6354e-02, -1.2515e-02, -2.7977e-01,  ..., -4.1148e-01,
           -5.5178e-01,  7.0079e-02]]]], grad_fn=<PermuteBackward0>))), hidden_states=None, attentions=None, cross_attentions=None)
output[0]
tensor([[[ -36.3292,  -36.3402,  -40.4228,  ...,  -46.0234,  -44.5284,
           -37.1276],
         [-114.9346, -116.5035, -117.9236,  ..., -117.8857, -119.3379,
          -112.9298],
         [-123.5036, -123.0548, -127.3876,  ..., -130.5238, -130.5279,
          -123.2711],
         ...,
         [-101.3852, -101.2506, -103.6583,  ..., -103.3747, -107.7192,
           -99.4521],
         [ -83.0701,  -84.3884,  -91.9513,  ...,  -91.7482,  -93.3971,
           -85.1204],
         [ -91.2749,  -93.1332,  -93.6408,  ...,  -94.3482,  -93.4517,
           -90.1472]]], grad_fn=<UnsafeViewBackward0>)
output[0].shape
torch.Size([1, 9, 50257])
torch.topk(output[0][0],5)
torch.return_types.topk(
values=tensor([[ -32.8755,  -33.1021,  -33.9975,  -34.4861,  -34.5463],
        [-105.5972, -106.3818, -106.3978, -106.9693, -107.0778],
        [-113.2521, -114.7346, -114.8781, -114.9605, -115.0834],
        [-118.2435, -119.2980, -119.5907, -119.6229, -119.7969],
        [ -83.6241,  -84.6822,  -84.8526,  -85.4978,  -86.6938],
        [ -79.9051,  -80.3284,  -81.6157,  -81.8538,  -82.9018],
        [ -90.4443,  -90.7053,  -91.9059,  -92.0003,  -92.1531],
        [ -75.2650,  -76.9698,  -77.5753,  -77.6700,  -77.8095],
        [ -78.7985,  -81.5545,  -81.6846,  -81.8984,  -82.5938]],
       grad_fn=<TopkBackward0>),
indices=tensor([[   11,    13,   198,   290,   286],
        [  262,   356,   314,   340,   257],
        [  262,   257,  1737,  2901,  2805],
        [  835,   717,   938, 10955,  1218],
        [  284,   736,  1363,   503,   422],
        [  670,   262,   616,   257,  1524],
        [ 9003,  2607, 11550,  4436,  4495],
        [   11,   314,   338,   284,   287],
        [  314,   616,   257,   262,   612]]))
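The `values` returned by `torch.topk` are raw logits, not probabilities. `output[0]` has shape `[batch, sequence length, vocabulary size]` (here `[1, 9, 50257]`). To get a probability distribution over the next token, apply softmax along the last dimension (in PyTorch, `torch.softmax(output[0][0], dim=-1)`). The dependency-free sketch below does the same computation on a made-up five-token vocabulary. The logit values are illustrative assumptions, not model outputs.

```python
import math

def softmax(logits):
    """Turn a list of raw scores (logits) into probabilities that sum to 1."""
    # subtracting the maximum avoids overflow and does not change the result
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(values, k):
    """Return (value, index) pairs of the k largest entries, like torch.topk."""
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=True)
    return [(values[i], i) for i in order[:k]]

# hypothetical logits for a 5-token vocabulary (illustration only)
logits = [-32.9, -33.1, -34.0, -34.5, -40.0]
probs = softmax(logits)
print(top_k(probs, 3))
```

Softmax is monotonic and depends only on differences between logits. So `topk` over logits and over probabilities returns the same indices, and the large negative magnitudes in the output above are not a problem in themselves.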
encoding.input_ids[0]
tensor([8888,   11,  319,  616,  835,  284,  262, 6403,   11])
# prefix ending at token i, followed by the model's top-1 prediction for token i+1
for i in range(1, len(encoding.input_ids[0])):
    print(tokenizer.decode(encoding.input_ids[0][:i+1]), '\t→', tokenizer.decode(torch.topk(output[0][0], 1).indices[i]))
Today, 	→  the
Today, on 	→  the
Today, on my 	→  way
Today, on my way 	→  to
Today, on my way to 	→  work
Today, on my way to the 	→  airport
Today, on my way to the university 	→ ,
Today, on my way to the university, 	→  I
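In the loop above, the prediction printed after a prefix ending at token `i` comes from `output[0][0][i]`. In a causal (autoregressive) model, the logits at position `i` score the token at position `i + 1`. Training and evaluation use the same shift. Below is a minimal sketch of the average next-token negative log-likelihood. It assumes a hypothetical 4-token vocabulary and hand-written logits.

```python
import math

def log_softmax(logits):
    """Log-probabilities from raw scores, computed stably via log-sum-exp."""
    m = max(logits)
    log_total = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_total for x in logits]

def next_token_nll(logits_per_position, token_ids):
    """Average negative log-likelihood of token_ids under a causal LM.

    logits_per_position[i] scores the token that FOLLOWS token_ids[i], so it
    is compared with token_ids[i + 1]; the last position has no target.
    """
    losses = []
    for i in range(len(token_ids) - 1):
        log_probs = log_softmax(logits_per_position[i])
        losses.append(-log_probs[token_ids[i + 1]])
    return sum(losses) / len(losses)

# hypothetical 3-token sequence over a 4-token vocabulary (illustration only)
token_ids = [0, 2, 1]
logits_per_position = [
    [0.0, 0.0, 2.0, 0.0],  # after token 0: favours token 2 (the actual next token)
    [0.0, 3.0, 0.0, 0.0],  # after token 2: favours token 1 (the actual next token)
    [1.0, 0.0, 0.0, 0.0],  # prediction for a 4th token; has no target here
]

# top-1 prediction at each position, as printed by the loop above
greedy = [max(range(len(row)), key=row.__getitem__) for row in logits_per_position]
nll = next_token_nll(logits_per_position, token_ids)
print(greedy)          # [2, 1, 0]
print(math.exp(nll))   # perplexity over the two predicted tokens
```

Exponentiating the average gives the perplexity. In `transformers`, passing `labels=encoding.input_ids` to the model should perform this shift internally and return the mean cross-entropy as `output.loss`.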

Text generation

encoding
{'input_ids': tensor([[8888,   11,  319,  616,  835,  284,  262, 6403,   11]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])}
text = TEXT
text
'Today, on my way to the university,'
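encoding = tokenizer(text, return_tensors='pt')
# greedy decoding: 10 steps, each appending the single most probable next token
for i in range(10):
    with torch.no_grad():  # inference only, no gradients needed
        output = pt_model(**encoding)
    # logits at the last position score the next token; take the top-1 id
    text += tokenizer.decode(torch.topk(output[0][0][-1], 1).indices)
    # re-tokenize the extended text for the next step
    encoding = tokenizer(text, return_tensors='pt')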
text
'Today, on my way to the university, I was approached by a man who was a student'
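The loop above decodes each chosen token to text and re-tokenizes the whole string at every step. Decoding and re-encoding are not guaranteed to reproduce the same token ids, so a common alternative is to keep the ids and append the chosen id directly. Below is a minimal sketch of greedy search in that style. `next_token_logits` and the bigram table are hypothetical stand-ins for the GPT-2 model.

```python
def greedy_decode(next_token_logits, prompt_ids, max_new_tokens, eos_id=None):
    """Greedy search: at each step append the single highest-scoring token.

    next_token_logits(ids) must return one score per vocabulary entry for the
    token that follows ids (for a causal LM: the logits at the last position).
    """
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        scores = next_token_logits(ids)
        best = max(range(len(scores)), key=scores.__getitem__)
        ids.append(best)
        if best == eos_id:  # stop early once the end-of-sequence id is produced
            break
    return ids

# hypothetical bigram "model" over a 4-token vocabulary (illustration only)
BIGRAM = {
    0: [0.1, 2.0, 0.3, 0.0],
    1: [0.0, 0.2, 1.5, 0.1],
    2: [0.5, 0.1, 0.0, 1.2],
    3: [2.0, 0.0, 0.1, 0.3],
}

def toy_logits(ids):
    # scores depend only on the last token
    return BIGRAM[ids[-1]]

print(greedy_decode(toy_logits, [0], max_new_tokens=5))  # [0, 1, 2, 3, 0, 1]
```

With the notebook's model, `next_token_logits` could plausibly be `lambda ids: pt_model(torch.tensor([ids])).logits[0, -1].tolist()`; this is an untested adaptation. GPT-2 can also reuse cached keys and values (`past_key_values`, which appear to be the nested tensors printed earlier), so a real implementation need not recompute the whole prefix at each step.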

What can be done to improve the result? Decoding strategies:

  • greedy search
  • random sampling
  • random sampling with temperature
  • top-k sampling or top-k sampling with temperature
  • top-p sampling (also called nucleus sampling) or top-p sampling with temperature
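The strategies above differ only in how the next token is picked from the model's distribution:

- Greedy search takes the argmax.
- Random sampling draws from the softmax distribution.
- Temperature divides the logits before the softmax.
- Top-k keeps only the k most probable tokens.
- Top-p (nucleus) keeps the smallest set of most probable tokens whose total probability reaches p.

Below is a dependency-free sketch of these filters applied to one logit vector. The function names and example logits are illustrative. This is not the `transformers` implementation.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def renormalize(probs, keep):
    """Zero out every token not in keep and rescale the rest to sum to 1."""
    keep = set(keep)
    total = sum(probs[i] for i in keep)
    return [p / total if i in keep else 0.0 for i, p in enumerate(probs)]

def filtered_distribution(logits, temperature=1.0, top_k=None, top_p=None):
    """Next-token distribution after temperature scaling, top-k and top-p."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # temperature < 1 sharpens the distribution, temperature > 1 flattens it
    probs = softmax([x / temperature for x in logits])
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k is not None:
        # top-k: keep only the k most probable tokens
        probs = renormalize(probs, order[:top_k])
    if top_p is not None:
        # top-p (nucleus): shortest prefix of the ranking whose mass reaches top_p
        nucleus, cumulative = [], 0.0
        for i in order:
            nucleus.append(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        probs = renormalize(probs, nucleus)
    return probs

def sample(probs, rng):
    """Random sampling: draw one token index according to probs."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

logits = [2.0, 1.0, 0.5, -1.0]  # made-up scores for a 4-token vocabulary
rng = random.Random(0)
greedy_token = max(range(len(logits)), key=logits.__getitem__)  # greedy search
print(greedy_token)  # 0
print(filtered_distribution(logits, top_k=2))    # only tokens 0 and 1 remain
print(filtered_distribution(logits, top_p=0.8))  # tokens 0 and 1 hold >= 0.8 of the mass
print(sample(filtered_distribution(logits, temperature=1.5, top_k=3), rng))
```

A temperature close to 0 approaches greedy search. `temperature=1` without `top_k`/`top_p` is plain random sampling. In `transformers`, the corresponding `generate` (and `text-generation` pipeline) arguments are `do_sample`, `temperature`, `top_k` and `top_p`. The last three only affect the result when sampling is enabled.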

pipeline

generator = pipeline('text-generation', model=model_name)
TEXT
'Today, on my way to the university,'
generator(TEXT, max_length=20, num_return_sequences=5)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
[{'generated_text': 'Today, on my way to the university, some of them would have been very pleased, and I'},
 {'generated_text': 'Today, on my way to the university, and he made me dinner, and he called me back'},
 {'generated_text': 'Today, on my way to the university, I saw three white girls who seemed a bit different—'},
 {'generated_text': 'Today, on my way to the university, I drove through the town, past trees and bushes,'},
 {'generated_text': 'Today, on my way to the university, I saw an elderly lady come up behind me."\n'}]
generator(TEXT, max_length=20, num_beams=1, do_sample=False)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
[{'generated_text': 'Today, on my way to the university, I was approached by a man who was a student at'}]
generator(TEXT, max_length=20, num_beams=10, top_p=0.2)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
[{'generated_text': 'Today, on my way to the university, I was approached by a man who was very nice and'}]
generator(TEXT, max_length=20, num_beams=10, temperature=1.0)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
[{'generated_text': 'Today, on my way to the university, I was approached by a group of students who asked me'}]
generator(TEXT, max_length=20, num_beams=10, temperature=10.0)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
[{'generated_text': 'Today, on my way to the university, I noticed some young boys who was very active on campus'}]
generator(TEXT, max_length=20, num_beams=10, temperature=100.0)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
[{'generated_text': 'Today, on my way to the university, the trainees have noticed how a car could become an'}]
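Several calls above set `num_beams=10`, i.e. beam search. Instead of committing to one token per step, beam search keeps the `num_beams` best partial sequences, ranked by the sum of their token log-probabilities. The outputs of those calls change with `temperature`, so sampling appears to be enabled there as well, combining beams with sampling. Below is a simplified sketch of plain beam search, without end-of-sequence handling or `length_penalty`. The toy transition table is hypothetical. It is chosen so that the greedy path is not the most probable sequence.

```python
import math

def beam_search(next_token_log_probs, prompt_ids, num_beams, max_new_tokens):
    """Simplified beam search (no end-of-sequence handling, no length penalty).

    next_token_log_probs(ids) returns one log-probability per vocabulary entry
    for the token following ids. A sequence's score is the sum of the
    log-probabilities of its generated tokens.
    """
    beams = [(0.0, list(prompt_ids))]
    for _ in range(max_new_tokens):
        candidates = []
        for score, ids in beams:
            for token, log_p in enumerate(next_token_log_probs(ids)):
                candidates.append((score + log_p, ids + [token]))
        # keep only the num_beams highest-scoring extensions
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = candidates[:num_beams]
    return beams  # sorted from best to worst

# hypothetical next-token probabilities that depend only on the last token
TABLE = {
    0: [0.1, 0.5, 0.4],
    1: [0.4, 0.3, 0.3],
    2: [0.9, 0.05, 0.05],
}

def toy_log_probs(ids):
    return [math.log(p) for p in TABLE[ids[-1]]]

greedy_like = beam_search(toy_log_probs, [0], num_beams=1, max_new_tokens=2)
best_score, best_ids = beam_search(toy_log_probs, [0], num_beams=2, max_new_tokens=2)[0]
print(greedy_like[0][1], math.exp(greedy_like[0][0]))  # [0, 1, 0], probability about 0.2
print(best_ids, math.exp(best_score))                   # [0, 2, 0], probability about 0.36
```

With `num_beams=1` the sketch reduces to greedy search. That path is `[0, 1, 0]` (0.5 × 0.4 = 0.2), while two beams find `[0, 2, 0]` (0.4 × 0.9 = 0.36). The library version additionally finishes hypotheses at the end-of-sequence token and rescales their scores with `length_penalty`.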

Other generation options:

  • repetition_penalty
  • length_penalty
  • no_repeat_ngram_size
  • bad_words_ids
  • force_words_ids
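Two of the options above can be sketched directly on the logits:

- `repetition_penalty` lowers the scores of tokens that already occur in the sequence. To our understanding, `transformers` divides positive logits by the penalty and multiplies negative ones by it, following the penalized sampling proposed for CTRL, and applies this to all previous tokens, prompt included.
- `no_repeat_ngram_size=n` forbids any token that would complete an n-gram already present.

The helper names below are illustrative, not library functions.

```python
def apply_repetition_penalty(logits, previous_ids, penalty):
    """Lower the scores of tokens that already occur in previous_ids.

    Positive logits are divided by penalty and negative ones multiplied by it,
    so with penalty > 1 a repeated token's score always goes down.
    """
    out = list(logits)
    for t in set(previous_ids):
        out[t] = out[t] / penalty if out[t] > 0 else out[t] * penalty
    return out

def banned_tokens(previous_ids, n):
    """Tokens that would complete an n-gram already present in previous_ids."""
    if len(previous_ids) < n:
        return set()  # no complete n-gram yet
    # the last n-1 tokens form the prefix of the n-gram about to be completed
    prefix = tuple(previous_ids[len(previous_ids) - (n - 1):])
    banned = set()
    for i in range(len(previous_ids) - n + 1):
        if tuple(previous_ids[i:i + n - 1]) == prefix:
            banned.add(previous_ids[i + n - 1])
    return banned

def block_repeated_ngrams(logits, previous_ids, n):
    """no_repeat_ngram_size-style filter: banned tokens get a score of -inf."""
    out = list(logits)
    for t in banned_tokens(previous_ids, n):
        out[t] = float("-inf")
    return out

print(apply_repetition_penalty([2.0, -1.0, 0.5, 3.0], [0, 1, 0], penalty=2.0))
# [1.0, -2.0, 0.5, 3.0]
print(banned_tokens([5, 6, 7, 5, 6], 3))  # {7}: the trigram 5 6 7 already occurred
```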

Hugging Face API

from transformers import CTRLTokenizer, CTRLModel
tokenizer = CTRLTokenizer.from_pretrained("ctrl")

CTRL

inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
inputs
{'input_ids': tensor([[43213,   586,  3153,     8, 83781]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}
tokenizer.control_codes
{'Pregnancy': 168629,
 'Christianity': 7675,
 'Explain': 106423,
 'Fitness': 63440,
 'Saving': 63163,
 'Ask': 27171,
 'Ass': 95985,
 'Joke': 163509,
 'Questions': 45622,
 'Thoughts': 49605,
 'Retail': 52342,
 'Feminism': 164338,
 'Writing': 11992,
 'Atheism': 192263,
 'Netflix': 48616,
 'Computing': 39639,
 'Opinion': 43213,
 'Alone': 44967,
 'Funny': 58917,
 'Gaming': 40358,
 'Human': 4088,
 'India': 1331,
 'Joker': 77138,
 'Diet': 36206,
 'Legal': 11859,
 'Norman': 4939,
 'Tip': 72689,
 'Weight': 52343,
 'Movies': 46273,
 'Running': 23425,
 'Science': 2090,
 'Horror': 37793,
 'Confession': 60572,
 'Finance': 12250,
 'Politics': 16360,
 'Scary': 191985,
 'Support': 12654,
 'Technologies': 32516,
 'Teenage': 66160,
 'Event': 32769,
 'Learned': 67460,
 'Notion': 182770,
 'Wikipedia': 37583,
 'Books': 6665,
 'Extract': 76050,
 'Confessions': 102701,
 'Conspiracy': 75932,
 'Links': 63674,
 'Narcissus': 150425,
 'Relationship': 54766,
 'Relationships': 134796,
 'Reviews': 41671,
 'News': 4256,
 'Translation': 26820,
 'multilingual': 128406}
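CTRL conditions generation on a control code placed at the very start of the prompt. That is why the prompts below begin with `Opinion`, `Technologies` or `Gaming`. It is also why the first id of the tokenized `"Opinion My dog is cute"` (43213) equals `tokenizer.control_codes['Opinion']`. Below is a small hypothetical helper that builds such prompts and checks the first id. The three-entry dictionary is copied from the listing above.

```python
# subset of tokenizer.control_codes, copied from the listing above
CONTROL_CODES = {"Opinion": 43213, "Technologies": 32516, "Gaming": 40358}

def ctrl_prompt(control_code, text, control_codes=CONTROL_CODES):
    """Hypothetical helper: prepend a known CTRL control code to a prompt."""
    if control_code not in control_codes:
        raise ValueError(f"unknown control code: {control_code!r}")
    return f"{control_code} {text}"

def starts_with_control_code(input_ids, control_code, control_codes=CONTROL_CODES):
    """True if the first token id of an encoded prompt is the control code's id."""
    return len(input_ids) > 0 and input_ids[0] == control_codes[control_code]

print(ctrl_prompt("Opinion", "Today"))  # Opinion Today
# input ids of "Opinion My dog is cute", copied from the tokenizer output above
print(starts_with_control_code([43213, 586, 3153, 8, 83781], "Opinion"))  # True
```

With the real tokenizer, one would pass `tokenizer.control_codes` instead of the hand-copied subset and `tokenizer(prompt).input_ids` as `input_ids`.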
generator = pipeline('text-generation', model="ctrl")
/home/kuba/anaconda3/envs/zajeciaei/lib/python3.10/site-packages/transformers/models/ctrl/modeling_ctrl.py:43: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  angle_rates = 1 / torch.pow(10000, (2 * (i // 2)) / d_model_size)
TEXT = "Today"
generator("Opinion " + TEXT, max_length = 50)

[{'generated_text': 'Opinion Today I learned that the US government has been spying on the citizens of other countries for years. \n Score: 6 \n \n Title: CMV: I think that the US should not be involved in the Middle East \n Text: I think that the US'}]

generator("Technologies " + TEXT, max_length = 50)

[{'generated_text': 'Technologies Today \n Score: 6 \n \n Title: The Internet is a great tool for the average person to get information and to share it with others. But it is also a great tool for the government to spy on us. \n Score: 6 \n \n Title: The'}]

generator("Gaming " + TEXT, max_length = 50)

[{'generated_text': 'Gaming Today \n Score: 6 \n \n Title: I just got a new gaming pc and I have a question \n Text: I just got a new gaming pc and I have a question \n \n I have a monitor that I bought a while back'}]

Task

Using GPT-2 or DistilGPT2, generate answers for the Challenging America challenge. There is no need to fine-tune the model.