zugb-materials-extra/meetup.ipynb

from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
import torch
model = "tiiuae/falcon-40b"
tokenizer = AutoTokenizer.from_pretrained(model)
vocab = tokenizer.get_vocab()
len(vocab)
65024
# Polish: "This morning a flying saucer landed at Łęgi in Poznań. The prime minister"
tokenizer.tokenize('Dzisiaj rano w Poznaniu na Łęgach wylądował latający talerz. Premier')
['D',
 'z',
 'isia',
 'j',
 'Ġr',
 'ano',
 'Ġw',
 'ĠPoz',
 'n',
 'aniu',
 'Ġna',
 'ĠÅģ',
 'ÄĻ',
 'g',
 'ach',
 'Ġw',
 'yl',
 'Äħd',
 'owaÅĤ',
 'Ġl',
 'ata',
 'jÄħ',
 'cy',
 'Ġtal',
 'er',
 'z',
 '.',
 'ĠPremier']
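Pieces such as `'ÅĤ'` or `'Äħd'` are not broken encoding in the notebook. The tokenizer appears to be a GPT-2-style byte-level BPE, which maps every UTF-8 byte to a printable character (a space becomes `Ġ`), so each Polish letter with a diacritic shows up as two "characters". The sketch below reimplements GPT-2's byte-to-character table in plain Python, assuming Falcon's tokenizer uses the same table, and decodes the token list printed above back into text. Inside the notebook, `tokenizer.convert_tokens_to_string(tokens)` should do the same job.

```python
def bytes_to_unicode():
    """GPT-2 style table mapping each byte value (0-255) to a printable character."""
    # Printable bytes keep their own code point: '!'..'~', '¡'..'¬', '®'..'ÿ'.
    bs = list(range(0x21, 0x7F)) + list(range(0xA1, 0xAD)) + list(range(0xAE, 0x100))
    cs = bs[:]
    # The remaining bytes (control characters, space, ...) are shifted to 256+.
    n = 0
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return dict(zip(bs, map(chr, cs)))


BYTE_DECODER = {ch: b for b, ch in bytes_to_unicode().items()}


def decode_tokens(tokens):
    """Join byte-level BPE tokens and decode the underlying UTF-8 bytes."""
    # Decode the concatenation rather than token by token: in general a token
    # boundary may fall inside a multi-byte character.
    raw = bytes(BYTE_DECODER[ch] for ch in "".join(tokens))
    return raw.decode("utf-8")


tokens = ['D', 'z', 'isia', 'j', 'Ġr', 'ano', 'Ġw', 'ĠPoz', 'n', 'aniu', 'Ġna',
          'ĠÅģ', 'ÄĻ', 'g', 'ach', 'Ġw', 'yl', 'Äħd', 'owaÅĤ', 'Ġl', 'ata',
          'jÄħ', 'cy', 'Ġtal', 'er', 'z', '.', 'ĠPremier']
print(decode_tokens(tokens))     # Dzisiaj rano w Poznaniu na Łęgach wylądował latający talerz. Premier
print(decode_tokens(['owaÅĤ']))  # ował  ('ł' is UTF-8 0xC5 0x82 -> 'Å', 'Ĥ')
```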
tokenizer.tokenize('Today a flying saucer has landed in Poznan. The prime')
['Today',
 'Ġa',
 'Ġflying',
 'Ġsau',
 'cer',
 'Ġhas',
 'Ġlanded',
 'Ġin',
 'ĠPoz',
 'nan',
 '.',
 'ĠThe',
 'Ġprime']
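The English sentence needs far fewer tokens than the Polish one, presumably because the 65,024-entry vocabulary was built mostly from English text. A simple way to quantify this is "fertility", the average number of tokens per whitespace-separated word. The sketch below only counts the two token lists printed above (copied by hand); with the real tokenizer you would pass `tokenizer.tokenize(text)` instead.

```python
def fertility(text, tokens):
    """Average number of subword tokens per whitespace-separated word."""
    return len(tokens) / len(text.split())


polish = 'Dzisiaj rano w Poznaniu na Łęgach wylądował latający talerz. Premier'
polish_tokens = ['D', 'z', 'isia', 'j', 'Ġr', 'ano', 'Ġw', 'ĠPoz', 'n', 'aniu',
                 'Ġna', 'ĠÅģ', 'ÄĻ', 'g', 'ach', 'Ġw', 'yl', 'Äħd', 'owaÅĤ',
                 'Ġl', 'ata', 'jÄħ', 'cy', 'Ġtal', 'er', 'z', '.', 'ĠPremier']

english = 'Today a flying saucer has landed in Poznan. The prime'
english_tokens = ['Today', 'Ġa', 'Ġflying', 'Ġsau', 'cer', 'Ġhas', 'Ġlanded',
                  'Ġin', 'ĠPoz', 'nan', '.', 'ĠThe', 'Ġprime']

# Both sentences have 10 words; Polish needs more than twice as many tokens.
print(len(polish_tokens), round(fertility(polish, polish_tokens), 2))    # 28 2.8
print(len(english_tokens), round(fertility(english, english_tokens), 2))  # 13 1.3
```

Higher fertility means shorter effective context and more generation steps for the same text, which matters when prompting this model in Polish.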
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)
A new version of the following files was downloaded from https://huggingface.co/tiiuae/falcon-40b:
- configuration_RW.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/tiiuae/falcon-40b:
- modelling_RW.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
Downloading shards:   0%|          | 0/9 [00:00<?, ?it/s]
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
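The fork warning above comes from the Rust `tokenizers` library: the tokenizer had already used its thread pool (the `tokenize` calls earlier in the notebook) when the process was later forked, apparently during the shard download. As the message itself suggests, setting `TOKENIZERS_PARALLELISM` explicitly silences it. A minimal sketch, best placed in the first cell before the tokenizer is created:

```python
import os

# Setting the variable to either "true" or "false" is what the warning asks for;
# "false" disables the tokenizers thread pool, so later forks are safe and silent.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
```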
Loading checkpoint shards:   0%|          | 0/9 [00:00<?, ?it/s]
Xformers is not installed correctly. If you want to use memorry_efficient_attention to accelerate training use the following command to install Xformers
pip install xformers.
The model 'RWForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'CodeGenForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'LlamaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MvpForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM', 'RobertaPreLayerNormForCausalLM', 'RoCBertForCausalLM', 'RoFormerForCausalLM', 'RwkvForCausalLM', 'Speech2Text2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'XGLMForCausalLM', 'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 'XLMRobertaForCausalLM', 'XLMRobertaXLForCausalLM', 'XLNetLMHeadModel', 'XmodForCausalLM'].
pipeline.model.eval()

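# Prompt (Polish): "This morning an epidemic of a mysterious new virus broke out in Piotrków Trybunalski. The prime minister"
sequences = pipeline(
    "Dzisiaj rano w Piotrkowie Trybunalskim wybuchła epidemia tajemniczego nowego wirusa. Premier",
    max_length=200,  # total length in tokens, prompt included
    do_sample=True,
    top_k=100,
    temperature=0.7,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
)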

sequences[0]['generated_text']
/home/filipg/miniconda3/envs/torch2b/lib/python3.11/site-packages/transformers/generation/utils.py:1255: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation)
  warnings.warn(
Setting `pad_token_id` to `eos_token_id`:11 for open-end generation.
'Dzisiaj rano w Piotrkowie Trybunalskim wybuchła epidemia tajemniczego nowego wirusa. Premier w osobie Krzysztofa Rutkowskiego ogłosił stan wyjątkowy i areszt domowy. Tymczasem bary i restauracje otwierają się, a ludzie wchodzą do nich.\nBary i restauracje otwierają się, a ludzie wchodzą do nich.\nKiedy wchodzę do baru, to widzę, że po całym mieście otwierają się bary i restauracje. Ludzie wchodzą do tych miejsc, nie zdając sobie sprawy, że wirus może przenosić się w powietrzu. Mimo stanu wyjątkowego, ludzie jakby nic nie widzieli i nie słyszeli.\nJak można'
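With `do_sample=True`, each new token is drawn at random instead of being picked greedily: `temperature=0.7` sharpens the distribution (logits are divided by 0.7 before the softmax) and `top_k=100` restricts the draw to the 100 highest-scoring tokens. Below is a simplified pure-Python sketch of one such step; it is not transformers' implementation, which applies these as logits-processor classes on batched tensors.

```python
import math
import random


def top_k_distribution(logits, k, temperature):
    """Return {token_id: probability} after temperature scaling and top-k filtering."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    # Keep only the k highest-scoring token ids.
    top = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:k]
    # Softmax over the survivors (subtract the max for numerical stability).
    m = max(scaled[i] for i in top)
    weights = [math.exp(scaled[i] - m) for i in top]
    total = sum(weights)
    return {i: w / total for i, w in zip(top, weights)}


def sample_next_token(logits, k=100, temperature=0.7, rng=random):
    """Draw one token id from the top-k, temperature-scaled distribution."""
    dist = top_k_distribution(logits, k, temperature)
    ids = list(dist)
    return rng.choices(ids, weights=[dist[i] for i in ids], k=1)[0]


# Toy 5-token vocabulary: with k=2 only ids 1 and 3 can ever be drawn,
# with probabilities of about 0.67 and 0.33 at temperature 0.7.
toy_logits = [0.1, 3.0, -1.0, 2.5, 0.0]
print(top_k_distribution(toy_logits, k=2, temperature=0.7))
rng = random.Random(0)
print([sample_next_token(toy_logits, k=2, rng=rng) for _ in range(10)])
```

Lowering the temperature pushes probability mass toward the top token; `k=1` degenerates to greedy decoding.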
pipeline.model
RWForCausalLM(
  (transformer): RWModel(
    (word_embeddings): Embedding(65024, 8192)
    (h): ModuleList(
      (0-59): 60 x DecoderLayer(
        (ln_attn): LayerNorm((8192,), eps=1e-05, elementwise_affine=True)
        (ln_mlp): LayerNorm((8192,), eps=1e-05, elementwise_affine=True)
        (self_attention): Attention(
          (maybe_rotary): RotaryEmbedding()
          (query_key_value): Linear(in_features=8192, out_features=9216, bias=False)
          (dense): Linear(in_features=8192, out_features=8192, bias=False)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (mlp): MLP(
          (dense_h_to_4h): Linear(in_features=8192, out_features=32768, bias=False)
          (act): GELU(approximate='none')
          (dense_4h_to_h): Linear(in_features=32768, out_features=8192, bias=False)
        )
      )
    )
    (ln_f): LayerNorm((8192,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=8192, out_features=65024, bias=False)
)
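The printed shapes are enough to estimate the model size. Everything below is read off the module dump except one assumption: if heads are 64-dimensional (128 query heads of 64 = 8192, not shown in the dump), the extra 1024 outputs of `query_key_value` would be one key and one value projection of 512 each, i.e. 8 shared key/value heads, which would also explain the many identical blocks in the `past_key_values` printed further down. The sketch counts `lm_head` separately; if it is tied to `word_embeddings`, the unique count drops by one 65024 × 8192 matrix.

```python
# Sizes read off the RWForCausalLM module dump above.
VOCAB = 65024      # word_embeddings rows / lm_head out_features
HIDDEN = 8192      # hidden size
FFN = 32768        # dense_h_to_4h out_features
LAYERS = 60        # (0-59): 60 x DecoderLayer
QKV_OUT = 9216     # query_key_value out_features

# Assumption (not in the dump): 64-dim heads, so 8192 / 64 = 128 query heads.
HEAD_DIM = 64
kv_width = (QKV_OUT - HIDDEN) // 2   # columns left for K and for V: 512 each
kv_heads = kv_width // HEAD_DIM      # shared key/value heads


def linear(n_in, n_out):
    """Weights of an nn.Linear with bias=False, as all Linear layers above are."""
    return n_in * n_out


def layer_norm(dim):
    """LayerNorm with elementwise_affine=True has a weight and a bias vector."""
    return 2 * dim


per_layer = (
    2 * layer_norm(HIDDEN)        # ln_attn, ln_mlp
    + linear(HIDDEN, QKV_OUT)     # self_attention.query_key_value
    + linear(HIDDEN, HIDDEN)      # self_attention.dense
    + linear(HIDDEN, FFN)         # mlp.dense_h_to_4h
    + linear(FFN, HIDDEN)         # mlp.dense_4h_to_h
)
embeddings = VOCAB * HIDDEN       # transformer.word_embeddings
lm_head = linear(HIDDEN, VOCAB)   # lm_head
total = embeddings + LAYERS * per_layer + layer_norm(HIDDEN) + lm_head  # + ln_f

print(f"K/V heads: {kv_heads}")                     # K/V heads: 8
print(f"per layer: {per_layer:,}")                  # per layer: 679,510,016
print(f"total:     {total:,}")                      # total:     41,835,970,560
print(f"if lm_head is tied: {total - lm_head:,}")   # if lm_head is tied: 41,303,293,952
print(f"bf16 weights: {total * 2 / 2**30:.1f} GiB") # bf16 weights: 77.9 GiB
```

In the notebook, `pipeline.model.num_parameters()` should report the actual figure, with any tied weights counted once. Either way, that is roughly 78 GiB of bfloat16 weights before activations or the KV cache.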
# Polish: "This morning a flying saucer landed in Madrid. The prime minister"
tokens = tokenizer.tokenize("Dzisiaj rano w Madrycie wylądował latający talerz. Premier")
token_ids = tokenizer.convert_tokens_to_ids(tokens)
token_ids
[47,
 101,
 33623,
 85,
 392,
 2808,
 251,
 5509,
 547,
 9228,
 251,
 1985,
 27255,
 38281,
 282,
 785,
 14585,
 2586,
 3438,
 246,
 101,
 25,
 15222]
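Calling the model directly, as the next cell does, returns raw logits of shape `[batch, sequence_length, vocab]`, here `[1, 23, 65024]` (23 token ids, and 65024 = `lm_head.out_features`). Row *i* scores every vocabulary entry as the token following position *i*, so the last row is the model's guess for what comes after `'ĠPremier'`. A minimal plain-Python sketch of turning one such row into the most likely next token ids; in the notebook you would pass the last row converted with `.float().tolist()` and map ids back with `tokenizer.convert_ids_to_tokens`.

```python
import math


def top_next_tokens(last_logits, n=5):
    """Softmax over one row of logits; return the n most probable (token_id, prob) pairs."""
    m = max(last_logits)                       # subtract max for numerical stability
    exps = [math.exp(x - m) for x in last_logits]
    z = sum(exps)
    ranked = sorted(range(len(exps)), key=lambda i: exps[i], reverse=True)
    return [(i, exps[i] / z) for i in ranked[:n]]


# Toy 4-entry row standing in for a real 65024-entry row of logits.
row = [2.0, 0.5, -1.0, 3.0]
for token_id, p in top_next_tokens(row, n=2):
    print(token_id, round(p, 4))
```

Note that in the printed result below, the logits appear under the `loss` field of the returned object rather than under `logits`, so check which attribute is actually populated before indexing.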
output = pipeline.model(torch.tensor([token_ids]))
output
CausalLMOutputWithCrossAttentions(loss={'logits': tensor([[[ -8.7500, -10.6250, -11.9375,  ..., -10.8125, -12.0000,  -9.5000],
         [ -9.1250, -10.5000, -12.1250,  ...,  -8.1250,  -9.4375,  -7.1562],
         [-16.1250, -22.2500, -24.0000,  ..., -19.1250, -19.5000, -18.6250],
         ...,
         [-14.1250, -16.0000, -20.2500,  ..., -12.3750, -17.7500, -11.3125],
         [-13.1875, -14.3125, -18.8750,  ..., -15.3125, -17.6250, -14.6250],
         [-13.5000, -14.4375, -18.0000,  ..., -12.5000, -16.5000,  -8.8750]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>), 'past_key_values': ((tensor([[[ 0.5547, -1.4062, -0.3340,  ...,  0.1514, -0.9297, -1.0312],
         [-0.4785, -1.4062, -0.4316,  ...,  0.0583, -1.4062,  0.1992],
         [-2.2500, -2.3281,  0.4707,  ..., -1.1562, -0.1973,  0.9883],
         ...,
         [-0.5508,  1.2656,  0.2754,  ...,  0.0527, -1.4062,  0.1992],
         [-0.6953,  1.1406, -0.3242,  ..., -0.9414, -0.6211, -1.1953],
         [-1.4141,  1.2656,  0.2314,  ..., -0.1182, -0.2041, -0.2812]],

        [[ 0.5547, -1.4062, -0.3340,  ...,  0.1514, -0.9297, -1.0312],
         [-0.4785, -1.4062, -0.4316,  ...,  0.0583, -1.4062,  0.1992],
         [-2.2500, -2.3281,  0.4707,  ..., -1.1562, -0.1973,  0.9883],
         ...,
         [-0.5508,  1.2656,  0.2754,  ...,  0.0527, -1.4062,  0.1992],
         [-0.6953,  1.1406, -0.3242,  ..., -0.9414, -0.6211, -1.1953],
         [-1.4141,  1.2656,  0.2314,  ..., -0.1182, -0.2041, -0.2812]],

        [[ 0.5547, -1.4062, -0.3340,  ...,  0.1514, -0.9297, -1.0312],
         [-0.4785, -1.4062, -0.4316,  ...,  0.0583, -1.4062,  0.1992],
         [-2.2500, -2.3281,  0.4707,  ..., -1.1562, -0.1973,  0.9883],
         ...,
         [-0.5508,  1.2656,  0.2754,  ...,  0.0527, -1.4062,  0.1992],
         [-0.6953,  1.1406, -0.3242,  ..., -0.9414, -0.6211, -1.1953],
         [-1.4141,  1.2656,  0.2314,  ..., -0.1182, -0.2041, -0.2812]],

        ...,

        [[ 0.4551, -0.1377, -0.2383,  ...,  0.9648, -1.7422,  0.1562],
         [ 0.2422, -1.1875, -0.4629,  ..., -0.1816, -0.8867, -0.7070],
         [-1.3672, -1.3750, -1.0781,  ...,  0.2119,  1.6172,  0.6758],
         ...,
         [ 0.0078, -0.5391,  0.4355,  ..., -0.1787, -0.8945, -0.7070],
         [-0.1738,  0.0957,  0.4004,  ...,  0.7578, -1.8203,  1.6328],
         [-0.4863,  1.2500, -2.1094,  ...,  0.9961,  1.2109,  1.5938]],

        [[ 0.4551, -0.1377, -0.2383,  ...,  0.9648, -1.7422,  0.1562],
         [ 0.2422, -1.1875, -0.4629,  ..., -0.1816, -0.8867, -0.7070],
         [-1.3672, -1.3750, -1.0781,  ...,  0.2119,  1.6172,  0.6758],
         ...,
         [ 0.0078, -0.5391,  0.4355,  ..., -0.1787, -0.8945, -0.7070],
         [-0.1738,  0.0957,  0.4004,  ...,  0.7578, -1.8203,  1.6328],
         [-0.4863,  1.2500, -2.1094,  ...,  0.9961,  1.2109,  1.5938]],

        [[ 0.4551, -0.1377, -0.2383,  ...,  0.9648, -1.7422,  0.1562],
         [ 0.2422, -1.1875, -0.4629,  ..., -0.1816, -0.8867, -0.7070],
         [-1.3672, -1.3750, -1.0781,  ...,  0.2119,  1.6172,  0.6758],
         ...,
         [ 0.0078, -0.5391,  0.4355,  ..., -0.1787, -0.8945, -0.7070],
         [-0.1738,  0.0957,  0.4004,  ...,  0.7578, -1.8203,  1.6328],
         [-0.4863,  1.2500, -2.1094,  ...,  0.9961,  1.2109,  1.5938]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>), tensor([[[-0.0036,  0.0417,  0.0364,  ..., -0.0087, -0.0391, -0.0474],
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [-0.0591, -0.0118, -0.1279,  ..., -0.1592,  0.1338,  0.0255],
         ...,
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [ 0.0087,  0.0205,  0.0557,  ...,  0.0085,  0.0039, -0.0454],
         [-0.0659,  0.0427,  0.1006,  ...,  0.1055, -0.0527, -0.0339]],

        [[-0.0036,  0.0417,  0.0364,  ..., -0.0087, -0.0391, -0.0474],
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [-0.0591, -0.0118, -0.1279,  ..., -0.1592,  0.1338,  0.0255],
         ...,
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [ 0.0087,  0.0205,  0.0557,  ...,  0.0085,  0.0039, -0.0454],
         [-0.0659,  0.0427,  0.1006,  ...,  0.1055, -0.0527, -0.0339]],

        [[-0.0036,  0.0417,  0.0364,  ..., -0.0087, -0.0391, -0.0474],
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [-0.0591, -0.0118, -0.1279,  ..., -0.1592,  0.1338,  0.0255],
         ...,
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [ 0.0087,  0.0205,  0.0557,  ...,  0.0085,  0.0039, -0.0454],
         [-0.0659,  0.0427,  0.1006,  ...,  0.1055, -0.0527, -0.0339]],

        ...,

        [[-0.0593, -0.0640, -0.0276,  ..., -0.0116, -0.0459, -0.0016],
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0708, -0.0918,  0.2285,  ..., -0.0635,  0.1396,  0.0603],
         ...,
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0923,  0.0505,  0.0068,  ...,  0.0239, -0.0119,  0.0031],
         [-0.2773, -0.3125, -0.3086,  ...,  0.0464,  0.1826, -0.2871]],

        [[-0.0593, -0.0640, -0.0276,  ..., -0.0116, -0.0459, -0.0016],
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0708, -0.0918,  0.2285,  ..., -0.0635,  0.1396,  0.0603],
         ...,
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0923,  0.0505,  0.0068,  ...,  0.0239, -0.0119,  0.0031],
         [-0.2773, -0.3125, -0.3086,  ...,  0.0464,  0.1826, -0.2871]],

        [[-0.0593, -0.0640, -0.0276,  ..., -0.0116, -0.0459, -0.0016],
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0708, -0.0918,  0.2285,  ..., -0.0635,  0.1396,  0.0603],
         ...,
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0923,  0.0505,  0.0068,  ...,  0.0239, -0.0119,  0.0031],
         [-0.2773, -0.3125, -0.3086,  ...,  0.0464,  0.1826, -0.2871]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[ 2.4531,  1.6719, -4.0938,  ..., -1.4219,  0.8477,  1.2422],
         [ 7.1250,  2.4375, -2.5938,  ..., -1.1406,  0.6445,  1.5938],
         [ 4.6875,  4.4375, -2.6875,  ..., -0.1709,  0.5898,  3.0156],
         ...,
         [ 7.1875,  3.5156,  0.3594,  ..., -2.2969,  0.7539,  3.3750],
         [ 2.5312,  0.1777, -1.4219,  ...,  0.1572,  2.0312, -0.9766],
         [-3.5000, -2.9531, -4.5312,  ..., -1.9609,  1.1016,  1.3672]],

        [[ 2.4531,  1.6719, -4.0938,  ..., -1.4219,  0.8477,  1.2422],
         [ 7.1250,  2.4375, -2.5938,  ..., -1.1406,  0.6445,  1.5938],
         [ 4.6875,  4.4375, -2.6875,  ..., -0.1709,  0.5898,  3.0156],
         ...,
         [ 7.1875,  3.5156,  0.3594,  ..., -2.2969,  0.7539,  3.3750],
         [ 2.5312,  0.1777, -1.4219,  ...,  0.1572,  2.0312, -0.9766],
         [-3.5000, -2.9531, -4.5312,  ..., -1.9609,  1.1016,  1.3672]],

        [[ 2.4531,  1.6719, -4.0938,  ..., -1.4219,  0.8477,  1.2422],
         [ 7.1250,  2.4375, -2.5938,  ..., -1.1406,  0.6445,  1.5938],
         [ 4.6875,  4.4375, -2.6875,  ..., -0.1709,  0.5898,  3.0156],
         ...,
         [ 7.1875,  3.5156,  0.3594,  ..., -2.2969,  0.7539,  3.3750],
         [ 2.5312,  0.1777, -1.4219,  ...,  0.1572,  2.0312, -0.9766],
         [-3.5000, -2.9531, -4.5312,  ..., -1.9609,  1.1016,  1.3672]],

        ...,

        [[ 0.1846,  0.7070, -2.5938,  ..., -1.6562,  3.1719, -1.0938],
         [ 3.0312,  1.2500, -1.8672,  ..., -2.9219,  2.3594, -0.6680],
         [ 5.0625,  0.7070, -0.4258,  ..., -4.0938,  3.5781, -3.2344],
         ...,
         [ 3.1094, -0.4531, -1.0938,  ..., -2.7969,  4.0625, -3.1562],
         [ 0.1914, -0.6250, -0.1875,  ..., -4.5000,  2.4844, -2.0000],
         [-2.2344,  3.1250, -4.0625,  ..., -2.6406,  3.4844, -0.3887]],

        [[ 0.1846,  0.7070, -2.5938,  ..., -1.6562,  3.1719, -1.0938],
         [ 3.0312,  1.2500, -1.8672,  ..., -2.9219,  2.3594, -0.6680],
         [ 5.0625,  0.7070, -0.4258,  ..., -4.0938,  3.5781, -3.2344],
         ...,
         [ 3.1094, -0.4531, -1.0938,  ..., -2.7969,  4.0625, -3.1562],
         [ 0.1914, -0.6250, -0.1875,  ..., -4.5000,  2.4844, -2.0000],
         [-2.2344,  3.1250, -4.0625,  ..., -2.6406,  3.4844, -0.3887]],

        [[ 0.1846,  0.7070, -2.5938,  ..., -1.6562,  3.1719, -1.0938],
         [ 3.0312,  1.2500, -1.8672,  ..., -2.9219,  2.3594, -0.6680],
         [ 5.0625,  0.7070, -0.4258,  ..., -4.0938,  3.5781, -3.2344],
         ...,
         [ 3.1094, -0.4531, -1.0938,  ..., -2.7969,  4.0625, -3.1562],
         [ 0.1914, -0.6250, -0.1875,  ..., -4.5000,  2.4844, -2.0000],
         [-2.2344,  3.1250, -4.0625,  ..., -2.6406,  3.4844, -0.3887]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>), tensor([[[ 0.0366,  0.0337,  0.0234,  ...,  0.1299,  0.2031, -0.1338],
         [-0.0269, -0.0304,  0.0752,  ..., -0.1118, -0.0603,  0.0413],
         [-0.0747,  0.0601,  0.0187,  ..., -0.0109, -0.2598, -0.1670],
         ...,
         [ 0.2061, -0.0820,  0.0376,  ..., -0.0986,  0.1738,  0.1660],
         [-0.0284, -0.0187,  0.0200,  ...,  0.0508, -0.0062, -0.0474],
         [-0.1807,  0.1826,  0.0069,  ...,  0.1045, -0.3145, -0.1138]],

        [[ 0.0366,  0.0337,  0.0234,  ...,  0.1299,  0.2031, -0.1338],
         [-0.0269, -0.0304,  0.0752,  ..., -0.1118, -0.0603,  0.0413],
         [-0.0747,  0.0601,  0.0187,  ..., -0.0109, -0.2598, -0.1670],
         ...,
         [ 0.2061, -0.0820,  0.0376,  ..., -0.0986,  0.1738,  0.1660],
         [-0.0284, -0.0187,  0.0200,  ...,  0.0508, -0.0062, -0.0474],
         [-0.1807,  0.1826,  0.0069,  ...,  0.1045, -0.3145, -0.1138]],

        [[ 0.0366,  0.0337,  0.0234,  ...,  0.1299,  0.2031, -0.1338],
         [-0.0269, -0.0304,  0.0752,  ..., -0.1118, -0.0603,  0.0413],
         [-0.0747,  0.0601,  0.0187,  ..., -0.0109, -0.2598, -0.1670],
         ...,
         [ 0.2061, -0.0820,  0.0376,  ..., -0.0986,  0.1738,  0.1660],
         [-0.0284, -0.0187,  0.0200,  ...,  0.0508, -0.0062, -0.0474],
         [-0.1807,  0.1826,  0.0069,  ...,  0.1045, -0.3145, -0.1138]],

        ...,

        [[ 0.0688,  0.0198,  0.0096,  ..., -0.0469, -0.0825, -0.0283],
         [-0.0437,  0.0231, -0.0981,  ...,  0.0354, -0.0835, -0.0356],
         [-0.1641,  0.0330, -0.0334,  ...,  0.0674, -0.1543,  0.1328],
         ...,
         [ 0.0854, -0.0124, -0.1245,  ...,  0.0864, -0.0591, -0.0588],
         [-0.0104, -0.0232,  0.0012,  ...,  0.0289,  0.0244,  0.0532],
         [ 0.0466,  0.1074,  0.2637,  ..., -0.0938,  0.0044,  0.0801]],

        [[ 0.0688,  0.0198,  0.0096,  ..., -0.0469, -0.0825, -0.0283],
         [-0.0437,  0.0231, -0.0981,  ...,  0.0354, -0.0835, -0.0356],
         [-0.1641,  0.0330, -0.0334,  ...,  0.0674, -0.1543,  0.1328],
         ...,
         [ 0.0854, -0.0124, -0.1245,  ...,  0.0864, -0.0591, -0.0588],
         [-0.0104, -0.0232,  0.0012,  ...,  0.0289,  0.0244,  0.0532],
         [ 0.0466,  0.1074,  0.2637,  ..., -0.0938,  0.0044,  0.0801]],

        [[ 0.0688,  0.0198,  0.0096,  ..., -0.0469, -0.0825, -0.0283],
         [-0.0437,  0.0231, -0.0981,  ...,  0.0354, -0.0835, -0.0356],
         [-0.1641,  0.0330, -0.0334,  ...,  0.0674, -0.1543,  0.1328],
         ...,
         [ 0.0854, -0.0124, -0.1245,  ...,  0.0864, -0.0591, -0.0588],
         [-0.0104, -0.0232,  0.0012,  ...,  0.0289,  0.0244,  0.0532],
         [ 0.0466,  0.1074,  0.2637,  ..., -0.0938,  0.0044,  0.0801]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[-2.1820e-03, -5.0964e-03,  8.5831e-04,  ..., -9.5312e-01,
           6.8970e-03,  4.1406e-01],
         [ 1.4219e+00,  1.4141e+00,  2.2188e+00,  ...,  6.9062e+00,
           1.8164e-01, -1.7188e+00],
         [ 1.3672e+00,  7.8613e-02,  1.4531e+00,  ...,  6.3438e+00,
          -1.0547e+00, -2.1562e+00],
         ...,
         [ 2.2031e+00, -3.1250e-01, -2.8516e-01,  ...,  7.3438e+00,
          -2.1406e+00, -2.9688e+00],
         [ 1.5234e-01, -2.6367e-01,  1.6699e-01,  ...,  6.3125e+00,
          -8.0469e-01, -1.4844e+00],
         [ 1.2031e+00, -1.0469e+00,  1.1016e+00,  ...,  6.2812e+00,
           1.1250e+00, -1.1328e+00]],

        [[-2.1820e-03, -5.0964e-03,  8.5831e-04,  ..., -9.5312e-01,
           6.8970e-03,  4.1406e-01],
         [ 1.4219e+00,  1.4141e+00,  2.2188e+00,  ...,  6.9062e+00,
           1.8164e-01, -1.7188e+00],
         [ 1.3672e+00,  7.8613e-02,  1.4531e+00,  ...,  6.3438e+00,
          -1.0547e+00, -2.1562e+00],
         ...,
         [ 2.2031e+00, -3.1250e-01, -2.8516e-01,  ...,  7.3438e+00,
          -2.1406e+00, -2.9688e+00],
         [ 1.5234e-01, -2.6367e-01,  1.6699e-01,  ...,  6.3125e+00,
          -8.0469e-01, -1.4844e+00],
         [ 1.2031e+00, -1.0469e+00,  1.1016e+00,  ...,  6.2812e+00,
           1.1250e+00, -1.1328e+00]],

        [[-2.1820e-03, -5.0964e-03,  8.5831e-04,  ..., -9.5312e-01,
           6.8970e-03,  4.1406e-01],
         [ 1.4219e+00,  1.4141e+00,  2.2188e+00,  ...,  6.9062e+00,
           1.8164e-01, -1.7188e+00],
         [ 1.3672e+00,  7.8613e-02,  1.4531e+00,  ...,  6.3438e+00,
          -1.0547e+00, -2.1562e+00],
         ...,
         [ 2.2031e+00, -3.1250e-01, -2.8516e-01,  ...,  7.3438e+00,
          -2.1406e+00, -2.9688e+00],
         [ 1.5234e-01, -2.6367e-01,  1.6699e-01,  ...,  6.3125e+00,
          -8.0469e-01, -1.4844e+00],
         [ 1.2031e+00, -1.0469e+00,  1.1016e+00,  ...,  6.2812e+00,
           1.1250e+00, -1.1328e+00]],

        ...,

        [[ 1.0315e-02,  7.4005e-04, -5.4626e-03,  ..., -4.5654e-02,
           9.8145e-02,  1.5015e-02],
         [ 4.0000e+00,  1.5156e+00,  1.7188e+00,  ...,  8.3008e-02,
          -2.4844e+00,  1.7188e+00],
         [ 2.1875e+00, -4.4434e-02,  1.3672e+00,  ...,  1.4609e+00,
          -1.6250e+00,  1.2578e+00],
         ...,
         [ 3.4844e+00,  9.6484e-01,  5.4297e-01,  ...,  1.0625e+00,
          -1.5938e+00,  7.6172e-01],
         [ 8.1250e-01,  7.3242e-04,  7.5000e-01,  ...,  1.5156e+00,
           4.2773e-01,  2.3594e+00],
         [-5.3516e-01,  1.0469e+00,  2.1719e+00,  ...,  6.0547e-01,
          -1.0391e+00,  2.4805e-01]],

        [[ 1.0315e-02,  7.4005e-04, -5.4626e-03,  ..., -4.5654e-02,
           9.8145e-02,  1.5015e-02],
         [ 4.0000e+00,  1.5156e+00,  1.7188e+00,  ...,  8.3008e-02,
          -2.4844e+00,  1.7188e+00],
         [ 2.1875e+00, -4.4434e-02,  1.3672e+00,  ...,  1.4609e+00,
          -1.6250e+00,  1.2578e+00],
         ...,
         [ 3.4844e+00,  9.6484e-01,  5.4297e-01,  ...,  1.0625e+00,
          -1.5938e+00,  7.6172e-01],
         [ 8.1250e-01,  7.3242e-04,  7.5000e-01,  ...,  1.5156e+00,
           4.2773e-01,  2.3594e+00],
         [-5.3516e-01,  1.0469e+00,  2.1719e+00,  ...,  6.0547e-01,
          -1.0391e+00,  2.4805e-01]],

        [[ 1.0315e-02,  7.4005e-04, -5.4626e-03,  ..., -4.5654e-02,
           9.8145e-02,  1.5015e-02],
         [ 4.0000e+00,  1.5156e+00,  1.7188e+00,  ...,  8.3008e-02,
          -2.4844e+00,  1.7188e+00],
         [ 2.1875e+00, -4.4434e-02,  1.3672e+00,  ...,  1.4609e+00,
          -1.6250e+00,  1.2578e+00],
         ...,
         [ 3.4844e+00,  9.6484e-01,  5.4297e-01,  ...,  1.0625e+00,
          -1.5938e+00,  7.6172e-01],
         [ 8.1250e-01,  7.3242e-04,  7.5000e-01,  ...,  1.5156e+00,
           4.2773e-01,  2.3594e+00],
         [-5.3516e-01,  1.0469e+00,  2.1719e+00,  ...,  6.0547e-01,
          -1.0391e+00,  2.4805e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-1.8997e-03, -1.0864e-02,  2.7313e-03,  ..., -5.3711e-03,
           7.7820e-04,  3.0975e-03],
         [ 1.2793e-01,  3.1445e-01,  5.7812e-01,  ...,  5.1953e-01,
          -1.4771e-02, -3.0078e-01],
         [-2.1973e-01,  5.4443e-02, -1.6699e-01,  ..., -5.3516e-01,
           2.7930e-01,  1.0205e-01],
         ...,
         [-1.7480e-01,  7.2266e-02,  2.2949e-01,  ...,  2.8809e-02,
           2.4512e-01, -1.9775e-02],
         [ 3.5742e-01, -2.8198e-02,  5.3955e-02,  ...,  3.4766e-01,
          -1.0400e-01, -1.5820e-01],
         [ 4.0234e-01,  1.1719e+00, -1.2812e+00,  ...,  1.0156e+00,
          -5.1172e-01,  5.5469e-01]],

        [[-1.8997e-03, -1.0864e-02,  2.7313e-03,  ..., -5.3711e-03,
           7.7820e-04,  3.0975e-03],
         [ 1.2793e-01,  3.1445e-01,  5.7812e-01,  ...,  5.1953e-01,
          -1.4771e-02, -3.0078e-01],
         [-2.1973e-01,  5.4443e-02, -1.6699e-01,  ..., -5.3516e-01,
           2.7930e-01,  1.0205e-01],
         ...,
         [-1.7480e-01,  7.2266e-02,  2.2949e-01,  ...,  2.8809e-02,
           2.4512e-01, -1.9775e-02],
         [ 3.5742e-01, -2.8198e-02,  5.3955e-02,  ...,  3.4766e-01,
          -1.0400e-01, -1.5820e-01],
         [ 4.0234e-01,  1.1719e+00, -1.2812e+00,  ...,  1.0156e+00,
          -5.1172e-01,  5.5469e-01]],

        [[-1.8997e-03, -1.0864e-02,  2.7313e-03,  ..., -5.3711e-03,
           7.7820e-04,  3.0975e-03],
         [ 1.2793e-01,  3.1445e-01,  5.7812e-01,  ...,  5.1953e-01,
          -1.4771e-02, -3.0078e-01],
         [-2.1973e-01,  5.4443e-02, -1.6699e-01,  ..., -5.3516e-01,
           2.7930e-01,  1.0205e-01],
         ...,
         [-1.7480e-01,  7.2266e-02,  2.2949e-01,  ...,  2.8809e-02,
           2.4512e-01, -1.9775e-02],
         [ 3.5742e-01, -2.8198e-02,  5.3955e-02,  ...,  3.4766e-01,
          -1.0400e-01, -1.5820e-01],
         [ 4.0234e-01,  1.1719e+00, -1.2812e+00,  ...,  1.0156e+00,
          -5.1172e-01,  5.5469e-01]],

        ...,

        [[ 1.1658e-02, -7.6599e-03, -4.5967e-04,  ..., -3.1128e-03,
           3.2349e-03,  9.1934e-04],
         [ 2.6953e-01,  2.1582e-01, -2.3633e-01,  ...,  2.2095e-02,
           2.5195e-01,  5.7373e-02],
         [-3.2031e-01,  1.0315e-02,  5.2979e-02,  ...,  4.0820e-01,
           1.6895e-01, -2.3047e-01],
         ...,
         [-2.4048e-02,  1.3184e-01, -2.1289e-01,  ...,  2.3340e-01,
           2.2949e-01,  3.2617e-01],
         [-3.1250e-01,  1.7090e-01, -1.5918e-01,  ...,  2.1289e-01,
          -2.3926e-01, -2.6245e-02],
         [-1.4062e-01, -1.6504e-01,  1.4282e-02,  ..., -2.7344e-01,
          -2.6758e-01, -4.4727e-01]],

        [[ 1.1658e-02, -7.6599e-03, -4.5967e-04,  ..., -3.1128e-03,
           3.2349e-03,  9.1934e-04],
         [ 2.6953e-01,  2.1582e-01, -2.3633e-01,  ...,  2.2095e-02,
           2.5195e-01,  5.7373e-02],
         [-3.2031e-01,  1.0315e-02,  5.2979e-02,  ...,  4.0820e-01,
           1.6895e-01, -2.3047e-01],
         ...,
         [-2.4048e-02,  1.3184e-01, -2.1289e-01,  ...,  2.3340e-01,
           2.2949e-01,  3.2617e-01],
         [-3.1250e-01,  1.7090e-01, -1.5918e-01,  ...,  2.1289e-01,
          -2.3926e-01, -2.6245e-02],
         [-1.4062e-01, -1.6504e-01,  1.4282e-02,  ..., -2.7344e-01,
          -2.6758e-01, -4.4727e-01]],

        [[ 1.1658e-02, -7.6599e-03, -4.5967e-04,  ..., -3.1128e-03,
           3.2349e-03,  9.1934e-04],
         [ 2.6953e-01,  2.1582e-01, -2.3633e-01,  ...,  2.2095e-02,
           2.5195e-01,  5.7373e-02],
         [-3.2031e-01,  1.0315e-02,  5.2979e-02,  ...,  4.0820e-01,
           1.6895e-01, -2.3047e-01],
         ...,
         [-2.4048e-02,  1.3184e-01, -2.1289e-01,  ...,  2.3340e-01,
           2.2949e-01,  3.2617e-01],
         [-3.1250e-01,  1.7090e-01, -1.5918e-01,  ...,  2.1289e-01,
          -2.3926e-01, -2.6245e-02],
         [-1.4062e-01, -1.6504e-01,  1.4282e-02,  ..., -2.7344e-01,
          -2.6758e-01, -4.4727e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 7.3242e-03,  4.6387e-03,  8.1787e-03,  ...,  6.4844e-01,
           5.0000e-01, -3.8477e-01],
         [-1.9688e+00,  1.5000e+00, -7.1484e-01,  ..., -3.2031e+00,
          -2.9375e+00,  2.9883e-01],
         [-5.7031e-01,  5.5078e-01, -1.0781e+00,  ..., -3.3281e+00,
          -3.3906e+00,  1.0469e+00],
         ...,
         [-6.6406e-01,  3.0078e-01,  8.3984e-01,  ..., -4.9375e+00,
          -2.9844e+00,  8.1543e-02],
         [-3.3203e-01, -7.6172e-02,  1.9531e-01,  ..., -4.9062e+00,
          -3.5625e+00,  6.4453e-01],
         [ 4.2188e-01, -8.7891e-01,  1.4355e-01,  ..., -5.2188e+00,
          -2.3281e+00,  1.0625e+00]],

        [[ 7.3242e-03,  4.6387e-03,  8.1787e-03,  ...,  6.4844e-01,
           5.0000e-01, -3.8477e-01],
         [-1.9688e+00,  1.5000e+00, -7.1484e-01,  ..., -3.2031e+00,
          -2.9375e+00,  2.9883e-01],
         [-5.7031e-01,  5.5078e-01, -1.0781e+00,  ..., -3.3281e+00,
          -3.3906e+00,  1.0469e+00],
         ...,
         [-6.6406e-01,  3.0078e-01,  8.3984e-01,  ..., -4.9375e+00,
          -2.9844e+00,  8.1543e-02],
         [-3.3203e-01, -7.6172e-02,  1.9531e-01,  ..., -4.9062e+00,
          -3.5625e+00,  6.4453e-01],
         [ 4.2188e-01, -8.7891e-01,  1.4355e-01,  ..., -5.2188e+00,
          -2.3281e+00,  1.0625e+00]],

        [[ 7.3242e-03,  4.6387e-03,  8.1787e-03,  ...,  6.4844e-01,
           5.0000e-01, -3.8477e-01],
         [-1.9688e+00,  1.5000e+00, -7.1484e-01,  ..., -3.2031e+00,
          -2.9375e+00,  2.9883e-01],
         [-5.7031e-01,  5.5078e-01, -1.0781e+00,  ..., -3.3281e+00,
          -3.3906e+00,  1.0469e+00],
         ...,
         [-6.6406e-01,  3.0078e-01,  8.3984e-01,  ..., -4.9375e+00,
          -2.9844e+00,  8.1543e-02],
         [-3.3203e-01, -7.6172e-02,  1.9531e-01,  ..., -4.9062e+00,
          -3.5625e+00,  6.4453e-01],
         [ 4.2188e-01, -8.7891e-01,  1.4355e-01,  ..., -5.2188e+00,
          -2.3281e+00,  1.0625e+00]],

        ...,

        [[ 2.6398e-03, -5.1575e-03,  3.1586e-03,  ...,  8.0078e-02,
          -5.4321e-03, -2.4219e-01],
         [-5.1875e+00, -8.9062e-01,  7.8516e-01,  ..., -6.6406e-01,
          -1.3281e+00, -6.8750e-01],
         [-3.2188e+00, -5.5664e-02,  1.1719e-02,  ...,  5.1172e-01,
           3.5352e-01, -9.7266e-01],
         ...,
         [-4.7188e+00,  2.7344e+00,  1.2500e+00,  ..., -6.9531e-01,
           4.5508e-01, -7.4219e-01],
         [-2.1406e+00,  2.5312e+00,  1.1875e+00,  ...,  6.6406e-01,
           7.3828e-01, -7.6953e-01],
         [ 3.0469e+00,  1.2812e+00,  2.6562e+00,  ..., -1.1094e+00,
           2.1875e-01, -2.4292e-02]],

        [[ 2.6398e-03, -5.1575e-03,  3.1586e-03,  ...,  8.0078e-02,
          -5.4321e-03, -2.4219e-01],
         [-5.1875e+00, -8.9062e-01,  7.8516e-01,  ..., -6.6406e-01,
          -1.3281e+00, -6.8750e-01],
         [-3.2188e+00, -5.5664e-02,  1.1719e-02,  ...,  5.1172e-01,
           3.5352e-01, -9.7266e-01],
         ...,
         [-4.7188e+00,  2.7344e+00,  1.2500e+00,  ..., -6.9531e-01,
           4.5508e-01, -7.4219e-01],
         [-2.1406e+00,  2.5312e+00,  1.1875e+00,  ...,  6.6406e-01,
           7.3828e-01, -7.6953e-01],
         [ 3.0469e+00,  1.2812e+00,  2.6562e+00,  ..., -1.1094e+00,
           2.1875e-01, -2.4292e-02]],

        [[ 2.6398e-03, -5.1575e-03,  3.1586e-03,  ...,  8.0078e-02,
          -5.4321e-03, -2.4219e-01],
         [-5.1875e+00, -8.9062e-01,  7.8516e-01,  ..., -6.6406e-01,
          -1.3281e+00, -6.8750e-01],
         [-3.2188e+00, -5.5664e-02,  1.1719e-02,  ...,  5.1172e-01,
           3.5352e-01, -9.7266e-01],
         ...,
         [-4.7188e+00,  2.7344e+00,  1.2500e+00,  ..., -6.9531e-01,
           4.5508e-01, -7.4219e-01],
         [-2.1406e+00,  2.5312e+00,  1.1875e+00,  ...,  6.6406e-01,
           7.3828e-01, -7.6953e-01],
         [ 3.0469e+00,  1.2812e+00,  2.6562e+00,  ..., -1.1094e+00,
           2.1875e-01, -2.4292e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 5.3406e-04,  6.9580e-03, -4.0283e-03,  ...,  1.3428e-03,
          -3.2349e-03, -1.6174e-03],
         [ 1.1963e-01, -6.2109e-01,  2.1680e-01,  ..., -2.0874e-02,
           1.0193e-02,  2.7930e-01],
         [-3.1055e-01,  4.1016e-01,  8.4375e-01,  ...,  1.8457e-01,
          -7.3730e-02, -3.7305e-01],
         ...,
         [ 1.3672e-01,  4.7656e-01,  3.8477e-01,  ...,  1.4258e-01,
           3.2422e-01,  1.6479e-02],
         [ 7.4609e-01, -6.8359e-01,  1.6211e-01,  ..., -2.5781e-01,
           4.6875e-01,  9.9121e-02],
         [-1.3867e-01,  1.6699e-01,  1.1279e-01,  ...,  4.2969e-01,
          -8.2397e-03, -2.7539e-01]],

        [[ 5.3406e-04,  6.9580e-03, -4.0283e-03,  ...,  1.3428e-03,
          -3.2349e-03, -1.6174e-03],
         [ 1.1963e-01, -6.2109e-01,  2.1680e-01,  ..., -2.0874e-02,
           1.0193e-02,  2.7930e-01],
         [-3.1055e-01,  4.1016e-01,  8.4375e-01,  ...,  1.8457e-01,
          -7.3730e-02, -3.7305e-01],
         ...,
         [ 1.3672e-01,  4.7656e-01,  3.8477e-01,  ...,  1.4258e-01,
           3.2422e-01,  1.6479e-02],
         [ 7.4609e-01, -6.8359e-01,  1.6211e-01,  ..., -2.5781e-01,
           4.6875e-01,  9.9121e-02],
         [-1.3867e-01,  1.6699e-01,  1.1279e-01,  ...,  4.2969e-01,
          -8.2397e-03, -2.7539e-01]],

        [[ 5.3406e-04,  6.9580e-03, -4.0283e-03,  ...,  1.3428e-03,
          -3.2349e-03, -1.6174e-03],
         [ 1.1963e-01, -6.2109e-01,  2.1680e-01,  ..., -2.0874e-02,
           1.0193e-02,  2.7930e-01],
         [-3.1055e-01,  4.1016e-01,  8.4375e-01,  ...,  1.8457e-01,
          -7.3730e-02, -3.7305e-01],
         ...,
         [ 1.3672e-01,  4.7656e-01,  3.8477e-01,  ...,  1.4258e-01,
           3.2422e-01,  1.6479e-02],
         [ 7.4609e-01, -6.8359e-01,  1.6211e-01,  ..., -2.5781e-01,
           4.6875e-01,  9.9121e-02],
         [-1.3867e-01,  1.6699e-01,  1.1279e-01,  ...,  4.2969e-01,
          -8.2397e-03, -2.7539e-01]],

        ...,

        [[-3.9368e-03, -2.0752e-03,  3.1433e-03,  ...,  3.3569e-04,
           3.0212e-03, -7.6675e-04],
         [-3.5547e-01,  2.2339e-02, -1.4648e-01,  ..., -1.3770e-01,
          -2.1973e-01,  1.2024e-02],
         [ 8.7891e-02, -1.6699e-01,  6.2500e-02,  ...,  4.0039e-01,
          -2.6953e-01,  1.7383e-01],
         ...,
         [ 7.7148e-02,  2.5586e-01, -2.6172e-01,  ...,  1.7578e-01,
           9.2773e-02, -1.6968e-02],
         [ 1.5039e-01,  4.4678e-02,  1.2061e-01,  ...,  1.4648e-02,
           2.7539e-01, -1.4453e-01],
         [-1.5234e-01, -3.2617e-01, -4.0625e-01,  ...,  7.1716e-03,
          -9.5215e-02,  6.0791e-02]],

        [[-3.9368e-03, -2.0752e-03,  3.1433e-03,  ...,  3.3569e-04,
           3.0212e-03, -7.6675e-04],
         [-3.5547e-01,  2.2339e-02, -1.4648e-01,  ..., -1.3770e-01,
          -2.1973e-01,  1.2024e-02],
         [ 8.7891e-02, -1.6699e-01,  6.2500e-02,  ...,  4.0039e-01,
          -2.6953e-01,  1.7383e-01],
         ...,
         [ 7.7148e-02,  2.5586e-01, -2.6172e-01,  ...,  1.7578e-01,
           9.2773e-02, -1.6968e-02],
         [ 1.5039e-01,  4.4678e-02,  1.2061e-01,  ...,  1.4648e-02,
           2.7539e-01, -1.4453e-01],
         [-1.5234e-01, -3.2617e-01, -4.0625e-01,  ...,  7.1716e-03,
          -9.5215e-02,  6.0791e-02]],

        [[-3.9368e-03, -2.0752e-03,  3.1433e-03,  ...,  3.3569e-04,
           3.0212e-03, -7.6675e-04],
         [-3.5547e-01,  2.2339e-02, -1.4648e-01,  ..., -1.3770e-01,
          -2.1973e-01,  1.2024e-02],
         [ 8.7891e-02, -1.6699e-01,  6.2500e-02,  ...,  4.0039e-01,
          -2.6953e-01,  1.7383e-01],
         ...,
         [ 7.7148e-02,  2.5586e-01, -2.6172e-01,  ...,  1.7578e-01,
           9.2773e-02, -1.6968e-02],
         [ 1.5039e-01,  4.4678e-02,  1.2061e-01,  ...,  1.4648e-02,
           2.7539e-01, -1.4453e-01],
         [-1.5234e-01, -3.2617e-01, -4.0625e-01,  ...,  7.1716e-03,
          -9.5215e-02,  6.0791e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 0.0146,  0.0317, -0.0143,  ..., -0.2393, -0.6523, -0.1162],
         [ 1.6094, -1.1406, -0.0879,  ...,  0.4395,  2.6562, -0.4863],
         [ 2.5781, -0.8555, -0.3262,  ...,  0.2451,  3.0469, -0.1484],
         ...,
         [ 1.9609,  0.9023, -1.1562,  ...,  2.4219,  3.5000,  1.3750],
         [-0.1348, -0.6406, -0.3516,  ...,  1.4609,  0.8047,  0.5312],
         [ 0.5703,  0.3828, -0.3555,  ...,  1.1250,  2.1094, -2.1875]],

        [[ 0.0146,  0.0317, -0.0143,  ..., -0.2393, -0.6523, -0.1162],
         [ 1.6094, -1.1406, -0.0879,  ...,  0.4395,  2.6562, -0.4863],
         [ 2.5781, -0.8555, -0.3262,  ...,  0.2451,  3.0469, -0.1484],
         ...,
         [ 1.9609,  0.9023, -1.1562,  ...,  2.4219,  3.5000,  1.3750],
         [-0.1348, -0.6406, -0.3516,  ...,  1.4609,  0.8047,  0.5312],
         [ 0.5703,  0.3828, -0.3555,  ...,  1.1250,  2.1094, -2.1875]],

        [[ 0.0146,  0.0317, -0.0143,  ..., -0.2393, -0.6523, -0.1162],
         [ 1.6094, -1.1406, -0.0879,  ...,  0.4395,  2.6562, -0.4863],
         [ 2.5781, -0.8555, -0.3262,  ...,  0.2451,  3.0469, -0.1484],
         ...,
         [ 1.9609,  0.9023, -1.1562,  ...,  2.4219,  3.5000,  1.3750],
         [-0.1348, -0.6406, -0.3516,  ...,  1.4609,  0.8047,  0.5312],
         [ 0.5703,  0.3828, -0.3555,  ...,  1.1250,  2.1094, -2.1875]],

        ...,

        [[-0.0132, -0.0184,  0.0106,  ...,  0.0322, -0.2773,  0.1357],
         [ 2.0625, -1.0625, -1.7812,  ..., -0.1167, -0.0369, -0.5156],
         [ 0.3672, -0.5586, -0.3867,  ...,  0.9570, -0.5742, -0.1230],
         ...,
         [ 1.5938, -0.4102,  0.0332,  ...,  1.7266,  1.3203, -1.1406],
         [ 0.0674,  0.0227,  0.0391,  ...,  0.0649, -1.4062,  1.2812],
         [-0.1367,  0.7305, -0.2539,  ...,  1.7578,  0.3047, -4.2812]],

        [[-0.0132, -0.0184,  0.0106,  ...,  0.0322, -0.2773,  0.1357],
         [ 2.0625, -1.0625, -1.7812,  ..., -0.1167, -0.0369, -0.5156],
         [ 0.3672, -0.5586, -0.3867,  ...,  0.9570, -0.5742, -0.1230],
         ...,
         [ 1.5938, -0.4102,  0.0332,  ...,  1.7266,  1.3203, -1.1406],
         [ 0.0674,  0.0227,  0.0391,  ...,  0.0649, -1.4062,  1.2812],
         [-0.1367,  0.7305, -0.2539,  ...,  1.7578,  0.3047, -4.2812]],

        [[-0.0132, -0.0184,  0.0106,  ...,  0.0322, -0.2773,  0.1357],
         [ 2.0625, -1.0625, -1.7812,  ..., -0.1167, -0.0369, -0.5156],
         [ 0.3672, -0.5586, -0.3867,  ...,  0.9570, -0.5742, -0.1230],
         ...,
         [ 1.5938, -0.4102,  0.0332,  ...,  1.7266,  1.3203, -1.1406],
         [ 0.0674,  0.0227,  0.0391,  ...,  0.0649, -1.4062,  1.2812],
         [-0.1367,  0.7305, -0.2539,  ...,  1.7578,  0.3047, -4.2812]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>), tensor([[[-9.8267e-03,  3.4027e-03, -1.1963e-02,  ...,  4.8523e-03,
          -4.0894e-03, -9.4604e-03],
         [-1.2988e-01,  6.7578e-01, -2.5586e-01,  ..., -2.1191e-01,
          -2.3828e-01,  6.7188e-01],
         [ 3.7109e-01, -2.2363e-01, -2.2559e-01,  ...,  5.8594e-01,
           6.9531e-01, -2.0117e-01],
         ...,
         [ 2.3193e-02,  1.3867e-01,  7.3828e-01,  ...,  5.7031e-01,
           1.9434e-01,  1.4648e-01],
         [ 8.2422e-01, -5.4688e-01,  6.0938e-01,  ..., -4.4141e-01,
          -1.9434e-01,  5.6641e-01],
         [-1.5332e-01,  3.4912e-02, -2.3535e-01,  ...,  3.2715e-02,
           2.0508e-01, -1.7285e-01]],

        [[-9.8267e-03,  3.4027e-03, -1.1963e-02,  ...,  4.8523e-03,
          -4.0894e-03, -9.4604e-03],
         [-1.2988e-01,  6.7578e-01, -2.5586e-01,  ..., -2.1191e-01,
          -2.3828e-01,  6.7188e-01],
         [ 3.7109e-01, -2.2363e-01, -2.2559e-01,  ...,  5.8594e-01,
           6.9531e-01, -2.0117e-01],
         ...,
         [ 2.3193e-02,  1.3867e-01,  7.3828e-01,  ...,  5.7031e-01,
           1.9434e-01,  1.4648e-01],
         [ 8.2422e-01, -5.4688e-01,  6.0938e-01,  ..., -4.4141e-01,
          -1.9434e-01,  5.6641e-01],
         [-1.5332e-01,  3.4912e-02, -2.3535e-01,  ...,  3.2715e-02,
           2.0508e-01, -1.7285e-01]],

        [[-9.8267e-03,  3.4027e-03, -1.1963e-02,  ...,  4.8523e-03,
          -4.0894e-03, -9.4604e-03],
         [-1.2988e-01,  6.7578e-01, -2.5586e-01,  ..., -2.1191e-01,
          -2.3828e-01,  6.7188e-01],
         [ 3.7109e-01, -2.2363e-01, -2.2559e-01,  ...,  5.8594e-01,
           6.9531e-01, -2.0117e-01],
         ...,
         [ 2.3193e-02,  1.3867e-01,  7.3828e-01,  ...,  5.7031e-01,
           1.9434e-01,  1.4648e-01],
         [ 8.2422e-01, -5.4688e-01,  6.0938e-01,  ..., -4.4141e-01,
          -1.9434e-01,  5.6641e-01],
         [-1.5332e-01,  3.4912e-02, -2.3535e-01,  ...,  3.2715e-02,
           2.0508e-01, -1.7285e-01]],

        ...,

        [[-4.5967e-04, -4.8523e-03, -3.3447e-02,  ...,  6.7139e-03,
           7.4768e-03,  5.1880e-03],
         [-1.9629e-01, -2.9541e-02,  9.2163e-03,  ..., -6.4844e-01,
          -5.3516e-01,  3.8574e-02],
         [ 4.1602e-01, -6.7578e-01,  1.9531e-01,  ..., -4.4531e-01,
           3.5156e-02,  4.7070e-01],
         ...,
         [ 4.4141e-01,  5.0391e-01, -5.2002e-02,  ..., -3.4375e-01,
           6.8848e-02,  2.1973e-01],
         [ 2.0996e-01,  4.8828e-01,  4.5508e-01,  ..., -2.5195e-01,
          -1.0547e-01,  3.1836e-01],
         [-1.2500e-01, -6.4844e-01, -1.1816e-01,  ..., -1.4648e-01,
           8.2016e-04, -3.0859e-01]],

        [[-4.5967e-04, -4.8523e-03, -3.3447e-02,  ...,  6.7139e-03,
           7.4768e-03,  5.1880e-03],
         [-1.9629e-01, -2.9541e-02,  9.2163e-03,  ..., -6.4844e-01,
          -5.3516e-01,  3.8574e-02],
         [ 4.1602e-01, -6.7578e-01,  1.9531e-01,  ..., -4.4531e-01,
           3.5156e-02,  4.7070e-01],
         ...,
         [ 4.4141e-01,  5.0391e-01, -5.2002e-02,  ..., -3.4375e-01,
           6.8848e-02,  2.1973e-01],
         [ 2.0996e-01,  4.8828e-01,  4.5508e-01,  ..., -2.5195e-01,
          -1.0547e-01,  3.1836e-01],
         [-1.2500e-01, -6.4844e-01, -1.1816e-01,  ..., -1.4648e-01,
           8.2016e-04, -3.0859e-01]],

        [[-4.5967e-04, -4.8523e-03, -3.3447e-02,  ...,  6.7139e-03,
           7.4768e-03,  5.1880e-03],
         [-1.9629e-01, -2.9541e-02,  9.2163e-03,  ..., -6.4844e-01,
          -5.3516e-01,  3.8574e-02],
         [ 4.1602e-01, -6.7578e-01,  1.9531e-01,  ..., -4.4531e-01,
           3.5156e-02,  4.7070e-01],
         ...,
         [ 4.4141e-01,  5.0391e-01, -5.2002e-02,  ..., -3.4375e-01,
           6.8848e-02,  2.1973e-01],
         [ 2.0996e-01,  4.8828e-01,  4.5508e-01,  ..., -2.5195e-01,
          -1.0547e-01,  3.1836e-01],
         [-1.2500e-01, -6.4844e-01, -1.1816e-01,  ..., -1.4648e-01,
           8.2016e-04, -3.0859e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-5.8289e-03, -5.2185e-03, -3.2043e-03,  ...,  1.9141e-01,
          -9.6191e-02, -6.4062e-01],
         [-3.3750e+00, -1.9297e+00,  2.3828e-01,  ..., -1.8359e-01,
           7.5391e-01,  2.9375e+00],
         [ 5.9375e-01, -2.5312e+00, -1.2266e+00,  ..., -3.8574e-02,
           8.2812e-01,  1.7812e+00],
         ...,
         [-1.8828e+00, -1.8125e+00,  1.3438e+00,  ..., -1.3125e+00,
          -4.3359e-01,  2.8594e+00],
         [ 1.2188e+00, -1.3867e-01,  1.0703e+00,  ..., -2.2812e+00,
          -1.9824e-01,  3.6406e+00],
         [ 3.5156e+00,  1.7188e+00,  1.7188e+00,  ...,  6.7188e-01,
          -2.3340e-01,  2.0469e+00]],

        [[-5.8289e-03, -5.2185e-03, -3.2043e-03,  ...,  1.9141e-01,
          -9.6191e-02, -6.4062e-01],
         [-3.3750e+00, -1.9297e+00,  2.3828e-01,  ..., -1.8359e-01,
           7.5391e-01,  2.9375e+00],
         [ 5.9375e-01, -2.5312e+00, -1.2266e+00,  ..., -3.8574e-02,
           8.2812e-01,  1.7812e+00],
         ...,
         [-1.8828e+00, -1.8125e+00,  1.3438e+00,  ..., -1.3125e+00,
          -4.3359e-01,  2.8594e+00],
         [ 1.2188e+00, -1.3867e-01,  1.0703e+00,  ..., -2.2812e+00,
          -1.9824e-01,  3.6406e+00],
         [ 3.5156e+00,  1.7188e+00,  1.7188e+00,  ...,  6.7188e-01,
          -2.3340e-01,  2.0469e+00]],

        [[-5.8289e-03, -5.2185e-03, -3.2043e-03,  ...,  1.9141e-01,
          -9.6191e-02, -6.4062e-01],
         [-3.3750e+00, -1.9297e+00,  2.3828e-01,  ..., -1.8359e-01,
           7.5391e-01,  2.9375e+00],
         [ 5.9375e-01, -2.5312e+00, -1.2266e+00,  ..., -3.8574e-02,
           8.2812e-01,  1.7812e+00],
         ...,
         [-1.8828e+00, -1.8125e+00,  1.3438e+00,  ..., -1.3125e+00,
          -4.3359e-01,  2.8594e+00],
         [ 1.2188e+00, -1.3867e-01,  1.0703e+00,  ..., -2.2812e+00,
          -1.9824e-01,  3.6406e+00],
         [ 3.5156e+00,  1.7188e+00,  1.7188e+00,  ...,  6.7188e-01,
          -2.3340e-01,  2.0469e+00]],

        ...,

        [[-2.0599e-03, -3.8338e-04,  5.0049e-03,  ...,  1.0437e-02,
           1.6235e-02,  2.6758e-01],
         [-6.3672e-01,  5.5469e-01, -4.7656e-01,  ..., -2.0625e+00,
           9.9609e-01, -1.2969e+00],
         [ 1.2266e+00,  1.9531e+00,  3.3984e-01,  ..., -7.0703e-01,
          -2.5781e+00,  2.0605e-01],
         ...,
         [ 3.9062e-01,  1.8906e+00,  3.1250e-02,  ...,  6.8359e-01,
           2.2852e-01, -3.0664e-01],
         [ 2.9688e-01, -9.6094e-01,  9.8828e-01,  ...,  1.2578e+00,
          -2.0000e+00, -1.3516e+00],
         [ 1.7734e+00, -2.6953e-01, -1.8906e+00,  ...,  1.6328e+00,
          -9.9609e-01, -6.7188e-01]],

        [[-2.0599e-03, -3.8338e-04,  5.0049e-03,  ...,  1.0437e-02,
           1.6235e-02,  2.6758e-01],
         [-6.3672e-01,  5.5469e-01, -4.7656e-01,  ..., -2.0625e+00,
           9.9609e-01, -1.2969e+00],
         [ 1.2266e+00,  1.9531e+00,  3.3984e-01,  ..., -7.0703e-01,
          -2.5781e+00,  2.0605e-01],
         ...,
         [ 3.9062e-01,  1.8906e+00,  3.1250e-02,  ...,  6.8359e-01,
           2.2852e-01, -3.0664e-01],
         [ 2.9688e-01, -9.6094e-01,  9.8828e-01,  ...,  1.2578e+00,
          -2.0000e+00, -1.3516e+00],
         [ 1.7734e+00, -2.6953e-01, -1.8906e+00,  ...,  1.6328e+00,
          -9.9609e-01, -6.7188e-01]],

        [[-2.0599e-03, -3.8338e-04,  5.0049e-03,  ...,  1.0437e-02,
           1.6235e-02,  2.6758e-01],
         [-6.3672e-01,  5.5469e-01, -4.7656e-01,  ..., -2.0625e+00,
           9.9609e-01, -1.2969e+00],
         [ 1.2266e+00,  1.9531e+00,  3.3984e-01,  ..., -7.0703e-01,
          -2.5781e+00,  2.0605e-01],
         ...,
         [ 3.9062e-01,  1.8906e+00,  3.1250e-02,  ...,  6.8359e-01,
           2.2852e-01, -3.0664e-01],
         [ 2.9688e-01, -9.6094e-01,  9.8828e-01,  ...,  1.2578e+00,
          -2.0000e+00, -1.3516e+00],
         [ 1.7734e+00, -2.6953e-01, -1.8906e+00,  ...,  1.6328e+00,
          -9.9609e-01, -6.7188e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-6.6833e-03,  9.1553e-03, -9.5215e-03,  ...,  9.5825e-03,
          -1.5747e-02,  8.4229e-03],
         [ 1.8164e-01,  5.7861e-02,  8.6719e-01,  ..., -2.4414e-01,
           3.2812e-01,  5.7422e-01],
         [ 1.6113e-01, -7.6953e-01,  4.9414e-01,  ..., -5.0391e-01,
          -2.1191e-01, -4.2578e-01],
         ...,
         [ 1.7188e-01, -1.1572e-01, -2.1484e-02,  ..., -2.3828e-01,
           6.3281e-01,  5.4297e-01],
         [ 8.6914e-02, -3.3398e-01, -7.8125e-02,  ..., -3.1445e-01,
           9.1309e-02, -2.5781e-01],
         [-3.2812e-01,  2.1191e-01,  2.6172e-01,  ..., -1.0205e-01,
           1.3281e+00, -2.5586e-01]],

        [[-6.6833e-03,  9.1553e-03, -9.5215e-03,  ...,  9.5825e-03,
          -1.5747e-02,  8.4229e-03],
         [ 1.8164e-01,  5.7861e-02,  8.6719e-01,  ..., -2.4414e-01,
           3.2812e-01,  5.7422e-01],
         [ 1.6113e-01, -7.6953e-01,  4.9414e-01,  ..., -5.0391e-01,
          -2.1191e-01, -4.2578e-01],
         ...,
         [ 1.7188e-01, -1.1572e-01, -2.1484e-02,  ..., -2.3828e-01,
           6.3281e-01,  5.4297e-01],
         [ 8.6914e-02, -3.3398e-01, -7.8125e-02,  ..., -3.1445e-01,
           9.1309e-02, -2.5781e-01],
         [-3.2812e-01,  2.1191e-01,  2.6172e-01,  ..., -1.0205e-01,
           1.3281e+00, -2.5586e-01]],

        [[-6.6833e-03,  9.1553e-03, -9.5215e-03,  ...,  9.5825e-03,
          -1.5747e-02,  8.4229e-03],
         [ 1.8164e-01,  5.7861e-02,  8.6719e-01,  ..., -2.4414e-01,
           3.2812e-01,  5.7422e-01],
         [ 1.6113e-01, -7.6953e-01,  4.9414e-01,  ..., -5.0391e-01,
          -2.1191e-01, -4.2578e-01],
         ...,
         [ 1.7188e-01, -1.1572e-01, -2.1484e-02,  ..., -2.3828e-01,
           6.3281e-01,  5.4297e-01],
         [ 8.6914e-02, -3.3398e-01, -7.8125e-02,  ..., -3.1445e-01,
           9.1309e-02, -2.5781e-01],
         [-3.2812e-01,  2.1191e-01,  2.6172e-01,  ..., -1.0205e-01,
           1.3281e+00, -2.5586e-01]],

        ...,

        [[ 7.4463e-03,  4.1504e-03, -7.1716e-03,  ...,  4.1199e-03,
           1.5625e-02, -1.1841e-02],
         [ 2.1289e-01,  2.9102e-01, -2.2363e-01,  ...,  1.0010e-01,
           2.1973e-01,  3.2812e-01],
         [-1.3672e-01,  5.5859e-01,  4.2534e-04,  ..., -3.8672e-01,
          -6.1328e-01, -5.4688e-01],
         ...,
         [ 1.4551e-01,  4.1406e-01, -5.8594e-01,  ..., -7.6953e-01,
          -8.2520e-02,  3.2422e-01],
         [-1.6699e-01, -3.3203e-02, -3.7842e-02,  ..., -4.4141e-01,
          -2.6855e-02,  2.8906e-01],
         [ 4.4141e-01,  1.0234e+00,  8.4961e-02,  ...,  3.7305e-01,
          -2.3828e-01, -9.7656e-02]],

        [[ 7.4463e-03,  4.1504e-03, -7.1716e-03,  ...,  4.1199e-03,
           1.5625e-02, -1.1841e-02],
         [ 2.1289e-01,  2.9102e-01, -2.2363e-01,  ...,  1.0010e-01,
           2.1973e-01,  3.2812e-01],
         [-1.3672e-01,  5.5859e-01,  4.2534e-04,  ..., -3.8672e-01,
          -6.1328e-01, -5.4688e-01],
         ...,
         [ 1.4551e-01,  4.1406e-01, -5.8594e-01,  ..., -7.6953e-01,
          -8.2520e-02,  3.2422e-01],
         [-1.6699e-01, -3.3203e-02, -3.7842e-02,  ..., -4.4141e-01,
          -2.6855e-02,  2.8906e-01],
         [ 4.4141e-01,  1.0234e+00,  8.4961e-02,  ...,  3.7305e-01,
          -2.3828e-01, -9.7656e-02]],

        [[ 7.4463e-03,  4.1504e-03, -7.1716e-03,  ...,  4.1199e-03,
           1.5625e-02, -1.1841e-02],
         [ 2.1289e-01,  2.9102e-01, -2.2363e-01,  ...,  1.0010e-01,
           2.1973e-01,  3.2812e-01],
         [-1.3672e-01,  5.5859e-01,  4.2534e-04,  ..., -3.8672e-01,
          -6.1328e-01, -5.4688e-01],
         ...,
         [ 1.4551e-01,  4.1406e-01, -5.8594e-01,  ..., -7.6953e-01,
          -8.2520e-02,  3.2422e-01],
         [-1.6699e-01, -3.3203e-02, -3.7842e-02,  ..., -4.4141e-01,
          -2.6855e-02,  2.8906e-01],
         [ 4.4141e-01,  1.0234e+00,  8.4961e-02,  ...,  3.7305e-01,
          -2.3828e-01, -9.7656e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 0.0128,  0.0156, -0.0175,  ...,  0.0996, -0.0767, -0.1533],
         [-1.1406, -0.2266, -0.6562,  ...,  1.5859,  0.4961,  0.7188],
         [-1.5938,  0.2578,  0.3047,  ...,  0.1641,  1.9609, -1.6484],
         ...,
         [-2.0469, -1.1094, -1.8203,  ...,  0.1680,  1.1094, -2.9062],
         [ 0.0063, -0.4082,  0.0391,  ..., -2.1719,  2.5312,  1.8047],
         [-1.1875, -0.7344, -0.9141,  ..., -0.6055, -0.3867, -0.0410]],

        [[ 0.0128,  0.0156, -0.0175,  ...,  0.0996, -0.0767, -0.1533],
         [-1.1406, -0.2266, -0.6562,  ...,  1.5859,  0.4961,  0.7188],
         [-1.5938,  0.2578,  0.3047,  ...,  0.1641,  1.9609, -1.6484],
         ...,
         [-2.0469, -1.1094, -1.8203,  ...,  0.1680,  1.1094, -2.9062],
         [ 0.0063, -0.4082,  0.0391,  ..., -2.1719,  2.5312,  1.8047],
         [-1.1875, -0.7344, -0.9141,  ..., -0.6055, -0.3867, -0.0410]],

        [[ 0.0128,  0.0156, -0.0175,  ...,  0.0996, -0.0767, -0.1533],
         [-1.1406, -0.2266, -0.6562,  ...,  1.5859,  0.4961,  0.7188],
         [-1.5938,  0.2578,  0.3047,  ...,  0.1641,  1.9609, -1.6484],
         ...,
         [-2.0469, -1.1094, -1.8203,  ...,  0.1680,  1.1094, -2.9062],
         [ 0.0063, -0.4082,  0.0391,  ..., -2.1719,  2.5312,  1.8047],
         [-1.1875, -0.7344, -0.9141,  ..., -0.6055, -0.3867, -0.0410]],

        ...,

        [[-0.0070, -0.0035, -0.0055,  ..., -0.0588, -0.0309,  0.0552],
         [ 1.9453,  1.6719,  0.0781,  ...,  1.7734, -0.5156,  1.0938],
         [ 2.8125,  1.0000,  0.5859,  ...,  1.2578,  0.5547,  1.7344],
         ...,
         [ 2.8750,  0.9023, -0.5430,  ...,  0.7539,  0.1641, -0.8438],
         [ 0.4609, -0.8828,  0.0352,  ...,  0.4473,  1.8203, -1.0000],
         [ 1.2812, -0.7305, -1.2734,  ..., -0.7500, -0.3809, -2.3281]],

        [[-0.0070, -0.0035, -0.0055,  ..., -0.0588, -0.0309,  0.0552],
         [ 1.9453,  1.6719,  0.0781,  ...,  1.7734, -0.5156,  1.0938],
         [ 2.8125,  1.0000,  0.5859,  ...,  1.2578,  0.5547,  1.7344],
         ...,
         [ 2.8750,  0.9023, -0.5430,  ...,  0.7539,  0.1641, -0.8438],
         [ 0.4609, -0.8828,  0.0352,  ...,  0.4473,  1.8203, -1.0000],
         [ 1.2812, -0.7305, -1.2734,  ..., -0.7500, -0.3809, -2.3281]],

        [[-0.0070, -0.0035, -0.0055,  ..., -0.0588, -0.0309,  0.0552],
         [ 1.9453,  1.6719,  0.0781,  ...,  1.7734, -0.5156,  1.0938],
         [ 2.8125,  1.0000,  0.5859,  ...,  1.2578,  0.5547,  1.7344],
         ...,
         [ 2.8750,  0.9023, -0.5430,  ...,  0.7539,  0.1641, -0.8438],
         [ 0.4609, -0.8828,  0.0352,  ...,  0.4473,  1.8203, -1.0000],
         [ 1.2812, -0.7305, -1.2734,  ..., -0.7500, -0.3809, -2.3281]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>), tensor([[[-0.0125,  0.0048, -0.0069,  ...,  0.0050, -0.0082, -0.0070],
         [ 0.1729, -0.1914,  0.0452,  ...,  0.5156,  0.0698,  0.3691],
         [-1.0156,  0.7773,  0.3066,  ..., -0.1240,  0.8320,  0.8203],
         ...,
         [-0.3867,  0.8281, -0.7305,  ...,  0.2080, -0.0381, -0.7500],
         [-0.6133, -0.3047,  0.1768,  ...,  0.2129, -0.0342, -0.4980],
         [-0.0039,  0.4922, -0.2188,  ...,  0.1338, -0.0232, -0.2285]],

        [[-0.0125,  0.0048, -0.0069,  ...,  0.0050, -0.0082, -0.0070],
         [ 0.1729, -0.1914,  0.0452,  ...,  0.5156,  0.0698,  0.3691],
         [-1.0156,  0.7773,  0.3066,  ..., -0.1240,  0.8320,  0.8203],
         ...,
         [-0.3867,  0.8281, -0.7305,  ...,  0.2080, -0.0381, -0.7500],
         [-0.6133, -0.3047,  0.1768,  ...,  0.2129, -0.0342, -0.4980],
         [-0.0039,  0.4922, -0.2188,  ...,  0.1338, -0.0232, -0.2285]],

        [[-0.0125,  0.0048, -0.0069,  ...,  0.0050, -0.0082, -0.0070],
         [ 0.1729, -0.1914,  0.0452,  ...,  0.5156,  0.0698,  0.3691],
         [-1.0156,  0.7773,  0.3066,  ..., -0.1240,  0.8320,  0.8203],
         ...,
         [-0.3867,  0.8281, -0.7305,  ...,  0.2080, -0.0381, -0.7500],
         [-0.6133, -0.3047,  0.1768,  ...,  0.2129, -0.0342, -0.4980],
         [-0.0039,  0.4922, -0.2188,  ...,  0.1338, -0.0232, -0.2285]],

        ...,

        [[-0.0022, -0.0028, -0.0028,  ..., -0.0060,  0.0087, -0.0074],
         [ 0.0129,  0.1001,  0.2002,  ..., -0.2734,  0.4023,  0.2168],
         [ 0.3848,  0.5156,  0.6914,  ..., -0.5273,  0.2832,  0.9570],
         ...,
         [ 0.3203,  0.0518, -0.5703,  ...,  0.3145, -0.5781, -0.1191],
         [-0.0708,  0.5273, -0.2002,  ..., -0.3906, -0.0850,  0.4043],
         [ 0.3457,  0.4980,  0.2041,  ...,  0.0576,  0.2148,  0.0630]],

        [[-0.0022, -0.0028, -0.0028,  ..., -0.0060,  0.0087, -0.0074],
         [ 0.0129,  0.1001,  0.2002,  ..., -0.2734,  0.4023,  0.2168],
         [ 0.3848,  0.5156,  0.6914,  ..., -0.5273,  0.2832,  0.9570],
         ...,
         [ 0.3203,  0.0518, -0.5703,  ...,  0.3145, -0.5781, -0.1191],
         [-0.0708,  0.5273, -0.2002,  ..., -0.3906, -0.0850,  0.4043],
         [ 0.3457,  0.4980,  0.2041,  ...,  0.0576,  0.2148,  0.0630]],

        [[-0.0022, -0.0028, -0.0028,  ..., -0.0060,  0.0087, -0.0074],
         [ 0.0129,  0.1001,  0.2002,  ..., -0.2734,  0.4023,  0.2168],
         [ 0.3848,  0.5156,  0.6914,  ..., -0.5273,  0.2832,  0.9570],
         ...,
         [ 0.3203,  0.0518, -0.5703,  ...,  0.3145, -0.5781, -0.1191],
         [-0.0708,  0.5273, -0.2002,  ..., -0.3906, -0.0850,  0.4043],
         [ 0.3457,  0.4980,  0.2041,  ...,  0.0576,  0.2148,  0.0630]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[-1.5015e-02, -4.2725e-03, -8.5831e-04,  ..., -5.9375e-01,
          -1.5723e-01,  2.8076e-02],
         [-1.6641e+00,  2.8125e-01, -1.2188e+00,  ...,  2.2031e+00,
           6.6797e-01,  4.2188e-01],
         [-1.6406e-01,  1.0391e+00, -1.2695e-01,  ..., -1.5442e-02,
           2.4531e+00,  1.2109e+00],
         ...,
         [-7.9297e-01,  5.0000e-01, -4.8828e-01,  ...,  1.3516e+00,
           2.1719e+00, -9.2188e-01],
         [ 4.7266e-01, -1.6357e-02, -4.2578e-01,  ...,  8.0078e-01,
           1.0938e+00,  2.0898e-01],
         [ 6.9922e-01,  4.7266e-01, -1.2578e+00,  ...,  1.9531e+00,
           3.2617e-01,  7.8516e-01]],

        [[-1.5015e-02, -4.2725e-03, -8.5831e-04,  ..., -5.9375e-01,
          -1.5723e-01,  2.8076e-02],
         [-1.6641e+00,  2.8125e-01, -1.2188e+00,  ...,  2.2031e+00,
           6.6797e-01,  4.2188e-01],
         [-1.6406e-01,  1.0391e+00, -1.2695e-01,  ..., -1.5442e-02,
           2.4531e+00,  1.2109e+00],
         ...,
         [-7.9297e-01,  5.0000e-01, -4.8828e-01,  ...,  1.3516e+00,
           2.1719e+00, -9.2188e-01],
         [ 4.7266e-01, -1.6357e-02, -4.2578e-01,  ...,  8.0078e-01,
           1.0938e+00,  2.0898e-01],
         [ 6.9922e-01,  4.7266e-01, -1.2578e+00,  ...,  1.9531e+00,
           3.2617e-01,  7.8516e-01]],

        [[-1.5015e-02, -4.2725e-03, -8.5831e-04,  ..., -5.9375e-01,
          -1.5723e-01,  2.8076e-02],
         [-1.6641e+00,  2.8125e-01, -1.2188e+00,  ...,  2.2031e+00,
           6.6797e-01,  4.2188e-01],
         [-1.6406e-01,  1.0391e+00, -1.2695e-01,  ..., -1.5442e-02,
           2.4531e+00,  1.2109e+00],
         ...,
         [-7.9297e-01,  5.0000e-01, -4.8828e-01,  ...,  1.3516e+00,
           2.1719e+00, -9.2188e-01],
         [ 4.7266e-01, -1.6357e-02, -4.2578e-01,  ...,  8.0078e-01,
           1.0938e+00,  2.0898e-01],
         [ 6.9922e-01,  4.7266e-01, -1.2578e+00,  ...,  1.9531e+00,
           3.2617e-01,  7.8516e-01]],

        ...,

        [[ 3.1433e-03,  8.4839e-03,  4.5967e-04,  ...,  1.1914e-01,
           5.0781e-02, -9.8267e-03],
         [ 1.4688e+00, -8.5547e-01, -1.2734e+00,  ..., -4.9805e-01,
          -4.1504e-02,  1.1016e+00],
         [-1.2422e+00, -9.2969e-01, -9.9219e-01,  ...,  5.4297e-01,
          -1.6328e+00,  5.5078e-01],
         ...,
         [-3.9062e-02, -6.8750e-01, -1.0391e+00,  ...,  3.6523e-01,
           3.0469e-01,  1.6172e+00],
         [-9.2188e-01,  9.9609e-02, -2.0898e-01,  ...,  7.7344e-01,
          -9.6875e-01,  2.2188e+00],
         [-9.2188e-01, -6.5625e-01, -1.4688e+00,  ..., -6.6016e-01,
          -2.5391e-01,  9.0234e-01]],

        [[ 3.1433e-03,  8.4839e-03,  4.5967e-04,  ...,  1.1914e-01,
           5.0781e-02, -9.8267e-03],
         [ 1.4688e+00, -8.5547e-01, -1.2734e+00,  ..., -4.9805e-01,
          -4.1504e-02,  1.1016e+00],
         [-1.2422e+00, -9.2969e-01, -9.9219e-01,  ...,  5.4297e-01,
          -1.6328e+00,  5.5078e-01],
         ...,
         [-3.9062e-02, -6.8750e-01, -1.0391e+00,  ...,  3.6523e-01,
           3.0469e-01,  1.6172e+00],
         [-9.2188e-01,  9.9609e-02, -2.0898e-01,  ...,  7.7344e-01,
          -9.6875e-01,  2.2188e+00],
         [-9.2188e-01, -6.5625e-01, -1.4688e+00,  ..., -6.6016e-01,
          -2.5391e-01,  9.0234e-01]],

        [[ 3.1433e-03,  8.4839e-03,  4.5967e-04,  ...,  1.1914e-01,
           5.0781e-02, -9.8267e-03],
         [ 1.4688e+00, -8.5547e-01, -1.2734e+00,  ..., -4.9805e-01,
          -4.1504e-02,  1.1016e+00],
         [-1.2422e+00, -9.2969e-01, -9.9219e-01,  ...,  5.4297e-01,
          -1.6328e+00,  5.5078e-01],
         ...,
         [-3.9062e-02, -6.8750e-01, -1.0391e+00,  ...,  3.6523e-01,
           3.0469e-01,  1.6172e+00],
         [-9.2188e-01,  9.9609e-02, -2.0898e-01,  ...,  7.7344e-01,
          -9.6875e-01,  2.2188e+00],
         [-9.2188e-01, -6.5625e-01, -1.4688e+00,  ..., -6.6016e-01,
          -2.5391e-01,  9.0234e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-2.3651e-03, -5.9204e-03,  4.3030e-03,  ...,  5.5313e-04,
           9.9182e-04, -5.1880e-03],
         [ 4.3750e-01,  5.5859e-01,  6.6016e-01,  ...,  1.4160e-01,
          -8.3594e-01, -5.1514e-02],
         [ 5.1953e-01, -4.2725e-02,  1.4062e-01,  ..., -4.8828e-01,
           2.5586e-01,  2.6367e-01],
         ...,
         [-2.7344e-01,  3.1445e-01, -3.1982e-02,  ...,  1.6797e-01,
          -4.8047e-01,  9.5312e-01],
         [ 3.2617e-01,  4.5703e-01, -5.0391e-01,  ..., -3.3398e-01,
           1.3379e-01,  5.8203e-01],
         [-5.2734e-02,  7.9688e-01, -9.8145e-02,  ..., -4.6875e-02,
          -8.3203e-01,  1.7676e-01]],

        [[-2.3651e-03, -5.9204e-03,  4.3030e-03,  ...,  5.5313e-04,
           9.9182e-04, -5.1880e-03],
         [ 4.3750e-01,  5.5859e-01,  6.6016e-01,  ...,  1.4160e-01,
          -8.3594e-01, -5.1514e-02],
         [ 5.1953e-01, -4.2725e-02,  1.4062e-01,  ..., -4.8828e-01,
           2.5586e-01,  2.6367e-01],
         ...,
         [-2.7344e-01,  3.1445e-01, -3.1982e-02,  ...,  1.6797e-01,
          -4.8047e-01,  9.5312e-01],
         [ 3.2617e-01,  4.5703e-01, -5.0391e-01,  ..., -3.3398e-01,
           1.3379e-01,  5.8203e-01],
         [-5.2734e-02,  7.9688e-01, -9.8145e-02,  ..., -4.6875e-02,
          -8.3203e-01,  1.7676e-01]],

        [[-2.3651e-03, -5.9204e-03,  4.3030e-03,  ...,  5.5313e-04,
           9.9182e-04, -5.1880e-03],
         [ 4.3750e-01,  5.5859e-01,  6.6016e-01,  ...,  1.4160e-01,
          -8.3594e-01, -5.1514e-02],
         [ 5.1953e-01, -4.2725e-02,  1.4062e-01,  ..., -4.8828e-01,
           2.5586e-01,  2.6367e-01],
         ...,
         [-2.7344e-01,  3.1445e-01, -3.1982e-02,  ...,  1.6797e-01,
          -4.8047e-01,  9.5312e-01],
         [ 3.2617e-01,  4.5703e-01, -5.0391e-01,  ..., -3.3398e-01,
           1.3379e-01,  5.8203e-01],
         [-5.2734e-02,  7.9688e-01, -9.8145e-02,  ..., -4.6875e-02,
          -8.3203e-01,  1.7676e-01]],

        ...,

        [[-4.8065e-04,  9.7275e-04,  5.5176e-02,  ..., -3.0518e-02,
          -1.1108e-02,  2.5635e-02],
         [-2.1680e-01, -1.3750e+00, -3.3984e-01,  ...,  4.8438e-01,
          -1.5039e-01,  7.4219e-01],
         [-1.3965e-01, -8.9453e-01, -3.2617e-01,  ...,  3.6914e-01,
          -7.4609e-01,  1.6699e-01],
         ...,
         [-8.2031e-01, -1.6602e-01,  1.9043e-01,  ...,  1.2656e+00,
          -1.7188e-01,  1.3906e+00],
         [-8.9844e-01,  3.6328e-01,  1.3379e-01,  ...,  3.0469e-01,
          -3.5156e-01,  5.1172e-01],
         [ 1.1094e+00, -8.5156e-01, -4.6680e-01,  ..., -1.9629e-01,
           1.2598e-01, -5.4688e-01]],

        [[-4.8065e-04,  9.7275e-04,  5.5176e-02,  ..., -3.0518e-02,
          -1.1108e-02,  2.5635e-02],
         [-2.1680e-01, -1.3750e+00, -3.3984e-01,  ...,  4.8438e-01,
          -1.5039e-01,  7.4219e-01],
         [-1.3965e-01, -8.9453e-01, -3.2617e-01,  ...,  3.6914e-01,
          -7.4609e-01,  1.6699e-01],
         ...,
         [-8.2031e-01, -1.6602e-01,  1.9043e-01,  ...,  1.2656e+00,
          -1.7188e-01,  1.3906e+00],
         [-8.9844e-01,  3.6328e-01,  1.3379e-01,  ...,  3.0469e-01,
          -3.5156e-01,  5.1172e-01],
         [ 1.1094e+00, -8.5156e-01, -4.6680e-01,  ..., -1.9629e-01,
           1.2598e-01, -5.4688e-01]],

        [[-4.8065e-04,  9.7275e-04,  5.5176e-02,  ..., -3.0518e-02,
          -1.1108e-02,  2.5635e-02],
         [-2.1680e-01, -1.3750e+00, -3.3984e-01,  ...,  4.8438e-01,
          -1.5039e-01,  7.4219e-01],
         [-1.3965e-01, -8.9453e-01, -3.2617e-01,  ...,  3.6914e-01,
          -7.4609e-01,  1.6699e-01],
         ...,
         [-8.2031e-01, -1.6602e-01,  1.9043e-01,  ...,  1.2656e+00,
          -1.7188e-01,  1.3906e+00],
         [-8.9844e-01,  3.6328e-01,  1.3379e-01,  ...,  3.0469e-01,
          -3.5156e-01,  5.1172e-01],
         [ 1.1094e+00, -8.5156e-01, -4.6680e-01,  ..., -1.9629e-01,
           1.2598e-01, -5.4688e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 1.1780e-02, -3.5858e-03, -3.2196e-03,  ..., -3.6328e-01,
          -1.3733e-02, -4.6875e-02],
         [-1.5234e-01,  1.4062e-01, -9.8438e-01,  ..., -8.7500e-01,
           1.2969e+00,  1.2969e+00],
         [ 8.9062e-01, -1.4062e+00, -6.4062e-01,  ...,  1.8828e+00,
           2.2656e+00,  1.3438e+00],
         ...,
         [ 1.4062e+00, -1.9531e-02,  2.8125e-01,  ..., -1.6895e-01,
          -1.0156e+00, -1.1182e-01],
         [ 2.1484e-02,  2.8711e-01, -2.1875e-01,  ...,  1.9609e+00,
          -1.7031e+00, -1.3750e+00],
         [ 3.4766e-01,  8.7500e-01, -1.2266e+00,  ...,  1.3203e+00,
           2.4688e+00,  1.8516e+00]],

        [[ 1.1780e-02, -3.5858e-03, -3.2196e-03,  ..., -3.6328e-01,
          -1.3733e-02, -4.6875e-02],
         [-1.5234e-01,  1.4062e-01, -9.8438e-01,  ..., -8.7500e-01,
           1.2969e+00,  1.2969e+00],
         [ 8.9062e-01, -1.4062e+00, -6.4062e-01,  ...,  1.8828e+00,
           2.2656e+00,  1.3438e+00],
         ...,
         [ 1.4062e+00, -1.9531e-02,  2.8125e-01,  ..., -1.6895e-01,
          -1.0156e+00, -1.1182e-01],
         [ 2.1484e-02,  2.8711e-01, -2.1875e-01,  ...,  1.9609e+00,
          -1.7031e+00, -1.3750e+00],
         [ 3.4766e-01,  8.7500e-01, -1.2266e+00,  ...,  1.3203e+00,
           2.4688e+00,  1.8516e+00]],

        [[ 1.1780e-02, -3.5858e-03, -3.2196e-03,  ..., -3.6328e-01,
          -1.3733e-02, -4.6875e-02],
         [-1.5234e-01,  1.4062e-01, -9.8438e-01,  ..., -8.7500e-01,
           1.2969e+00,  1.2969e+00],
         [ 8.9062e-01, -1.4062e+00, -6.4062e-01,  ...,  1.8828e+00,
           2.2656e+00,  1.3438e+00],
         ...,
         [ 1.4062e+00, -1.9531e-02,  2.8125e-01,  ..., -1.6895e-01,
          -1.0156e+00, -1.1182e-01],
         [ 2.1484e-02,  2.8711e-01, -2.1875e-01,  ...,  1.9609e+00,
          -1.7031e+00, -1.3750e+00],
         [ 3.4766e-01,  8.7500e-01, -1.2266e+00,  ...,  1.3203e+00,
           2.4688e+00,  1.8516e+00]],

        ...,

        [[-6.5613e-04,  1.4099e-02, -2.3956e-03,  ..., -2.5977e-01,
          -1.6406e-01, -4.7607e-03],
         [-1.0156e+00,  3.3203e-01,  2.2656e+00,  ...,  1.3516e+00,
          -9.7266e-01, -3.9844e-01],
         [-2.8750e+00,  5.5469e-01,  2.0625e+00,  ...,  1.0391e+00,
          -9.7168e-02, -1.7109e+00],
         ...,
         [-2.4375e+00,  1.5469e+00,  2.6172e-01,  ..., -9.8438e-01,
           8.1641e-01, -4.6680e-01],
         [-5.3906e-01,  5.1172e-01,  2.4219e-01,  ..., -1.3672e+00,
          -6.3672e-01,  6.7188e-01],
         [-8.7109e-01,  7.7344e-01,  8.2812e-01,  ..., -1.8652e-01,
           4.8633e-01, -2.4023e-01]],

        [[-6.5613e-04,  1.4099e-02, -2.3956e-03,  ..., -2.5977e-01,
          -1.6406e-01, -4.7607e-03],
         [-1.0156e+00,  3.3203e-01,  2.2656e+00,  ...,  1.3516e+00,
          -9.7266e-01, -3.9844e-01],
         [-2.8750e+00,  5.5469e-01,  2.0625e+00,  ...,  1.0391e+00,
          -9.7168e-02, -1.7109e+00],
         ...,
         [-2.4375e+00,  1.5469e+00,  2.6172e-01,  ..., -9.8438e-01,
           8.1641e-01, -4.6680e-01],
         [-5.3906e-01,  5.1172e-01,  2.4219e-01,  ..., -1.3672e+00,
          -6.3672e-01,  6.7188e-01],
         [-8.7109e-01,  7.7344e-01,  8.2812e-01,  ..., -1.8652e-01,
           4.8633e-01, -2.4023e-01]],

        [[-6.5613e-04,  1.4099e-02, -2.3956e-03,  ..., -2.5977e-01,
          -1.6406e-01, -4.7607e-03],
         [-1.0156e+00,  3.3203e-01,  2.2656e+00,  ...,  1.3516e+00,
          -9.7266e-01, -3.9844e-01],
         [-2.8750e+00,  5.5469e-01,  2.0625e+00,  ...,  1.0391e+00,
          -9.7168e-02, -1.7109e+00],
         ...,
         [-2.4375e+00,  1.5469e+00,  2.6172e-01,  ..., -9.8438e-01,
           8.1641e-01, -4.6680e-01],
         [-5.3906e-01,  5.1172e-01,  2.4219e-01,  ..., -1.3672e+00,
          -6.3672e-01,  6.7188e-01],
         [-8.7109e-01,  7.7344e-01,  8.2812e-01,  ..., -1.8652e-01,
           4.8633e-01, -2.4023e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 3.7842e-02, -7.5073e-03, -7.4768e-04,  ...,  8.5449e-02,
           3.0762e-02,  1.2695e-02],
         [-6.9141e-01,  4.3359e-01, -2.8320e-02,  ...,  3.4570e-01,
          -2.2949e-01, -1.6699e-01],
         [-4.7070e-01,  7.6172e-02, -4.6631e-02,  ...,  1.8750e-01,
          -2.4219e-01, -5.5859e-01],
         ...,
         [-2.9053e-02,  5.7422e-01,  4.2578e-01,  ..., -4.7461e-01,
          -7.3828e-01, -3.1641e-01],
         [-7.1875e-01,  4.4141e-01,  6.3672e-01,  ..., -1.6797e-01,
          -6.6406e-01,  3.8867e-01],
         [ 4.3945e-01,  5.2734e-01, -9.7656e-02,  ...,  3.7305e-01,
          -2.3340e-01, -9.1406e-01]],

        [[ 3.7842e-02, -7.5073e-03, -7.4768e-04,  ...,  8.5449e-02,
           3.0762e-02,  1.2695e-02],
         [-6.9141e-01,  4.3359e-01, -2.8320e-02,  ...,  3.4570e-01,
          -2.2949e-01, -1.6699e-01],
         [-4.7070e-01,  7.6172e-02, -4.6631e-02,  ...,  1.8750e-01,
          -2.4219e-01, -5.5859e-01],
         ...,
         [-2.9053e-02,  5.7422e-01,  4.2578e-01,  ..., -4.7461e-01,
          -7.3828e-01, -3.1641e-01],
         [-7.1875e-01,  4.4141e-01,  6.3672e-01,  ..., -1.6797e-01,
          -6.6406e-01,  3.8867e-01],
         [ 4.3945e-01,  5.2734e-01, -9.7656e-02,  ...,  3.7305e-01,
          -2.3340e-01, -9.1406e-01]],

        [[ 3.7842e-02, -7.5073e-03, -7.4768e-04,  ...,  8.5449e-02,
           3.0762e-02,  1.2695e-02],
         [-6.9141e-01,  4.3359e-01, -2.8320e-02,  ...,  3.4570e-01,
          -2.2949e-01, -1.6699e-01],
         [-4.7070e-01,  7.6172e-02, -4.6631e-02,  ...,  1.8750e-01,
          -2.4219e-01, -5.5859e-01],
         ...,
         [-2.9053e-02,  5.7422e-01,  4.2578e-01,  ..., -4.7461e-01,
          -7.3828e-01, -3.1641e-01],
         [-7.1875e-01,  4.4141e-01,  6.3672e-01,  ..., -1.6797e-01,
          -6.6406e-01,  3.8867e-01],
         [ 4.3945e-01,  5.2734e-01, -9.7656e-02,  ...,  3.7305e-01,
          -2.3340e-01, -9.1406e-01]],

        ...,

        [[ 1.3580e-03, -1.8082e-03, -3.6812e-04,  ...,  1.6861e-03,
          -4.3945e-03,  9.2163e-03],
         [ 6.5625e-01, -1.4954e-03, -7.1484e-01,  ...,  2.5000e-01,
           8.4375e-01,  1.0781e+00],
         [-1.8555e-01, -9.1309e-02, -5.0000e-01,  ..., -4.0430e-01,
           1.7188e-01,  5.8203e-01],
         ...,
         [ 3.4961e-01,  1.2158e-01,  7.4609e-01,  ..., -4.6875e-02,
           2.7930e-01, -2.1289e-01],
         [-7.7637e-02, -1.1963e-01,  9.2578e-01,  ...,  5.3906e-01,
          -6.4844e-01, -1.7285e-01],
         [-9.8828e-01, -4.9414e-01,  4.3945e-01,  ...,  2.1210e-03,
           6.1719e-01, -1.0205e-01]],

        [[ 1.3580e-03, -1.8082e-03, -3.6812e-04,  ...,  1.6861e-03,
          -4.3945e-03,  9.2163e-03],
         [ 6.5625e-01, -1.4954e-03, -7.1484e-01,  ...,  2.5000e-01,
           8.4375e-01,  1.0781e+00],
         [-1.8555e-01, -9.1309e-02, -5.0000e-01,  ..., -4.0430e-01,
           1.7188e-01,  5.8203e-01],
         ...,
         [ 3.4961e-01,  1.2158e-01,  7.4609e-01,  ..., -4.6875e-02,
           2.7930e-01, -2.1289e-01],
         [-7.7637e-02, -1.1963e-01,  9.2578e-01,  ...,  5.3906e-01,
          -6.4844e-01, -1.7285e-01],
         [-9.8828e-01, -4.9414e-01,  4.3945e-01,  ...,  2.1210e-03,
           6.1719e-01, -1.0205e-01]],

        [[ 1.3580e-03, -1.8082e-03, -3.6812e-04,  ...,  1.6861e-03,
          -4.3945e-03,  9.2163e-03],
         [ 6.5625e-01, -1.4954e-03, -7.1484e-01,  ...,  2.5000e-01,
           8.4375e-01,  1.0781e+00],
         [-1.8555e-01, -9.1309e-02, -5.0000e-01,  ..., -4.0430e-01,
           1.7188e-01,  5.8203e-01],
         ...,
         [ 3.4961e-01,  1.2158e-01,  7.4609e-01,  ..., -4.6875e-02,
           2.7930e-01, -2.1289e-01],
         [-7.7637e-02, -1.1963e-01,  9.2578e-01,  ...,  5.3906e-01,
          -6.4844e-01, -1.7285e-01],
         [-9.8828e-01, -4.9414e-01,  4.3945e-01,  ...,  2.1210e-03,
           6.1719e-01, -1.0205e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-1.1719e-02,  1.8799e-02,  3.8605e-03,  ..., -1.9629e-01,
          -1.7090e-01, -1.7578e-01],
         [ 3.2344e+00, -3.0781e+00,  3.3203e-01,  ...,  5.3516e-01,
          -2.3125e+00,  9.1016e-01],
         [ 3.6914e-01,  1.4160e-01,  1.7031e+00,  ...,  9.1016e-01,
           3.7500e-01, -1.6602e-01],
         ...,
         [ 1.1328e-01, -4.3555e-01,  1.9922e+00,  ..., -1.1016e+00,
          -9.2969e-01,  2.3281e+00],
         [-1.2812e+00, -5.1953e-01,  6.0938e-01,  ...,  9.6191e-02,
          -8.7891e-01,  1.8594e+00],
         [-9.6094e-01,  1.2969e+00,  1.3203e+00,  ...,  1.0078e+00,
           2.4375e+00, -1.6719e+00]],

        [[-1.1719e-02,  1.8799e-02,  3.8605e-03,  ..., -1.9629e-01,
          -1.7090e-01, -1.7578e-01],
         [ 3.2344e+00, -3.0781e+00,  3.3203e-01,  ...,  5.3516e-01,
          -2.3125e+00,  9.1016e-01],
         [ 3.6914e-01,  1.4160e-01,  1.7031e+00,  ...,  9.1016e-01,
           3.7500e-01, -1.6602e-01],
         ...,
         [ 1.1328e-01, -4.3555e-01,  1.9922e+00,  ..., -1.1016e+00,
          -9.2969e-01,  2.3281e+00],
         [-1.2812e+00, -5.1953e-01,  6.0938e-01,  ...,  9.6191e-02,
          -8.7891e-01,  1.8594e+00],
         [-9.6094e-01,  1.2969e+00,  1.3203e+00,  ...,  1.0078e+00,
           2.4375e+00, -1.6719e+00]],

        [[-1.1719e-02,  1.8799e-02,  3.8605e-03,  ..., -1.9629e-01,
          -1.7090e-01, -1.7578e-01],
         [ 3.2344e+00, -3.0781e+00,  3.3203e-01,  ...,  5.3516e-01,
          -2.3125e+00,  9.1016e-01],
         [ 3.6914e-01,  1.4160e-01,  1.7031e+00,  ...,  9.1016e-01,
           3.7500e-01, -1.6602e-01],
         ...,
         [ 1.1328e-01, -4.3555e-01,  1.9922e+00,  ..., -1.1016e+00,
          -9.2969e-01,  2.3281e+00],
         [-1.2812e+00, -5.1953e-01,  6.0938e-01,  ...,  9.6191e-02,
          -8.7891e-01,  1.8594e+00],
         [-9.6094e-01,  1.2969e+00,  1.3203e+00,  ...,  1.0078e+00,
           2.4375e+00, -1.6719e+00]],

        ...,

        [[ 6.1798e-04,  1.2024e-02, -1.5076e-02,  ...,  1.9141e-01,
           4.7656e-01,  4.5703e-01],
         [-3.6406e+00, -1.2109e-01,  1.8203e+00,  ...,  1.6797e+00,
          -2.1719e+00,  1.7109e+00],
         [-1.5938e+00,  1.7383e-01,  1.0625e+00,  ...,  1.1250e+00,
          -4.1875e+00,  7.7637e-02],
         ...,
         [-2.0469e+00,  1.5859e+00, -5.2344e-01,  ..., -7.0703e-01,
          -3.7812e+00,  5.1514e-02],
         [ 3.9062e-01,  1.1719e+00,  5.7812e-01,  ..., -4.3750e-01,
          -3.6250e+00, -1.8047e+00],
         [-4.4727e-01,  1.0156e+00,  1.9141e-01,  ...,  2.3281e+00,
          -1.2969e+00,  8.4375e-01]],

        [[ 6.1798e-04,  1.2024e-02, -1.5076e-02,  ...,  1.9141e-01,
           4.7656e-01,  4.5703e-01],
         [-3.6406e+00, -1.2109e-01,  1.8203e+00,  ...,  1.6797e+00,
          -2.1719e+00,  1.7109e+00],
         [-1.5938e+00,  1.7383e-01,  1.0625e+00,  ...,  1.1250e+00,
          -4.1875e+00,  7.7637e-02],
         ...,
         [-2.0469e+00,  1.5859e+00, -5.2344e-01,  ..., -7.0703e-01,
          -3.7812e+00,  5.1514e-02],
         [ 3.9062e-01,  1.1719e+00,  5.7812e-01,  ..., -4.3750e-01,
          -3.6250e+00, -1.8047e+00],
         [-4.4727e-01,  1.0156e+00,  1.9141e-01,  ...,  2.3281e+00,
          -1.2969e+00,  8.4375e-01]],

        [[ 6.1798e-04,  1.2024e-02, -1.5076e-02,  ...,  1.9141e-01,
           4.7656e-01,  4.5703e-01],
         [-3.6406e+00, -1.2109e-01,  1.8203e+00,  ...,  1.6797e+00,
          -2.1719e+00,  1.7109e+00],
         [-1.5938e+00,  1.7383e-01,  1.0625e+00,  ...,  1.1250e+00,
          -4.1875e+00,  7.7637e-02],
         ...,
         [-2.0469e+00,  1.5859e+00, -5.2344e-01,  ..., -7.0703e-01,
          -3.7812e+00,  5.1514e-02],
         [ 3.9062e-01,  1.1719e+00,  5.7812e-01,  ..., -4.3750e-01,
          -3.6250e+00, -1.8047e+00],
         [-4.4727e-01,  1.0156e+00,  1.9141e-01,  ...,  2.3281e+00,
          -1.2969e+00,  8.4375e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-3.8300e-03,  1.1749e-03, -5.0964e-03,  ...,  1.0193e-02,
           1.3489e-02, -1.4221e-02],
         [ 2.2559e-01, -1.7212e-02, -2.6562e-01,  ...,  2.5586e-01,
           1.8457e-01,  7.3828e-01],
         [-4.7656e-01,  3.0664e-01,  2.2559e-01,  ..., -1.1084e-01,
          -3.3594e-01, -2.4023e-01],
         ...,
         [ 7.6562e-01,  3.4766e-01, -1.2031e+00,  ..., -3.7891e-01,
          -2.8564e-02, -1.4453e-01],
         [ 2.3535e-01,  3.3789e-01, -4.4922e-01,  ..., -8.3008e-02,
          -3.1641e-01, -2.4707e-01],
         [ 1.1816e-01, -5.2344e-01,  7.5391e-01,  ..., -2.3047e-01,
           4.6680e-01,  3.7305e-01]],

        [[-3.8300e-03,  1.1749e-03, -5.0964e-03,  ...,  1.0193e-02,
           1.3489e-02, -1.4221e-02],
         [ 2.2559e-01, -1.7212e-02, -2.6562e-01,  ...,  2.5586e-01,
           1.8457e-01,  7.3828e-01],
         [-4.7656e-01,  3.0664e-01,  2.2559e-01,  ..., -1.1084e-01,
          -3.3594e-01, -2.4023e-01],
         ...,
         [ 7.6562e-01,  3.4766e-01, -1.2031e+00,  ..., -3.7891e-01,
          -2.8564e-02, -1.4453e-01],
         [ 2.3535e-01,  3.3789e-01, -4.4922e-01,  ..., -8.3008e-02,
          -3.1641e-01, -2.4707e-01],
         [ 1.1816e-01, -5.2344e-01,  7.5391e-01,  ..., -2.3047e-01,
           4.6680e-01,  3.7305e-01]],

        [[-3.8300e-03,  1.1749e-03, -5.0964e-03,  ...,  1.0193e-02,
           1.3489e-02, -1.4221e-02],
         [ 2.2559e-01, -1.7212e-02, -2.6562e-01,  ...,  2.5586e-01,
           1.8457e-01,  7.3828e-01],
         [-4.7656e-01,  3.0664e-01,  2.2559e-01,  ..., -1.1084e-01,
          -3.3594e-01, -2.4023e-01],
         ...,
         [ 7.6562e-01,  3.4766e-01, -1.2031e+00,  ..., -3.7891e-01,
          -2.8564e-02, -1.4453e-01],
         [ 2.3535e-01,  3.3789e-01, -4.4922e-01,  ..., -8.3008e-02,
          -3.1641e-01, -2.4707e-01],
         [ 1.1816e-01, -5.2344e-01,  7.5391e-01,  ..., -2.3047e-01,
           4.6680e-01,  3.7305e-01]],

        ...,

        [[-8.8501e-03, -3.8452e-03, -2.5940e-03,  ...,  2.5635e-03,
           3.6469e-03, -3.3875e-03],
         [ 1.6504e-01, -4.8242e-01, -6.4453e-02,  ..., -2.0508e-01,
          -2.4902e-01, -2.7344e-01],
         [-7.1484e-01,  8.3203e-01,  2.2656e-01,  ..., -7.3242e-02,
          -9.1797e-01,  1.1016e+00],
         ...,
         [-3.3984e-01,  3.8867e-01, -4.2188e-01,  ..., -2.5000e-01,
          -3.8867e-01, -3.2422e-01],
         [-1.3965e-01,  4.3750e-01, -3.8281e-01,  ...,  2.6758e-01,
           6.6895e-02, -2.2168e-01],
         [ 1.0156e+00, -1.1328e-01, -8.1250e-01,  ..., -5.2979e-02,
           1.0205e-01,  6.6406e-01]],

        [[-8.8501e-03, -3.8452e-03, -2.5940e-03,  ...,  2.5635e-03,
           3.6469e-03, -3.3875e-03],
         [ 1.6504e-01, -4.8242e-01, -6.4453e-02,  ..., -2.0508e-01,
          -2.4902e-01, -2.7344e-01],
         [-7.1484e-01,  8.3203e-01,  2.2656e-01,  ..., -7.3242e-02,
          -9.1797e-01,  1.1016e+00],
         ...,
         [-3.3984e-01,  3.8867e-01, -4.2188e-01,  ..., -2.5000e-01,
          -3.8867e-01, -3.2422e-01],
         [-1.3965e-01,  4.3750e-01, -3.8281e-01,  ...,  2.6758e-01,
           6.6895e-02, -2.2168e-01],
         [ 1.0156e+00, -1.1328e-01, -8.1250e-01,  ..., -5.2979e-02,
           1.0205e-01,  6.6406e-01]],

        [[-8.8501e-03, -3.8452e-03, -2.5940e-03,  ...,  2.5635e-03,
           3.6469e-03, -3.3875e-03],
         [ 1.6504e-01, -4.8242e-01, -6.4453e-02,  ..., -2.0508e-01,
          -2.4902e-01, -2.7344e-01],
         [-7.1484e-01,  8.3203e-01,  2.2656e-01,  ..., -7.3242e-02,
          -9.1797e-01,  1.1016e+00],
         ...,
         [-3.3984e-01,  3.8867e-01, -4.2188e-01,  ..., -2.5000e-01,
          -3.8867e-01, -3.2422e-01],
         [-1.3965e-01,  4.3750e-01, -3.8281e-01,  ...,  2.6758e-01,
           6.6895e-02, -2.2168e-01],
         [ 1.0156e+00, -1.1328e-01, -8.1250e-01,  ..., -5.2979e-02,
           1.0205e-01,  6.6406e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 3.0060e-03,  8.0872e-04, -1.0437e-02,  ..., -2.7344e-02,
          -3.0518e-03, -2.2461e-01],
         [-1.8594e+00,  2.4688e+00,  2.5000e+00,  ...,  1.6328e+00,
          -5.8350e-02,  1.6328e+00],
         [ 1.2500e+00, -1.9336e-01,  1.2656e+00,  ...,  1.1562e+00,
          -5.8203e-01,  1.1562e+00],
         ...,
         [-4.9609e-01, -1.7188e+00, -6.1328e-01,  ...,  2.2656e+00,
          -6.4844e-01, -3.7500e-01],
         [-3.4570e-01, -4.0820e-01, -6.8750e-01,  ...,  2.1094e+00,
           9.8145e-02, -2.5000e+00],
         [ 1.5625e+00, -4.1797e-01,  9.7266e-01,  ...,  3.0469e-01,
          -7.7344e-01, -5.6641e-01]],

        [[ 3.0060e-03,  8.0872e-04, -1.0437e-02,  ..., -2.7344e-02,
          -3.0518e-03, -2.2461e-01],
         [-1.8594e+00,  2.4688e+00,  2.5000e+00,  ...,  1.6328e+00,
          -5.8350e-02,  1.6328e+00],
         [ 1.2500e+00, -1.9336e-01,  1.2656e+00,  ...,  1.1562e+00,
          -5.8203e-01,  1.1562e+00],
         ...,
         [-4.9609e-01, -1.7188e+00, -6.1328e-01,  ...,  2.2656e+00,
          -6.4844e-01, -3.7500e-01],
         [-3.4570e-01, -4.0820e-01, -6.8750e-01,  ...,  2.1094e+00,
           9.8145e-02, -2.5000e+00],
         [ 1.5625e+00, -4.1797e-01,  9.7266e-01,  ...,  3.0469e-01,
          -7.7344e-01, -5.6641e-01]],

        [[ 3.0060e-03,  8.0872e-04, -1.0437e-02,  ..., -2.7344e-02,
          -3.0518e-03, -2.2461e-01],
         [-1.8594e+00,  2.4688e+00,  2.5000e+00,  ...,  1.6328e+00,
          -5.8350e-02,  1.6328e+00],
         [ 1.2500e+00, -1.9336e-01,  1.2656e+00,  ...,  1.1562e+00,
          -5.8203e-01,  1.1562e+00],
         ...,
         [-4.9609e-01, -1.7188e+00, -6.1328e-01,  ...,  2.2656e+00,
          -6.4844e-01, -3.7500e-01],
         [-3.4570e-01, -4.0820e-01, -6.8750e-01,  ...,  2.1094e+00,
           9.8145e-02, -2.5000e+00],
         [ 1.5625e+00, -4.1797e-01,  9.7266e-01,  ...,  3.0469e-01,
          -7.7344e-01, -5.6641e-01]],

        ...,

        [[-8.2779e-04,  5.0964e-03, -1.3123e-03,  ...,  2.4609e-01,
           7.2656e-01,  5.5078e-01],
         [-1.8203e+00,  7.3047e-01, -2.6875e+00,  ...,  4.5312e-01,
          -3.4219e+00,  3.5156e-01],
         [ 1.9824e-01,  3.9453e-01, -7.2266e-01,  ...,  5.5859e-01,
          -6.7500e+00,  4.4141e-01],
         ...,
         [-1.0469e+00, -7.5391e-01, -6.4062e-01,  ...,  2.3438e+00,
          -4.1875e+00,  2.3438e+00],
         [ 9.7656e-02,  8.2812e-01, -1.1572e-01,  ..., -1.5332e-01,
          -6.5312e+00,  3.3594e+00],
         [ 8.3984e-01,  3.0469e-01, -1.4531e+00,  ..., -1.3359e+00,
          -2.6719e+00,  3.0000e+00]],

        [[-8.2779e-04,  5.0964e-03, -1.3123e-03,  ...,  2.4609e-01,
           7.2656e-01,  5.5078e-01],
         [-1.8203e+00,  7.3047e-01, -2.6875e+00,  ...,  4.5312e-01,
          -3.4219e+00,  3.5156e-01],
         [ 1.9824e-01,  3.9453e-01, -7.2266e-01,  ...,  5.5859e-01,
          -6.7500e+00,  4.4141e-01],
         ...,
         [-1.0469e+00, -7.5391e-01, -6.4062e-01,  ...,  2.3438e+00,
          -4.1875e+00,  2.3438e+00],
         [ 9.7656e-02,  8.2812e-01, -1.1572e-01,  ..., -1.5332e-01,
          -6.5312e+00,  3.3594e+00],
         [ 8.3984e-01,  3.0469e-01, -1.4531e+00,  ..., -1.3359e+00,
          -2.6719e+00,  3.0000e+00]],

        [[-8.2779e-04,  5.0964e-03, -1.3123e-03,  ...,  2.4609e-01,
           7.2656e-01,  5.5078e-01],
         [-1.8203e+00,  7.3047e-01, -2.6875e+00,  ...,  4.5312e-01,
          -3.4219e+00,  3.5156e-01],
         [ 1.9824e-01,  3.9453e-01, -7.2266e-01,  ...,  5.5859e-01,
          -6.7500e+00,  4.4141e-01],
         ...,
         [-1.0469e+00, -7.5391e-01, -6.4062e-01,  ...,  2.3438e+00,
          -4.1875e+00,  2.3438e+00],
         [ 9.7656e-02,  8.2812e-01, -1.1572e-01,  ..., -1.5332e-01,
          -6.5312e+00,  3.3594e+00],
         [ 8.3984e-01,  3.0469e-01, -1.4531e+00,  ..., -1.3359e+00,
          -2.6719e+00,  3.0000e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-1.3046e-03,  9.9487e-03, -1.2756e-02,  ..., -1.3000e-02,
           3.1281e-03, -3.1891e-03],
         [ 3.8672e-01, -2.0215e-01, -6.1035e-02,  ...,  8.5449e-02,
           4.7656e-01,  5.9570e-02],
         [ 1.0391e+00, -6.9336e-02, -2.5391e-01,  ..., -3.4375e-01,
          -4.6387e-02, -2.8906e-01],
         ...,
         [-6.8359e-01,  2.4902e-01,  1.6895e-01,  ..., -1.6016e-01,
          -2.2559e-01, -3.7305e-01],
         [-1.4844e-01,  6.0547e-01, -4.9805e-01,  ...,  9.2969e-01,
           1.5918e-01,  2.5977e-01],
         [ 1.1426e-01, -6.3965e-02, -1.1963e-01,  ..., -1.8359e-01,
          -4.8633e-01, -4.6680e-01]],

        [[-1.3046e-03,  9.9487e-03, -1.2756e-02,  ..., -1.3000e-02,
           3.1281e-03, -3.1891e-03],
         [ 3.8672e-01, -2.0215e-01, -6.1035e-02,  ...,  8.5449e-02,
           4.7656e-01,  5.9570e-02],
         [ 1.0391e+00, -6.9336e-02, -2.5391e-01,  ..., -3.4375e-01,
          -4.6387e-02, -2.8906e-01],
         ...,
         [-6.8359e-01,  2.4902e-01,  1.6895e-01,  ..., -1.6016e-01,
          -2.2559e-01, -3.7305e-01],
         [-1.4844e-01,  6.0547e-01, -4.9805e-01,  ...,  9.2969e-01,
           1.5918e-01,  2.5977e-01],
         [ 1.1426e-01, -6.3965e-02, -1.1963e-01,  ..., -1.8359e-01,
          -4.8633e-01, -4.6680e-01]],

        [[-1.3046e-03,  9.9487e-03, -1.2756e-02,  ..., -1.3000e-02,
           3.1281e-03, -3.1891e-03],
         [ 3.8672e-01, -2.0215e-01, -6.1035e-02,  ...,  8.5449e-02,
           4.7656e-01,  5.9570e-02],
         [ 1.0391e+00, -6.9336e-02, -2.5391e-01,  ..., -3.4375e-01,
          -4.6387e-02, -2.8906e-01],
         ...,
         [-6.8359e-01,  2.4902e-01,  1.6895e-01,  ..., -1.6016e-01,
          -2.2559e-01, -3.7305e-01],
         [-1.4844e-01,  6.0547e-01, -4.9805e-01,  ...,  9.2969e-01,
           1.5918e-01,  2.5977e-01],
         [ 1.1426e-01, -6.3965e-02, -1.1963e-01,  ..., -1.8359e-01,
          -4.8633e-01, -4.6680e-01]],

        ...,

        [[ 2.9785e-02, -4.5166e-02,  2.2949e-02,  ...,  2.6001e-02,
           3.2715e-02,  8.3618e-03],
         [-6.6016e-01, -8.2520e-02, -2.0410e-01,  ...,  9.5825e-03,
          -3.3008e-01, -7.3730e-02],
         [-8.3203e-01,  3.6719e-01, -1.0391e+00,  ..., -5.0781e-01,
           1.6211e-01, -2.3242e-01],
         ...,
         [-6.3672e-01,  3.8867e-01, -8.5547e-01,  ...,  1.1328e+00,
          -6.0547e-01,  4.8438e-01],
         [-2.8906e-01,  3.8818e-02, -1.3828e+00,  ...,  3.4766e-01,
          -3.5938e-01, -7.3242e-02],
         [-6.9141e-01, -2.5781e-01, -4.4727e-01,  ..., -5.3125e-01,
          -1.4746e-01,  2.6245e-02]],

        [[ 2.9785e-02, -4.5166e-02,  2.2949e-02,  ...,  2.6001e-02,
           3.2715e-02,  8.3618e-03],
         [-6.6016e-01, -8.2520e-02, -2.0410e-01,  ...,  9.5825e-03,
          -3.3008e-01, -7.3730e-02],
         [-8.3203e-01,  3.6719e-01, -1.0391e+00,  ..., -5.0781e-01,
           1.6211e-01, -2.3242e-01],
         ...,
         [-6.3672e-01,  3.8867e-01, -8.5547e-01,  ...,  1.1328e+00,
          -6.0547e-01,  4.8438e-01],
         [-2.8906e-01,  3.8818e-02, -1.3828e+00,  ...,  3.4766e-01,
          -3.5938e-01, -7.3242e-02],
         [-6.9141e-01, -2.5781e-01, -4.4727e-01,  ..., -5.3125e-01,
          -1.4746e-01,  2.6245e-02]],

        [[ 2.9785e-02, -4.5166e-02,  2.2949e-02,  ...,  2.6001e-02,
           3.2715e-02,  8.3618e-03],
         [-6.6016e-01, -8.2520e-02, -2.0410e-01,  ...,  9.5825e-03,
          -3.3008e-01, -7.3730e-02],
         [-8.3203e-01,  3.6719e-01, -1.0391e+00,  ..., -5.0781e-01,
           1.6211e-01, -2.3242e-01],
         ...,
         [-6.3672e-01,  3.8867e-01, -8.5547e-01,  ...,  1.1328e+00,
          -6.0547e-01,  4.8438e-01],
         [-2.8906e-01,  3.8818e-02, -1.3828e+00,  ...,  3.4766e-01,
          -3.5938e-01, -7.3242e-02],
         [-6.9141e-01, -2.5781e-01, -4.4727e-01,  ..., -5.3125e-01,
          -1.4746e-01,  2.6245e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-1.6602e-02,  7.8735e-03, -1.5869e-02,  ..., -4.2969e-01,
           3.4570e-01,  9.5215e-02],
         [-2.8281e+00, -1.3750e+00,  1.6641e+00,  ...,  1.5625e+00,
           3.0273e-02,  7.7637e-02],
         [ 5.7812e-01, -7.7344e-01,  1.5781e+00,  ...,  2.2969e+00,
           1.6641e+00, -5.8984e-01],
         ...,
         [ 1.9922e-01, -2.4844e+00,  3.3594e-01,  ...,  1.4453e+00,
          -1.4688e+00,  6.1719e-01],
         [ 2.2461e-02, -4.3359e-01, -5.1172e-01,  ...,  1.2734e+00,
          -8.7891e-01,  7.1094e-01],
         [ 1.3516e+00, -3.9844e-01,  9.4531e-01,  ...,  8.2031e-01,
           2.1289e-01, -2.2188e+00]],

        [[-1.6602e-02,  7.8735e-03, -1.5869e-02,  ..., -4.2969e-01,
           3.4570e-01,  9.5215e-02],
         [-2.8281e+00, -1.3750e+00,  1.6641e+00,  ...,  1.5625e+00,
           3.0273e-02,  7.7637e-02],
         [ 5.7812e-01, -7.7344e-01,  1.5781e+00,  ...,  2.2969e+00,
           1.6641e+00, -5.8984e-01],
         ...,
         [ 1.9922e-01, -2.4844e+00,  3.3594e-01,  ...,  1.4453e+00,
          -1.4688e+00,  6.1719e-01],
         [ 2.2461e-02, -4.3359e-01, -5.1172e-01,  ...,  1.2734e+00,
          -8.7891e-01,  7.1094e-01],
         [ 1.3516e+00, -3.9844e-01,  9.4531e-01,  ...,  8.2031e-01,
           2.1289e-01, -2.2188e+00]],

        [[-1.6602e-02,  7.8735e-03, -1.5869e-02,  ..., -4.2969e-01,
           3.4570e-01,  9.5215e-02],
         [-2.8281e+00, -1.3750e+00,  1.6641e+00,  ...,  1.5625e+00,
           3.0273e-02,  7.7637e-02],
         [ 5.7812e-01, -7.7344e-01,  1.5781e+00,  ...,  2.2969e+00,
           1.6641e+00, -5.8984e-01],
         ...,
         [ 1.9922e-01, -2.4844e+00,  3.3594e-01,  ...,  1.4453e+00,
          -1.4688e+00,  6.1719e-01],
         [ 2.2461e-02, -4.3359e-01, -5.1172e-01,  ...,  1.2734e+00,
          -8.7891e-01,  7.1094e-01],
         [ 1.3516e+00, -3.9844e-01,  9.4531e-01,  ...,  8.2031e-01,
           2.1289e-01, -2.2188e+00]],

        ...,

        [[-2.3315e-02, -1.9043e-02, -2.8687e-03,  ...,  7.2754e-02,
          -2.9219e+00, -4.1602e-01],
         [ 3.3203e-01, -5.1953e-01, -1.9922e-01,  ...,  5.7422e-01,
           8.7500e+00,  5.1875e+00],
         [-2.8516e-01, -3.2422e-01, -1.7578e-01,  ...,  8.8281e-01,
           8.1875e+00,  5.2500e+00],
         ...,
         [-1.0645e-01, -2.2461e-01,  6.9141e-01,  ...,  6.5625e-01,
           8.3125e+00, -1.0547e+00],
         [ 8.3008e-02,  2.9688e-01,  1.5430e-01,  ...,  9.1797e-01,
           5.9375e+00, -2.1094e+00],
         [-7.9688e-01,  5.8594e-01, -2.3340e-01,  ..., -1.5938e+00,
           7.2188e+00,  2.4688e+00]],

        [[-2.3315e-02, -1.9043e-02, -2.8687e-03,  ...,  7.2754e-02,
          -2.9219e+00, -4.1602e-01],
         [ 3.3203e-01, -5.1953e-01, -1.9922e-01,  ...,  5.7422e-01,
           8.7500e+00,  5.1875e+00],
         [-2.8516e-01, -3.2422e-01, -1.7578e-01,  ...,  8.8281e-01,
           8.1875e+00,  5.2500e+00],
         ...,
         [-1.0645e-01, -2.2461e-01,  6.9141e-01,  ...,  6.5625e-01,
           8.3125e+00, -1.0547e+00],
         [ 8.3008e-02,  2.9688e-01,  1.5430e-01,  ...,  9.1797e-01,
           5.9375e+00, -2.1094e+00],
         [-7.9688e-01,  5.8594e-01, -2.3340e-01,  ..., -1.5938e+00,
           7.2188e+00,  2.4688e+00]],

        [[-2.3315e-02, -1.9043e-02, -2.8687e-03,  ...,  7.2754e-02,
          -2.9219e+00, -4.1602e-01],
         [ 3.3203e-01, -5.1953e-01, -1.9922e-01,  ...,  5.7422e-01,
           8.7500e+00,  5.1875e+00],
         [-2.8516e-01, -3.2422e-01, -1.7578e-01,  ...,  8.8281e-01,
           8.1875e+00,  5.2500e+00],
         ...,
         [-1.0645e-01, -2.2461e-01,  6.9141e-01,  ...,  6.5625e-01,
           8.3125e+00, -1.0547e+00],
         [ 8.3008e-02,  2.9688e-01,  1.5430e-01,  ...,  9.1797e-01,
           5.9375e+00, -2.1094e+00],
         [-7.9688e-01,  5.8594e-01, -2.3340e-01,  ..., -1.5938e+00,
           7.2188e+00,  2.4688e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-2.7832e-02, -1.5869e-02, -4.3701e-02,  ...,  5.1270e-03,
           9.5215e-03,  1.0620e-02],
         [ 5.9375e-01, -4.9414e-01,  4.2969e-01,  ...,  1.3086e-01,
          -1.9434e-01,  3.2812e-01],
         [ 2.7734e-01, -1.6504e-01,  7.1484e-01,  ...,  1.1797e+00,
           2.8906e-01,  2.6245e-02],
         ...,
         [ 2.3633e-01, -1.5312e+00, -3.9258e-01,  ...,  2.0508e-01,
           4.6875e-01,  9.9121e-02],
         [ 1.9629e-01, -1.7500e+00, -7.6562e-01,  ...,  1.1875e+00,
           3.7305e-01,  8.2031e-01],
         [ 5.9766e-01, -5.2979e-02,  1.0791e-01,  ...,  8.1543e-02,
           7.0703e-01, -9.5215e-02]],

        [[-2.7832e-02, -1.5869e-02, -4.3701e-02,  ...,  5.1270e-03,
           9.5215e-03,  1.0620e-02],
         [ 5.9375e-01, -4.9414e-01,  4.2969e-01,  ...,  1.3086e-01,
          -1.9434e-01,  3.2812e-01],
         [ 2.7734e-01, -1.6504e-01,  7.1484e-01,  ...,  1.1797e+00,
           2.8906e-01,  2.6245e-02],
         ...,
         [ 2.3633e-01, -1.5312e+00, -3.9258e-01,  ...,  2.0508e-01,
           4.6875e-01,  9.9121e-02],
         [ 1.9629e-01, -1.7500e+00, -7.6562e-01,  ...,  1.1875e+00,
           3.7305e-01,  8.2031e-01],
         [ 5.9766e-01, -5.2979e-02,  1.0791e-01,  ...,  8.1543e-02,
           7.0703e-01, -9.5215e-02]],

        [[-2.7832e-02, -1.5869e-02, -4.3701e-02,  ...,  5.1270e-03,
           9.5215e-03,  1.0620e-02],
         [ 5.9375e-01, -4.9414e-01,  4.2969e-01,  ...,  1.3086e-01,
          -1.9434e-01,  3.2812e-01],
         [ 2.7734e-01, -1.6504e-01,  7.1484e-01,  ...,  1.1797e+00,
           2.8906e-01,  2.6245e-02],
         ...,
         [ 2.3633e-01, -1.5312e+00, -3.9258e-01,  ...,  2.0508e-01,
           4.6875e-01,  9.9121e-02],
         [ 1.9629e-01, -1.7500e+00, -7.6562e-01,  ...,  1.1875e+00,
           3.7305e-01,  8.2031e-01],
         [ 5.9766e-01, -5.2979e-02,  1.0791e-01,  ...,  8.1543e-02,
           7.0703e-01, -9.5215e-02]],

        ...,

        [[-2.3438e-02,  2.6978e-02,  1.3062e-02,  ...,  7.7438e-04,
           1.4954e-02,  1.2283e-03],
         [ 6.6406e-01, -8.2031e-01, -3.6523e-01,  ..., -5.2795e-03,
          -1.8359e-01,  2.6367e-01],
         [-1.0498e-01, -6.3281e-01, -3.3789e-01,  ..., -1.1250e+00,
          -7.9688e-01,  8.9844e-02],
         ...,
         [ 1.9688e+00, -5.5469e-01,  3.4375e-01,  ..., -3.0469e-01,
          -1.8359e+00, -1.1250e+00],
         [ 3.9648e-01,  2.4707e-01,  3.3203e-01,  ...,  4.6387e-02,
          -1.0703e+00, -6.7578e-01],
         [-2.1387e-01, -6.0938e-01, -9.7266e-01,  ..., -2.6758e-01,
          -9.7656e-02,  5.2734e-01]],

        [[-2.3438e-02,  2.6978e-02,  1.3062e-02,  ...,  7.7438e-04,
           1.4954e-02,  1.2283e-03],
         [ 6.6406e-01, -8.2031e-01, -3.6523e-01,  ..., -5.2795e-03,
          -1.8359e-01,  2.6367e-01],
         [-1.0498e-01, -6.3281e-01, -3.3789e-01,  ..., -1.1250e+00,
          -7.9688e-01,  8.9844e-02],
         ...,
         [ 1.9688e+00, -5.5469e-01,  3.4375e-01,  ..., -3.0469e-01,
          -1.8359e+00, -1.1250e+00],
         [ 3.9648e-01,  2.4707e-01,  3.3203e-01,  ...,  4.6387e-02,
          -1.0703e+00, -6.7578e-01],
         [-2.1387e-01, -6.0938e-01, -9.7266e-01,  ..., -2.6758e-01,
          -9.7656e-02,  5.2734e-01]],

        [[-2.3438e-02,  2.6978e-02,  1.3062e-02,  ...,  7.7438e-04,
           1.4954e-02,  1.2283e-03],
         [ 6.6406e-01, -8.2031e-01, -3.6523e-01,  ..., -5.2795e-03,
          -1.8359e-01,  2.6367e-01],
         [-1.0498e-01, -6.3281e-01, -3.3789e-01,  ..., -1.1250e+00,
          -7.9688e-01,  8.9844e-02],
         ...,
         [ 1.9688e+00, -5.5469e-01,  3.4375e-01,  ..., -3.0469e-01,
          -1.8359e+00, -1.1250e+00],
         [ 3.9648e-01,  2.4707e-01,  3.3203e-01,  ...,  4.6387e-02,
          -1.0703e+00, -6.7578e-01],
         [-2.1387e-01, -6.0938e-01, -9.7266e-01,  ..., -2.6758e-01,
          -9.7656e-02,  5.2734e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 7.9727e-04,  1.0925e-02, -3.7842e-03,  ..., -1.3184e-01,
          -2.9297e-01, -9.7168e-02],
         [ 2.0000e+00,  7.6172e-01, -2.5938e+00,  ...,  9.4238e-02,
           6.7188e-01, -9.1406e-01],
         [ 6.0156e-01, -1.2344e+00, -1.1250e+00,  ...,  8.0078e-02,
           1.8672e+00, -4.7461e-01],
         ...,
         [ 0.0000e+00, -3.4375e-01,  2.4023e-01,  ...,  1.5703e+00,
          -5.7031e-01, -5.5469e-01],
         [ 7.8125e-03,  3.3398e-01, -1.8555e-01,  ...,  4.0430e-01,
          -5.6250e-01, -9.7656e-01],
         [-2.0781e+00,  5.2734e-01, -1.5938e+00,  ..., -2.3594e+00,
          -2.4062e+00, -1.7734e+00]],

        [[ 7.9727e-04,  1.0925e-02, -3.7842e-03,  ..., -1.3184e-01,
          -2.9297e-01, -9.7168e-02],
         [ 2.0000e+00,  7.6172e-01, -2.5938e+00,  ...,  9.4238e-02,
           6.7188e-01, -9.1406e-01],
         [ 6.0156e-01, -1.2344e+00, -1.1250e+00,  ...,  8.0078e-02,
           1.8672e+00, -4.7461e-01],
         ...,
         [ 0.0000e+00, -3.4375e-01,  2.4023e-01,  ...,  1.5703e+00,
          -5.7031e-01, -5.5469e-01],
         [ 7.8125e-03,  3.3398e-01, -1.8555e-01,  ...,  4.0430e-01,
          -5.6250e-01, -9.7656e-01],
         [-2.0781e+00,  5.2734e-01, -1.5938e+00,  ..., -2.3594e+00,
          -2.4062e+00, -1.7734e+00]],

        [[ 7.9727e-04,  1.0925e-02, -3.7842e-03,  ..., -1.3184e-01,
          -2.9297e-01, -9.7168e-02],
         [ 2.0000e+00,  7.6172e-01, -2.5938e+00,  ...,  9.4238e-02,
           6.7188e-01, -9.1406e-01],
         [ 6.0156e-01, -1.2344e+00, -1.1250e+00,  ...,  8.0078e-02,
           1.8672e+00, -4.7461e-01],
         ...,
         [ 0.0000e+00, -3.4375e-01,  2.4023e-01,  ...,  1.5703e+00,
          -5.7031e-01, -5.5469e-01],
         [ 7.8125e-03,  3.3398e-01, -1.8555e-01,  ...,  4.0430e-01,
          -5.6250e-01, -9.7656e-01],
         [-2.0781e+00,  5.2734e-01, -1.5938e+00,  ..., -2.3594e+00,
          -2.4062e+00, -1.7734e+00]],

        ...,

        [[ 3.6469e-03,  1.0742e-02,  3.5706e-03,  ..., -2.5156e+00,
          -1.3672e-02,  1.0498e-01],
         [ 1.4219e+00,  1.2500e-01,  5.7812e-01,  ...,  8.2500e+00,
          -4.3945e-01, -1.3047e+00],
         [-5.0781e-01,  6.7969e-01, -1.4766e+00,  ...,  8.4375e+00,
           1.2422e+00, -1.2578e+00],
         ...,
         [-1.0703e+00,  1.0312e+00,  1.5430e-01,  ...,  7.1875e+00,
          -1.1094e+00,  7.3828e-01],
         [-1.3086e-01, -1.7383e-01,  3.3789e-01,  ...,  5.3750e+00,
          -2.3047e-01,  7.4609e-01],
         [-4.8242e-01, -2.8906e-01,  1.1250e+00,  ...,  6.9375e+00,
          -2.0625e+00, -1.1328e+00]],

        [[ 3.6469e-03,  1.0742e-02,  3.5706e-03,  ..., -2.5156e+00,
          -1.3672e-02,  1.0498e-01],
         [ 1.4219e+00,  1.2500e-01,  5.7812e-01,  ...,  8.2500e+00,
          -4.3945e-01, -1.3047e+00],
         [-5.0781e-01,  6.7969e-01, -1.4766e+00,  ...,  8.4375e+00,
           1.2422e+00, -1.2578e+00],
         ...,
         [-1.0703e+00,  1.0312e+00,  1.5430e-01,  ...,  7.1875e+00,
          -1.1094e+00,  7.3828e-01],
         [-1.3086e-01, -1.7383e-01,  3.3789e-01,  ...,  5.3750e+00,
          -2.3047e-01,  7.4609e-01],
         [-4.8242e-01, -2.8906e-01,  1.1250e+00,  ...,  6.9375e+00,
          -2.0625e+00, -1.1328e+00]],

        [[ 3.6469e-03,  1.0742e-02,  3.5706e-03,  ..., -2.5156e+00,
          -1.3672e-02,  1.0498e-01],
         [ 1.4219e+00,  1.2500e-01,  5.7812e-01,  ...,  8.2500e+00,
          -4.3945e-01, -1.3047e+00],
         [-5.0781e-01,  6.7969e-01, -1.4766e+00,  ...,  8.4375e+00,
           1.2422e+00, -1.2578e+00],
         ...,
         [-1.0703e+00,  1.0312e+00,  1.5430e-01,  ...,  7.1875e+00,
          -1.1094e+00,  7.3828e-01],
         [-1.3086e-01, -1.7383e-01,  3.3789e-01,  ...,  5.3750e+00,
          -2.3047e-01,  7.4609e-01],
         [-4.8242e-01, -2.8906e-01,  1.1250e+00,  ...,  6.9375e+00,
          -2.0625e+00, -1.1328e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-1.9287e-02,  2.2461e-02, -3.5645e-02,  ..., -2.5391e-02,
          -1.1841e-02, -1.5564e-02],
         [ 4.3359e-01,  2.4805e-01, -4.0234e-01,  ...,  2.8125e-01,
          -6.6016e-01, -6.8750e-01],
         [-2.2949e-02, -7.6660e-02, -1.0547e-01,  ..., -5.8594e-01,
          -3.2617e-01,  8.6426e-02],
         ...,
         [-3.8867e-01, -6.7969e-01,  6.2500e-01,  ..., -7.9102e-02,
          -1.2031e+00, -1.2207e-01],
         [-1.3770e-01, -7.2656e-01, -1.7456e-02,  ...,  1.7676e-01,
          -6.8359e-01,  4.4336e-01],
         [ 4.5117e-01,  1.2266e+00,  1.1328e-01,  ...,  8.0078e-01,
          -2.0410e-01, -1.1475e-01]],

        [[-1.9287e-02,  2.2461e-02, -3.5645e-02,  ..., -2.5391e-02,
          -1.1841e-02, -1.5564e-02],
         [ 4.3359e-01,  2.4805e-01, -4.0234e-01,  ...,  2.8125e-01,
          -6.6016e-01, -6.8750e-01],
         [-2.2949e-02, -7.6660e-02, -1.0547e-01,  ..., -5.8594e-01,
          -3.2617e-01,  8.6426e-02],
         ...,
         [-3.8867e-01, -6.7969e-01,  6.2500e-01,  ..., -7.9102e-02,
          -1.2031e+00, -1.2207e-01],
         [-1.3770e-01, -7.2656e-01, -1.7456e-02,  ...,  1.7676e-01,
          -6.8359e-01,  4.4336e-01],
         [ 4.5117e-01,  1.2266e+00,  1.1328e-01,  ...,  8.0078e-01,
          -2.0410e-01, -1.1475e-01]],

        [[-1.9287e-02,  2.2461e-02, -3.5645e-02,  ..., -2.5391e-02,
          -1.1841e-02, -1.5564e-02],
         [ 4.3359e-01,  2.4805e-01, -4.0234e-01,  ...,  2.8125e-01,
          -6.6016e-01, -6.8750e-01],
         [-2.2949e-02, -7.6660e-02, -1.0547e-01,  ..., -5.8594e-01,
          -3.2617e-01,  8.6426e-02],
         ...,
         [-3.8867e-01, -6.7969e-01,  6.2500e-01,  ..., -7.9102e-02,
          -1.2031e+00, -1.2207e-01],
         [-1.3770e-01, -7.2656e-01, -1.7456e-02,  ...,  1.7676e-01,
          -6.8359e-01,  4.4336e-01],
         [ 4.5117e-01,  1.2266e+00,  1.1328e-01,  ...,  8.0078e-01,
          -2.0410e-01, -1.1475e-01]],

        ...,

        [[ 1.9653e-02,  5.0049e-03, -5.9509e-03,  ...,  3.7354e-02,
          -1.7090e-02,  1.6022e-03],
         [-8.8379e-02, -1.9238e-01,  2.8320e-01,  ...,  3.4570e-01,
          -5.9082e-02,  5.5908e-02],
         [ 3.1250e-01, -8.3594e-01,  3.4375e-01,  ...,  5.9326e-02,
           3.1250e-01, -3.3008e-01],
         ...,
         [-7.9688e-01,  5.7422e-01,  6.7871e-02,  ..., -2.0117e-01,
          -5.0391e-01,  8.3008e-02],
         [-5.7617e-02, -2.6733e-02, -6.6376e-04,  ..., -1.5332e-01,
          -2.6367e-01,  3.0859e-01],
         [-2.1875e-01,  3.9844e-01, -2.5000e-01,  ..., -8.3984e-01,
          -1.0156e+00,  1.5723e-01]],

        [[ 1.9653e-02,  5.0049e-03, -5.9509e-03,  ...,  3.7354e-02,
          -1.7090e-02,  1.6022e-03],
         [-8.8379e-02, -1.9238e-01,  2.8320e-01,  ...,  3.4570e-01,
          -5.9082e-02,  5.5908e-02],
         [ 3.1250e-01, -8.3594e-01,  3.4375e-01,  ...,  5.9326e-02,
           3.1250e-01, -3.3008e-01],
         ...,
         [-7.9688e-01,  5.7422e-01,  6.7871e-02,  ..., -2.0117e-01,
          -5.0391e-01,  8.3008e-02],
         [-5.7617e-02, -2.6733e-02, -6.6376e-04,  ..., -1.5332e-01,
          -2.6367e-01,  3.0859e-01],
         [-2.1875e-01,  3.9844e-01, -2.5000e-01,  ..., -8.3984e-01,
          -1.0156e+00,  1.5723e-01]],

        [[ 1.9653e-02,  5.0049e-03, -5.9509e-03,  ...,  3.7354e-02,
          -1.7090e-02,  1.6022e-03],
         [-8.8379e-02, -1.9238e-01,  2.8320e-01,  ...,  3.4570e-01,
          -5.9082e-02,  5.5908e-02],
         [ 3.1250e-01, -8.3594e-01,  3.4375e-01,  ...,  5.9326e-02,
           3.1250e-01, -3.3008e-01],
         ...,
         [-7.9688e-01,  5.7422e-01,  6.7871e-02,  ..., -2.0117e-01,
          -5.0391e-01,  8.3008e-02],
         [-5.7617e-02, -2.6733e-02, -6.6376e-04,  ..., -1.5332e-01,
          -2.6367e-01,  3.0859e-01],
         [-2.1875e-01,  3.9844e-01, -2.5000e-01,  ..., -8.3984e-01,
          -1.0156e+00,  1.5723e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-1.7700e-02, -1.2329e-02,  1.7456e-02,  ...,  7.7637e-02,
           5.0537e-02, -1.0107e-01],
         [ 2.0000e+00,  1.4551e-01,  2.0469e+00,  ..., -4.6289e-01,
          -6.8359e-01, -7.5391e-01],
         [ 1.6562e+00,  1.2734e+00, -2.5977e-01,  ...,  2.4512e-01,
          -2.8594e+00,  1.6562e+00],
         ...,
         [ 3.1445e-01,  6.0156e-01, -3.8086e-02,  ..., -1.6953e+00,
           5.4375e+00, -1.1172e+00],
         [-8.0469e-01, -4.6680e-01,  3.3398e-01,  ..., -3.8672e-01,
           5.1172e-01, -2.6250e+00],
         [-5.6641e-01, -4.4531e-01,  1.4219e+00,  ..., -6.8359e-01,
           9.9121e-02,  1.1562e+00]],

        [[-1.7700e-02, -1.2329e-02,  1.7456e-02,  ...,  7.7637e-02,
           5.0537e-02, -1.0107e-01],
         [ 2.0000e+00,  1.4551e-01,  2.0469e+00,  ..., -4.6289e-01,
          -6.8359e-01, -7.5391e-01],
         [ 1.6562e+00,  1.2734e+00, -2.5977e-01,  ...,  2.4512e-01,
          -2.8594e+00,  1.6562e+00],
         ...,
         [ 3.1445e-01,  6.0156e-01, -3.8086e-02,  ..., -1.6953e+00,
           5.4375e+00, -1.1172e+00],
         [-8.0469e-01, -4.6680e-01,  3.3398e-01,  ..., -3.8672e-01,
           5.1172e-01, -2.6250e+00],
         [-5.6641e-01, -4.4531e-01,  1.4219e+00,  ..., -6.8359e-01,
           9.9121e-02,  1.1562e+00]],

        [[-1.7700e-02, -1.2329e-02,  1.7456e-02,  ...,  7.7637e-02,
           5.0537e-02, -1.0107e-01],
         [ 2.0000e+00,  1.4551e-01,  2.0469e+00,  ..., -4.6289e-01,
          -6.8359e-01, -7.5391e-01],
         [ 1.6562e+00,  1.2734e+00, -2.5977e-01,  ...,  2.4512e-01,
          -2.8594e+00,  1.6562e+00],
         ...,
         [ 3.1445e-01,  6.0156e-01, -3.8086e-02,  ..., -1.6953e+00,
           5.4375e+00, -1.1172e+00],
         [-8.0469e-01, -4.6680e-01,  3.3398e-01,  ..., -3.8672e-01,
           5.1172e-01, -2.6250e+00],
         [-5.6641e-01, -4.4531e-01,  1.4219e+00,  ..., -6.8359e-01,
           9.9121e-02,  1.1562e+00]],

        ...,

        [[-3.5706e-03,  6.9275e-03,  1.4221e-02,  ..., -1.7285e-01,
           1.7773e-01, -1.6968e-02],
         [ 2.4062e+00,  1.9062e+00,  2.1777e-01,  ..., -1.5156e+00,
          -1.3750e+00, -9.2578e-01],
         [ 9.4922e-01,  1.1484e+00,  3.0273e-01,  ..., -3.0938e+00,
          -2.1250e+00, -2.0938e+00],
         ...,
         [ 1.9531e+00, -1.5820e-01, -4.6875e-01,  ...,  7.3438e-01,
           7.0703e-01,  1.8984e+00],
         [-2.0508e-01,  1.2500e-01, -6.8359e-01,  ..., -7.4707e-02,
           5.0391e-01,  2.5469e+00],
         [-1.3125e+00, -2.1406e+00, -1.0391e+00,  ..., -6.2500e-01,
          -1.4453e+00,  1.4453e+00]],

        [[-3.5706e-03,  6.9275e-03,  1.4221e-02,  ..., -1.7285e-01,
           1.7773e-01, -1.6968e-02],
         [ 2.4062e+00,  1.9062e+00,  2.1777e-01,  ..., -1.5156e+00,
          -1.3750e+00, -9.2578e-01],
         [ 9.4922e-01,  1.1484e+00,  3.0273e-01,  ..., -3.0938e+00,
          -2.1250e+00, -2.0938e+00],
         ...,
         [ 1.9531e+00, -1.5820e-01, -4.6875e-01,  ...,  7.3438e-01,
           7.0703e-01,  1.8984e+00],
         [-2.0508e-01,  1.2500e-01, -6.8359e-01,  ..., -7.4707e-02,
           5.0391e-01,  2.5469e+00],
         [-1.3125e+00, -2.1406e+00, -1.0391e+00,  ..., -6.2500e-01,
          -1.4453e+00,  1.4453e+00]],

        [[-3.5706e-03,  6.9275e-03,  1.4221e-02,  ..., -1.7285e-01,
           1.7773e-01, -1.6968e-02],
         [ 2.4062e+00,  1.9062e+00,  2.1777e-01,  ..., -1.5156e+00,
          -1.3750e+00, -9.2578e-01],
         [ 9.4922e-01,  1.1484e+00,  3.0273e-01,  ..., -3.0938e+00,
          -2.1250e+00, -2.0938e+00],
         ...,
         [ 1.9531e+00, -1.5820e-01, -4.6875e-01,  ...,  7.3438e-01,
           7.0703e-01,  1.8984e+00],
         [-2.0508e-01,  1.2500e-01, -6.8359e-01,  ..., -7.4707e-02,
           5.0391e-01,  2.5469e+00],
         [-1.3125e+00, -2.1406e+00, -1.0391e+00,  ..., -6.2500e-01,
          -1.4453e+00,  1.4453e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 0.0537, -0.0067, -0.0131,  ...,  0.0227, -0.0175,  0.0188],
         [-0.4023, -0.3379, -0.4727,  ..., -0.3730, -0.2930,  0.3906],
         [-0.2852, -0.7734,  0.0374,  ..., -0.4922, -0.1074,  0.3262],
         ...,
         [ 0.9141,  0.0201,  0.0767,  ...,  0.3379,  0.7227, -0.1226],
         [ 0.4629, -0.3047, -1.0156,  ..., -0.3867,  0.9883, -0.2891],
         [-0.3574, -0.0312,  0.0036,  ...,  0.0957,  0.3242,  0.3711]],

        [[ 0.0537, -0.0067, -0.0131,  ...,  0.0227, -0.0175,  0.0188],
         [-0.4023, -0.3379, -0.4727,  ..., -0.3730, -0.2930,  0.3906],
         [-0.2852, -0.7734,  0.0374,  ..., -0.4922, -0.1074,  0.3262],
         ...,
         [ 0.9141,  0.0201,  0.0767,  ...,  0.3379,  0.7227, -0.1226],
         [ 0.4629, -0.3047, -1.0156,  ..., -0.3867,  0.9883, -0.2891],
         [-0.3574, -0.0312,  0.0036,  ...,  0.0957,  0.3242,  0.3711]],

        [[ 0.0537, -0.0067, -0.0131,  ...,  0.0227, -0.0175,  0.0188],
         [-0.4023, -0.3379, -0.4727,  ..., -0.3730, -0.2930,  0.3906],
         [-0.2852, -0.7734,  0.0374,  ..., -0.4922, -0.1074,  0.3262],
         ...,
         [ 0.9141,  0.0201,  0.0767,  ...,  0.3379,  0.7227, -0.1226],
         [ 0.4629, -0.3047, -1.0156,  ..., -0.3867,  0.9883, -0.2891],
         [-0.3574, -0.0312,  0.0036,  ...,  0.0957,  0.3242,  0.3711]],

        ...,

        [[-0.0300, -0.0415, -0.0488,  ..., -0.0352,  0.0204,  0.0442],
         [-0.2051,  0.1680,  0.2500,  ..., -0.3320, -0.2949,  0.1670],
         [ 0.0203, -0.3398, -0.0371,  ..., -0.2354, -0.0206,  0.0171],
         ...,
         [ 0.3613, -0.0474,  0.5508,  ...,  0.7266,  0.5312,  0.2363],
         [ 0.0061, -0.2520,  0.2520,  ...,  0.5820, -0.9219, -0.2422],
         [ 0.3750,  0.4922,  0.8398,  ...,  0.4512,  0.2139,  0.4316]],

        [[-0.0300, -0.0415, -0.0488,  ..., -0.0352,  0.0204,  0.0442],
         [-0.2051,  0.1680,  0.2500,  ..., -0.3320, -0.2949,  0.1670],
         [ 0.0203, -0.3398, -0.0371,  ..., -0.2354, -0.0206,  0.0171],
         ...,
         [ 0.3613, -0.0474,  0.5508,  ...,  0.7266,  0.5312,  0.2363],
         [ 0.0061, -0.2520,  0.2520,  ...,  0.5820, -0.9219, -0.2422],
         [ 0.3750,  0.4922,  0.8398,  ...,  0.4512,  0.2139,  0.4316]],

        [[-0.0300, -0.0415, -0.0488,  ..., -0.0352,  0.0204,  0.0442],
         [-0.2051,  0.1680,  0.2500,  ..., -0.3320, -0.2949,  0.1670],
         [ 0.0203, -0.3398, -0.0371,  ..., -0.2354, -0.0206,  0.0171],
         ...,
         [ 0.3613, -0.0474,  0.5508,  ...,  0.7266,  0.5312,  0.2363],
         [ 0.0061, -0.2520,  0.2520,  ...,  0.5820, -0.9219, -0.2422],
         [ 0.3750,  0.4922,  0.8398,  ...,  0.4512,  0.2139,  0.4316]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[-7.9956e-03, -3.5553e-03, -3.8147e-03,  ..., -1.2500e+00,
           1.5747e-02,  6.2012e-02],
         [ 7.0312e-01,  3.6719e-01,  1.5820e-01,  ...,  2.6172e-01,
          -5.8594e-01, -1.5625e-01],
         [-6.3281e-01,  7.7734e-01, -8.4229e-03,  ...,  1.1719e+00,
          -1.7109e+00, -2.7500e+00],
         ...,
         [-1.0312e+00,  1.3086e-01,  2.3828e-01,  ..., -1.7109e+00,
          -2.5156e+00,  7.3438e-01],
         [-3.6523e-01, -8.7891e-02, -4.9414e-01,  ...,  6.7871e-02,
          -1.5156e+00, -7.3438e-01],
         [-5.7422e-01, -1.4062e+00, -1.4766e+00,  ...,  6.8750e-01,
           1.4609e+00, -1.2188e+00]],

        [[-7.9956e-03, -3.5553e-03, -3.8147e-03,  ..., -1.2500e+00,
           1.5747e-02,  6.2012e-02],
         [ 7.0312e-01,  3.6719e-01,  1.5820e-01,  ...,  2.6172e-01,
          -5.8594e-01, -1.5625e-01],
         [-6.3281e-01,  7.7734e-01, -8.4229e-03,  ...,  1.1719e+00,
          -1.7109e+00, -2.7500e+00],
         ...,
         [-1.0312e+00,  1.3086e-01,  2.3828e-01,  ..., -1.7109e+00,
          -2.5156e+00,  7.3438e-01],
         [-3.6523e-01, -8.7891e-02, -4.9414e-01,  ...,  6.7871e-02,
          -1.5156e+00, -7.3438e-01],
         [-5.7422e-01, -1.4062e+00, -1.4766e+00,  ...,  6.8750e-01,
           1.4609e+00, -1.2188e+00]],

        [[-7.9956e-03, -3.5553e-03, -3.8147e-03,  ..., -1.2500e+00,
           1.5747e-02,  6.2012e-02],
         [ 7.0312e-01,  3.6719e-01,  1.5820e-01,  ...,  2.6172e-01,
          -5.8594e-01, -1.5625e-01],
         [-6.3281e-01,  7.7734e-01, -8.4229e-03,  ...,  1.1719e+00,
          -1.7109e+00, -2.7500e+00],
         ...,
         [-1.0312e+00,  1.3086e-01,  2.3828e-01,  ..., -1.7109e+00,
          -2.5156e+00,  7.3438e-01],
         [-3.6523e-01, -8.7891e-02, -4.9414e-01,  ...,  6.7871e-02,
          -1.5156e+00, -7.3438e-01],
         [-5.7422e-01, -1.4062e+00, -1.4766e+00,  ...,  6.8750e-01,
           1.4609e+00, -1.2188e+00]],

        ...,

        [[-1.8692e-03,  1.0300e-03, -7.0496e-03,  ...,  1.4375e+00,
          -2.7969e+00, -3.5156e-01],
         [ 1.7188e-01, -2.3281e+00, -9.9219e-01,  ..., -1.9219e+00,
           6.9375e+00,  1.6641e+00],
         [ 1.5527e-01, -3.7109e-01, -6.2500e-01,  ..., -2.3906e+00,
           7.8438e+00,  4.6562e+00],
         ...,
         [ 9.3750e-01,  9.6484e-01,  1.3770e-01,  ..., -1.2812e+00,
           6.0625e+00, -4.2188e+00],
         [-4.9609e-01,  5.3125e-01,  1.7285e-01,  ..., -8.5938e-01,
           4.5625e+00, -1.1719e+00],
         [ 1.4258e-01,  1.9375e+00,  3.6523e-01,  ..., -4.2188e-01,
           5.3438e+00,  2.2969e+00]],

        [[-1.8692e-03,  1.0300e-03, -7.0496e-03,  ...,  1.4375e+00,
          -2.7969e+00, -3.5156e-01],
         [ 1.7188e-01, -2.3281e+00, -9.9219e-01,  ..., -1.9219e+00,
           6.9375e+00,  1.6641e+00],
         [ 1.5527e-01, -3.7109e-01, -6.2500e-01,  ..., -2.3906e+00,
           7.8438e+00,  4.6562e+00],
         ...,
         [ 9.3750e-01,  9.6484e-01,  1.3770e-01,  ..., -1.2812e+00,
           6.0625e+00, -4.2188e+00],
         [-4.9609e-01,  5.3125e-01,  1.7285e-01,  ..., -8.5938e-01,
           4.5625e+00, -1.1719e+00],
         [ 1.4258e-01,  1.9375e+00,  3.6523e-01,  ..., -4.2188e-01,
           5.3438e+00,  2.2969e+00]],

        [[-1.8692e-03,  1.0300e-03, -7.0496e-03,  ...,  1.4375e+00,
          -2.7969e+00, -3.5156e-01],
         [ 1.7188e-01, -2.3281e+00, -9.9219e-01,  ..., -1.9219e+00,
           6.9375e+00,  1.6641e+00],
         [ 1.5527e-01, -3.7109e-01, -6.2500e-01,  ..., -2.3906e+00,
           7.8438e+00,  4.6562e+00],
         ...,
         [ 9.3750e-01,  9.6484e-01,  1.3770e-01,  ..., -1.2812e+00,
           6.0625e+00, -4.2188e+00],
         [-4.9609e-01,  5.3125e-01,  1.7285e-01,  ..., -8.5938e-01,
           4.5625e+00, -1.1719e+00],
         [ 1.4258e-01,  1.9375e+00,  3.6523e-01,  ..., -4.2188e-01,
           5.3438e+00,  2.2969e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 0.0623,  0.0913, -0.0081,  ...,  0.0413, -0.0123,  0.0043],
         [-0.5742,  0.0535, -0.3262,  ...,  0.3320,  0.4824,  0.2266],
         [-0.7031, -0.4531,  0.2227,  ..., -0.0579,  0.4180,  0.1084],
         ...,
         [ 0.5781, -0.4688,  0.4727,  ...,  0.2832, -0.0732,  0.8008],
         [ 0.0183, -0.4453,  0.8086,  ...,  0.1299,  0.0757,  0.7695],
         [ 0.2207, -0.5938, -1.3984,  ...,  0.3223,  0.6367, -0.5078]],

        [[ 0.0623,  0.0913, -0.0081,  ...,  0.0413, -0.0123,  0.0043],
         [-0.5742,  0.0535, -0.3262,  ...,  0.3320,  0.4824,  0.2266],
         [-0.7031, -0.4531,  0.2227,  ..., -0.0579,  0.4180,  0.1084],
         ...,
         [ 0.5781, -0.4688,  0.4727,  ...,  0.2832, -0.0732,  0.8008],
         [ 0.0183, -0.4453,  0.8086,  ...,  0.1299,  0.0757,  0.7695],
         [ 0.2207, -0.5938, -1.3984,  ...,  0.3223,  0.6367, -0.5078]],

        [[ 0.0623,  0.0913, -0.0081,  ...,  0.0413, -0.0123,  0.0043],
         [-0.5742,  0.0535, -0.3262,  ...,  0.3320,  0.4824,  0.2266],
         [-0.7031, -0.4531,  0.2227,  ..., -0.0579,  0.4180,  0.1084],
         ...,
         [ 0.5781, -0.4688,  0.4727,  ...,  0.2832, -0.0732,  0.8008],
         [ 0.0183, -0.4453,  0.8086,  ...,  0.1299,  0.0757,  0.7695],
         [ 0.2207, -0.5938, -1.3984,  ...,  0.3223,  0.6367, -0.5078]],

        ...,

        [[-0.0090,  0.0481,  0.0194,  ...,  0.0170,  0.0413,  0.0540],
         [-0.3809, -0.5859,  0.0077,  ...,  0.2715,  0.6562, -0.1631],
         [-0.8516, -0.4805, -0.6523,  ..., -0.4141,  0.1099, -0.4941],
         ...,
         [ 0.9414, -0.0767, -0.1104,  ..., -0.5312,  0.1113, -0.4844],
         [ 1.3125,  0.2100, -0.3555,  ..., -0.5938,  0.0104, -0.3984],
         [-0.3887, -0.5547, -0.7422,  ...,  0.2266, -0.4219, -0.0659]],

        [[-0.0090,  0.0481,  0.0194,  ...,  0.0170,  0.0413,  0.0540],
         [-0.3809, -0.5859,  0.0077,  ...,  0.2715,  0.6562, -0.1631],
         [-0.8516, -0.4805, -0.6523,  ..., -0.4141,  0.1099, -0.4941],
         ...,
         [ 0.9414, -0.0767, -0.1104,  ..., -0.5312,  0.1113, -0.4844],
         [ 1.3125,  0.2100, -0.3555,  ..., -0.5938,  0.0104, -0.3984],
         [-0.3887, -0.5547, -0.7422,  ...,  0.2266, -0.4219, -0.0659]],

        [[-0.0090,  0.0481,  0.0194,  ...,  0.0170,  0.0413,  0.0540],
         [-0.3809, -0.5859,  0.0077,  ...,  0.2715,  0.6562, -0.1631],
         [-0.8516, -0.4805, -0.6523,  ..., -0.4141,  0.1099, -0.4941],
         ...,
         [ 0.9414, -0.0767, -0.1104,  ..., -0.5312,  0.1113, -0.4844],
         [ 1.3125,  0.2100, -0.3555,  ..., -0.5938,  0.0104, -0.3984],
         [-0.3887, -0.5547, -0.7422,  ...,  0.2266, -0.4219, -0.0659]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[-7.8125e-03,  3.9673e-03, -4.2419e-03,  ..., -6.7188e-01,
          -1.7090e-01,  9.5215e-02],
         [ 1.6406e+00,  9.9609e-01,  1.7422e+00,  ..., -1.0781e+00,
           1.0469e+00,  1.4141e+00],
         [ 9.2578e-01,  9.8047e-01,  1.8828e+00,  ..., -2.4062e+00,
          -1.1328e+00,  2.3750e+00],
         ...,
         [ 1.0781e+00,  7.4219e-01,  1.1328e+00,  ..., -3.3447e-02,
           4.1797e-01,  8.0078e-01],
         [-4.8047e-01,  2.4512e-01,  8.0469e-01,  ..., -1.1084e-01,
          -1.2891e+00,  1.4375e+00],
         [ 2.4414e-01, -1.0156e+00,  2.1406e+00,  ..., -1.4746e-01,
           1.3281e+00,  3.9551e-02]],

        [[-7.8125e-03,  3.9673e-03, -4.2419e-03,  ..., -6.7188e-01,
          -1.7090e-01,  9.5215e-02],
         [ 1.6406e+00,  9.9609e-01,  1.7422e+00,  ..., -1.0781e+00,
           1.0469e+00,  1.4141e+00],
         [ 9.2578e-01,  9.8047e-01,  1.8828e+00,  ..., -2.4062e+00,
          -1.1328e+00,  2.3750e+00],
         ...,
         [ 1.0781e+00,  7.4219e-01,  1.1328e+00,  ..., -3.3447e-02,
           4.1797e-01,  8.0078e-01],
         [-4.8047e-01,  2.4512e-01,  8.0469e-01,  ..., -1.1084e-01,
          -1.2891e+00,  1.4375e+00],
         [ 2.4414e-01, -1.0156e+00,  2.1406e+00,  ..., -1.4746e-01,
           1.3281e+00,  3.9551e-02]],

        [[-7.8125e-03,  3.9673e-03, -4.2419e-03,  ..., -6.7188e-01,
          -1.7090e-01,  9.5215e-02],
         [ 1.6406e+00,  9.9609e-01,  1.7422e+00,  ..., -1.0781e+00,
           1.0469e+00,  1.4141e+00],
         [ 9.2578e-01,  9.8047e-01,  1.8828e+00,  ..., -2.4062e+00,
          -1.1328e+00,  2.3750e+00],
         ...,
         [ 1.0781e+00,  7.4219e-01,  1.1328e+00,  ..., -3.3447e-02,
           4.1797e-01,  8.0078e-01],
         [-4.8047e-01,  2.4512e-01,  8.0469e-01,  ..., -1.1084e-01,
          -1.2891e+00,  1.4375e+00],
         [ 2.4414e-01, -1.0156e+00,  2.1406e+00,  ..., -1.4746e-01,
           1.3281e+00,  3.9551e-02]],

        ...,

        [[ 1.5198e-02,  1.0864e-02, -6.3477e-03,  ...,  1.2146e-02,
           1.2451e-02,  2.3281e+00],
         [ 8.3984e-01,  2.2344e+00,  1.0391e+00,  ...,  3.3398e-01,
          -3.8438e+00, -7.8438e+00],
         [-2.1484e-01,  1.6562e+00,  1.2695e-01,  ...,  5.5859e-01,
          -1.9219e+00, -1.0062e+01],
         ...,
         [-8.4375e-01, -1.0156e+00, -1.2969e+00,  ...,  8.9062e-01,
           3.1055e-01, -4.8438e+00],
         [-4.6094e-01, -5.3125e-01, -7.4609e-01,  ...,  1.6797e+00,
          -6.4062e-01, -5.3438e+00],
         [-1.4609e+00, -1.6719e+00,  1.9531e-03,  ...,  4.9609e-01,
          -2.7188e+00, -5.4688e+00]],

        [[ 1.5198e-02,  1.0864e-02, -6.3477e-03,  ...,  1.2146e-02,
           1.2451e-02,  2.3281e+00],
         [ 8.3984e-01,  2.2344e+00,  1.0391e+00,  ...,  3.3398e-01,
          -3.8438e+00, -7.8438e+00],
         [-2.1484e-01,  1.6562e+00,  1.2695e-01,  ...,  5.5859e-01,
          -1.9219e+00, -1.0062e+01],
         ...,
         [-8.4375e-01, -1.0156e+00, -1.2969e+00,  ...,  8.9062e-01,
           3.1055e-01, -4.8438e+00],
         [-4.6094e-01, -5.3125e-01, -7.4609e-01,  ...,  1.6797e+00,
          -6.4062e-01, -5.3438e+00],
         [-1.4609e+00, -1.6719e+00,  1.9531e-03,  ...,  4.9609e-01,
          -2.7188e+00, -5.4688e+00]],

        [[ 1.5198e-02,  1.0864e-02, -6.3477e-03,  ...,  1.2146e-02,
           1.2451e-02,  2.3281e+00],
         [ 8.3984e-01,  2.2344e+00,  1.0391e+00,  ...,  3.3398e-01,
          -3.8438e+00, -7.8438e+00],
         [-2.1484e-01,  1.6562e+00,  1.2695e-01,  ...,  5.5859e-01,
          -1.9219e+00, -1.0062e+01],
         ...,
         [-8.4375e-01, -1.0156e+00, -1.2969e+00,  ...,  8.9062e-01,
           3.1055e-01, -4.8438e+00],
         [-4.6094e-01, -5.3125e-01, -7.4609e-01,  ...,  1.6797e+00,
          -6.4062e-01, -5.3438e+00],
         [-1.4609e+00, -1.6719e+00,  1.9531e-03,  ...,  4.9609e-01,
          -2.7188e+00, -5.4688e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-1.0437e-02,  4.6730e-04,  1.5869e-02,  ...,  4.7607e-02,
          -4.8584e-02, -2.1118e-02],
         [-4.8438e-01,  6.2256e-02,  3.7500e-01,  ..., -1.2329e-02,
           6.8054e-03,  4.7461e-01],
         [ 2.6367e-01, -3.1641e-01,  3.7109e-01,  ..., -4.0039e-02,
           4.7461e-01,  4.8242e-01],
         ...,
         [ 1.5332e-01,  4.4531e-01,  6.9531e-01,  ..., -6.4844e-01,
          -5.8984e-01, -6.6406e-02],
         [ 2.7930e-01,  2.0752e-02, -3.1836e-01,  ..., -1.6602e-01,
           6.7578e-01, -1.7773e-01],
         [-2.0215e-01,  4.3555e-01,  8.1543e-02,  ...,  1.6699e-01,
          -3.8477e-01,  2.6367e-01]],

        [[-1.0437e-02,  4.6730e-04,  1.5869e-02,  ...,  4.7607e-02,
          -4.8584e-02, -2.1118e-02],
         [-4.8438e-01,  6.2256e-02,  3.7500e-01,  ..., -1.2329e-02,
           6.8054e-03,  4.7461e-01],
         [ 2.6367e-01, -3.1641e-01,  3.7109e-01,  ..., -4.0039e-02,
           4.7461e-01,  4.8242e-01],
         ...,
         [ 1.5332e-01,  4.4531e-01,  6.9531e-01,  ..., -6.4844e-01,
          -5.8984e-01, -6.6406e-02],
         [ 2.7930e-01,  2.0752e-02, -3.1836e-01,  ..., -1.6602e-01,
           6.7578e-01, -1.7773e-01],
         [-2.0215e-01,  4.3555e-01,  8.1543e-02,  ...,  1.6699e-01,
          -3.8477e-01,  2.6367e-01]],

        [[-1.0437e-02,  4.6730e-04,  1.5869e-02,  ...,  4.7607e-02,
          -4.8584e-02, -2.1118e-02],
         [-4.8438e-01,  6.2256e-02,  3.7500e-01,  ..., -1.2329e-02,
           6.8054e-03,  4.7461e-01],
         [ 2.6367e-01, -3.1641e-01,  3.7109e-01,  ..., -4.0039e-02,
           4.7461e-01,  4.8242e-01],
         ...,
         [ 1.5332e-01,  4.4531e-01,  6.9531e-01,  ..., -6.4844e-01,
          -5.8984e-01, -6.6406e-02],
         [ 2.7930e-01,  2.0752e-02, -3.1836e-01,  ..., -1.6602e-01,
           6.7578e-01, -1.7773e-01],
         [-2.0215e-01,  4.3555e-01,  8.1543e-02,  ...,  1.6699e-01,
          -3.8477e-01,  2.6367e-01]],

        ...,

        [[-1.8799e-02,  1.3245e-02, -1.6602e-02,  ..., -1.1292e-02,
           1.3489e-02,  2.1820e-03],
         [ 2.6758e-01, -1.4062e-01, -9.5215e-02,  ..., -1.2500e-01,
           4.9219e-01,  2.7930e-01],
         [-3.2422e-01,  1.4453e-01, -4.1211e-01,  ...,  3.0078e-01,
           3.6328e-01, -1.3672e-01],
         ...,
         [ 2.0801e-01,  2.7734e-01,  3.7598e-02,  ..., -2.9883e-01,
           4.1602e-01, -7.1094e-01],
         [ 1.3965e-01, -5.3125e-01, -5.4297e-01,  ...,  1.8066e-01,
          -1.9434e-01, -7.2266e-02],
         [-6.0120e-03, -2.2656e-01, -4.0234e-01,  ..., -3.0396e-02,
          -3.7695e-01,  2.3145e-01]],

        [[-1.8799e-02,  1.3245e-02, -1.6602e-02,  ..., -1.1292e-02,
           1.3489e-02,  2.1820e-03],
         [ 2.6758e-01, -1.4062e-01, -9.5215e-02,  ..., -1.2500e-01,
           4.9219e-01,  2.7930e-01],
         [-3.2422e-01,  1.4453e-01, -4.1211e-01,  ...,  3.0078e-01,
           3.6328e-01, -1.3672e-01],
         ...,
         [ 2.0801e-01,  2.7734e-01,  3.7598e-02,  ..., -2.9883e-01,
           4.1602e-01, -7.1094e-01],
         [ 1.3965e-01, -5.3125e-01, -5.4297e-01,  ...,  1.8066e-01,
          -1.9434e-01, -7.2266e-02],
         [-6.0120e-03, -2.2656e-01, -4.0234e-01,  ..., -3.0396e-02,
          -3.7695e-01,  2.3145e-01]],

        [[-1.8799e-02,  1.3245e-02, -1.6602e-02,  ..., -1.1292e-02,
           1.3489e-02,  2.1820e-03],
         [ 2.6758e-01, -1.4062e-01, -9.5215e-02,  ..., -1.2500e-01,
           4.9219e-01,  2.7930e-01],
         [-3.2422e-01,  1.4453e-01, -4.1211e-01,  ...,  3.0078e-01,
           3.6328e-01, -1.3672e-01],
         ...,
         [ 2.0801e-01,  2.7734e-01,  3.7598e-02,  ..., -2.9883e-01,
           4.1602e-01, -7.1094e-01],
         [ 1.3965e-01, -5.3125e-01, -5.4297e-01,  ...,  1.8066e-01,
          -1.9434e-01, -7.2266e-02],
         [-6.0120e-03, -2.2656e-01, -4.0234e-01,  ..., -3.0396e-02,
          -3.7695e-01,  2.3145e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-2.4567e-03, -9.7046e-03,  2.9449e-03,  ..., -2.8711e-01,
          -1.3770e-01, -1.1230e-01],
         [-1.0078e+00,  2.8711e-01, -1.0469e+00,  ..., -1.1562e+00,
           3.9688e+00, -3.5156e-01],
         [ 3.0469e-01,  6.7578e-01, -3.7109e-01,  ...,  1.1230e-01,
           7.0703e-01,  2.5625e+00],
         ...,
         [ 2.2461e-01, -1.7656e+00,  1.3281e-01,  ..., -5.0000e-01,
          -1.4922e+00, -1.3516e+00],
         [ 7.4219e-01,  2.6953e-01,  7.0312e-01,  ..., -1.1797e+00,
           9.0625e-01,  1.6895e-01],
         [ 1.7422e+00, -3.1250e-02,  2.4219e-01,  ..., -3.7812e+00,
           1.9766e+00,  1.7734e+00]],

        [[-2.4567e-03, -9.7046e-03,  2.9449e-03,  ..., -2.8711e-01,
          -1.3770e-01, -1.1230e-01],
         [-1.0078e+00,  2.8711e-01, -1.0469e+00,  ..., -1.1562e+00,
           3.9688e+00, -3.5156e-01],
         [ 3.0469e-01,  6.7578e-01, -3.7109e-01,  ...,  1.1230e-01,
           7.0703e-01,  2.5625e+00],
         ...,
         [ 2.2461e-01, -1.7656e+00,  1.3281e-01,  ..., -5.0000e-01,
          -1.4922e+00, -1.3516e+00],
         [ 7.4219e-01,  2.6953e-01,  7.0312e-01,  ..., -1.1797e+00,
           9.0625e-01,  1.6895e-01],
         [ 1.7422e+00, -3.1250e-02,  2.4219e-01,  ..., -3.7812e+00,
           1.9766e+00,  1.7734e+00]],

        [[-2.4567e-03, -9.7046e-03,  2.9449e-03,  ..., -2.8711e-01,
          -1.3770e-01, -1.1230e-01],
         [-1.0078e+00,  2.8711e-01, -1.0469e+00,  ..., -1.1562e+00,
           3.9688e+00, -3.5156e-01],
         [ 3.0469e-01,  6.7578e-01, -3.7109e-01,  ...,  1.1230e-01,
           7.0703e-01,  2.5625e+00],
         ...,
         [ 2.2461e-01, -1.7656e+00,  1.3281e-01,  ..., -5.0000e-01,
          -1.4922e+00, -1.3516e+00],
         [ 7.4219e-01,  2.6953e-01,  7.0312e-01,  ..., -1.1797e+00,
           9.0625e-01,  1.6895e-01],
         [ 1.7422e+00, -3.1250e-02,  2.4219e-01,  ..., -3.7812e+00,
           1.9766e+00,  1.7734e+00]],

        ...,

        [[ 3.6163e-03, -1.3794e-02,  1.1597e-02,  ...,  3.6914e-01,
           7.5684e-02,  1.3184e-01],
         [ 7.3047e-01,  1.1328e+00,  1.2109e+00,  ..., -3.8672e-01,
          -5.5469e-01, -7.9688e-01],
         [ 6.7578e-01, -7.4707e-02,  8.4766e-01,  ...,  3.9844e+00,
          -1.8125e+00,  1.0000e+00],
         ...,
         [ 1.2109e+00, -1.1719e-01,  1.1250e+00,  ..., -1.9688e+00,
           2.1719e+00,  1.1641e+00],
         [-4.9219e-01, -1.7676e-01, -9.9121e-02,  ..., -2.5781e+00,
           9.0234e-01,  1.3047e+00],
         [-5.8203e-01, -9.1797e-01,  1.6875e+00,  ..., -1.1719e+00,
          -8.6914e-02, -1.3984e+00]],

        [[ 3.6163e-03, -1.3794e-02,  1.1597e-02,  ...,  3.6914e-01,
           7.5684e-02,  1.3184e-01],
         [ 7.3047e-01,  1.1328e+00,  1.2109e+00,  ..., -3.8672e-01,
          -5.5469e-01, -7.9688e-01],
         [ 6.7578e-01, -7.4707e-02,  8.4766e-01,  ...,  3.9844e+00,
          -1.8125e+00,  1.0000e+00],
         ...,
         [ 1.2109e+00, -1.1719e-01,  1.1250e+00,  ..., -1.9688e+00,
           2.1719e+00,  1.1641e+00],
         [-4.9219e-01, -1.7676e-01, -9.9121e-02,  ..., -2.5781e+00,
           9.0234e-01,  1.3047e+00],
         [-5.8203e-01, -9.1797e-01,  1.6875e+00,  ..., -1.1719e+00,
          -8.6914e-02, -1.3984e+00]],

        [[ 3.6163e-03, -1.3794e-02,  1.1597e-02,  ...,  3.6914e-01,
           7.5684e-02,  1.3184e-01],
         [ 7.3047e-01,  1.1328e+00,  1.2109e+00,  ..., -3.8672e-01,
          -5.5469e-01, -7.9688e-01],
         [ 6.7578e-01, -7.4707e-02,  8.4766e-01,  ...,  3.9844e+00,
          -1.8125e+00,  1.0000e+00],
         ...,
         [ 1.2109e+00, -1.1719e-01,  1.1250e+00,  ..., -1.9688e+00,
           2.1719e+00,  1.1641e+00],
         [-4.9219e-01, -1.7676e-01, -9.9121e-02,  ..., -2.5781e+00,
           9.0234e-01,  1.3047e+00],
         [-5.8203e-01, -9.1797e-01,  1.6875e+00,  ..., -1.1719e+00,
          -8.6914e-02, -1.3984e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-1.9775e-02,  3.9551e-02, -2.7344e-02,  ..., -4.4434e-02,
          -7.5684e-02,  1.7944e-02],
         [ 2.5781e-01, -4.1797e-01,  1.6406e-01,  ...,  6.4453e-01,
          -3.4766e-01, -4.9414e-01],
         [-5.0781e-01,  4.7070e-01, -2.5781e-01,  ...,  7.9297e-01,
          -6.0547e-01,  4.2969e-02],
         ...,
         [-6.9922e-01,  1.0889e-01, -1.0156e-01,  ..., -5.6641e-01,
          -3.1836e-01, -6.0156e-01],
         [-8.1250e-01, -4.5508e-01, -2.4902e-01,  ..., -3.4375e-01,
          -2.1191e-01, -1.1406e+00],
         [-3.5547e-01,  1.3574e-01, -6.2891e-01,  ...,  1.0681e-02,
          -1.2695e-01,  1.5234e-01]],

        [[-1.9775e-02,  3.9551e-02, -2.7344e-02,  ..., -4.4434e-02,
          -7.5684e-02,  1.7944e-02],
         [ 2.5781e-01, -4.1797e-01,  1.6406e-01,  ...,  6.4453e-01,
          -3.4766e-01, -4.9414e-01],
         [-5.0781e-01,  4.7070e-01, -2.5781e-01,  ...,  7.9297e-01,
          -6.0547e-01,  4.2969e-02],
         ...,
         [-6.9922e-01,  1.0889e-01, -1.0156e-01,  ..., -5.6641e-01,
          -3.1836e-01, -6.0156e-01],
         [-8.1250e-01, -4.5508e-01, -2.4902e-01,  ..., -3.4375e-01,
          -2.1191e-01, -1.1406e+00],
         [-3.5547e-01,  1.3574e-01, -6.2891e-01,  ...,  1.0681e-02,
          -1.2695e-01,  1.5234e-01]],

        [[-1.9775e-02,  3.9551e-02, -2.7344e-02,  ..., -4.4434e-02,
          -7.5684e-02,  1.7944e-02],
         [ 2.5781e-01, -4.1797e-01,  1.6406e-01,  ...,  6.4453e-01,
          -3.4766e-01, -4.9414e-01],
         [-5.0781e-01,  4.7070e-01, -2.5781e-01,  ...,  7.9297e-01,
          -6.0547e-01,  4.2969e-02],
         ...,
         [-6.9922e-01,  1.0889e-01, -1.0156e-01,  ..., -5.6641e-01,
          -3.1836e-01, -6.0156e-01],
         [-8.1250e-01, -4.5508e-01, -2.4902e-01,  ..., -3.4375e-01,
          -2.1191e-01, -1.1406e+00],
         [-3.5547e-01,  1.3574e-01, -6.2891e-01,  ...,  1.0681e-02,
          -1.2695e-01,  1.5234e-01]],

        ...,

        [[ 4.2725e-03, -1.0498e-02, -1.6357e-02,  ..., -2.4261e-03,
          -2.6489e-02,  8.6975e-04],
         [-3.2031e-01, -3.2617e-01,  8.8379e-02,  ..., -8.1250e-01,
           1.6504e-01,  4.5508e-01],
         [-2.8711e-01, -3.9453e-01,  1.5137e-01,  ...,  5.1514e-02,
           6.3965e-02,  1.4746e-01],
         ...,
         [-4.1992e-01,  5.3223e-02,  2.7344e-01,  ...,  1.0234e+00,
           5.7031e-01,  2.1680e-01],
         [-3.6914e-01, -1.1084e-01,  5.3906e-01,  ...,  4.9072e-02,
           7.4609e-01,  6.0156e-01],
         [ 3.7695e-01, -1.3867e-01, -1.8848e-01,  ...,  2.7344e-01,
           5.7031e-01, -4.0527e-02]],

        [[ 4.2725e-03, -1.0498e-02, -1.6357e-02,  ..., -2.4261e-03,
          -2.6489e-02,  8.6975e-04],
         [-3.2031e-01, -3.2617e-01,  8.8379e-02,  ..., -8.1250e-01,
           1.6504e-01,  4.5508e-01],
         [-2.8711e-01, -3.9453e-01,  1.5137e-01,  ...,  5.1514e-02,
           6.3965e-02,  1.4746e-01],
         ...,
         [-4.1992e-01,  5.3223e-02,  2.7344e-01,  ...,  1.0234e+00,
           5.7031e-01,  2.1680e-01],
         [-3.6914e-01, -1.1084e-01,  5.3906e-01,  ...,  4.9072e-02,
           7.4609e-01,  6.0156e-01],
         [ 3.7695e-01, -1.3867e-01, -1.8848e-01,  ...,  2.7344e-01,
           5.7031e-01, -4.0527e-02]],

        [[ 4.2725e-03, -1.0498e-02, -1.6357e-02,  ..., -2.4261e-03,
          -2.6489e-02,  8.6975e-04],
         [-3.2031e-01, -3.2617e-01,  8.8379e-02,  ..., -8.1250e-01,
           1.6504e-01,  4.5508e-01],
         [-2.8711e-01, -3.9453e-01,  1.5137e-01,  ...,  5.1514e-02,
           6.3965e-02,  1.4746e-01],
         ...,
         [-4.1992e-01,  5.3223e-02,  2.7344e-01,  ...,  1.0234e+00,
           5.7031e-01,  2.1680e-01],
         [-3.6914e-01, -1.1084e-01,  5.3906e-01,  ...,  4.9072e-02,
           7.4609e-01,  6.0156e-01],
         [ 3.7695e-01, -1.3867e-01, -1.8848e-01,  ...,  2.7344e-01,
           5.7031e-01, -4.0527e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 2.0504e-04, -5.9509e-04, -8.3618e-03,  ...,  3.5706e-03,
           7.7637e-02, -1.4258e-01],
         [-1.5547e+00, -6.5234e-01, -7.3047e-01,  ...,  2.0312e+00,
          -6.8750e-01,  6.7578e-01],
         [ 6.4844e-01, -6.2500e-01, -6.1328e-01,  ..., -7.3047e-01,
          -1.5078e+00, -6.0547e-01],
         ...,
         [ 8.9062e-01, -2.5781e-01, -3.6914e-01,  ...,  9.8828e-01,
           2.0156e+00,  2.5781e+00],
         [ 7.8125e-01,  4.2383e-01, -4.6094e-01,  ...,  7.5195e-02,
           6.2500e-01,  1.4375e+00],
         [ 1.6484e+00,  8.2812e-01, -1.4141e+00,  ...,  9.5703e-01,
          -1.1641e+00,  6.7188e-01]],

        [[ 2.0504e-04, -5.9509e-04, -8.3618e-03,  ...,  3.5706e-03,
           7.7637e-02, -1.4258e-01],
         [-1.5547e+00, -6.5234e-01, -7.3047e-01,  ...,  2.0312e+00,
          -6.8750e-01,  6.7578e-01],
         [ 6.4844e-01, -6.2500e-01, -6.1328e-01,  ..., -7.3047e-01,
          -1.5078e+00, -6.0547e-01],
         ...,
         [ 8.9062e-01, -2.5781e-01, -3.6914e-01,  ...,  9.8828e-01,
           2.0156e+00,  2.5781e+00],
         [ 7.8125e-01,  4.2383e-01, -4.6094e-01,  ...,  7.5195e-02,
           6.2500e-01,  1.4375e+00],
         [ 1.6484e+00,  8.2812e-01, -1.4141e+00,  ...,  9.5703e-01,
          -1.1641e+00,  6.7188e-01]],

        [[ 2.0504e-04, -5.9509e-04, -8.3618e-03,  ...,  3.5706e-03,
           7.7637e-02, -1.4258e-01],
         [-1.5547e+00, -6.5234e-01, -7.3047e-01,  ...,  2.0312e+00,
          -6.8750e-01,  6.7578e-01],
         [ 6.4844e-01, -6.2500e-01, -6.1328e-01,  ..., -7.3047e-01,
          -1.5078e+00, -6.0547e-01],
         ...,
         [ 8.9062e-01, -2.5781e-01, -3.6914e-01,  ...,  9.8828e-01,
           2.0156e+00,  2.5781e+00],
         [ 7.8125e-01,  4.2383e-01, -4.6094e-01,  ...,  7.5195e-02,
           6.2500e-01,  1.4375e+00],
         [ 1.6484e+00,  8.2812e-01, -1.4141e+00,  ...,  9.5703e-01,
          -1.1641e+00,  6.7188e-01]],

        ...,

        [[-8.8882e-04, -6.9275e-03, -3.6926e-03,  ..., -2.1094e-01,
          -4.8584e-02, -6.6406e-02],
         [ 7.1875e-01,  4.1016e-01, -6.3281e-01,  ...,  2.8438e+00,
          -3.8906e+00, -7.0703e-01],
         [ 1.3086e-01, -2.4805e-01, -8.0859e-01,  ...,  2.0000e+00,
          -4.7812e+00, -3.4531e+00],
         ...,
         [ 2.8906e-01,  1.1377e-01,  2.5781e-01,  ..., -3.8594e+00,
          -1.2891e+00,  5.7373e-03],
         [ 2.7148e-01, -1.9336e-01, -2.1289e-01,  ..., -3.6406e+00,
           2.2656e+00, -1.6484e+00],
         [-2.3828e-01, -6.2500e-02,  5.2002e-02,  ...,  3.7500e+00,
           1.3438e+00,  6.1768e-02]],

        [[-8.8882e-04, -6.9275e-03, -3.6926e-03,  ..., -2.1094e-01,
          -4.8584e-02, -6.6406e-02],
         [ 7.1875e-01,  4.1016e-01, -6.3281e-01,  ...,  2.8438e+00,
          -3.8906e+00, -7.0703e-01],
         [ 1.3086e-01, -2.4805e-01, -8.0859e-01,  ...,  2.0000e+00,
          -4.7812e+00, -3.4531e+00],
         ...,
         [ 2.8906e-01,  1.1377e-01,  2.5781e-01,  ..., -3.8594e+00,
          -1.2891e+00,  5.7373e-03],
         [ 2.7148e-01, -1.9336e-01, -2.1289e-01,  ..., -3.6406e+00,
           2.2656e+00, -1.6484e+00],
         [-2.3828e-01, -6.2500e-02,  5.2002e-02,  ...,  3.7500e+00,
           1.3438e+00,  6.1768e-02]],

        [[-8.8882e-04, -6.9275e-03, -3.6926e-03,  ..., -2.1094e-01,
          -4.8584e-02, -6.6406e-02],
         [ 7.1875e-01,  4.1016e-01, -6.3281e-01,  ...,  2.8438e+00,
          -3.8906e+00, -7.0703e-01],
         [ 1.3086e-01, -2.4805e-01, -8.0859e-01,  ...,  2.0000e+00,
          -4.7812e+00, -3.4531e+00],
         ...,
         [ 2.8906e-01,  1.1377e-01,  2.5781e-01,  ..., -3.8594e+00,
          -1.2891e+00,  5.7373e-03],
         [ 2.7148e-01, -1.9336e-01, -2.1289e-01,  ..., -3.6406e+00,
           2.2656e+00, -1.6484e+00],
         [-2.3828e-01, -6.2500e-02,  5.2002e-02,  ...,  3.7500e+00,
           1.3438e+00,  6.1768e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-4.9210e-04, -2.6398e-03,  5.1270e-03,  ...,  1.7929e-03,
           4.4556e-03,  9.5825e-03],
         [-7.7637e-02, -2.2949e-01, -3.5742e-01,  ...,  1.1597e-02,
          -8.4766e-01, -5.3906e-01],
         [-3.6719e-01,  3.7500e-01,  4.3164e-01,  ..., -1.7188e-01,
          -4.4922e-01,  4.8438e-01],
         ...,
         [-3.7891e-01,  2.6172e-01, -1.5332e-01,  ...,  5.3906e-01,
           4.8242e-01, -4.9414e-01],
         [-2.4023e-01, -6.6406e-02,  2.0801e-01,  ...,  7.9297e-01,
          -8.7402e-02, -2.8320e-01],
         [-7.9297e-01,  2.0312e-01,  1.6992e-01,  ..., -4.4922e-01,
           5.3516e-01,  3.1445e-01]],

        [[-4.9210e-04, -2.6398e-03,  5.1270e-03,  ...,  1.7929e-03,
           4.4556e-03,  9.5825e-03],
         [-7.7637e-02, -2.2949e-01, -3.5742e-01,  ...,  1.1597e-02,
          -8.4766e-01, -5.3906e-01],
         [-3.6719e-01,  3.7500e-01,  4.3164e-01,  ..., -1.7188e-01,
          -4.4922e-01,  4.8438e-01],
         ...,
         [-3.7891e-01,  2.6172e-01, -1.5332e-01,  ...,  5.3906e-01,
           4.8242e-01, -4.9414e-01],
         [-2.4023e-01, -6.6406e-02,  2.0801e-01,  ...,  7.9297e-01,
          -8.7402e-02, -2.8320e-01],
         [-7.9297e-01,  2.0312e-01,  1.6992e-01,  ..., -4.4922e-01,
           5.3516e-01,  3.1445e-01]],

        [[-4.9210e-04, -2.6398e-03,  5.1270e-03,  ...,  1.7929e-03,
           4.4556e-03,  9.5825e-03],
         [-7.7637e-02, -2.2949e-01, -3.5742e-01,  ...,  1.1597e-02,
          -8.4766e-01, -5.3906e-01],
         [-3.6719e-01,  3.7500e-01,  4.3164e-01,  ..., -1.7188e-01,
          -4.4922e-01,  4.8438e-01],
         ...,
         [-3.7891e-01,  2.6172e-01, -1.5332e-01,  ...,  5.3906e-01,
           4.8242e-01, -4.9414e-01],
         [-2.4023e-01, -6.6406e-02,  2.0801e-01,  ...,  7.9297e-01,
          -8.7402e-02, -2.8320e-01],
         [-7.9297e-01,  2.0312e-01,  1.6992e-01,  ..., -4.4922e-01,
           5.3516e-01,  3.1445e-01]],

        ...,

        [[-1.7456e-02,  5.1514e-02, -1.7262e-04,  ...,  2.8968e-05,
           2.8687e-03,  1.8066e-02],
         [-1.3125e+00,  4.2969e-02,  1.1719e-01,  ...,  7.4219e-01,
          -8.2812e-01, -4.1602e-01],
         [-8.3984e-01, -2.1406e+00, -1.8828e+00,  ..., -1.0312e+00,
          -6.6797e-01,  1.4355e-01],
         ...,
         [ 6.5234e-01, -1.3125e+00, -4.6680e-01,  ..., -9.5215e-02,
           1.4648e-01,  5.7031e-01],
         [ 2.6562e-01, -2.0215e-01,  1.2793e-01,  ...,  2.0020e-01,
           7.5391e-01,  2.7588e-02],
         [-4.5508e-01, -4.7852e-01, -9.8047e-01,  ...,  6.3281e-01,
          -5.2734e-01, -1.3672e-01]],

        [[-1.7456e-02,  5.1514e-02, -1.7262e-04,  ...,  2.8968e-05,
           2.8687e-03,  1.8066e-02],
         [-1.3125e+00,  4.2969e-02,  1.1719e-01,  ...,  7.4219e-01,
          -8.2812e-01, -4.1602e-01],
         [-8.3984e-01, -2.1406e+00, -1.8828e+00,  ..., -1.0312e+00,
          -6.6797e-01,  1.4355e-01],
         ...,
         [ 6.5234e-01, -1.3125e+00, -4.6680e-01,  ..., -9.5215e-02,
           1.4648e-01,  5.7031e-01],
         [ 2.6562e-01, -2.0215e-01,  1.2793e-01,  ...,  2.0020e-01,
           7.5391e-01,  2.7588e-02],
         [-4.5508e-01, -4.7852e-01, -9.8047e-01,  ...,  6.3281e-01,
          -5.2734e-01, -1.3672e-01]],

        [[-1.7456e-02,  5.1514e-02, -1.7262e-04,  ...,  2.8968e-05,
           2.8687e-03,  1.8066e-02],
         [-1.3125e+00,  4.2969e-02,  1.1719e-01,  ...,  7.4219e-01,
          -8.2812e-01, -4.1602e-01],
         [-8.3984e-01, -2.1406e+00, -1.8828e+00,  ..., -1.0312e+00,
          -6.6797e-01,  1.4355e-01],
         ...,
         [ 6.5234e-01, -1.3125e+00, -4.6680e-01,  ..., -9.5215e-02,
           1.4648e-01,  5.7031e-01],
         [ 2.6562e-01, -2.0215e-01,  1.2793e-01,  ...,  2.0020e-01,
           7.5391e-01,  2.7588e-02],
         [-4.5508e-01, -4.7852e-01, -9.8047e-01,  ...,  6.3281e-01,
          -5.2734e-01, -1.3672e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 6.1646e-03, -5.6763e-03,  5.2490e-03,  ..., -9.8633e-02,
          -2.2656e-01, -2.7954e-02],
         [-1.6406e+00, -1.4844e-01,  2.5391e-01,  ..., -2.9688e-01,
          -2.1562e+00,  4.9072e-02],
         [-2.0625e+00,  9.0234e-01, -6.2109e-01,  ...,  1.6875e+00,
          -1.6250e+00, -1.6562e+00],
         ...,
         [-3.9258e-01, -9.1016e-01,  1.0000e+00,  ..., -2.5625e+00,
           1.7266e+00, -1.3516e+00],
         [ 1.8848e-01, -1.0625e+00, -4.5898e-01,  ..., -8.3203e-01,
           7.1094e-01, -2.9844e+00],
         [-1.4531e+00, -3.9453e-01,  1.3203e+00,  ..., -1.2266e+00,
          -3.2031e+00,  3.0156e+00]],

        [[ 6.1646e-03, -5.6763e-03,  5.2490e-03,  ..., -9.8633e-02,
          -2.2656e-01, -2.7954e-02],
         [-1.6406e+00, -1.4844e-01,  2.5391e-01,  ..., -2.9688e-01,
          -2.1562e+00,  4.9072e-02],
         [-2.0625e+00,  9.0234e-01, -6.2109e-01,  ...,  1.6875e+00,
          -1.6250e+00, -1.6562e+00],
         ...,
         [-3.9258e-01, -9.1016e-01,  1.0000e+00,  ..., -2.5625e+00,
           1.7266e+00, -1.3516e+00],
         [ 1.8848e-01, -1.0625e+00, -4.5898e-01,  ..., -8.3203e-01,
           7.1094e-01, -2.9844e+00],
         [-1.4531e+00, -3.9453e-01,  1.3203e+00,  ..., -1.2266e+00,
          -3.2031e+00,  3.0156e+00]],

        [[ 6.1646e-03, -5.6763e-03,  5.2490e-03,  ..., -9.8633e-02,
          -2.2656e-01, -2.7954e-02],
         [-1.6406e+00, -1.4844e-01,  2.5391e-01,  ..., -2.9688e-01,
          -2.1562e+00,  4.9072e-02],
         [-2.0625e+00,  9.0234e-01, -6.2109e-01,  ...,  1.6875e+00,
          -1.6250e+00, -1.6562e+00],
         ...,
         [-3.9258e-01, -9.1016e-01,  1.0000e+00,  ..., -2.5625e+00,
           1.7266e+00, -1.3516e+00],
         [ 1.8848e-01, -1.0625e+00, -4.5898e-01,  ..., -8.3203e-01,
           7.1094e-01, -2.9844e+00],
         [-1.4531e+00, -3.9453e-01,  1.3203e+00,  ..., -1.2266e+00,
          -3.2031e+00,  3.0156e+00]],

        ...,

        [[ 2.3193e-03,  1.8921e-02,  1.1902e-02,  ...,  3.8574e-02,
           1.0620e-02,  1.2207e-01],
         [ 9.7656e-01, -4.1016e-02,  1.2031e+00,  ...,  4.9805e-01,
           4.1211e-01, -8.9453e-01],
         [ 9.9609e-01,  1.0938e+00,  2.0781e+00,  ...,  2.4062e+00,
          -1.6357e-02, -6.0156e-01],
         ...,
         [ 1.5156e+00, -8.8672e-01,  5.6641e-02,  ..., -8.6328e-01,
           1.6406e+00, -2.1875e+00],
         [ 3.9062e-01, -9.1406e-01, -2.4805e-01,  ...,  1.1484e+00,
           1.7344e+00, -1.8594e+00],
         [ 6.8750e-01, -1.5000e+00,  1.0156e+00,  ...,  1.6094e+00,
           5.6641e-01, -2.8516e-01]],

        [[ 2.3193e-03,  1.8921e-02,  1.1902e-02,  ...,  3.8574e-02,
           1.0620e-02,  1.2207e-01],
         [ 9.7656e-01, -4.1016e-02,  1.2031e+00,  ...,  4.9805e-01,
           4.1211e-01, -8.9453e-01],
         [ 9.9609e-01,  1.0938e+00,  2.0781e+00,  ...,  2.4062e+00,
          -1.6357e-02, -6.0156e-01],
         ...,
         [ 1.5156e+00, -8.8672e-01,  5.6641e-02,  ..., -8.6328e-01,
           1.6406e+00, -2.1875e+00],
         [ 3.9062e-01, -9.1406e-01, -2.4805e-01,  ...,  1.1484e+00,
           1.7344e+00, -1.8594e+00],
         [ 6.8750e-01, -1.5000e+00,  1.0156e+00,  ...,  1.6094e+00,
           5.6641e-01, -2.8516e-01]],

        [[ 2.3193e-03,  1.8921e-02,  1.1902e-02,  ...,  3.8574e-02,
           1.0620e-02,  1.2207e-01],
         [ 9.7656e-01, -4.1016e-02,  1.2031e+00,  ...,  4.9805e-01,
           4.1211e-01, -8.9453e-01],
         [ 9.9609e-01,  1.0938e+00,  2.0781e+00,  ...,  2.4062e+00,
          -1.6357e-02, -6.0156e-01],
         ...,
         [ 1.5156e+00, -8.8672e-01,  5.6641e-02,  ..., -8.6328e-01,
           1.6406e+00, -2.1875e+00],
         [ 3.9062e-01, -9.1406e-01, -2.4805e-01,  ...,  1.1484e+00,
           1.7344e+00, -1.8594e+00],
         [ 6.8750e-01, -1.5000e+00,  1.0156e+00,  ...,  1.6094e+00,
           5.6641e-01, -2.8516e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-0.0067, -0.0065, -0.0085,  ..., -0.0104, -0.0156, -0.0349],
         [-0.1035,  0.9180,  0.8984,  ..., -0.1216,  0.4629,  0.6289],
         [-0.7266,  1.1094,  0.8711,  ...,  0.2344, -0.6406,  1.3047],
         ...,
         [ 1.1484,  0.2832, -0.2598,  ...,  0.4062,  0.6953, -1.1328],
         [-0.3594, -0.0439, -0.4395,  ...,  0.5234,  0.3008,  0.2461],
         [ 0.5234, -1.4375,  0.5469,  ..., -0.3008,  1.1250,  0.1299]],

        [[-0.0067, -0.0065, -0.0085,  ..., -0.0104, -0.0156, -0.0349],
         [-0.1035,  0.9180,  0.8984,  ..., -0.1216,  0.4629,  0.6289],
         [-0.7266,  1.1094,  0.8711,  ...,  0.2344, -0.6406,  1.3047],
         ...,
         [ 1.1484,  0.2832, -0.2598,  ...,  0.4062,  0.6953, -1.1328],
         [-0.3594, -0.0439, -0.4395,  ...,  0.5234,  0.3008,  0.2461],
         [ 0.5234, -1.4375,  0.5469,  ..., -0.3008,  1.1250,  0.1299]],

        [[-0.0067, -0.0065, -0.0085,  ..., -0.0104, -0.0156, -0.0349],
         [-0.1035,  0.9180,  0.8984,  ..., -0.1216,  0.4629,  0.6289],
         [-0.7266,  1.1094,  0.8711,  ...,  0.2344, -0.6406,  1.3047],
         ...,
         [ 1.1484,  0.2832, -0.2598,  ...,  0.4062,  0.6953, -1.1328],
         [-0.3594, -0.0439, -0.4395,  ...,  0.5234,  0.3008,  0.2461],
         [ 0.5234, -1.4375,  0.5469,  ..., -0.3008,  1.1250,  0.1299]],

        ...,

        [[ 0.0139, -0.0273,  0.0087,  ..., -0.0417,  0.0271,  0.0422],
         [ 0.2275,  0.4062, -0.0045,  ...,  0.7266, -0.3867, -0.6211],
         [-0.2422, -0.1992,  1.3438,  ...,  0.9805, -0.5234, -0.2324],
         ...,
         [ 0.6875, -1.4297, -0.3418,  ...,  0.1069, -0.1187,  0.1299],
         [ 0.5156, -0.5547, -0.6094,  ...,  0.0503,  0.1641,  0.1436],
         [-0.2793,  0.5352, -0.6367,  ..., -0.7109, -1.2656, -0.4180]],

        [[ 0.0139, -0.0273,  0.0087,  ..., -0.0417,  0.0271,  0.0422],
         [ 0.2275,  0.4062, -0.0045,  ...,  0.7266, -0.3867, -0.6211],
         [-0.2422, -0.1992,  1.3438,  ...,  0.9805, -0.5234, -0.2324],
         ...,
         [ 0.6875, -1.4297, -0.3418,  ...,  0.1069, -0.1187,  0.1299],
         [ 0.5156, -0.5547, -0.6094,  ...,  0.0503,  0.1641,  0.1436],
         [-0.2793,  0.5352, -0.6367,  ..., -0.7109, -1.2656, -0.4180]],

        [[ 0.0139, -0.0273,  0.0087,  ..., -0.0417,  0.0271,  0.0422],
         [ 0.2275,  0.4062, -0.0045,  ...,  0.7266, -0.3867, -0.6211],
         [-0.2422, -0.1992,  1.3438,  ...,  0.9805, -0.5234, -0.2324],
         ...,
         [ 0.6875, -1.4297, -0.3418,  ...,  0.1069, -0.1187,  0.1299],
         [ 0.5156, -0.5547, -0.6094,  ...,  0.0503,  0.1641,  0.1436],
         [-0.2793,  0.5352, -0.6367,  ..., -0.7109, -1.2656, -0.4180]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[ 6.0425e-03, -9.7656e-03,  6.5002e-03,  ...,  1.8164e-01,
          -2.0938e+00, -4.0039e-02],
         [ 2.4023e-01,  2.0605e-01,  1.4297e+00,  ..., -9.3359e-01,
           6.4688e+00, -7.0312e-01],
         [-8.1543e-02,  7.2266e-01,  9.5312e-01,  ..., -9.1406e-01,
           7.5000e+00, -3.5938e-01],
         ...,
         [ 8.6328e-01,  1.2793e-01,  4.8438e-01,  ...,  5.8203e-01,
           4.9062e+00,  1.4609e+00],
         [ 2.7344e-02, -3.1006e-02,  4.3359e-01,  ..., -5.8594e-01,
           4.8125e+00,  2.2812e+00],
         [ 8.0078e-01, -9.1797e-01,  1.2344e+00,  ...,  1.5391e+00,
           5.9375e+00, -3.1875e+00]],

        [[ 6.0425e-03, -9.7656e-03,  6.5002e-03,  ...,  1.8164e-01,
          -2.0938e+00, -4.0039e-02],
         [ 2.4023e-01,  2.0605e-01,  1.4297e+00,  ..., -9.3359e-01,
           6.4688e+00, -7.0312e-01],
         [-8.1543e-02,  7.2266e-01,  9.5312e-01,  ..., -9.1406e-01,
           7.5000e+00, -3.5938e-01],
         ...,
         [ 8.6328e-01,  1.2793e-01,  4.8438e-01,  ...,  5.8203e-01,
           4.9062e+00,  1.4609e+00],
         [ 2.7344e-02, -3.1006e-02,  4.3359e-01,  ..., -5.8594e-01,
           4.8125e+00,  2.2812e+00],
         [ 8.0078e-01, -9.1797e-01,  1.2344e+00,  ...,  1.5391e+00,
           5.9375e+00, -3.1875e+00]],

        [[ 6.0425e-03, -9.7656e-03,  6.5002e-03,  ...,  1.8164e-01,
          -2.0938e+00, -4.0039e-02],
         [ 2.4023e-01,  2.0605e-01,  1.4297e+00,  ..., -9.3359e-01,
           6.4688e+00, -7.0312e-01],
         [-8.1543e-02,  7.2266e-01,  9.5312e-01,  ..., -9.1406e-01,
           7.5000e+00, -3.5938e-01],
         ...,
         [ 8.6328e-01,  1.2793e-01,  4.8438e-01,  ...,  5.8203e-01,
           4.9062e+00,  1.4609e+00],
         [ 2.7344e-02, -3.1006e-02,  4.3359e-01,  ..., -5.8594e-01,
           4.8125e+00,  2.2812e+00],
         [ 8.0078e-01, -9.1797e-01,  1.2344e+00,  ...,  1.5391e+00,
           5.9375e+00, -3.1875e+00]],

        ...,

        [[-4.9744e-03, -7.0190e-04, -9.7656e-03,  ...,  3.0859e-01,
          -4.3750e-01, -6.5430e-02],
         [ 1.8906e+00, -1.2500e+00,  9.7656e-01,  ..., -7.7344e-01,
           1.6250e+00, -5.1172e-01],
         [-1.9531e-01,  5.0781e-01, -6.7188e-01,  ..., -1.9531e+00,
           1.1719e+00,  3.0859e-01],
         ...,
         [ 7.8906e-01,  1.1484e+00,  4.3359e-01,  ..., -1.6250e+00,
           2.2754e-01,  1.5625e-01],
         [-5.8594e-01,  4.2578e-01,  6.7969e-01,  ..., -2.0000e+00,
          -5.1562e-01,  1.9165e-02],
         [-1.6016e+00,  4.2188e-01,  1.0859e+00,  ..., -1.5859e+00,
           1.0938e+00,  7.9102e-02]],

        [[-4.9744e-03, -7.0190e-04, -9.7656e-03,  ...,  3.0859e-01,
          -4.3750e-01, -6.5430e-02],
         [ 1.8906e+00, -1.2500e+00,  9.7656e-01,  ..., -7.7344e-01,
           1.6250e+00, -5.1172e-01],
         [-1.9531e-01,  5.0781e-01, -6.7188e-01,  ..., -1.9531e+00,
           1.1719e+00,  3.0859e-01],
         ...,
         [ 7.8906e-01,  1.1484e+00,  4.3359e-01,  ..., -1.6250e+00,
           2.2754e-01,  1.5625e-01],
         [-5.8594e-01,  4.2578e-01,  6.7969e-01,  ..., -2.0000e+00,
          -5.1562e-01,  1.9165e-02],
         [-1.6016e+00,  4.2188e-01,  1.0859e+00,  ..., -1.5859e+00,
           1.0938e+00,  7.9102e-02]],

        [[-4.9744e-03, -7.0190e-04, -9.7656e-03,  ...,  3.0859e-01,
          -4.3750e-01, -6.5430e-02],
         [ 1.8906e+00, -1.2500e+00,  9.7656e-01,  ..., -7.7344e-01,
           1.6250e+00, -5.1172e-01],
         [-1.9531e-01,  5.0781e-01, -6.7188e-01,  ..., -1.9531e+00,
           1.1719e+00,  3.0859e-01],
         ...,
         [ 7.8906e-01,  1.1484e+00,  4.3359e-01,  ..., -1.6250e+00,
           2.2754e-01,  1.5625e-01],
         [-5.8594e-01,  4.2578e-01,  6.7969e-01,  ..., -2.0000e+00,
          -5.1562e-01,  1.9165e-02],
         [-1.6016e+00,  4.2188e-01,  1.0859e+00,  ..., -1.5859e+00,
           1.0938e+00,  7.9102e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 0.0245,  0.0625, -0.0684,  ..., -0.0576, -0.0496,  0.0322],
         [-0.6133,  0.3223, -0.2158,  ...,  0.2930,  0.2852, -0.3613],
         [ 0.2656,  0.1973,  0.1157,  ..., -0.3672, -0.1279,  0.3633],
         ...,
         [ 0.4922, -0.7812,  0.2363,  ...,  0.1992,  0.3848, -0.3848],
         [-0.3398, -1.0859, -0.5117,  ...,  0.4961, -0.0102, -0.2246],
         [ 0.2129,  0.6172, -0.6797,  ...,  0.3301, -0.1914,  0.5352]],

        [[ 0.0245,  0.0625, -0.0684,  ..., -0.0576, -0.0496,  0.0322],
         [-0.6133,  0.3223, -0.2158,  ...,  0.2930,  0.2852, -0.3613],
         [ 0.2656,  0.1973,  0.1157,  ..., -0.3672, -0.1279,  0.3633],
         ...,
         [ 0.4922, -0.7812,  0.2363,  ...,  0.1992,  0.3848, -0.3848],
         [-0.3398, -1.0859, -0.5117,  ...,  0.4961, -0.0102, -0.2246],
         [ 0.2129,  0.6172, -0.6797,  ...,  0.3301, -0.1914,  0.5352]],

        [[ 0.0245,  0.0625, -0.0684,  ..., -0.0576, -0.0496,  0.0322],
         [-0.6133,  0.3223, -0.2158,  ...,  0.2930,  0.2852, -0.3613],
         [ 0.2656,  0.1973,  0.1157,  ..., -0.3672, -0.1279,  0.3633],
         ...,
         [ 0.4922, -0.7812,  0.2363,  ...,  0.1992,  0.3848, -0.3848],
         [-0.3398, -1.0859, -0.5117,  ...,  0.4961, -0.0102, -0.2246],
         [ 0.2129,  0.6172, -0.6797,  ...,  0.3301, -0.1914,  0.5352]],

        ...,

        [[ 0.0066, -0.0017,  0.0037,  ...,  0.0084,  0.0132,  0.0036],
         [-0.4062, -1.0234,  0.8516,  ...,  0.2617,  1.2031,  0.4395],
         [-0.5898,  0.8008,  0.1118,  ...,  0.6875,  1.5156, -0.1387],
         ...,
         [ 0.7734,  1.6328,  0.4062,  ...,  0.0635, -1.4531, -0.7656],
         [ 0.9453, -0.4082,  0.7617,  ...,  1.6797,  0.2002,  0.5781],
         [-0.2031,  0.1699,  0.1191,  ..., -0.1377,  0.6641,  0.8008]],

        [[ 0.0066, -0.0017,  0.0037,  ...,  0.0084,  0.0132,  0.0036],
         [-0.4062, -1.0234,  0.8516,  ...,  0.2617,  1.2031,  0.4395],
         [-0.5898,  0.8008,  0.1118,  ...,  0.6875,  1.5156, -0.1387],
         ...,
         [ 0.7734,  1.6328,  0.4062,  ...,  0.0635, -1.4531, -0.7656],
         [ 0.9453, -0.4082,  0.7617,  ...,  1.6797,  0.2002,  0.5781],
         [-0.2031,  0.1699,  0.1191,  ..., -0.1377,  0.6641,  0.8008]],

        [[ 0.0066, -0.0017,  0.0037,  ...,  0.0084,  0.0132,  0.0036],
         [-0.4062, -1.0234,  0.8516,  ...,  0.2617,  1.2031,  0.4395],
         [-0.5898,  0.8008,  0.1118,  ...,  0.6875,  1.5156, -0.1387],
         ...,
         [ 0.7734,  1.6328,  0.4062,  ...,  0.0635, -1.4531, -0.7656],
         [ 0.9453, -0.4082,  0.7617,  ...,  1.6797,  0.2002,  0.5781],
         [-0.2031,  0.1699,  0.1191,  ..., -0.1377,  0.6641,  0.8008]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[-3.7537e-03, -2.6245e-03,  1.2329e-02,  ...,  1.1768e-01,
          -4.6387e-02,  2.0508e-01],
         [-1.5781e+00, -1.7578e-02, -8.6719e-01,  ..., -1.2500e+00,
           7.1875e-01, -3.4961e-01],
         [-7.1875e-01, -1.1182e-01,  2.5195e-01,  ..., -1.3438e+00,
           3.3594e+00, -1.1172e+00],
         ...,
         [-3.6719e-01,  1.2500e+00, -5.3125e-01,  ..., -1.8359e+00,
          -5.3516e-01,  2.1875e-01],
         [ 4.0625e-01,  1.9824e-01, -3.6719e-01,  ..., -5.3125e-01,
          -1.8848e-01, -9.7656e-03],
         [ 1.7188e-01,  7.0703e-01, -8.2422e-01,  ...,  1.3047e+00,
           2.5625e+00, -3.1406e+00]],

        [[-3.7537e-03, -2.6245e-03,  1.2329e-02,  ...,  1.1768e-01,
          -4.6387e-02,  2.0508e-01],
         [-1.5781e+00, -1.7578e-02, -8.6719e-01,  ..., -1.2500e+00,
           7.1875e-01, -3.4961e-01],
         [-7.1875e-01, -1.1182e-01,  2.5195e-01,  ..., -1.3438e+00,
           3.3594e+00, -1.1172e+00],
         ...,
         [-3.6719e-01,  1.2500e+00, -5.3125e-01,  ..., -1.8359e+00,
          -5.3516e-01,  2.1875e-01],
         [ 4.0625e-01,  1.9824e-01, -3.6719e-01,  ..., -5.3125e-01,
          -1.8848e-01, -9.7656e-03],
         [ 1.7188e-01,  7.0703e-01, -8.2422e-01,  ...,  1.3047e+00,
           2.5625e+00, -3.1406e+00]],

        [[-3.7537e-03, -2.6245e-03,  1.2329e-02,  ...,  1.1768e-01,
          -4.6387e-02,  2.0508e-01],
         [-1.5781e+00, -1.7578e-02, -8.6719e-01,  ..., -1.2500e+00,
           7.1875e-01, -3.4961e-01],
         [-7.1875e-01, -1.1182e-01,  2.5195e-01,  ..., -1.3438e+00,
           3.3594e+00, -1.1172e+00],
         ...,
         [-3.6719e-01,  1.2500e+00, -5.3125e-01,  ..., -1.8359e+00,
          -5.3516e-01,  2.1875e-01],
         [ 4.0625e-01,  1.9824e-01, -3.6719e-01,  ..., -5.3125e-01,
          -1.8848e-01, -9.7656e-03],
         [ 1.7188e-01,  7.0703e-01, -8.2422e-01,  ...,  1.3047e+00,
           2.5625e+00, -3.1406e+00]],

        ...,

        [[-1.0315e-02,  6.5804e-05,  1.0681e-02,  ...,  5.1172e-01,
           1.9043e-02, -2.9907e-02],
         [ 9.0625e-01, -7.8125e-02, -1.0156e+00,  ..., -3.0000e+00,
           4.2578e-01, -9.8047e-01],
         [ 6.7188e-01,  3.4180e-01,  4.0625e-01,  ..., -1.3594e+00,
           1.4844e-01, -1.0312e+00],
         ...,
         [-3.3203e-01,  3.3594e-01, -1.1865e-01,  ..., -1.2031e+00,
          -2.9844e+00, -7.1484e-01],
         [-3.5938e-01, -5.0293e-02,  2.6245e-02,  ..., -4.6631e-02,
          -1.1016e+00,  5.2344e-01],
         [-1.6875e+00, -2.2266e-01, -1.6406e+00,  ...,  4.3164e-01,
           2.1606e-02,  7.4609e-01]],

        [[-1.0315e-02,  6.5804e-05,  1.0681e-02,  ...,  5.1172e-01,
           1.9043e-02, -2.9907e-02],
         [ 9.0625e-01, -7.8125e-02, -1.0156e+00,  ..., -3.0000e+00,
           4.2578e-01, -9.8047e-01],
         [ 6.7188e-01,  3.4180e-01,  4.0625e-01,  ..., -1.3594e+00,
           1.4844e-01, -1.0312e+00],
         ...,
         [-3.3203e-01,  3.3594e-01, -1.1865e-01,  ..., -1.2031e+00,
          -2.9844e+00, -7.1484e-01],
         [-3.5938e-01, -5.0293e-02,  2.6245e-02,  ..., -4.6631e-02,
          -1.1016e+00,  5.2344e-01],
         [-1.6875e+00, -2.2266e-01, -1.6406e+00,  ...,  4.3164e-01,
           2.1606e-02,  7.4609e-01]],

        [[-1.0315e-02,  6.5804e-05,  1.0681e-02,  ...,  5.1172e-01,
           1.9043e-02, -2.9907e-02],
         [ 9.0625e-01, -7.8125e-02, -1.0156e+00,  ..., -3.0000e+00,
           4.2578e-01, -9.8047e-01],
         [ 6.7188e-01,  3.4180e-01,  4.0625e-01,  ..., -1.3594e+00,
           1.4844e-01, -1.0312e+00],
         ...,
         [-3.3203e-01,  3.3594e-01, -1.1865e-01,  ..., -1.2031e+00,
          -2.9844e+00, -7.1484e-01],
         [-3.5938e-01, -5.0293e-02,  2.6245e-02,  ..., -4.6631e-02,
          -1.1016e+00,  5.2344e-01],
         [-1.6875e+00, -2.2266e-01, -1.6406e+00,  ...,  4.3164e-01,
           2.1606e-02,  7.4609e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-2.2339e-02, -2.2583e-02, -1.3672e-01,  ...,  2.0386e-02,
           5.4016e-03, -2.3193e-02],
         [-7.5000e-01, -3.5547e-01,  6.2500e-01,  ...,  3.8281e-01,
           3.4961e-01,  3.8477e-01],
         [ 2.2266e-01, -3.2617e-01,  4.4727e-01,  ..., -1.6504e-01,
           5.8203e-01,  2.3047e-01],
         ...,
         [-3.8086e-02, -1.5723e-01, -9.1797e-01,  ...,  4.5117e-01,
           3.3203e-01,  1.0156e+00],
         [-8.5547e-01,  1.8438e+00,  9.1797e-02,  ...,  1.1406e+00,
           5.5859e-01,  8.1641e-01],
         [ 5.8594e-01, -1.0059e-01, -5.7031e-01,  ..., -5.2734e-02,
           2.9883e-01,  5.0391e-01]],

        [[-2.2339e-02, -2.2583e-02, -1.3672e-01,  ...,  2.0386e-02,
           5.4016e-03, -2.3193e-02],
         [-7.5000e-01, -3.5547e-01,  6.2500e-01,  ...,  3.8281e-01,
           3.4961e-01,  3.8477e-01],
         [ 2.2266e-01, -3.2617e-01,  4.4727e-01,  ..., -1.6504e-01,
           5.8203e-01,  2.3047e-01],
         ...,
         [-3.8086e-02, -1.5723e-01, -9.1797e-01,  ...,  4.5117e-01,
           3.3203e-01,  1.0156e+00],
         [-8.5547e-01,  1.8438e+00,  9.1797e-02,  ...,  1.1406e+00,
           5.5859e-01,  8.1641e-01],
         [ 5.8594e-01, -1.0059e-01, -5.7031e-01,  ..., -5.2734e-02,
           2.9883e-01,  5.0391e-01]],

        [[-2.2339e-02, -2.2583e-02, -1.3672e-01,  ...,  2.0386e-02,
           5.4016e-03, -2.3193e-02],
         [-7.5000e-01, -3.5547e-01,  6.2500e-01,  ...,  3.8281e-01,
           3.4961e-01,  3.8477e-01],
         [ 2.2266e-01, -3.2617e-01,  4.4727e-01,  ..., -1.6504e-01,
           5.8203e-01,  2.3047e-01],
         ...,
         [-3.8086e-02, -1.5723e-01, -9.1797e-01,  ...,  4.5117e-01,
           3.3203e-01,  1.0156e+00],
         [-8.5547e-01,  1.8438e+00,  9.1797e-02,  ...,  1.1406e+00,
           5.5859e-01,  8.1641e-01],
         [ 5.8594e-01, -1.0059e-01, -5.7031e-01,  ..., -5.2734e-02,
           2.9883e-01,  5.0391e-01]],

        ...,

        [[-1.1963e-02,  1.4801e-03, -6.3171e-03,  ...,  3.9062e-03,
          -2.3315e-02, -1.6327e-03],
         [-4.7852e-01, -4.9219e-01,  1.7285e-01,  ..., -3.8086e-01,
           7.8516e-01,  5.4297e-01],
         [-1.1953e+00, -3.8281e-01, -4.9609e-01,  ..., -6.0938e-01,
           5.3516e-01,  4.7266e-01],
         ...,
         [-1.1279e-01,  9.8633e-02,  7.9688e-01,  ...,  6.2109e-01,
          -5.8984e-01, -8.0859e-01],
         [-1.5625e-01,  4.9805e-02,  2.5781e-01,  ...,  9.7656e-01,
           2.1973e-01,  2.3633e-01],
         [-2.4805e-01,  3.6914e-01,  1.1016e+00,  ...,  7.0312e-01,
           2.5391e-01, -3.4912e-02]],

        [[-1.1963e-02,  1.4801e-03, -6.3171e-03,  ...,  3.9062e-03,
          -2.3315e-02, -1.6327e-03],
         [-4.7852e-01, -4.9219e-01,  1.7285e-01,  ..., -3.8086e-01,
           7.8516e-01,  5.4297e-01],
         [-1.1953e+00, -3.8281e-01, -4.9609e-01,  ..., -6.0938e-01,
           5.3516e-01,  4.7266e-01],
         ...,
         [-1.1279e-01,  9.8633e-02,  7.9688e-01,  ...,  6.2109e-01,
          -5.8984e-01, -8.0859e-01],
         [-1.5625e-01,  4.9805e-02,  2.5781e-01,  ...,  9.7656e-01,
           2.1973e-01,  2.3633e-01],
         [-2.4805e-01,  3.6914e-01,  1.1016e+00,  ...,  7.0312e-01,
           2.5391e-01, -3.4912e-02]],

        [[-1.1963e-02,  1.4801e-03, -6.3171e-03,  ...,  3.9062e-03,
          -2.3315e-02, -1.6327e-03],
         [-4.7852e-01, -4.9219e-01,  1.7285e-01,  ..., -3.8086e-01,
           7.8516e-01,  5.4297e-01],
         [-1.1953e+00, -3.8281e-01, -4.9609e-01,  ..., -6.0938e-01,
           5.3516e-01,  4.7266e-01],
         ...,
         [-1.1279e-01,  9.8633e-02,  7.9688e-01,  ...,  6.2109e-01,
          -5.8984e-01, -8.0859e-01],
         [-1.5625e-01,  4.9805e-02,  2.5781e-01,  ...,  9.7656e-01,
           2.1973e-01,  2.3633e-01],
         [-2.4805e-01,  3.6914e-01,  1.1016e+00,  ...,  7.0312e-01,
           2.5391e-01, -3.4912e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 7.9956e-03,  4.7302e-03,  9.1553e-05,  ...,  4.0234e-01,
          -7.4707e-02, -5.3516e-01],
         [ 1.8750e-01, -2.0312e+00,  1.7109e+00,  ..., -2.7344e+00,
           9.1406e-01,  3.5938e-01],
         [-6.9922e-01,  2.0508e-01,  7.5781e-01,  ..., -2.6875e+00,
           4.7070e-01, -4.1504e-02],
         ...,
         [-1.1797e+00,  2.8906e-01, -1.2031e+00,  ..., -3.5625e+00,
          -1.6797e-01,  1.8281e+00],
         [-5.8203e-01,  8.5938e-01,  9.7656e-02,  ..., -3.3750e+00,
          -2.4375e+00,  3.0781e+00],
         [-1.3672e+00,  1.1094e+00,  8.6328e-01,  ..., -4.0938e+00,
           2.7188e+00,  1.2891e+00]],

        [[ 7.9956e-03,  4.7302e-03,  9.1553e-05,  ...,  4.0234e-01,
          -7.4707e-02, -5.3516e-01],
         [ 1.8750e-01, -2.0312e+00,  1.7109e+00,  ..., -2.7344e+00,
           9.1406e-01,  3.5938e-01],
         [-6.9922e-01,  2.0508e-01,  7.5781e-01,  ..., -2.6875e+00,
           4.7070e-01, -4.1504e-02],
         ...,
         [-1.1797e+00,  2.8906e-01, -1.2031e+00,  ..., -3.5625e+00,
          -1.6797e-01,  1.8281e+00],
         [-5.8203e-01,  8.5938e-01,  9.7656e-02,  ..., -3.3750e+00,
          -2.4375e+00,  3.0781e+00],
         [-1.3672e+00,  1.1094e+00,  8.6328e-01,  ..., -4.0938e+00,
           2.7188e+00,  1.2891e+00]],

        [[ 7.9956e-03,  4.7302e-03,  9.1553e-05,  ...,  4.0234e-01,
          -7.4707e-02, -5.3516e-01],
         [ 1.8750e-01, -2.0312e+00,  1.7109e+00,  ..., -2.7344e+00,
           9.1406e-01,  3.5938e-01],
         [-6.9922e-01,  2.0508e-01,  7.5781e-01,  ..., -2.6875e+00,
           4.7070e-01, -4.1504e-02],
         ...,
         [-1.1797e+00,  2.8906e-01, -1.2031e+00,  ..., -3.5625e+00,
          -1.6797e-01,  1.8281e+00],
         [-5.8203e-01,  8.5938e-01,  9.7656e-02,  ..., -3.3750e+00,
          -2.4375e+00,  3.0781e+00],
         [-1.3672e+00,  1.1094e+00,  8.6328e-01,  ..., -4.0938e+00,
           2.7188e+00,  1.2891e+00]],

        ...,

        [[ 3.9368e-03,  7.5378e-03, -1.5137e-02,  ...,  9.0820e-02,
          -4.9805e-01, -1.0469e+00],
         [ 1.6719e+00, -2.3633e-01,  4.3359e-01,  ..., -6.9531e-01,
           4.8438e+00,  3.9531e+00],
         [ 2.9297e-01,  6.9922e-01,  4.2383e-01,  ..., -1.3359e+00,
           3.0156e+00,  5.8125e+00],
         ...,
         [ 4.7461e-01,  1.2344e+00, -1.4062e+00,  ..., -1.7212e-02,
          -5.6641e-01,  3.5625e+00],
         [-3.2812e-01,  3.0078e-01,  1.4062e-01,  ...,  1.2969e+00,
          -1.8359e-01,  4.1562e+00],
         [-1.3984e+00,  7.1875e-01,  9.4531e-01,  ..., -3.4375e+00,
           2.9844e+00,  5.0938e+00]],

        [[ 3.9368e-03,  7.5378e-03, -1.5137e-02,  ...,  9.0820e-02,
          -4.9805e-01, -1.0469e+00],
         [ 1.6719e+00, -2.3633e-01,  4.3359e-01,  ..., -6.9531e-01,
           4.8438e+00,  3.9531e+00],
         [ 2.9297e-01,  6.9922e-01,  4.2383e-01,  ..., -1.3359e+00,
           3.0156e+00,  5.8125e+00],
         ...,
         [ 4.7461e-01,  1.2344e+00, -1.4062e+00,  ..., -1.7212e-02,
          -5.6641e-01,  3.5625e+00],
         [-3.2812e-01,  3.0078e-01,  1.4062e-01,  ...,  1.2969e+00,
          -1.8359e-01,  4.1562e+00],
         [-1.3984e+00,  7.1875e-01,  9.4531e-01,  ..., -3.4375e+00,
           2.9844e+00,  5.0938e+00]],

        [[ 3.9368e-03,  7.5378e-03, -1.5137e-02,  ...,  9.0820e-02,
          -4.9805e-01, -1.0469e+00],
         [ 1.6719e+00, -2.3633e-01,  4.3359e-01,  ..., -6.9531e-01,
           4.8438e+00,  3.9531e+00],
         [ 2.9297e-01,  6.9922e-01,  4.2383e-01,  ..., -1.3359e+00,
           3.0156e+00,  5.8125e+00],
         ...,
         [ 4.7461e-01,  1.2344e+00, -1.4062e+00,  ..., -1.7212e-02,
          -5.6641e-01,  3.5625e+00],
         [-3.2812e-01,  3.0078e-01,  1.4062e-01,  ...,  1.2969e+00,
          -1.8359e-01,  4.1562e+00],
         [-1.3984e+00,  7.1875e-01,  9.4531e-01,  ..., -3.4375e+00,
           2.9844e+00,  5.0938e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-7.3547e-03, -4.1504e-03, -1.7334e-02,  ...,  1.0132e-02,
          -1.4725e-03,  4.8828e-03],
         [-9.4531e-01,  1.7188e-01,  2.5781e+00,  ..., -4.9219e-01,
           1.5234e+00,  1.3594e+00],
         [-2.6562e-01, -1.1484e+00,  1.6406e+00,  ..., -1.5078e+00,
           9.1016e-01,  1.6406e+00],
         ...,
         [-6.5625e-01,  3.2959e-02, -3.8281e-01,  ...,  1.9141e-01,
          -4.0039e-01, -2.3804e-02],
         [-1.8262e-01, -4.1016e-01, -5.8984e-01,  ...,  1.1572e-01,
          -3.8867e-01,  8.2812e-01],
         [-1.5938e+00, -4.6289e-01,  6.7188e-01,  ..., -7.8125e-02,
           5.4932e-02, -9.1797e-02]],

        [[-7.3547e-03, -4.1504e-03, -1.7334e-02,  ...,  1.0132e-02,
          -1.4725e-03,  4.8828e-03],
         [-9.4531e-01,  1.7188e-01,  2.5781e+00,  ..., -4.9219e-01,
           1.5234e+00,  1.3594e+00],
         [-2.6562e-01, -1.1484e+00,  1.6406e+00,  ..., -1.5078e+00,
           9.1016e-01,  1.6406e+00],
         ...,
         [-6.5625e-01,  3.2959e-02, -3.8281e-01,  ...,  1.9141e-01,
          -4.0039e-01, -2.3804e-02],
         [-1.8262e-01, -4.1016e-01, -5.8984e-01,  ...,  1.1572e-01,
          -3.8867e-01,  8.2812e-01],
         [-1.5938e+00, -4.6289e-01,  6.7188e-01,  ..., -7.8125e-02,
           5.4932e-02, -9.1797e-02]],

        [[-7.3547e-03, -4.1504e-03, -1.7334e-02,  ...,  1.0132e-02,
          -1.4725e-03,  4.8828e-03],
         [-9.4531e-01,  1.7188e-01,  2.5781e+00,  ..., -4.9219e-01,
           1.5234e+00,  1.3594e+00],
         [-2.6562e-01, -1.1484e+00,  1.6406e+00,  ..., -1.5078e+00,
           9.1016e-01,  1.6406e+00],
         ...,
         [-6.5625e-01,  3.2959e-02, -3.8281e-01,  ...,  1.9141e-01,
          -4.0039e-01, -2.3804e-02],
         [-1.8262e-01, -4.1016e-01, -5.8984e-01,  ...,  1.1572e-01,
          -3.8867e-01,  8.2812e-01],
         [-1.5938e+00, -4.6289e-01,  6.7188e-01,  ..., -7.8125e-02,
           5.4932e-02, -9.1797e-02]],

        ...,

        [[-7.8735e-03,  1.6357e-02,  8.4686e-04,  ..., -9.6436e-03,
          -1.5747e-02,  1.2512e-02],
         [-1.0840e-01, -1.2031e+00, -2.6953e-01,  ..., -1.9653e-02,
          -2.6953e-01,  1.5991e-02],
         [ 5.1172e-01,  1.1328e+00, -1.2402e-01,  ...,  1.3281e-01,
           4.2773e-01, -2.9297e-01],
         ...,
         [-8.1641e-01, -6.1328e-01, -5.0781e-02,  ...,  6.1719e-01,
          -1.1250e+00, -1.1328e-01],
         [ 3.2471e-02, -1.7383e-01,  3.6328e-01,  ..., -6.7188e-01,
           6.5625e-01, -1.2793e-01],
         [-1.1875e+00,  2.0625e+00,  1.7188e+00,  ..., -5.2344e-01,
           9.9219e-01, -6.4453e-01]],

        [[-7.8735e-03,  1.6357e-02,  8.4686e-04,  ..., -9.6436e-03,
          -1.5747e-02,  1.2512e-02],
         [-1.0840e-01, -1.2031e+00, -2.6953e-01,  ..., -1.9653e-02,
          -2.6953e-01,  1.5991e-02],
         [ 5.1172e-01,  1.1328e+00, -1.2402e-01,  ...,  1.3281e-01,
           4.2773e-01, -2.9297e-01],
         ...,
         [-8.1641e-01, -6.1328e-01, -5.0781e-02,  ...,  6.1719e-01,
          -1.1250e+00, -1.1328e-01],
         [ 3.2471e-02, -1.7383e-01,  3.6328e-01,  ..., -6.7188e-01,
           6.5625e-01, -1.2793e-01],
         [-1.1875e+00,  2.0625e+00,  1.7188e+00,  ..., -5.2344e-01,
           9.9219e-01, -6.4453e-01]],

        [[-7.8735e-03,  1.6357e-02,  8.4686e-04,  ..., -9.6436e-03,
          -1.5747e-02,  1.2512e-02],
         [-1.0840e-01, -1.2031e+00, -2.6953e-01,  ..., -1.9653e-02,
          -2.6953e-01,  1.5991e-02],
         [ 5.1172e-01,  1.1328e+00, -1.2402e-01,  ...,  1.3281e-01,
           4.2773e-01, -2.9297e-01],
         ...,
         [-8.1641e-01, -6.1328e-01, -5.0781e-02,  ...,  6.1719e-01,
          -1.1250e+00, -1.1328e-01],
         [ 3.2471e-02, -1.7383e-01,  3.6328e-01,  ..., -6.7188e-01,
           6.5625e-01, -1.2793e-01],
         [-1.1875e+00,  2.0625e+00,  1.7188e+00,  ..., -5.2344e-01,
           9.9219e-01, -6.4453e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 3.4332e-04,  1.2634e-02, -1.8433e-02,  ...,  3.8330e-02,
           1.4941e-01, -1.8828e+00],
         [ 9.5312e-01, -1.9219e+00, -1.5820e-01,  ..., -1.0391e+00,
          -6.2891e-01,  7.4688e+00],
         [-3.3594e-01, -6.1719e-01, -3.4180e-01,  ...,  6.3281e-01,
           2.0156e+00,  8.3750e+00],
         ...,
         [-4.4922e-01,  3.6914e-01, -8.1641e-01,  ..., -2.8594e+00,
          -3.4688e+00,  5.1562e+00],
         [ 2.5391e-01, -2.1484e-01,  1.2988e-01,  ..., -2.7500e+00,
          -1.4766e+00,  5.7500e+00],
         [-1.0234e+00,  1.0625e+00, -1.3750e+00,  ..., -8.0859e-01,
           2.2188e+00,  7.1875e+00]],

        [[ 3.4332e-04,  1.2634e-02, -1.8433e-02,  ...,  3.8330e-02,
           1.4941e-01, -1.8828e+00],
         [ 9.5312e-01, -1.9219e+00, -1.5820e-01,  ..., -1.0391e+00,
          -6.2891e-01,  7.4688e+00],
         [-3.3594e-01, -6.1719e-01, -3.4180e-01,  ...,  6.3281e-01,
           2.0156e+00,  8.3750e+00],
         ...,
         [-4.4922e-01,  3.6914e-01, -8.1641e-01,  ..., -2.8594e+00,
          -3.4688e+00,  5.1562e+00],
         [ 2.5391e-01, -2.1484e-01,  1.2988e-01,  ..., -2.7500e+00,
          -1.4766e+00,  5.7500e+00],
         [-1.0234e+00,  1.0625e+00, -1.3750e+00,  ..., -8.0859e-01,
           2.2188e+00,  7.1875e+00]],

        [[ 3.4332e-04,  1.2634e-02, -1.8433e-02,  ...,  3.8330e-02,
           1.4941e-01, -1.8828e+00],
         [ 9.5312e-01, -1.9219e+00, -1.5820e-01,  ..., -1.0391e+00,
          -6.2891e-01,  7.4688e+00],
         [-3.3594e-01, -6.1719e-01, -3.4180e-01,  ...,  6.3281e-01,
           2.0156e+00,  8.3750e+00],
         ...,
         [-4.4922e-01,  3.6914e-01, -8.1641e-01,  ..., -2.8594e+00,
          -3.4688e+00,  5.1562e+00],
         [ 2.5391e-01, -2.1484e-01,  1.2988e-01,  ..., -2.7500e+00,
          -1.4766e+00,  5.7500e+00],
         [-1.0234e+00,  1.0625e+00, -1.3750e+00,  ..., -8.0859e-01,
           2.2188e+00,  7.1875e+00]],

        ...,

        [[ 4.5776e-03, -9.3994e-03, -8.3618e-03,  ..., -1.5156e+00,
          -4.8828e-01,  8.8379e-02],
         [-2.8438e+00,  2.7188e+00,  1.3594e+00,  ...,  8.3750e+00,
          -1.0234e+00, -5.2344e-01],
         [-7.5000e-01,  4.7266e-01,  3.7305e-01,  ...,  9.1875e+00,
          -2.5000e+00, -2.6406e+00],
         ...,
         [-1.3672e+00, -8.7500e-01, -1.0391e+00,  ...,  7.7188e+00,
          -2.6367e-01,  9.9609e-01],
         [-1.4688e+00, -8.3203e-01, -6.3281e-01,  ...,  6.9062e+00,
           1.3672e+00,  5.8203e-01],
         [-3.4570e-01, -1.9844e+00, -1.2109e+00,  ...,  7.4062e+00,
          -4.2480e-02, -2.3750e+00]],

        [[ 4.5776e-03, -9.3994e-03, -8.3618e-03,  ..., -1.5156e+00,
          -4.8828e-01,  8.8379e-02],
         [-2.8438e+00,  2.7188e+00,  1.3594e+00,  ...,  8.3750e+00,
          -1.0234e+00, -5.2344e-01],
         [-7.5000e-01,  4.7266e-01,  3.7305e-01,  ...,  9.1875e+00,
          -2.5000e+00, -2.6406e+00],
         ...,
         [-1.3672e+00, -8.7500e-01, -1.0391e+00,  ...,  7.7188e+00,
          -2.6367e-01,  9.9609e-01],
         [-1.4688e+00, -8.3203e-01, -6.3281e-01,  ...,  6.9062e+00,
           1.3672e+00,  5.8203e-01],
         [-3.4570e-01, -1.9844e+00, -1.2109e+00,  ...,  7.4062e+00,
          -4.2480e-02, -2.3750e+00]],

        [[ 4.5776e-03, -9.3994e-03, -8.3618e-03,  ..., -1.5156e+00,
          -4.8828e-01,  8.8379e-02],
         [-2.8438e+00,  2.7188e+00,  1.3594e+00,  ...,  8.3750e+00,
          -1.0234e+00, -5.2344e-01],
         [-7.5000e-01,  4.7266e-01,  3.7305e-01,  ...,  9.1875e+00,
          -2.5000e+00, -2.6406e+00],
         ...,
         [-1.3672e+00, -8.7500e-01, -1.0391e+00,  ...,  7.7188e+00,
          -2.6367e-01,  9.9609e-01],
         [-1.4688e+00, -8.3203e-01, -6.3281e-01,  ...,  6.9062e+00,
           1.3672e+00,  5.8203e-01],
         [-3.4570e-01, -1.9844e+00, -1.2109e+00,  ...,  7.4062e+00,
          -4.2480e-02, -2.3750e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 8.1543e-02,  4.4556e-03,  5.4199e-02,  ..., -3.0212e-03,
          -8.1787e-03, -3.4668e-02],
         [-3.1641e-01, -2.2461e-01, -2.6758e-01,  ...,  4.6289e-01,
          -2.1191e-01, -3.3789e-01],
         [-8.7402e-02,  3.1250e-01,  4.8828e-01,  ...,  7.9297e-01,
           9.8633e-02,  1.7090e-01],
         ...,
         [ 8.5156e-01,  1.0703e+00,  3.1982e-02,  ...,  1.0498e-01,
           2.8125e-01, -5.4688e-01],
         [-3.8330e-02,  3.7500e-01,  3.3594e-01,  ...,  9.4922e-01,
           7.2656e-01, -8.1641e-01],
         [ 2.1387e-01,  3.5547e-01, -1.5234e-01,  ...,  4.3750e-01,
           6.2891e-01, -4.7852e-02]],

        [[ 8.1543e-02,  4.4556e-03,  5.4199e-02,  ..., -3.0212e-03,
          -8.1787e-03, -3.4668e-02],
         [-3.1641e-01, -2.2461e-01, -2.6758e-01,  ...,  4.6289e-01,
          -2.1191e-01, -3.3789e-01],
         [-8.7402e-02,  3.1250e-01,  4.8828e-01,  ...,  7.9297e-01,
           9.8633e-02,  1.7090e-01],
         ...,
         [ 8.5156e-01,  1.0703e+00,  3.1982e-02,  ...,  1.0498e-01,
           2.8125e-01, -5.4688e-01],
         [-3.8330e-02,  3.7500e-01,  3.3594e-01,  ...,  9.4922e-01,
           7.2656e-01, -8.1641e-01],
         [ 2.1387e-01,  3.5547e-01, -1.5234e-01,  ...,  4.3750e-01,
           6.2891e-01, -4.7852e-02]],

        [[ 8.1543e-02,  4.4556e-03,  5.4199e-02,  ..., -3.0212e-03,
          -8.1787e-03, -3.4668e-02],
         [-3.1641e-01, -2.2461e-01, -2.6758e-01,  ...,  4.6289e-01,
          -2.1191e-01, -3.3789e-01],
         [-8.7402e-02,  3.1250e-01,  4.8828e-01,  ...,  7.9297e-01,
           9.8633e-02,  1.7090e-01],
         ...,
         [ 8.5156e-01,  1.0703e+00,  3.1982e-02,  ...,  1.0498e-01,
           2.8125e-01, -5.4688e-01],
         [-3.8330e-02,  3.7500e-01,  3.3594e-01,  ...,  9.4922e-01,
           7.2656e-01, -8.1641e-01],
         [ 2.1387e-01,  3.5547e-01, -1.5234e-01,  ...,  4.3750e-01,
           6.2891e-01, -4.7852e-02]],

        ...,

        [[ 5.6763e-03, -2.8534e-03,  6.3705e-04,  ...,  3.8300e-03,
          -9.3994e-03,  3.9673e-03],
         [-2.5938e+00,  3.1406e+00, -4.9023e-01,  ..., -5.6250e-01,
          -1.3281e+00,  2.0142e-02],
         [-2.4844e+00,  1.4297e+00, -1.9922e+00,  ...,  5.3516e-01,
          -4.1797e-01, -2.4609e-01],
         ...,
         [ 7.5391e-01, -5.7031e-01,  7.5391e-01,  ...,  8.8672e-01,
          -7.6172e-01,  4.5117e-01],
         [ 4.8633e-01, -8.3008e-02,  4.8438e-01,  ..., -6.4941e-02,
          -2.4316e-01,  6.4941e-02],
         [-8.1055e-02,  5.4443e-02, -5.2734e-01,  ...,  3.4180e-01,
           2.0625e+00, -9.8828e-01]],

        [[ 5.6763e-03, -2.8534e-03,  6.3705e-04,  ...,  3.8300e-03,
          -9.3994e-03,  3.9673e-03],
         [-2.5938e+00,  3.1406e+00, -4.9023e-01,  ..., -5.6250e-01,
          -1.3281e+00,  2.0142e-02],
         [-2.4844e+00,  1.4297e+00, -1.9922e+00,  ...,  5.3516e-01,
          -4.1797e-01, -2.4609e-01],
         ...,
         [ 7.5391e-01, -5.7031e-01,  7.5391e-01,  ...,  8.8672e-01,
          -7.6172e-01,  4.5117e-01],
         [ 4.8633e-01, -8.3008e-02,  4.8438e-01,  ..., -6.4941e-02,
          -2.4316e-01,  6.4941e-02],
         [-8.1055e-02,  5.4443e-02, -5.2734e-01,  ...,  3.4180e-01,
           2.0625e+00, -9.8828e-01]],

        [[ 5.6763e-03, -2.8534e-03,  6.3705e-04,  ...,  3.8300e-03,
          -9.3994e-03,  3.9673e-03],
         [-2.5938e+00,  3.1406e+00, -4.9023e-01,  ..., -5.6250e-01,
          -1.3281e+00,  2.0142e-02],
         [-2.4844e+00,  1.4297e+00, -1.9922e+00,  ...,  5.3516e-01,
          -4.1797e-01, -2.4609e-01],
         ...,
         [ 7.5391e-01, -5.7031e-01,  7.5391e-01,  ...,  8.8672e-01,
          -7.6172e-01,  4.5117e-01],
         [ 4.8633e-01, -8.3008e-02,  4.8438e-01,  ..., -6.4941e-02,
          -2.4316e-01,  6.4941e-02],
         [-8.1055e-02,  5.4443e-02, -5.2734e-01,  ...,  3.4180e-01,
           2.0625e+00, -9.8828e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-4.6921e-04,  1.2268e-02,  1.6113e-02,  ...,  2.0703e-01,
           7.5684e-02,  7.1716e-03],
         [-1.4531e+00, -7.7344e-01, -3.1055e-01,  ..., -5.7812e-01,
          -1.6968e-02,  5.7068e-03],
         [-9.0234e-01,  3.5938e-01, -5.3125e-01,  ..., -1.9922e+00,
           2.1250e+00, -2.1406e+00],
         ...,
         [-4.9219e-01, -5.6250e-01,  4.9805e-01,  ...,  1.3359e+00,
          -3.9531e+00, -1.8672e+00],
         [-1.2402e-01,  1.0156e-01,  2.5195e-01,  ...,  7.1875e-01,
          -5.1953e-01, -2.2812e+00],
         [-7.7148e-02,  2.7734e-01, -1.0938e-01,  ..., -9.9609e-01,
           5.5078e-01, -7.0312e-01]],

        [[-4.6921e-04,  1.2268e-02,  1.6113e-02,  ...,  2.0703e-01,
           7.5684e-02,  7.1716e-03],
         [-1.4531e+00, -7.7344e-01, -3.1055e-01,  ..., -5.7812e-01,
          -1.6968e-02,  5.7068e-03],
         [-9.0234e-01,  3.5938e-01, -5.3125e-01,  ..., -1.9922e+00,
           2.1250e+00, -2.1406e+00],
         ...,
         [-4.9219e-01, -5.6250e-01,  4.9805e-01,  ...,  1.3359e+00,
          -3.9531e+00, -1.8672e+00],
         [-1.2402e-01,  1.0156e-01,  2.5195e-01,  ...,  7.1875e-01,
          -5.1953e-01, -2.2812e+00],
         [-7.7148e-02,  2.7734e-01, -1.0938e-01,  ..., -9.9609e-01,
           5.5078e-01, -7.0312e-01]],

        [[-4.6921e-04,  1.2268e-02,  1.6113e-02,  ...,  2.0703e-01,
           7.5684e-02,  7.1716e-03],
         [-1.4531e+00, -7.7344e-01, -3.1055e-01,  ..., -5.7812e-01,
          -1.6968e-02,  5.7068e-03],
         [-9.0234e-01,  3.5938e-01, -5.3125e-01,  ..., -1.9922e+00,
           2.1250e+00, -2.1406e+00],
         ...,
         [-4.9219e-01, -5.6250e-01,  4.9805e-01,  ...,  1.3359e+00,
          -3.9531e+00, -1.8672e+00],
         [-1.2402e-01,  1.0156e-01,  2.5195e-01,  ...,  7.1875e-01,
          -5.1953e-01, -2.2812e+00],
         [-7.7148e-02,  2.7734e-01, -1.0938e-01,  ..., -9.9609e-01,
           5.5078e-01, -7.0312e-01]],

        ...,

        [[-1.4221e-02,  1.0681e-02,  2.6398e-03,  ...,  5.9082e-02,
          -9.4727e-02, -1.1328e-01],
         [ 9.3359e-01, -9.9219e-01,  3.9062e-01,  ..., -4.4727e-01,
          -3.0625e+00, -1.9727e-01],
         [-1.7969e-01, -1.2266e+00,  8.1641e-01,  ...,  1.0234e+00,
          -3.5938e-01,  8.9062e-01],
         ...,
         [-3.9062e-01, -1.0938e+00, -1.1562e+00,  ...,  1.4688e+00,
           4.1797e-01,  1.3516e+00],
         [-6.3281e-01, -8.7500e-01, -1.5312e+00,  ...,  6.4844e-01,
           5.7031e-01,  1.2656e+00],
         [-1.0938e+00, -4.2969e-01, -1.4062e+00,  ...,  1.5312e+00,
           9.1406e-01, -4.4062e+00]],

        [[-1.4221e-02,  1.0681e-02,  2.6398e-03,  ...,  5.9082e-02,
          -9.4727e-02, -1.1328e-01],
         [ 9.3359e-01, -9.9219e-01,  3.9062e-01,  ..., -4.4727e-01,
          -3.0625e+00, -1.9727e-01],
         [-1.7969e-01, -1.2266e+00,  8.1641e-01,  ...,  1.0234e+00,
          -3.5938e-01,  8.9062e-01],
         ...,
         [-3.9062e-01, -1.0938e+00, -1.1562e+00,  ...,  1.4688e+00,
           4.1797e-01,  1.3516e+00],
         [-6.3281e-01, -8.7500e-01, -1.5312e+00,  ...,  6.4844e-01,
           5.7031e-01,  1.2656e+00],
         [-1.0938e+00, -4.2969e-01, -1.4062e+00,  ...,  1.5312e+00,
           9.1406e-01, -4.4062e+00]],

        [[-1.4221e-02,  1.0681e-02,  2.6398e-03,  ...,  5.9082e-02,
          -9.4727e-02, -1.1328e-01],
         [ 9.3359e-01, -9.9219e-01,  3.9062e-01,  ..., -4.4727e-01,
          -3.0625e+00, -1.9727e-01],
         [-1.7969e-01, -1.2266e+00,  8.1641e-01,  ...,  1.0234e+00,
          -3.5938e-01,  8.9062e-01],
         ...,
         [-3.9062e-01, -1.0938e+00, -1.1562e+00,  ...,  1.4688e+00,
           4.1797e-01,  1.3516e+00],
         [-6.3281e-01, -8.7500e-01, -1.5312e+00,  ...,  6.4844e-01,
           5.7031e-01,  1.2656e+00],
         [-1.0938e+00, -4.2969e-01, -1.4062e+00,  ...,  1.5312e+00,
           9.1406e-01, -4.4062e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 9.8877e-03,  2.3193e-03,  7.3547e-03,  ..., -1.6357e-02,
           8.9722e-03, -5.0781e-02],
         [ 2.2461e-01, -3.2422e-01, -6.4062e-01,  ...,  3.5938e-01,
          -6.1328e-01, -2.9053e-02],
         [ 1.3086e-01, -2.3633e-01,  6.6406e-01,  ...,  6.1719e-01,
          -1.0625e+00,  4.8438e-01],
         ...,
         [ 1.2188e+00,  1.3281e-01, -2.0020e-01,  ...,  4.7656e-01,
          -6.7188e-01,  1.3770e-01],
         [ 1.0000e+00,  4.0625e-01, -5.2344e-01,  ...,  2.0605e-01,
          -5.0000e-01,  1.3203e+00],
         [ 3.6133e-01,  5.6250e-01, -5.2344e-01,  ..., -1.4062e-01,
          -6.6797e-01,  1.4844e+00]],

        [[ 9.8877e-03,  2.3193e-03,  7.3547e-03,  ..., -1.6357e-02,
           8.9722e-03, -5.0781e-02],
         [ 2.2461e-01, -3.2422e-01, -6.4062e-01,  ...,  3.5938e-01,
          -6.1328e-01, -2.9053e-02],
         [ 1.3086e-01, -2.3633e-01,  6.6406e-01,  ...,  6.1719e-01,
          -1.0625e+00,  4.8438e-01],
         ...,
         [ 1.2188e+00,  1.3281e-01, -2.0020e-01,  ...,  4.7656e-01,
          -6.7188e-01,  1.3770e-01],
         [ 1.0000e+00,  4.0625e-01, -5.2344e-01,  ...,  2.0605e-01,
          -5.0000e-01,  1.3203e+00],
         [ 3.6133e-01,  5.6250e-01, -5.2344e-01,  ..., -1.4062e-01,
          -6.6797e-01,  1.4844e+00]],

        [[ 9.8877e-03,  2.3193e-03,  7.3547e-03,  ..., -1.6357e-02,
           8.9722e-03, -5.0781e-02],
         [ 2.2461e-01, -3.2422e-01, -6.4062e-01,  ...,  3.5938e-01,
          -6.1328e-01, -2.9053e-02],
         [ 1.3086e-01, -2.3633e-01,  6.6406e-01,  ...,  6.1719e-01,
          -1.0625e+00,  4.8438e-01],
         ...,
         [ 1.2188e+00,  1.3281e-01, -2.0020e-01,  ...,  4.7656e-01,
          -6.7188e-01,  1.3770e-01],
         [ 1.0000e+00,  4.0625e-01, -5.2344e-01,  ...,  2.0605e-01,
          -5.0000e-01,  1.3203e+00],
         [ 3.6133e-01,  5.6250e-01, -5.2344e-01,  ..., -1.4062e-01,
          -6.6797e-01,  1.4844e+00]],

        ...,

        [[ 3.3569e-03, -6.7139e-03, -8.9264e-04,  ...,  1.3046e-03,
           2.6703e-03,  5.0659e-03],
         [-2.7812e+00, -3.8867e-01, -9.0234e-01,  ..., -2.7969e+00,
          -6.2109e-01, -2.7500e+00],
         [-1.2578e+00,  3.4961e-01, -1.9531e-01,  ..., -3.0078e-01,
           8.3984e-01, -1.6484e+00],
         ...,
         [ 1.6797e-01, -1.6602e-01, -1.6699e-01,  ...,  2.8516e-01,
           2.9883e-01, -4.2578e-01],
         [-1.9043e-02,  1.1426e-01, -1.1035e-01,  ..., -8.4961e-02,
           1.2109e-01, -6.0938e-01],
         [-1.0000e+00, -1.6406e-01,  1.4297e+00,  ...,  8.9844e-01,
           1.5703e+00,  4.9805e-01]],

        [[ 3.3569e-03, -6.7139e-03, -8.9264e-04,  ...,  1.3046e-03,
           2.6703e-03,  5.0659e-03],
         [-2.7812e+00, -3.8867e-01, -9.0234e-01,  ..., -2.7969e+00,
          -6.2109e-01, -2.7500e+00],
         [-1.2578e+00,  3.4961e-01, -1.9531e-01,  ..., -3.0078e-01,
           8.3984e-01, -1.6484e+00],
         ...,
         [ 1.6797e-01, -1.6602e-01, -1.6699e-01,  ...,  2.8516e-01,
           2.9883e-01, -4.2578e-01],
         [-1.9043e-02,  1.1426e-01, -1.1035e-01,  ..., -8.4961e-02,
           1.2109e-01, -6.0938e-01],
         [-1.0000e+00, -1.6406e-01,  1.4297e+00,  ...,  8.9844e-01,
           1.5703e+00,  4.9805e-01]],

        [[ 3.3569e-03, -6.7139e-03, -8.9264e-04,  ...,  1.3046e-03,
           2.6703e-03,  5.0659e-03],
         [-2.7812e+00, -3.8867e-01, -9.0234e-01,  ..., -2.7969e+00,
          -6.2109e-01, -2.7500e+00],
         [-1.2578e+00,  3.4961e-01, -1.9531e-01,  ..., -3.0078e-01,
           8.3984e-01, -1.6484e+00],
         ...,
         [ 1.6797e-01, -1.6602e-01, -1.6699e-01,  ...,  2.8516e-01,
           2.9883e-01, -4.2578e-01],
         [-1.9043e-02,  1.1426e-01, -1.1035e-01,  ..., -8.4961e-02,
           1.2109e-01, -6.0938e-01],
         [-1.0000e+00, -1.6406e-01,  1.4297e+00,  ...,  8.9844e-01,
           1.5703e+00,  4.9805e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-1.6724e-02, -4.3030e-03,  4.7607e-03,  ...,  1.0352e-01,
           8.0566e-02,  3.9062e-01],
         [-4.0625e-01, -7.0312e-01, -8.1250e-01,  ..., -2.0781e+00,
          -1.4922e+00, -6.7812e+00],
         [-1.4688e+00, -1.0312e+00, -9.8828e-01,  ..., -9.4531e-01,
          -2.5469e+00, -6.0625e+00],
         ...,
         [-5.3906e-01, -1.0938e+00,  1.0781e+00,  ..., -1.3984e+00,
          -2.7500e+00, -3.6562e+00],
         [-3.7891e-01, -4.7266e-01,  1.1953e+00,  ..., -1.5938e+00,
          -2.7344e+00, -5.3438e+00],
         [-2.6953e-01,  4.3359e-01,  4.2773e-01,  ..., -1.1172e+00,
          -3.2500e+00, -5.4062e+00]],

        [[-1.6724e-02, -4.3030e-03,  4.7607e-03,  ...,  1.0352e-01,
           8.0566e-02,  3.9062e-01],
         [-4.0625e-01, -7.0312e-01, -8.1250e-01,  ..., -2.0781e+00,
          -1.4922e+00, -6.7812e+00],
         [-1.4688e+00, -1.0312e+00, -9.8828e-01,  ..., -9.4531e-01,
          -2.5469e+00, -6.0625e+00],
         ...,
         [-5.3906e-01, -1.0938e+00,  1.0781e+00,  ..., -1.3984e+00,
          -2.7500e+00, -3.6562e+00],
         [-3.7891e-01, -4.7266e-01,  1.1953e+00,  ..., -1.5938e+00,
          -2.7344e+00, -5.3438e+00],
         [-2.6953e-01,  4.3359e-01,  4.2773e-01,  ..., -1.1172e+00,
          -3.2500e+00, -5.4062e+00]],

        [[-1.6724e-02, -4.3030e-03,  4.7607e-03,  ...,  1.0352e-01,
           8.0566e-02,  3.9062e-01],
         [-4.0625e-01, -7.0312e-01, -8.1250e-01,  ..., -2.0781e+00,
          -1.4922e+00, -6.7812e+00],
         [-1.4688e+00, -1.0312e+00, -9.8828e-01,  ..., -9.4531e-01,
          -2.5469e+00, -6.0625e+00],
         ...,
         [-5.3906e-01, -1.0938e+00,  1.0781e+00,  ..., -1.3984e+00,
          -2.7500e+00, -3.6562e+00],
         [-3.7891e-01, -4.7266e-01,  1.1953e+00,  ..., -1.5938e+00,
          -2.7344e+00, -5.3438e+00],
         [-2.6953e-01,  4.3359e-01,  4.2773e-01,  ..., -1.1172e+00,
          -3.2500e+00, -5.4062e+00]],

        ...,

        [[-5.0354e-04,  4.9744e-03,  1.7471e-03,  ...,  4.1504e-02,
          -3.3203e-02,  4.9316e-02],
         [ 9.1016e-01, -8.8672e-01, -8.7891e-01,  ..., -1.8828e+00,
           1.9297e+00,  2.3047e-01],
         [-1.1406e+00,  1.6895e-01, -6.6797e-01,  ..., -1.1094e+00,
           4.7852e-01,  2.0312e+00],
         ...,
         [ 4.3750e-01, -1.3281e+00,  1.2422e+00,  ...,  8.4375e-01,
           1.1797e+00,  9.6191e-02],
         [-2.6367e-02, -8.2031e-01,  0.0000e+00,  ..., -6.2891e-01,
           9.3750e-01, -8.6328e-01],
         [-2.2500e+00,  7.5391e-01, -7.5781e-01,  ..., -2.4902e-02,
           1.8750e+00,  2.4062e+00]],

        [[-5.0354e-04,  4.9744e-03,  1.7471e-03,  ...,  4.1504e-02,
          -3.3203e-02,  4.9316e-02],
         [ 9.1016e-01, -8.8672e-01, -8.7891e-01,  ..., -1.8828e+00,
           1.9297e+00,  2.3047e-01],
         [-1.1406e+00,  1.6895e-01, -6.6797e-01,  ..., -1.1094e+00,
           4.7852e-01,  2.0312e+00],
         ...,
         [ 4.3750e-01, -1.3281e+00,  1.2422e+00,  ...,  8.4375e-01,
           1.1797e+00,  9.6191e-02],
         [-2.6367e-02, -8.2031e-01,  0.0000e+00,  ..., -6.2891e-01,
           9.3750e-01, -8.6328e-01],
         [-2.2500e+00,  7.5391e-01, -7.5781e-01,  ..., -2.4902e-02,
           1.8750e+00,  2.4062e+00]],

        [[-5.0354e-04,  4.9744e-03,  1.7471e-03,  ...,  4.1504e-02,
          -3.3203e-02,  4.9316e-02],
         [ 9.1016e-01, -8.8672e-01, -8.7891e-01,  ..., -1.8828e+00,
           1.9297e+00,  2.3047e-01],
         [-1.1406e+00,  1.6895e-01, -6.6797e-01,  ..., -1.1094e+00,
           4.7852e-01,  2.0312e+00],
         ...,
         [ 4.3750e-01, -1.3281e+00,  1.2422e+00,  ...,  8.4375e-01,
           1.1797e+00,  9.6191e-02],
         [-2.6367e-02, -8.2031e-01,  0.0000e+00,  ..., -6.2891e-01,
           9.3750e-01, -8.6328e-01],
         [-2.2500e+00,  7.5391e-01, -7.5781e-01,  ..., -2.4902e-02,
           1.8750e+00,  2.4062e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 1.1902e-03, -6.4392e-03,  8.5449e-03,  ...,  5.4626e-03,
           7.1716e-03,  7.1106e-03],
         [ 5.1562e-01,  1.0625e+00, -7.3438e-01,  ...,  1.2266e+00,
          -1.7422e+00,  6.5625e-01],
         [ 3.5156e-01, -6.4844e-01, -9.4531e-01,  ...,  4.4727e-01,
          -1.4531e+00,  6.0938e-01],
         ...,
         [ 7.3242e-02,  1.6484e+00, -2.2363e-01,  ..., -1.1094e+00,
          -1.3906e+00,  6.1719e-01],
         [ 3.8867e-01,  3.4570e-01,  1.3477e-01,  ...,  2.0215e-01,
          -1.8164e-01,  4.9414e-01],
         [ 1.0156e+00,  4.2578e-01, -1.3438e+00,  ..., -5.2490e-02,
          -1.3281e-01,  7.1777e-02]],

        [[ 1.1902e-03, -6.4392e-03,  8.5449e-03,  ...,  5.4626e-03,
           7.1716e-03,  7.1106e-03],
         [ 5.1562e-01,  1.0625e+00, -7.3438e-01,  ...,  1.2266e+00,
          -1.7422e+00,  6.5625e-01],
         [ 3.5156e-01, -6.4844e-01, -9.4531e-01,  ...,  4.4727e-01,
          -1.4531e+00,  6.0938e-01],
         ...,
         [ 7.3242e-02,  1.6484e+00, -2.2363e-01,  ..., -1.1094e+00,
          -1.3906e+00,  6.1719e-01],
         [ 3.8867e-01,  3.4570e-01,  1.3477e-01,  ...,  2.0215e-01,
          -1.8164e-01,  4.9414e-01],
         [ 1.0156e+00,  4.2578e-01, -1.3438e+00,  ..., -5.2490e-02,
          -1.3281e-01,  7.1777e-02]],

        [[ 1.1902e-03, -6.4392e-03,  8.5449e-03,  ...,  5.4626e-03,
           7.1716e-03,  7.1106e-03],
         [ 5.1562e-01,  1.0625e+00, -7.3438e-01,  ...,  1.2266e+00,
          -1.7422e+00,  6.5625e-01],
         [ 3.5156e-01, -6.4844e-01, -9.4531e-01,  ...,  4.4727e-01,
          -1.4531e+00,  6.0938e-01],
         ...,
         [ 7.3242e-02,  1.6484e+00, -2.2363e-01,  ..., -1.1094e+00,
          -1.3906e+00,  6.1719e-01],
         [ 3.8867e-01,  3.4570e-01,  1.3477e-01,  ...,  2.0215e-01,
          -1.8164e-01,  4.9414e-01],
         [ 1.0156e+00,  4.2578e-01, -1.3438e+00,  ..., -5.2490e-02,
          -1.3281e-01,  7.1777e-02]],

        ...,

        [[-5.5847e-03, -3.2501e-03,  4.8523e-03,  ...,  1.0010e-02,
          -5.5542e-03,  7.4768e-03],
         [ 5.9375e-01, -1.4160e-01, -1.7383e-01,  ...,  2.6758e-01,
          -4.1992e-01,  2.2949e-01],
         [-2.5781e-01,  1.6562e+00,  1.6172e+00,  ...,  1.5781e+00,
           7.8516e-01,  3.4375e-01],
         ...,
         [-1.2812e+00,  8.8672e-01, -1.3281e+00,  ..., -8.7109e-01,
           1.0078e+00, -7.8125e-01],
         [-3.3594e-01,  2.1191e-01, -7.7148e-02,  ...,  1.1768e-01,
           8.4766e-01,  3.8477e-01],
         [-4.4336e-01, -1.6484e+00,  3.8477e-01,  ..., -1.3203e+00,
           3.1445e-01, -3.1641e-01]],

        [[-5.5847e-03, -3.2501e-03,  4.8523e-03,  ...,  1.0010e-02,
          -5.5542e-03,  7.4768e-03],
         [ 5.9375e-01, -1.4160e-01, -1.7383e-01,  ...,  2.6758e-01,
          -4.1992e-01,  2.2949e-01],
         [-2.5781e-01,  1.6562e+00,  1.6172e+00,  ...,  1.5781e+00,
           7.8516e-01,  3.4375e-01],
         ...,
         [-1.2812e+00,  8.8672e-01, -1.3281e+00,  ..., -8.7109e-01,
           1.0078e+00, -7.8125e-01],
         [-3.3594e-01,  2.1191e-01, -7.7148e-02,  ...,  1.1768e-01,
           8.4766e-01,  3.8477e-01],
         [-4.4336e-01, -1.6484e+00,  3.8477e-01,  ..., -1.3203e+00,
           3.1445e-01, -3.1641e-01]],

        [[-5.5847e-03, -3.2501e-03,  4.8523e-03,  ...,  1.0010e-02,
          -5.5542e-03,  7.4768e-03],
         [ 5.9375e-01, -1.4160e-01, -1.7383e-01,  ...,  2.6758e-01,
          -4.1992e-01,  2.2949e-01],
         [-2.5781e-01,  1.6562e+00,  1.6172e+00,  ...,  1.5781e+00,
           7.8516e-01,  3.4375e-01],
         ...,
         [-1.2812e+00,  8.8672e-01, -1.3281e+00,  ..., -8.7109e-01,
           1.0078e+00, -7.8125e-01],
         [-3.3594e-01,  2.1191e-01, -7.7148e-02,  ...,  1.1768e-01,
           8.4766e-01,  3.8477e-01],
         [-4.4336e-01, -1.6484e+00,  3.8477e-01,  ..., -1.3203e+00,
           3.1445e-01, -3.1641e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 2.4414e-03,  9.9487e-03,  9.2773e-03,  ..., -2.6367e-01,
           4.9316e-02,  1.9043e-02],
         [-8.3594e-01, -7.3828e-01,  3.5352e-01,  ...,  4.2812e+00,
          -3.0469e-01, -2.1875e+00],
         [ 6.0156e-01,  1.6406e-01,  8.0078e-02,  ...,  6.3125e+00,
          -2.9883e-01,  1.1658e-02],
         ...,
         [ 2.5391e-01, -1.4453e-01, -2.3535e-01,  ...,  2.1562e+00,
           1.7285e-01,  7.7734e-01],
         [ 1.3125e+00,  3.8281e-01, -6.7969e-01,  ...,  3.4062e+00,
           2.0020e-01,  8.7891e-03],
         [ 9.6094e-01, -1.1328e+00, -2.9102e-01,  ...,  3.6875e+00,
          -1.0078e+00, -1.9043e-01]],

        [[ 2.4414e-03,  9.9487e-03,  9.2773e-03,  ..., -2.6367e-01,
           4.9316e-02,  1.9043e-02],
         [-8.3594e-01, -7.3828e-01,  3.5352e-01,  ...,  4.2812e+00,
          -3.0469e-01, -2.1875e+00],
         [ 6.0156e-01,  1.6406e-01,  8.0078e-02,  ...,  6.3125e+00,
          -2.9883e-01,  1.1658e-02],
         ...,
         [ 2.5391e-01, -1.4453e-01, -2.3535e-01,  ...,  2.1562e+00,
           1.7285e-01,  7.7734e-01],
         [ 1.3125e+00,  3.8281e-01, -6.7969e-01,  ...,  3.4062e+00,
           2.0020e-01,  8.7891e-03],
         [ 9.6094e-01, -1.1328e+00, -2.9102e-01,  ...,  3.6875e+00,
          -1.0078e+00, -1.9043e-01]],

        [[ 2.4414e-03,  9.9487e-03,  9.2773e-03,  ..., -2.6367e-01,
           4.9316e-02,  1.9043e-02],
         [-8.3594e-01, -7.3828e-01,  3.5352e-01,  ...,  4.2812e+00,
          -3.0469e-01, -2.1875e+00],
         [ 6.0156e-01,  1.6406e-01,  8.0078e-02,  ...,  6.3125e+00,
          -2.9883e-01,  1.1658e-02],
         ...,
         [ 2.5391e-01, -1.4453e-01, -2.3535e-01,  ...,  2.1562e+00,
           1.7285e-01,  7.7734e-01],
         [ 1.3125e+00,  3.8281e-01, -6.7969e-01,  ...,  3.4062e+00,
           2.0020e-01,  8.7891e-03],
         [ 9.6094e-01, -1.1328e+00, -2.9102e-01,  ...,  3.6875e+00,
          -1.0078e+00, -1.9043e-01]],

        ...,

        [[-6.4850e-04,  4.4250e-03, -1.9684e-03,  ...,  9.6680e-02,
          -5.8984e-01, -5.2246e-02],
         [-2.4609e-01, -6.4844e-01, -1.1133e-01,  ...,  2.4062e+00,
           8.1250e+00,  3.2188e+00],
         [-8.2031e-02,  6.2500e-01, -2.9492e-01,  ..., -3.3750e+00,
           1.0375e+01,  4.1406e-01],
         ...,
         [ 3.3398e-01, -3.7305e-01,  6.9531e-01,  ...,  3.3789e-01,
           8.4375e+00,  5.9375e+00],
         [-5.8594e-02,  1.0254e-01,  2.1094e-01,  ..., -1.3281e+00,
           7.4375e+00,  1.9531e+00],
         [-3.3789e-01,  1.7578e-02, -2.4170e-02,  ...,  9.6484e-01,
           9.0625e+00, -5.2812e+00]],

        [[-6.4850e-04,  4.4250e-03, -1.9684e-03,  ...,  9.6680e-02,
          -5.8984e-01, -5.2246e-02],
         [-2.4609e-01, -6.4844e-01, -1.1133e-01,  ...,  2.4062e+00,
           8.1250e+00,  3.2188e+00],
         [-8.2031e-02,  6.2500e-01, -2.9492e-01,  ..., -3.3750e+00,
           1.0375e+01,  4.1406e-01],
         ...,
         [ 3.3398e-01, -3.7305e-01,  6.9531e-01,  ...,  3.3789e-01,
           8.4375e+00,  5.9375e+00],
         [-5.8594e-02,  1.0254e-01,  2.1094e-01,  ..., -1.3281e+00,
           7.4375e+00,  1.9531e+00],
         [-3.3789e-01,  1.7578e-02, -2.4170e-02,  ...,  9.6484e-01,
           9.0625e+00, -5.2812e+00]],

        [[-6.4850e-04,  4.4250e-03, -1.9684e-03,  ...,  9.6680e-02,
          -5.8984e-01, -5.2246e-02],
         [-2.4609e-01, -6.4844e-01, -1.1133e-01,  ...,  2.4062e+00,
           8.1250e+00,  3.2188e+00],
         [-8.2031e-02,  6.2500e-01, -2.9492e-01,  ..., -3.3750e+00,
           1.0375e+01,  4.1406e-01],
         ...,
         [ 3.3398e-01, -3.7305e-01,  6.9531e-01,  ...,  3.3789e-01,
           8.4375e+00,  5.9375e+00],
         [-5.8594e-02,  1.0254e-01,  2.1094e-01,  ..., -1.3281e+00,
           7.4375e+00,  1.9531e+00],
         [-3.3789e-01,  1.7578e-02, -2.4170e-02,  ...,  9.6484e-01,
           9.0625e+00, -5.2812e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-4.2114e-03,  1.7700e-03,  8.1787e-03,  ...,  2.0599e-03,
          -3.2616e-04, -2.4414e-03],
         [ 2.2031e+00, -1.1377e-01, -2.0156e+00,  ...,  1.6719e+00,
          -4.4375e+00, -2.3594e+00],
         [ 5.7812e-01, -7.9297e-01, -7.2266e-01,  ...,  6.5430e-02,
           8.1543e-02, -1.7891e+00],
         ...,
         [ 8.0859e-01, -2.1250e+00, -3.8477e-01,  ..., -9.8145e-02,
          -1.8945e-01, -3.4375e-01],
         [ 1.1016e+00,  3.6719e-01,  1.8750e-01,  ..., -1.4062e-01,
           7.4707e-02, -4.2383e-01],
         [ 4.7461e-01,  1.8359e+00,  2.2969e+00,  ...,  1.2812e+00,
          -8.9355e-02, -1.1406e+00]],

        [[-4.2114e-03,  1.7700e-03,  8.1787e-03,  ...,  2.0599e-03,
          -3.2616e-04, -2.4414e-03],
         [ 2.2031e+00, -1.1377e-01, -2.0156e+00,  ...,  1.6719e+00,
          -4.4375e+00, -2.3594e+00],
         [ 5.7812e-01, -7.9297e-01, -7.2266e-01,  ...,  6.5430e-02,
           8.1543e-02, -1.7891e+00],
         ...,
         [ 8.0859e-01, -2.1250e+00, -3.8477e-01,  ..., -9.8145e-02,
          -1.8945e-01, -3.4375e-01],
         [ 1.1016e+00,  3.6719e-01,  1.8750e-01,  ..., -1.4062e-01,
           7.4707e-02, -4.2383e-01],
         [ 4.7461e-01,  1.8359e+00,  2.2969e+00,  ...,  1.2812e+00,
          -8.9355e-02, -1.1406e+00]],

        [[-4.2114e-03,  1.7700e-03,  8.1787e-03,  ...,  2.0599e-03,
          -3.2616e-04, -2.4414e-03],
         [ 2.2031e+00, -1.1377e-01, -2.0156e+00,  ...,  1.6719e+00,
          -4.4375e+00, -2.3594e+00],
         [ 5.7812e-01, -7.9297e-01, -7.2266e-01,  ...,  6.5430e-02,
           8.1543e-02, -1.7891e+00],
         ...,
         [ 8.0859e-01, -2.1250e+00, -3.8477e-01,  ..., -9.8145e-02,
          -1.8945e-01, -3.4375e-01],
         [ 1.1016e+00,  3.6719e-01,  1.8750e-01,  ..., -1.4062e-01,
           7.4707e-02, -4.2383e-01],
         [ 4.7461e-01,  1.8359e+00,  2.2969e+00,  ...,  1.2812e+00,
          -8.9355e-02, -1.1406e+00]],

        ...,

        [[ 1.8311e-02, -2.3499e-03,  1.9455e-04,  ..., -1.7334e-02,
           4.1389e-04,  4.0283e-03],
         [ 2.8125e-01,  3.6914e-01, -3.3398e-01,  ...,  1.3672e+00,
           4.8047e-01,  6.4062e-01],
         [-1.0391e+00, -1.2695e-01,  4.0039e-01,  ...,  5.3467e-02,
           1.9043e-01,  2.7148e-01],
         ...,
         [ 4.8828e-01,  3.5742e-01,  5.7422e-01,  ..., -5.3516e-01,
          -3.8867e-01,  4.6484e-01],
         [-7.7637e-02,  4.4336e-01,  3.5400e-02,  ...,  3.3008e-01,
           2.5757e-02, -7.4707e-02],
         [-9.7656e-01,  1.4062e+00, -1.0303e-01,  ..., -1.0938e+00,
           1.2891e+00,  2.2461e-01]],

        [[ 1.8311e-02, -2.3499e-03,  1.9455e-04,  ..., -1.7334e-02,
           4.1389e-04,  4.0283e-03],
         [ 2.8125e-01,  3.6914e-01, -3.3398e-01,  ...,  1.3672e+00,
           4.8047e-01,  6.4062e-01],
         [-1.0391e+00, -1.2695e-01,  4.0039e-01,  ...,  5.3467e-02,
           1.9043e-01,  2.7148e-01],
         ...,
         [ 4.8828e-01,  3.5742e-01,  5.7422e-01,  ..., -5.3516e-01,
          -3.8867e-01,  4.6484e-01],
         [-7.7637e-02,  4.4336e-01,  3.5400e-02,  ...,  3.3008e-01,
           2.5757e-02, -7.4707e-02],
         [-9.7656e-01,  1.4062e+00, -1.0303e-01,  ..., -1.0938e+00,
           1.2891e+00,  2.2461e-01]],

        [[ 1.8311e-02, -2.3499e-03,  1.9455e-04,  ..., -1.7334e-02,
           4.1389e-04,  4.0283e-03],
         [ 2.8125e-01,  3.6914e-01, -3.3398e-01,  ...,  1.3672e+00,
           4.8047e-01,  6.4062e-01],
         [-1.0391e+00, -1.2695e-01,  4.0039e-01,  ...,  5.3467e-02,
           1.9043e-01,  2.7148e-01],
         ...,
         [ 4.8828e-01,  3.5742e-01,  5.7422e-01,  ..., -5.3516e-01,
          -3.8867e-01,  4.6484e-01],
         [-7.7637e-02,  4.4336e-01,  3.5400e-02,  ...,  3.3008e-01,
           2.5757e-02, -7.4707e-02],
         [-9.7656e-01,  1.4062e+00, -1.0303e-01,  ..., -1.0938e+00,
           1.2891e+00,  2.2461e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-3.4943e-03,  2.4414e-02,  5.6839e-04,  ...,  5.4932e-02,
          -1.2354e-01, -2.4023e-01],
         [-1.8359e-01, -8.9844e-02, -1.6953e+00,  ..., -1.0156e+00,
           3.7188e+00,  4.8047e-01],
         [-1.4062e+00, -1.9219e+00, -7.4609e-01,  ...,  1.8672e+00,
          -7.1716e-03,  3.0625e+00],
         ...,
         [ 2.2559e-01, -4.1797e-01,  3.4766e-01,  ..., -1.9141e+00,
           8.8672e-01,  1.4609e+00],
         [-6.6895e-02, -5.7812e-01, -7.1289e-02,  ..., -7.0312e-01,
           8.4375e-01,  2.1094e+00],
         [-4.0039e-01,  4.0625e-01, -2.4219e-01,  ..., -3.5938e-01,
           8.2031e-01,  2.0703e-01]],

        [[-3.4943e-03,  2.4414e-02,  5.6839e-04,  ...,  5.4932e-02,
          -1.2354e-01, -2.4023e-01],
         [-1.8359e-01, -8.9844e-02, -1.6953e+00,  ..., -1.0156e+00,
           3.7188e+00,  4.8047e-01],
         [-1.4062e+00, -1.9219e+00, -7.4609e-01,  ...,  1.8672e+00,
          -7.1716e-03,  3.0625e+00],
         ...,
         [ 2.2559e-01, -4.1797e-01,  3.4766e-01,  ..., -1.9141e+00,
           8.8672e-01,  1.4609e+00],
         [-6.6895e-02, -5.7812e-01, -7.1289e-02,  ..., -7.0312e-01,
           8.4375e-01,  2.1094e+00],
         [-4.0039e-01,  4.0625e-01, -2.4219e-01,  ..., -3.5938e-01,
           8.2031e-01,  2.0703e-01]],

        [[-3.4943e-03,  2.4414e-02,  5.6839e-04,  ...,  5.4932e-02,
          -1.2354e-01, -2.4023e-01],
         [-1.8359e-01, -8.9844e-02, -1.6953e+00,  ..., -1.0156e+00,
           3.7188e+00,  4.8047e-01],
         [-1.4062e+00, -1.9219e+00, -7.4609e-01,  ...,  1.8672e+00,
          -7.1716e-03,  3.0625e+00],
         ...,
         [ 2.2559e-01, -4.1797e-01,  3.4766e-01,  ..., -1.9141e+00,
           8.8672e-01,  1.4609e+00],
         [-6.6895e-02, -5.7812e-01, -7.1289e-02,  ..., -7.0312e-01,
           8.4375e-01,  2.1094e+00],
         [-4.0039e-01,  4.0625e-01, -2.4219e-01,  ..., -3.5938e-01,
           8.2031e-01,  2.0703e-01]],

        ...,

        [[ 1.0729e-04,  8.3008e-03, -1.4099e-02,  ...,  5.6763e-03,
          -5.2979e-02,  1.0864e-02],
         [-3.2031e-01, -1.0498e-01,  6.7188e-01,  ..., -9.3750e-01,
           4.4688e+00, -1.2344e+00],
         [-5.8203e-01, -7.8906e-01,  1.0156e+00,  ...,  1.8516e+00,
           2.5312e+00, -1.0547e+00],
         ...,
         [-2.6562e-01, -3.9844e-01, -1.0391e+00,  ...,  1.5918e-01,
           8.0859e-01, -2.3125e+00],
         [-1.4453e-01,  5.5469e-01, -1.4746e-01,  ...,  5.8984e-01,
           1.8906e+00, -2.2461e-01],
         [-4.3164e-01, -1.8750e-01,  6.9531e-01,  ..., -6.6016e-01,
          -1.9727e-01, -6.8750e-01]],

        [[ 1.0729e-04,  8.3008e-03, -1.4099e-02,  ...,  5.6763e-03,
          -5.2979e-02,  1.0864e-02],
         [-3.2031e-01, -1.0498e-01,  6.7188e-01,  ..., -9.3750e-01,
           4.4688e+00, -1.2344e+00],
         [-5.8203e-01, -7.8906e-01,  1.0156e+00,  ...,  1.8516e+00,
           2.5312e+00, -1.0547e+00],
         ...,
         [-2.6562e-01, -3.9844e-01, -1.0391e+00,  ...,  1.5918e-01,
           8.0859e-01, -2.3125e+00],
         [-1.4453e-01,  5.5469e-01, -1.4746e-01,  ...,  5.8984e-01,
           1.8906e+00, -2.2461e-01],
         [-4.3164e-01, -1.8750e-01,  6.9531e-01,  ..., -6.6016e-01,
          -1.9727e-01, -6.8750e-01]],

        [[ 1.0729e-04,  8.3008e-03, -1.4099e-02,  ...,  5.6763e-03,
          -5.2979e-02,  1.0864e-02],
         [-3.2031e-01, -1.0498e-01,  6.7188e-01,  ..., -9.3750e-01,
           4.4688e+00, -1.2344e+00],
         [-5.8203e-01, -7.8906e-01,  1.0156e+00,  ...,  1.8516e+00,
           2.5312e+00, -1.0547e+00],
         ...,
         [-2.6562e-01, -3.9844e-01, -1.0391e+00,  ...,  1.5918e-01,
           8.0859e-01, -2.3125e+00],
         [-1.4453e-01,  5.5469e-01, -1.4746e-01,  ...,  5.8984e-01,
           1.8906e+00, -2.2461e-01],
         [-4.3164e-01, -1.8750e-01,  6.9531e-01,  ..., -6.6016e-01,
          -1.9727e-01, -6.8750e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 4.5776e-03,  1.1414e-02,  6.9275e-03,  ...,  3.7994e-03,
          -2.7161e-03, -2.9755e-03],
         [ 1.0469e+00, -4.7812e+00, -1.7734e+00,  ..., -2.7188e+00,
           1.1641e+00,  6.6016e-01],
         [ 6.2891e-01, -1.8359e+00, -9.1797e-01,  ...,  1.0010e-02,
           1.4453e+00,  1.7266e+00],
         ...,
         [-4.3945e-01, -1.0547e+00, -9.0625e-01,  ...,  1.2656e+00,
           1.0078e+00,  1.0469e+00],
         [ 7.3730e-02,  3.5156e-02, -8.3203e-01,  ...,  6.7969e-01,
           6.2891e-01,  4.6875e-01],
         [-6.0156e-01,  3.8867e-01,  1.1328e+00,  ..., -5.0391e-01,
           3.0664e-01,  8.8672e-01]],

        [[ 4.5776e-03,  1.1414e-02,  6.9275e-03,  ...,  3.7994e-03,
          -2.7161e-03, -2.9755e-03],
         [ 1.0469e+00, -4.7812e+00, -1.7734e+00,  ..., -2.7188e+00,
           1.1641e+00,  6.6016e-01],
         [ 6.2891e-01, -1.8359e+00, -9.1797e-01,  ...,  1.0010e-02,
           1.4453e+00,  1.7266e+00],
         ...,
         [-4.3945e-01, -1.0547e+00, -9.0625e-01,  ...,  1.2656e+00,
           1.0078e+00,  1.0469e+00],
         [ 7.3730e-02,  3.5156e-02, -8.3203e-01,  ...,  6.7969e-01,
           6.2891e-01,  4.6875e-01],
         [-6.0156e-01,  3.8867e-01,  1.1328e+00,  ..., -5.0391e-01,
           3.0664e-01,  8.8672e-01]],

        [[ 4.5776e-03,  1.1414e-02,  6.9275e-03,  ...,  3.7994e-03,
          -2.7161e-03, -2.9755e-03],
         [ 1.0469e+00, -4.7812e+00, -1.7734e+00,  ..., -2.7188e+00,
           1.1641e+00,  6.6016e-01],
         [ 6.2891e-01, -1.8359e+00, -9.1797e-01,  ...,  1.0010e-02,
           1.4453e+00,  1.7266e+00],
         ...,
         [-4.3945e-01, -1.0547e+00, -9.0625e-01,  ...,  1.2656e+00,
           1.0078e+00,  1.0469e+00],
         [ 7.3730e-02,  3.5156e-02, -8.3203e-01,  ...,  6.7969e-01,
           6.2891e-01,  4.6875e-01],
         [-6.0156e-01,  3.8867e-01,  1.1328e+00,  ..., -5.0391e-01,
           3.0664e-01,  8.8672e-01]],

        ...,

        [[-3.8757e-03,  1.6327e-03,  7.3242e-03,  ..., -1.7822e-02,
          -6.9336e-02,  4.8447e-04],
         [-1.7344e+00,  4.6094e-01, -3.1875e+00,  ..., -2.4609e-01,
           2.3125e+00, -3.9688e+00],
         [ 8.3984e-01,  1.3281e+00, -1.0938e+00,  ...,  2.6953e-01,
          -2.3145e-01, -3.3203e-01],
         ...,
         [ 6.9922e-01, -7.2266e-01,  5.9375e-01,  ..., -9.2188e-01,
          -6.6016e-01,  8.5938e-01],
         [ 8.8281e-01,  3.8477e-01, -1.2793e-01,  ...,  4.5117e-01,
           1.9844e+00,  4.0234e-01],
         [ 1.0625e+00,  2.6758e-01, -1.9375e+00,  ...,  1.5000e+00,
           1.0078e+00,  1.0391e+00]],

        [[-3.8757e-03,  1.6327e-03,  7.3242e-03,  ..., -1.7822e-02,
          -6.9336e-02,  4.8447e-04],
         [-1.7344e+00,  4.6094e-01, -3.1875e+00,  ..., -2.4609e-01,
           2.3125e+00, -3.9688e+00],
         [ 8.3984e-01,  1.3281e+00, -1.0938e+00,  ...,  2.6953e-01,
          -2.3145e-01, -3.3203e-01],
         ...,
         [ 6.9922e-01, -7.2266e-01,  5.9375e-01,  ..., -9.2188e-01,
          -6.6016e-01,  8.5938e-01],
         [ 8.8281e-01,  3.8477e-01, -1.2793e-01,  ...,  4.5117e-01,
           1.9844e+00,  4.0234e-01],
         [ 1.0625e+00,  2.6758e-01, -1.9375e+00,  ...,  1.5000e+00,
           1.0078e+00,  1.0391e+00]],

        [[-3.8757e-03,  1.6327e-03,  7.3242e-03,  ..., -1.7822e-02,
          -6.9336e-02,  4.8447e-04],
         [-1.7344e+00,  4.6094e-01, -3.1875e+00,  ..., -2.4609e-01,
           2.3125e+00, -3.9688e+00],
         [ 8.3984e-01,  1.3281e+00, -1.0938e+00,  ...,  2.6953e-01,
          -2.3145e-01, -3.3203e-01],
         ...,
         [ 6.9922e-01, -7.2266e-01,  5.9375e-01,  ..., -9.2188e-01,
          -6.6016e-01,  8.5938e-01],
         [ 8.8281e-01,  3.8477e-01, -1.2793e-01,  ...,  4.5117e-01,
           1.9844e+00,  4.0234e-01],
         [ 1.0625e+00,  2.6758e-01, -1.9375e+00,  ...,  1.5000e+00,
           1.0078e+00,  1.0391e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-0.0067,  0.0189, -0.0097,  ...,  0.1387,  0.0334,  0.0076],
         [-0.2676, -0.6953,  1.0625,  ..., -3.4531, -0.6523,  1.7344],
         [-0.5508,  1.1797,  0.9883,  ..., -4.2188,  0.6836,  0.8125],
         ...,
         [-0.1157,  0.5312, -0.1660,  ..., -4.0312, -1.5781, -1.5156],
         [-0.1172,  0.1138,  0.5273,  ..., -3.8750, -1.6719, -0.4551],
         [ 0.5469, -0.3672,  1.1875,  ..., -3.5625, -0.6055,  0.0591]],

        [[-0.0067,  0.0189, -0.0097,  ...,  0.1387,  0.0334,  0.0076],
         [-0.2676, -0.6953,  1.0625,  ..., -3.4531, -0.6523,  1.7344],
         [-0.5508,  1.1797,  0.9883,  ..., -4.2188,  0.6836,  0.8125],
         ...,
         [-0.1157,  0.5312, -0.1660,  ..., -4.0312, -1.5781, -1.5156],
         [-0.1172,  0.1138,  0.5273,  ..., -3.8750, -1.6719, -0.4551],
         [ 0.5469, -0.3672,  1.1875,  ..., -3.5625, -0.6055,  0.0591]],

        [[-0.0067,  0.0189, -0.0097,  ...,  0.1387,  0.0334,  0.0076],
         [-0.2676, -0.6953,  1.0625,  ..., -3.4531, -0.6523,  1.7344],
         [-0.5508,  1.1797,  0.9883,  ..., -4.2188,  0.6836,  0.8125],
         ...,
         [-0.1157,  0.5312, -0.1660,  ..., -4.0312, -1.5781, -1.5156],
         [-0.1172,  0.1138,  0.5273,  ..., -3.8750, -1.6719, -0.4551],
         [ 0.5469, -0.3672,  1.1875,  ..., -3.5625, -0.6055,  0.0591]],

        ...,

        [[-0.0256,  0.0177,  0.0053,  ...,  0.0048, -0.0129, -0.0820],
         [ 1.0859, -0.6836, -0.5078,  ...,  0.0684, -0.6016,  1.6953],
         [-1.2891, -1.2500,  0.5430,  ...,  1.9766,  0.4082,  2.3125],
         ...,
         [ 0.7500,  0.3965, -1.0234,  ..., -0.2041, -0.9102,  0.6211],
         [-1.3516,  0.5820, -0.5195,  ..., -0.0172, -0.9375,  0.9922],
         [-2.9062,  1.6953, -0.3223,  ..., -1.8125, -0.1445,  0.2734]],

        [[-0.0256,  0.0177,  0.0053,  ...,  0.0048, -0.0129, -0.0820],
         [ 1.0859, -0.6836, -0.5078,  ...,  0.0684, -0.6016,  1.6953],
         [-1.2891, -1.2500,  0.5430,  ...,  1.9766,  0.4082,  2.3125],
         ...,
         [ 0.7500,  0.3965, -1.0234,  ..., -0.2041, -0.9102,  0.6211],
         [-1.3516,  0.5820, -0.5195,  ..., -0.0172, -0.9375,  0.9922],
         [-2.9062,  1.6953, -0.3223,  ..., -1.8125, -0.1445,  0.2734]],

        [[-0.0256,  0.0177,  0.0053,  ...,  0.0048, -0.0129, -0.0820],
         [ 1.0859, -0.6836, -0.5078,  ...,  0.0684, -0.6016,  1.6953],
         [-1.2891, -1.2500,  0.5430,  ...,  1.9766,  0.4082,  2.3125],
         ...,
         [ 0.7500,  0.3965, -1.0234,  ..., -0.2041, -0.9102,  0.6211],
         [-1.3516,  0.5820, -0.5195,  ..., -0.0172, -0.9375,  0.9922],
         [-2.9062,  1.6953, -0.3223,  ..., -1.8125, -0.1445,  0.2734]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>), tensor([[[-5.5542e-03, -1.3123e-03, -4.8218e-03,  ...,  5.2185e-03,
          -5.2185e-03, -7.4768e-04],
         [-1.1562e+00,  5.5469e-01,  4.4922e-01,  ..., -1.0078e+00,
          -3.0859e-01,  2.7734e-01],
         [ 5.1172e-01, -2.8711e-01,  3.6865e-02,  ...,  5.1953e-01,
           8.7109e-01, -9.5312e-01],
         ...,
         [ 3.4180e-01,  3.7109e-01, -3.6377e-02,  ..., -1.5723e-01,
          -1.6641e+00, -7.6562e-01],
         [ 7.1289e-02,  1.4221e-02,  1.0742e-01,  ..., -3.0859e-01,
          -1.1953e+00, -7.9688e-01],
         [ 5.3516e-01,  8.4766e-01, -3.1055e-01,  ...,  3.7305e-01,
          -4.6680e-01,  2.9883e-01]],

        [[-5.5542e-03, -1.3123e-03, -4.8218e-03,  ...,  5.2185e-03,
          -5.2185e-03, -7.4768e-04],
         [-1.1562e+00,  5.5469e-01,  4.4922e-01,  ..., -1.0078e+00,
          -3.0859e-01,  2.7734e-01],
         [ 5.1172e-01, -2.8711e-01,  3.6865e-02,  ...,  5.1953e-01,
           8.7109e-01, -9.5312e-01],
         ...,
         [ 3.4180e-01,  3.7109e-01, -3.6377e-02,  ..., -1.5723e-01,
          -1.6641e+00, -7.6562e-01],
         [ 7.1289e-02,  1.4221e-02,  1.0742e-01,  ..., -3.0859e-01,
          -1.1953e+00, -7.9688e-01],
         [ 5.3516e-01,  8.4766e-01, -3.1055e-01,  ...,  3.7305e-01,
          -4.6680e-01,  2.9883e-01]],

        [[-5.5542e-03, -1.3123e-03, -4.8218e-03,  ...,  5.2185e-03,
          -5.2185e-03, -7.4768e-04],
         [-1.1562e+00,  5.5469e-01,  4.4922e-01,  ..., -1.0078e+00,
          -3.0859e-01,  2.7734e-01],
         [ 5.1172e-01, -2.8711e-01,  3.6865e-02,  ...,  5.1953e-01,
           8.7109e-01, -9.5312e-01],
         ...,
         [ 3.4180e-01,  3.7109e-01, -3.6377e-02,  ..., -1.5723e-01,
          -1.6641e+00, -7.6562e-01],
         [ 7.1289e-02,  1.4221e-02,  1.0742e-01,  ..., -3.0859e-01,
          -1.1953e+00, -7.9688e-01],
         [ 5.3516e-01,  8.4766e-01, -3.1055e-01,  ...,  3.7305e-01,
          -4.6680e-01,  2.9883e-01]],

        ...,

        [[-7.5378e-03, -4.8161e-05,  1.0132e-02,  ..., -6.2256e-03,
          -5.2185e-03, -4.3640e-03],
         [ 1.0156e+00, -2.6367e-01, -4.8242e-01,  ..., -1.2969e+00,
           5.2734e-01,  2.0996e-01],
         [-4.5898e-02,  1.1719e+00,  9.1797e-01,  ..., -1.0000e+00,
           8.6719e-01, -4.1211e-01],
         ...,
         [-1.8066e-01, -1.7344e+00, -4.3359e-01,  ..., -8.9844e-01,
          -8.0469e-01,  9.5215e-02],
         [ 3.6133e-01, -2.1875e-01, -1.9434e-01,  ..., -1.3965e-01,
          -1.6641e+00, -4.4922e-01],
         [-1.1328e+00, -5.2734e-01,  3.2031e-01,  ..., -2.3071e-02,
           7.3438e-01,  4.8828e-02]],

        [[-7.5378e-03, -4.8161e-05,  1.0132e-02,  ..., -6.2256e-03,
          -5.2185e-03, -4.3640e-03],
         [ 1.0156e+00, -2.6367e-01, -4.8242e-01,  ..., -1.2969e+00,
           5.2734e-01,  2.0996e-01],
         [-4.5898e-02,  1.1719e+00,  9.1797e-01,  ..., -1.0000e+00,
           8.6719e-01, -4.1211e-01],
         ...,
         [-1.8066e-01, -1.7344e+00, -4.3359e-01,  ..., -8.9844e-01,
          -8.0469e-01,  9.5215e-02],
         [ 3.6133e-01, -2.1875e-01, -1.9434e-01,  ..., -1.3965e-01,
          -1.6641e+00, -4.4922e-01],
         [-1.1328e+00, -5.2734e-01,  3.2031e-01,  ..., -2.3071e-02,
           7.3438e-01,  4.8828e-02]],

        [[-7.5378e-03, -4.8161e-05,  1.0132e-02,  ..., -6.2256e-03,
          -5.2185e-03, -4.3640e-03],
         [ 1.0156e+00, -2.6367e-01, -4.8242e-01,  ..., -1.2969e+00,
           5.2734e-01,  2.0996e-01],
         [-4.5898e-02,  1.1719e+00,  9.1797e-01,  ..., -1.0000e+00,
           8.6719e-01, -4.1211e-01],
         ...,
         [-1.8066e-01, -1.7344e+00, -4.3359e-01,  ..., -8.9844e-01,
          -8.0469e-01,  9.5215e-02],
         [ 3.6133e-01, -2.1875e-01, -1.9434e-01,  ..., -1.3965e-01,
          -1.6641e+00, -4.4922e-01],
         [-1.1328e+00, -5.2734e-01,  3.2031e-01,  ..., -2.3071e-02,
           7.3438e-01,  4.8828e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 1.5198e-02,  4.3030e-03,  1.2131e-03,  ...,  3.6865e-02,
          -2.1387e-01, -1.9287e-02],
         [-1.3828e+00,  2.1875e-01, -4.2480e-02,  ...,  1.6562e+00,
           4.1406e-01,  3.2344e+00],
         [-9.0234e-01,  2.6001e-02, -2.6172e-01,  ...,  2.2656e+00,
           1.5625e+00,  2.1719e+00],
         ...,
         [-8.6914e-02, -6.6797e-01,  8.8281e-01,  ..., -3.2471e-02,
          -2.3125e+00,  2.3906e+00],
         [ 2.8711e-01, -7.0312e-02,  1.1914e-01,  ..., -2.1875e-01,
          -1.6484e+00, -8.7891e-01],
         [-1.6211e-01, -3.6523e-01, -3.8477e-01,  ..., -3.7031e+00,
          -2.7031e+00,  2.1719e+00]],

        [[ 1.5198e-02,  4.3030e-03,  1.2131e-03,  ...,  3.6865e-02,
          -2.1387e-01, -1.9287e-02],
         [-1.3828e+00,  2.1875e-01, -4.2480e-02,  ...,  1.6562e+00,
           4.1406e-01,  3.2344e+00],
         [-9.0234e-01,  2.6001e-02, -2.6172e-01,  ...,  2.2656e+00,
           1.5625e+00,  2.1719e+00],
         ...,
         [-8.6914e-02, -6.6797e-01,  8.8281e-01,  ..., -3.2471e-02,
          -2.3125e+00,  2.3906e+00],
         [ 2.8711e-01, -7.0312e-02,  1.1914e-01,  ..., -2.1875e-01,
          -1.6484e+00, -8.7891e-01],
         [-1.6211e-01, -3.6523e-01, -3.8477e-01,  ..., -3.7031e+00,
          -2.7031e+00,  2.1719e+00]],

        [[ 1.5198e-02,  4.3030e-03,  1.2131e-03,  ...,  3.6865e-02,
          -2.1387e-01, -1.9287e-02],
         [-1.3828e+00,  2.1875e-01, -4.2480e-02,  ...,  1.6562e+00,
           4.1406e-01,  3.2344e+00],
         [-9.0234e-01,  2.6001e-02, -2.6172e-01,  ...,  2.2656e+00,
           1.5625e+00,  2.1719e+00],
         ...,
         [-8.6914e-02, -6.6797e-01,  8.8281e-01,  ..., -3.2471e-02,
          -2.3125e+00,  2.3906e+00],
         [ 2.8711e-01, -7.0312e-02,  1.1914e-01,  ..., -2.1875e-01,
          -1.6484e+00, -8.7891e-01],
         [-1.6211e-01, -3.6523e-01, -3.8477e-01,  ..., -3.7031e+00,
          -2.7031e+00,  2.1719e+00]],

        ...,

        [[-1.1536e-02, -6.7139e-03,  3.7003e-04,  ...,  1.7944e-02,
           6.1279e-02,  1.0938e+00],
         [-4.4922e-02, -1.0078e+00,  9.3750e-02,  ..., -1.9609e+00,
           3.4219e+00, -8.7500e+00],
         [-1.6602e-01, -6.3672e-01,  5.1270e-03,  ...,  7.3828e-01,
          -2.9375e+00, -1.0000e+01],
         ...,
         [-3.8281e-01, -9.3359e-01,  3.3594e-01,  ...,  2.8125e+00,
           6.5625e-01, -6.7812e+00],
         [-4.1016e-02,  1.3477e-01, -2.2852e-01,  ..., -5.8203e-01,
          -1.6172e+00, -7.9062e+00],
         [ 4.8242e-01,  6.7578e-01, -4.8242e-01,  ..., -6.4062e-01,
          -2.4531e+00, -7.1562e+00]],

        [[-1.1536e-02, -6.7139e-03,  3.7003e-04,  ...,  1.7944e-02,
           6.1279e-02,  1.0938e+00],
         [-4.4922e-02, -1.0078e+00,  9.3750e-02,  ..., -1.9609e+00,
           3.4219e+00, -8.7500e+00],
         [-1.6602e-01, -6.3672e-01,  5.1270e-03,  ...,  7.3828e-01,
          -2.9375e+00, -1.0000e+01],
         ...,
         [-3.8281e-01, -9.3359e-01,  3.3594e-01,  ...,  2.8125e+00,
           6.5625e-01, -6.7812e+00],
         [-4.1016e-02,  1.3477e-01, -2.2852e-01,  ..., -5.8203e-01,
          -1.6172e+00, -7.9062e+00],
         [ 4.8242e-01,  6.7578e-01, -4.8242e-01,  ..., -6.4062e-01,
          -2.4531e+00, -7.1562e+00]],

        [[-1.1536e-02, -6.7139e-03,  3.7003e-04,  ...,  1.7944e-02,
           6.1279e-02,  1.0938e+00],
         [-4.4922e-02, -1.0078e+00,  9.3750e-02,  ..., -1.9609e+00,
           3.4219e+00, -8.7500e+00],
         [-1.6602e-01, -6.3672e-01,  5.1270e-03,  ...,  7.3828e-01,
          -2.9375e+00, -1.0000e+01],
         ...,
         [-3.8281e-01, -9.3359e-01,  3.3594e-01,  ...,  2.8125e+00,
           6.5625e-01, -6.7812e+00],
         [-4.1016e-02,  1.3477e-01, -2.2852e-01,  ..., -5.8203e-01,
          -1.6172e+00, -7.9062e+00],
         [ 4.8242e-01,  6.7578e-01, -4.8242e-01,  ..., -6.4062e-01,
          -2.4531e+00, -7.1562e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 2.9755e-04,  3.8757e-03,  6.9427e-04,  ...,  5.5237e-03,
           1.3828e-04, -6.5918e-03],
         [ 1.3594e+00,  5.7422e-01, -1.0938e+00,  ..., -1.8828e+00,
           5.3516e-01,  5.5859e-01],
         [ 2.3315e-02, -2.5938e+00,  8.9844e-01,  ..., -9.7266e-01,
           3.0664e-01,  1.2266e+00],
         ...,
         [ 1.4844e+00, -5.8984e-01,  1.9629e-01,  ..., -6.0156e-01,
          -1.9238e-01, -2.1777e-01],
         [ 1.1406e+00, -1.0938e+00, -3.2422e-01,  ..., -1.1719e-01,
          -3.1055e-01, -2.3438e-02],
         [ 5.8203e-01, -1.5747e-02, -6.3281e-01,  ...,  1.0547e+00,
          -1.0645e-01,  5.0537e-02]],

        [[ 2.9755e-04,  3.8757e-03,  6.9427e-04,  ...,  5.5237e-03,
           1.3828e-04, -6.5918e-03],
         [ 1.3594e+00,  5.7422e-01, -1.0938e+00,  ..., -1.8828e+00,
           5.3516e-01,  5.5859e-01],
         [ 2.3315e-02, -2.5938e+00,  8.9844e-01,  ..., -9.7266e-01,
           3.0664e-01,  1.2266e+00],
         ...,
         [ 1.4844e+00, -5.8984e-01,  1.9629e-01,  ..., -6.0156e-01,
          -1.9238e-01, -2.1777e-01],
         [ 1.1406e+00, -1.0938e+00, -3.2422e-01,  ..., -1.1719e-01,
          -3.1055e-01, -2.3438e-02],
         [ 5.8203e-01, -1.5747e-02, -6.3281e-01,  ...,  1.0547e+00,
          -1.0645e-01,  5.0537e-02]],

        [[ 2.9755e-04,  3.8757e-03,  6.9427e-04,  ...,  5.5237e-03,
           1.3828e-04, -6.5918e-03],
         [ 1.3594e+00,  5.7422e-01, -1.0938e+00,  ..., -1.8828e+00,
           5.3516e-01,  5.5859e-01],
         [ 2.3315e-02, -2.5938e+00,  8.9844e-01,  ..., -9.7266e-01,
           3.0664e-01,  1.2266e+00],
         ...,
         [ 1.4844e+00, -5.8984e-01,  1.9629e-01,  ..., -6.0156e-01,
          -1.9238e-01, -2.1777e-01],
         [ 1.1406e+00, -1.0938e+00, -3.2422e-01,  ..., -1.1719e-01,
          -3.1055e-01, -2.3438e-02],
         [ 5.8203e-01, -1.5747e-02, -6.3281e-01,  ...,  1.0547e+00,
          -1.0645e-01,  5.0537e-02]],

        ...,

        [[ 2.6703e-03, -6.5918e-03, -9.8267e-03,  ..., -7.3242e-03,
          -2.3460e-04, -4.3945e-03],
         [-1.8906e+00,  7.5391e-01, -9.6191e-02,  ...,  2.6367e-01,
           1.5156e+00, -1.2031e+00],
         [ 1.4531e+00, -1.3359e+00,  9.0820e-02,  ..., -2.6406e+00,
           4.2188e+00,  3.1875e+00],
         ...,
         [ 6.2109e-01,  9.7266e-01, -8.7891e-01,  ...,  1.5625e-01,
          -3.2617e-01, -9.0234e-01],
         [-3.6719e-01, -1.1475e-01, -5.4297e-01,  ...,  2.7734e-01,
           1.1328e-01,  5.5176e-02],
         [ 5.9766e-01,  3.8086e-01,  2.1562e+00,  ...,  1.7656e+00,
          -3.3008e-01,  5.6152e-02]],

        [[ 2.6703e-03, -6.5918e-03, -9.8267e-03,  ..., -7.3242e-03,
          -2.3460e-04, -4.3945e-03],
         [-1.8906e+00,  7.5391e-01, -9.6191e-02,  ...,  2.6367e-01,
           1.5156e+00, -1.2031e+00],
         [ 1.4531e+00, -1.3359e+00,  9.0820e-02,  ..., -2.6406e+00,
           4.2188e+00,  3.1875e+00],
         ...,
         [ 6.2109e-01,  9.7266e-01, -8.7891e-01,  ...,  1.5625e-01,
          -3.2617e-01, -9.0234e-01],
         [-3.6719e-01, -1.1475e-01, -5.4297e-01,  ...,  2.7734e-01,
           1.1328e-01,  5.5176e-02],
         [ 5.9766e-01,  3.8086e-01,  2.1562e+00,  ...,  1.7656e+00,
          -3.3008e-01,  5.6152e-02]],

        [[ 2.6703e-03, -6.5918e-03, -9.8267e-03,  ..., -7.3242e-03,
          -2.3460e-04, -4.3945e-03],
         [-1.8906e+00,  7.5391e-01, -9.6191e-02,  ...,  2.6367e-01,
           1.5156e+00, -1.2031e+00],
         [ 1.4531e+00, -1.3359e+00,  9.0820e-02,  ..., -2.6406e+00,
           4.2188e+00,  3.1875e+00],
         ...,
         [ 6.2109e-01,  9.7266e-01, -8.7891e-01,  ...,  1.5625e-01,
          -3.2617e-01, -9.0234e-01],
         [-3.6719e-01, -1.1475e-01, -5.4297e-01,  ...,  2.7734e-01,
           1.1328e-01,  5.5176e-02],
         [ 5.9766e-01,  3.8086e-01,  2.1562e+00,  ...,  1.7656e+00,
          -3.3008e-01,  5.6152e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-3.6926e-03,  3.5858e-03, -1.4191e-03,  ...,  2.9907e-02,
          -6.6406e-02,  5.3516e-01],
         [ 2.0938e+00,  1.7031e+00, -1.5156e+00,  ..., -2.1875e+00,
           1.8594e+00, -5.1875e+00],
         [ 7.6172e-01,  1.2109e+00, -1.4062e-01,  ..., -1.0938e+00,
          -1.8047e+00, -3.8281e+00],
         ...,
         [ 4.6875e-01, -4.2383e-01,  2.3828e-01,  ..., -2.9102e-01,
           4.6875e+00, -5.3125e+00],
         [-4.4922e-01, -3.2812e-01,  1.7188e-01,  ...,  8.3594e-01,
           6.0938e-01, -4.4688e+00],
         [-2.3750e+00, -2.3125e+00, -1.1562e+00,  ..., -1.7188e+00,
          -5.0391e-01, -5.0000e+00]],

        [[-3.6926e-03,  3.5858e-03, -1.4191e-03,  ...,  2.9907e-02,
          -6.6406e-02,  5.3516e-01],
         [ 2.0938e+00,  1.7031e+00, -1.5156e+00,  ..., -2.1875e+00,
           1.8594e+00, -5.1875e+00],
         [ 7.6172e-01,  1.2109e+00, -1.4062e-01,  ..., -1.0938e+00,
          -1.8047e+00, -3.8281e+00],
         ...,
         [ 4.6875e-01, -4.2383e-01,  2.3828e-01,  ..., -2.9102e-01,
           4.6875e+00, -5.3125e+00],
         [-4.4922e-01, -3.2812e-01,  1.7188e-01,  ...,  8.3594e-01,
           6.0938e-01, -4.4688e+00],
         [-2.3750e+00, -2.3125e+00, -1.1562e+00,  ..., -1.7188e+00,
          -5.0391e-01, -5.0000e+00]],

        [[-3.6926e-03,  3.5858e-03, -1.4191e-03,  ...,  2.9907e-02,
          -6.6406e-02,  5.3516e-01],
         [ 2.0938e+00,  1.7031e+00, -1.5156e+00,  ..., -2.1875e+00,
           1.8594e+00, -5.1875e+00],
         [ 7.6172e-01,  1.2109e+00, -1.4062e-01,  ..., -1.0938e+00,
          -1.8047e+00, -3.8281e+00],
         ...,
         [ 4.6875e-01, -4.2383e-01,  2.3828e-01,  ..., -2.9102e-01,
           4.6875e+00, -5.3125e+00],
         [-4.4922e-01, -3.2812e-01,  1.7188e-01,  ...,  8.3594e-01,
           6.0938e-01, -4.4688e+00],
         [-2.3750e+00, -2.3125e+00, -1.1562e+00,  ..., -1.7188e+00,
          -5.0391e-01, -5.0000e+00]],

        ...,

        [[-1.0834e-03,  1.1597e-02,  6.6223e-03,  ...,  1.1426e-01,
           5.9082e-02,  7.5195e-02],
         [ 1.6719e+00, -5.2344e-01, -8.1641e-01,  ..., -6.9922e-01,
          -5.8594e-01, -1.0781e+00],
         [ 1.2344e+00, -1.4453e+00, -9.9219e-01,  ..., -6.4844e-01,
          -4.3945e-01, -6.7969e-01],
         ...,
         [ 6.7969e-01, -6.4453e-01,  9.9219e-01,  ..., -1.2344e+00,
          -1.3984e+00, -1.0000e+00],
         [-4.7070e-01,  3.3984e-01,  5.8984e-01,  ..., -4.1992e-01,
          -1.4258e-01,  6.3965e-02],
         [-1.0000e+00,  1.5781e+00,  5.8203e-01,  ..., -9.4141e-01,
           4.6631e-02, -1.3984e+00]],

        [[-1.0834e-03,  1.1597e-02,  6.6223e-03,  ...,  1.1426e-01,
           5.9082e-02,  7.5195e-02],
         [ 1.6719e+00, -5.2344e-01, -8.1641e-01,  ..., -6.9922e-01,
          -5.8594e-01, -1.0781e+00],
         [ 1.2344e+00, -1.4453e+00, -9.9219e-01,  ..., -6.4844e-01,
          -4.3945e-01, -6.7969e-01],
         ...,
         [ 6.7969e-01, -6.4453e-01,  9.9219e-01,  ..., -1.2344e+00,
          -1.3984e+00, -1.0000e+00],
         [-4.7070e-01,  3.3984e-01,  5.8984e-01,  ..., -4.1992e-01,
          -1.4258e-01,  6.3965e-02],
         [-1.0000e+00,  1.5781e+00,  5.8203e-01,  ..., -9.4141e-01,
           4.6631e-02, -1.3984e+00]],

        [[-1.0834e-03,  1.1597e-02,  6.6223e-03,  ...,  1.1426e-01,
           5.9082e-02,  7.5195e-02],
         [ 1.6719e+00, -5.2344e-01, -8.1641e-01,  ..., -6.9922e-01,
          -5.8594e-01, -1.0781e+00],
         [ 1.2344e+00, -1.4453e+00, -9.9219e-01,  ..., -6.4844e-01,
          -4.3945e-01, -6.7969e-01],
         ...,
         [ 6.7969e-01, -6.4453e-01,  9.9219e-01,  ..., -1.2344e+00,
          -1.3984e+00, -1.0000e+00],
         [-4.7070e-01,  3.3984e-01,  5.8984e-01,  ..., -4.1992e-01,
          -1.4258e-01,  6.3965e-02],
         [-1.0000e+00,  1.5781e+00,  5.8203e-01,  ..., -9.4141e-01,
           4.6631e-02, -1.3984e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-1.1215e-03, -1.3184e-02,  2.7313e-03,  ...,  1.4160e-02,
           2.8038e-04,  7.0496e-03],
         [ 6.5625e-01,  3.0078e-01,  3.0664e-01,  ..., -2.2754e-01,
          -8.6719e-01, -2.0703e-01],
         [ 3.8672e-01, -2.0020e-01,  3.7305e-01,  ...,  1.3477e-01,
           7.9297e-01,  3.1250e-01],
         ...,
         [-1.7109e+00,  2.7344e-01,  9.6680e-02,  ...,  1.7773e-01,
           3.5938e-01, -5.0781e-01],
         [-1.5703e+00, -2.3633e-01,  2.7539e-01,  ...,  4.3945e-01,
          -2.5177e-03, -4.2578e-01],
         [-5.3906e-01, -4.6289e-01,  4.3359e-01,  ...,  4.4531e-01,
          -3.0859e-01, -9.4141e-01]],

        [[-1.1215e-03, -1.3184e-02,  2.7313e-03,  ...,  1.4160e-02,
           2.8038e-04,  7.0496e-03],
         [ 6.5625e-01,  3.0078e-01,  3.0664e-01,  ..., -2.2754e-01,
          -8.6719e-01, -2.0703e-01],
         [ 3.8672e-01, -2.0020e-01,  3.7305e-01,  ...,  1.3477e-01,
           7.9297e-01,  3.1250e-01],
         ...,
         [-1.7109e+00,  2.7344e-01,  9.6680e-02,  ...,  1.7773e-01,
           3.5938e-01, -5.0781e-01],
         [-1.5703e+00, -2.3633e-01,  2.7539e-01,  ...,  4.3945e-01,
          -2.5177e-03, -4.2578e-01],
         [-5.3906e-01, -4.6289e-01,  4.3359e-01,  ...,  4.4531e-01,
          -3.0859e-01, -9.4141e-01]],

        [[-1.1215e-03, -1.3184e-02,  2.7313e-03,  ...,  1.4160e-02,
           2.8038e-04,  7.0496e-03],
         [ 6.5625e-01,  3.0078e-01,  3.0664e-01,  ..., -2.2754e-01,
          -8.6719e-01, -2.0703e-01],
         [ 3.8672e-01, -2.0020e-01,  3.7305e-01,  ...,  1.3477e-01,
           7.9297e-01,  3.1250e-01],
         ...,
         [-1.7109e+00,  2.7344e-01,  9.6680e-02,  ...,  1.7773e-01,
           3.5938e-01, -5.0781e-01],
         [-1.5703e+00, -2.3633e-01,  2.7539e-01,  ...,  4.3945e-01,
          -2.5177e-03, -4.2578e-01],
         [-5.3906e-01, -4.6289e-01,  4.3359e-01,  ...,  4.4531e-01,
          -3.0859e-01, -9.4141e-01]],

        ...,

        [[-9.1553e-03, -9.7656e-03, -6.1951e-03,  ...,  5.8899e-03,
           6.4697e-03, -7.2098e-04],
         [-5.0781e-01,  1.3281e+00, -1.4766e+00,  ..., -7.6172e-01,
          -1.5430e-01, -1.7656e+00],
         [-4.9414e-01,  2.2949e-01,  1.8203e+00,  ..., -3.0273e-01,
          -2.4805e-01, -6.5625e-01],
         ...,
         [-3.7842e-02,  9.4727e-02,  1.5703e+00,  ...,  1.1094e+00,
          -1.2344e+00,  1.5723e-01],
         [-1.6895e-01,  4.3750e-01,  7.5000e-01,  ...,  1.6504e-01,
          -1.5332e-01,  8.2397e-03],
         [-1.0000e+00, -7.9297e-01,  8.7109e-01,  ...,  1.4688e+00,
           1.1875e+00, -5.3906e-01]],

        [[-9.1553e-03, -9.7656e-03, -6.1951e-03,  ...,  5.8899e-03,
           6.4697e-03, -7.2098e-04],
         [-5.0781e-01,  1.3281e+00, -1.4766e+00,  ..., -7.6172e-01,
          -1.5430e-01, -1.7656e+00],
         [-4.9414e-01,  2.2949e-01,  1.8203e+00,  ..., -3.0273e-01,
          -2.4805e-01, -6.5625e-01],
         ...,
         [-3.7842e-02,  9.4727e-02,  1.5703e+00,  ...,  1.1094e+00,
          -1.2344e+00,  1.5723e-01],
         [-1.6895e-01,  4.3750e-01,  7.5000e-01,  ...,  1.6504e-01,
          -1.5332e-01,  8.2397e-03],
         [-1.0000e+00, -7.9297e-01,  8.7109e-01,  ...,  1.4688e+00,
           1.1875e+00, -5.3906e-01]],

        [[-9.1553e-03, -9.7656e-03, -6.1951e-03,  ...,  5.8899e-03,
           6.4697e-03, -7.2098e-04],
         [-5.0781e-01,  1.3281e+00, -1.4766e+00,  ..., -7.6172e-01,
          -1.5430e-01, -1.7656e+00],
         [-4.9414e-01,  2.2949e-01,  1.8203e+00,  ..., -3.0273e-01,
          -2.4805e-01, -6.5625e-01],
         ...,
         [-3.7842e-02,  9.4727e-02,  1.5703e+00,  ...,  1.1094e+00,
          -1.2344e+00,  1.5723e-01],
         [-1.6895e-01,  4.3750e-01,  7.5000e-01,  ...,  1.6504e-01,
          -1.5332e-01,  8.2397e-03],
         [-1.0000e+00, -7.9297e-01,  8.7109e-01,  ...,  1.4688e+00,
           1.1875e+00, -5.3906e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-2.0386e-02,  2.8801e-04,  1.3504e-03,  ..., -6.3281e-01,
          -2.7222e-02, -1.0645e-01],
         [-3.9062e-02,  2.2969e+00,  1.6562e+00,  ...,  1.0688e+01,
          -2.3750e+00,  5.7031e-01],
         [-1.7656e+00,  8.3594e-01,  4.0234e-01,  ...,  1.1438e+01,
          -7.1094e-01,  8.7109e-01],
         ...,
         [-2.5195e-01, -8.2031e-02,  7.3828e-01,  ...,  9.6875e+00,
           1.2500e+00,  1.0078e+00],
         [-4.1211e-01, -1.0781e+00,  6.8750e-01,  ...,  8.6250e+00,
           7.2656e-01,  5.3906e-01],
         [-1.0781e+00, -1.0000e+00,  9.1406e-01,  ...,  9.2500e+00,
           3.4961e-01,  3.0762e-02]],

        [[-2.0386e-02,  2.8801e-04,  1.3504e-03,  ..., -6.3281e-01,
          -2.7222e-02, -1.0645e-01],
         [-3.9062e-02,  2.2969e+00,  1.6562e+00,  ...,  1.0688e+01,
          -2.3750e+00,  5.7031e-01],
         [-1.7656e+00,  8.3594e-01,  4.0234e-01,  ...,  1.1438e+01,
          -7.1094e-01,  8.7109e-01],
         ...,
         [-2.5195e-01, -8.2031e-02,  7.3828e-01,  ...,  9.6875e+00,
           1.2500e+00,  1.0078e+00],
         [-4.1211e-01, -1.0781e+00,  6.8750e-01,  ...,  8.6250e+00,
           7.2656e-01,  5.3906e-01],
         [-1.0781e+00, -1.0000e+00,  9.1406e-01,  ...,  9.2500e+00,
           3.4961e-01,  3.0762e-02]],

        [[-2.0386e-02,  2.8801e-04,  1.3504e-03,  ..., -6.3281e-01,
          -2.7222e-02, -1.0645e-01],
         [-3.9062e-02,  2.2969e+00,  1.6562e+00,  ...,  1.0688e+01,
          -2.3750e+00,  5.7031e-01],
         [-1.7656e+00,  8.3594e-01,  4.0234e-01,  ...,  1.1438e+01,
          -7.1094e-01,  8.7109e-01],
         ...,
         [-2.5195e-01, -8.2031e-02,  7.3828e-01,  ...,  9.6875e+00,
           1.2500e+00,  1.0078e+00],
         [-4.1211e-01, -1.0781e+00,  6.8750e-01,  ...,  8.6250e+00,
           7.2656e-01,  5.3906e-01],
         [-1.0781e+00, -1.0000e+00,  9.1406e-01,  ...,  9.2500e+00,
           3.4961e-01,  3.0762e-02]],

        ...,

        [[ 4.9744e-03,  5.5237e-03,  7.2937e-03,  ..., -2.2949e-02,
           8.1787e-03,  1.5723e-01],
         [-1.3750e+00, -8.7891e-01, -8.5156e-01,  ...,  1.4453e+00,
          -2.7734e-01, -3.7812e+00],
         [-8.2812e-01, -5.9375e-01,  1.8652e-01,  ...,  3.1875e+00,
           2.5977e-01,  1.8555e-01],
         ...,
         [-1.0156e+00,  6.3281e-01, -1.9727e-01,  ..., -1.7500e+00,
          -2.6250e+00, -3.7656e+00],
         [ 2.1680e-01,  4.5703e-01, -6.1719e-01,  ...,  5.9375e-01,
           5.2344e-01, -6.9531e-01],
         [ 1.2969e+00,  8.9062e-01, -1.2188e+00,  ...,  2.5312e+00,
          -1.3516e+00, -2.0469e+00]],

        [[ 4.9744e-03,  5.5237e-03,  7.2937e-03,  ..., -2.2949e-02,
           8.1787e-03,  1.5723e-01],
         [-1.3750e+00, -8.7891e-01, -8.5156e-01,  ...,  1.4453e+00,
          -2.7734e-01, -3.7812e+00],
         [-8.2812e-01, -5.9375e-01,  1.8652e-01,  ...,  3.1875e+00,
           2.5977e-01,  1.8555e-01],
         ...,
         [-1.0156e+00,  6.3281e-01, -1.9727e-01,  ..., -1.7500e+00,
          -2.6250e+00, -3.7656e+00],
         [ 2.1680e-01,  4.5703e-01, -6.1719e-01,  ...,  5.9375e-01,
           5.2344e-01, -6.9531e-01],
         [ 1.2969e+00,  8.9062e-01, -1.2188e+00,  ...,  2.5312e+00,
          -1.3516e+00, -2.0469e+00]],

        [[ 4.9744e-03,  5.5237e-03,  7.2937e-03,  ..., -2.2949e-02,
           8.1787e-03,  1.5723e-01],
         [-1.3750e+00, -8.7891e-01, -8.5156e-01,  ...,  1.4453e+00,
          -2.7734e-01, -3.7812e+00],
         [-8.2812e-01, -5.9375e-01,  1.8652e-01,  ...,  3.1875e+00,
           2.5977e-01,  1.8555e-01],
         ...,
         [-1.0156e+00,  6.3281e-01, -1.9727e-01,  ..., -1.7500e+00,
          -2.6250e+00, -3.7656e+00],
         [ 2.1680e-01,  4.5703e-01, -6.1719e-01,  ...,  5.9375e-01,
           5.2344e-01, -6.9531e-01],
         [ 1.2969e+00,  8.9062e-01, -1.2188e+00,  ...,  2.5312e+00,
          -1.3516e+00, -2.0469e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-0.0063, -0.0132, -0.0074,  ...,  0.0058, -0.0055, -0.0167],
         [-1.8672, -0.9336, -1.5859,  ..., -1.7891, -0.7305, -0.5938],
         [-1.5156,  0.0366,  0.4688,  ..., -0.1069,  0.0542,  0.1040],
         ...,
         [ 0.1250,  0.2812, -1.1172,  ...,  0.2754, -0.7305,  0.0072],
         [ 0.8047, -0.1641,  0.0850,  ..., -0.5703, -0.4941,  0.3301],
         [-1.2109,  0.2578,  1.3516,  ...,  0.1367,  0.1660,  0.7227]],

        [[-0.0063, -0.0132, -0.0074,  ...,  0.0058, -0.0055, -0.0167],
         [-1.8672, -0.9336, -1.5859,  ..., -1.7891, -0.7305, -0.5938],
         [-1.5156,  0.0366,  0.4688,  ..., -0.1069,  0.0542,  0.1040],
         ...,
         [ 0.1250,  0.2812, -1.1172,  ...,  0.2754, -0.7305,  0.0072],
         [ 0.8047, -0.1641,  0.0850,  ..., -0.5703, -0.4941,  0.3301],
         [-1.2109,  0.2578,  1.3516,  ...,  0.1367,  0.1660,  0.7227]],

        [[-0.0063, -0.0132, -0.0074,  ...,  0.0058, -0.0055, -0.0167],
         [-1.8672, -0.9336, -1.5859,  ..., -1.7891, -0.7305, -0.5938],
         [-1.5156,  0.0366,  0.4688,  ..., -0.1069,  0.0542,  0.1040],
         ...,
         [ 0.1250,  0.2812, -1.1172,  ...,  0.2754, -0.7305,  0.0072],
         [ 0.8047, -0.1641,  0.0850,  ..., -0.5703, -0.4941,  0.3301],
         [-1.2109,  0.2578,  1.3516,  ...,  0.1367,  0.1660,  0.7227]],

        ...,

        [[ 0.0068,  0.0091,  0.0122,  ..., -0.0236, -0.0041,  0.0050],
         [ 0.5156,  0.7461,  1.4219,  ...,  1.0859,  1.1172,  0.5859],
         [-0.3027,  0.3945,  1.0859,  ...,  0.7852,  0.3652, -0.1992],
         ...,
         [-0.3750, -0.2832, -0.0659,  ...,  0.2617, -0.7578,  0.3301],
         [ 0.1924,  0.1963, -0.1709,  ...,  0.2207,  0.0713,  0.5234],
         [-0.9453, -0.0046,  0.9141,  ...,  0.2100, -0.6445,  0.6133]],

        [[ 0.0068,  0.0091,  0.0122,  ..., -0.0236, -0.0041,  0.0050],
         [ 0.5156,  0.7461,  1.4219,  ...,  1.0859,  1.1172,  0.5859],
         [-0.3027,  0.3945,  1.0859,  ...,  0.7852,  0.3652, -0.1992],
         ...,
         [-0.3750, -0.2832, -0.0659,  ...,  0.2617, -0.7578,  0.3301],
         [ 0.1924,  0.1963, -0.1709,  ...,  0.2207,  0.0713,  0.5234],
         [-0.9453, -0.0046,  0.9141,  ...,  0.2100, -0.6445,  0.6133]],

        [[ 0.0068,  0.0091,  0.0122,  ..., -0.0236, -0.0041,  0.0050],
         [ 0.5156,  0.7461,  1.4219,  ...,  1.0859,  1.1172,  0.5859],
         [-0.3027,  0.3945,  1.0859,  ...,  0.7852,  0.3652, -0.1992],
         ...,
         [-0.3750, -0.2832, -0.0659,  ...,  0.2617, -0.7578,  0.3301],
         [ 0.1924,  0.1963, -0.1709,  ...,  0.2207,  0.0713,  0.5234],
         [-0.9453, -0.0046,  0.9141,  ...,  0.2100, -0.6445,  0.6133]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[ 2.4261e-03,  2.3804e-03,  2.8076e-03,  ..., -3.6377e-02,
          -2.4219e-01,  8.2812e-01],
         [-4.5312e-01, -1.1094e+00,  3.9844e-01,  ...,  6.1328e-01,
          -1.2354e-01, -7.6875e+00],
         [-8.2422e-01, -1.3516e+00,  9.1406e-01,  ..., -1.0859e+00,
           2.6875e+00, -9.0625e+00],
         ...,
         [ 3.4375e-01,  1.1562e+00, -3.6719e-01,  ...,  3.4844e+00,
          -5.7500e+00, -7.5000e+00],
         [-1.6309e-01,  1.9336e-01, -9.5703e-01,  ...,  5.5078e-01,
          -2.3594e+00, -7.0625e+00],
         [-1.6016e+00,  1.0156e+00, -9.3359e-01,  ..., -8.0566e-02,
           1.5938e+00, -8.0625e+00]],

        [[ 2.4261e-03,  2.3804e-03,  2.8076e-03,  ..., -3.6377e-02,
          -2.4219e-01,  8.2812e-01],
         [-4.5312e-01, -1.1094e+00,  3.9844e-01,  ...,  6.1328e-01,
          -1.2354e-01, -7.6875e+00],
         [-8.2422e-01, -1.3516e+00,  9.1406e-01,  ..., -1.0859e+00,
           2.6875e+00, -9.0625e+00],
         ...,
         [ 3.4375e-01,  1.1562e+00, -3.6719e-01,  ...,  3.4844e+00,
          -5.7500e+00, -7.5000e+00],
         [-1.6309e-01,  1.9336e-01, -9.5703e-01,  ...,  5.5078e-01,
          -2.3594e+00, -7.0625e+00],
         [-1.6016e+00,  1.0156e+00, -9.3359e-01,  ..., -8.0566e-02,
           1.5938e+00, -8.0625e+00]],

        [[ 2.4261e-03,  2.3804e-03,  2.8076e-03,  ..., -3.6377e-02,
          -2.4219e-01,  8.2812e-01],
         [-4.5312e-01, -1.1094e+00,  3.9844e-01,  ...,  6.1328e-01,
          -1.2354e-01, -7.6875e+00],
         [-8.2422e-01, -1.3516e+00,  9.1406e-01,  ..., -1.0859e+00,
           2.6875e+00, -9.0625e+00],
         ...,
         [ 3.4375e-01,  1.1562e+00, -3.6719e-01,  ...,  3.4844e+00,
          -5.7500e+00, -7.5000e+00],
         [-1.6309e-01,  1.9336e-01, -9.5703e-01,  ...,  5.5078e-01,
          -2.3594e+00, -7.0625e+00],
         [-1.6016e+00,  1.0156e+00, -9.3359e-01,  ..., -8.0566e-02,
           1.5938e+00, -8.0625e+00]],

        ...,

        [[ 1.5747e-02, -5.4016e-03,  8.7891e-03,  ..., -1.3184e-02,
          -5.9326e-02,  3.2812e-01],
         [-2.4688e+00,  2.1875e-01, -9.1797e-01,  ..., -9.3750e-01,
           1.6250e+00, -2.6875e+00],
         [-2.8125e-01,  2.2461e-01, -4.1406e-01,  ..., -2.9844e+00,
           1.8555e-01, -2.4531e+00],
         ...,
         [-1.2266e+00,  4.6680e-01,  9.8047e-01,  ..., -1.6172e+00,
          -1.7871e-01, -3.3281e+00],
         [-4.3750e-01,  2.4121e-01,  1.2891e-01,  ..., -3.5000e+00,
          -9.5312e-01, -9.8047e-01],
         [ 7.6172e-02, -5.3125e-01, -3.3398e-01,  ..., -3.4688e+00,
           2.4844e+00, -1.8047e+00]],

        [[ 1.5747e-02, -5.4016e-03,  8.7891e-03,  ..., -1.3184e-02,
          -5.9326e-02,  3.2812e-01],
         [-2.4688e+00,  2.1875e-01, -9.1797e-01,  ..., -9.3750e-01,
           1.6250e+00, -2.6875e+00],
         [-2.8125e-01,  2.2461e-01, -4.1406e-01,  ..., -2.9844e+00,
           1.8555e-01, -2.4531e+00],
         ...,
         [-1.2266e+00,  4.6680e-01,  9.8047e-01,  ..., -1.6172e+00,
          -1.7871e-01, -3.3281e+00],
         [-4.3750e-01,  2.4121e-01,  1.2891e-01,  ..., -3.5000e+00,
          -9.5312e-01, -9.8047e-01],
         [ 7.6172e-02, -5.3125e-01, -3.3398e-01,  ..., -3.4688e+00,
           2.4844e+00, -1.8047e+00]],

        [[ 1.5747e-02, -5.4016e-03,  8.7891e-03,  ..., -1.3184e-02,
          -5.9326e-02,  3.2812e-01],
         [-2.4688e+00,  2.1875e-01, -9.1797e-01,  ..., -9.3750e-01,
           1.6250e+00, -2.6875e+00],
         [-2.8125e-01,  2.2461e-01, -4.1406e-01,  ..., -2.9844e+00,
           1.8555e-01, -2.4531e+00],
         ...,
         [-1.2266e+00,  4.6680e-01,  9.8047e-01,  ..., -1.6172e+00,
          -1.7871e-01, -3.3281e+00],
         [-4.3750e-01,  2.4121e-01,  1.2891e-01,  ..., -3.5000e+00,
          -9.5312e-01, -9.8047e-01],
         [ 7.6172e-02, -5.3125e-01, -3.3398e-01,  ..., -3.4688e+00,
           2.4844e+00, -1.8047e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-5.2795e-03,  2.1648e-04, -1.0498e-02,  ..., -1.2054e-03,
           1.2665e-03,  1.0254e-02],
         [ 1.1016e+00, -6.9922e-01,  2.5391e-01,  ..., -3.6914e-01,
           2.3438e-01, -1.0791e-01],
         [ 1.4551e-01, -2.8198e-02,  4.4727e-01,  ..., -1.1719e+00,
           6.7188e-01, -6.7188e-01],
         ...,
         [ 1.6719e+00,  1.6797e+00,  2.3438e+00,  ...,  2.0469e+00,
          -4.2188e+00,  3.6875e+00],
         [ 1.1953e+00,  7.4609e-01,  1.7578e-01,  ...,  4.6387e-02,
          -6.5234e-01,  1.0781e+00],
         [-3.6914e-01,  2.9883e-01,  6.5625e-01,  ..., -9.6484e-01,
           5.4688e-02, -3.2227e-01]],

        [[-5.2795e-03,  2.1648e-04, -1.0498e-02,  ..., -1.2054e-03,
           1.2665e-03,  1.0254e-02],
         [ 1.1016e+00, -6.9922e-01,  2.5391e-01,  ..., -3.6914e-01,
           2.3438e-01, -1.0791e-01],
         [ 1.4551e-01, -2.8198e-02,  4.4727e-01,  ..., -1.1719e+00,
           6.7188e-01, -6.7188e-01],
         ...,
         [ 1.6719e+00,  1.6797e+00,  2.3438e+00,  ...,  2.0469e+00,
          -4.2188e+00,  3.6875e+00],
         [ 1.1953e+00,  7.4609e-01,  1.7578e-01,  ...,  4.6387e-02,
          -6.5234e-01,  1.0781e+00],
         [-3.6914e-01,  2.9883e-01,  6.5625e-01,  ..., -9.6484e-01,
           5.4688e-02, -3.2227e-01]],

        [[-5.2795e-03,  2.1648e-04, -1.0498e-02,  ..., -1.2054e-03,
           1.2665e-03,  1.0254e-02],
         [ 1.1016e+00, -6.9922e-01,  2.5391e-01,  ..., -3.6914e-01,
           2.3438e-01, -1.0791e-01],
         [ 1.4551e-01, -2.8198e-02,  4.4727e-01,  ..., -1.1719e+00,
           6.7188e-01, -6.7188e-01],
         ...,
         [ 1.6719e+00,  1.6797e+00,  2.3438e+00,  ...,  2.0469e+00,
          -4.2188e+00,  3.6875e+00],
         [ 1.1953e+00,  7.4609e-01,  1.7578e-01,  ...,  4.6387e-02,
          -6.5234e-01,  1.0781e+00],
         [-3.6914e-01,  2.9883e-01,  6.5625e-01,  ..., -9.6484e-01,
           5.4688e-02, -3.2227e-01]],

        ...,

        [[-1.2054e-03,  1.3611e-02, -3.5553e-03,  ..., -1.2817e-03,
          -1.3123e-03,  2.8839e-03],
         [-1.1377e-01, -1.1875e+00,  1.1719e+00,  ..., -1.6250e+00,
           4.6680e-01, -6.9531e-01],
         [ 1.9453e+00,  1.0840e-01,  1.4038e-02,  ...,  8.9844e-02,
          -5.5078e-01,  4.6484e-01],
         ...,
         [-3.7500e-01, -8.2031e-01,  7.1484e-01,  ..., -9.7656e-02,
           2.6172e-01, -9.3750e-01],
         [ 1.0156e-01, -3.9648e-01,  3.5547e-01,  ...,  7.2656e-01,
           5.0000e-01, -8.2031e-01],
         [-3.0273e-01,  2.1094e-01, -3.9453e-01,  ...,  1.2109e+00,
           2.5586e-01,  1.6113e-01]],

        [[-1.2054e-03,  1.3611e-02, -3.5553e-03,  ..., -1.2817e-03,
          -1.3123e-03,  2.8839e-03],
         [-1.1377e-01, -1.1875e+00,  1.1719e+00,  ..., -1.6250e+00,
           4.6680e-01, -6.9531e-01],
         [ 1.9453e+00,  1.0840e-01,  1.4038e-02,  ...,  8.9844e-02,
          -5.5078e-01,  4.6484e-01],
         ...,
         [-3.7500e-01, -8.2031e-01,  7.1484e-01,  ..., -9.7656e-02,
           2.6172e-01, -9.3750e-01],
         [ 1.0156e-01, -3.9648e-01,  3.5547e-01,  ...,  7.2656e-01,
           5.0000e-01, -8.2031e-01],
         [-3.0273e-01,  2.1094e-01, -3.9453e-01,  ...,  1.2109e+00,
           2.5586e-01,  1.6113e-01]],

        [[-1.2054e-03,  1.3611e-02, -3.5553e-03,  ..., -1.2817e-03,
          -1.3123e-03,  2.8839e-03],
         [-1.1377e-01, -1.1875e+00,  1.1719e+00,  ..., -1.6250e+00,
           4.6680e-01, -6.9531e-01],
         [ 1.9453e+00,  1.0840e-01,  1.4038e-02,  ...,  8.9844e-02,
          -5.5078e-01,  4.6484e-01],
         ...,
         [-3.7500e-01, -8.2031e-01,  7.1484e-01,  ..., -9.7656e-02,
           2.6172e-01, -9.3750e-01],
         [ 1.0156e-01, -3.9648e-01,  3.5547e-01,  ...,  7.2656e-01,
           5.0000e-01, -8.2031e-01],
         [-3.0273e-01,  2.1094e-01, -3.9453e-01,  ...,  1.2109e+00,
           2.5586e-01,  1.6113e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 1.3977e-02, -1.9302e-03,  5.9509e-03,  ..., -5.0537e-02,
          -9.7656e-02, -5.2734e-01],
         [-1.3203e+00, -7.7344e-01, -8.3984e-01,  ..., -8.1641e-01,
          -6.2109e-01,  8.3750e+00],
         [ 3.7500e-01,  5.3125e-01, -1.0938e+00,  ..., -1.4062e-01,
           9.2285e-02,  8.3750e+00],
         ...,
         [-1.5156e+00,  1.6484e+00,  1.5938e+00,  ...,  2.2500e+00,
           6.4062e-01,  7.3750e+00],
         [ 4.4141e-01,  4.9805e-01, -2.7344e-02,  ...,  1.3906e+00,
           1.1719e+00,  7.0938e+00],
         [ 1.9531e+00, -1.7578e+00,  5.5078e-01,  ...,  1.3750e+00,
           1.6562e+00,  7.6875e+00]],

        [[ 1.3977e-02, -1.9302e-03,  5.9509e-03,  ..., -5.0537e-02,
          -9.7656e-02, -5.2734e-01],
         [-1.3203e+00, -7.7344e-01, -8.3984e-01,  ..., -8.1641e-01,
          -6.2109e-01,  8.3750e+00],
         [ 3.7500e-01,  5.3125e-01, -1.0938e+00,  ..., -1.4062e-01,
           9.2285e-02,  8.3750e+00],
         ...,
         [-1.5156e+00,  1.6484e+00,  1.5938e+00,  ...,  2.2500e+00,
           6.4062e-01,  7.3750e+00],
         [ 4.4141e-01,  4.9805e-01, -2.7344e-02,  ...,  1.3906e+00,
           1.1719e+00,  7.0938e+00],
         [ 1.9531e+00, -1.7578e+00,  5.5078e-01,  ...,  1.3750e+00,
           1.6562e+00,  7.6875e+00]],

        [[ 1.3977e-02, -1.9302e-03,  5.9509e-03,  ..., -5.0537e-02,
          -9.7656e-02, -5.2734e-01],
         [-1.3203e+00, -7.7344e-01, -8.3984e-01,  ..., -8.1641e-01,
          -6.2109e-01,  8.3750e+00],
         [ 3.7500e-01,  5.3125e-01, -1.0938e+00,  ..., -1.4062e-01,
           9.2285e-02,  8.3750e+00],
         ...,
         [-1.5156e+00,  1.6484e+00,  1.5938e+00,  ...,  2.2500e+00,
           6.4062e-01,  7.3750e+00],
         [ 4.4141e-01,  4.9805e-01, -2.7344e-02,  ...,  1.3906e+00,
           1.1719e+00,  7.0938e+00],
         [ 1.9531e+00, -1.7578e+00,  5.5078e-01,  ...,  1.3750e+00,
           1.6562e+00,  7.6875e+00]],

        ...,

        [[ 3.8719e-04, -5.5237e-03,  4.5013e-04,  ...,  4.7852e-02,
           4.2114e-03,  6.7383e-02],
         [-2.9062e+00,  1.9297e+00,  1.4219e+00,  ...,  3.7305e-01,
          -9.2578e-01,  1.9766e+00],
         [-1.4062e+00,  7.8906e-01,  4.8047e-01,  ..., -1.0840e-01,
           1.9922e+00,  8.5938e-01],
         ...,
         [-1.0312e+00, -8.5156e-01, -5.5469e-01,  ...,  9.4531e-01,
           5.7422e-01, -3.4570e-01],
         [-1.9531e-02, -4.3555e-01,  7.1289e-02,  ...,  1.6113e-01,
           1.8125e+00, -1.6484e+00],
         [ 2.8516e-01, -3.0664e-01,  1.0938e+00,  ..., -2.6367e-01,
          -1.0547e+00, -5.6250e-01]],

        [[ 3.8719e-04, -5.5237e-03,  4.5013e-04,  ...,  4.7852e-02,
           4.2114e-03,  6.7383e-02],
         [-2.9062e+00,  1.9297e+00,  1.4219e+00,  ...,  3.7305e-01,
          -9.2578e-01,  1.9766e+00],
         [-1.4062e+00,  7.8906e-01,  4.8047e-01,  ..., -1.0840e-01,
           1.9922e+00,  8.5938e-01],
         ...,
         [-1.0312e+00, -8.5156e-01, -5.5469e-01,  ...,  9.4531e-01,
           5.7422e-01, -3.4570e-01],
         [-1.9531e-02, -4.3555e-01,  7.1289e-02,  ...,  1.6113e-01,
           1.8125e+00, -1.6484e+00],
         [ 2.8516e-01, -3.0664e-01,  1.0938e+00,  ..., -2.6367e-01,
          -1.0547e+00, -5.6250e-01]],

        [[ 3.8719e-04, -5.5237e-03,  4.5013e-04,  ...,  4.7852e-02,
           4.2114e-03,  6.7383e-02],
         [-2.9062e+00,  1.9297e+00,  1.4219e+00,  ...,  3.7305e-01,
          -9.2578e-01,  1.9766e+00],
         [-1.4062e+00,  7.8906e-01,  4.8047e-01,  ..., -1.0840e-01,
           1.9922e+00,  8.5938e-01],
         ...,
         [-1.0312e+00, -8.5156e-01, -5.5469e-01,  ...,  9.4531e-01,
           5.7422e-01, -3.4570e-01],
         [-1.9531e-02, -4.3555e-01,  7.1289e-02,  ...,  1.6113e-01,
           1.8125e+00, -1.6484e+00],
         [ 2.8516e-01, -3.0664e-01,  1.0938e+00,  ..., -2.6367e-01,
          -1.0547e+00, -5.6250e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 6.1035e-03,  4.6997e-03,  3.7384e-04,  ..., -1.2817e-03,
           8.5449e-03,  3.8757e-03],
         [-1.0234e+00, -6.3672e-01, -7.3047e-01,  ...,  4.5312e-01,
          -1.2598e-01,  1.3750e+00],
         [ 2.8516e-01, -2.5312e+00, -1.6309e-01,  ...,  1.1953e+00,
           8.6328e-01,  5.8594e-01],
         ...,
         [ 2.3281e+00,  3.5742e-01, -1.3516e+00,  ..., -4.5312e-01,
           1.7578e-01, -5.1562e-01],
         [ 1.1016e+00, -2.5781e-01, -1.2656e+00,  ..., -2.7539e-01,
           1.0391e+00, -6.2891e-01],
         [ 4.8633e-01, -1.0625e+00, -8.7500e-01,  ..., -9.7656e-01,
           1.9824e-01, -6.6797e-01]],

        [[ 6.1035e-03,  4.6997e-03,  3.7384e-04,  ..., -1.2817e-03,
           8.5449e-03,  3.8757e-03],
         [-1.0234e+00, -6.3672e-01, -7.3047e-01,  ...,  4.5312e-01,
          -1.2598e-01,  1.3750e+00],
         [ 2.8516e-01, -2.5312e+00, -1.6309e-01,  ...,  1.1953e+00,
           8.6328e-01,  5.8594e-01],
         ...,
         [ 2.3281e+00,  3.5742e-01, -1.3516e+00,  ..., -4.5312e-01,
           1.7578e-01, -5.1562e-01],
         [ 1.1016e+00, -2.5781e-01, -1.2656e+00,  ..., -2.7539e-01,
           1.0391e+00, -6.2891e-01],
         [ 4.8633e-01, -1.0625e+00, -8.7500e-01,  ..., -9.7656e-01,
           1.9824e-01, -6.6797e-01]],

        [[ 6.1035e-03,  4.6997e-03,  3.7384e-04,  ..., -1.2817e-03,
           8.5449e-03,  3.8757e-03],
         [-1.0234e+00, -6.3672e-01, -7.3047e-01,  ...,  4.5312e-01,
          -1.2598e-01,  1.3750e+00],
         [ 2.8516e-01, -2.5312e+00, -1.6309e-01,  ...,  1.1953e+00,
           8.6328e-01,  5.8594e-01],
         ...,
         [ 2.3281e+00,  3.5742e-01, -1.3516e+00,  ..., -4.5312e-01,
           1.7578e-01, -5.1562e-01],
         [ 1.1016e+00, -2.5781e-01, -1.2656e+00,  ..., -2.7539e-01,
           1.0391e+00, -6.2891e-01],
         [ 4.8633e-01, -1.0625e+00, -8.7500e-01,  ..., -9.7656e-01,
           1.9824e-01, -6.6797e-01]],

        ...,

        [[-3.0136e-04, -1.9379e-03, -3.3569e-03,  ..., -4.6692e-03,
          -7.5989e-03, -9.1934e-04],
         [-2.0117e-01, -1.7969e+00,  1.0791e-01,  ...,  2.6406e+00,
          -6.2500e-01, -4.7656e-01],
         [-1.2734e+00,  4.7070e-01, -9.9609e-01,  ..., -2.3438e+00,
           1.2344e+00,  9.4141e-01],
         ...,
         [-3.1445e-01, -6.9141e-01, -4.0039e-02,  ..., -4.6680e-01,
           1.4258e-01, -2.6367e-01],
         [-5.5859e-01,  1.3984e+00, -3.5742e-01,  ..., -1.3047e+00,
           1.1641e+00,  4.7070e-01],
         [-3.7305e-01,  7.2266e-02,  7.2754e-02,  ..., -2.1094e-01,
           8.2812e-01, -2.4316e-01]],

        [[-3.0136e-04, -1.9379e-03, -3.3569e-03,  ..., -4.6692e-03,
          -7.5989e-03, -9.1934e-04],
         [-2.0117e-01, -1.7969e+00,  1.0791e-01,  ...,  2.6406e+00,
          -6.2500e-01, -4.7656e-01],
         [-1.2734e+00,  4.7070e-01, -9.9609e-01,  ..., -2.3438e+00,
           1.2344e+00,  9.4141e-01],
         ...,
         [-3.1445e-01, -6.9141e-01, -4.0039e-02,  ..., -4.6680e-01,
           1.4258e-01, -2.6367e-01],
         [-5.5859e-01,  1.3984e+00, -3.5742e-01,  ..., -1.3047e+00,
           1.1641e+00,  4.7070e-01],
         [-3.7305e-01,  7.2266e-02,  7.2754e-02,  ..., -2.1094e-01,
           8.2812e-01, -2.4316e-01]],

        [[-3.0136e-04, -1.9379e-03, -3.3569e-03,  ..., -4.6692e-03,
          -7.5989e-03, -9.1934e-04],
         [-2.0117e-01, -1.7969e+00,  1.0791e-01,  ...,  2.6406e+00,
          -6.2500e-01, -4.7656e-01],
         [-1.2734e+00,  4.7070e-01, -9.9609e-01,  ..., -2.3438e+00,
           1.2344e+00,  9.4141e-01],
         ...,
         [-3.1445e-01, -6.9141e-01, -4.0039e-02,  ..., -4.6680e-01,
           1.4258e-01, -2.6367e-01],
         [-5.5859e-01,  1.3984e+00, -3.5742e-01,  ..., -1.3047e+00,
           1.1641e+00,  4.7070e-01],
         [-3.7305e-01,  7.2266e-02,  7.2754e-02,  ..., -2.1094e-01,
           8.2812e-01, -2.4316e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-1.1215e-03,  1.2283e-03,  4.0588e-03,  ...,  6.5430e-02,
           7.5684e-03, -4.1748e-02],
         [-5.1953e-01, -8.5938e-01, -1.2812e+00,  ..., -8.3203e-01,
           1.4141e+00,  1.0859e+00],
         [-1.6719e+00, -1.0156e+00, -2.2500e+00,  ..., -1.5859e+00,
          -1.4297e+00,  2.0781e+00],
         ...,
         [-1.0059e-01, -7.3828e-01,  8.1250e-01,  ..., -2.7500e+00,
           2.2278e-03,  2.7500e+00],
         [-9.7656e-04,  1.2061e-01,  6.2891e-01,  ..., -4.0000e+00,
           6.6797e-01,  2.4531e+00],
         [-1.3750e+00,  1.5625e+00,  3.8281e-01,  ..., -2.5156e+00,
          -1.7969e-01,  2.0625e+00]],

        [[-1.1215e-03,  1.2283e-03,  4.0588e-03,  ...,  6.5430e-02,
           7.5684e-03, -4.1748e-02],
         [-5.1953e-01, -8.5938e-01, -1.2812e+00,  ..., -8.3203e-01,
           1.4141e+00,  1.0859e+00],
         [-1.6719e+00, -1.0156e+00, -2.2500e+00,  ..., -1.5859e+00,
          -1.4297e+00,  2.0781e+00],
         ...,
         [-1.0059e-01, -7.3828e-01,  8.1250e-01,  ..., -2.7500e+00,
           2.2278e-03,  2.7500e+00],
         [-9.7656e-04,  1.2061e-01,  6.2891e-01,  ..., -4.0000e+00,
           6.6797e-01,  2.4531e+00],
         [-1.3750e+00,  1.5625e+00,  3.8281e-01,  ..., -2.5156e+00,
          -1.7969e-01,  2.0625e+00]],

        [[-1.1215e-03,  1.2283e-03,  4.0588e-03,  ...,  6.5430e-02,
           7.5684e-03, -4.1748e-02],
         [-5.1953e-01, -8.5938e-01, -1.2812e+00,  ..., -8.3203e-01,
           1.4141e+00,  1.0859e+00],
         [-1.6719e+00, -1.0156e+00, -2.2500e+00,  ..., -1.5859e+00,
          -1.4297e+00,  2.0781e+00],
         ...,
         [-1.0059e-01, -7.3828e-01,  8.1250e-01,  ..., -2.7500e+00,
           2.2278e-03,  2.7500e+00],
         [-9.7656e-04,  1.2061e-01,  6.2891e-01,  ..., -4.0000e+00,
           6.6797e-01,  2.4531e+00],
         [-1.3750e+00,  1.5625e+00,  3.8281e-01,  ..., -2.5156e+00,
          -1.7969e-01,  2.0625e+00]],

        ...,

        [[-2.6245e-03,  7.7820e-03, -2.7924e-03,  ...,  9.4604e-03,
           4.3640e-03, -3.2227e-01],
         [-5.1172e-01, -1.6953e+00, -1.0156e+00,  ..., -1.5156e+00,
           2.1250e+00,  1.0250e+01],
         [-1.2812e+00, -1.2109e+00, -8.0078e-01,  ..., -3.9648e-01,
           2.2559e-01,  1.0125e+01],
         ...,
         [-2.9297e-01, -2.2852e-01,  3.7305e-01,  ..., -1.7383e-01,
           2.5938e+00,  7.5938e+00],
         [-8.0859e-01,  6.4453e-01,  1.6406e-01,  ..., -1.0781e+00,
           9.4141e-01,  7.9375e+00],
         [-9.7266e-01,  1.1406e+00, -6.1328e-01,  ..., -1.6953e+00,
           4.5117e-01,  8.5000e+00]],

        [[-2.6245e-03,  7.7820e-03, -2.7924e-03,  ...,  9.4604e-03,
           4.3640e-03, -3.2227e-01],
         [-5.1172e-01, -1.6953e+00, -1.0156e+00,  ..., -1.5156e+00,
           2.1250e+00,  1.0250e+01],
         [-1.2812e+00, -1.2109e+00, -8.0078e-01,  ..., -3.9648e-01,
           2.2559e-01,  1.0125e+01],
         ...,
         [-2.9297e-01, -2.2852e-01,  3.7305e-01,  ..., -1.7383e-01,
           2.5938e+00,  7.5938e+00],
         [-8.0859e-01,  6.4453e-01,  1.6406e-01,  ..., -1.0781e+00,
           9.4141e-01,  7.9375e+00],
         [-9.7266e-01,  1.1406e+00, -6.1328e-01,  ..., -1.6953e+00,
           4.5117e-01,  8.5000e+00]],

        [[-2.6245e-03,  7.7820e-03, -2.7924e-03,  ...,  9.4604e-03,
           4.3640e-03, -3.2227e-01],
         [-5.1172e-01, -1.6953e+00, -1.0156e+00,  ..., -1.5156e+00,
           2.1250e+00,  1.0250e+01],
         [-1.2812e+00, -1.2109e+00, -8.0078e-01,  ..., -3.9648e-01,
           2.2559e-01,  1.0125e+01],
         ...,
         [-2.9297e-01, -2.2852e-01,  3.7305e-01,  ..., -1.7383e-01,
           2.5938e+00,  7.5938e+00],
         [-8.0859e-01,  6.4453e-01,  1.6406e-01,  ..., -1.0781e+00,
           9.4141e-01,  7.9375e+00],
         [-9.7266e-01,  1.1406e+00, -6.1328e-01,  ..., -1.6953e+00,
           4.5117e-01,  8.5000e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 5.8289e-03, -2.0905e-03, -2.3041e-03,  ...,  6.0654e-04,
           6.7139e-03, -1.5717e-03],
         [ 5.4688e-01,  5.8899e-03,  2.3145e-01,  ...,  1.2695e-01,
           4.5508e-01,  6.8750e-01],
         [ 9.8828e-01, -4.2969e-01,  9.2969e-01,  ..., -1.8828e+00,
           1.0547e+00, -1.2578e+00],
         ...,
         [ 1.1875e+00, -5.3906e-01, -8.0078e-02,  ..., -6.2891e-01,
           9.9609e-02,  1.0205e-01],
         [ 4.2773e-01,  7.8125e-01,  2.7100e-02,  ..., -6.7578e-01,
          -9.2969e-01,  8.0469e-01],
         [ 1.7285e-01,  2.3047e-01,  6.9824e-02,  ...,  1.3672e-01,
           1.0547e+00, -6.7188e-01]],

        [[ 5.8289e-03, -2.0905e-03, -2.3041e-03,  ...,  6.0654e-04,
           6.7139e-03, -1.5717e-03],
         [ 5.4688e-01,  5.8899e-03,  2.3145e-01,  ...,  1.2695e-01,
           4.5508e-01,  6.8750e-01],
         [ 9.8828e-01, -4.2969e-01,  9.2969e-01,  ..., -1.8828e+00,
           1.0547e+00, -1.2578e+00],
         ...,
         [ 1.1875e+00, -5.3906e-01, -8.0078e-02,  ..., -6.2891e-01,
           9.9609e-02,  1.0205e-01],
         [ 4.2773e-01,  7.8125e-01,  2.7100e-02,  ..., -6.7578e-01,
          -9.2969e-01,  8.0469e-01],
         [ 1.7285e-01,  2.3047e-01,  6.9824e-02,  ...,  1.3672e-01,
           1.0547e+00, -6.7188e-01]],

        [[ 5.8289e-03, -2.0905e-03, -2.3041e-03,  ...,  6.0654e-04,
           6.7139e-03, -1.5717e-03],
         [ 5.4688e-01,  5.8899e-03,  2.3145e-01,  ...,  1.2695e-01,
           4.5508e-01,  6.8750e-01],
         [ 9.8828e-01, -4.2969e-01,  9.2969e-01,  ..., -1.8828e+00,
           1.0547e+00, -1.2578e+00],
         ...,
         [ 1.1875e+00, -5.3906e-01, -8.0078e-02,  ..., -6.2891e-01,
           9.9609e-02,  1.0205e-01],
         [ 4.2773e-01,  7.8125e-01,  2.7100e-02,  ..., -6.7578e-01,
          -9.2969e-01,  8.0469e-01],
         [ 1.7285e-01,  2.3047e-01,  6.9824e-02,  ...,  1.3672e-01,
           1.0547e+00, -6.7188e-01]],

        ...,

        [[-3.1128e-03, -7.8583e-04,  5.3406e-03,  ..., -7.8125e-03,
           1.1963e-02, -1.4420e-03],
         [ 8.1250e-01, -4.7461e-01, -8.2812e-01,  ...,  4.3555e-01,
          -1.8125e+00, -8.7891e-01],
         [-1.0391e+00, -8.7500e-01, -4.8242e-01,  ..., -1.1963e-01,
          -1.3984e+00, -1.8359e-01],
         ...,
         [ 2.9883e-01,  3.4180e-01,  3.6914e-01,  ...,  9.6094e-01,
          -6.0156e-01, -1.7676e-01],
         [ 8.3203e-01, -1.9434e-01,  5.6250e-01,  ...,  5.4297e-01,
          -1.6797e-01,  4.9805e-01],
         [ 9.6484e-01, -6.8359e-01,  7.0312e-01,  ..., -4.8047e-01,
           1.8262e-01, -3.4180e-01]],

        [[-3.1128e-03, -7.8583e-04,  5.3406e-03,  ..., -7.8125e-03,
           1.1963e-02, -1.4420e-03],
         [ 8.1250e-01, -4.7461e-01, -8.2812e-01,  ...,  4.3555e-01,
          -1.8125e+00, -8.7891e-01],
         [-1.0391e+00, -8.7500e-01, -4.8242e-01,  ..., -1.1963e-01,
          -1.3984e+00, -1.8359e-01],
         ...,
         [ 2.9883e-01,  3.4180e-01,  3.6914e-01,  ...,  9.6094e-01,
          -6.0156e-01, -1.7676e-01],
         [ 8.3203e-01, -1.9434e-01,  5.6250e-01,  ...,  5.4297e-01,
          -1.6797e-01,  4.9805e-01],
         [ 9.6484e-01, -6.8359e-01,  7.0312e-01,  ..., -4.8047e-01,
           1.8262e-01, -3.4180e-01]],

        [[-3.1128e-03, -7.8583e-04,  5.3406e-03,  ..., -7.8125e-03,
           1.1963e-02, -1.4420e-03],
         [ 8.1250e-01, -4.7461e-01, -8.2812e-01,  ...,  4.3555e-01,
          -1.8125e+00, -8.7891e-01],
         [-1.0391e+00, -8.7500e-01, -4.8242e-01,  ..., -1.1963e-01,
          -1.3984e+00, -1.8359e-01],
         ...,
         [ 2.9883e-01,  3.4180e-01,  3.6914e-01,  ...,  9.6094e-01,
          -6.0156e-01, -1.7676e-01],
         [ 8.3203e-01, -1.9434e-01,  5.6250e-01,  ...,  5.4297e-01,
          -1.6797e-01,  4.9805e-01],
         [ 9.6484e-01, -6.8359e-01,  7.0312e-01,  ..., -4.8047e-01,
           1.8262e-01, -3.4180e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-2.0508e-02,  6.1340e-03, -8.1787e-03,  ..., -1.7285e-01,
          -1.0010e-01,  9.5703e-02],
         [ 2.2656e+00, -1.7500e+00,  1.6562e+00,  ...,  1.0625e+00,
           8.9453e-01, -1.7969e+00],
         [ 1.6172e+00, -1.5234e+00,  1.7031e+00,  ...,  7.7344e-01,
           8.0078e-02,  5.8984e-01],
         ...,
         [ 1.1094e+00, -5.0391e-01, -7.0703e-01,  ...,  1.2578e+00,
           1.0703e+00, -1.4922e+00],
         [ 3.9062e-01,  2.0264e-02, -1.5234e-01,  ...,  1.1719e+00,
           1.2266e+00, -4.7852e-01],
         [ 1.0193e-02,  3.9062e-01, -2.1484e-01,  ...,  1.2422e+00,
           3.1094e+00, -9.8828e-01]],

        [[-2.0508e-02,  6.1340e-03, -8.1787e-03,  ..., -1.7285e-01,
          -1.0010e-01,  9.5703e-02],
         [ 2.2656e+00, -1.7500e+00,  1.6562e+00,  ...,  1.0625e+00,
           8.9453e-01, -1.7969e+00],
         [ 1.6172e+00, -1.5234e+00,  1.7031e+00,  ...,  7.7344e-01,
           8.0078e-02,  5.8984e-01],
         ...,
         [ 1.1094e+00, -5.0391e-01, -7.0703e-01,  ...,  1.2578e+00,
           1.0703e+00, -1.4922e+00],
         [ 3.9062e-01,  2.0264e-02, -1.5234e-01,  ...,  1.1719e+00,
           1.2266e+00, -4.7852e-01],
         [ 1.0193e-02,  3.9062e-01, -2.1484e-01,  ...,  1.2422e+00,
           3.1094e+00, -9.8828e-01]],

        [[-2.0508e-02,  6.1340e-03, -8.1787e-03,  ..., -1.7285e-01,
          -1.0010e-01,  9.5703e-02],
         [ 2.2656e+00, -1.7500e+00,  1.6562e+00,  ...,  1.0625e+00,
           8.9453e-01, -1.7969e+00],
         [ 1.6172e+00, -1.5234e+00,  1.7031e+00,  ...,  7.7344e-01,
           8.0078e-02,  5.8984e-01],
         ...,
         [ 1.1094e+00, -5.0391e-01, -7.0703e-01,  ...,  1.2578e+00,
           1.0703e+00, -1.4922e+00],
         [ 3.9062e-01,  2.0264e-02, -1.5234e-01,  ...,  1.1719e+00,
           1.2266e+00, -4.7852e-01],
         [ 1.0193e-02,  3.9062e-01, -2.1484e-01,  ...,  1.2422e+00,
           3.1094e+00, -9.8828e-01]],

        ...,

        [[-7.2021e-03, -4.3335e-03,  4.3335e-03,  ..., -1.3245e-02,
          -7.4609e-01,  1.4941e-01],
         [-1.0156e+00,  2.8125e+00,  1.6016e-01,  ...,  1.0391e+00,
           7.6562e+00, -8.5156e-01],
         [-2.3750e+00,  1.4844e+00,  1.1094e+00,  ...,  5.0781e-01,
           8.9375e+00, -2.2812e+00],
         ...,
         [-1.7969e-01, -7.8125e-02, -1.3184e-01,  ...,  6.5234e-01,
           7.0938e+00, -4.5312e-01],
         [-1.3438e+00, -1.0625e+00, -1.1094e+00,  ..., -4.7852e-01,
           7.4688e+00, -7.8516e-01],
         [-1.3516e+00, -1.6953e+00, -1.4688e+00,  ...,  1.2734e+00,
           7.4688e+00, -8.6328e-01]],

        [[-7.2021e-03, -4.3335e-03,  4.3335e-03,  ..., -1.3245e-02,
          -7.4609e-01,  1.4941e-01],
         [-1.0156e+00,  2.8125e+00,  1.6016e-01,  ...,  1.0391e+00,
           7.6562e+00, -8.5156e-01],
         [-2.3750e+00,  1.4844e+00,  1.1094e+00,  ...,  5.0781e-01,
           8.9375e+00, -2.2812e+00],
         ...,
         [-1.7969e-01, -7.8125e-02, -1.3184e-01,  ...,  6.5234e-01,
           7.0938e+00, -4.5312e-01],
         [-1.3438e+00, -1.0625e+00, -1.1094e+00,  ..., -4.7852e-01,
           7.4688e+00, -7.8516e-01],
         [-1.3516e+00, -1.6953e+00, -1.4688e+00,  ...,  1.2734e+00,
           7.4688e+00, -8.6328e-01]],

        [[-7.2021e-03, -4.3335e-03,  4.3335e-03,  ..., -1.3245e-02,
          -7.4609e-01,  1.4941e-01],
         [-1.0156e+00,  2.8125e+00,  1.6016e-01,  ...,  1.0391e+00,
           7.6562e+00, -8.5156e-01],
         [-2.3750e+00,  1.4844e+00,  1.1094e+00,  ...,  5.0781e-01,
           8.9375e+00, -2.2812e+00],
         ...,
         [-1.7969e-01, -7.8125e-02, -1.3184e-01,  ...,  6.5234e-01,
           7.0938e+00, -4.5312e-01],
         [-1.3438e+00, -1.0625e+00, -1.1094e+00,  ..., -4.7852e-01,
           7.4688e+00, -7.8516e-01],
         [-1.3516e+00, -1.6953e+00, -1.4688e+00,  ...,  1.2734e+00,
           7.4688e+00, -8.6328e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 2.0752e-03, -2.1210e-03, -6.5002e-03,  ...,  5.2185e-03,
          -5.4626e-03, -2.9144e-03],
         [ 8.2397e-03, -2.9102e-01,  2.8516e-01,  ..., -9.3750e-01,
          -5.7812e-01,  6.2109e-01],
         [ 3.7500e-01, -6.1328e-01,  1.0547e-01,  ..., -1.8438e+00,
           4.2969e-01, -1.8281e+00],
         ...,
         [ 4.7070e-01, -1.4141e+00, -3.8086e-01,  ..., -7.5684e-02,
          -1.0625e+00,  1.6309e-01],
         [ 3.5742e-01, -9.2578e-01, -9.1016e-01,  ...,  8.3984e-01,
          -1.1016e+00, -1.3281e+00],
         [ 1.0156e+00, -9.0234e-01, -1.9434e-01,  ...,  4.3555e-01,
          -1.3672e+00, -9.7266e-01]],

        [[ 2.0752e-03, -2.1210e-03, -6.5002e-03,  ...,  5.2185e-03,
          -5.4626e-03, -2.9144e-03],
         [ 8.2397e-03, -2.9102e-01,  2.8516e-01,  ..., -9.3750e-01,
          -5.7812e-01,  6.2109e-01],
         [ 3.7500e-01, -6.1328e-01,  1.0547e-01,  ..., -1.8438e+00,
           4.2969e-01, -1.8281e+00],
         ...,
         [ 4.7070e-01, -1.4141e+00, -3.8086e-01,  ..., -7.5684e-02,
          -1.0625e+00,  1.6309e-01],
         [ 3.5742e-01, -9.2578e-01, -9.1016e-01,  ...,  8.3984e-01,
          -1.1016e+00, -1.3281e+00],
         [ 1.0156e+00, -9.0234e-01, -1.9434e-01,  ...,  4.3555e-01,
          -1.3672e+00, -9.7266e-01]],

        [[ 2.0752e-03, -2.1210e-03, -6.5002e-03,  ...,  5.2185e-03,
          -5.4626e-03, -2.9144e-03],
         [ 8.2397e-03, -2.9102e-01,  2.8516e-01,  ..., -9.3750e-01,
          -5.7812e-01,  6.2109e-01],
         [ 3.7500e-01, -6.1328e-01,  1.0547e-01,  ..., -1.8438e+00,
           4.2969e-01, -1.8281e+00],
         ...,
         [ 4.7070e-01, -1.4141e+00, -3.8086e-01,  ..., -7.5684e-02,
          -1.0625e+00,  1.6309e-01],
         [ 3.5742e-01, -9.2578e-01, -9.1016e-01,  ...,  8.3984e-01,
          -1.1016e+00, -1.3281e+00],
         [ 1.0156e+00, -9.0234e-01, -1.9434e-01,  ...,  4.3555e-01,
          -1.3672e+00, -9.7266e-01]],

        ...,

        [[ 4.7302e-03,  9.9182e-04, -4.4556e-03,  ...,  1.8921e-03,
          -6.1340e-03,  2.1820e-03],
         [ 5.3906e-01,  8.6426e-02, -3.4424e-02,  ...,  9.8438e-01,
           8.1250e-01,  7.6172e-01],
         [-3.0859e-01,  3.6719e-01,  1.8906e+00,  ...,  1.7871e-01,
           6.7871e-02, -2.1562e+00],
         ...,
         [-5.3516e-01,  1.5015e-02, -2.0630e-02,  ..., -5.3467e-02,
           4.6387e-02, -8.9844e-02],
         [ 2.6562e-01,  8.6719e-01, -6.9922e-01,  ..., -1.2656e+00,
          -3.1641e-01, -8.0469e-01],
         [-2.8516e-01, -5.2490e-02,  8.6719e-01,  ..., -4.3555e-01,
          -8.6426e-02,  1.1768e-01]],

        [[ 4.7302e-03,  9.9182e-04, -4.4556e-03,  ...,  1.8921e-03,
          -6.1340e-03,  2.1820e-03],
         [ 5.3906e-01,  8.6426e-02, -3.4424e-02,  ...,  9.8438e-01,
           8.1250e-01,  7.6172e-01],
         [-3.0859e-01,  3.6719e-01,  1.8906e+00,  ...,  1.7871e-01,
           6.7871e-02, -2.1562e+00],
         ...,
         [-5.3516e-01,  1.5015e-02, -2.0630e-02,  ..., -5.3467e-02,
           4.6387e-02, -8.9844e-02],
         [ 2.6562e-01,  8.6719e-01, -6.9922e-01,  ..., -1.2656e+00,
          -3.1641e-01, -8.0469e-01],
         [-2.8516e-01, -5.2490e-02,  8.6719e-01,  ..., -4.3555e-01,
          -8.6426e-02,  1.1768e-01]],

        [[ 4.7302e-03,  9.9182e-04, -4.4556e-03,  ...,  1.8921e-03,
          -6.1340e-03,  2.1820e-03],
         [ 5.3906e-01,  8.6426e-02, -3.4424e-02,  ...,  9.8438e-01,
           8.1250e-01,  7.6172e-01],
         [-3.0859e-01,  3.6719e-01,  1.8906e+00,  ...,  1.7871e-01,
           6.7871e-02, -2.1562e+00],
         ...,
         [-5.3516e-01,  1.5015e-02, -2.0630e-02,  ..., -5.3467e-02,
           4.6387e-02, -8.9844e-02],
         [ 2.6562e-01,  8.6719e-01, -6.9922e-01,  ..., -1.2656e+00,
          -3.1641e-01, -8.0469e-01],
         [-2.8516e-01, -5.2490e-02,  8.6719e-01,  ..., -4.3555e-01,
          -8.6426e-02,  1.1768e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 3.7994e-03, -2.5146e-02, -1.2360e-03,  ..., -2.3633e-01,
          -9.6680e-02,  1.0059e-01],
         [ 2.5000e-01, -3.6328e-01,  3.5156e-01,  ..., -8.7402e-02,
          -9.6484e-01, -1.3516e+00],
         [ 1.1797e+00, -1.0859e+00, -7.4609e-01,  ..., -5.4062e+00,
          -4.2500e+00, -3.3750e+00],
         ...,
         [ 9.9609e-02,  9.0234e-01,  2.5781e-01,  ..., -5.1250e+00,
          -7.9297e-01, -1.4375e+00],
         [ 8.1641e-01,  2.0703e-01,  5.8203e-01,  ..., -4.9688e+00,
           1.2207e-01,  1.0234e+00],
         [ 1.4922e+00,  1.2305e-01,  2.4121e-01,  ..., -3.5156e+00,
          -3.0469e+00, -1.1250e+00]],

        [[ 3.7994e-03, -2.5146e-02, -1.2360e-03,  ..., -2.3633e-01,
          -9.6680e-02,  1.0059e-01],
         [ 2.5000e-01, -3.6328e-01,  3.5156e-01,  ..., -8.7402e-02,
          -9.6484e-01, -1.3516e+00],
         [ 1.1797e+00, -1.0859e+00, -7.4609e-01,  ..., -5.4062e+00,
          -4.2500e+00, -3.3750e+00],
         ...,
         [ 9.9609e-02,  9.0234e-01,  2.5781e-01,  ..., -5.1250e+00,
          -7.9297e-01, -1.4375e+00],
         [ 8.1641e-01,  2.0703e-01,  5.8203e-01,  ..., -4.9688e+00,
           1.2207e-01,  1.0234e+00],
         [ 1.4922e+00,  1.2305e-01,  2.4121e-01,  ..., -3.5156e+00,
          -3.0469e+00, -1.1250e+00]],

        [[ 3.7994e-03, -2.5146e-02, -1.2360e-03,  ..., -2.3633e-01,
          -9.6680e-02,  1.0059e-01],
         [ 2.5000e-01, -3.6328e-01,  3.5156e-01,  ..., -8.7402e-02,
          -9.6484e-01, -1.3516e+00],
         [ 1.1797e+00, -1.0859e+00, -7.4609e-01,  ..., -5.4062e+00,
          -4.2500e+00, -3.3750e+00],
         ...,
         [ 9.9609e-02,  9.0234e-01,  2.5781e-01,  ..., -5.1250e+00,
          -7.9297e-01, -1.4375e+00],
         [ 8.1641e-01,  2.0703e-01,  5.8203e-01,  ..., -4.9688e+00,
           1.2207e-01,  1.0234e+00],
         [ 1.4922e+00,  1.2305e-01,  2.4121e-01,  ..., -3.5156e+00,
          -3.0469e+00, -1.1250e+00]],

        ...,

        [[-8.1177e-03, -2.3041e-03, -8.8501e-04,  ...,  4.3213e-02,
           1.1426e-01,  1.2305e-01],
         [ 1.2969e+00,  1.7734e+00,  1.0703e+00,  ...,  2.6406e+00,
           9.8828e-01,  1.5723e-01],
         [ 1.7188e-01,  6.2500e-01,  2.2344e+00,  ...,  1.0625e+00,
          -7.8906e-01,  2.1484e-01],
         ...,
         [-8.8867e-02,  3.2617e-01,  2.4219e-01,  ...,  4.5703e-01,
           4.7656e-01, -2.8281e+00],
         [ 1.9141e-01,  4.1016e-01,  6.0156e-01,  ...,  2.6172e-01,
           9.1309e-02, -3.1055e-01],
         [ 1.2266e+00, -9.2188e-01,  2.0781e+00,  ...,  2.1719e+00,
           9.9219e-01, -1.4297e+00]],

        [[-8.1177e-03, -2.3041e-03, -8.8501e-04,  ...,  4.3213e-02,
           1.1426e-01,  1.2305e-01],
         [ 1.2969e+00,  1.7734e+00,  1.0703e+00,  ...,  2.6406e+00,
           9.8828e-01,  1.5723e-01],
         [ 1.7188e-01,  6.2500e-01,  2.2344e+00,  ...,  1.0625e+00,
          -7.8906e-01,  2.1484e-01],
         ...,
         [-8.8867e-02,  3.2617e-01,  2.4219e-01,  ...,  4.5703e-01,
           4.7656e-01, -2.8281e+00],
         [ 1.9141e-01,  4.1016e-01,  6.0156e-01,  ...,  2.6172e-01,
           9.1309e-02, -3.1055e-01],
         [ 1.2266e+00, -9.2188e-01,  2.0781e+00,  ...,  2.1719e+00,
           9.9219e-01, -1.4297e+00]],

        [[-8.1177e-03, -2.3041e-03, -8.8501e-04,  ...,  4.3213e-02,
           1.1426e-01,  1.2305e-01],
         [ 1.2969e+00,  1.7734e+00,  1.0703e+00,  ...,  2.6406e+00,
           9.8828e-01,  1.5723e-01],
         [ 1.7188e-01,  6.2500e-01,  2.2344e+00,  ...,  1.0625e+00,
          -7.8906e-01,  2.1484e-01],
         ...,
         [-8.8867e-02,  3.2617e-01,  2.4219e-01,  ...,  4.5703e-01,
           4.7656e-01, -2.8281e+00],
         [ 1.9141e-01,  4.1016e-01,  6.0156e-01,  ...,  2.6172e-01,
           9.1309e-02, -3.1055e-01],
         [ 1.2266e+00, -9.2188e-01,  2.0781e+00,  ...,  2.1719e+00,
           9.9219e-01, -1.4297e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 7.2327e-03, -4.8218e-03, -5.1575e-03,  ...,  1.7090e-03,
           1.6479e-03,  5.5847e-03],
         [-1.0781e+00, -9.1406e-01,  2.9688e+00,  ..., -1.7031e+00,
          -5.6641e-01, -1.1016e+00],
         [-1.9219e+00, -4.1562e+00,  4.4688e+00,  ..., -7.5781e-01,
          -1.3359e+00, -9.6094e-01],
         ...,
         [-1.5625e+00, -6.8750e-01, -4.9072e-02,  ...,  1.3906e+00,
           5.8984e-01,  1.4922e+00],
         [-3.9062e-01, -1.6953e+00,  1.7500e+00,  ..., -6.4844e-01,
           4.9609e-01,  1.6328e+00],
         [-1.2031e+00, -2.3594e+00,  1.7812e+00,  ..., -2.5195e-01,
          -3.6133e-01,  2.5586e-01]],

        [[ 7.2327e-03, -4.8218e-03, -5.1575e-03,  ...,  1.7090e-03,
           1.6479e-03,  5.5847e-03],
         [-1.0781e+00, -9.1406e-01,  2.9688e+00,  ..., -1.7031e+00,
          -5.6641e-01, -1.1016e+00],
         [-1.9219e+00, -4.1562e+00,  4.4688e+00,  ..., -7.5781e-01,
          -1.3359e+00, -9.6094e-01],
         ...,
         [-1.5625e+00, -6.8750e-01, -4.9072e-02,  ...,  1.3906e+00,
           5.8984e-01,  1.4922e+00],
         [-3.9062e-01, -1.6953e+00,  1.7500e+00,  ..., -6.4844e-01,
           4.9609e-01,  1.6328e+00],
         [-1.2031e+00, -2.3594e+00,  1.7812e+00,  ..., -2.5195e-01,
          -3.6133e-01,  2.5586e-01]],

        [[ 7.2327e-03, -4.8218e-03, -5.1575e-03,  ...,  1.7090e-03,
           1.6479e-03,  5.5847e-03],
         [-1.0781e+00, -9.1406e-01,  2.9688e+00,  ..., -1.7031e+00,
          -5.6641e-01, -1.1016e+00],
         [-1.9219e+00, -4.1562e+00,  4.4688e+00,  ..., -7.5781e-01,
          -1.3359e+00, -9.6094e-01],
         ...,
         [-1.5625e+00, -6.8750e-01, -4.9072e-02,  ...,  1.3906e+00,
           5.8984e-01,  1.4922e+00],
         [-3.9062e-01, -1.6953e+00,  1.7500e+00,  ..., -6.4844e-01,
           4.9609e-01,  1.6328e+00],
         [-1.2031e+00, -2.3594e+00,  1.7812e+00,  ..., -2.5195e-01,
          -3.6133e-01,  2.5586e-01]],

        ...,

        [[ 5.9509e-04,  1.7090e-02, -3.9368e-03,  ...,  3.6011e-03,
           3.3569e-03, -2.1362e-03],
         [-1.0547e-01,  1.1963e-01,  1.7266e+00,  ..., -1.3359e+00,
          -6.9141e-01, -1.7422e+00],
         [-4.1809e-03, -8.7402e-02,  5.8594e-01,  ..., -8.2422e-01,
          -4.3701e-02, -3.7598e-02],
         ...,
         [-1.4453e+00, -1.0596e-01,  4.8242e-01,  ..., -1.8750e+00,
          -2.7344e+00, -1.2344e+00],
         [-8.1250e-01, -5.6250e-01, -2.3828e-01,  ..., -1.0859e+00,
          -1.8359e+00, -1.8750e+00],
         [-1.5078e+00,  7.9297e-01,  8.1250e-01,  ..., -1.9531e+00,
          -5.8203e-01,  1.2500e+00]],

        [[ 5.9509e-04,  1.7090e-02, -3.9368e-03,  ...,  3.6011e-03,
           3.3569e-03, -2.1362e-03],
         [-1.0547e-01,  1.1963e-01,  1.7266e+00,  ..., -1.3359e+00,
          -6.9141e-01, -1.7422e+00],
         [-4.1809e-03, -8.7402e-02,  5.8594e-01,  ..., -8.2422e-01,
          -4.3701e-02, -3.7598e-02],
         ...,
         [-1.4453e+00, -1.0596e-01,  4.8242e-01,  ..., -1.8750e+00,
          -2.7344e+00, -1.2344e+00],
         [-8.1250e-01, -5.6250e-01, -2.3828e-01,  ..., -1.0859e+00,
          -1.8359e+00, -1.8750e+00],
         [-1.5078e+00,  7.9297e-01,  8.1250e-01,  ..., -1.9531e+00,
          -5.8203e-01,  1.2500e+00]],

        [[ 5.9509e-04,  1.7090e-02, -3.9368e-03,  ...,  3.6011e-03,
           3.3569e-03, -2.1362e-03],
         [-1.0547e-01,  1.1963e-01,  1.7266e+00,  ..., -1.3359e+00,
          -6.9141e-01, -1.7422e+00],
         [-4.1809e-03, -8.7402e-02,  5.8594e-01,  ..., -8.2422e-01,
          -4.3701e-02, -3.7598e-02],
         ...,
         [-1.4453e+00, -1.0596e-01,  4.8242e-01,  ..., -1.8750e+00,
          -2.7344e+00, -1.2344e+00],
         [-8.1250e-01, -5.6250e-01, -2.3828e-01,  ..., -1.0859e+00,
          -1.8359e+00, -1.8750e+00],
         [-1.5078e+00,  7.9297e-01,  8.1250e-01,  ..., -1.9531e+00,
          -5.8203e-01,  1.2500e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-1.2024e-02,  5.1270e-03,  5.8289e-03,  ..., -1.3594e+00,
           1.7090e-03,  2.0215e-01],
         [ 3.5547e-01,  1.9609e+00, -6.4844e-01,  ...,  9.5625e+00,
           1.6113e-01, -2.9062e+00],
         [-9.1016e-01, -1.1621e-01,  2.7734e-01,  ...,  1.0000e+01,
           1.7188e+00, -6.8359e-01],
         ...,
         [ 1.7090e-01, -5.3516e-01, -1.1377e-01,  ...,  9.2500e+00,
           1.3477e-01, -1.3984e+00],
         [-1.1406e+00, -1.6328e+00, -5.9766e-01,  ...,  8.8125e+00,
           3.7891e-01, -1.2812e+00],
         [-1.8281e+00, -3.5156e-01, -1.2812e+00,  ...,  8.8125e+00,
          -8.1250e-01, -2.5781e+00]],

        [[-1.2024e-02,  5.1270e-03,  5.8289e-03,  ..., -1.3594e+00,
           1.7090e-03,  2.0215e-01],
         [ 3.5547e-01,  1.9609e+00, -6.4844e-01,  ...,  9.5625e+00,
           1.6113e-01, -2.9062e+00],
         [-9.1016e-01, -1.1621e-01,  2.7734e-01,  ...,  1.0000e+01,
           1.7188e+00, -6.8359e-01],
         ...,
         [ 1.7090e-01, -5.3516e-01, -1.1377e-01,  ...,  9.2500e+00,
           1.3477e-01, -1.3984e+00],
         [-1.1406e+00, -1.6328e+00, -5.9766e-01,  ...,  8.8125e+00,
           3.7891e-01, -1.2812e+00],
         [-1.8281e+00, -3.5156e-01, -1.2812e+00,  ...,  8.8125e+00,
          -8.1250e-01, -2.5781e+00]],

        [[-1.2024e-02,  5.1270e-03,  5.8289e-03,  ..., -1.3594e+00,
           1.7090e-03,  2.0215e-01],
         [ 3.5547e-01,  1.9609e+00, -6.4844e-01,  ...,  9.5625e+00,
           1.6113e-01, -2.9062e+00],
         [-9.1016e-01, -1.1621e-01,  2.7734e-01,  ...,  1.0000e+01,
           1.7188e+00, -6.8359e-01],
         ...,
         [ 1.7090e-01, -5.3516e-01, -1.1377e-01,  ...,  9.2500e+00,
           1.3477e-01, -1.3984e+00],
         [-1.1406e+00, -1.6328e+00, -5.9766e-01,  ...,  8.8125e+00,
           3.7891e-01, -1.2812e+00],
         [-1.8281e+00, -3.5156e-01, -1.2812e+00,  ...,  8.8125e+00,
          -8.1250e-01, -2.5781e+00]],

        ...,

        [[ 9.3384e-03,  1.5198e-02, -1.2146e-02,  ...,  7.6660e-02,
          -3.2422e-01,  1.1523e-01],
         [-1.5078e+00,  2.6406e+00,  1.2656e+00,  ...,  2.8711e-01,
           2.0469e+00,  1.1523e-01],
         [-7.0312e-01,  9.1309e-02,  5.9375e-01,  ...,  4.6875e+00,
          -4.1562e+00,  1.5234e+00],
         ...,
         [-1.0781e+00,  3.7500e-01, -1.4648e-02,  ...,  1.6406e+00,
          -3.6719e+00,  2.9688e+00],
         [ 5.0781e-01, -5.5664e-02,  2.1387e-01,  ...,  7.9688e-01,
          -3.6562e+00,  3.0781e+00],
         [ 2.4531e+00, -9.6875e-01,  5.7031e-01,  ...,  1.5703e+00,
          -1.7500e+00,  1.7344e+00]],

        [[ 9.3384e-03,  1.5198e-02, -1.2146e-02,  ...,  7.6660e-02,
          -3.2422e-01,  1.1523e-01],
         [-1.5078e+00,  2.6406e+00,  1.2656e+00,  ...,  2.8711e-01,
           2.0469e+00,  1.1523e-01],
         [-7.0312e-01,  9.1309e-02,  5.9375e-01,  ...,  4.6875e+00,
          -4.1562e+00,  1.5234e+00],
         ...,
         [-1.0781e+00,  3.7500e-01, -1.4648e-02,  ...,  1.6406e+00,
          -3.6719e+00,  2.9688e+00],
         [ 5.0781e-01, -5.5664e-02,  2.1387e-01,  ...,  7.9688e-01,
          -3.6562e+00,  3.0781e+00],
         [ 2.4531e+00, -9.6875e-01,  5.7031e-01,  ...,  1.5703e+00,
          -1.7500e+00,  1.7344e+00]],

        [[ 9.3384e-03,  1.5198e-02, -1.2146e-02,  ...,  7.6660e-02,
          -3.2422e-01,  1.1523e-01],
         [-1.5078e+00,  2.6406e+00,  1.2656e+00,  ...,  2.8711e-01,
           2.0469e+00,  1.1523e-01],
         [-7.0312e-01,  9.1309e-02,  5.9375e-01,  ...,  4.6875e+00,
          -4.1562e+00,  1.5234e+00],
         ...,
         [-1.0781e+00,  3.7500e-01, -1.4648e-02,  ...,  1.6406e+00,
          -3.6719e+00,  2.9688e+00],
         [ 5.0781e-01, -5.5664e-02,  2.1387e-01,  ...,  7.9688e-01,
          -3.6562e+00,  3.0781e+00],
         [ 2.4531e+00, -9.6875e-01,  5.7031e-01,  ...,  1.5703e+00,
          -1.7500e+00,  1.7344e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-9.0942e-03,  5.8289e-03, -4.0894e-03,  ..., -1.3184e-02,
          -4.6997e-03,  2.1172e-04],
         [ 1.4062e-01, -7.3047e-01, -4.3164e-01,  ..., -9.0625e-01,
           6.6797e-01, -9.2969e-01],
         [-9.3359e-01, -2.1719e+00,  3.2959e-02,  ...,  9.1406e-01,
           5.8984e-01,  5.2734e-01],
         ...,
         [-4.8047e-01,  1.2793e-01,  8.3203e-01,  ...,  5.4688e-01,
           7.3828e-01,  2.0410e-01],
         [-7.6172e-01,  1.3125e+00, -8.0469e-01,  ...,  1.3203e+00,
           9.9609e-01,  9.3750e-02],
         [ 4.5703e-01,  1.8047e+00,  7.8516e-01,  ...,  1.1094e+00,
          -1.2207e-01, -3.5742e-01]],

        [[-9.0942e-03,  5.8289e-03, -4.0894e-03,  ..., -1.3184e-02,
          -4.6997e-03,  2.1172e-04],
         [ 1.4062e-01, -7.3047e-01, -4.3164e-01,  ..., -9.0625e-01,
           6.6797e-01, -9.2969e-01],
         [-9.3359e-01, -2.1719e+00,  3.2959e-02,  ...,  9.1406e-01,
           5.8984e-01,  5.2734e-01],
         ...,
         [-4.8047e-01,  1.2793e-01,  8.3203e-01,  ...,  5.4688e-01,
           7.3828e-01,  2.0410e-01],
         [-7.6172e-01,  1.3125e+00, -8.0469e-01,  ...,  1.3203e+00,
           9.9609e-01,  9.3750e-02],
         [ 4.5703e-01,  1.8047e+00,  7.8516e-01,  ...,  1.1094e+00,
          -1.2207e-01, -3.5742e-01]],

        [[-9.0942e-03,  5.8289e-03, -4.0894e-03,  ..., -1.3184e-02,
          -4.6997e-03,  2.1172e-04],
         [ 1.4062e-01, -7.3047e-01, -4.3164e-01,  ..., -9.0625e-01,
           6.6797e-01, -9.2969e-01],
         [-9.3359e-01, -2.1719e+00,  3.2959e-02,  ...,  9.1406e-01,
           5.8984e-01,  5.2734e-01],
         ...,
         [-4.8047e-01,  1.2793e-01,  8.3203e-01,  ...,  5.4688e-01,
           7.3828e-01,  2.0410e-01],
         [-7.6172e-01,  1.3125e+00, -8.0469e-01,  ...,  1.3203e+00,
           9.9609e-01,  9.3750e-02],
         [ 4.5703e-01,  1.8047e+00,  7.8516e-01,  ...,  1.1094e+00,
          -1.2207e-01, -3.5742e-01]],

        ...,

        [[ 3.6955e-05, -7.8735e-03,  3.1853e-04,  ...,  1.9531e-02,
           2.4986e-04, -6.5308e-03],
         [-1.0449e-01, -9.7656e-02,  5.8984e-01,  ...,  8.9844e-01,
          -1.6094e+00,  1.9434e-01],
         [-1.0781e+00,  3.4219e+00,  2.5000e+00,  ..., -6.2500e-01,
          -1.0859e+00,  1.2812e+00],
         ...,
         [-2.5195e-01, -3.1445e-01,  7.3438e-01,  ..., -4.7266e-01,
          -2.1094e+00,  3.3984e-01],
         [-1.0234e+00,  1.3672e+00,  1.4922e+00,  ..., -9.8438e-01,
          -3.2969e+00, -2.1191e-01],
         [-1.6875e+00,  1.4688e+00,  1.3594e+00,  ..., -3.1055e-01,
          -2.5156e+00,  1.2422e+00]],

        [[ 3.6955e-05, -7.8735e-03,  3.1853e-04,  ...,  1.9531e-02,
           2.4986e-04, -6.5308e-03],
         [-1.0449e-01, -9.7656e-02,  5.8984e-01,  ...,  8.9844e-01,
          -1.6094e+00,  1.9434e-01],
         [-1.0781e+00,  3.4219e+00,  2.5000e+00,  ..., -6.2500e-01,
          -1.0859e+00,  1.2812e+00],
         ...,
         [-2.5195e-01, -3.1445e-01,  7.3438e-01,  ..., -4.7266e-01,
          -2.1094e+00,  3.3984e-01],
         [-1.0234e+00,  1.3672e+00,  1.4922e+00,  ..., -9.8438e-01,
          -3.2969e+00, -2.1191e-01],
         [-1.6875e+00,  1.4688e+00,  1.3594e+00,  ..., -3.1055e-01,
          -2.5156e+00,  1.2422e+00]],

        [[ 3.6955e-05, -7.8735e-03,  3.1853e-04,  ...,  1.9531e-02,
           2.4986e-04, -6.5308e-03],
         [-1.0449e-01, -9.7656e-02,  5.8984e-01,  ...,  8.9844e-01,
          -1.6094e+00,  1.9434e-01],
         [-1.0781e+00,  3.4219e+00,  2.5000e+00,  ..., -6.2500e-01,
          -1.0859e+00,  1.2812e+00],
         ...,
         [-2.5195e-01, -3.1445e-01,  7.3438e-01,  ..., -4.7266e-01,
          -2.1094e+00,  3.3984e-01],
         [-1.0234e+00,  1.3672e+00,  1.4922e+00,  ..., -9.8438e-01,
          -3.2969e+00, -2.1191e-01],
         [-1.6875e+00,  1.4688e+00,  1.3594e+00,  ..., -3.1055e-01,
          -2.5156e+00,  1.2422e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 2.8076e-03, -3.0212e-03,  3.6774e-03,  ...,  4.7266e-01,
           5.4443e-02, -6.8848e-02],
         [-4.5625e+00,  8.1250e-01,  1.7500e+00,  ..., -5.6641e-01,
          -1.9141e-01,  8.0078e-02],
         [-1.1719e+00,  2.5938e+00,  1.4844e+00,  ..., -6.6406e-01,
           4.4727e-01,  4.9609e-01],
         ...,
         [-2.5312e+00,  8.7500e-01, -2.4062e+00,  ...,  9.1406e-01,
           4.8438e-01,  1.2969e+00],
         [-1.0938e+00,  3.2617e-01, -9.8438e-01,  ...,  5.5469e-01,
          -2.2754e-01, -7.5000e-01],
         [ 2.1094e+00, -2.6875e+00, -4.4727e-01,  ...,  1.5391e+00,
          -2.5781e+00, -1.7578e+00]],

        [[ 2.8076e-03, -3.0212e-03,  3.6774e-03,  ...,  4.7266e-01,
           5.4443e-02, -6.8848e-02],
         [-4.5625e+00,  8.1250e-01,  1.7500e+00,  ..., -5.6641e-01,
          -1.9141e-01,  8.0078e-02],
         [-1.1719e+00,  2.5938e+00,  1.4844e+00,  ..., -6.6406e-01,
           4.4727e-01,  4.9609e-01],
         ...,
         [-2.5312e+00,  8.7500e-01, -2.4062e+00,  ...,  9.1406e-01,
           4.8438e-01,  1.2969e+00],
         [-1.0938e+00,  3.2617e-01, -9.8438e-01,  ...,  5.5469e-01,
          -2.2754e-01, -7.5000e-01],
         [ 2.1094e+00, -2.6875e+00, -4.4727e-01,  ...,  1.5391e+00,
          -2.5781e+00, -1.7578e+00]],

        [[ 2.8076e-03, -3.0212e-03,  3.6774e-03,  ...,  4.7266e-01,
           5.4443e-02, -6.8848e-02],
         [-4.5625e+00,  8.1250e-01,  1.7500e+00,  ..., -5.6641e-01,
          -1.9141e-01,  8.0078e-02],
         [-1.1719e+00,  2.5938e+00,  1.4844e+00,  ..., -6.6406e-01,
           4.4727e-01,  4.9609e-01],
         ...,
         [-2.5312e+00,  8.7500e-01, -2.4062e+00,  ...,  9.1406e-01,
           4.8438e-01,  1.2969e+00],
         [-1.0938e+00,  3.2617e-01, -9.8438e-01,  ...,  5.5469e-01,
          -2.2754e-01, -7.5000e-01],
         [ 2.1094e+00, -2.6875e+00, -4.4727e-01,  ...,  1.5391e+00,
          -2.5781e+00, -1.7578e+00]],

        ...,

        [[ 1.2390e-02,  3.0670e-03, -8.9111e-03,  ...,  2.3438e-01,
          -5.9326e-02,  6.6895e-02],
         [-1.9043e-01, -1.5625e-02, -2.5000e-01,  ..., -2.9688e-01,
           6.7188e-01, -3.9258e-01],
         [-3.3789e-01,  2.4316e-01,  9.7168e-02,  ..., -1.0703e+00,
           5.2734e-01, -1.5000e+00],
         ...,
         [ 3.4961e-01,  6.8359e-01, -5.0000e-01,  ...,  5.1953e-01,
          -7.3438e-01, -6.1328e-01],
         [ 7.3242e-02,  9.3750e-02, -2.8516e-01,  ...,  1.1484e+00,
          -1.0107e-01,  5.0049e-02],
         [ 7.6953e-01, -7.6172e-01, -1.0469e+00,  ...,  1.2891e+00,
           7.8906e-01,  2.9907e-02]],

        [[ 1.2390e-02,  3.0670e-03, -8.9111e-03,  ...,  2.3438e-01,
          -5.9326e-02,  6.6895e-02],
         [-1.9043e-01, -1.5625e-02, -2.5000e-01,  ..., -2.9688e-01,
           6.7188e-01, -3.9258e-01],
         [-3.3789e-01,  2.4316e-01,  9.7168e-02,  ..., -1.0703e+00,
           5.2734e-01, -1.5000e+00],
         ...,
         [ 3.4961e-01,  6.8359e-01, -5.0000e-01,  ...,  5.1953e-01,
          -7.3438e-01, -6.1328e-01],
         [ 7.3242e-02,  9.3750e-02, -2.8516e-01,  ...,  1.1484e+00,
          -1.0107e-01,  5.0049e-02],
         [ 7.6953e-01, -7.6172e-01, -1.0469e+00,  ...,  1.2891e+00,
           7.8906e-01,  2.9907e-02]],

        [[ 1.2390e-02,  3.0670e-03, -8.9111e-03,  ...,  2.3438e-01,
          -5.9326e-02,  6.6895e-02],
         [-1.9043e-01, -1.5625e-02, -2.5000e-01,  ..., -2.9688e-01,
           6.7188e-01, -3.9258e-01],
         [-3.3789e-01,  2.4316e-01,  9.7168e-02,  ..., -1.0703e+00,
           5.2734e-01, -1.5000e+00],
         ...,
         [ 3.4961e-01,  6.8359e-01, -5.0000e-01,  ...,  5.1953e-01,
          -7.3438e-01, -6.1328e-01],
         [ 7.3242e-02,  9.3750e-02, -2.8516e-01,  ...,  1.1484e+00,
          -1.0107e-01,  5.0049e-02],
         [ 7.6953e-01, -7.6172e-01, -1.0469e+00,  ...,  1.2891e+00,
           7.8906e-01,  2.9907e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-3.1891e-03,  2.6703e-03, -3.2196e-03,  ..., -9.7656e-03,
          -9.3994e-03,  2.1667e-03],
         [-3.2422e-01,  1.2891e-01,  5.2344e-01,  ...,  4.4922e-01,
           1.2598e-01, -9.6484e-01],
         [ 8.7500e-01, -7.2656e-01,  8.9355e-02,  ...,  5.7812e-01,
          -7.8516e-01, -1.2793e-01],
         ...,
         [-4.2969e-01, -1.3047e+00, -1.1377e-01,  ...,  1.5625e+00,
           1.0781e+00, -1.0840e-01],
         [ 1.0000e+00, -3.8867e-01, -1.0547e+00,  ...,  3.3984e-01,
          -4.8633e-01, -2.3242e-01],
         [-1.3906e+00, -1.2578e+00, -5.5469e-01,  ...,  2.1094e+00,
           2.3438e+00,  1.1953e+00]],

        [[-3.1891e-03,  2.6703e-03, -3.2196e-03,  ..., -9.7656e-03,
          -9.3994e-03,  2.1667e-03],
         [-3.2422e-01,  1.2891e-01,  5.2344e-01,  ...,  4.4922e-01,
           1.2598e-01, -9.6484e-01],
         [ 8.7500e-01, -7.2656e-01,  8.9355e-02,  ...,  5.7812e-01,
          -7.8516e-01, -1.2793e-01],
         ...,
         [-4.2969e-01, -1.3047e+00, -1.1377e-01,  ...,  1.5625e+00,
           1.0781e+00, -1.0840e-01],
         [ 1.0000e+00, -3.8867e-01, -1.0547e+00,  ...,  3.3984e-01,
          -4.8633e-01, -2.3242e-01],
         [-1.3906e+00, -1.2578e+00, -5.5469e-01,  ...,  2.1094e+00,
           2.3438e+00,  1.1953e+00]],

        [[-3.1891e-03,  2.6703e-03, -3.2196e-03,  ..., -9.7656e-03,
          -9.3994e-03,  2.1667e-03],
         [-3.2422e-01,  1.2891e-01,  5.2344e-01,  ...,  4.4922e-01,
           1.2598e-01, -9.6484e-01],
         [ 8.7500e-01, -7.2656e-01,  8.9355e-02,  ...,  5.7812e-01,
          -7.8516e-01, -1.2793e-01],
         ...,
         [-4.2969e-01, -1.3047e+00, -1.1377e-01,  ...,  1.5625e+00,
           1.0781e+00, -1.0840e-01],
         [ 1.0000e+00, -3.8867e-01, -1.0547e+00,  ...,  3.3984e-01,
          -4.8633e-01, -2.3242e-01],
         [-1.3906e+00, -1.2578e+00, -5.5469e-01,  ...,  2.1094e+00,
           2.3438e+00,  1.1953e+00]],

        ...,

        [[-4.9133e-03, -5.0049e-03, -8.3008e-03,  ..., -8.8501e-04,
          -2.7618e-03,  1.4587e-02],
         [-5.5469e-01, -1.0859e+00, -3.4570e-01,  ..., -1.1641e+00,
           5.1172e-01,  1.9629e-01],
         [ 2.1875e-01, -9.4531e-01,  2.5781e-01,  ..., -7.8516e-01,
           5.6641e-01,  1.0000e+00],
         ...,
         [ 1.1353e-02,  6.7188e-01,  1.7969e+00,  ..., -1.8594e+00,
           8.3594e-01, -3.3398e-01],
         [ 1.1641e+00,  8.3984e-02,  4.1406e-01,  ..., -1.0000e+00,
          -1.2500e+00, -2.9688e-01],
         [-2.5330e-03, -2.1191e-01,  1.6016e+00,  ..., -9.2578e-01,
           1.8516e+00, -1.0000e+00]],

        [[-4.9133e-03, -5.0049e-03, -8.3008e-03,  ..., -8.8501e-04,
          -2.7618e-03,  1.4587e-02],
         [-5.5469e-01, -1.0859e+00, -3.4570e-01,  ..., -1.1641e+00,
           5.1172e-01,  1.9629e-01],
         [ 2.1875e-01, -9.4531e-01,  2.5781e-01,  ..., -7.8516e-01,
           5.6641e-01,  1.0000e+00],
         ...,
         [ 1.1353e-02,  6.7188e-01,  1.7969e+00,  ..., -1.8594e+00,
           8.3594e-01, -3.3398e-01],
         [ 1.1641e+00,  8.3984e-02,  4.1406e-01,  ..., -1.0000e+00,
          -1.2500e+00, -2.9688e-01],
         [-2.5330e-03, -2.1191e-01,  1.6016e+00,  ..., -9.2578e-01,
           1.8516e+00, -1.0000e+00]],

        [[-4.9133e-03, -5.0049e-03, -8.3008e-03,  ..., -8.8501e-04,
          -2.7618e-03,  1.4587e-02],
         [-5.5469e-01, -1.0859e+00, -3.4570e-01,  ..., -1.1641e+00,
           5.1172e-01,  1.9629e-01],
         [ 2.1875e-01, -9.4531e-01,  2.5781e-01,  ..., -7.8516e-01,
           5.6641e-01,  1.0000e+00],
         ...,
         [ 1.1353e-02,  6.7188e-01,  1.7969e+00,  ..., -1.8594e+00,
           8.3594e-01, -3.3398e-01],
         [ 1.1641e+00,  8.3984e-02,  4.1406e-01,  ..., -1.0000e+00,
          -1.2500e+00, -2.9688e-01],
         [-2.5330e-03, -2.1191e-01,  1.6016e+00,  ..., -9.2578e-01,
           1.8516e+00, -1.0000e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-1.6724e-02, -4.4861e-03, -2.8076e-03,  ...,  1.1523e-01,
           1.3574e-01, -1.5747e-02],
         [ 4.4141e-01,  3.1250e-01, -9.4141e-01,  ...,  7.8906e-01,
           4.4531e-01,  4.8438e-01],
         [-1.5469e+00,  1.9727e-01, -2.2188e+00,  ...,  6.8359e-01,
          -1.1426e-01,  2.2500e+00],
         ...,
         [-3.3594e-01,  6.7188e-01,  6.0156e-01,  ...,  7.7148e-02,
          -1.5625e+00, -1.3965e-01],
         [-1.4922e+00, -2.9883e-01,  5.7031e-01,  ...,  5.3125e-01,
          -4.9414e-01,  7.4219e-02],
         [-1.6875e+00, -1.0625e+00,  3.0859e-01,  ..., -1.7734e+00,
          -4.6094e-01,  9.8828e-01]],

        [[-1.6724e-02, -4.4861e-03, -2.8076e-03,  ...,  1.1523e-01,
           1.3574e-01, -1.5747e-02],
         [ 4.4141e-01,  3.1250e-01, -9.4141e-01,  ...,  7.8906e-01,
           4.4531e-01,  4.8438e-01],
         [-1.5469e+00,  1.9727e-01, -2.2188e+00,  ...,  6.8359e-01,
          -1.1426e-01,  2.2500e+00],
         ...,
         [-3.3594e-01,  6.7188e-01,  6.0156e-01,  ...,  7.7148e-02,
          -1.5625e+00, -1.3965e-01],
         [-1.4922e+00, -2.9883e-01,  5.7031e-01,  ...,  5.3125e-01,
          -4.9414e-01,  7.4219e-02],
         [-1.6875e+00, -1.0625e+00,  3.0859e-01,  ..., -1.7734e+00,
          -4.6094e-01,  9.8828e-01]],

        [[-1.6724e-02, -4.4861e-03, -2.8076e-03,  ...,  1.1523e-01,
           1.3574e-01, -1.5747e-02],
         [ 4.4141e-01,  3.1250e-01, -9.4141e-01,  ...,  7.8906e-01,
           4.4531e-01,  4.8438e-01],
         [-1.5469e+00,  1.9727e-01, -2.2188e+00,  ...,  6.8359e-01,
          -1.1426e-01,  2.2500e+00],
         ...,
         [-3.3594e-01,  6.7188e-01,  6.0156e-01,  ...,  7.7148e-02,
          -1.5625e+00, -1.3965e-01],
         [-1.4922e+00, -2.9883e-01,  5.7031e-01,  ...,  5.3125e-01,
          -4.9414e-01,  7.4219e-02],
         [-1.6875e+00, -1.0625e+00,  3.0859e-01,  ..., -1.7734e+00,
          -4.6094e-01,  9.8828e-01]],

        ...,

        [[ 2.2125e-03,  1.2573e-02,  1.5335e-03,  ...,  9.7168e-02,
          -4.8523e-03, -1.6504e-01],
         [ 1.8203e+00,  8.9062e-01, -2.1484e-01,  ..., -2.0801e-01,
          -4.8242e-01,  1.1562e+00],
         [ 1.4375e+00,  2.5156e+00, -8.4375e-01,  ...,  2.4805e-01,
          -2.2949e-01,  2.6406e+00],
         ...,
         [-3.7500e-01,  1.3984e+00,  5.2344e-01,  ..., -1.4219e+00,
          -1.3477e-01,  3.4180e-01],
         [ 1.0625e+00,  1.3594e+00, -2.8906e-01,  ...,  1.2500e+00,
           3.9258e-01,  1.4141e+00],
         [ 2.1250e+00,  1.5625e-01, -8.6719e-01,  ..., -6.8359e-01,
          -1.3086e-01,  2.7969e+00]],

        [[ 2.2125e-03,  1.2573e-02,  1.5335e-03,  ...,  9.7168e-02,
          -4.8523e-03, -1.6504e-01],
         [ 1.8203e+00,  8.9062e-01, -2.1484e-01,  ..., -2.0801e-01,
          -4.8242e-01,  1.1562e+00],
         [ 1.4375e+00,  2.5156e+00, -8.4375e-01,  ...,  2.4805e-01,
          -2.2949e-01,  2.6406e+00],
         ...,
         [-3.7500e-01,  1.3984e+00,  5.2344e-01,  ..., -1.4219e+00,
          -1.3477e-01,  3.4180e-01],
         [ 1.0625e+00,  1.3594e+00, -2.8906e-01,  ...,  1.2500e+00,
           3.9258e-01,  1.4141e+00],
         [ 2.1250e+00,  1.5625e-01, -8.6719e-01,  ..., -6.8359e-01,
          -1.3086e-01,  2.7969e+00]],

        [[ 2.2125e-03,  1.2573e-02,  1.5335e-03,  ...,  9.7168e-02,
          -4.8523e-03, -1.6504e-01],
         [ 1.8203e+00,  8.9062e-01, -2.1484e-01,  ..., -2.0801e-01,
          -4.8242e-01,  1.1562e+00],
         [ 1.4375e+00,  2.5156e+00, -8.4375e-01,  ...,  2.4805e-01,
          -2.2949e-01,  2.6406e+00],
         ...,
         [-3.7500e-01,  1.3984e+00,  5.2344e-01,  ..., -1.4219e+00,
          -1.3477e-01,  3.4180e-01],
         [ 1.0625e+00,  1.3594e+00, -2.8906e-01,  ...,  1.2500e+00,
           3.9258e-01,  1.4141e+00],
         [ 2.1250e+00,  1.5625e-01, -8.6719e-01,  ..., -6.8359e-01,
          -1.3086e-01,  2.7969e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 0.0064,  0.0118,  0.0106,  ...,  0.0147,  0.0085, -0.0025],
         [-1.1484, -1.8438,  0.8281,  ...,  1.4219, -1.5938, -0.3262],
         [-0.4883, -0.5859, -0.4316,  ...,  0.0128, -0.0199,  1.2422],
         ...,
         [-0.0113,  0.6250, -0.1709,  ..., -0.4199,  0.6523,  0.4590],
         [ 0.2715,  0.1885, -0.7188,  ...,  0.4238,  0.5508,  0.9023],
         [-0.2246, -1.1953, -0.4316,  ...,  0.0273,  1.0234,  0.8789]],

        [[ 0.0064,  0.0118,  0.0106,  ...,  0.0147,  0.0085, -0.0025],
         [-1.1484, -1.8438,  0.8281,  ...,  1.4219, -1.5938, -0.3262],
         [-0.4883, -0.5859, -0.4316,  ...,  0.0128, -0.0199,  1.2422],
         ...,
         [-0.0113,  0.6250, -0.1709,  ..., -0.4199,  0.6523,  0.4590],
         [ 0.2715,  0.1885, -0.7188,  ...,  0.4238,  0.5508,  0.9023],
         [-0.2246, -1.1953, -0.4316,  ...,  0.0273,  1.0234,  0.8789]],

        [[ 0.0064,  0.0118,  0.0106,  ...,  0.0147,  0.0085, -0.0025],
         [-1.1484, -1.8438,  0.8281,  ...,  1.4219, -1.5938, -0.3262],
         [-0.4883, -0.5859, -0.4316,  ...,  0.0128, -0.0199,  1.2422],
         ...,
         [-0.0113,  0.6250, -0.1709,  ..., -0.4199,  0.6523,  0.4590],
         [ 0.2715,  0.1885, -0.7188,  ...,  0.4238,  0.5508,  0.9023],
         [-0.2246, -1.1953, -0.4316,  ...,  0.0273,  1.0234,  0.8789]],

        ...,

        [[-0.0048,  0.0029, -0.0053,  ..., -0.0062, -0.0024,  0.0045],
         [ 0.4160, -0.6328,  0.5195,  ...,  0.7461, -0.2754,  0.5273],
         [-1.4375,  1.0703,  0.4629,  ..., -0.9336, -1.2734, -0.4043],
         ...,
         [-1.0938, -0.0304,  0.2852,  ..., -0.2266, -0.4199, -1.0000],
         [-2.0781, -0.8477, -1.4531,  ..., -0.8672, -1.1094, -0.0728],
         [ 1.1094,  0.2715, -0.0287,  ...,  0.6367,  0.2637, -0.0728]],

        [[-0.0048,  0.0029, -0.0053,  ..., -0.0062, -0.0024,  0.0045],
         [ 0.4160, -0.6328,  0.5195,  ...,  0.7461, -0.2754,  0.5273],
         [-1.4375,  1.0703,  0.4629,  ..., -0.9336, -1.2734, -0.4043],
         ...,
         [-1.0938, -0.0304,  0.2852,  ..., -0.2266, -0.4199, -1.0000],
         [-2.0781, -0.8477, -1.4531,  ..., -0.8672, -1.1094, -0.0728],
         [ 1.1094,  0.2715, -0.0287,  ...,  0.6367,  0.2637, -0.0728]],

        [[-0.0048,  0.0029, -0.0053,  ..., -0.0062, -0.0024,  0.0045],
         [ 0.4160, -0.6328,  0.5195,  ...,  0.7461, -0.2754,  0.5273],
         [-1.4375,  1.0703,  0.4629,  ..., -0.9336, -1.2734, -0.4043],
         ...,
         [-1.0938, -0.0304,  0.2852,  ..., -0.2266, -0.4199, -1.0000],
         [-2.0781, -0.8477, -1.4531,  ..., -0.8672, -1.1094, -0.0728],
         [ 1.1094,  0.2715, -0.0287,  ...,  0.6367,  0.2637, -0.0728]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[-7.9346e-04,  2.2430e-03,  1.5717e-03,  ..., -7.9956e-03,
          -1.6406e+00,  8.7402e-02],
         [ 3.9453e-01,  1.5625e-01, -1.2266e+00,  ..., -1.7285e-01,
           6.5938e+00, -1.5527e-01],
         [ 4.2188e-01, -3.2812e-01,  1.5430e-01,  ..., -2.4375e+00,
           9.6875e+00,  1.8281e+00],
         ...,
         [ 3.8672e-01,  1.8555e-01, -7.6172e-01,  ...,  4.9062e+00,
           7.7188e+00, -4.0000e+00],
         [-3.9062e-03,  1.2695e-01, -2.9492e-01,  ...,  1.1797e+00,
           5.7188e+00, -2.1094e+00],
         [-1.0059e-01, -1.6602e-01, -9.6875e-01,  ..., -2.2969e+00,
           8.8125e+00, -1.2500e+00]],

        [[-7.9346e-04,  2.2430e-03,  1.5717e-03,  ..., -7.9956e-03,
          -1.6406e+00,  8.7402e-02],
         [ 3.9453e-01,  1.5625e-01, -1.2266e+00,  ..., -1.7285e-01,
           6.5938e+00, -1.5527e-01],
         [ 4.2188e-01, -3.2812e-01,  1.5430e-01,  ..., -2.4375e+00,
           9.6875e+00,  1.8281e+00],
         ...,
         [ 3.8672e-01,  1.8555e-01, -7.6172e-01,  ...,  4.9062e+00,
           7.7188e+00, -4.0000e+00],
         [-3.9062e-03,  1.2695e-01, -2.9492e-01,  ...,  1.1797e+00,
           5.7188e+00, -2.1094e+00],
         [-1.0059e-01, -1.6602e-01, -9.6875e-01,  ..., -2.2969e+00,
           8.8125e+00, -1.2500e+00]],

        [[-7.9346e-04,  2.2430e-03,  1.5717e-03,  ..., -7.9956e-03,
          -1.6406e+00,  8.7402e-02],
         [ 3.9453e-01,  1.5625e-01, -1.2266e+00,  ..., -1.7285e-01,
           6.5938e+00, -1.5527e-01],
         [ 4.2188e-01, -3.2812e-01,  1.5430e-01,  ..., -2.4375e+00,
           9.6875e+00,  1.8281e+00],
         ...,
         [ 3.8672e-01,  1.8555e-01, -7.6172e-01,  ...,  4.9062e+00,
           7.7188e+00, -4.0000e+00],
         [-3.9062e-03,  1.2695e-01, -2.9492e-01,  ...,  1.1797e+00,
           5.7188e+00, -2.1094e+00],
         [-1.0059e-01, -1.6602e-01, -9.6875e-01,  ..., -2.2969e+00,
           8.8125e+00, -1.2500e+00]],

        ...,

        [[-3.6774e-03, -1.3000e-02, -1.9989e-03,  ..., -5.5908e-02,
           2.0386e-02, -7.2754e-02],
         [ 1.8594e+00, -1.2344e+00, -7.7344e-01,  ...,  5.4297e-01,
          -1.0703e+00,  1.8359e+00],
         [ 1.0000e+00, -7.2656e-01, -4.6875e-01,  ...,  7.3438e-01,
           3.0273e-01, -2.5781e-01],
         ...,
         [-6.2500e-01,  9.7266e-01,  1.9062e+00,  ..., -1.0938e+00,
           2.7656e+00,  6.9141e-01],
         [ 2.8320e-01, -1.2256e-01,  1.0625e+00,  ..., -1.0625e+00,
           5.5469e-01, -1.7344e+00],
         [ 1.8594e+00,  1.4219e+00,  1.4844e+00,  ..., -1.1016e+00,
           1.4453e+00,  7.5391e-01]],

        [[-3.6774e-03, -1.3000e-02, -1.9989e-03,  ..., -5.5908e-02,
           2.0386e-02, -7.2754e-02],
         [ 1.8594e+00, -1.2344e+00, -7.7344e-01,  ...,  5.4297e-01,
          -1.0703e+00,  1.8359e+00],
         [ 1.0000e+00, -7.2656e-01, -4.6875e-01,  ...,  7.3438e-01,
           3.0273e-01, -2.5781e-01],
         ...,
         [-6.2500e-01,  9.7266e-01,  1.9062e+00,  ..., -1.0938e+00,
           2.7656e+00,  6.9141e-01],
         [ 2.8320e-01, -1.2256e-01,  1.0625e+00,  ..., -1.0625e+00,
           5.5469e-01, -1.7344e+00],
         [ 1.8594e+00,  1.4219e+00,  1.4844e+00,  ..., -1.1016e+00,
           1.4453e+00,  7.5391e-01]],

        [[-3.6774e-03, -1.3000e-02, -1.9989e-03,  ..., -5.5908e-02,
           2.0386e-02, -7.2754e-02],
         [ 1.8594e+00, -1.2344e+00, -7.7344e-01,  ...,  5.4297e-01,
          -1.0703e+00,  1.8359e+00],
         [ 1.0000e+00, -7.2656e-01, -4.6875e-01,  ...,  7.3438e-01,
           3.0273e-01, -2.5781e-01],
         ...,
         [-6.2500e-01,  9.7266e-01,  1.9062e+00,  ..., -1.0938e+00,
           2.7656e+00,  6.9141e-01],
         [ 2.8320e-01, -1.2256e-01,  1.0625e+00,  ..., -1.0625e+00,
           5.5469e-01, -1.7344e+00],
         [ 1.8594e+00,  1.4219e+00,  1.4844e+00,  ..., -1.1016e+00,
           1.4453e+00,  7.5391e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 4.6692e-03,  2.2095e-02, -3.2043e-03,  ..., -1.2268e-02,
           1.4160e-02,  3.9307e-02],
         [ 5.5078e-01, -2.9102e-01, -6.2500e-01,  ..., -8.1635e-04,
          -1.5078e+00, -1.0781e+00],
         [-2.6367e-02,  6.0156e-01,  2.3047e-01,  ..., -8.0469e-01,
           4.0820e-01, -1.7578e-01],
         ...,
         [-1.0078e+00, -3.1445e-01, -9.9609e-01,  ...,  7.6172e-01,
          -7.9590e-02,  1.6016e+00],
         [-8.0469e-01, -7.3730e-02, -8.0859e-01,  ..., -4.1797e-01,
           5.3125e-01,  1.4297e+00],
         [ 1.0078e+00, -1.0547e+00, -1.7188e+00,  ..., -1.2109e+00,
          -2.2188e+00,  8.7891e-01]],

        [[ 4.6692e-03,  2.2095e-02, -3.2043e-03,  ..., -1.2268e-02,
           1.4160e-02,  3.9307e-02],
         [ 5.5078e-01, -2.9102e-01, -6.2500e-01,  ..., -8.1635e-04,
          -1.5078e+00, -1.0781e+00],
         [-2.6367e-02,  6.0156e-01,  2.3047e-01,  ..., -8.0469e-01,
           4.0820e-01, -1.7578e-01],
         ...,
         [-1.0078e+00, -3.1445e-01, -9.9609e-01,  ...,  7.6172e-01,
          -7.9590e-02,  1.6016e+00],
         [-8.0469e-01, -7.3730e-02, -8.0859e-01,  ..., -4.1797e-01,
           5.3125e-01,  1.4297e+00],
         [ 1.0078e+00, -1.0547e+00, -1.7188e+00,  ..., -1.2109e+00,
          -2.2188e+00,  8.7891e-01]],

        [[ 4.6692e-03,  2.2095e-02, -3.2043e-03,  ..., -1.2268e-02,
           1.4160e-02,  3.9307e-02],
         [ 5.5078e-01, -2.9102e-01, -6.2500e-01,  ..., -8.1635e-04,
          -1.5078e+00, -1.0781e+00],
         [-2.6367e-02,  6.0156e-01,  2.3047e-01,  ..., -8.0469e-01,
           4.0820e-01, -1.7578e-01],
         ...,
         [-1.0078e+00, -3.1445e-01, -9.9609e-01,  ...,  7.6172e-01,
          -7.9590e-02,  1.6016e+00],
         [-8.0469e-01, -7.3730e-02, -8.0859e-01,  ..., -4.1797e-01,
           5.3125e-01,  1.4297e+00],
         [ 1.0078e+00, -1.0547e+00, -1.7188e+00,  ..., -1.2109e+00,
          -2.2188e+00,  8.7891e-01]],

        ...,

        [[ 1.7319e-03,  1.0872e-04, -7.7209e-03,  ..., -1.3489e-02,
          -8.0566e-03,  6.9580e-03],
         [-7.5000e-01,  3.8086e-01, -8.7500e-01,  ..., -4.7070e-01,
          -1.2578e+00,  4.1602e-01],
         [-2.5000e-01,  2.1289e-01,  5.5469e-01,  ...,  8.9844e-01,
          -1.0469e+00,  1.2266e+00],
         ...,
         [-3.7842e-02,  6.1951e-03,  5.5469e-01,  ..., -1.7285e-01,
          -3.0078e-01,  8.3984e-02],
         [ 3.1250e-01, -4.4922e-02,  5.0000e-01,  ...,  1.0234e+00,
           1.3125e+00, -1.8750e-01],
         [ 1.8848e-01,  1.1875e+00,  5.0781e-01,  ...,  3.3203e-01,
           2.2949e-01,  2.6758e-01]],

        [[ 1.7319e-03,  1.0872e-04, -7.7209e-03,  ..., -1.3489e-02,
          -8.0566e-03,  6.9580e-03],
         [-7.5000e-01,  3.8086e-01, -8.7500e-01,  ..., -4.7070e-01,
          -1.2578e+00,  4.1602e-01],
         [-2.5000e-01,  2.1289e-01,  5.5469e-01,  ...,  8.9844e-01,
          -1.0469e+00,  1.2266e+00],
         ...,
         [-3.7842e-02,  6.1951e-03,  5.5469e-01,  ..., -1.7285e-01,
          -3.0078e-01,  8.3984e-02],
         [ 3.1250e-01, -4.4922e-02,  5.0000e-01,  ...,  1.0234e+00,
           1.3125e+00, -1.8750e-01],
         [ 1.8848e-01,  1.1875e+00,  5.0781e-01,  ...,  3.3203e-01,
           2.2949e-01,  2.6758e-01]],

        [[ 1.7319e-03,  1.0872e-04, -7.7209e-03,  ..., -1.3489e-02,
          -8.0566e-03,  6.9580e-03],
         [-7.5000e-01,  3.8086e-01, -8.7500e-01,  ..., -4.7070e-01,
          -1.2578e+00,  4.1602e-01],
         [-2.5000e-01,  2.1289e-01,  5.5469e-01,  ...,  8.9844e-01,
          -1.0469e+00,  1.2266e+00],
         ...,
         [-3.7842e-02,  6.1951e-03,  5.5469e-01,  ..., -1.7285e-01,
          -3.0078e-01,  8.3984e-02],
         [ 3.1250e-01, -4.4922e-02,  5.0000e-01,  ...,  1.0234e+00,
           1.3125e+00, -1.8750e-01],
         [ 1.8848e-01,  1.1875e+00,  5.0781e-01,  ...,  3.3203e-01,
           2.2949e-01,  2.6758e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 1.2695e-02,  1.6098e-03,  1.1658e-02,  ...,  2.1680e-01,
           1.3672e-01, -1.2812e+00],
         [ 7.8516e-01,  4.4531e-01,  3.4375e-01,  ...,  1.6406e+00,
           8.2031e-01,  1.7891e+00],
         [ 2.5781e-01, -4.2969e-01, -2.9785e-02,  ...,  1.4766e+00,
          -2.0781e+00,  2.2656e+00],
         ...,
         [-8.7891e-03, -3.9062e-01,  1.9922e-01,  ..., -3.2500e+00,
           7.1484e-01,  3.8594e+00],
         [ 3.5156e-01, -5.8594e-01, -1.5430e-01,  ..., -1.0859e+00,
          -1.1953e+00,  2.6719e+00],
         [-1.2793e-01, -2.0508e-01,  3.0664e-01,  ...,  8.1641e-01,
          -1.0469e+00,  4.5312e+00]],

        [[ 1.2695e-02,  1.6098e-03,  1.1658e-02,  ...,  2.1680e-01,
           1.3672e-01, -1.2812e+00],
         [ 7.8516e-01,  4.4531e-01,  3.4375e-01,  ...,  1.6406e+00,
           8.2031e-01,  1.7891e+00],
         [ 2.5781e-01, -4.2969e-01, -2.9785e-02,  ...,  1.4766e+00,
          -2.0781e+00,  2.2656e+00],
         ...,
         [-8.7891e-03, -3.9062e-01,  1.9922e-01,  ..., -3.2500e+00,
           7.1484e-01,  3.8594e+00],
         [ 3.5156e-01, -5.8594e-01, -1.5430e-01,  ..., -1.0859e+00,
          -1.1953e+00,  2.6719e+00],
         [-1.2793e-01, -2.0508e-01,  3.0664e-01,  ...,  8.1641e-01,
          -1.0469e+00,  4.5312e+00]],

        [[ 1.2695e-02,  1.6098e-03,  1.1658e-02,  ...,  2.1680e-01,
           1.3672e-01, -1.2812e+00],
         [ 7.8516e-01,  4.4531e-01,  3.4375e-01,  ...,  1.6406e+00,
           8.2031e-01,  1.7891e+00],
         [ 2.5781e-01, -4.2969e-01, -2.9785e-02,  ...,  1.4766e+00,
          -2.0781e+00,  2.2656e+00],
         ...,
         [-8.7891e-03, -3.9062e-01,  1.9922e-01,  ..., -3.2500e+00,
           7.1484e-01,  3.8594e+00],
         [ 3.5156e-01, -5.8594e-01, -1.5430e-01,  ..., -1.0859e+00,
          -1.1953e+00,  2.6719e+00],
         [-1.2793e-01, -2.0508e-01,  3.0664e-01,  ...,  8.1641e-01,
          -1.0469e+00,  4.5312e+00]],

        ...,

        [[ 9.3384e-03, -2.3438e-02, -4.6692e-03,  ...,  5.3516e-01,
          -2.9297e-01, -8.1543e-02],
         [ 4.1211e-01,  9.6875e-01, -4.1992e-02,  ..., -1.7188e+00,
           9.1016e-01, -3.4062e+00],
         [ 2.2852e-01,  7.6904e-03,  3.0469e-01,  ..., -2.0703e-01,
           8.8281e-01, -1.5000e+00],
         ...,
         [ 6.4453e-01, -2.1289e-01, -7.6172e-01,  ..., -1.6484e+00,
          -1.7734e+00, -2.7969e+00],
         [ 5.1758e-02, -1.7188e-01, -2.4414e-01,  ...,  1.6484e+00,
          -8.3594e-01,  4.2188e-01],
         [ 3.4961e-01, -7.7734e-01, -4.8828e-02,  ...,  1.1562e+00,
           1.0391e+00, -6.5312e+00]],

        [[ 9.3384e-03, -2.3438e-02, -4.6692e-03,  ...,  5.3516e-01,
          -2.9297e-01, -8.1543e-02],
         [ 4.1211e-01,  9.6875e-01, -4.1992e-02,  ..., -1.7188e+00,
           9.1016e-01, -3.4062e+00],
         [ 2.2852e-01,  7.6904e-03,  3.0469e-01,  ..., -2.0703e-01,
           8.8281e-01, -1.5000e+00],
         ...,
         [ 6.4453e-01, -2.1289e-01, -7.6172e-01,  ..., -1.6484e+00,
          -1.7734e+00, -2.7969e+00],
         [ 5.1758e-02, -1.7188e-01, -2.4414e-01,  ...,  1.6484e+00,
          -8.3594e-01,  4.2188e-01],
         [ 3.4961e-01, -7.7734e-01, -4.8828e-02,  ...,  1.1562e+00,
           1.0391e+00, -6.5312e+00]],

        [[ 9.3384e-03, -2.3438e-02, -4.6692e-03,  ...,  5.3516e-01,
          -2.9297e-01, -8.1543e-02],
         [ 4.1211e-01,  9.6875e-01, -4.1992e-02,  ..., -1.7188e+00,
           9.1016e-01, -3.4062e+00],
         [ 2.2852e-01,  7.6904e-03,  3.0469e-01,  ..., -2.0703e-01,
           8.8281e-01, -1.5000e+00],
         ...,
         [ 6.4453e-01, -2.1289e-01, -7.6172e-01,  ..., -1.6484e+00,
          -1.7734e+00, -2.7969e+00],
         [ 5.1758e-02, -1.7188e-01, -2.4414e-01,  ...,  1.6484e+00,
          -8.3594e-01,  4.2188e-01],
         [ 3.4961e-01, -7.7734e-01, -4.8828e-02,  ...,  1.1562e+00,
           1.0391e+00, -6.5312e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-5.3406e-03,  1.2398e-04,  1.0864e-02,  ..., -7.3547e-03,
           1.1475e-02,  7.3547e-03],
         [ 2.5156e+00, -1.6172e+00, -5.4688e-01,  ...,  1.7773e-01,
           7.1094e-01,  2.1406e+00],
         [ 2.7930e-01, -1.3125e+00, -1.1172e+00,  ...,  2.2754e-01,
           3.7109e-01,  7.1094e-01],
         ...,
         [-1.3770e-01, -5.5078e-01,  5.2344e-01,  ...,  3.8477e-01,
          -8.5449e-02,  9.4141e-01],
         [-1.4941e-01,  9.8145e-02, -5.4297e-01,  ...,  5.1172e-01,
           4.0283e-02,  7.4609e-01],
         [ 1.7109e+00, -6.5234e-01,  1.4941e-01,  ...,  1.8672e+00,
          -1.6406e+00,  1.0156e+00]],

        [[-5.3406e-03,  1.2398e-04,  1.0864e-02,  ..., -7.3547e-03,
           1.1475e-02,  7.3547e-03],
         [ 2.5156e+00, -1.6172e+00, -5.4688e-01,  ...,  1.7773e-01,
           7.1094e-01,  2.1406e+00],
         [ 2.7930e-01, -1.3125e+00, -1.1172e+00,  ...,  2.2754e-01,
           3.7109e-01,  7.1094e-01],
         ...,
         [-1.3770e-01, -5.5078e-01,  5.2344e-01,  ...,  3.8477e-01,
          -8.5449e-02,  9.4141e-01],
         [-1.4941e-01,  9.8145e-02, -5.4297e-01,  ...,  5.1172e-01,
           4.0283e-02,  7.4609e-01],
         [ 1.7109e+00, -6.5234e-01,  1.4941e-01,  ...,  1.8672e+00,
          -1.6406e+00,  1.0156e+00]],

        [[-5.3406e-03,  1.2398e-04,  1.0864e-02,  ..., -7.3547e-03,
           1.1475e-02,  7.3547e-03],
         [ 2.5156e+00, -1.6172e+00, -5.4688e-01,  ...,  1.7773e-01,
           7.1094e-01,  2.1406e+00],
         [ 2.7930e-01, -1.3125e+00, -1.1172e+00,  ...,  2.2754e-01,
           3.7109e-01,  7.1094e-01],
         ...,
         [-1.3770e-01, -5.5078e-01,  5.2344e-01,  ...,  3.8477e-01,
          -8.5449e-02,  9.4141e-01],
         [-1.4941e-01,  9.8145e-02, -5.4297e-01,  ...,  5.1172e-01,
           4.0283e-02,  7.4609e-01],
         [ 1.7109e+00, -6.5234e-01,  1.4941e-01,  ...,  1.8672e+00,
          -1.6406e+00,  1.0156e+00]],

        ...,

        [[ 5.8289e-03,  1.7822e-02, -2.7222e-02,  ..., -1.2360e-03,
           1.0071e-02,  9.7656e-04],
         [-6.9531e-01, -1.4062e+00,  4.2773e-01,  ...,  6.6016e-01,
           1.1719e-01,  2.8438e+00],
         [-3.5352e-01, -1.9141e+00,  2.0996e-01,  ..., -3.6328e-01,
           1.7383e-01, -1.1562e+00],
         ...,
         [-3.9648e-01, -6.9922e-01, -3.2227e-01,  ...,  4.1602e-01,
          -2.1680e-01, -2.9297e-01],
         [-1.0859e+00, -3.5352e-01, -3.5547e-01,  ...,  3.8672e-01,
           1.0547e+00,  5.9326e-02],
         [ 1.5859e+00,  1.7969e-01,  1.9297e+00,  ...,  1.9629e-01,
           9.7656e-01, -1.1406e+00]],

        [[ 5.8289e-03,  1.7822e-02, -2.7222e-02,  ..., -1.2360e-03,
           1.0071e-02,  9.7656e-04],
         [-6.9531e-01, -1.4062e+00,  4.2773e-01,  ...,  6.6016e-01,
           1.1719e-01,  2.8438e+00],
         [-3.5352e-01, -1.9141e+00,  2.0996e-01,  ..., -3.6328e-01,
           1.7383e-01, -1.1562e+00],
         ...,
         [-3.9648e-01, -6.9922e-01, -3.2227e-01,  ...,  4.1602e-01,
          -2.1680e-01, -2.9297e-01],
         [-1.0859e+00, -3.5352e-01, -3.5547e-01,  ...,  3.8672e-01,
           1.0547e+00,  5.9326e-02],
         [ 1.5859e+00,  1.7969e-01,  1.9297e+00,  ...,  1.9629e-01,
           9.7656e-01, -1.1406e+00]],

        [[ 5.8289e-03,  1.7822e-02, -2.7222e-02,  ..., -1.2360e-03,
           1.0071e-02,  9.7656e-04],
         [-6.9531e-01, -1.4062e+00,  4.2773e-01,  ...,  6.6016e-01,
           1.1719e-01,  2.8438e+00],
         [-3.5352e-01, -1.9141e+00,  2.0996e-01,  ..., -3.6328e-01,
           1.7383e-01, -1.1562e+00],
         ...,
         [-3.9648e-01, -6.9922e-01, -3.2227e-01,  ...,  4.1602e-01,
          -2.1680e-01, -2.9297e-01],
         [-1.0859e+00, -3.5352e-01, -3.5547e-01,  ...,  3.8672e-01,
           1.0547e+00,  5.9326e-02],
         [ 1.5859e+00,  1.7969e-01,  1.9297e+00,  ...,  1.9629e-01,
           9.7656e-01, -1.1406e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 8.9722e-03,  1.1139e-03, -4.6997e-03,  ...,  2.6611e-02,
          -1.4531e+00,  2.0020e-01],
         [-3.1875e+00,  1.5625e-01, -1.3438e+00,  ..., -1.1797e+00,
           7.4688e+00, -3.3750e+00],
         [-7.5000e-01,  7.6172e-01, -5.3125e-01,  ...,  7.8516e-01,
           8.7500e+00, -2.4844e+00],
         ...,
         [-3.9844e-01,  1.3965e-01,  5.4688e-01,  ..., -5.5859e-01,
           5.4688e+00, -4.1562e+00],
         [-4.7656e-01, -1.4062e-01,  3.7891e-01,  ...,  6.3965e-02,
           6.6875e+00, -2.1875e+00],
         [ 7.3047e-01, -1.1250e+00, -1.0312e+00,  ..., -5.3516e-01,
           7.5000e+00, -1.1562e+00]],

        [[ 8.9722e-03,  1.1139e-03, -4.6997e-03,  ...,  2.6611e-02,
          -1.4531e+00,  2.0020e-01],
         [-3.1875e+00,  1.5625e-01, -1.3438e+00,  ..., -1.1797e+00,
           7.4688e+00, -3.3750e+00],
         [-7.5000e-01,  7.6172e-01, -5.3125e-01,  ...,  7.8516e-01,
           8.7500e+00, -2.4844e+00],
         ...,
         [-3.9844e-01,  1.3965e-01,  5.4688e-01,  ..., -5.5859e-01,
           5.4688e+00, -4.1562e+00],
         [-4.7656e-01, -1.4062e-01,  3.7891e-01,  ...,  6.3965e-02,
           6.6875e+00, -2.1875e+00],
         [ 7.3047e-01, -1.1250e+00, -1.0312e+00,  ..., -5.3516e-01,
           7.5000e+00, -1.1562e+00]],

        [[ 8.9722e-03,  1.1139e-03, -4.6997e-03,  ...,  2.6611e-02,
          -1.4531e+00,  2.0020e-01],
         [-3.1875e+00,  1.5625e-01, -1.3438e+00,  ..., -1.1797e+00,
           7.4688e+00, -3.3750e+00],
         [-7.5000e-01,  7.6172e-01, -5.3125e-01,  ...,  7.8516e-01,
           8.7500e+00, -2.4844e+00],
         ...,
         [-3.9844e-01,  1.3965e-01,  5.4688e-01,  ..., -5.5859e-01,
           5.4688e+00, -4.1562e+00],
         [-4.7656e-01, -1.4062e-01,  3.7891e-01,  ...,  6.3965e-02,
           6.6875e+00, -2.1875e+00],
         [ 7.3047e-01, -1.1250e+00, -1.0312e+00,  ..., -5.3516e-01,
           7.5000e+00, -1.1562e+00]],

        ...,

        [[-6.3782e-03, -8.7280e-03, -9.7046e-03,  ...,  8.0078e-02,
           8.1055e-02, -2.4707e-01],
         [ 3.1562e+00,  3.8672e-01,  2.2266e-01,  ..., -4.9805e-01,
          -1.1484e+00,  1.0645e-01],
         [ 9.3359e-01,  1.5391e+00,  1.1328e-01,  ...,  1.3203e+00,
           2.0938e+00,  7.9688e-01],
         ...,
         [ 9.6484e-01,  4.7266e-01, -1.6699e-01,  ..., -2.5000e+00,
           6.3281e-01,  1.9922e+00],
         [ 6.7969e-01,  1.0156e+00,  3.4375e-01,  ..., -2.1250e+00,
           2.3594e+00,  7.5000e-01],
         [-1.0938e+00,  5.5859e-01,  2.3633e-01,  ...,  3.1445e-01,
           6.7969e-01, -1.6016e-01]],

        [[-6.3782e-03, -8.7280e-03, -9.7046e-03,  ...,  8.0078e-02,
           8.1055e-02, -2.4707e-01],
         [ 3.1562e+00,  3.8672e-01,  2.2266e-01,  ..., -4.9805e-01,
          -1.1484e+00,  1.0645e-01],
         [ 9.3359e-01,  1.5391e+00,  1.1328e-01,  ...,  1.3203e+00,
           2.0938e+00,  7.9688e-01],
         ...,
         [ 9.6484e-01,  4.7266e-01, -1.6699e-01,  ..., -2.5000e+00,
           6.3281e-01,  1.9922e+00],
         [ 6.7969e-01,  1.0156e+00,  3.4375e-01,  ..., -2.1250e+00,
           2.3594e+00,  7.5000e-01],
         [-1.0938e+00,  5.5859e-01,  2.3633e-01,  ...,  3.1445e-01,
           6.7969e-01, -1.6016e-01]],

        [[-6.3782e-03, -8.7280e-03, -9.7046e-03,  ...,  8.0078e-02,
           8.1055e-02, -2.4707e-01],
         [ 3.1562e+00,  3.8672e-01,  2.2266e-01,  ..., -4.9805e-01,
          -1.1484e+00,  1.0645e-01],
         [ 9.3359e-01,  1.5391e+00,  1.1328e-01,  ...,  1.3203e+00,
           2.0938e+00,  7.9688e-01],
         ...,
         [ 9.6484e-01,  4.7266e-01, -1.6699e-01,  ..., -2.5000e+00,
           6.3281e-01,  1.9922e+00],
         [ 6.7969e-01,  1.0156e+00,  3.4375e-01,  ..., -2.1250e+00,
           2.3594e+00,  7.5000e-01],
         [-1.0938e+00,  5.5859e-01,  2.3633e-01,  ...,  3.1445e-01,
           6.7969e-01, -1.6016e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 7.6294e-03, -6.6528e-03, -5.7678e-03,  ...,  6.5613e-03,
          -5.0354e-03,  2.7313e-03],
         [ 9.2969e-01,  1.1641e+00, -3.2812e-01,  ...,  1.3281e+00,
           2.6367e-01,  2.0625e+00],
         [-4.9609e-01, -8.0078e-01,  4.6680e-01,  ...,  5.9766e-01,
          -2.7734e-01,  8.7500e-01],
         ...,
         [ 1.0469e+00, -6.2109e-01, -1.0625e+00,  ...,  4.8047e-01,
           1.5000e+00, -7.9688e-01],
         [ 5.7812e-01, -6.2891e-01,  1.5039e-01,  ...,  6.8359e-01,
           7.3047e-01, -2.0703e-01],
         [ 1.3594e+00,  9.7168e-02,  2.0781e+00,  ..., -9.9609e-01,
          -1.0781e+00,  1.6016e+00]],

        [[ 7.6294e-03, -6.6528e-03, -5.7678e-03,  ...,  6.5613e-03,
          -5.0354e-03,  2.7313e-03],
         [ 9.2969e-01,  1.1641e+00, -3.2812e-01,  ...,  1.3281e+00,
           2.6367e-01,  2.0625e+00],
         [-4.9609e-01, -8.0078e-01,  4.6680e-01,  ...,  5.9766e-01,
          -2.7734e-01,  8.7500e-01],
         ...,
         [ 1.0469e+00, -6.2109e-01, -1.0625e+00,  ...,  4.8047e-01,
           1.5000e+00, -7.9688e-01],
         [ 5.7812e-01, -6.2891e-01,  1.5039e-01,  ...,  6.8359e-01,
           7.3047e-01, -2.0703e-01],
         [ 1.3594e+00,  9.7168e-02,  2.0781e+00,  ..., -9.9609e-01,
          -1.0781e+00,  1.6016e+00]],

        [[ 7.6294e-03, -6.6528e-03, -5.7678e-03,  ...,  6.5613e-03,
          -5.0354e-03,  2.7313e-03],
         [ 9.2969e-01,  1.1641e+00, -3.2812e-01,  ...,  1.3281e+00,
           2.6367e-01,  2.0625e+00],
         [-4.9609e-01, -8.0078e-01,  4.6680e-01,  ...,  5.9766e-01,
          -2.7734e-01,  8.7500e-01],
         ...,
         [ 1.0469e+00, -6.2109e-01, -1.0625e+00,  ...,  4.8047e-01,
           1.5000e+00, -7.9688e-01],
         [ 5.7812e-01, -6.2891e-01,  1.5039e-01,  ...,  6.8359e-01,
           7.3047e-01, -2.0703e-01],
         [ 1.3594e+00,  9.7168e-02,  2.0781e+00,  ..., -9.9609e-01,
          -1.0781e+00,  1.6016e+00]],

        ...,

        [[-6.7234e-05,  2.8381e-03, -5.2795e-03,  ...,  2.2736e-03,
          -6.7444e-03,  6.7749e-03],
         [-7.3242e-02,  5.4688e-01,  8.6328e-01,  ...,  7.6953e-01,
           4.8633e-01,  3.1055e-01],
         [-6.3672e-01,  3.1445e-01,  2.5977e-01,  ..., -3.6865e-02,
           1.9165e-02, -8.3594e-01],
         ...,
         [-1.0303e-01,  9.1016e-01,  4.2383e-01,  ..., -8.9453e-01,
           5.4297e-01,  7.2266e-01],
         [-2.6367e-01, -5.5859e-01,  1.3184e-01,  ...,  5.0391e-01,
           1.7734e+00, -1.7773e-01],
         [ 1.0312e+00, -1.5625e-01,  7.3438e-01,  ..., -1.5332e-01,
           6.8359e-01,  9.1016e-01]],

        [[-6.7234e-05,  2.8381e-03, -5.2795e-03,  ...,  2.2736e-03,
          -6.7444e-03,  6.7749e-03],
         [-7.3242e-02,  5.4688e-01,  8.6328e-01,  ...,  7.6953e-01,
           4.8633e-01,  3.1055e-01],
         [-6.3672e-01,  3.1445e-01,  2.5977e-01,  ..., -3.6865e-02,
           1.9165e-02, -8.3594e-01],
         ...,
         [-1.0303e-01,  9.1016e-01,  4.2383e-01,  ..., -8.9453e-01,
           5.4297e-01,  7.2266e-01],
         [-2.6367e-01, -5.5859e-01,  1.3184e-01,  ...,  5.0391e-01,
           1.7734e+00, -1.7773e-01],
         [ 1.0312e+00, -1.5625e-01,  7.3438e-01,  ..., -1.5332e-01,
           6.8359e-01,  9.1016e-01]],

        [[-6.7234e-05,  2.8381e-03, -5.2795e-03,  ...,  2.2736e-03,
          -6.7444e-03,  6.7749e-03],
         [-7.3242e-02,  5.4688e-01,  8.6328e-01,  ...,  7.6953e-01,
           4.8633e-01,  3.1055e-01],
         [-6.3672e-01,  3.1445e-01,  2.5977e-01,  ..., -3.6865e-02,
           1.9165e-02, -8.3594e-01],
         ...,
         [-1.0303e-01,  9.1016e-01,  4.2383e-01,  ..., -8.9453e-01,
           5.4297e-01,  7.2266e-01],
         [-2.6367e-01, -5.5859e-01,  1.3184e-01,  ...,  5.0391e-01,
           1.7734e+00, -1.7773e-01],
         [ 1.0312e+00, -1.5625e-01,  7.3438e-01,  ..., -1.5332e-01,
           6.8359e-01,  9.1016e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 1.3245e-02, -6.1035e-04,  2.7618e-03,  ..., -1.0303e-01,
          -1.4941e-01,  1.9727e-01],
         [-1.1875e+00, -4.7070e-01,  2.4531e+00,  ...,  9.4141e-01,
           2.5195e-01,  1.8125e+00],
         [ 1.2451e-01,  2.1387e-01,  2.1973e-01,  ...,  1.1719e+00,
          -3.3398e-01, -2.8125e+00],
         ...,
         [ 5.8594e-03,  4.7266e-01,  6.2500e-01,  ...,  5.7031e-01,
           1.3203e+00, -1.0156e+00],
         [ 1.2695e-01, -2.5586e-01,  8.5938e-01,  ..., -3.5156e-01,
           1.4160e-01, -1.8906e+00],
         [ 2.7344e-01,  7.3438e-01,  1.5391e+00,  ...,  1.2578e+00,
           9.4922e-01,  7.7344e-01]],

        [[ 1.3245e-02, -6.1035e-04,  2.7618e-03,  ..., -1.0303e-01,
          -1.4941e-01,  1.9727e-01],
         [-1.1875e+00, -4.7070e-01,  2.4531e+00,  ...,  9.4141e-01,
           2.5195e-01,  1.8125e+00],
         [ 1.2451e-01,  2.1387e-01,  2.1973e-01,  ...,  1.1719e+00,
          -3.3398e-01, -2.8125e+00],
         ...,
         [ 5.8594e-03,  4.7266e-01,  6.2500e-01,  ...,  5.7031e-01,
           1.3203e+00, -1.0156e+00],
         [ 1.2695e-01, -2.5586e-01,  8.5938e-01,  ..., -3.5156e-01,
           1.4160e-01, -1.8906e+00],
         [ 2.7344e-01,  7.3438e-01,  1.5391e+00,  ...,  1.2578e+00,
           9.4922e-01,  7.7344e-01]],

        [[ 1.3245e-02, -6.1035e-04,  2.7618e-03,  ..., -1.0303e-01,
          -1.4941e-01,  1.9727e-01],
         [-1.1875e+00, -4.7070e-01,  2.4531e+00,  ...,  9.4141e-01,
           2.5195e-01,  1.8125e+00],
         [ 1.2451e-01,  2.1387e-01,  2.1973e-01,  ...,  1.1719e+00,
          -3.3398e-01, -2.8125e+00],
         ...,
         [ 5.8594e-03,  4.7266e-01,  6.2500e-01,  ...,  5.7031e-01,
           1.3203e+00, -1.0156e+00],
         [ 1.2695e-01, -2.5586e-01,  8.5938e-01,  ..., -3.5156e-01,
           1.4160e-01, -1.8906e+00],
         [ 2.7344e-01,  7.3438e-01,  1.5391e+00,  ...,  1.2578e+00,
           9.4922e-01,  7.7344e-01]],

        ...,

        [[ 7.3242e-03,  2.6398e-03,  3.4027e-03,  ..., -1.0352e-01,
           8.5938e-01, -3.6914e-01],
         [-1.3281e-01,  6.6016e-01,  1.0059e-01,  ...,  2.9531e+00,
          -3.9062e+00, -5.7031e-01],
         [-2.3828e-01,  6.2012e-02, -2.3535e-01,  ...,  1.9609e+00,
          -5.2500e+00,  1.8516e+00],
         ...,
         [-7.7344e-01, -4.3750e-01,  2.8125e-01,  ..., -1.9531e+00,
           6.1768e-02,  1.1406e+00],
         [-6.1523e-02, -5.1562e-01,  4.5117e-01,  ..., -8.5156e-01,
          -2.5156e+00,  1.6250e+00],
         [-2.0898e-01, -3.9453e-01,  8.7109e-01,  ...,  2.0000e+00,
          -3.0000e+00,  1.9629e-01]],

        [[ 7.3242e-03,  2.6398e-03,  3.4027e-03,  ..., -1.0352e-01,
           8.5938e-01, -3.6914e-01],
         [-1.3281e-01,  6.6016e-01,  1.0059e-01,  ...,  2.9531e+00,
          -3.9062e+00, -5.7031e-01],
         [-2.3828e-01,  6.2012e-02, -2.3535e-01,  ...,  1.9609e+00,
          -5.2500e+00,  1.8516e+00],
         ...,
         [-7.7344e-01, -4.3750e-01,  2.8125e-01,  ..., -1.9531e+00,
           6.1768e-02,  1.1406e+00],
         [-6.1523e-02, -5.1562e-01,  4.5117e-01,  ..., -8.5156e-01,
          -2.5156e+00,  1.6250e+00],
         [-2.0898e-01, -3.9453e-01,  8.7109e-01,  ...,  2.0000e+00,
          -3.0000e+00,  1.9629e-01]],

        [[ 7.3242e-03,  2.6398e-03,  3.4027e-03,  ..., -1.0352e-01,
           8.5938e-01, -3.6914e-01],
         [-1.3281e-01,  6.6016e-01,  1.0059e-01,  ...,  2.9531e+00,
          -3.9062e+00, -5.7031e-01],
         [-2.3828e-01,  6.2012e-02, -2.3535e-01,  ...,  1.9609e+00,
          -5.2500e+00,  1.8516e+00],
         ...,
         [-7.7344e-01, -4.3750e-01,  2.8125e-01,  ..., -1.9531e+00,
           6.1768e-02,  1.1406e+00],
         [-6.1523e-02, -5.1562e-01,  4.5117e-01,  ..., -8.5156e-01,
          -2.5156e+00,  1.6250e+00],
         [-2.0898e-01, -3.9453e-01,  8.7109e-01,  ...,  2.0000e+00,
          -3.0000e+00,  1.9629e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-7.5378e-03,  4.1504e-03,  5.4932e-03,  ..., -3.3569e-03,
           1.2024e-02,  3.7384e-03],
         [-1.2656e+00, -1.1250e+00, -8.9453e-01,  ...,  3.3008e-01,
          -6.0156e-01, -7.5391e-01],
         [-2.0156e+00, -9.3750e-01, -2.6245e-03,  ..., -3.0000e+00,
           9.2578e-01, -1.0234e+00],
         ...,
         [ 7.5391e-01, -1.2734e+00, -9.1406e-01,  ..., -3.1562e+00,
           1.1719e+00,  5.8203e-01],
         [ 8.2812e-01, -5.7031e-01, -1.7500e+00,  ..., -5.6875e+00,
           2.3594e+00, -1.2344e+00],
         [-3.9258e-01,  1.4801e-03, -1.1016e+00,  ..., -2.4219e+00,
          -6.8750e-01, -4.8047e-01]],

        [[-7.5378e-03,  4.1504e-03,  5.4932e-03,  ..., -3.3569e-03,
           1.2024e-02,  3.7384e-03],
         [-1.2656e+00, -1.1250e+00, -8.9453e-01,  ...,  3.3008e-01,
          -6.0156e-01, -7.5391e-01],
         [-2.0156e+00, -9.3750e-01, -2.6245e-03,  ..., -3.0000e+00,
           9.2578e-01, -1.0234e+00],
         ...,
         [ 7.5391e-01, -1.2734e+00, -9.1406e-01,  ..., -3.1562e+00,
           1.1719e+00,  5.8203e-01],
         [ 8.2812e-01, -5.7031e-01, -1.7500e+00,  ..., -5.6875e+00,
           2.3594e+00, -1.2344e+00],
         [-3.9258e-01,  1.4801e-03, -1.1016e+00,  ..., -2.4219e+00,
          -6.8750e-01, -4.8047e-01]],

        [[-7.5378e-03,  4.1504e-03,  5.4932e-03,  ..., -3.3569e-03,
           1.2024e-02,  3.7384e-03],
         [-1.2656e+00, -1.1250e+00, -8.9453e-01,  ...,  3.3008e-01,
          -6.0156e-01, -7.5391e-01],
         [-2.0156e+00, -9.3750e-01, -2.6245e-03,  ..., -3.0000e+00,
           9.2578e-01, -1.0234e+00],
         ...,
         [ 7.5391e-01, -1.2734e+00, -9.1406e-01,  ..., -3.1562e+00,
           1.1719e+00,  5.8203e-01],
         [ 8.2812e-01, -5.7031e-01, -1.7500e+00,  ..., -5.6875e+00,
           2.3594e+00, -1.2344e+00],
         [-3.9258e-01,  1.4801e-03, -1.1016e+00,  ..., -2.4219e+00,
          -6.8750e-01, -4.8047e-01]],

        ...,

        [[ 7.0190e-03, -1.6556e-03, -3.8605e-03,  ..., -3.7384e-03,
           2.1973e-03,  1.5564e-03],
         [-8.9844e-01, -1.7422e+00,  4.0312e+00,  ..., -1.5469e+00,
           9.5312e-01,  2.0625e+00],
         [ 1.4954e-03, -8.6719e-01,  2.4375e+00,  ..., -1.3672e+00,
          -6.4062e-01, -1.4453e-01],
         ...,
         [-2.2363e-01,  1.2344e+00, -1.8750e+00,  ...,  1.4038e-02,
           2.9297e-01,  1.0312e+00],
         [-5.4016e-03,  6.6797e-01, -5.9375e-01,  ...,  1.4453e-01,
           1.8262e-01,  8.2031e-01],
         [-1.1641e+00, -7.0312e-01,  1.0938e+00,  ..., -1.9531e+00,
           7.4609e-01,  4.9023e-01]],

        [[ 7.0190e-03, -1.6556e-03, -3.8605e-03,  ..., -3.7384e-03,
           2.1973e-03,  1.5564e-03],
         [-8.9844e-01, -1.7422e+00,  4.0312e+00,  ..., -1.5469e+00,
           9.5312e-01,  2.0625e+00],
         [ 1.4954e-03, -8.6719e-01,  2.4375e+00,  ..., -1.3672e+00,
          -6.4062e-01, -1.4453e-01],
         ...,
         [-2.2363e-01,  1.2344e+00, -1.8750e+00,  ...,  1.4038e-02,
           2.9297e-01,  1.0312e+00],
         [-5.4016e-03,  6.6797e-01, -5.9375e-01,  ...,  1.4453e-01,
           1.8262e-01,  8.2031e-01],
         [-1.1641e+00, -7.0312e-01,  1.0938e+00,  ..., -1.9531e+00,
           7.4609e-01,  4.9023e-01]],

        [[ 7.0190e-03, -1.6556e-03, -3.8605e-03,  ..., -3.7384e-03,
           2.1973e-03,  1.5564e-03],
         [-8.9844e-01, -1.7422e+00,  4.0312e+00,  ..., -1.5469e+00,
           9.5312e-01,  2.0625e+00],
         [ 1.4954e-03, -8.6719e-01,  2.4375e+00,  ..., -1.3672e+00,
          -6.4062e-01, -1.4453e-01],
         ...,
         [-2.2363e-01,  1.2344e+00, -1.8750e+00,  ...,  1.4038e-02,
           2.9297e-01,  1.0312e+00],
         [-5.4016e-03,  6.6797e-01, -5.9375e-01,  ...,  1.4453e-01,
           1.8262e-01,  8.2031e-01],
         [-1.1641e+00, -7.0312e-01,  1.0938e+00,  ..., -1.9531e+00,
           7.4609e-01,  4.9023e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 8.5449e-03,  1.2741e-03, -8.4229e-03,  ..., -1.3379e-01,
          -6.2866e-03,  7.2327e-03],
         [-1.0781e+00, -6.3281e-01, -2.5391e-01,  ...,  5.6641e-01,
          -1.0625e+00,  2.5156e+00],
         [ 1.4258e-01,  2.4609e-01, -9.9219e-01,  ...,  1.7344e+00,
           2.8125e-01,  3.3906e+00],
         ...,
         [-6.8750e-01,  7.4219e-01,  8.4766e-01,  ..., -2.7500e+00,
           8.9355e-02,  5.7031e-01],
         [-3.4180e-01,  1.0078e+00, -3.3203e-02,  ...,  2.0156e+00,
          -1.0234e+00,  1.7500e+00],
         [ 6.6406e-01,  1.1328e+00,  3.1836e-01,  ..., -5.0391e-01,
          -5.0781e-01,  3.5469e+00]],

        [[ 8.5449e-03,  1.2741e-03, -8.4229e-03,  ..., -1.3379e-01,
          -6.2866e-03,  7.2327e-03],
         [-1.0781e+00, -6.3281e-01, -2.5391e-01,  ...,  5.6641e-01,
          -1.0625e+00,  2.5156e+00],
         [ 1.4258e-01,  2.4609e-01, -9.9219e-01,  ...,  1.7344e+00,
           2.8125e-01,  3.3906e+00],
         ...,
         [-6.8750e-01,  7.4219e-01,  8.4766e-01,  ..., -2.7500e+00,
           8.9355e-02,  5.7031e-01],
         [-3.4180e-01,  1.0078e+00, -3.3203e-02,  ...,  2.0156e+00,
          -1.0234e+00,  1.7500e+00],
         [ 6.6406e-01,  1.1328e+00,  3.1836e-01,  ..., -5.0391e-01,
          -5.0781e-01,  3.5469e+00]],

        [[ 8.5449e-03,  1.2741e-03, -8.4229e-03,  ..., -1.3379e-01,
          -6.2866e-03,  7.2327e-03],
         [-1.0781e+00, -6.3281e-01, -2.5391e-01,  ...,  5.6641e-01,
          -1.0625e+00,  2.5156e+00],
         [ 1.4258e-01,  2.4609e-01, -9.9219e-01,  ...,  1.7344e+00,
           2.8125e-01,  3.3906e+00],
         ...,
         [-6.8750e-01,  7.4219e-01,  8.4766e-01,  ..., -2.7500e+00,
           8.9355e-02,  5.7031e-01],
         [-3.4180e-01,  1.0078e+00, -3.3203e-02,  ...,  2.0156e+00,
          -1.0234e+00,  1.7500e+00],
         [ 6.6406e-01,  1.1328e+00,  3.1836e-01,  ..., -5.0391e-01,
          -5.0781e-01,  3.5469e+00]],

        ...,

        [[ 4.4861e-03, -7.5073e-03, -7.8735e-03,  ..., -7.3242e-02,
          -5.5908e-02,  5.2490e-02],
         [ 1.3125e+00, -1.4844e-01,  1.0781e+00,  ..., -9.2969e-01,
           2.9883e-01,  5.6641e-01],
         [-1.1484e+00,  1.1875e+00,  1.0840e-01,  ...,  8.4766e-01,
           3.3789e-01,  1.4219e+00],
         ...,
         [ 1.2695e-02,  1.3750e+00, -9.3359e-01,  ..., -2.7344e+00,
          -1.7656e+00, -3.8281e+00],
         [-1.6016e-01,  6.6406e-01, -2.6367e-01,  ..., -2.9375e+00,
          -2.0000e+00, -2.0469e+00],
         [-1.6250e+00, -4.2969e-01,  3.0859e-01,  ..., -4.7070e-01,
           3.0469e-01,  2.5312e+00]],

        [[ 4.4861e-03, -7.5073e-03, -7.8735e-03,  ..., -7.3242e-02,
          -5.5908e-02,  5.2490e-02],
         [ 1.3125e+00, -1.4844e-01,  1.0781e+00,  ..., -9.2969e-01,
           2.9883e-01,  5.6641e-01],
         [-1.1484e+00,  1.1875e+00,  1.0840e-01,  ...,  8.4766e-01,
           3.3789e-01,  1.4219e+00],
         ...,
         [ 1.2695e-02,  1.3750e+00, -9.3359e-01,  ..., -2.7344e+00,
          -1.7656e+00, -3.8281e+00],
         [-1.6016e-01,  6.6406e-01, -2.6367e-01,  ..., -2.9375e+00,
          -2.0000e+00, -2.0469e+00],
         [-1.6250e+00, -4.2969e-01,  3.0859e-01,  ..., -4.7070e-01,
           3.0469e-01,  2.5312e+00]],

        [[ 4.4861e-03, -7.5073e-03, -7.8735e-03,  ..., -7.3242e-02,
          -5.5908e-02,  5.2490e-02],
         [ 1.3125e+00, -1.4844e-01,  1.0781e+00,  ..., -9.2969e-01,
           2.9883e-01,  5.6641e-01],
         [-1.1484e+00,  1.1875e+00,  1.0840e-01,  ...,  8.4766e-01,
           3.3789e-01,  1.4219e+00],
         ...,
         [ 1.2695e-02,  1.3750e+00, -9.3359e-01,  ..., -2.7344e+00,
          -1.7656e+00, -3.8281e+00],
         [-1.6016e-01,  6.6406e-01, -2.6367e-01,  ..., -2.9375e+00,
          -2.0000e+00, -2.0469e+00],
         [-1.6250e+00, -4.2969e-01,  3.0859e-01,  ..., -4.7070e-01,
           3.0469e-01,  2.5312e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 5.7068e-03,  9.1171e-04, -8.8501e-03,  ..., -4.3869e-04,
          -9.8877e-03, -9.2316e-04],
         [-5.3906e-01, -3.3008e-01,  1.1035e-01,  ...,  1.0791e-01,
          -5.3125e-01, -3.2422e-01],
         [-1.3245e-02, -2.9492e-01, -7.5000e-01,  ..., -4.2773e-01,
          -9.1016e-01,  1.9824e-01],
         ...,
         [ 7.1094e-01,  9.6680e-02, -9.3359e-01,  ...,  2.2266e-01,
          -1.2031e+00, -1.0156e+00],
         [-5.8105e-02,  4.6484e-01, -3.0859e-01,  ..., -5.5664e-02,
          -9.3359e-01,  7.7734e-01],
         [-6.2109e-01,  1.9922e+00, -4.1797e-01,  ..., -6.7969e-01,
           2.6367e-01,  7.7637e-02]],

        [[ 5.7068e-03,  9.1171e-04, -8.8501e-03,  ..., -4.3869e-04,
          -9.8877e-03, -9.2316e-04],
         [-5.3906e-01, -3.3008e-01,  1.1035e-01,  ...,  1.0791e-01,
          -5.3125e-01, -3.2422e-01],
         [-1.3245e-02, -2.9492e-01, -7.5000e-01,  ..., -4.2773e-01,
          -9.1016e-01,  1.9824e-01],
         ...,
         [ 7.1094e-01,  9.6680e-02, -9.3359e-01,  ...,  2.2266e-01,
          -1.2031e+00, -1.0156e+00],
         [-5.8105e-02,  4.6484e-01, -3.0859e-01,  ..., -5.5664e-02,
          -9.3359e-01,  7.7734e-01],
         [-6.2109e-01,  1.9922e+00, -4.1797e-01,  ..., -6.7969e-01,
           2.6367e-01,  7.7637e-02]],

        [[ 5.7068e-03,  9.1171e-04, -8.8501e-03,  ..., -4.3869e-04,
          -9.8877e-03, -9.2316e-04],
         [-5.3906e-01, -3.3008e-01,  1.1035e-01,  ...,  1.0791e-01,
          -5.3125e-01, -3.2422e-01],
         [-1.3245e-02, -2.9492e-01, -7.5000e-01,  ..., -4.2773e-01,
          -9.1016e-01,  1.9824e-01],
         ...,
         [ 7.1094e-01,  9.6680e-02, -9.3359e-01,  ...,  2.2266e-01,
          -1.2031e+00, -1.0156e+00],
         [-5.8105e-02,  4.6484e-01, -3.0859e-01,  ..., -5.5664e-02,
          -9.3359e-01,  7.7734e-01],
         [-6.2109e-01,  1.9922e+00, -4.1797e-01,  ..., -6.7969e-01,
           2.6367e-01,  7.7637e-02]],

        ...,

        [[ 3.2227e-02,  2.3682e-02,  8.5449e-02,  ...,  9.9487e-03,
          -4.2236e-02,  3.6377e-02],
         [-3.1836e-01, -1.2695e-01,  9.5312e-01,  ..., -5.3516e-01,
           2.8711e-01, -3.7305e-01],
         [ 2.5000e-01, -4.2725e-02,  1.6797e-01,  ...,  1.0254e-01,
          -1.0078e+00,  5.4199e-02],
         ...,
         [ 2.0874e-02, -1.2734e+00,  1.9453e+00,  ...,  5.7422e-01,
           1.8188e-02, -9.0625e-01],
         [-4.6289e-01, -8.9844e-01,  4.6875e-01,  ..., -9.7656e-02,
           2.9492e-01, -1.2109e+00],
         [ 4.5898e-01, -5.0781e-01,  2.1191e-01,  ...,  1.0078e+00,
           1.1328e+00,  3.3789e-01]],

        [[ 3.2227e-02,  2.3682e-02,  8.5449e-02,  ...,  9.9487e-03,
          -4.2236e-02,  3.6377e-02],
         [-3.1836e-01, -1.2695e-01,  9.5312e-01,  ..., -5.3516e-01,
           2.8711e-01, -3.7305e-01],
         [ 2.5000e-01, -4.2725e-02,  1.6797e-01,  ...,  1.0254e-01,
          -1.0078e+00,  5.4199e-02],
         ...,
         [ 2.0874e-02, -1.2734e+00,  1.9453e+00,  ...,  5.7422e-01,
           1.8188e-02, -9.0625e-01],
         [-4.6289e-01, -8.9844e-01,  4.6875e-01,  ..., -9.7656e-02,
           2.9492e-01, -1.2109e+00],
         [ 4.5898e-01, -5.0781e-01,  2.1191e-01,  ...,  1.0078e+00,
           1.1328e+00,  3.3789e-01]],

        [[ 3.2227e-02,  2.3682e-02,  8.5449e-02,  ...,  9.9487e-03,
          -4.2236e-02,  3.6377e-02],
         [-3.1836e-01, -1.2695e-01,  9.5312e-01,  ..., -5.3516e-01,
           2.8711e-01, -3.7305e-01],
         [ 2.5000e-01, -4.2725e-02,  1.6797e-01,  ...,  1.0254e-01,
          -1.0078e+00,  5.4199e-02],
         ...,
         [ 2.0874e-02, -1.2734e+00,  1.9453e+00,  ...,  5.7422e-01,
           1.8188e-02, -9.0625e-01],
         [-4.6289e-01, -8.9844e-01,  4.6875e-01,  ..., -9.7656e-02,
           2.9492e-01, -1.2109e+00],
         [ 4.5898e-01, -5.0781e-01,  2.1191e-01,  ...,  1.0078e+00,
           1.1328e+00,  3.3789e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 1.4893e-02,  5.1270e-03,  1.1780e-02,  ...,  1.9375e+00,
          -9.0820e-02,  3.8086e-02],
         [ 2.3438e-02,  2.0508e-01,  1.1572e-01,  ..., -6.1250e+00,
          -3.9648e-01,  3.4766e-01],
         [ 6.9336e-02, -4.5508e-01, -1.8555e-02,  ..., -7.1875e+00,
           5.3516e-01,  2.1406e+00],
         ...,
         [ 5.8594e-01, -3.9648e-01, -5.8350e-02,  ..., -4.1875e+00,
           1.4375e+00, -4.0938e+00],
         [ 1.3477e-01, -3.2812e-01,  5.7422e-01,  ..., -5.7812e+00,
           1.1797e+00, -5.6250e-01],
         [-1.3477e-01,  8.2520e-02,  3.6621e-02,  ..., -6.3125e+00,
          -5.8984e-01,  3.1836e-01]],

        [[ 1.4893e-02,  5.1270e-03,  1.1780e-02,  ...,  1.9375e+00,
          -9.0820e-02,  3.8086e-02],
         [ 2.3438e-02,  2.0508e-01,  1.1572e-01,  ..., -6.1250e+00,
          -3.9648e-01,  3.4766e-01],
         [ 6.9336e-02, -4.5508e-01, -1.8555e-02,  ..., -7.1875e+00,
           5.3516e-01,  2.1406e+00],
         ...,
         [ 5.8594e-01, -3.9648e-01, -5.8350e-02,  ..., -4.1875e+00,
           1.4375e+00, -4.0938e+00],
         [ 1.3477e-01, -3.2812e-01,  5.7422e-01,  ..., -5.7812e+00,
           1.1797e+00, -5.6250e-01],
         [-1.3477e-01,  8.2520e-02,  3.6621e-02,  ..., -6.3125e+00,
          -5.8984e-01,  3.1836e-01]],

        [[ 1.4893e-02,  5.1270e-03,  1.1780e-02,  ...,  1.9375e+00,
          -9.0820e-02,  3.8086e-02],
         [ 2.3438e-02,  2.0508e-01,  1.1572e-01,  ..., -6.1250e+00,
          -3.9648e-01,  3.4766e-01],
         [ 6.9336e-02, -4.5508e-01, -1.8555e-02,  ..., -7.1875e+00,
           5.3516e-01,  2.1406e+00],
         ...,
         [ 5.8594e-01, -3.9648e-01, -5.8350e-02,  ..., -4.1875e+00,
           1.4375e+00, -4.0938e+00],
         [ 1.3477e-01, -3.2812e-01,  5.7422e-01,  ..., -5.7812e+00,
           1.1797e+00, -5.6250e-01],
         [-1.3477e-01,  8.2520e-02,  3.6621e-02,  ..., -6.3125e+00,
          -5.8984e-01,  3.1836e-01]],

        ...,

        [[ 2.5787e-03,  3.2501e-03,  1.0620e-02,  ...,  6.8359e-02,
           9.8047e-01, -1.3281e-01],
         [-3.9062e-03,  4.6484e-01, -1.1250e+00,  ...,  4.6631e-02,
          -1.1953e+00, -3.4766e-01],
         [-4.2188e-01, -9.4531e-01, -8.1641e-01,  ...,  1.5332e-01,
          -4.5938e+00,  2.9688e-01],
         ...,
         [-3.3008e-01, -1.1797e+00,  1.0391e+00,  ...,  7.0312e-02,
          -2.2500e+00, -2.3560e-02],
         [-4.2578e-01,  9.8145e-02,  2.2461e-02,  ..., -3.5889e-02,
          -2.8125e+00, -2.8198e-02],
         [-1.7422e+00,  7.0312e-02, -1.5625e-01,  ..., -8.1641e-01,
          -2.9844e+00,  3.8086e-01]],

        [[ 2.5787e-03,  3.2501e-03,  1.0620e-02,  ...,  6.8359e-02,
           9.8047e-01, -1.3281e-01],
         [-3.9062e-03,  4.6484e-01, -1.1250e+00,  ...,  4.6631e-02,
          -1.1953e+00, -3.4766e-01],
         [-4.2188e-01, -9.4531e-01, -8.1641e-01,  ...,  1.5332e-01,
          -4.5938e+00,  2.9688e-01],
         ...,
         [-3.3008e-01, -1.1797e+00,  1.0391e+00,  ...,  7.0312e-02,
          -2.2500e+00, -2.3560e-02],
         [-4.2578e-01,  9.8145e-02,  2.2461e-02,  ..., -3.5889e-02,
          -2.8125e+00, -2.8198e-02],
         [-1.7422e+00,  7.0312e-02, -1.5625e-01,  ..., -8.1641e-01,
          -2.9844e+00,  3.8086e-01]],

        [[ 2.5787e-03,  3.2501e-03,  1.0620e-02,  ...,  6.8359e-02,
           9.8047e-01, -1.3281e-01],
         [-3.9062e-03,  4.6484e-01, -1.1250e+00,  ...,  4.6631e-02,
          -1.1953e+00, -3.4766e-01],
         [-4.2188e-01, -9.4531e-01, -8.1641e-01,  ...,  1.5332e-01,
          -4.5938e+00,  2.9688e-01],
         ...,
         [-3.3008e-01, -1.1797e+00,  1.0391e+00,  ...,  7.0312e-02,
          -2.2500e+00, -2.3560e-02],
         [-4.2578e-01,  9.8145e-02,  2.2461e-02,  ..., -3.5889e-02,
          -2.8125e+00, -2.8198e-02],
         [-1.7422e+00,  7.0312e-02, -1.5625e-01,  ..., -8.1641e-01,
          -2.9844e+00,  3.8086e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-0.0053,  0.0099, -0.0089,  ...,  0.0029, -0.0044, -0.0042],
         [ 2.7031, -0.5781,  0.9375,  ...,  2.6094, -1.6562, -0.2637],
         [ 0.5664, -1.2891,  0.1299,  ...,  0.7930, -0.0786, -1.3047],
         ...,
         [-1.0547, -0.8203, -0.3613,  ...,  0.1406,  0.5547, -0.5547],
         [-1.5312,  0.5430, -0.8750,  ..., -0.2227, -0.8008, -0.6992],
         [ 1.0469,  0.3750, -0.0276,  ..., -0.4746,  1.1250,  0.3066]],

        [[-0.0053,  0.0099, -0.0089,  ...,  0.0029, -0.0044, -0.0042],
         [ 2.7031, -0.5781,  0.9375,  ...,  2.6094, -1.6562, -0.2637],
         [ 0.5664, -1.2891,  0.1299,  ...,  0.7930, -0.0786, -1.3047],
         ...,
         [-1.0547, -0.8203, -0.3613,  ...,  0.1406,  0.5547, -0.5547],
         [-1.5312,  0.5430, -0.8750,  ..., -0.2227, -0.8008, -0.6992],
         [ 1.0469,  0.3750, -0.0276,  ..., -0.4746,  1.1250,  0.3066]],

        [[-0.0053,  0.0099, -0.0089,  ...,  0.0029, -0.0044, -0.0042],
         [ 2.7031, -0.5781,  0.9375,  ...,  2.6094, -1.6562, -0.2637],
         [ 0.5664, -1.2891,  0.1299,  ...,  0.7930, -0.0786, -1.3047],
         ...,
         [-1.0547, -0.8203, -0.3613,  ...,  0.1406,  0.5547, -0.5547],
         [-1.5312,  0.5430, -0.8750,  ..., -0.2227, -0.8008, -0.6992],
         [ 1.0469,  0.3750, -0.0276,  ..., -0.4746,  1.1250,  0.3066]],

        ...,

        [[-0.0347,  0.0099,  0.0165,  ..., -0.0503, -0.0281,  0.0081],
         [ 0.5117,  0.2988,  0.5977,  ...,  0.2715,  0.8438, -0.1855],
         [ 1.2891,  0.1729,  0.1484,  ...,  0.2363,  0.6758, -0.1934],
         ...,
         [ 0.9688,  0.0054,  0.2578,  ...,  0.7305, -0.4180, -0.0801],
         [ 1.8203,  0.5039,  0.0349,  ..., -0.1875, -0.0165, -0.0610],
         [ 0.6172,  0.1592, -0.6055,  ..., -0.0170, -0.2832, -0.6602]],

        [[-0.0347,  0.0099,  0.0165,  ..., -0.0503, -0.0281,  0.0081],
         [ 0.5117,  0.2988,  0.5977,  ...,  0.2715,  0.8438, -0.1855],
         [ 1.2891,  0.1729,  0.1484,  ...,  0.2363,  0.6758, -0.1934],
         ...,
         [ 0.9688,  0.0054,  0.2578,  ...,  0.7305, -0.4180, -0.0801],
         [ 1.8203,  0.5039,  0.0349,  ..., -0.1875, -0.0165, -0.0610],
         [ 0.6172,  0.1592, -0.6055,  ..., -0.0170, -0.2832, -0.6602]],

        [[-0.0347,  0.0099,  0.0165,  ..., -0.0503, -0.0281,  0.0081],
         [ 0.5117,  0.2988,  0.5977,  ...,  0.2715,  0.8438, -0.1855],
         [ 1.2891,  0.1729,  0.1484,  ...,  0.2363,  0.6758, -0.1934],
         ...,
         [ 0.9688,  0.0054,  0.2578,  ...,  0.7305, -0.4180, -0.0801],
         [ 1.8203,  0.5039,  0.0349,  ..., -0.1875, -0.0165, -0.0610],
         [ 0.6172,  0.1592, -0.6055,  ..., -0.0170, -0.2832, -0.6602]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[ 1.6113e-02,  5.1498e-04, -1.2268e-02,  ...,  2.3047e-01,
           6.2500e-02,  7.8906e-01],
         [ 1.0156e-01, -8.4766e-01,  2.4219e-01,  ...,  5.6641e-01,
          -2.0781e+00, -3.1562e+00],
         [ 1.3672e-01,  1.2793e-01,  9.2188e-01,  ..., -3.5938e-01,
          -1.2109e+00, -4.1875e+00],
         ...,
         [-4.8047e-01, -2.7344e-01, -4.8828e-01,  ..., -5.5859e-01,
          -4.6484e-01,  9.4727e-02],
         [ 2.6562e-01, -7.6172e-01,  2.9297e-01,  ..., -6.7969e-01,
           5.7812e-01, -1.2812e+00],
         [ 3.6328e-01,  2.2461e-01,  3.4375e-01,  ..., -6.7578e-01,
           1.6328e+00, -2.6875e+00]],

        [[ 1.6113e-02,  5.1498e-04, -1.2268e-02,  ...,  2.3047e-01,
           6.2500e-02,  7.8906e-01],
         [ 1.0156e-01, -8.4766e-01,  2.4219e-01,  ...,  5.6641e-01,
          -2.0781e+00, -3.1562e+00],
         [ 1.3672e-01,  1.2793e-01,  9.2188e-01,  ..., -3.5938e-01,
          -1.2109e+00, -4.1875e+00],
         ...,
         [-4.8047e-01, -2.7344e-01, -4.8828e-01,  ..., -5.5859e-01,
          -4.6484e-01,  9.4727e-02],
         [ 2.6562e-01, -7.6172e-01,  2.9297e-01,  ..., -6.7969e-01,
           5.7812e-01, -1.2812e+00],
         [ 3.6328e-01,  2.2461e-01,  3.4375e-01,  ..., -6.7578e-01,
           1.6328e+00, -2.6875e+00]],

        [[ 1.6113e-02,  5.1498e-04, -1.2268e-02,  ...,  2.3047e-01,
           6.2500e-02,  7.8906e-01],
         [ 1.0156e-01, -8.4766e-01,  2.4219e-01,  ...,  5.6641e-01,
          -2.0781e+00, -3.1562e+00],
         [ 1.3672e-01,  1.2793e-01,  9.2188e-01,  ..., -3.5938e-01,
          -1.2109e+00, -4.1875e+00],
         ...,
         [-4.8047e-01, -2.7344e-01, -4.8828e-01,  ..., -5.5859e-01,
          -4.6484e-01,  9.4727e-02],
         [ 2.6562e-01, -7.6172e-01,  2.9297e-01,  ..., -6.7969e-01,
           5.7812e-01, -1.2812e+00],
         [ 3.6328e-01,  2.2461e-01,  3.4375e-01,  ..., -6.7578e-01,
           1.6328e+00, -2.6875e+00]],

        ...,

        [[-6.3171e-03,  1.1536e-02, -4.2152e-04,  ..., -9.0820e-02,
          -2.0605e-01,  7.3730e-02],
         [ 5.0781e-02,  6.1523e-02,  7.6172e-02,  ...,  2.1562e+00,
           1.2988e-01, -1.7812e+00],
         [-2.1484e-01,  3.0273e-01,  4.9219e-01,  ...,  3.0938e+00,
          -1.2266e+00, -3.8594e+00],
         ...,
         [-9.6094e-01,  3.2227e-02,  2.9297e-02,  ..., -1.8906e+00,
          -1.9141e+00, -7.1777e-02],
         [-1.7578e-01, -2.8711e-01, -6.5430e-02,  ..., -3.0884e-02,
           4.5703e-01, -2.5000e-01],
         [ 1.0234e+00,  7.0801e-02, -8.7500e-01,  ...,  2.2188e+00,
           6.3281e-01,  8.6328e-01]],

        [[-6.3171e-03,  1.1536e-02, -4.2152e-04,  ..., -9.0820e-02,
          -2.0605e-01,  7.3730e-02],
         [ 5.0781e-02,  6.1523e-02,  7.6172e-02,  ...,  2.1562e+00,
           1.2988e-01, -1.7812e+00],
         [-2.1484e-01,  3.0273e-01,  4.9219e-01,  ...,  3.0938e+00,
          -1.2266e+00, -3.8594e+00],
         ...,
         [-9.6094e-01,  3.2227e-02,  2.9297e-02,  ..., -1.8906e+00,
          -1.9141e+00, -7.1777e-02],
         [-1.7578e-01, -2.8711e-01, -6.5430e-02,  ..., -3.0884e-02,
           4.5703e-01, -2.5000e-01],
         [ 1.0234e+00,  7.0801e-02, -8.7500e-01,  ...,  2.2188e+00,
           6.3281e-01,  8.6328e-01]],

        [[-6.3171e-03,  1.1536e-02, -4.2152e-04,  ..., -9.0820e-02,
          -2.0605e-01,  7.3730e-02],
         [ 5.0781e-02,  6.1523e-02,  7.6172e-02,  ...,  2.1562e+00,
           1.2988e-01, -1.7812e+00],
         [-2.1484e-01,  3.0273e-01,  4.9219e-01,  ...,  3.0938e+00,
          -1.2266e+00, -3.8594e+00],
         ...,
         [-9.6094e-01,  3.2227e-02,  2.9297e-02,  ..., -1.8906e+00,
          -1.9141e+00, -7.1777e-02],
         [-1.7578e-01, -2.8711e-01, -6.5430e-02,  ..., -3.0884e-02,
           4.5703e-01, -2.5000e-01],
         [ 1.0234e+00,  7.0801e-02, -8.7500e-01,  ...,  2.2188e+00,
           6.3281e-01,  8.6328e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 6.7444e-03, -1.3672e-02,  2.2278e-03,  ...,  3.4180e-03,
          -8.3008e-03, -3.9062e-03],
         [ 9.1309e-02,  9.2773e-02, -8.5156e-01,  ..., -1.7734e+00,
          -1.3203e+00,  1.1094e+00],
         [ 8.8281e-01, -3.5352e-01, -1.4258e-01,  ..., -5.0391e-01,
           6.4941e-02,  2.2656e-01],
         ...,
         [ 7.9102e-02, -4.5703e-01, -6.7578e-01,  ..., -4.8242e-01,
           9.5312e-01,  2.5977e-01],
         [-1.0469e+00, -1.8555e-02,  4.5703e-01,  ...,  4.4336e-01,
          -6.0547e-01, -3.1055e-01],
         [-4.3555e-01, -1.3047e+00,  2.1582e-01,  ...,  5.9326e-02,
           1.3516e+00,  3.2617e-01]],

        [[ 6.7444e-03, -1.3672e-02,  2.2278e-03,  ...,  3.4180e-03,
          -8.3008e-03, -3.9062e-03],
         [ 9.1309e-02,  9.2773e-02, -8.5156e-01,  ..., -1.7734e+00,
          -1.3203e+00,  1.1094e+00],
         [ 8.8281e-01, -3.5352e-01, -1.4258e-01,  ..., -5.0391e-01,
           6.4941e-02,  2.2656e-01],
         ...,
         [ 7.9102e-02, -4.5703e-01, -6.7578e-01,  ..., -4.8242e-01,
           9.5312e-01,  2.5977e-01],
         [-1.0469e+00, -1.8555e-02,  4.5703e-01,  ...,  4.4336e-01,
          -6.0547e-01, -3.1055e-01],
         [-4.3555e-01, -1.3047e+00,  2.1582e-01,  ...,  5.9326e-02,
           1.3516e+00,  3.2617e-01]],

        [[ 6.7444e-03, -1.3672e-02,  2.2278e-03,  ...,  3.4180e-03,
          -8.3008e-03, -3.9062e-03],
         [ 9.1309e-02,  9.2773e-02, -8.5156e-01,  ..., -1.7734e+00,
          -1.3203e+00,  1.1094e+00],
         [ 8.8281e-01, -3.5352e-01, -1.4258e-01,  ..., -5.0391e-01,
           6.4941e-02,  2.2656e-01],
         ...,
         [ 7.9102e-02, -4.5703e-01, -6.7578e-01,  ..., -4.8242e-01,
           9.5312e-01,  2.5977e-01],
         [-1.0469e+00, -1.8555e-02,  4.5703e-01,  ...,  4.4336e-01,
          -6.0547e-01, -3.1055e-01],
         [-4.3555e-01, -1.3047e+00,  2.1582e-01,  ...,  5.9326e-02,
           1.3516e+00,  3.2617e-01]],

        ...,

        [[-2.0142e-02,  5.4626e-03,  2.3499e-03,  ..., -1.0147e-03,
          -1.6113e-02,  8.1177e-03],
         [ 1.0078e+00,  2.0801e-01, -1.6250e+00,  ...,  2.9297e-01,
           7.2656e-01,  2.6094e+00],
         [ 8.9453e-01,  3.0156e+00, -7.5000e-01,  ...,  6.4453e-01,
           1.6602e-01,  2.5625e+00],
         ...,
         [-1.6895e-01,  5.7812e-01, -1.6016e-01,  ..., -1.2598e-01,
          -5.0781e-01,  1.9141e-01],
         [-5.3516e-01,  7.4219e-01, -9.5215e-02,  ...,  6.4453e-01,
           5.2734e-01, -1.5332e-01],
         [-9.6875e-01, -1.0234e+00, -1.8203e+00,  ..., -7.6953e-01,
           8.1250e-01, -1.8281e+00]],

        [[-2.0142e-02,  5.4626e-03,  2.3499e-03,  ..., -1.0147e-03,
          -1.6113e-02,  8.1177e-03],
         [ 1.0078e+00,  2.0801e-01, -1.6250e+00,  ...,  2.9297e-01,
           7.2656e-01,  2.6094e+00],
         [ 8.9453e-01,  3.0156e+00, -7.5000e-01,  ...,  6.4453e-01,
           1.6602e-01,  2.5625e+00],
         ...,
         [-1.6895e-01,  5.7812e-01, -1.6016e-01,  ..., -1.2598e-01,
          -5.0781e-01,  1.9141e-01],
         [-5.3516e-01,  7.4219e-01, -9.5215e-02,  ...,  6.4453e-01,
           5.2734e-01, -1.5332e-01],
         [-9.6875e-01, -1.0234e+00, -1.8203e+00,  ..., -7.6953e-01,
           8.1250e-01, -1.8281e+00]],

        [[-2.0142e-02,  5.4626e-03,  2.3499e-03,  ..., -1.0147e-03,
          -1.6113e-02,  8.1177e-03],
         [ 1.0078e+00,  2.0801e-01, -1.6250e+00,  ...,  2.9297e-01,
           7.2656e-01,  2.6094e+00],
         [ 8.9453e-01,  3.0156e+00, -7.5000e-01,  ...,  6.4453e-01,
           1.6602e-01,  2.5625e+00],
         ...,
         [-1.6895e-01,  5.7812e-01, -1.6016e-01,  ..., -1.2598e-01,
          -5.0781e-01,  1.9141e-01],
         [-5.3516e-01,  7.4219e-01, -9.5215e-02,  ...,  6.4453e-01,
           5.2734e-01, -1.5332e-01],
         [-9.6875e-01, -1.0234e+00, -1.8203e+00,  ..., -7.6953e-01,
           8.1250e-01, -1.8281e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 2.2583e-03,  1.4648e-03, -5.1270e-03,  ...,  8.6426e-02,
          -2.5586e-01, -1.3062e-02],
         [-2.4531e+00, -2.8516e-01, -9.7266e-01,  ..., -4.1797e-01,
          -6.9141e-01, -7.8516e-01],
         [-2.7031e+00, -9.8828e-01, -1.5234e+00,  ...,  1.9141e+00,
          -1.3672e+00,  1.1797e+00],
         ...,
         [-1.3750e+00, -7.4219e-01, -8.9062e-01,  ..., -1.7266e+00,
          -5.1562e-01, -7.1094e-01],
         [-1.0391e+00, -6.3672e-01, -9.1016e-01,  ..., -1.8438e+00,
          -1.3828e+00,  3.2617e-01],
         [-2.3730e-01, -3.7891e-01, -1.7500e+00,  ..., -1.2656e+00,
          -1.1250e+00, -2.7656e+00]],

        [[ 2.2583e-03,  1.4648e-03, -5.1270e-03,  ...,  8.6426e-02,
          -2.5586e-01, -1.3062e-02],
         [-2.4531e+00, -2.8516e-01, -9.7266e-01,  ..., -4.1797e-01,
          -6.9141e-01, -7.8516e-01],
         [-2.7031e+00, -9.8828e-01, -1.5234e+00,  ...,  1.9141e+00,
          -1.3672e+00,  1.1797e+00],
         ...,
         [-1.3750e+00, -7.4219e-01, -8.9062e-01,  ..., -1.7266e+00,
          -5.1562e-01, -7.1094e-01],
         [-1.0391e+00, -6.3672e-01, -9.1016e-01,  ..., -1.8438e+00,
          -1.3828e+00,  3.2617e-01],
         [-2.3730e-01, -3.7891e-01, -1.7500e+00,  ..., -1.2656e+00,
          -1.1250e+00, -2.7656e+00]],

        [[ 2.2583e-03,  1.4648e-03, -5.1270e-03,  ...,  8.6426e-02,
          -2.5586e-01, -1.3062e-02],
         [-2.4531e+00, -2.8516e-01, -9.7266e-01,  ..., -4.1797e-01,
          -6.9141e-01, -7.8516e-01],
         [-2.7031e+00, -9.8828e-01, -1.5234e+00,  ...,  1.9141e+00,
          -1.3672e+00,  1.1797e+00],
         ...,
         [-1.3750e+00, -7.4219e-01, -8.9062e-01,  ..., -1.7266e+00,
          -5.1562e-01, -7.1094e-01],
         [-1.0391e+00, -6.3672e-01, -9.1016e-01,  ..., -1.8438e+00,
          -1.3828e+00,  3.2617e-01],
         [-2.3730e-01, -3.7891e-01, -1.7500e+00,  ..., -1.2656e+00,
          -1.1250e+00, -2.7656e+00]],

        ...,

        [[-3.2349e-03,  5.8594e-03,  7.5378e-03,  ..., -8.3008e-02,
           6.6406e-02, -8.7891e-02],
         [-3.1250e+00,  9.4141e-01, -2.5781e-01,  ...,  6.3281e-01,
          -1.2969e+00,  9.1797e-02],
         [ 1.6406e-01, -5.3906e-01, -2.2188e+00,  ..., -1.6602e-01,
          -2.3906e+00,  1.1328e-01],
         ...,
         [-1.6875e+00,  6.4062e-01,  1.8906e+00,  ..., -8.6426e-02,
           7.7344e-01, -1.6406e+00],
         [-3.6719e-01,  1.5430e-01,  1.8438e+00,  ..., -1.1182e-01,
           2.2949e-02, -1.5469e+00],
         [ 1.8359e+00,  1.3828e+00,  5.3906e-01,  ..., -6.1719e-01,
          -5.9766e-01,  1.7456e-02]],

        [[-3.2349e-03,  5.8594e-03,  7.5378e-03,  ..., -8.3008e-02,
           6.6406e-02, -8.7891e-02],
         [-3.1250e+00,  9.4141e-01, -2.5781e-01,  ...,  6.3281e-01,
          -1.2969e+00,  9.1797e-02],
         [ 1.6406e-01, -5.3906e-01, -2.2188e+00,  ..., -1.6602e-01,
          -2.3906e+00,  1.1328e-01],
         ...,
         [-1.6875e+00,  6.4062e-01,  1.8906e+00,  ..., -8.6426e-02,
           7.7344e-01, -1.6406e+00],
         [-3.6719e-01,  1.5430e-01,  1.8438e+00,  ..., -1.1182e-01,
           2.2949e-02, -1.5469e+00],
         [ 1.8359e+00,  1.3828e+00,  5.3906e-01,  ..., -6.1719e-01,
          -5.9766e-01,  1.7456e-02]],

        [[-3.2349e-03,  5.8594e-03,  7.5378e-03,  ..., -8.3008e-02,
           6.6406e-02, -8.7891e-02],
         [-3.1250e+00,  9.4141e-01, -2.5781e-01,  ...,  6.3281e-01,
          -1.2969e+00,  9.1797e-02],
         [ 1.6406e-01, -5.3906e-01, -2.2188e+00,  ..., -1.6602e-01,
          -2.3906e+00,  1.1328e-01],
         ...,
         [-1.6875e+00,  6.4062e-01,  1.8906e+00,  ..., -8.6426e-02,
           7.7344e-01, -1.6406e+00],
         [-3.6719e-01,  1.5430e-01,  1.8438e+00,  ..., -1.1182e-01,
           2.2949e-02, -1.5469e+00],
         [ 1.8359e+00,  1.3828e+00,  5.3906e-01,  ..., -6.1719e-01,
          -5.9766e-01,  1.7456e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 5.5847e-03, -3.1738e-03,  1.5137e-02,  ...,  1.8692e-03,
           1.3733e-02,  1.8433e-02],
         [ 1.6992e-01, -1.4609e+00,  1.2012e-01,  ..., -3.7305e-01,
           1.3359e+00, -2.1191e-01],
         [ 3.2617e-01, -1.7090e-01,  1.1172e+00,  ...,  1.2656e+00,
          -4.8633e-01, -4.1602e-01],
         ...,
         [-1.3672e-01,  2.1777e-01,  7.2266e-01,  ..., -1.3062e-02,
          -1.6113e-01, -3.6914e-01],
         [ 6.2891e-01,  5.0000e-01, -7.6172e-02,  ...,  3.5352e-01,
           7.7734e-01, -3.6523e-01],
         [-8.6719e-01,  4.4922e-01, -5.7031e-01,  ..., -2.5781e-01,
          -6.9141e-01, -1.5469e+00]],

        [[ 5.5847e-03, -3.1738e-03,  1.5137e-02,  ...,  1.8692e-03,
           1.3733e-02,  1.8433e-02],
         [ 1.6992e-01, -1.4609e+00,  1.2012e-01,  ..., -3.7305e-01,
           1.3359e+00, -2.1191e-01],
         [ 3.2617e-01, -1.7090e-01,  1.1172e+00,  ...,  1.2656e+00,
          -4.8633e-01, -4.1602e-01],
         ...,
         [-1.3672e-01,  2.1777e-01,  7.2266e-01,  ..., -1.3062e-02,
          -1.6113e-01, -3.6914e-01],
         [ 6.2891e-01,  5.0000e-01, -7.6172e-02,  ...,  3.5352e-01,
           7.7734e-01, -3.6523e-01],
         [-8.6719e-01,  4.4922e-01, -5.7031e-01,  ..., -2.5781e-01,
          -6.9141e-01, -1.5469e+00]],

        [[ 5.5847e-03, -3.1738e-03,  1.5137e-02,  ...,  1.8692e-03,
           1.3733e-02,  1.8433e-02],
         [ 1.6992e-01, -1.4609e+00,  1.2012e-01,  ..., -3.7305e-01,
           1.3359e+00, -2.1191e-01],
         [ 3.2617e-01, -1.7090e-01,  1.1172e+00,  ...,  1.2656e+00,
          -4.8633e-01, -4.1602e-01],
         ...,
         [-1.3672e-01,  2.1777e-01,  7.2266e-01,  ..., -1.3062e-02,
          -1.6113e-01, -3.6914e-01],
         [ 6.2891e-01,  5.0000e-01, -7.6172e-02,  ...,  3.5352e-01,
           7.7734e-01, -3.6523e-01],
         [-8.6719e-01,  4.4922e-01, -5.7031e-01,  ..., -2.5781e-01,
          -6.9141e-01, -1.5469e+00]],

        ...,

        [[ 3.8910e-03, -1.4572e-03, -4.1809e-03,  ...,  1.4282e-02,
          -5.4321e-03,  6.4392e-03],
         [ 1.1250e+00, -6.4453e-01, -1.4551e-01,  ..., -6.5430e-02,
           6.2891e-01, -6.4453e-01],
         [-3.3398e-01, -1.8457e-01,  5.4688e-01,  ...,  8.7500e-01,
           2.1191e-01,  1.0781e+00],
         ...,
         [-2.0215e-01, -4.0625e-01,  1.2578e+00,  ...,  4.1797e-01,
           1.7969e+00, -8.7891e-02],
         [ 3.3203e-01,  8.5938e-01, -5.9766e-01,  ..., -7.1484e-01,
          -9.8438e-01, -5.3906e-01],
         [ 1.4922e+00, -2.6367e-01,  9.3262e-02,  ...,  8.6719e-01,
          -4.4336e-01,  7.1875e-01]],

        [[ 3.8910e-03, -1.4572e-03, -4.1809e-03,  ...,  1.4282e-02,
          -5.4321e-03,  6.4392e-03],
         [ 1.1250e+00, -6.4453e-01, -1.4551e-01,  ..., -6.5430e-02,
           6.2891e-01, -6.4453e-01],
         [-3.3398e-01, -1.8457e-01,  5.4688e-01,  ...,  8.7500e-01,
           2.1191e-01,  1.0781e+00],
         ...,
         [-2.0215e-01, -4.0625e-01,  1.2578e+00,  ...,  4.1797e-01,
           1.7969e+00, -8.7891e-02],
         [ 3.3203e-01,  8.5938e-01, -5.9766e-01,  ..., -7.1484e-01,
          -9.8438e-01, -5.3906e-01],
         [ 1.4922e+00, -2.6367e-01,  9.3262e-02,  ...,  8.6719e-01,
          -4.4336e-01,  7.1875e-01]],

        [[ 3.8910e-03, -1.4572e-03, -4.1809e-03,  ...,  1.4282e-02,
          -5.4321e-03,  6.4392e-03],
         [ 1.1250e+00, -6.4453e-01, -1.4551e-01,  ..., -6.5430e-02,
           6.2891e-01, -6.4453e-01],
         [-3.3398e-01, -1.8457e-01,  5.4688e-01,  ...,  8.7500e-01,
           2.1191e-01,  1.0781e+00],
         ...,
         [-2.0215e-01, -4.0625e-01,  1.2578e+00,  ...,  4.1797e-01,
           1.7969e+00, -8.7891e-02],
         [ 3.3203e-01,  8.5938e-01, -5.9766e-01,  ..., -7.1484e-01,
          -9.8438e-01, -5.3906e-01],
         [ 1.4922e+00, -2.6367e-01,  9.3262e-02,  ...,  8.6719e-01,
          -4.4336e-01,  7.1875e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 9.3994e-03, -1.4465e-02, -9.5825e-03,  ..., -1.4551e-01,
           5.5469e-01,  2.6562e+00],
         [ 9.4922e-01, -3.2812e-01,  2.4414e-02,  ..., -9.5312e-01,
          -3.2422e-01, -6.4062e+00],
         [-4.3750e-01, -2.2070e-01,  1.6016e-01,  ...,  1.9453e+00,
          -2.6978e-02, -9.3125e+00],
         ...,
         [-6.6895e-02, -3.8477e-01, -2.7344e-01,  ...,  3.0273e-01,
          -1.2988e-01, -6.1875e+00],
         [ 1.6016e-01, -1.3867e-01, -2.9883e-01,  ...,  2.9102e-01,
          -1.1016e+00, -4.7500e+00],
         [ 2.0801e-01, -3.7109e-01,  8.1250e-01,  ...,  2.9844e+00,
           4.0771e-02, -7.2500e+00]],

        [[ 9.3994e-03, -1.4465e-02, -9.5825e-03,  ..., -1.4551e-01,
           5.5469e-01,  2.6562e+00],
         [ 9.4922e-01, -3.2812e-01,  2.4414e-02,  ..., -9.5312e-01,
          -3.2422e-01, -6.4062e+00],
         [-4.3750e-01, -2.2070e-01,  1.6016e-01,  ...,  1.9453e+00,
          -2.6978e-02, -9.3125e+00],
         ...,
         [-6.6895e-02, -3.8477e-01, -2.7344e-01,  ...,  3.0273e-01,
          -1.2988e-01, -6.1875e+00],
         [ 1.6016e-01, -1.3867e-01, -2.9883e-01,  ...,  2.9102e-01,
          -1.1016e+00, -4.7500e+00],
         [ 2.0801e-01, -3.7109e-01,  8.1250e-01,  ...,  2.9844e+00,
           4.0771e-02, -7.2500e+00]],

        [[ 9.3994e-03, -1.4465e-02, -9.5825e-03,  ..., -1.4551e-01,
           5.5469e-01,  2.6562e+00],
         [ 9.4922e-01, -3.2812e-01,  2.4414e-02,  ..., -9.5312e-01,
          -3.2422e-01, -6.4062e+00],
         [-4.3750e-01, -2.2070e-01,  1.6016e-01,  ...,  1.9453e+00,
          -2.6978e-02, -9.3125e+00],
         ...,
         [-6.6895e-02, -3.8477e-01, -2.7344e-01,  ...,  3.0273e-01,
          -1.2988e-01, -6.1875e+00],
         [ 1.6016e-01, -1.3867e-01, -2.9883e-01,  ...,  2.9102e-01,
          -1.1016e+00, -4.7500e+00],
         [ 2.0801e-01, -3.7109e-01,  8.1250e-01,  ...,  2.9844e+00,
           4.0771e-02, -7.2500e+00]],

        ...,

        [[-2.3926e-02, -3.6240e-04, -4.7607e-03,  ...,  2.4048e-02,
           2.1875e+00,  3.8086e-01],
         [ 1.6250e+00, -7.9297e-01, -4.3359e-01,  ..., -3.8086e-01,
          -4.6250e+00,  2.2852e-01],
         [ 7.5195e-02, -5.9766e-01,  7.3828e-01,  ...,  9.8047e-01,
          -3.7500e+00,  1.2256e-01],
         ...,
         [-4.1211e-01, -7.4219e-01,  1.6113e-01,  ...,  2.2344e+00,
          -3.8125e+00,  9.4922e-01],
         [-3.1250e-01, -2.9492e-01, -5.8105e-02,  ...,  9.8438e-01,
          -4.7188e+00,  7.6172e-01],
         [-1.5234e+00, -9.0625e-01,  2.8711e-01,  ...,  2.0469e+00,
          -3.7812e+00,  1.7500e+00]],

        [[-2.3926e-02, -3.6240e-04, -4.7607e-03,  ...,  2.4048e-02,
           2.1875e+00,  3.8086e-01],
         [ 1.6250e+00, -7.9297e-01, -4.3359e-01,  ..., -3.8086e-01,
          -4.6250e+00,  2.2852e-01],
         [ 7.5195e-02, -5.9766e-01,  7.3828e-01,  ...,  9.8047e-01,
          -3.7500e+00,  1.2256e-01],
         ...,
         [-4.1211e-01, -7.4219e-01,  1.6113e-01,  ...,  2.2344e+00,
          -3.8125e+00,  9.4922e-01],
         [-3.1250e-01, -2.9492e-01, -5.8105e-02,  ...,  9.8438e-01,
          -4.7188e+00,  7.6172e-01],
         [-1.5234e+00, -9.0625e-01,  2.8711e-01,  ...,  2.0469e+00,
          -3.7812e+00,  1.7500e+00]],

        [[-2.3926e-02, -3.6240e-04, -4.7607e-03,  ...,  2.4048e-02,
           2.1875e+00,  3.8086e-01],
         [ 1.6250e+00, -7.9297e-01, -4.3359e-01,  ..., -3.8086e-01,
          -4.6250e+00,  2.2852e-01],
         [ 7.5195e-02, -5.9766e-01,  7.3828e-01,  ...,  9.8047e-01,
          -3.7500e+00,  1.2256e-01],
         ...,
         [-4.1211e-01, -7.4219e-01,  1.6113e-01,  ...,  2.2344e+00,
          -3.8125e+00,  9.4922e-01],
         [-3.1250e-01, -2.9492e-01, -5.8105e-02,  ...,  9.8438e-01,
          -4.7188e+00,  7.6172e-01],
         [-1.5234e+00, -9.0625e-01,  2.8711e-01,  ...,  2.0469e+00,
          -3.7812e+00,  1.7500e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 1.2268e-02, -1.1047e-02, -2.2949e-02,  ...,  1.1963e-02,
          -4.4861e-03,  2.0508e-02],
         [-3.2031e-01, -8.4766e-01,  5.0391e-01,  ..., -1.3184e-01,
           2.0801e-01, -1.2188e+00],
         [-5.3711e-02,  1.8281e+00,  2.7656e+00,  ..., -1.2256e-01,
           2.0000e+00,  1.0234e+00],
         ...,
         [-1.8203e+00,  1.4297e+00,  3.7305e-01,  ...,  1.6328e+00,
           1.0703e+00,  2.3906e+00],
         [-3.9844e-01,  4.4531e-01,  9.3750e-01,  ...,  1.8750e-01,
           1.4141e+00,  2.0625e+00],
         [-9.2188e-01,  2.4414e-01, -8.1641e-01,  ..., -8.8281e-01,
          -8.3008e-02, -7.6562e-01]],

        [[ 1.2268e-02, -1.1047e-02, -2.2949e-02,  ...,  1.1963e-02,
          -4.4861e-03,  2.0508e-02],
         [-3.2031e-01, -8.4766e-01,  5.0391e-01,  ..., -1.3184e-01,
           2.0801e-01, -1.2188e+00],
         [-5.3711e-02,  1.8281e+00,  2.7656e+00,  ..., -1.2256e-01,
           2.0000e+00,  1.0234e+00],
         ...,
         [-1.8203e+00,  1.4297e+00,  3.7305e-01,  ...,  1.6328e+00,
           1.0703e+00,  2.3906e+00],
         [-3.9844e-01,  4.4531e-01,  9.3750e-01,  ...,  1.8750e-01,
           1.4141e+00,  2.0625e+00],
         [-9.2188e-01,  2.4414e-01, -8.1641e-01,  ..., -8.8281e-01,
          -8.3008e-02, -7.6562e-01]],

        [[ 1.2268e-02, -1.1047e-02, -2.2949e-02,  ...,  1.1963e-02,
          -4.4861e-03,  2.0508e-02],
         [-3.2031e-01, -8.4766e-01,  5.0391e-01,  ..., -1.3184e-01,
           2.0801e-01, -1.2188e+00],
         [-5.3711e-02,  1.8281e+00,  2.7656e+00,  ..., -1.2256e-01,
           2.0000e+00,  1.0234e+00],
         ...,
         [-1.8203e+00,  1.4297e+00,  3.7305e-01,  ...,  1.6328e+00,
           1.0703e+00,  2.3906e+00],
         [-3.9844e-01,  4.4531e-01,  9.3750e-01,  ...,  1.8750e-01,
           1.4141e+00,  2.0625e+00],
         [-9.2188e-01,  2.4414e-01, -8.1641e-01,  ..., -8.8281e-01,
          -8.3008e-02, -7.6562e-01]],

        ...,

        [[ 4.6082e-03, -1.6022e-04, -2.8687e-02,  ..., -4.7607e-03,
          -8.6060e-03, -6.3705e-04],
         [-1.5938e+00, -1.2305e-01,  2.7148e-01,  ..., -1.8555e-02,
          -4.1016e-01,  4.1406e-01],
         [-1.4062e+00,  6.1328e-01, -2.0605e-01,  ..., -3.7305e-01,
           5.9375e-01,  4.0625e-01],
         ...,
         [-3.4375e-01,  1.2969e+00,  1.0625e+00,  ..., -6.4941e-02,
          -5.7031e-01, -1.8906e+00],
         [ 3.0859e-01,  2.1484e-01,  7.1094e-01,  ..., -7.7734e-01,
           8.3203e-01,  7.1777e-02],
         [ 8.0859e-01,  2.4512e-01,  6.4844e-01,  ..., -1.0000e+00,
          -1.9629e-01, -1.4297e+00]],

        [[ 4.6082e-03, -1.6022e-04, -2.8687e-02,  ..., -4.7607e-03,
          -8.6060e-03, -6.3705e-04],
         [-1.5938e+00, -1.2305e-01,  2.7148e-01,  ..., -1.8555e-02,
          -4.1016e-01,  4.1406e-01],
         [-1.4062e+00,  6.1328e-01, -2.0605e-01,  ..., -3.7305e-01,
           5.9375e-01,  4.0625e-01],
         ...,
         [-3.4375e-01,  1.2969e+00,  1.0625e+00,  ..., -6.4941e-02,
          -5.7031e-01, -1.8906e+00],
         [ 3.0859e-01,  2.1484e-01,  7.1094e-01,  ..., -7.7734e-01,
           8.3203e-01,  7.1777e-02],
         [ 8.0859e-01,  2.4512e-01,  6.4844e-01,  ..., -1.0000e+00,
          -1.9629e-01, -1.4297e+00]],

        [[ 4.6082e-03, -1.6022e-04, -2.8687e-02,  ..., -4.7607e-03,
          -8.6060e-03, -6.3705e-04],
         [-1.5938e+00, -1.2305e-01,  2.7148e-01,  ..., -1.8555e-02,
          -4.1016e-01,  4.1406e-01],
         [-1.4062e+00,  6.1328e-01, -2.0605e-01,  ..., -3.7305e-01,
           5.9375e-01,  4.0625e-01],
         ...,
         [-3.4375e-01,  1.2969e+00,  1.0625e+00,  ..., -6.4941e-02,
          -5.7031e-01, -1.8906e+00],
         [ 3.0859e-01,  2.1484e-01,  7.1094e-01,  ..., -7.7734e-01,
           8.3203e-01,  7.1777e-02],
         [ 8.0859e-01,  2.4512e-01,  6.4844e-01,  ..., -1.0000e+00,
          -1.9629e-01, -1.4297e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 7.6904e-03, -4.5166e-03,  4.1809e-03,  ...,  7.6599e-03,
          -2.2754e-01,  7.9102e-02],
         [-2.2031e+00,  9.3750e-02, -2.4512e-01,  ...,  1.4609e+00,
           7.2266e-01, -8.1641e-01],
         [-1.9062e+00, -5.0781e-01,  7.1289e-02,  ...,  1.2031e+00,
           1.2344e+00, -8.0078e-01],
         ...,
         [-8.5547e-01,  3.1250e-02, -8.2031e-01,  ...,  3.7969e+00,
          -1.9727e-01,  1.7656e+00],
         [-1.0625e+00, -1.1406e+00, -5.3516e-01,  ...,  4.0312e+00,
          -9.2188e-01,  4.9805e-01],
         [-1.1953e+00, -3.3984e-01, -1.0234e+00,  ...,  1.4297e+00,
           2.4062e+00,  2.4023e-01]],

        [[ 7.6904e-03, -4.5166e-03,  4.1809e-03,  ...,  7.6599e-03,
          -2.2754e-01,  7.9102e-02],
         [-2.2031e+00,  9.3750e-02, -2.4512e-01,  ...,  1.4609e+00,
           7.2266e-01, -8.1641e-01],
         [-1.9062e+00, -5.0781e-01,  7.1289e-02,  ...,  1.2031e+00,
           1.2344e+00, -8.0078e-01],
         ...,
         [-8.5547e-01,  3.1250e-02, -8.2031e-01,  ...,  3.7969e+00,
          -1.9727e-01,  1.7656e+00],
         [-1.0625e+00, -1.1406e+00, -5.3516e-01,  ...,  4.0312e+00,
          -9.2188e-01,  4.9805e-01],
         [-1.1953e+00, -3.3984e-01, -1.0234e+00,  ...,  1.4297e+00,
           2.4062e+00,  2.4023e-01]],

        [[ 7.6904e-03, -4.5166e-03,  4.1809e-03,  ...,  7.6599e-03,
          -2.2754e-01,  7.9102e-02],
         [-2.2031e+00,  9.3750e-02, -2.4512e-01,  ...,  1.4609e+00,
           7.2266e-01, -8.1641e-01],
         [-1.9062e+00, -5.0781e-01,  7.1289e-02,  ...,  1.2031e+00,
           1.2344e+00, -8.0078e-01],
         ...,
         [-8.5547e-01,  3.1250e-02, -8.2031e-01,  ...,  3.7969e+00,
          -1.9727e-01,  1.7656e+00],
         [-1.0625e+00, -1.1406e+00, -5.3516e-01,  ...,  4.0312e+00,
          -9.2188e-01,  4.9805e-01],
         [-1.1953e+00, -3.3984e-01, -1.0234e+00,  ...,  1.4297e+00,
           2.4062e+00,  2.4023e-01]],

        ...,

        [[ 4.9744e-03, -6.3171e-03,  2.1076e-04,  ..., -7.6172e-02,
           2.6562e-01,  1.2354e-01],
         [ 1.2969e+00,  1.5625e+00,  1.3906e+00,  ...,  1.0703e+00,
           5.5078e-01,  1.7812e+00],
         [ 4.7266e-01, -8.3984e-01,  1.8750e+00,  ..., -9.1406e-01,
           1.4453e+00,  4.1992e-01],
         ...,
         [ 9.0625e-01, -9.6875e-01, -6.8848e-02,  ..., -1.6406e-01,
           2.1094e+00, -1.0000e+00],
         [-2.7832e-02, -2.7539e-01, -1.5137e-01,  ..., -2.0625e+00,
           1.6641e+00, -2.5156e+00],
         [-6.0547e-01, -4.1797e-01,  1.6953e+00,  ..., -3.7891e-01,
           1.7031e+00, -2.8906e-01]],

        [[ 4.9744e-03, -6.3171e-03,  2.1076e-04,  ..., -7.6172e-02,
           2.6562e-01,  1.2354e-01],
         [ 1.2969e+00,  1.5625e+00,  1.3906e+00,  ...,  1.0703e+00,
           5.5078e-01,  1.7812e+00],
         [ 4.7266e-01, -8.3984e-01,  1.8750e+00,  ..., -9.1406e-01,
           1.4453e+00,  4.1992e-01],
         ...,
         [ 9.0625e-01, -9.6875e-01, -6.8848e-02,  ..., -1.6406e-01,
           2.1094e+00, -1.0000e+00],
         [-2.7832e-02, -2.7539e-01, -1.5137e-01,  ..., -2.0625e+00,
           1.6641e+00, -2.5156e+00],
         [-6.0547e-01, -4.1797e-01,  1.6953e+00,  ..., -3.7891e-01,
           1.7031e+00, -2.8906e-01]],

        [[ 4.9744e-03, -6.3171e-03,  2.1076e-04,  ..., -7.6172e-02,
           2.6562e-01,  1.2354e-01],
         [ 1.2969e+00,  1.5625e+00,  1.3906e+00,  ...,  1.0703e+00,
           5.5078e-01,  1.7812e+00],
         [ 4.7266e-01, -8.3984e-01,  1.8750e+00,  ..., -9.1406e-01,
           1.4453e+00,  4.1992e-01],
         ...,
         [ 9.0625e-01, -9.6875e-01, -6.8848e-02,  ..., -1.6406e-01,
           2.1094e+00, -1.0000e+00],
         [-2.7832e-02, -2.7539e-01, -1.5137e-01,  ..., -2.0625e+00,
           1.6641e+00, -2.5156e+00],
         [-6.0547e-01, -4.1797e-01,  1.6953e+00,  ..., -3.7891e-01,
           1.7031e+00, -2.8906e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-4.2725e-03,  5.8594e-03,  1.6937e-03,  ...,  5.0049e-03,
          -1.9775e-02, -1.0620e-02],
         [ 3.3203e-01,  6.7969e-01,  8.1641e-01,  ...,  3.4375e-01,
          -1.2207e-01, -5.8984e-01],
         [-2.0215e-01,  7.8516e-01, -8.2812e-01,  ...,  1.2969e+00,
          -4.8047e-01, -6.5234e-01],
         ...,
         [-1.4531e+00,  1.0703e+00,  7.7148e-02,  ...,  2.0156e+00,
           3.0938e+00, -8.3203e-01],
         [-1.5938e+00,  6.0156e-01,  1.6094e+00,  ...,  9.5312e-01,
           2.8125e-01, -1.8438e+00],
         [-3.6719e-01, -1.6992e-01, -8.1250e-01,  ..., -1.0469e+00,
          -1.6895e-01, -4.2188e-01]],

        [[-4.2725e-03,  5.8594e-03,  1.6937e-03,  ...,  5.0049e-03,
          -1.9775e-02, -1.0620e-02],
         [ 3.3203e-01,  6.7969e-01,  8.1641e-01,  ...,  3.4375e-01,
          -1.2207e-01, -5.8984e-01],
         [-2.0215e-01,  7.8516e-01, -8.2812e-01,  ...,  1.2969e+00,
          -4.8047e-01, -6.5234e-01],
         ...,
         [-1.4531e+00,  1.0703e+00,  7.7148e-02,  ...,  2.0156e+00,
           3.0938e+00, -8.3203e-01],
         [-1.5938e+00,  6.0156e-01,  1.6094e+00,  ...,  9.5312e-01,
           2.8125e-01, -1.8438e+00],
         [-3.6719e-01, -1.6992e-01, -8.1250e-01,  ..., -1.0469e+00,
          -1.6895e-01, -4.2188e-01]],

        [[-4.2725e-03,  5.8594e-03,  1.6937e-03,  ...,  5.0049e-03,
          -1.9775e-02, -1.0620e-02],
         [ 3.3203e-01,  6.7969e-01,  8.1641e-01,  ...,  3.4375e-01,
          -1.2207e-01, -5.8984e-01],
         [-2.0215e-01,  7.8516e-01, -8.2812e-01,  ...,  1.2969e+00,
          -4.8047e-01, -6.5234e-01],
         ...,
         [-1.4531e+00,  1.0703e+00,  7.7148e-02,  ...,  2.0156e+00,
           3.0938e+00, -8.3203e-01],
         [-1.5938e+00,  6.0156e-01,  1.6094e+00,  ...,  9.5312e-01,
           2.8125e-01, -1.8438e+00],
         [-3.6719e-01, -1.6992e-01, -8.1250e-01,  ..., -1.0469e+00,
          -1.6895e-01, -4.2188e-01]],

        ...,

        [[ 2.7588e-02, -9.9487e-03, -1.7700e-02,  ..., -6.4087e-03,
           1.4038e-03,  5.3711e-03],
         [ 2.7930e-01, -2.5977e-01,  7.3047e-01,  ...,  8.4961e-02,
           3.4570e-01, -1.5332e-01],
         [-8.1177e-03,  1.7578e+00, -4.1016e-01,  ...,  3.9062e-01,
           3.3398e-01,  4.3750e-01],
         ...,
         [-2.8125e-01, -3.3789e-01,  4.7266e-01,  ...,  3.8477e-01,
          -9.4141e-01,  2.2949e-01],
         [ 9.4531e-01, -1.5312e+00,  1.1797e+00,  ...,  3.3906e+00,
           1.2266e+00, -4.6484e-01],
         [ 4.1211e-01, -7.2656e-01,  4.3359e-01,  ...,  8.6060e-03,
          -8.3594e-01,  2.0020e-01]],

        [[ 2.7588e-02, -9.9487e-03, -1.7700e-02,  ..., -6.4087e-03,
           1.4038e-03,  5.3711e-03],
         [ 2.7930e-01, -2.5977e-01,  7.3047e-01,  ...,  8.4961e-02,
           3.4570e-01, -1.5332e-01],
         [-8.1177e-03,  1.7578e+00, -4.1016e-01,  ...,  3.9062e-01,
           3.3398e-01,  4.3750e-01],
         ...,
         [-2.8125e-01, -3.3789e-01,  4.7266e-01,  ...,  3.8477e-01,
          -9.4141e-01,  2.2949e-01],
         [ 9.4531e-01, -1.5312e+00,  1.1797e+00,  ...,  3.3906e+00,
           1.2266e+00, -4.6484e-01],
         [ 4.1211e-01, -7.2656e-01,  4.3359e-01,  ...,  8.6060e-03,
          -8.3594e-01,  2.0020e-01]],

        [[ 2.7588e-02, -9.9487e-03, -1.7700e-02,  ..., -6.4087e-03,
           1.4038e-03,  5.3711e-03],
         [ 2.7930e-01, -2.5977e-01,  7.3047e-01,  ...,  8.4961e-02,
           3.4570e-01, -1.5332e-01],
         [-8.1177e-03,  1.7578e+00, -4.1016e-01,  ...,  3.9062e-01,
           3.3398e-01,  4.3750e-01],
         ...,
         [-2.8125e-01, -3.3789e-01,  4.7266e-01,  ...,  3.8477e-01,
          -9.4141e-01,  2.2949e-01],
         [ 9.4531e-01, -1.5312e+00,  1.1797e+00,  ...,  3.3906e+00,
           1.2266e+00, -4.6484e-01],
         [ 4.1211e-01, -7.2656e-01,  4.3359e-01,  ...,  8.6060e-03,
          -8.3594e-01,  2.0020e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-2.3346e-03,  9.5215e-03, -9.2163e-03,  ..., -1.6602e-01,
          -1.3672e-01,  2.4414e-01],
         [ 2.0312e+00, -6.0156e-01, -2.6953e-01,  ...,  1.7734e+00,
           7.8516e-01, -8.5547e-01],
         [ 1.6016e+00, -2.7734e-01, -2.1289e-01,  ...,  2.1406e+00,
           1.4551e-01, -5.0000e-01],
         ...,
         [ 9.1406e-01, -3.8672e-01,  7.9297e-01,  ...,  1.3281e-01,
          -1.9434e-01,  3.5547e-01],
         [ 5.8594e-01, -3.4766e-01,  7.6172e-02,  ...,  1.5820e-01,
           1.6875e+00,  1.0312e+00],
         [ 1.1875e+00,  9.2969e-01,  1.3867e-01,  ...,  7.6172e-01,
           8.8672e-01, -2.7344e-01]],

        [[-2.3346e-03,  9.5215e-03, -9.2163e-03,  ..., -1.6602e-01,
          -1.3672e-01,  2.4414e-01],
         [ 2.0312e+00, -6.0156e-01, -2.6953e-01,  ...,  1.7734e+00,
           7.8516e-01, -8.5547e-01],
         [ 1.6016e+00, -2.7734e-01, -2.1289e-01,  ...,  2.1406e+00,
           1.4551e-01, -5.0000e-01],
         ...,
         [ 9.1406e-01, -3.8672e-01,  7.9297e-01,  ...,  1.3281e-01,
          -1.9434e-01,  3.5547e-01],
         [ 5.8594e-01, -3.4766e-01,  7.6172e-02,  ...,  1.5820e-01,
           1.6875e+00,  1.0312e+00],
         [ 1.1875e+00,  9.2969e-01,  1.3867e-01,  ...,  7.6172e-01,
           8.8672e-01, -2.7344e-01]],

        [[-2.3346e-03,  9.5215e-03, -9.2163e-03,  ..., -1.6602e-01,
          -1.3672e-01,  2.4414e-01],
         [ 2.0312e+00, -6.0156e-01, -2.6953e-01,  ...,  1.7734e+00,
           7.8516e-01, -8.5547e-01],
         [ 1.6016e+00, -2.7734e-01, -2.1289e-01,  ...,  2.1406e+00,
           1.4551e-01, -5.0000e-01],
         ...,
         [ 9.1406e-01, -3.8672e-01,  7.9297e-01,  ...,  1.3281e-01,
          -1.9434e-01,  3.5547e-01],
         [ 5.8594e-01, -3.4766e-01,  7.6172e-02,  ...,  1.5820e-01,
           1.6875e+00,  1.0312e+00],
         [ 1.1875e+00,  9.2969e-01,  1.3867e-01,  ...,  7.6172e-01,
           8.8672e-01, -2.7344e-01]],

        ...,

        [[ 1.5137e-02, -7.3547e-03, -4.0894e-03,  ..., -2.2852e-01,
           8.4375e-01, -1.9141e-01],
         [-1.2422e+00, -1.0703e+00,  6.9922e-01,  ..., -3.3008e-01,
           1.2266e+00,  3.3125e+00],
         [-9.8438e-01, -7.5391e-01,  7.5000e-01,  ..., -1.3184e-01,
           2.2266e-01, -2.5938e+00],
         ...,
         [-1.0312e+00,  3.0664e-01,  1.0625e+00,  ...,  3.0000e+00,
          -5.0391e-01,  7.0312e-01],
         [-8.0469e-01,  1.0781e+00,  5.2734e-01,  ...,  3.8281e-01,
          -6.5430e-02,  5.4297e-01],
         [ 1.9434e-01,  1.3750e+00, -8.7891e-02,  ..., -9.5703e-01,
          -6.0547e-01, -1.4141e+00]],

        [[ 1.5137e-02, -7.3547e-03, -4.0894e-03,  ..., -2.2852e-01,
           8.4375e-01, -1.9141e-01],
         [-1.2422e+00, -1.0703e+00,  6.9922e-01,  ..., -3.3008e-01,
           1.2266e+00,  3.3125e+00],
         [-9.8438e-01, -7.5391e-01,  7.5000e-01,  ..., -1.3184e-01,
           2.2266e-01, -2.5938e+00],
         ...,
         [-1.0312e+00,  3.0664e-01,  1.0625e+00,  ...,  3.0000e+00,
          -5.0391e-01,  7.0312e-01],
         [-8.0469e-01,  1.0781e+00,  5.2734e-01,  ...,  3.8281e-01,
          -6.5430e-02,  5.4297e-01],
         [ 1.9434e-01,  1.3750e+00, -8.7891e-02,  ..., -9.5703e-01,
          -6.0547e-01, -1.4141e+00]],

        [[ 1.5137e-02, -7.3547e-03, -4.0894e-03,  ..., -2.2852e-01,
           8.4375e-01, -1.9141e-01],
         [-1.2422e+00, -1.0703e+00,  6.9922e-01,  ..., -3.3008e-01,
           1.2266e+00,  3.3125e+00],
         [-9.8438e-01, -7.5391e-01,  7.5000e-01,  ..., -1.3184e-01,
           2.2266e-01, -2.5938e+00],
         ...,
         [-1.0312e+00,  3.0664e-01,  1.0625e+00,  ...,  3.0000e+00,
          -5.0391e-01,  7.0312e-01],
         [-8.0469e-01,  1.0781e+00,  5.2734e-01,  ...,  3.8281e-01,
          -6.5430e-02,  5.4297e-01],
         [ 1.9434e-01,  1.3750e+00, -8.7891e-02,  ..., -9.5703e-01,
          -6.0547e-01, -1.4141e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-7.1106e-03, -1.7944e-02,  7.7820e-03,  ...,  1.2756e-02,
          -1.6724e-02,  4.2725e-03],
         [-1.6797e-01, -6.9922e-01,  4.2969e-02,  ...,  7.4707e-02,
          -1.8066e-01, -3.6133e-02],
         [ 2.7539e-01,  1.1562e+00, -1.1875e+00,  ..., -6.7383e-02,
          -3.1055e-01,  2.6611e-02],
         ...,
         [ 2.0703e-01,  2.5391e-01, -4.3701e-02,  ...,  2.9883e-01,
          -5.7812e-01,  3.8867e-01],
         [-5.6250e-01,  4.2773e-01,  7.8125e-01,  ..., -9.9219e-01,
          -5.0000e-01, -5.6641e-01],
         [ 8.3008e-02,  1.2969e+00,  1.1094e+00,  ...,  1.0469e+00,
           1.1484e+00,  2.0312e+00]],

        [[-7.1106e-03, -1.7944e-02,  7.7820e-03,  ...,  1.2756e-02,
          -1.6724e-02,  4.2725e-03],
         [-1.6797e-01, -6.9922e-01,  4.2969e-02,  ...,  7.4707e-02,
          -1.8066e-01, -3.6133e-02],
         [ 2.7539e-01,  1.1562e+00, -1.1875e+00,  ..., -6.7383e-02,
          -3.1055e-01,  2.6611e-02],
         ...,
         [ 2.0703e-01,  2.5391e-01, -4.3701e-02,  ...,  2.9883e-01,
          -5.7812e-01,  3.8867e-01],
         [-5.6250e-01,  4.2773e-01,  7.8125e-01,  ..., -9.9219e-01,
          -5.0000e-01, -5.6641e-01],
         [ 8.3008e-02,  1.2969e+00,  1.1094e+00,  ...,  1.0469e+00,
           1.1484e+00,  2.0312e+00]],

        [[-7.1106e-03, -1.7944e-02,  7.7820e-03,  ...,  1.2756e-02,
          -1.6724e-02,  4.2725e-03],
         [-1.6797e-01, -6.9922e-01,  4.2969e-02,  ...,  7.4707e-02,
          -1.8066e-01, -3.6133e-02],
         [ 2.7539e-01,  1.1562e+00, -1.1875e+00,  ..., -6.7383e-02,
          -3.1055e-01,  2.6611e-02],
         ...,
         [ 2.0703e-01,  2.5391e-01, -4.3701e-02,  ...,  2.9883e-01,
          -5.7812e-01,  3.8867e-01],
         [-5.6250e-01,  4.2773e-01,  7.8125e-01,  ..., -9.9219e-01,
          -5.0000e-01, -5.6641e-01],
         [ 8.3008e-02,  1.2969e+00,  1.1094e+00,  ...,  1.0469e+00,
           1.1484e+00,  2.0312e+00]],

        ...,

        [[-3.6469e-03, -8.3618e-03,  3.2349e-03,  ...,  2.0504e-04,
          -3.1128e-03,  1.9836e-03],
         [-2.4844e+00,  9.6484e-01, -8.0469e-01,  ...,  1.8047e+00,
           2.6562e+00, -2.4688e+00],
         [-8.5938e-01, -1.7344e+00, -9.7266e-01,  ...,  6.6797e-01,
           5.0781e-01, -2.2500e+00],
         ...,
         [ 1.3281e+00, -1.3438e+00, -9.9609e-01,  ...,  7.6562e-01,
          -1.2188e+00,  5.0000e-01],
         [-5.9375e-01, -3.8086e-01, -1.7734e+00,  ...,  1.4922e+00,
           5.1758e-02, -1.1016e+00],
         [-2.2656e+00,  2.6367e-01, -3.0156e+00,  ...,  1.0391e+00,
          -1.9688e+00, -7.5000e-01]],

        [[-3.6469e-03, -8.3618e-03,  3.2349e-03,  ...,  2.0504e-04,
          -3.1128e-03,  1.9836e-03],
         [-2.4844e+00,  9.6484e-01, -8.0469e-01,  ...,  1.8047e+00,
           2.6562e+00, -2.4688e+00],
         [-8.5938e-01, -1.7344e+00, -9.7266e-01,  ...,  6.6797e-01,
           5.0781e-01, -2.2500e+00],
         ...,
         [ 1.3281e+00, -1.3438e+00, -9.9609e-01,  ...,  7.6562e-01,
          -1.2188e+00,  5.0000e-01],
         [-5.9375e-01, -3.8086e-01, -1.7734e+00,  ...,  1.4922e+00,
           5.1758e-02, -1.1016e+00],
         [-2.2656e+00,  2.6367e-01, -3.0156e+00,  ...,  1.0391e+00,
          -1.9688e+00, -7.5000e-01]],

        [[-3.6469e-03, -8.3618e-03,  3.2349e-03,  ...,  2.0504e-04,
          -3.1128e-03,  1.9836e-03],
         [-2.4844e+00,  9.6484e-01, -8.0469e-01,  ...,  1.8047e+00,
           2.6562e+00, -2.4688e+00],
         [-8.5938e-01, -1.7344e+00, -9.7266e-01,  ...,  6.6797e-01,
           5.0781e-01, -2.2500e+00],
         ...,
         [ 1.3281e+00, -1.3438e+00, -9.9609e-01,  ...,  7.6562e-01,
          -1.2188e+00,  5.0000e-01],
         [-5.9375e-01, -3.8086e-01, -1.7734e+00,  ...,  1.4922e+00,
           5.1758e-02, -1.1016e+00],
         [-2.2656e+00,  2.6367e-01, -3.0156e+00,  ...,  1.0391e+00,
          -1.9688e+00, -7.5000e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 6.3324e-04,  3.0518e-03, -2.5635e-03,  ..., -1.9043e-01,
          -1.2695e-01,  1.2598e-01],
         [ 0.0000e+00,  9.1406e-01, -2.5000e+00,  ...,  1.4062e-01,
           6.0938e-01, -8.7891e-01],
         [-1.0078e+00,  1.2500e+00, -2.9492e-01,  ..., -2.1680e-01,
          -4.3555e-01,  2.7734e-01],
         ...,
         [-1.2695e-01,  1.9336e-01, -1.1475e-01,  ..., -6.9336e-02,
          -7.3828e-01, -6.5234e-01],
         [-2.9492e-01,  2.8906e-01,  7.0312e-01,  ..., -1.9453e+00,
          -6.0938e-01,  8.4375e-01],
         [-1.2266e+00, -3.7891e-01,  4.9219e-01,  ..., -2.3535e-01,
          -1.6797e+00, -1.3906e+00]],

        [[ 6.3324e-04,  3.0518e-03, -2.5635e-03,  ..., -1.9043e-01,
          -1.2695e-01,  1.2598e-01],
         [ 0.0000e+00,  9.1406e-01, -2.5000e+00,  ...,  1.4062e-01,
           6.0938e-01, -8.7891e-01],
         [-1.0078e+00,  1.2500e+00, -2.9492e-01,  ..., -2.1680e-01,
          -4.3555e-01,  2.7734e-01],
         ...,
         [-1.2695e-01,  1.9336e-01, -1.1475e-01,  ..., -6.9336e-02,
          -7.3828e-01, -6.5234e-01],
         [-2.9492e-01,  2.8906e-01,  7.0312e-01,  ..., -1.9453e+00,
          -6.0938e-01,  8.4375e-01],
         [-1.2266e+00, -3.7891e-01,  4.9219e-01,  ..., -2.3535e-01,
          -1.6797e+00, -1.3906e+00]],

        [[ 6.3324e-04,  3.0518e-03, -2.5635e-03,  ..., -1.9043e-01,
          -1.2695e-01,  1.2598e-01],
         [ 0.0000e+00,  9.1406e-01, -2.5000e+00,  ...,  1.4062e-01,
           6.0938e-01, -8.7891e-01],
         [-1.0078e+00,  1.2500e+00, -2.9492e-01,  ..., -2.1680e-01,
          -4.3555e-01,  2.7734e-01],
         ...,
         [-1.2695e-01,  1.9336e-01, -1.1475e-01,  ..., -6.9336e-02,
          -7.3828e-01, -6.5234e-01],
         [-2.9492e-01,  2.8906e-01,  7.0312e-01,  ..., -1.9453e+00,
          -6.0938e-01,  8.4375e-01],
         [-1.2266e+00, -3.7891e-01,  4.9219e-01,  ..., -2.3535e-01,
          -1.6797e+00, -1.3906e+00]],

        ...,

        [[-1.1978e-03, -1.1841e-02, -8.0566e-03,  ...,  1.3965e-01,
           5.9082e-02,  1.1865e-01],
         [-2.5781e+00, -1.0078e+00,  1.7578e+00,  ..., -2.5156e+00,
           2.7734e-01, -6.7188e-01],
         [-1.5938e+00, -1.5312e+00,  1.4609e+00,  ..., -3.6719e-01,
           1.8125e+00, -1.7188e+00],
         ...,
         [-1.1094e+00,  1.3359e+00,  7.1875e-01,  ...,  7.3047e-01,
          -2.3906e+00,  2.0312e+00],
         [-1.2031e+00,  1.4297e+00,  8.4766e-01,  ...,  1.7812e+00,
          -5.7422e-01,  2.2969e+00],
         [-5.5078e-01,  2.2500e+00,  1.7266e+00,  ..., -5.8203e-01,
          -1.6211e-01,  9.9609e-02]],

        [[-1.1978e-03, -1.1841e-02, -8.0566e-03,  ...,  1.3965e-01,
           5.9082e-02,  1.1865e-01],
         [-2.5781e+00, -1.0078e+00,  1.7578e+00,  ..., -2.5156e+00,
           2.7734e-01, -6.7188e-01],
         [-1.5938e+00, -1.5312e+00,  1.4609e+00,  ..., -3.6719e-01,
           1.8125e+00, -1.7188e+00],
         ...,
         [-1.1094e+00,  1.3359e+00,  7.1875e-01,  ...,  7.3047e-01,
          -2.3906e+00,  2.0312e+00],
         [-1.2031e+00,  1.4297e+00,  8.4766e-01,  ...,  1.7812e+00,
          -5.7422e-01,  2.2969e+00],
         [-5.5078e-01,  2.2500e+00,  1.7266e+00,  ..., -5.8203e-01,
          -1.6211e-01,  9.9609e-02]],

        [[-1.1978e-03, -1.1841e-02, -8.0566e-03,  ...,  1.3965e-01,
           5.9082e-02,  1.1865e-01],
         [-2.5781e+00, -1.0078e+00,  1.7578e+00,  ..., -2.5156e+00,
           2.7734e-01, -6.7188e-01],
         [-1.5938e+00, -1.5312e+00,  1.4609e+00,  ..., -3.6719e-01,
           1.8125e+00, -1.7188e+00],
         ...,
         [-1.1094e+00,  1.3359e+00,  7.1875e-01,  ...,  7.3047e-01,
          -2.3906e+00,  2.0312e+00],
         [-1.2031e+00,  1.4297e+00,  8.4766e-01,  ...,  1.7812e+00,
          -5.7422e-01,  2.2969e+00],
         [-5.5078e-01,  2.2500e+00,  1.7266e+00,  ..., -5.8203e-01,
          -1.6211e-01,  9.9609e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 2.4658e-02, -2.5558e-04, -4.5166e-03,  ..., -9.0942e-03,
           1.7822e-02, -9.8877e-03],
         [-3.9648e-01,  2.2656e-01,  2.7344e-01,  ..., -8.3984e-01,
           4.1602e-01, -3.8086e-01],
         [-1.7383e-01,  4.6094e-01, -1.5859e+00,  ...,  3.0469e-01,
          -1.4375e+00, -8.2812e-01],
         ...,
         [ 7.5781e-01,  2.2344e+00,  2.4062e+00,  ..., -2.7222e-02,
           1.5000e+00,  7.8516e-01],
         [ 1.0234e+00,  1.5312e+00,  8.7891e-03,  ..., -4.0234e-01,
           1.0078e+00, -3.8330e-02],
         [-9.5703e-01,  5.8594e-01, -5.7422e-01,  ...,  4.8828e-01,
           4.1016e-01, -6.6016e-01]],

        [[ 2.4658e-02, -2.5558e-04, -4.5166e-03,  ..., -9.0942e-03,
           1.7822e-02, -9.8877e-03],
         [-3.9648e-01,  2.2656e-01,  2.7344e-01,  ..., -8.3984e-01,
           4.1602e-01, -3.8086e-01],
         [-1.7383e-01,  4.6094e-01, -1.5859e+00,  ...,  3.0469e-01,
          -1.4375e+00, -8.2812e-01],
         ...,
         [ 7.5781e-01,  2.2344e+00,  2.4062e+00,  ..., -2.7222e-02,
           1.5000e+00,  7.8516e-01],
         [ 1.0234e+00,  1.5312e+00,  8.7891e-03,  ..., -4.0234e-01,
           1.0078e+00, -3.8330e-02],
         [-9.5703e-01,  5.8594e-01, -5.7422e-01,  ...,  4.8828e-01,
           4.1016e-01, -6.6016e-01]],

        [[ 2.4658e-02, -2.5558e-04, -4.5166e-03,  ..., -9.0942e-03,
           1.7822e-02, -9.8877e-03],
         [-3.9648e-01,  2.2656e-01,  2.7344e-01,  ..., -8.3984e-01,
           4.1602e-01, -3.8086e-01],
         [-1.7383e-01,  4.6094e-01, -1.5859e+00,  ...,  3.0469e-01,
          -1.4375e+00, -8.2812e-01],
         ...,
         [ 7.5781e-01,  2.2344e+00,  2.4062e+00,  ..., -2.7222e-02,
           1.5000e+00,  7.8516e-01],
         [ 1.0234e+00,  1.5312e+00,  8.7891e-03,  ..., -4.0234e-01,
           1.0078e+00, -3.8330e-02],
         [-9.5703e-01,  5.8594e-01, -5.7422e-01,  ...,  4.8828e-01,
           4.1016e-01, -6.6016e-01]],

        ...,

        [[-1.2085e-02, -7.4463e-03, -1.9836e-03,  ..., -1.0315e-02,
          -3.5858e-03, -1.2756e-02],
         [-3.5742e-01, -1.9336e-01, -2.2656e-01,  ..., -8.3984e-01,
           5.3906e-01, -1.9141e+00],
         [-9.7656e-03, -3.3594e-01,  3.7109e-02,  ...,  8.7109e-01,
          -7.7734e-01, -7.7734e-01],
         ...,
         [ 6.2109e-01,  8.6328e-01, -1.0625e+00,  ..., -5.8203e-01,
           1.2891e+00,  4.4678e-02],
         [ 7.6660e-02, -1.6699e-01, -1.4062e+00,  ...,  5.7422e-01,
           8.7891e-01,  2.0215e-01],
         [-3.9062e-01, -1.0156e+00, -8.7500e-01,  ..., -6.6406e-01,
           9.0234e-01, -2.3071e-02]],

        [[-1.2085e-02, -7.4463e-03, -1.9836e-03,  ..., -1.0315e-02,
          -3.5858e-03, -1.2756e-02],
         [-3.5742e-01, -1.9336e-01, -2.2656e-01,  ..., -8.3984e-01,
           5.3906e-01, -1.9141e+00],
         [-9.7656e-03, -3.3594e-01,  3.7109e-02,  ...,  8.7109e-01,
          -7.7734e-01, -7.7734e-01],
         ...,
         [ 6.2109e-01,  8.6328e-01, -1.0625e+00,  ..., -5.8203e-01,
           1.2891e+00,  4.4678e-02],
         [ 7.6660e-02, -1.6699e-01, -1.4062e+00,  ...,  5.7422e-01,
           8.7891e-01,  2.0215e-01],
         [-3.9062e-01, -1.0156e+00, -8.7500e-01,  ..., -6.6406e-01,
           9.0234e-01, -2.3071e-02]],

        [[-1.2085e-02, -7.4463e-03, -1.9836e-03,  ..., -1.0315e-02,
          -3.5858e-03, -1.2756e-02],
         [-3.5742e-01, -1.9336e-01, -2.2656e-01,  ..., -8.3984e-01,
           5.3906e-01, -1.9141e+00],
         [-9.7656e-03, -3.3594e-01,  3.7109e-02,  ...,  8.7109e-01,
          -7.7734e-01, -7.7734e-01],
         ...,
         [ 6.2109e-01,  8.6328e-01, -1.0625e+00,  ..., -5.8203e-01,
           1.2891e+00,  4.4678e-02],
         [ 7.6660e-02, -1.6699e-01, -1.4062e+00,  ...,  5.7422e-01,
           8.7891e-01,  2.0215e-01],
         [-3.9062e-01, -1.0156e+00, -8.7500e-01,  ..., -6.6406e-01,
           9.0234e-01, -2.3071e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-3.1433e-03,  4.5776e-03,  1.3306e-02,  ..., -1.6016e-01,
           2.3535e-01,  3.0273e-01],
         [-3.1719e+00, -1.2656e+00, -1.2812e+00,  ...,  2.9883e-01,
          -2.4805e-01, -5.2734e-01],
         [-1.9141e-01, -8.1250e-01, -4.2969e-01,  ..., -3.5889e-02,
          -1.2031e+00, -1.0234e+00],
         ...,
         [-1.8594e+00,  1.0703e+00, -1.7188e-01,  ...,  1.1484e+00,
          -7.7344e-01, -9.6484e-01],
         [-6.2891e-01,  2.5781e+00, -1.2500e+00,  ...,  1.0312e+00,
          -3.5938e-01, -2.7344e-01],
         [ 2.9219e+00,  1.4531e+00, -1.7578e+00,  ...,  1.1406e+00,
           1.0234e+00, -1.8672e+00]],

        [[-3.1433e-03,  4.5776e-03,  1.3306e-02,  ..., -1.6016e-01,
           2.3535e-01,  3.0273e-01],
         [-3.1719e+00, -1.2656e+00, -1.2812e+00,  ...,  2.9883e-01,
          -2.4805e-01, -5.2734e-01],
         [-1.9141e-01, -8.1250e-01, -4.2969e-01,  ..., -3.5889e-02,
          -1.2031e+00, -1.0234e+00],
         ...,
         [-1.8594e+00,  1.0703e+00, -1.7188e-01,  ...,  1.1484e+00,
          -7.7344e-01, -9.6484e-01],
         [-6.2891e-01,  2.5781e+00, -1.2500e+00,  ...,  1.0312e+00,
          -3.5938e-01, -2.7344e-01],
         [ 2.9219e+00,  1.4531e+00, -1.7578e+00,  ...,  1.1406e+00,
           1.0234e+00, -1.8672e+00]],

        [[-3.1433e-03,  4.5776e-03,  1.3306e-02,  ..., -1.6016e-01,
           2.3535e-01,  3.0273e-01],
         [-3.1719e+00, -1.2656e+00, -1.2812e+00,  ...,  2.9883e-01,
          -2.4805e-01, -5.2734e-01],
         [-1.9141e-01, -8.1250e-01, -4.2969e-01,  ..., -3.5889e-02,
          -1.2031e+00, -1.0234e+00],
         ...,
         [-1.8594e+00,  1.0703e+00, -1.7188e-01,  ...,  1.1484e+00,
          -7.7344e-01, -9.6484e-01],
         [-6.2891e-01,  2.5781e+00, -1.2500e+00,  ...,  1.0312e+00,
          -3.5938e-01, -2.7344e-01],
         [ 2.9219e+00,  1.4531e+00, -1.7578e+00,  ...,  1.1406e+00,
           1.0234e+00, -1.8672e+00]],

        ...,

        [[ 6.1798e-04,  9.0027e-04, -2.0996e-02,  ...,  1.0681e-03,
          -4.8828e-02, -1.9336e-01],
         [-3.7031e+00, -1.8047e+00, -6.6406e-02,  ...,  2.6562e-01,
           1.8594e+00,  1.1914e-01],
         [-7.6953e-01, -1.3828e+00,  2.4414e-02,  ...,  5.9375e-01,
          -1.5391e+00, -1.1484e+00],
         ...,
         [-2.1719e+00,  1.4062e+00, -5.5469e-01,  ...,  7.3047e-01,
           8.5938e-01, -3.2031e-01],
         [-6.4062e-01,  1.2500e+00, -9.1406e-01,  ...,  3.3984e-01,
          -2.2188e+00, -9.2969e-01],
         [ 2.3750e+00,  1.8594e+00, -1.6797e+00,  ...,  1.4609e+00,
           2.5391e-01, -1.1406e+00]],

        [[ 6.1798e-04,  9.0027e-04, -2.0996e-02,  ...,  1.0681e-03,
          -4.8828e-02, -1.9336e-01],
         [-3.7031e+00, -1.8047e+00, -6.6406e-02,  ...,  2.6562e-01,
           1.8594e+00,  1.1914e-01],
         [-7.6953e-01, -1.3828e+00,  2.4414e-02,  ...,  5.9375e-01,
          -1.5391e+00, -1.1484e+00],
         ...,
         [-2.1719e+00,  1.4062e+00, -5.5469e-01,  ...,  7.3047e-01,
           8.5938e-01, -3.2031e-01],
         [-6.4062e-01,  1.2500e+00, -9.1406e-01,  ...,  3.3984e-01,
          -2.2188e+00, -9.2969e-01],
         [ 2.3750e+00,  1.8594e+00, -1.6797e+00,  ...,  1.4609e+00,
           2.5391e-01, -1.1406e+00]],

        [[ 6.1798e-04,  9.0027e-04, -2.0996e-02,  ...,  1.0681e-03,
          -4.8828e-02, -1.9336e-01],
         [-3.7031e+00, -1.8047e+00, -6.6406e-02,  ...,  2.6562e-01,
           1.8594e+00,  1.1914e-01],
         [-7.6953e-01, -1.3828e+00,  2.4414e-02,  ...,  5.9375e-01,
          -1.5391e+00, -1.1484e+00],
         ...,
         [-2.1719e+00,  1.4062e+00, -5.5469e-01,  ...,  7.3047e-01,
           8.5938e-01, -3.2031e-01],
         [-6.4062e-01,  1.2500e+00, -9.1406e-01,  ...,  3.3984e-01,
          -2.2188e+00, -9.2969e-01],
         [ 2.3750e+00,  1.8594e+00, -1.6797e+00,  ...,  1.4609e+00,
           2.5391e-01, -1.1406e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 1.0925e-02, -2.8229e-03, -1.0071e-02,  ..., -1.0803e-02,
           2.3193e-02, -4.3030e-03],
         [-3.6328e-01,  2.4512e-01, -1.2031e+00,  ..., -1.0498e-02,
           2.5781e-01, -9.3359e-01],
         [-7.7637e-02,  4.9219e-01, -2.6367e-01,  ...,  1.1016e+00,
          -1.1562e+00, -8.4473e-02],
         ...,
         [-1.0625e+00,  1.0703e+00, -4.9023e-01,  ...,  4.3164e-01,
          -7.1094e-01, -4.4141e-01],
         [-2.4902e-01, -3.9648e-01, -8.9355e-02,  ...,  9.7266e-01,
          -1.2812e+00,  1.2266e+00],
         [-6.9531e-01,  2.7539e-01, -6.8359e-01,  ...,  9.1797e-01,
          -1.5156e+00,  8.3594e-01]],

        [[ 1.0925e-02, -2.8229e-03, -1.0071e-02,  ..., -1.0803e-02,
           2.3193e-02, -4.3030e-03],
         [-3.6328e-01,  2.4512e-01, -1.2031e+00,  ..., -1.0498e-02,
           2.5781e-01, -9.3359e-01],
         [-7.7637e-02,  4.9219e-01, -2.6367e-01,  ...,  1.1016e+00,
          -1.1562e+00, -8.4473e-02],
         ...,
         [-1.0625e+00,  1.0703e+00, -4.9023e-01,  ...,  4.3164e-01,
          -7.1094e-01, -4.4141e-01],
         [-2.4902e-01, -3.9648e-01, -8.9355e-02,  ...,  9.7266e-01,
          -1.2812e+00,  1.2266e+00],
         [-6.9531e-01,  2.7539e-01, -6.8359e-01,  ...,  9.1797e-01,
          -1.5156e+00,  8.3594e-01]],

        [[ 1.0925e-02, -2.8229e-03, -1.0071e-02,  ..., -1.0803e-02,
           2.3193e-02, -4.3030e-03],
         [-3.6328e-01,  2.4512e-01, -1.2031e+00,  ..., -1.0498e-02,
           2.5781e-01, -9.3359e-01],
         [-7.7637e-02,  4.9219e-01, -2.6367e-01,  ...,  1.1016e+00,
          -1.1562e+00, -8.4473e-02],
         ...,
         [-1.0625e+00,  1.0703e+00, -4.9023e-01,  ...,  4.3164e-01,
          -7.1094e-01, -4.4141e-01],
         [-2.4902e-01, -3.9648e-01, -8.9355e-02,  ...,  9.7266e-01,
          -1.2812e+00,  1.2266e+00],
         [-6.9531e-01,  2.7539e-01, -6.8359e-01,  ...,  9.1797e-01,
          -1.5156e+00,  8.3594e-01]],

        ...,

        [[-6.8283e-04, -3.7079e-03,  7.7820e-04,  ...,  2.1118e-02,
           6.7444e-03,  5.3406e-03],
         [-2.0508e-01, -2.8516e-01,  4.9609e-01,  ...,  1.5938e+00,
           1.1719e+00,  5.5078e-01],
         [-2.8125e-01, -1.6992e-01, -5.9766e-01,  ...,  4.9414e-01,
           1.4160e-02,  4.6289e-01],
         ...,
         [ 1.2695e-01, -1.2891e+00,  1.0986e-01,  ...,  1.1250e+00,
          -2.4219e+00, -1.2793e-01],
         [ 1.2012e-01,  8.4375e-01,  1.5625e+00,  ..., -9.3750e-02,
          -8.6719e-01,  1.2812e+00],
         [-8.2812e-01, -3.7305e-01,  6.7871e-02,  ..., -4.4336e-01,
           1.5527e-01, -7.6562e-01]],

        [[-6.8283e-04, -3.7079e-03,  7.7820e-04,  ...,  2.1118e-02,
           6.7444e-03,  5.3406e-03],
         [-2.0508e-01, -2.8516e-01,  4.9609e-01,  ...,  1.5938e+00,
           1.1719e+00,  5.5078e-01],
         [-2.8125e-01, -1.6992e-01, -5.9766e-01,  ...,  4.9414e-01,
           1.4160e-02,  4.6289e-01],
         ...,
         [ 1.2695e-01, -1.2891e+00,  1.0986e-01,  ...,  1.1250e+00,
          -2.4219e+00, -1.2793e-01],
         [ 1.2012e-01,  8.4375e-01,  1.5625e+00,  ..., -9.3750e-02,
          -8.6719e-01,  1.2812e+00],
         [-8.2812e-01, -3.7305e-01,  6.7871e-02,  ..., -4.4336e-01,
           1.5527e-01, -7.6562e-01]],

        [[-6.8283e-04, -3.7079e-03,  7.7820e-04,  ...,  2.1118e-02,
           6.7444e-03,  5.3406e-03],
         [-2.0508e-01, -2.8516e-01,  4.9609e-01,  ...,  1.5938e+00,
           1.1719e+00,  5.5078e-01],
         [-2.8125e-01, -1.6992e-01, -5.9766e-01,  ...,  4.9414e-01,
           1.4160e-02,  4.6289e-01],
         ...,
         [ 1.2695e-01, -1.2891e+00,  1.0986e-01,  ...,  1.1250e+00,
          -2.4219e+00, -1.2793e-01],
         [ 1.2012e-01,  8.4375e-01,  1.5625e+00,  ..., -9.3750e-02,
          -8.6719e-01,  1.2812e+00],
         [-8.2812e-01, -3.7305e-01,  6.7871e-02,  ..., -4.4336e-01,
           1.5527e-01, -7.6562e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 1.0437e-02, -6.3782e-03, -7.9346e-03,  ..., -1.8750e+00,
           8.4473e-02,  7.5781e-01],
         [-4.1016e-01,  6.1719e-01,  5.3906e-01,  ...,  1.1016e+00,
           1.3281e-01, -5.6562e+00],
         [ 4.3945e-01, -1.9336e-01,  2.6562e-01,  ...,  4.3438e+00,
          -3.1250e+00, -2.8594e+00],
         ...,
         [-3.6328e-01, -4.7852e-02, -5.0391e-01,  ...,  3.5781e+00,
          -9.4531e-01, -7.7344e-01],
         [ 2.4609e-01,  2.9688e-01, -1.9922e-01,  ...,  4.0000e+00,
           1.8652e-01, -2.3594e+00],
         [-2.8125e-01, -3.0078e-01, -2.6367e-01,  ...,  3.5312e+00,
           6.6406e-01, -2.9531e+00]],

        [[ 1.0437e-02, -6.3782e-03, -7.9346e-03,  ..., -1.8750e+00,
           8.4473e-02,  7.5781e-01],
         [-4.1016e-01,  6.1719e-01,  5.3906e-01,  ...,  1.1016e+00,
           1.3281e-01, -5.6562e+00],
         [ 4.3945e-01, -1.9336e-01,  2.6562e-01,  ...,  4.3438e+00,
          -3.1250e+00, -2.8594e+00],
         ...,
         [-3.6328e-01, -4.7852e-02, -5.0391e-01,  ...,  3.5781e+00,
          -9.4531e-01, -7.7344e-01],
         [ 2.4609e-01,  2.9688e-01, -1.9922e-01,  ...,  4.0000e+00,
           1.8652e-01, -2.3594e+00],
         [-2.8125e-01, -3.0078e-01, -2.6367e-01,  ...,  3.5312e+00,
           6.6406e-01, -2.9531e+00]],

        [[ 1.0437e-02, -6.3782e-03, -7.9346e-03,  ..., -1.8750e+00,
           8.4473e-02,  7.5781e-01],
         [-4.1016e-01,  6.1719e-01,  5.3906e-01,  ...,  1.1016e+00,
           1.3281e-01, -5.6562e+00],
         [ 4.3945e-01, -1.9336e-01,  2.6562e-01,  ...,  4.3438e+00,
          -3.1250e+00, -2.8594e+00],
         ...,
         [-3.6328e-01, -4.7852e-02, -5.0391e-01,  ...,  3.5781e+00,
          -9.4531e-01, -7.7344e-01],
         [ 2.4609e-01,  2.9688e-01, -1.9922e-01,  ...,  4.0000e+00,
           1.8652e-01, -2.3594e+00],
         [-2.8125e-01, -3.0078e-01, -2.6367e-01,  ...,  3.5312e+00,
           6.6406e-01, -2.9531e+00]],

        ...,

        [[-6.0730e-03, -9.9487e-03, -2.1057e-03,  ...,  1.2158e-01,
           2.1484e-01,  6.1646e-03],
         [-1.2656e+00, -9.0625e-01,  1.5625e+00,  ...,  1.6641e+00,
           8.0469e-01, -7.3438e-01],
         [-1.6211e-01,  7.6172e-02,  1.0312e+00,  ...,  5.5625e+00,
           3.5312e+00,  3.0664e-01],
         ...,
         [-8.9062e-01, -7.3438e-01, -8.0469e-01,  ...,  3.4062e+00,
           2.8750e+00,  3.3594e+00],
         [-1.5039e-01,  3.1445e-01,  5.5078e-01,  ...,  4.2500e+00,
           3.8594e+00,  7.0312e+00],
         [-3.2227e-01,  4.5312e-01,  6.8750e-01,  ...,  4.5312e+00,
           1.7969e+00,  3.2812e+00]],

        [[-6.0730e-03, -9.9487e-03, -2.1057e-03,  ...,  1.2158e-01,
           2.1484e-01,  6.1646e-03],
         [-1.2656e+00, -9.0625e-01,  1.5625e+00,  ...,  1.6641e+00,
           8.0469e-01, -7.3438e-01],
         [-1.6211e-01,  7.6172e-02,  1.0312e+00,  ...,  5.5625e+00,
           3.5312e+00,  3.0664e-01],
         ...,
         [-8.9062e-01, -7.3438e-01, -8.0469e-01,  ...,  3.4062e+00,
           2.8750e+00,  3.3594e+00],
         [-1.5039e-01,  3.1445e-01,  5.5078e-01,  ...,  4.2500e+00,
           3.8594e+00,  7.0312e+00],
         [-3.2227e-01,  4.5312e-01,  6.8750e-01,  ...,  4.5312e+00,
           1.7969e+00,  3.2812e+00]],

        [[-6.0730e-03, -9.9487e-03, -2.1057e-03,  ...,  1.2158e-01,
           2.1484e-01,  6.1646e-03],
         [-1.2656e+00, -9.0625e-01,  1.5625e+00,  ...,  1.6641e+00,
           8.0469e-01, -7.3438e-01],
         [-1.6211e-01,  7.6172e-02,  1.0312e+00,  ...,  5.5625e+00,
           3.5312e+00,  3.0664e-01],
         ...,
         [-8.9062e-01, -7.3438e-01, -8.0469e-01,  ...,  3.4062e+00,
           2.8750e+00,  3.3594e+00],
         [-1.5039e-01,  3.1445e-01,  5.5078e-01,  ...,  4.2500e+00,
           3.8594e+00,  7.0312e+00],
         [-3.2227e-01,  4.5312e-01,  6.8750e-01,  ...,  4.5312e+00,
           1.7969e+00,  3.2812e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-1.1658e-02,  3.4332e-03,  1.4099e-02,  ...,  2.7954e-02,
           8.4229e-03, -1.0010e-02],
         [-8.8672e-01,  2.0156e+00, -1.7500e+00,  ..., -9.3750e-01,
           2.1719e+00,  2.9531e+00],
         [ 1.2969e+00, -3.3750e+00,  9.1406e-01,  ...,  8.5156e-01,
           4.4727e-01,  2.2031e+00],
         ...,
         [ 7.6172e-01,  1.3906e+00, -6.0938e-01,  ..., -8.8672e-01,
           6.7578e-01,  3.3691e-02],
         [ 4.3555e-01, -5.9766e-01, -2.5391e-02,  ...,  5.8984e-01,
          -2.7148e-01,  5.9375e-01],
         [ 3.0312e+00, -8.5156e-01,  1.4844e+00,  ...,  1.1875e+00,
          -7.4609e-01, -8.9844e-02]],

        [[-1.1658e-02,  3.4332e-03,  1.4099e-02,  ...,  2.7954e-02,
           8.4229e-03, -1.0010e-02],
         [-8.8672e-01,  2.0156e+00, -1.7500e+00,  ..., -9.3750e-01,
           2.1719e+00,  2.9531e+00],
         [ 1.2969e+00, -3.3750e+00,  9.1406e-01,  ...,  8.5156e-01,
           4.4727e-01,  2.2031e+00],
         ...,
         [ 7.6172e-01,  1.3906e+00, -6.0938e-01,  ..., -8.8672e-01,
           6.7578e-01,  3.3691e-02],
         [ 4.3555e-01, -5.9766e-01, -2.5391e-02,  ...,  5.8984e-01,
          -2.7148e-01,  5.9375e-01],
         [ 3.0312e+00, -8.5156e-01,  1.4844e+00,  ...,  1.1875e+00,
          -7.4609e-01, -8.9844e-02]],

        [[-1.1658e-02,  3.4332e-03,  1.4099e-02,  ...,  2.7954e-02,
           8.4229e-03, -1.0010e-02],
         [-8.8672e-01,  2.0156e+00, -1.7500e+00,  ..., -9.3750e-01,
           2.1719e+00,  2.9531e+00],
         [ 1.2969e+00, -3.3750e+00,  9.1406e-01,  ...,  8.5156e-01,
           4.4727e-01,  2.2031e+00],
         ...,
         [ 7.6172e-01,  1.3906e+00, -6.0938e-01,  ..., -8.8672e-01,
           6.7578e-01,  3.3691e-02],
         [ 4.3555e-01, -5.9766e-01, -2.5391e-02,  ...,  5.8984e-01,
          -2.7148e-01,  5.9375e-01],
         [ 3.0312e+00, -8.5156e-01,  1.4844e+00,  ...,  1.1875e+00,
          -7.4609e-01, -8.9844e-02]],

        ...,

        [[ 2.0599e-03,  1.2146e-02, -2.2095e-02,  ...,  1.0620e-02,
          -2.1057e-03,  3.3264e-03],
         [ 2.0000e+00,  6.7578e-01,  7.5391e-01,  ...,  1.5469e+00,
           4.0527e-02, -1.3438e+00],
         [ 4.1797e-01,  1.6406e+00,  1.8672e+00,  ...,  1.5234e+00,
          -2.3281e+00, -2.5625e+00],
         ...,
         [ 4.9062e+00,  3.6250e+00,  2.4531e+00,  ...,  1.8828e+00,
          -5.5000e+00, -2.2500e+00],
         [ 5.8438e+00,  3.8438e+00,  3.3281e+00,  ...,  3.5156e+00,
          -5.9375e+00, -3.8438e+00],
         [ 3.6875e+00,  2.4062e+00,  1.6172e+00,  ...,  3.0469e+00,
          -4.8438e+00, -2.1562e+00]],

        [[ 2.0599e-03,  1.2146e-02, -2.2095e-02,  ...,  1.0620e-02,
          -2.1057e-03,  3.3264e-03],
         [ 2.0000e+00,  6.7578e-01,  7.5391e-01,  ...,  1.5469e+00,
           4.0527e-02, -1.3438e+00],
         [ 4.1797e-01,  1.6406e+00,  1.8672e+00,  ...,  1.5234e+00,
          -2.3281e+00, -2.5625e+00],
         ...,
         [ 4.9062e+00,  3.6250e+00,  2.4531e+00,  ...,  1.8828e+00,
          -5.5000e+00, -2.2500e+00],
         [ 5.8438e+00,  3.8438e+00,  3.3281e+00,  ...,  3.5156e+00,
          -5.9375e+00, -3.8438e+00],
         [ 3.6875e+00,  2.4062e+00,  1.6172e+00,  ...,  3.0469e+00,
          -4.8438e+00, -2.1562e+00]],

        [[ 2.0599e-03,  1.2146e-02, -2.2095e-02,  ...,  1.0620e-02,
          -2.1057e-03,  3.3264e-03],
         [ 2.0000e+00,  6.7578e-01,  7.5391e-01,  ...,  1.5469e+00,
           4.0527e-02, -1.3438e+00],
         [ 4.1797e-01,  1.6406e+00,  1.8672e+00,  ...,  1.5234e+00,
          -2.3281e+00, -2.5625e+00],
         ...,
         [ 4.9062e+00,  3.6250e+00,  2.4531e+00,  ...,  1.8828e+00,
          -5.5000e+00, -2.2500e+00],
         [ 5.8438e+00,  3.8438e+00,  3.3281e+00,  ...,  3.5156e+00,
          -5.9375e+00, -3.8438e+00],
         [ 3.6875e+00,  2.4062e+00,  1.6172e+00,  ...,  3.0469e+00,
          -4.8438e+00, -2.1562e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 8.5831e-04, -4.5166e-03,  2.7008e-03,  ...,  4.5166e-02,
           3.1055e-01, -8.4961e-02],
         [ 3.1250e+00,  2.2656e+00, -6.4062e-01,  ...,  6.3281e-01,
          -5.9375e-01,  1.2500e+00],
         [ 1.9375e+00,  1.9688e+00,  1.0625e+00,  ...,  5.1953e-01,
           5.8105e-02,  9.5312e-01],
         ...,
         [ 1.2969e+00, -5.9375e-01,  1.4062e+00,  ..., -2.9541e-02,
           8.9844e-01,  2.5195e-01],
         [ 2.3438e+00, -1.2969e+00, -2.4219e-01,  ...,  3.6914e-01,
          -3.2959e-02, -3.6328e-01],
         [ 9.8047e-01, -1.1484e+00, -1.8750e-01,  ...,  7.5000e-01,
          -1.0625e+00, -4.8047e-01]],

        [[ 8.5831e-04, -4.5166e-03,  2.7008e-03,  ...,  4.5166e-02,
           3.1055e-01, -8.4961e-02],
         [ 3.1250e+00,  2.2656e+00, -6.4062e-01,  ...,  6.3281e-01,
          -5.9375e-01,  1.2500e+00],
         [ 1.9375e+00,  1.9688e+00,  1.0625e+00,  ...,  5.1953e-01,
           5.8105e-02,  9.5312e-01],
         ...,
         [ 1.2969e+00, -5.9375e-01,  1.4062e+00,  ..., -2.9541e-02,
           8.9844e-01,  2.5195e-01],
         [ 2.3438e+00, -1.2969e+00, -2.4219e-01,  ...,  3.6914e-01,
          -3.2959e-02, -3.6328e-01],
         [ 9.8047e-01, -1.1484e+00, -1.8750e-01,  ...,  7.5000e-01,
          -1.0625e+00, -4.8047e-01]],

        [[ 8.5831e-04, -4.5166e-03,  2.7008e-03,  ...,  4.5166e-02,
           3.1055e-01, -8.4961e-02],
         [ 3.1250e+00,  2.2656e+00, -6.4062e-01,  ...,  6.3281e-01,
          -5.9375e-01,  1.2500e+00],
         [ 1.9375e+00,  1.9688e+00,  1.0625e+00,  ...,  5.1953e-01,
           5.8105e-02,  9.5312e-01],
         ...,
         [ 1.2969e+00, -5.9375e-01,  1.4062e+00,  ..., -2.9541e-02,
           8.9844e-01,  2.5195e-01],
         [ 2.3438e+00, -1.2969e+00, -2.4219e-01,  ...,  3.6914e-01,
          -3.2959e-02, -3.6328e-01],
         [ 9.8047e-01, -1.1484e+00, -1.8750e-01,  ...,  7.5000e-01,
          -1.0625e+00, -4.8047e-01]],

        ...,

        [[ 3.2959e-03, -2.9755e-03, -6.4392e-03,  ...,  2.6562e-01,
           4.3945e-01,  1.4551e-01],
         [-3.2031e+00,  1.4766e+00, -1.5938e+00,  ..., -3.1641e-01,
          -1.6484e+00, -4.7852e-01],
         [ 2.7344e-02,  1.8262e-01, -1.2031e+00,  ..., -1.3359e+00,
          -2.0781e+00,  1.5938e+00],
         ...,
         [-2.4062e+00, -1.2812e+00, -9.6484e-01,  ..., -9.5703e-01,
          -1.3906e+00,  5.3125e-01],
         [-2.4219e-01, -1.3047e+00, -9.2578e-01,  ...,  8.5938e-01,
          -2.3438e+00,  1.2891e+00],
         [ 2.6719e+00, -1.6016e+00, -2.4219e+00,  ..., -1.0156e+00,
          -2.2188e+00, -1.3867e-01]],

        [[ 3.2959e-03, -2.9755e-03, -6.4392e-03,  ...,  2.6562e-01,
           4.3945e-01,  1.4551e-01],
         [-3.2031e+00,  1.4766e+00, -1.5938e+00,  ..., -3.1641e-01,
          -1.6484e+00, -4.7852e-01],
         [ 2.7344e-02,  1.8262e-01, -1.2031e+00,  ..., -1.3359e+00,
          -2.0781e+00,  1.5938e+00],
         ...,
         [-2.4062e+00, -1.2812e+00, -9.6484e-01,  ..., -9.5703e-01,
          -1.3906e+00,  5.3125e-01],
         [-2.4219e-01, -1.3047e+00, -9.2578e-01,  ...,  8.5938e-01,
          -2.3438e+00,  1.2891e+00],
         [ 2.6719e+00, -1.6016e+00, -2.4219e+00,  ..., -1.0156e+00,
          -2.2188e+00, -1.3867e-01]],

        [[ 3.2959e-03, -2.9755e-03, -6.4392e-03,  ...,  2.6562e-01,
           4.3945e-01,  1.4551e-01],
         [-3.2031e+00,  1.4766e+00, -1.5938e+00,  ..., -3.1641e-01,
          -1.6484e+00, -4.7852e-01],
         [ 2.7344e-02,  1.8262e-01, -1.2031e+00,  ..., -1.3359e+00,
          -2.0781e+00,  1.5938e+00],
         ...,
         [-2.4062e+00, -1.2812e+00, -9.6484e-01,  ..., -9.5703e-01,
          -1.3906e+00,  5.3125e-01],
         [-2.4219e-01, -1.3047e+00, -9.2578e-01,  ...,  8.5938e-01,
          -2.3438e+00,  1.2891e+00],
         [ 2.6719e+00, -1.6016e+00, -2.4219e+00,  ..., -1.0156e+00,
          -2.2188e+00, -1.3867e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-2.0264e-02,  7.4005e-04, -1.5869e-02,  ..., -1.4893e-02,
           4.4250e-03, -2.7771e-03],
         [ 5.3125e-01,  7.6172e-01,  1.8203e+00,  ..., -3.2812e-01,
          -6.6016e-01, -5.2734e-01],
         [ 1.6328e+00,  1.3438e+00,  4.5898e-01,  ...,  7.7344e-01,
          -7.0703e-01,  9.9219e-01],
         ...,
         [ 1.8457e-01, -9.2969e-01,  1.8594e+00,  ...,  2.0996e-01,
           9.4922e-01,  7.8516e-01],
         [-3.9844e-01, -9.1406e-01, -7.1094e-01,  ..., -5.8984e-01,
           7.3438e-01,  8.3496e-02],
         [ 4.0283e-03,  5.5469e-01, -7.4219e-01,  ..., -5.6250e-01,
           1.1406e+00, -1.2578e+00]],

        [[-2.0264e-02,  7.4005e-04, -1.5869e-02,  ..., -1.4893e-02,
           4.4250e-03, -2.7771e-03],
         [ 5.3125e-01,  7.6172e-01,  1.8203e+00,  ..., -3.2812e-01,
          -6.6016e-01, -5.2734e-01],
         [ 1.6328e+00,  1.3438e+00,  4.5898e-01,  ...,  7.7344e-01,
          -7.0703e-01,  9.9219e-01],
         ...,
         [ 1.8457e-01, -9.2969e-01,  1.8594e+00,  ...,  2.0996e-01,
           9.4922e-01,  7.8516e-01],
         [-3.9844e-01, -9.1406e-01, -7.1094e-01,  ..., -5.8984e-01,
           7.3438e-01,  8.3496e-02],
         [ 4.0283e-03,  5.5469e-01, -7.4219e-01,  ..., -5.6250e-01,
           1.1406e+00, -1.2578e+00]],

        [[-2.0264e-02,  7.4005e-04, -1.5869e-02,  ..., -1.4893e-02,
           4.4250e-03, -2.7771e-03],
         [ 5.3125e-01,  7.6172e-01,  1.8203e+00,  ..., -3.2812e-01,
          -6.6016e-01, -5.2734e-01],
         [ 1.6328e+00,  1.3438e+00,  4.5898e-01,  ...,  7.7344e-01,
          -7.0703e-01,  9.9219e-01],
         ...,
         [ 1.8457e-01, -9.2969e-01,  1.8594e+00,  ...,  2.0996e-01,
           9.4922e-01,  7.8516e-01],
         [-3.9844e-01, -9.1406e-01, -7.1094e-01,  ..., -5.8984e-01,
           7.3438e-01,  8.3496e-02],
         [ 4.0283e-03,  5.5469e-01, -7.4219e-01,  ..., -5.6250e-01,
           1.1406e+00, -1.2578e+00]],

        ...,

        [[ 6.8665e-04, -4.1809e-03, -1.6968e-02,  ...,  1.6235e-02,
           5.7068e-03,  1.3062e-02],
         [ 2.1191e-01, -4.6289e-01,  7.0703e-01,  ..., -3.9648e-01,
          -3.3984e-01,  4.9805e-01],
         [-1.2793e-01, -1.2500e-01,  1.2734e+00,  ...,  2.5391e-01,
           2.8906e-01, -1.5469e+00],
         ...,
         [-1.0234e+00,  1.1250e+00,  1.3047e+00,  ...,  9.9609e-01,
          -1.0938e+00, -8.7891e-01],
         [ 5.3516e-01, -2.3047e-01,  1.5312e+00,  ..., -4.1875e+00,
           3.9062e-01, -1.0625e+00],
         [ 6.4844e-01,  6.7969e-01,  5.7812e-01,  ..., -3.5156e-01,
          -7.2266e-01, -8.9062e-01]],

        [[ 6.8665e-04, -4.1809e-03, -1.6968e-02,  ...,  1.6235e-02,
           5.7068e-03,  1.3062e-02],
         [ 2.1191e-01, -4.6289e-01,  7.0703e-01,  ..., -3.9648e-01,
          -3.3984e-01,  4.9805e-01],
         [-1.2793e-01, -1.2500e-01,  1.2734e+00,  ...,  2.5391e-01,
           2.8906e-01, -1.5469e+00],
         ...,
         [-1.0234e+00,  1.1250e+00,  1.3047e+00,  ...,  9.9609e-01,
          -1.0938e+00, -8.7891e-01],
         [ 5.3516e-01, -2.3047e-01,  1.5312e+00,  ..., -4.1875e+00,
           3.9062e-01, -1.0625e+00],
         [ 6.4844e-01,  6.7969e-01,  5.7812e-01,  ..., -3.5156e-01,
          -7.2266e-01, -8.9062e-01]],

        [[ 6.8665e-04, -4.1809e-03, -1.6968e-02,  ...,  1.6235e-02,
           5.7068e-03,  1.3062e-02],
         [ 2.1191e-01, -4.6289e-01,  7.0703e-01,  ..., -3.9648e-01,
          -3.3984e-01,  4.9805e-01],
         [-1.2793e-01, -1.2500e-01,  1.2734e+00,  ...,  2.5391e-01,
           2.8906e-01, -1.5469e+00],
         ...,
         [-1.0234e+00,  1.1250e+00,  1.3047e+00,  ...,  9.9609e-01,
          -1.0938e+00, -8.7891e-01],
         [ 5.3516e-01, -2.3047e-01,  1.5312e+00,  ..., -4.1875e+00,
           3.9062e-01, -1.0625e+00],
         [ 6.4844e-01,  6.7969e-01,  5.7812e-01,  ..., -3.5156e-01,
          -7.2266e-01, -8.9062e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 6.1340e-03, -1.6724e-02,  5.8365e-04,  ...,  3.6328e-01,
          -4.5898e-02, -1.3086e-01],
         [ 1.1641e+00,  6.7969e-01, -1.1875e+00,  ..., -9.2188e-01,
          -5.7373e-02,  9.6094e-01],
         [ 1.6875e+00,  1.3438e+00, -6.3281e-01,  ..., -1.5078e+00,
          -4.3164e-01,  1.9375e+00],
         ...,
         [-6.6406e-02, -8.2422e-01,  1.5000e+00,  ..., -3.3203e-01,
          -1.2812e+00,  1.2031e+00],
         [ 7.6562e-01, -9.0625e-01, -5.5859e-01,  ..., -8.7109e-01,
           1.0352e-01,  1.8984e+00],
         [ 2.4062e+00, -6.9141e-01,  2.1484e-01,  ..., -1.1250e+00,
          -8.1641e-01,  1.0312e+00]],

        [[ 6.1340e-03, -1.6724e-02,  5.8365e-04,  ...,  3.6328e-01,
          -4.5898e-02, -1.3086e-01],
         [ 1.1641e+00,  6.7969e-01, -1.1875e+00,  ..., -9.2188e-01,
          -5.7373e-02,  9.6094e-01],
         [ 1.6875e+00,  1.3438e+00, -6.3281e-01,  ..., -1.5078e+00,
          -4.3164e-01,  1.9375e+00],
         ...,
         [-6.6406e-02, -8.2422e-01,  1.5000e+00,  ..., -3.3203e-01,
          -1.2812e+00,  1.2031e+00],
         [ 7.6562e-01, -9.0625e-01, -5.5859e-01,  ..., -8.7109e-01,
           1.0352e-01,  1.8984e+00],
         [ 2.4062e+00, -6.9141e-01,  2.1484e-01,  ..., -1.1250e+00,
          -8.1641e-01,  1.0312e+00]],

        [[ 6.1340e-03, -1.6724e-02,  5.8365e-04,  ...,  3.6328e-01,
          -4.5898e-02, -1.3086e-01],
         [ 1.1641e+00,  6.7969e-01, -1.1875e+00,  ..., -9.2188e-01,
          -5.7373e-02,  9.6094e-01],
         [ 1.6875e+00,  1.3438e+00, -6.3281e-01,  ..., -1.5078e+00,
          -4.3164e-01,  1.9375e+00],
         ...,
         [-6.6406e-02, -8.2422e-01,  1.5000e+00,  ..., -3.3203e-01,
          -1.2812e+00,  1.2031e+00],
         [ 7.6562e-01, -9.0625e-01, -5.5859e-01,  ..., -8.7109e-01,
           1.0352e-01,  1.8984e+00],
         [ 2.4062e+00, -6.9141e-01,  2.1484e-01,  ..., -1.1250e+00,
          -8.1641e-01,  1.0312e+00]],

        ...,

        [[ 1.5076e-02,  1.1780e-02,  4.5471e-03,  ..., -1.2012e-01,
          -1.2024e-02, -1.5391e+00],
         [ 3.3125e+00, -1.0156e-01, -9.9219e-01,  ..., -3.9062e-01,
           5.7812e-01,  6.6250e+00],
         [ 2.0781e+00,  1.8438e+00, -9.1406e-01,  ..., -1.8359e+00,
           4.6875e-02,  8.0625e+00],
         ...,
         [ 2.0781e+00,  1.6406e+00,  1.6875e+00,  ...,  6.8750e-01,
          -1.1953e+00,  7.0312e+00],
         [ 2.3438e+00,  1.4844e+00,  1.4062e+00,  ...,  1.2578e+00,
          -3.7109e-01,  7.4688e+00],
         [ 1.1250e+00, -1.1406e+00,  7.2656e-01,  ...,  7.4219e-01,
          -3.7695e-01,  7.5625e+00]],

        [[ 1.5076e-02,  1.1780e-02,  4.5471e-03,  ..., -1.2012e-01,
          -1.2024e-02, -1.5391e+00],
         [ 3.3125e+00, -1.0156e-01, -9.9219e-01,  ..., -3.9062e-01,
           5.7812e-01,  6.6250e+00],
         [ 2.0781e+00,  1.8438e+00, -9.1406e-01,  ..., -1.8359e+00,
           4.6875e-02,  8.0625e+00],
         ...,
         [ 2.0781e+00,  1.6406e+00,  1.6875e+00,  ...,  6.8750e-01,
          -1.1953e+00,  7.0312e+00],
         [ 2.3438e+00,  1.4844e+00,  1.4062e+00,  ...,  1.2578e+00,
          -3.7109e-01,  7.4688e+00],
         [ 1.1250e+00, -1.1406e+00,  7.2656e-01,  ...,  7.4219e-01,
          -3.7695e-01,  7.5625e+00]],

        [[ 1.5076e-02,  1.1780e-02,  4.5471e-03,  ..., -1.2012e-01,
          -1.2024e-02, -1.5391e+00],
         [ 3.3125e+00, -1.0156e-01, -9.9219e-01,  ..., -3.9062e-01,
           5.7812e-01,  6.6250e+00],
         [ 2.0781e+00,  1.8438e+00, -9.1406e-01,  ..., -1.8359e+00,
           4.6875e-02,  8.0625e+00],
         ...,
         [ 2.0781e+00,  1.6406e+00,  1.6875e+00,  ...,  6.8750e-01,
          -1.1953e+00,  7.0312e+00],
         [ 2.3438e+00,  1.4844e+00,  1.4062e+00,  ...,  1.2578e+00,
          -3.7109e-01,  7.4688e+00],
         [ 1.1250e+00, -1.1406e+00,  7.2656e-01,  ...,  7.4219e-01,
          -3.7695e-01,  7.5625e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 4.6387e-03,  8.9722e-03, -1.2268e-02,  ...,  2.8687e-02,
           2.5024e-03,  4.5776e-03],
         [-2.1094e-01,  3.4180e-02,  6.0938e-01,  ...,  7.0312e-01,
           3.2031e-01,  3.0078e-01],
         [ 7.0312e-01,  9.6094e-01, -2.3730e-01,  ..., -2.8125e-01,
           9.0625e-01, -6.3281e-01],
         ...,
         [-1.3750e+00,  1.5781e+00, -6.0303e-02,  ...,  2.2188e+00,
           5.4297e-01,  4.2236e-02],
         [ 5.4688e-01, -1.0938e+00,  1.6719e+00,  ..., -1.5469e+00,
          -2.9883e-01, -6.0547e-01],
         [ 2.5195e-01, -1.4219e+00,  7.6172e-01,  ..., -2.3906e+00,
           9.2578e-01,  7.5000e-01]],

        [[ 4.6387e-03,  8.9722e-03, -1.2268e-02,  ...,  2.8687e-02,
           2.5024e-03,  4.5776e-03],
         [-2.1094e-01,  3.4180e-02,  6.0938e-01,  ...,  7.0312e-01,
           3.2031e-01,  3.0078e-01],
         [ 7.0312e-01,  9.6094e-01, -2.3730e-01,  ..., -2.8125e-01,
           9.0625e-01, -6.3281e-01],
         ...,
         [-1.3750e+00,  1.5781e+00, -6.0303e-02,  ...,  2.2188e+00,
           5.4297e-01,  4.2236e-02],
         [ 5.4688e-01, -1.0938e+00,  1.6719e+00,  ..., -1.5469e+00,
          -2.9883e-01, -6.0547e-01],
         [ 2.5195e-01, -1.4219e+00,  7.6172e-01,  ..., -2.3906e+00,
           9.2578e-01,  7.5000e-01]],

        [[ 4.6387e-03,  8.9722e-03, -1.2268e-02,  ...,  2.8687e-02,
           2.5024e-03,  4.5776e-03],
         [-2.1094e-01,  3.4180e-02,  6.0938e-01,  ...,  7.0312e-01,
           3.2031e-01,  3.0078e-01],
         [ 7.0312e-01,  9.6094e-01, -2.3730e-01,  ..., -2.8125e-01,
           9.0625e-01, -6.3281e-01],
         ...,
         [-1.3750e+00,  1.5781e+00, -6.0303e-02,  ...,  2.2188e+00,
           5.4297e-01,  4.2236e-02],
         [ 5.4688e-01, -1.0938e+00,  1.6719e+00,  ..., -1.5469e+00,
          -2.9883e-01, -6.0547e-01],
         [ 2.5195e-01, -1.4219e+00,  7.6172e-01,  ..., -2.3906e+00,
           9.2578e-01,  7.5000e-01]],

        ...,

        [[-7.5684e-03,  1.5869e-02, -2.8198e-02,  ...,  1.9531e-02,
           6.4697e-03, -8.6975e-04],
         [ 1.1914e-01,  4.0283e-02,  1.3438e+00,  ..., -8.8672e-01,
           1.7031e+00, -2.0996e-01],
         [ 5.4297e-01, -9.4922e-01,  1.1484e+00,  ..., -2.2656e+00,
          -1.2266e+00, -2.0781e+00],
         ...,
         [ 2.4375e+00,  8.4375e-01,  4.7656e-01,  ..., -1.0703e+00,
           4.1602e-01, -4.1250e+00],
         [ 1.1562e+00, -7.8516e-01, -1.8359e-01,  ..., -1.0234e+00,
          -1.5391e+00, -1.7891e+00],
         [ 2.4844e+00, -3.0469e+00, -7.0801e-02,  ..., -2.3438e+00,
           7.8125e-01, -1.6250e+00]],

        [[-7.5684e-03,  1.5869e-02, -2.8198e-02,  ...,  1.9531e-02,
           6.4697e-03, -8.6975e-04],
         [ 1.1914e-01,  4.0283e-02,  1.3438e+00,  ..., -8.8672e-01,
           1.7031e+00, -2.0996e-01],
         [ 5.4297e-01, -9.4922e-01,  1.1484e+00,  ..., -2.2656e+00,
          -1.2266e+00, -2.0781e+00],
         ...,
         [ 2.4375e+00,  8.4375e-01,  4.7656e-01,  ..., -1.0703e+00,
           4.1602e-01, -4.1250e+00],
         [ 1.1562e+00, -7.8516e-01, -1.8359e-01,  ..., -1.0234e+00,
          -1.5391e+00, -1.7891e+00],
         [ 2.4844e+00, -3.0469e+00, -7.0801e-02,  ..., -2.3438e+00,
           7.8125e-01, -1.6250e+00]],

        [[-7.5684e-03,  1.5869e-02, -2.8198e-02,  ...,  1.9531e-02,
           6.4697e-03, -8.6975e-04],
         [ 1.1914e-01,  4.0283e-02,  1.3438e+00,  ..., -8.8672e-01,
           1.7031e+00, -2.0996e-01],
         [ 5.4297e-01, -9.4922e-01,  1.1484e+00,  ..., -2.2656e+00,
          -1.2266e+00, -2.0781e+00],
         ...,
         [ 2.4375e+00,  8.4375e-01,  4.7656e-01,  ..., -1.0703e+00,
           4.1602e-01, -4.1250e+00],
         [ 1.1562e+00, -7.8516e-01, -1.8359e-01,  ..., -1.0234e+00,
          -1.5391e+00, -1.7891e+00],
         [ 2.4844e+00, -3.0469e+00, -7.0801e-02,  ..., -2.3438e+00,
           7.8125e-01, -1.6250e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-1.7090e-02,  5.7983e-03,  1.4267e-03,  ..., -4.4189e-02,
          -2.8076e-03, -2.6001e-02],
         [ 2.1250e+00, -6.4844e-01, -5.2734e-01,  ...,  1.3359e+00,
          -8.1641e-01,  7.1484e-01],
         [-1.7383e-01,  3.6328e-01,  2.1406e+00,  ...,  1.6953e+00,
          -2.8750e+00, -4.3555e-01],
         ...,
         [ 5.0781e-01, -1.5312e+00,  1.0781e+00,  ..., -3.1562e+00,
           5.9375e-01,  4.2969e-01],
         [-4.1211e-01, -4.0234e-01,  1.1562e+00,  ..., -2.0156e+00,
          -1.7188e+00,  1.5391e+00],
         [-1.6172e+00, -1.4160e-01,  1.8750e+00,  ..., -1.6016e+00,
          -6.8750e-01, -1.5723e-01]],

        [[-1.7090e-02,  5.7983e-03,  1.4267e-03,  ..., -4.4189e-02,
          -2.8076e-03, -2.6001e-02],
         [ 2.1250e+00, -6.4844e-01, -5.2734e-01,  ...,  1.3359e+00,
          -8.1641e-01,  7.1484e-01],
         [-1.7383e-01,  3.6328e-01,  2.1406e+00,  ...,  1.6953e+00,
          -2.8750e+00, -4.3555e-01],
         ...,
         [ 5.0781e-01, -1.5312e+00,  1.0781e+00,  ..., -3.1562e+00,
           5.9375e-01,  4.2969e-01],
         [-4.1211e-01, -4.0234e-01,  1.1562e+00,  ..., -2.0156e+00,
          -1.7188e+00,  1.5391e+00],
         [-1.6172e+00, -1.4160e-01,  1.8750e+00,  ..., -1.6016e+00,
          -6.8750e-01, -1.5723e-01]],

        [[-1.7090e-02,  5.7983e-03,  1.4267e-03,  ..., -4.4189e-02,
          -2.8076e-03, -2.6001e-02],
         [ 2.1250e+00, -6.4844e-01, -5.2734e-01,  ...,  1.3359e+00,
          -8.1641e-01,  7.1484e-01],
         [-1.7383e-01,  3.6328e-01,  2.1406e+00,  ...,  1.6953e+00,
          -2.8750e+00, -4.3555e-01],
         ...,
         [ 5.0781e-01, -1.5312e+00,  1.0781e+00,  ..., -3.1562e+00,
           5.9375e-01,  4.2969e-01],
         [-4.1211e-01, -4.0234e-01,  1.1562e+00,  ..., -2.0156e+00,
          -1.7188e+00,  1.5391e+00],
         [-1.6172e+00, -1.4160e-01,  1.8750e+00,  ..., -1.6016e+00,
          -6.8750e-01, -1.5723e-01]],

        ...,

        [[-1.4221e-02,  2.1240e-02, -1.3855e-02,  ...,  1.4258e-01,
           8.3984e-02, -1.8906e+00],
         [ 2.9297e-02, -2.5391e-01, -9.0625e-01,  ...,  4.1016e-01,
          -1.7344e+00,  5.9375e+00],
         [ 4.1406e-01,  3.7842e-02,  1.7578e-01,  ..., -5.4297e-01,
          -2.0469e+00,  6.5312e+00],
         ...,
         [ 1.0156e+00,  8.3594e-01, -3.8477e-01,  ..., -8.2422e-01,
           1.6016e+00,  5.2188e+00],
         [ 1.0234e+00,  4.8633e-01,  1.6602e-01,  ..., -2.6562e+00,
          -6.7188e-01,  5.3438e+00],
         [ 1.0391e+00,  2.2266e-01, -4.4922e-01,  ..., -2.5156e+00,
          -1.1484e+00,  4.9688e+00]],

        [[-1.4221e-02,  2.1240e-02, -1.3855e-02,  ...,  1.4258e-01,
           8.3984e-02, -1.8906e+00],
         [ 2.9297e-02, -2.5391e-01, -9.0625e-01,  ...,  4.1016e-01,
          -1.7344e+00,  5.9375e+00],
         [ 4.1406e-01,  3.7842e-02,  1.7578e-01,  ..., -5.4297e-01,
          -2.0469e+00,  6.5312e+00],
         ...,
         [ 1.0156e+00,  8.3594e-01, -3.8477e-01,  ..., -8.2422e-01,
           1.6016e+00,  5.2188e+00],
         [ 1.0234e+00,  4.8633e-01,  1.6602e-01,  ..., -2.6562e+00,
          -6.7188e-01,  5.3438e+00],
         [ 1.0391e+00,  2.2266e-01, -4.4922e-01,  ..., -2.5156e+00,
          -1.1484e+00,  4.9688e+00]],

        [[-1.4221e-02,  2.1240e-02, -1.3855e-02,  ...,  1.4258e-01,
           8.3984e-02, -1.8906e+00],
         [ 2.9297e-02, -2.5391e-01, -9.0625e-01,  ...,  4.1016e-01,
          -1.7344e+00,  5.9375e+00],
         [ 4.1406e-01,  3.7842e-02,  1.7578e-01,  ..., -5.4297e-01,
          -2.0469e+00,  6.5312e+00],
         ...,
         [ 1.0156e+00,  8.3594e-01, -3.8477e-01,  ..., -8.2422e-01,
           1.6016e+00,  5.2188e+00],
         [ 1.0234e+00,  4.8633e-01,  1.6602e-01,  ..., -2.6562e+00,
          -6.7188e-01,  5.3438e+00],
         [ 1.0391e+00,  2.2266e-01, -4.4922e-01,  ..., -2.5156e+00,
          -1.1484e+00,  4.9688e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-4.0039e-02, -1.9409e-02, -2.3804e-02,  ...,  3.9795e-02,
           6.1035e-02, -2.9785e-02],
         [-1.0938e+00,  4.3750e-01, -2.5781e-01,  ..., -3.9258e-01,
           7.9590e-02,  1.3828e+00],
         [-7.3828e-01, -2.5391e-01,  6.8359e-01,  ..., -4.2578e-01,
           6.7969e-01,  4.2578e-01],
         ...,
         [-7.7344e-01,  1.6172e+00, -2.0469e+00,  ...,  4.6875e-01,
           1.2500e+00,  7.4219e-01],
         [ 3.7109e-02, -2.1875e-01, -3.4766e-01,  ..., -7.3828e-01,
          -6.5625e-01, -2.0625e+00],
         [ 7.8906e-01, -8.4375e-01, -3.8281e+00,  ...,  9.5312e-01,
           6.2891e-01, -2.9375e+00]],

        [[-4.0039e-02, -1.9409e-02, -2.3804e-02,  ...,  3.9795e-02,
           6.1035e-02, -2.9785e-02],
         [-1.0938e+00,  4.3750e-01, -2.5781e-01,  ..., -3.9258e-01,
           7.9590e-02,  1.3828e+00],
         [-7.3828e-01, -2.5391e-01,  6.8359e-01,  ..., -4.2578e-01,
           6.7969e-01,  4.2578e-01],
         ...,
         [-7.7344e-01,  1.6172e+00, -2.0469e+00,  ...,  4.6875e-01,
           1.2500e+00,  7.4219e-01],
         [ 3.7109e-02, -2.1875e-01, -3.4766e-01,  ..., -7.3828e-01,
          -6.5625e-01, -2.0625e+00],
         [ 7.8906e-01, -8.4375e-01, -3.8281e+00,  ...,  9.5312e-01,
           6.2891e-01, -2.9375e+00]],

        [[-4.0039e-02, -1.9409e-02, -2.3804e-02,  ...,  3.9795e-02,
           6.1035e-02, -2.9785e-02],
         [-1.0938e+00,  4.3750e-01, -2.5781e-01,  ..., -3.9258e-01,
           7.9590e-02,  1.3828e+00],
         [-7.3828e-01, -2.5391e-01,  6.8359e-01,  ..., -4.2578e-01,
           6.7969e-01,  4.2578e-01],
         ...,
         [-7.7344e-01,  1.6172e+00, -2.0469e+00,  ...,  4.6875e-01,
           1.2500e+00,  7.4219e-01],
         [ 3.7109e-02, -2.1875e-01, -3.4766e-01,  ..., -7.3828e-01,
          -6.5625e-01, -2.0625e+00],
         [ 7.8906e-01, -8.4375e-01, -3.8281e+00,  ...,  9.5312e-01,
           6.2891e-01, -2.9375e+00]],

        ...,

        [[ 6.0730e-03, -1.4343e-02, -1.2146e-02,  ...,  1.8692e-03,
          -4.9744e-03,  3.8757e-03],
         [-4.6875e-01,  2.4414e-01, -9.5312e-01,  ...,  5.5859e-01,
          -7.8125e-01,  4.4727e-01],
         [-8.5547e-01, -5.8984e-01,  4.5898e-01,  ..., -2.5625e+00,
          -5.1270e-02,  3.3008e-01],
         ...,
         [-3.9453e-01, -3.5352e-01,  7.9688e-01,  ..., -4.6680e-01,
          -1.5869e-02,  6.4844e-01],
         [ 2.4902e-02, -4.5898e-01, -9.2969e-01,  ..., -1.0078e+00,
          -5.7678e-03,  3.3594e-01],
         [-1.9062e+00, -7.9688e-01,  1.7822e-02,  ..., -1.8438e+00,
           6.4453e-02, -4.3750e-01]],

        [[ 6.0730e-03, -1.4343e-02, -1.2146e-02,  ...,  1.8692e-03,
          -4.9744e-03,  3.8757e-03],
         [-4.6875e-01,  2.4414e-01, -9.5312e-01,  ...,  5.5859e-01,
          -7.8125e-01,  4.4727e-01],
         [-8.5547e-01, -5.8984e-01,  4.5898e-01,  ..., -2.5625e+00,
          -5.1270e-02,  3.3008e-01],
         ...,
         [-3.9453e-01, -3.5352e-01,  7.9688e-01,  ..., -4.6680e-01,
          -1.5869e-02,  6.4844e-01],
         [ 2.4902e-02, -4.5898e-01, -9.2969e-01,  ..., -1.0078e+00,
          -5.7678e-03,  3.3594e-01],
         [-1.9062e+00, -7.9688e-01,  1.7822e-02,  ..., -1.8438e+00,
           6.4453e-02, -4.3750e-01]],

        [[ 6.0730e-03, -1.4343e-02, -1.2146e-02,  ...,  1.8692e-03,
          -4.9744e-03,  3.8757e-03],
         [-4.6875e-01,  2.4414e-01, -9.5312e-01,  ...,  5.5859e-01,
          -7.8125e-01,  4.4727e-01],
         [-8.5547e-01, -5.8984e-01,  4.5898e-01,  ..., -2.5625e+00,
          -5.1270e-02,  3.3008e-01],
         ...,
         [-3.9453e-01, -3.5352e-01,  7.9688e-01,  ..., -4.6680e-01,
          -1.5869e-02,  6.4844e-01],
         [ 2.4902e-02, -4.5898e-01, -9.2969e-01,  ..., -1.0078e+00,
          -5.7678e-03,  3.3594e-01],
         [-1.9062e+00, -7.9688e-01,  1.7822e-02,  ..., -1.8438e+00,
           6.4453e-02, -4.3750e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-0.0069,  0.0071, -0.0229,  ..., -0.2656, -0.3203,  0.5664],
         [ 1.9844, -0.9141,  1.1562,  ..., -2.0938, -0.4727,  0.7070],
         [ 1.4766, -0.3887,  0.7656,  ..., -2.5625, -0.1885, -0.8750],
         ...,
         [ 0.9492,  0.5312, -1.0234,  ..., -1.1406,  0.0277, -2.6562],
         [ 0.4082, -0.0649, -0.3672,  ..., -2.3750,  1.6641, -3.0000],
         [ 0.7344,  0.7656, -0.0039,  ..., -2.3750, -1.9297, -2.0469]],

        [[-0.0069,  0.0071, -0.0229,  ..., -0.2656, -0.3203,  0.5664],
         [ 1.9844, -0.9141,  1.1562,  ..., -2.0938, -0.4727,  0.7070],
         [ 1.4766, -0.3887,  0.7656,  ..., -2.5625, -0.1885, -0.8750],
         ...,
         [ 0.9492,  0.5312, -1.0234,  ..., -1.1406,  0.0277, -2.6562],
         [ 0.4082, -0.0649, -0.3672,  ..., -2.3750,  1.6641, -3.0000],
         [ 0.7344,  0.7656, -0.0039,  ..., -2.3750, -1.9297, -2.0469]],

        [[-0.0069,  0.0071, -0.0229,  ..., -0.2656, -0.3203,  0.5664],
         [ 1.9844, -0.9141,  1.1562,  ..., -2.0938, -0.4727,  0.7070],
         [ 1.4766, -0.3887,  0.7656,  ..., -2.5625, -0.1885, -0.8750],
         ...,
         [ 0.9492,  0.5312, -1.0234,  ..., -1.1406,  0.0277, -2.6562],
         [ 0.4082, -0.0649, -0.3672,  ..., -2.3750,  1.6641, -3.0000],
         [ 0.7344,  0.7656, -0.0039,  ..., -2.3750, -1.9297, -2.0469]],

        ...,

        [[-0.0102, -0.0106,  0.0152,  ..., -0.0306,  0.2354,  0.2734],
         [ 0.5859,  0.5820,  1.4844,  ...,  0.5117, -0.1738, -0.3418],
         [-0.3145,  2.4375,  0.7930,  ..., -0.0059,  0.1729, -0.4980],
         ...,
         [ 0.3984,  0.6250, -1.8047,  ..., -0.6641,  0.0165, -0.4453],
         [-0.3789,  0.5547,  0.1738,  ...,  1.2578, -1.5703,  1.8906],
         [-2.2188, -0.7695,  0.3887,  ..., -0.2637, -0.2061,  1.6406]],

        [[-0.0102, -0.0106,  0.0152,  ..., -0.0306,  0.2354,  0.2734],
         [ 0.5859,  0.5820,  1.4844,  ...,  0.5117, -0.1738, -0.3418],
         [-0.3145,  2.4375,  0.7930,  ..., -0.0059,  0.1729, -0.4980],
         ...,
         [ 0.3984,  0.6250, -1.8047,  ..., -0.6641,  0.0165, -0.4453],
         [-0.3789,  0.5547,  0.1738,  ...,  1.2578, -1.5703,  1.8906],
         [-2.2188, -0.7695,  0.3887,  ..., -0.2637, -0.2061,  1.6406]],

        [[-0.0102, -0.0106,  0.0152,  ..., -0.0306,  0.2354,  0.2734],
         [ 0.5859,  0.5820,  1.4844,  ...,  0.5117, -0.1738, -0.3418],
         [-0.3145,  2.4375,  0.7930,  ..., -0.0059,  0.1729, -0.4980],
         ...,
         [ 0.3984,  0.6250, -1.8047,  ..., -0.6641,  0.0165, -0.4453],
         [-0.3789,  0.5547,  0.1738,  ...,  1.2578, -1.5703,  1.8906],
         [-2.2188, -0.7695,  0.3887,  ..., -0.2637, -0.2061,  1.6406]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>), tensor([[[ 0.0066,  0.0254,  0.0081,  ...,  0.0171,  0.0244, -0.0117],
         [ 0.2949, -0.6914,  0.5703,  ...,  0.0679,  0.4258, -0.1797],
         [-0.5859, -1.0625, -0.5781,  ...,  0.0036,  0.9961, -1.3984],
         ...,
         [-0.2295,  0.6641, -0.1377,  ..., -1.1875,  0.7969,  0.6602],
         [-0.0615,  0.4805,  1.2031,  ..., -1.3047,  0.3047,  0.3750],
         [-0.4727,  0.3672,  1.0938,  ...,  0.9219, -0.0688, -0.8086]],

        [[ 0.0066,  0.0254,  0.0081,  ...,  0.0171,  0.0244, -0.0117],
         [ 0.2949, -0.6914,  0.5703,  ...,  0.0679,  0.4258, -0.1797],
         [-0.5859, -1.0625, -0.5781,  ...,  0.0036,  0.9961, -1.3984],
         ...,
         [-0.2295,  0.6641, -0.1377,  ..., -1.1875,  0.7969,  0.6602],
         [-0.0615,  0.4805,  1.2031,  ..., -1.3047,  0.3047,  0.3750],
         [-0.4727,  0.3672,  1.0938,  ...,  0.9219, -0.0688, -0.8086]],

        [[ 0.0066,  0.0254,  0.0081,  ...,  0.0171,  0.0244, -0.0117],
         [ 0.2949, -0.6914,  0.5703,  ...,  0.0679,  0.4258, -0.1797],
         [-0.5859, -1.0625, -0.5781,  ...,  0.0036,  0.9961, -1.3984],
         ...,
         [-0.2295,  0.6641, -0.1377,  ..., -1.1875,  0.7969,  0.6602],
         [-0.0615,  0.4805,  1.2031,  ..., -1.3047,  0.3047,  0.3750],
         [-0.4727,  0.3672,  1.0938,  ...,  0.9219, -0.0688, -0.8086]],

        ...,

        [[-0.0140, -0.0067,  0.0165,  ..., -0.0031, -0.0269,  0.0166],
         [ 0.7578,  0.6719, -0.7812,  ..., -0.4512, -1.5156, -0.4316],
         [ 1.1406,  1.6641, -0.1426,  ..., -0.7422,  0.7422, -2.5156],
         ...,
         [-1.5625, -1.3594,  0.0552,  ..., -0.0302, -0.5117,  0.4004],
         [-1.6406, -0.3691, -0.2852,  ...,  0.9336, -0.0918, -1.4609],
         [-1.5547, -0.0308, -1.0156,  ..., -0.6406, -0.0928,  0.2949]],

        [[-0.0140, -0.0067,  0.0165,  ..., -0.0031, -0.0269,  0.0166],
         [ 0.7578,  0.6719, -0.7812,  ..., -0.4512, -1.5156, -0.4316],
         [ 1.1406,  1.6641, -0.1426,  ..., -0.7422,  0.7422, -2.5156],
         ...,
         [-1.5625, -1.3594,  0.0552,  ..., -0.0302, -0.5117,  0.4004],
         [-1.6406, -0.3691, -0.2852,  ...,  0.9336, -0.0918, -1.4609],
         [-1.5547, -0.0308, -1.0156,  ..., -0.6406, -0.0928,  0.2949]],

        [[-0.0140, -0.0067,  0.0165,  ..., -0.0031, -0.0269,  0.0166],
         [ 0.7578,  0.6719, -0.7812,  ..., -0.4512, -1.5156, -0.4316],
         [ 1.1406,  1.6641, -0.1426,  ..., -0.7422,  0.7422, -2.5156],
         ...,
         [-1.5625, -1.3594,  0.0552,  ..., -0.0302, -0.5117,  0.4004],
         [-1.6406, -0.3691, -0.2852,  ...,  0.9336, -0.0918, -1.4609],
         [-1.5547, -0.0308, -1.0156,  ..., -0.6406, -0.0928,  0.2949]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[ 4.0894e-03,  4.7607e-03, -2.6245e-03,  ..., -6.8848e-02,
           4.7852e-02,  1.8262e-01],
         [-3.7500e+00,  1.0625e+00, -1.1094e+00,  ..., -1.0303e-01,
           3.3203e-01, -1.0547e+00],
         [-1.4766e+00,  3.1562e+00, -1.0625e+00,  ..., -1.7578e-02,
          -9.4531e-01, -2.1250e+00],
         ...,
         [-2.8438e+00,  1.2266e+00,  1.4688e+00,  ...,  6.7578e-01,
           1.4141e+00, -8.2812e-01],
         [-4.9609e-01,  8.3984e-01,  6.4062e-01,  ...,  6.5234e-01,
          -1.3828e+00,  1.9062e+00],
         [ 1.1094e+00, -3.2812e-01, -3.3008e-01,  ...,  3.6865e-02,
          -7.6953e-01,  1.6406e-01]],

        [[ 4.0894e-03,  4.7607e-03, -2.6245e-03,  ..., -6.8848e-02,
           4.7852e-02,  1.8262e-01],
         [-3.7500e+00,  1.0625e+00, -1.1094e+00,  ..., -1.0303e-01,
           3.3203e-01, -1.0547e+00],
         [-1.4766e+00,  3.1562e+00, -1.0625e+00,  ..., -1.7578e-02,
          -9.4531e-01, -2.1250e+00],
         ...,
         [-2.8438e+00,  1.2266e+00,  1.4688e+00,  ...,  6.7578e-01,
           1.4141e+00, -8.2812e-01],
         [-4.9609e-01,  8.3984e-01,  6.4062e-01,  ...,  6.5234e-01,
          -1.3828e+00,  1.9062e+00],
         [ 1.1094e+00, -3.2812e-01, -3.3008e-01,  ...,  3.6865e-02,
          -7.6953e-01,  1.6406e-01]],

        [[ 4.0894e-03,  4.7607e-03, -2.6245e-03,  ..., -6.8848e-02,
           4.7852e-02,  1.8262e-01],
         [-3.7500e+00,  1.0625e+00, -1.1094e+00,  ..., -1.0303e-01,
           3.3203e-01, -1.0547e+00],
         [-1.4766e+00,  3.1562e+00, -1.0625e+00,  ..., -1.7578e-02,
          -9.4531e-01, -2.1250e+00],
         ...,
         [-2.8438e+00,  1.2266e+00,  1.4688e+00,  ...,  6.7578e-01,
           1.4141e+00, -8.2812e-01],
         [-4.9609e-01,  8.3984e-01,  6.4062e-01,  ...,  6.5234e-01,
          -1.3828e+00,  1.9062e+00],
         [ 1.1094e+00, -3.2812e-01, -3.3008e-01,  ...,  3.6865e-02,
          -7.6953e-01,  1.6406e-01]],

        ...,

        [[ 8.1787e-03,  1.9165e-02, -1.0620e-02,  ...,  1.6699e-01,
          -1.3965e-01, -1.6016e-01],
         [-1.8906e+00,  1.1406e+00,  4.0039e-01,  ...,  1.2305e-01,
           4.7852e-01, -2.8906e-01],
         [ 5.7031e-01, -1.0469e+00, -5.1172e-01,  ..., -6.7578e-01,
          -2.7734e-01, -1.0781e+00],
         ...,
         [-1.7773e-01, -3.9453e-01, -1.3906e+00,  ...,  5.5078e-01,
          -2.9688e-01, -6.3965e-02],
         [-1.6016e-01, -5.8203e-01, -4.2773e-01,  ...,  9.1016e-01,
          -7.6172e-01,  9.6875e-01],
         [ 1.6484e+00, -1.0312e+00, -1.4922e+00,  ...,  3.8281e-01,
           1.3828e+00, -8.8379e-02]],

        [[ 8.1787e-03,  1.9165e-02, -1.0620e-02,  ...,  1.6699e-01,
          -1.3965e-01, -1.6016e-01],
         [-1.8906e+00,  1.1406e+00,  4.0039e-01,  ...,  1.2305e-01,
           4.7852e-01, -2.8906e-01],
         [ 5.7031e-01, -1.0469e+00, -5.1172e-01,  ..., -6.7578e-01,
          -2.7734e-01, -1.0781e+00],
         ...,
         [-1.7773e-01, -3.9453e-01, -1.3906e+00,  ...,  5.5078e-01,
          -2.9688e-01, -6.3965e-02],
         [-1.6016e-01, -5.8203e-01, -4.2773e-01,  ...,  9.1016e-01,
          -7.6172e-01,  9.6875e-01],
         [ 1.6484e+00, -1.0312e+00, -1.4922e+00,  ...,  3.8281e-01,
           1.3828e+00, -8.8379e-02]],

        [[ 8.1787e-03,  1.9165e-02, -1.0620e-02,  ...,  1.6699e-01,
          -1.3965e-01, -1.6016e-01],
         [-1.8906e+00,  1.1406e+00,  4.0039e-01,  ...,  1.2305e-01,
           4.7852e-01, -2.8906e-01],
         [ 5.7031e-01, -1.0469e+00, -5.1172e-01,  ..., -6.7578e-01,
          -2.7734e-01, -1.0781e+00],
         ...,
         [-1.7773e-01, -3.9453e-01, -1.3906e+00,  ...,  5.5078e-01,
          -2.9688e-01, -6.3965e-02],
         [-1.6016e-01, -5.8203e-01, -4.2773e-01,  ...,  9.1016e-01,
          -7.6172e-01,  9.6875e-01],
         [ 1.6484e+00, -1.0312e+00, -1.4922e+00,  ...,  3.8281e-01,
           1.3828e+00, -8.8379e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-0.0254, -0.0189,  0.0183,  ..., -0.0042,  0.1099, -0.0179],
         [-0.3301,  0.7344,  1.5312,  ..., -0.8516, -0.1631,  0.8555],
         [ 0.0043,  0.4824, -0.2041,  ..., -0.1660,  0.5938,  0.3242],
         ...,
         [-0.1543,  0.1572,  0.0144,  ..., -1.6328,  0.8789,  1.7969],
         [-0.2715, -0.2930,  0.0708,  ..., -1.1562, -1.2656,  1.4141],
         [ 0.0659,  0.4883,  0.2578,  ...,  0.1475, -0.5508, -0.7031]],

        [[-0.0254, -0.0189,  0.0183,  ..., -0.0042,  0.1099, -0.0179],
         [-0.3301,  0.7344,  1.5312,  ..., -0.8516, -0.1631,  0.8555],
         [ 0.0043,  0.4824, -0.2041,  ..., -0.1660,  0.5938,  0.3242],
         ...,
         [-0.1543,  0.1572,  0.0144,  ..., -1.6328,  0.8789,  1.7969],
         [-0.2715, -0.2930,  0.0708,  ..., -1.1562, -1.2656,  1.4141],
         [ 0.0659,  0.4883,  0.2578,  ...,  0.1475, -0.5508, -0.7031]],

        [[-0.0254, -0.0189,  0.0183,  ..., -0.0042,  0.1099, -0.0179],
         [-0.3301,  0.7344,  1.5312,  ..., -0.8516, -0.1631,  0.8555],
         [ 0.0043,  0.4824, -0.2041,  ..., -0.1660,  0.5938,  0.3242],
         ...,
         [-0.1543,  0.1572,  0.0144,  ..., -1.6328,  0.8789,  1.7969],
         [-0.2715, -0.2930,  0.0708,  ..., -1.1562, -1.2656,  1.4141],
         [ 0.0659,  0.4883,  0.2578,  ...,  0.1475, -0.5508, -0.7031]],

        ...,

        [[-0.0245,  0.0131,  0.0088,  ...,  0.0496, -0.0391, -0.0024],
         [-0.2373, -0.2256,  0.3340,  ..., -0.0225, -0.1021,  0.9844],
         [-0.2637, -1.0156, -0.0444,  ..., -0.5508, -0.2002, -0.1064],
         ...,
         [ 0.4883,  0.1514,  1.5000,  ..., -2.3125,  0.5117,  0.1934],
         [ 0.4023, -0.1167, -0.0220,  ..., -1.4219,  0.4004,  0.7227],
         [ 0.4785,  0.1875, -0.4512,  ...,  0.1953, -0.0601, -0.0166]],

        [[-0.0245,  0.0131,  0.0088,  ...,  0.0496, -0.0391, -0.0024],
         [-0.2373, -0.2256,  0.3340,  ..., -0.0225, -0.1021,  0.9844],
         [-0.2637, -1.0156, -0.0444,  ..., -0.5508, -0.2002, -0.1064],
         ...,
         [ 0.4883,  0.1514,  1.5000,  ..., -2.3125,  0.5117,  0.1934],
         [ 0.4023, -0.1167, -0.0220,  ..., -1.4219,  0.4004,  0.7227],
         [ 0.4785,  0.1875, -0.4512,  ...,  0.1953, -0.0601, -0.0166]],

        [[-0.0245,  0.0131,  0.0088,  ...,  0.0496, -0.0391, -0.0024],
         [-0.2373, -0.2256,  0.3340,  ..., -0.0225, -0.1021,  0.9844],
         [-0.2637, -1.0156, -0.0444,  ..., -0.5508, -0.2002, -0.1064],
         ...,
         [ 0.4883,  0.1514,  1.5000,  ..., -2.3125,  0.5117,  0.1934],
         [ 0.4023, -0.1167, -0.0220,  ..., -1.4219,  0.4004,  0.7227],
         [ 0.4785,  0.1875, -0.4512,  ...,  0.1953, -0.0601, -0.0166]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[ 2.4048e-02,  1.5381e-02, -2.6489e-02,  ...,  1.5442e-02,
           1.6094e+00, -2.5586e-01],
         [-1.9297e+00,  1.5312e+00,  3.8281e-01,  ..., -1.4844e+00,
          -4.0000e+00,  7.5000e-01],
         [-4.2188e-01,  1.6562e+00,  1.6953e+00,  ...,  1.9453e+00,
          -5.9062e+00,  1.8555e-01],
         ...,
         [-7.4219e-01,  1.7188e-01, -1.2188e+00,  ..., -3.7031e+00,
          -5.1875e+00,  6.4062e-01],
         [-1.7188e-01, -8.3984e-01, -1.4062e+00,  ..., -1.6328e+00,
          -6.4688e+00,  1.0625e+00],
         [ 1.9922e+00, -1.2422e+00, -1.0391e+00,  ..., -1.9141e+00,
          -4.5312e+00,  3.3984e-01]],

        [[ 2.4048e-02,  1.5381e-02, -2.6489e-02,  ...,  1.5442e-02,
           1.6094e+00, -2.5586e-01],
         [-1.9297e+00,  1.5312e+00,  3.8281e-01,  ..., -1.4844e+00,
          -4.0000e+00,  7.5000e-01],
         [-4.2188e-01,  1.6562e+00,  1.6953e+00,  ...,  1.9453e+00,
          -5.9062e+00,  1.8555e-01],
         ...,
         [-7.4219e-01,  1.7188e-01, -1.2188e+00,  ..., -3.7031e+00,
          -5.1875e+00,  6.4062e-01],
         [-1.7188e-01, -8.3984e-01, -1.4062e+00,  ..., -1.6328e+00,
          -6.4688e+00,  1.0625e+00],
         [ 1.9922e+00, -1.2422e+00, -1.0391e+00,  ..., -1.9141e+00,
          -4.5312e+00,  3.3984e-01]],

        [[ 2.4048e-02,  1.5381e-02, -2.6489e-02,  ...,  1.5442e-02,
           1.6094e+00, -2.5586e-01],
         [-1.9297e+00,  1.5312e+00,  3.8281e-01,  ..., -1.4844e+00,
          -4.0000e+00,  7.5000e-01],
         [-4.2188e-01,  1.6562e+00,  1.6953e+00,  ...,  1.9453e+00,
          -5.9062e+00,  1.8555e-01],
         ...,
         [-7.4219e-01,  1.7188e-01, -1.2188e+00,  ..., -3.7031e+00,
          -5.1875e+00,  6.4062e-01],
         [-1.7188e-01, -8.3984e-01, -1.4062e+00,  ..., -1.6328e+00,
          -6.4688e+00,  1.0625e+00],
         [ 1.9922e+00, -1.2422e+00, -1.0391e+00,  ..., -1.9141e+00,
          -4.5312e+00,  3.3984e-01]],

        ...,

        [[-1.1475e-02,  4.6997e-03, -4.4861e-03,  ...,  4.9438e-03,
           5.2002e-02, -1.9844e+00],
         [-9.8047e-01, -8.5547e-01,  8.7500e-01,  ...,  9.5312e-01,
           1.0681e-02,  5.5000e+00],
         [-2.0312e+00, -2.3594e+00, -6.3672e-01,  ..., -8.6328e-01,
          -3.6719e-01,  7.7500e+00],
         ...,
         [-1.4922e+00, -1.3125e+00, -1.1328e+00,  ...,  7.7344e-01,
           7.3828e-01,  4.9375e+00],
         [-4.8633e-01, -8.9453e-01, -1.8359e-01,  ...,  1.5859e+00,
           1.5234e+00,  6.5938e+00],
         [-4.5117e-01, -6.8359e-01, -9.2578e-01,  ...,  2.7031e+00,
           3.2812e-01,  5.8438e+00]],

        [[-1.1475e-02,  4.6997e-03, -4.4861e-03,  ...,  4.9438e-03,
           5.2002e-02, -1.9844e+00],
         [-9.8047e-01, -8.5547e-01,  8.7500e-01,  ...,  9.5312e-01,
           1.0681e-02,  5.5000e+00],
         [-2.0312e+00, -2.3594e+00, -6.3672e-01,  ..., -8.6328e-01,
          -3.6719e-01,  7.7500e+00],
         ...,
         [-1.4922e+00, -1.3125e+00, -1.1328e+00,  ...,  7.7344e-01,
           7.3828e-01,  4.9375e+00],
         [-4.8633e-01, -8.9453e-01, -1.8359e-01,  ...,  1.5859e+00,
           1.5234e+00,  6.5938e+00],
         [-4.5117e-01, -6.8359e-01, -9.2578e-01,  ...,  2.7031e+00,
           3.2812e-01,  5.8438e+00]],

        [[-1.1475e-02,  4.6997e-03, -4.4861e-03,  ...,  4.9438e-03,
           5.2002e-02, -1.9844e+00],
         [-9.8047e-01, -8.5547e-01,  8.7500e-01,  ...,  9.5312e-01,
           1.0681e-02,  5.5000e+00],
         [-2.0312e+00, -2.3594e+00, -6.3672e-01,  ..., -8.6328e-01,
          -3.6719e-01,  7.7500e+00],
         ...,
         [-1.4922e+00, -1.3125e+00, -1.1328e+00,  ...,  7.7344e-01,
           7.3828e-01,  4.9375e+00],
         [-4.8633e-01, -8.9453e-01, -1.8359e-01,  ...,  1.5859e+00,
           1.5234e+00,  6.5938e+00],
         [-4.5117e-01, -6.8359e-01, -9.2578e-01,  ...,  2.7031e+00,
           3.2812e-01,  5.8438e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 9.2163e-03,  1.6724e-02,  7.9346e-03,  ..., -2.8442e-02,
          -1.2573e-02,  1.4038e-02],
         [ 3.3984e-01,  3.1836e-01,  3.7305e-01,  ...,  4.7852e-01,
          -5.5078e-01,  4.5703e-01],
         [-2.9883e-01, -6.1328e-01,  1.1094e+00,  ..., -9.4531e-01,
           6.4844e-01, -4.6680e-01],
         ...,
         [ 4.6680e-01,  5.5078e-01, -1.4922e+00,  ...,  6.7969e-01,
           1.6406e+00,  3.8867e-01],
         [ 1.1797e+00,  1.1250e+00, -2.2168e-01,  ...,  1.5234e+00,
           3.6719e+00, -3.1445e-01],
         [ 1.0781e+00,  1.3047e+00,  8.6719e-01,  ...,  1.2344e+00,
           1.0469e+00, -8.6719e-01]],

        [[ 9.2163e-03,  1.6724e-02,  7.9346e-03,  ..., -2.8442e-02,
          -1.2573e-02,  1.4038e-02],
         [ 3.3984e-01,  3.1836e-01,  3.7305e-01,  ...,  4.7852e-01,
          -5.5078e-01,  4.5703e-01],
         [-2.9883e-01, -6.1328e-01,  1.1094e+00,  ..., -9.4531e-01,
           6.4844e-01, -4.6680e-01],
         ...,
         [ 4.6680e-01,  5.5078e-01, -1.4922e+00,  ...,  6.7969e-01,
           1.6406e+00,  3.8867e-01],
         [ 1.1797e+00,  1.1250e+00, -2.2168e-01,  ...,  1.5234e+00,
           3.6719e+00, -3.1445e-01],
         [ 1.0781e+00,  1.3047e+00,  8.6719e-01,  ...,  1.2344e+00,
           1.0469e+00, -8.6719e-01]],

        [[ 9.2163e-03,  1.6724e-02,  7.9346e-03,  ..., -2.8442e-02,
          -1.2573e-02,  1.4038e-02],
         [ 3.3984e-01,  3.1836e-01,  3.7305e-01,  ...,  4.7852e-01,
          -5.5078e-01,  4.5703e-01],
         [-2.9883e-01, -6.1328e-01,  1.1094e+00,  ..., -9.4531e-01,
           6.4844e-01, -4.6680e-01],
         ...,
         [ 4.6680e-01,  5.5078e-01, -1.4922e+00,  ...,  6.7969e-01,
           1.6406e+00,  3.8867e-01],
         [ 1.1797e+00,  1.1250e+00, -2.2168e-01,  ...,  1.5234e+00,
           3.6719e+00, -3.1445e-01],
         [ 1.0781e+00,  1.3047e+00,  8.6719e-01,  ...,  1.2344e+00,
           1.0469e+00, -8.6719e-01]],

        ...,

        [[ 1.2817e-02,  1.9897e-02, -2.3682e-02,  ...,  7.8201e-05,
          -3.9795e-02,  1.3306e-02],
         [ 1.4922e+00, -5.1172e-01, -7.1094e-01,  ...,  1.6016e+00,
          -3.3594e-01,  2.4707e-01],
         [ 2.9297e-01, -3.4062e+00, -3.9648e-01,  ..., -7.1094e-01,
          -1.3828e+00, -2.6875e+00],
         ...,
         [-4.9414e-01, -1.2305e-01, -1.0791e-01,  ...,  3.7695e-01,
          -1.5078e+00,  1.8594e+00],
         [-6.6016e-01, -1.2266e+00,  8.8281e-01,  ..., -1.4609e+00,
           3.7598e-02, -1.5234e+00],
         [ 1.1250e+00,  1.8359e+00,  9.0625e-01,  ..., -1.4922e+00,
          -1.9922e-01,  7.0312e-01]],

        [[ 1.2817e-02,  1.9897e-02, -2.3682e-02,  ...,  7.8201e-05,
          -3.9795e-02,  1.3306e-02],
         [ 1.4922e+00, -5.1172e-01, -7.1094e-01,  ...,  1.6016e+00,
          -3.3594e-01,  2.4707e-01],
         [ 2.9297e-01, -3.4062e+00, -3.9648e-01,  ..., -7.1094e-01,
          -1.3828e+00, -2.6875e+00],
         ...,
         [-4.9414e-01, -1.2305e-01, -1.0791e-01,  ...,  3.7695e-01,
          -1.5078e+00,  1.8594e+00],
         [-6.6016e-01, -1.2266e+00,  8.8281e-01,  ..., -1.4609e+00,
           3.7598e-02, -1.5234e+00],
         [ 1.1250e+00,  1.8359e+00,  9.0625e-01,  ..., -1.4922e+00,
          -1.9922e-01,  7.0312e-01]],

        [[ 1.2817e-02,  1.9897e-02, -2.3682e-02,  ...,  7.8201e-05,
          -3.9795e-02,  1.3306e-02],
         [ 1.4922e+00, -5.1172e-01, -7.1094e-01,  ...,  1.6016e+00,
          -3.3594e-01,  2.4707e-01],
         [ 2.9297e-01, -3.4062e+00, -3.9648e-01,  ..., -7.1094e-01,
          -1.3828e+00, -2.6875e+00],
         ...,
         [-4.9414e-01, -1.2305e-01, -1.0791e-01,  ...,  3.7695e-01,
          -1.5078e+00,  1.8594e+00],
         [-6.6016e-01, -1.2266e+00,  8.8281e-01,  ..., -1.4609e+00,
           3.7598e-02, -1.5234e+00],
         [ 1.1250e+00,  1.8359e+00,  9.0625e-01,  ..., -1.4922e+00,
          -1.9922e-01,  7.0312e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-0.2656, -0.6289, -0.9727,  ...,  0.3301,  0.3516, -1.3516],
         [ 0.3125,  0.0537,  0.1123,  ...,  0.2930,  0.2480,  0.2793],
         [-0.2490, -0.2793, -0.0850,  ...,  0.4082, -2.8750, -1.3750],
         ...,
         [ 0.5781,  1.2500,  1.0391,  ...,  4.0625, -3.5156, -2.3125],
         [ 0.6719,  0.7148,  0.9531,  ...,  1.0469, -1.4766, -0.2793],
         [-0.4727, -0.0156,  0.3750,  ...,  2.2344, -2.3281, -0.0830]],

        [[-0.2656, -0.6289, -0.9727,  ...,  0.3301,  0.3516, -1.3516],
         [ 0.3125,  0.0537,  0.1123,  ...,  0.2930,  0.2480,  0.2793],
         [-0.2490, -0.2793, -0.0850,  ...,  0.4082, -2.8750, -1.3750],
         ...,
         [ 0.5781,  1.2500,  1.0391,  ...,  4.0625, -3.5156, -2.3125],
         [ 0.6719,  0.7148,  0.9531,  ...,  1.0469, -1.4766, -0.2793],
         [-0.4727, -0.0156,  0.3750,  ...,  2.2344, -2.3281, -0.0830]],

        [[-0.2656, -0.6289, -0.9727,  ...,  0.3301,  0.3516, -1.3516],
         [ 0.3125,  0.0537,  0.1123,  ...,  0.2930,  0.2480,  0.2793],
         [-0.2490, -0.2793, -0.0850,  ...,  0.4082, -2.8750, -1.3750],
         ...,
         [ 0.5781,  1.2500,  1.0391,  ...,  4.0625, -3.5156, -2.3125],
         [ 0.6719,  0.7148,  0.9531,  ...,  1.0469, -1.4766, -0.2793],
         [-0.4727, -0.0156,  0.3750,  ...,  2.2344, -2.3281, -0.0830]],

        ...,

        [[-0.0134,  0.9180, -0.1533,  ...,  2.3906, -3.7188,  1.6250],
         [ 0.9141, -0.7188, -0.4258,  ..., -0.6445, -0.3828,  1.3047],
         [-0.3672,  0.5781,  0.1060,  ...,  0.2412,  0.2207, -0.4941],
         ...,
         [ 0.4062,  0.6875,  1.2500,  ..., -3.4531,  0.4219, -2.7656],
         [-0.8125,  0.6562,  1.0156,  ..., -2.9688, -4.0938, -0.3340],
         [-1.1797,  0.5312,  0.7070,  ..., -3.2500,  0.8203, -2.3281]],

        [[-0.0134,  0.9180, -0.1533,  ...,  2.3906, -3.7188,  1.6250],
         [ 0.9141, -0.7188, -0.4258,  ..., -0.6445, -0.3828,  1.3047],
         [-0.3672,  0.5781,  0.1060,  ...,  0.2412,  0.2207, -0.4941],
         ...,
         [ 0.4062,  0.6875,  1.2500,  ..., -3.4531,  0.4219, -2.7656],
         [-0.8125,  0.6562,  1.0156,  ..., -2.9688, -4.0938, -0.3340],
         [-1.1797,  0.5312,  0.7070,  ..., -3.2500,  0.8203, -2.3281]],

        [[-0.0134,  0.9180, -0.1533,  ...,  2.3906, -3.7188,  1.6250],
         [ 0.9141, -0.7188, -0.4258,  ..., -0.6445, -0.3828,  1.3047],
         [-0.3672,  0.5781,  0.1060,  ...,  0.2412,  0.2207, -0.4941],
         ...,
         [ 0.4062,  0.6875,  1.2500,  ..., -3.4531,  0.4219, -2.7656],
         [-0.8125,  0.6562,  1.0156,  ..., -2.9688, -4.0938, -0.3340],
         [-1.1797,  0.5312,  0.7070,  ..., -3.2500,  0.8203, -2.3281]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>), tensor([[[ 0.4805, -0.1099, -0.3379,  ..., -0.0300,  0.0986,  0.2188],
         [-0.0181, -0.2637, -0.2461,  ..., -0.2080, -0.1777, -0.1348],
         [-0.4180, -0.0654,  0.0275,  ...,  0.0275, -0.0159,  0.0723],
         ...,
         [ 0.2021,  0.0786,  0.4375,  ..., -0.6445,  0.6992, -0.8477],
         [ 0.2148,  0.3418, -0.0889,  ..., -0.0840,  0.3711, -0.3125],
         [ 0.4238,  0.2969,  0.4492,  ...,  0.4297,  0.6289, -0.4531]],

        [[ 0.4805, -0.1099, -0.3379,  ..., -0.0300,  0.0986,  0.2188],
         [-0.0181, -0.2637, -0.2461,  ..., -0.2080, -0.1777, -0.1348],
         [-0.4180, -0.0654,  0.0275,  ...,  0.0275, -0.0159,  0.0723],
         ...,
         [ 0.2021,  0.0786,  0.4375,  ..., -0.6445,  0.6992, -0.8477],
         [ 0.2148,  0.3418, -0.0889,  ..., -0.0840,  0.3711, -0.3125],
         [ 0.4238,  0.2969,  0.4492,  ...,  0.4297,  0.6289, -0.4531]],

        [[ 0.4805, -0.1099, -0.3379,  ..., -0.0300,  0.0986,  0.2188],
         [-0.0181, -0.2637, -0.2461,  ..., -0.2080, -0.1777, -0.1348],
         [-0.4180, -0.0654,  0.0275,  ...,  0.0275, -0.0159,  0.0723],
         ...,
         [ 0.2021,  0.0786,  0.4375,  ..., -0.6445,  0.6992, -0.8477],
         [ 0.2148,  0.3418, -0.0889,  ..., -0.0840,  0.3711, -0.3125],
         [ 0.4238,  0.2969,  0.4492,  ...,  0.4297,  0.6289, -0.4531]],

        ...,

        [[ 0.3242, -0.4062, -0.1709,  ...,  0.2695, -0.5391, -0.3926],
         [ 0.2695, -0.2197,  0.2227,  ..., -0.2734, -0.5898,  0.0435],
         [ 0.8477, -0.6094, -0.0427,  ..., -0.1914, -0.5508, -0.7031],
         ...,
         [-1.0234,  1.2969, -0.0723,  ...,  0.1797,  0.4980,  0.5195],
         [-1.3984,  1.1719, -0.3809,  ...,  0.3809,  0.0420,  0.3145],
         [-0.1367, -0.0142, -0.7539,  ...,  0.3477, -0.4102,  0.3594]],

        [[ 0.3242, -0.4062, -0.1709,  ...,  0.2695, -0.5391, -0.3926],
         [ 0.2695, -0.2197,  0.2227,  ..., -0.2734, -0.5898,  0.0435],
         [ 0.8477, -0.6094, -0.0427,  ..., -0.1914, -0.5508, -0.7031],
         ...,
         [-1.0234,  1.2969, -0.0723,  ...,  0.1797,  0.4980,  0.5195],
         [-1.3984,  1.1719, -0.3809,  ...,  0.3809,  0.0420,  0.3145],
         [-0.1367, -0.0142, -0.7539,  ...,  0.3477, -0.4102,  0.3594]],

        [[ 0.3242, -0.4062, -0.1709,  ...,  0.2695, -0.5391, -0.3926],
         [ 0.2695, -0.2197,  0.2227,  ..., -0.2734, -0.5898,  0.0435],
         [ 0.8477, -0.6094, -0.0427,  ..., -0.1914, -0.5508, -0.7031],
         ...,
         [-1.0234,  1.2969, -0.0723,  ...,  0.1797,  0.4980,  0.5195],
         [-1.3984,  1.1719, -0.3809,  ...,  0.3809,  0.0420,  0.3145],
         [-0.1367, -0.0142, -0.7539,  ...,  0.3477, -0.4102,  0.3594]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)))}, logits=tensor([[[ -8.7500, -10.6250, -11.9375,  ..., -10.8125, -12.0000,  -9.5000],
         [ -9.1250, -10.5000, -12.1250,  ...,  -8.1250,  -9.4375,  -7.1562],
         [-16.1250, -22.2500, -24.0000,  ..., -19.1250, -19.5000, -18.6250],
         ...,
         [-14.1250, -16.0000, -20.2500,  ..., -12.3750, -17.7500, -11.3125],
         [-13.1875, -14.3125, -18.8750,  ..., -15.3125, -17.6250, -14.6250],
         [-13.5000, -14.4375, -18.0000,  ..., -12.5000, -16.5000,  -8.8750]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>), past_key_values=((tensor([[[ 0.5547, -1.4062, -0.3340,  ...,  0.1514, -0.9297, -1.0312],
         [-0.4785, -1.4062, -0.4316,  ...,  0.0583, -1.4062,  0.1992],
         [-2.2500, -2.3281,  0.4707,  ..., -1.1562, -0.1973,  0.9883],
         ...,
         [-0.5508,  1.2656,  0.2754,  ...,  0.0527, -1.4062,  0.1992],
         [-0.6953,  1.1406, -0.3242,  ..., -0.9414, -0.6211, -1.1953],
         [-1.4141,  1.2656,  0.2314,  ..., -0.1182, -0.2041, -0.2812]],

        [[ 0.5547, -1.4062, -0.3340,  ...,  0.1514, -0.9297, -1.0312],
         [-0.4785, -1.4062, -0.4316,  ...,  0.0583, -1.4062,  0.1992],
         [-2.2500, -2.3281,  0.4707,  ..., -1.1562, -0.1973,  0.9883],
         ...,
         [-0.5508,  1.2656,  0.2754,  ...,  0.0527, -1.4062,  0.1992],
         [-0.6953,  1.1406, -0.3242,  ..., -0.9414, -0.6211, -1.1953],
         [-1.4141,  1.2656,  0.2314,  ..., -0.1182, -0.2041, -0.2812]],

        [[ 0.5547, -1.4062, -0.3340,  ...,  0.1514, -0.9297, -1.0312],
         [-0.4785, -1.4062, -0.4316,  ...,  0.0583, -1.4062,  0.1992],
         [-2.2500, -2.3281,  0.4707,  ..., -1.1562, -0.1973,  0.9883],
         ...,
         [-0.5508,  1.2656,  0.2754,  ...,  0.0527, -1.4062,  0.1992],
         [-0.6953,  1.1406, -0.3242,  ..., -0.9414, -0.6211, -1.1953],
         [-1.4141,  1.2656,  0.2314,  ..., -0.1182, -0.2041, -0.2812]],

        ...,

        [[ 0.4551, -0.1377, -0.2383,  ...,  0.9648, -1.7422,  0.1562],
         [ 0.2422, -1.1875, -0.4629,  ..., -0.1816, -0.8867, -0.7070],
         [-1.3672, -1.3750, -1.0781,  ...,  0.2119,  1.6172,  0.6758],
         ...,
         [ 0.0078, -0.5391,  0.4355,  ..., -0.1787, -0.8945, -0.7070],
         [-0.1738,  0.0957,  0.4004,  ...,  0.7578, -1.8203,  1.6328],
         [-0.4863,  1.2500, -2.1094,  ...,  0.9961,  1.2109,  1.5938]],

        [[ 0.4551, -0.1377, -0.2383,  ...,  0.9648, -1.7422,  0.1562],
         [ 0.2422, -1.1875, -0.4629,  ..., -0.1816, -0.8867, -0.7070],
         [-1.3672, -1.3750, -1.0781,  ...,  0.2119,  1.6172,  0.6758],
         ...,
         [ 0.0078, -0.5391,  0.4355,  ..., -0.1787, -0.8945, -0.7070],
         [-0.1738,  0.0957,  0.4004,  ...,  0.7578, -1.8203,  1.6328],
         [-0.4863,  1.2500, -2.1094,  ...,  0.9961,  1.2109,  1.5938]],

        [[ 0.4551, -0.1377, -0.2383,  ...,  0.9648, -1.7422,  0.1562],
         [ 0.2422, -1.1875, -0.4629,  ..., -0.1816, -0.8867, -0.7070],
         [-1.3672, -1.3750, -1.0781,  ...,  0.2119,  1.6172,  0.6758],
         ...,
         [ 0.0078, -0.5391,  0.4355,  ..., -0.1787, -0.8945, -0.7070],
         [-0.1738,  0.0957,  0.4004,  ...,  0.7578, -1.8203,  1.6328],
         [-0.4863,  1.2500, -2.1094,  ...,  0.9961,  1.2109,  1.5938]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>), tensor([[[-0.0036,  0.0417,  0.0364,  ..., -0.0087, -0.0391, -0.0474],
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [-0.0591, -0.0118, -0.1279,  ..., -0.1592,  0.1338,  0.0255],
         ...,
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [ 0.0087,  0.0205,  0.0557,  ...,  0.0085,  0.0039, -0.0454],
         [-0.0659,  0.0427,  0.1006,  ...,  0.1055, -0.0527, -0.0339]],

        [[-0.0036,  0.0417,  0.0364,  ..., -0.0087, -0.0391, -0.0474],
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [-0.0591, -0.0118, -0.1279,  ..., -0.1592,  0.1338,  0.0255],
         ...,
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [ 0.0087,  0.0205,  0.0557,  ...,  0.0085,  0.0039, -0.0454],
         [-0.0659,  0.0427,  0.1006,  ...,  0.1055, -0.0527, -0.0339]],

        [[-0.0036,  0.0417,  0.0364,  ..., -0.0087, -0.0391, -0.0474],
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [-0.0591, -0.0118, -0.1279,  ..., -0.1592,  0.1338,  0.0255],
         ...,
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [ 0.0087,  0.0205,  0.0557,  ...,  0.0085,  0.0039, -0.0454],
         [-0.0659,  0.0427,  0.1006,  ...,  0.1055, -0.0527, -0.0339]],

        ...,

        [[-0.0593, -0.0640, -0.0276,  ..., -0.0116, -0.0459, -0.0016],
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0708, -0.0918,  0.2285,  ..., -0.0635,  0.1396,  0.0603],
         ...,
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0923,  0.0505,  0.0068,  ...,  0.0239, -0.0119,  0.0031],
         [-0.2773, -0.3125, -0.3086,  ...,  0.0464,  0.1826, -0.2871]],

        [[-0.0593, -0.0640, -0.0276,  ..., -0.0116, -0.0459, -0.0016],
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0708, -0.0918,  0.2285,  ..., -0.0635,  0.1396,  0.0603],
         ...,
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0923,  0.0505,  0.0068,  ...,  0.0239, -0.0119,  0.0031],
         [-0.2773, -0.3125, -0.3086,  ...,  0.0464,  0.1826, -0.2871]],

        [[-0.0593, -0.0640, -0.0276,  ..., -0.0116, -0.0459, -0.0016],
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0708, -0.0918,  0.2285,  ..., -0.0635,  0.1396,  0.0603],
         ...,
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0923,  0.0505,  0.0068,  ...,  0.0239, -0.0119,  0.0031],
         [-0.2773, -0.3125, -0.3086,  ...,  0.0464,  0.1826, -0.2871]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[ 2.4531,  1.6719, -4.0938,  ..., -1.4219,  0.8477,  1.2422],
         [ 7.1250,  2.4375, -2.5938,  ..., -1.1406,  0.6445,  1.5938],
         [ 4.6875,  4.4375, -2.6875,  ..., -0.1709,  0.5898,  3.0156],
         ...,
         [ 7.1875,  3.5156,  0.3594,  ..., -2.2969,  0.7539,  3.3750],
         [ 2.5312,  0.1777, -1.4219,  ...,  0.1572,  2.0312, -0.9766],
         [-3.5000, -2.9531, -4.5312,  ..., -1.9609,  1.1016,  1.3672]],

        [[ 2.4531,  1.6719, -4.0938,  ..., -1.4219,  0.8477,  1.2422],
         [ 7.1250,  2.4375, -2.5938,  ..., -1.1406,  0.6445,  1.5938],
         [ 4.6875,  4.4375, -2.6875,  ..., -0.1709,  0.5898,  3.0156],
         ...,
         [ 7.1875,  3.5156,  0.3594,  ..., -2.2969,  0.7539,  3.3750],
         [ 2.5312,  0.1777, -1.4219,  ...,  0.1572,  2.0312, -0.9766],
         [-3.5000, -2.9531, -4.5312,  ..., -1.9609,  1.1016,  1.3672]],

        [[ 2.4531,  1.6719, -4.0938,  ..., -1.4219,  0.8477,  1.2422],
         [ 7.1250,  2.4375, -2.5938,  ..., -1.1406,  0.6445,  1.5938],
         [ 4.6875,  4.4375, -2.6875,  ..., -0.1709,  0.5898,  3.0156],
         ...,
         [ 7.1875,  3.5156,  0.3594,  ..., -2.2969,  0.7539,  3.3750],
         [ 2.5312,  0.1777, -1.4219,  ...,  0.1572,  2.0312, -0.9766],
         [-3.5000, -2.9531, -4.5312,  ..., -1.9609,  1.1016,  1.3672]],

        ...,

        [[ 0.1846,  0.7070, -2.5938,  ..., -1.6562,  3.1719, -1.0938],
         [ 3.0312,  1.2500, -1.8672,  ..., -2.9219,  2.3594, -0.6680],
         [ 5.0625,  0.7070, -0.4258,  ..., -4.0938,  3.5781, -3.2344],
         ...,
         [ 3.1094, -0.4531, -1.0938,  ..., -2.7969,  4.0625, -3.1562],
         [ 0.1914, -0.6250, -0.1875,  ..., -4.5000,  2.4844, -2.0000],
         [-2.2344,  3.1250, -4.0625,  ..., -2.6406,  3.4844, -0.3887]],

        [[ 0.1846,  0.7070, -2.5938,  ..., -1.6562,  3.1719, -1.0938],
         [ 3.0312,  1.2500, -1.8672,  ..., -2.9219,  2.3594, -0.6680],
         [ 5.0625,  0.7070, -0.4258,  ..., -4.0938,  3.5781, -3.2344],
         ...,
         [ 3.1094, -0.4531, -1.0938,  ..., -2.7969,  4.0625, -3.1562],
         [ 0.1914, -0.6250, -0.1875,  ..., -4.5000,  2.4844, -2.0000],
         [-2.2344,  3.1250, -4.0625,  ..., -2.6406,  3.4844, -0.3887]],

        [[ 0.1846,  0.7070, -2.5938,  ..., -1.6562,  3.1719, -1.0938],
         [ 3.0312,  1.2500, -1.8672,  ..., -2.9219,  2.3594, -0.6680],
         [ 5.0625,  0.7070, -0.4258,  ..., -4.0938,  3.5781, -3.2344],
         ...,
         [ 3.1094, -0.4531, -1.0938,  ..., -2.7969,  4.0625, -3.1562],
         [ 0.1914, -0.6250, -0.1875,  ..., -4.5000,  2.4844, -2.0000],
         [-2.2344,  3.1250, -4.0625,  ..., -2.6406,  3.4844, -0.3887]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>), tensor([[[ 0.0366,  0.0337,  0.0234,  ...,  0.1299,  0.2031, -0.1338],
         [-0.0269, -0.0304,  0.0752,  ..., -0.1118, -0.0603,  0.0413],
         [-0.0747,  0.0601,  0.0187,  ..., -0.0109, -0.2598, -0.1670],
         ...,
         [ 0.2061, -0.0820,  0.0376,  ..., -0.0986,  0.1738,  0.1660],
         [-0.0284, -0.0187,  0.0200,  ...,  0.0508, -0.0062, -0.0474],
         [-0.1807,  0.1826,  0.0069,  ...,  0.1045, -0.3145, -0.1138]],

        [[ 0.0366,  0.0337,  0.0234,  ...,  0.1299,  0.2031, -0.1338],
         [-0.0269, -0.0304,  0.0752,  ..., -0.1118, -0.0603,  0.0413],
         [-0.0747,  0.0601,  0.0187,  ..., -0.0109, -0.2598, -0.1670],
         ...,
         [ 0.2061, -0.0820,  0.0376,  ..., -0.0986,  0.1738,  0.1660],
         [-0.0284, -0.0187,  0.0200,  ...,  0.0508, -0.0062, -0.0474],
         [-0.1807,  0.1826,  0.0069,  ...,  0.1045, -0.3145, -0.1138]],

        [[ 0.0366,  0.0337,  0.0234,  ...,  0.1299,  0.2031, -0.1338],
         [-0.0269, -0.0304,  0.0752,  ..., -0.1118, -0.0603,  0.0413],
         [-0.0747,  0.0601,  0.0187,  ..., -0.0109, -0.2598, -0.1670],
         ...,
         [ 0.2061, -0.0820,  0.0376,  ..., -0.0986,  0.1738,  0.1660],
         [-0.0284, -0.0187,  0.0200,  ...,  0.0508, -0.0062, -0.0474],
         [-0.1807,  0.1826,  0.0069,  ...,  0.1045, -0.3145, -0.1138]],

        ...,

        [[ 0.0688,  0.0198,  0.0096,  ..., -0.0469, -0.0825, -0.0283],
         [-0.0437,  0.0231, -0.0981,  ...,  0.0354, -0.0835, -0.0356],
         [-0.1641,  0.0330, -0.0334,  ...,  0.0674, -0.1543,  0.1328],
         ...,
         [ 0.0854, -0.0124, -0.1245,  ...,  0.0864, -0.0591, -0.0588],
         [-0.0104, -0.0232,  0.0012,  ...,  0.0289,  0.0244,  0.0532],
         [ 0.0466,  0.1074,  0.2637,  ..., -0.0938,  0.0044,  0.0801]],

        [[ 0.0688,  0.0198,  0.0096,  ..., -0.0469, -0.0825, -0.0283],
         [-0.0437,  0.0231, -0.0981,  ...,  0.0354, -0.0835, -0.0356],
         [-0.1641,  0.0330, -0.0334,  ...,  0.0674, -0.1543,  0.1328],
         ...,
         [ 0.0854, -0.0124, -0.1245,  ...,  0.0864, -0.0591, -0.0588],
         [-0.0104, -0.0232,  0.0012,  ...,  0.0289,  0.0244,  0.0532],
         [ 0.0466,  0.1074,  0.2637,  ..., -0.0938,  0.0044,  0.0801]],

        [[ 0.0688,  0.0198,  0.0096,  ..., -0.0469, -0.0825, -0.0283],
         [-0.0437,  0.0231, -0.0981,  ...,  0.0354, -0.0835, -0.0356],
         [-0.1641,  0.0330, -0.0334,  ...,  0.0674, -0.1543,  0.1328],
         ...,
         [ 0.0854, -0.0124, -0.1245,  ...,  0.0864, -0.0591, -0.0588],
         [-0.0104, -0.0232,  0.0012,  ...,  0.0289,  0.0244,  0.0532],
         [ 0.0466,  0.1074,  0.2637,  ..., -0.0938,  0.0044,  0.0801]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[-2.1820e-03, -5.0964e-03,  8.5831e-04,  ..., -9.5312e-01,
           6.8970e-03,  4.1406e-01],
         [ 1.4219e+00,  1.4141e+00,  2.2188e+00,  ...,  6.9062e+00,
           1.8164e-01, -1.7188e+00],
         [ 1.3672e+00,  7.8613e-02,  1.4531e+00,  ...,  6.3438e+00,
          -1.0547e+00, -2.1562e+00],
         ...,
         [ 2.2031e+00, -3.1250e-01, -2.8516e-01,  ...,  7.3438e+00,
          -2.1406e+00, -2.9688e+00],
         [ 1.5234e-01, -2.6367e-01,  1.6699e-01,  ...,  6.3125e+00,
          -8.0469e-01, -1.4844e+00],
         [ 1.2031e+00, -1.0469e+00,  1.1016e+00,  ...,  6.2812e+00,
           1.1250e+00, -1.1328e+00]],

        [[-2.1820e-03, -5.0964e-03,  8.5831e-04,  ..., -9.5312e-01,
           6.8970e-03,  4.1406e-01],
         [ 1.4219e+00,  1.4141e+00,  2.2188e+00,  ...,  6.9062e+00,
           1.8164e-01, -1.7188e+00],
         [ 1.3672e+00,  7.8613e-02,  1.4531e+00,  ...,  6.3438e+00,
          -1.0547e+00, -2.1562e+00],
         ...,
         [ 2.2031e+00, -3.1250e-01, -2.8516e-01,  ...,  7.3438e+00,
          -2.1406e+00, -2.9688e+00],
         [ 1.5234e-01, -2.6367e-01,  1.6699e-01,  ...,  6.3125e+00,
          -8.0469e-01, -1.4844e+00],
         [ 1.2031e+00, -1.0469e+00,  1.1016e+00,  ...,  6.2812e+00,
           1.1250e+00, -1.1328e+00]],

        [[-2.1820e-03, -5.0964e-03,  8.5831e-04,  ..., -9.5312e-01,
           6.8970e-03,  4.1406e-01],
         [ 1.4219e+00,  1.4141e+00,  2.2188e+00,  ...,  6.9062e+00,
           1.8164e-01, -1.7188e+00],
         [ 1.3672e+00,  7.8613e-02,  1.4531e+00,  ...,  6.3438e+00,
          -1.0547e+00, -2.1562e+00],
         ...,
         [ 2.2031e+00, -3.1250e-01, -2.8516e-01,  ...,  7.3438e+00,
          -2.1406e+00, -2.9688e+00],
         [ 1.5234e-01, -2.6367e-01,  1.6699e-01,  ...,  6.3125e+00,
          -8.0469e-01, -1.4844e+00],
         [ 1.2031e+00, -1.0469e+00,  1.1016e+00,  ...,  6.2812e+00,
           1.1250e+00, -1.1328e+00]],

        ...,

        [[ 1.0315e-02,  7.4005e-04, -5.4626e-03,  ..., -4.5654e-02,
           9.8145e-02,  1.5015e-02],
         [ 4.0000e+00,  1.5156e+00,  1.7188e+00,  ...,  8.3008e-02,
          -2.4844e+00,  1.7188e+00],
         [ 2.1875e+00, -4.4434e-02,  1.3672e+00,  ...,  1.4609e+00,
          -1.6250e+00,  1.2578e+00],
         ...,
         [ 3.4844e+00,  9.6484e-01,  5.4297e-01,  ...,  1.0625e+00,
          -1.5938e+00,  7.6172e-01],
         [ 8.1250e-01,  7.3242e-04,  7.5000e-01,  ...,  1.5156e+00,
           4.2773e-01,  2.3594e+00],
         [-5.3516e-01,  1.0469e+00,  2.1719e+00,  ...,  6.0547e-01,
          -1.0391e+00,  2.4805e-01]],

        [[ 1.0315e-02,  7.4005e-04, -5.4626e-03,  ..., -4.5654e-02,
           9.8145e-02,  1.5015e-02],
         [ 4.0000e+00,  1.5156e+00,  1.7188e+00,  ...,  8.3008e-02,
          -2.4844e+00,  1.7188e+00],
         [ 2.1875e+00, -4.4434e-02,  1.3672e+00,  ...,  1.4609e+00,
          -1.6250e+00,  1.2578e+00],
         ...,
         [ 3.4844e+00,  9.6484e-01,  5.4297e-01,  ...,  1.0625e+00,
          -1.5938e+00,  7.6172e-01],
         [ 8.1250e-01,  7.3242e-04,  7.5000e-01,  ...,  1.5156e+00,
           4.2773e-01,  2.3594e+00],
         [-5.3516e-01,  1.0469e+00,  2.1719e+00,  ...,  6.0547e-01,
          -1.0391e+00,  2.4805e-01]],

        [[ 1.0315e-02,  7.4005e-04, -5.4626e-03,  ..., -4.5654e-02,
           9.8145e-02,  1.5015e-02],
         [ 4.0000e+00,  1.5156e+00,  1.7188e+00,  ...,  8.3008e-02,
          -2.4844e+00,  1.7188e+00],
         [ 2.1875e+00, -4.4434e-02,  1.3672e+00,  ...,  1.4609e+00,
          -1.6250e+00,  1.2578e+00],
         ...,
         [ 3.4844e+00,  9.6484e-01,  5.4297e-01,  ...,  1.0625e+00,
          -1.5938e+00,  7.6172e-01],
         [ 8.1250e-01,  7.3242e-04,  7.5000e-01,  ...,  1.5156e+00,
           4.2773e-01,  2.3594e+00],
         [-5.3516e-01,  1.0469e+00,  2.1719e+00,  ...,  6.0547e-01,
          -1.0391e+00,  2.4805e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-1.8997e-03, -1.0864e-02,  2.7313e-03,  ..., -5.3711e-03,
           7.7820e-04,  3.0975e-03],
         [ 1.2793e-01,  3.1445e-01,  5.7812e-01,  ...,  5.1953e-01,
          -1.4771e-02, -3.0078e-01],
         [-2.1973e-01,  5.4443e-02, -1.6699e-01,  ..., -5.3516e-01,
           2.7930e-01,  1.0205e-01],
         ...,
         [-1.7480e-01,  7.2266e-02,  2.2949e-01,  ...,  2.8809e-02,
           2.4512e-01, -1.9775e-02],
         [ 3.5742e-01, -2.8198e-02,  5.3955e-02,  ...,  3.4766e-01,
          -1.0400e-01, -1.5820e-01],
         [ 4.0234e-01,  1.1719e+00, -1.2812e+00,  ...,  1.0156e+00,
          -5.1172e-01,  5.5469e-01]],

        [[-1.8997e-03, -1.0864e-02,  2.7313e-03,  ..., -5.3711e-03,
           7.7820e-04,  3.0975e-03],
         [ 1.2793e-01,  3.1445e-01,  5.7812e-01,  ...,  5.1953e-01,
          -1.4771e-02, -3.0078e-01],
         [-2.1973e-01,  5.4443e-02, -1.6699e-01,  ..., -5.3516e-01,
           2.7930e-01,  1.0205e-01],
         ...,
         [-1.7480e-01,  7.2266e-02,  2.2949e-01,  ...,  2.8809e-02,
           2.4512e-01, -1.9775e-02],
         [ 3.5742e-01, -2.8198e-02,  5.3955e-02,  ...,  3.4766e-01,
          -1.0400e-01, -1.5820e-01],
         [ 4.0234e-01,  1.1719e+00, -1.2812e+00,  ...,  1.0156e+00,
          -5.1172e-01,  5.5469e-01]],

        [[-1.8997e-03, -1.0864e-02,  2.7313e-03,  ..., -5.3711e-03,
           7.7820e-04,  3.0975e-03],
         [ 1.2793e-01,  3.1445e-01,  5.7812e-01,  ...,  5.1953e-01,
          -1.4771e-02, -3.0078e-01],
         [-2.1973e-01,  5.4443e-02, -1.6699e-01,  ..., -5.3516e-01,
           2.7930e-01,  1.0205e-01],
         ...,
         [-1.7480e-01,  7.2266e-02,  2.2949e-01,  ...,  2.8809e-02,
           2.4512e-01, -1.9775e-02],
         [ 3.5742e-01, -2.8198e-02,  5.3955e-02,  ...,  3.4766e-01,
          -1.0400e-01, -1.5820e-01],
         [ 4.0234e-01,  1.1719e+00, -1.2812e+00,  ...,  1.0156e+00,
          -5.1172e-01,  5.5469e-01]],

        ...,

        [[ 1.1658e-02, -7.6599e-03, -4.5967e-04,  ..., -3.1128e-03,
           3.2349e-03,  9.1934e-04],
         [ 2.6953e-01,  2.1582e-01, -2.3633e-01,  ...,  2.2095e-02,
           2.5195e-01,  5.7373e-02],
         [-3.2031e-01,  1.0315e-02,  5.2979e-02,  ...,  4.0820e-01,
           1.6895e-01, -2.3047e-01],
         ...,
         [-2.4048e-02,  1.3184e-01, -2.1289e-01,  ...,  2.3340e-01,
           2.2949e-01,  3.2617e-01],
         [-3.1250e-01,  1.7090e-01, -1.5918e-01,  ...,  2.1289e-01,
          -2.3926e-01, -2.6245e-02],
         [-1.4062e-01, -1.6504e-01,  1.4282e-02,  ..., -2.7344e-01,
          -2.6758e-01, -4.4727e-01]],

        [[ 1.1658e-02, -7.6599e-03, -4.5967e-04,  ..., -3.1128e-03,
           3.2349e-03,  9.1934e-04],
         [ 2.6953e-01,  2.1582e-01, -2.3633e-01,  ...,  2.2095e-02,
           2.5195e-01,  5.7373e-02],
         [-3.2031e-01,  1.0315e-02,  5.2979e-02,  ...,  4.0820e-01,
           1.6895e-01, -2.3047e-01],
         ...,
         [-2.4048e-02,  1.3184e-01, -2.1289e-01,  ...,  2.3340e-01,
           2.2949e-01,  3.2617e-01],
         [-3.1250e-01,  1.7090e-01, -1.5918e-01,  ...,  2.1289e-01,
          -2.3926e-01, -2.6245e-02],
         [-1.4062e-01, -1.6504e-01,  1.4282e-02,  ..., -2.7344e-01,
          -2.6758e-01, -4.4727e-01]],

        [[ 1.1658e-02, -7.6599e-03, -4.5967e-04,  ..., -3.1128e-03,
           3.2349e-03,  9.1934e-04],
         [ 2.6953e-01,  2.1582e-01, -2.3633e-01,  ...,  2.2095e-02,
           2.5195e-01,  5.7373e-02],
         [-3.2031e-01,  1.0315e-02,  5.2979e-02,  ...,  4.0820e-01,
           1.6895e-01, -2.3047e-01],
         ...,
         [-2.4048e-02,  1.3184e-01, -2.1289e-01,  ...,  2.3340e-01,
           2.2949e-01,  3.2617e-01],
         [-3.1250e-01,  1.7090e-01, -1.5918e-01,  ...,  2.1289e-01,
          -2.3926e-01, -2.6245e-02],
         [-1.4062e-01, -1.6504e-01,  1.4282e-02,  ..., -2.7344e-01,
          -2.6758e-01, -4.4727e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 7.3242e-03,  4.6387e-03,  8.1787e-03,  ...,  6.4844e-01,
           5.0000e-01, -3.8477e-01],
         [-1.9688e+00,  1.5000e+00, -7.1484e-01,  ..., -3.2031e+00,
          -2.9375e+00,  2.9883e-01],
         [-5.7031e-01,  5.5078e-01, -1.0781e+00,  ..., -3.3281e+00,
          -3.3906e+00,  1.0469e+00],
         ...,
         [-6.6406e-01,  3.0078e-01,  8.3984e-01,  ..., -4.9375e+00,
          -2.9844e+00,  8.1543e-02],
         [-3.3203e-01, -7.6172e-02,  1.9531e-01,  ..., -4.9062e+00,
          -3.5625e+00,  6.4453e-01],
         [ 4.2188e-01, -8.7891e-01,  1.4355e-01,  ..., -5.2188e+00,
          -2.3281e+00,  1.0625e+00]],

        [[ 7.3242e-03,  4.6387e-03,  8.1787e-03,  ...,  6.4844e-01,
           5.0000e-01, -3.8477e-01],
         [-1.9688e+00,  1.5000e+00, -7.1484e-01,  ..., -3.2031e+00,
          -2.9375e+00,  2.9883e-01],
         [-5.7031e-01,  5.5078e-01, -1.0781e+00,  ..., -3.3281e+00,
          -3.3906e+00,  1.0469e+00],
         ...,
         [-6.6406e-01,  3.0078e-01,  8.3984e-01,  ..., -4.9375e+00,
          -2.9844e+00,  8.1543e-02],
         [-3.3203e-01, -7.6172e-02,  1.9531e-01,  ..., -4.9062e+00,
          -3.5625e+00,  6.4453e-01],
         [ 4.2188e-01, -8.7891e-01,  1.4355e-01,  ..., -5.2188e+00,
          -2.3281e+00,  1.0625e+00]],

        [[ 7.3242e-03,  4.6387e-03,  8.1787e-03,  ...,  6.4844e-01,
           5.0000e-01, -3.8477e-01],
         [-1.9688e+00,  1.5000e+00, -7.1484e-01,  ..., -3.2031e+00,
          -2.9375e+00,  2.9883e-01],
         [-5.7031e-01,  5.5078e-01, -1.0781e+00,  ..., -3.3281e+00,
          -3.3906e+00,  1.0469e+00],
         ...,
         [-6.6406e-01,  3.0078e-01,  8.3984e-01,  ..., -4.9375e+00,
          -2.9844e+00,  8.1543e-02],
         [-3.3203e-01, -7.6172e-02,  1.9531e-01,  ..., -4.9062e+00,
          -3.5625e+00,  6.4453e-01],
         [ 4.2188e-01, -8.7891e-01,  1.4355e-01,  ..., -5.2188e+00,
          -2.3281e+00,  1.0625e+00]],

        ...,

        [[ 2.6398e-03, -5.1575e-03,  3.1586e-03,  ...,  8.0078e-02,
          -5.4321e-03, -2.4219e-01],
         [-5.1875e+00, -8.9062e-01,  7.8516e-01,  ..., -6.6406e-01,
          -1.3281e+00, -6.8750e-01],
         [-3.2188e+00, -5.5664e-02,  1.1719e-02,  ...,  5.1172e-01,
           3.5352e-01, -9.7266e-01],
         ...,
         [-4.7188e+00,  2.7344e+00,  1.2500e+00,  ..., -6.9531e-01,
           4.5508e-01, -7.4219e-01],
         [-2.1406e+00,  2.5312e+00,  1.1875e+00,  ...,  6.6406e-01,
           7.3828e-01, -7.6953e-01],
         [ 3.0469e+00,  1.2812e+00,  2.6562e+00,  ..., -1.1094e+00,
           2.1875e-01, -2.4292e-02]],

        [[ 2.6398e-03, -5.1575e-03,  3.1586e-03,  ...,  8.0078e-02,
          -5.4321e-03, -2.4219e-01],
         [-5.1875e+00, -8.9062e-01,  7.8516e-01,  ..., -6.6406e-01,
          -1.3281e+00, -6.8750e-01],
         [-3.2188e+00, -5.5664e-02,  1.1719e-02,  ...,  5.1172e-01,
           3.5352e-01, -9.7266e-01],
         ...,
         [-4.7188e+00,  2.7344e+00,  1.2500e+00,  ..., -6.9531e-01,
           4.5508e-01, -7.4219e-01],
         [-2.1406e+00,  2.5312e+00,  1.1875e+00,  ...,  6.6406e-01,
           7.3828e-01, -7.6953e-01],
         [ 3.0469e+00,  1.2812e+00,  2.6562e+00,  ..., -1.1094e+00,
           2.1875e-01, -2.4292e-02]],

        [[ 2.6398e-03, -5.1575e-03,  3.1586e-03,  ...,  8.0078e-02,
          -5.4321e-03, -2.4219e-01],
         [-5.1875e+00, -8.9062e-01,  7.8516e-01,  ..., -6.6406e-01,
          -1.3281e+00, -6.8750e-01],
         [-3.2188e+00, -5.5664e-02,  1.1719e-02,  ...,  5.1172e-01,
           3.5352e-01, -9.7266e-01],
         ...,
         [-4.7188e+00,  2.7344e+00,  1.2500e+00,  ..., -6.9531e-01,
           4.5508e-01, -7.4219e-01],
         [-2.1406e+00,  2.5312e+00,  1.1875e+00,  ...,  6.6406e-01,
           7.3828e-01, -7.6953e-01],
         [ 3.0469e+00,  1.2812e+00,  2.6562e+00,  ..., -1.1094e+00,
           2.1875e-01, -2.4292e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 5.3406e-04,  6.9580e-03, -4.0283e-03,  ...,  1.3428e-03,
          -3.2349e-03, -1.6174e-03],
         [ 1.1963e-01, -6.2109e-01,  2.1680e-01,  ..., -2.0874e-02,
           1.0193e-02,  2.7930e-01],
         [-3.1055e-01,  4.1016e-01,  8.4375e-01,  ...,  1.8457e-01,
          -7.3730e-02, -3.7305e-01],
         ...,
         [ 1.3672e-01,  4.7656e-01,  3.8477e-01,  ...,  1.4258e-01,
           3.2422e-01,  1.6479e-02],
         [ 7.4609e-01, -6.8359e-01,  1.6211e-01,  ..., -2.5781e-01,
           4.6875e-01,  9.9121e-02],
         [-1.3867e-01,  1.6699e-01,  1.1279e-01,  ...,  4.2969e-01,
          -8.2397e-03, -2.7539e-01]],

        [[ 5.3406e-04,  6.9580e-03, -4.0283e-03,  ...,  1.3428e-03,
          -3.2349e-03, -1.6174e-03],
         [ 1.1963e-01, -6.2109e-01,  2.1680e-01,  ..., -2.0874e-02,
           1.0193e-02,  2.7930e-01],
         [-3.1055e-01,  4.1016e-01,  8.4375e-01,  ...,  1.8457e-01,
          -7.3730e-02, -3.7305e-01],
         ...,
         [ 1.3672e-01,  4.7656e-01,  3.8477e-01,  ...,  1.4258e-01,
           3.2422e-01,  1.6479e-02],
         [ 7.4609e-01, -6.8359e-01,  1.6211e-01,  ..., -2.5781e-01,
           4.6875e-01,  9.9121e-02],
         [-1.3867e-01,  1.6699e-01,  1.1279e-01,  ...,  4.2969e-01,
          -8.2397e-03, -2.7539e-01]],

        [[ 5.3406e-04,  6.9580e-03, -4.0283e-03,  ...,  1.3428e-03,
          -3.2349e-03, -1.6174e-03],
         [ 1.1963e-01, -6.2109e-01,  2.1680e-01,  ..., -2.0874e-02,
           1.0193e-02,  2.7930e-01],
         [-3.1055e-01,  4.1016e-01,  8.4375e-01,  ...,  1.8457e-01,
          -7.3730e-02, -3.7305e-01],
         ...,
         [ 1.3672e-01,  4.7656e-01,  3.8477e-01,  ...,  1.4258e-01,
           3.2422e-01,  1.6479e-02],
         [ 7.4609e-01, -6.8359e-01,  1.6211e-01,  ..., -2.5781e-01,
           4.6875e-01,  9.9121e-02],
         [-1.3867e-01,  1.6699e-01,  1.1279e-01,  ...,  4.2969e-01,
          -8.2397e-03, -2.7539e-01]],

        ...,

        [[-3.9368e-03, -2.0752e-03,  3.1433e-03,  ...,  3.3569e-04,
           3.0212e-03, -7.6675e-04],
         [-3.5547e-01,  2.2339e-02, -1.4648e-01,  ..., -1.3770e-01,
          -2.1973e-01,  1.2024e-02],
         [ 8.7891e-02, -1.6699e-01,  6.2500e-02,  ...,  4.0039e-01,
          -2.6953e-01,  1.7383e-01],
         ...,
         [ 7.7148e-02,  2.5586e-01, -2.6172e-01,  ...,  1.7578e-01,
           9.2773e-02, -1.6968e-02],
         [ 1.5039e-01,  4.4678e-02,  1.2061e-01,  ...,  1.4648e-02,
           2.7539e-01, -1.4453e-01],
         [-1.5234e-01, -3.2617e-01, -4.0625e-01,  ...,  7.1716e-03,
          -9.5215e-02,  6.0791e-02]],

        [[-3.9368e-03, -2.0752e-03,  3.1433e-03,  ...,  3.3569e-04,
           3.0212e-03, -7.6675e-04],
         [-3.5547e-01,  2.2339e-02, -1.4648e-01,  ..., -1.3770e-01,
          -2.1973e-01,  1.2024e-02],
         [ 8.7891e-02, -1.6699e-01,  6.2500e-02,  ...,  4.0039e-01,
          -2.6953e-01,  1.7383e-01],
         ...,
         [ 7.7148e-02,  2.5586e-01, -2.6172e-01,  ...,  1.7578e-01,
           9.2773e-02, -1.6968e-02],
         [ 1.5039e-01,  4.4678e-02,  1.2061e-01,  ...,  1.4648e-02,
           2.7539e-01, -1.4453e-01],
         [-1.5234e-01, -3.2617e-01, -4.0625e-01,  ...,  7.1716e-03,
          -9.5215e-02,  6.0791e-02]],

        [[-3.9368e-03, -2.0752e-03,  3.1433e-03,  ...,  3.3569e-04,
           3.0212e-03, -7.6675e-04],
         [-3.5547e-01,  2.2339e-02, -1.4648e-01,  ..., -1.3770e-01,
          -2.1973e-01,  1.2024e-02],
         [ 8.7891e-02, -1.6699e-01,  6.2500e-02,  ...,  4.0039e-01,
          -2.6953e-01,  1.7383e-01],
         ...,
         [ 7.7148e-02,  2.5586e-01, -2.6172e-01,  ...,  1.7578e-01,
           9.2773e-02, -1.6968e-02],
         [ 1.5039e-01,  4.4678e-02,  1.2061e-01,  ...,  1.4648e-02,
           2.7539e-01, -1.4453e-01],
         [-1.5234e-01, -3.2617e-01, -4.0625e-01,  ...,  7.1716e-03,
          -9.5215e-02,  6.0791e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 0.0146,  0.0317, -0.0143,  ..., -0.2393, -0.6523, -0.1162],
         [ 1.6094, -1.1406, -0.0879,  ...,  0.4395,  2.6562, -0.4863],
         [ 2.5781, -0.8555, -0.3262,  ...,  0.2451,  3.0469, -0.1484],
         ...,
         [ 1.9609,  0.9023, -1.1562,  ...,  2.4219,  3.5000,  1.3750],
         [-0.1348, -0.6406, -0.3516,  ...,  1.4609,  0.8047,  0.5312],
         [ 0.5703,  0.3828, -0.3555,  ...,  1.1250,  2.1094, -2.1875]],

        [[ 0.0146,  0.0317, -0.0143,  ..., -0.2393, -0.6523, -0.1162],
         [ 1.6094, -1.1406, -0.0879,  ...,  0.4395,  2.6562, -0.4863],
         [ 2.5781, -0.8555, -0.3262,  ...,  0.2451,  3.0469, -0.1484],
         ...,
         [ 1.9609,  0.9023, -1.1562,  ...,  2.4219,  3.5000,  1.3750],
         [-0.1348, -0.6406, -0.3516,  ...,  1.4609,  0.8047,  0.5312],
         [ 0.5703,  0.3828, -0.3555,  ...,  1.1250,  2.1094, -2.1875]],

        [[ 0.0146,  0.0317, -0.0143,  ..., -0.2393, -0.6523, -0.1162],
         [ 1.6094, -1.1406, -0.0879,  ...,  0.4395,  2.6562, -0.4863],
         [ 2.5781, -0.8555, -0.3262,  ...,  0.2451,  3.0469, -0.1484],
         ...,
         [ 1.9609,  0.9023, -1.1562,  ...,  2.4219,  3.5000,  1.3750],
         [-0.1348, -0.6406, -0.3516,  ...,  1.4609,  0.8047,  0.5312],
         [ 0.5703,  0.3828, -0.3555,  ...,  1.1250,  2.1094, -2.1875]],

        ...,

        [[-0.0132, -0.0184,  0.0106,  ...,  0.0322, -0.2773,  0.1357],
         [ 2.0625, -1.0625, -1.7812,  ..., -0.1167, -0.0369, -0.5156],
         [ 0.3672, -0.5586, -0.3867,  ...,  0.9570, -0.5742, -0.1230],
         ...,
         [ 1.5938, -0.4102,  0.0332,  ...,  1.7266,  1.3203, -1.1406],
         [ 0.0674,  0.0227,  0.0391,  ...,  0.0649, -1.4062,  1.2812],
         [-0.1367,  0.7305, -0.2539,  ...,  1.7578,  0.3047, -4.2812]],

        [[-0.0132, -0.0184,  0.0106,  ...,  0.0322, -0.2773,  0.1357],
         [ 2.0625, -1.0625, -1.7812,  ..., -0.1167, -0.0369, -0.5156],
         [ 0.3672, -0.5586, -0.3867,  ...,  0.9570, -0.5742, -0.1230],
         ...,
         [ 1.5938, -0.4102,  0.0332,  ...,  1.7266,  1.3203, -1.1406],
         [ 0.0674,  0.0227,  0.0391,  ...,  0.0649, -1.4062,  1.2812],
         [-0.1367,  0.7305, -0.2539,  ...,  1.7578,  0.3047, -4.2812]],

        [[-0.0132, -0.0184,  0.0106,  ...,  0.0322, -0.2773,  0.1357],
         [ 2.0625, -1.0625, -1.7812,  ..., -0.1167, -0.0369, -0.5156],
         [ 0.3672, -0.5586, -0.3867,  ...,  0.9570, -0.5742, -0.1230],
         ...,
         [ 1.5938, -0.4102,  0.0332,  ...,  1.7266,  1.3203, -1.1406],
         [ 0.0674,  0.0227,  0.0391,  ...,  0.0649, -1.4062,  1.2812],
         [-0.1367,  0.7305, -0.2539,  ...,  1.7578,  0.3047, -4.2812]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>), tensor([[[-9.8267e-03,  3.4027e-03, -1.1963e-02,  ...,  4.8523e-03,
          -4.0894e-03, -9.4604e-03],
         [-1.2988e-01,  6.7578e-01, -2.5586e-01,  ..., -2.1191e-01,
          -2.3828e-01,  6.7188e-01],
         [ 3.7109e-01, -2.2363e-01, -2.2559e-01,  ...,  5.8594e-01,
           6.9531e-01, -2.0117e-01],
         ...,
         [ 2.3193e-02,  1.3867e-01,  7.3828e-01,  ...,  5.7031e-01,
           1.9434e-01,  1.4648e-01],
         [ 8.2422e-01, -5.4688e-01,  6.0938e-01,  ..., -4.4141e-01,
          -1.9434e-01,  5.6641e-01],
         [-1.5332e-01,  3.4912e-02, -2.3535e-01,  ...,  3.2715e-02,
           2.0508e-01, -1.7285e-01]],

        [[-9.8267e-03,  3.4027e-03, -1.1963e-02,  ...,  4.8523e-03,
          -4.0894e-03, -9.4604e-03],
         [-1.2988e-01,  6.7578e-01, -2.5586e-01,  ..., -2.1191e-01,
          -2.3828e-01,  6.7188e-01],
         [ 3.7109e-01, -2.2363e-01, -2.2559e-01,  ...,  5.8594e-01,
           6.9531e-01, -2.0117e-01],
         ...,
         [ 2.3193e-02,  1.3867e-01,  7.3828e-01,  ...,  5.7031e-01,
           1.9434e-01,  1.4648e-01],
         [ 8.2422e-01, -5.4688e-01,  6.0938e-01,  ..., -4.4141e-01,
          -1.9434e-01,  5.6641e-01],
         [-1.5332e-01,  3.4912e-02, -2.3535e-01,  ...,  3.2715e-02,
           2.0508e-01, -1.7285e-01]],

        [[-9.8267e-03,  3.4027e-03, -1.1963e-02,  ...,  4.8523e-03,
          -4.0894e-03, -9.4604e-03],
         [-1.2988e-01,  6.7578e-01, -2.5586e-01,  ..., -2.1191e-01,
          -2.3828e-01,  6.7188e-01],
         [ 3.7109e-01, -2.2363e-01, -2.2559e-01,  ...,  5.8594e-01,
           6.9531e-01, -2.0117e-01],
         ...,
         [ 2.3193e-02,  1.3867e-01,  7.3828e-01,  ...,  5.7031e-01,
           1.9434e-01,  1.4648e-01],
         [ 8.2422e-01, -5.4688e-01,  6.0938e-01,  ..., -4.4141e-01,
          -1.9434e-01,  5.6641e-01],
         [-1.5332e-01,  3.4912e-02, -2.3535e-01,  ...,  3.2715e-02,
           2.0508e-01, -1.7285e-01]],

        ...,

        [[-4.5967e-04, -4.8523e-03, -3.3447e-02,  ...,  6.7139e-03,
           7.4768e-03,  5.1880e-03],
         [-1.9629e-01, -2.9541e-02,  9.2163e-03,  ..., -6.4844e-01,
          -5.3516e-01,  3.8574e-02],
         [ 4.1602e-01, -6.7578e-01,  1.9531e-01,  ..., -4.4531e-01,
           3.5156e-02,  4.7070e-01],
         ...,
         [ 4.4141e-01,  5.0391e-01, -5.2002e-02,  ..., -3.4375e-01,
           6.8848e-02,  2.1973e-01],
         [ 2.0996e-01,  4.8828e-01,  4.5508e-01,  ..., -2.5195e-01,
          -1.0547e-01,  3.1836e-01],
         [-1.2500e-01, -6.4844e-01, -1.1816e-01,  ..., -1.4648e-01,
           8.2016e-04, -3.0859e-01]],

        [[-4.5967e-04, -4.8523e-03, -3.3447e-02,  ...,  6.7139e-03,
           7.4768e-03,  5.1880e-03],
         [-1.9629e-01, -2.9541e-02,  9.2163e-03,  ..., -6.4844e-01,
          -5.3516e-01,  3.8574e-02],
         [ 4.1602e-01, -6.7578e-01,  1.9531e-01,  ..., -4.4531e-01,
           3.5156e-02,  4.7070e-01],
         ...,
         [ 4.4141e-01,  5.0391e-01, -5.2002e-02,  ..., -3.4375e-01,
           6.8848e-02,  2.1973e-01],
         [ 2.0996e-01,  4.8828e-01,  4.5508e-01,  ..., -2.5195e-01,
          -1.0547e-01,  3.1836e-01],
         [-1.2500e-01, -6.4844e-01, -1.1816e-01,  ..., -1.4648e-01,
           8.2016e-04, -3.0859e-01]],

        [[-4.5967e-04, -4.8523e-03, -3.3447e-02,  ...,  6.7139e-03,
           7.4768e-03,  5.1880e-03],
         [-1.9629e-01, -2.9541e-02,  9.2163e-03,  ..., -6.4844e-01,
          -5.3516e-01,  3.8574e-02],
         [ 4.1602e-01, -6.7578e-01,  1.9531e-01,  ..., -4.4531e-01,
           3.5156e-02,  4.7070e-01],
         ...,
         [ 4.4141e-01,  5.0391e-01, -5.2002e-02,  ..., -3.4375e-01,
           6.8848e-02,  2.1973e-01],
         [ 2.0996e-01,  4.8828e-01,  4.5508e-01,  ..., -2.5195e-01,
          -1.0547e-01,  3.1836e-01],
         [-1.2500e-01, -6.4844e-01, -1.1816e-01,  ..., -1.4648e-01,
           8.2016e-04, -3.0859e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-5.8289e-03, -5.2185e-03, -3.2043e-03,  ...,  1.9141e-01,
          -9.6191e-02, -6.4062e-01],
         [-3.3750e+00, -1.9297e+00,  2.3828e-01,  ..., -1.8359e-01,
           7.5391e-01,  2.9375e+00],
         [ 5.9375e-01, -2.5312e+00, -1.2266e+00,  ..., -3.8574e-02,
           8.2812e-01,  1.7812e+00],
         ...,
         [-1.8828e+00, -1.8125e+00,  1.3438e+00,  ..., -1.3125e+00,
          -4.3359e-01,  2.8594e+00],
         [ 1.2188e+00, -1.3867e-01,  1.0703e+00,  ..., -2.2812e+00,
          -1.9824e-01,  3.6406e+00],
         [ 3.5156e+00,  1.7188e+00,  1.7188e+00,  ...,  6.7188e-01,
          -2.3340e-01,  2.0469e+00]],

        [[-5.8289e-03, -5.2185e-03, -3.2043e-03,  ...,  1.9141e-01,
          -9.6191e-02, -6.4062e-01],
         [-3.3750e+00, -1.9297e+00,  2.3828e-01,  ..., -1.8359e-01,
           7.5391e-01,  2.9375e+00],
         [ 5.9375e-01, -2.5312e+00, -1.2266e+00,  ..., -3.8574e-02,
           8.2812e-01,  1.7812e+00],
         ...,
         [-1.8828e+00, -1.8125e+00,  1.3438e+00,  ..., -1.3125e+00,
          -4.3359e-01,  2.8594e+00],
         [ 1.2188e+00, -1.3867e-01,  1.0703e+00,  ..., -2.2812e+00,
          -1.9824e-01,  3.6406e+00],
         [ 3.5156e+00,  1.7188e+00,  1.7188e+00,  ...,  6.7188e-01,
          -2.3340e-01,  2.0469e+00]],

        [[-5.8289e-03, -5.2185e-03, -3.2043e-03,  ...,  1.9141e-01,
          -9.6191e-02, -6.4062e-01],
         [-3.3750e+00, -1.9297e+00,  2.3828e-01,  ..., -1.8359e-01,
           7.5391e-01,  2.9375e+00],
         [ 5.9375e-01, -2.5312e+00, -1.2266e+00,  ..., -3.8574e-02,
           8.2812e-01,  1.7812e+00],
         ...,
         [-1.8828e+00, -1.8125e+00,  1.3438e+00,  ..., -1.3125e+00,
          -4.3359e-01,  2.8594e+00],
         [ 1.2188e+00, -1.3867e-01,  1.0703e+00,  ..., -2.2812e+00,
          -1.9824e-01,  3.6406e+00],
         [ 3.5156e+00,  1.7188e+00,  1.7188e+00,  ...,  6.7188e-01,
          -2.3340e-01,  2.0469e+00]],

        ...,

        [[-2.0599e-03, -3.8338e-04,  5.0049e-03,  ...,  1.0437e-02,
           1.6235e-02,  2.6758e-01],
         [-6.3672e-01,  5.5469e-01, -4.7656e-01,  ..., -2.0625e+00,
           9.9609e-01, -1.2969e+00],
         [ 1.2266e+00,  1.9531e+00,  3.3984e-01,  ..., -7.0703e-01,
          -2.5781e+00,  2.0605e-01],
         ...,
         [ 3.9062e-01,  1.8906e+00,  3.1250e-02,  ...,  6.8359e-01,
           2.2852e-01, -3.0664e-01],
         [ 2.9688e-01, -9.6094e-01,  9.8828e-01,  ...,  1.2578e+00,
          -2.0000e+00, -1.3516e+00],
         [ 1.7734e+00, -2.6953e-01, -1.8906e+00,  ...,  1.6328e+00,
          -9.9609e-01, -6.7188e-01]],

        [[-2.0599e-03, -3.8338e-04,  5.0049e-03,  ...,  1.0437e-02,
           1.6235e-02,  2.6758e-01],
         [-6.3672e-01,  5.5469e-01, -4.7656e-01,  ..., -2.0625e+00,
           9.9609e-01, -1.2969e+00],
         [ 1.2266e+00,  1.9531e+00,  3.3984e-01,  ..., -7.0703e-01,
          -2.5781e+00,  2.0605e-01],
         ...,
         [ 3.9062e-01,  1.8906e+00,  3.1250e-02,  ...,  6.8359e-01,
           2.2852e-01, -3.0664e-01],
         [ 2.9688e-01, -9.6094e-01,  9.8828e-01,  ...,  1.2578e+00,
          -2.0000e+00, -1.3516e+00],
         [ 1.7734e+00, -2.6953e-01, -1.8906e+00,  ...,  1.6328e+00,
          -9.9609e-01, -6.7188e-01]],

        [[-2.0599e-03, -3.8338e-04,  5.0049e-03,  ...,  1.0437e-02,
           1.6235e-02,  2.6758e-01],
         [-6.3672e-01,  5.5469e-01, -4.7656e-01,  ..., -2.0625e+00,
           9.9609e-01, -1.2969e+00],
         [ 1.2266e+00,  1.9531e+00,  3.3984e-01,  ..., -7.0703e-01,
          -2.5781e+00,  2.0605e-01],
         ...,
         [ 3.9062e-01,  1.8906e+00,  3.1250e-02,  ...,  6.8359e-01,
           2.2852e-01, -3.0664e-01],
         [ 2.9688e-01, -9.6094e-01,  9.8828e-01,  ...,  1.2578e+00,
          -2.0000e+00, -1.3516e+00],
         [ 1.7734e+00, -2.6953e-01, -1.8906e+00,  ...,  1.6328e+00,
          -9.9609e-01, -6.7188e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-6.6833e-03,  9.1553e-03, -9.5215e-03,  ...,  9.5825e-03,
          -1.5747e-02,  8.4229e-03],
         [ 1.8164e-01,  5.7861e-02,  8.6719e-01,  ..., -2.4414e-01,
           3.2812e-01,  5.7422e-01],
         [ 1.6113e-01, -7.6953e-01,  4.9414e-01,  ..., -5.0391e-01,
          -2.1191e-01, -4.2578e-01],
         ...,
         [ 1.7188e-01, -1.1572e-01, -2.1484e-02,  ..., -2.3828e-01,
           6.3281e-01,  5.4297e-01],
         [ 8.6914e-02, -3.3398e-01, -7.8125e-02,  ..., -3.1445e-01,
           9.1309e-02, -2.5781e-01],
         [-3.2812e-01,  2.1191e-01,  2.6172e-01,  ..., -1.0205e-01,
           1.3281e+00, -2.5586e-01]],

        [[-6.6833e-03,  9.1553e-03, -9.5215e-03,  ...,  9.5825e-03,
          -1.5747e-02,  8.4229e-03],
         [ 1.8164e-01,  5.7861e-02,  8.6719e-01,  ..., -2.4414e-01,
           3.2812e-01,  5.7422e-01],
         [ 1.6113e-01, -7.6953e-01,  4.9414e-01,  ..., -5.0391e-01,
          -2.1191e-01, -4.2578e-01],
         ...,
         [ 1.7188e-01, -1.1572e-01, -2.1484e-02,  ..., -2.3828e-01,
           6.3281e-01,  5.4297e-01],
         [ 8.6914e-02, -3.3398e-01, -7.8125e-02,  ..., -3.1445e-01,
           9.1309e-02, -2.5781e-01],
         [-3.2812e-01,  2.1191e-01,  2.6172e-01,  ..., -1.0205e-01,
           1.3281e+00, -2.5586e-01]],

        [[-6.6833e-03,  9.1553e-03, -9.5215e-03,  ...,  9.5825e-03,
          -1.5747e-02,  8.4229e-03],
         [ 1.8164e-01,  5.7861e-02,  8.6719e-01,  ..., -2.4414e-01,
           3.2812e-01,  5.7422e-01],
         [ 1.6113e-01, -7.6953e-01,  4.9414e-01,  ..., -5.0391e-01,
          -2.1191e-01, -4.2578e-01],
         ...,
         [ 1.7188e-01, -1.1572e-01, -2.1484e-02,  ..., -2.3828e-01,
           6.3281e-01,  5.4297e-01],
         [ 8.6914e-02, -3.3398e-01, -7.8125e-02,  ..., -3.1445e-01,
           9.1309e-02, -2.5781e-01],
         [-3.2812e-01,  2.1191e-01,  2.6172e-01,  ..., -1.0205e-01,
           1.3281e+00, -2.5586e-01]],

        ...,

        [[ 7.4463e-03,  4.1504e-03, -7.1716e-03,  ...,  4.1199e-03,
           1.5625e-02, -1.1841e-02],
         [ 2.1289e-01,  2.9102e-01, -2.2363e-01,  ...,  1.0010e-01,
           2.1973e-01,  3.2812e-01],
         [-1.3672e-01,  5.5859e-01,  4.2534e-04,  ..., -3.8672e-01,
          -6.1328e-01, -5.4688e-01],
         ...,
         [ 1.4551e-01,  4.1406e-01, -5.8594e-01,  ..., -7.6953e-01,
          -8.2520e-02,  3.2422e-01],
         [-1.6699e-01, -3.3203e-02, -3.7842e-02,  ..., -4.4141e-01,
          -2.6855e-02,  2.8906e-01],
         [ 4.4141e-01,  1.0234e+00,  8.4961e-02,  ...,  3.7305e-01,
          -2.3828e-01, -9.7656e-02]],

        [[ 7.4463e-03,  4.1504e-03, -7.1716e-03,  ...,  4.1199e-03,
           1.5625e-02, -1.1841e-02],
         [ 2.1289e-01,  2.9102e-01, -2.2363e-01,  ...,  1.0010e-01,
           2.1973e-01,  3.2812e-01],
         [-1.3672e-01,  5.5859e-01,  4.2534e-04,  ..., -3.8672e-01,
          -6.1328e-01, -5.4688e-01],
         ...,
         [ 1.4551e-01,  4.1406e-01, -5.8594e-01,  ..., -7.6953e-01,
          -8.2520e-02,  3.2422e-01],
         [-1.6699e-01, -3.3203e-02, -3.7842e-02,  ..., -4.4141e-01,
          -2.6855e-02,  2.8906e-01],
         [ 4.4141e-01,  1.0234e+00,  8.4961e-02,  ...,  3.7305e-01,
          -2.3828e-01, -9.7656e-02]],

        [[ 7.4463e-03,  4.1504e-03, -7.1716e-03,  ...,  4.1199e-03,
           1.5625e-02, -1.1841e-02],
         [ 2.1289e-01,  2.9102e-01, -2.2363e-01,  ...,  1.0010e-01,
           2.1973e-01,  3.2812e-01],
         [-1.3672e-01,  5.5859e-01,  4.2534e-04,  ..., -3.8672e-01,
          -6.1328e-01, -5.4688e-01],
         ...,
         [ 1.4551e-01,  4.1406e-01, -5.8594e-01,  ..., -7.6953e-01,
          -8.2520e-02,  3.2422e-01],
         [-1.6699e-01, -3.3203e-02, -3.7842e-02,  ..., -4.4141e-01,
          -2.6855e-02,  2.8906e-01],
         [ 4.4141e-01,  1.0234e+00,  8.4961e-02,  ...,  3.7305e-01,
          -2.3828e-01, -9.7656e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 0.0128,  0.0156, -0.0175,  ...,  0.0996, -0.0767, -0.1533],
         [-1.1406, -0.2266, -0.6562,  ...,  1.5859,  0.4961,  0.7188],
         [-1.5938,  0.2578,  0.3047,  ...,  0.1641,  1.9609, -1.6484],
         ...,
         [-2.0469, -1.1094, -1.8203,  ...,  0.1680,  1.1094, -2.9062],
         [ 0.0063, -0.4082,  0.0391,  ..., -2.1719,  2.5312,  1.8047],
         [-1.1875, -0.7344, -0.9141,  ..., -0.6055, -0.3867, -0.0410]],

        [[ 0.0128,  0.0156, -0.0175,  ...,  0.0996, -0.0767, -0.1533],
         [-1.1406, -0.2266, -0.6562,  ...,  1.5859,  0.4961,  0.7188],
         [-1.5938,  0.2578,  0.3047,  ...,  0.1641,  1.9609, -1.6484],
         ...,
         [-2.0469, -1.1094, -1.8203,  ...,  0.1680,  1.1094, -2.9062],
         [ 0.0063, -0.4082,  0.0391,  ..., -2.1719,  2.5312,  1.8047],
         [-1.1875, -0.7344, -0.9141,  ..., -0.6055, -0.3867, -0.0410]],

        [[ 0.0128,  0.0156, -0.0175,  ...,  0.0996, -0.0767, -0.1533],
         [-1.1406, -0.2266, -0.6562,  ...,  1.5859,  0.4961,  0.7188],
         [-1.5938,  0.2578,  0.3047,  ...,  0.1641,  1.9609, -1.6484],
         ...,
         [-2.0469, -1.1094, -1.8203,  ...,  0.1680,  1.1094, -2.9062],
         [ 0.0063, -0.4082,  0.0391,  ..., -2.1719,  2.5312,  1.8047],
         [-1.1875, -0.7344, -0.9141,  ..., -0.6055, -0.3867, -0.0410]],

        ...,

        [[-0.0070, -0.0035, -0.0055,  ..., -0.0588, -0.0309,  0.0552],
         [ 1.9453,  1.6719,  0.0781,  ...,  1.7734, -0.5156,  1.0938],
         [ 2.8125,  1.0000,  0.5859,  ...,  1.2578,  0.5547,  1.7344],
         ...,
         [ 2.8750,  0.9023, -0.5430,  ...,  0.7539,  0.1641, -0.8438],
         [ 0.4609, -0.8828,  0.0352,  ...,  0.4473,  1.8203, -1.0000],
         [ 1.2812, -0.7305, -1.2734,  ..., -0.7500, -0.3809, -2.3281]],

        [[-0.0070, -0.0035, -0.0055,  ..., -0.0588, -0.0309,  0.0552],
         [ 1.9453,  1.6719,  0.0781,  ...,  1.7734, -0.5156,  1.0938],
         [ 2.8125,  1.0000,  0.5859,  ...,  1.2578,  0.5547,  1.7344],
         ...,
         [ 2.8750,  0.9023, -0.5430,  ...,  0.7539,  0.1641, -0.8438],
         [ 0.4609, -0.8828,  0.0352,  ...,  0.4473,  1.8203, -1.0000],
         [ 1.2812, -0.7305, -1.2734,  ..., -0.7500, -0.3809, -2.3281]],

        [[-0.0070, -0.0035, -0.0055,  ..., -0.0588, -0.0309,  0.0552],
         [ 1.9453,  1.6719,  0.0781,  ...,  1.7734, -0.5156,  1.0938],
         [ 2.8125,  1.0000,  0.5859,  ...,  1.2578,  0.5547,  1.7344],
         ...,
         [ 2.8750,  0.9023, -0.5430,  ...,  0.7539,  0.1641, -0.8438],
         [ 0.4609, -0.8828,  0.0352,  ...,  0.4473,  1.8203, -1.0000],
         [ 1.2812, -0.7305, -1.2734,  ..., -0.7500, -0.3809, -2.3281]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>), tensor([[[-0.0125,  0.0048, -0.0069,  ...,  0.0050, -0.0082, -0.0070],
         [ 0.1729, -0.1914,  0.0452,  ...,  0.5156,  0.0698,  0.3691],
         [-1.0156,  0.7773,  0.3066,  ..., -0.1240,  0.8320,  0.8203],
         ...,
         [-0.3867,  0.8281, -0.7305,  ...,  0.2080, -0.0381, -0.7500],
         [-0.6133, -0.3047,  0.1768,  ...,  0.2129, -0.0342, -0.4980],
         [-0.0039,  0.4922, -0.2188,  ...,  0.1338, -0.0232, -0.2285]],

        [[-0.0125,  0.0048, -0.0069,  ...,  0.0050, -0.0082, -0.0070],
         [ 0.1729, -0.1914,  0.0452,  ...,  0.5156,  0.0698,  0.3691],
         [-1.0156,  0.7773,  0.3066,  ..., -0.1240,  0.8320,  0.8203],
         ...,
         [-0.3867,  0.8281, -0.7305,  ...,  0.2080, -0.0381, -0.7500],
         [-0.6133, -0.3047,  0.1768,  ...,  0.2129, -0.0342, -0.4980],
         [-0.0039,  0.4922, -0.2188,  ...,  0.1338, -0.0232, -0.2285]],

        [[-0.0125,  0.0048, -0.0069,  ...,  0.0050, -0.0082, -0.0070],
         [ 0.1729, -0.1914,  0.0452,  ...,  0.5156,  0.0698,  0.3691],
         [-1.0156,  0.7773,  0.3066,  ..., -0.1240,  0.8320,  0.8203],
         ...,
         [-0.3867,  0.8281, -0.7305,  ...,  0.2080, -0.0381, -0.7500],
         [-0.6133, -0.3047,  0.1768,  ...,  0.2129, -0.0342, -0.4980],
         [-0.0039,  0.4922, -0.2188,  ...,  0.1338, -0.0232, -0.2285]],

        ...,

        [[-0.0022, -0.0028, -0.0028,  ..., -0.0060,  0.0087, -0.0074],
         [ 0.0129,  0.1001,  0.2002,  ..., -0.2734,  0.4023,  0.2168],
         [ 0.3848,  0.5156,  0.6914,  ..., -0.5273,  0.2832,  0.9570],
         ...,
         [ 0.3203,  0.0518, -0.5703,  ...,  0.3145, -0.5781, -0.1191],
         [-0.0708,  0.5273, -0.2002,  ..., -0.3906, -0.0850,  0.4043],
         [ 0.3457,  0.4980,  0.2041,  ...,  0.0576,  0.2148,  0.0630]],

        [[-0.0022, -0.0028, -0.0028,  ..., -0.0060,  0.0087, -0.0074],
         [ 0.0129,  0.1001,  0.2002,  ..., -0.2734,  0.4023,  0.2168],
         [ 0.3848,  0.5156,  0.6914,  ..., -0.5273,  0.2832,  0.9570],
         ...,
         [ 0.3203,  0.0518, -0.5703,  ...,  0.3145, -0.5781, -0.1191],
         [-0.0708,  0.5273, -0.2002,  ..., -0.3906, -0.0850,  0.4043],
         [ 0.3457,  0.4980,  0.2041,  ...,  0.0576,  0.2148,  0.0630]],

        [[-0.0022, -0.0028, -0.0028,  ..., -0.0060,  0.0087, -0.0074],
         [ 0.0129,  0.1001,  0.2002,  ..., -0.2734,  0.4023,  0.2168],
         [ 0.3848,  0.5156,  0.6914,  ..., -0.5273,  0.2832,  0.9570],
         ...,
         [ 0.3203,  0.0518, -0.5703,  ...,  0.3145, -0.5781, -0.1191],
         [-0.0708,  0.5273, -0.2002,  ..., -0.3906, -0.0850,  0.4043],
         [ 0.3457,  0.4980,  0.2041,  ...,  0.0576,  0.2148,  0.0630]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[-1.5015e-02, -4.2725e-03, -8.5831e-04,  ..., -5.9375e-01,
          -1.5723e-01,  2.8076e-02],
         [-1.6641e+00,  2.8125e-01, -1.2188e+00,  ...,  2.2031e+00,
           6.6797e-01,  4.2188e-01],
         [-1.6406e-01,  1.0391e+00, -1.2695e-01,  ..., -1.5442e-02,
           2.4531e+00,  1.2109e+00],
         ...,
         [-7.9297e-01,  5.0000e-01, -4.8828e-01,  ...,  1.3516e+00,
           2.1719e+00, -9.2188e-01],
         [ 4.7266e-01, -1.6357e-02, -4.2578e-01,  ...,  8.0078e-01,
           1.0938e+00,  2.0898e-01],
         [ 6.9922e-01,  4.7266e-01, -1.2578e+00,  ...,  1.9531e+00,
           3.2617e-01,  7.8516e-01]],

        [[-1.5015e-02, -4.2725e-03, -8.5831e-04,  ..., -5.9375e-01,
          -1.5723e-01,  2.8076e-02],
         [-1.6641e+00,  2.8125e-01, -1.2188e+00,  ...,  2.2031e+00,
           6.6797e-01,  4.2188e-01],
         [-1.6406e-01,  1.0391e+00, -1.2695e-01,  ..., -1.5442e-02,
           2.4531e+00,  1.2109e+00],
         ...,
         [-7.9297e-01,  5.0000e-01, -4.8828e-01,  ...,  1.3516e+00,
           2.1719e+00, -9.2188e-01],
         [ 4.7266e-01, -1.6357e-02, -4.2578e-01,  ...,  8.0078e-01,
           1.0938e+00,  2.0898e-01],
         [ 6.9922e-01,  4.7266e-01, -1.2578e+00,  ...,  1.9531e+00,
           3.2617e-01,  7.8516e-01]],

        [[-1.5015e-02, -4.2725e-03, -8.5831e-04,  ..., -5.9375e-01,
          -1.5723e-01,  2.8076e-02],
         [-1.6641e+00,  2.8125e-01, -1.2188e+00,  ...,  2.2031e+00,
           6.6797e-01,  4.2188e-01],
         [-1.6406e-01,  1.0391e+00, -1.2695e-01,  ..., -1.5442e-02,
           2.4531e+00,  1.2109e+00],
         ...,
         [-7.9297e-01,  5.0000e-01, -4.8828e-01,  ...,  1.3516e+00,
           2.1719e+00, -9.2188e-01],
         [ 4.7266e-01, -1.6357e-02, -4.2578e-01,  ...,  8.0078e-01,
           1.0938e+00,  2.0898e-01],
         [ 6.9922e-01,  4.7266e-01, -1.2578e+00,  ...,  1.9531e+00,
           3.2617e-01,  7.8516e-01]],

        ...,

        [[ 3.1433e-03,  8.4839e-03,  4.5967e-04,  ...,  1.1914e-01,
           5.0781e-02, -9.8267e-03],
         [ 1.4688e+00, -8.5547e-01, -1.2734e+00,  ..., -4.9805e-01,
          -4.1504e-02,  1.1016e+00],
         [-1.2422e+00, -9.2969e-01, -9.9219e-01,  ...,  5.4297e-01,
          -1.6328e+00,  5.5078e-01],
         ...,
         [-3.9062e-02, -6.8750e-01, -1.0391e+00,  ...,  3.6523e-01,
           3.0469e-01,  1.6172e+00],
         [-9.2188e-01,  9.9609e-02, -2.0898e-01,  ...,  7.7344e-01,
          -9.6875e-01,  2.2188e+00],
         [-9.2188e-01, -6.5625e-01, -1.4688e+00,  ..., -6.6016e-01,
          -2.5391e-01,  9.0234e-01]],

        [[ 3.1433e-03,  8.4839e-03,  4.5967e-04,  ...,  1.1914e-01,
           5.0781e-02, -9.8267e-03],
         [ 1.4688e+00, -8.5547e-01, -1.2734e+00,  ..., -4.9805e-01,
          -4.1504e-02,  1.1016e+00],
         [-1.2422e+00, -9.2969e-01, -9.9219e-01,  ...,  5.4297e-01,
          -1.6328e+00,  5.5078e-01],
         ...,
         [-3.9062e-02, -6.8750e-01, -1.0391e+00,  ...,  3.6523e-01,
           3.0469e-01,  1.6172e+00],
         [-9.2188e-01,  9.9609e-02, -2.0898e-01,  ...,  7.7344e-01,
          -9.6875e-01,  2.2188e+00],
         [-9.2188e-01, -6.5625e-01, -1.4688e+00,  ..., -6.6016e-01,
          -2.5391e-01,  9.0234e-01]],

        [[ 3.1433e-03,  8.4839e-03,  4.5967e-04,  ...,  1.1914e-01,
           5.0781e-02, -9.8267e-03],
         [ 1.4688e+00, -8.5547e-01, -1.2734e+00,  ..., -4.9805e-01,
          -4.1504e-02,  1.1016e+00],
         [-1.2422e+00, -9.2969e-01, -9.9219e-01,  ...,  5.4297e-01,
          -1.6328e+00,  5.5078e-01],
         ...,
         [-3.9062e-02, -6.8750e-01, -1.0391e+00,  ...,  3.6523e-01,
           3.0469e-01,  1.6172e+00],
         [-9.2188e-01,  9.9609e-02, -2.0898e-01,  ...,  7.7344e-01,
          -9.6875e-01,  2.2188e+00],
         [-9.2188e-01, -6.5625e-01, -1.4688e+00,  ..., -6.6016e-01,
          -2.5391e-01,  9.0234e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-2.3651e-03, -5.9204e-03,  4.3030e-03,  ...,  5.5313e-04,
           9.9182e-04, -5.1880e-03],
         [ 4.3750e-01,  5.5859e-01,  6.6016e-01,  ...,  1.4160e-01,
          -8.3594e-01, -5.1514e-02],
         [ 5.1953e-01, -4.2725e-02,  1.4062e-01,  ..., -4.8828e-01,
           2.5586e-01,  2.6367e-01],
         ...,
         [-2.7344e-01,  3.1445e-01, -3.1982e-02,  ...,  1.6797e-01,
          -4.8047e-01,  9.5312e-01],
         [ 3.2617e-01,  4.5703e-01, -5.0391e-01,  ..., -3.3398e-01,
           1.3379e-01,  5.8203e-01],
         [-5.2734e-02,  7.9688e-01, -9.8145e-02,  ..., -4.6875e-02,
          -8.3203e-01,  1.7676e-01]],

        [[-2.3651e-03, -5.9204e-03,  4.3030e-03,  ...,  5.5313e-04,
           9.9182e-04, -5.1880e-03],
         [ 4.3750e-01,  5.5859e-01,  6.6016e-01,  ...,  1.4160e-01,
          -8.3594e-01, -5.1514e-02],
         [ 5.1953e-01, -4.2725e-02,  1.4062e-01,  ..., -4.8828e-01,
           2.5586e-01,  2.6367e-01],
         ...,
         [-2.7344e-01,  3.1445e-01, -3.1982e-02,  ...,  1.6797e-01,
          -4.8047e-01,  9.5312e-01],
         [ 3.2617e-01,  4.5703e-01, -5.0391e-01,  ..., -3.3398e-01,
           1.3379e-01,  5.8203e-01],
         [-5.2734e-02,  7.9688e-01, -9.8145e-02,  ..., -4.6875e-02,
          -8.3203e-01,  1.7676e-01]],

        [[-2.3651e-03, -5.9204e-03,  4.3030e-03,  ...,  5.5313e-04,
           9.9182e-04, -5.1880e-03],
         [ 4.3750e-01,  5.5859e-01,  6.6016e-01,  ...,  1.4160e-01,
          -8.3594e-01, -5.1514e-02],
         [ 5.1953e-01, -4.2725e-02,  1.4062e-01,  ..., -4.8828e-01,
           2.5586e-01,  2.6367e-01],
         ...,
         [-2.7344e-01,  3.1445e-01, -3.1982e-02,  ...,  1.6797e-01,
          -4.8047e-01,  9.5312e-01],
         [ 3.2617e-01,  4.5703e-01, -5.0391e-01,  ..., -3.3398e-01,
           1.3379e-01,  5.8203e-01],
         [-5.2734e-02,  7.9688e-01, -9.8145e-02,  ..., -4.6875e-02,
          -8.3203e-01,  1.7676e-01]],

        ...,

        [[-4.8065e-04,  9.7275e-04,  5.5176e-02,  ..., -3.0518e-02,
          -1.1108e-02,  2.5635e-02],
         [-2.1680e-01, -1.3750e+00, -3.3984e-01,  ...,  4.8438e-01,
          -1.5039e-01,  7.4219e-01],
         [-1.3965e-01, -8.9453e-01, -3.2617e-01,  ...,  3.6914e-01,
          -7.4609e-01,  1.6699e-01],
         ...,
         [-8.2031e-01, -1.6602e-01,  1.9043e-01,  ...,  1.2656e+00,
          -1.7188e-01,  1.3906e+00],
         [-8.9844e-01,  3.6328e-01,  1.3379e-01,  ...,  3.0469e-01,
          -3.5156e-01,  5.1172e-01],
         [ 1.1094e+00, -8.5156e-01, -4.6680e-01,  ..., -1.9629e-01,
           1.2598e-01, -5.4688e-01]],

        [[-4.8065e-04,  9.7275e-04,  5.5176e-02,  ..., -3.0518e-02,
          -1.1108e-02,  2.5635e-02],
         [-2.1680e-01, -1.3750e+00, -3.3984e-01,  ...,  4.8438e-01,
          -1.5039e-01,  7.4219e-01],
         [-1.3965e-01, -8.9453e-01, -3.2617e-01,  ...,  3.6914e-01,
          -7.4609e-01,  1.6699e-01],
         ...,
         [-8.2031e-01, -1.6602e-01,  1.9043e-01,  ...,  1.2656e+00,
          -1.7188e-01,  1.3906e+00],
         [-8.9844e-01,  3.6328e-01,  1.3379e-01,  ...,  3.0469e-01,
          -3.5156e-01,  5.1172e-01],
         [ 1.1094e+00, -8.5156e-01, -4.6680e-01,  ..., -1.9629e-01,
           1.2598e-01, -5.4688e-01]],

        [[-4.8065e-04,  9.7275e-04,  5.5176e-02,  ..., -3.0518e-02,
          -1.1108e-02,  2.5635e-02],
         [-2.1680e-01, -1.3750e+00, -3.3984e-01,  ...,  4.8438e-01,
          -1.5039e-01,  7.4219e-01],
         [-1.3965e-01, -8.9453e-01, -3.2617e-01,  ...,  3.6914e-01,
          -7.4609e-01,  1.6699e-01],
         ...,
         [-8.2031e-01, -1.6602e-01,  1.9043e-01,  ...,  1.2656e+00,
          -1.7188e-01,  1.3906e+00],
         [-8.9844e-01,  3.6328e-01,  1.3379e-01,  ...,  3.0469e-01,
          -3.5156e-01,  5.1172e-01],
         [ 1.1094e+00, -8.5156e-01, -4.6680e-01,  ..., -1.9629e-01,
           1.2598e-01, -5.4688e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 1.1780e-02, -3.5858e-03, -3.2196e-03,  ..., -3.6328e-01,
          -1.3733e-02, -4.6875e-02],
         [-1.5234e-01,  1.4062e-01, -9.8438e-01,  ..., -8.7500e-01,
           1.2969e+00,  1.2969e+00],
         [ 8.9062e-01, -1.4062e+00, -6.4062e-01,  ...,  1.8828e+00,
           2.2656e+00,  1.3438e+00],
         ...,
         [ 1.4062e+00, -1.9531e-02,  2.8125e-01,  ..., -1.6895e-01,
          -1.0156e+00, -1.1182e-01],
         [ 2.1484e-02,  2.8711e-01, -2.1875e-01,  ...,  1.9609e+00,
          -1.7031e+00, -1.3750e+00],
         [ 3.4766e-01,  8.7500e-01, -1.2266e+00,  ...,  1.3203e+00,
           2.4688e+00,  1.8516e+00]],

        [[ 1.1780e-02, -3.5858e-03, -3.2196e-03,  ..., -3.6328e-01,
          -1.3733e-02, -4.6875e-02],
         [-1.5234e-01,  1.4062e-01, -9.8438e-01,  ..., -8.7500e-01,
           1.2969e+00,  1.2969e+00],
         [ 8.9062e-01, -1.4062e+00, -6.4062e-01,  ...,  1.8828e+00,
           2.2656e+00,  1.3438e+00],
         ...,
         [ 1.4062e+00, -1.9531e-02,  2.8125e-01,  ..., -1.6895e-01,
          -1.0156e+00, -1.1182e-01],
         [ 2.1484e-02,  2.8711e-01, -2.1875e-01,  ...,  1.9609e+00,
          -1.7031e+00, -1.3750e+00],
         [ 3.4766e-01,  8.7500e-01, -1.2266e+00,  ...,  1.3203e+00,
           2.4688e+00,  1.8516e+00]],

        [[ 1.1780e-02, -3.5858e-03, -3.2196e-03,  ..., -3.6328e-01,
          -1.3733e-02, -4.6875e-02],
         [-1.5234e-01,  1.4062e-01, -9.8438e-01,  ..., -8.7500e-01,
           1.2969e+00,  1.2969e+00],
         [ 8.9062e-01, -1.4062e+00, -6.4062e-01,  ...,  1.8828e+00,
           2.2656e+00,  1.3438e+00],
         ...,
         [ 1.4062e+00, -1.9531e-02,  2.8125e-01,  ..., -1.6895e-01,
          -1.0156e+00, -1.1182e-01],
         [ 2.1484e-02,  2.8711e-01, -2.1875e-01,  ...,  1.9609e+00,
          -1.7031e+00, -1.3750e+00],
         [ 3.4766e-01,  8.7500e-01, -1.2266e+00,  ...,  1.3203e+00,
           2.4688e+00,  1.8516e+00]],

        ...,

        [[-6.5613e-04,  1.4099e-02, -2.3956e-03,  ..., -2.5977e-01,
          -1.6406e-01, -4.7607e-03],
         [-1.0156e+00,  3.3203e-01,  2.2656e+00,  ...,  1.3516e+00,
          -9.7266e-01, -3.9844e-01],
         [-2.8750e+00,  5.5469e-01,  2.0625e+00,  ...,  1.0391e+00,
          -9.7168e-02, -1.7109e+00],
         ...,
         [-2.4375e+00,  1.5469e+00,  2.6172e-01,  ..., -9.8438e-01,
           8.1641e-01, -4.6680e-01],
         [-5.3906e-01,  5.1172e-01,  2.4219e-01,  ..., -1.3672e+00,
          -6.3672e-01,  6.7188e-01],
         [-8.7109e-01,  7.7344e-01,  8.2812e-01,  ..., -1.8652e-01,
           4.8633e-01, -2.4023e-01]],

        [[-6.5613e-04,  1.4099e-02, -2.3956e-03,  ..., -2.5977e-01,
          -1.6406e-01, -4.7607e-03],
         [-1.0156e+00,  3.3203e-01,  2.2656e+00,  ...,  1.3516e+00,
          -9.7266e-01, -3.9844e-01],
         [-2.8750e+00,  5.5469e-01,  2.0625e+00,  ...,  1.0391e+00,
          -9.7168e-02, -1.7109e+00],
         ...,
         [-2.4375e+00,  1.5469e+00,  2.6172e-01,  ..., -9.8438e-01,
           8.1641e-01, -4.6680e-01],
         [-5.3906e-01,  5.1172e-01,  2.4219e-01,  ..., -1.3672e+00,
          -6.3672e-01,  6.7188e-01],
         [-8.7109e-01,  7.7344e-01,  8.2812e-01,  ..., -1.8652e-01,
           4.8633e-01, -2.4023e-01]],

        [[-6.5613e-04,  1.4099e-02, -2.3956e-03,  ..., -2.5977e-01,
          -1.6406e-01, -4.7607e-03],
         [-1.0156e+00,  3.3203e-01,  2.2656e+00,  ...,  1.3516e+00,
          -9.7266e-01, -3.9844e-01],
         [-2.8750e+00,  5.5469e-01,  2.0625e+00,  ...,  1.0391e+00,
          -9.7168e-02, -1.7109e+00],
         ...,
         [-2.4375e+00,  1.5469e+00,  2.6172e-01,  ..., -9.8438e-01,
           8.1641e-01, -4.6680e-01],
         [-5.3906e-01,  5.1172e-01,  2.4219e-01,  ..., -1.3672e+00,
          -6.3672e-01,  6.7188e-01],
         [-8.7109e-01,  7.7344e-01,  8.2812e-01,  ..., -1.8652e-01,
           4.8633e-01, -2.4023e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 3.7842e-02, -7.5073e-03, -7.4768e-04,  ...,  8.5449e-02,
           3.0762e-02,  1.2695e-02],
         [-6.9141e-01,  4.3359e-01, -2.8320e-02,  ...,  3.4570e-01,
          -2.2949e-01, -1.6699e-01],
         [-4.7070e-01,  7.6172e-02, -4.6631e-02,  ...,  1.8750e-01,
          -2.4219e-01, -5.5859e-01],
         ...,
         [-2.9053e-02,  5.7422e-01,  4.2578e-01,  ..., -4.7461e-01,
          -7.3828e-01, -3.1641e-01],
         [-7.1875e-01,  4.4141e-01,  6.3672e-01,  ..., -1.6797e-01,
          -6.6406e-01,  3.8867e-01],
         [ 4.3945e-01,  5.2734e-01, -9.7656e-02,  ...,  3.7305e-01,
          -2.3340e-01, -9.1406e-01]],

        [[ 3.7842e-02, -7.5073e-03, -7.4768e-04,  ...,  8.5449e-02,
           3.0762e-02,  1.2695e-02],
         [-6.9141e-01,  4.3359e-01, -2.8320e-02,  ...,  3.4570e-01,
          -2.2949e-01, -1.6699e-01],
         [-4.7070e-01,  7.6172e-02, -4.6631e-02,  ...,  1.8750e-01,
          -2.4219e-01, -5.5859e-01],
         ...,
         [-2.9053e-02,  5.7422e-01,  4.2578e-01,  ..., -4.7461e-01,
          -7.3828e-01, -3.1641e-01],
         [-7.1875e-01,  4.4141e-01,  6.3672e-01,  ..., -1.6797e-01,
          -6.6406e-01,  3.8867e-01],
         [ 4.3945e-01,  5.2734e-01, -9.7656e-02,  ...,  3.7305e-01,
          -2.3340e-01, -9.1406e-01]],

        [[ 3.7842e-02, -7.5073e-03, -7.4768e-04,  ...,  8.5449e-02,
           3.0762e-02,  1.2695e-02],
         [-6.9141e-01,  4.3359e-01, -2.8320e-02,  ...,  3.4570e-01,
          -2.2949e-01, -1.6699e-01],
         [-4.7070e-01,  7.6172e-02, -4.6631e-02,  ...,  1.8750e-01,
          -2.4219e-01, -5.5859e-01],
         ...,
         [-2.9053e-02,  5.7422e-01,  4.2578e-01,  ..., -4.7461e-01,
          -7.3828e-01, -3.1641e-01],
         [-7.1875e-01,  4.4141e-01,  6.3672e-01,  ..., -1.6797e-01,
          -6.6406e-01,  3.8867e-01],
         [ 4.3945e-01,  5.2734e-01, -9.7656e-02,  ...,  3.7305e-01,
          -2.3340e-01, -9.1406e-01]],

        ...,

        [[ 1.3580e-03, -1.8082e-03, -3.6812e-04,  ...,  1.6861e-03,
          -4.3945e-03,  9.2163e-03],
         [ 6.5625e-01, -1.4954e-03, -7.1484e-01,  ...,  2.5000e-01,
           8.4375e-01,  1.0781e+00],
         [-1.8555e-01, -9.1309e-02, -5.0000e-01,  ..., -4.0430e-01,
           1.7188e-01,  5.8203e-01],
         ...,
         [ 3.4961e-01,  1.2158e-01,  7.4609e-01,  ..., -4.6875e-02,
           2.7930e-01, -2.1289e-01],
         [-7.7637e-02, -1.1963e-01,  9.2578e-01,  ...,  5.3906e-01,
          -6.4844e-01, -1.7285e-01],
         [-9.8828e-01, -4.9414e-01,  4.3945e-01,  ...,  2.1210e-03,
           6.1719e-01, -1.0205e-01]],

        [[ 1.3580e-03, -1.8082e-03, -3.6812e-04,  ...,  1.6861e-03,
          -4.3945e-03,  9.2163e-03],
         [ 6.5625e-01, -1.4954e-03, -7.1484e-01,  ...,  2.5000e-01,
           8.4375e-01,  1.0781e+00],
         [-1.8555e-01, -9.1309e-02, -5.0000e-01,  ..., -4.0430e-01,
           1.7188e-01,  5.8203e-01],
         ...,
         [ 3.4961e-01,  1.2158e-01,  7.4609e-01,  ..., -4.6875e-02,
           2.7930e-01, -2.1289e-01],
         [-7.7637e-02, -1.1963e-01,  9.2578e-01,  ...,  5.3906e-01,
          -6.4844e-01, -1.7285e-01],
         [-9.8828e-01, -4.9414e-01,  4.3945e-01,  ...,  2.1210e-03,
           6.1719e-01, -1.0205e-01]],

        [[ 1.3580e-03, -1.8082e-03, -3.6812e-04,  ...,  1.6861e-03,
          -4.3945e-03,  9.2163e-03],
         [ 6.5625e-01, -1.4954e-03, -7.1484e-01,  ...,  2.5000e-01,
           8.4375e-01,  1.0781e+00],
         [-1.8555e-01, -9.1309e-02, -5.0000e-01,  ..., -4.0430e-01,
           1.7188e-01,  5.8203e-01],
         ...,
         [ 3.4961e-01,  1.2158e-01,  7.4609e-01,  ..., -4.6875e-02,
           2.7930e-01, -2.1289e-01],
         [-7.7637e-02, -1.1963e-01,  9.2578e-01,  ...,  5.3906e-01,
          -6.4844e-01, -1.7285e-01],
         [-9.8828e-01, -4.9414e-01,  4.3945e-01,  ...,  2.1210e-03,
           6.1719e-01, -1.0205e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-1.1719e-02,  1.8799e-02,  3.8605e-03,  ..., -1.9629e-01,
          -1.7090e-01, -1.7578e-01],
         [ 3.2344e+00, -3.0781e+00,  3.3203e-01,  ...,  5.3516e-01,
          -2.3125e+00,  9.1016e-01],
         [ 3.6914e-01,  1.4160e-01,  1.7031e+00,  ...,  9.1016e-01,
           3.7500e-01, -1.6602e-01],
         ...,
         [ 1.1328e-01, -4.3555e-01,  1.9922e+00,  ..., -1.1016e+00,
          -9.2969e-01,  2.3281e+00],
         [-1.2812e+00, -5.1953e-01,  6.0938e-01,  ...,  9.6191e-02,
          -8.7891e-01,  1.8594e+00],
         [-9.6094e-01,  1.2969e+00,  1.3203e+00,  ...,  1.0078e+00,
           2.4375e+00, -1.6719e+00]],

        [[-1.1719e-02,  1.8799e-02,  3.8605e-03,  ..., -1.9629e-01,
          -1.7090e-01, -1.7578e-01],
         [ 3.2344e+00, -3.0781e+00,  3.3203e-01,  ...,  5.3516e-01,
          -2.3125e+00,  9.1016e-01],
         [ 3.6914e-01,  1.4160e-01,  1.7031e+00,  ...,  9.1016e-01,
           3.7500e-01, -1.6602e-01],
         ...,
         [ 1.1328e-01, -4.3555e-01,  1.9922e+00,  ..., -1.1016e+00,
          -9.2969e-01,  2.3281e+00],
         [-1.2812e+00, -5.1953e-01,  6.0938e-01,  ...,  9.6191e-02,
          -8.7891e-01,  1.8594e+00],
         [-9.6094e-01,  1.2969e+00,  1.3203e+00,  ...,  1.0078e+00,
           2.4375e+00, -1.6719e+00]],

        [[-1.1719e-02,  1.8799e-02,  3.8605e-03,  ..., -1.9629e-01,
          -1.7090e-01, -1.7578e-01],
         [ 3.2344e+00, -3.0781e+00,  3.3203e-01,  ...,  5.3516e-01,
          -2.3125e+00,  9.1016e-01],
         [ 3.6914e-01,  1.4160e-01,  1.7031e+00,  ...,  9.1016e-01,
           3.7500e-01, -1.6602e-01],
         ...,
         [ 1.1328e-01, -4.3555e-01,  1.9922e+00,  ..., -1.1016e+00,
          -9.2969e-01,  2.3281e+00],
         [-1.2812e+00, -5.1953e-01,  6.0938e-01,  ...,  9.6191e-02,
          -8.7891e-01,  1.8594e+00],
         [-9.6094e-01,  1.2969e+00,  1.3203e+00,  ...,  1.0078e+00,
           2.4375e+00, -1.6719e+00]],

        ...,

        [[ 6.1798e-04,  1.2024e-02, -1.5076e-02,  ...,  1.9141e-01,
           4.7656e-01,  4.5703e-01],
         [-3.6406e+00, -1.2109e-01,  1.8203e+00,  ...,  1.6797e+00,
          -2.1719e+00,  1.7109e+00],
         [-1.5938e+00,  1.7383e-01,  1.0625e+00,  ...,  1.1250e+00,
          -4.1875e+00,  7.7637e-02],
         ...,
         [-2.0469e+00,  1.5859e+00, -5.2344e-01,  ..., -7.0703e-01,
          -3.7812e+00,  5.1514e-02],
         [ 3.9062e-01,  1.1719e+00,  5.7812e-01,  ..., -4.3750e-01,
          -3.6250e+00, -1.8047e+00],
         [-4.4727e-01,  1.0156e+00,  1.9141e-01,  ...,  2.3281e+00,
          -1.2969e+00,  8.4375e-01]],

        [[ 6.1798e-04,  1.2024e-02, -1.5076e-02,  ...,  1.9141e-01,
           4.7656e-01,  4.5703e-01],
         [-3.6406e+00, -1.2109e-01,  1.8203e+00,  ...,  1.6797e+00,
          -2.1719e+00,  1.7109e+00],
         [-1.5938e+00,  1.7383e-01,  1.0625e+00,  ...,  1.1250e+00,
          -4.1875e+00,  7.7637e-02],
         ...,
         [-2.0469e+00,  1.5859e+00, -5.2344e-01,  ..., -7.0703e-01,
          -3.7812e+00,  5.1514e-02],
         [ 3.9062e-01,  1.1719e+00,  5.7812e-01,  ..., -4.3750e-01,
          -3.6250e+00, -1.8047e+00],
         [-4.4727e-01,  1.0156e+00,  1.9141e-01,  ...,  2.3281e+00,
          -1.2969e+00,  8.4375e-01]],

        [[ 6.1798e-04,  1.2024e-02, -1.5076e-02,  ...,  1.9141e-01,
           4.7656e-01,  4.5703e-01],
         [-3.6406e+00, -1.2109e-01,  1.8203e+00,  ...,  1.6797e+00,
          -2.1719e+00,  1.7109e+00],
         [-1.5938e+00,  1.7383e-01,  1.0625e+00,  ...,  1.1250e+00,
          -4.1875e+00,  7.7637e-02],
         ...,
         [-2.0469e+00,  1.5859e+00, -5.2344e-01,  ..., -7.0703e-01,
          -3.7812e+00,  5.1514e-02],
         [ 3.9062e-01,  1.1719e+00,  5.7812e-01,  ..., -4.3750e-01,
          -3.6250e+00, -1.8047e+00],
         [-4.4727e-01,  1.0156e+00,  1.9141e-01,  ...,  2.3281e+00,
          -1.2969e+00,  8.4375e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-3.8300e-03,  1.1749e-03, -5.0964e-03,  ...,  1.0193e-02,
           1.3489e-02, -1.4221e-02],
         [ 2.2559e-01, -1.7212e-02, -2.6562e-01,  ...,  2.5586e-01,
           1.8457e-01,  7.3828e-01],
         [-4.7656e-01,  3.0664e-01,  2.2559e-01,  ..., -1.1084e-01,
          -3.3594e-01, -2.4023e-01],
         ...,
         [ 7.6562e-01,  3.4766e-01, -1.2031e+00,  ..., -3.7891e-01,
          -2.8564e-02, -1.4453e-01],
         [ 2.3535e-01,  3.3789e-01, -4.4922e-01,  ..., -8.3008e-02,
          -3.1641e-01, -2.4707e-01],
         [ 1.1816e-01, -5.2344e-01,  7.5391e-01,  ..., -2.3047e-01,
           4.6680e-01,  3.7305e-01]],

        [[-3.8300e-03,  1.1749e-03, -5.0964e-03,  ...,  1.0193e-02,
           1.3489e-02, -1.4221e-02],
         [ 2.2559e-01, -1.7212e-02, -2.6562e-01,  ...,  2.5586e-01,
           1.8457e-01,  7.3828e-01],
         [-4.7656e-01,  3.0664e-01,  2.2559e-01,  ..., -1.1084e-01,
          -3.3594e-01, -2.4023e-01],
         ...,
         [ 7.6562e-01,  3.4766e-01, -1.2031e+00,  ..., -3.7891e-01,
          -2.8564e-02, -1.4453e-01],
         [ 2.3535e-01,  3.3789e-01, -4.4922e-01,  ..., -8.3008e-02,
          -3.1641e-01, -2.4707e-01],
         [ 1.1816e-01, -5.2344e-01,  7.5391e-01,  ..., -2.3047e-01,
           4.6680e-01,  3.7305e-01]],

        [[-3.8300e-03,  1.1749e-03, -5.0964e-03,  ...,  1.0193e-02,
           1.3489e-02, -1.4221e-02],
         [ 2.2559e-01, -1.7212e-02, -2.6562e-01,  ...,  2.5586e-01,
           1.8457e-01,  7.3828e-01],
         [-4.7656e-01,  3.0664e-01,  2.2559e-01,  ..., -1.1084e-01,
          -3.3594e-01, -2.4023e-01],
         ...,
         [ 7.6562e-01,  3.4766e-01, -1.2031e+00,  ..., -3.7891e-01,
          -2.8564e-02, -1.4453e-01],
         [ 2.3535e-01,  3.3789e-01, -4.4922e-01,  ..., -8.3008e-02,
          -3.1641e-01, -2.4707e-01],
         [ 1.1816e-01, -5.2344e-01,  7.5391e-01,  ..., -2.3047e-01,
           4.6680e-01,  3.7305e-01]],

        ...,

        [[-8.8501e-03, -3.8452e-03, -2.5940e-03,  ...,  2.5635e-03,
           3.6469e-03, -3.3875e-03],
         [ 1.6504e-01, -4.8242e-01, -6.4453e-02,  ..., -2.0508e-01,
          -2.4902e-01, -2.7344e-01],
         [-7.1484e-01,  8.3203e-01,  2.2656e-01,  ..., -7.3242e-02,
          -9.1797e-01,  1.1016e+00],
         ...,
         [-3.3984e-01,  3.8867e-01, -4.2188e-01,  ..., -2.5000e-01,
          -3.8867e-01, -3.2422e-01],
         [-1.3965e-01,  4.3750e-01, -3.8281e-01,  ...,  2.6758e-01,
           6.6895e-02, -2.2168e-01],
         [ 1.0156e+00, -1.1328e-01, -8.1250e-01,  ..., -5.2979e-02,
           1.0205e-01,  6.6406e-01]],

        [[-8.8501e-03, -3.8452e-03, -2.5940e-03,  ...,  2.5635e-03,
           3.6469e-03, -3.3875e-03],
         [ 1.6504e-01, -4.8242e-01, -6.4453e-02,  ..., -2.0508e-01,
          -2.4902e-01, -2.7344e-01],
         [-7.1484e-01,  8.3203e-01,  2.2656e-01,  ..., -7.3242e-02,
          -9.1797e-01,  1.1016e+00],
         ...,
         [-3.3984e-01,  3.8867e-01, -4.2188e-01,  ..., -2.5000e-01,
          -3.8867e-01, -3.2422e-01],
         [-1.3965e-01,  4.3750e-01, -3.8281e-01,  ...,  2.6758e-01,
           6.6895e-02, -2.2168e-01],
         [ 1.0156e+00, -1.1328e-01, -8.1250e-01,  ..., -5.2979e-02,
           1.0205e-01,  6.6406e-01]],

        [[-8.8501e-03, -3.8452e-03, -2.5940e-03,  ...,  2.5635e-03,
           3.6469e-03, -3.3875e-03],
         [ 1.6504e-01, -4.8242e-01, -6.4453e-02,  ..., -2.0508e-01,
          -2.4902e-01, -2.7344e-01],
         [-7.1484e-01,  8.3203e-01,  2.2656e-01,  ..., -7.3242e-02,
          -9.1797e-01,  1.1016e+00],
         ...,
         [-3.3984e-01,  3.8867e-01, -4.2188e-01,  ..., -2.5000e-01,
          -3.8867e-01, -3.2422e-01],
         [-1.3965e-01,  4.3750e-01, -3.8281e-01,  ...,  2.6758e-01,
           6.6895e-02, -2.2168e-01],
         [ 1.0156e+00, -1.1328e-01, -8.1250e-01,  ..., -5.2979e-02,
           1.0205e-01,  6.6406e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 3.0060e-03,  8.0872e-04, -1.0437e-02,  ..., -2.7344e-02,
          -3.0518e-03, -2.2461e-01],
         [-1.8594e+00,  2.4688e+00,  2.5000e+00,  ...,  1.6328e+00,
          -5.8350e-02,  1.6328e+00],
         [ 1.2500e+00, -1.9336e-01,  1.2656e+00,  ...,  1.1562e+00,
          -5.8203e-01,  1.1562e+00],
         ...,
         [-4.9609e-01, -1.7188e+00, -6.1328e-01,  ...,  2.2656e+00,
          -6.4844e-01, -3.7500e-01],
         [-3.4570e-01, -4.0820e-01, -6.8750e-01,  ...,  2.1094e+00,
           9.8145e-02, -2.5000e+00],
         [ 1.5625e+00, -4.1797e-01,  9.7266e-01,  ...,  3.0469e-01,
          -7.7344e-01, -5.6641e-01]],

        [[ 3.0060e-03,  8.0872e-04, -1.0437e-02,  ..., -2.7344e-02,
          -3.0518e-03, -2.2461e-01],
         [-1.8594e+00,  2.4688e+00,  2.5000e+00,  ...,  1.6328e+00,
          -5.8350e-02,  1.6328e+00],
         [ 1.2500e+00, -1.9336e-01,  1.2656e+00,  ...,  1.1562e+00,
          -5.8203e-01,  1.1562e+00],
         ...,
         [-4.9609e-01, -1.7188e+00, -6.1328e-01,  ...,  2.2656e+00,
          -6.4844e-01, -3.7500e-01],
         [-3.4570e-01, -4.0820e-01, -6.8750e-01,  ...,  2.1094e+00,
           9.8145e-02, -2.5000e+00],
         [ 1.5625e+00, -4.1797e-01,  9.7266e-01,  ...,  3.0469e-01,
          -7.7344e-01, -5.6641e-01]],

        [[ 3.0060e-03,  8.0872e-04, -1.0437e-02,  ..., -2.7344e-02,
          -3.0518e-03, -2.2461e-01],
         [-1.8594e+00,  2.4688e+00,  2.5000e+00,  ...,  1.6328e+00,
          -5.8350e-02,  1.6328e+00],
         [ 1.2500e+00, -1.9336e-01,  1.2656e+00,  ...,  1.1562e+00,
          -5.8203e-01,  1.1562e+00],
         ...,
         [-4.9609e-01, -1.7188e+00, -6.1328e-01,  ...,  2.2656e+00,
          -6.4844e-01, -3.7500e-01],
         [-3.4570e-01, -4.0820e-01, -6.8750e-01,  ...,  2.1094e+00,
           9.8145e-02, -2.5000e+00],
         [ 1.5625e+00, -4.1797e-01,  9.7266e-01,  ...,  3.0469e-01,
          -7.7344e-01, -5.6641e-01]],

        ...,

        [[-8.2779e-04,  5.0964e-03, -1.3123e-03,  ...,  2.4609e-01,
           7.2656e-01,  5.5078e-01],
         [-1.8203e+00,  7.3047e-01, -2.6875e+00,  ...,  4.5312e-01,
          -3.4219e+00,  3.5156e-01],
         [ 1.9824e-01,  3.9453e-01, -7.2266e-01,  ...,  5.5859e-01,
          -6.7500e+00,  4.4141e-01],
         ...,
         [-1.0469e+00, -7.5391e-01, -6.4062e-01,  ...,  2.3438e+00,
          -4.1875e+00,  2.3438e+00],
         [ 9.7656e-02,  8.2812e-01, -1.1572e-01,  ..., -1.5332e-01,
          -6.5312e+00,  3.3594e+00],
         [ 8.3984e-01,  3.0469e-01, -1.4531e+00,  ..., -1.3359e+00,
          -2.6719e+00,  3.0000e+00]],

        [[-8.2779e-04,  5.0964e-03, -1.3123e-03,  ...,  2.4609e-01,
           7.2656e-01,  5.5078e-01],
         [-1.8203e+00,  7.3047e-01, -2.6875e+00,  ...,  4.5312e-01,
          -3.4219e+00,  3.5156e-01],
         [ 1.9824e-01,  3.9453e-01, -7.2266e-01,  ...,  5.5859e-01,
          -6.7500e+00,  4.4141e-01],
         ...,
         [-1.0469e+00, -7.5391e-01, -6.4062e-01,  ...,  2.3438e+00,
          -4.1875e+00,  2.3438e+00],
         [ 9.7656e-02,  8.2812e-01, -1.1572e-01,  ..., -1.5332e-01,
          -6.5312e+00,  3.3594e+00],
         [ 8.3984e-01,  3.0469e-01, -1.4531e+00,  ..., -1.3359e+00,
          -2.6719e+00,  3.0000e+00]],

        [[-8.2779e-04,  5.0964e-03, -1.3123e-03,  ...,  2.4609e-01,
           7.2656e-01,  5.5078e-01],
         [-1.8203e+00,  7.3047e-01, -2.6875e+00,  ...,  4.5312e-01,
          -3.4219e+00,  3.5156e-01],
         [ 1.9824e-01,  3.9453e-01, -7.2266e-01,  ...,  5.5859e-01,
          -6.7500e+00,  4.4141e-01],
         ...,
         [-1.0469e+00, -7.5391e-01, -6.4062e-01,  ...,  2.3438e+00,
          -4.1875e+00,  2.3438e+00],
         [ 9.7656e-02,  8.2812e-01, -1.1572e-01,  ..., -1.5332e-01,
          -6.5312e+00,  3.3594e+00],
         [ 8.3984e-01,  3.0469e-01, -1.4531e+00,  ..., -1.3359e+00,
          -2.6719e+00,  3.0000e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-1.3046e-03,  9.9487e-03, -1.2756e-02,  ..., -1.3000e-02,
           3.1281e-03, -3.1891e-03],
         [ 3.8672e-01, -2.0215e-01, -6.1035e-02,  ...,  8.5449e-02,
           4.7656e-01,  5.9570e-02],
         [ 1.0391e+00, -6.9336e-02, -2.5391e-01,  ..., -3.4375e-01,
          -4.6387e-02, -2.8906e-01],
         ...,
         [-6.8359e-01,  2.4902e-01,  1.6895e-01,  ..., -1.6016e-01,
          -2.2559e-01, -3.7305e-01],
         [-1.4844e-01,  6.0547e-01, -4.9805e-01,  ...,  9.2969e-01,
           1.5918e-01,  2.5977e-01],
         [ 1.1426e-01, -6.3965e-02, -1.1963e-01,  ..., -1.8359e-01,
          -4.8633e-01, -4.6680e-01]],

        [[-1.3046e-03,  9.9487e-03, -1.2756e-02,  ..., -1.3000e-02,
           3.1281e-03, -3.1891e-03],
         [ 3.8672e-01, -2.0215e-01, -6.1035e-02,  ...,  8.5449e-02,
           4.7656e-01,  5.9570e-02],
         [ 1.0391e+00, -6.9336e-02, -2.5391e-01,  ..., -3.4375e-01,
          -4.6387e-02, -2.8906e-01],
         ...,
         [-6.8359e-01,  2.4902e-01,  1.6895e-01,  ..., -1.6016e-01,
          -2.2559e-01, -3.7305e-01],
         [-1.4844e-01,  6.0547e-01, -4.9805e-01,  ...,  9.2969e-01,
           1.5918e-01,  2.5977e-01],
         [ 1.1426e-01, -6.3965e-02, -1.1963e-01,  ..., -1.8359e-01,
          -4.8633e-01, -4.6680e-01]],

        [[-1.3046e-03,  9.9487e-03, -1.2756e-02,  ..., -1.3000e-02,
           3.1281e-03, -3.1891e-03],
         [ 3.8672e-01, -2.0215e-01, -6.1035e-02,  ...,  8.5449e-02,
           4.7656e-01,  5.9570e-02],
         [ 1.0391e+00, -6.9336e-02, -2.5391e-01,  ..., -3.4375e-01,
          -4.6387e-02, -2.8906e-01],
         ...,
         [-6.8359e-01,  2.4902e-01,  1.6895e-01,  ..., -1.6016e-01,
          -2.2559e-01, -3.7305e-01],
         [-1.4844e-01,  6.0547e-01, -4.9805e-01,  ...,  9.2969e-01,
           1.5918e-01,  2.5977e-01],
         [ 1.1426e-01, -6.3965e-02, -1.1963e-01,  ..., -1.8359e-01,
          -4.8633e-01, -4.6680e-01]],

        ...,

        [[ 2.9785e-02, -4.5166e-02,  2.2949e-02,  ...,  2.6001e-02,
           3.2715e-02,  8.3618e-03],
         [-6.6016e-01, -8.2520e-02, -2.0410e-01,  ...,  9.5825e-03,
          -3.3008e-01, -7.3730e-02],
         [-8.3203e-01,  3.6719e-01, -1.0391e+00,  ..., -5.0781e-01,
           1.6211e-01, -2.3242e-01],
         ...,
         [-6.3672e-01,  3.8867e-01, -8.5547e-01,  ...,  1.1328e+00,
          -6.0547e-01,  4.8438e-01],
         [-2.8906e-01,  3.8818e-02, -1.3828e+00,  ...,  3.4766e-01,
          -3.5938e-01, -7.3242e-02],
         [-6.9141e-01, -2.5781e-01, -4.4727e-01,  ..., -5.3125e-01,
          -1.4746e-01,  2.6245e-02]],

        [[ 2.9785e-02, -4.5166e-02,  2.2949e-02,  ...,  2.6001e-02,
           3.2715e-02,  8.3618e-03],
         [-6.6016e-01, -8.2520e-02, -2.0410e-01,  ...,  9.5825e-03,
          -3.3008e-01, -7.3730e-02],
         [-8.3203e-01,  3.6719e-01, -1.0391e+00,  ..., -5.0781e-01,
           1.6211e-01, -2.3242e-01],
         ...,
         [-6.3672e-01,  3.8867e-01, -8.5547e-01,  ...,  1.1328e+00,
          -6.0547e-01,  4.8438e-01],
         [-2.8906e-01,  3.8818e-02, -1.3828e+00,  ...,  3.4766e-01,
          -3.5938e-01, -7.3242e-02],
         [-6.9141e-01, -2.5781e-01, -4.4727e-01,  ..., -5.3125e-01,
          -1.4746e-01,  2.6245e-02]],

        [[ 2.9785e-02, -4.5166e-02,  2.2949e-02,  ...,  2.6001e-02,
           3.2715e-02,  8.3618e-03],
         [-6.6016e-01, -8.2520e-02, -2.0410e-01,  ...,  9.5825e-03,
          -3.3008e-01, -7.3730e-02],
         [-8.3203e-01,  3.6719e-01, -1.0391e+00,  ..., -5.0781e-01,
           1.6211e-01, -2.3242e-01],
         ...,
         [-6.3672e-01,  3.8867e-01, -8.5547e-01,  ...,  1.1328e+00,
          -6.0547e-01,  4.8438e-01],
         [-2.8906e-01,  3.8818e-02, -1.3828e+00,  ...,  3.4766e-01,
          -3.5938e-01, -7.3242e-02],
         [-6.9141e-01, -2.5781e-01, -4.4727e-01,  ..., -5.3125e-01,
          -1.4746e-01,  2.6245e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-1.6602e-02,  7.8735e-03, -1.5869e-02,  ..., -4.2969e-01,
           3.4570e-01,  9.5215e-02],
         [-2.8281e+00, -1.3750e+00,  1.6641e+00,  ...,  1.5625e+00,
           3.0273e-02,  7.7637e-02],
         [ 5.7812e-01, -7.7344e-01,  1.5781e+00,  ...,  2.2969e+00,
           1.6641e+00, -5.8984e-01],
         ...,
         [ 1.9922e-01, -2.4844e+00,  3.3594e-01,  ...,  1.4453e+00,
          -1.4688e+00,  6.1719e-01],
         [ 2.2461e-02, -4.3359e-01, -5.1172e-01,  ...,  1.2734e+00,
          -8.7891e-01,  7.1094e-01],
         [ 1.3516e+00, -3.9844e-01,  9.4531e-01,  ...,  8.2031e-01,
           2.1289e-01, -2.2188e+00]],

        [[-1.6602e-02,  7.8735e-03, -1.5869e-02,  ..., -4.2969e-01,
           3.4570e-01,  9.5215e-02],
         [-2.8281e+00, -1.3750e+00,  1.6641e+00,  ...,  1.5625e+00,
           3.0273e-02,  7.7637e-02],
         [ 5.7812e-01, -7.7344e-01,  1.5781e+00,  ...,  2.2969e+00,
           1.6641e+00, -5.8984e-01],
         ...,
         [ 1.9922e-01, -2.4844e+00,  3.3594e-01,  ...,  1.4453e+00,
          -1.4688e+00,  6.1719e-01],
         [ 2.2461e-02, -4.3359e-01, -5.1172e-01,  ...,  1.2734e+00,
          -8.7891e-01,  7.1094e-01],
         [ 1.3516e+00, -3.9844e-01,  9.4531e-01,  ...,  8.2031e-01,
           2.1289e-01, -2.2188e+00]],

        [[-1.6602e-02,  7.8735e-03, -1.5869e-02,  ..., -4.2969e-01,
           3.4570e-01,  9.5215e-02],
         [-2.8281e+00, -1.3750e+00,  1.6641e+00,  ...,  1.5625e+00,
           3.0273e-02,  7.7637e-02],
         [ 5.7812e-01, -7.7344e-01,  1.5781e+00,  ...,  2.2969e+00,
           1.6641e+00, -5.8984e-01],
         ...,
         [ 1.9922e-01, -2.4844e+00,  3.3594e-01,  ...,  1.4453e+00,
          -1.4688e+00,  6.1719e-01],
         [ 2.2461e-02, -4.3359e-01, -5.1172e-01,  ...,  1.2734e+00,
          -8.7891e-01,  7.1094e-01],
         [ 1.3516e+00, -3.9844e-01,  9.4531e-01,  ...,  8.2031e-01,
           2.1289e-01, -2.2188e+00]],

        ...,

        [[-2.3315e-02, -1.9043e-02, -2.8687e-03,  ...,  7.2754e-02,
          -2.9219e+00, -4.1602e-01],
         [ 3.3203e-01, -5.1953e-01, -1.9922e-01,  ...,  5.7422e-01,
           8.7500e+00,  5.1875e+00],
         [-2.8516e-01, -3.2422e-01, -1.7578e-01,  ...,  8.8281e-01,
           8.1875e+00,  5.2500e+00],
         ...,
         [-1.0645e-01, -2.2461e-01,  6.9141e-01,  ...,  6.5625e-01,
           8.3125e+00, -1.0547e+00],
         [ 8.3008e-02,  2.9688e-01,  1.5430e-01,  ...,  9.1797e-01,
           5.9375e+00, -2.1094e+00],
         [-7.9688e-01,  5.8594e-01, -2.3340e-01,  ..., -1.5938e+00,
           7.2188e+00,  2.4688e+00]],

        [[-2.3315e-02, -1.9043e-02, -2.8687e-03,  ...,  7.2754e-02,
          -2.9219e+00, -4.1602e-01],
         [ 3.3203e-01, -5.1953e-01, -1.9922e-01,  ...,  5.7422e-01,
           8.7500e+00,  5.1875e+00],
         [-2.8516e-01, -3.2422e-01, -1.7578e-01,  ...,  8.8281e-01,
           8.1875e+00,  5.2500e+00],
         ...,
         [-1.0645e-01, -2.2461e-01,  6.9141e-01,  ...,  6.5625e-01,
           8.3125e+00, -1.0547e+00],
         [ 8.3008e-02,  2.9688e-01,  1.5430e-01,  ...,  9.1797e-01,
           5.9375e+00, -2.1094e+00],
         [-7.9688e-01,  5.8594e-01, -2.3340e-01,  ..., -1.5938e+00,
           7.2188e+00,  2.4688e+00]],

        [[-2.3315e-02, -1.9043e-02, -2.8687e-03,  ...,  7.2754e-02,
          -2.9219e+00, -4.1602e-01],
         [ 3.3203e-01, -5.1953e-01, -1.9922e-01,  ...,  5.7422e-01,
           8.7500e+00,  5.1875e+00],
         [-2.8516e-01, -3.2422e-01, -1.7578e-01,  ...,  8.8281e-01,
           8.1875e+00,  5.2500e+00],
         ...,
         [-1.0645e-01, -2.2461e-01,  6.9141e-01,  ...,  6.5625e-01,
           8.3125e+00, -1.0547e+00],
         [ 8.3008e-02,  2.9688e-01,  1.5430e-01,  ...,  9.1797e-01,
           5.9375e+00, -2.1094e+00],
         [-7.9688e-01,  5.8594e-01, -2.3340e-01,  ..., -1.5938e+00,
           7.2188e+00,  2.4688e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-2.7832e-02, -1.5869e-02, -4.3701e-02,  ...,  5.1270e-03,
           9.5215e-03,  1.0620e-02],
         [ 5.9375e-01, -4.9414e-01,  4.2969e-01,  ...,  1.3086e-01,
          -1.9434e-01,  3.2812e-01],
         [ 2.7734e-01, -1.6504e-01,  7.1484e-01,  ...,  1.1797e+00,
           2.8906e-01,  2.6245e-02],
         ...,
         [ 2.3633e-01, -1.5312e+00, -3.9258e-01,  ...,  2.0508e-01,
           4.6875e-01,  9.9121e-02],
         [ 1.9629e-01, -1.7500e+00, -7.6562e-01,  ...,  1.1875e+00,
           3.7305e-01,  8.2031e-01],
         [ 5.9766e-01, -5.2979e-02,  1.0791e-01,  ...,  8.1543e-02,
           7.0703e-01, -9.5215e-02]],

        [[-2.7832e-02, -1.5869e-02, -4.3701e-02,  ...,  5.1270e-03,
           9.5215e-03,  1.0620e-02],
         [ 5.9375e-01, -4.9414e-01,  4.2969e-01,  ...,  1.3086e-01,
          -1.9434e-01,  3.2812e-01],
         [ 2.7734e-01, -1.6504e-01,  7.1484e-01,  ...,  1.1797e+00,
           2.8906e-01,  2.6245e-02],
         ...,
         [ 2.3633e-01, -1.5312e+00, -3.9258e-01,  ...,  2.0508e-01,
           4.6875e-01,  9.9121e-02],
         [ 1.9629e-01, -1.7500e+00, -7.6562e-01,  ...,  1.1875e+00,
           3.7305e-01,  8.2031e-01],
         [ 5.9766e-01, -5.2979e-02,  1.0791e-01,  ...,  8.1543e-02,
           7.0703e-01, -9.5215e-02]],

        [[-2.7832e-02, -1.5869e-02, -4.3701e-02,  ...,  5.1270e-03,
           9.5215e-03,  1.0620e-02],
         [ 5.9375e-01, -4.9414e-01,  4.2969e-01,  ...,  1.3086e-01,
          -1.9434e-01,  3.2812e-01],
         [ 2.7734e-01, -1.6504e-01,  7.1484e-01,  ...,  1.1797e+00,
           2.8906e-01,  2.6245e-02],
         ...,
         [ 2.3633e-01, -1.5312e+00, -3.9258e-01,  ...,  2.0508e-01,
           4.6875e-01,  9.9121e-02],
         [ 1.9629e-01, -1.7500e+00, -7.6562e-01,  ...,  1.1875e+00,
           3.7305e-01,  8.2031e-01],
         [ 5.9766e-01, -5.2979e-02,  1.0791e-01,  ...,  8.1543e-02,
           7.0703e-01, -9.5215e-02]],

        ...,

        [[-2.3438e-02,  2.6978e-02,  1.3062e-02,  ...,  7.7438e-04,
           1.4954e-02,  1.2283e-03],
         [ 6.6406e-01, -8.2031e-01, -3.6523e-01,  ..., -5.2795e-03,
          -1.8359e-01,  2.6367e-01],
         [-1.0498e-01, -6.3281e-01, -3.3789e-01,  ..., -1.1250e+00,
          -7.9688e-01,  8.9844e-02],
         ...,
         [ 1.9688e+00, -5.5469e-01,  3.4375e-01,  ..., -3.0469e-01,
          -1.8359e+00, -1.1250e+00],
         [ 3.9648e-01,  2.4707e-01,  3.3203e-01,  ...,  4.6387e-02,
          -1.0703e+00, -6.7578e-01],
         [-2.1387e-01, -6.0938e-01, -9.7266e-01,  ..., -2.6758e-01,
          -9.7656e-02,  5.2734e-01]],

        [[-2.3438e-02,  2.6978e-02,  1.3062e-02,  ...,  7.7438e-04,
           1.4954e-02,  1.2283e-03],
         [ 6.6406e-01, -8.2031e-01, -3.6523e-01,  ..., -5.2795e-03,
          -1.8359e-01,  2.6367e-01],
         [-1.0498e-01, -6.3281e-01, -3.3789e-01,  ..., -1.1250e+00,
          -7.9688e-01,  8.9844e-02],
         ...,
         [ 1.9688e+00, -5.5469e-01,  3.4375e-01,  ..., -3.0469e-01,
          -1.8359e+00, -1.1250e+00],
         [ 3.9648e-01,  2.4707e-01,  3.3203e-01,  ...,  4.6387e-02,
          -1.0703e+00, -6.7578e-01],
         [-2.1387e-01, -6.0938e-01, -9.7266e-01,  ..., -2.6758e-01,
          -9.7656e-02,  5.2734e-01]],

        [[-2.3438e-02,  2.6978e-02,  1.3062e-02,  ...,  7.7438e-04,
           1.4954e-02,  1.2283e-03],
         [ 6.6406e-01, -8.2031e-01, -3.6523e-01,  ..., -5.2795e-03,
          -1.8359e-01,  2.6367e-01],
         [-1.0498e-01, -6.3281e-01, -3.3789e-01,  ..., -1.1250e+00,
          -7.9688e-01,  8.9844e-02],
         ...,
         [ 1.9688e+00, -5.5469e-01,  3.4375e-01,  ..., -3.0469e-01,
          -1.8359e+00, -1.1250e+00],
         [ 3.9648e-01,  2.4707e-01,  3.3203e-01,  ...,  4.6387e-02,
          -1.0703e+00, -6.7578e-01],
         [-2.1387e-01, -6.0938e-01, -9.7266e-01,  ..., -2.6758e-01,
          -9.7656e-02,  5.2734e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 7.9727e-04,  1.0925e-02, -3.7842e-03,  ..., -1.3184e-01,
          -2.9297e-01, -9.7168e-02],
         [ 2.0000e+00,  7.6172e-01, -2.5938e+00,  ...,  9.4238e-02,
           6.7188e-01, -9.1406e-01],
         [ 6.0156e-01, -1.2344e+00, -1.1250e+00,  ...,  8.0078e-02,
           1.8672e+00, -4.7461e-01],
         ...,
         [ 0.0000e+00, -3.4375e-01,  2.4023e-01,  ...,  1.5703e+00,
          -5.7031e-01, -5.5469e-01],
         [ 7.8125e-03,  3.3398e-01, -1.8555e-01,  ...,  4.0430e-01,
          -5.6250e-01, -9.7656e-01],
         [-2.0781e+00,  5.2734e-01, -1.5938e+00,  ..., -2.3594e+00,
          -2.4062e+00, -1.7734e+00]],

        [[ 7.9727e-04,  1.0925e-02, -3.7842e-03,  ..., -1.3184e-01,
          -2.9297e-01, -9.7168e-02],
         [ 2.0000e+00,  7.6172e-01, -2.5938e+00,  ...,  9.4238e-02,
           6.7188e-01, -9.1406e-01],
         [ 6.0156e-01, -1.2344e+00, -1.1250e+00,  ...,  8.0078e-02,
           1.8672e+00, -4.7461e-01],
         ...,
         [ 0.0000e+00, -3.4375e-01,  2.4023e-01,  ...,  1.5703e+00,
          -5.7031e-01, -5.5469e-01],
         [ 7.8125e-03,  3.3398e-01, -1.8555e-01,  ...,  4.0430e-01,
          -5.6250e-01, -9.7656e-01],
         [-2.0781e+00,  5.2734e-01, -1.5938e+00,  ..., -2.3594e+00,
          -2.4062e+00, -1.7734e+00]],

        [[ 7.9727e-04,  1.0925e-02, -3.7842e-03,  ..., -1.3184e-01,
          -2.9297e-01, -9.7168e-02],
         [ 2.0000e+00,  7.6172e-01, -2.5938e+00,  ...,  9.4238e-02,
           6.7188e-01, -9.1406e-01],
         [ 6.0156e-01, -1.2344e+00, -1.1250e+00,  ...,  8.0078e-02,
           1.8672e+00, -4.7461e-01],
         ...,
         [ 0.0000e+00, -3.4375e-01,  2.4023e-01,  ...,  1.5703e+00,
          -5.7031e-01, -5.5469e-01],
         [ 7.8125e-03,  3.3398e-01, -1.8555e-01,  ...,  4.0430e-01,
          -5.6250e-01, -9.7656e-01],
         [-2.0781e+00,  5.2734e-01, -1.5938e+00,  ..., -2.3594e+00,
          -2.4062e+00, -1.7734e+00]],

        ...,

        [[ 3.6469e-03,  1.0742e-02,  3.5706e-03,  ..., -2.5156e+00,
          -1.3672e-02,  1.0498e-01],
         [ 1.4219e+00,  1.2500e-01,  5.7812e-01,  ...,  8.2500e+00,
          -4.3945e-01, -1.3047e+00],
         [-5.0781e-01,  6.7969e-01, -1.4766e+00,  ...,  8.4375e+00,
           1.2422e+00, -1.2578e+00],
         ...,
         [-1.0703e+00,  1.0312e+00,  1.5430e-01,  ...,  7.1875e+00,
          -1.1094e+00,  7.3828e-01],
         [-1.3086e-01, -1.7383e-01,  3.3789e-01,  ...,  5.3750e+00,
          -2.3047e-01,  7.4609e-01],
         [-4.8242e-01, -2.8906e-01,  1.1250e+00,  ...,  6.9375e+00,
          -2.0625e+00, -1.1328e+00]],

        [[ 3.6469e-03,  1.0742e-02,  3.5706e-03,  ..., -2.5156e+00,
          -1.3672e-02,  1.0498e-01],
         [ 1.4219e+00,  1.2500e-01,  5.7812e-01,  ...,  8.2500e+00,
          -4.3945e-01, -1.3047e+00],
         [-5.0781e-01,  6.7969e-01, -1.4766e+00,  ...,  8.4375e+00,
           1.2422e+00, -1.2578e+00],
         ...,
         [-1.0703e+00,  1.0312e+00,  1.5430e-01,  ...,  7.1875e+00,
          -1.1094e+00,  7.3828e-01],
         [-1.3086e-01, -1.7383e-01,  3.3789e-01,  ...,  5.3750e+00,
          -2.3047e-01,  7.4609e-01],
         [-4.8242e-01, -2.8906e-01,  1.1250e+00,  ...,  6.9375e+00,
          -2.0625e+00, -1.1328e+00]],

        [[ 3.6469e-03,  1.0742e-02,  3.5706e-03,  ..., -2.5156e+00,
          -1.3672e-02,  1.0498e-01],
         [ 1.4219e+00,  1.2500e-01,  5.7812e-01,  ...,  8.2500e+00,
          -4.3945e-01, -1.3047e+00],
         [-5.0781e-01,  6.7969e-01, -1.4766e+00,  ...,  8.4375e+00,
           1.2422e+00, -1.2578e+00],
         ...,
         [-1.0703e+00,  1.0312e+00,  1.5430e-01,  ...,  7.1875e+00,
          -1.1094e+00,  7.3828e-01],
         [-1.3086e-01, -1.7383e-01,  3.3789e-01,  ...,  5.3750e+00,
          -2.3047e-01,  7.4609e-01],
         [-4.8242e-01, -2.8906e-01,  1.1250e+00,  ...,  6.9375e+00,
          -2.0625e+00, -1.1328e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-1.9287e-02,  2.2461e-02, -3.5645e-02,  ..., -2.5391e-02,
          -1.1841e-02, -1.5564e-02],
         [ 4.3359e-01,  2.4805e-01, -4.0234e-01,  ...,  2.8125e-01,
          -6.6016e-01, -6.8750e-01],
         [-2.2949e-02, -7.6660e-02, -1.0547e-01,  ..., -5.8594e-01,
          -3.2617e-01,  8.6426e-02],
         ...,
         [-3.8867e-01, -6.7969e-01,  6.2500e-01,  ..., -7.9102e-02,
          -1.2031e+00, -1.2207e-01],
         [-1.3770e-01, -7.2656e-01, -1.7456e-02,  ...,  1.7676e-01,
          -6.8359e-01,  4.4336e-01],
         [ 4.5117e-01,  1.2266e+00,  1.1328e-01,  ...,  8.0078e-01,
          -2.0410e-01, -1.1475e-01]],

        [[-1.9287e-02,  2.2461e-02, -3.5645e-02,  ..., -2.5391e-02,
          -1.1841e-02, -1.5564e-02],
         [ 4.3359e-01,  2.4805e-01, -4.0234e-01,  ...,  2.8125e-01,
          -6.6016e-01, -6.8750e-01],
         [-2.2949e-02, -7.6660e-02, -1.0547e-01,  ..., -5.8594e-01,
          -3.2617e-01,  8.6426e-02],
         ...,
         [-3.8867e-01, -6.7969e-01,  6.2500e-01,  ..., -7.9102e-02,
          -1.2031e+00, -1.2207e-01],
         [-1.3770e-01, -7.2656e-01, -1.7456e-02,  ...,  1.7676e-01,
          -6.8359e-01,  4.4336e-01],
         [ 4.5117e-01,  1.2266e+00,  1.1328e-01,  ...,  8.0078e-01,
          -2.0410e-01, -1.1475e-01]],

        [[-1.9287e-02,  2.2461e-02, -3.5645e-02,  ..., -2.5391e-02,
          -1.1841e-02, -1.5564e-02],
         [ 4.3359e-01,  2.4805e-01, -4.0234e-01,  ...,  2.8125e-01,
          -6.6016e-01, -6.8750e-01],
         [-2.2949e-02, -7.6660e-02, -1.0547e-01,  ..., -5.8594e-01,
          -3.2617e-01,  8.6426e-02],
         ...,
         [-3.8867e-01, -6.7969e-01,  6.2500e-01,  ..., -7.9102e-02,
          -1.2031e+00, -1.2207e-01],
         [-1.3770e-01, -7.2656e-01, -1.7456e-02,  ...,  1.7676e-01,
          -6.8359e-01,  4.4336e-01],
         [ 4.5117e-01,  1.2266e+00,  1.1328e-01,  ...,  8.0078e-01,
          -2.0410e-01, -1.1475e-01]],

        ...,

        [[ 1.9653e-02,  5.0049e-03, -5.9509e-03,  ...,  3.7354e-02,
          -1.7090e-02,  1.6022e-03],
         [-8.8379e-02, -1.9238e-01,  2.8320e-01,  ...,  3.4570e-01,
          -5.9082e-02,  5.5908e-02],
         [ 3.1250e-01, -8.3594e-01,  3.4375e-01,  ...,  5.9326e-02,
           3.1250e-01, -3.3008e-01],
         ...,
         [-7.9688e-01,  5.7422e-01,  6.7871e-02,  ..., -2.0117e-01,
          -5.0391e-01,  8.3008e-02],
         [-5.7617e-02, -2.6733e-02, -6.6376e-04,  ..., -1.5332e-01,
          -2.6367e-01,  3.0859e-01],
         [-2.1875e-01,  3.9844e-01, -2.5000e-01,  ..., -8.3984e-01,
          -1.0156e+00,  1.5723e-01]],

        [[ 1.9653e-02,  5.0049e-03, -5.9509e-03,  ...,  3.7354e-02,
          -1.7090e-02,  1.6022e-03],
         [-8.8379e-02, -1.9238e-01,  2.8320e-01,  ...,  3.4570e-01,
          -5.9082e-02,  5.5908e-02],
         [ 3.1250e-01, -8.3594e-01,  3.4375e-01,  ...,  5.9326e-02,
           3.1250e-01, -3.3008e-01],
         ...,
         [-7.9688e-01,  5.7422e-01,  6.7871e-02,  ..., -2.0117e-01,
          -5.0391e-01,  8.3008e-02],
         [-5.7617e-02, -2.6733e-02, -6.6376e-04,  ..., -1.5332e-01,
          -2.6367e-01,  3.0859e-01],
         [-2.1875e-01,  3.9844e-01, -2.5000e-01,  ..., -8.3984e-01,
          -1.0156e+00,  1.5723e-01]],

        [[ 1.9653e-02,  5.0049e-03, -5.9509e-03,  ...,  3.7354e-02,
          -1.7090e-02,  1.6022e-03],
         [-8.8379e-02, -1.9238e-01,  2.8320e-01,  ...,  3.4570e-01,
          -5.9082e-02,  5.5908e-02],
         [ 3.1250e-01, -8.3594e-01,  3.4375e-01,  ...,  5.9326e-02,
           3.1250e-01, -3.3008e-01],
         ...,
         [-7.9688e-01,  5.7422e-01,  6.7871e-02,  ..., -2.0117e-01,
          -5.0391e-01,  8.3008e-02],
         [-5.7617e-02, -2.6733e-02, -6.6376e-04,  ..., -1.5332e-01,
          -2.6367e-01,  3.0859e-01],
         [-2.1875e-01,  3.9844e-01, -2.5000e-01,  ..., -8.3984e-01,
          -1.0156e+00,  1.5723e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-1.7700e-02, -1.2329e-02,  1.7456e-02,  ...,  7.7637e-02,
           5.0537e-02, -1.0107e-01],
         [ 2.0000e+00,  1.4551e-01,  2.0469e+00,  ..., -4.6289e-01,
          -6.8359e-01, -7.5391e-01],
         [ 1.6562e+00,  1.2734e+00, -2.5977e-01,  ...,  2.4512e-01,
          -2.8594e+00,  1.6562e+00],
         ...,
         [ 3.1445e-01,  6.0156e-01, -3.8086e-02,  ..., -1.6953e+00,
           5.4375e+00, -1.1172e+00],
         [-8.0469e-01, -4.6680e-01,  3.3398e-01,  ..., -3.8672e-01,
           5.1172e-01, -2.6250e+00],
         [-5.6641e-01, -4.4531e-01,  1.4219e+00,  ..., -6.8359e-01,
           9.9121e-02,  1.1562e+00]],

        [[-1.7700e-02, -1.2329e-02,  1.7456e-02,  ...,  7.7637e-02,
           5.0537e-02, -1.0107e-01],
         [ 2.0000e+00,  1.4551e-01,  2.0469e+00,  ..., -4.6289e-01,
          -6.8359e-01, -7.5391e-01],
         [ 1.6562e+00,  1.2734e+00, -2.5977e-01,  ...,  2.4512e-01,
          -2.8594e+00,  1.6562e+00],
         ...,
         [ 3.1445e-01,  6.0156e-01, -3.8086e-02,  ..., -1.6953e+00,
           5.4375e+00, -1.1172e+00],
         [-8.0469e-01, -4.6680e-01,  3.3398e-01,  ..., -3.8672e-01,
           5.1172e-01, -2.6250e+00],
         [-5.6641e-01, -4.4531e-01,  1.4219e+00,  ..., -6.8359e-01,
           9.9121e-02,  1.1562e+00]],

        [[-1.7700e-02, -1.2329e-02,  1.7456e-02,  ...,  7.7637e-02,
           5.0537e-02, -1.0107e-01],
         [ 2.0000e+00,  1.4551e-01,  2.0469e+00,  ..., -4.6289e-01,
          -6.8359e-01, -7.5391e-01],
         [ 1.6562e+00,  1.2734e+00, -2.5977e-01,  ...,  2.4512e-01,
          -2.8594e+00,  1.6562e+00],
         ...,
         [ 3.1445e-01,  6.0156e-01, -3.8086e-02,  ..., -1.6953e+00,
           5.4375e+00, -1.1172e+00],
         [-8.0469e-01, -4.6680e-01,  3.3398e-01,  ..., -3.8672e-01,
           5.1172e-01, -2.6250e+00],
         [-5.6641e-01, -4.4531e-01,  1.4219e+00,  ..., -6.8359e-01,
           9.9121e-02,  1.1562e+00]],

        ...,

        [[-3.5706e-03,  6.9275e-03,  1.4221e-02,  ..., -1.7285e-01,
           1.7773e-01, -1.6968e-02],
         [ 2.4062e+00,  1.9062e+00,  2.1777e-01,  ..., -1.5156e+00,
          -1.3750e+00, -9.2578e-01],
         [ 9.4922e-01,  1.1484e+00,  3.0273e-01,  ..., -3.0938e+00,
          -2.1250e+00, -2.0938e+00],
         ...,
         [ 1.9531e+00, -1.5820e-01, -4.6875e-01,  ...,  7.3438e-01,
           7.0703e-01,  1.8984e+00],
         [-2.0508e-01,  1.2500e-01, -6.8359e-01,  ..., -7.4707e-02,
           5.0391e-01,  2.5469e+00],
         [-1.3125e+00, -2.1406e+00, -1.0391e+00,  ..., -6.2500e-01,
          -1.4453e+00,  1.4453e+00]],

        [[-3.5706e-03,  6.9275e-03,  1.4221e-02,  ..., -1.7285e-01,
           1.7773e-01, -1.6968e-02],
         [ 2.4062e+00,  1.9062e+00,  2.1777e-01,  ..., -1.5156e+00,
          -1.3750e+00, -9.2578e-01],
         [ 9.4922e-01,  1.1484e+00,  3.0273e-01,  ..., -3.0938e+00,
          -2.1250e+00, -2.0938e+00],
         ...,
         [ 1.9531e+00, -1.5820e-01, -4.6875e-01,  ...,  7.3438e-01,
           7.0703e-01,  1.8984e+00],
         [-2.0508e-01,  1.2500e-01, -6.8359e-01,  ..., -7.4707e-02,
           5.0391e-01,  2.5469e+00],
         [-1.3125e+00, -2.1406e+00, -1.0391e+00,  ..., -6.2500e-01,
          -1.4453e+00,  1.4453e+00]],

        [[-3.5706e-03,  6.9275e-03,  1.4221e-02,  ..., -1.7285e-01,
           1.7773e-01, -1.6968e-02],
         [ 2.4062e+00,  1.9062e+00,  2.1777e-01,  ..., -1.5156e+00,
          -1.3750e+00, -9.2578e-01],
         [ 9.4922e-01,  1.1484e+00,  3.0273e-01,  ..., -3.0938e+00,
          -2.1250e+00, -2.0938e+00],
         ...,
         [ 1.9531e+00, -1.5820e-01, -4.6875e-01,  ...,  7.3438e-01,
           7.0703e-01,  1.8984e+00],
         [-2.0508e-01,  1.2500e-01, -6.8359e-01,  ..., -7.4707e-02,
           5.0391e-01,  2.5469e+00],
         [-1.3125e+00, -2.1406e+00, -1.0391e+00,  ..., -6.2500e-01,
          -1.4453e+00,  1.4453e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 0.0537, -0.0067, -0.0131,  ...,  0.0227, -0.0175,  0.0188],
         [-0.4023, -0.3379, -0.4727,  ..., -0.3730, -0.2930,  0.3906],
         [-0.2852, -0.7734,  0.0374,  ..., -0.4922, -0.1074,  0.3262],
         ...,
         [ 0.9141,  0.0201,  0.0767,  ...,  0.3379,  0.7227, -0.1226],
         [ 0.4629, -0.3047, -1.0156,  ..., -0.3867,  0.9883, -0.2891],
         [-0.3574, -0.0312,  0.0036,  ...,  0.0957,  0.3242,  0.3711]],

        [[ 0.0537, -0.0067, -0.0131,  ...,  0.0227, -0.0175,  0.0188],
         [-0.4023, -0.3379, -0.4727,  ..., -0.3730, -0.2930,  0.3906],
         [-0.2852, -0.7734,  0.0374,  ..., -0.4922, -0.1074,  0.3262],
         ...,
         [ 0.9141,  0.0201,  0.0767,  ...,  0.3379,  0.7227, -0.1226],
         [ 0.4629, -0.3047, -1.0156,  ..., -0.3867,  0.9883, -0.2891],
         [-0.3574, -0.0312,  0.0036,  ...,  0.0957,  0.3242,  0.3711]],

        [[ 0.0537, -0.0067, -0.0131,  ...,  0.0227, -0.0175,  0.0188],
         [-0.4023, -0.3379, -0.4727,  ..., -0.3730, -0.2930,  0.3906],
         [-0.2852, -0.7734,  0.0374,  ..., -0.4922, -0.1074,  0.3262],
         ...,
         [ 0.9141,  0.0201,  0.0767,  ...,  0.3379,  0.7227, -0.1226],
         [ 0.4629, -0.3047, -1.0156,  ..., -0.3867,  0.9883, -0.2891],
         [-0.3574, -0.0312,  0.0036,  ...,  0.0957,  0.3242,  0.3711]],

        ...,

        [[-0.0300, -0.0415, -0.0488,  ..., -0.0352,  0.0204,  0.0442],
         [-0.2051,  0.1680,  0.2500,  ..., -0.3320, -0.2949,  0.1670],
         [ 0.0203, -0.3398, -0.0371,  ..., -0.2354, -0.0206,  0.0171],
         ...,
         [ 0.3613, -0.0474,  0.5508,  ...,  0.7266,  0.5312,  0.2363],
         [ 0.0061, -0.2520,  0.2520,  ...,  0.5820, -0.9219, -0.2422],
         [ 0.3750,  0.4922,  0.8398,  ...,  0.4512,  0.2139,  0.4316]],

        [[-0.0300, -0.0415, -0.0488,  ..., -0.0352,  0.0204,  0.0442],
         [-0.2051,  0.1680,  0.2500,  ..., -0.3320, -0.2949,  0.1670],
         [ 0.0203, -0.3398, -0.0371,  ..., -0.2354, -0.0206,  0.0171],
         ...,
         [ 0.3613, -0.0474,  0.5508,  ...,  0.7266,  0.5312,  0.2363],
         [ 0.0061, -0.2520,  0.2520,  ...,  0.5820, -0.9219, -0.2422],
         [ 0.3750,  0.4922,  0.8398,  ...,  0.4512,  0.2139,  0.4316]],

        [[-0.0300, -0.0415, -0.0488,  ..., -0.0352,  0.0204,  0.0442],
         [-0.2051,  0.1680,  0.2500,  ..., -0.3320, -0.2949,  0.1670],
         [ 0.0203, -0.3398, -0.0371,  ..., -0.2354, -0.0206,  0.0171],
         ...,
         [ 0.3613, -0.0474,  0.5508,  ...,  0.7266,  0.5312,  0.2363],
         [ 0.0061, -0.2520,  0.2520,  ...,  0.5820, -0.9219, -0.2422],
         [ 0.3750,  0.4922,  0.8398,  ...,  0.4512,  0.2139,  0.4316]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[-7.9956e-03, -3.5553e-03, -3.8147e-03,  ..., -1.2500e+00,
           1.5747e-02,  6.2012e-02],
         [ 7.0312e-01,  3.6719e-01,  1.5820e-01,  ...,  2.6172e-01,
          -5.8594e-01, -1.5625e-01],
         [-6.3281e-01,  7.7734e-01, -8.4229e-03,  ...,  1.1719e+00,
          -1.7109e+00, -2.7500e+00],
         ...,
         [-1.0312e+00,  1.3086e-01,  2.3828e-01,  ..., -1.7109e+00,
          -2.5156e+00,  7.3438e-01],
         [-3.6523e-01, -8.7891e-02, -4.9414e-01,  ...,  6.7871e-02,
          -1.5156e+00, -7.3438e-01],
         [-5.7422e-01, -1.4062e+00, -1.4766e+00,  ...,  6.8750e-01,
           1.4609e+00, -1.2188e+00]],

        [[-7.9956e-03, -3.5553e-03, -3.8147e-03,  ..., -1.2500e+00,
           1.5747e-02,  6.2012e-02],
         [ 7.0312e-01,  3.6719e-01,  1.5820e-01,  ...,  2.6172e-01,
          -5.8594e-01, -1.5625e-01],
         [-6.3281e-01,  7.7734e-01, -8.4229e-03,  ...,  1.1719e+00,
          -1.7109e+00, -2.7500e+00],
         ...,
         [-1.0312e+00,  1.3086e-01,  2.3828e-01,  ..., -1.7109e+00,
          -2.5156e+00,  7.3438e-01],
         [-3.6523e-01, -8.7891e-02, -4.9414e-01,  ...,  6.7871e-02,
          -1.5156e+00, -7.3438e-01],
         [-5.7422e-01, -1.4062e+00, -1.4766e+00,  ...,  6.8750e-01,
           1.4609e+00, -1.2188e+00]],

        [[-7.9956e-03, -3.5553e-03, -3.8147e-03,  ..., -1.2500e+00,
           1.5747e-02,  6.2012e-02],
         [ 7.0312e-01,  3.6719e-01,  1.5820e-01,  ...,  2.6172e-01,
          -5.8594e-01, -1.5625e-01],
         [-6.3281e-01,  7.7734e-01, -8.4229e-03,  ...,  1.1719e+00,
          -1.7109e+00, -2.7500e+00],
         ...,
         [-1.0312e+00,  1.3086e-01,  2.3828e-01,  ..., -1.7109e+00,
          -2.5156e+00,  7.3438e-01],
         [-3.6523e-01, -8.7891e-02, -4.9414e-01,  ...,  6.7871e-02,
          -1.5156e+00, -7.3438e-01],
         [-5.7422e-01, -1.4062e+00, -1.4766e+00,  ...,  6.8750e-01,
           1.4609e+00, -1.2188e+00]],

        ...,

        [[-1.8692e-03,  1.0300e-03, -7.0496e-03,  ...,  1.4375e+00,
          -2.7969e+00, -3.5156e-01],
         [ 1.7188e-01, -2.3281e+00, -9.9219e-01,  ..., -1.9219e+00,
           6.9375e+00,  1.6641e+00],
         [ 1.5527e-01, -3.7109e-01, -6.2500e-01,  ..., -2.3906e+00,
           7.8438e+00,  4.6562e+00],
         ...,
         [ 9.3750e-01,  9.6484e-01,  1.3770e-01,  ..., -1.2812e+00,
           6.0625e+00, -4.2188e+00],
         [-4.9609e-01,  5.3125e-01,  1.7285e-01,  ..., -8.5938e-01,
           4.5625e+00, -1.1719e+00],
         [ 1.4258e-01,  1.9375e+00,  3.6523e-01,  ..., -4.2188e-01,
           5.3438e+00,  2.2969e+00]],

        [[-1.8692e-03,  1.0300e-03, -7.0496e-03,  ...,  1.4375e+00,
          -2.7969e+00, -3.5156e-01],
         [ 1.7188e-01, -2.3281e+00, -9.9219e-01,  ..., -1.9219e+00,
           6.9375e+00,  1.6641e+00],
         [ 1.5527e-01, -3.7109e-01, -6.2500e-01,  ..., -2.3906e+00,
           7.8438e+00,  4.6562e+00],
         ...,
         [ 9.3750e-01,  9.6484e-01,  1.3770e-01,  ..., -1.2812e+00,
           6.0625e+00, -4.2188e+00],
         [-4.9609e-01,  5.3125e-01,  1.7285e-01,  ..., -8.5938e-01,
           4.5625e+00, -1.1719e+00],
         [ 1.4258e-01,  1.9375e+00,  3.6523e-01,  ..., -4.2188e-01,
           5.3438e+00,  2.2969e+00]],

        [[-1.8692e-03,  1.0300e-03, -7.0496e-03,  ...,  1.4375e+00,
          -2.7969e+00, -3.5156e-01],
         [ 1.7188e-01, -2.3281e+00, -9.9219e-01,  ..., -1.9219e+00,
           6.9375e+00,  1.6641e+00],
         [ 1.5527e-01, -3.7109e-01, -6.2500e-01,  ..., -2.3906e+00,
           7.8438e+00,  4.6562e+00],
         ...,
         [ 9.3750e-01,  9.6484e-01,  1.3770e-01,  ..., -1.2812e+00,
           6.0625e+00, -4.2188e+00],
         [-4.9609e-01,  5.3125e-01,  1.7285e-01,  ..., -8.5938e-01,
           4.5625e+00, -1.1719e+00],
         [ 1.4258e-01,  1.9375e+00,  3.6523e-01,  ..., -4.2188e-01,
           5.3438e+00,  2.2969e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 0.0623,  0.0913, -0.0081,  ...,  0.0413, -0.0123,  0.0043],
         [-0.5742,  0.0535, -0.3262,  ...,  0.3320,  0.4824,  0.2266],
         [-0.7031, -0.4531,  0.2227,  ..., -0.0579,  0.4180,  0.1084],
         ...,
         [ 0.5781, -0.4688,  0.4727,  ...,  0.2832, -0.0732,  0.8008],
         [ 0.0183, -0.4453,  0.8086,  ...,  0.1299,  0.0757,  0.7695],
         [ 0.2207, -0.5938, -1.3984,  ...,  0.3223,  0.6367, -0.5078]],

        [[ 0.0623,  0.0913, -0.0081,  ...,  0.0413, -0.0123,  0.0043],
         [-0.5742,  0.0535, -0.3262,  ...,  0.3320,  0.4824,  0.2266],
         [-0.7031, -0.4531,  0.2227,  ..., -0.0579,  0.4180,  0.1084],
         ...,
         [ 0.5781, -0.4688,  0.4727,  ...,  0.2832, -0.0732,  0.8008],
         [ 0.0183, -0.4453,  0.8086,  ...,  0.1299,  0.0757,  0.7695],
         [ 0.2207, -0.5938, -1.3984,  ...,  0.3223,  0.6367, -0.5078]],

        [[ 0.0623,  0.0913, -0.0081,  ...,  0.0413, -0.0123,  0.0043],
         [-0.5742,  0.0535, -0.3262,  ...,  0.3320,  0.4824,  0.2266],
         [-0.7031, -0.4531,  0.2227,  ..., -0.0579,  0.4180,  0.1084],
         ...,
         [ 0.5781, -0.4688,  0.4727,  ...,  0.2832, -0.0732,  0.8008],
         [ 0.0183, -0.4453,  0.8086,  ...,  0.1299,  0.0757,  0.7695],
         [ 0.2207, -0.5938, -1.3984,  ...,  0.3223,  0.6367, -0.5078]],

        ...,

        [[-0.0090,  0.0481,  0.0194,  ...,  0.0170,  0.0413,  0.0540],
         [-0.3809, -0.5859,  0.0077,  ...,  0.2715,  0.6562, -0.1631],
         [-0.8516, -0.4805, -0.6523,  ..., -0.4141,  0.1099, -0.4941],
         ...,
         [ 0.9414, -0.0767, -0.1104,  ..., -0.5312,  0.1113, -0.4844],
         [ 1.3125,  0.2100, -0.3555,  ..., -0.5938,  0.0104, -0.3984],
         [-0.3887, -0.5547, -0.7422,  ...,  0.2266, -0.4219, -0.0659]],

        [[-0.0090,  0.0481,  0.0194,  ...,  0.0170,  0.0413,  0.0540],
         [-0.3809, -0.5859,  0.0077,  ...,  0.2715,  0.6562, -0.1631],
         [-0.8516, -0.4805, -0.6523,  ..., -0.4141,  0.1099, -0.4941],
         ...,
         [ 0.9414, -0.0767, -0.1104,  ..., -0.5312,  0.1113, -0.4844],
         [ 1.3125,  0.2100, -0.3555,  ..., -0.5938,  0.0104, -0.3984],
         [-0.3887, -0.5547, -0.7422,  ...,  0.2266, -0.4219, -0.0659]],

        [[-0.0090,  0.0481,  0.0194,  ...,  0.0170,  0.0413,  0.0540],
         [-0.3809, -0.5859,  0.0077,  ...,  0.2715,  0.6562, -0.1631],
         [-0.8516, -0.4805, -0.6523,  ..., -0.4141,  0.1099, -0.4941],
         ...,
         [ 0.9414, -0.0767, -0.1104,  ..., -0.5312,  0.1113, -0.4844],
         [ 1.3125,  0.2100, -0.3555,  ..., -0.5938,  0.0104, -0.3984],
         [-0.3887, -0.5547, -0.7422,  ...,  0.2266, -0.4219, -0.0659]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[-7.8125e-03,  3.9673e-03, -4.2419e-03,  ..., -6.7188e-01,
          -1.7090e-01,  9.5215e-02],
         [ 1.6406e+00,  9.9609e-01,  1.7422e+00,  ..., -1.0781e+00,
           1.0469e+00,  1.4141e+00],
         [ 9.2578e-01,  9.8047e-01,  1.8828e+00,  ..., -2.4062e+00,
          -1.1328e+00,  2.3750e+00],
         ...,
         [ 1.0781e+00,  7.4219e-01,  1.1328e+00,  ..., -3.3447e-02,
           4.1797e-01,  8.0078e-01],
         [-4.8047e-01,  2.4512e-01,  8.0469e-01,  ..., -1.1084e-01,
          -1.2891e+00,  1.4375e+00],
         [ 2.4414e-01, -1.0156e+00,  2.1406e+00,  ..., -1.4746e-01,
           1.3281e+00,  3.9551e-02]],

        [[-7.8125e-03,  3.9673e-03, -4.2419e-03,  ..., -6.7188e-01,
          -1.7090e-01,  9.5215e-02],
         [ 1.6406e+00,  9.9609e-01,  1.7422e+00,  ..., -1.0781e+00,
           1.0469e+00,  1.4141e+00],
         [ 9.2578e-01,  9.8047e-01,  1.8828e+00,  ..., -2.4062e+00,
          -1.1328e+00,  2.3750e+00],
         ...,
         [ 1.0781e+00,  7.4219e-01,  1.1328e+00,  ..., -3.3447e-02,
           4.1797e-01,  8.0078e-01],
         [-4.8047e-01,  2.4512e-01,  8.0469e-01,  ..., -1.1084e-01,
          -1.2891e+00,  1.4375e+00],
         [ 2.4414e-01, -1.0156e+00,  2.1406e+00,  ..., -1.4746e-01,
           1.3281e+00,  3.9551e-02]],

        [[-7.8125e-03,  3.9673e-03, -4.2419e-03,  ..., -6.7188e-01,
          -1.7090e-01,  9.5215e-02],
         [ 1.6406e+00,  9.9609e-01,  1.7422e+00,  ..., -1.0781e+00,
           1.0469e+00,  1.4141e+00],
         [ 9.2578e-01,  9.8047e-01,  1.8828e+00,  ..., -2.4062e+00,
          -1.1328e+00,  2.3750e+00],
         ...,
         [ 1.0781e+00,  7.4219e-01,  1.1328e+00,  ..., -3.3447e-02,
           4.1797e-01,  8.0078e-01],
         [-4.8047e-01,  2.4512e-01,  8.0469e-01,  ..., -1.1084e-01,
          -1.2891e+00,  1.4375e+00],
         [ 2.4414e-01, -1.0156e+00,  2.1406e+00,  ..., -1.4746e-01,
           1.3281e+00,  3.9551e-02]],

        ...,

        [[ 1.5198e-02,  1.0864e-02, -6.3477e-03,  ...,  1.2146e-02,
           1.2451e-02,  2.3281e+00],
         [ 8.3984e-01,  2.2344e+00,  1.0391e+00,  ...,  3.3398e-01,
          -3.8438e+00, -7.8438e+00],
         [-2.1484e-01,  1.6562e+00,  1.2695e-01,  ...,  5.5859e-01,
          -1.9219e+00, -1.0062e+01],
         ...,
         [-8.4375e-01, -1.0156e+00, -1.2969e+00,  ...,  8.9062e-01,
           3.1055e-01, -4.8438e+00],
         [-4.6094e-01, -5.3125e-01, -7.4609e-01,  ...,  1.6797e+00,
          -6.4062e-01, -5.3438e+00],
         [-1.4609e+00, -1.6719e+00,  1.9531e-03,  ...,  4.9609e-01,
          -2.7188e+00, -5.4688e+00]],

        [[ 1.5198e-02,  1.0864e-02, -6.3477e-03,  ...,  1.2146e-02,
           1.2451e-02,  2.3281e+00],
         [ 8.3984e-01,  2.2344e+00,  1.0391e+00,  ...,  3.3398e-01,
          -3.8438e+00, -7.8438e+00],
         [-2.1484e-01,  1.6562e+00,  1.2695e-01,  ...,  5.5859e-01,
          -1.9219e+00, -1.0062e+01],
         ...,
         [-8.4375e-01, -1.0156e+00, -1.2969e+00,  ...,  8.9062e-01,
           3.1055e-01, -4.8438e+00],
         [-4.6094e-01, -5.3125e-01, -7.4609e-01,  ...,  1.6797e+00,
          -6.4062e-01, -5.3438e+00],
         [-1.4609e+00, -1.6719e+00,  1.9531e-03,  ...,  4.9609e-01,
          -2.7188e+00, -5.4688e+00]],

        [[ 1.5198e-02,  1.0864e-02, -6.3477e-03,  ...,  1.2146e-02,
           1.2451e-02,  2.3281e+00],
         [ 8.3984e-01,  2.2344e+00,  1.0391e+00,  ...,  3.3398e-01,
          -3.8438e+00, -7.8438e+00],
         [-2.1484e-01,  1.6562e+00,  1.2695e-01,  ...,  5.5859e-01,
          -1.9219e+00, -1.0062e+01],
         ...,
         [-8.4375e-01, -1.0156e+00, -1.2969e+00,  ...,  8.9062e-01,
           3.1055e-01, -4.8438e+00],
         [-4.6094e-01, -5.3125e-01, -7.4609e-01,  ...,  1.6797e+00,
          -6.4062e-01, -5.3438e+00],
         [-1.4609e+00, -1.6719e+00,  1.9531e-03,  ...,  4.9609e-01,
          -2.7188e+00, -5.4688e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-1.0437e-02,  4.6730e-04,  1.5869e-02,  ...,  4.7607e-02,
          -4.8584e-02, -2.1118e-02],
         [-4.8438e-01,  6.2256e-02,  3.7500e-01,  ..., -1.2329e-02,
           6.8054e-03,  4.7461e-01],
         [ 2.6367e-01, -3.1641e-01,  3.7109e-01,  ..., -4.0039e-02,
           4.7461e-01,  4.8242e-01],
         ...,
         [ 1.5332e-01,  4.4531e-01,  6.9531e-01,  ..., -6.4844e-01,
          -5.8984e-01, -6.6406e-02],
         [ 2.7930e-01,  2.0752e-02, -3.1836e-01,  ..., -1.6602e-01,
           6.7578e-01, -1.7773e-01],
         [-2.0215e-01,  4.3555e-01,  8.1543e-02,  ...,  1.6699e-01,
          -3.8477e-01,  2.6367e-01]],

        [[-1.0437e-02,  4.6730e-04,  1.5869e-02,  ...,  4.7607e-02,
          -4.8584e-02, -2.1118e-02],
         [-4.8438e-01,  6.2256e-02,  3.7500e-01,  ..., -1.2329e-02,
           6.8054e-03,  4.7461e-01],
         [ 2.6367e-01, -3.1641e-01,  3.7109e-01,  ..., -4.0039e-02,
           4.7461e-01,  4.8242e-01],
         ...,
         [ 1.5332e-01,  4.4531e-01,  6.9531e-01,  ..., -6.4844e-01,
          -5.8984e-01, -6.6406e-02],
         [ 2.7930e-01,  2.0752e-02, -3.1836e-01,  ..., -1.6602e-01,
           6.7578e-01, -1.7773e-01],
         [-2.0215e-01,  4.3555e-01,  8.1543e-02,  ...,  1.6699e-01,
          -3.8477e-01,  2.6367e-01]],

        [[-1.0437e-02,  4.6730e-04,  1.5869e-02,  ...,  4.7607e-02,
          -4.8584e-02, -2.1118e-02],
         [-4.8438e-01,  6.2256e-02,  3.7500e-01,  ..., -1.2329e-02,
           6.8054e-03,  4.7461e-01],
         [ 2.6367e-01, -3.1641e-01,  3.7109e-01,  ..., -4.0039e-02,
           4.7461e-01,  4.8242e-01],
         ...,
         [ 1.5332e-01,  4.4531e-01,  6.9531e-01,  ..., -6.4844e-01,
          -5.8984e-01, -6.6406e-02],
         [ 2.7930e-01,  2.0752e-02, -3.1836e-01,  ..., -1.6602e-01,
           6.7578e-01, -1.7773e-01],
         [-2.0215e-01,  4.3555e-01,  8.1543e-02,  ...,  1.6699e-01,
          -3.8477e-01,  2.6367e-01]],

        ...,

        [[-1.8799e-02,  1.3245e-02, -1.6602e-02,  ..., -1.1292e-02,
           1.3489e-02,  2.1820e-03],
         [ 2.6758e-01, -1.4062e-01, -9.5215e-02,  ..., -1.2500e-01,
           4.9219e-01,  2.7930e-01],
         [-3.2422e-01,  1.4453e-01, -4.1211e-01,  ...,  3.0078e-01,
           3.6328e-01, -1.3672e-01],
         ...,
         [ 2.0801e-01,  2.7734e-01,  3.7598e-02,  ..., -2.9883e-01,
           4.1602e-01, -7.1094e-01],
         [ 1.3965e-01, -5.3125e-01, -5.4297e-01,  ...,  1.8066e-01,
          -1.9434e-01, -7.2266e-02],
         [-6.0120e-03, -2.2656e-01, -4.0234e-01,  ..., -3.0396e-02,
          -3.7695e-01,  2.3145e-01]],

        [[-1.8799e-02,  1.3245e-02, -1.6602e-02,  ..., -1.1292e-02,
           1.3489e-02,  2.1820e-03],
         [ 2.6758e-01, -1.4062e-01, -9.5215e-02,  ..., -1.2500e-01,
           4.9219e-01,  2.7930e-01],
         [-3.2422e-01,  1.4453e-01, -4.1211e-01,  ...,  3.0078e-01,
           3.6328e-01, -1.3672e-01],
         ...,
         [ 2.0801e-01,  2.7734e-01,  3.7598e-02,  ..., -2.9883e-01,
           4.1602e-01, -7.1094e-01],
         [ 1.3965e-01, -5.3125e-01, -5.4297e-01,  ...,  1.8066e-01,
          -1.9434e-01, -7.2266e-02],
         [-6.0120e-03, -2.2656e-01, -4.0234e-01,  ..., -3.0396e-02,
          -3.7695e-01,  2.3145e-01]],

        [[-1.8799e-02,  1.3245e-02, -1.6602e-02,  ..., -1.1292e-02,
           1.3489e-02,  2.1820e-03],
         [ 2.6758e-01, -1.4062e-01, -9.5215e-02,  ..., -1.2500e-01,
           4.9219e-01,  2.7930e-01],
         [-3.2422e-01,  1.4453e-01, -4.1211e-01,  ...,  3.0078e-01,
           3.6328e-01, -1.3672e-01],
         ...,
         [ 2.0801e-01,  2.7734e-01,  3.7598e-02,  ..., -2.9883e-01,
           4.1602e-01, -7.1094e-01],
         [ 1.3965e-01, -5.3125e-01, -5.4297e-01,  ...,  1.8066e-01,
          -1.9434e-01, -7.2266e-02],
         [-6.0120e-03, -2.2656e-01, -4.0234e-01,  ..., -3.0396e-02,
          -3.7695e-01,  2.3145e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-2.4567e-03, -9.7046e-03,  2.9449e-03,  ..., -2.8711e-01,
          -1.3770e-01, -1.1230e-01],
         [-1.0078e+00,  2.8711e-01, -1.0469e+00,  ..., -1.1562e+00,
           3.9688e+00, -3.5156e-01],
         [ 3.0469e-01,  6.7578e-01, -3.7109e-01,  ...,  1.1230e-01,
           7.0703e-01,  2.5625e+00],
         ...,
         [ 2.2461e-01, -1.7656e+00,  1.3281e-01,  ..., -5.0000e-01,
          -1.4922e+00, -1.3516e+00],
         [ 7.4219e-01,  2.6953e-01,  7.0312e-01,  ..., -1.1797e+00,
           9.0625e-01,  1.6895e-01],
         [ 1.7422e+00, -3.1250e-02,  2.4219e-01,  ..., -3.7812e+00,
           1.9766e+00,  1.7734e+00]],

        [[-2.4567e-03, -9.7046e-03,  2.9449e-03,  ..., -2.8711e-01,
          -1.3770e-01, -1.1230e-01],
         [-1.0078e+00,  2.8711e-01, -1.0469e+00,  ..., -1.1562e+00,
           3.9688e+00, -3.5156e-01],
         [ 3.0469e-01,  6.7578e-01, -3.7109e-01,  ...,  1.1230e-01,
           7.0703e-01,  2.5625e+00],
         ...,
         [ 2.2461e-01, -1.7656e+00,  1.3281e-01,  ..., -5.0000e-01,
          -1.4922e+00, -1.3516e+00],
         [ 7.4219e-01,  2.6953e-01,  7.0312e-01,  ..., -1.1797e+00,
           9.0625e-01,  1.6895e-01],
         [ 1.7422e+00, -3.1250e-02,  2.4219e-01,  ..., -3.7812e+00,
           1.9766e+00,  1.7734e+00]],

        [[-2.4567e-03, -9.7046e-03,  2.9449e-03,  ..., -2.8711e-01,
          -1.3770e-01, -1.1230e-01],
         [-1.0078e+00,  2.8711e-01, -1.0469e+00,  ..., -1.1562e+00,
           3.9688e+00, -3.5156e-01],
         [ 3.0469e-01,  6.7578e-01, -3.7109e-01,  ...,  1.1230e-01,
           7.0703e-01,  2.5625e+00],
         ...,
         [ 2.2461e-01, -1.7656e+00,  1.3281e-01,  ..., -5.0000e-01,
          -1.4922e+00, -1.3516e+00],
         [ 7.4219e-01,  2.6953e-01,  7.0312e-01,  ..., -1.1797e+00,
           9.0625e-01,  1.6895e-01],
         [ 1.7422e+00, -3.1250e-02,  2.4219e-01,  ..., -3.7812e+00,
           1.9766e+00,  1.7734e+00]],

        ...,

        [[ 3.6163e-03, -1.3794e-02,  1.1597e-02,  ...,  3.6914e-01,
           7.5684e-02,  1.3184e-01],
         [ 7.3047e-01,  1.1328e+00,  1.2109e+00,  ..., -3.8672e-01,
          -5.5469e-01, -7.9688e-01],
         [ 6.7578e-01, -7.4707e-02,  8.4766e-01,  ...,  3.9844e+00,
          -1.8125e+00,  1.0000e+00],
         ...,
         [ 1.2109e+00, -1.1719e-01,  1.1250e+00,  ..., -1.9688e+00,
           2.1719e+00,  1.1641e+00],
         [-4.9219e-01, -1.7676e-01, -9.9121e-02,  ..., -2.5781e+00,
           9.0234e-01,  1.3047e+00],
         [-5.8203e-01, -9.1797e-01,  1.6875e+00,  ..., -1.1719e+00,
          -8.6914e-02, -1.3984e+00]],

        [[ 3.6163e-03, -1.3794e-02,  1.1597e-02,  ...,  3.6914e-01,
           7.5684e-02,  1.3184e-01],
         [ 7.3047e-01,  1.1328e+00,  1.2109e+00,  ..., -3.8672e-01,
          -5.5469e-01, -7.9688e-01],
         [ 6.7578e-01, -7.4707e-02,  8.4766e-01,  ...,  3.9844e+00,
          -1.8125e+00,  1.0000e+00],
         ...,
         [ 1.2109e+00, -1.1719e-01,  1.1250e+00,  ..., -1.9688e+00,
           2.1719e+00,  1.1641e+00],
         [-4.9219e-01, -1.7676e-01, -9.9121e-02,  ..., -2.5781e+00,
           9.0234e-01,  1.3047e+00],
         [-5.8203e-01, -9.1797e-01,  1.6875e+00,  ..., -1.1719e+00,
          -8.6914e-02, -1.3984e+00]],

        [[ 3.6163e-03, -1.3794e-02,  1.1597e-02,  ...,  3.6914e-01,
           7.5684e-02,  1.3184e-01],
         [ 7.3047e-01,  1.1328e+00,  1.2109e+00,  ..., -3.8672e-01,
          -5.5469e-01, -7.9688e-01],
         [ 6.7578e-01, -7.4707e-02,  8.4766e-01,  ...,  3.9844e+00,
          -1.8125e+00,  1.0000e+00],
         ...,
         [ 1.2109e+00, -1.1719e-01,  1.1250e+00,  ..., -1.9688e+00,
           2.1719e+00,  1.1641e+00],
         [-4.9219e-01, -1.7676e-01, -9.9121e-02,  ..., -2.5781e+00,
           9.0234e-01,  1.3047e+00],
         [-5.8203e-01, -9.1797e-01,  1.6875e+00,  ..., -1.1719e+00,
          -8.6914e-02, -1.3984e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-1.9775e-02,  3.9551e-02, -2.7344e-02,  ..., -4.4434e-02,
          -7.5684e-02,  1.7944e-02],
         [ 2.5781e-01, -4.1797e-01,  1.6406e-01,  ...,  6.4453e-01,
          -3.4766e-01, -4.9414e-01],
         [-5.0781e-01,  4.7070e-01, -2.5781e-01,  ...,  7.9297e-01,
          -6.0547e-01,  4.2969e-02],
         ...,
         [-6.9922e-01,  1.0889e-01, -1.0156e-01,  ..., -5.6641e-01,
          -3.1836e-01, -6.0156e-01],
         [-8.1250e-01, -4.5508e-01, -2.4902e-01,  ..., -3.4375e-01,
          -2.1191e-01, -1.1406e+00],
         [-3.5547e-01,  1.3574e-01, -6.2891e-01,  ...,  1.0681e-02,
          -1.2695e-01,  1.5234e-01]],

        [[-1.9775e-02,  3.9551e-02, -2.7344e-02,  ..., -4.4434e-02,
          -7.5684e-02,  1.7944e-02],
         [ 2.5781e-01, -4.1797e-01,  1.6406e-01,  ...,  6.4453e-01,
          -3.4766e-01, -4.9414e-01],
         [-5.0781e-01,  4.7070e-01, -2.5781e-01,  ...,  7.9297e-01,
          -6.0547e-01,  4.2969e-02],
         ...,
         [-6.9922e-01,  1.0889e-01, -1.0156e-01,  ..., -5.6641e-01,
          -3.1836e-01, -6.0156e-01],
         [-8.1250e-01, -4.5508e-01, -2.4902e-01,  ..., -3.4375e-01,
          -2.1191e-01, -1.1406e+00],
         [-3.5547e-01,  1.3574e-01, -6.2891e-01,  ...,  1.0681e-02,
          -1.2695e-01,  1.5234e-01]],

        [[-1.9775e-02,  3.9551e-02, -2.7344e-02,  ..., -4.4434e-02,
          -7.5684e-02,  1.7944e-02],
         [ 2.5781e-01, -4.1797e-01,  1.6406e-01,  ...,  6.4453e-01,
          -3.4766e-01, -4.9414e-01],
         [-5.0781e-01,  4.7070e-01, -2.5781e-01,  ...,  7.9297e-01,
          -6.0547e-01,  4.2969e-02],
         ...,
         [-6.9922e-01,  1.0889e-01, -1.0156e-01,  ..., -5.6641e-01,
          -3.1836e-01, -6.0156e-01],
         [-8.1250e-01, -4.5508e-01, -2.4902e-01,  ..., -3.4375e-01,
          -2.1191e-01, -1.1406e+00],
         [-3.5547e-01,  1.3574e-01, -6.2891e-01,  ...,  1.0681e-02,
          -1.2695e-01,  1.5234e-01]],

        ...,

        [[ 4.2725e-03, -1.0498e-02, -1.6357e-02,  ..., -2.4261e-03,
          -2.6489e-02,  8.6975e-04],
         [-3.2031e-01, -3.2617e-01,  8.8379e-02,  ..., -8.1250e-01,
           1.6504e-01,  4.5508e-01],
         [-2.8711e-01, -3.9453e-01,  1.5137e-01,  ...,  5.1514e-02,
           6.3965e-02,  1.4746e-01],
         ...,
         [-4.1992e-01,  5.3223e-02,  2.7344e-01,  ...,  1.0234e+00,
           5.7031e-01,  2.1680e-01],
         [-3.6914e-01, -1.1084e-01,  5.3906e-01,  ...,  4.9072e-02,
           7.4609e-01,  6.0156e-01],
         [ 3.7695e-01, -1.3867e-01, -1.8848e-01,  ...,  2.7344e-01,
           5.7031e-01, -4.0527e-02]],

        [[ 4.2725e-03, -1.0498e-02, -1.6357e-02,  ..., -2.4261e-03,
          -2.6489e-02,  8.6975e-04],
         [-3.2031e-01, -3.2617e-01,  8.8379e-02,  ..., -8.1250e-01,
           1.6504e-01,  4.5508e-01],
         [-2.8711e-01, -3.9453e-01,  1.5137e-01,  ...,  5.1514e-02,
           6.3965e-02,  1.4746e-01],
         ...,
         [-4.1992e-01,  5.3223e-02,  2.7344e-01,  ...,  1.0234e+00,
           5.7031e-01,  2.1680e-01],
         [-3.6914e-01, -1.1084e-01,  5.3906e-01,  ...,  4.9072e-02,
           7.4609e-01,  6.0156e-01],
         [ 3.7695e-01, -1.3867e-01, -1.8848e-01,  ...,  2.7344e-01,
           5.7031e-01, -4.0527e-02]],

        [[ 4.2725e-03, -1.0498e-02, -1.6357e-02,  ..., -2.4261e-03,
          -2.6489e-02,  8.6975e-04],
         [-3.2031e-01, -3.2617e-01,  8.8379e-02,  ..., -8.1250e-01,
           1.6504e-01,  4.5508e-01],
         [-2.8711e-01, -3.9453e-01,  1.5137e-01,  ...,  5.1514e-02,
           6.3965e-02,  1.4746e-01],
         ...,
         [-4.1992e-01,  5.3223e-02,  2.7344e-01,  ...,  1.0234e+00,
           5.7031e-01,  2.1680e-01],
         [-3.6914e-01, -1.1084e-01,  5.3906e-01,  ...,  4.9072e-02,
           7.4609e-01,  6.0156e-01],
         [ 3.7695e-01, -1.3867e-01, -1.8848e-01,  ...,  2.7344e-01,
           5.7031e-01, -4.0527e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 2.0504e-04, -5.9509e-04, -8.3618e-03,  ...,  3.5706e-03,
           7.7637e-02, -1.4258e-01],
         [-1.5547e+00, -6.5234e-01, -7.3047e-01,  ...,  2.0312e+00,
          -6.8750e-01,  6.7578e-01],
         [ 6.4844e-01, -6.2500e-01, -6.1328e-01,  ..., -7.3047e-01,
          -1.5078e+00, -6.0547e-01],
         ...,
         [ 8.9062e-01, -2.5781e-01, -3.6914e-01,  ...,  9.8828e-01,
           2.0156e+00,  2.5781e+00],
         [ 7.8125e-01,  4.2383e-01, -4.6094e-01,  ...,  7.5195e-02,
           6.2500e-01,  1.4375e+00],
         [ 1.6484e+00,  8.2812e-01, -1.4141e+00,  ...,  9.5703e-01,
          -1.1641e+00,  6.7188e-01]],

        [[ 2.0504e-04, -5.9509e-04, -8.3618e-03,  ...,  3.5706e-03,
           7.7637e-02, -1.4258e-01],
         [-1.5547e+00, -6.5234e-01, -7.3047e-01,  ...,  2.0312e+00,
          -6.8750e-01,  6.7578e-01],
         [ 6.4844e-01, -6.2500e-01, -6.1328e-01,  ..., -7.3047e-01,
          -1.5078e+00, -6.0547e-01],
         ...,
         [ 8.9062e-01, -2.5781e-01, -3.6914e-01,  ...,  9.8828e-01,
           2.0156e+00,  2.5781e+00],
         [ 7.8125e-01,  4.2383e-01, -4.6094e-01,  ...,  7.5195e-02,
           6.2500e-01,  1.4375e+00],
         [ 1.6484e+00,  8.2812e-01, -1.4141e+00,  ...,  9.5703e-01,
          -1.1641e+00,  6.7188e-01]],

        [[ 2.0504e-04, -5.9509e-04, -8.3618e-03,  ...,  3.5706e-03,
           7.7637e-02, -1.4258e-01],
         [-1.5547e+00, -6.5234e-01, -7.3047e-01,  ...,  2.0312e+00,
          -6.8750e-01,  6.7578e-01],
         [ 6.4844e-01, -6.2500e-01, -6.1328e-01,  ..., -7.3047e-01,
          -1.5078e+00, -6.0547e-01],
         ...,
         [ 8.9062e-01, -2.5781e-01, -3.6914e-01,  ...,  9.8828e-01,
           2.0156e+00,  2.5781e+00],
         [ 7.8125e-01,  4.2383e-01, -4.6094e-01,  ...,  7.5195e-02,
           6.2500e-01,  1.4375e+00],
         [ 1.6484e+00,  8.2812e-01, -1.4141e+00,  ...,  9.5703e-01,
          -1.1641e+00,  6.7188e-01]],

        ...,

        [[-8.8882e-04, -6.9275e-03, -3.6926e-03,  ..., -2.1094e-01,
          -4.8584e-02, -6.6406e-02],
         [ 7.1875e-01,  4.1016e-01, -6.3281e-01,  ...,  2.8438e+00,
          -3.8906e+00, -7.0703e-01],
         [ 1.3086e-01, -2.4805e-01, -8.0859e-01,  ...,  2.0000e+00,
          -4.7812e+00, -3.4531e+00],
         ...,
         [ 2.8906e-01,  1.1377e-01,  2.5781e-01,  ..., -3.8594e+00,
          -1.2891e+00,  5.7373e-03],
         [ 2.7148e-01, -1.9336e-01, -2.1289e-01,  ..., -3.6406e+00,
           2.2656e+00, -1.6484e+00],
         [-2.3828e-01, -6.2500e-02,  5.2002e-02,  ...,  3.7500e+00,
           1.3438e+00,  6.1768e-02]],

        [[-8.8882e-04, -6.9275e-03, -3.6926e-03,  ..., -2.1094e-01,
          -4.8584e-02, -6.6406e-02],
         [ 7.1875e-01,  4.1016e-01, -6.3281e-01,  ...,  2.8438e+00,
          -3.8906e+00, -7.0703e-01],
         [ 1.3086e-01, -2.4805e-01, -8.0859e-01,  ...,  2.0000e+00,
          -4.7812e+00, -3.4531e+00],
         ...,
         [ 2.8906e-01,  1.1377e-01,  2.5781e-01,  ..., -3.8594e+00,
          -1.2891e+00,  5.7373e-03],
         [ 2.7148e-01, -1.9336e-01, -2.1289e-01,  ..., -3.6406e+00,
           2.2656e+00, -1.6484e+00],
         [-2.3828e-01, -6.2500e-02,  5.2002e-02,  ...,  3.7500e+00,
           1.3438e+00,  6.1768e-02]],

        [[-8.8882e-04, -6.9275e-03, -3.6926e-03,  ..., -2.1094e-01,
          -4.8584e-02, -6.6406e-02],
         [ 7.1875e-01,  4.1016e-01, -6.3281e-01,  ...,  2.8438e+00,
          -3.8906e+00, -7.0703e-01],
         [ 1.3086e-01, -2.4805e-01, -8.0859e-01,  ...,  2.0000e+00,
          -4.7812e+00, -3.4531e+00],
         ...,
         [ 2.8906e-01,  1.1377e-01,  2.5781e-01,  ..., -3.8594e+00,
          -1.2891e+00,  5.7373e-03],
         [ 2.7148e-01, -1.9336e-01, -2.1289e-01,  ..., -3.6406e+00,
           2.2656e+00, -1.6484e+00],
         [-2.3828e-01, -6.2500e-02,  5.2002e-02,  ...,  3.7500e+00,
           1.3438e+00,  6.1768e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-4.9210e-04, -2.6398e-03,  5.1270e-03,  ...,  1.7929e-03,
           4.4556e-03,  9.5825e-03],
         [-7.7637e-02, -2.2949e-01, -3.5742e-01,  ...,  1.1597e-02,
          -8.4766e-01, -5.3906e-01],
         [-3.6719e-01,  3.7500e-01,  4.3164e-01,  ..., -1.7188e-01,
          -4.4922e-01,  4.8438e-01],
         ...,
         [-3.7891e-01,  2.6172e-01, -1.5332e-01,  ...,  5.3906e-01,
           4.8242e-01, -4.9414e-01],
         [-2.4023e-01, -6.6406e-02,  2.0801e-01,  ...,  7.9297e-01,
          -8.7402e-02, -2.8320e-01],
         [-7.9297e-01,  2.0312e-01,  1.6992e-01,  ..., -4.4922e-01,
           5.3516e-01,  3.1445e-01]],

        [[-4.9210e-04, -2.6398e-03,  5.1270e-03,  ...,  1.7929e-03,
           4.4556e-03,  9.5825e-03],
         [-7.7637e-02, -2.2949e-01, -3.5742e-01,  ...,  1.1597e-02,
          -8.4766e-01, -5.3906e-01],
         [-3.6719e-01,  3.7500e-01,  4.3164e-01,  ..., -1.7188e-01,
          -4.4922e-01,  4.8438e-01],
         ...,
         [-3.7891e-01,  2.6172e-01, -1.5332e-01,  ...,  5.3906e-01,
           4.8242e-01, -4.9414e-01],
         [-2.4023e-01, -6.6406e-02,  2.0801e-01,  ...,  7.9297e-01,
          -8.7402e-02, -2.8320e-01],
         [-7.9297e-01,  2.0312e-01,  1.6992e-01,  ..., -4.4922e-01,
           5.3516e-01,  3.1445e-01]],

        [[-4.9210e-04, -2.6398e-03,  5.1270e-03,  ...,  1.7929e-03,
           4.4556e-03,  9.5825e-03],
         [-7.7637e-02, -2.2949e-01, -3.5742e-01,  ...,  1.1597e-02,
          -8.4766e-01, -5.3906e-01],
         [-3.6719e-01,  3.7500e-01,  4.3164e-01,  ..., -1.7188e-01,
          -4.4922e-01,  4.8438e-01],
         ...,
         [-3.7891e-01,  2.6172e-01, -1.5332e-01,  ...,  5.3906e-01,
           4.8242e-01, -4.9414e-01],
         [-2.4023e-01, -6.6406e-02,  2.0801e-01,  ...,  7.9297e-01,
          -8.7402e-02, -2.8320e-01],
         [-7.9297e-01,  2.0312e-01,  1.6992e-01,  ..., -4.4922e-01,
           5.3516e-01,  3.1445e-01]],

        ...,

        [[-1.7456e-02,  5.1514e-02, -1.7262e-04,  ...,  2.8968e-05,
           2.8687e-03,  1.8066e-02],
         [-1.3125e+00,  4.2969e-02,  1.1719e-01,  ...,  7.4219e-01,
          -8.2812e-01, -4.1602e-01],
         [-8.3984e-01, -2.1406e+00, -1.8828e+00,  ..., -1.0312e+00,
          -6.6797e-01,  1.4355e-01],
         ...,
         [ 6.5234e-01, -1.3125e+00, -4.6680e-01,  ..., -9.5215e-02,
           1.4648e-01,  5.7031e-01],
         [ 2.6562e-01, -2.0215e-01,  1.2793e-01,  ...,  2.0020e-01,
           7.5391e-01,  2.7588e-02],
         [-4.5508e-01, -4.7852e-01, -9.8047e-01,  ...,  6.3281e-01,
          -5.2734e-01, -1.3672e-01]],

        [[-1.7456e-02,  5.1514e-02, -1.7262e-04,  ...,  2.8968e-05,
           2.8687e-03,  1.8066e-02],
         [-1.3125e+00,  4.2969e-02,  1.1719e-01,  ...,  7.4219e-01,
          -8.2812e-01, -4.1602e-01],
         [-8.3984e-01, -2.1406e+00, -1.8828e+00,  ..., -1.0312e+00,
          -6.6797e-01,  1.4355e-01],
         ...,
         [ 6.5234e-01, -1.3125e+00, -4.6680e-01,  ..., -9.5215e-02,
           1.4648e-01,  5.7031e-01],
         [ 2.6562e-01, -2.0215e-01,  1.2793e-01,  ...,  2.0020e-01,
           7.5391e-01,  2.7588e-02],
         [-4.5508e-01, -4.7852e-01, -9.8047e-01,  ...,  6.3281e-01,
          -5.2734e-01, -1.3672e-01]],

        [[-1.7456e-02,  5.1514e-02, -1.7262e-04,  ...,  2.8968e-05,
           2.8687e-03,  1.8066e-02],
         [-1.3125e+00,  4.2969e-02,  1.1719e-01,  ...,  7.4219e-01,
          -8.2812e-01, -4.1602e-01],
         [-8.3984e-01, -2.1406e+00, -1.8828e+00,  ..., -1.0312e+00,
          -6.6797e-01,  1.4355e-01],
         ...,
         [ 6.5234e-01, -1.3125e+00, -4.6680e-01,  ..., -9.5215e-02,
           1.4648e-01,  5.7031e-01],
         [ 2.6562e-01, -2.0215e-01,  1.2793e-01,  ...,  2.0020e-01,
           7.5391e-01,  2.7588e-02],
         [-4.5508e-01, -4.7852e-01, -9.8047e-01,  ...,  6.3281e-01,
          -5.2734e-01, -1.3672e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 6.1646e-03, -5.6763e-03,  5.2490e-03,  ..., -9.8633e-02,
          -2.2656e-01, -2.7954e-02],
         [-1.6406e+00, -1.4844e-01,  2.5391e-01,  ..., -2.9688e-01,
          -2.1562e+00,  4.9072e-02],
         [-2.0625e+00,  9.0234e-01, -6.2109e-01,  ...,  1.6875e+00,
          -1.6250e+00, -1.6562e+00],
         ...,
         [-3.9258e-01, -9.1016e-01,  1.0000e+00,  ..., -2.5625e+00,
           1.7266e+00, -1.3516e+00],
         [ 1.8848e-01, -1.0625e+00, -4.5898e-01,  ..., -8.3203e-01,
           7.1094e-01, -2.9844e+00],
         [-1.4531e+00, -3.9453e-01,  1.3203e+00,  ..., -1.2266e+00,
          -3.2031e+00,  3.0156e+00]],

        [[ 6.1646e-03, -5.6763e-03,  5.2490e-03,  ..., -9.8633e-02,
          -2.2656e-01, -2.7954e-02],
         [-1.6406e+00, -1.4844e-01,  2.5391e-01,  ..., -2.9688e-01,
          -2.1562e+00,  4.9072e-02],
         [-2.0625e+00,  9.0234e-01, -6.2109e-01,  ...,  1.6875e+00,
          -1.6250e+00, -1.6562e+00],
         ...,
         [-3.9258e-01, -9.1016e-01,  1.0000e+00,  ..., -2.5625e+00,
           1.7266e+00, -1.3516e+00],
         [ 1.8848e-01, -1.0625e+00, -4.5898e-01,  ..., -8.3203e-01,
           7.1094e-01, -2.9844e+00],
         [-1.4531e+00, -3.9453e-01,  1.3203e+00,  ..., -1.2266e+00,
          -3.2031e+00,  3.0156e+00]],

        [[ 6.1646e-03, -5.6763e-03,  5.2490e-03,  ..., -9.8633e-02,
          -2.2656e-01, -2.7954e-02],
         [-1.6406e+00, -1.4844e-01,  2.5391e-01,  ..., -2.9688e-01,
          -2.1562e+00,  4.9072e-02],
         [-2.0625e+00,  9.0234e-01, -6.2109e-01,  ...,  1.6875e+00,
          -1.6250e+00, -1.6562e+00],
         ...,
         [-3.9258e-01, -9.1016e-01,  1.0000e+00,  ..., -2.5625e+00,
           1.7266e+00, -1.3516e+00],
         [ 1.8848e-01, -1.0625e+00, -4.5898e-01,  ..., -8.3203e-01,
           7.1094e-01, -2.9844e+00],
         [-1.4531e+00, -3.9453e-01,  1.3203e+00,  ..., -1.2266e+00,
          -3.2031e+00,  3.0156e+00]],

        ...,

        [[ 2.3193e-03,  1.8921e-02,  1.1902e-02,  ...,  3.8574e-02,
           1.0620e-02,  1.2207e-01],
         [ 9.7656e-01, -4.1016e-02,  1.2031e+00,  ...,  4.9805e-01,
           4.1211e-01, -8.9453e-01],
         [ 9.9609e-01,  1.0938e+00,  2.0781e+00,  ...,  2.4062e+00,
          -1.6357e-02, -6.0156e-01],
         ...,
         [ 1.5156e+00, -8.8672e-01,  5.6641e-02,  ..., -8.6328e-01,
           1.6406e+00, -2.1875e+00],
         [ 3.9062e-01, -9.1406e-01, -2.4805e-01,  ...,  1.1484e+00,
           1.7344e+00, -1.8594e+00],
         [ 6.8750e-01, -1.5000e+00,  1.0156e+00,  ...,  1.6094e+00,
           5.6641e-01, -2.8516e-01]],

        [[ 2.3193e-03,  1.8921e-02,  1.1902e-02,  ...,  3.8574e-02,
           1.0620e-02,  1.2207e-01],
         [ 9.7656e-01, -4.1016e-02,  1.2031e+00,  ...,  4.9805e-01,
           4.1211e-01, -8.9453e-01],
         [ 9.9609e-01,  1.0938e+00,  2.0781e+00,  ...,  2.4062e+00,
          -1.6357e-02, -6.0156e-01],
         ...,
         [ 1.5156e+00, -8.8672e-01,  5.6641e-02,  ..., -8.6328e-01,
           1.6406e+00, -2.1875e+00],
         [ 3.9062e-01, -9.1406e-01, -2.4805e-01,  ...,  1.1484e+00,
           1.7344e+00, -1.8594e+00],
         [ 6.8750e-01, -1.5000e+00,  1.0156e+00,  ...,  1.6094e+00,
           5.6641e-01, -2.8516e-01]],

        [[ 2.3193e-03,  1.8921e-02,  1.1902e-02,  ...,  3.8574e-02,
           1.0620e-02,  1.2207e-01],
         [ 9.7656e-01, -4.1016e-02,  1.2031e+00,  ...,  4.9805e-01,
           4.1211e-01, -8.9453e-01],
         [ 9.9609e-01,  1.0938e+00,  2.0781e+00,  ...,  2.4062e+00,
          -1.6357e-02, -6.0156e-01],
         ...,
         [ 1.5156e+00, -8.8672e-01,  5.6641e-02,  ..., -8.6328e-01,
           1.6406e+00, -2.1875e+00],
         [ 3.9062e-01, -9.1406e-01, -2.4805e-01,  ...,  1.1484e+00,
           1.7344e+00, -1.8594e+00],
         [ 6.8750e-01, -1.5000e+00,  1.0156e+00,  ...,  1.6094e+00,
           5.6641e-01, -2.8516e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-0.0067, -0.0065, -0.0085,  ..., -0.0104, -0.0156, -0.0349],
         [-0.1035,  0.9180,  0.8984,  ..., -0.1216,  0.4629,  0.6289],
         [-0.7266,  1.1094,  0.8711,  ...,  0.2344, -0.6406,  1.3047],
         ...,
         [ 1.1484,  0.2832, -0.2598,  ...,  0.4062,  0.6953, -1.1328],
         [-0.3594, -0.0439, -0.4395,  ...,  0.5234,  0.3008,  0.2461],
         [ 0.5234, -1.4375,  0.5469,  ..., -0.3008,  1.1250,  0.1299]],

        [[-0.0067, -0.0065, -0.0085,  ..., -0.0104, -0.0156, -0.0349],
         [-0.1035,  0.9180,  0.8984,  ..., -0.1216,  0.4629,  0.6289],
         [-0.7266,  1.1094,  0.8711,  ...,  0.2344, -0.6406,  1.3047],
         ...,
         [ 1.1484,  0.2832, -0.2598,  ...,  0.4062,  0.6953, -1.1328],
         [-0.3594, -0.0439, -0.4395,  ...,  0.5234,  0.3008,  0.2461],
         [ 0.5234, -1.4375,  0.5469,  ..., -0.3008,  1.1250,  0.1299]],

        [[-0.0067, -0.0065, -0.0085,  ..., -0.0104, -0.0156, -0.0349],
         [-0.1035,  0.9180,  0.8984,  ..., -0.1216,  0.4629,  0.6289],
         [-0.7266,  1.1094,  0.8711,  ...,  0.2344, -0.6406,  1.3047],
         ...,
         [ 1.1484,  0.2832, -0.2598,  ...,  0.4062,  0.6953, -1.1328],
         [-0.3594, -0.0439, -0.4395,  ...,  0.5234,  0.3008,  0.2461],
         [ 0.5234, -1.4375,  0.5469,  ..., -0.3008,  1.1250,  0.1299]],

        ...,

        [[ 0.0139, -0.0273,  0.0087,  ..., -0.0417,  0.0271,  0.0422],
         [ 0.2275,  0.4062, -0.0045,  ...,  0.7266, -0.3867, -0.6211],
         [-0.2422, -0.1992,  1.3438,  ...,  0.9805, -0.5234, -0.2324],
         ...,
         [ 0.6875, -1.4297, -0.3418,  ...,  0.1069, -0.1187,  0.1299],
         [ 0.5156, -0.5547, -0.6094,  ...,  0.0503,  0.1641,  0.1436],
         [-0.2793,  0.5352, -0.6367,  ..., -0.7109, -1.2656, -0.4180]],

        [[ 0.0139, -0.0273,  0.0087,  ..., -0.0417,  0.0271,  0.0422],
         [ 0.2275,  0.4062, -0.0045,  ...,  0.7266, -0.3867, -0.6211],
         [-0.2422, -0.1992,  1.3438,  ...,  0.9805, -0.5234, -0.2324],
         ...,
         [ 0.6875, -1.4297, -0.3418,  ...,  0.1069, -0.1187,  0.1299],
         [ 0.5156, -0.5547, -0.6094,  ...,  0.0503,  0.1641,  0.1436],
         [-0.2793,  0.5352, -0.6367,  ..., -0.7109, -1.2656, -0.4180]],

        [[ 0.0139, -0.0273,  0.0087,  ..., -0.0417,  0.0271,  0.0422],
         [ 0.2275,  0.4062, -0.0045,  ...,  0.7266, -0.3867, -0.6211],
         [-0.2422, -0.1992,  1.3438,  ...,  0.9805, -0.5234, -0.2324],
         ...,
         [ 0.6875, -1.4297, -0.3418,  ...,  0.1069, -0.1187,  0.1299],
         [ 0.5156, -0.5547, -0.6094,  ...,  0.0503,  0.1641,  0.1436],
         [-0.2793,  0.5352, -0.6367,  ..., -0.7109, -1.2656, -0.4180]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[ 6.0425e-03, -9.7656e-03,  6.5002e-03,  ...,  1.8164e-01,
          -2.0938e+00, -4.0039e-02],
         [ 2.4023e-01,  2.0605e-01,  1.4297e+00,  ..., -9.3359e-01,
           6.4688e+00, -7.0312e-01],
         [-8.1543e-02,  7.2266e-01,  9.5312e-01,  ..., -9.1406e-01,
           7.5000e+00, -3.5938e-01],
         ...,
         [ 8.6328e-01,  1.2793e-01,  4.8438e-01,  ...,  5.8203e-01,
           4.9062e+00,  1.4609e+00],
         [ 2.7344e-02, -3.1006e-02,  4.3359e-01,  ..., -5.8594e-01,
           4.8125e+00,  2.2812e+00],
         [ 8.0078e-01, -9.1797e-01,  1.2344e+00,  ...,  1.5391e+00,
           5.9375e+00, -3.1875e+00]],

        [[ 6.0425e-03, -9.7656e-03,  6.5002e-03,  ...,  1.8164e-01,
          -2.0938e+00, -4.0039e-02],
         [ 2.4023e-01,  2.0605e-01,  1.4297e+00,  ..., -9.3359e-01,
           6.4688e+00, -7.0312e-01],
         [-8.1543e-02,  7.2266e-01,  9.5312e-01,  ..., -9.1406e-01,
           7.5000e+00, -3.5938e-01],
         ...,
         [ 8.6328e-01,  1.2793e-01,  4.8438e-01,  ...,  5.8203e-01,
           4.9062e+00,  1.4609e+00],
         [ 2.7344e-02, -3.1006e-02,  4.3359e-01,  ..., -5.8594e-01,
           4.8125e+00,  2.2812e+00],
         [ 8.0078e-01, -9.1797e-01,  1.2344e+00,  ...,  1.5391e+00,
           5.9375e+00, -3.1875e+00]],

        [[ 6.0425e-03, -9.7656e-03,  6.5002e-03,  ...,  1.8164e-01,
          -2.0938e+00, -4.0039e-02],
         [ 2.4023e-01,  2.0605e-01,  1.4297e+00,  ..., -9.3359e-01,
           6.4688e+00, -7.0312e-01],
         [-8.1543e-02,  7.2266e-01,  9.5312e-01,  ..., -9.1406e-01,
           7.5000e+00, -3.5938e-01],
         ...,
         [ 8.6328e-01,  1.2793e-01,  4.8438e-01,  ...,  5.8203e-01,
           4.9062e+00,  1.4609e+00],
         [ 2.7344e-02, -3.1006e-02,  4.3359e-01,  ..., -5.8594e-01,
           4.8125e+00,  2.2812e+00],
         [ 8.0078e-01, -9.1797e-01,  1.2344e+00,  ...,  1.5391e+00,
           5.9375e+00, -3.1875e+00]],

        ...,

        [[-4.9744e-03, -7.0190e-04, -9.7656e-03,  ...,  3.0859e-01,
          -4.3750e-01, -6.5430e-02],
         [ 1.8906e+00, -1.2500e+00,  9.7656e-01,  ..., -7.7344e-01,
           1.6250e+00, -5.1172e-01],
         [-1.9531e-01,  5.0781e-01, -6.7188e-01,  ..., -1.9531e+00,
           1.1719e+00,  3.0859e-01],
         ...,
         [ 7.8906e-01,  1.1484e+00,  4.3359e-01,  ..., -1.6250e+00,
           2.2754e-01,  1.5625e-01],
         [-5.8594e-01,  4.2578e-01,  6.7969e-01,  ..., -2.0000e+00,
          -5.1562e-01,  1.9165e-02],
         [-1.6016e+00,  4.2188e-01,  1.0859e+00,  ..., -1.5859e+00,
           1.0938e+00,  7.9102e-02]],

        [[-4.9744e-03, -7.0190e-04, -9.7656e-03,  ...,  3.0859e-01,
          -4.3750e-01, -6.5430e-02],
         [ 1.8906e+00, -1.2500e+00,  9.7656e-01,  ..., -7.7344e-01,
           1.6250e+00, -5.1172e-01],
         [-1.9531e-01,  5.0781e-01, -6.7188e-01,  ..., -1.9531e+00,
           1.1719e+00,  3.0859e-01],
         ...,
         [ 7.8906e-01,  1.1484e+00,  4.3359e-01,  ..., -1.6250e+00,
           2.2754e-01,  1.5625e-01],
         [-5.8594e-01,  4.2578e-01,  6.7969e-01,  ..., -2.0000e+00,
          -5.1562e-01,  1.9165e-02],
         [-1.6016e+00,  4.2188e-01,  1.0859e+00,  ..., -1.5859e+00,
           1.0938e+00,  7.9102e-02]],

        [[-4.9744e-03, -7.0190e-04, -9.7656e-03,  ...,  3.0859e-01,
          -4.3750e-01, -6.5430e-02],
         [ 1.8906e+00, -1.2500e+00,  9.7656e-01,  ..., -7.7344e-01,
           1.6250e+00, -5.1172e-01],
         [-1.9531e-01,  5.0781e-01, -6.7188e-01,  ..., -1.9531e+00,
           1.1719e+00,  3.0859e-01],
         ...,
         [ 7.8906e-01,  1.1484e+00,  4.3359e-01,  ..., -1.6250e+00,
           2.2754e-01,  1.5625e-01],
         [-5.8594e-01,  4.2578e-01,  6.7969e-01,  ..., -2.0000e+00,
          -5.1562e-01,  1.9165e-02],
         [-1.6016e+00,  4.2188e-01,  1.0859e+00,  ..., -1.5859e+00,
           1.0938e+00,  7.9102e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 0.0245,  0.0625, -0.0684,  ..., -0.0576, -0.0496,  0.0322],
         [-0.6133,  0.3223, -0.2158,  ...,  0.2930,  0.2852, -0.3613],
         [ 0.2656,  0.1973,  0.1157,  ..., -0.3672, -0.1279,  0.3633],
         ...,
         [ 0.4922, -0.7812,  0.2363,  ...,  0.1992,  0.3848, -0.3848],
         [-0.3398, -1.0859, -0.5117,  ...,  0.4961, -0.0102, -0.2246],
         [ 0.2129,  0.6172, -0.6797,  ...,  0.3301, -0.1914,  0.5352]],

        [[ 0.0245,  0.0625, -0.0684,  ..., -0.0576, -0.0496,  0.0322],
         [-0.6133,  0.3223, -0.2158,  ...,  0.2930,  0.2852, -0.3613],
         [ 0.2656,  0.1973,  0.1157,  ..., -0.3672, -0.1279,  0.3633],
         ...,
         [ 0.4922, -0.7812,  0.2363,  ...,  0.1992,  0.3848, -0.3848],
         [-0.3398, -1.0859, -0.5117,  ...,  0.4961, -0.0102, -0.2246],
         [ 0.2129,  0.6172, -0.6797,  ...,  0.3301, -0.1914,  0.5352]],

        [[ 0.0245,  0.0625, -0.0684,  ..., -0.0576, -0.0496,  0.0322],
         [-0.6133,  0.3223, -0.2158,  ...,  0.2930,  0.2852, -0.3613],
         [ 0.2656,  0.1973,  0.1157,  ..., -0.3672, -0.1279,  0.3633],
         ...,
         [ 0.4922, -0.7812,  0.2363,  ...,  0.1992,  0.3848, -0.3848],
         [-0.3398, -1.0859, -0.5117,  ...,  0.4961, -0.0102, -0.2246],
         [ 0.2129,  0.6172, -0.6797,  ...,  0.3301, -0.1914,  0.5352]],

        ...,

        [[ 0.0066, -0.0017,  0.0037,  ...,  0.0084,  0.0132,  0.0036],
         [-0.4062, -1.0234,  0.8516,  ...,  0.2617,  1.2031,  0.4395],
         [-0.5898,  0.8008,  0.1118,  ...,  0.6875,  1.5156, -0.1387],
         ...,
         [ 0.7734,  1.6328,  0.4062,  ...,  0.0635, -1.4531, -0.7656],
         [ 0.9453, -0.4082,  0.7617,  ...,  1.6797,  0.2002,  0.5781],
         [-0.2031,  0.1699,  0.1191,  ..., -0.1377,  0.6641,  0.8008]],

        [[ 0.0066, -0.0017,  0.0037,  ...,  0.0084,  0.0132,  0.0036],
         [-0.4062, -1.0234,  0.8516,  ...,  0.2617,  1.2031,  0.4395],
         [-0.5898,  0.8008,  0.1118,  ...,  0.6875,  1.5156, -0.1387],
         ...,
         [ 0.7734,  1.6328,  0.4062,  ...,  0.0635, -1.4531, -0.7656],
         [ 0.9453, -0.4082,  0.7617,  ...,  1.6797,  0.2002,  0.5781],
         [-0.2031,  0.1699,  0.1191,  ..., -0.1377,  0.6641,  0.8008]],

        [[ 0.0066, -0.0017,  0.0037,  ...,  0.0084,  0.0132,  0.0036],
         [-0.4062, -1.0234,  0.8516,  ...,  0.2617,  1.2031,  0.4395],
         [-0.5898,  0.8008,  0.1118,  ...,  0.6875,  1.5156, -0.1387],
         ...,
         [ 0.7734,  1.6328,  0.4062,  ...,  0.0635, -1.4531, -0.7656],
         [ 0.9453, -0.4082,  0.7617,  ...,  1.6797,  0.2002,  0.5781],
         [-0.2031,  0.1699,  0.1191,  ..., -0.1377,  0.6641,  0.8008]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[-3.7537e-03, -2.6245e-03,  1.2329e-02,  ...,  1.1768e-01,
          -4.6387e-02,  2.0508e-01],
         [-1.5781e+00, -1.7578e-02, -8.6719e-01,  ..., -1.2500e+00,
           7.1875e-01, -3.4961e-01],
         [-7.1875e-01, -1.1182e-01,  2.5195e-01,  ..., -1.3438e+00,
           3.3594e+00, -1.1172e+00],
         ...,
         [-3.6719e-01,  1.2500e+00, -5.3125e-01,  ..., -1.8359e+00,
          -5.3516e-01,  2.1875e-01],
         [ 4.0625e-01,  1.9824e-01, -3.6719e-01,  ..., -5.3125e-01,
          -1.8848e-01, -9.7656e-03],
         [ 1.7188e-01,  7.0703e-01, -8.2422e-01,  ...,  1.3047e+00,
           2.5625e+00, -3.1406e+00]],

        [[-3.7537e-03, -2.6245e-03,  1.2329e-02,  ...,  1.1768e-01,
          -4.6387e-02,  2.0508e-01],
         [-1.5781e+00, -1.7578e-02, -8.6719e-01,  ..., -1.2500e+00,
           7.1875e-01, -3.4961e-01],
         [-7.1875e-01, -1.1182e-01,  2.5195e-01,  ..., -1.3438e+00,
           3.3594e+00, -1.1172e+00],
         ...,
         [-3.6719e-01,  1.2500e+00, -5.3125e-01,  ..., -1.8359e+00,
          -5.3516e-01,  2.1875e-01],
         [ 4.0625e-01,  1.9824e-01, -3.6719e-01,  ..., -5.3125e-01,
          -1.8848e-01, -9.7656e-03],
         [ 1.7188e-01,  7.0703e-01, -8.2422e-01,  ...,  1.3047e+00,
           2.5625e+00, -3.1406e+00]],

        [[-3.7537e-03, -2.6245e-03,  1.2329e-02,  ...,  1.1768e-01,
          -4.6387e-02,  2.0508e-01],
         [-1.5781e+00, -1.7578e-02, -8.6719e-01,  ..., -1.2500e+00,
           7.1875e-01, -3.4961e-01],
         [-7.1875e-01, -1.1182e-01,  2.5195e-01,  ..., -1.3438e+00,
           3.3594e+00, -1.1172e+00],
         ...,
         [-3.6719e-01,  1.2500e+00, -5.3125e-01,  ..., -1.8359e+00,
          -5.3516e-01,  2.1875e-01],
         [ 4.0625e-01,  1.9824e-01, -3.6719e-01,  ..., -5.3125e-01,
          -1.8848e-01, -9.7656e-03],
         [ 1.7188e-01,  7.0703e-01, -8.2422e-01,  ...,  1.3047e+00,
           2.5625e+00, -3.1406e+00]],

        ...,

        [[-1.0315e-02,  6.5804e-05,  1.0681e-02,  ...,  5.1172e-01,
           1.9043e-02, -2.9907e-02],
         [ 9.0625e-01, -7.8125e-02, -1.0156e+00,  ..., -3.0000e+00,
           4.2578e-01, -9.8047e-01],
         [ 6.7188e-01,  3.4180e-01,  4.0625e-01,  ..., -1.3594e+00,
           1.4844e-01, -1.0312e+00],
         ...,
         [-3.3203e-01,  3.3594e-01, -1.1865e-01,  ..., -1.2031e+00,
          -2.9844e+00, -7.1484e-01],
         [-3.5938e-01, -5.0293e-02,  2.6245e-02,  ..., -4.6631e-02,
          -1.1016e+00,  5.2344e-01],
         [-1.6875e+00, -2.2266e-01, -1.6406e+00,  ...,  4.3164e-01,
           2.1606e-02,  7.4609e-01]],

        [[-1.0315e-02,  6.5804e-05,  1.0681e-02,  ...,  5.1172e-01,
           1.9043e-02, -2.9907e-02],
         [ 9.0625e-01, -7.8125e-02, -1.0156e+00,  ..., -3.0000e+00,
           4.2578e-01, -9.8047e-01],
         [ 6.7188e-01,  3.4180e-01,  4.0625e-01,  ..., -1.3594e+00,
           1.4844e-01, -1.0312e+00],
         ...,
         [-3.3203e-01,  3.3594e-01, -1.1865e-01,  ..., -1.2031e+00,
          -2.9844e+00, -7.1484e-01],
         [-3.5938e-01, -5.0293e-02,  2.6245e-02,  ..., -4.6631e-02,
          -1.1016e+00,  5.2344e-01],
         [-1.6875e+00, -2.2266e-01, -1.6406e+00,  ...,  4.3164e-01,
           2.1606e-02,  7.4609e-01]],

        [[-1.0315e-02,  6.5804e-05,  1.0681e-02,  ...,  5.1172e-01,
           1.9043e-02, -2.9907e-02],
         [ 9.0625e-01, -7.8125e-02, -1.0156e+00,  ..., -3.0000e+00,
           4.2578e-01, -9.8047e-01],
         [ 6.7188e-01,  3.4180e-01,  4.0625e-01,  ..., -1.3594e+00,
           1.4844e-01, -1.0312e+00],
         ...,
         [-3.3203e-01,  3.3594e-01, -1.1865e-01,  ..., -1.2031e+00,
          -2.9844e+00, -7.1484e-01],
         [-3.5938e-01, -5.0293e-02,  2.6245e-02,  ..., -4.6631e-02,
          -1.1016e+00,  5.2344e-01],
         [-1.6875e+00, -2.2266e-01, -1.6406e+00,  ...,  4.3164e-01,
           2.1606e-02,  7.4609e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-2.2339e-02, -2.2583e-02, -1.3672e-01,  ...,  2.0386e-02,
           5.4016e-03, -2.3193e-02],
         [-7.5000e-01, -3.5547e-01,  6.2500e-01,  ...,  3.8281e-01,
           3.4961e-01,  3.8477e-01],
         [ 2.2266e-01, -3.2617e-01,  4.4727e-01,  ..., -1.6504e-01,
           5.8203e-01,  2.3047e-01],
         ...,
         [-3.8086e-02, -1.5723e-01, -9.1797e-01,  ...,  4.5117e-01,
           3.3203e-01,  1.0156e+00],
         [-8.5547e-01,  1.8438e+00,  9.1797e-02,  ...,  1.1406e+00,
           5.5859e-01,  8.1641e-01],
         [ 5.8594e-01, -1.0059e-01, -5.7031e-01,  ..., -5.2734e-02,
           2.9883e-01,  5.0391e-01]],

        [[-2.2339e-02, -2.2583e-02, -1.3672e-01,  ...,  2.0386e-02,
           5.4016e-03, -2.3193e-02],
         [-7.5000e-01, -3.5547e-01,  6.2500e-01,  ...,  3.8281e-01,
           3.4961e-01,  3.8477e-01],
         [ 2.2266e-01, -3.2617e-01,  4.4727e-01,  ..., -1.6504e-01,
           5.8203e-01,  2.3047e-01],
         ...,
         [-3.8086e-02, -1.5723e-01, -9.1797e-01,  ...,  4.5117e-01,
           3.3203e-01,  1.0156e+00],
         [-8.5547e-01,  1.8438e+00,  9.1797e-02,  ...,  1.1406e+00,
           5.5859e-01,  8.1641e-01],
         [ 5.8594e-01, -1.0059e-01, -5.7031e-01,  ..., -5.2734e-02,
           2.9883e-01,  5.0391e-01]],

        [[-2.2339e-02, -2.2583e-02, -1.3672e-01,  ...,  2.0386e-02,
           5.4016e-03, -2.3193e-02],
         [-7.5000e-01, -3.5547e-01,  6.2500e-01,  ...,  3.8281e-01,
           3.4961e-01,  3.8477e-01],
         [ 2.2266e-01, -3.2617e-01,  4.4727e-01,  ..., -1.6504e-01,
           5.8203e-01,  2.3047e-01],
         ...,
         [-3.8086e-02, -1.5723e-01, -9.1797e-01,  ...,  4.5117e-01,
           3.3203e-01,  1.0156e+00],
         [-8.5547e-01,  1.8438e+00,  9.1797e-02,  ...,  1.1406e+00,
           5.5859e-01,  8.1641e-01],
         [ 5.8594e-01, -1.0059e-01, -5.7031e-01,  ..., -5.2734e-02,
           2.9883e-01,  5.0391e-01]],

        ...,

        [[-1.1963e-02,  1.4801e-03, -6.3171e-03,  ...,  3.9062e-03,
          -2.3315e-02, -1.6327e-03],
         [-4.7852e-01, -4.9219e-01,  1.7285e-01,  ..., -3.8086e-01,
           7.8516e-01,  5.4297e-01],
         [-1.1953e+00, -3.8281e-01, -4.9609e-01,  ..., -6.0938e-01,
           5.3516e-01,  4.7266e-01],
         ...,
         [-1.1279e-01,  9.8633e-02,  7.9688e-01,  ...,  6.2109e-01,
          -5.8984e-01, -8.0859e-01],
         [-1.5625e-01,  4.9805e-02,  2.5781e-01,  ...,  9.7656e-01,
           2.1973e-01,  2.3633e-01],
         [-2.4805e-01,  3.6914e-01,  1.1016e+00,  ...,  7.0312e-01,
           2.5391e-01, -3.4912e-02]],

        [[-1.1963e-02,  1.4801e-03, -6.3171e-03,  ...,  3.9062e-03,
          -2.3315e-02, -1.6327e-03],
         [-4.7852e-01, -4.9219e-01,  1.7285e-01,  ..., -3.8086e-01,
           7.8516e-01,  5.4297e-01],
         [-1.1953e+00, -3.8281e-01, -4.9609e-01,  ..., -6.0938e-01,
           5.3516e-01,  4.7266e-01],
         ...,
         [-1.1279e-01,  9.8633e-02,  7.9688e-01,  ...,  6.2109e-01,
          -5.8984e-01, -8.0859e-01],
         [-1.5625e-01,  4.9805e-02,  2.5781e-01,  ...,  9.7656e-01,
           2.1973e-01,  2.3633e-01],
         [-2.4805e-01,  3.6914e-01,  1.1016e+00,  ...,  7.0312e-01,
           2.5391e-01, -3.4912e-02]],

        [[-1.1963e-02,  1.4801e-03, -6.3171e-03,  ...,  3.9062e-03,
          -2.3315e-02, -1.6327e-03],
         [-4.7852e-01, -4.9219e-01,  1.7285e-01,  ..., -3.8086e-01,
           7.8516e-01,  5.4297e-01],
         [-1.1953e+00, -3.8281e-01, -4.9609e-01,  ..., -6.0938e-01,
           5.3516e-01,  4.7266e-01],
         ...,
         [-1.1279e-01,  9.8633e-02,  7.9688e-01,  ...,  6.2109e-01,
          -5.8984e-01, -8.0859e-01],
         [-1.5625e-01,  4.9805e-02,  2.5781e-01,  ...,  9.7656e-01,
           2.1973e-01,  2.3633e-01],
         [-2.4805e-01,  3.6914e-01,  1.1016e+00,  ...,  7.0312e-01,
           2.5391e-01, -3.4912e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 7.9956e-03,  4.7302e-03,  9.1553e-05,  ...,  4.0234e-01,
          -7.4707e-02, -5.3516e-01],
         [ 1.8750e-01, -2.0312e+00,  1.7109e+00,  ..., -2.7344e+00,
           9.1406e-01,  3.5938e-01],
         [-6.9922e-01,  2.0508e-01,  7.5781e-01,  ..., -2.6875e+00,
           4.7070e-01, -4.1504e-02],
         ...,
         [-1.1797e+00,  2.8906e-01, -1.2031e+00,  ..., -3.5625e+00,
          -1.6797e-01,  1.8281e+00],
         [-5.8203e-01,  8.5938e-01,  9.7656e-02,  ..., -3.3750e+00,
          -2.4375e+00,  3.0781e+00],
         [-1.3672e+00,  1.1094e+00,  8.6328e-01,  ..., -4.0938e+00,
           2.7188e+00,  1.2891e+00]],

        [[ 7.9956e-03,  4.7302e-03,  9.1553e-05,  ...,  4.0234e-01,
          -7.4707e-02, -5.3516e-01],
         [ 1.8750e-01, -2.0312e+00,  1.7109e+00,  ..., -2.7344e+00,
           9.1406e-01,  3.5938e-01],
         [-6.9922e-01,  2.0508e-01,  7.5781e-01,  ..., -2.6875e+00,
           4.7070e-01, -4.1504e-02],
         ...,
         [-1.1797e+00,  2.8906e-01, -1.2031e+00,  ..., -3.5625e+00,
          -1.6797e-01,  1.8281e+00],
         [-5.8203e-01,  8.5938e-01,  9.7656e-02,  ..., -3.3750e+00,
          -2.4375e+00,  3.0781e+00],
         [-1.3672e+00,  1.1094e+00,  8.6328e-01,  ..., -4.0938e+00,
           2.7188e+00,  1.2891e+00]],

        [[ 7.9956e-03,  4.7302e-03,  9.1553e-05,  ...,  4.0234e-01,
          -7.4707e-02, -5.3516e-01],
         [ 1.8750e-01, -2.0312e+00,  1.7109e+00,  ..., -2.7344e+00,
           9.1406e-01,  3.5938e-01],
         [-6.9922e-01,  2.0508e-01,  7.5781e-01,  ..., -2.6875e+00,
           4.7070e-01, -4.1504e-02],
         ...,
         [-1.1797e+00,  2.8906e-01, -1.2031e+00,  ..., -3.5625e+00,
          -1.6797e-01,  1.8281e+00],
         [-5.8203e-01,  8.5938e-01,  9.7656e-02,  ..., -3.3750e+00,
          -2.4375e+00,  3.0781e+00],
         [-1.3672e+00,  1.1094e+00,  8.6328e-01,  ..., -4.0938e+00,
           2.7188e+00,  1.2891e+00]],

        ...,

        [[ 3.9368e-03,  7.5378e-03, -1.5137e-02,  ...,  9.0820e-02,
          -4.9805e-01, -1.0469e+00],
         [ 1.6719e+00, -2.3633e-01,  4.3359e-01,  ..., -6.9531e-01,
           4.8438e+00,  3.9531e+00],
         [ 2.9297e-01,  6.9922e-01,  4.2383e-01,  ..., -1.3359e+00,
           3.0156e+00,  5.8125e+00],
         ...,
         [ 4.7461e-01,  1.2344e+00, -1.4062e+00,  ..., -1.7212e-02,
          -5.6641e-01,  3.5625e+00],
         [-3.2812e-01,  3.0078e-01,  1.4062e-01,  ...,  1.2969e+00,
          -1.8359e-01,  4.1562e+00],
         [-1.3984e+00,  7.1875e-01,  9.4531e-01,  ..., -3.4375e+00,
           2.9844e+00,  5.0938e+00]],

        [[ 3.9368e-03,  7.5378e-03, -1.5137e-02,  ...,  9.0820e-02,
          -4.9805e-01, -1.0469e+00],
         [ 1.6719e+00, -2.3633e-01,  4.3359e-01,  ..., -6.9531e-01,
           4.8438e+00,  3.9531e+00],
         [ 2.9297e-01,  6.9922e-01,  4.2383e-01,  ..., -1.3359e+00,
           3.0156e+00,  5.8125e+00],
         ...,
         [ 4.7461e-01,  1.2344e+00, -1.4062e+00,  ..., -1.7212e-02,
          -5.6641e-01,  3.5625e+00],
         [-3.2812e-01,  3.0078e-01,  1.4062e-01,  ...,  1.2969e+00,
          -1.8359e-01,  4.1562e+00],
         [-1.3984e+00,  7.1875e-01,  9.4531e-01,  ..., -3.4375e+00,
           2.9844e+00,  5.0938e+00]],

        [[ 3.9368e-03,  7.5378e-03, -1.5137e-02,  ...,  9.0820e-02,
          -4.9805e-01, -1.0469e+00],
         [ 1.6719e+00, -2.3633e-01,  4.3359e-01,  ..., -6.9531e-01,
           4.8438e+00,  3.9531e+00],
         [ 2.9297e-01,  6.9922e-01,  4.2383e-01,  ..., -1.3359e+00,
           3.0156e+00,  5.8125e+00],
         ...,
         [ 4.7461e-01,  1.2344e+00, -1.4062e+00,  ..., -1.7212e-02,
          -5.6641e-01,  3.5625e+00],
         [-3.2812e-01,  3.0078e-01,  1.4062e-01,  ...,  1.2969e+00,
          -1.8359e-01,  4.1562e+00],
         [-1.3984e+00,  7.1875e-01,  9.4531e-01,  ..., -3.4375e+00,
           2.9844e+00,  5.0938e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-7.3547e-03, -4.1504e-03, -1.7334e-02,  ...,  1.0132e-02,
          -1.4725e-03,  4.8828e-03],
         [-9.4531e-01,  1.7188e-01,  2.5781e+00,  ..., -4.9219e-01,
           1.5234e+00,  1.3594e+00],
         [-2.6562e-01, -1.1484e+00,  1.6406e+00,  ..., -1.5078e+00,
           9.1016e-01,  1.6406e+00],
         ...,
         [-6.5625e-01,  3.2959e-02, -3.8281e-01,  ...,  1.9141e-01,
          -4.0039e-01, -2.3804e-02],
         [-1.8262e-01, -4.1016e-01, -5.8984e-01,  ...,  1.1572e-01,
          -3.8867e-01,  8.2812e-01],
         [-1.5938e+00, -4.6289e-01,  6.7188e-01,  ..., -7.8125e-02,
           5.4932e-02, -9.1797e-02]],

        [[-7.3547e-03, -4.1504e-03, -1.7334e-02,  ...,  1.0132e-02,
          -1.4725e-03,  4.8828e-03],
         [-9.4531e-01,  1.7188e-01,  2.5781e+00,  ..., -4.9219e-01,
           1.5234e+00,  1.3594e+00],
         [-2.6562e-01, -1.1484e+00,  1.6406e+00,  ..., -1.5078e+00,
           9.1016e-01,  1.6406e+00],
         ...,
         [-6.5625e-01,  3.2959e-02, -3.8281e-01,  ...,  1.9141e-01,
          -4.0039e-01, -2.3804e-02],
         [-1.8262e-01, -4.1016e-01, -5.8984e-01,  ...,  1.1572e-01,
          -3.8867e-01,  8.2812e-01],
         [-1.5938e+00, -4.6289e-01,  6.7188e-01,  ..., -7.8125e-02,
           5.4932e-02, -9.1797e-02]],

        [[-7.3547e-03, -4.1504e-03, -1.7334e-02,  ...,  1.0132e-02,
          -1.4725e-03,  4.8828e-03],
         [-9.4531e-01,  1.7188e-01,  2.5781e+00,  ..., -4.9219e-01,
           1.5234e+00,  1.3594e+00],
         [-2.6562e-01, -1.1484e+00,  1.6406e+00,  ..., -1.5078e+00,
           9.1016e-01,  1.6406e+00],
         ...,
         [-6.5625e-01,  3.2959e-02, -3.8281e-01,  ...,  1.9141e-01,
          -4.0039e-01, -2.3804e-02],
         [-1.8262e-01, -4.1016e-01, -5.8984e-01,  ...,  1.1572e-01,
          -3.8867e-01,  8.2812e-01],
         [-1.5938e+00, -4.6289e-01,  6.7188e-01,  ..., -7.8125e-02,
           5.4932e-02, -9.1797e-02]],

        ...,

        [[-7.8735e-03,  1.6357e-02,  8.4686e-04,  ..., -9.6436e-03,
          -1.5747e-02,  1.2512e-02],
         [-1.0840e-01, -1.2031e+00, -2.6953e-01,  ..., -1.9653e-02,
          -2.6953e-01,  1.5991e-02],
         [ 5.1172e-01,  1.1328e+00, -1.2402e-01,  ...,  1.3281e-01,
           4.2773e-01, -2.9297e-01],
         ...,
         [-8.1641e-01, -6.1328e-01, -5.0781e-02,  ...,  6.1719e-01,
          -1.1250e+00, -1.1328e-01],
         [ 3.2471e-02, -1.7383e-01,  3.6328e-01,  ..., -6.7188e-01,
           6.5625e-01, -1.2793e-01],
         [-1.1875e+00,  2.0625e+00,  1.7188e+00,  ..., -5.2344e-01,
           9.9219e-01, -6.4453e-01]],

        [[-7.8735e-03,  1.6357e-02,  8.4686e-04,  ..., -9.6436e-03,
          -1.5747e-02,  1.2512e-02],
         [-1.0840e-01, -1.2031e+00, -2.6953e-01,  ..., -1.9653e-02,
          -2.6953e-01,  1.5991e-02],
         [ 5.1172e-01,  1.1328e+00, -1.2402e-01,  ...,  1.3281e-01,
           4.2773e-01, -2.9297e-01],
         ...,
         [-8.1641e-01, -6.1328e-01, -5.0781e-02,  ...,  6.1719e-01,
          -1.1250e+00, -1.1328e-01],
         [ 3.2471e-02, -1.7383e-01,  3.6328e-01,  ..., -6.7188e-01,
           6.5625e-01, -1.2793e-01],
         [-1.1875e+00,  2.0625e+00,  1.7188e+00,  ..., -5.2344e-01,
           9.9219e-01, -6.4453e-01]],

        [[-7.8735e-03,  1.6357e-02,  8.4686e-04,  ..., -9.6436e-03,
          -1.5747e-02,  1.2512e-02],
         [-1.0840e-01, -1.2031e+00, -2.6953e-01,  ..., -1.9653e-02,
          -2.6953e-01,  1.5991e-02],
         [ 5.1172e-01,  1.1328e+00, -1.2402e-01,  ...,  1.3281e-01,
           4.2773e-01, -2.9297e-01],
         ...,
         [-8.1641e-01, -6.1328e-01, -5.0781e-02,  ...,  6.1719e-01,
          -1.1250e+00, -1.1328e-01],
         [ 3.2471e-02, -1.7383e-01,  3.6328e-01,  ..., -6.7188e-01,
           6.5625e-01, -1.2793e-01],
         [-1.1875e+00,  2.0625e+00,  1.7188e+00,  ..., -5.2344e-01,
           9.9219e-01, -6.4453e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 3.4332e-04,  1.2634e-02, -1.8433e-02,  ...,  3.8330e-02,
           1.4941e-01, -1.8828e+00],
         [ 9.5312e-01, -1.9219e+00, -1.5820e-01,  ..., -1.0391e+00,
          -6.2891e-01,  7.4688e+00],
         [-3.3594e-01, -6.1719e-01, -3.4180e-01,  ...,  6.3281e-01,
           2.0156e+00,  8.3750e+00],
         ...,
         [-4.4922e-01,  3.6914e-01, -8.1641e-01,  ..., -2.8594e+00,
          -3.4688e+00,  5.1562e+00],
         [ 2.5391e-01, -2.1484e-01,  1.2988e-01,  ..., -2.7500e+00,
          -1.4766e+00,  5.7500e+00],
         [-1.0234e+00,  1.0625e+00, -1.3750e+00,  ..., -8.0859e-01,
           2.2188e+00,  7.1875e+00]],

        [[ 3.4332e-04,  1.2634e-02, -1.8433e-02,  ...,  3.8330e-02,
           1.4941e-01, -1.8828e+00],
         [ 9.5312e-01, -1.9219e+00, -1.5820e-01,  ..., -1.0391e+00,
          -6.2891e-01,  7.4688e+00],
         [-3.3594e-01, -6.1719e-01, -3.4180e-01,  ...,  6.3281e-01,
           2.0156e+00,  8.3750e+00],
         ...,
         [-4.4922e-01,  3.6914e-01, -8.1641e-01,  ..., -2.8594e+00,
          -3.4688e+00,  5.1562e+00],
         [ 2.5391e-01, -2.1484e-01,  1.2988e-01,  ..., -2.7500e+00,
          -1.4766e+00,  5.7500e+00],
         [-1.0234e+00,  1.0625e+00, -1.3750e+00,  ..., -8.0859e-01,
           2.2188e+00,  7.1875e+00]],

        [[ 3.4332e-04,  1.2634e-02, -1.8433e-02,  ...,  3.8330e-02,
           1.4941e-01, -1.8828e+00],
         [ 9.5312e-01, -1.9219e+00, -1.5820e-01,  ..., -1.0391e+00,
          -6.2891e-01,  7.4688e+00],
         [-3.3594e-01, -6.1719e-01, -3.4180e-01,  ...,  6.3281e-01,
           2.0156e+00,  8.3750e+00],
         ...,
         [-4.4922e-01,  3.6914e-01, -8.1641e-01,  ..., -2.8594e+00,
          -3.4688e+00,  5.1562e+00],
         [ 2.5391e-01, -2.1484e-01,  1.2988e-01,  ..., -2.7500e+00,
          -1.4766e+00,  5.7500e+00],
         [-1.0234e+00,  1.0625e+00, -1.3750e+00,  ..., -8.0859e-01,
           2.2188e+00,  7.1875e+00]],

        ...,

        [[ 4.5776e-03, -9.3994e-03, -8.3618e-03,  ..., -1.5156e+00,
          -4.8828e-01,  8.8379e-02],
         [-2.8438e+00,  2.7188e+00,  1.3594e+00,  ...,  8.3750e+00,
          -1.0234e+00, -5.2344e-01],
         [-7.5000e-01,  4.7266e-01,  3.7305e-01,  ...,  9.1875e+00,
          -2.5000e+00, -2.6406e+00],
         ...,
         [-1.3672e+00, -8.7500e-01, -1.0391e+00,  ...,  7.7188e+00,
          -2.6367e-01,  9.9609e-01],
         [-1.4688e+00, -8.3203e-01, -6.3281e-01,  ...,  6.9062e+00,
           1.3672e+00,  5.8203e-01],
         [-3.4570e-01, -1.9844e+00, -1.2109e+00,  ...,  7.4062e+00,
          -4.2480e-02, -2.3750e+00]],

        [[ 4.5776e-03, -9.3994e-03, -8.3618e-03,  ..., -1.5156e+00,
          -4.8828e-01,  8.8379e-02],
         [-2.8438e+00,  2.7188e+00,  1.3594e+00,  ...,  8.3750e+00,
          -1.0234e+00, -5.2344e-01],
         [-7.5000e-01,  4.7266e-01,  3.7305e-01,  ...,  9.1875e+00,
          -2.5000e+00, -2.6406e+00],
         ...,
         [-1.3672e+00, -8.7500e-01, -1.0391e+00,  ...,  7.7188e+00,
          -2.6367e-01,  9.9609e-01],
         [-1.4688e+00, -8.3203e-01, -6.3281e-01,  ...,  6.9062e+00,
           1.3672e+00,  5.8203e-01],
         [-3.4570e-01, -1.9844e+00, -1.2109e+00,  ...,  7.4062e+00,
          -4.2480e-02, -2.3750e+00]],

        [[ 4.5776e-03, -9.3994e-03, -8.3618e-03,  ..., -1.5156e+00,
          -4.8828e-01,  8.8379e-02],
         [-2.8438e+00,  2.7188e+00,  1.3594e+00,  ...,  8.3750e+00,
          -1.0234e+00, -5.2344e-01],
         [-7.5000e-01,  4.7266e-01,  3.7305e-01,  ...,  9.1875e+00,
          -2.5000e+00, -2.6406e+00],
         ...,
         [-1.3672e+00, -8.7500e-01, -1.0391e+00,  ...,  7.7188e+00,
          -2.6367e-01,  9.9609e-01],
         [-1.4688e+00, -8.3203e-01, -6.3281e-01,  ...,  6.9062e+00,
           1.3672e+00,  5.8203e-01],
         [-3.4570e-01, -1.9844e+00, -1.2109e+00,  ...,  7.4062e+00,
          -4.2480e-02, -2.3750e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 8.1543e-02,  4.4556e-03,  5.4199e-02,  ..., -3.0212e-03,
          -8.1787e-03, -3.4668e-02],
         [-3.1641e-01, -2.2461e-01, -2.6758e-01,  ...,  4.6289e-01,
          -2.1191e-01, -3.3789e-01],
         [-8.7402e-02,  3.1250e-01,  4.8828e-01,  ...,  7.9297e-01,
           9.8633e-02,  1.7090e-01],
         ...,
         [ 8.5156e-01,  1.0703e+00,  3.1982e-02,  ...,  1.0498e-01,
           2.8125e-01, -5.4688e-01],
         [-3.8330e-02,  3.7500e-01,  3.3594e-01,  ...,  9.4922e-01,
           7.2656e-01, -8.1641e-01],
         [ 2.1387e-01,  3.5547e-01, -1.5234e-01,  ...,  4.3750e-01,
           6.2891e-01, -4.7852e-02]],

        [[ 8.1543e-02,  4.4556e-03,  5.4199e-02,  ..., -3.0212e-03,
          -8.1787e-03, -3.4668e-02],
         [-3.1641e-01, -2.2461e-01, -2.6758e-01,  ...,  4.6289e-01,
          -2.1191e-01, -3.3789e-01],
         [-8.7402e-02,  3.1250e-01,  4.8828e-01,  ...,  7.9297e-01,
           9.8633e-02,  1.7090e-01],
         ...,
         [ 8.5156e-01,  1.0703e+00,  3.1982e-02,  ...,  1.0498e-01,
           2.8125e-01, -5.4688e-01],
         [-3.8330e-02,  3.7500e-01,  3.3594e-01,  ...,  9.4922e-01,
           7.2656e-01, -8.1641e-01],
         [ 2.1387e-01,  3.5547e-01, -1.5234e-01,  ...,  4.3750e-01,
           6.2891e-01, -4.7852e-02]],

        [[ 8.1543e-02,  4.4556e-03,  5.4199e-02,  ..., -3.0212e-03,
          -8.1787e-03, -3.4668e-02],
         [-3.1641e-01, -2.2461e-01, -2.6758e-01,  ...,  4.6289e-01,
          -2.1191e-01, -3.3789e-01],
         [-8.7402e-02,  3.1250e-01,  4.8828e-01,  ...,  7.9297e-01,
           9.8633e-02,  1.7090e-01],
         ...,
         [ 8.5156e-01,  1.0703e+00,  3.1982e-02,  ...,  1.0498e-01,
           2.8125e-01, -5.4688e-01],
         [-3.8330e-02,  3.7500e-01,  3.3594e-01,  ...,  9.4922e-01,
           7.2656e-01, -8.1641e-01],
         [ 2.1387e-01,  3.5547e-01, -1.5234e-01,  ...,  4.3750e-01,
           6.2891e-01, -4.7852e-02]],

        ...,

        [[ 5.6763e-03, -2.8534e-03,  6.3705e-04,  ...,  3.8300e-03,
          -9.3994e-03,  3.9673e-03],
         [-2.5938e+00,  3.1406e+00, -4.9023e-01,  ..., -5.6250e-01,
          -1.3281e+00,  2.0142e-02],
         [-2.4844e+00,  1.4297e+00, -1.9922e+00,  ...,  5.3516e-01,
          -4.1797e-01, -2.4609e-01],
         ...,
         [ 7.5391e-01, -5.7031e-01,  7.5391e-01,  ...,  8.8672e-01,
          -7.6172e-01,  4.5117e-01],
         [ 4.8633e-01, -8.3008e-02,  4.8438e-01,  ..., -6.4941e-02,
          -2.4316e-01,  6.4941e-02],
         [-8.1055e-02,  5.4443e-02, -5.2734e-01,  ...,  3.4180e-01,
           2.0625e+00, -9.8828e-01]],

        [[ 5.6763e-03, -2.8534e-03,  6.3705e-04,  ...,  3.8300e-03,
          -9.3994e-03,  3.9673e-03],
         [-2.5938e+00,  3.1406e+00, -4.9023e-01,  ..., -5.6250e-01,
          -1.3281e+00,  2.0142e-02],
         [-2.4844e+00,  1.4297e+00, -1.9922e+00,  ...,  5.3516e-01,
          -4.1797e-01, -2.4609e-01],
         ...,
         [ 7.5391e-01, -5.7031e-01,  7.5391e-01,  ...,  8.8672e-01,
          -7.6172e-01,  4.5117e-01],
         [ 4.8633e-01, -8.3008e-02,  4.8438e-01,  ..., -6.4941e-02,
          -2.4316e-01,  6.4941e-02],
         [-8.1055e-02,  5.4443e-02, -5.2734e-01,  ...,  3.4180e-01,
           2.0625e+00, -9.8828e-01]],

        [[ 5.6763e-03, -2.8534e-03,  6.3705e-04,  ...,  3.8300e-03,
          -9.3994e-03,  3.9673e-03],
         [-2.5938e+00,  3.1406e+00, -4.9023e-01,  ..., -5.6250e-01,
          -1.3281e+00,  2.0142e-02],
         [-2.4844e+00,  1.4297e+00, -1.9922e+00,  ...,  5.3516e-01,
          -4.1797e-01, -2.4609e-01],
         ...,
         [ 7.5391e-01, -5.7031e-01,  7.5391e-01,  ...,  8.8672e-01,
          -7.6172e-01,  4.5117e-01],
         [ 4.8633e-01, -8.3008e-02,  4.8438e-01,  ..., -6.4941e-02,
          -2.4316e-01,  6.4941e-02],
         [-8.1055e-02,  5.4443e-02, -5.2734e-01,  ...,  3.4180e-01,
           2.0625e+00, -9.8828e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-4.6921e-04,  1.2268e-02,  1.6113e-02,  ...,  2.0703e-01,
           7.5684e-02,  7.1716e-03],
         [-1.4531e+00, -7.7344e-01, -3.1055e-01,  ..., -5.7812e-01,
          -1.6968e-02,  5.7068e-03],
         [-9.0234e-01,  3.5938e-01, -5.3125e-01,  ..., -1.9922e+00,
           2.1250e+00, -2.1406e+00],
         ...,
         [-4.9219e-01, -5.6250e-01,  4.9805e-01,  ...,  1.3359e+00,
          -3.9531e+00, -1.8672e+00],
         [-1.2402e-01,  1.0156e-01,  2.5195e-01,  ...,  7.1875e-01,
          -5.1953e-01, -2.2812e+00],
         [-7.7148e-02,  2.7734e-01, -1.0938e-01,  ..., -9.9609e-01,
           5.5078e-01, -7.0312e-01]],

        [[-4.6921e-04,  1.2268e-02,  1.6113e-02,  ...,  2.0703e-01,
           7.5684e-02,  7.1716e-03],
         [-1.4531e+00, -7.7344e-01, -3.1055e-01,  ..., -5.7812e-01,
          -1.6968e-02,  5.7068e-03],
         [-9.0234e-01,  3.5938e-01, -5.3125e-01,  ..., -1.9922e+00,
           2.1250e+00, -2.1406e+00],
         ...,
         [-4.9219e-01, -5.6250e-01,  4.9805e-01,  ...,  1.3359e+00,
          -3.9531e+00, -1.8672e+00],
         [-1.2402e-01,  1.0156e-01,  2.5195e-01,  ...,  7.1875e-01,
          -5.1953e-01, -2.2812e+00],
         [-7.7148e-02,  2.7734e-01, -1.0938e-01,  ..., -9.9609e-01,
           5.5078e-01, -7.0312e-01]],

        [[-4.6921e-04,  1.2268e-02,  1.6113e-02,  ...,  2.0703e-01,
           7.5684e-02,  7.1716e-03],
         [-1.4531e+00, -7.7344e-01, -3.1055e-01,  ..., -5.7812e-01,
          -1.6968e-02,  5.7068e-03],
         [-9.0234e-01,  3.5938e-01, -5.3125e-01,  ..., -1.9922e+00,
           2.1250e+00, -2.1406e+00],
         ...,
         [-4.9219e-01, -5.6250e-01,  4.9805e-01,  ...,  1.3359e+00,
          -3.9531e+00, -1.8672e+00],
         [-1.2402e-01,  1.0156e-01,  2.5195e-01,  ...,  7.1875e-01,
          -5.1953e-01, -2.2812e+00],
         [-7.7148e-02,  2.7734e-01, -1.0938e-01,  ..., -9.9609e-01,
           5.5078e-01, -7.0312e-01]],

        ...,

        [[-1.4221e-02,  1.0681e-02,  2.6398e-03,  ...,  5.9082e-02,
          -9.4727e-02, -1.1328e-01],
         [ 9.3359e-01, -9.9219e-01,  3.9062e-01,  ..., -4.4727e-01,
          -3.0625e+00, -1.9727e-01],
         [-1.7969e-01, -1.2266e+00,  8.1641e-01,  ...,  1.0234e+00,
          -3.5938e-01,  8.9062e-01],
         ...,
         [-3.9062e-01, -1.0938e+00, -1.1562e+00,  ...,  1.4688e+00,
           4.1797e-01,  1.3516e+00],
         [-6.3281e-01, -8.7500e-01, -1.5312e+00,  ...,  6.4844e-01,
           5.7031e-01,  1.2656e+00],
         [-1.0938e+00, -4.2969e-01, -1.4062e+00,  ...,  1.5312e+00,
           9.1406e-01, -4.4062e+00]],

        [[-1.4221e-02,  1.0681e-02,  2.6398e-03,  ...,  5.9082e-02,
          -9.4727e-02, -1.1328e-01],
         [ 9.3359e-01, -9.9219e-01,  3.9062e-01,  ..., -4.4727e-01,
          -3.0625e+00, -1.9727e-01],
         [-1.7969e-01, -1.2266e+00,  8.1641e-01,  ...,  1.0234e+00,
          -3.5938e-01,  8.9062e-01],
         ...,
         [-3.9062e-01, -1.0938e+00, -1.1562e+00,  ...,  1.4688e+00,
           4.1797e-01,  1.3516e+00],
         [-6.3281e-01, -8.7500e-01, -1.5312e+00,  ...,  6.4844e-01,
           5.7031e-01,  1.2656e+00],
         [-1.0938e+00, -4.2969e-01, -1.4062e+00,  ...,  1.5312e+00,
           9.1406e-01, -4.4062e+00]],

        [[-1.4221e-02,  1.0681e-02,  2.6398e-03,  ...,  5.9082e-02,
          -9.4727e-02, -1.1328e-01],
         [ 9.3359e-01, -9.9219e-01,  3.9062e-01,  ..., -4.4727e-01,
          -3.0625e+00, -1.9727e-01],
         [-1.7969e-01, -1.2266e+00,  8.1641e-01,  ...,  1.0234e+00,
          -3.5938e-01,  8.9062e-01],
         ...,
         [-3.9062e-01, -1.0938e+00, -1.1562e+00,  ...,  1.4688e+00,
           4.1797e-01,  1.3516e+00],
         [-6.3281e-01, -8.7500e-01, -1.5312e+00,  ...,  6.4844e-01,
           5.7031e-01,  1.2656e+00],
         [-1.0938e+00, -4.2969e-01, -1.4062e+00,  ...,  1.5312e+00,
           9.1406e-01, -4.4062e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 9.8877e-03,  2.3193e-03,  7.3547e-03,  ..., -1.6357e-02,
           8.9722e-03, -5.0781e-02],
         [ 2.2461e-01, -3.2422e-01, -6.4062e-01,  ...,  3.5938e-01,
          -6.1328e-01, -2.9053e-02],
         [ 1.3086e-01, -2.3633e-01,  6.6406e-01,  ...,  6.1719e-01,
          -1.0625e+00,  4.8438e-01],
         ...,
         [ 1.2188e+00,  1.3281e-01, -2.0020e-01,  ...,  4.7656e-01,
          -6.7188e-01,  1.3770e-01],
         [ 1.0000e+00,  4.0625e-01, -5.2344e-01,  ...,  2.0605e-01,
          -5.0000e-01,  1.3203e+00],
         [ 3.6133e-01,  5.6250e-01, -5.2344e-01,  ..., -1.4062e-01,
          -6.6797e-01,  1.4844e+00]],

        [[ 9.8877e-03,  2.3193e-03,  7.3547e-03,  ..., -1.6357e-02,
           8.9722e-03, -5.0781e-02],
         [ 2.2461e-01, -3.2422e-01, -6.4062e-01,  ...,  3.5938e-01,
          -6.1328e-01, -2.9053e-02],
         [ 1.3086e-01, -2.3633e-01,  6.6406e-01,  ...,  6.1719e-01,
          -1.0625e+00,  4.8438e-01],
         ...,
         [ 1.2188e+00,  1.3281e-01, -2.0020e-01,  ...,  4.7656e-01,
          -6.7188e-01,  1.3770e-01],
         [ 1.0000e+00,  4.0625e-01, -5.2344e-01,  ...,  2.0605e-01,
          -5.0000e-01,  1.3203e+00],
         [ 3.6133e-01,  5.6250e-01, -5.2344e-01,  ..., -1.4062e-01,
          -6.6797e-01,  1.4844e+00]],

        [[ 9.8877e-03,  2.3193e-03,  7.3547e-03,  ..., -1.6357e-02,
           8.9722e-03, -5.0781e-02],
         [ 2.2461e-01, -3.2422e-01, -6.4062e-01,  ...,  3.5938e-01,
          -6.1328e-01, -2.9053e-02],
         [ 1.3086e-01, -2.3633e-01,  6.6406e-01,  ...,  6.1719e-01,
          -1.0625e+00,  4.8438e-01],
         ...,
         [ 1.2188e+00,  1.3281e-01, -2.0020e-01,  ...,  4.7656e-01,
          -6.7188e-01,  1.3770e-01],
         [ 1.0000e+00,  4.0625e-01, -5.2344e-01,  ...,  2.0605e-01,
          -5.0000e-01,  1.3203e+00],
         [ 3.6133e-01,  5.6250e-01, -5.2344e-01,  ..., -1.4062e-01,
          -6.6797e-01,  1.4844e+00]],

        ...,

        [[ 3.3569e-03, -6.7139e-03, -8.9264e-04,  ...,  1.3046e-03,
           2.6703e-03,  5.0659e-03],
         [-2.7812e+00, -3.8867e-01, -9.0234e-01,  ..., -2.7969e+00,
          -6.2109e-01, -2.7500e+00],
         [-1.2578e+00,  3.4961e-01, -1.9531e-01,  ..., -3.0078e-01,
           8.3984e-01, -1.6484e+00],
         ...,
         [ 1.6797e-01, -1.6602e-01, -1.6699e-01,  ...,  2.8516e-01,
           2.9883e-01, -4.2578e-01],
         [-1.9043e-02,  1.1426e-01, -1.1035e-01,  ..., -8.4961e-02,
           1.2109e-01, -6.0938e-01],
         [-1.0000e+00, -1.6406e-01,  1.4297e+00,  ...,  8.9844e-01,
           1.5703e+00,  4.9805e-01]],

        [[ 3.3569e-03, -6.7139e-03, -8.9264e-04,  ...,  1.3046e-03,
           2.6703e-03,  5.0659e-03],
         [-2.7812e+00, -3.8867e-01, -9.0234e-01,  ..., -2.7969e+00,
          -6.2109e-01, -2.7500e+00],
         [-1.2578e+00,  3.4961e-01, -1.9531e-01,  ..., -3.0078e-01,
           8.3984e-01, -1.6484e+00],
         ...,
         [ 1.6797e-01, -1.6602e-01, -1.6699e-01,  ...,  2.8516e-01,
           2.9883e-01, -4.2578e-01],
         [-1.9043e-02,  1.1426e-01, -1.1035e-01,  ..., -8.4961e-02,
           1.2109e-01, -6.0938e-01],
         [-1.0000e+00, -1.6406e-01,  1.4297e+00,  ...,  8.9844e-01,
           1.5703e+00,  4.9805e-01]],

        [[ 3.3569e-03, -6.7139e-03, -8.9264e-04,  ...,  1.3046e-03,
           2.6703e-03,  5.0659e-03],
         [-2.7812e+00, -3.8867e-01, -9.0234e-01,  ..., -2.7969e+00,
          -6.2109e-01, -2.7500e+00],
         [-1.2578e+00,  3.4961e-01, -1.9531e-01,  ..., -3.0078e-01,
           8.3984e-01, -1.6484e+00],
         ...,
         [ 1.6797e-01, -1.6602e-01, -1.6699e-01,  ...,  2.8516e-01,
           2.9883e-01, -4.2578e-01],
         [-1.9043e-02,  1.1426e-01, -1.1035e-01,  ..., -8.4961e-02,
           1.2109e-01, -6.0938e-01],
         [-1.0000e+00, -1.6406e-01,  1.4297e+00,  ...,  8.9844e-01,
           1.5703e+00,  4.9805e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-1.6724e-02, -4.3030e-03,  4.7607e-03,  ...,  1.0352e-01,
           8.0566e-02,  3.9062e-01],
         [-4.0625e-01, -7.0312e-01, -8.1250e-01,  ..., -2.0781e+00,
          -1.4922e+00, -6.7812e+00],
         [-1.4688e+00, -1.0312e+00, -9.8828e-01,  ..., -9.4531e-01,
          -2.5469e+00, -6.0625e+00],
         ...,
         [-5.3906e-01, -1.0938e+00,  1.0781e+00,  ..., -1.3984e+00,
          -2.7500e+00, -3.6562e+00],
         [-3.7891e-01, -4.7266e-01,  1.1953e+00,  ..., -1.5938e+00,
          -2.7344e+00, -5.3438e+00],
         [-2.6953e-01,  4.3359e-01,  4.2773e-01,  ..., -1.1172e+00,
          -3.2500e+00, -5.4062e+00]],

        [[-1.6724e-02, -4.3030e-03,  4.7607e-03,  ...,  1.0352e-01,
           8.0566e-02,  3.9062e-01],
         [-4.0625e-01, -7.0312e-01, -8.1250e-01,  ..., -2.0781e+00,
          -1.4922e+00, -6.7812e+00],
         [-1.4688e+00, -1.0312e+00, -9.8828e-01,  ..., -9.4531e-01,
          -2.5469e+00, -6.0625e+00],
         ...,
         [-5.3906e-01, -1.0938e+00,  1.0781e+00,  ..., -1.3984e+00,
          -2.7500e+00, -3.6562e+00],
         [-3.7891e-01, -4.7266e-01,  1.1953e+00,  ..., -1.5938e+00,
          -2.7344e+00, -5.3438e+00],
         [-2.6953e-01,  4.3359e-01,  4.2773e-01,  ..., -1.1172e+00,
          -3.2500e+00, -5.4062e+00]],

        [[-1.6724e-02, -4.3030e-03,  4.7607e-03,  ...,  1.0352e-01,
           8.0566e-02,  3.9062e-01],
         [-4.0625e-01, -7.0312e-01, -8.1250e-01,  ..., -2.0781e+00,
          -1.4922e+00, -6.7812e+00],
         [-1.4688e+00, -1.0312e+00, -9.8828e-01,  ..., -9.4531e-01,
          -2.5469e+00, -6.0625e+00],
         ...,
         [-5.3906e-01, -1.0938e+00,  1.0781e+00,  ..., -1.3984e+00,
          -2.7500e+00, -3.6562e+00],
         [-3.7891e-01, -4.7266e-01,  1.1953e+00,  ..., -1.5938e+00,
          -2.7344e+00, -5.3438e+00],
         [-2.6953e-01,  4.3359e-01,  4.2773e-01,  ..., -1.1172e+00,
          -3.2500e+00, -5.4062e+00]],

        ...,

        [[-5.0354e-04,  4.9744e-03,  1.7471e-03,  ...,  4.1504e-02,
          -3.3203e-02,  4.9316e-02],
         [ 9.1016e-01, -8.8672e-01, -8.7891e-01,  ..., -1.8828e+00,
           1.9297e+00,  2.3047e-01],
         [-1.1406e+00,  1.6895e-01, -6.6797e-01,  ..., -1.1094e+00,
           4.7852e-01,  2.0312e+00],
         ...,
         [ 4.3750e-01, -1.3281e+00,  1.2422e+00,  ...,  8.4375e-01,
           1.1797e+00,  9.6191e-02],
         [-2.6367e-02, -8.2031e-01,  0.0000e+00,  ..., -6.2891e-01,
           9.3750e-01, -8.6328e-01],
         [-2.2500e+00,  7.5391e-01, -7.5781e-01,  ..., -2.4902e-02,
           1.8750e+00,  2.4062e+00]],

        [[-5.0354e-04,  4.9744e-03,  1.7471e-03,  ...,  4.1504e-02,
          -3.3203e-02,  4.9316e-02],
         [ 9.1016e-01, -8.8672e-01, -8.7891e-01,  ..., -1.8828e+00,
           1.9297e+00,  2.3047e-01],
         [-1.1406e+00,  1.6895e-01, -6.6797e-01,  ..., -1.1094e+00,
           4.7852e-01,  2.0312e+00],
         ...,
         [ 4.3750e-01, -1.3281e+00,  1.2422e+00,  ...,  8.4375e-01,
           1.1797e+00,  9.6191e-02],
         [-2.6367e-02, -8.2031e-01,  0.0000e+00,  ..., -6.2891e-01,
           9.3750e-01, -8.6328e-01],
         [-2.2500e+00,  7.5391e-01, -7.5781e-01,  ..., -2.4902e-02,
           1.8750e+00,  2.4062e+00]],

        [[-5.0354e-04,  4.9744e-03,  1.7471e-03,  ...,  4.1504e-02,
          -3.3203e-02,  4.9316e-02],
         [ 9.1016e-01, -8.8672e-01, -8.7891e-01,  ..., -1.8828e+00,
           1.9297e+00,  2.3047e-01],
         [-1.1406e+00,  1.6895e-01, -6.6797e-01,  ..., -1.1094e+00,
           4.7852e-01,  2.0312e+00],
         ...,
         [ 4.3750e-01, -1.3281e+00,  1.2422e+00,  ...,  8.4375e-01,
           1.1797e+00,  9.6191e-02],
         [-2.6367e-02, -8.2031e-01,  0.0000e+00,  ..., -6.2891e-01,
           9.3750e-01, -8.6328e-01],
         [-2.2500e+00,  7.5391e-01, -7.5781e-01,  ..., -2.4902e-02,
           1.8750e+00,  2.4062e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 1.1902e-03, -6.4392e-03,  8.5449e-03,  ...,  5.4626e-03,
           7.1716e-03,  7.1106e-03],
         [ 5.1562e-01,  1.0625e+00, -7.3438e-01,  ...,  1.2266e+00,
          -1.7422e+00,  6.5625e-01],
         [ 3.5156e-01, -6.4844e-01, -9.4531e-01,  ...,  4.4727e-01,
          -1.4531e+00,  6.0938e-01],
         ...,
         [ 7.3242e-02,  1.6484e+00, -2.2363e-01,  ..., -1.1094e+00,
          -1.3906e+00,  6.1719e-01],
         [ 3.8867e-01,  3.4570e-01,  1.3477e-01,  ...,  2.0215e-01,
          -1.8164e-01,  4.9414e-01],
         [ 1.0156e+00,  4.2578e-01, -1.3438e+00,  ..., -5.2490e-02,
          -1.3281e-01,  7.1777e-02]],

        [[ 1.1902e-03, -6.4392e-03,  8.5449e-03,  ...,  5.4626e-03,
           7.1716e-03,  7.1106e-03],
         [ 5.1562e-01,  1.0625e+00, -7.3438e-01,  ...,  1.2266e+00,
          -1.7422e+00,  6.5625e-01],
         [ 3.5156e-01, -6.4844e-01, -9.4531e-01,  ...,  4.4727e-01,
          -1.4531e+00,  6.0938e-01],
         ...,
         [ 7.3242e-02,  1.6484e+00, -2.2363e-01,  ..., -1.1094e+00,
          -1.3906e+00,  6.1719e-01],
         [ 3.8867e-01,  3.4570e-01,  1.3477e-01,  ...,  2.0215e-01,
          -1.8164e-01,  4.9414e-01],
         [ 1.0156e+00,  4.2578e-01, -1.3438e+00,  ..., -5.2490e-02,
          -1.3281e-01,  7.1777e-02]],

        [[ 1.1902e-03, -6.4392e-03,  8.5449e-03,  ...,  5.4626e-03,
           7.1716e-03,  7.1106e-03],
         [ 5.1562e-01,  1.0625e+00, -7.3438e-01,  ...,  1.2266e+00,
          -1.7422e+00,  6.5625e-01],
         [ 3.5156e-01, -6.4844e-01, -9.4531e-01,  ...,  4.4727e-01,
          -1.4531e+00,  6.0938e-01],
         ...,
         [ 7.3242e-02,  1.6484e+00, -2.2363e-01,  ..., -1.1094e+00,
          -1.3906e+00,  6.1719e-01],
         [ 3.8867e-01,  3.4570e-01,  1.3477e-01,  ...,  2.0215e-01,
          -1.8164e-01,  4.9414e-01],
         [ 1.0156e+00,  4.2578e-01, -1.3438e+00,  ..., -5.2490e-02,
          -1.3281e-01,  7.1777e-02]],

        ...,

        [[-5.5847e-03, -3.2501e-03,  4.8523e-03,  ...,  1.0010e-02,
          -5.5542e-03,  7.4768e-03],
         [ 5.9375e-01, -1.4160e-01, -1.7383e-01,  ...,  2.6758e-01,
          -4.1992e-01,  2.2949e-01],
         [-2.5781e-01,  1.6562e+00,  1.6172e+00,  ...,  1.5781e+00,
           7.8516e-01,  3.4375e-01],
         ...,
         [-1.2812e+00,  8.8672e-01, -1.3281e+00,  ..., -8.7109e-01,
           1.0078e+00, -7.8125e-01],
         [-3.3594e-01,  2.1191e-01, -7.7148e-02,  ...,  1.1768e-01,
           8.4766e-01,  3.8477e-01],
         [-4.4336e-01, -1.6484e+00,  3.8477e-01,  ..., -1.3203e+00,
           3.1445e-01, -3.1641e-01]],

        [[-5.5847e-03, -3.2501e-03,  4.8523e-03,  ...,  1.0010e-02,
          -5.5542e-03,  7.4768e-03],
         [ 5.9375e-01, -1.4160e-01, -1.7383e-01,  ...,  2.6758e-01,
          -4.1992e-01,  2.2949e-01],
         [-2.5781e-01,  1.6562e+00,  1.6172e+00,  ...,  1.5781e+00,
           7.8516e-01,  3.4375e-01],
         ...,
         [-1.2812e+00,  8.8672e-01, -1.3281e+00,  ..., -8.7109e-01,
           1.0078e+00, -7.8125e-01],
         [-3.3594e-01,  2.1191e-01, -7.7148e-02,  ...,  1.1768e-01,
           8.4766e-01,  3.8477e-01],
         [-4.4336e-01, -1.6484e+00,  3.8477e-01,  ..., -1.3203e+00,
           3.1445e-01, -3.1641e-01]],

        [[-5.5847e-03, -3.2501e-03,  4.8523e-03,  ...,  1.0010e-02,
          -5.5542e-03,  7.4768e-03],
         [ 5.9375e-01, -1.4160e-01, -1.7383e-01,  ...,  2.6758e-01,
          -4.1992e-01,  2.2949e-01],
         [-2.5781e-01,  1.6562e+00,  1.6172e+00,  ...,  1.5781e+00,
           7.8516e-01,  3.4375e-01],
         ...,
         [-1.2812e+00,  8.8672e-01, -1.3281e+00,  ..., -8.7109e-01,
           1.0078e+00, -7.8125e-01],
         [-3.3594e-01,  2.1191e-01, -7.7148e-02,  ...,  1.1768e-01,
           8.4766e-01,  3.8477e-01],
         [-4.4336e-01, -1.6484e+00,  3.8477e-01,  ..., -1.3203e+00,
           3.1445e-01, -3.1641e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 2.4414e-03,  9.9487e-03,  9.2773e-03,  ..., -2.6367e-01,
           4.9316e-02,  1.9043e-02],
         [-8.3594e-01, -7.3828e-01,  3.5352e-01,  ...,  4.2812e+00,
          -3.0469e-01, -2.1875e+00],
         [ 6.0156e-01,  1.6406e-01,  8.0078e-02,  ...,  6.3125e+00,
          -2.9883e-01,  1.1658e-02],
         ...,
         [ 2.5391e-01, -1.4453e-01, -2.3535e-01,  ...,  2.1562e+00,
           1.7285e-01,  7.7734e-01],
         [ 1.3125e+00,  3.8281e-01, -6.7969e-01,  ...,  3.4062e+00,
           2.0020e-01,  8.7891e-03],
         [ 9.6094e-01, -1.1328e+00, -2.9102e-01,  ...,  3.6875e+00,
          -1.0078e+00, -1.9043e-01]],

        [[ 2.4414e-03,  9.9487e-03,  9.2773e-03,  ..., -2.6367e-01,
           4.9316e-02,  1.9043e-02],
         [-8.3594e-01, -7.3828e-01,  3.5352e-01,  ...,  4.2812e+00,
          -3.0469e-01, -2.1875e+00],
         [ 6.0156e-01,  1.6406e-01,  8.0078e-02,  ...,  6.3125e+00,
          -2.9883e-01,  1.1658e-02],
         ...,
         [ 2.5391e-01, -1.4453e-01, -2.3535e-01,  ...,  2.1562e+00,
           1.7285e-01,  7.7734e-01],
         [ 1.3125e+00,  3.8281e-01, -6.7969e-01,  ...,  3.4062e+00,
           2.0020e-01,  8.7891e-03],
         [ 9.6094e-01, -1.1328e+00, -2.9102e-01,  ...,  3.6875e+00,
          -1.0078e+00, -1.9043e-01]],

        [[ 2.4414e-03,  9.9487e-03,  9.2773e-03,  ..., -2.6367e-01,
           4.9316e-02,  1.9043e-02],
         [-8.3594e-01, -7.3828e-01,  3.5352e-01,  ...,  4.2812e+00,
          -3.0469e-01, -2.1875e+00],
         [ 6.0156e-01,  1.6406e-01,  8.0078e-02,  ...,  6.3125e+00,
          -2.9883e-01,  1.1658e-02],
         ...,
         [ 2.5391e-01, -1.4453e-01, -2.3535e-01,  ...,  2.1562e+00,
           1.7285e-01,  7.7734e-01],
         [ 1.3125e+00,  3.8281e-01, -6.7969e-01,  ...,  3.4062e+00,
           2.0020e-01,  8.7891e-03],
         [ 9.6094e-01, -1.1328e+00, -2.9102e-01,  ...,  3.6875e+00,
          -1.0078e+00, -1.9043e-01]],

        ...,

        [[-6.4850e-04,  4.4250e-03, -1.9684e-03,  ...,  9.6680e-02,
          -5.8984e-01, -5.2246e-02],
         [-2.4609e-01, -6.4844e-01, -1.1133e-01,  ...,  2.4062e+00,
           8.1250e+00,  3.2188e+00],
         [-8.2031e-02,  6.2500e-01, -2.9492e-01,  ..., -3.3750e+00,
           1.0375e+01,  4.1406e-01],
         ...,
         [ 3.3398e-01, -3.7305e-01,  6.9531e-01,  ...,  3.3789e-01,
           8.4375e+00,  5.9375e+00],
         [-5.8594e-02,  1.0254e-01,  2.1094e-01,  ..., -1.3281e+00,
           7.4375e+00,  1.9531e+00],
         [-3.3789e-01,  1.7578e-02, -2.4170e-02,  ...,  9.6484e-01,
           9.0625e+00, -5.2812e+00]],

        [[-6.4850e-04,  4.4250e-03, -1.9684e-03,  ...,  9.6680e-02,
          -5.8984e-01, -5.2246e-02],
         [-2.4609e-01, -6.4844e-01, -1.1133e-01,  ...,  2.4062e+00,
           8.1250e+00,  3.2188e+00],
         [-8.2031e-02,  6.2500e-01, -2.9492e-01,  ..., -3.3750e+00,
           1.0375e+01,  4.1406e-01],
         ...,
         [ 3.3398e-01, -3.7305e-01,  6.9531e-01,  ...,  3.3789e-01,
           8.4375e+00,  5.9375e+00],
         [-5.8594e-02,  1.0254e-01,  2.1094e-01,  ..., -1.3281e+00,
           7.4375e+00,  1.9531e+00],
         [-3.3789e-01,  1.7578e-02, -2.4170e-02,  ...,  9.6484e-01,
           9.0625e+00, -5.2812e+00]],

        [[-6.4850e-04,  4.4250e-03, -1.9684e-03,  ...,  9.6680e-02,
          -5.8984e-01, -5.2246e-02],
         [-2.4609e-01, -6.4844e-01, -1.1133e-01,  ...,  2.4062e+00,
           8.1250e+00,  3.2188e+00],
         [-8.2031e-02,  6.2500e-01, -2.9492e-01,  ..., -3.3750e+00,
           1.0375e+01,  4.1406e-01],
         ...,
         [ 3.3398e-01, -3.7305e-01,  6.9531e-01,  ...,  3.3789e-01,
           8.4375e+00,  5.9375e+00],
         [-5.8594e-02,  1.0254e-01,  2.1094e-01,  ..., -1.3281e+00,
           7.4375e+00,  1.9531e+00],
         [-3.3789e-01,  1.7578e-02, -2.4170e-02,  ...,  9.6484e-01,
           9.0625e+00, -5.2812e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-4.2114e-03,  1.7700e-03,  8.1787e-03,  ...,  2.0599e-03,
          -3.2616e-04, -2.4414e-03],
         [ 2.2031e+00, -1.1377e-01, -2.0156e+00,  ...,  1.6719e+00,
          -4.4375e+00, -2.3594e+00],
         [ 5.7812e-01, -7.9297e-01, -7.2266e-01,  ...,  6.5430e-02,
           8.1543e-02, -1.7891e+00],
         ...,
         [ 8.0859e-01, -2.1250e+00, -3.8477e-01,  ..., -9.8145e-02,
          -1.8945e-01, -3.4375e-01],
         [ 1.1016e+00,  3.6719e-01,  1.8750e-01,  ..., -1.4062e-01,
           7.4707e-02, -4.2383e-01],
         [ 4.7461e-01,  1.8359e+00,  2.2969e+00,  ...,  1.2812e+00,
          -8.9355e-02, -1.1406e+00]],

        [[-4.2114e-03,  1.7700e-03,  8.1787e-03,  ...,  2.0599e-03,
          -3.2616e-04, -2.4414e-03],
         [ 2.2031e+00, -1.1377e-01, -2.0156e+00,  ...,  1.6719e+00,
          -4.4375e+00, -2.3594e+00],
         [ 5.7812e-01, -7.9297e-01, -7.2266e-01,  ...,  6.5430e-02,
           8.1543e-02, -1.7891e+00],
         ...,
         [ 8.0859e-01, -2.1250e+00, -3.8477e-01,  ..., -9.8145e-02,
          -1.8945e-01, -3.4375e-01],
         [ 1.1016e+00,  3.6719e-01,  1.8750e-01,  ..., -1.4062e-01,
           7.4707e-02, -4.2383e-01],
         [ 4.7461e-01,  1.8359e+00,  2.2969e+00,  ...,  1.2812e+00,
          -8.9355e-02, -1.1406e+00]],

        [[-4.2114e-03,  1.7700e-03,  8.1787e-03,  ...,  2.0599e-03,
          -3.2616e-04, -2.4414e-03],
         [ 2.2031e+00, -1.1377e-01, -2.0156e+00,  ...,  1.6719e+00,
          -4.4375e+00, -2.3594e+00],
         [ 5.7812e-01, -7.9297e-01, -7.2266e-01,  ...,  6.5430e-02,
           8.1543e-02, -1.7891e+00],
         ...,
         [ 8.0859e-01, -2.1250e+00, -3.8477e-01,  ..., -9.8145e-02,
          -1.8945e-01, -3.4375e-01],
         [ 1.1016e+00,  3.6719e-01,  1.8750e-01,  ..., -1.4062e-01,
           7.4707e-02, -4.2383e-01],
         [ 4.7461e-01,  1.8359e+00,  2.2969e+00,  ...,  1.2812e+00,
          -8.9355e-02, -1.1406e+00]],

        ...,

        [[ 1.8311e-02, -2.3499e-03,  1.9455e-04,  ..., -1.7334e-02,
           4.1389e-04,  4.0283e-03],
         [ 2.8125e-01,  3.6914e-01, -3.3398e-01,  ...,  1.3672e+00,
           4.8047e-01,  6.4062e-01],
         [-1.0391e+00, -1.2695e-01,  4.0039e-01,  ...,  5.3467e-02,
           1.9043e-01,  2.7148e-01],
         ...,
         [ 4.8828e-01,  3.5742e-01,  5.7422e-01,  ..., -5.3516e-01,
          -3.8867e-01,  4.6484e-01],
         [-7.7637e-02,  4.4336e-01,  3.5400e-02,  ...,  3.3008e-01,
           2.5757e-02, -7.4707e-02],
         [-9.7656e-01,  1.4062e+00, -1.0303e-01,  ..., -1.0938e+00,
           1.2891e+00,  2.2461e-01]],

        [[ 1.8311e-02, -2.3499e-03,  1.9455e-04,  ..., -1.7334e-02,
           4.1389e-04,  4.0283e-03],
         [ 2.8125e-01,  3.6914e-01, -3.3398e-01,  ...,  1.3672e+00,
           4.8047e-01,  6.4062e-01],
         [-1.0391e+00, -1.2695e-01,  4.0039e-01,  ...,  5.3467e-02,
           1.9043e-01,  2.7148e-01],
         ...,
         [ 4.8828e-01,  3.5742e-01,  5.7422e-01,  ..., -5.3516e-01,
          -3.8867e-01,  4.6484e-01],
         [-7.7637e-02,  4.4336e-01,  3.5400e-02,  ...,  3.3008e-01,
           2.5757e-02, -7.4707e-02],
         [-9.7656e-01,  1.4062e+00, -1.0303e-01,  ..., -1.0938e+00,
           1.2891e+00,  2.2461e-01]],

        [[ 1.8311e-02, -2.3499e-03,  1.9455e-04,  ..., -1.7334e-02,
           4.1389e-04,  4.0283e-03],
         [ 2.8125e-01,  3.6914e-01, -3.3398e-01,  ...,  1.3672e+00,
           4.8047e-01,  6.4062e-01],
         [-1.0391e+00, -1.2695e-01,  4.0039e-01,  ...,  5.3467e-02,
           1.9043e-01,  2.7148e-01],
         ...,
         [ 4.8828e-01,  3.5742e-01,  5.7422e-01,  ..., -5.3516e-01,
          -3.8867e-01,  4.6484e-01],
         [-7.7637e-02,  4.4336e-01,  3.5400e-02,  ...,  3.3008e-01,
           2.5757e-02, -7.4707e-02],
         [-9.7656e-01,  1.4062e+00, -1.0303e-01,  ..., -1.0938e+00,
           1.2891e+00,  2.2461e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-3.4943e-03,  2.4414e-02,  5.6839e-04,  ...,  5.4932e-02,
          -1.2354e-01, -2.4023e-01],
         [-1.8359e-01, -8.9844e-02, -1.6953e+00,  ..., -1.0156e+00,
           3.7188e+00,  4.8047e-01],
         [-1.4062e+00, -1.9219e+00, -7.4609e-01,  ...,  1.8672e+00,
          -7.1716e-03,  3.0625e+00],
         ...,
         [ 2.2559e-01, -4.1797e-01,  3.4766e-01,  ..., -1.9141e+00,
           8.8672e-01,  1.4609e+00],
         [-6.6895e-02, -5.7812e-01, -7.1289e-02,  ..., -7.0312e-01,
           8.4375e-01,  2.1094e+00],
         [-4.0039e-01,  4.0625e-01, -2.4219e-01,  ..., -3.5938e-01,
           8.2031e-01,  2.0703e-01]],

        [[-3.4943e-03,  2.4414e-02,  5.6839e-04,  ...,  5.4932e-02,
          -1.2354e-01, -2.4023e-01],
         [-1.8359e-01, -8.9844e-02, -1.6953e+00,  ..., -1.0156e+00,
           3.7188e+00,  4.8047e-01],
         [-1.4062e+00, -1.9219e+00, -7.4609e-01,  ...,  1.8672e+00,
          -7.1716e-03,  3.0625e+00],
         ...,
         [ 2.2559e-01, -4.1797e-01,  3.4766e-01,  ..., -1.9141e+00,
           8.8672e-01,  1.4609e+00],
         [-6.6895e-02, -5.7812e-01, -7.1289e-02,  ..., -7.0312e-01,
           8.4375e-01,  2.1094e+00],
         [-4.0039e-01,  4.0625e-01, -2.4219e-01,  ..., -3.5938e-01,
           8.2031e-01,  2.0703e-01]],

        [[-3.4943e-03,  2.4414e-02,  5.6839e-04,  ...,  5.4932e-02,
          -1.2354e-01, -2.4023e-01],
         [-1.8359e-01, -8.9844e-02, -1.6953e+00,  ..., -1.0156e+00,
           3.7188e+00,  4.8047e-01],
         [-1.4062e+00, -1.9219e+00, -7.4609e-01,  ...,  1.8672e+00,
          -7.1716e-03,  3.0625e+00],
         ...,
         [ 2.2559e-01, -4.1797e-01,  3.4766e-01,  ..., -1.9141e+00,
           8.8672e-01,  1.4609e+00],
         [-6.6895e-02, -5.7812e-01, -7.1289e-02,  ..., -7.0312e-01,
           8.4375e-01,  2.1094e+00],
         [-4.0039e-01,  4.0625e-01, -2.4219e-01,  ..., -3.5938e-01,
           8.2031e-01,  2.0703e-01]],

        ...,

        [[ 1.0729e-04,  8.3008e-03, -1.4099e-02,  ...,  5.6763e-03,
          -5.2979e-02,  1.0864e-02],
         [-3.2031e-01, -1.0498e-01,  6.7188e-01,  ..., -9.3750e-01,
           4.4688e+00, -1.2344e+00],
         [-5.8203e-01, -7.8906e-01,  1.0156e+00,  ...,  1.8516e+00,
           2.5312e+00, -1.0547e+00],
         ...,
         [-2.6562e-01, -3.9844e-01, -1.0391e+00,  ...,  1.5918e-01,
           8.0859e-01, -2.3125e+00],
         [-1.4453e-01,  5.5469e-01, -1.4746e-01,  ...,  5.8984e-01,
           1.8906e+00, -2.2461e-01],
         [-4.3164e-01, -1.8750e-01,  6.9531e-01,  ..., -6.6016e-01,
          -1.9727e-01, -6.8750e-01]],

        [[ 1.0729e-04,  8.3008e-03, -1.4099e-02,  ...,  5.6763e-03,
          -5.2979e-02,  1.0864e-02],
         [-3.2031e-01, -1.0498e-01,  6.7188e-01,  ..., -9.3750e-01,
           4.4688e+00, -1.2344e+00],
         [-5.8203e-01, -7.8906e-01,  1.0156e+00,  ...,  1.8516e+00,
           2.5312e+00, -1.0547e+00],
         ...,
         [-2.6562e-01, -3.9844e-01, -1.0391e+00,  ...,  1.5918e-01,
           8.0859e-01, -2.3125e+00],
         [-1.4453e-01,  5.5469e-01, -1.4746e-01,  ...,  5.8984e-01,
           1.8906e+00, -2.2461e-01],
         [-4.3164e-01, -1.8750e-01,  6.9531e-01,  ..., -6.6016e-01,
          -1.9727e-01, -6.8750e-01]],

        [[ 1.0729e-04,  8.3008e-03, -1.4099e-02,  ...,  5.6763e-03,
          -5.2979e-02,  1.0864e-02],
         [-3.2031e-01, -1.0498e-01,  6.7188e-01,  ..., -9.3750e-01,
           4.4688e+00, -1.2344e+00],
         [-5.8203e-01, -7.8906e-01,  1.0156e+00,  ...,  1.8516e+00,
           2.5312e+00, -1.0547e+00],
         ...,
         [-2.6562e-01, -3.9844e-01, -1.0391e+00,  ...,  1.5918e-01,
           8.0859e-01, -2.3125e+00],
         [-1.4453e-01,  5.5469e-01, -1.4746e-01,  ...,  5.8984e-01,
           1.8906e+00, -2.2461e-01],
         [-4.3164e-01, -1.8750e-01,  6.9531e-01,  ..., -6.6016e-01,
          -1.9727e-01, -6.8750e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 4.5776e-03,  1.1414e-02,  6.9275e-03,  ...,  3.7994e-03,
          -2.7161e-03, -2.9755e-03],
         [ 1.0469e+00, -4.7812e+00, -1.7734e+00,  ..., -2.7188e+00,
           1.1641e+00,  6.6016e-01],
         [ 6.2891e-01, -1.8359e+00, -9.1797e-01,  ...,  1.0010e-02,
           1.4453e+00,  1.7266e+00],
         ...,
         [-4.3945e-01, -1.0547e+00, -9.0625e-01,  ...,  1.2656e+00,
           1.0078e+00,  1.0469e+00],
         [ 7.3730e-02,  3.5156e-02, -8.3203e-01,  ...,  6.7969e-01,
           6.2891e-01,  4.6875e-01],
         [-6.0156e-01,  3.8867e-01,  1.1328e+00,  ..., -5.0391e-01,
           3.0664e-01,  8.8672e-01]],

        [[ 4.5776e-03,  1.1414e-02,  6.9275e-03,  ...,  3.7994e-03,
          -2.7161e-03, -2.9755e-03],
         [ 1.0469e+00, -4.7812e+00, -1.7734e+00,  ..., -2.7188e+00,
           1.1641e+00,  6.6016e-01],
         [ 6.2891e-01, -1.8359e+00, -9.1797e-01,  ...,  1.0010e-02,
           1.4453e+00,  1.7266e+00],
         ...,
         [-4.3945e-01, -1.0547e+00, -9.0625e-01,  ...,  1.2656e+00,
           1.0078e+00,  1.0469e+00],
         [ 7.3730e-02,  3.5156e-02, -8.3203e-01,  ...,  6.7969e-01,
           6.2891e-01,  4.6875e-01],
         [-6.0156e-01,  3.8867e-01,  1.1328e+00,  ..., -5.0391e-01,
           3.0664e-01,  8.8672e-01]],

        [[ 4.5776e-03,  1.1414e-02,  6.9275e-03,  ...,  3.7994e-03,
          -2.7161e-03, -2.9755e-03],
         [ 1.0469e+00, -4.7812e+00, -1.7734e+00,  ..., -2.7188e+00,
           1.1641e+00,  6.6016e-01],
         [ 6.2891e-01, -1.8359e+00, -9.1797e-01,  ...,  1.0010e-02,
           1.4453e+00,  1.7266e+00],
         ...,
         [-4.3945e-01, -1.0547e+00, -9.0625e-01,  ...,  1.2656e+00,
           1.0078e+00,  1.0469e+00],
         [ 7.3730e-02,  3.5156e-02, -8.3203e-01,  ...,  6.7969e-01,
           6.2891e-01,  4.6875e-01],
         [-6.0156e-01,  3.8867e-01,  1.1328e+00,  ..., -5.0391e-01,
           3.0664e-01,  8.8672e-01]],

        ...,

        [[-3.8757e-03,  1.6327e-03,  7.3242e-03,  ..., -1.7822e-02,
          -6.9336e-02,  4.8447e-04],
         [-1.7344e+00,  4.6094e-01, -3.1875e+00,  ..., -2.4609e-01,
           2.3125e+00, -3.9688e+00],
         [ 8.3984e-01,  1.3281e+00, -1.0938e+00,  ...,  2.6953e-01,
          -2.3145e-01, -3.3203e-01],
         ...,
         [ 6.9922e-01, -7.2266e-01,  5.9375e-01,  ..., -9.2188e-01,
          -6.6016e-01,  8.5938e-01],
         [ 8.8281e-01,  3.8477e-01, -1.2793e-01,  ...,  4.5117e-01,
           1.9844e+00,  4.0234e-01],
         [ 1.0625e+00,  2.6758e-01, -1.9375e+00,  ...,  1.5000e+00,
           1.0078e+00,  1.0391e+00]],

        [[-3.8757e-03,  1.6327e-03,  7.3242e-03,  ..., -1.7822e-02,
          -6.9336e-02,  4.8447e-04],
         [-1.7344e+00,  4.6094e-01, -3.1875e+00,  ..., -2.4609e-01,
           2.3125e+00, -3.9688e+00],
         [ 8.3984e-01,  1.3281e+00, -1.0938e+00,  ...,  2.6953e-01,
          -2.3145e-01, -3.3203e-01],
         ...,
         [ 6.9922e-01, -7.2266e-01,  5.9375e-01,  ..., -9.2188e-01,
          -6.6016e-01,  8.5938e-01],
         [ 8.8281e-01,  3.8477e-01, -1.2793e-01,  ...,  4.5117e-01,
           1.9844e+00,  4.0234e-01],
         [ 1.0625e+00,  2.6758e-01, -1.9375e+00,  ...,  1.5000e+00,
           1.0078e+00,  1.0391e+00]],

        [[-3.8757e-03,  1.6327e-03,  7.3242e-03,  ..., -1.7822e-02,
          -6.9336e-02,  4.8447e-04],
         [-1.7344e+00,  4.6094e-01, -3.1875e+00,  ..., -2.4609e-01,
           2.3125e+00, -3.9688e+00],
         [ 8.3984e-01,  1.3281e+00, -1.0938e+00,  ...,  2.6953e-01,
          -2.3145e-01, -3.3203e-01],
         ...,
         [ 6.9922e-01, -7.2266e-01,  5.9375e-01,  ..., -9.2188e-01,
          -6.6016e-01,  8.5938e-01],
         [ 8.8281e-01,  3.8477e-01, -1.2793e-01,  ...,  4.5117e-01,
           1.9844e+00,  4.0234e-01],
         [ 1.0625e+00,  2.6758e-01, -1.9375e+00,  ...,  1.5000e+00,
           1.0078e+00,  1.0391e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-0.0067,  0.0189, -0.0097,  ...,  0.1387,  0.0334,  0.0076],
         [-0.2676, -0.6953,  1.0625,  ..., -3.4531, -0.6523,  1.7344],
         [-0.5508,  1.1797,  0.9883,  ..., -4.2188,  0.6836,  0.8125],
         ...,
         [-0.1157,  0.5312, -0.1660,  ..., -4.0312, -1.5781, -1.5156],
         [-0.1172,  0.1138,  0.5273,  ..., -3.8750, -1.6719, -0.4551],
         [ 0.5469, -0.3672,  1.1875,  ..., -3.5625, -0.6055,  0.0591]],

        [[-0.0067,  0.0189, -0.0097,  ...,  0.1387,  0.0334,  0.0076],
         [-0.2676, -0.6953,  1.0625,  ..., -3.4531, -0.6523,  1.7344],
         [-0.5508,  1.1797,  0.9883,  ..., -4.2188,  0.6836,  0.8125],
         ...,
         [-0.1157,  0.5312, -0.1660,  ..., -4.0312, -1.5781, -1.5156],
         [-0.1172,  0.1138,  0.5273,  ..., -3.8750, -1.6719, -0.4551],
         [ 0.5469, -0.3672,  1.1875,  ..., -3.5625, -0.6055,  0.0591]],

        [[-0.0067,  0.0189, -0.0097,  ...,  0.1387,  0.0334,  0.0076],
         [-0.2676, -0.6953,  1.0625,  ..., -3.4531, -0.6523,  1.7344],
         [-0.5508,  1.1797,  0.9883,  ..., -4.2188,  0.6836,  0.8125],
         ...,
         [-0.1157,  0.5312, -0.1660,  ..., -4.0312, -1.5781, -1.5156],
         [-0.1172,  0.1138,  0.5273,  ..., -3.8750, -1.6719, -0.4551],
         [ 0.5469, -0.3672,  1.1875,  ..., -3.5625, -0.6055,  0.0591]],

        ...,

        [[-0.0256,  0.0177,  0.0053,  ...,  0.0048, -0.0129, -0.0820],
         [ 1.0859, -0.6836, -0.5078,  ...,  0.0684, -0.6016,  1.6953],
         [-1.2891, -1.2500,  0.5430,  ...,  1.9766,  0.4082,  2.3125],
         ...,
         [ 0.7500,  0.3965, -1.0234,  ..., -0.2041, -0.9102,  0.6211],
         [-1.3516,  0.5820, -0.5195,  ..., -0.0172, -0.9375,  0.9922],
         [-2.9062,  1.6953, -0.3223,  ..., -1.8125, -0.1445,  0.2734]],

        [[-0.0256,  0.0177,  0.0053,  ...,  0.0048, -0.0129, -0.0820],
         [ 1.0859, -0.6836, -0.5078,  ...,  0.0684, -0.6016,  1.6953],
         [-1.2891, -1.2500,  0.5430,  ...,  1.9766,  0.4082,  2.3125],
         ...,
         [ 0.7500,  0.3965, -1.0234,  ..., -0.2041, -0.9102,  0.6211],
         [-1.3516,  0.5820, -0.5195,  ..., -0.0172, -0.9375,  0.9922],
         [-2.9062,  1.6953, -0.3223,  ..., -1.8125, -0.1445,  0.2734]],

        [[-0.0256,  0.0177,  0.0053,  ...,  0.0048, -0.0129, -0.0820],
         [ 1.0859, -0.6836, -0.5078,  ...,  0.0684, -0.6016,  1.6953],
         [-1.2891, -1.2500,  0.5430,  ...,  1.9766,  0.4082,  2.3125],
         ...,
         [ 0.7500,  0.3965, -1.0234,  ..., -0.2041, -0.9102,  0.6211],
         [-1.3516,  0.5820, -0.5195,  ..., -0.0172, -0.9375,  0.9922],
         [-2.9062,  1.6953, -0.3223,  ..., -1.8125, -0.1445,  0.2734]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>), tensor([[[-5.5542e-03, -1.3123e-03, -4.8218e-03,  ...,  5.2185e-03,
          -5.2185e-03, -7.4768e-04],
         [-1.1562e+00,  5.5469e-01,  4.4922e-01,  ..., -1.0078e+00,
          -3.0859e-01,  2.7734e-01],
         [ 5.1172e-01, -2.8711e-01,  3.6865e-02,  ...,  5.1953e-01,
           8.7109e-01, -9.5312e-01],
         ...,
         [ 3.4180e-01,  3.7109e-01, -3.6377e-02,  ..., -1.5723e-01,
          -1.6641e+00, -7.6562e-01],
         [ 7.1289e-02,  1.4221e-02,  1.0742e-01,  ..., -3.0859e-01,
          -1.1953e+00, -7.9688e-01],
         [ 5.3516e-01,  8.4766e-01, -3.1055e-01,  ...,  3.7305e-01,
          -4.6680e-01,  2.9883e-01]],

        [[-5.5542e-03, -1.3123e-03, -4.8218e-03,  ...,  5.2185e-03,
          -5.2185e-03, -7.4768e-04],
         [-1.1562e+00,  5.5469e-01,  4.4922e-01,  ..., -1.0078e+00,
          -3.0859e-01,  2.7734e-01],
         [ 5.1172e-01, -2.8711e-01,  3.6865e-02,  ...,  5.1953e-01,
           8.7109e-01, -9.5312e-01],
         ...,
         [ 3.4180e-01,  3.7109e-01, -3.6377e-02,  ..., -1.5723e-01,
          -1.6641e+00, -7.6562e-01],
         [ 7.1289e-02,  1.4221e-02,  1.0742e-01,  ..., -3.0859e-01,
          -1.1953e+00, -7.9688e-01],
         [ 5.3516e-01,  8.4766e-01, -3.1055e-01,  ...,  3.7305e-01,
          -4.6680e-01,  2.9883e-01]],

        [[-5.5542e-03, -1.3123e-03, -4.8218e-03,  ...,  5.2185e-03,
          -5.2185e-03, -7.4768e-04],
         [-1.1562e+00,  5.5469e-01,  4.4922e-01,  ..., -1.0078e+00,
          -3.0859e-01,  2.7734e-01],
         [ 5.1172e-01, -2.8711e-01,  3.6865e-02,  ...,  5.1953e-01,
           8.7109e-01, -9.5312e-01],
         ...,
         [ 3.4180e-01,  3.7109e-01, -3.6377e-02,  ..., -1.5723e-01,
          -1.6641e+00, -7.6562e-01],
         [ 7.1289e-02,  1.4221e-02,  1.0742e-01,  ..., -3.0859e-01,
          -1.1953e+00, -7.9688e-01],
         [ 5.3516e-01,  8.4766e-01, -3.1055e-01,  ...,  3.7305e-01,
          -4.6680e-01,  2.9883e-01]],

        ...,

        [[-7.5378e-03, -4.8161e-05,  1.0132e-02,  ..., -6.2256e-03,
          -5.2185e-03, -4.3640e-03],
         [ 1.0156e+00, -2.6367e-01, -4.8242e-01,  ..., -1.2969e+00,
           5.2734e-01,  2.0996e-01],
         [-4.5898e-02,  1.1719e+00,  9.1797e-01,  ..., -1.0000e+00,
           8.6719e-01, -4.1211e-01],
         ...,
         [-1.8066e-01, -1.7344e+00, -4.3359e-01,  ..., -8.9844e-01,
          -8.0469e-01,  9.5215e-02],
         [ 3.6133e-01, -2.1875e-01, -1.9434e-01,  ..., -1.3965e-01,
          -1.6641e+00, -4.4922e-01],
         [-1.1328e+00, -5.2734e-01,  3.2031e-01,  ..., -2.3071e-02,
           7.3438e-01,  4.8828e-02]],

        [[-7.5378e-03, -4.8161e-05,  1.0132e-02,  ..., -6.2256e-03,
          -5.2185e-03, -4.3640e-03],
         [ 1.0156e+00, -2.6367e-01, -4.8242e-01,  ..., -1.2969e+00,
           5.2734e-01,  2.0996e-01],
         [-4.5898e-02,  1.1719e+00,  9.1797e-01,  ..., -1.0000e+00,
           8.6719e-01, -4.1211e-01],
         ...,
         [-1.8066e-01, -1.7344e+00, -4.3359e-01,  ..., -8.9844e-01,
          -8.0469e-01,  9.5215e-02],
         [ 3.6133e-01, -2.1875e-01, -1.9434e-01,  ..., -1.3965e-01,
          -1.6641e+00, -4.4922e-01],
         [-1.1328e+00, -5.2734e-01,  3.2031e-01,  ..., -2.3071e-02,
           7.3438e-01,  4.8828e-02]],

        [[-7.5378e-03, -4.8161e-05,  1.0132e-02,  ..., -6.2256e-03,
          -5.2185e-03, -4.3640e-03],
         [ 1.0156e+00, -2.6367e-01, -4.8242e-01,  ..., -1.2969e+00,
           5.2734e-01,  2.0996e-01],
         [-4.5898e-02,  1.1719e+00,  9.1797e-01,  ..., -1.0000e+00,
           8.6719e-01, -4.1211e-01],
         ...,
         [-1.8066e-01, -1.7344e+00, -4.3359e-01,  ..., -8.9844e-01,
          -8.0469e-01,  9.5215e-02],
         [ 3.6133e-01, -2.1875e-01, -1.9434e-01,  ..., -1.3965e-01,
          -1.6641e+00, -4.4922e-01],
         [-1.1328e+00, -5.2734e-01,  3.2031e-01,  ..., -2.3071e-02,
           7.3438e-01,  4.8828e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 1.5198e-02,  4.3030e-03,  1.2131e-03,  ...,  3.6865e-02,
          -2.1387e-01, -1.9287e-02],
         [-1.3828e+00,  2.1875e-01, -4.2480e-02,  ...,  1.6562e+00,
           4.1406e-01,  3.2344e+00],
         [-9.0234e-01,  2.6001e-02, -2.6172e-01,  ...,  2.2656e+00,
           1.5625e+00,  2.1719e+00],
         ...,
         [-8.6914e-02, -6.6797e-01,  8.8281e-01,  ..., -3.2471e-02,
          -2.3125e+00,  2.3906e+00],
         [ 2.8711e-01, -7.0312e-02,  1.1914e-01,  ..., -2.1875e-01,
          -1.6484e+00, -8.7891e-01],
         [-1.6211e-01, -3.6523e-01, -3.8477e-01,  ..., -3.7031e+00,
          -2.7031e+00,  2.1719e+00]],

        [[ 1.5198e-02,  4.3030e-03,  1.2131e-03,  ...,  3.6865e-02,
          -2.1387e-01, -1.9287e-02],
         [-1.3828e+00,  2.1875e-01, -4.2480e-02,  ...,  1.6562e+00,
           4.1406e-01,  3.2344e+00],
         [-9.0234e-01,  2.6001e-02, -2.6172e-01,  ...,  2.2656e+00,
           1.5625e+00,  2.1719e+00],
         ...,
         [-8.6914e-02, -6.6797e-01,  8.8281e-01,  ..., -3.2471e-02,
          -2.3125e+00,  2.3906e+00],
         [ 2.8711e-01, -7.0312e-02,  1.1914e-01,  ..., -2.1875e-01,
          -1.6484e+00, -8.7891e-01],
         [-1.6211e-01, -3.6523e-01, -3.8477e-01,  ..., -3.7031e+00,
          -2.7031e+00,  2.1719e+00]],

        [[ 1.5198e-02,  4.3030e-03,  1.2131e-03,  ...,  3.6865e-02,
          -2.1387e-01, -1.9287e-02],
         [-1.3828e+00,  2.1875e-01, -4.2480e-02,  ...,  1.6562e+00,
           4.1406e-01,  3.2344e+00],
         [-9.0234e-01,  2.6001e-02, -2.6172e-01,  ...,  2.2656e+00,
           1.5625e+00,  2.1719e+00],
         ...,
         [-8.6914e-02, -6.6797e-01,  8.8281e-01,  ..., -3.2471e-02,
          -2.3125e+00,  2.3906e+00],
         [ 2.8711e-01, -7.0312e-02,  1.1914e-01,  ..., -2.1875e-01,
          -1.6484e+00, -8.7891e-01],
         [-1.6211e-01, -3.6523e-01, -3.8477e-01,  ..., -3.7031e+00,
          -2.7031e+00,  2.1719e+00]],

        ...,

        [[-1.1536e-02, -6.7139e-03,  3.7003e-04,  ...,  1.7944e-02,
           6.1279e-02,  1.0938e+00],
         [-4.4922e-02, -1.0078e+00,  9.3750e-02,  ..., -1.9609e+00,
           3.4219e+00, -8.7500e+00],
         [-1.6602e-01, -6.3672e-01,  5.1270e-03,  ...,  7.3828e-01,
          -2.9375e+00, -1.0000e+01],
         ...,
         [-3.8281e-01, -9.3359e-01,  3.3594e-01,  ...,  2.8125e+00,
           6.5625e-01, -6.7812e+00],
         [-4.1016e-02,  1.3477e-01, -2.2852e-01,  ..., -5.8203e-01,
          -1.6172e+00, -7.9062e+00],
         [ 4.8242e-01,  6.7578e-01, -4.8242e-01,  ..., -6.4062e-01,
          -2.4531e+00, -7.1562e+00]],

        [[-1.1536e-02, -6.7139e-03,  3.7003e-04,  ...,  1.7944e-02,
           6.1279e-02,  1.0938e+00],
         [-4.4922e-02, -1.0078e+00,  9.3750e-02,  ..., -1.9609e+00,
           3.4219e+00, -8.7500e+00],
         [-1.6602e-01, -6.3672e-01,  5.1270e-03,  ...,  7.3828e-01,
          -2.9375e+00, -1.0000e+01],
         ...,
         [-3.8281e-01, -9.3359e-01,  3.3594e-01,  ...,  2.8125e+00,
           6.5625e-01, -6.7812e+00],
         [-4.1016e-02,  1.3477e-01, -2.2852e-01,  ..., -5.8203e-01,
          -1.6172e+00, -7.9062e+00],
         [ 4.8242e-01,  6.7578e-01, -4.8242e-01,  ..., -6.4062e-01,
          -2.4531e+00, -7.1562e+00]],

        [[-1.1536e-02, -6.7139e-03,  3.7003e-04,  ...,  1.7944e-02,
           6.1279e-02,  1.0938e+00],
         [-4.4922e-02, -1.0078e+00,  9.3750e-02,  ..., -1.9609e+00,
           3.4219e+00, -8.7500e+00],
         [-1.6602e-01, -6.3672e-01,  5.1270e-03,  ...,  7.3828e-01,
          -2.9375e+00, -1.0000e+01],
         ...,
         [-3.8281e-01, -9.3359e-01,  3.3594e-01,  ...,  2.8125e+00,
           6.5625e-01, -6.7812e+00],
         [-4.1016e-02,  1.3477e-01, -2.2852e-01,  ..., -5.8203e-01,
          -1.6172e+00, -7.9062e+00],
         [ 4.8242e-01,  6.7578e-01, -4.8242e-01,  ..., -6.4062e-01,
          -2.4531e+00, -7.1562e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 2.9755e-04,  3.8757e-03,  6.9427e-04,  ...,  5.5237e-03,
           1.3828e-04, -6.5918e-03],
         [ 1.3594e+00,  5.7422e-01, -1.0938e+00,  ..., -1.8828e+00,
           5.3516e-01,  5.5859e-01],
         [ 2.3315e-02, -2.5938e+00,  8.9844e-01,  ..., -9.7266e-01,
           3.0664e-01,  1.2266e+00],
         ...,
         [ 1.4844e+00, -5.8984e-01,  1.9629e-01,  ..., -6.0156e-01,
          -1.9238e-01, -2.1777e-01],
         [ 1.1406e+00, -1.0938e+00, -3.2422e-01,  ..., -1.1719e-01,
          -3.1055e-01, -2.3438e-02],
         [ 5.8203e-01, -1.5747e-02, -6.3281e-01,  ...,  1.0547e+00,
          -1.0645e-01,  5.0537e-02]],

        [[ 2.9755e-04,  3.8757e-03,  6.9427e-04,  ...,  5.5237e-03,
           1.3828e-04, -6.5918e-03],
         [ 1.3594e+00,  5.7422e-01, -1.0938e+00,  ..., -1.8828e+00,
           5.3516e-01,  5.5859e-01],
         [ 2.3315e-02, -2.5938e+00,  8.9844e-01,  ..., -9.7266e-01,
           3.0664e-01,  1.2266e+00],
         ...,
         [ 1.4844e+00, -5.8984e-01,  1.9629e-01,  ..., -6.0156e-01,
          -1.9238e-01, -2.1777e-01],
         [ 1.1406e+00, -1.0938e+00, -3.2422e-01,  ..., -1.1719e-01,
          -3.1055e-01, -2.3438e-02],
         [ 5.8203e-01, -1.5747e-02, -6.3281e-01,  ...,  1.0547e+00,
          -1.0645e-01,  5.0537e-02]],

        [[ 2.9755e-04,  3.8757e-03,  6.9427e-04,  ...,  5.5237e-03,
           1.3828e-04, -6.5918e-03],
         [ 1.3594e+00,  5.7422e-01, -1.0938e+00,  ..., -1.8828e+00,
           5.3516e-01,  5.5859e-01],
         [ 2.3315e-02, -2.5938e+00,  8.9844e-01,  ..., -9.7266e-01,
           3.0664e-01,  1.2266e+00],
         ...,
         [ 1.4844e+00, -5.8984e-01,  1.9629e-01,  ..., -6.0156e-01,
          -1.9238e-01, -2.1777e-01],
         [ 1.1406e+00, -1.0938e+00, -3.2422e-01,  ..., -1.1719e-01,
          -3.1055e-01, -2.3438e-02],
         [ 5.8203e-01, -1.5747e-02, -6.3281e-01,  ...,  1.0547e+00,
          -1.0645e-01,  5.0537e-02]],

        ...,

        [[ 2.6703e-03, -6.5918e-03, -9.8267e-03,  ..., -7.3242e-03,
          -2.3460e-04, -4.3945e-03],
         [-1.8906e+00,  7.5391e-01, -9.6191e-02,  ...,  2.6367e-01,
           1.5156e+00, -1.2031e+00],
         [ 1.4531e+00, -1.3359e+00,  9.0820e-02,  ..., -2.6406e+00,
           4.2188e+00,  3.1875e+00],
         ...,
         [ 6.2109e-01,  9.7266e-01, -8.7891e-01,  ...,  1.5625e-01,
          -3.2617e-01, -9.0234e-01],
         [-3.6719e-01, -1.1475e-01, -5.4297e-01,  ...,  2.7734e-01,
           1.1328e-01,  5.5176e-02],
         [ 5.9766e-01,  3.8086e-01,  2.1562e+00,  ...,  1.7656e+00,
          -3.3008e-01,  5.6152e-02]],

        [[ 2.6703e-03, -6.5918e-03, -9.8267e-03,  ..., -7.3242e-03,
          -2.3460e-04, -4.3945e-03],
         [-1.8906e+00,  7.5391e-01, -9.6191e-02,  ...,  2.6367e-01,
           1.5156e+00, -1.2031e+00],
         [ 1.4531e+00, -1.3359e+00,  9.0820e-02,  ..., -2.6406e+00,
           4.2188e+00,  3.1875e+00],
         ...,
         [ 6.2109e-01,  9.7266e-01, -8.7891e-01,  ...,  1.5625e-01,
          -3.2617e-01, -9.0234e-01],
         [-3.6719e-01, -1.1475e-01, -5.4297e-01,  ...,  2.7734e-01,
           1.1328e-01,  5.5176e-02],
         [ 5.9766e-01,  3.8086e-01,  2.1562e+00,  ...,  1.7656e+00,
          -3.3008e-01,  5.6152e-02]],

        [[ 2.6703e-03, -6.5918e-03, -9.8267e-03,  ..., -7.3242e-03,
          -2.3460e-04, -4.3945e-03],
         [-1.8906e+00,  7.5391e-01, -9.6191e-02,  ...,  2.6367e-01,
           1.5156e+00, -1.2031e+00],
         [ 1.4531e+00, -1.3359e+00,  9.0820e-02,  ..., -2.6406e+00,
           4.2188e+00,  3.1875e+00],
         ...,
         [ 6.2109e-01,  9.7266e-01, -8.7891e-01,  ...,  1.5625e-01,
          -3.2617e-01, -9.0234e-01],
         [-3.6719e-01, -1.1475e-01, -5.4297e-01,  ...,  2.7734e-01,
           1.1328e-01,  5.5176e-02],
         [ 5.9766e-01,  3.8086e-01,  2.1562e+00,  ...,  1.7656e+00,
          -3.3008e-01,  5.6152e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-3.6926e-03,  3.5858e-03, -1.4191e-03,  ...,  2.9907e-02,
          -6.6406e-02,  5.3516e-01],
         [ 2.0938e+00,  1.7031e+00, -1.5156e+00,  ..., -2.1875e+00,
           1.8594e+00, -5.1875e+00],
         [ 7.6172e-01,  1.2109e+00, -1.4062e-01,  ..., -1.0938e+00,
          -1.8047e+00, -3.8281e+00],
         ...,
         [ 4.6875e-01, -4.2383e-01,  2.3828e-01,  ..., -2.9102e-01,
           4.6875e+00, -5.3125e+00],
         [-4.4922e-01, -3.2812e-01,  1.7188e-01,  ...,  8.3594e-01,
           6.0938e-01, -4.4688e+00],
         [-2.3750e+00, -2.3125e+00, -1.1562e+00,  ..., -1.7188e+00,
          -5.0391e-01, -5.0000e+00]],

        [[-3.6926e-03,  3.5858e-03, -1.4191e-03,  ...,  2.9907e-02,
          -6.6406e-02,  5.3516e-01],
         [ 2.0938e+00,  1.7031e+00, -1.5156e+00,  ..., -2.1875e+00,
           1.8594e+00, -5.1875e+00],
         [ 7.6172e-01,  1.2109e+00, -1.4062e-01,  ..., -1.0938e+00,
          -1.8047e+00, -3.8281e+00],
         ...,
         [ 4.6875e-01, -4.2383e-01,  2.3828e-01,  ..., -2.9102e-01,
           4.6875e+00, -5.3125e+00],
         [-4.4922e-01, -3.2812e-01,  1.7188e-01,  ...,  8.3594e-01,
           6.0938e-01, -4.4688e+00],
         [-2.3750e+00, -2.3125e+00, -1.1562e+00,  ..., -1.7188e+00,
          -5.0391e-01, -5.0000e+00]],

        [[-3.6926e-03,  3.5858e-03, -1.4191e-03,  ...,  2.9907e-02,
          -6.6406e-02,  5.3516e-01],
         [ 2.0938e+00,  1.7031e+00, -1.5156e+00,  ..., -2.1875e+00,
           1.8594e+00, -5.1875e+00],
         [ 7.6172e-01,  1.2109e+00, -1.4062e-01,  ..., -1.0938e+00,
          -1.8047e+00, -3.8281e+00],
         ...,
         [ 4.6875e-01, -4.2383e-01,  2.3828e-01,  ..., -2.9102e-01,
           4.6875e+00, -5.3125e+00],
         [-4.4922e-01, -3.2812e-01,  1.7188e-01,  ...,  8.3594e-01,
           6.0938e-01, -4.4688e+00],
         [-2.3750e+00, -2.3125e+00, -1.1562e+00,  ..., -1.7188e+00,
          -5.0391e-01, -5.0000e+00]],

        ...,

        [[-1.0834e-03,  1.1597e-02,  6.6223e-03,  ...,  1.1426e-01,
           5.9082e-02,  7.5195e-02],
         [ 1.6719e+00, -5.2344e-01, -8.1641e-01,  ..., -6.9922e-01,
          -5.8594e-01, -1.0781e+00],
         [ 1.2344e+00, -1.4453e+00, -9.9219e-01,  ..., -6.4844e-01,
          -4.3945e-01, -6.7969e-01],
         ...,
         [ 6.7969e-01, -6.4453e-01,  9.9219e-01,  ..., -1.2344e+00,
          -1.3984e+00, -1.0000e+00],
         [-4.7070e-01,  3.3984e-01,  5.8984e-01,  ..., -4.1992e-01,
          -1.4258e-01,  6.3965e-02],
         [-1.0000e+00,  1.5781e+00,  5.8203e-01,  ..., -9.4141e-01,
           4.6631e-02, -1.3984e+00]],

        [[-1.0834e-03,  1.1597e-02,  6.6223e-03,  ...,  1.1426e-01,
           5.9082e-02,  7.5195e-02],
         [ 1.6719e+00, -5.2344e-01, -8.1641e-01,  ..., -6.9922e-01,
          -5.8594e-01, -1.0781e+00],
         [ 1.2344e+00, -1.4453e+00, -9.9219e-01,  ..., -6.4844e-01,
          -4.3945e-01, -6.7969e-01],
         ...,
         [ 6.7969e-01, -6.4453e-01,  9.9219e-01,  ..., -1.2344e+00,
          -1.3984e+00, -1.0000e+00],
         [-4.7070e-01,  3.3984e-01,  5.8984e-01,  ..., -4.1992e-01,
          -1.4258e-01,  6.3965e-02],
         [-1.0000e+00,  1.5781e+00,  5.8203e-01,  ..., -9.4141e-01,
           4.6631e-02, -1.3984e+00]],

        [[-1.0834e-03,  1.1597e-02,  6.6223e-03,  ...,  1.1426e-01,
           5.9082e-02,  7.5195e-02],
         [ 1.6719e+00, -5.2344e-01, -8.1641e-01,  ..., -6.9922e-01,
          -5.8594e-01, -1.0781e+00],
         [ 1.2344e+00, -1.4453e+00, -9.9219e-01,  ..., -6.4844e-01,
          -4.3945e-01, -6.7969e-01],
         ...,
         [ 6.7969e-01, -6.4453e-01,  9.9219e-01,  ..., -1.2344e+00,
          -1.3984e+00, -1.0000e+00],
         [-4.7070e-01,  3.3984e-01,  5.8984e-01,  ..., -4.1992e-01,
          -1.4258e-01,  6.3965e-02],
         [-1.0000e+00,  1.5781e+00,  5.8203e-01,  ..., -9.4141e-01,
           4.6631e-02, -1.3984e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-1.1215e-03, -1.3184e-02,  2.7313e-03,  ...,  1.4160e-02,
           2.8038e-04,  7.0496e-03],
         [ 6.5625e-01,  3.0078e-01,  3.0664e-01,  ..., -2.2754e-01,
          -8.6719e-01, -2.0703e-01],
         [ 3.8672e-01, -2.0020e-01,  3.7305e-01,  ...,  1.3477e-01,
           7.9297e-01,  3.1250e-01],
         ...,
         [-1.7109e+00,  2.7344e-01,  9.6680e-02,  ...,  1.7773e-01,
           3.5938e-01, -5.0781e-01],
         [-1.5703e+00, -2.3633e-01,  2.7539e-01,  ...,  4.3945e-01,
          -2.5177e-03, -4.2578e-01],
         [-5.3906e-01, -4.6289e-01,  4.3359e-01,  ...,  4.4531e-01,
          -3.0859e-01, -9.4141e-01]],

        [[-1.1215e-03, -1.3184e-02,  2.7313e-03,  ...,  1.4160e-02,
           2.8038e-04,  7.0496e-03],
         [ 6.5625e-01,  3.0078e-01,  3.0664e-01,  ..., -2.2754e-01,
          -8.6719e-01, -2.0703e-01],
         [ 3.8672e-01, -2.0020e-01,  3.7305e-01,  ...,  1.3477e-01,
           7.9297e-01,  3.1250e-01],
         ...,
         [-1.7109e+00,  2.7344e-01,  9.6680e-02,  ...,  1.7773e-01,
           3.5938e-01, -5.0781e-01],
         [-1.5703e+00, -2.3633e-01,  2.7539e-01,  ...,  4.3945e-01,
          -2.5177e-03, -4.2578e-01],
         [-5.3906e-01, -4.6289e-01,  4.3359e-01,  ...,  4.4531e-01,
          -3.0859e-01, -9.4141e-01]],

        [[-1.1215e-03, -1.3184e-02,  2.7313e-03,  ...,  1.4160e-02,
           2.8038e-04,  7.0496e-03],
         [ 6.5625e-01,  3.0078e-01,  3.0664e-01,  ..., -2.2754e-01,
          -8.6719e-01, -2.0703e-01],
         [ 3.8672e-01, -2.0020e-01,  3.7305e-01,  ...,  1.3477e-01,
           7.9297e-01,  3.1250e-01],
         ...,
         [-1.7109e+00,  2.7344e-01,  9.6680e-02,  ...,  1.7773e-01,
           3.5938e-01, -5.0781e-01],
         [-1.5703e+00, -2.3633e-01,  2.7539e-01,  ...,  4.3945e-01,
          -2.5177e-03, -4.2578e-01],
         [-5.3906e-01, -4.6289e-01,  4.3359e-01,  ...,  4.4531e-01,
          -3.0859e-01, -9.4141e-01]],

        ...,

        [[-9.1553e-03, -9.7656e-03, -6.1951e-03,  ...,  5.8899e-03,
           6.4697e-03, -7.2098e-04],
         [-5.0781e-01,  1.3281e+00, -1.4766e+00,  ..., -7.6172e-01,
          -1.5430e-01, -1.7656e+00],
         [-4.9414e-01,  2.2949e-01,  1.8203e+00,  ..., -3.0273e-01,
          -2.4805e-01, -6.5625e-01],
         ...,
         [-3.7842e-02,  9.4727e-02,  1.5703e+00,  ...,  1.1094e+00,
          -1.2344e+00,  1.5723e-01],
         [-1.6895e-01,  4.3750e-01,  7.5000e-01,  ...,  1.6504e-01,
          -1.5332e-01,  8.2397e-03],
         [-1.0000e+00, -7.9297e-01,  8.7109e-01,  ...,  1.4688e+00,
           1.1875e+00, -5.3906e-01]],

        [[-9.1553e-03, -9.7656e-03, -6.1951e-03,  ...,  5.8899e-03,
           6.4697e-03, -7.2098e-04],
         [-5.0781e-01,  1.3281e+00, -1.4766e+00,  ..., -7.6172e-01,
          -1.5430e-01, -1.7656e+00],
         [-4.9414e-01,  2.2949e-01,  1.8203e+00,  ..., -3.0273e-01,
          -2.4805e-01, -6.5625e-01],
         ...,
         [-3.7842e-02,  9.4727e-02,  1.5703e+00,  ...,  1.1094e+00,
          -1.2344e+00,  1.5723e-01],
         [-1.6895e-01,  4.3750e-01,  7.5000e-01,  ...,  1.6504e-01,
          -1.5332e-01,  8.2397e-03],
         [-1.0000e+00, -7.9297e-01,  8.7109e-01,  ...,  1.4688e+00,
           1.1875e+00, -5.3906e-01]],

        [[-9.1553e-03, -9.7656e-03, -6.1951e-03,  ...,  5.8899e-03,
           6.4697e-03, -7.2098e-04],
         [-5.0781e-01,  1.3281e+00, -1.4766e+00,  ..., -7.6172e-01,
          -1.5430e-01, -1.7656e+00],
         [-4.9414e-01,  2.2949e-01,  1.8203e+00,  ..., -3.0273e-01,
          -2.4805e-01, -6.5625e-01],
         ...,
         [-3.7842e-02,  9.4727e-02,  1.5703e+00,  ...,  1.1094e+00,
          -1.2344e+00,  1.5723e-01],
         [-1.6895e-01,  4.3750e-01,  7.5000e-01,  ...,  1.6504e-01,
          -1.5332e-01,  8.2397e-03],
         [-1.0000e+00, -7.9297e-01,  8.7109e-01,  ...,  1.4688e+00,
           1.1875e+00, -5.3906e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-2.0386e-02,  2.8801e-04,  1.3504e-03,  ..., -6.3281e-01,
          -2.7222e-02, -1.0645e-01],
         [-3.9062e-02,  2.2969e+00,  1.6562e+00,  ...,  1.0688e+01,
          -2.3750e+00,  5.7031e-01],
         [-1.7656e+00,  8.3594e-01,  4.0234e-01,  ...,  1.1438e+01,
          -7.1094e-01,  8.7109e-01],
         ...,
         [-2.5195e-01, -8.2031e-02,  7.3828e-01,  ...,  9.6875e+00,
           1.2500e+00,  1.0078e+00],
         [-4.1211e-01, -1.0781e+00,  6.8750e-01,  ...,  8.6250e+00,
           7.2656e-01,  5.3906e-01],
         [-1.0781e+00, -1.0000e+00,  9.1406e-01,  ...,  9.2500e+00,
           3.4961e-01,  3.0762e-02]],

        [[-2.0386e-02,  2.8801e-04,  1.3504e-03,  ..., -6.3281e-01,
          -2.7222e-02, -1.0645e-01],
         [-3.9062e-02,  2.2969e+00,  1.6562e+00,  ...,  1.0688e+01,
          -2.3750e+00,  5.7031e-01],
         [-1.7656e+00,  8.3594e-01,  4.0234e-01,  ...,  1.1438e+01,
          -7.1094e-01,  8.7109e-01],
         ...,
         [-2.5195e-01, -8.2031e-02,  7.3828e-01,  ...,  9.6875e+00,
           1.2500e+00,  1.0078e+00],
         [-4.1211e-01, -1.0781e+00,  6.8750e-01,  ...,  8.6250e+00,
           7.2656e-01,  5.3906e-01],
         [-1.0781e+00, -1.0000e+00,  9.1406e-01,  ...,  9.2500e+00,
           3.4961e-01,  3.0762e-02]],

        [[-2.0386e-02,  2.8801e-04,  1.3504e-03,  ..., -6.3281e-01,
          -2.7222e-02, -1.0645e-01],
         [-3.9062e-02,  2.2969e+00,  1.6562e+00,  ...,  1.0688e+01,
          -2.3750e+00,  5.7031e-01],
         [-1.7656e+00,  8.3594e-01,  4.0234e-01,  ...,  1.1438e+01,
          -7.1094e-01,  8.7109e-01],
         ...,
         [-2.5195e-01, -8.2031e-02,  7.3828e-01,  ...,  9.6875e+00,
           1.2500e+00,  1.0078e+00],
         [-4.1211e-01, -1.0781e+00,  6.8750e-01,  ...,  8.6250e+00,
           7.2656e-01,  5.3906e-01],
         [-1.0781e+00, -1.0000e+00,  9.1406e-01,  ...,  9.2500e+00,
           3.4961e-01,  3.0762e-02]],

        ...,

        [[ 4.9744e-03,  5.5237e-03,  7.2937e-03,  ..., -2.2949e-02,
           8.1787e-03,  1.5723e-01],
         [-1.3750e+00, -8.7891e-01, -8.5156e-01,  ...,  1.4453e+00,
          -2.7734e-01, -3.7812e+00],
         [-8.2812e-01, -5.9375e-01,  1.8652e-01,  ...,  3.1875e+00,
           2.5977e-01,  1.8555e-01],
         ...,
         [-1.0156e+00,  6.3281e-01, -1.9727e-01,  ..., -1.7500e+00,
          -2.6250e+00, -3.7656e+00],
         [ 2.1680e-01,  4.5703e-01, -6.1719e-01,  ...,  5.9375e-01,
           5.2344e-01, -6.9531e-01],
         [ 1.2969e+00,  8.9062e-01, -1.2188e+00,  ...,  2.5312e+00,
          -1.3516e+00, -2.0469e+00]],

        [[ 4.9744e-03,  5.5237e-03,  7.2937e-03,  ..., -2.2949e-02,
           8.1787e-03,  1.5723e-01],
         [-1.3750e+00, -8.7891e-01, -8.5156e-01,  ...,  1.4453e+00,
          -2.7734e-01, -3.7812e+00],
         [-8.2812e-01, -5.9375e-01,  1.8652e-01,  ...,  3.1875e+00,
           2.5977e-01,  1.8555e-01],
         ...,
         [-1.0156e+00,  6.3281e-01, -1.9727e-01,  ..., -1.7500e+00,
          -2.6250e+00, -3.7656e+00],
         [ 2.1680e-01,  4.5703e-01, -6.1719e-01,  ...,  5.9375e-01,
           5.2344e-01, -6.9531e-01],
         [ 1.2969e+00,  8.9062e-01, -1.2188e+00,  ...,  2.5312e+00,
          -1.3516e+00, -2.0469e+00]],

        [[ 4.9744e-03,  5.5237e-03,  7.2937e-03,  ..., -2.2949e-02,
           8.1787e-03,  1.5723e-01],
         [-1.3750e+00, -8.7891e-01, -8.5156e-01,  ...,  1.4453e+00,
          -2.7734e-01, -3.7812e+00],
         [-8.2812e-01, -5.9375e-01,  1.8652e-01,  ...,  3.1875e+00,
           2.5977e-01,  1.8555e-01],
         ...,
         [-1.0156e+00,  6.3281e-01, -1.9727e-01,  ..., -1.7500e+00,
          -2.6250e+00, -3.7656e+00],
         [ 2.1680e-01,  4.5703e-01, -6.1719e-01,  ...,  5.9375e-01,
           5.2344e-01, -6.9531e-01],
         [ 1.2969e+00,  8.9062e-01, -1.2188e+00,  ...,  2.5312e+00,
          -1.3516e+00, -2.0469e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-0.0063, -0.0132, -0.0074,  ...,  0.0058, -0.0055, -0.0167],
         [-1.8672, -0.9336, -1.5859,  ..., -1.7891, -0.7305, -0.5938],
         [-1.5156,  0.0366,  0.4688,  ..., -0.1069,  0.0542,  0.1040],
         ...,
         [ 0.1250,  0.2812, -1.1172,  ...,  0.2754, -0.7305,  0.0072],
         [ 0.8047, -0.1641,  0.0850,  ..., -0.5703, -0.4941,  0.3301],
         [-1.2109,  0.2578,  1.3516,  ...,  0.1367,  0.1660,  0.7227]],

        [[-0.0063, -0.0132, -0.0074,  ...,  0.0058, -0.0055, -0.0167],
         [-1.8672, -0.9336, -1.5859,  ..., -1.7891, -0.7305, -0.5938],
         [-1.5156,  0.0366,  0.4688,  ..., -0.1069,  0.0542,  0.1040],
         ...,
         [ 0.1250,  0.2812, -1.1172,  ...,  0.2754, -0.7305,  0.0072],
         [ 0.8047, -0.1641,  0.0850,  ..., -0.5703, -0.4941,  0.3301],
         [-1.2109,  0.2578,  1.3516,  ...,  0.1367,  0.1660,  0.7227]],

        [[-0.0063, -0.0132, -0.0074,  ...,  0.0058, -0.0055, -0.0167],
         [-1.8672, -0.9336, -1.5859,  ..., -1.7891, -0.7305, -0.5938],
         [-1.5156,  0.0366,  0.4688,  ..., -0.1069,  0.0542,  0.1040],
         ...,
         [ 0.1250,  0.2812, -1.1172,  ...,  0.2754, -0.7305,  0.0072],
         [ 0.8047, -0.1641,  0.0850,  ..., -0.5703, -0.4941,  0.3301],
         [-1.2109,  0.2578,  1.3516,  ...,  0.1367,  0.1660,  0.7227]],

        ...,

        [[ 0.0068,  0.0091,  0.0122,  ..., -0.0236, -0.0041,  0.0050],
         [ 0.5156,  0.7461,  1.4219,  ...,  1.0859,  1.1172,  0.5859],
         [-0.3027,  0.3945,  1.0859,  ...,  0.7852,  0.3652, -0.1992],
         ...,
         [-0.3750, -0.2832, -0.0659,  ...,  0.2617, -0.7578,  0.3301],
         [ 0.1924,  0.1963, -0.1709,  ...,  0.2207,  0.0713,  0.5234],
         [-0.9453, -0.0046,  0.9141,  ...,  0.2100, -0.6445,  0.6133]],

        [[ 0.0068,  0.0091,  0.0122,  ..., -0.0236, -0.0041,  0.0050],
         [ 0.5156,  0.7461,  1.4219,  ...,  1.0859,  1.1172,  0.5859],
         [-0.3027,  0.3945,  1.0859,  ...,  0.7852,  0.3652, -0.1992],
         ...,
         [-0.3750, -0.2832, -0.0659,  ...,  0.2617, -0.7578,  0.3301],
         [ 0.1924,  0.1963, -0.1709,  ...,  0.2207,  0.0713,  0.5234],
         [-0.9453, -0.0046,  0.9141,  ...,  0.2100, -0.6445,  0.6133]],

        [[ 0.0068,  0.0091,  0.0122,  ..., -0.0236, -0.0041,  0.0050],
         [ 0.5156,  0.7461,  1.4219,  ...,  1.0859,  1.1172,  0.5859],
         [-0.3027,  0.3945,  1.0859,  ...,  0.7852,  0.3652, -0.1992],
         ...,
         [-0.3750, -0.2832, -0.0659,  ...,  0.2617, -0.7578,  0.3301],
         [ 0.1924,  0.1963, -0.1709,  ...,  0.2207,  0.0713,  0.5234],
         [-0.9453, -0.0046,  0.9141,  ...,  0.2100, -0.6445,  0.6133]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[ 2.4261e-03,  2.3804e-03,  2.8076e-03,  ..., -3.6377e-02,
          -2.4219e-01,  8.2812e-01],
         [-4.5312e-01, -1.1094e+00,  3.9844e-01,  ...,  6.1328e-01,
          -1.2354e-01, -7.6875e+00],
         [-8.2422e-01, -1.3516e+00,  9.1406e-01,  ..., -1.0859e+00,
           2.6875e+00, -9.0625e+00],
         ...,
         [ 3.4375e-01,  1.1562e+00, -3.6719e-01,  ...,  3.4844e+00,
          -5.7500e+00, -7.5000e+00],
         [-1.6309e-01,  1.9336e-01, -9.5703e-01,  ...,  5.5078e-01,
          -2.3594e+00, -7.0625e+00],
         [-1.6016e+00,  1.0156e+00, -9.3359e-01,  ..., -8.0566e-02,
           1.5938e+00, -8.0625e+00]],

        [[ 2.4261e-03,  2.3804e-03,  2.8076e-03,  ..., -3.6377e-02,
          -2.4219e-01,  8.2812e-01],
         [-4.5312e-01, -1.1094e+00,  3.9844e-01,  ...,  6.1328e-01,
          -1.2354e-01, -7.6875e+00],
         [-8.2422e-01, -1.3516e+00,  9.1406e-01,  ..., -1.0859e+00,
           2.6875e+00, -9.0625e+00],
         ...,
         [ 3.4375e-01,  1.1562e+00, -3.6719e-01,  ...,  3.4844e+00,
          -5.7500e+00, -7.5000e+00],
         [-1.6309e-01,  1.9336e-01, -9.5703e-01,  ...,  5.5078e-01,
          -2.3594e+00, -7.0625e+00],
         [-1.6016e+00,  1.0156e+00, -9.3359e-01,  ..., -8.0566e-02,
           1.5938e+00, -8.0625e+00]],

        [[ 2.4261e-03,  2.3804e-03,  2.8076e-03,  ..., -3.6377e-02,
          -2.4219e-01,  8.2812e-01],
         [-4.5312e-01, -1.1094e+00,  3.9844e-01,  ...,  6.1328e-01,
          -1.2354e-01, -7.6875e+00],
         [-8.2422e-01, -1.3516e+00,  9.1406e-01,  ..., -1.0859e+00,
           2.6875e+00, -9.0625e+00],
         ...,
         [ 3.4375e-01,  1.1562e+00, -3.6719e-01,  ...,  3.4844e+00,
          -5.7500e+00, -7.5000e+00],
         [-1.6309e-01,  1.9336e-01, -9.5703e-01,  ...,  5.5078e-01,
          -2.3594e+00, -7.0625e+00],
         [-1.6016e+00,  1.0156e+00, -9.3359e-01,  ..., -8.0566e-02,
           1.5938e+00, -8.0625e+00]],

        ...,

        [[ 1.5747e-02, -5.4016e-03,  8.7891e-03,  ..., -1.3184e-02,
          -5.9326e-02,  3.2812e-01],
         [-2.4688e+00,  2.1875e-01, -9.1797e-01,  ..., -9.3750e-01,
           1.6250e+00, -2.6875e+00],
         [-2.8125e-01,  2.2461e-01, -4.1406e-01,  ..., -2.9844e+00,
           1.8555e-01, -2.4531e+00],
         ...,
         [-1.2266e+00,  4.6680e-01,  9.8047e-01,  ..., -1.6172e+00,
          -1.7871e-01, -3.3281e+00],
         [-4.3750e-01,  2.4121e-01,  1.2891e-01,  ..., -3.5000e+00,
          -9.5312e-01, -9.8047e-01],
         [ 7.6172e-02, -5.3125e-01, -3.3398e-01,  ..., -3.4688e+00,
           2.4844e+00, -1.8047e+00]],

        [[ 1.5747e-02, -5.4016e-03,  8.7891e-03,  ..., -1.3184e-02,
          -5.9326e-02,  3.2812e-01],
         [-2.4688e+00,  2.1875e-01, -9.1797e-01,  ..., -9.3750e-01,
           1.6250e+00, -2.6875e+00],
         [-2.8125e-01,  2.2461e-01, -4.1406e-01,  ..., -2.9844e+00,
           1.8555e-01, -2.4531e+00],
         ...,
         [-1.2266e+00,  4.6680e-01,  9.8047e-01,  ..., -1.6172e+00,
          -1.7871e-01, -3.3281e+00],
         [-4.3750e-01,  2.4121e-01,  1.2891e-01,  ..., -3.5000e+00,
          -9.5312e-01, -9.8047e-01],
         [ 7.6172e-02, -5.3125e-01, -3.3398e-01,  ..., -3.4688e+00,
           2.4844e+00, -1.8047e+00]],

        [[ 1.5747e-02, -5.4016e-03,  8.7891e-03,  ..., -1.3184e-02,
          -5.9326e-02,  3.2812e-01],
         [-2.4688e+00,  2.1875e-01, -9.1797e-01,  ..., -9.3750e-01,
           1.6250e+00, -2.6875e+00],
         [-2.8125e-01,  2.2461e-01, -4.1406e-01,  ..., -2.9844e+00,
           1.8555e-01, -2.4531e+00],
         ...,
         [-1.2266e+00,  4.6680e-01,  9.8047e-01,  ..., -1.6172e+00,
          -1.7871e-01, -3.3281e+00],
         [-4.3750e-01,  2.4121e-01,  1.2891e-01,  ..., -3.5000e+00,
          -9.5312e-01, -9.8047e-01],
         [ 7.6172e-02, -5.3125e-01, -3.3398e-01,  ..., -3.4688e+00,
           2.4844e+00, -1.8047e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-5.2795e-03,  2.1648e-04, -1.0498e-02,  ..., -1.2054e-03,
           1.2665e-03,  1.0254e-02],
         [ 1.1016e+00, -6.9922e-01,  2.5391e-01,  ..., -3.6914e-01,
           2.3438e-01, -1.0791e-01],
         [ 1.4551e-01, -2.8198e-02,  4.4727e-01,  ..., -1.1719e+00,
           6.7188e-01, -6.7188e-01],
         ...,
         [ 1.6719e+00,  1.6797e+00,  2.3438e+00,  ...,  2.0469e+00,
          -4.2188e+00,  3.6875e+00],
         [ 1.1953e+00,  7.4609e-01,  1.7578e-01,  ...,  4.6387e-02,
          -6.5234e-01,  1.0781e+00],
         [-3.6914e-01,  2.9883e-01,  6.5625e-01,  ..., -9.6484e-01,
           5.4688e-02, -3.2227e-01]],

        [[-5.2795e-03,  2.1648e-04, -1.0498e-02,  ..., -1.2054e-03,
           1.2665e-03,  1.0254e-02],
         [ 1.1016e+00, -6.9922e-01,  2.5391e-01,  ..., -3.6914e-01,
           2.3438e-01, -1.0791e-01],
         [ 1.4551e-01, -2.8198e-02,  4.4727e-01,  ..., -1.1719e+00,
           6.7188e-01, -6.7188e-01],
         ...,
         [ 1.6719e+00,  1.6797e+00,  2.3438e+00,  ...,  2.0469e+00,
          -4.2188e+00,  3.6875e+00],
         [ 1.1953e+00,  7.4609e-01,  1.7578e-01,  ...,  4.6387e-02,
          -6.5234e-01,  1.0781e+00],
         [-3.6914e-01,  2.9883e-01,  6.5625e-01,  ..., -9.6484e-01,
           5.4688e-02, -3.2227e-01]],

        [[-5.2795e-03,  2.1648e-04, -1.0498e-02,  ..., -1.2054e-03,
           1.2665e-03,  1.0254e-02],
         [ 1.1016e+00, -6.9922e-01,  2.5391e-01,  ..., -3.6914e-01,
           2.3438e-01, -1.0791e-01],
         [ 1.4551e-01, -2.8198e-02,  4.4727e-01,  ..., -1.1719e+00,
           6.7188e-01, -6.7188e-01],
         ...,
         [ 1.6719e+00,  1.6797e+00,  2.3438e+00,  ...,  2.0469e+00,
          -4.2188e+00,  3.6875e+00],
         [ 1.1953e+00,  7.4609e-01,  1.7578e-01,  ...,  4.6387e-02,
          -6.5234e-01,  1.0781e+00],
         [-3.6914e-01,  2.9883e-01,  6.5625e-01,  ..., -9.6484e-01,
           5.4688e-02, -3.2227e-01]],

        ...,

        [[-1.2054e-03,  1.3611e-02, -3.5553e-03,  ..., -1.2817e-03,
          -1.3123e-03,  2.8839e-03],
         [-1.1377e-01, -1.1875e+00,  1.1719e+00,  ..., -1.6250e+00,
           4.6680e-01, -6.9531e-01],
         [ 1.9453e+00,  1.0840e-01,  1.4038e-02,  ...,  8.9844e-02,
          -5.5078e-01,  4.6484e-01],
         ...,
         [-3.7500e-01, -8.2031e-01,  7.1484e-01,  ..., -9.7656e-02,
           2.6172e-01, -9.3750e-01],
         [ 1.0156e-01, -3.9648e-01,  3.5547e-01,  ...,  7.2656e-01,
           5.0000e-01, -8.2031e-01],
         [-3.0273e-01,  2.1094e-01, -3.9453e-01,  ...,  1.2109e+00,
           2.5586e-01,  1.6113e-01]],

        [[-1.2054e-03,  1.3611e-02, -3.5553e-03,  ..., -1.2817e-03,
          -1.3123e-03,  2.8839e-03],
         [-1.1377e-01, -1.1875e+00,  1.1719e+00,  ..., -1.6250e+00,
           4.6680e-01, -6.9531e-01],
         [ 1.9453e+00,  1.0840e-01,  1.4038e-02,  ...,  8.9844e-02,
          -5.5078e-01,  4.6484e-01],
         ...,
         [-3.7500e-01, -8.2031e-01,  7.1484e-01,  ..., -9.7656e-02,
           2.6172e-01, -9.3750e-01],
         [ 1.0156e-01, -3.9648e-01,  3.5547e-01,  ...,  7.2656e-01,
           5.0000e-01, -8.2031e-01],
         [-3.0273e-01,  2.1094e-01, -3.9453e-01,  ...,  1.2109e+00,
           2.5586e-01,  1.6113e-01]],

        [[-1.2054e-03,  1.3611e-02, -3.5553e-03,  ..., -1.2817e-03,
          -1.3123e-03,  2.8839e-03],
         [-1.1377e-01, -1.1875e+00,  1.1719e+00,  ..., -1.6250e+00,
           4.6680e-01, -6.9531e-01],
         [ 1.9453e+00,  1.0840e-01,  1.4038e-02,  ...,  8.9844e-02,
          -5.5078e-01,  4.6484e-01],
         ...,
         [-3.7500e-01, -8.2031e-01,  7.1484e-01,  ..., -9.7656e-02,
           2.6172e-01, -9.3750e-01],
         [ 1.0156e-01, -3.9648e-01,  3.5547e-01,  ...,  7.2656e-01,
           5.0000e-01, -8.2031e-01],
         [-3.0273e-01,  2.1094e-01, -3.9453e-01,  ...,  1.2109e+00,
           2.5586e-01,  1.6113e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 1.3977e-02, -1.9302e-03,  5.9509e-03,  ..., -5.0537e-02,
          -9.7656e-02, -5.2734e-01],
         [-1.3203e+00, -7.7344e-01, -8.3984e-01,  ..., -8.1641e-01,
          -6.2109e-01,  8.3750e+00],
         [ 3.7500e-01,  5.3125e-01, -1.0938e+00,  ..., -1.4062e-01,
           9.2285e-02,  8.3750e+00],
         ...,
         [-1.5156e+00,  1.6484e+00,  1.5938e+00,  ...,  2.2500e+00,
           6.4062e-01,  7.3750e+00],
         [ 4.4141e-01,  4.9805e-01, -2.7344e-02,  ...,  1.3906e+00,
           1.1719e+00,  7.0938e+00],
         [ 1.9531e+00, -1.7578e+00,  5.5078e-01,  ...,  1.3750e+00,
           1.6562e+00,  7.6875e+00]],

        [[ 1.3977e-02, -1.9302e-03,  5.9509e-03,  ..., -5.0537e-02,
          -9.7656e-02, -5.2734e-01],
         [-1.3203e+00, -7.7344e-01, -8.3984e-01,  ..., -8.1641e-01,
          -6.2109e-01,  8.3750e+00],
         [ 3.7500e-01,  5.3125e-01, -1.0938e+00,  ..., -1.4062e-01,
           9.2285e-02,  8.3750e+00],
         ...,
         [-1.5156e+00,  1.6484e+00,  1.5938e+00,  ...,  2.2500e+00,
           6.4062e-01,  7.3750e+00],
         [ 4.4141e-01,  4.9805e-01, -2.7344e-02,  ...,  1.3906e+00,
           1.1719e+00,  7.0938e+00],
         [ 1.9531e+00, -1.7578e+00,  5.5078e-01,  ...,  1.3750e+00,
           1.6562e+00,  7.6875e+00]],

        [[ 1.3977e-02, -1.9302e-03,  5.9509e-03,  ..., -5.0537e-02,
          -9.7656e-02, -5.2734e-01],
         [-1.3203e+00, -7.7344e-01, -8.3984e-01,  ..., -8.1641e-01,
          -6.2109e-01,  8.3750e+00],
         [ 3.7500e-01,  5.3125e-01, -1.0938e+00,  ..., -1.4062e-01,
           9.2285e-02,  8.3750e+00],
         ...,
         [-1.5156e+00,  1.6484e+00,  1.5938e+00,  ...,  2.2500e+00,
           6.4062e-01,  7.3750e+00],
         [ 4.4141e-01,  4.9805e-01, -2.7344e-02,  ...,  1.3906e+00,
           1.1719e+00,  7.0938e+00],
         [ 1.9531e+00, -1.7578e+00,  5.5078e-01,  ...,  1.3750e+00,
           1.6562e+00,  7.6875e+00]],

        ...,

        [[ 3.8719e-04, -5.5237e-03,  4.5013e-04,  ...,  4.7852e-02,
           4.2114e-03,  6.7383e-02],
         [-2.9062e+00,  1.9297e+00,  1.4219e+00,  ...,  3.7305e-01,
          -9.2578e-01,  1.9766e+00],
         [-1.4062e+00,  7.8906e-01,  4.8047e-01,  ..., -1.0840e-01,
           1.9922e+00,  8.5938e-01],
         ...,
         [-1.0312e+00, -8.5156e-01, -5.5469e-01,  ...,  9.4531e-01,
           5.7422e-01, -3.4570e-01],
         [-1.9531e-02, -4.3555e-01,  7.1289e-02,  ...,  1.6113e-01,
           1.8125e+00, -1.6484e+00],
         [ 2.8516e-01, -3.0664e-01,  1.0938e+00,  ..., -2.6367e-01,
          -1.0547e+00, -5.6250e-01]],

        [[ 3.8719e-04, -5.5237e-03,  4.5013e-04,  ...,  4.7852e-02,
           4.2114e-03,  6.7383e-02],
         [-2.9062e+00,  1.9297e+00,  1.4219e+00,  ...,  3.7305e-01,
          -9.2578e-01,  1.9766e+00],
         [-1.4062e+00,  7.8906e-01,  4.8047e-01,  ..., -1.0840e-01,
           1.9922e+00,  8.5938e-01],
         ...,
         [-1.0312e+00, -8.5156e-01, -5.5469e-01,  ...,  9.4531e-01,
           5.7422e-01, -3.4570e-01],
         [-1.9531e-02, -4.3555e-01,  7.1289e-02,  ...,  1.6113e-01,
           1.8125e+00, -1.6484e+00],
         [ 2.8516e-01, -3.0664e-01,  1.0938e+00,  ..., -2.6367e-01,
          -1.0547e+00, -5.6250e-01]],

        [[ 3.8719e-04, -5.5237e-03,  4.5013e-04,  ...,  4.7852e-02,
           4.2114e-03,  6.7383e-02],
         [-2.9062e+00,  1.9297e+00,  1.4219e+00,  ...,  3.7305e-01,
          -9.2578e-01,  1.9766e+00],
         [-1.4062e+00,  7.8906e-01,  4.8047e-01,  ..., -1.0840e-01,
           1.9922e+00,  8.5938e-01],
         ...,
         [-1.0312e+00, -8.5156e-01, -5.5469e-01,  ...,  9.4531e-01,
           5.7422e-01, -3.4570e-01],
         [-1.9531e-02, -4.3555e-01,  7.1289e-02,  ...,  1.6113e-01,
           1.8125e+00, -1.6484e+00],
         [ 2.8516e-01, -3.0664e-01,  1.0938e+00,  ..., -2.6367e-01,
          -1.0547e+00, -5.6250e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 6.1035e-03,  4.6997e-03,  3.7384e-04,  ..., -1.2817e-03,
           8.5449e-03,  3.8757e-03],
         [-1.0234e+00, -6.3672e-01, -7.3047e-01,  ...,  4.5312e-01,
          -1.2598e-01,  1.3750e+00],
         [ 2.8516e-01, -2.5312e+00, -1.6309e-01,  ...,  1.1953e+00,
           8.6328e-01,  5.8594e-01],
         ...,
         [ 2.3281e+00,  3.5742e-01, -1.3516e+00,  ..., -4.5312e-01,
           1.7578e-01, -5.1562e-01],
         [ 1.1016e+00, -2.5781e-01, -1.2656e+00,  ..., -2.7539e-01,
           1.0391e+00, -6.2891e-01],
         [ 4.8633e-01, -1.0625e+00, -8.7500e-01,  ..., -9.7656e-01,
           1.9824e-01, -6.6797e-01]],

        [[ 6.1035e-03,  4.6997e-03,  3.7384e-04,  ..., -1.2817e-03,
           8.5449e-03,  3.8757e-03],
         [-1.0234e+00, -6.3672e-01, -7.3047e-01,  ...,  4.5312e-01,
          -1.2598e-01,  1.3750e+00],
         [ 2.8516e-01, -2.5312e+00, -1.6309e-01,  ...,  1.1953e+00,
           8.6328e-01,  5.8594e-01],
         ...,
         [ 2.3281e+00,  3.5742e-01, -1.3516e+00,  ..., -4.5312e-01,
           1.7578e-01, -5.1562e-01],
         [ 1.1016e+00, -2.5781e-01, -1.2656e+00,  ..., -2.7539e-01,
           1.0391e+00, -6.2891e-01],
         [ 4.8633e-01, -1.0625e+00, -8.7500e-01,  ..., -9.7656e-01,
           1.9824e-01, -6.6797e-01]],

        [[ 6.1035e-03,  4.6997e-03,  3.7384e-04,  ..., -1.2817e-03,
           8.5449e-03,  3.8757e-03],
         [-1.0234e+00, -6.3672e-01, -7.3047e-01,  ...,  4.5312e-01,
          -1.2598e-01,  1.3750e+00],
         [ 2.8516e-01, -2.5312e+00, -1.6309e-01,  ...,  1.1953e+00,
           8.6328e-01,  5.8594e-01],
         ...,
         [ 2.3281e+00,  3.5742e-01, -1.3516e+00,  ..., -4.5312e-01,
           1.7578e-01, -5.1562e-01],
         [ 1.1016e+00, -2.5781e-01, -1.2656e+00,  ..., -2.7539e-01,
           1.0391e+00, -6.2891e-01],
         [ 4.8633e-01, -1.0625e+00, -8.7500e-01,  ..., -9.7656e-01,
           1.9824e-01, -6.6797e-01]],

        ...,

        [[-3.0136e-04, -1.9379e-03, -3.3569e-03,  ..., -4.6692e-03,
          -7.5989e-03, -9.1934e-04],
         [-2.0117e-01, -1.7969e+00,  1.0791e-01,  ...,  2.6406e+00,
          -6.2500e-01, -4.7656e-01],
         [-1.2734e+00,  4.7070e-01, -9.9609e-01,  ..., -2.3438e+00,
           1.2344e+00,  9.4141e-01],
         ...,
         [-3.1445e-01, -6.9141e-01, -4.0039e-02,  ..., -4.6680e-01,
           1.4258e-01, -2.6367e-01],
         [-5.5859e-01,  1.3984e+00, -3.5742e-01,  ..., -1.3047e+00,
           1.1641e+00,  4.7070e-01],
         [-3.7305e-01,  7.2266e-02,  7.2754e-02,  ..., -2.1094e-01,
           8.2812e-01, -2.4316e-01]],

        [[-3.0136e-04, -1.9379e-03, -3.3569e-03,  ..., -4.6692e-03,
          -7.5989e-03, -9.1934e-04],
         [-2.0117e-01, -1.7969e+00,  1.0791e-01,  ...,  2.6406e+00,
          -6.2500e-01, -4.7656e-01],
         [-1.2734e+00,  4.7070e-01, -9.9609e-01,  ..., -2.3438e+00,
           1.2344e+00,  9.4141e-01],
         ...,
         [-3.1445e-01, -6.9141e-01, -4.0039e-02,  ..., -4.6680e-01,
           1.4258e-01, -2.6367e-01],
         [-5.5859e-01,  1.3984e+00, -3.5742e-01,  ..., -1.3047e+00,
           1.1641e+00,  4.7070e-01],
         [-3.7305e-01,  7.2266e-02,  7.2754e-02,  ..., -2.1094e-01,
           8.2812e-01, -2.4316e-01]],

        [[-3.0136e-04, -1.9379e-03, -3.3569e-03,  ..., -4.6692e-03,
          -7.5989e-03, -9.1934e-04],
         [-2.0117e-01, -1.7969e+00,  1.0791e-01,  ...,  2.6406e+00,
          -6.2500e-01, -4.7656e-01],
         [-1.2734e+00,  4.7070e-01, -9.9609e-01,  ..., -2.3438e+00,
           1.2344e+00,  9.4141e-01],
         ...,
         [-3.1445e-01, -6.9141e-01, -4.0039e-02,  ..., -4.6680e-01,
           1.4258e-01, -2.6367e-01],
         [-5.5859e-01,  1.3984e+00, -3.5742e-01,  ..., -1.3047e+00,
           1.1641e+00,  4.7070e-01],
         [-3.7305e-01,  7.2266e-02,  7.2754e-02,  ..., -2.1094e-01,
           8.2812e-01, -2.4316e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-1.1215e-03,  1.2283e-03,  4.0588e-03,  ...,  6.5430e-02,
           7.5684e-03, -4.1748e-02],
         [-5.1953e-01, -8.5938e-01, -1.2812e+00,  ..., -8.3203e-01,
           1.4141e+00,  1.0859e+00],
         [-1.6719e+00, -1.0156e+00, -2.2500e+00,  ..., -1.5859e+00,
          -1.4297e+00,  2.0781e+00],
         ...,
         [-1.0059e-01, -7.3828e-01,  8.1250e-01,  ..., -2.7500e+00,
           2.2278e-03,  2.7500e+00],
         [-9.7656e-04,  1.2061e-01,  6.2891e-01,  ..., -4.0000e+00,
           6.6797e-01,  2.4531e+00],
         [-1.3750e+00,  1.5625e+00,  3.8281e-01,  ..., -2.5156e+00,
          -1.7969e-01,  2.0625e+00]],

        [[-1.1215e-03,  1.2283e-03,  4.0588e-03,  ...,  6.5430e-02,
           7.5684e-03, -4.1748e-02],
         [-5.1953e-01, -8.5938e-01, -1.2812e+00,  ..., -8.3203e-01,
           1.4141e+00,  1.0859e+00],
         [-1.6719e+00, -1.0156e+00, -2.2500e+00,  ..., -1.5859e+00,
          -1.4297e+00,  2.0781e+00],
         ...,
         [-1.0059e-01, -7.3828e-01,  8.1250e-01,  ..., -2.7500e+00,
           2.2278e-03,  2.7500e+00],
         [-9.7656e-04,  1.2061e-01,  6.2891e-01,  ..., -4.0000e+00,
           6.6797e-01,  2.4531e+00],
         [-1.3750e+00,  1.5625e+00,  3.8281e-01,  ..., -2.5156e+00,
          -1.7969e-01,  2.0625e+00]],

        [[-1.1215e-03,  1.2283e-03,  4.0588e-03,  ...,  6.5430e-02,
           7.5684e-03, -4.1748e-02],
         [-5.1953e-01, -8.5938e-01, -1.2812e+00,  ..., -8.3203e-01,
           1.4141e+00,  1.0859e+00],
         [-1.6719e+00, -1.0156e+00, -2.2500e+00,  ..., -1.5859e+00,
          -1.4297e+00,  2.0781e+00],
         ...,
         [-1.0059e-01, -7.3828e-01,  8.1250e-01,  ..., -2.7500e+00,
           2.2278e-03,  2.7500e+00],
         [-9.7656e-04,  1.2061e-01,  6.2891e-01,  ..., -4.0000e+00,
           6.6797e-01,  2.4531e+00],
         [-1.3750e+00,  1.5625e+00,  3.8281e-01,  ..., -2.5156e+00,
          -1.7969e-01,  2.0625e+00]],

        ...,

        [[-2.6245e-03,  7.7820e-03, -2.7924e-03,  ...,  9.4604e-03,
           4.3640e-03, -3.2227e-01],
         [-5.1172e-01, -1.6953e+00, -1.0156e+00,  ..., -1.5156e+00,
           2.1250e+00,  1.0250e+01],
         [-1.2812e+00, -1.2109e+00, -8.0078e-01,  ..., -3.9648e-01,
           2.2559e-01,  1.0125e+01],
         ...,
         [-2.9297e-01, -2.2852e-01,  3.7305e-01,  ..., -1.7383e-01,
           2.5938e+00,  7.5938e+00],
         [-8.0859e-01,  6.4453e-01,  1.6406e-01,  ..., -1.0781e+00,
           9.4141e-01,  7.9375e+00],
         [-9.7266e-01,  1.1406e+00, -6.1328e-01,  ..., -1.6953e+00,
           4.5117e-01,  8.5000e+00]],

        [[-2.6245e-03,  7.7820e-03, -2.7924e-03,  ...,  9.4604e-03,
           4.3640e-03, -3.2227e-01],
         [-5.1172e-01, -1.6953e+00, -1.0156e+00,  ..., -1.5156e+00,
           2.1250e+00,  1.0250e+01],
         [-1.2812e+00, -1.2109e+00, -8.0078e-01,  ..., -3.9648e-01,
           2.2559e-01,  1.0125e+01],
         ...,
         [-2.9297e-01, -2.2852e-01,  3.7305e-01,  ..., -1.7383e-01,
           2.5938e+00,  7.5938e+00],
         [-8.0859e-01,  6.4453e-01,  1.6406e-01,  ..., -1.0781e+00,
           9.4141e-01,  7.9375e+00],
         [-9.7266e-01,  1.1406e+00, -6.1328e-01,  ..., -1.6953e+00,
           4.5117e-01,  8.5000e+00]],

        [[-2.6245e-03,  7.7820e-03, -2.7924e-03,  ...,  9.4604e-03,
           4.3640e-03, -3.2227e-01],
         [-5.1172e-01, -1.6953e+00, -1.0156e+00,  ..., -1.5156e+00,
           2.1250e+00,  1.0250e+01],
         [-1.2812e+00, -1.2109e+00, -8.0078e-01,  ..., -3.9648e-01,
           2.2559e-01,  1.0125e+01],
         ...,
         [-2.9297e-01, -2.2852e-01,  3.7305e-01,  ..., -1.7383e-01,
           2.5938e+00,  7.5938e+00],
         [-8.0859e-01,  6.4453e-01,  1.6406e-01,  ..., -1.0781e+00,
           9.4141e-01,  7.9375e+00],
         [-9.7266e-01,  1.1406e+00, -6.1328e-01,  ..., -1.6953e+00,
           4.5117e-01,  8.5000e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 5.8289e-03, -2.0905e-03, -2.3041e-03,  ...,  6.0654e-04,
           6.7139e-03, -1.5717e-03],
         [ 5.4688e-01,  5.8899e-03,  2.3145e-01,  ...,  1.2695e-01,
           4.5508e-01,  6.8750e-01],
         [ 9.8828e-01, -4.2969e-01,  9.2969e-01,  ..., -1.8828e+00,
           1.0547e+00, -1.2578e+00],
         ...,
         [ 1.1875e+00, -5.3906e-01, -8.0078e-02,  ..., -6.2891e-01,
           9.9609e-02,  1.0205e-01],
         [ 4.2773e-01,  7.8125e-01,  2.7100e-02,  ..., -6.7578e-01,
          -9.2969e-01,  8.0469e-01],
         [ 1.7285e-01,  2.3047e-01,  6.9824e-02,  ...,  1.3672e-01,
           1.0547e+00, -6.7188e-01]],

        [[ 5.8289e-03, -2.0905e-03, -2.3041e-03,  ...,  6.0654e-04,
           6.7139e-03, -1.5717e-03],
         [ 5.4688e-01,  5.8899e-03,  2.3145e-01,  ...,  1.2695e-01,
           4.5508e-01,  6.8750e-01],
         [ 9.8828e-01, -4.2969e-01,  9.2969e-01,  ..., -1.8828e+00,
           1.0547e+00, -1.2578e+00],
         ...,
         [ 1.1875e+00, -5.3906e-01, -8.0078e-02,  ..., -6.2891e-01,
           9.9609e-02,  1.0205e-01],
         [ 4.2773e-01,  7.8125e-01,  2.7100e-02,  ..., -6.7578e-01,
          -9.2969e-01,  8.0469e-01],
         [ 1.7285e-01,  2.3047e-01,  6.9824e-02,  ...,  1.3672e-01,
           1.0547e+00, -6.7188e-01]],

        [[ 5.8289e-03, -2.0905e-03, -2.3041e-03,  ...,  6.0654e-04,
           6.7139e-03, -1.5717e-03],
         [ 5.4688e-01,  5.8899e-03,  2.3145e-01,  ...,  1.2695e-01,
           4.5508e-01,  6.8750e-01],
         [ 9.8828e-01, -4.2969e-01,  9.2969e-01,  ..., -1.8828e+00,
           1.0547e+00, -1.2578e+00],
         ...,
         [ 1.1875e+00, -5.3906e-01, -8.0078e-02,  ..., -6.2891e-01,
           9.9609e-02,  1.0205e-01],
         [ 4.2773e-01,  7.8125e-01,  2.7100e-02,  ..., -6.7578e-01,
          -9.2969e-01,  8.0469e-01],
         [ 1.7285e-01,  2.3047e-01,  6.9824e-02,  ...,  1.3672e-01,
           1.0547e+00, -6.7188e-01]],

        ...,

        [[-3.1128e-03, -7.8583e-04,  5.3406e-03,  ..., -7.8125e-03,
           1.1963e-02, -1.4420e-03],
         [ 8.1250e-01, -4.7461e-01, -8.2812e-01,  ...,  4.3555e-01,
          -1.8125e+00, -8.7891e-01],
         [-1.0391e+00, -8.7500e-01, -4.8242e-01,  ..., -1.1963e-01,
          -1.3984e+00, -1.8359e-01],
         ...,
         [ 2.9883e-01,  3.4180e-01,  3.6914e-01,  ...,  9.6094e-01,
          -6.0156e-01, -1.7676e-01],
         [ 8.3203e-01, -1.9434e-01,  5.6250e-01,  ...,  5.4297e-01,
          -1.6797e-01,  4.9805e-01],
         [ 9.6484e-01, -6.8359e-01,  7.0312e-01,  ..., -4.8047e-01,
           1.8262e-01, -3.4180e-01]],

        [[-3.1128e-03, -7.8583e-04,  5.3406e-03,  ..., -7.8125e-03,
           1.1963e-02, -1.4420e-03],
         [ 8.1250e-01, -4.7461e-01, -8.2812e-01,  ...,  4.3555e-01,
          -1.8125e+00, -8.7891e-01],
         [-1.0391e+00, -8.7500e-01, -4.8242e-01,  ..., -1.1963e-01,
          -1.3984e+00, -1.8359e-01],
         ...,
         [ 2.9883e-01,  3.4180e-01,  3.6914e-01,  ...,  9.6094e-01,
          -6.0156e-01, -1.7676e-01],
         [ 8.3203e-01, -1.9434e-01,  5.6250e-01,  ...,  5.4297e-01,
          -1.6797e-01,  4.9805e-01],
         [ 9.6484e-01, -6.8359e-01,  7.0312e-01,  ..., -4.8047e-01,
           1.8262e-01, -3.4180e-01]],

        [[-3.1128e-03, -7.8583e-04,  5.3406e-03,  ..., -7.8125e-03,
           1.1963e-02, -1.4420e-03],
         [ 8.1250e-01, -4.7461e-01, -8.2812e-01,  ...,  4.3555e-01,
          -1.8125e+00, -8.7891e-01],
         [-1.0391e+00, -8.7500e-01, -4.8242e-01,  ..., -1.1963e-01,
          -1.3984e+00, -1.8359e-01],
         ...,
         [ 2.9883e-01,  3.4180e-01,  3.6914e-01,  ...,  9.6094e-01,
          -6.0156e-01, -1.7676e-01],
         [ 8.3203e-01, -1.9434e-01,  5.6250e-01,  ...,  5.4297e-01,
          -1.6797e-01,  4.9805e-01],
         [ 9.6484e-01, -6.8359e-01,  7.0312e-01,  ..., -4.8047e-01,
           1.8262e-01, -3.4180e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-2.0508e-02,  6.1340e-03, -8.1787e-03,  ..., -1.7285e-01,
          -1.0010e-01,  9.5703e-02],
         [ 2.2656e+00, -1.7500e+00,  1.6562e+00,  ...,  1.0625e+00,
           8.9453e-01, -1.7969e+00],
         [ 1.6172e+00, -1.5234e+00,  1.7031e+00,  ...,  7.7344e-01,
           8.0078e-02,  5.8984e-01],
         ...,
         [ 1.1094e+00, -5.0391e-01, -7.0703e-01,  ...,  1.2578e+00,
           1.0703e+00, -1.4922e+00],
         [ 3.9062e-01,  2.0264e-02, -1.5234e-01,  ...,  1.1719e+00,
           1.2266e+00, -4.7852e-01],
         [ 1.0193e-02,  3.9062e-01, -2.1484e-01,  ...,  1.2422e+00,
           3.1094e+00, -9.8828e-01]],

        [[-2.0508e-02,  6.1340e-03, -8.1787e-03,  ..., -1.7285e-01,
          -1.0010e-01,  9.5703e-02],
         [ 2.2656e+00, -1.7500e+00,  1.6562e+00,  ...,  1.0625e+00,
           8.9453e-01, -1.7969e+00],
         [ 1.6172e+00, -1.5234e+00,  1.7031e+00,  ...,  7.7344e-01,
           8.0078e-02,  5.8984e-01],
         ...,
         [ 1.1094e+00, -5.0391e-01, -7.0703e-01,  ...,  1.2578e+00,
           1.0703e+00, -1.4922e+00],
         [ 3.9062e-01,  2.0264e-02, -1.5234e-01,  ...,  1.1719e+00,
           1.2266e+00, -4.7852e-01],
         [ 1.0193e-02,  3.9062e-01, -2.1484e-01,  ...,  1.2422e+00,
           3.1094e+00, -9.8828e-01]],

        [[-2.0508e-02,  6.1340e-03, -8.1787e-03,  ..., -1.7285e-01,
          -1.0010e-01,  9.5703e-02],
         [ 2.2656e+00, -1.7500e+00,  1.6562e+00,  ...,  1.0625e+00,
           8.9453e-01, -1.7969e+00],
         [ 1.6172e+00, -1.5234e+00,  1.7031e+00,  ...,  7.7344e-01,
           8.0078e-02,  5.8984e-01],
         ...,
         [ 1.1094e+00, -5.0391e-01, -7.0703e-01,  ...,  1.2578e+00,
           1.0703e+00, -1.4922e+00],
         [ 3.9062e-01,  2.0264e-02, -1.5234e-01,  ...,  1.1719e+00,
           1.2266e+00, -4.7852e-01],
         [ 1.0193e-02,  3.9062e-01, -2.1484e-01,  ...,  1.2422e+00,
           3.1094e+00, -9.8828e-01]],

        ...,

        [[-7.2021e-03, -4.3335e-03,  4.3335e-03,  ..., -1.3245e-02,
          -7.4609e-01,  1.4941e-01],
         [-1.0156e+00,  2.8125e+00,  1.6016e-01,  ...,  1.0391e+00,
           7.6562e+00, -8.5156e-01],
         [-2.3750e+00,  1.4844e+00,  1.1094e+00,  ...,  5.0781e-01,
           8.9375e+00, -2.2812e+00],
         ...,
         [-1.7969e-01, -7.8125e-02, -1.3184e-01,  ...,  6.5234e-01,
           7.0938e+00, -4.5312e-01],
         [-1.3438e+00, -1.0625e+00, -1.1094e+00,  ..., -4.7852e-01,
           7.4688e+00, -7.8516e-01],
         [-1.3516e+00, -1.6953e+00, -1.4688e+00,  ...,  1.2734e+00,
           7.4688e+00, -8.6328e-01]],

        [[-7.2021e-03, -4.3335e-03,  4.3335e-03,  ..., -1.3245e-02,
          -7.4609e-01,  1.4941e-01],
         [-1.0156e+00,  2.8125e+00,  1.6016e-01,  ...,  1.0391e+00,
           7.6562e+00, -8.5156e-01],
         [-2.3750e+00,  1.4844e+00,  1.1094e+00,  ...,  5.0781e-01,
           8.9375e+00, -2.2812e+00],
         ...,
         [-1.7969e-01, -7.8125e-02, -1.3184e-01,  ...,  6.5234e-01,
           7.0938e+00, -4.5312e-01],
         [-1.3438e+00, -1.0625e+00, -1.1094e+00,  ..., -4.7852e-01,
           7.4688e+00, -7.8516e-01],
         [-1.3516e+00, -1.6953e+00, -1.4688e+00,  ...,  1.2734e+00,
           7.4688e+00, -8.6328e-01]],

        [[-7.2021e-03, -4.3335e-03,  4.3335e-03,  ..., -1.3245e-02,
          -7.4609e-01,  1.4941e-01],
         [-1.0156e+00,  2.8125e+00,  1.6016e-01,  ...,  1.0391e+00,
           7.6562e+00, -8.5156e-01],
         [-2.3750e+00,  1.4844e+00,  1.1094e+00,  ...,  5.0781e-01,
           8.9375e+00, -2.2812e+00],
         ...,
         [-1.7969e-01, -7.8125e-02, -1.3184e-01,  ...,  6.5234e-01,
           7.0938e+00, -4.5312e-01],
         [-1.3438e+00, -1.0625e+00, -1.1094e+00,  ..., -4.7852e-01,
           7.4688e+00, -7.8516e-01],
         [-1.3516e+00, -1.6953e+00, -1.4688e+00,  ...,  1.2734e+00,
           7.4688e+00, -8.6328e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 2.0752e-03, -2.1210e-03, -6.5002e-03,  ...,  5.2185e-03,
          -5.4626e-03, -2.9144e-03],
         [ 8.2397e-03, -2.9102e-01,  2.8516e-01,  ..., -9.3750e-01,
          -5.7812e-01,  6.2109e-01],
         [ 3.7500e-01, -6.1328e-01,  1.0547e-01,  ..., -1.8438e+00,
           4.2969e-01, -1.8281e+00],
         ...,
         [ 4.7070e-01, -1.4141e+00, -3.8086e-01,  ..., -7.5684e-02,
          -1.0625e+00,  1.6309e-01],
         [ 3.5742e-01, -9.2578e-01, -9.1016e-01,  ...,  8.3984e-01,
          -1.1016e+00, -1.3281e+00],
         [ 1.0156e+00, -9.0234e-01, -1.9434e-01,  ...,  4.3555e-01,
          -1.3672e+00, -9.7266e-01]],

        [[ 2.0752e-03, -2.1210e-03, -6.5002e-03,  ...,  5.2185e-03,
          -5.4626e-03, -2.9144e-03],
         [ 8.2397e-03, -2.9102e-01,  2.8516e-01,  ..., -9.3750e-01,
          -5.7812e-01,  6.2109e-01],
         [ 3.7500e-01, -6.1328e-01,  1.0547e-01,  ..., -1.8438e+00,
           4.2969e-01, -1.8281e+00],
         ...,
         [ 4.7070e-01, -1.4141e+00, -3.8086e-01,  ..., -7.5684e-02,
          -1.0625e+00,  1.6309e-01],
         [ 3.5742e-01, -9.2578e-01, -9.1016e-01,  ...,  8.3984e-01,
          -1.1016e+00, -1.3281e+00],
         [ 1.0156e+00, -9.0234e-01, -1.9434e-01,  ...,  4.3555e-01,
          -1.3672e+00, -9.7266e-01]],

        [[ 2.0752e-03, -2.1210e-03, -6.5002e-03,  ...,  5.2185e-03,
          -5.4626e-03, -2.9144e-03],
         [ 8.2397e-03, -2.9102e-01,  2.8516e-01,  ..., -9.3750e-01,
          -5.7812e-01,  6.2109e-01],
         [ 3.7500e-01, -6.1328e-01,  1.0547e-01,  ..., -1.8438e+00,
           4.2969e-01, -1.8281e+00],
         ...,
         [ 4.7070e-01, -1.4141e+00, -3.8086e-01,  ..., -7.5684e-02,
          -1.0625e+00,  1.6309e-01],
         [ 3.5742e-01, -9.2578e-01, -9.1016e-01,  ...,  8.3984e-01,
          -1.1016e+00, -1.3281e+00],
         [ 1.0156e+00, -9.0234e-01, -1.9434e-01,  ...,  4.3555e-01,
          -1.3672e+00, -9.7266e-01]],

        ...,

        [[ 4.7302e-03,  9.9182e-04, -4.4556e-03,  ...,  1.8921e-03,
          -6.1340e-03,  2.1820e-03],
         [ 5.3906e-01,  8.6426e-02, -3.4424e-02,  ...,  9.8438e-01,
           8.1250e-01,  7.6172e-01],
         [-3.0859e-01,  3.6719e-01,  1.8906e+00,  ...,  1.7871e-01,
           6.7871e-02, -2.1562e+00],
         ...,
         [-5.3516e-01,  1.5015e-02, -2.0630e-02,  ..., -5.3467e-02,
           4.6387e-02, -8.9844e-02],
         [ 2.6562e-01,  8.6719e-01, -6.9922e-01,  ..., -1.2656e+00,
          -3.1641e-01, -8.0469e-01],
         [-2.8516e-01, -5.2490e-02,  8.6719e-01,  ..., -4.3555e-01,
          -8.6426e-02,  1.1768e-01]],

        [[ 4.7302e-03,  9.9182e-04, -4.4556e-03,  ...,  1.8921e-03,
          -6.1340e-03,  2.1820e-03],
         [ 5.3906e-01,  8.6426e-02, -3.4424e-02,  ...,  9.8438e-01,
           8.1250e-01,  7.6172e-01],
         [-3.0859e-01,  3.6719e-01,  1.8906e+00,  ...,  1.7871e-01,
           6.7871e-02, -2.1562e+00],
         ...,
         [-5.3516e-01,  1.5015e-02, -2.0630e-02,  ..., -5.3467e-02,
           4.6387e-02, -8.9844e-02],
         [ 2.6562e-01,  8.6719e-01, -6.9922e-01,  ..., -1.2656e+00,
          -3.1641e-01, -8.0469e-01],
         [-2.8516e-01, -5.2490e-02,  8.6719e-01,  ..., -4.3555e-01,
          -8.6426e-02,  1.1768e-01]],

        [[ 4.7302e-03,  9.9182e-04, -4.4556e-03,  ...,  1.8921e-03,
          -6.1340e-03,  2.1820e-03],
         [ 5.3906e-01,  8.6426e-02, -3.4424e-02,  ...,  9.8438e-01,
           8.1250e-01,  7.6172e-01],
         [-3.0859e-01,  3.6719e-01,  1.8906e+00,  ...,  1.7871e-01,
           6.7871e-02, -2.1562e+00],
         ...,
         [-5.3516e-01,  1.5015e-02, -2.0630e-02,  ..., -5.3467e-02,
           4.6387e-02, -8.9844e-02],
         [ 2.6562e-01,  8.6719e-01, -6.9922e-01,  ..., -1.2656e+00,
          -3.1641e-01, -8.0469e-01],
         [-2.8516e-01, -5.2490e-02,  8.6719e-01,  ..., -4.3555e-01,
          -8.6426e-02,  1.1768e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 3.7994e-03, -2.5146e-02, -1.2360e-03,  ..., -2.3633e-01,
          -9.6680e-02,  1.0059e-01],
         [ 2.5000e-01, -3.6328e-01,  3.5156e-01,  ..., -8.7402e-02,
          -9.6484e-01, -1.3516e+00],
         [ 1.1797e+00, -1.0859e+00, -7.4609e-01,  ..., -5.4062e+00,
          -4.2500e+00, -3.3750e+00],
         ...,
         [ 9.9609e-02,  9.0234e-01,  2.5781e-01,  ..., -5.1250e+00,
          -7.9297e-01, -1.4375e+00],
         [ 8.1641e-01,  2.0703e-01,  5.8203e-01,  ..., -4.9688e+00,
           1.2207e-01,  1.0234e+00],
         [ 1.4922e+00,  1.2305e-01,  2.4121e-01,  ..., -3.5156e+00,
          -3.0469e+00, -1.1250e+00]],

        [[ 3.7994e-03, -2.5146e-02, -1.2360e-03,  ..., -2.3633e-01,
          -9.6680e-02,  1.0059e-01],
         [ 2.5000e-01, -3.6328e-01,  3.5156e-01,  ..., -8.7402e-02,
          -9.6484e-01, -1.3516e+00],
         [ 1.1797e+00, -1.0859e+00, -7.4609e-01,  ..., -5.4062e+00,
          -4.2500e+00, -3.3750e+00],
         ...,
         [ 9.9609e-02,  9.0234e-01,  2.5781e-01,  ..., -5.1250e+00,
          -7.9297e-01, -1.4375e+00],
         [ 8.1641e-01,  2.0703e-01,  5.8203e-01,  ..., -4.9688e+00,
           1.2207e-01,  1.0234e+00],
         [ 1.4922e+00,  1.2305e-01,  2.4121e-01,  ..., -3.5156e+00,
          -3.0469e+00, -1.1250e+00]],

        [[ 3.7994e-03, -2.5146e-02, -1.2360e-03,  ..., -2.3633e-01,
          -9.6680e-02,  1.0059e-01],
         [ 2.5000e-01, -3.6328e-01,  3.5156e-01,  ..., -8.7402e-02,
          -9.6484e-01, -1.3516e+00],
         [ 1.1797e+00, -1.0859e+00, -7.4609e-01,  ..., -5.4062e+00,
          -4.2500e+00, -3.3750e+00],
         ...,
         [ 9.9609e-02,  9.0234e-01,  2.5781e-01,  ..., -5.1250e+00,
          -7.9297e-01, -1.4375e+00],
         [ 8.1641e-01,  2.0703e-01,  5.8203e-01,  ..., -4.9688e+00,
           1.2207e-01,  1.0234e+00],
         [ 1.4922e+00,  1.2305e-01,  2.4121e-01,  ..., -3.5156e+00,
          -3.0469e+00, -1.1250e+00]],

        ...,

        [[-8.1177e-03, -2.3041e-03, -8.8501e-04,  ...,  4.3213e-02,
           1.1426e-01,  1.2305e-01],
         [ 1.2969e+00,  1.7734e+00,  1.0703e+00,  ...,  2.6406e+00,
           9.8828e-01,  1.5723e-01],
         [ 1.7188e-01,  6.2500e-01,  2.2344e+00,  ...,  1.0625e+00,
          -7.8906e-01,  2.1484e-01],
         ...,
         [-8.8867e-02,  3.2617e-01,  2.4219e-01,  ...,  4.5703e-01,
           4.7656e-01, -2.8281e+00],
         [ 1.9141e-01,  4.1016e-01,  6.0156e-01,  ...,  2.6172e-01,
           9.1309e-02, -3.1055e-01],
         [ 1.2266e+00, -9.2188e-01,  2.0781e+00,  ...,  2.1719e+00,
           9.9219e-01, -1.4297e+00]],

        [[-8.1177e-03, -2.3041e-03, -8.8501e-04,  ...,  4.3213e-02,
           1.1426e-01,  1.2305e-01],
         [ 1.2969e+00,  1.7734e+00,  1.0703e+00,  ...,  2.6406e+00,
           9.8828e-01,  1.5723e-01],
         [ 1.7188e-01,  6.2500e-01,  2.2344e+00,  ...,  1.0625e+00,
          -7.8906e-01,  2.1484e-01],
         ...,
         [-8.8867e-02,  3.2617e-01,  2.4219e-01,  ...,  4.5703e-01,
           4.7656e-01, -2.8281e+00],
         [ 1.9141e-01,  4.1016e-01,  6.0156e-01,  ...,  2.6172e-01,
           9.1309e-02, -3.1055e-01],
         [ 1.2266e+00, -9.2188e-01,  2.0781e+00,  ...,  2.1719e+00,
           9.9219e-01, -1.4297e+00]],

        [[-8.1177e-03, -2.3041e-03, -8.8501e-04,  ...,  4.3213e-02,
           1.1426e-01,  1.2305e-01],
         [ 1.2969e+00,  1.7734e+00,  1.0703e+00,  ...,  2.6406e+00,
           9.8828e-01,  1.5723e-01],
         [ 1.7188e-01,  6.2500e-01,  2.2344e+00,  ...,  1.0625e+00,
          -7.8906e-01,  2.1484e-01],
         ...,
         [-8.8867e-02,  3.2617e-01,  2.4219e-01,  ...,  4.5703e-01,
           4.7656e-01, -2.8281e+00],
         [ 1.9141e-01,  4.1016e-01,  6.0156e-01,  ...,  2.6172e-01,
           9.1309e-02, -3.1055e-01],
         [ 1.2266e+00, -9.2188e-01,  2.0781e+00,  ...,  2.1719e+00,
           9.9219e-01, -1.4297e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 7.2327e-03, -4.8218e-03, -5.1575e-03,  ...,  1.7090e-03,
           1.6479e-03,  5.5847e-03],
         [-1.0781e+00, -9.1406e-01,  2.9688e+00,  ..., -1.7031e+00,
          -5.6641e-01, -1.1016e+00],
         [-1.9219e+00, -4.1562e+00,  4.4688e+00,  ..., -7.5781e-01,
          -1.3359e+00, -9.6094e-01],
         ...,
         [-1.5625e+00, -6.8750e-01, -4.9072e-02,  ...,  1.3906e+00,
           5.8984e-01,  1.4922e+00],
         [-3.9062e-01, -1.6953e+00,  1.7500e+00,  ..., -6.4844e-01,
           4.9609e-01,  1.6328e+00],
         [-1.2031e+00, -2.3594e+00,  1.7812e+00,  ..., -2.5195e-01,
          -3.6133e-01,  2.5586e-01]],

        [[ 7.2327e-03, -4.8218e-03, -5.1575e-03,  ...,  1.7090e-03,
           1.6479e-03,  5.5847e-03],
         [-1.0781e+00, -9.1406e-01,  2.9688e+00,  ..., -1.7031e+00,
          -5.6641e-01, -1.1016e+00],
         [-1.9219e+00, -4.1562e+00,  4.4688e+00,  ..., -7.5781e-01,
          -1.3359e+00, -9.6094e-01],
         ...,
         [-1.5625e+00, -6.8750e-01, -4.9072e-02,  ...,  1.3906e+00,
           5.8984e-01,  1.4922e+00],
         [-3.9062e-01, -1.6953e+00,  1.7500e+00,  ..., -6.4844e-01,
           4.9609e-01,  1.6328e+00],
         [-1.2031e+00, -2.3594e+00,  1.7812e+00,  ..., -2.5195e-01,
          -3.6133e-01,  2.5586e-01]],

        [[ 7.2327e-03, -4.8218e-03, -5.1575e-03,  ...,  1.7090e-03,
           1.6479e-03,  5.5847e-03],
         [-1.0781e+00, -9.1406e-01,  2.9688e+00,  ..., -1.7031e+00,
          -5.6641e-01, -1.1016e+00],
         [-1.9219e+00, -4.1562e+00,  4.4688e+00,  ..., -7.5781e-01,
          -1.3359e+00, -9.6094e-01],
         ...,
         [-1.5625e+00, -6.8750e-01, -4.9072e-02,  ...,  1.3906e+00,
           5.8984e-01,  1.4922e+00],
         [-3.9062e-01, -1.6953e+00,  1.7500e+00,  ..., -6.4844e-01,
           4.9609e-01,  1.6328e+00],
         [-1.2031e+00, -2.3594e+00,  1.7812e+00,  ..., -2.5195e-01,
          -3.6133e-01,  2.5586e-01]],

        ...,

        [[ 5.9509e-04,  1.7090e-02, -3.9368e-03,  ...,  3.6011e-03,
           3.3569e-03, -2.1362e-03],
         [-1.0547e-01,  1.1963e-01,  1.7266e+00,  ..., -1.3359e+00,
          -6.9141e-01, -1.7422e+00],
         [-4.1809e-03, -8.7402e-02,  5.8594e-01,  ..., -8.2422e-01,
          -4.3701e-02, -3.7598e-02],
         ...,
         [-1.4453e+00, -1.0596e-01,  4.8242e-01,  ..., -1.8750e+00,
          -2.7344e+00, -1.2344e+00],
         [-8.1250e-01, -5.6250e-01, -2.3828e-01,  ..., -1.0859e+00,
          -1.8359e+00, -1.8750e+00],
         [-1.5078e+00,  7.9297e-01,  8.1250e-01,  ..., -1.9531e+00,
          -5.8203e-01,  1.2500e+00]],

        [[ 5.9509e-04,  1.7090e-02, -3.9368e-03,  ...,  3.6011e-03,
           3.3569e-03, -2.1362e-03],
         [-1.0547e-01,  1.1963e-01,  1.7266e+00,  ..., -1.3359e+00,
          -6.9141e-01, -1.7422e+00],
         [-4.1809e-03, -8.7402e-02,  5.8594e-01,  ..., -8.2422e-01,
          -4.3701e-02, -3.7598e-02],
         ...,
         [-1.4453e+00, -1.0596e-01,  4.8242e-01,  ..., -1.8750e+00,
          -2.7344e+00, -1.2344e+00],
         [-8.1250e-01, -5.6250e-01, -2.3828e-01,  ..., -1.0859e+00,
          -1.8359e+00, -1.8750e+00],
         [-1.5078e+00,  7.9297e-01,  8.1250e-01,  ..., -1.9531e+00,
          -5.8203e-01,  1.2500e+00]],

        [[ 5.9509e-04,  1.7090e-02, -3.9368e-03,  ...,  3.6011e-03,
           3.3569e-03, -2.1362e-03],
         [-1.0547e-01,  1.1963e-01,  1.7266e+00,  ..., -1.3359e+00,
          -6.9141e-01, -1.7422e+00],
         [-4.1809e-03, -8.7402e-02,  5.8594e-01,  ..., -8.2422e-01,
          -4.3701e-02, -3.7598e-02],
         ...,
         [-1.4453e+00, -1.0596e-01,  4.8242e-01,  ..., -1.8750e+00,
          -2.7344e+00, -1.2344e+00],
         [-8.1250e-01, -5.6250e-01, -2.3828e-01,  ..., -1.0859e+00,
          -1.8359e+00, -1.8750e+00],
         [-1.5078e+00,  7.9297e-01,  8.1250e-01,  ..., -1.9531e+00,
          -5.8203e-01,  1.2500e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-1.2024e-02,  5.1270e-03,  5.8289e-03,  ..., -1.3594e+00,
           1.7090e-03,  2.0215e-01],
         [ 3.5547e-01,  1.9609e+00, -6.4844e-01,  ...,  9.5625e+00,
           1.6113e-01, -2.9062e+00],
         [-9.1016e-01, -1.1621e-01,  2.7734e-01,  ...,  1.0000e+01,
           1.7188e+00, -6.8359e-01],
         ...,
         [ 1.7090e-01, -5.3516e-01, -1.1377e-01,  ...,  9.2500e+00,
           1.3477e-01, -1.3984e+00],
         [-1.1406e+00, -1.6328e+00, -5.9766e-01,  ...,  8.8125e+00,
           3.7891e-01, -1.2812e+00],
         [-1.8281e+00, -3.5156e-01, -1.2812e+00,  ...,  8.8125e+00,
          -8.1250e-01, -2.5781e+00]],

        [[-1.2024e-02,  5.1270e-03,  5.8289e-03,  ..., -1.3594e+00,
           1.7090e-03,  2.0215e-01],
         [ 3.5547e-01,  1.9609e+00, -6.4844e-01,  ...,  9.5625e+00,
           1.6113e-01, -2.9062e+00],
         [-9.1016e-01, -1.1621e-01,  2.7734e-01,  ...,  1.0000e+01,
           1.7188e+00, -6.8359e-01],
         ...,
         [ 1.7090e-01, -5.3516e-01, -1.1377e-01,  ...,  9.2500e+00,
           1.3477e-01, -1.3984e+00],
         [-1.1406e+00, -1.6328e+00, -5.9766e-01,  ...,  8.8125e+00,
           3.7891e-01, -1.2812e+00],
         [-1.8281e+00, -3.5156e-01, -1.2812e+00,  ...,  8.8125e+00,
          -8.1250e-01, -2.5781e+00]],

        [[-1.2024e-02,  5.1270e-03,  5.8289e-03,  ..., -1.3594e+00,
           1.7090e-03,  2.0215e-01],
         [ 3.5547e-01,  1.9609e+00, -6.4844e-01,  ...,  9.5625e+00,
           1.6113e-01, -2.9062e+00],
         [-9.1016e-01, -1.1621e-01,  2.7734e-01,  ...,  1.0000e+01,
           1.7188e+00, -6.8359e-01],
         ...,
         [ 1.7090e-01, -5.3516e-01, -1.1377e-01,  ...,  9.2500e+00,
           1.3477e-01, -1.3984e+00],
         [-1.1406e+00, -1.6328e+00, -5.9766e-01,  ...,  8.8125e+00,
           3.7891e-01, -1.2812e+00],
         [-1.8281e+00, -3.5156e-01, -1.2812e+00,  ...,  8.8125e+00,
          -8.1250e-01, -2.5781e+00]],

        ...,

        [[ 9.3384e-03,  1.5198e-02, -1.2146e-02,  ...,  7.6660e-02,
          -3.2422e-01,  1.1523e-01],
         [-1.5078e+00,  2.6406e+00,  1.2656e+00,  ...,  2.8711e-01,
           2.0469e+00,  1.1523e-01],
         [-7.0312e-01,  9.1309e-02,  5.9375e-01,  ...,  4.6875e+00,
          -4.1562e+00,  1.5234e+00],
         ...,
         [-1.0781e+00,  3.7500e-01, -1.4648e-02,  ...,  1.6406e+00,
          -3.6719e+00,  2.9688e+00],
         [ 5.0781e-01, -5.5664e-02,  2.1387e-01,  ...,  7.9688e-01,
          -3.6562e+00,  3.0781e+00],
         [ 2.4531e+00, -9.6875e-01,  5.7031e-01,  ...,  1.5703e+00,
          -1.7500e+00,  1.7344e+00]],

        [[ 9.3384e-03,  1.5198e-02, -1.2146e-02,  ...,  7.6660e-02,
          -3.2422e-01,  1.1523e-01],
         [-1.5078e+00,  2.6406e+00,  1.2656e+00,  ...,  2.8711e-01,
           2.0469e+00,  1.1523e-01],
         [-7.0312e-01,  9.1309e-02,  5.9375e-01,  ...,  4.6875e+00,
          -4.1562e+00,  1.5234e+00],
         ...,
         [-1.0781e+00,  3.7500e-01, -1.4648e-02,  ...,  1.6406e+00,
          -3.6719e+00,  2.9688e+00],
         [ 5.0781e-01, -5.5664e-02,  2.1387e-01,  ...,  7.9688e-01,
          -3.6562e+00,  3.0781e+00],
         [ 2.4531e+00, -9.6875e-01,  5.7031e-01,  ...,  1.5703e+00,
          -1.7500e+00,  1.7344e+00]],

        [[ 9.3384e-03,  1.5198e-02, -1.2146e-02,  ...,  7.6660e-02,
          -3.2422e-01,  1.1523e-01],
         [-1.5078e+00,  2.6406e+00,  1.2656e+00,  ...,  2.8711e-01,
           2.0469e+00,  1.1523e-01],
         [-7.0312e-01,  9.1309e-02,  5.9375e-01,  ...,  4.6875e+00,
          -4.1562e+00,  1.5234e+00],
         ...,
         [-1.0781e+00,  3.7500e-01, -1.4648e-02,  ...,  1.6406e+00,
          -3.6719e+00,  2.9688e+00],
         [ 5.0781e-01, -5.5664e-02,  2.1387e-01,  ...,  7.9688e-01,
          -3.6562e+00,  3.0781e+00],
         [ 2.4531e+00, -9.6875e-01,  5.7031e-01,  ...,  1.5703e+00,
          -1.7500e+00,  1.7344e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-9.0942e-03,  5.8289e-03, -4.0894e-03,  ..., -1.3184e-02,
          -4.6997e-03,  2.1172e-04],
         [ 1.4062e-01, -7.3047e-01, -4.3164e-01,  ..., -9.0625e-01,
           6.6797e-01, -9.2969e-01],
         [-9.3359e-01, -2.1719e+00,  3.2959e-02,  ...,  9.1406e-01,
           5.8984e-01,  5.2734e-01],
         ...,
         [-4.8047e-01,  1.2793e-01,  8.3203e-01,  ...,  5.4688e-01,
           7.3828e-01,  2.0410e-01],
         [-7.6172e-01,  1.3125e+00, -8.0469e-01,  ...,  1.3203e+00,
           9.9609e-01,  9.3750e-02],
         [ 4.5703e-01,  1.8047e+00,  7.8516e-01,  ...,  1.1094e+00,
          -1.2207e-01, -3.5742e-01]],

        [[-9.0942e-03,  5.8289e-03, -4.0894e-03,  ..., -1.3184e-02,
          -4.6997e-03,  2.1172e-04],
         [ 1.4062e-01, -7.3047e-01, -4.3164e-01,  ..., -9.0625e-01,
           6.6797e-01, -9.2969e-01],
         [-9.3359e-01, -2.1719e+00,  3.2959e-02,  ...,  9.1406e-01,
           5.8984e-01,  5.2734e-01],
         ...,
         [-4.8047e-01,  1.2793e-01,  8.3203e-01,  ...,  5.4688e-01,
           7.3828e-01,  2.0410e-01],
         [-7.6172e-01,  1.3125e+00, -8.0469e-01,  ...,  1.3203e+00,
           9.9609e-01,  9.3750e-02],
         [ 4.5703e-01,  1.8047e+00,  7.8516e-01,  ...,  1.1094e+00,
          -1.2207e-01, -3.5742e-01]],

        [[-9.0942e-03,  5.8289e-03, -4.0894e-03,  ..., -1.3184e-02,
          -4.6997e-03,  2.1172e-04],
         [ 1.4062e-01, -7.3047e-01, -4.3164e-01,  ..., -9.0625e-01,
           6.6797e-01, -9.2969e-01],
         [-9.3359e-01, -2.1719e+00,  3.2959e-02,  ...,  9.1406e-01,
           5.8984e-01,  5.2734e-01],
         ...,
         [-4.8047e-01,  1.2793e-01,  8.3203e-01,  ...,  5.4688e-01,
           7.3828e-01,  2.0410e-01],
         [-7.6172e-01,  1.3125e+00, -8.0469e-01,  ...,  1.3203e+00,
           9.9609e-01,  9.3750e-02],
         [ 4.5703e-01,  1.8047e+00,  7.8516e-01,  ...,  1.1094e+00,
          -1.2207e-01, -3.5742e-01]],

        ...,

        [[ 3.6955e-05, -7.8735e-03,  3.1853e-04,  ...,  1.9531e-02,
           2.4986e-04, -6.5308e-03],
         [-1.0449e-01, -9.7656e-02,  5.8984e-01,  ...,  8.9844e-01,
          -1.6094e+00,  1.9434e-01],
         [-1.0781e+00,  3.4219e+00,  2.5000e+00,  ..., -6.2500e-01,
          -1.0859e+00,  1.2812e+00],
         ...,
         [-2.5195e-01, -3.1445e-01,  7.3438e-01,  ..., -4.7266e-01,
          -2.1094e+00,  3.3984e-01],
         [-1.0234e+00,  1.3672e+00,  1.4922e+00,  ..., -9.8438e-01,
          -3.2969e+00, -2.1191e-01],
         [-1.6875e+00,  1.4688e+00,  1.3594e+00,  ..., -3.1055e-01,
          -2.5156e+00,  1.2422e+00]],

        [[ 3.6955e-05, -7.8735e-03,  3.1853e-04,  ...,  1.9531e-02,
           2.4986e-04, -6.5308e-03],
         [-1.0449e-01, -9.7656e-02,  5.8984e-01,  ...,  8.9844e-01,
          -1.6094e+00,  1.9434e-01],
         [-1.0781e+00,  3.4219e+00,  2.5000e+00,  ..., -6.2500e-01,
          -1.0859e+00,  1.2812e+00],
         ...,
         [-2.5195e-01, -3.1445e-01,  7.3438e-01,  ..., -4.7266e-01,
          -2.1094e+00,  3.3984e-01],
         [-1.0234e+00,  1.3672e+00,  1.4922e+00,  ..., -9.8438e-01,
          -3.2969e+00, -2.1191e-01],
         [-1.6875e+00,  1.4688e+00,  1.3594e+00,  ..., -3.1055e-01,
          -2.5156e+00,  1.2422e+00]],

        [[ 3.6955e-05, -7.8735e-03,  3.1853e-04,  ...,  1.9531e-02,
           2.4986e-04, -6.5308e-03],
         [-1.0449e-01, -9.7656e-02,  5.8984e-01,  ...,  8.9844e-01,
          -1.6094e+00,  1.9434e-01],
         [-1.0781e+00,  3.4219e+00,  2.5000e+00,  ..., -6.2500e-01,
          -1.0859e+00,  1.2812e+00],
         ...,
         [-2.5195e-01, -3.1445e-01,  7.3438e-01,  ..., -4.7266e-01,
          -2.1094e+00,  3.3984e-01],
         [-1.0234e+00,  1.3672e+00,  1.4922e+00,  ..., -9.8438e-01,
          -3.2969e+00, -2.1191e-01],
         [-1.6875e+00,  1.4688e+00,  1.3594e+00,  ..., -3.1055e-01,
          -2.5156e+00,  1.2422e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 2.8076e-03, -3.0212e-03,  3.6774e-03,  ...,  4.7266e-01,
           5.4443e-02, -6.8848e-02],
         [-4.5625e+00,  8.1250e-01,  1.7500e+00,  ..., -5.6641e-01,
          -1.9141e-01,  8.0078e-02],
         [-1.1719e+00,  2.5938e+00,  1.4844e+00,  ..., -6.6406e-01,
           4.4727e-01,  4.9609e-01],
         ...,
         [-2.5312e+00,  8.7500e-01, -2.4062e+00,  ...,  9.1406e-01,
           4.8438e-01,  1.2969e+00],
         [-1.0938e+00,  3.2617e-01, -9.8438e-01,  ...,  5.5469e-01,
          -2.2754e-01, -7.5000e-01],
         [ 2.1094e+00, -2.6875e+00, -4.4727e-01,  ...,  1.5391e+00,
          -2.5781e+00, -1.7578e+00]],

        [[ 2.8076e-03, -3.0212e-03,  3.6774e-03,  ...,  4.7266e-01,
           5.4443e-02, -6.8848e-02],
         [-4.5625e+00,  8.1250e-01,  1.7500e+00,  ..., -5.6641e-01,
          -1.9141e-01,  8.0078e-02],
         [-1.1719e+00,  2.5938e+00,  1.4844e+00,  ..., -6.6406e-01,
           4.4727e-01,  4.9609e-01],
         ...,
         [-2.5312e+00,  8.7500e-01, -2.4062e+00,  ...,  9.1406e-01,
           4.8438e-01,  1.2969e+00],
         [-1.0938e+00,  3.2617e-01, -9.8438e-01,  ...,  5.5469e-01,
          -2.2754e-01, -7.5000e-01],
         [ 2.1094e+00, -2.6875e+00, -4.4727e-01,  ...,  1.5391e+00,
          -2.5781e+00, -1.7578e+00]],

        [[ 2.8076e-03, -3.0212e-03,  3.6774e-03,  ...,  4.7266e-01,
           5.4443e-02, -6.8848e-02],
         [-4.5625e+00,  8.1250e-01,  1.7500e+00,  ..., -5.6641e-01,
          -1.9141e-01,  8.0078e-02],
         [-1.1719e+00,  2.5938e+00,  1.4844e+00,  ..., -6.6406e-01,
           4.4727e-01,  4.9609e-01],
         ...,
         [-2.5312e+00,  8.7500e-01, -2.4062e+00,  ...,  9.1406e-01,
           4.8438e-01,  1.2969e+00],
         [-1.0938e+00,  3.2617e-01, -9.8438e-01,  ...,  5.5469e-01,
          -2.2754e-01, -7.5000e-01],
         [ 2.1094e+00, -2.6875e+00, -4.4727e-01,  ...,  1.5391e+00,
          -2.5781e+00, -1.7578e+00]],

        ...,

        [[ 1.2390e-02,  3.0670e-03, -8.9111e-03,  ...,  2.3438e-01,
          -5.9326e-02,  6.6895e-02],
         [-1.9043e-01, -1.5625e-02, -2.5000e-01,  ..., -2.9688e-01,
           6.7188e-01, -3.9258e-01],
         [-3.3789e-01,  2.4316e-01,  9.7168e-02,  ..., -1.0703e+00,
           5.2734e-01, -1.5000e+00],
         ...,
         [ 3.4961e-01,  6.8359e-01, -5.0000e-01,  ...,  5.1953e-01,
          -7.3438e-01, -6.1328e-01],
         [ 7.3242e-02,  9.3750e-02, -2.8516e-01,  ...,  1.1484e+00,
          -1.0107e-01,  5.0049e-02],
         [ 7.6953e-01, -7.6172e-01, -1.0469e+00,  ...,  1.2891e+00,
           7.8906e-01,  2.9907e-02]],

        [[ 1.2390e-02,  3.0670e-03, -8.9111e-03,  ...,  2.3438e-01,
          -5.9326e-02,  6.6895e-02],
         [-1.9043e-01, -1.5625e-02, -2.5000e-01,  ..., -2.9688e-01,
           6.7188e-01, -3.9258e-01],
         [-3.3789e-01,  2.4316e-01,  9.7168e-02,  ..., -1.0703e+00,
           5.2734e-01, -1.5000e+00],
         ...,
         [ 3.4961e-01,  6.8359e-01, -5.0000e-01,  ...,  5.1953e-01,
          -7.3438e-01, -6.1328e-01],
         [ 7.3242e-02,  9.3750e-02, -2.8516e-01,  ...,  1.1484e+00,
          -1.0107e-01,  5.0049e-02],
         [ 7.6953e-01, -7.6172e-01, -1.0469e+00,  ...,  1.2891e+00,
           7.8906e-01,  2.9907e-02]],

        [[ 1.2390e-02,  3.0670e-03, -8.9111e-03,  ...,  2.3438e-01,
          -5.9326e-02,  6.6895e-02],
         [-1.9043e-01, -1.5625e-02, -2.5000e-01,  ..., -2.9688e-01,
           6.7188e-01, -3.9258e-01],
         [-3.3789e-01,  2.4316e-01,  9.7168e-02,  ..., -1.0703e+00,
           5.2734e-01, -1.5000e+00],
         ...,
         [ 3.4961e-01,  6.8359e-01, -5.0000e-01,  ...,  5.1953e-01,
          -7.3438e-01, -6.1328e-01],
         [ 7.3242e-02,  9.3750e-02, -2.8516e-01,  ...,  1.1484e+00,
          -1.0107e-01,  5.0049e-02],
         [ 7.6953e-01, -7.6172e-01, -1.0469e+00,  ...,  1.2891e+00,
           7.8906e-01,  2.9907e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-3.1891e-03,  2.6703e-03, -3.2196e-03,  ..., -9.7656e-03,
          -9.3994e-03,  2.1667e-03],
         [-3.2422e-01,  1.2891e-01,  5.2344e-01,  ...,  4.4922e-01,
           1.2598e-01, -9.6484e-01],
         [ 8.7500e-01, -7.2656e-01,  8.9355e-02,  ...,  5.7812e-01,
          -7.8516e-01, -1.2793e-01],
         ...,
         [-4.2969e-01, -1.3047e+00, -1.1377e-01,  ...,  1.5625e+00,
           1.0781e+00, -1.0840e-01],
         [ 1.0000e+00, -3.8867e-01, -1.0547e+00,  ...,  3.3984e-01,
          -4.8633e-01, -2.3242e-01],
         [-1.3906e+00, -1.2578e+00, -5.5469e-01,  ...,  2.1094e+00,
           2.3438e+00,  1.1953e+00]],

        [[-3.1891e-03,  2.6703e-03, -3.2196e-03,  ..., -9.7656e-03,
          -9.3994e-03,  2.1667e-03],
         [-3.2422e-01,  1.2891e-01,  5.2344e-01,  ...,  4.4922e-01,
           1.2598e-01, -9.6484e-01],
         [ 8.7500e-01, -7.2656e-01,  8.9355e-02,  ...,  5.7812e-01,
          -7.8516e-01, -1.2793e-01],
         ...,
         [-4.2969e-01, -1.3047e+00, -1.1377e-01,  ...,  1.5625e+00,
           1.0781e+00, -1.0840e-01],
         [ 1.0000e+00, -3.8867e-01, -1.0547e+00,  ...,  3.3984e-01,
          -4.8633e-01, -2.3242e-01],
         [-1.3906e+00, -1.2578e+00, -5.5469e-01,  ...,  2.1094e+00,
           2.3438e+00,  1.1953e+00]],

        [[-3.1891e-03,  2.6703e-03, -3.2196e-03,  ..., -9.7656e-03,
          -9.3994e-03,  2.1667e-03],
         [-3.2422e-01,  1.2891e-01,  5.2344e-01,  ...,  4.4922e-01,
           1.2598e-01, -9.6484e-01],
         [ 8.7500e-01, -7.2656e-01,  8.9355e-02,  ...,  5.7812e-01,
          -7.8516e-01, -1.2793e-01],
         ...,
         [-4.2969e-01, -1.3047e+00, -1.1377e-01,  ...,  1.5625e+00,
           1.0781e+00, -1.0840e-01],
         [ 1.0000e+00, -3.8867e-01, -1.0547e+00,  ...,  3.3984e-01,
          -4.8633e-01, -2.3242e-01],
         [-1.3906e+00, -1.2578e+00, -5.5469e-01,  ...,  2.1094e+00,
           2.3438e+00,  1.1953e+00]],

        ...,

        [[-4.9133e-03, -5.0049e-03, -8.3008e-03,  ..., -8.8501e-04,
          -2.7618e-03,  1.4587e-02],
         [-5.5469e-01, -1.0859e+00, -3.4570e-01,  ..., -1.1641e+00,
           5.1172e-01,  1.9629e-01],
         [ 2.1875e-01, -9.4531e-01,  2.5781e-01,  ..., -7.8516e-01,
           5.6641e-01,  1.0000e+00],
         ...,
         [ 1.1353e-02,  6.7188e-01,  1.7969e+00,  ..., -1.8594e+00,
           8.3594e-01, -3.3398e-01],
         [ 1.1641e+00,  8.3984e-02,  4.1406e-01,  ..., -1.0000e+00,
          -1.2500e+00, -2.9688e-01],
         [-2.5330e-03, -2.1191e-01,  1.6016e+00,  ..., -9.2578e-01,
           1.8516e+00, -1.0000e+00]],

        [[-4.9133e-03, -5.0049e-03, -8.3008e-03,  ..., -8.8501e-04,
          -2.7618e-03,  1.4587e-02],
         [-5.5469e-01, -1.0859e+00, -3.4570e-01,  ..., -1.1641e+00,
           5.1172e-01,  1.9629e-01],
         [ 2.1875e-01, -9.4531e-01,  2.5781e-01,  ..., -7.8516e-01,
           5.6641e-01,  1.0000e+00],
         ...,
         [ 1.1353e-02,  6.7188e-01,  1.7969e+00,  ..., -1.8594e+00,
           8.3594e-01, -3.3398e-01],
         [ 1.1641e+00,  8.3984e-02,  4.1406e-01,  ..., -1.0000e+00,
          -1.2500e+00, -2.9688e-01],
         [-2.5330e-03, -2.1191e-01,  1.6016e+00,  ..., -9.2578e-01,
           1.8516e+00, -1.0000e+00]],

        [[-4.9133e-03, -5.0049e-03, -8.3008e-03,  ..., -8.8501e-04,
          -2.7618e-03,  1.4587e-02],
         [-5.5469e-01, -1.0859e+00, -3.4570e-01,  ..., -1.1641e+00,
           5.1172e-01,  1.9629e-01],
         [ 2.1875e-01, -9.4531e-01,  2.5781e-01,  ..., -7.8516e-01,
           5.6641e-01,  1.0000e+00],
         ...,
         [ 1.1353e-02,  6.7188e-01,  1.7969e+00,  ..., -1.8594e+00,
           8.3594e-01, -3.3398e-01],
         [ 1.1641e+00,  8.3984e-02,  4.1406e-01,  ..., -1.0000e+00,
          -1.2500e+00, -2.9688e-01],
         [-2.5330e-03, -2.1191e-01,  1.6016e+00,  ..., -9.2578e-01,
           1.8516e+00, -1.0000e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-1.6724e-02, -4.4861e-03, -2.8076e-03,  ...,  1.1523e-01,
           1.3574e-01, -1.5747e-02],
         [ 4.4141e-01,  3.1250e-01, -9.4141e-01,  ...,  7.8906e-01,
           4.4531e-01,  4.8438e-01],
         [-1.5469e+00,  1.9727e-01, -2.2188e+00,  ...,  6.8359e-01,
          -1.1426e-01,  2.2500e+00],
         ...,
         [-3.3594e-01,  6.7188e-01,  6.0156e-01,  ...,  7.7148e-02,
          -1.5625e+00, -1.3965e-01],
         [-1.4922e+00, -2.9883e-01,  5.7031e-01,  ...,  5.3125e-01,
          -4.9414e-01,  7.4219e-02],
         [-1.6875e+00, -1.0625e+00,  3.0859e-01,  ..., -1.7734e+00,
          -4.6094e-01,  9.8828e-01]],

        [[-1.6724e-02, -4.4861e-03, -2.8076e-03,  ...,  1.1523e-01,
           1.3574e-01, -1.5747e-02],
         [ 4.4141e-01,  3.1250e-01, -9.4141e-01,  ...,  7.8906e-01,
           4.4531e-01,  4.8438e-01],
         [-1.5469e+00,  1.9727e-01, -2.2188e+00,  ...,  6.8359e-01,
          -1.1426e-01,  2.2500e+00],
         ...,
         [-3.3594e-01,  6.7188e-01,  6.0156e-01,  ...,  7.7148e-02,
          -1.5625e+00, -1.3965e-01],
         [-1.4922e+00, -2.9883e-01,  5.7031e-01,  ...,  5.3125e-01,
          -4.9414e-01,  7.4219e-02],
         [-1.6875e+00, -1.0625e+00,  3.0859e-01,  ..., -1.7734e+00,
          -4.6094e-01,  9.8828e-01]],

        [[-1.6724e-02, -4.4861e-03, -2.8076e-03,  ...,  1.1523e-01,
           1.3574e-01, -1.5747e-02],
         [ 4.4141e-01,  3.1250e-01, -9.4141e-01,  ...,  7.8906e-01,
           4.4531e-01,  4.8438e-01],
         [-1.5469e+00,  1.9727e-01, -2.2188e+00,  ...,  6.8359e-01,
          -1.1426e-01,  2.2500e+00],
         ...,
         [-3.3594e-01,  6.7188e-01,  6.0156e-01,  ...,  7.7148e-02,
          -1.5625e+00, -1.3965e-01],
         [-1.4922e+00, -2.9883e-01,  5.7031e-01,  ...,  5.3125e-01,
          -4.9414e-01,  7.4219e-02],
         [-1.6875e+00, -1.0625e+00,  3.0859e-01,  ..., -1.7734e+00,
          -4.6094e-01,  9.8828e-01]],

        ...,

        [[ 2.2125e-03,  1.2573e-02,  1.5335e-03,  ...,  9.7168e-02,
          -4.8523e-03, -1.6504e-01],
         [ 1.8203e+00,  8.9062e-01, -2.1484e-01,  ..., -2.0801e-01,
          -4.8242e-01,  1.1562e+00],
         [ 1.4375e+00,  2.5156e+00, -8.4375e-01,  ...,  2.4805e-01,
          -2.2949e-01,  2.6406e+00],
         ...,
         [-3.7500e-01,  1.3984e+00,  5.2344e-01,  ..., -1.4219e+00,
          -1.3477e-01,  3.4180e-01],
         [ 1.0625e+00,  1.3594e+00, -2.8906e-01,  ...,  1.2500e+00,
           3.9258e-01,  1.4141e+00],
         [ 2.1250e+00,  1.5625e-01, -8.6719e-01,  ..., -6.8359e-01,
          -1.3086e-01,  2.7969e+00]],

        [[ 2.2125e-03,  1.2573e-02,  1.5335e-03,  ...,  9.7168e-02,
          -4.8523e-03, -1.6504e-01],
         [ 1.8203e+00,  8.9062e-01, -2.1484e-01,  ..., -2.0801e-01,
          -4.8242e-01,  1.1562e+00],
         [ 1.4375e+00,  2.5156e+00, -8.4375e-01,  ...,  2.4805e-01,
          -2.2949e-01,  2.6406e+00],
         ...,
         [-3.7500e-01,  1.3984e+00,  5.2344e-01,  ..., -1.4219e+00,
          -1.3477e-01,  3.4180e-01],
         [ 1.0625e+00,  1.3594e+00, -2.8906e-01,  ...,  1.2500e+00,
           3.9258e-01,  1.4141e+00],
         [ 2.1250e+00,  1.5625e-01, -8.6719e-01,  ..., -6.8359e-01,
          -1.3086e-01,  2.7969e+00]],

        [[ 2.2125e-03,  1.2573e-02,  1.5335e-03,  ...,  9.7168e-02,
          -4.8523e-03, -1.6504e-01],
         [ 1.8203e+00,  8.9062e-01, -2.1484e-01,  ..., -2.0801e-01,
          -4.8242e-01,  1.1562e+00],
         [ 1.4375e+00,  2.5156e+00, -8.4375e-01,  ...,  2.4805e-01,
          -2.2949e-01,  2.6406e+00],
         ...,
         [-3.7500e-01,  1.3984e+00,  5.2344e-01,  ..., -1.4219e+00,
          -1.3477e-01,  3.4180e-01],
         [ 1.0625e+00,  1.3594e+00, -2.8906e-01,  ...,  1.2500e+00,
           3.9258e-01,  1.4141e+00],
         [ 2.1250e+00,  1.5625e-01, -8.6719e-01,  ..., -6.8359e-01,
          -1.3086e-01,  2.7969e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 0.0064,  0.0118,  0.0106,  ...,  0.0147,  0.0085, -0.0025],
         [-1.1484, -1.8438,  0.8281,  ...,  1.4219, -1.5938, -0.3262],
         [-0.4883, -0.5859, -0.4316,  ...,  0.0128, -0.0199,  1.2422],
         ...,
         [-0.0113,  0.6250, -0.1709,  ..., -0.4199,  0.6523,  0.4590],
         [ 0.2715,  0.1885, -0.7188,  ...,  0.4238,  0.5508,  0.9023],
         [-0.2246, -1.1953, -0.4316,  ...,  0.0273,  1.0234,  0.8789]],

        [[ 0.0064,  0.0118,  0.0106,  ...,  0.0147,  0.0085, -0.0025],
         [-1.1484, -1.8438,  0.8281,  ...,  1.4219, -1.5938, -0.3262],
         [-0.4883, -0.5859, -0.4316,  ...,  0.0128, -0.0199,  1.2422],
         ...,
         [-0.0113,  0.6250, -0.1709,  ..., -0.4199,  0.6523,  0.4590],
         [ 0.2715,  0.1885, -0.7188,  ...,  0.4238,  0.5508,  0.9023],
         [-0.2246, -1.1953, -0.4316,  ...,  0.0273,  1.0234,  0.8789]],

        [[ 0.0064,  0.0118,  0.0106,  ...,  0.0147,  0.0085, -0.0025],
         [-1.1484, -1.8438,  0.8281,  ...,  1.4219, -1.5938, -0.3262],
         [-0.4883, -0.5859, -0.4316,  ...,  0.0128, -0.0199,  1.2422],
         ...,
         [-0.0113,  0.6250, -0.1709,  ..., -0.4199,  0.6523,  0.4590],
         [ 0.2715,  0.1885, -0.7188,  ...,  0.4238,  0.5508,  0.9023],
         [-0.2246, -1.1953, -0.4316,  ...,  0.0273,  1.0234,  0.8789]],

        ...,

        [[-0.0048,  0.0029, -0.0053,  ..., -0.0062, -0.0024,  0.0045],
         [ 0.4160, -0.6328,  0.5195,  ...,  0.7461, -0.2754,  0.5273],
         [-1.4375,  1.0703,  0.4629,  ..., -0.9336, -1.2734, -0.4043],
         ...,
         [-1.0938, -0.0304,  0.2852,  ..., -0.2266, -0.4199, -1.0000],
         [-2.0781, -0.8477, -1.4531,  ..., -0.8672, -1.1094, -0.0728],
         [ 1.1094,  0.2715, -0.0287,  ...,  0.6367,  0.2637, -0.0728]],

        [[-0.0048,  0.0029, -0.0053,  ..., -0.0062, -0.0024,  0.0045],
         [ 0.4160, -0.6328,  0.5195,  ...,  0.7461, -0.2754,  0.5273],
         [-1.4375,  1.0703,  0.4629,  ..., -0.9336, -1.2734, -0.4043],
         ...,
         [-1.0938, -0.0304,  0.2852,  ..., -0.2266, -0.4199, -1.0000],
         [-2.0781, -0.8477, -1.4531,  ..., -0.8672, -1.1094, -0.0728],
         [ 1.1094,  0.2715, -0.0287,  ...,  0.6367,  0.2637, -0.0728]],

        [[-0.0048,  0.0029, -0.0053,  ..., -0.0062, -0.0024,  0.0045],
         [ 0.4160, -0.6328,  0.5195,  ...,  0.7461, -0.2754,  0.5273],
         [-1.4375,  1.0703,  0.4629,  ..., -0.9336, -1.2734, -0.4043],
         ...,
         [-1.0938, -0.0304,  0.2852,  ..., -0.2266, -0.4199, -1.0000],
         [-2.0781, -0.8477, -1.4531,  ..., -0.8672, -1.1094, -0.0728],
         [ 1.1094,  0.2715, -0.0287,  ...,  0.6367,  0.2637, -0.0728]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[-7.9346e-04,  2.2430e-03,  1.5717e-03,  ..., -7.9956e-03,
          -1.6406e+00,  8.7402e-02],
         [ 3.9453e-01,  1.5625e-01, -1.2266e+00,  ..., -1.7285e-01,
           6.5938e+00, -1.5527e-01],
         [ 4.2188e-01, -3.2812e-01,  1.5430e-01,  ..., -2.4375e+00,
           9.6875e+00,  1.8281e+00],
         ...,
         [ 3.8672e-01,  1.8555e-01, -7.6172e-01,  ...,  4.9062e+00,
           7.7188e+00, -4.0000e+00],
         [-3.9062e-03,  1.2695e-01, -2.9492e-01,  ...,  1.1797e+00,
           5.7188e+00, -2.1094e+00],
         [-1.0059e-01, -1.6602e-01, -9.6875e-01,  ..., -2.2969e+00,
           8.8125e+00, -1.2500e+00]],

        [[-7.9346e-04,  2.2430e-03,  1.5717e-03,  ..., -7.9956e-03,
          -1.6406e+00,  8.7402e-02],
         [ 3.9453e-01,  1.5625e-01, -1.2266e+00,  ..., -1.7285e-01,
           6.5938e+00, -1.5527e-01],
         [ 4.2188e-01, -3.2812e-01,  1.5430e-01,  ..., -2.4375e+00,
           9.6875e+00,  1.8281e+00],
         ...,
         [ 3.8672e-01,  1.8555e-01, -7.6172e-01,  ...,  4.9062e+00,
           7.7188e+00, -4.0000e+00],
         [-3.9062e-03,  1.2695e-01, -2.9492e-01,  ...,  1.1797e+00,
           5.7188e+00, -2.1094e+00],
         [-1.0059e-01, -1.6602e-01, -9.6875e-01,  ..., -2.2969e+00,
           8.8125e+00, -1.2500e+00]],

        [[-7.9346e-04,  2.2430e-03,  1.5717e-03,  ..., -7.9956e-03,
          -1.6406e+00,  8.7402e-02],
         [ 3.9453e-01,  1.5625e-01, -1.2266e+00,  ..., -1.7285e-01,
           6.5938e+00, -1.5527e-01],
         [ 4.2188e-01, -3.2812e-01,  1.5430e-01,  ..., -2.4375e+00,
           9.6875e+00,  1.8281e+00],
         ...,
         [ 3.8672e-01,  1.8555e-01, -7.6172e-01,  ...,  4.9062e+00,
           7.7188e+00, -4.0000e+00],
         [-3.9062e-03,  1.2695e-01, -2.9492e-01,  ...,  1.1797e+00,
           5.7188e+00, -2.1094e+00],
         [-1.0059e-01, -1.6602e-01, -9.6875e-01,  ..., -2.2969e+00,
           8.8125e+00, -1.2500e+00]],

        ...,

        [[-3.6774e-03, -1.3000e-02, -1.9989e-03,  ..., -5.5908e-02,
           2.0386e-02, -7.2754e-02],
         [ 1.8594e+00, -1.2344e+00, -7.7344e-01,  ...,  5.4297e-01,
          -1.0703e+00,  1.8359e+00],
         [ 1.0000e+00, -7.2656e-01, -4.6875e-01,  ...,  7.3438e-01,
           3.0273e-01, -2.5781e-01],
         ...,
         [-6.2500e-01,  9.7266e-01,  1.9062e+00,  ..., -1.0938e+00,
           2.7656e+00,  6.9141e-01],
         [ 2.8320e-01, -1.2256e-01,  1.0625e+00,  ..., -1.0625e+00,
           5.5469e-01, -1.7344e+00],
         [ 1.8594e+00,  1.4219e+00,  1.4844e+00,  ..., -1.1016e+00,
           1.4453e+00,  7.5391e-01]],

        [[-3.6774e-03, -1.3000e-02, -1.9989e-03,  ..., -5.5908e-02,
           2.0386e-02, -7.2754e-02],
         [ 1.8594e+00, -1.2344e+00, -7.7344e-01,  ...,  5.4297e-01,
          -1.0703e+00,  1.8359e+00],
         [ 1.0000e+00, -7.2656e-01, -4.6875e-01,  ...,  7.3438e-01,
           3.0273e-01, -2.5781e-01],
         ...,
         [-6.2500e-01,  9.7266e-01,  1.9062e+00,  ..., -1.0938e+00,
           2.7656e+00,  6.9141e-01],
         [ 2.8320e-01, -1.2256e-01,  1.0625e+00,  ..., -1.0625e+00,
           5.5469e-01, -1.7344e+00],
         [ 1.8594e+00,  1.4219e+00,  1.4844e+00,  ..., -1.1016e+00,
           1.4453e+00,  7.5391e-01]],

        [[-3.6774e-03, -1.3000e-02, -1.9989e-03,  ..., -5.5908e-02,
           2.0386e-02, -7.2754e-02],
         [ 1.8594e+00, -1.2344e+00, -7.7344e-01,  ...,  5.4297e-01,
          -1.0703e+00,  1.8359e+00],
         [ 1.0000e+00, -7.2656e-01, -4.6875e-01,  ...,  7.3438e-01,
           3.0273e-01, -2.5781e-01],
         ...,
         [-6.2500e-01,  9.7266e-01,  1.9062e+00,  ..., -1.0938e+00,
           2.7656e+00,  6.9141e-01],
         [ 2.8320e-01, -1.2256e-01,  1.0625e+00,  ..., -1.0625e+00,
           5.5469e-01, -1.7344e+00],
         [ 1.8594e+00,  1.4219e+00,  1.4844e+00,  ..., -1.1016e+00,
           1.4453e+00,  7.5391e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 4.6692e-03,  2.2095e-02, -3.2043e-03,  ..., -1.2268e-02,
           1.4160e-02,  3.9307e-02],
         [ 5.5078e-01, -2.9102e-01, -6.2500e-01,  ..., -8.1635e-04,
          -1.5078e+00, -1.0781e+00],
         [-2.6367e-02,  6.0156e-01,  2.3047e-01,  ..., -8.0469e-01,
           4.0820e-01, -1.7578e-01],
         ...,
         [-1.0078e+00, -3.1445e-01, -9.9609e-01,  ...,  7.6172e-01,
          -7.9590e-02,  1.6016e+00],
         [-8.0469e-01, -7.3730e-02, -8.0859e-01,  ..., -4.1797e-01,
           5.3125e-01,  1.4297e+00],
         [ 1.0078e+00, -1.0547e+00, -1.7188e+00,  ..., -1.2109e+00,
          -2.2188e+00,  8.7891e-01]],

        [[ 4.6692e-03,  2.2095e-02, -3.2043e-03,  ..., -1.2268e-02,
           1.4160e-02,  3.9307e-02],
         [ 5.5078e-01, -2.9102e-01, -6.2500e-01,  ..., -8.1635e-04,
          -1.5078e+00, -1.0781e+00],
         [-2.6367e-02,  6.0156e-01,  2.3047e-01,  ..., -8.0469e-01,
           4.0820e-01, -1.7578e-01],
         ...,
         [-1.0078e+00, -3.1445e-01, -9.9609e-01,  ...,  7.6172e-01,
          -7.9590e-02,  1.6016e+00],
         [-8.0469e-01, -7.3730e-02, -8.0859e-01,  ..., -4.1797e-01,
           5.3125e-01,  1.4297e+00],
         [ 1.0078e+00, -1.0547e+00, -1.7188e+00,  ..., -1.2109e+00,
          -2.2188e+00,  8.7891e-01]],

        [[ 4.6692e-03,  2.2095e-02, -3.2043e-03,  ..., -1.2268e-02,
           1.4160e-02,  3.9307e-02],
         [ 5.5078e-01, -2.9102e-01, -6.2500e-01,  ..., -8.1635e-04,
          -1.5078e+00, -1.0781e+00],
         [-2.6367e-02,  6.0156e-01,  2.3047e-01,  ..., -8.0469e-01,
           4.0820e-01, -1.7578e-01],
         ...,
         [-1.0078e+00, -3.1445e-01, -9.9609e-01,  ...,  7.6172e-01,
          -7.9590e-02,  1.6016e+00],
         [-8.0469e-01, -7.3730e-02, -8.0859e-01,  ..., -4.1797e-01,
           5.3125e-01,  1.4297e+00],
         [ 1.0078e+00, -1.0547e+00, -1.7188e+00,  ..., -1.2109e+00,
          -2.2188e+00,  8.7891e-01]],

        ...,

        [[ 1.7319e-03,  1.0872e-04, -7.7209e-03,  ..., -1.3489e-02,
          -8.0566e-03,  6.9580e-03],
         [-7.5000e-01,  3.8086e-01, -8.7500e-01,  ..., -4.7070e-01,
          -1.2578e+00,  4.1602e-01],
         [-2.5000e-01,  2.1289e-01,  5.5469e-01,  ...,  8.9844e-01,
          -1.0469e+00,  1.2266e+00],
         ...,
         [-3.7842e-02,  6.1951e-03,  5.5469e-01,  ..., -1.7285e-01,
          -3.0078e-01,  8.3984e-02],
         [ 3.1250e-01, -4.4922e-02,  5.0000e-01,  ...,  1.0234e+00,
           1.3125e+00, -1.8750e-01],
         [ 1.8848e-01,  1.1875e+00,  5.0781e-01,  ...,  3.3203e-01,
           2.2949e-01,  2.6758e-01]],

        [[ 1.7319e-03,  1.0872e-04, -7.7209e-03,  ..., -1.3489e-02,
          -8.0566e-03,  6.9580e-03],
         [-7.5000e-01,  3.8086e-01, -8.7500e-01,  ..., -4.7070e-01,
          -1.2578e+00,  4.1602e-01],
         [-2.5000e-01,  2.1289e-01,  5.5469e-01,  ...,  8.9844e-01,
          -1.0469e+00,  1.2266e+00],
         ...,
         [-3.7842e-02,  6.1951e-03,  5.5469e-01,  ..., -1.7285e-01,
          -3.0078e-01,  8.3984e-02],
         [ 3.1250e-01, -4.4922e-02,  5.0000e-01,  ...,  1.0234e+00,
           1.3125e+00, -1.8750e-01],
         [ 1.8848e-01,  1.1875e+00,  5.0781e-01,  ...,  3.3203e-01,
           2.2949e-01,  2.6758e-01]],

        [[ 1.7319e-03,  1.0872e-04, -7.7209e-03,  ..., -1.3489e-02,
          -8.0566e-03,  6.9580e-03],
         [-7.5000e-01,  3.8086e-01, -8.7500e-01,  ..., -4.7070e-01,
          -1.2578e+00,  4.1602e-01],
         [-2.5000e-01,  2.1289e-01,  5.5469e-01,  ...,  8.9844e-01,
          -1.0469e+00,  1.2266e+00],
         ...,
         [-3.7842e-02,  6.1951e-03,  5.5469e-01,  ..., -1.7285e-01,
          -3.0078e-01,  8.3984e-02],
         [ 3.1250e-01, -4.4922e-02,  5.0000e-01,  ...,  1.0234e+00,
           1.3125e+00, -1.8750e-01],
         [ 1.8848e-01,  1.1875e+00,  5.0781e-01,  ...,  3.3203e-01,
           2.2949e-01,  2.6758e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 1.2695e-02,  1.6098e-03,  1.1658e-02,  ...,  2.1680e-01,
           1.3672e-01, -1.2812e+00],
         [ 7.8516e-01,  4.4531e-01,  3.4375e-01,  ...,  1.6406e+00,
           8.2031e-01,  1.7891e+00],
         [ 2.5781e-01, -4.2969e-01, -2.9785e-02,  ...,  1.4766e+00,
          -2.0781e+00,  2.2656e+00],
         ...,
         [-8.7891e-03, -3.9062e-01,  1.9922e-01,  ..., -3.2500e+00,
           7.1484e-01,  3.8594e+00],
         [ 3.5156e-01, -5.8594e-01, -1.5430e-01,  ..., -1.0859e+00,
          -1.1953e+00,  2.6719e+00],
         [-1.2793e-01, -2.0508e-01,  3.0664e-01,  ...,  8.1641e-01,
          -1.0469e+00,  4.5312e+00]],

        [[ 1.2695e-02,  1.6098e-03,  1.1658e-02,  ...,  2.1680e-01,
           1.3672e-01, -1.2812e+00],
         [ 7.8516e-01,  4.4531e-01,  3.4375e-01,  ...,  1.6406e+00,
           8.2031e-01,  1.7891e+00],
         [ 2.5781e-01, -4.2969e-01, -2.9785e-02,  ...,  1.4766e+00,
          -2.0781e+00,  2.2656e+00],
         ...,
         [-8.7891e-03, -3.9062e-01,  1.9922e-01,  ..., -3.2500e+00,
           7.1484e-01,  3.8594e+00],
         [ 3.5156e-01, -5.8594e-01, -1.5430e-01,  ..., -1.0859e+00,
          -1.1953e+00,  2.6719e+00],
         [-1.2793e-01, -2.0508e-01,  3.0664e-01,  ...,  8.1641e-01,
          -1.0469e+00,  4.5312e+00]],

        [[ 1.2695e-02,  1.6098e-03,  1.1658e-02,  ...,  2.1680e-01,
           1.3672e-01, -1.2812e+00],
         [ 7.8516e-01,  4.4531e-01,  3.4375e-01,  ...,  1.6406e+00,
           8.2031e-01,  1.7891e+00],
         [ 2.5781e-01, -4.2969e-01, -2.9785e-02,  ...,  1.4766e+00,
          -2.0781e+00,  2.2656e+00],
         ...,
         [-8.7891e-03, -3.9062e-01,  1.9922e-01,  ..., -3.2500e+00,
           7.1484e-01,  3.8594e+00],
         [ 3.5156e-01, -5.8594e-01, -1.5430e-01,  ..., -1.0859e+00,
          -1.1953e+00,  2.6719e+00],
         [-1.2793e-01, -2.0508e-01,  3.0664e-01,  ...,  8.1641e-01,
          -1.0469e+00,  4.5312e+00]],

        ...,

        [[ 9.3384e-03, -2.3438e-02, -4.6692e-03,  ...,  5.3516e-01,
          -2.9297e-01, -8.1543e-02],
         [ 4.1211e-01,  9.6875e-01, -4.1992e-02,  ..., -1.7188e+00,
           9.1016e-01, -3.4062e+00],
         [ 2.2852e-01,  7.6904e-03,  3.0469e-01,  ..., -2.0703e-01,
           8.8281e-01, -1.5000e+00],
         ...,
         [ 6.4453e-01, -2.1289e-01, -7.6172e-01,  ..., -1.6484e+00,
          -1.7734e+00, -2.7969e+00],
         [ 5.1758e-02, -1.7188e-01, -2.4414e-01,  ...,  1.6484e+00,
          -8.3594e-01,  4.2188e-01],
         [ 3.4961e-01, -7.7734e-01, -4.8828e-02,  ...,  1.1562e+00,
           1.0391e+00, -6.5312e+00]],

        [[ 9.3384e-03, -2.3438e-02, -4.6692e-03,  ...,  5.3516e-01,
          -2.9297e-01, -8.1543e-02],
         [ 4.1211e-01,  9.6875e-01, -4.1992e-02,  ..., -1.7188e+00,
           9.1016e-01, -3.4062e+00],
         [ 2.2852e-01,  7.6904e-03,  3.0469e-01,  ..., -2.0703e-01,
           8.8281e-01, -1.5000e+00],
         ...,
         [ 6.4453e-01, -2.1289e-01, -7.6172e-01,  ..., -1.6484e+00,
          -1.7734e+00, -2.7969e+00],
         [ 5.1758e-02, -1.7188e-01, -2.4414e-01,  ...,  1.6484e+00,
          -8.3594e-01,  4.2188e-01],
         [ 3.4961e-01, -7.7734e-01, -4.8828e-02,  ...,  1.1562e+00,
           1.0391e+00, -6.5312e+00]],

        [[ 9.3384e-03, -2.3438e-02, -4.6692e-03,  ...,  5.3516e-01,
          -2.9297e-01, -8.1543e-02],
         [ 4.1211e-01,  9.6875e-01, -4.1992e-02,  ..., -1.7188e+00,
           9.1016e-01, -3.4062e+00],
         [ 2.2852e-01,  7.6904e-03,  3.0469e-01,  ..., -2.0703e-01,
           8.8281e-01, -1.5000e+00],
         ...,
         [ 6.4453e-01, -2.1289e-01, -7.6172e-01,  ..., -1.6484e+00,
          -1.7734e+00, -2.7969e+00],
         [ 5.1758e-02, -1.7188e-01, -2.4414e-01,  ...,  1.6484e+00,
          -8.3594e-01,  4.2188e-01],
         [ 3.4961e-01, -7.7734e-01, -4.8828e-02,  ...,  1.1562e+00,
           1.0391e+00, -6.5312e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-5.3406e-03,  1.2398e-04,  1.0864e-02,  ..., -7.3547e-03,
           1.1475e-02,  7.3547e-03],
         [ 2.5156e+00, -1.6172e+00, -5.4688e-01,  ...,  1.7773e-01,
           7.1094e-01,  2.1406e+00],
         [ 2.7930e-01, -1.3125e+00, -1.1172e+00,  ...,  2.2754e-01,
           3.7109e-01,  7.1094e-01],
         ...,
         [-1.3770e-01, -5.5078e-01,  5.2344e-01,  ...,  3.8477e-01,
          -8.5449e-02,  9.4141e-01],
         [-1.4941e-01,  9.8145e-02, -5.4297e-01,  ...,  5.1172e-01,
           4.0283e-02,  7.4609e-01],
         [ 1.7109e+00, -6.5234e-01,  1.4941e-01,  ...,  1.8672e+00,
          -1.6406e+00,  1.0156e+00]],

        [[-5.3406e-03,  1.2398e-04,  1.0864e-02,  ..., -7.3547e-03,
           1.1475e-02,  7.3547e-03],
         [ 2.5156e+00, -1.6172e+00, -5.4688e-01,  ...,  1.7773e-01,
           7.1094e-01,  2.1406e+00],
         [ 2.7930e-01, -1.3125e+00, -1.1172e+00,  ...,  2.2754e-01,
           3.7109e-01,  7.1094e-01],
         ...,
         [-1.3770e-01, -5.5078e-01,  5.2344e-01,  ...,  3.8477e-01,
          -8.5449e-02,  9.4141e-01],
         [-1.4941e-01,  9.8145e-02, -5.4297e-01,  ...,  5.1172e-01,
           4.0283e-02,  7.4609e-01],
         [ 1.7109e+00, -6.5234e-01,  1.4941e-01,  ...,  1.8672e+00,
          -1.6406e+00,  1.0156e+00]],

        [[-5.3406e-03,  1.2398e-04,  1.0864e-02,  ..., -7.3547e-03,
           1.1475e-02,  7.3547e-03],
         [ 2.5156e+00, -1.6172e+00, -5.4688e-01,  ...,  1.7773e-01,
           7.1094e-01,  2.1406e+00],
         [ 2.7930e-01, -1.3125e+00, -1.1172e+00,  ...,  2.2754e-01,
           3.7109e-01,  7.1094e-01],
         ...,
         [-1.3770e-01, -5.5078e-01,  5.2344e-01,  ...,  3.8477e-01,
          -8.5449e-02,  9.4141e-01],
         [-1.4941e-01,  9.8145e-02, -5.4297e-01,  ...,  5.1172e-01,
           4.0283e-02,  7.4609e-01],
         [ 1.7109e+00, -6.5234e-01,  1.4941e-01,  ...,  1.8672e+00,
          -1.6406e+00,  1.0156e+00]],

        ...,

        [[ 5.8289e-03,  1.7822e-02, -2.7222e-02,  ..., -1.2360e-03,
           1.0071e-02,  9.7656e-04],
         [-6.9531e-01, -1.4062e+00,  4.2773e-01,  ...,  6.6016e-01,
           1.1719e-01,  2.8438e+00],
         [-3.5352e-01, -1.9141e+00,  2.0996e-01,  ..., -3.6328e-01,
           1.7383e-01, -1.1562e+00],
         ...,
         [-3.9648e-01, -6.9922e-01, -3.2227e-01,  ...,  4.1602e-01,
          -2.1680e-01, -2.9297e-01],
         [-1.0859e+00, -3.5352e-01, -3.5547e-01,  ...,  3.8672e-01,
           1.0547e+00,  5.9326e-02],
         [ 1.5859e+00,  1.7969e-01,  1.9297e+00,  ...,  1.9629e-01,
           9.7656e-01, -1.1406e+00]],

        [[ 5.8289e-03,  1.7822e-02, -2.7222e-02,  ..., -1.2360e-03,
           1.0071e-02,  9.7656e-04],
         [-6.9531e-01, -1.4062e+00,  4.2773e-01,  ...,  6.6016e-01,
           1.1719e-01,  2.8438e+00],
         [-3.5352e-01, -1.9141e+00,  2.0996e-01,  ..., -3.6328e-01,
           1.7383e-01, -1.1562e+00],
         ...,
         [-3.9648e-01, -6.9922e-01, -3.2227e-01,  ...,  4.1602e-01,
          -2.1680e-01, -2.9297e-01],
         [-1.0859e+00, -3.5352e-01, -3.5547e-01,  ...,  3.8672e-01,
           1.0547e+00,  5.9326e-02],
         [ 1.5859e+00,  1.7969e-01,  1.9297e+00,  ...,  1.9629e-01,
           9.7656e-01, -1.1406e+00]],

        [[ 5.8289e-03,  1.7822e-02, -2.7222e-02,  ..., -1.2360e-03,
           1.0071e-02,  9.7656e-04],
         [-6.9531e-01, -1.4062e+00,  4.2773e-01,  ...,  6.6016e-01,
           1.1719e-01,  2.8438e+00],
         [-3.5352e-01, -1.9141e+00,  2.0996e-01,  ..., -3.6328e-01,
           1.7383e-01, -1.1562e+00],
         ...,
         [-3.9648e-01, -6.9922e-01, -3.2227e-01,  ...,  4.1602e-01,
          -2.1680e-01, -2.9297e-01],
         [-1.0859e+00, -3.5352e-01, -3.5547e-01,  ...,  3.8672e-01,
           1.0547e+00,  5.9326e-02],
         [ 1.5859e+00,  1.7969e-01,  1.9297e+00,  ...,  1.9629e-01,
           9.7656e-01, -1.1406e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 8.9722e-03,  1.1139e-03, -4.6997e-03,  ...,  2.6611e-02,
          -1.4531e+00,  2.0020e-01],
         [-3.1875e+00,  1.5625e-01, -1.3438e+00,  ..., -1.1797e+00,
           7.4688e+00, -3.3750e+00],
         [-7.5000e-01,  7.6172e-01, -5.3125e-01,  ...,  7.8516e-01,
           8.7500e+00, -2.4844e+00],
         ...,
         [-3.9844e-01,  1.3965e-01,  5.4688e-01,  ..., -5.5859e-01,
           5.4688e+00, -4.1562e+00],
         [-4.7656e-01, -1.4062e-01,  3.7891e-01,  ...,  6.3965e-02,
           6.6875e+00, -2.1875e+00],
         [ 7.3047e-01, -1.1250e+00, -1.0312e+00,  ..., -5.3516e-01,
           7.5000e+00, -1.1562e+00]],

        [[ 8.9722e-03,  1.1139e-03, -4.6997e-03,  ...,  2.6611e-02,
          -1.4531e+00,  2.0020e-01],
         [-3.1875e+00,  1.5625e-01, -1.3438e+00,  ..., -1.1797e+00,
           7.4688e+00, -3.3750e+00],
         [-7.5000e-01,  7.6172e-01, -5.3125e-01,  ...,  7.8516e-01,
           8.7500e+00, -2.4844e+00],
         ...,
         [-3.9844e-01,  1.3965e-01,  5.4688e-01,  ..., -5.5859e-01,
           5.4688e+00, -4.1562e+00],
         [-4.7656e-01, -1.4062e-01,  3.7891e-01,  ...,  6.3965e-02,
           6.6875e+00, -2.1875e+00],
         [ 7.3047e-01, -1.1250e+00, -1.0312e+00,  ..., -5.3516e-01,
           7.5000e+00, -1.1562e+00]],

        [[ 8.9722e-03,  1.1139e-03, -4.6997e-03,  ...,  2.6611e-02,
          -1.4531e+00,  2.0020e-01],
         [-3.1875e+00,  1.5625e-01, -1.3438e+00,  ..., -1.1797e+00,
           7.4688e+00, -3.3750e+00],
         [-7.5000e-01,  7.6172e-01, -5.3125e-01,  ...,  7.8516e-01,
           8.7500e+00, -2.4844e+00],
         ...,
         [-3.9844e-01,  1.3965e-01,  5.4688e-01,  ..., -5.5859e-01,
           5.4688e+00, -4.1562e+00],
         [-4.7656e-01, -1.4062e-01,  3.7891e-01,  ...,  6.3965e-02,
           6.6875e+00, -2.1875e+00],
         [ 7.3047e-01, -1.1250e+00, -1.0312e+00,  ..., -5.3516e-01,
           7.5000e+00, -1.1562e+00]],

        ...,

        [[-6.3782e-03, -8.7280e-03, -9.7046e-03,  ...,  8.0078e-02,
           8.1055e-02, -2.4707e-01],
         [ 3.1562e+00,  3.8672e-01,  2.2266e-01,  ..., -4.9805e-01,
          -1.1484e+00,  1.0645e-01],
         [ 9.3359e-01,  1.5391e+00,  1.1328e-01,  ...,  1.3203e+00,
           2.0938e+00,  7.9688e-01],
         ...,
         [ 9.6484e-01,  4.7266e-01, -1.6699e-01,  ..., -2.5000e+00,
           6.3281e-01,  1.9922e+00],
         [ 6.7969e-01,  1.0156e+00,  3.4375e-01,  ..., -2.1250e+00,
           2.3594e+00,  7.5000e-01],
         [-1.0938e+00,  5.5859e-01,  2.3633e-01,  ...,  3.1445e-01,
           6.7969e-01, -1.6016e-01]],

        [[-6.3782e-03, -8.7280e-03, -9.7046e-03,  ...,  8.0078e-02,
           8.1055e-02, -2.4707e-01],
         [ 3.1562e+00,  3.8672e-01,  2.2266e-01,  ..., -4.9805e-01,
          -1.1484e+00,  1.0645e-01],
         [ 9.3359e-01,  1.5391e+00,  1.1328e-01,  ...,  1.3203e+00,
           2.0938e+00,  7.9688e-01],
         ...,
         [ 9.6484e-01,  4.7266e-01, -1.6699e-01,  ..., -2.5000e+00,
           6.3281e-01,  1.9922e+00],
         [ 6.7969e-01,  1.0156e+00,  3.4375e-01,  ..., -2.1250e+00,
           2.3594e+00,  7.5000e-01],
         [-1.0938e+00,  5.5859e-01,  2.3633e-01,  ...,  3.1445e-01,
           6.7969e-01, -1.6016e-01]],

        [[-6.3782e-03, -8.7280e-03, -9.7046e-03,  ...,  8.0078e-02,
           8.1055e-02, -2.4707e-01],
         [ 3.1562e+00,  3.8672e-01,  2.2266e-01,  ..., -4.9805e-01,
          -1.1484e+00,  1.0645e-01],
         [ 9.3359e-01,  1.5391e+00,  1.1328e-01,  ...,  1.3203e+00,
           2.0938e+00,  7.9688e-01],
         ...,
         [ 9.6484e-01,  4.7266e-01, -1.6699e-01,  ..., -2.5000e+00,
           6.3281e-01,  1.9922e+00],
         [ 6.7969e-01,  1.0156e+00,  3.4375e-01,  ..., -2.1250e+00,
           2.3594e+00,  7.5000e-01],
         [-1.0938e+00,  5.5859e-01,  2.3633e-01,  ...,  3.1445e-01,
           6.7969e-01, -1.6016e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 7.6294e-03, -6.6528e-03, -5.7678e-03,  ...,  6.5613e-03,
          -5.0354e-03,  2.7313e-03],
         [ 9.2969e-01,  1.1641e+00, -3.2812e-01,  ...,  1.3281e+00,
           2.6367e-01,  2.0625e+00],
         [-4.9609e-01, -8.0078e-01,  4.6680e-01,  ...,  5.9766e-01,
          -2.7734e-01,  8.7500e-01],
         ...,
         [ 1.0469e+00, -6.2109e-01, -1.0625e+00,  ...,  4.8047e-01,
           1.5000e+00, -7.9688e-01],
         [ 5.7812e-01, -6.2891e-01,  1.5039e-01,  ...,  6.8359e-01,
           7.3047e-01, -2.0703e-01],
         [ 1.3594e+00,  9.7168e-02,  2.0781e+00,  ..., -9.9609e-01,
          -1.0781e+00,  1.6016e+00]],

        [[ 7.6294e-03, -6.6528e-03, -5.7678e-03,  ...,  6.5613e-03,
          -5.0354e-03,  2.7313e-03],
         [ 9.2969e-01,  1.1641e+00, -3.2812e-01,  ...,  1.3281e+00,
           2.6367e-01,  2.0625e+00],
         [-4.9609e-01, -8.0078e-01,  4.6680e-01,  ...,  5.9766e-01,
          -2.7734e-01,  8.7500e-01],
         ...,
         [ 1.0469e+00, -6.2109e-01, -1.0625e+00,  ...,  4.8047e-01,
           1.5000e+00, -7.9688e-01],
         [ 5.7812e-01, -6.2891e-01,  1.5039e-01,  ...,  6.8359e-01,
           7.3047e-01, -2.0703e-01],
         [ 1.3594e+00,  9.7168e-02,  2.0781e+00,  ..., -9.9609e-01,
          -1.0781e+00,  1.6016e+00]],

        [[ 7.6294e-03, -6.6528e-03, -5.7678e-03,  ...,  6.5613e-03,
          -5.0354e-03,  2.7313e-03],
         [ 9.2969e-01,  1.1641e+00, -3.2812e-01,  ...,  1.3281e+00,
           2.6367e-01,  2.0625e+00],
         [-4.9609e-01, -8.0078e-01,  4.6680e-01,  ...,  5.9766e-01,
          -2.7734e-01,  8.7500e-01],
         ...,
         [ 1.0469e+00, -6.2109e-01, -1.0625e+00,  ...,  4.8047e-01,
           1.5000e+00, -7.9688e-01],
         [ 5.7812e-01, -6.2891e-01,  1.5039e-01,  ...,  6.8359e-01,
           7.3047e-01, -2.0703e-01],
         [ 1.3594e+00,  9.7168e-02,  2.0781e+00,  ..., -9.9609e-01,
          -1.0781e+00,  1.6016e+00]],

        ...,

        [[-6.7234e-05,  2.8381e-03, -5.2795e-03,  ...,  2.2736e-03,
          -6.7444e-03,  6.7749e-03],
         [-7.3242e-02,  5.4688e-01,  8.6328e-01,  ...,  7.6953e-01,
           4.8633e-01,  3.1055e-01],
         [-6.3672e-01,  3.1445e-01,  2.5977e-01,  ..., -3.6865e-02,
           1.9165e-02, -8.3594e-01],
         ...,
         [-1.0303e-01,  9.1016e-01,  4.2383e-01,  ..., -8.9453e-01,
           5.4297e-01,  7.2266e-01],
         [-2.6367e-01, -5.5859e-01,  1.3184e-01,  ...,  5.0391e-01,
           1.7734e+00, -1.7773e-01],
         [ 1.0312e+00, -1.5625e-01,  7.3438e-01,  ..., -1.5332e-01,
           6.8359e-01,  9.1016e-01]],

        [[-6.7234e-05,  2.8381e-03, -5.2795e-03,  ...,  2.2736e-03,
          -6.7444e-03,  6.7749e-03],
         [-7.3242e-02,  5.4688e-01,  8.6328e-01,  ...,  7.6953e-01,
           4.8633e-01,  3.1055e-01],
         [-6.3672e-01,  3.1445e-01,  2.5977e-01,  ..., -3.6865e-02,
           1.9165e-02, -8.3594e-01],
         ...,
         [-1.0303e-01,  9.1016e-01,  4.2383e-01,  ..., -8.9453e-01,
           5.4297e-01,  7.2266e-01],
         [-2.6367e-01, -5.5859e-01,  1.3184e-01,  ...,  5.0391e-01,
           1.7734e+00, -1.7773e-01],
         [ 1.0312e+00, -1.5625e-01,  7.3438e-01,  ..., -1.5332e-01,
           6.8359e-01,  9.1016e-01]],

        [[-6.7234e-05,  2.8381e-03, -5.2795e-03,  ...,  2.2736e-03,
          -6.7444e-03,  6.7749e-03],
         [-7.3242e-02,  5.4688e-01,  8.6328e-01,  ...,  7.6953e-01,
           4.8633e-01,  3.1055e-01],
         [-6.3672e-01,  3.1445e-01,  2.5977e-01,  ..., -3.6865e-02,
           1.9165e-02, -8.3594e-01],
         ...,
         [-1.0303e-01,  9.1016e-01,  4.2383e-01,  ..., -8.9453e-01,
           5.4297e-01,  7.2266e-01],
         [-2.6367e-01, -5.5859e-01,  1.3184e-01,  ...,  5.0391e-01,
           1.7734e+00, -1.7773e-01],
         [ 1.0312e+00, -1.5625e-01,  7.3438e-01,  ..., -1.5332e-01,
           6.8359e-01,  9.1016e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 1.3245e-02, -6.1035e-04,  2.7618e-03,  ..., -1.0303e-01,
          -1.4941e-01,  1.9727e-01],
         [-1.1875e+00, -4.7070e-01,  2.4531e+00,  ...,  9.4141e-01,
           2.5195e-01,  1.8125e+00],
         [ 1.2451e-01,  2.1387e-01,  2.1973e-01,  ...,  1.1719e+00,
          -3.3398e-01, -2.8125e+00],
         ...,
         [ 5.8594e-03,  4.7266e-01,  6.2500e-01,  ...,  5.7031e-01,
           1.3203e+00, -1.0156e+00],
         [ 1.2695e-01, -2.5586e-01,  8.5938e-01,  ..., -3.5156e-01,
           1.4160e-01, -1.8906e+00],
         [ 2.7344e-01,  7.3438e-01,  1.5391e+00,  ...,  1.2578e+00,
           9.4922e-01,  7.7344e-01]],

        [[ 1.3245e-02, -6.1035e-04,  2.7618e-03,  ..., -1.0303e-01,
          -1.4941e-01,  1.9727e-01],
         [-1.1875e+00, -4.7070e-01,  2.4531e+00,  ...,  9.4141e-01,
           2.5195e-01,  1.8125e+00],
         [ 1.2451e-01,  2.1387e-01,  2.1973e-01,  ...,  1.1719e+00,
          -3.3398e-01, -2.8125e+00],
         ...,
         [ 5.8594e-03,  4.7266e-01,  6.2500e-01,  ...,  5.7031e-01,
           1.3203e+00, -1.0156e+00],
         [ 1.2695e-01, -2.5586e-01,  8.5938e-01,  ..., -3.5156e-01,
           1.4160e-01, -1.8906e+00],
         [ 2.7344e-01,  7.3438e-01,  1.5391e+00,  ...,  1.2578e+00,
           9.4922e-01,  7.7344e-01]],

        [[ 1.3245e-02, -6.1035e-04,  2.7618e-03,  ..., -1.0303e-01,
          -1.4941e-01,  1.9727e-01],
         [-1.1875e+00, -4.7070e-01,  2.4531e+00,  ...,  9.4141e-01,
           2.5195e-01,  1.8125e+00],
         [ 1.2451e-01,  2.1387e-01,  2.1973e-01,  ...,  1.1719e+00,
          -3.3398e-01, -2.8125e+00],
         ...,
         [ 5.8594e-03,  4.7266e-01,  6.2500e-01,  ...,  5.7031e-01,
           1.3203e+00, -1.0156e+00],
         [ 1.2695e-01, -2.5586e-01,  8.5938e-01,  ..., -3.5156e-01,
           1.4160e-01, -1.8906e+00],
         [ 2.7344e-01,  7.3438e-01,  1.5391e+00,  ...,  1.2578e+00,
           9.4922e-01,  7.7344e-01]],

        ...,

        [[ 7.3242e-03,  2.6398e-03,  3.4027e-03,  ..., -1.0352e-01,
           8.5938e-01, -3.6914e-01],
         [-1.3281e-01,  6.6016e-01,  1.0059e-01,  ...,  2.9531e+00,
          -3.9062e+00, -5.7031e-01],
         [-2.3828e-01,  6.2012e-02, -2.3535e-01,  ...,  1.9609e+00,
          -5.2500e+00,  1.8516e+00],
         ...,
         [-7.7344e-01, -4.3750e-01,  2.8125e-01,  ..., -1.9531e+00,
           6.1768e-02,  1.1406e+00],
         [-6.1523e-02, -5.1562e-01,  4.5117e-01,  ..., -8.5156e-01,
          -2.5156e+00,  1.6250e+00],
         [-2.0898e-01, -3.9453e-01,  8.7109e-01,  ...,  2.0000e+00,
          -3.0000e+00,  1.9629e-01]],

        [[ 7.3242e-03,  2.6398e-03,  3.4027e-03,  ..., -1.0352e-01,
           8.5938e-01, -3.6914e-01],
         [-1.3281e-01,  6.6016e-01,  1.0059e-01,  ...,  2.9531e+00,
          -3.9062e+00, -5.7031e-01],
         [-2.3828e-01,  6.2012e-02, -2.3535e-01,  ...,  1.9609e+00,
          -5.2500e+00,  1.8516e+00],
         ...,
         [-7.7344e-01, -4.3750e-01,  2.8125e-01,  ..., -1.9531e+00,
           6.1768e-02,  1.1406e+00],
         [-6.1523e-02, -5.1562e-01,  4.5117e-01,  ..., -8.5156e-01,
          -2.5156e+00,  1.6250e+00],
         [-2.0898e-01, -3.9453e-01,  8.7109e-01,  ...,  2.0000e+00,
          -3.0000e+00,  1.9629e-01]],

        [[ 7.3242e-03,  2.6398e-03,  3.4027e-03,  ..., -1.0352e-01,
           8.5938e-01, -3.6914e-01],
         [-1.3281e-01,  6.6016e-01,  1.0059e-01,  ...,  2.9531e+00,
          -3.9062e+00, -5.7031e-01],
         [-2.3828e-01,  6.2012e-02, -2.3535e-01,  ...,  1.9609e+00,
          -5.2500e+00,  1.8516e+00],
         ...,
         [-7.7344e-01, -4.3750e-01,  2.8125e-01,  ..., -1.9531e+00,
           6.1768e-02,  1.1406e+00],
         [-6.1523e-02, -5.1562e-01,  4.5117e-01,  ..., -8.5156e-01,
          -2.5156e+00,  1.6250e+00],
         [-2.0898e-01, -3.9453e-01,  8.7109e-01,  ...,  2.0000e+00,
          -3.0000e+00,  1.9629e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-7.5378e-03,  4.1504e-03,  5.4932e-03,  ..., -3.3569e-03,
           1.2024e-02,  3.7384e-03],
         [-1.2656e+00, -1.1250e+00, -8.9453e-01,  ...,  3.3008e-01,
          -6.0156e-01, -7.5391e-01],
         [-2.0156e+00, -9.3750e-01, -2.6245e-03,  ..., -3.0000e+00,
           9.2578e-01, -1.0234e+00],
         ...,
         [ 7.5391e-01, -1.2734e+00, -9.1406e-01,  ..., -3.1562e+00,
           1.1719e+00,  5.8203e-01],
         [ 8.2812e-01, -5.7031e-01, -1.7500e+00,  ..., -5.6875e+00,
           2.3594e+00, -1.2344e+00],
         [-3.9258e-01,  1.4801e-03, -1.1016e+00,  ..., -2.4219e+00,
          -6.8750e-01, -4.8047e-01]],

        [[-7.5378e-03,  4.1504e-03,  5.4932e-03,  ..., -3.3569e-03,
           1.2024e-02,  3.7384e-03],
         [-1.2656e+00, -1.1250e+00, -8.9453e-01,  ...,  3.3008e-01,
          -6.0156e-01, -7.5391e-01],
         [-2.0156e+00, -9.3750e-01, -2.6245e-03,  ..., -3.0000e+00,
           9.2578e-01, -1.0234e+00],
         ...,
         [ 7.5391e-01, -1.2734e+00, -9.1406e-01,  ..., -3.1562e+00,
           1.1719e+00,  5.8203e-01],
         [ 8.2812e-01, -5.7031e-01, -1.7500e+00,  ..., -5.6875e+00,
           2.3594e+00, -1.2344e+00],
         [-3.9258e-01,  1.4801e-03, -1.1016e+00,  ..., -2.4219e+00,
          -6.8750e-01, -4.8047e-01]],

        [[-7.5378e-03,  4.1504e-03,  5.4932e-03,  ..., -3.3569e-03,
           1.2024e-02,  3.7384e-03],
         [-1.2656e+00, -1.1250e+00, -8.9453e-01,  ...,  3.3008e-01,
          -6.0156e-01, -7.5391e-01],
         [-2.0156e+00, -9.3750e-01, -2.6245e-03,  ..., -3.0000e+00,
           9.2578e-01, -1.0234e+00],
         ...,
         [ 7.5391e-01, -1.2734e+00, -9.1406e-01,  ..., -3.1562e+00,
           1.1719e+00,  5.8203e-01],
         [ 8.2812e-01, -5.7031e-01, -1.7500e+00,  ..., -5.6875e+00,
           2.3594e+00, -1.2344e+00],
         [-3.9258e-01,  1.4801e-03, -1.1016e+00,  ..., -2.4219e+00,
          -6.8750e-01, -4.8047e-01]],

        ...,

        [[ 7.0190e-03, -1.6556e-03, -3.8605e-03,  ..., -3.7384e-03,
           2.1973e-03,  1.5564e-03],
         [-8.9844e-01, -1.7422e+00,  4.0312e+00,  ..., -1.5469e+00,
           9.5312e-01,  2.0625e+00],
         [ 1.4954e-03, -8.6719e-01,  2.4375e+00,  ..., -1.3672e+00,
          -6.4062e-01, -1.4453e-01],
         ...,
         [-2.2363e-01,  1.2344e+00, -1.8750e+00,  ...,  1.4038e-02,
           2.9297e-01,  1.0312e+00],
         [-5.4016e-03,  6.6797e-01, -5.9375e-01,  ...,  1.4453e-01,
           1.8262e-01,  8.2031e-01],
         [-1.1641e+00, -7.0312e-01,  1.0938e+00,  ..., -1.9531e+00,
           7.4609e-01,  4.9023e-01]],

        [[ 7.0190e-03, -1.6556e-03, -3.8605e-03,  ..., -3.7384e-03,
           2.1973e-03,  1.5564e-03],
         [-8.9844e-01, -1.7422e+00,  4.0312e+00,  ..., -1.5469e+00,
           9.5312e-01,  2.0625e+00],
         [ 1.4954e-03, -8.6719e-01,  2.4375e+00,  ..., -1.3672e+00,
          -6.4062e-01, -1.4453e-01],
         ...,
         [-2.2363e-01,  1.2344e+00, -1.8750e+00,  ...,  1.4038e-02,
           2.9297e-01,  1.0312e+00],
         [-5.4016e-03,  6.6797e-01, -5.9375e-01,  ...,  1.4453e-01,
           1.8262e-01,  8.2031e-01],
         [-1.1641e+00, -7.0312e-01,  1.0938e+00,  ..., -1.9531e+00,
           7.4609e-01,  4.9023e-01]],

        [[ 7.0190e-03, -1.6556e-03, -3.8605e-03,  ..., -3.7384e-03,
           2.1973e-03,  1.5564e-03],
         [-8.9844e-01, -1.7422e+00,  4.0312e+00,  ..., -1.5469e+00,
           9.5312e-01,  2.0625e+00],
         [ 1.4954e-03, -8.6719e-01,  2.4375e+00,  ..., -1.3672e+00,
          -6.4062e-01, -1.4453e-01],
         ...,
         [-2.2363e-01,  1.2344e+00, -1.8750e+00,  ...,  1.4038e-02,
           2.9297e-01,  1.0312e+00],
         [-5.4016e-03,  6.6797e-01, -5.9375e-01,  ...,  1.4453e-01,
           1.8262e-01,  8.2031e-01],
         [-1.1641e+00, -7.0312e-01,  1.0938e+00,  ..., -1.9531e+00,
           7.4609e-01,  4.9023e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 8.5449e-03,  1.2741e-03, -8.4229e-03,  ..., -1.3379e-01,
          -6.2866e-03,  7.2327e-03],
         [-1.0781e+00, -6.3281e-01, -2.5391e-01,  ...,  5.6641e-01,
          -1.0625e+00,  2.5156e+00],
         [ 1.4258e-01,  2.4609e-01, -9.9219e-01,  ...,  1.7344e+00,
           2.8125e-01,  3.3906e+00],
         ...,
         [-6.8750e-01,  7.4219e-01,  8.4766e-01,  ..., -2.7500e+00,
           8.9355e-02,  5.7031e-01],
         [-3.4180e-01,  1.0078e+00, -3.3203e-02,  ...,  2.0156e+00,
          -1.0234e+00,  1.7500e+00],
         [ 6.6406e-01,  1.1328e+00,  3.1836e-01,  ..., -5.0391e-01,
          -5.0781e-01,  3.5469e+00]],

        [[ 8.5449e-03,  1.2741e-03, -8.4229e-03,  ..., -1.3379e-01,
          -6.2866e-03,  7.2327e-03],
         [-1.0781e+00, -6.3281e-01, -2.5391e-01,  ...,  5.6641e-01,
          -1.0625e+00,  2.5156e+00],
         [ 1.4258e-01,  2.4609e-01, -9.9219e-01,  ...,  1.7344e+00,
           2.8125e-01,  3.3906e+00],
         ...,
         [-6.8750e-01,  7.4219e-01,  8.4766e-01,  ..., -2.7500e+00,
           8.9355e-02,  5.7031e-01],
         [-3.4180e-01,  1.0078e+00, -3.3203e-02,  ...,  2.0156e+00,
          -1.0234e+00,  1.7500e+00],
         [ 6.6406e-01,  1.1328e+00,  3.1836e-01,  ..., -5.0391e-01,
          -5.0781e-01,  3.5469e+00]],

        [[ 8.5449e-03,  1.2741e-03, -8.4229e-03,  ..., -1.3379e-01,
          -6.2866e-03,  7.2327e-03],
         [-1.0781e+00, -6.3281e-01, -2.5391e-01,  ...,  5.6641e-01,
          -1.0625e+00,  2.5156e+00],
         [ 1.4258e-01,  2.4609e-01, -9.9219e-01,  ...,  1.7344e+00,
           2.8125e-01,  3.3906e+00],
         ...,
         [-6.8750e-01,  7.4219e-01,  8.4766e-01,  ..., -2.7500e+00,
           8.9355e-02,  5.7031e-01],
         [-3.4180e-01,  1.0078e+00, -3.3203e-02,  ...,  2.0156e+00,
          -1.0234e+00,  1.7500e+00],
         [ 6.6406e-01,  1.1328e+00,  3.1836e-01,  ..., -5.0391e-01,
          -5.0781e-01,  3.5469e+00]],

        ...,

        [[ 4.4861e-03, -7.5073e-03, -7.8735e-03,  ..., -7.3242e-02,
          -5.5908e-02,  5.2490e-02],
         [ 1.3125e+00, -1.4844e-01,  1.0781e+00,  ..., -9.2969e-01,
           2.9883e-01,  5.6641e-01],
         [-1.1484e+00,  1.1875e+00,  1.0840e-01,  ...,  8.4766e-01,
           3.3789e-01,  1.4219e+00],
         ...,
         [ 1.2695e-02,  1.3750e+00, -9.3359e-01,  ..., -2.7344e+00,
          -1.7656e+00, -3.8281e+00],
         [-1.6016e-01,  6.6406e-01, -2.6367e-01,  ..., -2.9375e+00,
          -2.0000e+00, -2.0469e+00],
         [-1.6250e+00, -4.2969e-01,  3.0859e-01,  ..., -4.7070e-01,
           3.0469e-01,  2.5312e+00]],

        [[ 4.4861e-03, -7.5073e-03, -7.8735e-03,  ..., -7.3242e-02,
          -5.5908e-02,  5.2490e-02],
         [ 1.3125e+00, -1.4844e-01,  1.0781e+00,  ..., -9.2969e-01,
           2.9883e-01,  5.6641e-01],
         [-1.1484e+00,  1.1875e+00,  1.0840e-01,  ...,  8.4766e-01,
           3.3789e-01,  1.4219e+00],
         ...,
         [ 1.2695e-02,  1.3750e+00, -9.3359e-01,  ..., -2.7344e+00,
          -1.7656e+00, -3.8281e+00],
         [-1.6016e-01,  6.6406e-01, -2.6367e-01,  ..., -2.9375e+00,
          -2.0000e+00, -2.0469e+00],
         [-1.6250e+00, -4.2969e-01,  3.0859e-01,  ..., -4.7070e-01,
           3.0469e-01,  2.5312e+00]],

        [[ 4.4861e-03, -7.5073e-03, -7.8735e-03,  ..., -7.3242e-02,
          -5.5908e-02,  5.2490e-02],
         [ 1.3125e+00, -1.4844e-01,  1.0781e+00,  ..., -9.2969e-01,
           2.9883e-01,  5.6641e-01],
         [-1.1484e+00,  1.1875e+00,  1.0840e-01,  ...,  8.4766e-01,
           3.3789e-01,  1.4219e+00],
         ...,
         [ 1.2695e-02,  1.3750e+00, -9.3359e-01,  ..., -2.7344e+00,
          -1.7656e+00, -3.8281e+00],
         [-1.6016e-01,  6.6406e-01, -2.6367e-01,  ..., -2.9375e+00,
          -2.0000e+00, -2.0469e+00],
         [-1.6250e+00, -4.2969e-01,  3.0859e-01,  ..., -4.7070e-01,
           3.0469e-01,  2.5312e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 5.7068e-03,  9.1171e-04, -8.8501e-03,  ..., -4.3869e-04,
          -9.8877e-03, -9.2316e-04],
         [-5.3906e-01, -3.3008e-01,  1.1035e-01,  ...,  1.0791e-01,
          -5.3125e-01, -3.2422e-01],
         [-1.3245e-02, -2.9492e-01, -7.5000e-01,  ..., -4.2773e-01,
          -9.1016e-01,  1.9824e-01],
         ...,
         [ 7.1094e-01,  9.6680e-02, -9.3359e-01,  ...,  2.2266e-01,
          -1.2031e+00, -1.0156e+00],
         [-5.8105e-02,  4.6484e-01, -3.0859e-01,  ..., -5.5664e-02,
          -9.3359e-01,  7.7734e-01],
         [-6.2109e-01,  1.9922e+00, -4.1797e-01,  ..., -6.7969e-01,
           2.6367e-01,  7.7637e-02]],

        [[ 5.7068e-03,  9.1171e-04, -8.8501e-03,  ..., -4.3869e-04,
          -9.8877e-03, -9.2316e-04],
         [-5.3906e-01, -3.3008e-01,  1.1035e-01,  ...,  1.0791e-01,
          -5.3125e-01, -3.2422e-01],
         [-1.3245e-02, -2.9492e-01, -7.5000e-01,  ..., -4.2773e-01,
          -9.1016e-01,  1.9824e-01],
         ...,
         [ 7.1094e-01,  9.6680e-02, -9.3359e-01,  ...,  2.2266e-01,
          -1.2031e+00, -1.0156e+00],
         [-5.8105e-02,  4.6484e-01, -3.0859e-01,  ..., -5.5664e-02,
          -9.3359e-01,  7.7734e-01],
         [-6.2109e-01,  1.9922e+00, -4.1797e-01,  ..., -6.7969e-01,
           2.6367e-01,  7.7637e-02]],

        [[ 5.7068e-03,  9.1171e-04, -8.8501e-03,  ..., -4.3869e-04,
          -9.8877e-03, -9.2316e-04],
         [-5.3906e-01, -3.3008e-01,  1.1035e-01,  ...,  1.0791e-01,
          -5.3125e-01, -3.2422e-01],
         [-1.3245e-02, -2.9492e-01, -7.5000e-01,  ..., -4.2773e-01,
          -9.1016e-01,  1.9824e-01],
         ...,
         [ 7.1094e-01,  9.6680e-02, -9.3359e-01,  ...,  2.2266e-01,
          -1.2031e+00, -1.0156e+00],
         [-5.8105e-02,  4.6484e-01, -3.0859e-01,  ..., -5.5664e-02,
          -9.3359e-01,  7.7734e-01],
         [-6.2109e-01,  1.9922e+00, -4.1797e-01,  ..., -6.7969e-01,
           2.6367e-01,  7.7637e-02]],

        ...,

        [[ 3.2227e-02,  2.3682e-02,  8.5449e-02,  ...,  9.9487e-03,
          -4.2236e-02,  3.6377e-02],
         [-3.1836e-01, -1.2695e-01,  9.5312e-01,  ..., -5.3516e-01,
           2.8711e-01, -3.7305e-01],
         [ 2.5000e-01, -4.2725e-02,  1.6797e-01,  ...,  1.0254e-01,
          -1.0078e+00,  5.4199e-02],
         ...,
         [ 2.0874e-02, -1.2734e+00,  1.9453e+00,  ...,  5.7422e-01,
           1.8188e-02, -9.0625e-01],
         [-4.6289e-01, -8.9844e-01,  4.6875e-01,  ..., -9.7656e-02,
           2.9492e-01, -1.2109e+00],
         [ 4.5898e-01, -5.0781e-01,  2.1191e-01,  ...,  1.0078e+00,
           1.1328e+00,  3.3789e-01]],

        [[ 3.2227e-02,  2.3682e-02,  8.5449e-02,  ...,  9.9487e-03,
          -4.2236e-02,  3.6377e-02],
         [-3.1836e-01, -1.2695e-01,  9.5312e-01,  ..., -5.3516e-01,
           2.8711e-01, -3.7305e-01],
         [ 2.5000e-01, -4.2725e-02,  1.6797e-01,  ...,  1.0254e-01,
          -1.0078e+00,  5.4199e-02],
         ...,
         [ 2.0874e-02, -1.2734e+00,  1.9453e+00,  ...,  5.7422e-01,
           1.8188e-02, -9.0625e-01],
         [-4.6289e-01, -8.9844e-01,  4.6875e-01,  ..., -9.7656e-02,
           2.9492e-01, -1.2109e+00],
         [ 4.5898e-01, -5.0781e-01,  2.1191e-01,  ...,  1.0078e+00,
           1.1328e+00,  3.3789e-01]],

        [[ 3.2227e-02,  2.3682e-02,  8.5449e-02,  ...,  9.9487e-03,
          -4.2236e-02,  3.6377e-02],
         [-3.1836e-01, -1.2695e-01,  9.5312e-01,  ..., -5.3516e-01,
           2.8711e-01, -3.7305e-01],
         [ 2.5000e-01, -4.2725e-02,  1.6797e-01,  ...,  1.0254e-01,
          -1.0078e+00,  5.4199e-02],
         ...,
         [ 2.0874e-02, -1.2734e+00,  1.9453e+00,  ...,  5.7422e-01,
           1.8188e-02, -9.0625e-01],
         [-4.6289e-01, -8.9844e-01,  4.6875e-01,  ..., -9.7656e-02,
           2.9492e-01, -1.2109e+00],
         [ 4.5898e-01, -5.0781e-01,  2.1191e-01,  ...,  1.0078e+00,
           1.1328e+00,  3.3789e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 1.4893e-02,  5.1270e-03,  1.1780e-02,  ...,  1.9375e+00,
          -9.0820e-02,  3.8086e-02],
         [ 2.3438e-02,  2.0508e-01,  1.1572e-01,  ..., -6.1250e+00,
          -3.9648e-01,  3.4766e-01],
         [ 6.9336e-02, -4.5508e-01, -1.8555e-02,  ..., -7.1875e+00,
           5.3516e-01,  2.1406e+00],
         ...,
         [ 5.8594e-01, -3.9648e-01, -5.8350e-02,  ..., -4.1875e+00,
           1.4375e+00, -4.0938e+00],
         [ 1.3477e-01, -3.2812e-01,  5.7422e-01,  ..., -5.7812e+00,
           1.1797e+00, -5.6250e-01],
         [-1.3477e-01,  8.2520e-02,  3.6621e-02,  ..., -6.3125e+00,
          -5.8984e-01,  3.1836e-01]],

        [[ 1.4893e-02,  5.1270e-03,  1.1780e-02,  ...,  1.9375e+00,
          -9.0820e-02,  3.8086e-02],
         [ 2.3438e-02,  2.0508e-01,  1.1572e-01,  ..., -6.1250e+00,
          -3.9648e-01,  3.4766e-01],
         [ 6.9336e-02, -4.5508e-01, -1.8555e-02,  ..., -7.1875e+00,
           5.3516e-01,  2.1406e+00],
         ...,
         [ 5.8594e-01, -3.9648e-01, -5.8350e-02,  ..., -4.1875e+00,
           1.4375e+00, -4.0938e+00],
         [ 1.3477e-01, -3.2812e-01,  5.7422e-01,  ..., -5.7812e+00,
           1.1797e+00, -5.6250e-01],
         [-1.3477e-01,  8.2520e-02,  3.6621e-02,  ..., -6.3125e+00,
          -5.8984e-01,  3.1836e-01]],

        [[ 1.4893e-02,  5.1270e-03,  1.1780e-02,  ...,  1.9375e+00,
          -9.0820e-02,  3.8086e-02],
         [ 2.3438e-02,  2.0508e-01,  1.1572e-01,  ..., -6.1250e+00,
          -3.9648e-01,  3.4766e-01],
         [ 6.9336e-02, -4.5508e-01, -1.8555e-02,  ..., -7.1875e+00,
           5.3516e-01,  2.1406e+00],
         ...,
         [ 5.8594e-01, -3.9648e-01, -5.8350e-02,  ..., -4.1875e+00,
           1.4375e+00, -4.0938e+00],
         [ 1.3477e-01, -3.2812e-01,  5.7422e-01,  ..., -5.7812e+00,
           1.1797e+00, -5.6250e-01],
         [-1.3477e-01,  8.2520e-02,  3.6621e-02,  ..., -6.3125e+00,
          -5.8984e-01,  3.1836e-01]],

        ...,

        [[ 2.5787e-03,  3.2501e-03,  1.0620e-02,  ...,  6.8359e-02,
           9.8047e-01, -1.3281e-01],
         [-3.9062e-03,  4.6484e-01, -1.1250e+00,  ...,  4.6631e-02,
          -1.1953e+00, -3.4766e-01],
         [-4.2188e-01, -9.4531e-01, -8.1641e-01,  ...,  1.5332e-01,
          -4.5938e+00,  2.9688e-01],
         ...,
         [-3.3008e-01, -1.1797e+00,  1.0391e+00,  ...,  7.0312e-02,
          -2.2500e+00, -2.3560e-02],
         [-4.2578e-01,  9.8145e-02,  2.2461e-02,  ..., -3.5889e-02,
          -2.8125e+00, -2.8198e-02],
         [-1.7422e+00,  7.0312e-02, -1.5625e-01,  ..., -8.1641e-01,
          -2.9844e+00,  3.8086e-01]],

        [[ 2.5787e-03,  3.2501e-03,  1.0620e-02,  ...,  6.8359e-02,
           9.8047e-01, -1.3281e-01],
         [-3.9062e-03,  4.6484e-01, -1.1250e+00,  ...,  4.6631e-02,
          -1.1953e+00, -3.4766e-01],
         [-4.2188e-01, -9.4531e-01, -8.1641e-01,  ...,  1.5332e-01,
          -4.5938e+00,  2.9688e-01],
         ...,
         [-3.3008e-01, -1.1797e+00,  1.0391e+00,  ...,  7.0312e-02,
          -2.2500e+00, -2.3560e-02],
         [-4.2578e-01,  9.8145e-02,  2.2461e-02,  ..., -3.5889e-02,
          -2.8125e+00, -2.8198e-02],
         [-1.7422e+00,  7.0312e-02, -1.5625e-01,  ..., -8.1641e-01,
          -2.9844e+00,  3.8086e-01]],

        [[ 2.5787e-03,  3.2501e-03,  1.0620e-02,  ...,  6.8359e-02,
           9.8047e-01, -1.3281e-01],
         [-3.9062e-03,  4.6484e-01, -1.1250e+00,  ...,  4.6631e-02,
          -1.1953e+00, -3.4766e-01],
         [-4.2188e-01, -9.4531e-01, -8.1641e-01,  ...,  1.5332e-01,
          -4.5938e+00,  2.9688e-01],
         ...,
         [-3.3008e-01, -1.1797e+00,  1.0391e+00,  ...,  7.0312e-02,
          -2.2500e+00, -2.3560e-02],
         [-4.2578e-01,  9.8145e-02,  2.2461e-02,  ..., -3.5889e-02,
          -2.8125e+00, -2.8198e-02],
         [-1.7422e+00,  7.0312e-02, -1.5625e-01,  ..., -8.1641e-01,
          -2.9844e+00,  3.8086e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-0.0053,  0.0099, -0.0089,  ...,  0.0029, -0.0044, -0.0042],
         [ 2.7031, -0.5781,  0.9375,  ...,  2.6094, -1.6562, -0.2637],
         [ 0.5664, -1.2891,  0.1299,  ...,  0.7930, -0.0786, -1.3047],
         ...,
         [-1.0547, -0.8203, -0.3613,  ...,  0.1406,  0.5547, -0.5547],
         [-1.5312,  0.5430, -0.8750,  ..., -0.2227, -0.8008, -0.6992],
         [ 1.0469,  0.3750, -0.0276,  ..., -0.4746,  1.1250,  0.3066]],

        [[-0.0053,  0.0099, -0.0089,  ...,  0.0029, -0.0044, -0.0042],
         [ 2.7031, -0.5781,  0.9375,  ...,  2.6094, -1.6562, -0.2637],
         [ 0.5664, -1.2891,  0.1299,  ...,  0.7930, -0.0786, -1.3047],
         ...,
         [-1.0547, -0.8203, -0.3613,  ...,  0.1406,  0.5547, -0.5547],
         [-1.5312,  0.5430, -0.8750,  ..., -0.2227, -0.8008, -0.6992],
         [ 1.0469,  0.3750, -0.0276,  ..., -0.4746,  1.1250,  0.3066]],

        [[-0.0053,  0.0099, -0.0089,  ...,  0.0029, -0.0044, -0.0042],
         [ 2.7031, -0.5781,  0.9375,  ...,  2.6094, -1.6562, -0.2637],
         [ 0.5664, -1.2891,  0.1299,  ...,  0.7930, -0.0786, -1.3047],
         ...,
         [-1.0547, -0.8203, -0.3613,  ...,  0.1406,  0.5547, -0.5547],
         [-1.5312,  0.5430, -0.8750,  ..., -0.2227, -0.8008, -0.6992],
         [ 1.0469,  0.3750, -0.0276,  ..., -0.4746,  1.1250,  0.3066]],

        ...,

        [[-0.0347,  0.0099,  0.0165,  ..., -0.0503, -0.0281,  0.0081],
         [ 0.5117,  0.2988,  0.5977,  ...,  0.2715,  0.8438, -0.1855],
         [ 1.2891,  0.1729,  0.1484,  ...,  0.2363,  0.6758, -0.1934],
         ...,
         [ 0.9688,  0.0054,  0.2578,  ...,  0.7305, -0.4180, -0.0801],
         [ 1.8203,  0.5039,  0.0349,  ..., -0.1875, -0.0165, -0.0610],
         [ 0.6172,  0.1592, -0.6055,  ..., -0.0170, -0.2832, -0.6602]],

        [[-0.0347,  0.0099,  0.0165,  ..., -0.0503, -0.0281,  0.0081],
         [ 0.5117,  0.2988,  0.5977,  ...,  0.2715,  0.8438, -0.1855],
         [ 1.2891,  0.1729,  0.1484,  ...,  0.2363,  0.6758, -0.1934],
         ...,
         [ 0.9688,  0.0054,  0.2578,  ...,  0.7305, -0.4180, -0.0801],
         [ 1.8203,  0.5039,  0.0349,  ..., -0.1875, -0.0165, -0.0610],
         [ 0.6172,  0.1592, -0.6055,  ..., -0.0170, -0.2832, -0.6602]],

        [[-0.0347,  0.0099,  0.0165,  ..., -0.0503, -0.0281,  0.0081],
         [ 0.5117,  0.2988,  0.5977,  ...,  0.2715,  0.8438, -0.1855],
         [ 1.2891,  0.1729,  0.1484,  ...,  0.2363,  0.6758, -0.1934],
         ...,
         [ 0.9688,  0.0054,  0.2578,  ...,  0.7305, -0.4180, -0.0801],
         [ 1.8203,  0.5039,  0.0349,  ..., -0.1875, -0.0165, -0.0610],
         [ 0.6172,  0.1592, -0.6055,  ..., -0.0170, -0.2832, -0.6602]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[ 1.6113e-02,  5.1498e-04, -1.2268e-02,  ...,  2.3047e-01,
           6.2500e-02,  7.8906e-01],
         [ 1.0156e-01, -8.4766e-01,  2.4219e-01,  ...,  5.6641e-01,
          -2.0781e+00, -3.1562e+00],
         [ 1.3672e-01,  1.2793e-01,  9.2188e-01,  ..., -3.5938e-01,
          -1.2109e+00, -4.1875e+00],
         ...,
         [-4.8047e-01, -2.7344e-01, -4.8828e-01,  ..., -5.5859e-01,
          -4.6484e-01,  9.4727e-02],
         [ 2.6562e-01, -7.6172e-01,  2.9297e-01,  ..., -6.7969e-01,
           5.7812e-01, -1.2812e+00],
         [ 3.6328e-01,  2.2461e-01,  3.4375e-01,  ..., -6.7578e-01,
           1.6328e+00, -2.6875e+00]],

        [[ 1.6113e-02,  5.1498e-04, -1.2268e-02,  ...,  2.3047e-01,
           6.2500e-02,  7.8906e-01],
         [ 1.0156e-01, -8.4766e-01,  2.4219e-01,  ...,  5.6641e-01,
          -2.0781e+00, -3.1562e+00],
         [ 1.3672e-01,  1.2793e-01,  9.2188e-01,  ..., -3.5938e-01,
          -1.2109e+00, -4.1875e+00],
         ...,
         [-4.8047e-01, -2.7344e-01, -4.8828e-01,  ..., -5.5859e-01,
          -4.6484e-01,  9.4727e-02],
         [ 2.6562e-01, -7.6172e-01,  2.9297e-01,  ..., -6.7969e-01,
           5.7812e-01, -1.2812e+00],
         [ 3.6328e-01,  2.2461e-01,  3.4375e-01,  ..., -6.7578e-01,
           1.6328e+00, -2.6875e+00]],

        [[ 1.6113e-02,  5.1498e-04, -1.2268e-02,  ...,  2.3047e-01,
           6.2500e-02,  7.8906e-01],
         [ 1.0156e-01, -8.4766e-01,  2.4219e-01,  ...,  5.6641e-01,
          -2.0781e+00, -3.1562e+00],
         [ 1.3672e-01,  1.2793e-01,  9.2188e-01,  ..., -3.5938e-01,
          -1.2109e+00, -4.1875e+00],
         ...,
         [-4.8047e-01, -2.7344e-01, -4.8828e-01,  ..., -5.5859e-01,
          -4.6484e-01,  9.4727e-02],
         [ 2.6562e-01, -7.6172e-01,  2.9297e-01,  ..., -6.7969e-01,
           5.7812e-01, -1.2812e+00],
         [ 3.6328e-01,  2.2461e-01,  3.4375e-01,  ..., -6.7578e-01,
           1.6328e+00, -2.6875e+00]],

        ...,

        [[-6.3171e-03,  1.1536e-02, -4.2152e-04,  ..., -9.0820e-02,
          -2.0605e-01,  7.3730e-02],
         [ 5.0781e-02,  6.1523e-02,  7.6172e-02,  ...,  2.1562e+00,
           1.2988e-01, -1.7812e+00],
         [-2.1484e-01,  3.0273e-01,  4.9219e-01,  ...,  3.0938e+00,
          -1.2266e+00, -3.8594e+00],
         ...,
         [-9.6094e-01,  3.2227e-02,  2.9297e-02,  ..., -1.8906e+00,
          -1.9141e+00, -7.1777e-02],
         [-1.7578e-01, -2.8711e-01, -6.5430e-02,  ..., -3.0884e-02,
           4.5703e-01, -2.5000e-01],
         [ 1.0234e+00,  7.0801e-02, -8.7500e-01,  ...,  2.2188e+00,
           6.3281e-01,  8.6328e-01]],

        [[-6.3171e-03,  1.1536e-02, -4.2152e-04,  ..., -9.0820e-02,
          -2.0605e-01,  7.3730e-02],
         [ 5.0781e-02,  6.1523e-02,  7.6172e-02,  ...,  2.1562e+00,
           1.2988e-01, -1.7812e+00],
         [-2.1484e-01,  3.0273e-01,  4.9219e-01,  ...,  3.0938e+00,
          -1.2266e+00, -3.8594e+00],
         ...,
         [-9.6094e-01,  3.2227e-02,  2.9297e-02,  ..., -1.8906e+00,
          -1.9141e+00, -7.1777e-02],
         [-1.7578e-01, -2.8711e-01, -6.5430e-02,  ..., -3.0884e-02,
           4.5703e-01, -2.5000e-01],
         [ 1.0234e+00,  7.0801e-02, -8.7500e-01,  ...,  2.2188e+00,
           6.3281e-01,  8.6328e-01]],

        [[-6.3171e-03,  1.1536e-02, -4.2152e-04,  ..., -9.0820e-02,
          -2.0605e-01,  7.3730e-02],
         [ 5.0781e-02,  6.1523e-02,  7.6172e-02,  ...,  2.1562e+00,
           1.2988e-01, -1.7812e+00],
         [-2.1484e-01,  3.0273e-01,  4.9219e-01,  ...,  3.0938e+00,
          -1.2266e+00, -3.8594e+00],
         ...,
         [-9.6094e-01,  3.2227e-02,  2.9297e-02,  ..., -1.8906e+00,
          -1.9141e+00, -7.1777e-02],
         [-1.7578e-01, -2.8711e-01, -6.5430e-02,  ..., -3.0884e-02,
           4.5703e-01, -2.5000e-01],
         [ 1.0234e+00,  7.0801e-02, -8.7500e-01,  ...,  2.2188e+00,
           6.3281e-01,  8.6328e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 6.7444e-03, -1.3672e-02,  2.2278e-03,  ...,  3.4180e-03,
          -8.3008e-03, -3.9062e-03],
         [ 9.1309e-02,  9.2773e-02, -8.5156e-01,  ..., -1.7734e+00,
          -1.3203e+00,  1.1094e+00],
         [ 8.8281e-01, -3.5352e-01, -1.4258e-01,  ..., -5.0391e-01,
           6.4941e-02,  2.2656e-01],
         ...,
         [ 7.9102e-02, -4.5703e-01, -6.7578e-01,  ..., -4.8242e-01,
           9.5312e-01,  2.5977e-01],
         [-1.0469e+00, -1.8555e-02,  4.5703e-01,  ...,  4.4336e-01,
          -6.0547e-01, -3.1055e-01],
         [-4.3555e-01, -1.3047e+00,  2.1582e-01,  ...,  5.9326e-02,
           1.3516e+00,  3.2617e-01]],

        [[ 6.7444e-03, -1.3672e-02,  2.2278e-03,  ...,  3.4180e-03,
          -8.3008e-03, -3.9062e-03],
         [ 9.1309e-02,  9.2773e-02, -8.5156e-01,  ..., -1.7734e+00,
          -1.3203e+00,  1.1094e+00],
         [ 8.8281e-01, -3.5352e-01, -1.4258e-01,  ..., -5.0391e-01,
           6.4941e-02,  2.2656e-01],
         ...,
         [ 7.9102e-02, -4.5703e-01, -6.7578e-01,  ..., -4.8242e-01,
           9.5312e-01,  2.5977e-01],
         [-1.0469e+00, -1.8555e-02,  4.5703e-01,  ...,  4.4336e-01,
          -6.0547e-01, -3.1055e-01],
         [-4.3555e-01, -1.3047e+00,  2.1582e-01,  ...,  5.9326e-02,
           1.3516e+00,  3.2617e-01]],

        [[ 6.7444e-03, -1.3672e-02,  2.2278e-03,  ...,  3.4180e-03,
          -8.3008e-03, -3.9062e-03],
         [ 9.1309e-02,  9.2773e-02, -8.5156e-01,  ..., -1.7734e+00,
          -1.3203e+00,  1.1094e+00],
         [ 8.8281e-01, -3.5352e-01, -1.4258e-01,  ..., -5.0391e-01,
           6.4941e-02,  2.2656e-01],
         ...,
         [ 7.9102e-02, -4.5703e-01, -6.7578e-01,  ..., -4.8242e-01,
           9.5312e-01,  2.5977e-01],
         [-1.0469e+00, -1.8555e-02,  4.5703e-01,  ...,  4.4336e-01,
          -6.0547e-01, -3.1055e-01],
         [-4.3555e-01, -1.3047e+00,  2.1582e-01,  ...,  5.9326e-02,
           1.3516e+00,  3.2617e-01]],

        ...,

        [[-2.0142e-02,  5.4626e-03,  2.3499e-03,  ..., -1.0147e-03,
          -1.6113e-02,  8.1177e-03],
         [ 1.0078e+00,  2.0801e-01, -1.6250e+00,  ...,  2.9297e-01,
           7.2656e-01,  2.6094e+00],
         [ 8.9453e-01,  3.0156e+00, -7.5000e-01,  ...,  6.4453e-01,
           1.6602e-01,  2.5625e+00],
         ...,
         [-1.6895e-01,  5.7812e-01, -1.6016e-01,  ..., -1.2598e-01,
          -5.0781e-01,  1.9141e-01],
         [-5.3516e-01,  7.4219e-01, -9.5215e-02,  ...,  6.4453e-01,
           5.2734e-01, -1.5332e-01],
         [-9.6875e-01, -1.0234e+00, -1.8203e+00,  ..., -7.6953e-01,
           8.1250e-01, -1.8281e+00]],

        [[-2.0142e-02,  5.4626e-03,  2.3499e-03,  ..., -1.0147e-03,
          -1.6113e-02,  8.1177e-03],
         [ 1.0078e+00,  2.0801e-01, -1.6250e+00,  ...,  2.9297e-01,
           7.2656e-01,  2.6094e+00],
         [ 8.9453e-01,  3.0156e+00, -7.5000e-01,  ...,  6.4453e-01,
           1.6602e-01,  2.5625e+00],
         ...,
         [-1.6895e-01,  5.7812e-01, -1.6016e-01,  ..., -1.2598e-01,
          -5.0781e-01,  1.9141e-01],
         [-5.3516e-01,  7.4219e-01, -9.5215e-02,  ...,  6.4453e-01,
           5.2734e-01, -1.5332e-01],
         [-9.6875e-01, -1.0234e+00, -1.8203e+00,  ..., -7.6953e-01,
           8.1250e-01, -1.8281e+00]],

        [[-2.0142e-02,  5.4626e-03,  2.3499e-03,  ..., -1.0147e-03,
          -1.6113e-02,  8.1177e-03],
         [ 1.0078e+00,  2.0801e-01, -1.6250e+00,  ...,  2.9297e-01,
           7.2656e-01,  2.6094e+00],
         [ 8.9453e-01,  3.0156e+00, -7.5000e-01,  ...,  6.4453e-01,
           1.6602e-01,  2.5625e+00],
         ...,
         [-1.6895e-01,  5.7812e-01, -1.6016e-01,  ..., -1.2598e-01,
          -5.0781e-01,  1.9141e-01],
         [-5.3516e-01,  7.4219e-01, -9.5215e-02,  ...,  6.4453e-01,
           5.2734e-01, -1.5332e-01],
         [-9.6875e-01, -1.0234e+00, -1.8203e+00,  ..., -7.6953e-01,
           8.1250e-01, -1.8281e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 2.2583e-03,  1.4648e-03, -5.1270e-03,  ...,  8.6426e-02,
          -2.5586e-01, -1.3062e-02],
         [-2.4531e+00, -2.8516e-01, -9.7266e-01,  ..., -4.1797e-01,
          -6.9141e-01, -7.8516e-01],
         [-2.7031e+00, -9.8828e-01, -1.5234e+00,  ...,  1.9141e+00,
          -1.3672e+00,  1.1797e+00],
         ...,
         [-1.3750e+00, -7.4219e-01, -8.9062e-01,  ..., -1.7266e+00,
          -5.1562e-01, -7.1094e-01],
         [-1.0391e+00, -6.3672e-01, -9.1016e-01,  ..., -1.8438e+00,
          -1.3828e+00,  3.2617e-01],
         [-2.3730e-01, -3.7891e-01, -1.7500e+00,  ..., -1.2656e+00,
          -1.1250e+00, -2.7656e+00]],

        [[ 2.2583e-03,  1.4648e-03, -5.1270e-03,  ...,  8.6426e-02,
          -2.5586e-01, -1.3062e-02],
         [-2.4531e+00, -2.8516e-01, -9.7266e-01,  ..., -4.1797e-01,
          -6.9141e-01, -7.8516e-01],
         [-2.7031e+00, -9.8828e-01, -1.5234e+00,  ...,  1.9141e+00,
          -1.3672e+00,  1.1797e+00],
         ...,
         [-1.3750e+00, -7.4219e-01, -8.9062e-01,  ..., -1.7266e+00,
          -5.1562e-01, -7.1094e-01],
         [-1.0391e+00, -6.3672e-01, -9.1016e-01,  ..., -1.8438e+00,
          -1.3828e+00,  3.2617e-01],
         [-2.3730e-01, -3.7891e-01, -1.7500e+00,  ..., -1.2656e+00,
          -1.1250e+00, -2.7656e+00]],

        [[ 2.2583e-03,  1.4648e-03, -5.1270e-03,  ...,  8.6426e-02,
          -2.5586e-01, -1.3062e-02],
         [-2.4531e+00, -2.8516e-01, -9.7266e-01,  ..., -4.1797e-01,
          -6.9141e-01, -7.8516e-01],
         [-2.7031e+00, -9.8828e-01, -1.5234e+00,  ...,  1.9141e+00,
          -1.3672e+00,  1.1797e+00],
         ...,
         [-1.3750e+00, -7.4219e-01, -8.9062e-01,  ..., -1.7266e+00,
          -5.1562e-01, -7.1094e-01],
         [-1.0391e+00, -6.3672e-01, -9.1016e-01,  ..., -1.8438e+00,
          -1.3828e+00,  3.2617e-01],
         [-2.3730e-01, -3.7891e-01, -1.7500e+00,  ..., -1.2656e+00,
          -1.1250e+00, -2.7656e+00]],

        ...,

        [[-3.2349e-03,  5.8594e-03,  7.5378e-03,  ..., -8.3008e-02,
           6.6406e-02, -8.7891e-02],
         [-3.1250e+00,  9.4141e-01, -2.5781e-01,  ...,  6.3281e-01,
          -1.2969e+00,  9.1797e-02],
         [ 1.6406e-01, -5.3906e-01, -2.2188e+00,  ..., -1.6602e-01,
          -2.3906e+00,  1.1328e-01],
         ...,
         [-1.6875e+00,  6.4062e-01,  1.8906e+00,  ..., -8.6426e-02,
           7.7344e-01, -1.6406e+00],
         [-3.6719e-01,  1.5430e-01,  1.8438e+00,  ..., -1.1182e-01,
           2.2949e-02, -1.5469e+00],
         [ 1.8359e+00,  1.3828e+00,  5.3906e-01,  ..., -6.1719e-01,
          -5.9766e-01,  1.7456e-02]],

        [[-3.2349e-03,  5.8594e-03,  7.5378e-03,  ..., -8.3008e-02,
           6.6406e-02, -8.7891e-02],
         [-3.1250e+00,  9.4141e-01, -2.5781e-01,  ...,  6.3281e-01,
          -1.2969e+00,  9.1797e-02],
         [ 1.6406e-01, -5.3906e-01, -2.2188e+00,  ..., -1.6602e-01,
          -2.3906e+00,  1.1328e-01],
         ...,
         [-1.6875e+00,  6.4062e-01,  1.8906e+00,  ..., -8.6426e-02,
           7.7344e-01, -1.6406e+00],
         [-3.6719e-01,  1.5430e-01,  1.8438e+00,  ..., -1.1182e-01,
           2.2949e-02, -1.5469e+00],
         [ 1.8359e+00,  1.3828e+00,  5.3906e-01,  ..., -6.1719e-01,
          -5.9766e-01,  1.7456e-02]],

        [[-3.2349e-03,  5.8594e-03,  7.5378e-03,  ..., -8.3008e-02,
           6.6406e-02, -8.7891e-02],
         [-3.1250e+00,  9.4141e-01, -2.5781e-01,  ...,  6.3281e-01,
          -1.2969e+00,  9.1797e-02],
         [ 1.6406e-01, -5.3906e-01, -2.2188e+00,  ..., -1.6602e-01,
          -2.3906e+00,  1.1328e-01],
         ...,
         [-1.6875e+00,  6.4062e-01,  1.8906e+00,  ..., -8.6426e-02,
           7.7344e-01, -1.6406e+00],
         [-3.6719e-01,  1.5430e-01,  1.8438e+00,  ..., -1.1182e-01,
           2.2949e-02, -1.5469e+00],
         [ 1.8359e+00,  1.3828e+00,  5.3906e-01,  ..., -6.1719e-01,
          -5.9766e-01,  1.7456e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 5.5847e-03, -3.1738e-03,  1.5137e-02,  ...,  1.8692e-03,
           1.3733e-02,  1.8433e-02],
         [ 1.6992e-01, -1.4609e+00,  1.2012e-01,  ..., -3.7305e-01,
           1.3359e+00, -2.1191e-01],
         [ 3.2617e-01, -1.7090e-01,  1.1172e+00,  ...,  1.2656e+00,
          -4.8633e-01, -4.1602e-01],
         ...,
         [-1.3672e-01,  2.1777e-01,  7.2266e-01,  ..., -1.3062e-02,
          -1.6113e-01, -3.6914e-01],
         [ 6.2891e-01,  5.0000e-01, -7.6172e-02,  ...,  3.5352e-01,
           7.7734e-01, -3.6523e-01],
         [-8.6719e-01,  4.4922e-01, -5.7031e-01,  ..., -2.5781e-01,
          -6.9141e-01, -1.5469e+00]],

        [[ 5.5847e-03, -3.1738e-03,  1.5137e-02,  ...,  1.8692e-03,
           1.3733e-02,  1.8433e-02],
         [ 1.6992e-01, -1.4609e+00,  1.2012e-01,  ..., -3.7305e-01,
           1.3359e+00, -2.1191e-01],
         [ 3.2617e-01, -1.7090e-01,  1.1172e+00,  ...,  1.2656e+00,
          -4.8633e-01, -4.1602e-01],
         ...,
         [-1.3672e-01,  2.1777e-01,  7.2266e-01,  ..., -1.3062e-02,
          -1.6113e-01, -3.6914e-01],
         [ 6.2891e-01,  5.0000e-01, -7.6172e-02,  ...,  3.5352e-01,
           7.7734e-01, -3.6523e-01],
         [-8.6719e-01,  4.4922e-01, -5.7031e-01,  ..., -2.5781e-01,
          -6.9141e-01, -1.5469e+00]],

        [[ 5.5847e-03, -3.1738e-03,  1.5137e-02,  ...,  1.8692e-03,
           1.3733e-02,  1.8433e-02],
         [ 1.6992e-01, -1.4609e+00,  1.2012e-01,  ..., -3.7305e-01,
           1.3359e+00, -2.1191e-01],
         [ 3.2617e-01, -1.7090e-01,  1.1172e+00,  ...,  1.2656e+00,
          -4.8633e-01, -4.1602e-01],
         ...,
         [-1.3672e-01,  2.1777e-01,  7.2266e-01,  ..., -1.3062e-02,
          -1.6113e-01, -3.6914e-01],
         [ 6.2891e-01,  5.0000e-01, -7.6172e-02,  ...,  3.5352e-01,
           7.7734e-01, -3.6523e-01],
         [-8.6719e-01,  4.4922e-01, -5.7031e-01,  ..., -2.5781e-01,
          -6.9141e-01, -1.5469e+00]],

        ...,

        [[ 3.8910e-03, -1.4572e-03, -4.1809e-03,  ...,  1.4282e-02,
          -5.4321e-03,  6.4392e-03],
         [ 1.1250e+00, -6.4453e-01, -1.4551e-01,  ..., -6.5430e-02,
           6.2891e-01, -6.4453e-01],
         [-3.3398e-01, -1.8457e-01,  5.4688e-01,  ...,  8.7500e-01,
           2.1191e-01,  1.0781e+00],
         ...,
         [-2.0215e-01, -4.0625e-01,  1.2578e+00,  ...,  4.1797e-01,
           1.7969e+00, -8.7891e-02],
         [ 3.3203e-01,  8.5938e-01, -5.9766e-01,  ..., -7.1484e-01,
          -9.8438e-01, -5.3906e-01],
         [ 1.4922e+00, -2.6367e-01,  9.3262e-02,  ...,  8.6719e-01,
          -4.4336e-01,  7.1875e-01]],

        [[ 3.8910e-03, -1.4572e-03, -4.1809e-03,  ...,  1.4282e-02,
          -5.4321e-03,  6.4392e-03],
         [ 1.1250e+00, -6.4453e-01, -1.4551e-01,  ..., -6.5430e-02,
           6.2891e-01, -6.4453e-01],
         [-3.3398e-01, -1.8457e-01,  5.4688e-01,  ...,  8.7500e-01,
           2.1191e-01,  1.0781e+00],
         ...,
         [-2.0215e-01, -4.0625e-01,  1.2578e+00,  ...,  4.1797e-01,
           1.7969e+00, -8.7891e-02],
         [ 3.3203e-01,  8.5938e-01, -5.9766e-01,  ..., -7.1484e-01,
          -9.8438e-01, -5.3906e-01],
         [ 1.4922e+00, -2.6367e-01,  9.3262e-02,  ...,  8.6719e-01,
          -4.4336e-01,  7.1875e-01]],

        [[ 3.8910e-03, -1.4572e-03, -4.1809e-03,  ...,  1.4282e-02,
          -5.4321e-03,  6.4392e-03],
         [ 1.1250e+00, -6.4453e-01, -1.4551e-01,  ..., -6.5430e-02,
           6.2891e-01, -6.4453e-01],
         [-3.3398e-01, -1.8457e-01,  5.4688e-01,  ...,  8.7500e-01,
           2.1191e-01,  1.0781e+00],
         ...,
         [-2.0215e-01, -4.0625e-01,  1.2578e+00,  ...,  4.1797e-01,
           1.7969e+00, -8.7891e-02],
         [ 3.3203e-01,  8.5938e-01, -5.9766e-01,  ..., -7.1484e-01,
          -9.8438e-01, -5.3906e-01],
         [ 1.4922e+00, -2.6367e-01,  9.3262e-02,  ...,  8.6719e-01,
          -4.4336e-01,  7.1875e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 9.3994e-03, -1.4465e-02, -9.5825e-03,  ..., -1.4551e-01,
           5.5469e-01,  2.6562e+00],
         [ 9.4922e-01, -3.2812e-01,  2.4414e-02,  ..., -9.5312e-01,
          -3.2422e-01, -6.4062e+00],
         [-4.3750e-01, -2.2070e-01,  1.6016e-01,  ...,  1.9453e+00,
          -2.6978e-02, -9.3125e+00],
         ...,
         [-6.6895e-02, -3.8477e-01, -2.7344e-01,  ...,  3.0273e-01,
          -1.2988e-01, -6.1875e+00],
         [ 1.6016e-01, -1.3867e-01, -2.9883e-01,  ...,  2.9102e-01,
          -1.1016e+00, -4.7500e+00],
         [ 2.0801e-01, -3.7109e-01,  8.1250e-01,  ...,  2.9844e+00,
           4.0771e-02, -7.2500e+00]],

        [[ 9.3994e-03, -1.4465e-02, -9.5825e-03,  ..., -1.4551e-01,
           5.5469e-01,  2.6562e+00],
         [ 9.4922e-01, -3.2812e-01,  2.4414e-02,  ..., -9.5312e-01,
          -3.2422e-01, -6.4062e+00],
         [-4.3750e-01, -2.2070e-01,  1.6016e-01,  ...,  1.9453e+00,
          -2.6978e-02, -9.3125e+00],
         ...,
         [-6.6895e-02, -3.8477e-01, -2.7344e-01,  ...,  3.0273e-01,
          -1.2988e-01, -6.1875e+00],
         [ 1.6016e-01, -1.3867e-01, -2.9883e-01,  ...,  2.9102e-01,
          -1.1016e+00, -4.7500e+00],
         [ 2.0801e-01, -3.7109e-01,  8.1250e-01,  ...,  2.9844e+00,
           4.0771e-02, -7.2500e+00]],

        [[ 9.3994e-03, -1.4465e-02, -9.5825e-03,  ..., -1.4551e-01,
           5.5469e-01,  2.6562e+00],
         [ 9.4922e-01, -3.2812e-01,  2.4414e-02,  ..., -9.5312e-01,
          -3.2422e-01, -6.4062e+00],
         [-4.3750e-01, -2.2070e-01,  1.6016e-01,  ...,  1.9453e+00,
          -2.6978e-02, -9.3125e+00],
         ...,
         [-6.6895e-02, -3.8477e-01, -2.7344e-01,  ...,  3.0273e-01,
          -1.2988e-01, -6.1875e+00],
         [ 1.6016e-01, -1.3867e-01, -2.9883e-01,  ...,  2.9102e-01,
          -1.1016e+00, -4.7500e+00],
         [ 2.0801e-01, -3.7109e-01,  8.1250e-01,  ...,  2.9844e+00,
           4.0771e-02, -7.2500e+00]],

        ...,

        [[-2.3926e-02, -3.6240e-04, -4.7607e-03,  ...,  2.4048e-02,
           2.1875e+00,  3.8086e-01],
         [ 1.6250e+00, -7.9297e-01, -4.3359e-01,  ..., -3.8086e-01,
          -4.6250e+00,  2.2852e-01],
         [ 7.5195e-02, -5.9766e-01,  7.3828e-01,  ...,  9.8047e-01,
          -3.7500e+00,  1.2256e-01],
         ...,
         [-4.1211e-01, -7.4219e-01,  1.6113e-01,  ...,  2.2344e+00,
          -3.8125e+00,  9.4922e-01],
         [-3.1250e-01, -2.9492e-01, -5.8105e-02,  ...,  9.8438e-01,
          -4.7188e+00,  7.6172e-01],
         [-1.5234e+00, -9.0625e-01,  2.8711e-01,  ...,  2.0469e+00,
          -3.7812e+00,  1.7500e+00]],

        [[-2.3926e-02, -3.6240e-04, -4.7607e-03,  ...,  2.4048e-02,
           2.1875e+00,  3.8086e-01],
         [ 1.6250e+00, -7.9297e-01, -4.3359e-01,  ..., -3.8086e-01,
          -4.6250e+00,  2.2852e-01],
         [ 7.5195e-02, -5.9766e-01,  7.3828e-01,  ...,  9.8047e-01,
          -3.7500e+00,  1.2256e-01],
         ...,
         [-4.1211e-01, -7.4219e-01,  1.6113e-01,  ...,  2.2344e+00,
          -3.8125e+00,  9.4922e-01],
         [-3.1250e-01, -2.9492e-01, -5.8105e-02,  ...,  9.8438e-01,
          -4.7188e+00,  7.6172e-01],
         [-1.5234e+00, -9.0625e-01,  2.8711e-01,  ...,  2.0469e+00,
          -3.7812e+00,  1.7500e+00]],

        [[-2.3926e-02, -3.6240e-04, -4.7607e-03,  ...,  2.4048e-02,
           2.1875e+00,  3.8086e-01],
         [ 1.6250e+00, -7.9297e-01, -4.3359e-01,  ..., -3.8086e-01,
          -4.6250e+00,  2.2852e-01],
         [ 7.5195e-02, -5.9766e-01,  7.3828e-01,  ...,  9.8047e-01,
          -3.7500e+00,  1.2256e-01],
         ...,
         [-4.1211e-01, -7.4219e-01,  1.6113e-01,  ...,  2.2344e+00,
          -3.8125e+00,  9.4922e-01],
         [-3.1250e-01, -2.9492e-01, -5.8105e-02,  ...,  9.8438e-01,
          -4.7188e+00,  7.6172e-01],
         [-1.5234e+00, -9.0625e-01,  2.8711e-01,  ...,  2.0469e+00,
          -3.7812e+00,  1.7500e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 1.2268e-02, -1.1047e-02, -2.2949e-02,  ...,  1.1963e-02,
          -4.4861e-03,  2.0508e-02],
         [-3.2031e-01, -8.4766e-01,  5.0391e-01,  ..., -1.3184e-01,
           2.0801e-01, -1.2188e+00],
         [-5.3711e-02,  1.8281e+00,  2.7656e+00,  ..., -1.2256e-01,
           2.0000e+00,  1.0234e+00],
         ...,
         [-1.8203e+00,  1.4297e+00,  3.7305e-01,  ...,  1.6328e+00,
           1.0703e+00,  2.3906e+00],
         [-3.9844e-01,  4.4531e-01,  9.3750e-01,  ...,  1.8750e-01,
           1.4141e+00,  2.0625e+00],
         [-9.2188e-01,  2.4414e-01, -8.1641e-01,  ..., -8.8281e-01,
          -8.3008e-02, -7.6562e-01]],

        [[ 1.2268e-02, -1.1047e-02, -2.2949e-02,  ...,  1.1963e-02,
          -4.4861e-03,  2.0508e-02],
         [-3.2031e-01, -8.4766e-01,  5.0391e-01,  ..., -1.3184e-01,
           2.0801e-01, -1.2188e+00],
         [-5.3711e-02,  1.8281e+00,  2.7656e+00,  ..., -1.2256e-01,
           2.0000e+00,  1.0234e+00],
         ...,
         [-1.8203e+00,  1.4297e+00,  3.7305e-01,  ...,  1.6328e+00,
           1.0703e+00,  2.3906e+00],
         [-3.9844e-01,  4.4531e-01,  9.3750e-01,  ...,  1.8750e-01,
           1.4141e+00,  2.0625e+00],
         [-9.2188e-01,  2.4414e-01, -8.1641e-01,  ..., -8.8281e-01,
          -8.3008e-02, -7.6562e-01]],

        [[ 1.2268e-02, -1.1047e-02, -2.2949e-02,  ...,  1.1963e-02,
          -4.4861e-03,  2.0508e-02],
         [-3.2031e-01, -8.4766e-01,  5.0391e-01,  ..., -1.3184e-01,
           2.0801e-01, -1.2188e+00],
         [-5.3711e-02,  1.8281e+00,  2.7656e+00,  ..., -1.2256e-01,
           2.0000e+00,  1.0234e+00],
         ...,
         [-1.8203e+00,  1.4297e+00,  3.7305e-01,  ...,  1.6328e+00,
           1.0703e+00,  2.3906e+00],
         [-3.9844e-01,  4.4531e-01,  9.3750e-01,  ...,  1.8750e-01,
           1.4141e+00,  2.0625e+00],
         [-9.2188e-01,  2.4414e-01, -8.1641e-01,  ..., -8.8281e-01,
          -8.3008e-02, -7.6562e-01]],

        ...,

        [[ 4.6082e-03, -1.6022e-04, -2.8687e-02,  ..., -4.7607e-03,
          -8.6060e-03, -6.3705e-04],
         [-1.5938e+00, -1.2305e-01,  2.7148e-01,  ..., -1.8555e-02,
          -4.1016e-01,  4.1406e-01],
         [-1.4062e+00,  6.1328e-01, -2.0605e-01,  ..., -3.7305e-01,
           5.9375e-01,  4.0625e-01],
         ...,
         [-3.4375e-01,  1.2969e+00,  1.0625e+00,  ..., -6.4941e-02,
          -5.7031e-01, -1.8906e+00],
         [ 3.0859e-01,  2.1484e-01,  7.1094e-01,  ..., -7.7734e-01,
           8.3203e-01,  7.1777e-02],
         [ 8.0859e-01,  2.4512e-01,  6.4844e-01,  ..., -1.0000e+00,
          -1.9629e-01, -1.4297e+00]],

        [[ 4.6082e-03, -1.6022e-04, -2.8687e-02,  ..., -4.7607e-03,
          -8.6060e-03, -6.3705e-04],
         [-1.5938e+00, -1.2305e-01,  2.7148e-01,  ..., -1.8555e-02,
          -4.1016e-01,  4.1406e-01],
         [-1.4062e+00,  6.1328e-01, -2.0605e-01,  ..., -3.7305e-01,
           5.9375e-01,  4.0625e-01],
         ...,
         [-3.4375e-01,  1.2969e+00,  1.0625e+00,  ..., -6.4941e-02,
          -5.7031e-01, -1.8906e+00],
         [ 3.0859e-01,  2.1484e-01,  7.1094e-01,  ..., -7.7734e-01,
           8.3203e-01,  7.1777e-02],
         [ 8.0859e-01,  2.4512e-01,  6.4844e-01,  ..., -1.0000e+00,
          -1.9629e-01, -1.4297e+00]],

        [[ 4.6082e-03, -1.6022e-04, -2.8687e-02,  ..., -4.7607e-03,
          -8.6060e-03, -6.3705e-04],
         [-1.5938e+00, -1.2305e-01,  2.7148e-01,  ..., -1.8555e-02,
          -4.1016e-01,  4.1406e-01],
         [-1.4062e+00,  6.1328e-01, -2.0605e-01,  ..., -3.7305e-01,
           5.9375e-01,  4.0625e-01],
         ...,
         [-3.4375e-01,  1.2969e+00,  1.0625e+00,  ..., -6.4941e-02,
          -5.7031e-01, -1.8906e+00],
         [ 3.0859e-01,  2.1484e-01,  7.1094e-01,  ..., -7.7734e-01,
           8.3203e-01,  7.1777e-02],
         [ 8.0859e-01,  2.4512e-01,  6.4844e-01,  ..., -1.0000e+00,
          -1.9629e-01, -1.4297e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 7.6904e-03, -4.5166e-03,  4.1809e-03,  ...,  7.6599e-03,
          -2.2754e-01,  7.9102e-02],
         [-2.2031e+00,  9.3750e-02, -2.4512e-01,  ...,  1.4609e+00,
           7.2266e-01, -8.1641e-01],
         [-1.9062e+00, -5.0781e-01,  7.1289e-02,  ...,  1.2031e+00,
           1.2344e+00, -8.0078e-01],
         ...,
         [-8.5547e-01,  3.1250e-02, -8.2031e-01,  ...,  3.7969e+00,
          -1.9727e-01,  1.7656e+00],
         [-1.0625e+00, -1.1406e+00, -5.3516e-01,  ...,  4.0312e+00,
          -9.2188e-01,  4.9805e-01],
         [-1.1953e+00, -3.3984e-01, -1.0234e+00,  ...,  1.4297e+00,
           2.4062e+00,  2.4023e-01]],

        [[ 7.6904e-03, -4.5166e-03,  4.1809e-03,  ...,  7.6599e-03,
          -2.2754e-01,  7.9102e-02],
         [-2.2031e+00,  9.3750e-02, -2.4512e-01,  ...,  1.4609e+00,
           7.2266e-01, -8.1641e-01],
         [-1.9062e+00, -5.0781e-01,  7.1289e-02,  ...,  1.2031e+00,
           1.2344e+00, -8.0078e-01],
         ...,
         [-8.5547e-01,  3.1250e-02, -8.2031e-01,  ...,  3.7969e+00,
          -1.9727e-01,  1.7656e+00],
         [-1.0625e+00, -1.1406e+00, -5.3516e-01,  ...,  4.0312e+00,
          -9.2188e-01,  4.9805e-01],
         [-1.1953e+00, -3.3984e-01, -1.0234e+00,  ...,  1.4297e+00,
           2.4062e+00,  2.4023e-01]],

        [[ 7.6904e-03, -4.5166e-03,  4.1809e-03,  ...,  7.6599e-03,
          -2.2754e-01,  7.9102e-02],
         [-2.2031e+00,  9.3750e-02, -2.4512e-01,  ...,  1.4609e+00,
           7.2266e-01, -8.1641e-01],
         [-1.9062e+00, -5.0781e-01,  7.1289e-02,  ...,  1.2031e+00,
           1.2344e+00, -8.0078e-01],
         ...,
         [-8.5547e-01,  3.1250e-02, -8.2031e-01,  ...,  3.7969e+00,
          -1.9727e-01,  1.7656e+00],
         [-1.0625e+00, -1.1406e+00, -5.3516e-01,  ...,  4.0312e+00,
          -9.2188e-01,  4.9805e-01],
         [-1.1953e+00, -3.3984e-01, -1.0234e+00,  ...,  1.4297e+00,
           2.4062e+00,  2.4023e-01]],

        ...,

        [[ 4.9744e-03, -6.3171e-03,  2.1076e-04,  ..., -7.6172e-02,
           2.6562e-01,  1.2354e-01],
         [ 1.2969e+00,  1.5625e+00,  1.3906e+00,  ...,  1.0703e+00,
           5.5078e-01,  1.7812e+00],
         [ 4.7266e-01, -8.3984e-01,  1.8750e+00,  ..., -9.1406e-01,
           1.4453e+00,  4.1992e-01],
         ...,
         [ 9.0625e-01, -9.6875e-01, -6.8848e-02,  ..., -1.6406e-01,
           2.1094e+00, -1.0000e+00],
         [-2.7832e-02, -2.7539e-01, -1.5137e-01,  ..., -2.0625e+00,
           1.6641e+00, -2.5156e+00],
         [-6.0547e-01, -4.1797e-01,  1.6953e+00,  ..., -3.7891e-01,
           1.7031e+00, -2.8906e-01]],

        [[ 4.9744e-03, -6.3171e-03,  2.1076e-04,  ..., -7.6172e-02,
           2.6562e-01,  1.2354e-01],
         [ 1.2969e+00,  1.5625e+00,  1.3906e+00,  ...,  1.0703e+00,
           5.5078e-01,  1.7812e+00],
         [ 4.7266e-01, -8.3984e-01,  1.8750e+00,  ..., -9.1406e-01,
           1.4453e+00,  4.1992e-01],
         ...,
         [ 9.0625e-01, -9.6875e-01, -6.8848e-02,  ..., -1.6406e-01,
           2.1094e+00, -1.0000e+00],
         [-2.7832e-02, -2.7539e-01, -1.5137e-01,  ..., -2.0625e+00,
           1.6641e+00, -2.5156e+00],
         [-6.0547e-01, -4.1797e-01,  1.6953e+00,  ..., -3.7891e-01,
           1.7031e+00, -2.8906e-01]],

        [[ 4.9744e-03, -6.3171e-03,  2.1076e-04,  ..., -7.6172e-02,
           2.6562e-01,  1.2354e-01],
         [ 1.2969e+00,  1.5625e+00,  1.3906e+00,  ...,  1.0703e+00,
           5.5078e-01,  1.7812e+00],
         [ 4.7266e-01, -8.3984e-01,  1.8750e+00,  ..., -9.1406e-01,
           1.4453e+00,  4.1992e-01],
         ...,
         [ 9.0625e-01, -9.6875e-01, -6.8848e-02,  ..., -1.6406e-01,
           2.1094e+00, -1.0000e+00],
         [-2.7832e-02, -2.7539e-01, -1.5137e-01,  ..., -2.0625e+00,
           1.6641e+00, -2.5156e+00],
         [-6.0547e-01, -4.1797e-01,  1.6953e+00,  ..., -3.7891e-01,
           1.7031e+00, -2.8906e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-4.2725e-03,  5.8594e-03,  1.6937e-03,  ...,  5.0049e-03,
          -1.9775e-02, -1.0620e-02],
         [ 3.3203e-01,  6.7969e-01,  8.1641e-01,  ...,  3.4375e-01,
          -1.2207e-01, -5.8984e-01],
         [-2.0215e-01,  7.8516e-01, -8.2812e-01,  ...,  1.2969e+00,
          -4.8047e-01, -6.5234e-01],
         ...,
         [-1.4531e+00,  1.0703e+00,  7.7148e-02,  ...,  2.0156e+00,
           3.0938e+00, -8.3203e-01],
         [-1.5938e+00,  6.0156e-01,  1.6094e+00,  ...,  9.5312e-01,
           2.8125e-01, -1.8438e+00],
         [-3.6719e-01, -1.6992e-01, -8.1250e-01,  ..., -1.0469e+00,
          -1.6895e-01, -4.2188e-01]],

        [[-4.2725e-03,  5.8594e-03,  1.6937e-03,  ...,  5.0049e-03,
          -1.9775e-02, -1.0620e-02],
         [ 3.3203e-01,  6.7969e-01,  8.1641e-01,  ...,  3.4375e-01,
          -1.2207e-01, -5.8984e-01],
         [-2.0215e-01,  7.8516e-01, -8.2812e-01,  ...,  1.2969e+00,
          -4.8047e-01, -6.5234e-01],
         ...,
         [-1.4531e+00,  1.0703e+00,  7.7148e-02,  ...,  2.0156e+00,
           3.0938e+00, -8.3203e-01],
         [-1.5938e+00,  6.0156e-01,  1.6094e+00,  ...,  9.5312e-01,
           2.8125e-01, -1.8438e+00],
         [-3.6719e-01, -1.6992e-01, -8.1250e-01,  ..., -1.0469e+00,
          -1.6895e-01, -4.2188e-01]],

        [[-4.2725e-03,  5.8594e-03,  1.6937e-03,  ...,  5.0049e-03,
          -1.9775e-02, -1.0620e-02],
         [ 3.3203e-01,  6.7969e-01,  8.1641e-01,  ...,  3.4375e-01,
          -1.2207e-01, -5.8984e-01],
         [-2.0215e-01,  7.8516e-01, -8.2812e-01,  ...,  1.2969e+00,
          -4.8047e-01, -6.5234e-01],
         ...,
         [-1.4531e+00,  1.0703e+00,  7.7148e-02,  ...,  2.0156e+00,
           3.0938e+00, -8.3203e-01],
         [-1.5938e+00,  6.0156e-01,  1.6094e+00,  ...,  9.5312e-01,
           2.8125e-01, -1.8438e+00],
         [-3.6719e-01, -1.6992e-01, -8.1250e-01,  ..., -1.0469e+00,
          -1.6895e-01, -4.2188e-01]],

        ...,

        [[ 2.7588e-02, -9.9487e-03, -1.7700e-02,  ..., -6.4087e-03,
           1.4038e-03,  5.3711e-03],
         [ 2.7930e-01, -2.5977e-01,  7.3047e-01,  ...,  8.4961e-02,
           3.4570e-01, -1.5332e-01],
         [-8.1177e-03,  1.7578e+00, -4.1016e-01,  ...,  3.9062e-01,
           3.3398e-01,  4.3750e-01],
         ...,
         [-2.8125e-01, -3.3789e-01,  4.7266e-01,  ...,  3.8477e-01,
          -9.4141e-01,  2.2949e-01],
         [ 9.4531e-01, -1.5312e+00,  1.1797e+00,  ...,  3.3906e+00,
           1.2266e+00, -4.6484e-01],
         [ 4.1211e-01, -7.2656e-01,  4.3359e-01,  ...,  8.6060e-03,
          -8.3594e-01,  2.0020e-01]],

        [[ 2.7588e-02, -9.9487e-03, -1.7700e-02,  ..., -6.4087e-03,
           1.4038e-03,  5.3711e-03],
         [ 2.7930e-01, -2.5977e-01,  7.3047e-01,  ...,  8.4961e-02,
           3.4570e-01, -1.5332e-01],
         [-8.1177e-03,  1.7578e+00, -4.1016e-01,  ...,  3.9062e-01,
           3.3398e-01,  4.3750e-01],
         ...,
         [-2.8125e-01, -3.3789e-01,  4.7266e-01,  ...,  3.8477e-01,
          -9.4141e-01,  2.2949e-01],
         [ 9.4531e-01, -1.5312e+00,  1.1797e+00,  ...,  3.3906e+00,
           1.2266e+00, -4.6484e-01],
         [ 4.1211e-01, -7.2656e-01,  4.3359e-01,  ...,  8.6060e-03,
          -8.3594e-01,  2.0020e-01]],

        [[ 2.7588e-02, -9.9487e-03, -1.7700e-02,  ..., -6.4087e-03,
           1.4038e-03,  5.3711e-03],
         [ 2.7930e-01, -2.5977e-01,  7.3047e-01,  ...,  8.4961e-02,
           3.4570e-01, -1.5332e-01],
         [-8.1177e-03,  1.7578e+00, -4.1016e-01,  ...,  3.9062e-01,
           3.3398e-01,  4.3750e-01],
         ...,
         [-2.8125e-01, -3.3789e-01,  4.7266e-01,  ...,  3.8477e-01,
          -9.4141e-01,  2.2949e-01],
         [ 9.4531e-01, -1.5312e+00,  1.1797e+00,  ...,  3.3906e+00,
           1.2266e+00, -4.6484e-01],
         [ 4.1211e-01, -7.2656e-01,  4.3359e-01,  ...,  8.6060e-03,
          -8.3594e-01,  2.0020e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-2.3346e-03,  9.5215e-03, -9.2163e-03,  ..., -1.6602e-01,
          -1.3672e-01,  2.4414e-01],
         [ 2.0312e+00, -6.0156e-01, -2.6953e-01,  ...,  1.7734e+00,
           7.8516e-01, -8.5547e-01],
         [ 1.6016e+00, -2.7734e-01, -2.1289e-01,  ...,  2.1406e+00,
           1.4551e-01, -5.0000e-01],
         ...,
         [ 9.1406e-01, -3.8672e-01,  7.9297e-01,  ...,  1.3281e-01,
          -1.9434e-01,  3.5547e-01],
         [ 5.8594e-01, -3.4766e-01,  7.6172e-02,  ...,  1.5820e-01,
           1.6875e+00,  1.0312e+00],
         [ 1.1875e+00,  9.2969e-01,  1.3867e-01,  ...,  7.6172e-01,
           8.8672e-01, -2.7344e-01]],

        [[-2.3346e-03,  9.5215e-03, -9.2163e-03,  ..., -1.6602e-01,
          -1.3672e-01,  2.4414e-01],
         [ 2.0312e+00, -6.0156e-01, -2.6953e-01,  ...,  1.7734e+00,
           7.8516e-01, -8.5547e-01],
         [ 1.6016e+00, -2.7734e-01, -2.1289e-01,  ...,  2.1406e+00,
           1.4551e-01, -5.0000e-01],
         ...,
         [ 9.1406e-01, -3.8672e-01,  7.9297e-01,  ...,  1.3281e-01,
          -1.9434e-01,  3.5547e-01],
         [ 5.8594e-01, -3.4766e-01,  7.6172e-02,  ...,  1.5820e-01,
           1.6875e+00,  1.0312e+00],
         [ 1.1875e+00,  9.2969e-01,  1.3867e-01,  ...,  7.6172e-01,
           8.8672e-01, -2.7344e-01]],

        [[-2.3346e-03,  9.5215e-03, -9.2163e-03,  ..., -1.6602e-01,
          -1.3672e-01,  2.4414e-01],
         [ 2.0312e+00, -6.0156e-01, -2.6953e-01,  ...,  1.7734e+00,
           7.8516e-01, -8.5547e-01],
         [ 1.6016e+00, -2.7734e-01, -2.1289e-01,  ...,  2.1406e+00,
           1.4551e-01, -5.0000e-01],
         ...,
         [ 9.1406e-01, -3.8672e-01,  7.9297e-01,  ...,  1.3281e-01,
          -1.9434e-01,  3.5547e-01],
         [ 5.8594e-01, -3.4766e-01,  7.6172e-02,  ...,  1.5820e-01,
           1.6875e+00,  1.0312e+00],
         [ 1.1875e+00,  9.2969e-01,  1.3867e-01,  ...,  7.6172e-01,
           8.8672e-01, -2.7344e-01]],

        ...,

        [[ 1.5137e-02, -7.3547e-03, -4.0894e-03,  ..., -2.2852e-01,
           8.4375e-01, -1.9141e-01],
         [-1.2422e+00, -1.0703e+00,  6.9922e-01,  ..., -3.3008e-01,
           1.2266e+00,  3.3125e+00],
         [-9.8438e-01, -7.5391e-01,  7.5000e-01,  ..., -1.3184e-01,
           2.2266e-01, -2.5938e+00],
         ...,
         [-1.0312e+00,  3.0664e-01,  1.0625e+00,  ...,  3.0000e+00,
          -5.0391e-01,  7.0312e-01],
         [-8.0469e-01,  1.0781e+00,  5.2734e-01,  ...,  3.8281e-01,
          -6.5430e-02,  5.4297e-01],
         [ 1.9434e-01,  1.3750e+00, -8.7891e-02,  ..., -9.5703e-01,
          -6.0547e-01, -1.4141e+00]],

        [[ 1.5137e-02, -7.3547e-03, -4.0894e-03,  ..., -2.2852e-01,
           8.4375e-01, -1.9141e-01],
         [-1.2422e+00, -1.0703e+00,  6.9922e-01,  ..., -3.3008e-01,
           1.2266e+00,  3.3125e+00],
         [-9.8438e-01, -7.5391e-01,  7.5000e-01,  ..., -1.3184e-01,
           2.2266e-01, -2.5938e+00],
         ...,
         [-1.0312e+00,  3.0664e-01,  1.0625e+00,  ...,  3.0000e+00,
          -5.0391e-01,  7.0312e-01],
         [-8.0469e-01,  1.0781e+00,  5.2734e-01,  ...,  3.8281e-01,
          -6.5430e-02,  5.4297e-01],
         [ 1.9434e-01,  1.3750e+00, -8.7891e-02,  ..., -9.5703e-01,
          -6.0547e-01, -1.4141e+00]],

        [[ 1.5137e-02, -7.3547e-03, -4.0894e-03,  ..., -2.2852e-01,
           8.4375e-01, -1.9141e-01],
         [-1.2422e+00, -1.0703e+00,  6.9922e-01,  ..., -3.3008e-01,
           1.2266e+00,  3.3125e+00],
         [-9.8438e-01, -7.5391e-01,  7.5000e-01,  ..., -1.3184e-01,
           2.2266e-01, -2.5938e+00],
         ...,
         [-1.0312e+00,  3.0664e-01,  1.0625e+00,  ...,  3.0000e+00,
          -5.0391e-01,  7.0312e-01],
         [-8.0469e-01,  1.0781e+00,  5.2734e-01,  ...,  3.8281e-01,
          -6.5430e-02,  5.4297e-01],
         [ 1.9434e-01,  1.3750e+00, -8.7891e-02,  ..., -9.5703e-01,
          -6.0547e-01, -1.4141e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-7.1106e-03, -1.7944e-02,  7.7820e-03,  ...,  1.2756e-02,
          -1.6724e-02,  4.2725e-03],
         [-1.6797e-01, -6.9922e-01,  4.2969e-02,  ...,  7.4707e-02,
          -1.8066e-01, -3.6133e-02],
         [ 2.7539e-01,  1.1562e+00, -1.1875e+00,  ..., -6.7383e-02,
          -3.1055e-01,  2.6611e-02],
         ...,
         [ 2.0703e-01,  2.5391e-01, -4.3701e-02,  ...,  2.9883e-01,
          -5.7812e-01,  3.8867e-01],
         [-5.6250e-01,  4.2773e-01,  7.8125e-01,  ..., -9.9219e-01,
          -5.0000e-01, -5.6641e-01],
         [ 8.3008e-02,  1.2969e+00,  1.1094e+00,  ...,  1.0469e+00,
           1.1484e+00,  2.0312e+00]],

        [[-7.1106e-03, -1.7944e-02,  7.7820e-03,  ...,  1.2756e-02,
          -1.6724e-02,  4.2725e-03],
         [-1.6797e-01, -6.9922e-01,  4.2969e-02,  ...,  7.4707e-02,
          -1.8066e-01, -3.6133e-02],
         [ 2.7539e-01,  1.1562e+00, -1.1875e+00,  ..., -6.7383e-02,
          -3.1055e-01,  2.6611e-02],
         ...,
         [ 2.0703e-01,  2.5391e-01, -4.3701e-02,  ...,  2.9883e-01,
          -5.7812e-01,  3.8867e-01],
         [-5.6250e-01,  4.2773e-01,  7.8125e-01,  ..., -9.9219e-01,
          -5.0000e-01, -5.6641e-01],
         [ 8.3008e-02,  1.2969e+00,  1.1094e+00,  ...,  1.0469e+00,
           1.1484e+00,  2.0312e+00]],

        [[-7.1106e-03, -1.7944e-02,  7.7820e-03,  ...,  1.2756e-02,
          -1.6724e-02,  4.2725e-03],
         [-1.6797e-01, -6.9922e-01,  4.2969e-02,  ...,  7.4707e-02,
          -1.8066e-01, -3.6133e-02],
         [ 2.7539e-01,  1.1562e+00, -1.1875e+00,  ..., -6.7383e-02,
          -3.1055e-01,  2.6611e-02],
         ...,
         [ 2.0703e-01,  2.5391e-01, -4.3701e-02,  ...,  2.9883e-01,
          -5.7812e-01,  3.8867e-01],
         [-5.6250e-01,  4.2773e-01,  7.8125e-01,  ..., -9.9219e-01,
          -5.0000e-01, -5.6641e-01],
         [ 8.3008e-02,  1.2969e+00,  1.1094e+00,  ...,  1.0469e+00,
           1.1484e+00,  2.0312e+00]],

        ...,

        [[-3.6469e-03, -8.3618e-03,  3.2349e-03,  ...,  2.0504e-04,
          -3.1128e-03,  1.9836e-03],
         [-2.4844e+00,  9.6484e-01, -8.0469e-01,  ...,  1.8047e+00,
           2.6562e+00, -2.4688e+00],
         [-8.5938e-01, -1.7344e+00, -9.7266e-01,  ...,  6.6797e-01,
           5.0781e-01, -2.2500e+00],
         ...,
         [ 1.3281e+00, -1.3438e+00, -9.9609e-01,  ...,  7.6562e-01,
          -1.2188e+00,  5.0000e-01],
         [-5.9375e-01, -3.8086e-01, -1.7734e+00,  ...,  1.4922e+00,
           5.1758e-02, -1.1016e+00],
         [-2.2656e+00,  2.6367e-01, -3.0156e+00,  ...,  1.0391e+00,
          -1.9688e+00, -7.5000e-01]],

        [[-3.6469e-03, -8.3618e-03,  3.2349e-03,  ...,  2.0504e-04,
          -3.1128e-03,  1.9836e-03],
         [-2.4844e+00,  9.6484e-01, -8.0469e-01,  ...,  1.8047e+00,
           2.6562e+00, -2.4688e+00],
         [-8.5938e-01, -1.7344e+00, -9.7266e-01,  ...,  6.6797e-01,
           5.0781e-01, -2.2500e+00],
         ...,
         [ 1.3281e+00, -1.3438e+00, -9.9609e-01,  ...,  7.6562e-01,
          -1.2188e+00,  5.0000e-01],
         [-5.9375e-01, -3.8086e-01, -1.7734e+00,  ...,  1.4922e+00,
           5.1758e-02, -1.1016e+00],
         [-2.2656e+00,  2.6367e-01, -3.0156e+00,  ...,  1.0391e+00,
          -1.9688e+00, -7.5000e-01]],

        [[-3.6469e-03, -8.3618e-03,  3.2349e-03,  ...,  2.0504e-04,
          -3.1128e-03,  1.9836e-03],
         [-2.4844e+00,  9.6484e-01, -8.0469e-01,  ...,  1.8047e+00,
           2.6562e+00, -2.4688e+00],
         [-8.5938e-01, -1.7344e+00, -9.7266e-01,  ...,  6.6797e-01,
           5.0781e-01, -2.2500e+00],
         ...,
         [ 1.3281e+00, -1.3438e+00, -9.9609e-01,  ...,  7.6562e-01,
          -1.2188e+00,  5.0000e-01],
         [-5.9375e-01, -3.8086e-01, -1.7734e+00,  ...,  1.4922e+00,
           5.1758e-02, -1.1016e+00],
         [-2.2656e+00,  2.6367e-01, -3.0156e+00,  ...,  1.0391e+00,
          -1.9688e+00, -7.5000e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 6.3324e-04,  3.0518e-03, -2.5635e-03,  ..., -1.9043e-01,
          -1.2695e-01,  1.2598e-01],
         [ 0.0000e+00,  9.1406e-01, -2.5000e+00,  ...,  1.4062e-01,
           6.0938e-01, -8.7891e-01],
         [-1.0078e+00,  1.2500e+00, -2.9492e-01,  ..., -2.1680e-01,
          -4.3555e-01,  2.7734e-01],
         ...,
         [-1.2695e-01,  1.9336e-01, -1.1475e-01,  ..., -6.9336e-02,
          -7.3828e-01, -6.5234e-01],
         [-2.9492e-01,  2.8906e-01,  7.0312e-01,  ..., -1.9453e+00,
          -6.0938e-01,  8.4375e-01],
         [-1.2266e+00, -3.7891e-01,  4.9219e-01,  ..., -2.3535e-01,
          -1.6797e+00, -1.3906e+00]],

        [[ 6.3324e-04,  3.0518e-03, -2.5635e-03,  ..., -1.9043e-01,
          -1.2695e-01,  1.2598e-01],
         [ 0.0000e+00,  9.1406e-01, -2.5000e+00,  ...,  1.4062e-01,
           6.0938e-01, -8.7891e-01],
         [-1.0078e+00,  1.2500e+00, -2.9492e-01,  ..., -2.1680e-01,
          -4.3555e-01,  2.7734e-01],
         ...,
         [-1.2695e-01,  1.9336e-01, -1.1475e-01,  ..., -6.9336e-02,
          -7.3828e-01, -6.5234e-01],
         [-2.9492e-01,  2.8906e-01,  7.0312e-01,  ..., -1.9453e+00,
          -6.0938e-01,  8.4375e-01],
         [-1.2266e+00, -3.7891e-01,  4.9219e-01,  ..., -2.3535e-01,
          -1.6797e+00, -1.3906e+00]],

        [[ 6.3324e-04,  3.0518e-03, -2.5635e-03,  ..., -1.9043e-01,
          -1.2695e-01,  1.2598e-01],
         [ 0.0000e+00,  9.1406e-01, -2.5000e+00,  ...,  1.4062e-01,
           6.0938e-01, -8.7891e-01],
         [-1.0078e+00,  1.2500e+00, -2.9492e-01,  ..., -2.1680e-01,
          -4.3555e-01,  2.7734e-01],
         ...,
         [-1.2695e-01,  1.9336e-01, -1.1475e-01,  ..., -6.9336e-02,
          -7.3828e-01, -6.5234e-01],
         [-2.9492e-01,  2.8906e-01,  7.0312e-01,  ..., -1.9453e+00,
          -6.0938e-01,  8.4375e-01],
         [-1.2266e+00, -3.7891e-01,  4.9219e-01,  ..., -2.3535e-01,
          -1.6797e+00, -1.3906e+00]],

        ...,

        [[-1.1978e-03, -1.1841e-02, -8.0566e-03,  ...,  1.3965e-01,
           5.9082e-02,  1.1865e-01],
         [-2.5781e+00, -1.0078e+00,  1.7578e+00,  ..., -2.5156e+00,
           2.7734e-01, -6.7188e-01],
         [-1.5938e+00, -1.5312e+00,  1.4609e+00,  ..., -3.6719e-01,
           1.8125e+00, -1.7188e+00],
         ...,
         [-1.1094e+00,  1.3359e+00,  7.1875e-01,  ...,  7.3047e-01,
          -2.3906e+00,  2.0312e+00],
         [-1.2031e+00,  1.4297e+00,  8.4766e-01,  ...,  1.7812e+00,
          -5.7422e-01,  2.2969e+00],
         [-5.5078e-01,  2.2500e+00,  1.7266e+00,  ..., -5.8203e-01,
          -1.6211e-01,  9.9609e-02]],

        [[-1.1978e-03, -1.1841e-02, -8.0566e-03,  ...,  1.3965e-01,
           5.9082e-02,  1.1865e-01],
         [-2.5781e+00, -1.0078e+00,  1.7578e+00,  ..., -2.5156e+00,
           2.7734e-01, -6.7188e-01],
         [-1.5938e+00, -1.5312e+00,  1.4609e+00,  ..., -3.6719e-01,
           1.8125e+00, -1.7188e+00],
         ...,
         [-1.1094e+00,  1.3359e+00,  7.1875e-01,  ...,  7.3047e-01,
          -2.3906e+00,  2.0312e+00],
         [-1.2031e+00,  1.4297e+00,  8.4766e-01,  ...,  1.7812e+00,
          -5.7422e-01,  2.2969e+00],
         [-5.5078e-01,  2.2500e+00,  1.7266e+00,  ..., -5.8203e-01,
          -1.6211e-01,  9.9609e-02]],

        [[-1.1978e-03, -1.1841e-02, -8.0566e-03,  ...,  1.3965e-01,
           5.9082e-02,  1.1865e-01],
         [-2.5781e+00, -1.0078e+00,  1.7578e+00,  ..., -2.5156e+00,
           2.7734e-01, -6.7188e-01],
         [-1.5938e+00, -1.5312e+00,  1.4609e+00,  ..., -3.6719e-01,
           1.8125e+00, -1.7188e+00],
         ...,
         [-1.1094e+00,  1.3359e+00,  7.1875e-01,  ...,  7.3047e-01,
          -2.3906e+00,  2.0312e+00],
         [-1.2031e+00,  1.4297e+00,  8.4766e-01,  ...,  1.7812e+00,
          -5.7422e-01,  2.2969e+00],
         [-5.5078e-01,  2.2500e+00,  1.7266e+00,  ..., -5.8203e-01,
          -1.6211e-01,  9.9609e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 2.4658e-02, -2.5558e-04, -4.5166e-03,  ..., -9.0942e-03,
           1.7822e-02, -9.8877e-03],
         [-3.9648e-01,  2.2656e-01,  2.7344e-01,  ..., -8.3984e-01,
           4.1602e-01, -3.8086e-01],
         [-1.7383e-01,  4.6094e-01, -1.5859e+00,  ...,  3.0469e-01,
          -1.4375e+00, -8.2812e-01],
         ...,
         [ 7.5781e-01,  2.2344e+00,  2.4062e+00,  ..., -2.7222e-02,
           1.5000e+00,  7.8516e-01],
         [ 1.0234e+00,  1.5312e+00,  8.7891e-03,  ..., -4.0234e-01,
           1.0078e+00, -3.8330e-02],
         [-9.5703e-01,  5.8594e-01, -5.7422e-01,  ...,  4.8828e-01,
           4.1016e-01, -6.6016e-01]],

        [[ 2.4658e-02, -2.5558e-04, -4.5166e-03,  ..., -9.0942e-03,
           1.7822e-02, -9.8877e-03],
         [-3.9648e-01,  2.2656e-01,  2.7344e-01,  ..., -8.3984e-01,
           4.1602e-01, -3.8086e-01],
         [-1.7383e-01,  4.6094e-01, -1.5859e+00,  ...,  3.0469e-01,
          -1.4375e+00, -8.2812e-01],
         ...,
         [ 7.5781e-01,  2.2344e+00,  2.4062e+00,  ..., -2.7222e-02,
           1.5000e+00,  7.8516e-01],
         [ 1.0234e+00,  1.5312e+00,  8.7891e-03,  ..., -4.0234e-01,
           1.0078e+00, -3.8330e-02],
         [-9.5703e-01,  5.8594e-01, -5.7422e-01,  ...,  4.8828e-01,
           4.1016e-01, -6.6016e-01]],

        [[ 2.4658e-02, -2.5558e-04, -4.5166e-03,  ..., -9.0942e-03,
           1.7822e-02, -9.8877e-03],
         [-3.9648e-01,  2.2656e-01,  2.7344e-01,  ..., -8.3984e-01,
           4.1602e-01, -3.8086e-01],
         [-1.7383e-01,  4.6094e-01, -1.5859e+00,  ...,  3.0469e-01,
          -1.4375e+00, -8.2812e-01],
         ...,
         [ 7.5781e-01,  2.2344e+00,  2.4062e+00,  ..., -2.7222e-02,
           1.5000e+00,  7.8516e-01],
         [ 1.0234e+00,  1.5312e+00,  8.7891e-03,  ..., -4.0234e-01,
           1.0078e+00, -3.8330e-02],
         [-9.5703e-01,  5.8594e-01, -5.7422e-01,  ...,  4.8828e-01,
           4.1016e-01, -6.6016e-01]],

        ...,

        [[-1.2085e-02, -7.4463e-03, -1.9836e-03,  ..., -1.0315e-02,
          -3.5858e-03, -1.2756e-02],
         [-3.5742e-01, -1.9336e-01, -2.2656e-01,  ..., -8.3984e-01,
           5.3906e-01, -1.9141e+00],
         [-9.7656e-03, -3.3594e-01,  3.7109e-02,  ...,  8.7109e-01,
          -7.7734e-01, -7.7734e-01],
         ...,
         [ 6.2109e-01,  8.6328e-01, -1.0625e+00,  ..., -5.8203e-01,
           1.2891e+00,  4.4678e-02],
         [ 7.6660e-02, -1.6699e-01, -1.4062e+00,  ...,  5.7422e-01,
           8.7891e-01,  2.0215e-01],
         [-3.9062e-01, -1.0156e+00, -8.7500e-01,  ..., -6.6406e-01,
           9.0234e-01, -2.3071e-02]],

        [[-1.2085e-02, -7.4463e-03, -1.9836e-03,  ..., -1.0315e-02,
          -3.5858e-03, -1.2756e-02],
         [-3.5742e-01, -1.9336e-01, -2.2656e-01,  ..., -8.3984e-01,
           5.3906e-01, -1.9141e+00],
         [-9.7656e-03, -3.3594e-01,  3.7109e-02,  ...,  8.7109e-01,
          -7.7734e-01, -7.7734e-01],
         ...,
         [ 6.2109e-01,  8.6328e-01, -1.0625e+00,  ..., -5.8203e-01,
           1.2891e+00,  4.4678e-02],
         [ 7.6660e-02, -1.6699e-01, -1.4062e+00,  ...,  5.7422e-01,
           8.7891e-01,  2.0215e-01],
         [-3.9062e-01, -1.0156e+00, -8.7500e-01,  ..., -6.6406e-01,
           9.0234e-01, -2.3071e-02]],

        [[-1.2085e-02, -7.4463e-03, -1.9836e-03,  ..., -1.0315e-02,
          -3.5858e-03, -1.2756e-02],
         [-3.5742e-01, -1.9336e-01, -2.2656e-01,  ..., -8.3984e-01,
           5.3906e-01, -1.9141e+00],
         [-9.7656e-03, -3.3594e-01,  3.7109e-02,  ...,  8.7109e-01,
          -7.7734e-01, -7.7734e-01],
         ...,
         [ 6.2109e-01,  8.6328e-01, -1.0625e+00,  ..., -5.8203e-01,
           1.2891e+00,  4.4678e-02],
         [ 7.6660e-02, -1.6699e-01, -1.4062e+00,  ...,  5.7422e-01,
           8.7891e-01,  2.0215e-01],
         [-3.9062e-01, -1.0156e+00, -8.7500e-01,  ..., -6.6406e-01,
           9.0234e-01, -2.3071e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-3.1433e-03,  4.5776e-03,  1.3306e-02,  ..., -1.6016e-01,
           2.3535e-01,  3.0273e-01],
         [-3.1719e+00, -1.2656e+00, -1.2812e+00,  ...,  2.9883e-01,
          -2.4805e-01, -5.2734e-01],
         [-1.9141e-01, -8.1250e-01, -4.2969e-01,  ..., -3.5889e-02,
          -1.2031e+00, -1.0234e+00],
         ...,
         [-1.8594e+00,  1.0703e+00, -1.7188e-01,  ...,  1.1484e+00,
          -7.7344e-01, -9.6484e-01],
         [-6.2891e-01,  2.5781e+00, -1.2500e+00,  ...,  1.0312e+00,
          -3.5938e-01, -2.7344e-01],
         [ 2.9219e+00,  1.4531e+00, -1.7578e+00,  ...,  1.1406e+00,
           1.0234e+00, -1.8672e+00]],

        [[-3.1433e-03,  4.5776e-03,  1.3306e-02,  ..., -1.6016e-01,
           2.3535e-01,  3.0273e-01],
         [-3.1719e+00, -1.2656e+00, -1.2812e+00,  ...,  2.9883e-01,
          -2.4805e-01, -5.2734e-01],
         [-1.9141e-01, -8.1250e-01, -4.2969e-01,  ..., -3.5889e-02,
          -1.2031e+00, -1.0234e+00],
         ...,
         [-1.8594e+00,  1.0703e+00, -1.7188e-01,  ...,  1.1484e+00,
          -7.7344e-01, -9.6484e-01],
         [-6.2891e-01,  2.5781e+00, -1.2500e+00,  ...,  1.0312e+00,
          -3.5938e-01, -2.7344e-01],
         [ 2.9219e+00,  1.4531e+00, -1.7578e+00,  ...,  1.1406e+00,
           1.0234e+00, -1.8672e+00]],

        [[-3.1433e-03,  4.5776e-03,  1.3306e-02,  ..., -1.6016e-01,
           2.3535e-01,  3.0273e-01],
         [-3.1719e+00, -1.2656e+00, -1.2812e+00,  ...,  2.9883e-01,
          -2.4805e-01, -5.2734e-01],
         [-1.9141e-01, -8.1250e-01, -4.2969e-01,  ..., -3.5889e-02,
          -1.2031e+00, -1.0234e+00],
         ...,
         [-1.8594e+00,  1.0703e+00, -1.7188e-01,  ...,  1.1484e+00,
          -7.7344e-01, -9.6484e-01],
         [-6.2891e-01,  2.5781e+00, -1.2500e+00,  ...,  1.0312e+00,
          -3.5938e-01, -2.7344e-01],
         [ 2.9219e+00,  1.4531e+00, -1.7578e+00,  ...,  1.1406e+00,
           1.0234e+00, -1.8672e+00]],

        ...,

        [[ 6.1798e-04,  9.0027e-04, -2.0996e-02,  ...,  1.0681e-03,
          -4.8828e-02, -1.9336e-01],
         [-3.7031e+00, -1.8047e+00, -6.6406e-02,  ...,  2.6562e-01,
           1.8594e+00,  1.1914e-01],
         [-7.6953e-01, -1.3828e+00,  2.4414e-02,  ...,  5.9375e-01,
          -1.5391e+00, -1.1484e+00],
         ...,
         [-2.1719e+00,  1.4062e+00, -5.5469e-01,  ...,  7.3047e-01,
           8.5938e-01, -3.2031e-01],
         [-6.4062e-01,  1.2500e+00, -9.1406e-01,  ...,  3.3984e-01,
          -2.2188e+00, -9.2969e-01],
         [ 2.3750e+00,  1.8594e+00, -1.6797e+00,  ...,  1.4609e+00,
           2.5391e-01, -1.1406e+00]],

        [[ 6.1798e-04,  9.0027e-04, -2.0996e-02,  ...,  1.0681e-03,
          -4.8828e-02, -1.9336e-01],
         [-3.7031e+00, -1.8047e+00, -6.6406e-02,  ...,  2.6562e-01,
           1.8594e+00,  1.1914e-01],
         [-7.6953e-01, -1.3828e+00,  2.4414e-02,  ...,  5.9375e-01,
          -1.5391e+00, -1.1484e+00],
         ...,
         [-2.1719e+00,  1.4062e+00, -5.5469e-01,  ...,  7.3047e-01,
           8.5938e-01, -3.2031e-01],
         [-6.4062e-01,  1.2500e+00, -9.1406e-01,  ...,  3.3984e-01,
          -2.2188e+00, -9.2969e-01],
         [ 2.3750e+00,  1.8594e+00, -1.6797e+00,  ...,  1.4609e+00,
           2.5391e-01, -1.1406e+00]],

        [[ 6.1798e-04,  9.0027e-04, -2.0996e-02,  ...,  1.0681e-03,
          -4.8828e-02, -1.9336e-01],
         [-3.7031e+00, -1.8047e+00, -6.6406e-02,  ...,  2.6562e-01,
           1.8594e+00,  1.1914e-01],
         [-7.6953e-01, -1.3828e+00,  2.4414e-02,  ...,  5.9375e-01,
          -1.5391e+00, -1.1484e+00],
         ...,
         [-2.1719e+00,  1.4062e+00, -5.5469e-01,  ...,  7.3047e-01,
           8.5938e-01, -3.2031e-01],
         [-6.4062e-01,  1.2500e+00, -9.1406e-01,  ...,  3.3984e-01,
          -2.2188e+00, -9.2969e-01],
         [ 2.3750e+00,  1.8594e+00, -1.6797e+00,  ...,  1.4609e+00,
           2.5391e-01, -1.1406e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 1.0925e-02, -2.8229e-03, -1.0071e-02,  ..., -1.0803e-02,
           2.3193e-02, -4.3030e-03],
         [-3.6328e-01,  2.4512e-01, -1.2031e+00,  ..., -1.0498e-02,
           2.5781e-01, -9.3359e-01],
         [-7.7637e-02,  4.9219e-01, -2.6367e-01,  ...,  1.1016e+00,
          -1.1562e+00, -8.4473e-02],
         ...,
         [-1.0625e+00,  1.0703e+00, -4.9023e-01,  ...,  4.3164e-01,
          -7.1094e-01, -4.4141e-01],
         [-2.4902e-01, -3.9648e-01, -8.9355e-02,  ...,  9.7266e-01,
          -1.2812e+00,  1.2266e+00],
         [-6.9531e-01,  2.7539e-01, -6.8359e-01,  ...,  9.1797e-01,
          -1.5156e+00,  8.3594e-01]],

        [[ 1.0925e-02, -2.8229e-03, -1.0071e-02,  ..., -1.0803e-02,
           2.3193e-02, -4.3030e-03],
         [-3.6328e-01,  2.4512e-01, -1.2031e+00,  ..., -1.0498e-02,
           2.5781e-01, -9.3359e-01],
         [-7.7637e-02,  4.9219e-01, -2.6367e-01,  ...,  1.1016e+00,
          -1.1562e+00, -8.4473e-02],
         ...,
         [-1.0625e+00,  1.0703e+00, -4.9023e-01,  ...,  4.3164e-01,
          -7.1094e-01, -4.4141e-01],
         [-2.4902e-01, -3.9648e-01, -8.9355e-02,  ...,  9.7266e-01,
          -1.2812e+00,  1.2266e+00],
         [-6.9531e-01,  2.7539e-01, -6.8359e-01,  ...,  9.1797e-01,
          -1.5156e+00,  8.3594e-01]],

        [[ 1.0925e-02, -2.8229e-03, -1.0071e-02,  ..., -1.0803e-02,
           2.3193e-02, -4.3030e-03],
         [-3.6328e-01,  2.4512e-01, -1.2031e+00,  ..., -1.0498e-02,
           2.5781e-01, -9.3359e-01],
         [-7.7637e-02,  4.9219e-01, -2.6367e-01,  ...,  1.1016e+00,
          -1.1562e+00, -8.4473e-02],
         ...,
         [-1.0625e+00,  1.0703e+00, -4.9023e-01,  ...,  4.3164e-01,
          -7.1094e-01, -4.4141e-01],
         [-2.4902e-01, -3.9648e-01, -8.9355e-02,  ...,  9.7266e-01,
          -1.2812e+00,  1.2266e+00],
         [-6.9531e-01,  2.7539e-01, -6.8359e-01,  ...,  9.1797e-01,
          -1.5156e+00,  8.3594e-01]],

        ...,

        [[-6.8283e-04, -3.7079e-03,  7.7820e-04,  ...,  2.1118e-02,
           6.7444e-03,  5.3406e-03],
         [-2.0508e-01, -2.8516e-01,  4.9609e-01,  ...,  1.5938e+00,
           1.1719e+00,  5.5078e-01],
         [-2.8125e-01, -1.6992e-01, -5.9766e-01,  ...,  4.9414e-01,
           1.4160e-02,  4.6289e-01],
         ...,
         [ 1.2695e-01, -1.2891e+00,  1.0986e-01,  ...,  1.1250e+00,
          -2.4219e+00, -1.2793e-01],
         [ 1.2012e-01,  8.4375e-01,  1.5625e+00,  ..., -9.3750e-02,
          -8.6719e-01,  1.2812e+00],
         [-8.2812e-01, -3.7305e-01,  6.7871e-02,  ..., -4.4336e-01,
           1.5527e-01, -7.6562e-01]],

        [[-6.8283e-04, -3.7079e-03,  7.7820e-04,  ...,  2.1118e-02,
           6.7444e-03,  5.3406e-03],
         [-2.0508e-01, -2.8516e-01,  4.9609e-01,  ...,  1.5938e+00,
           1.1719e+00,  5.5078e-01],
         [-2.8125e-01, -1.6992e-01, -5.9766e-01,  ...,  4.9414e-01,
           1.4160e-02,  4.6289e-01],
         ...,
         [ 1.2695e-01, -1.2891e+00,  1.0986e-01,  ...,  1.1250e+00,
          -2.4219e+00, -1.2793e-01],
         [ 1.2012e-01,  8.4375e-01,  1.5625e+00,  ..., -9.3750e-02,
          -8.6719e-01,  1.2812e+00],
         [-8.2812e-01, -3.7305e-01,  6.7871e-02,  ..., -4.4336e-01,
           1.5527e-01, -7.6562e-01]],

        [[-6.8283e-04, -3.7079e-03,  7.7820e-04,  ...,  2.1118e-02,
           6.7444e-03,  5.3406e-03],
         [-2.0508e-01, -2.8516e-01,  4.9609e-01,  ...,  1.5938e+00,
           1.1719e+00,  5.5078e-01],
         [-2.8125e-01, -1.6992e-01, -5.9766e-01,  ...,  4.9414e-01,
           1.4160e-02,  4.6289e-01],
         ...,
         [ 1.2695e-01, -1.2891e+00,  1.0986e-01,  ...,  1.1250e+00,
          -2.4219e+00, -1.2793e-01],
         [ 1.2012e-01,  8.4375e-01,  1.5625e+00,  ..., -9.3750e-02,
          -8.6719e-01,  1.2812e+00],
         [-8.2812e-01, -3.7305e-01,  6.7871e-02,  ..., -4.4336e-01,
           1.5527e-01, -7.6562e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 1.0437e-02, -6.3782e-03, -7.9346e-03,  ..., -1.8750e+00,
           8.4473e-02,  7.5781e-01],
         [-4.1016e-01,  6.1719e-01,  5.3906e-01,  ...,  1.1016e+00,
           1.3281e-01, -5.6562e+00],
         [ 4.3945e-01, -1.9336e-01,  2.6562e-01,  ...,  4.3438e+00,
          -3.1250e+00, -2.8594e+00],
         ...,
         [-3.6328e-01, -4.7852e-02, -5.0391e-01,  ...,  3.5781e+00,
          -9.4531e-01, -7.7344e-01],
         [ 2.4609e-01,  2.9688e-01, -1.9922e-01,  ...,  4.0000e+00,
           1.8652e-01, -2.3594e+00],
         [-2.8125e-01, -3.0078e-01, -2.6367e-01,  ...,  3.5312e+00,
           6.6406e-01, -2.9531e+00]],

        [[ 1.0437e-02, -6.3782e-03, -7.9346e-03,  ..., -1.8750e+00,
           8.4473e-02,  7.5781e-01],
         [-4.1016e-01,  6.1719e-01,  5.3906e-01,  ...,  1.1016e+00,
           1.3281e-01, -5.6562e+00],
         [ 4.3945e-01, -1.9336e-01,  2.6562e-01,  ...,  4.3438e+00,
          -3.1250e+00, -2.8594e+00],
         ...,
         [-3.6328e-01, -4.7852e-02, -5.0391e-01,  ...,  3.5781e+00,
          -9.4531e-01, -7.7344e-01],
         [ 2.4609e-01,  2.9688e-01, -1.9922e-01,  ...,  4.0000e+00,
           1.8652e-01, -2.3594e+00],
         [-2.8125e-01, -3.0078e-01, -2.6367e-01,  ...,  3.5312e+00,
           6.6406e-01, -2.9531e+00]],

        [[ 1.0437e-02, -6.3782e-03, -7.9346e-03,  ..., -1.8750e+00,
           8.4473e-02,  7.5781e-01],
         [-4.1016e-01,  6.1719e-01,  5.3906e-01,  ...,  1.1016e+00,
           1.3281e-01, -5.6562e+00],
         [ 4.3945e-01, -1.9336e-01,  2.6562e-01,  ...,  4.3438e+00,
          -3.1250e+00, -2.8594e+00],
         ...,
         [-3.6328e-01, -4.7852e-02, -5.0391e-01,  ...,  3.5781e+00,
          -9.4531e-01, -7.7344e-01],
         [ 2.4609e-01,  2.9688e-01, -1.9922e-01,  ...,  4.0000e+00,
           1.8652e-01, -2.3594e+00],
         [-2.8125e-01, -3.0078e-01, -2.6367e-01,  ...,  3.5312e+00,
           6.6406e-01, -2.9531e+00]],

        ...,

        [[-6.0730e-03, -9.9487e-03, -2.1057e-03,  ...,  1.2158e-01,
           2.1484e-01,  6.1646e-03],
         [-1.2656e+00, -9.0625e-01,  1.5625e+00,  ...,  1.6641e+00,
           8.0469e-01, -7.3438e-01],
         [-1.6211e-01,  7.6172e-02,  1.0312e+00,  ...,  5.5625e+00,
           3.5312e+00,  3.0664e-01],
         ...,
         [-8.9062e-01, -7.3438e-01, -8.0469e-01,  ...,  3.4062e+00,
           2.8750e+00,  3.3594e+00],
         [-1.5039e-01,  3.1445e-01,  5.5078e-01,  ...,  4.2500e+00,
           3.8594e+00,  7.0312e+00],
         [-3.2227e-01,  4.5312e-01,  6.8750e-01,  ...,  4.5312e+00,
           1.7969e+00,  3.2812e+00]],

        [[-6.0730e-03, -9.9487e-03, -2.1057e-03,  ...,  1.2158e-01,
           2.1484e-01,  6.1646e-03],
         [-1.2656e+00, -9.0625e-01,  1.5625e+00,  ...,  1.6641e+00,
           8.0469e-01, -7.3438e-01],
         [-1.6211e-01,  7.6172e-02,  1.0312e+00,  ...,  5.5625e+00,
           3.5312e+00,  3.0664e-01],
         ...,
         [-8.9062e-01, -7.3438e-01, -8.0469e-01,  ...,  3.4062e+00,
           2.8750e+00,  3.3594e+00],
         [-1.5039e-01,  3.1445e-01,  5.5078e-01,  ...,  4.2500e+00,
           3.8594e+00,  7.0312e+00],
         [-3.2227e-01,  4.5312e-01,  6.8750e-01,  ...,  4.5312e+00,
           1.7969e+00,  3.2812e+00]],

        [[-6.0730e-03, -9.9487e-03, -2.1057e-03,  ...,  1.2158e-01,
           2.1484e-01,  6.1646e-03],
         [-1.2656e+00, -9.0625e-01,  1.5625e+00,  ...,  1.6641e+00,
           8.0469e-01, -7.3438e-01],
         [-1.6211e-01,  7.6172e-02,  1.0312e+00,  ...,  5.5625e+00,
           3.5312e+00,  3.0664e-01],
         ...,
         [-8.9062e-01, -7.3438e-01, -8.0469e-01,  ...,  3.4062e+00,
           2.8750e+00,  3.3594e+00],
         [-1.5039e-01,  3.1445e-01,  5.5078e-01,  ...,  4.2500e+00,
           3.8594e+00,  7.0312e+00],
         [-3.2227e-01,  4.5312e-01,  6.8750e-01,  ...,  4.5312e+00,
           1.7969e+00,  3.2812e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-1.1658e-02,  3.4332e-03,  1.4099e-02,  ...,  2.7954e-02,
           8.4229e-03, -1.0010e-02],
         [-8.8672e-01,  2.0156e+00, -1.7500e+00,  ..., -9.3750e-01,
           2.1719e+00,  2.9531e+00],
         [ 1.2969e+00, -3.3750e+00,  9.1406e-01,  ...,  8.5156e-01,
           4.4727e-01,  2.2031e+00],
         ...,
         [ 7.6172e-01,  1.3906e+00, -6.0938e-01,  ..., -8.8672e-01,
           6.7578e-01,  3.3691e-02],
         [ 4.3555e-01, -5.9766e-01, -2.5391e-02,  ...,  5.8984e-01,
          -2.7148e-01,  5.9375e-01],
         [ 3.0312e+00, -8.5156e-01,  1.4844e+00,  ...,  1.1875e+00,
          -7.4609e-01, -8.9844e-02]],

        [[-1.1658e-02,  3.4332e-03,  1.4099e-02,  ...,  2.7954e-02,
           8.4229e-03, -1.0010e-02],
         [-8.8672e-01,  2.0156e+00, -1.7500e+00,  ..., -9.3750e-01,
           2.1719e+00,  2.9531e+00],
         [ 1.2969e+00, -3.3750e+00,  9.1406e-01,  ...,  8.5156e-01,
           4.4727e-01,  2.2031e+00],
         ...,
         [ 7.6172e-01,  1.3906e+00, -6.0938e-01,  ..., -8.8672e-01,
           6.7578e-01,  3.3691e-02],
         [ 4.3555e-01, -5.9766e-01, -2.5391e-02,  ...,  5.8984e-01,
          -2.7148e-01,  5.9375e-01],
         [ 3.0312e+00, -8.5156e-01,  1.4844e+00,  ...,  1.1875e+00,
          -7.4609e-01, -8.9844e-02]],

        [[-1.1658e-02,  3.4332e-03,  1.4099e-02,  ...,  2.7954e-02,
           8.4229e-03, -1.0010e-02],
         [-8.8672e-01,  2.0156e+00, -1.7500e+00,  ..., -9.3750e-01,
           2.1719e+00,  2.9531e+00],
         [ 1.2969e+00, -3.3750e+00,  9.1406e-01,  ...,  8.5156e-01,
           4.4727e-01,  2.2031e+00],
         ...,
         [ 7.6172e-01,  1.3906e+00, -6.0938e-01,  ..., -8.8672e-01,
           6.7578e-01,  3.3691e-02],
         [ 4.3555e-01, -5.9766e-01, -2.5391e-02,  ...,  5.8984e-01,
          -2.7148e-01,  5.9375e-01],
         [ 3.0312e+00, -8.5156e-01,  1.4844e+00,  ...,  1.1875e+00,
          -7.4609e-01, -8.9844e-02]],

        ...,

        [[ 2.0599e-03,  1.2146e-02, -2.2095e-02,  ...,  1.0620e-02,
          -2.1057e-03,  3.3264e-03],
         [ 2.0000e+00,  6.7578e-01,  7.5391e-01,  ...,  1.5469e+00,
           4.0527e-02, -1.3438e+00],
         [ 4.1797e-01,  1.6406e+00,  1.8672e+00,  ...,  1.5234e+00,
          -2.3281e+00, -2.5625e+00],
         ...,
         [ 4.9062e+00,  3.6250e+00,  2.4531e+00,  ...,  1.8828e+00,
          -5.5000e+00, -2.2500e+00],
         [ 5.8438e+00,  3.8438e+00,  3.3281e+00,  ...,  3.5156e+00,
          -5.9375e+00, -3.8438e+00],
         [ 3.6875e+00,  2.4062e+00,  1.6172e+00,  ...,  3.0469e+00,
          -4.8438e+00, -2.1562e+00]],

        [[ 2.0599e-03,  1.2146e-02, -2.2095e-02,  ...,  1.0620e-02,
          -2.1057e-03,  3.3264e-03],
         [ 2.0000e+00,  6.7578e-01,  7.5391e-01,  ...,  1.5469e+00,
           4.0527e-02, -1.3438e+00],
         [ 4.1797e-01,  1.6406e+00,  1.8672e+00,  ...,  1.5234e+00,
          -2.3281e+00, -2.5625e+00],
         ...,
         [ 4.9062e+00,  3.6250e+00,  2.4531e+00,  ...,  1.8828e+00,
          -5.5000e+00, -2.2500e+00],
         [ 5.8438e+00,  3.8438e+00,  3.3281e+00,  ...,  3.5156e+00,
          -5.9375e+00, -3.8438e+00],
         [ 3.6875e+00,  2.4062e+00,  1.6172e+00,  ...,  3.0469e+00,
          -4.8438e+00, -2.1562e+00]],

        [[ 2.0599e-03,  1.2146e-02, -2.2095e-02,  ...,  1.0620e-02,
          -2.1057e-03,  3.3264e-03],
         [ 2.0000e+00,  6.7578e-01,  7.5391e-01,  ...,  1.5469e+00,
           4.0527e-02, -1.3438e+00],
         [ 4.1797e-01,  1.6406e+00,  1.8672e+00,  ...,  1.5234e+00,
          -2.3281e+00, -2.5625e+00],
         ...,
         [ 4.9062e+00,  3.6250e+00,  2.4531e+00,  ...,  1.8828e+00,
          -5.5000e+00, -2.2500e+00],
         [ 5.8438e+00,  3.8438e+00,  3.3281e+00,  ...,  3.5156e+00,
          -5.9375e+00, -3.8438e+00],
         [ 3.6875e+00,  2.4062e+00,  1.6172e+00,  ...,  3.0469e+00,
          -4.8438e+00, -2.1562e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 8.5831e-04, -4.5166e-03,  2.7008e-03,  ...,  4.5166e-02,
           3.1055e-01, -8.4961e-02],
         [ 3.1250e+00,  2.2656e+00, -6.4062e-01,  ...,  6.3281e-01,
          -5.9375e-01,  1.2500e+00],
         [ 1.9375e+00,  1.9688e+00,  1.0625e+00,  ...,  5.1953e-01,
           5.8105e-02,  9.5312e-01],
         ...,
         [ 1.2969e+00, -5.9375e-01,  1.4062e+00,  ..., -2.9541e-02,
           8.9844e-01,  2.5195e-01],
         [ 2.3438e+00, -1.2969e+00, -2.4219e-01,  ...,  3.6914e-01,
          -3.2959e-02, -3.6328e-01],
         [ 9.8047e-01, -1.1484e+00, -1.8750e-01,  ...,  7.5000e-01,
          -1.0625e+00, -4.8047e-01]],

        [[ 8.5831e-04, -4.5166e-03,  2.7008e-03,  ...,  4.5166e-02,
           3.1055e-01, -8.4961e-02],
         [ 3.1250e+00,  2.2656e+00, -6.4062e-01,  ...,  6.3281e-01,
          -5.9375e-01,  1.2500e+00],
         [ 1.9375e+00,  1.9688e+00,  1.0625e+00,  ...,  5.1953e-01,
           5.8105e-02,  9.5312e-01],
         ...,
         [ 1.2969e+00, -5.9375e-01,  1.4062e+00,  ..., -2.9541e-02,
           8.9844e-01,  2.5195e-01],
         [ 2.3438e+00, -1.2969e+00, -2.4219e-01,  ...,  3.6914e-01,
          -3.2959e-02, -3.6328e-01],
         [ 9.8047e-01, -1.1484e+00, -1.8750e-01,  ...,  7.5000e-01,
          -1.0625e+00, -4.8047e-01]],

        [[ 8.5831e-04, -4.5166e-03,  2.7008e-03,  ...,  4.5166e-02,
           3.1055e-01, -8.4961e-02],
         [ 3.1250e+00,  2.2656e+00, -6.4062e-01,  ...,  6.3281e-01,
          -5.9375e-01,  1.2500e+00],
         [ 1.9375e+00,  1.9688e+00,  1.0625e+00,  ...,  5.1953e-01,
           5.8105e-02,  9.5312e-01],
         ...,
         [ 1.2969e+00, -5.9375e-01,  1.4062e+00,  ..., -2.9541e-02,
           8.9844e-01,  2.5195e-01],
         [ 2.3438e+00, -1.2969e+00, -2.4219e-01,  ...,  3.6914e-01,
          -3.2959e-02, -3.6328e-01],
         [ 9.8047e-01, -1.1484e+00, -1.8750e-01,  ...,  7.5000e-01,
          -1.0625e+00, -4.8047e-01]],

        ...,

        [[ 3.2959e-03, -2.9755e-03, -6.4392e-03,  ...,  2.6562e-01,
           4.3945e-01,  1.4551e-01],
         [-3.2031e+00,  1.4766e+00, -1.5938e+00,  ..., -3.1641e-01,
          -1.6484e+00, -4.7852e-01],
         [ 2.7344e-02,  1.8262e-01, -1.2031e+00,  ..., -1.3359e+00,
          -2.0781e+00,  1.5938e+00],
         ...,
         [-2.4062e+00, -1.2812e+00, -9.6484e-01,  ..., -9.5703e-01,
          -1.3906e+00,  5.3125e-01],
         [-2.4219e-01, -1.3047e+00, -9.2578e-01,  ...,  8.5938e-01,
          -2.3438e+00,  1.2891e+00],
         [ 2.6719e+00, -1.6016e+00, -2.4219e+00,  ..., -1.0156e+00,
          -2.2188e+00, -1.3867e-01]],

        [[ 3.2959e-03, -2.9755e-03, -6.4392e-03,  ...,  2.6562e-01,
           4.3945e-01,  1.4551e-01],
         [-3.2031e+00,  1.4766e+00, -1.5938e+00,  ..., -3.1641e-01,
          -1.6484e+00, -4.7852e-01],
         [ 2.7344e-02,  1.8262e-01, -1.2031e+00,  ..., -1.3359e+00,
          -2.0781e+00,  1.5938e+00],
         ...,
         [-2.4062e+00, -1.2812e+00, -9.6484e-01,  ..., -9.5703e-01,
          -1.3906e+00,  5.3125e-01],
         [-2.4219e-01, -1.3047e+00, -9.2578e-01,  ...,  8.5938e-01,
          -2.3438e+00,  1.2891e+00],
         [ 2.6719e+00, -1.6016e+00, -2.4219e+00,  ..., -1.0156e+00,
          -2.2188e+00, -1.3867e-01]],

        [[ 3.2959e-03, -2.9755e-03, -6.4392e-03,  ...,  2.6562e-01,
           4.3945e-01,  1.4551e-01],
         [-3.2031e+00,  1.4766e+00, -1.5938e+00,  ..., -3.1641e-01,
          -1.6484e+00, -4.7852e-01],
         [ 2.7344e-02,  1.8262e-01, -1.2031e+00,  ..., -1.3359e+00,
          -2.0781e+00,  1.5938e+00],
         ...,
         [-2.4062e+00, -1.2812e+00, -9.6484e-01,  ..., -9.5703e-01,
          -1.3906e+00,  5.3125e-01],
         [-2.4219e-01, -1.3047e+00, -9.2578e-01,  ...,  8.5938e-01,
          -2.3438e+00,  1.2891e+00],
         [ 2.6719e+00, -1.6016e+00, -2.4219e+00,  ..., -1.0156e+00,
          -2.2188e+00, -1.3867e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-2.0264e-02,  7.4005e-04, -1.5869e-02,  ..., -1.4893e-02,
           4.4250e-03, -2.7771e-03],
         [ 5.3125e-01,  7.6172e-01,  1.8203e+00,  ..., -3.2812e-01,
          -6.6016e-01, -5.2734e-01],
         [ 1.6328e+00,  1.3438e+00,  4.5898e-01,  ...,  7.7344e-01,
          -7.0703e-01,  9.9219e-01],
         ...,
         [ 1.8457e-01, -9.2969e-01,  1.8594e+00,  ...,  2.0996e-01,
           9.4922e-01,  7.8516e-01],
         [-3.9844e-01, -9.1406e-01, -7.1094e-01,  ..., -5.8984e-01,
           7.3438e-01,  8.3496e-02],
         [ 4.0283e-03,  5.5469e-01, -7.4219e-01,  ..., -5.6250e-01,
           1.1406e+00, -1.2578e+00]],

        [[-2.0264e-02,  7.4005e-04, -1.5869e-02,  ..., -1.4893e-02,
           4.4250e-03, -2.7771e-03],
         [ 5.3125e-01,  7.6172e-01,  1.8203e+00,  ..., -3.2812e-01,
          -6.6016e-01, -5.2734e-01],
         [ 1.6328e+00,  1.3438e+00,  4.5898e-01,  ...,  7.7344e-01,
          -7.0703e-01,  9.9219e-01],
         ...,
         [ 1.8457e-01, -9.2969e-01,  1.8594e+00,  ...,  2.0996e-01,
           9.4922e-01,  7.8516e-01],
         [-3.9844e-01, -9.1406e-01, -7.1094e-01,  ..., -5.8984e-01,
           7.3438e-01,  8.3496e-02],
         [ 4.0283e-03,  5.5469e-01, -7.4219e-01,  ..., -5.6250e-01,
           1.1406e+00, -1.2578e+00]],

        [[-2.0264e-02,  7.4005e-04, -1.5869e-02,  ..., -1.4893e-02,
           4.4250e-03, -2.7771e-03],
         [ 5.3125e-01,  7.6172e-01,  1.8203e+00,  ..., -3.2812e-01,
          -6.6016e-01, -5.2734e-01],
         [ 1.6328e+00,  1.3438e+00,  4.5898e-01,  ...,  7.7344e-01,
          -7.0703e-01,  9.9219e-01],
         ...,
         [ 1.8457e-01, -9.2969e-01,  1.8594e+00,  ...,  2.0996e-01,
           9.4922e-01,  7.8516e-01],
         [-3.9844e-01, -9.1406e-01, -7.1094e-01,  ..., -5.8984e-01,
           7.3438e-01,  8.3496e-02],
         [ 4.0283e-03,  5.5469e-01, -7.4219e-01,  ..., -5.6250e-01,
           1.1406e+00, -1.2578e+00]],

        ...,

        [[ 6.8665e-04, -4.1809e-03, -1.6968e-02,  ...,  1.6235e-02,
           5.7068e-03,  1.3062e-02],
         [ 2.1191e-01, -4.6289e-01,  7.0703e-01,  ..., -3.9648e-01,
          -3.3984e-01,  4.9805e-01],
         [-1.2793e-01, -1.2500e-01,  1.2734e+00,  ...,  2.5391e-01,
           2.8906e-01, -1.5469e+00],
         ...,
         [-1.0234e+00,  1.1250e+00,  1.3047e+00,  ...,  9.9609e-01,
          -1.0938e+00, -8.7891e-01],
         [ 5.3516e-01, -2.3047e-01,  1.5312e+00,  ..., -4.1875e+00,
           3.9062e-01, -1.0625e+00],
         [ 6.4844e-01,  6.7969e-01,  5.7812e-01,  ..., -3.5156e-01,
          -7.2266e-01, -8.9062e-01]],

        [[ 6.8665e-04, -4.1809e-03, -1.6968e-02,  ...,  1.6235e-02,
           5.7068e-03,  1.3062e-02],
         [ 2.1191e-01, -4.6289e-01,  7.0703e-01,  ..., -3.9648e-01,
          -3.3984e-01,  4.9805e-01],
         [-1.2793e-01, -1.2500e-01,  1.2734e+00,  ...,  2.5391e-01,
           2.8906e-01, -1.5469e+00],
         ...,
         [-1.0234e+00,  1.1250e+00,  1.3047e+00,  ...,  9.9609e-01,
          -1.0938e+00, -8.7891e-01],
         [ 5.3516e-01, -2.3047e-01,  1.5312e+00,  ..., -4.1875e+00,
           3.9062e-01, -1.0625e+00],
         [ 6.4844e-01,  6.7969e-01,  5.7812e-01,  ..., -3.5156e-01,
          -7.2266e-01, -8.9062e-01]],

        [[ 6.8665e-04, -4.1809e-03, -1.6968e-02,  ...,  1.6235e-02,
           5.7068e-03,  1.3062e-02],
         [ 2.1191e-01, -4.6289e-01,  7.0703e-01,  ..., -3.9648e-01,
          -3.3984e-01,  4.9805e-01],
         [-1.2793e-01, -1.2500e-01,  1.2734e+00,  ...,  2.5391e-01,
           2.8906e-01, -1.5469e+00],
         ...,
         [-1.0234e+00,  1.1250e+00,  1.3047e+00,  ...,  9.9609e-01,
          -1.0938e+00, -8.7891e-01],
         [ 5.3516e-01, -2.3047e-01,  1.5312e+00,  ..., -4.1875e+00,
           3.9062e-01, -1.0625e+00],
         [ 6.4844e-01,  6.7969e-01,  5.7812e-01,  ..., -3.5156e-01,
          -7.2266e-01, -8.9062e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[ 6.1340e-03, -1.6724e-02,  5.8365e-04,  ...,  3.6328e-01,
          -4.5898e-02, -1.3086e-01],
         [ 1.1641e+00,  6.7969e-01, -1.1875e+00,  ..., -9.2188e-01,
          -5.7373e-02,  9.6094e-01],
         [ 1.6875e+00,  1.3438e+00, -6.3281e-01,  ..., -1.5078e+00,
          -4.3164e-01,  1.9375e+00],
         ...,
         [-6.6406e-02, -8.2422e-01,  1.5000e+00,  ..., -3.3203e-01,
          -1.2812e+00,  1.2031e+00],
         [ 7.6562e-01, -9.0625e-01, -5.5859e-01,  ..., -8.7109e-01,
           1.0352e-01,  1.8984e+00],
         [ 2.4062e+00, -6.9141e-01,  2.1484e-01,  ..., -1.1250e+00,
          -8.1641e-01,  1.0312e+00]],

        [[ 6.1340e-03, -1.6724e-02,  5.8365e-04,  ...,  3.6328e-01,
          -4.5898e-02, -1.3086e-01],
         [ 1.1641e+00,  6.7969e-01, -1.1875e+00,  ..., -9.2188e-01,
          -5.7373e-02,  9.6094e-01],
         [ 1.6875e+00,  1.3438e+00, -6.3281e-01,  ..., -1.5078e+00,
          -4.3164e-01,  1.9375e+00],
         ...,
         [-6.6406e-02, -8.2422e-01,  1.5000e+00,  ..., -3.3203e-01,
          -1.2812e+00,  1.2031e+00],
         [ 7.6562e-01, -9.0625e-01, -5.5859e-01,  ..., -8.7109e-01,
           1.0352e-01,  1.8984e+00],
         [ 2.4062e+00, -6.9141e-01,  2.1484e-01,  ..., -1.1250e+00,
          -8.1641e-01,  1.0312e+00]],

        [[ 6.1340e-03, -1.6724e-02,  5.8365e-04,  ...,  3.6328e-01,
          -4.5898e-02, -1.3086e-01],
         [ 1.1641e+00,  6.7969e-01, -1.1875e+00,  ..., -9.2188e-01,
          -5.7373e-02,  9.6094e-01],
         [ 1.6875e+00,  1.3438e+00, -6.3281e-01,  ..., -1.5078e+00,
          -4.3164e-01,  1.9375e+00],
         ...,
         [-6.6406e-02, -8.2422e-01,  1.5000e+00,  ..., -3.3203e-01,
          -1.2812e+00,  1.2031e+00],
         [ 7.6562e-01, -9.0625e-01, -5.5859e-01,  ..., -8.7109e-01,
           1.0352e-01,  1.8984e+00],
         [ 2.4062e+00, -6.9141e-01,  2.1484e-01,  ..., -1.1250e+00,
          -8.1641e-01,  1.0312e+00]],

        ...,

        [[ 1.5076e-02,  1.1780e-02,  4.5471e-03,  ..., -1.2012e-01,
          -1.2024e-02, -1.5391e+00],
         [ 3.3125e+00, -1.0156e-01, -9.9219e-01,  ..., -3.9062e-01,
           5.7812e-01,  6.6250e+00],
         [ 2.0781e+00,  1.8438e+00, -9.1406e-01,  ..., -1.8359e+00,
           4.6875e-02,  8.0625e+00],
         ...,
         [ 2.0781e+00,  1.6406e+00,  1.6875e+00,  ...,  6.8750e-01,
          -1.1953e+00,  7.0312e+00],
         [ 2.3438e+00,  1.4844e+00,  1.4062e+00,  ...,  1.2578e+00,
          -3.7109e-01,  7.4688e+00],
         [ 1.1250e+00, -1.1406e+00,  7.2656e-01,  ...,  7.4219e-01,
          -3.7695e-01,  7.5625e+00]],

        [[ 1.5076e-02,  1.1780e-02,  4.5471e-03,  ..., -1.2012e-01,
          -1.2024e-02, -1.5391e+00],
         [ 3.3125e+00, -1.0156e-01, -9.9219e-01,  ..., -3.9062e-01,
           5.7812e-01,  6.6250e+00],
         [ 2.0781e+00,  1.8438e+00, -9.1406e-01,  ..., -1.8359e+00,
           4.6875e-02,  8.0625e+00],
         ...,
         [ 2.0781e+00,  1.6406e+00,  1.6875e+00,  ...,  6.8750e-01,
          -1.1953e+00,  7.0312e+00],
         [ 2.3438e+00,  1.4844e+00,  1.4062e+00,  ...,  1.2578e+00,
          -3.7109e-01,  7.4688e+00],
         [ 1.1250e+00, -1.1406e+00,  7.2656e-01,  ...,  7.4219e-01,
          -3.7695e-01,  7.5625e+00]],

        [[ 1.5076e-02,  1.1780e-02,  4.5471e-03,  ..., -1.2012e-01,
          -1.2024e-02, -1.5391e+00],
         [ 3.3125e+00, -1.0156e-01, -9.9219e-01,  ..., -3.9062e-01,
           5.7812e-01,  6.6250e+00],
         [ 2.0781e+00,  1.8438e+00, -9.1406e-01,  ..., -1.8359e+00,
           4.6875e-02,  8.0625e+00],
         ...,
         [ 2.0781e+00,  1.6406e+00,  1.6875e+00,  ...,  6.8750e-01,
          -1.1953e+00,  7.0312e+00],
         [ 2.3438e+00,  1.4844e+00,  1.4062e+00,  ...,  1.2578e+00,
          -3.7109e-01,  7.4688e+00],
         [ 1.1250e+00, -1.1406e+00,  7.2656e-01,  ...,  7.4219e-01,
          -3.7695e-01,  7.5625e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 4.6387e-03,  8.9722e-03, -1.2268e-02,  ...,  2.8687e-02,
           2.5024e-03,  4.5776e-03],
         [-2.1094e-01,  3.4180e-02,  6.0938e-01,  ...,  7.0312e-01,
           3.2031e-01,  3.0078e-01],
         [ 7.0312e-01,  9.6094e-01, -2.3730e-01,  ..., -2.8125e-01,
           9.0625e-01, -6.3281e-01],
         ...,
         [-1.3750e+00,  1.5781e+00, -6.0303e-02,  ...,  2.2188e+00,
           5.4297e-01,  4.2236e-02],
         [ 5.4688e-01, -1.0938e+00,  1.6719e+00,  ..., -1.5469e+00,
          -2.9883e-01, -6.0547e-01],
         [ 2.5195e-01, -1.4219e+00,  7.6172e-01,  ..., -2.3906e+00,
           9.2578e-01,  7.5000e-01]],

        [[ 4.6387e-03,  8.9722e-03, -1.2268e-02,  ...,  2.8687e-02,
           2.5024e-03,  4.5776e-03],
         [-2.1094e-01,  3.4180e-02,  6.0938e-01,  ...,  7.0312e-01,
           3.2031e-01,  3.0078e-01],
         [ 7.0312e-01,  9.6094e-01, -2.3730e-01,  ..., -2.8125e-01,
           9.0625e-01, -6.3281e-01],
         ...,
         [-1.3750e+00,  1.5781e+00, -6.0303e-02,  ...,  2.2188e+00,
           5.4297e-01,  4.2236e-02],
         [ 5.4688e-01, -1.0938e+00,  1.6719e+00,  ..., -1.5469e+00,
          -2.9883e-01, -6.0547e-01],
         [ 2.5195e-01, -1.4219e+00,  7.6172e-01,  ..., -2.3906e+00,
           9.2578e-01,  7.5000e-01]],

        [[ 4.6387e-03,  8.9722e-03, -1.2268e-02,  ...,  2.8687e-02,
           2.5024e-03,  4.5776e-03],
         [-2.1094e-01,  3.4180e-02,  6.0938e-01,  ...,  7.0312e-01,
           3.2031e-01,  3.0078e-01],
         [ 7.0312e-01,  9.6094e-01, -2.3730e-01,  ..., -2.8125e-01,
           9.0625e-01, -6.3281e-01],
         ...,
         [-1.3750e+00,  1.5781e+00, -6.0303e-02,  ...,  2.2188e+00,
           5.4297e-01,  4.2236e-02],
         [ 5.4688e-01, -1.0938e+00,  1.6719e+00,  ..., -1.5469e+00,
          -2.9883e-01, -6.0547e-01],
         [ 2.5195e-01, -1.4219e+00,  7.6172e-01,  ..., -2.3906e+00,
           9.2578e-01,  7.5000e-01]],

        ...,

        [[-7.5684e-03,  1.5869e-02, -2.8198e-02,  ...,  1.9531e-02,
           6.4697e-03, -8.6975e-04],
         [ 1.1914e-01,  4.0283e-02,  1.3438e+00,  ..., -8.8672e-01,
           1.7031e+00, -2.0996e-01],
         [ 5.4297e-01, -9.4922e-01,  1.1484e+00,  ..., -2.2656e+00,
          -1.2266e+00, -2.0781e+00],
         ...,
         [ 2.4375e+00,  8.4375e-01,  4.7656e-01,  ..., -1.0703e+00,
           4.1602e-01, -4.1250e+00],
         [ 1.1562e+00, -7.8516e-01, -1.8359e-01,  ..., -1.0234e+00,
          -1.5391e+00, -1.7891e+00],
         [ 2.4844e+00, -3.0469e+00, -7.0801e-02,  ..., -2.3438e+00,
           7.8125e-01, -1.6250e+00]],

        [[-7.5684e-03,  1.5869e-02, -2.8198e-02,  ...,  1.9531e-02,
           6.4697e-03, -8.6975e-04],
         [ 1.1914e-01,  4.0283e-02,  1.3438e+00,  ..., -8.8672e-01,
           1.7031e+00, -2.0996e-01],
         [ 5.4297e-01, -9.4922e-01,  1.1484e+00,  ..., -2.2656e+00,
          -1.2266e+00, -2.0781e+00],
         ...,
         [ 2.4375e+00,  8.4375e-01,  4.7656e-01,  ..., -1.0703e+00,
           4.1602e-01, -4.1250e+00],
         [ 1.1562e+00, -7.8516e-01, -1.8359e-01,  ..., -1.0234e+00,
          -1.5391e+00, -1.7891e+00],
         [ 2.4844e+00, -3.0469e+00, -7.0801e-02,  ..., -2.3438e+00,
           7.8125e-01, -1.6250e+00]],

        [[-7.5684e-03,  1.5869e-02, -2.8198e-02,  ...,  1.9531e-02,
           6.4697e-03, -8.6975e-04],
         [ 1.1914e-01,  4.0283e-02,  1.3438e+00,  ..., -8.8672e-01,
           1.7031e+00, -2.0996e-01],
         [ 5.4297e-01, -9.4922e-01,  1.1484e+00,  ..., -2.2656e+00,
          -1.2266e+00, -2.0781e+00],
         ...,
         [ 2.4375e+00,  8.4375e-01,  4.7656e-01,  ..., -1.0703e+00,
           4.1602e-01, -4.1250e+00],
         [ 1.1562e+00, -7.8516e-01, -1.8359e-01,  ..., -1.0234e+00,
          -1.5391e+00, -1.7891e+00],
         [ 2.4844e+00, -3.0469e+00, -7.0801e-02,  ..., -2.3438e+00,
           7.8125e-01, -1.6250e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-1.7090e-02,  5.7983e-03,  1.4267e-03,  ..., -4.4189e-02,
          -2.8076e-03, -2.6001e-02],
         [ 2.1250e+00, -6.4844e-01, -5.2734e-01,  ...,  1.3359e+00,
          -8.1641e-01,  7.1484e-01],
         [-1.7383e-01,  3.6328e-01,  2.1406e+00,  ...,  1.6953e+00,
          -2.8750e+00, -4.3555e-01],
         ...,
         [ 5.0781e-01, -1.5312e+00,  1.0781e+00,  ..., -3.1562e+00,
           5.9375e-01,  4.2969e-01],
         [-4.1211e-01, -4.0234e-01,  1.1562e+00,  ..., -2.0156e+00,
          -1.7188e+00,  1.5391e+00],
         [-1.6172e+00, -1.4160e-01,  1.8750e+00,  ..., -1.6016e+00,
          -6.8750e-01, -1.5723e-01]],

        [[-1.7090e-02,  5.7983e-03,  1.4267e-03,  ..., -4.4189e-02,
          -2.8076e-03, -2.6001e-02],
         [ 2.1250e+00, -6.4844e-01, -5.2734e-01,  ...,  1.3359e+00,
          -8.1641e-01,  7.1484e-01],
         [-1.7383e-01,  3.6328e-01,  2.1406e+00,  ...,  1.6953e+00,
          -2.8750e+00, -4.3555e-01],
         ...,
         [ 5.0781e-01, -1.5312e+00,  1.0781e+00,  ..., -3.1562e+00,
           5.9375e-01,  4.2969e-01],
         [-4.1211e-01, -4.0234e-01,  1.1562e+00,  ..., -2.0156e+00,
          -1.7188e+00,  1.5391e+00],
         [-1.6172e+00, -1.4160e-01,  1.8750e+00,  ..., -1.6016e+00,
          -6.8750e-01, -1.5723e-01]],

        [[-1.7090e-02,  5.7983e-03,  1.4267e-03,  ..., -4.4189e-02,
          -2.8076e-03, -2.6001e-02],
         [ 2.1250e+00, -6.4844e-01, -5.2734e-01,  ...,  1.3359e+00,
          -8.1641e-01,  7.1484e-01],
         [-1.7383e-01,  3.6328e-01,  2.1406e+00,  ...,  1.6953e+00,
          -2.8750e+00, -4.3555e-01],
         ...,
         [ 5.0781e-01, -1.5312e+00,  1.0781e+00,  ..., -3.1562e+00,
           5.9375e-01,  4.2969e-01],
         [-4.1211e-01, -4.0234e-01,  1.1562e+00,  ..., -2.0156e+00,
          -1.7188e+00,  1.5391e+00],
         [-1.6172e+00, -1.4160e-01,  1.8750e+00,  ..., -1.6016e+00,
          -6.8750e-01, -1.5723e-01]],

        ...,

        [[-1.4221e-02,  2.1240e-02, -1.3855e-02,  ...,  1.4258e-01,
           8.3984e-02, -1.8906e+00],
         [ 2.9297e-02, -2.5391e-01, -9.0625e-01,  ...,  4.1016e-01,
          -1.7344e+00,  5.9375e+00],
         [ 4.1406e-01,  3.7842e-02,  1.7578e-01,  ..., -5.4297e-01,
          -2.0469e+00,  6.5312e+00],
         ...,
         [ 1.0156e+00,  8.3594e-01, -3.8477e-01,  ..., -8.2422e-01,
           1.6016e+00,  5.2188e+00],
         [ 1.0234e+00,  4.8633e-01,  1.6602e-01,  ..., -2.6562e+00,
          -6.7188e-01,  5.3438e+00],
         [ 1.0391e+00,  2.2266e-01, -4.4922e-01,  ..., -2.5156e+00,
          -1.1484e+00,  4.9688e+00]],

        [[-1.4221e-02,  2.1240e-02, -1.3855e-02,  ...,  1.4258e-01,
           8.3984e-02, -1.8906e+00],
         [ 2.9297e-02, -2.5391e-01, -9.0625e-01,  ...,  4.1016e-01,
          -1.7344e+00,  5.9375e+00],
         [ 4.1406e-01,  3.7842e-02,  1.7578e-01,  ..., -5.4297e-01,
          -2.0469e+00,  6.5312e+00],
         ...,
         [ 1.0156e+00,  8.3594e-01, -3.8477e-01,  ..., -8.2422e-01,
           1.6016e+00,  5.2188e+00],
         [ 1.0234e+00,  4.8633e-01,  1.6602e-01,  ..., -2.6562e+00,
          -6.7188e-01,  5.3438e+00],
         [ 1.0391e+00,  2.2266e-01, -4.4922e-01,  ..., -2.5156e+00,
          -1.1484e+00,  4.9688e+00]],

        [[-1.4221e-02,  2.1240e-02, -1.3855e-02,  ...,  1.4258e-01,
           8.3984e-02, -1.8906e+00],
         [ 2.9297e-02, -2.5391e-01, -9.0625e-01,  ...,  4.1016e-01,
          -1.7344e+00,  5.9375e+00],
         [ 4.1406e-01,  3.7842e-02,  1.7578e-01,  ..., -5.4297e-01,
          -2.0469e+00,  6.5312e+00],
         ...,
         [ 1.0156e+00,  8.3594e-01, -3.8477e-01,  ..., -8.2422e-01,
           1.6016e+00,  5.2188e+00],
         [ 1.0234e+00,  4.8633e-01,  1.6602e-01,  ..., -2.6562e+00,
          -6.7188e-01,  5.3438e+00],
         [ 1.0391e+00,  2.2266e-01, -4.4922e-01,  ..., -2.5156e+00,
          -1.1484e+00,  4.9688e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-4.0039e-02, -1.9409e-02, -2.3804e-02,  ...,  3.9795e-02,
           6.1035e-02, -2.9785e-02],
         [-1.0938e+00,  4.3750e-01, -2.5781e-01,  ..., -3.9258e-01,
           7.9590e-02,  1.3828e+00],
         [-7.3828e-01, -2.5391e-01,  6.8359e-01,  ..., -4.2578e-01,
           6.7969e-01,  4.2578e-01],
         ...,
         [-7.7344e-01,  1.6172e+00, -2.0469e+00,  ...,  4.6875e-01,
           1.2500e+00,  7.4219e-01],
         [ 3.7109e-02, -2.1875e-01, -3.4766e-01,  ..., -7.3828e-01,
          -6.5625e-01, -2.0625e+00],
         [ 7.8906e-01, -8.4375e-01, -3.8281e+00,  ...,  9.5312e-01,
           6.2891e-01, -2.9375e+00]],

        [[-4.0039e-02, -1.9409e-02, -2.3804e-02,  ...,  3.9795e-02,
           6.1035e-02, -2.9785e-02],
         [-1.0938e+00,  4.3750e-01, -2.5781e-01,  ..., -3.9258e-01,
           7.9590e-02,  1.3828e+00],
         [-7.3828e-01, -2.5391e-01,  6.8359e-01,  ..., -4.2578e-01,
           6.7969e-01,  4.2578e-01],
         ...,
         [-7.7344e-01,  1.6172e+00, -2.0469e+00,  ...,  4.6875e-01,
           1.2500e+00,  7.4219e-01],
         [ 3.7109e-02, -2.1875e-01, -3.4766e-01,  ..., -7.3828e-01,
          -6.5625e-01, -2.0625e+00],
         [ 7.8906e-01, -8.4375e-01, -3.8281e+00,  ...,  9.5312e-01,
           6.2891e-01, -2.9375e+00]],

        [[-4.0039e-02, -1.9409e-02, -2.3804e-02,  ...,  3.9795e-02,
           6.1035e-02, -2.9785e-02],
         [-1.0938e+00,  4.3750e-01, -2.5781e-01,  ..., -3.9258e-01,
           7.9590e-02,  1.3828e+00],
         [-7.3828e-01, -2.5391e-01,  6.8359e-01,  ..., -4.2578e-01,
           6.7969e-01,  4.2578e-01],
         ...,
         [-7.7344e-01,  1.6172e+00, -2.0469e+00,  ...,  4.6875e-01,
           1.2500e+00,  7.4219e-01],
         [ 3.7109e-02, -2.1875e-01, -3.4766e-01,  ..., -7.3828e-01,
          -6.5625e-01, -2.0625e+00],
         [ 7.8906e-01, -8.4375e-01, -3.8281e+00,  ...,  9.5312e-01,
           6.2891e-01, -2.9375e+00]],

        ...,

        [[ 6.0730e-03, -1.4343e-02, -1.2146e-02,  ...,  1.8692e-03,
          -4.9744e-03,  3.8757e-03],
         [-4.6875e-01,  2.4414e-01, -9.5312e-01,  ...,  5.5859e-01,
          -7.8125e-01,  4.4727e-01],
         [-8.5547e-01, -5.8984e-01,  4.5898e-01,  ..., -2.5625e+00,
          -5.1270e-02,  3.3008e-01],
         ...,
         [-3.9453e-01, -3.5352e-01,  7.9688e-01,  ..., -4.6680e-01,
          -1.5869e-02,  6.4844e-01],
         [ 2.4902e-02, -4.5898e-01, -9.2969e-01,  ..., -1.0078e+00,
          -5.7678e-03,  3.3594e-01],
         [-1.9062e+00, -7.9688e-01,  1.7822e-02,  ..., -1.8438e+00,
           6.4453e-02, -4.3750e-01]],

        [[ 6.0730e-03, -1.4343e-02, -1.2146e-02,  ...,  1.8692e-03,
          -4.9744e-03,  3.8757e-03],
         [-4.6875e-01,  2.4414e-01, -9.5312e-01,  ...,  5.5859e-01,
          -7.8125e-01,  4.4727e-01],
         [-8.5547e-01, -5.8984e-01,  4.5898e-01,  ..., -2.5625e+00,
          -5.1270e-02,  3.3008e-01],
         ...,
         [-3.9453e-01, -3.5352e-01,  7.9688e-01,  ..., -4.6680e-01,
          -1.5869e-02,  6.4844e-01],
         [ 2.4902e-02, -4.5898e-01, -9.2969e-01,  ..., -1.0078e+00,
          -5.7678e-03,  3.3594e-01],
         [-1.9062e+00, -7.9688e-01,  1.7822e-02,  ..., -1.8438e+00,
           6.4453e-02, -4.3750e-01]],

        [[ 6.0730e-03, -1.4343e-02, -1.2146e-02,  ...,  1.8692e-03,
          -4.9744e-03,  3.8757e-03],
         [-4.6875e-01,  2.4414e-01, -9.5312e-01,  ...,  5.5859e-01,
          -7.8125e-01,  4.4727e-01],
         [-8.5547e-01, -5.8984e-01,  4.5898e-01,  ..., -2.5625e+00,
          -5.1270e-02,  3.3008e-01],
         ...,
         [-3.9453e-01, -3.5352e-01,  7.9688e-01,  ..., -4.6680e-01,
          -1.5869e-02,  6.4844e-01],
         [ 2.4902e-02, -4.5898e-01, -9.2969e-01,  ..., -1.0078e+00,
          -5.7678e-03,  3.3594e-01],
         [-1.9062e+00, -7.9688e-01,  1.7822e-02,  ..., -1.8438e+00,
           6.4453e-02, -4.3750e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-0.0069,  0.0071, -0.0229,  ..., -0.2656, -0.3203,  0.5664],
         [ 1.9844, -0.9141,  1.1562,  ..., -2.0938, -0.4727,  0.7070],
         [ 1.4766, -0.3887,  0.7656,  ..., -2.5625, -0.1885, -0.8750],
         ...,
         [ 0.9492,  0.5312, -1.0234,  ..., -1.1406,  0.0277, -2.6562],
         [ 0.4082, -0.0649, -0.3672,  ..., -2.3750,  1.6641, -3.0000],
         [ 0.7344,  0.7656, -0.0039,  ..., -2.3750, -1.9297, -2.0469]],

        [[-0.0069,  0.0071, -0.0229,  ..., -0.2656, -0.3203,  0.5664],
         [ 1.9844, -0.9141,  1.1562,  ..., -2.0938, -0.4727,  0.7070],
         [ 1.4766, -0.3887,  0.7656,  ..., -2.5625, -0.1885, -0.8750],
         ...,
         [ 0.9492,  0.5312, -1.0234,  ..., -1.1406,  0.0277, -2.6562],
         [ 0.4082, -0.0649, -0.3672,  ..., -2.3750,  1.6641, -3.0000],
         [ 0.7344,  0.7656, -0.0039,  ..., -2.3750, -1.9297, -2.0469]],

        [[-0.0069,  0.0071, -0.0229,  ..., -0.2656, -0.3203,  0.5664],
         [ 1.9844, -0.9141,  1.1562,  ..., -2.0938, -0.4727,  0.7070],
         [ 1.4766, -0.3887,  0.7656,  ..., -2.5625, -0.1885, -0.8750],
         ...,
         [ 0.9492,  0.5312, -1.0234,  ..., -1.1406,  0.0277, -2.6562],
         [ 0.4082, -0.0649, -0.3672,  ..., -2.3750,  1.6641, -3.0000],
         [ 0.7344,  0.7656, -0.0039,  ..., -2.3750, -1.9297, -2.0469]],

        ...,

        [[-0.0102, -0.0106,  0.0152,  ..., -0.0306,  0.2354,  0.2734],
         [ 0.5859,  0.5820,  1.4844,  ...,  0.5117, -0.1738, -0.3418],
         [-0.3145,  2.4375,  0.7930,  ..., -0.0059,  0.1729, -0.4980],
         ...,
         [ 0.3984,  0.6250, -1.8047,  ..., -0.6641,  0.0165, -0.4453],
         [-0.3789,  0.5547,  0.1738,  ...,  1.2578, -1.5703,  1.8906],
         [-2.2188, -0.7695,  0.3887,  ..., -0.2637, -0.2061,  1.6406]],

        [[-0.0102, -0.0106,  0.0152,  ..., -0.0306,  0.2354,  0.2734],
         [ 0.5859,  0.5820,  1.4844,  ...,  0.5117, -0.1738, -0.3418],
         [-0.3145,  2.4375,  0.7930,  ..., -0.0059,  0.1729, -0.4980],
         ...,
         [ 0.3984,  0.6250, -1.8047,  ..., -0.6641,  0.0165, -0.4453],
         [-0.3789,  0.5547,  0.1738,  ...,  1.2578, -1.5703,  1.8906],
         [-2.2188, -0.7695,  0.3887,  ..., -0.2637, -0.2061,  1.6406]],

        [[-0.0102, -0.0106,  0.0152,  ..., -0.0306,  0.2354,  0.2734],
         [ 0.5859,  0.5820,  1.4844,  ...,  0.5117, -0.1738, -0.3418],
         [-0.3145,  2.4375,  0.7930,  ..., -0.0059,  0.1729, -0.4980],
         ...,
         [ 0.3984,  0.6250, -1.8047,  ..., -0.6641,  0.0165, -0.4453],
         [-0.3789,  0.5547,  0.1738,  ...,  1.2578, -1.5703,  1.8906],
         [-2.2188, -0.7695,  0.3887,  ..., -0.2637, -0.2061,  1.6406]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>), tensor([[[ 0.0066,  0.0254,  0.0081,  ...,  0.0171,  0.0244, -0.0117],
         [ 0.2949, -0.6914,  0.5703,  ...,  0.0679,  0.4258, -0.1797],
         [-0.5859, -1.0625, -0.5781,  ...,  0.0036,  0.9961, -1.3984],
         ...,
         [-0.2295,  0.6641, -0.1377,  ..., -1.1875,  0.7969,  0.6602],
         [-0.0615,  0.4805,  1.2031,  ..., -1.3047,  0.3047,  0.3750],
         [-0.4727,  0.3672,  1.0938,  ...,  0.9219, -0.0688, -0.8086]],

        [[ 0.0066,  0.0254,  0.0081,  ...,  0.0171,  0.0244, -0.0117],
         [ 0.2949, -0.6914,  0.5703,  ...,  0.0679,  0.4258, -0.1797],
         [-0.5859, -1.0625, -0.5781,  ...,  0.0036,  0.9961, -1.3984],
         ...,
         [-0.2295,  0.6641, -0.1377,  ..., -1.1875,  0.7969,  0.6602],
         [-0.0615,  0.4805,  1.2031,  ..., -1.3047,  0.3047,  0.3750],
         [-0.4727,  0.3672,  1.0938,  ...,  0.9219, -0.0688, -0.8086]],

        [[ 0.0066,  0.0254,  0.0081,  ...,  0.0171,  0.0244, -0.0117],
         [ 0.2949, -0.6914,  0.5703,  ...,  0.0679,  0.4258, -0.1797],
         [-0.5859, -1.0625, -0.5781,  ...,  0.0036,  0.9961, -1.3984],
         ...,
         [-0.2295,  0.6641, -0.1377,  ..., -1.1875,  0.7969,  0.6602],
         [-0.0615,  0.4805,  1.2031,  ..., -1.3047,  0.3047,  0.3750],
         [-0.4727,  0.3672,  1.0938,  ...,  0.9219, -0.0688, -0.8086]],

        ...,

        [[-0.0140, -0.0067,  0.0165,  ..., -0.0031, -0.0269,  0.0166],
         [ 0.7578,  0.6719, -0.7812,  ..., -0.4512, -1.5156, -0.4316],
         [ 1.1406,  1.6641, -0.1426,  ..., -0.7422,  0.7422, -2.5156],
         ...,
         [-1.5625, -1.3594,  0.0552,  ..., -0.0302, -0.5117,  0.4004],
         [-1.6406, -0.3691, -0.2852,  ...,  0.9336, -0.0918, -1.4609],
         [-1.5547, -0.0308, -1.0156,  ..., -0.6406, -0.0928,  0.2949]],

        [[-0.0140, -0.0067,  0.0165,  ..., -0.0031, -0.0269,  0.0166],
         [ 0.7578,  0.6719, -0.7812,  ..., -0.4512, -1.5156, -0.4316],
         [ 1.1406,  1.6641, -0.1426,  ..., -0.7422,  0.7422, -2.5156],
         ...,
         [-1.5625, -1.3594,  0.0552,  ..., -0.0302, -0.5117,  0.4004],
         [-1.6406, -0.3691, -0.2852,  ...,  0.9336, -0.0918, -1.4609],
         [-1.5547, -0.0308, -1.0156,  ..., -0.6406, -0.0928,  0.2949]],

        [[-0.0140, -0.0067,  0.0165,  ..., -0.0031, -0.0269,  0.0166],
         [ 0.7578,  0.6719, -0.7812,  ..., -0.4512, -1.5156, -0.4316],
         [ 1.1406,  1.6641, -0.1426,  ..., -0.7422,  0.7422, -2.5156],
         ...,
         [-1.5625, -1.3594,  0.0552,  ..., -0.0302, -0.5117,  0.4004],
         [-1.6406, -0.3691, -0.2852,  ...,  0.9336, -0.0918, -1.4609],
         [-1.5547, -0.0308, -1.0156,  ..., -0.6406, -0.0928,  0.2949]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[ 4.0894e-03,  4.7607e-03, -2.6245e-03,  ..., -6.8848e-02,
           4.7852e-02,  1.8262e-01],
         [-3.7500e+00,  1.0625e+00, -1.1094e+00,  ..., -1.0303e-01,
           3.3203e-01, -1.0547e+00],
         [-1.4766e+00,  3.1562e+00, -1.0625e+00,  ..., -1.7578e-02,
          -9.4531e-01, -2.1250e+00],
         ...,
         [-2.8438e+00,  1.2266e+00,  1.4688e+00,  ...,  6.7578e-01,
           1.4141e+00, -8.2812e-01],
         [-4.9609e-01,  8.3984e-01,  6.4062e-01,  ...,  6.5234e-01,
          -1.3828e+00,  1.9062e+00],
         [ 1.1094e+00, -3.2812e-01, -3.3008e-01,  ...,  3.6865e-02,
          -7.6953e-01,  1.6406e-01]],

        [[ 4.0894e-03,  4.7607e-03, -2.6245e-03,  ..., -6.8848e-02,
           4.7852e-02,  1.8262e-01],
         [-3.7500e+00,  1.0625e+00, -1.1094e+00,  ..., -1.0303e-01,
           3.3203e-01, -1.0547e+00],
         [-1.4766e+00,  3.1562e+00, -1.0625e+00,  ..., -1.7578e-02,
          -9.4531e-01, -2.1250e+00],
         ...,
         [-2.8438e+00,  1.2266e+00,  1.4688e+00,  ...,  6.7578e-01,
           1.4141e+00, -8.2812e-01],
         [-4.9609e-01,  8.3984e-01,  6.4062e-01,  ...,  6.5234e-01,
          -1.3828e+00,  1.9062e+00],
         [ 1.1094e+00, -3.2812e-01, -3.3008e-01,  ...,  3.6865e-02,
          -7.6953e-01,  1.6406e-01]],

        [[ 4.0894e-03,  4.7607e-03, -2.6245e-03,  ..., -6.8848e-02,
           4.7852e-02,  1.8262e-01],
         [-3.7500e+00,  1.0625e+00, -1.1094e+00,  ..., -1.0303e-01,
           3.3203e-01, -1.0547e+00],
         [-1.4766e+00,  3.1562e+00, -1.0625e+00,  ..., -1.7578e-02,
          -9.4531e-01, -2.1250e+00],
         ...,
         [-2.8438e+00,  1.2266e+00,  1.4688e+00,  ...,  6.7578e-01,
           1.4141e+00, -8.2812e-01],
         [-4.9609e-01,  8.3984e-01,  6.4062e-01,  ...,  6.5234e-01,
          -1.3828e+00,  1.9062e+00],
         [ 1.1094e+00, -3.2812e-01, -3.3008e-01,  ...,  3.6865e-02,
          -7.6953e-01,  1.6406e-01]],

        ...,

        [[ 8.1787e-03,  1.9165e-02, -1.0620e-02,  ...,  1.6699e-01,
          -1.3965e-01, -1.6016e-01],
         [-1.8906e+00,  1.1406e+00,  4.0039e-01,  ...,  1.2305e-01,
           4.7852e-01, -2.8906e-01],
         [ 5.7031e-01, -1.0469e+00, -5.1172e-01,  ..., -6.7578e-01,
          -2.7734e-01, -1.0781e+00],
         ...,
         [-1.7773e-01, -3.9453e-01, -1.3906e+00,  ...,  5.5078e-01,
          -2.9688e-01, -6.3965e-02],
         [-1.6016e-01, -5.8203e-01, -4.2773e-01,  ...,  9.1016e-01,
          -7.6172e-01,  9.6875e-01],
         [ 1.6484e+00, -1.0312e+00, -1.4922e+00,  ...,  3.8281e-01,
           1.3828e+00, -8.8379e-02]],

        [[ 8.1787e-03,  1.9165e-02, -1.0620e-02,  ...,  1.6699e-01,
          -1.3965e-01, -1.6016e-01],
         [-1.8906e+00,  1.1406e+00,  4.0039e-01,  ...,  1.2305e-01,
           4.7852e-01, -2.8906e-01],
         [ 5.7031e-01, -1.0469e+00, -5.1172e-01,  ..., -6.7578e-01,
          -2.7734e-01, -1.0781e+00],
         ...,
         [-1.7773e-01, -3.9453e-01, -1.3906e+00,  ...,  5.5078e-01,
          -2.9688e-01, -6.3965e-02],
         [-1.6016e-01, -5.8203e-01, -4.2773e-01,  ...,  9.1016e-01,
          -7.6172e-01,  9.6875e-01],
         [ 1.6484e+00, -1.0312e+00, -1.4922e+00,  ...,  3.8281e-01,
           1.3828e+00, -8.8379e-02]],

        [[ 8.1787e-03,  1.9165e-02, -1.0620e-02,  ...,  1.6699e-01,
          -1.3965e-01, -1.6016e-01],
         [-1.8906e+00,  1.1406e+00,  4.0039e-01,  ...,  1.2305e-01,
           4.7852e-01, -2.8906e-01],
         [ 5.7031e-01, -1.0469e+00, -5.1172e-01,  ..., -6.7578e-01,
          -2.7734e-01, -1.0781e+00],
         ...,
         [-1.7773e-01, -3.9453e-01, -1.3906e+00,  ...,  5.5078e-01,
          -2.9688e-01, -6.3965e-02],
         [-1.6016e-01, -5.8203e-01, -4.2773e-01,  ...,  9.1016e-01,
          -7.6172e-01,  9.6875e-01],
         [ 1.6484e+00, -1.0312e+00, -1.4922e+00,  ...,  3.8281e-01,
           1.3828e+00, -8.8379e-02]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[-0.0254, -0.0189,  0.0183,  ..., -0.0042,  0.1099, -0.0179],
         [-0.3301,  0.7344,  1.5312,  ..., -0.8516, -0.1631,  0.8555],
         [ 0.0043,  0.4824, -0.2041,  ..., -0.1660,  0.5938,  0.3242],
         ...,
         [-0.1543,  0.1572,  0.0144,  ..., -1.6328,  0.8789,  1.7969],
         [-0.2715, -0.2930,  0.0708,  ..., -1.1562, -1.2656,  1.4141],
         [ 0.0659,  0.4883,  0.2578,  ...,  0.1475, -0.5508, -0.7031]],

        [[-0.0254, -0.0189,  0.0183,  ..., -0.0042,  0.1099, -0.0179],
         [-0.3301,  0.7344,  1.5312,  ..., -0.8516, -0.1631,  0.8555],
         [ 0.0043,  0.4824, -0.2041,  ..., -0.1660,  0.5938,  0.3242],
         ...,
         [-0.1543,  0.1572,  0.0144,  ..., -1.6328,  0.8789,  1.7969],
         [-0.2715, -0.2930,  0.0708,  ..., -1.1562, -1.2656,  1.4141],
         [ 0.0659,  0.4883,  0.2578,  ...,  0.1475, -0.5508, -0.7031]],

        [[-0.0254, -0.0189,  0.0183,  ..., -0.0042,  0.1099, -0.0179],
         [-0.3301,  0.7344,  1.5312,  ..., -0.8516, -0.1631,  0.8555],
         [ 0.0043,  0.4824, -0.2041,  ..., -0.1660,  0.5938,  0.3242],
         ...,
         [-0.1543,  0.1572,  0.0144,  ..., -1.6328,  0.8789,  1.7969],
         [-0.2715, -0.2930,  0.0708,  ..., -1.1562, -1.2656,  1.4141],
         [ 0.0659,  0.4883,  0.2578,  ...,  0.1475, -0.5508, -0.7031]],

        ...,

        [[-0.0245,  0.0131,  0.0088,  ...,  0.0496, -0.0391, -0.0024],
         [-0.2373, -0.2256,  0.3340,  ..., -0.0225, -0.1021,  0.9844],
         [-0.2637, -1.0156, -0.0444,  ..., -0.5508, -0.2002, -0.1064],
         ...,
         [ 0.4883,  0.1514,  1.5000,  ..., -2.3125,  0.5117,  0.1934],
         [ 0.4023, -0.1167, -0.0220,  ..., -1.4219,  0.4004,  0.7227],
         [ 0.4785,  0.1875, -0.4512,  ...,  0.1953, -0.0601, -0.0166]],

        [[-0.0245,  0.0131,  0.0088,  ...,  0.0496, -0.0391, -0.0024],
         [-0.2373, -0.2256,  0.3340,  ..., -0.0225, -0.1021,  0.9844],
         [-0.2637, -1.0156, -0.0444,  ..., -0.5508, -0.2002, -0.1064],
         ...,
         [ 0.4883,  0.1514,  1.5000,  ..., -2.3125,  0.5117,  0.1934],
         [ 0.4023, -0.1167, -0.0220,  ..., -1.4219,  0.4004,  0.7227],
         [ 0.4785,  0.1875, -0.4512,  ...,  0.1953, -0.0601, -0.0166]],

        [[-0.0245,  0.0131,  0.0088,  ...,  0.0496, -0.0391, -0.0024],
         [-0.2373, -0.2256,  0.3340,  ..., -0.0225, -0.1021,  0.9844],
         [-0.2637, -1.0156, -0.0444,  ..., -0.5508, -0.2002, -0.1064],
         ...,
         [ 0.4883,  0.1514,  1.5000,  ..., -2.3125,  0.5117,  0.1934],
         [ 0.4023, -0.1167, -0.0220,  ..., -1.4219,  0.4004,  0.7227],
         [ 0.4785,  0.1875, -0.4512,  ...,  0.1953, -0.0601, -0.0166]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)), (tensor([[[ 2.4048e-02,  1.5381e-02, -2.6489e-02,  ...,  1.5442e-02,
           1.6094e+00, -2.5586e-01],
         [-1.9297e+00,  1.5312e+00,  3.8281e-01,  ..., -1.4844e+00,
          -4.0000e+00,  7.5000e-01],
         [-4.2188e-01,  1.6562e+00,  1.6953e+00,  ...,  1.9453e+00,
          -5.9062e+00,  1.8555e-01],
         ...,
         [-7.4219e-01,  1.7188e-01, -1.2188e+00,  ..., -3.7031e+00,
          -5.1875e+00,  6.4062e-01],
         [-1.7188e-01, -8.3984e-01, -1.4062e+00,  ..., -1.6328e+00,
          -6.4688e+00,  1.0625e+00],
         [ 1.9922e+00, -1.2422e+00, -1.0391e+00,  ..., -1.9141e+00,
          -4.5312e+00,  3.3984e-01]],

        [[ 2.4048e-02,  1.5381e-02, -2.6489e-02,  ...,  1.5442e-02,
           1.6094e+00, -2.5586e-01],
         [-1.9297e+00,  1.5312e+00,  3.8281e-01,  ..., -1.4844e+00,
          -4.0000e+00,  7.5000e-01],
         [-4.2188e-01,  1.6562e+00,  1.6953e+00,  ...,  1.9453e+00,
          -5.9062e+00,  1.8555e-01],
         ...,
         [-7.4219e-01,  1.7188e-01, -1.2188e+00,  ..., -3.7031e+00,
          -5.1875e+00,  6.4062e-01],
         [-1.7188e-01, -8.3984e-01, -1.4062e+00,  ..., -1.6328e+00,
          -6.4688e+00,  1.0625e+00],
         [ 1.9922e+00, -1.2422e+00, -1.0391e+00,  ..., -1.9141e+00,
          -4.5312e+00,  3.3984e-01]],

        [[ 2.4048e-02,  1.5381e-02, -2.6489e-02,  ...,  1.5442e-02,
           1.6094e+00, -2.5586e-01],
         [-1.9297e+00,  1.5312e+00,  3.8281e-01,  ..., -1.4844e+00,
          -4.0000e+00,  7.5000e-01],
         [-4.2188e-01,  1.6562e+00,  1.6953e+00,  ...,  1.9453e+00,
          -5.9062e+00,  1.8555e-01],
         ...,
         [-7.4219e-01,  1.7188e-01, -1.2188e+00,  ..., -3.7031e+00,
          -5.1875e+00,  6.4062e-01],
         [-1.7188e-01, -8.3984e-01, -1.4062e+00,  ..., -1.6328e+00,
          -6.4688e+00,  1.0625e+00],
         [ 1.9922e+00, -1.2422e+00, -1.0391e+00,  ..., -1.9141e+00,
          -4.5312e+00,  3.3984e-01]],

        ...,

        [[-1.1475e-02,  4.6997e-03, -4.4861e-03,  ...,  4.9438e-03,
           5.2002e-02, -1.9844e+00],
         [-9.8047e-01, -8.5547e-01,  8.7500e-01,  ...,  9.5312e-01,
           1.0681e-02,  5.5000e+00],
         [-2.0312e+00, -2.3594e+00, -6.3672e-01,  ..., -8.6328e-01,
          -3.6719e-01,  7.7500e+00],
         ...,
         [-1.4922e+00, -1.3125e+00, -1.1328e+00,  ...,  7.7344e-01,
           7.3828e-01,  4.9375e+00],
         [-4.8633e-01, -8.9453e-01, -1.8359e-01,  ...,  1.5859e+00,
           1.5234e+00,  6.5938e+00],
         [-4.5117e-01, -6.8359e-01, -9.2578e-01,  ...,  2.7031e+00,
           3.2812e-01,  5.8438e+00]],

        [[-1.1475e-02,  4.6997e-03, -4.4861e-03,  ...,  4.9438e-03,
           5.2002e-02, -1.9844e+00],
         [-9.8047e-01, -8.5547e-01,  8.7500e-01,  ...,  9.5312e-01,
           1.0681e-02,  5.5000e+00],
         [-2.0312e+00, -2.3594e+00, -6.3672e-01,  ..., -8.6328e-01,
          -3.6719e-01,  7.7500e+00],
         ...,
         [-1.4922e+00, -1.3125e+00, -1.1328e+00,  ...,  7.7344e-01,
           7.3828e-01,  4.9375e+00],
         [-4.8633e-01, -8.9453e-01, -1.8359e-01,  ...,  1.5859e+00,
           1.5234e+00,  6.5938e+00],
         [-4.5117e-01, -6.8359e-01, -9.2578e-01,  ...,  2.7031e+00,
           3.2812e-01,  5.8438e+00]],

        [[-1.1475e-02,  4.6997e-03, -4.4861e-03,  ...,  4.9438e-03,
           5.2002e-02, -1.9844e+00],
         [-9.8047e-01, -8.5547e-01,  8.7500e-01,  ...,  9.5312e-01,
           1.0681e-02,  5.5000e+00],
         [-2.0312e+00, -2.3594e+00, -6.3672e-01,  ..., -8.6328e-01,
          -3.6719e-01,  7.7500e+00],
         ...,
         [-1.4922e+00, -1.3125e+00, -1.1328e+00,  ...,  7.7344e-01,
           7.3828e-01,  4.9375e+00],
         [-4.8633e-01, -8.9453e-01, -1.8359e-01,  ...,  1.5859e+00,
           1.5234e+00,  6.5938e+00],
         [-4.5117e-01, -6.8359e-01, -9.2578e-01,  ...,  2.7031e+00,
           3.2812e-01,  5.8438e+00]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>), tensor([[[ 9.2163e-03,  1.6724e-02,  7.9346e-03,  ..., -2.8442e-02,
          -1.2573e-02,  1.4038e-02],
         [ 3.3984e-01,  3.1836e-01,  3.7305e-01,  ...,  4.7852e-01,
          -5.5078e-01,  4.5703e-01],
         [-2.9883e-01, -6.1328e-01,  1.1094e+00,  ..., -9.4531e-01,
           6.4844e-01, -4.6680e-01],
         ...,
         [ 4.6680e-01,  5.5078e-01, -1.4922e+00,  ...,  6.7969e-01,
           1.6406e+00,  3.8867e-01],
         [ 1.1797e+00,  1.1250e+00, -2.2168e-01,  ...,  1.5234e+00,
           3.6719e+00, -3.1445e-01],
         [ 1.0781e+00,  1.3047e+00,  8.6719e-01,  ...,  1.2344e+00,
           1.0469e+00, -8.6719e-01]],

        [[ 9.2163e-03,  1.6724e-02,  7.9346e-03,  ..., -2.8442e-02,
          -1.2573e-02,  1.4038e-02],
         [ 3.3984e-01,  3.1836e-01,  3.7305e-01,  ...,  4.7852e-01,
          -5.5078e-01,  4.5703e-01],
         [-2.9883e-01, -6.1328e-01,  1.1094e+00,  ..., -9.4531e-01,
           6.4844e-01, -4.6680e-01],
         ...,
         [ 4.6680e-01,  5.5078e-01, -1.4922e+00,  ...,  6.7969e-01,
           1.6406e+00,  3.8867e-01],
         [ 1.1797e+00,  1.1250e+00, -2.2168e-01,  ...,  1.5234e+00,
           3.6719e+00, -3.1445e-01],
         [ 1.0781e+00,  1.3047e+00,  8.6719e-01,  ...,  1.2344e+00,
           1.0469e+00, -8.6719e-01]],

        [[ 9.2163e-03,  1.6724e-02,  7.9346e-03,  ..., -2.8442e-02,
          -1.2573e-02,  1.4038e-02],
         [ 3.3984e-01,  3.1836e-01,  3.7305e-01,  ...,  4.7852e-01,
          -5.5078e-01,  4.5703e-01],
         [-2.9883e-01, -6.1328e-01,  1.1094e+00,  ..., -9.4531e-01,
           6.4844e-01, -4.6680e-01],
         ...,
         [ 4.6680e-01,  5.5078e-01, -1.4922e+00,  ...,  6.7969e-01,
           1.6406e+00,  3.8867e-01],
         [ 1.1797e+00,  1.1250e+00, -2.2168e-01,  ...,  1.5234e+00,
           3.6719e+00, -3.1445e-01],
         [ 1.0781e+00,  1.3047e+00,  8.6719e-01,  ...,  1.2344e+00,
           1.0469e+00, -8.6719e-01]],

        ...,

        [[ 1.2817e-02,  1.9897e-02, -2.3682e-02,  ...,  7.8201e-05,
          -3.9795e-02,  1.3306e-02],
         [ 1.4922e+00, -5.1172e-01, -7.1094e-01,  ...,  1.6016e+00,
          -3.3594e-01,  2.4707e-01],
         [ 2.9297e-01, -3.4062e+00, -3.9648e-01,  ..., -7.1094e-01,
          -1.3828e+00, -2.6875e+00],
         ...,
         [-4.9414e-01, -1.2305e-01, -1.0791e-01,  ...,  3.7695e-01,
          -1.5078e+00,  1.8594e+00],
         [-6.6016e-01, -1.2266e+00,  8.8281e-01,  ..., -1.4609e+00,
           3.7598e-02, -1.5234e+00],
         [ 1.1250e+00,  1.8359e+00,  9.0625e-01,  ..., -1.4922e+00,
          -1.9922e-01,  7.0312e-01]],

        [[ 1.2817e-02,  1.9897e-02, -2.3682e-02,  ...,  7.8201e-05,
          -3.9795e-02,  1.3306e-02],
         [ 1.4922e+00, -5.1172e-01, -7.1094e-01,  ...,  1.6016e+00,
          -3.3594e-01,  2.4707e-01],
         [ 2.9297e-01, -3.4062e+00, -3.9648e-01,  ..., -7.1094e-01,
          -1.3828e+00, -2.6875e+00],
         ...,
         [-4.9414e-01, -1.2305e-01, -1.0791e-01,  ...,  3.7695e-01,
          -1.5078e+00,  1.8594e+00],
         [-6.6016e-01, -1.2266e+00,  8.8281e-01,  ..., -1.4609e+00,
           3.7598e-02, -1.5234e+00],
         [ 1.1250e+00,  1.8359e+00,  9.0625e-01,  ..., -1.4922e+00,
          -1.9922e-01,  7.0312e-01]],

        [[ 1.2817e-02,  1.9897e-02, -2.3682e-02,  ...,  7.8201e-05,
          -3.9795e-02,  1.3306e-02],
         [ 1.4922e+00, -5.1172e-01, -7.1094e-01,  ...,  1.6016e+00,
          -3.3594e-01,  2.4707e-01],
         [ 2.9297e-01, -3.4062e+00, -3.9648e-01,  ..., -7.1094e-01,
          -1.3828e+00, -2.6875e+00],
         ...,
         [-4.9414e-01, -1.2305e-01, -1.0791e-01,  ...,  3.7695e-01,
          -1.5078e+00,  1.8594e+00],
         [-6.6016e-01, -1.2266e+00,  8.8281e-01,  ..., -1.4609e+00,
           3.7598e-02, -1.5234e+00],
         [ 1.1250e+00,  1.8359e+00,  9.0625e-01,  ..., -1.4922e+00,
          -1.9922e-01,  7.0312e-01]]], dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)), (tensor([[[-0.2656, -0.6289, -0.9727,  ...,  0.3301,  0.3516, -1.3516],
         [ 0.3125,  0.0537,  0.1123,  ...,  0.2930,  0.2480,  0.2793],
         [-0.2490, -0.2793, -0.0850,  ...,  0.4082, -2.8750, -1.3750],
         ...,
         [ 0.5781,  1.2500,  1.0391,  ...,  4.0625, -3.5156, -2.3125],
         [ 0.6719,  0.7148,  0.9531,  ...,  1.0469, -1.4766, -0.2793],
         [-0.4727, -0.0156,  0.3750,  ...,  2.2344, -2.3281, -0.0830]],

        [[-0.2656, -0.6289, -0.9727,  ...,  0.3301,  0.3516, -1.3516],
         [ 0.3125,  0.0537,  0.1123,  ...,  0.2930,  0.2480,  0.2793],
         [-0.2490, -0.2793, -0.0850,  ...,  0.4082, -2.8750, -1.3750],
         ...,
         [ 0.5781,  1.2500,  1.0391,  ...,  4.0625, -3.5156, -2.3125],
         [ 0.6719,  0.7148,  0.9531,  ...,  1.0469, -1.4766, -0.2793],
         [-0.4727, -0.0156,  0.3750,  ...,  2.2344, -2.3281, -0.0830]],

        [[-0.2656, -0.6289, -0.9727,  ...,  0.3301,  0.3516, -1.3516],
         [ 0.3125,  0.0537,  0.1123,  ...,  0.2930,  0.2480,  0.2793],
         [-0.2490, -0.2793, -0.0850,  ...,  0.4082, -2.8750, -1.3750],
         ...,
         [ 0.5781,  1.2500,  1.0391,  ...,  4.0625, -3.5156, -2.3125],
         [ 0.6719,  0.7148,  0.9531,  ...,  1.0469, -1.4766, -0.2793],
         [-0.4727, -0.0156,  0.3750,  ...,  2.2344, -2.3281, -0.0830]],

        ...,

        [[-0.0134,  0.9180, -0.1533,  ...,  2.3906, -3.7188,  1.6250],
         [ 0.9141, -0.7188, -0.4258,  ..., -0.6445, -0.3828,  1.3047],
         [-0.3672,  0.5781,  0.1060,  ...,  0.2412,  0.2207, -0.4941],
         ...,
         [ 0.4062,  0.6875,  1.2500,  ..., -3.4531,  0.4219, -2.7656],
         [-0.8125,  0.6562,  1.0156,  ..., -2.9688, -4.0938, -0.3340],
         [-1.1797,  0.5312,  0.7070,  ..., -3.2500,  0.8203, -2.3281]],

        [[-0.0134,  0.9180, -0.1533,  ...,  2.3906, -3.7188,  1.6250],
         [ 0.9141, -0.7188, -0.4258,  ..., -0.6445, -0.3828,  1.3047],
         [-0.3672,  0.5781,  0.1060,  ...,  0.2412,  0.2207, -0.4941],
         ...,
         [ 0.4062,  0.6875,  1.2500,  ..., -3.4531,  0.4219, -2.7656],
         [-0.8125,  0.6562,  1.0156,  ..., -2.9688, -4.0938, -0.3340],
         [-1.1797,  0.5312,  0.7070,  ..., -3.2500,  0.8203, -2.3281]],

        [[-0.0134,  0.9180, -0.1533,  ...,  2.3906, -3.7188,  1.6250],
         [ 0.9141, -0.7188, -0.4258,  ..., -0.6445, -0.3828,  1.3047],
         [-0.3672,  0.5781,  0.1060,  ...,  0.2412,  0.2207, -0.4941],
         ...,
         [ 0.4062,  0.6875,  1.2500,  ..., -3.4531,  0.4219, -2.7656],
         [-0.8125,  0.6562,  1.0156,  ..., -2.9688, -4.0938, -0.3340],
         [-1.1797,  0.5312,  0.7070,  ..., -3.2500,  0.8203, -2.3281]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>), tensor([[[ 0.4805, -0.1099, -0.3379,  ..., -0.0300,  0.0986,  0.2188],
         [-0.0181, -0.2637, -0.2461,  ..., -0.2080, -0.1777, -0.1348],
         [-0.4180, -0.0654,  0.0275,  ...,  0.0275, -0.0159,  0.0723],
         ...,
         [ 0.2021,  0.0786,  0.4375,  ..., -0.6445,  0.6992, -0.8477],
         [ 0.2148,  0.3418, -0.0889,  ..., -0.0840,  0.3711, -0.3125],
         [ 0.4238,  0.2969,  0.4492,  ...,  0.4297,  0.6289, -0.4531]],

        [[ 0.4805, -0.1099, -0.3379,  ..., -0.0300,  0.0986,  0.2188],
         [-0.0181, -0.2637, -0.2461,  ..., -0.2080, -0.1777, -0.1348],
         [-0.4180, -0.0654,  0.0275,  ...,  0.0275, -0.0159,  0.0723],
         ...,
         [ 0.2021,  0.0786,  0.4375,  ..., -0.6445,  0.6992, -0.8477],
         [ 0.2148,  0.3418, -0.0889,  ..., -0.0840,  0.3711, -0.3125],
         [ 0.4238,  0.2969,  0.4492,  ...,  0.4297,  0.6289, -0.4531]],

        [[ 0.4805, -0.1099, -0.3379,  ..., -0.0300,  0.0986,  0.2188],
         [-0.0181, -0.2637, -0.2461,  ..., -0.2080, -0.1777, -0.1348],
         [-0.4180, -0.0654,  0.0275,  ...,  0.0275, -0.0159,  0.0723],
         ...,
         [ 0.2021,  0.0786,  0.4375,  ..., -0.6445,  0.6992, -0.8477],
         [ 0.2148,  0.3418, -0.0889,  ..., -0.0840,  0.3711, -0.3125],
         [ 0.4238,  0.2969,  0.4492,  ...,  0.4297,  0.6289, -0.4531]],

        ...,

        [[ 0.3242, -0.4062, -0.1709,  ...,  0.2695, -0.5391, -0.3926],
         [ 0.2695, -0.2197,  0.2227,  ..., -0.2734, -0.5898,  0.0435],
         [ 0.8477, -0.6094, -0.0427,  ..., -0.1914, -0.5508, -0.7031],
         ...,
         [-1.0234,  1.2969, -0.0723,  ...,  0.1797,  0.4980,  0.5195],
         [-1.3984,  1.1719, -0.3809,  ...,  0.3809,  0.0420,  0.3145],
         [-0.1367, -0.0142, -0.7539,  ...,  0.3477, -0.4102,  0.3594]],

        [[ 0.3242, -0.4062, -0.1709,  ...,  0.2695, -0.5391, -0.3926],
         [ 0.2695, -0.2197,  0.2227,  ..., -0.2734, -0.5898,  0.0435],
         [ 0.8477, -0.6094, -0.0427,  ..., -0.1914, -0.5508, -0.7031],
         ...,
         [-1.0234,  1.2969, -0.0723,  ...,  0.1797,  0.4980,  0.5195],
         [-1.3984,  1.1719, -0.3809,  ...,  0.3809,  0.0420,  0.3145],
         [-0.1367, -0.0142, -0.7539,  ...,  0.3477, -0.4102,  0.3594]],

        [[ 0.3242, -0.4062, -0.1709,  ...,  0.2695, -0.5391, -0.3926],
         [ 0.2695, -0.2197,  0.2227,  ..., -0.2734, -0.5898,  0.0435],
         [ 0.8477, -0.6094, -0.0427,  ..., -0.1914, -0.5508, -0.7031],
         ...,
         [-1.0234,  1.2969, -0.0723,  ...,  0.1797,  0.4980,  0.5195],
         [-1.3984,  1.1719, -0.3809,  ...,  0.3809,  0.0420,  0.3145],
         [-0.1367, -0.0142, -0.7539,  ...,  0.3477, -0.4102,  0.3594]]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>))), hidden_states=None, attentions=None, cross_attentions=None)
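Printing `output` dumps every cached key/value tensor, which is hard to read. Apart from the logits, the printout shows a `past_key_values` field: a tuple with one `(key, value)` pair per decoder layer. One way to inspect it is to print only the shapes. The helpers `kv_cache_shapes` and `print_kv_cache_summary` below are hypothetical. They assume the tuple-of-pairs layout seen in the dump above, and they work with any objects that expose a `.shape` attribute, for example `print_kv_cache_summary(output.past_key_values)`.

```python
def kv_cache_shapes(past_key_values):
    """Summarize a tuple of per-layer (key, value) tensors as shape tuples.

    Assumes the tuple-of-pairs ("legacy") cache layout seen in the printout:
    ((key_0, value_0), (key_1, value_1), ...), one pair per decoder layer.
    """
    shapes = []
    for layer_idx, (key, value) in enumerate(past_key_values):
        # tuple() turns a torch.Size (or any shape sequence) into a plain tuple
        shapes.append((layer_idx, tuple(key.shape), tuple(value.shape)))
    return shapes


def print_kv_cache_summary(past_key_values):
    """Print the number of cached layers and the key/value shape of each one."""
    shapes = kv_cache_shapes(past_key_values)
    print('layers in cache:', len(shapes))
    for layer_idx, key_shape, value_shape in shapes:
        print(f'layer {layer_idx:2d}: key {key_shape}, value {value_shape}')
```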
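# output[0] is the logits tensor, shape (batch, seq_len, vocab_size):
# take the last position of the only sequence, turn its logits into
# next-token probabilities and keep the 20 most likely tokens
top = torch.topk(torch.softmax(output[0][0][-1], 0), k=20)

for ix, token_id in enumerate(list(top.indices)):
    # decoded token text and its probability
    print('%s %0.5f' % (tokenizer.decode(token_id), top.values[ix].item()))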
 His 0.23145
 Pedro 0.07520
zy 0.04590
 Mar 0.04102
 Donald 0.02307
 i 0.02026
 David 0.01550
 Mate 0.01361
 his 0.01361
 And 0.01318
 Franc 0.01221
owi 0.01215
, 0.01166
 Be 0.01160
 K 0.01147
 J 0.01105
 Jose 0.01086
 José 0.01074
 z 0.00995
 Raj 0.00854
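The cell above converts the last position's logits into probabilities with `torch.softmax` and keeps the 20 largest with `torch.topk`. Below is a minimal plain-Python sketch of the same two steps. It runs on a made-up four-token vocabulary; the token strings and logits are illustrative only and do not come from Falcon.

```python
import math


def softmax(logits):
    # Subtract the max before exponentiating to avoid overflow;
    # softmax is shift-invariant, so the result is unchanged.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs, k):
    # (index, probability) pairs of the k largest probabilities, highest first,
    # mirroring what torch.topk returns as indices/values
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in order[:k]]


# Hypothetical toy vocabulary and logits, for illustration only
vocab = [' cheese', ' sandwich', ' ham', ' toast']
logits = [3.0, 3.0, -1.0, -2.0]

for token_id, p in top_k(softmax(logits), k=3):
    print('%s %0.5f' % (vocab[token_id], p))
```

Because softmax is monotonic, ranking tokens by probability gives the same order as ranking them by raw logit. The softmax only matters when you want the probabilities themselves, as in the listings printed here.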
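# Polish prompt: "The English translation of the sentence 'Jem kanapkę z serem'
# ('I eat a sandwich with cheese') is: I eat a"
tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("Tłumaczenie na język angielski zdania 'Jem kanapkę z serem' to: I eat a"))
# Inference only: no_grad skips building the autograd graph (hence no grad_fn in the output)
with torch.no_grad():
    output = pipeline.model(torch.tensor([tokens]))
top = torch.topk(torch.softmax(output[0][0][-1], 0), k=20)

for ix, token_id in enumerate(list(top.indices)):
    print('%s %0.5f' % (tokenizer.decode(token_id), top.values[ix].item()))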
 cheese 0.47461
 sandwich 0.46094
 ham 0.00879
 chees 0.00854
 sand 0.00781
 slice 0.00433
 toast 0.00337
 bread 0.00336
  0.00278
 snack 0.00224
 piece 0.00218
 toasted 0.00187
 grilled 0.00174
 s 0.00117
 cheddar 0.00092
 che 0.00087
 sandwiches 0.00071
 cheesy 0.00054
 ser 0.00048
 roll 0.00048
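The two cells above repeat the same steps: tokenize, run one forward pass, apply softmax to the last position, take the top k. A small helper could wrap them. This is a sketch, and the name `top_next_tokens` is hypothetical. It assumes that `model(input_ids)` returns an object whose first element (`output[0]`) holds logits of shape `(batch, seq_len, vocab_size)`, which is what the cells above rely on. For example, `top_next_tokens(pipeline.model, tokenizer, prompt)` with the translation prompt should give the same ranking as the listing printed above.

```python
import torch


def top_next_tokens(model, tokenizer, prompt, k=20):
    """Return (token_text, probability) pairs for the k most likely next tokens after `prompt`."""
    # Same tokenization path as the notebook cells: text -> tokens -> ids
    ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(prompt))
    with torch.no_grad():  # inference only, no autograd graph
        output = model(torch.tensor([ids]))
    # Logits at the last position of the single sequence in the batch
    last_logits = output[0][0][-1]
    top = torch.topk(torch.softmax(last_logits, dim=-1), k=k)
    return [(tokenizer.decode(int(i)), p.item()) for p, i in zip(top.values, top.indices)]
```

Returning the pairs instead of printing them keeps the helper reusable, for example for comparing the distributions of several prompts side by side.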
output
CausalLMOutputWithCrossAttentions(loss={'logits': tensor([[[ -8.7500, -10.6250, -11.9375,  ..., -10.8125, -12.0000,  -9.5000],
         [ -9.1250, -10.5000, -12.1250,  ...,  -8.1250,  -9.4375,  -7.1562],
         [-16.1250, -22.2500, -24.0000,  ..., -19.1250, -19.5000, -18.6250],
         ...,
         [-13.8750, -16.2500, -20.3750,  ..., -12.1250, -17.5000, -11.3750],
         [-13.4375, -14.7500, -19.0000,  ..., -15.5000, -17.7500, -14.8125],
         [-13.3750, -15.8125, -17.7500,  ..., -11.9375, -16.0000,  -9.5000]]],
       dtype=torch.bfloat16), 'past_key_values': ((tensor([[[ 0.5547, -1.4062, -0.3340,  ...,  0.1514, -0.9297, -1.0312],
         [-0.4785, -1.4062, -0.4316,  ...,  0.0583, -1.4062,  0.1992],
         [-2.2500, -2.3281,  0.4707,  ..., -1.1562, -0.1973,  0.9883],
         ...,
         [-0.1641,  1.3672, -0.2100,  ...,  0.0522, -1.4062,  0.1992],
         [-0.1484, -0.2168, -0.0801,  ..., -0.9414, -0.6211, -1.1953],
         [ 1.5781, -1.5859,  0.1396,  ..., -0.1191, -0.2041, -0.2812]],

        [[ 0.5547, -1.4062, -0.3340,  ...,  0.1514, -0.9297, -1.0312],
         [-0.4785, -1.4062, -0.4316,  ...,  0.0583, -1.4062,  0.1992],
         [-2.2500, -2.3281,  0.4707,  ..., -1.1562, -0.1973,  0.9883],
         ...,
         [-0.1641,  1.3672, -0.2100,  ...,  0.0522, -1.4062,  0.1992],
         [-0.1484, -0.2168, -0.0801,  ..., -0.9414, -0.6211, -1.1953],
         [ 1.5781, -1.5859,  0.1396,  ..., -0.1191, -0.2041, -0.2812]],

        [[ 0.5547, -1.4062, -0.3340,  ...,  0.1514, -0.9297, -1.0312],
         [-0.4785, -1.4062, -0.4316,  ...,  0.0583, -1.4062,  0.1992],
         [-2.2500, -2.3281,  0.4707,  ..., -1.1562, -0.1973,  0.9883],
         ...,
         [-0.1641,  1.3672, -0.2100,  ...,  0.0522, -1.4062,  0.1992],
         [-0.1484, -0.2168, -0.0801,  ..., -0.9414, -0.6211, -1.1953],
         [ 1.5781, -1.5859,  0.1396,  ..., -0.1191, -0.2041, -0.2812]],

        ...,

        [[ 0.4551, -0.1377, -0.2383,  ...,  0.9648, -1.7422,  0.1562],
         [ 0.2422, -1.1875, -0.4629,  ..., -0.1816, -0.8867, -0.7070],
         [-1.3672, -1.3750, -1.0781,  ...,  0.2119,  1.6172,  0.6758],
         ...,
         [-1.4375,  1.2188, -0.1250,  ..., -0.1787, -0.8945, -0.7070],
         [-0.1367,  0.2949, -0.0059,  ...,  0.7617, -1.8203,  1.6328],
         [ 1.4922,  0.6953, -1.4297,  ...,  0.9961,  1.2109,  1.5938]],

        [[ 0.4551, -0.1377, -0.2383,  ...,  0.9648, -1.7422,  0.1562],
         [ 0.2422, -1.1875, -0.4629,  ..., -0.1816, -0.8867, -0.7070],
         [-1.3672, -1.3750, -1.0781,  ...,  0.2119,  1.6172,  0.6758],
         ...,
         [-1.4375,  1.2188, -0.1250,  ..., -0.1787, -0.8945, -0.7070],
         [-0.1367,  0.2949, -0.0059,  ...,  0.7617, -1.8203,  1.6328],
         [ 1.4922,  0.6953, -1.4297,  ...,  0.9961,  1.2109,  1.5938]],

        [[ 0.4551, -0.1377, -0.2383,  ...,  0.9648, -1.7422,  0.1562],
         [ 0.2422, -1.1875, -0.4629,  ..., -0.1816, -0.8867, -0.7070],
         [-1.3672, -1.3750, -1.0781,  ...,  0.2119,  1.6172,  0.6758],
         ...,
         [-1.4375,  1.2188, -0.1250,  ..., -0.1787, -0.8945, -0.7070],
         [-0.1367,  0.2949, -0.0059,  ...,  0.7617, -1.8203,  1.6328],
         [ 1.4922,  0.6953, -1.4297,  ...,  0.9961,  1.2109,  1.5938]]],
       dtype=torch.bfloat16), tensor([[[-0.0036,  0.0417,  0.0364,  ..., -0.0087, -0.0391, -0.0474],
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [-0.0591, -0.0118, -0.1279,  ..., -0.1592,  0.1338,  0.0255],
         ...,
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [ 0.0087,  0.0205,  0.0557,  ...,  0.0085,  0.0039, -0.0454],
         [-0.0659,  0.0427,  0.1006,  ...,  0.1055, -0.0527, -0.0339]],

        [[-0.0036,  0.0417,  0.0364,  ..., -0.0087, -0.0391, -0.0474],
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [-0.0591, -0.0118, -0.1279,  ..., -0.1592,  0.1338,  0.0255],
         ...,
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [ 0.0087,  0.0205,  0.0557,  ...,  0.0085,  0.0039, -0.0454],
         [-0.0659,  0.0427,  0.1006,  ...,  0.1055, -0.0527, -0.0339]],

        [[-0.0036,  0.0417,  0.0364,  ..., -0.0087, -0.0391, -0.0474],
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [-0.0591, -0.0118, -0.1279,  ..., -0.1592,  0.1338,  0.0255],
         ...,
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [ 0.0087,  0.0205,  0.0557,  ...,  0.0085,  0.0039, -0.0454],
         [-0.0659,  0.0427,  0.1006,  ...,  0.1055, -0.0527, -0.0339]],

        ...,

        [[-0.0593, -0.0640, -0.0276,  ..., -0.0116, -0.0459, -0.0016],
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0708, -0.0918,  0.2285,  ..., -0.0635,  0.1396,  0.0603],
         ...,
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0923,  0.0505,  0.0068,  ...,  0.0239, -0.0119,  0.0031],
         [-0.2773, -0.3125, -0.3086,  ...,  0.0464,  0.1826, -0.2871]],

        [[-0.0593, -0.0640, -0.0276,  ..., -0.0116, -0.0459, -0.0016],
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0708, -0.0918,  0.2285,  ..., -0.0635,  0.1396,  0.0603],
         ...,
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0923,  0.0505,  0.0068,  ...,  0.0239, -0.0119,  0.0031],
         [-0.2773, -0.3125, -0.3086,  ...,  0.0464,  0.1826, -0.2871]],

        [[-0.0593, -0.0640, -0.0276,  ..., -0.0116, -0.0459, -0.0016],
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0708, -0.0918,  0.2285,  ..., -0.0635,  0.1396,  0.0603],
         ...,
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0923,  0.0505,  0.0068,  ...,  0.0239, -0.0119,  0.0031],
         [-0.2773, -0.3125, -0.3086,  ...,  0.0464,  0.1826, -0.2871]]],
       dtype=torch.bfloat16)), (tensor([[[ 2.4531,  1.6719, -4.0938,  ..., -1.4219,  0.8477,  1.2422],
         [ 7.1250,  2.4375, -2.5938,  ..., -1.1406,  0.6445,  1.5938],
         [ 4.6875,  4.4375, -2.6875,  ..., -0.1709,  0.5898,  3.0156],
         ...,
         [-3.0156, -2.6562, -2.6562,  ..., -2.2656,  0.7383,  3.3906],
         [-2.3125, -2.8281, -2.1562,  ...,  0.1621,  2.0000, -1.0938],
         [-5.6250, -4.4062, -2.9062,  ..., -1.9922,  1.0938,  1.3516]],

        [[ 2.4531,  1.6719, -4.0938,  ..., -1.4219,  0.8477,  1.2422],
         [ 7.1250,  2.4375, -2.5938,  ..., -1.1406,  0.6445,  1.5938],
         [ 4.6875,  4.4375, -2.6875,  ..., -0.1709,  0.5898,  3.0156],
         ...,
         [-3.0156, -2.6562, -2.6562,  ..., -2.2656,  0.7383,  3.3906],
         [-2.3125, -2.8281, -2.1562,  ...,  0.1621,  2.0000, -1.0938],
         [-5.6250, -4.4062, -2.9062,  ..., -1.9922,  1.0938,  1.3516]],

        [[ 2.4531,  1.6719, -4.0938,  ..., -1.4219,  0.8477,  1.2422],
         [ 7.1250,  2.4375, -2.5938,  ..., -1.1406,  0.6445,  1.5938],
         [ 4.6875,  4.4375, -2.6875,  ..., -0.1709,  0.5898,  3.0156],
         ...,
         [-3.0156, -2.6562, -2.6562,  ..., -2.2656,  0.7383,  3.3906],
         [-2.3125, -2.8281, -2.1562,  ...,  0.1621,  2.0000, -1.0938],
         [-5.6250, -4.4062, -2.9062,  ..., -1.9922,  1.0938,  1.3516]],

        ...,

        [[ 0.1846,  0.7070, -2.5938,  ..., -1.6562,  3.1719, -1.0938],
         [ 3.0312,  1.2500, -1.8672,  ..., -2.9219,  2.3594, -0.6680],
         [ 5.0625,  0.7070, -0.4258,  ..., -4.0938,  3.5781, -3.2344],
         ...,
         [-0.9375, -2.8438, -1.4531,  ..., -2.8594,  4.0312, -3.2188],
         [ 0.1758, -0.5859, -0.0127,  ..., -4.4688,  2.4375, -1.9375],
         [-5.9688, -1.0625, -4.2500,  ..., -2.6406,  3.4844, -0.4336]],

        [[ 0.1846,  0.7070, -2.5938,  ..., -1.6562,  3.1719, -1.0938],
         [ 3.0312,  1.2500, -1.8672,  ..., -2.9219,  2.3594, -0.6680],
         [ 5.0625,  0.7070, -0.4258,  ..., -4.0938,  3.5781, -3.2344],
         ...,
         [-0.9375, -2.8438, -1.4531,  ..., -2.8594,  4.0312, -3.2188],
         [ 0.1758, -0.5859, -0.0127,  ..., -4.4688,  2.4375, -1.9375],
         [-5.9688, -1.0625, -4.2500,  ..., -2.6406,  3.4844, -0.4336]],

        [[ 0.1846,  0.7070, -2.5938,  ..., -1.6562,  3.1719, -1.0938],
         [ 3.0312,  1.2500, -1.8672,  ..., -2.9219,  2.3594, -0.6680],
         [ 5.0625,  0.7070, -0.4258,  ..., -4.0938,  3.5781, -3.2344],
         ...,
         [-0.9375, -2.8438, -1.4531,  ..., -2.8594,  4.0312, -3.2188],
         [ 0.1758, -0.5859, -0.0127,  ..., -4.4688,  2.4375, -1.9375],
         [-5.9688, -1.0625, -4.2500,  ..., -2.6406,  3.4844, -0.4336]]],
       dtype=torch.bfloat16), tensor([[[ 0.0366,  0.0337,  0.0234,  ...,  0.1299,  0.2031, -0.1338],
         [-0.0269, -0.0304,  0.0752,  ..., -0.1118, -0.0603,  0.0413],
         [-0.0747,  0.0601,  0.0187,  ..., -0.0109, -0.2598, -0.1670],
         ...,
         [ 0.1973, -0.0928,  0.0315,  ..., -0.1162,  0.1777,  0.1699],
         [-0.0259, -0.0040,  0.0400,  ...,  0.0557, -0.0082, -0.0427],
         [-0.1729,  0.1797,  0.0096,  ...,  0.0938, -0.3145, -0.1206]],

        [[ 0.0366,  0.0337,  0.0234,  ...,  0.1299,  0.2031, -0.1338],
         [-0.0269, -0.0304,  0.0752,  ..., -0.1118, -0.0603,  0.0413],
         [-0.0747,  0.0601,  0.0187,  ..., -0.0109, -0.2598, -0.1670],
         ...,
         [ 0.1973, -0.0928,  0.0315,  ..., -0.1162,  0.1777,  0.1699],
         [-0.0259, -0.0040,  0.0400,  ...,  0.0557, -0.0082, -0.0427],
         [-0.1729,  0.1797,  0.0096,  ...,  0.0938, -0.3145, -0.1206]],

        [[ 0.0366,  0.0337,  0.0234,  ...,  0.1299,  0.2031, -0.1338],
         [-0.0269, -0.0304,  0.0752,  ..., -0.1118, -0.0603,  0.0413],
         [-0.0747,  0.0601,  0.0187,  ..., -0.0109, -0.2598, -0.1670],
         ...,
         [ 0.1973, -0.0928,  0.0315,  ..., -0.1162,  0.1777,  0.1699],
         [-0.0259, -0.0040,  0.0400,  ...,  0.0557, -0.0082, -0.0427],
         [-0.1729,  0.1797,  0.0096,  ...,  0.0938, -0.3145, -0.1206]],

        ...,

        [[ 0.0688,  0.0198,  0.0096,  ..., -0.0469, -0.0825, -0.0283],
         [-0.0437,  0.0231, -0.0981,  ...,  0.0354, -0.0835, -0.0356],
         [-0.1641,  0.0330, -0.0334,  ...,  0.0674, -0.1543,  0.1328],
         ...,
         [ 0.0874, -0.0222, -0.1260,  ...,  0.0718, -0.0569, -0.0508],
         [-0.0073, -0.0215,  0.0016,  ...,  0.0200,  0.0398,  0.0527],
         [ 0.0408,  0.1011,  0.2637,  ..., -0.0903,  0.0087,  0.0859]],

        [[ 0.0688,  0.0198,  0.0096,  ..., -0.0469, -0.0825, -0.0283],
         [-0.0437,  0.0231, -0.0981,  ...,  0.0354, -0.0835, -0.0356],
         [-0.1641,  0.0330, -0.0334,  ...,  0.0674, -0.1543,  0.1328],
         ...,
         [ 0.0874, -0.0222, -0.1260,  ...,  0.0718, -0.0569, -0.0508],
         [-0.0073, -0.0215,  0.0016,  ...,  0.0200,  0.0398,  0.0527],
         [ 0.0408,  0.1011,  0.2637,  ..., -0.0903,  0.0087,  0.0859]],

        [[ 0.0688,  0.0198,  0.0096,  ..., -0.0469, -0.0825, -0.0283],
         [-0.0437,  0.0231, -0.0981,  ...,  0.0354, -0.0835, -0.0356],
         [-0.1641,  0.0330, -0.0334,  ...,  0.0674, -0.1543,  0.1328],
         ...,
         [ 0.0874, -0.0222, -0.1260,  ...,  0.0718, -0.0569, -0.0508],
         [-0.0073, -0.0215,  0.0016,  ...,  0.0200,  0.0398,  0.0527],
         [ 0.0408,  0.1011,  0.2637,  ..., -0.0903,  0.0087,  0.0859]]],
       dtype=torch.bfloat16)), (tensor([[[-2.1820e-03, -5.0964e-03,  8.5831e-04,  ..., -9.5312e-01,
           6.8970e-03,  4.1406e-01],
         [ 1.4219e+00,  1.4141e+00,  2.2188e+00,  ...,  6.9062e+00,
           1.8164e-01, -1.7188e+00],
         [ 1.3672e+00,  7.8613e-02,  1.4531e+00,  ...,  6.3438e+00,
          -1.0547e+00, -2.1562e+00],
         ...,
         [ 8.2422e-01, -7.6953e-01,  1.9297e+00,  ...,  7.3125e+00,
          -2.0625e+00, -2.8750e+00],
         [ 7.0312e-02,  1.0352e-01,  2.8125e-01,  ...,  6.2500e+00,
          -7.6953e-01, -1.5312e+00],
         [-1.4062e+00, -1.1562e+00,  6.0547e-01,  ...,  6.2188e+00,
           1.2031e+00, -1.1016e+00]],

        [[-2.1820e-03, -5.0964e-03,  8.5831e-04,  ..., -9.5312e-01,
           6.8970e-03,  4.1406e-01],
         [ 1.4219e+00,  1.4141e+00,  2.2188e+00,  ...,  6.9062e+00,
           1.8164e-01, -1.7188e+00],
         [ 1.3672e+00,  7.8613e-02,  1.4531e+00,  ...,  6.3438e+00,
          -1.0547e+00, -2.1562e+00],
         ...,
         [ 8.2422e-01, -7.6953e-01,  1.9297e+00,  ...,  7.3125e+00,
          -2.0625e+00, -2.8750e+00],
         [ 7.0312e-02,  1.0352e-01,  2.8125e-01,  ...,  6.2500e+00,
          -7.6953e-01, -1.5312e+00],
         [-1.4062e+00, -1.1562e+00,  6.0547e-01,  ...,  6.2188e+00,
           1.2031e+00, -1.1016e+00]],

        [[-2.1820e-03, -5.0964e-03,  8.5831e-04,  ..., -9.5312e-01,
           6.8970e-03,  4.1406e-01],
         [ 1.4219e+00,  1.4141e+00,  2.2188e+00,  ...,  6.9062e+00,
           1.8164e-01, -1.7188e+00],
         [ 1.3672e+00,  7.8613e-02,  1.4531e+00,  ...,  6.3438e+00,
          -1.0547e+00, -2.1562e+00],
         ...,
         [ 8.2422e-01, -7.6953e-01,  1.9297e+00,  ...,  7.3125e+00,
          -2.0625e+00, -2.8750e+00],
         [ 7.0312e-02,  1.0352e-01,  2.8125e-01,  ...,  6.2500e+00,
          -7.6953e-01, -1.5312e+00],
         [-1.4062e+00, -1.1562e+00,  6.0547e-01,  ...,  6.2188e+00,
           1.2031e+00, -1.1016e+00]],

        ...,

        [[ 1.0315e-02,  7.4005e-04, -5.4626e-03,  ..., -4.5654e-02,
           9.8145e-02,  1.5015e-02],
         [ 4.0000e+00,  1.5156e+00,  1.7188e+00,  ...,  8.3008e-02,
          -2.4844e+00,  1.7188e+00],
         [ 2.1875e+00, -4.4434e-02,  1.3672e+00,  ...,  1.4609e+00,
          -1.6250e+00,  1.2578e+00],
         ...,
         [-2.4062e+00, -4.2383e-01,  2.0156e+00,  ...,  1.1406e+00,
          -1.5859e+00,  8.0469e-01],
         [-1.1875e+00, -2.2363e-01,  1.2422e+00,  ...,  1.6016e+00,
           4.3164e-01,  2.5000e+00],
         [-2.5781e+00, -9.8438e-01,  7.4219e-01,  ...,  6.5234e-01,
          -1.0391e+00,  1.8359e-01]],

        [[ 1.0315e-02,  7.4005e-04, -5.4626e-03,  ..., -4.5654e-02,
           9.8145e-02,  1.5015e-02],
         [ 4.0000e+00,  1.5156e+00,  1.7188e+00,  ...,  8.3008e-02,
          -2.4844e+00,  1.7188e+00],
         [ 2.1875e+00, -4.4434e-02,  1.3672e+00,  ...,  1.4609e+00,
          -1.6250e+00,  1.2578e+00],
         ...,
         [-2.4062e+00, -4.2383e-01,  2.0156e+00,  ...,  1.1406e+00,
          -1.5859e+00,  8.0469e-01],
         [-1.1875e+00, -2.2363e-01,  1.2422e+00,  ...,  1.6016e+00,
           4.3164e-01,  2.5000e+00],
         [-2.5781e+00, -9.8438e-01,  7.4219e-01,  ...,  6.5234e-01,
          -1.0391e+00,  1.8359e-01]],

        [[ 1.0315e-02,  7.4005e-04, -5.4626e-03,  ..., -4.5654e-02,
           9.8145e-02,  1.5015e-02],
         [ 4.0000e+00,  1.5156e+00,  1.7188e+00,  ...,  8.3008e-02,
          -2.4844e+00,  1.7188e+00],
         [ 2.1875e+00, -4.4434e-02,  1.3672e+00,  ...,  1.4609e+00,
          -1.6250e+00,  1.2578e+00],
         ...,
         [-2.4062e+00, -4.2383e-01,  2.0156e+00,  ...,  1.1406e+00,
          -1.5859e+00,  8.0469e-01],
         [-1.1875e+00, -2.2363e-01,  1.2422e+00,  ...,  1.6016e+00,
           4.3164e-01,  2.5000e+00],
         [-2.5781e+00, -9.8438e-01,  7.4219e-01,  ...,  6.5234e-01,
          -1.0391e+00,  1.8359e-01]]], dtype=torch.bfloat16), tensor([[[-1.8997e-03, -1.0864e-02,  2.7313e-03,  ..., -5.3711e-03,
           7.7820e-04,  3.0975e-03],
         [ 1.2793e-01,  3.1445e-01,  5.7812e-01,  ...,  5.1953e-01,
          -1.4771e-02, -3.0078e-01],
         [-2.1973e-01,  5.4443e-02, -1.6699e-01,  ..., -5.3516e-01,
           2.7930e-01,  1.0205e-01],
         ...,
         [-1.8262e-01,  3.0273e-02,  2.2461e-01,  ..., -1.0193e-02,
           2.6953e-01,  5.9891e-04],
         [ 3.1250e-01, -3.7354e-02,  4.6631e-02,  ...,  4.3945e-01,
          -9.2285e-02, -1.6895e-01],
         [ 3.7500e-01,  1.1406e+00, -1.2969e+00,  ...,  1.0156e+00,
          -5.2344e-01,  5.2734e-01]],

        [[-1.8997e-03, -1.0864e-02,  2.7313e-03,  ..., -5.3711e-03,
           7.7820e-04,  3.0975e-03],
         [ 1.2793e-01,  3.1445e-01,  5.7812e-01,  ...,  5.1953e-01,
          -1.4771e-02, -3.0078e-01],
         [-2.1973e-01,  5.4443e-02, -1.6699e-01,  ..., -5.3516e-01,
           2.7930e-01,  1.0205e-01],
         ...,
         [-1.8262e-01,  3.0273e-02,  2.2461e-01,  ..., -1.0193e-02,
           2.6953e-01,  5.9891e-04],
         [ 3.1250e-01, -3.7354e-02,  4.6631e-02,  ...,  4.3945e-01,
          -9.2285e-02, -1.6895e-01],
         [ 3.7500e-01,  1.1406e+00, -1.2969e+00,  ...,  1.0156e+00,
          -5.2344e-01,  5.2734e-01]],

        [[-1.8997e-03, -1.0864e-02,  2.7313e-03,  ..., -5.3711e-03,
           7.7820e-04,  3.0975e-03],
         [ 1.2793e-01,  3.1445e-01,  5.7812e-01,  ...,  5.1953e-01,
          -1.4771e-02, -3.0078e-01],
         [-2.1973e-01,  5.4443e-02, -1.6699e-01,  ..., -5.3516e-01,
           2.7930e-01,  1.0205e-01],
         ...,
         [-1.8262e-01,  3.0273e-02,  2.2461e-01,  ..., -1.0193e-02,
           2.6953e-01,  5.9891e-04],
         [ 3.1250e-01, -3.7354e-02,  4.6631e-02,  ...,  4.3945e-01,
          -9.2285e-02, -1.6895e-01],
         [ 3.7500e-01,  1.1406e+00, -1.2969e+00,  ...,  1.0156e+00,
          -5.2344e-01,  5.2734e-01]],

        ...,

        [[ 1.1658e-02, -7.6599e-03, -4.5967e-04,  ..., -3.1128e-03,
           3.2349e-03,  9.1934e-04],
         [ 2.6953e-01,  2.1582e-01, -2.3633e-01,  ...,  2.2095e-02,
           2.5195e-01,  5.7373e-02],
         [-3.2031e-01,  1.0315e-02,  5.2979e-02,  ...,  4.0820e-01,
           1.6895e-01, -2.3047e-01],
         ...,
         [-2.9785e-02,  1.4258e-01, -1.9824e-01,  ...,  2.0312e-01,
           1.8848e-01,  2.8711e-01],
         [-2.9688e-01,  1.5234e-01, -1.6797e-01,  ...,  1.9824e-01,
          -2.8711e-01, -4.9072e-02],
         [-1.3184e-01, -1.4160e-01,  2.1973e-02,  ..., -2.5781e-01,
          -2.5977e-01, -4.6680e-01]],

        [[ 1.1658e-02, -7.6599e-03, -4.5967e-04,  ..., -3.1128e-03,
           3.2349e-03,  9.1934e-04],
         [ 2.6953e-01,  2.1582e-01, -2.3633e-01,  ...,  2.2095e-02,
           2.5195e-01,  5.7373e-02],
         [-3.2031e-01,  1.0315e-02,  5.2979e-02,  ...,  4.0820e-01,
           1.6895e-01, -2.3047e-01],
         ...,
         [-2.9785e-02,  1.4258e-01, -1.9824e-01,  ...,  2.0312e-01,
           1.8848e-01,  2.8711e-01],
         [-2.9688e-01,  1.5234e-01, -1.6797e-01,  ...,  1.9824e-01,
          -2.8711e-01, -4.9072e-02],
         [-1.3184e-01, -1.4160e-01,  2.1973e-02,  ..., -2.5781e-01,
          -2.5977e-01, -4.6680e-01]],

        [[ 1.1658e-02, -7.6599e-03, -4.5967e-04,  ..., -3.1128e-03,
           3.2349e-03,  9.1934e-04],
         [ 2.6953e-01,  2.1582e-01, -2.3633e-01,  ...,  2.2095e-02,
           2.5195e-01,  5.7373e-02],
         [-3.2031e-01,  1.0315e-02,  5.2979e-02,  ...,  4.0820e-01,
           1.6895e-01, -2.3047e-01],
         ...,
         [-2.9785e-02,  1.4258e-01, -1.9824e-01,  ...,  2.0312e-01,
           1.8848e-01,  2.8711e-01],
         [-2.9688e-01,  1.5234e-01, -1.6797e-01,  ...,  1.9824e-01,
          -2.8711e-01, -4.9072e-02],
         [-1.3184e-01, -1.4160e-01,  2.1973e-02,  ..., -2.5781e-01,
          -2.5977e-01, -4.6680e-01]]], dtype=torch.bfloat16)), (tensor([[[ 7.3242e-03,  4.6387e-03,  8.1787e-03,  ...,  6.4844e-01,
           5.0000e-01, -3.8477e-01],
         [-1.9688e+00,  1.5000e+00, -7.1484e-01,  ..., -3.2031e+00,
          -2.9375e+00,  2.9883e-01],
         [-5.7031e-01,  5.5078e-01, -1.0781e+00,  ..., -3.3281e+00,
          -3.3906e+00,  1.0469e+00],
         ...,
         [ 5.1172e-01, -8.0469e-01,  1.9434e-01,  ..., -4.8438e+00,
          -3.0312e+00,  8.8867e-02],
         [ 1.2500e-01, -3.7305e-01, -3.4570e-01,  ..., -4.9375e+00,
          -3.5000e+00,  5.1953e-01],
         [ 1.6309e-01,  1.2891e-01,  3.4961e-01,  ..., -5.2500e+00,
          -2.3750e+00,  1.0078e+00]],

        [[ 7.3242e-03,  4.6387e-03,  8.1787e-03,  ...,  6.4844e-01,
           5.0000e-01, -3.8477e-01],
         [-1.9688e+00,  1.5000e+00, -7.1484e-01,  ..., -3.2031e+00,
          -2.9375e+00,  2.9883e-01],
         [-5.7031e-01,  5.5078e-01, -1.0781e+00,  ..., -3.3281e+00,
          -3.3906e+00,  1.0469e+00],
         ...,
         [ 5.1172e-01, -8.0469e-01,  1.9434e-01,  ..., -4.8438e+00,
          -3.0312e+00,  8.8867e-02],
         [ 1.2500e-01, -3.7305e-01, -3.4570e-01,  ..., -4.9375e+00,
          -3.5000e+00,  5.1953e-01],
         [ 1.6309e-01,  1.2891e-01,  3.4961e-01,  ..., -5.2500e+00,
          -2.3750e+00,  1.0078e+00]],

        [[ 7.3242e-03,  4.6387e-03,  8.1787e-03,  ...,  6.4844e-01,
           5.0000e-01, -3.8477e-01],
         [-1.9688e+00,  1.5000e+00, -7.1484e-01,  ..., -3.2031e+00,
          -2.9375e+00,  2.9883e-01],
         [-5.7031e-01,  5.5078e-01, -1.0781e+00,  ..., -3.3281e+00,
          -3.3906e+00,  1.0469e+00],
         ...,
         [ 5.1172e-01, -8.0469e-01,  1.9434e-01,  ..., -4.8438e+00,
          -3.0312e+00,  8.8867e-02],
         [ 1.2500e-01, -3.7305e-01, -3.4570e-01,  ..., -4.9375e+00,
          -3.5000e+00,  5.1953e-01],
         [ 1.6309e-01,  1.2891e-01,  3.4961e-01,  ..., -5.2500e+00,
          -2.3750e+00,  1.0078e+00]],

        ...,

        [[ 2.6398e-03, -5.1575e-03,  3.1586e-03,  ...,  8.0078e-02,
          -5.4321e-03, -2.4219e-01],
         [-5.1875e+00, -8.9062e-01,  7.8516e-01,  ..., -6.6406e-01,
          -1.3281e+00, -6.8750e-01],
         [-3.2188e+00, -5.5664e-02,  1.1719e-02,  ...,  5.1172e-01,
           3.5352e-01, -9.7266e-01],
         ...,
         [ 2.1562e+00,  1.8438e+00,  1.5703e+00,  ..., -6.1328e-01,
           5.1953e-01, -7.0312e-01],
         [ 3.2500e+00, -3.3789e-01,  1.6328e+00,  ...,  7.4219e-01,
           8.8672e-01, -7.8125e-01],
         [ 2.7812e+00, -2.0312e+00, -7.5781e-01,  ..., -1.0391e+00,
           2.1875e-01, -1.5039e-01]],

        [[ 2.6398e-03, -5.1575e-03,  3.1586e-03,  ...,  8.0078e-02,
          -5.4321e-03, -2.4219e-01],
         [-5.1875e+00, -8.9062e-01,  7.8516e-01,  ..., -6.6406e-01,
          -1.3281e+00, -6.8750e-01],
         [-3.2188e+00, -5.5664e-02,  1.1719e-02,  ...,  5.1172e-01,
           3.5352e-01, -9.7266e-01],
         ...,
         [ 2.1562e+00,  1.8438e+00,  1.5703e+00,  ..., -6.1328e-01,
           5.1953e-01, -7.0312e-01],
         [ 3.2500e+00, -3.3789e-01,  1.6328e+00,  ...,  7.4219e-01,
           8.8672e-01, -7.8125e-01],
         [ 2.7812e+00, -2.0312e+00, -7.5781e-01,  ..., -1.0391e+00,
           2.1875e-01, -1.5039e-01]],

        [[ 2.6398e-03, -5.1575e-03,  3.1586e-03,  ...,  8.0078e-02,
          -5.4321e-03, -2.4219e-01],
         [-5.1875e+00, -8.9062e-01,  7.8516e-01,  ..., -6.6406e-01,
          -1.3281e+00, -6.8750e-01],
         [-3.2188e+00, -5.5664e-02,  1.1719e-02,  ...,  5.1172e-01,
           3.5352e-01, -9.7266e-01],
         ...,
         [ 2.1562e+00,  1.8438e+00,  1.5703e+00,  ..., -6.1328e-01,
           5.1953e-01, -7.0312e-01],
         [ 3.2500e+00, -3.3789e-01,  1.6328e+00,  ...,  7.4219e-01,
           8.8672e-01, -7.8125e-01],
         [ 2.7812e+00, -2.0312e+00, -7.5781e-01,  ..., -1.0391e+00,
           2.1875e-01, -1.5039e-01]]], dtype=torch.bfloat16), tensor([[[ 5.3406e-04,  6.9580e-03, -4.0283e-03,  ...,  1.3428e-03,
          -3.2349e-03, -1.6174e-03],
         [ 1.1963e-01, -6.2109e-01,  2.1680e-01,  ..., -2.0874e-02,
           1.0193e-02,  2.7930e-01],
         [-3.1055e-01,  4.1016e-01,  8.4375e-01,  ...,  1.8457e-01,
          -7.3730e-02, -3.7305e-01],
         ...,
         [ 1.2891e-01,  4.1992e-01,  4.2188e-01,  ...,  1.4453e-01,
           3.4961e-01, -2.1118e-02],
         [ 7.2266e-01, -7.5000e-01,  1.8750e-01,  ..., -2.8516e-01,
           5.1172e-01,  2.8198e-02],
         [-1.3184e-01,  1.5332e-01,  1.0693e-01,  ...,  4.2578e-01,
          -6.1646e-03, -3.0859e-01]],

        [[ 5.3406e-04,  6.9580e-03, -4.0283e-03,  ...,  1.3428e-03,
          -3.2349e-03, -1.6174e-03],
         [ 1.1963e-01, -6.2109e-01,  2.1680e-01,  ..., -2.0874e-02,
           1.0193e-02,  2.7930e-01],
         [-3.1055e-01,  4.1016e-01,  8.4375e-01,  ...,  1.8457e-01,
          -7.3730e-02, -3.7305e-01],
         ...,
         [ 1.2891e-01,  4.1992e-01,  4.2188e-01,  ...,  1.4453e-01,
           3.4961e-01, -2.1118e-02],
         [ 7.2266e-01, -7.5000e-01,  1.8750e-01,  ..., -2.8516e-01,
           5.1172e-01,  2.8198e-02],
         [-1.3184e-01,  1.5332e-01,  1.0693e-01,  ...,  4.2578e-01,
          -6.1646e-03, -3.0859e-01]],

        [[ 5.3406e-04,  6.9580e-03, -4.0283e-03,  ...,  1.3428e-03,
          -3.2349e-03, -1.6174e-03],
         [ 1.1963e-01, -6.2109e-01,  2.1680e-01,  ..., -2.0874e-02,
           1.0193e-02,  2.7930e-01],
         [-3.1055e-01,  4.1016e-01,  8.4375e-01,  ...,  1.8457e-01,
          -7.3730e-02, -3.7305e-01],
         ...,
         [ 1.2891e-01,  4.1992e-01,  4.2188e-01,  ...,  1.4453e-01,
           3.4961e-01, -2.1118e-02],
         [ 7.2266e-01, -7.5000e-01,  1.8750e-01,  ..., -2.8516e-01,
           5.1172e-01,  2.8198e-02],
         [-1.3184e-01,  1.5332e-01,  1.0693e-01,  ...,  4.2578e-01,
          -6.1646e-03, -3.0859e-01]],

        ...,

        [[-3.9368e-03, -2.0752e-03,  3.1433e-03,  ...,  3.3569e-04,
           3.0212e-03, -7.6675e-04],
         [-3.5547e-01,  2.2339e-02, -1.4648e-01,  ..., -1.3770e-01,
          -2.1973e-01,  1.2024e-02],
         [ 8.7891e-02, -1.6699e-01,  6.2500e-02,  ...,  4.0039e-01,
          -2.6953e-01,  1.7383e-01],
         ...,
         [ 2.7222e-02,  2.2070e-01, -2.7930e-01,  ...,  1.5527e-01,
           8.2031e-02, -3.4790e-03],
         [ 1.3867e-01,  2.8198e-02,  8.3008e-02,  ..., -6.5308e-03,
           2.6367e-01, -1.5137e-01],
         [-1.4062e-01, -2.9297e-01, -4.2578e-01,  ...,  3.3691e-02,
          -9.7656e-02,  6.5918e-02]],

        [[-3.9368e-03, -2.0752e-03,  3.1433e-03,  ...,  3.3569e-04,
           3.0212e-03, -7.6675e-04],
         [-3.5547e-01,  2.2339e-02, -1.4648e-01,  ..., -1.3770e-01,
          -2.1973e-01,  1.2024e-02],
         [ 8.7891e-02, -1.6699e-01,  6.2500e-02,  ...,  4.0039e-01,
          -2.6953e-01,  1.7383e-01],
         ...,
         [ 2.7222e-02,  2.2070e-01, -2.7930e-01,  ...,  1.5527e-01,
           8.2031e-02, -3.4790e-03],
         [ 1.3867e-01,  2.8198e-02,  8.3008e-02,  ..., -6.5308e-03,
           2.6367e-01, -1.5137e-01],
         [-1.4062e-01, -2.9297e-01, -4.2578e-01,  ...,  3.3691e-02,
          -9.7656e-02,  6.5918e-02]],

        [[-3.9368e-03, -2.0752e-03,  3.1433e-03,  ...,  3.3569e-04,
           3.0212e-03, -7.6675e-04],
         [-3.5547e-01,  2.2339e-02, -1.4648e-01,  ..., -1.3770e-01,
          -2.1973e-01,  1.2024e-02],
         [ 8.7891e-02, -1.6699e-01,  6.2500e-02,  ...,  4.0039e-01,
          -2.6953e-01,  1.7383e-01],
         ...,
         [ 2.7222e-02,  2.2070e-01, -2.7930e-01,  ...,  1.5527e-01,
           8.2031e-02, -3.4790e-03],
         [ 1.3867e-01,  2.8198e-02,  8.3008e-02,  ..., -6.5308e-03,
           2.6367e-01, -1.5137e-01],
         [-1.4062e-01, -2.9297e-01, -4.2578e-01,  ...,  3.3691e-02,
          -9.7656e-02,  6.5918e-02]]], dtype=torch.bfloat16)), (tensor([[[ 0.0146,  0.0317, -0.0143,  ..., -0.2393, -0.6523, -0.1162],
         [ 1.6094, -1.1406, -0.0879,  ...,  0.4395,  2.6562, -0.4863],
         [ 2.5781, -0.8555, -0.3262,  ...,  0.2451,  3.0469, -0.1484],
         ...,
         [-0.6953,  0.5469, -1.6641,  ...,  2.4375,  3.4062,  1.3047],
         [-0.4199,  0.4336,  0.5078,  ...,  1.5391,  0.7188,  0.7188],
         [-0.5078,  2.0938, -0.6367,  ...,  1.0078,  2.1406, -2.1406]],

        [[ 0.0146,  0.0317, -0.0143,  ..., -0.2393, -0.6523, -0.1162],
         [ 1.6094, -1.1406, -0.0879,  ...,  0.4395,  2.6562, -0.4863],
         [ 2.5781, -0.8555, -0.3262,  ...,  0.2451,  3.0469, -0.1484],
         ...,
         [-0.6953,  0.5469, -1.6641,  ...,  2.4375,  3.4062,  1.3047],
         [-0.4199,  0.4336,  0.5078,  ...,  1.5391,  0.7188,  0.7188],
         [-0.5078,  2.0938, -0.6367,  ...,  1.0078,  2.1406, -2.1406]],

        [[ 0.0146,  0.0317, -0.0143,  ..., -0.2393, -0.6523, -0.1162],
         [ 1.6094, -1.1406, -0.0879,  ...,  0.4395,  2.6562, -0.4863],
         [ 2.5781, -0.8555, -0.3262,  ...,  0.2451,  3.0469, -0.1484],
         ...,
         [-0.6953,  0.5469, -1.6641,  ...,  2.4375,  3.4062,  1.3047],
         [-0.4199,  0.4336,  0.5078,  ...,  1.5391,  0.7188,  0.7188],
         [-0.5078,  2.0938, -0.6367,  ...,  1.0078,  2.1406, -2.1406]],

        ...,

        [[-0.0132, -0.0184,  0.0106,  ...,  0.0322, -0.2773,  0.1357],
         [ 2.0625, -1.0625, -1.7812,  ..., -0.1167, -0.0369, -0.5156],
         [ 0.3672, -0.5586, -0.3867,  ...,  0.9570, -0.5742, -0.1230],
         ...,
         [-1.0391,  0.9609, -1.2891,  ...,  1.7422,  1.3750, -1.1016],
         [-0.5234,  0.6836, -0.7266,  ...,  0.2891, -1.3828,  1.3281],
         [-0.7422,  0.3984, -0.5547,  ...,  1.7734,  0.4082, -4.3750]],

        [[-0.0132, -0.0184,  0.0106,  ...,  0.0322, -0.2773,  0.1357],
         [ 2.0625, -1.0625, -1.7812,  ..., -0.1167, -0.0369, -0.5156],
         [ 0.3672, -0.5586, -0.3867,  ...,  0.9570, -0.5742, -0.1230],
         ...,
         [-1.0391,  0.9609, -1.2891,  ...,  1.7422,  1.3750, -1.1016],
         [-0.5234,  0.6836, -0.7266,  ...,  0.2891, -1.3828,  1.3281],
         [-0.7422,  0.3984, -0.5547,  ...,  1.7734,  0.4082, -4.3750]],

        [[-0.0132, -0.0184,  0.0106,  ...,  0.0322, -0.2773,  0.1357],
         [ 2.0625, -1.0625, -1.7812,  ..., -0.1167, -0.0369, -0.5156],
         [ 0.3672, -0.5586, -0.3867,  ...,  0.9570, -0.5742, -0.1230],
         ...,
         [-1.0391,  0.9609, -1.2891,  ...,  1.7422,  1.3750, -1.1016],
         [-0.5234,  0.6836, -0.7266,  ...,  0.2891, -1.3828,  1.3281],
         [-0.7422,  0.3984, -0.5547,  ...,  1.7734,  0.4082, -4.3750]]],
       dtype=torch.bfloat16), tensor([[[-9.8267e-03,  3.4027e-03, -1.1963e-02,  ...,  4.8523e-03,
          -4.0894e-03, -9.4604e-03],
         [-1.2988e-01,  6.7578e-01, -2.5586e-01,  ..., -2.1191e-01,
          -2.3828e-01,  6.7188e-01],
         [ 3.7109e-01, -2.2363e-01, -2.2559e-01,  ...,  5.8594e-01,
           6.9531e-01, -2.0117e-01],
         ...,
         [ 5.6641e-02,  1.3867e-01,  7.0312e-01,  ...,  5.5469e-01,
           2.1191e-01,  1.2793e-01],
         [ 8.5938e-01, -5.3906e-01,  6.1328e-01,  ..., -4.2383e-01,
          -2.2168e-01,  5.3516e-01],
         [-1.4160e-01,  1.6357e-02, -2.3633e-01,  ...,  4.9316e-02,
           2.5391e-01, -2.0410e-01]],

        [[-9.8267e-03,  3.4027e-03, -1.1963e-02,  ...,  4.8523e-03,
          -4.0894e-03, -9.4604e-03],
         [-1.2988e-01,  6.7578e-01, -2.5586e-01,  ..., -2.1191e-01,
          -2.3828e-01,  6.7188e-01],
         [ 3.7109e-01, -2.2363e-01, -2.2559e-01,  ...,  5.8594e-01,
           6.9531e-01, -2.0117e-01],
         ...,
         [ 5.6641e-02,  1.3867e-01,  7.0312e-01,  ...,  5.5469e-01,
           2.1191e-01,  1.2793e-01],
         [ 8.5938e-01, -5.3906e-01,  6.1328e-01,  ..., -4.2383e-01,
          -2.2168e-01,  5.3516e-01],
         [-1.4160e-01,  1.6357e-02, -2.3633e-01,  ...,  4.9316e-02,
           2.5391e-01, -2.0410e-01]],

        [[-9.8267e-03,  3.4027e-03, -1.1963e-02,  ...,  4.8523e-03,
          -4.0894e-03, -9.4604e-03],
         [-1.2988e-01,  6.7578e-01, -2.5586e-01,  ..., -2.1191e-01,
          -2.3828e-01,  6.7188e-01],
         [ 3.7109e-01, -2.2363e-01, -2.2559e-01,  ...,  5.8594e-01,
           6.9531e-01, -2.0117e-01],
         ...,
         [ 5.6641e-02,  1.3867e-01,  7.0312e-01,  ...,  5.5469e-01,
           2.1191e-01,  1.2793e-01],
         [ 8.5938e-01, -5.3906e-01,  6.1328e-01,  ..., -4.2383e-01,
          -2.2168e-01,  5.3516e-01],
         [-1.4160e-01,  1.6357e-02, -2.3633e-01,  ...,  4.9316e-02,
           2.5391e-01, -2.0410e-01]],

        ...,

        [[-4.5967e-04, -4.8523e-03, -3.3447e-02,  ...,  6.7139e-03,
           7.4768e-03,  5.1880e-03],
         [-1.9629e-01, -2.9541e-02,  9.2163e-03,  ..., -6.4844e-01,
          -5.3516e-01,  3.8574e-02],
         [ 4.1602e-01, -6.7578e-01,  1.9531e-01,  ..., -4.4531e-01,
           3.5156e-02,  4.7070e-01],
         ...,
         [ 4.1602e-01,  5.5859e-01,  8.3923e-04,  ..., -2.9297e-01,
           7.3730e-02,  2.3340e-01],
         [ 2.1973e-01,  5.0391e-01,  4.7070e-01,  ..., -2.5391e-01,
          -7.0312e-02,  3.0859e-01],
         [-1.0547e-01, -6.4453e-01, -1.1670e-01,  ..., -1.3281e-01,
           3.1738e-02, -3.0469e-01]],

        [[-4.5967e-04, -4.8523e-03, -3.3447e-02,  ...,  6.7139e-03,
           7.4768e-03,  5.1880e-03],
         [-1.9629e-01, -2.9541e-02,  9.2163e-03,  ..., -6.4844e-01,
          -5.3516e-01,  3.8574e-02],
         [ 4.1602e-01, -6.7578e-01,  1.9531e-01,  ..., -4.4531e-01,
           3.5156e-02,  4.7070e-01],
         ...,
         [ 4.1602e-01,  5.5859e-01,  8.3923e-04,  ..., -2.9297e-01,
           7.3730e-02,  2.3340e-01],
         [ 2.1973e-01,  5.0391e-01,  4.7070e-01,  ..., -2.5391e-01,
          -7.0312e-02,  3.0859e-01],
         [-1.0547e-01, -6.4453e-01, -1.1670e-01,  ..., -1.3281e-01,
           3.1738e-02, -3.0469e-01]],

        [[-4.5967e-04, -4.8523e-03, -3.3447e-02,  ...,  6.7139e-03,
           7.4768e-03,  5.1880e-03],
         [-1.9629e-01, -2.9541e-02,  9.2163e-03,  ..., -6.4844e-01,
          -5.3516e-01,  3.8574e-02],
         [ 4.1602e-01, -6.7578e-01,  1.9531e-01,  ..., -4.4531e-01,
           3.5156e-02,  4.7070e-01],
         ...,
         [ 4.1602e-01,  5.5859e-01,  8.3923e-04,  ..., -2.9297e-01,
           7.3730e-02,  2.3340e-01],
         [ 2.1973e-01,  5.0391e-01,  4.7070e-01,  ..., -2.5391e-01,
          -7.0312e-02,  3.0859e-01],
         [-1.0547e-01, -6.4453e-01, -1.1670e-01,  ..., -1.3281e-01,
           3.1738e-02, -3.0469e-01]]], dtype=torch.bfloat16)), (tensor([[[-5.8289e-03, -5.2185e-03, -3.2043e-03,  ...,  1.9141e-01,
          -9.6191e-02, -6.4062e-01],
         [-3.3750e+00, -1.9297e+00,  2.3828e-01,  ..., -1.8359e-01,
           7.5391e-01,  2.9375e+00],
         [ 5.9375e-01, -2.5312e+00, -1.2266e+00,  ..., -3.8574e-02,
           8.2812e-01,  1.7812e+00],
         ...,
         [ 4.4375e+00,  2.1875e+00,  6.8359e-01,  ..., -1.2500e+00,
          -4.7070e-01,  2.8438e+00],
         [ 2.8906e+00,  1.8906e+00, -2.6367e-01,  ..., -2.2188e+00,
          -2.5977e-01,  3.6562e+00],
         [-2.9688e-01,  1.1016e+00, -2.8906e-01,  ...,  7.5781e-01,
          -2.6953e-01,  2.0938e+00]],

        [[-5.8289e-03, -5.2185e-03, -3.2043e-03,  ...,  1.9141e-01,
          -9.6191e-02, -6.4062e-01],
         [-3.3750e+00, -1.9297e+00,  2.3828e-01,  ..., -1.8359e-01,
           7.5391e-01,  2.9375e+00],
         [ 5.9375e-01, -2.5312e+00, -1.2266e+00,  ..., -3.8574e-02,
           8.2812e-01,  1.7812e+00],
         ...,
         [ 4.4375e+00,  2.1875e+00,  6.8359e-01,  ..., -1.2500e+00,
          -4.7070e-01,  2.8438e+00],
         [ 2.8906e+00,  1.8906e+00, -2.6367e-01,  ..., -2.2188e+00,
          -2.5977e-01,  3.6562e+00],
         [-2.9688e-01,  1.1016e+00, -2.8906e-01,  ...,  7.5781e-01,
          -2.6953e-01,  2.0938e+00]],

        [[-5.8289e-03, -5.2185e-03, -3.2043e-03,  ...,  1.9141e-01,
          -9.6191e-02, -6.4062e-01],
         [-3.3750e+00, -1.9297e+00,  2.3828e-01,  ..., -1.8359e-01,
           7.5391e-01,  2.9375e+00],
         [ 5.9375e-01, -2.5312e+00, -1.2266e+00,  ..., -3.8574e-02,
           8.2812e-01,  1.7812e+00],
         ...,
         [ 4.4375e+00,  2.1875e+00,  6.8359e-01,  ..., -1.2500e+00,
          -4.7070e-01,  2.8438e+00],
         [ 2.8906e+00,  1.8906e+00, -2.6367e-01,  ..., -2.2188e+00,
          -2.5977e-01,  3.6562e+00],
         [-2.9688e-01,  1.1016e+00, -2.8906e-01,  ...,  7.5781e-01,
          -2.6953e-01,  2.0938e+00]],

        ...,

        [[-2.0599e-03, -3.8338e-04,  5.0049e-03,  ...,  1.0437e-02,
           1.6235e-02,  2.6758e-01],
         [-6.3672e-01,  5.5469e-01, -4.7656e-01,  ..., -2.0625e+00,
           9.9609e-01, -1.2969e+00],
         [ 1.2266e+00,  1.9531e+00,  3.3984e-01,  ..., -7.0703e-01,
          -2.5781e+00,  2.0605e-01],
         ...,
         [ 2.7656e+00, -1.6562e+00, -4.3555e-01,  ...,  5.7031e-01,
           2.2266e-01, -4.2969e-01],
         [ 4.8340e-02, -6.9531e-01, -6.7969e-01,  ...,  1.0234e+00,
          -2.1719e+00, -1.3828e+00],
         [-1.5938e+00, -1.5469e+00,  3.7109e-01,  ...,  1.5469e+00,
          -1.0234e+00, -6.2891e-01]],

        [[-2.0599e-03, -3.8338e-04,  5.0049e-03,  ...,  1.0437e-02,
           1.6235e-02,  2.6758e-01],
         [-6.3672e-01,  5.5469e-01, -4.7656e-01,  ..., -2.0625e+00,
           9.9609e-01, -1.2969e+00],
         [ 1.2266e+00,  1.9531e+00,  3.3984e-01,  ..., -7.0703e-01,
          -2.5781e+00,  2.0605e-01],
         ...,
         [ 2.7656e+00, -1.6562e+00, -4.3555e-01,  ...,  5.7031e-01,
           2.2266e-01, -4.2969e-01],
         [ 4.8340e-02, -6.9531e-01, -6.7969e-01,  ...,  1.0234e+00,
          -2.1719e+00, -1.3828e+00],
         [-1.5938e+00, -1.5469e+00,  3.7109e-01,  ...,  1.5469e+00,
          -1.0234e+00, -6.2891e-01]],

        [[-2.0599e-03, -3.8338e-04,  5.0049e-03,  ...,  1.0437e-02,
           1.6235e-02,  2.6758e-01],
         [-6.3672e-01,  5.5469e-01, -4.7656e-01,  ..., -2.0625e+00,
           9.9609e-01, -1.2969e+00],
         [ 1.2266e+00,  1.9531e+00,  3.3984e-01,  ..., -7.0703e-01,
          -2.5781e+00,  2.0605e-01],
         ...,
         [ 2.7656e+00, -1.6562e+00, -4.3555e-01,  ...,  5.7031e-01,
           2.2266e-01, -4.2969e-01],
         [ 4.8340e-02, -6.9531e-01, -6.7969e-01,  ...,  1.0234e+00,
          -2.1719e+00, -1.3828e+00],
         [-1.5938e+00, -1.5469e+00,  3.7109e-01,  ...,  1.5469e+00,
          -1.0234e+00, -6.2891e-01]]], dtype=torch.bfloat16), tensor([[[-6.6833e-03,  9.1553e-03, -9.5215e-03,  ...,  9.5825e-03,
          -1.5747e-02,  8.4229e-03],
         [ 1.8164e-01,  5.7861e-02,  8.6719e-01,  ..., -2.4414e-01,
           3.2812e-01,  5.7422e-01],
         [ 1.6113e-01, -7.6953e-01,  4.9414e-01,  ..., -5.0391e-01,
          -2.1191e-01, -4.2578e-01],
         ...,
         [ 2.1191e-01, -1.1182e-01,  1.4160e-02,  ..., -2.5781e-01,
           6.4062e-01,  5.7422e-01],
         [ 1.1621e-01, -3.2031e-01, -2.4048e-02,  ..., -3.6133e-01,
           7.3730e-02, -2.4219e-01],
         [-2.8320e-01,  2.3047e-01,  2.9883e-01,  ..., -1.0254e-01,
           1.3594e+00, -2.2559e-01]],

        [[-6.6833e-03,  9.1553e-03, -9.5215e-03,  ...,  9.5825e-03,
          -1.5747e-02,  8.4229e-03],
         [ 1.8164e-01,  5.7861e-02,  8.6719e-01,  ..., -2.4414e-01,
           3.2812e-01,  5.7422e-01],
         [ 1.6113e-01, -7.6953e-01,  4.9414e-01,  ..., -5.0391e-01,
          -2.1191e-01, -4.2578e-01],
         ...,
         [ 2.1191e-01, -1.1182e-01,  1.4160e-02,  ..., -2.5781e-01,
           6.4062e-01,  5.7422e-01],
         [ 1.1621e-01, -3.2031e-01, -2.4048e-02,  ..., -3.6133e-01,
           7.3730e-02, -2.4219e-01],
         [-2.8320e-01,  2.3047e-01,  2.9883e-01,  ..., -1.0254e-01,
           1.3594e+00, -2.2559e-01]],

        [[-6.6833e-03,  9.1553e-03, -9.5215e-03,  ...,  9.5825e-03,
          -1.5747e-02,  8.4229e-03],
         [ 1.8164e-01,  5.7861e-02,  8.6719e-01,  ..., -2.4414e-01,
           3.2812e-01,  5.7422e-01],
         [ 1.6113e-01, -7.6953e-01,  4.9414e-01,  ..., -5.0391e-01,
          -2.1191e-01, -4.2578e-01],
         ...,
         [ 2.1191e-01, -1.1182e-01,  1.4160e-02,  ..., -2.5781e-01,
           6.4062e-01,  5.7422e-01],
         [ 1.1621e-01, -3.2031e-01, -2.4048e-02,  ..., -3.6133e-01,
           7.3730e-02, -2.4219e-01],
         [-2.8320e-01,  2.3047e-01,  2.9883e-01,  ..., -1.0254e-01,
           1.3594e+00, -2.2559e-01]],

        ...,

        [[ 7.4463e-03,  4.1504e-03, -7.1716e-03,  ...,  4.1199e-03,
           1.5625e-02, -1.1841e-02],
         [ 2.1289e-01,  2.9102e-01, -2.2363e-01,  ...,  1.0010e-01,
           2.1973e-01,  3.2812e-01],
         [-1.3672e-01,  5.5859e-01,  4.2534e-04,  ..., -3.8672e-01,
          -6.1328e-01, -5.4688e-01],
         ...,
         [ 1.4355e-01,  3.9844e-01, -5.7031e-01,  ..., -7.0703e-01,
          -1.1230e-01,  2.7930e-01],
         [-2.4512e-01, -9.1797e-02, -1.2402e-01,  ..., -5.3516e-01,
          -8.7402e-02,  2.3926e-01],
         [ 3.9844e-01,  9.9219e-01,  7.5195e-02,  ...,  3.6914e-01,
          -2.3535e-01, -8.0078e-02]],

        [[ 7.4463e-03,  4.1504e-03, -7.1716e-03,  ...,  4.1199e-03,
           1.5625e-02, -1.1841e-02],
         [ 2.1289e-01,  2.9102e-01, -2.2363e-01,  ...,  1.0010e-01,
           2.1973e-01,  3.2812e-01],
         [-1.3672e-01,  5.5859e-01,  4.2534e-04,  ..., -3.8672e-01,
          -6.1328e-01, -5.4688e-01],
         ...,
         [ 1.4355e-01,  3.9844e-01, -5.7031e-01,  ..., -7.0703e-01,
          -1.1230e-01,  2.7930e-01],
         [-2.4512e-01, -9.1797e-02, -1.2402e-01,  ..., -5.3516e-01,
          -8.7402e-02,  2.3926e-01],
         [ 3.9844e-01,  9.9219e-01,  7.5195e-02,  ...,  3.6914e-01,
          -2.3535e-01, -8.0078e-02]],

        [[ 7.4463e-03,  4.1504e-03, -7.1716e-03,  ...,  4.1199e-03,
           1.5625e-02, -1.1841e-02],
         [ 2.1289e-01,  2.9102e-01, -2.2363e-01,  ...,  1.0010e-01,
           2.1973e-01,  3.2812e-01],
         [-1.3672e-01,  5.5859e-01,  4.2534e-04,  ..., -3.8672e-01,
          -6.1328e-01, -5.4688e-01],
         ...,
         [ 1.4355e-01,  3.9844e-01, -5.7031e-01,  ..., -7.0703e-01,
          -1.1230e-01,  2.7930e-01],
         [-2.4512e-01, -9.1797e-02, -1.2402e-01,  ..., -5.3516e-01,
          -8.7402e-02,  2.3926e-01],
         [ 3.9844e-01,  9.9219e-01,  7.5195e-02,  ...,  3.6914e-01,
          -2.3535e-01, -8.0078e-02]]], dtype=torch.bfloat16)), (tensor([[[ 0.0128,  0.0156, -0.0175,  ...,  0.0996, -0.0767, -0.1533],
         [-1.1406, -0.2266, -0.6562,  ...,  1.5859,  0.4961,  0.7188],
         [-1.5938,  0.2578,  0.3047,  ...,  0.1641,  1.9609, -1.6484],
         ...,
         [ 0.5195, -0.9219, -0.9688,  ...,  0.1719,  1.0391, -2.8281],
         [ 0.2344, -0.1104,  0.5625,  ..., -2.2656,  2.6875,  1.8594],
         [ 2.0938,  1.2656,  0.6406,  ..., -0.5234, -0.3047, -0.0488]],

        [[ 0.0128,  0.0156, -0.0175,  ...,  0.0996, -0.0767, -0.1533],
         [-1.1406, -0.2266, -0.6562,  ...,  1.5859,  0.4961,  0.7188],
         [-1.5938,  0.2578,  0.3047,  ...,  0.1641,  1.9609, -1.6484],
         ...,
         [ 0.5195, -0.9219, -0.9688,  ...,  0.1719,  1.0391, -2.8281],
         [ 0.2344, -0.1104,  0.5625,  ..., -2.2656,  2.6875,  1.8594],
         [ 2.0938,  1.2656,  0.6406,  ..., -0.5234, -0.3047, -0.0488]],

        [[ 0.0128,  0.0156, -0.0175,  ...,  0.0996, -0.0767, -0.1533],
         [-1.1406, -0.2266, -0.6562,  ...,  1.5859,  0.4961,  0.7188],
         [-1.5938,  0.2578,  0.3047,  ...,  0.1641,  1.9609, -1.6484],
         ...,
         [ 0.5195, -0.9219, -0.9688,  ...,  0.1719,  1.0391, -2.8281],
         [ 0.2344, -0.1104,  0.5625,  ..., -2.2656,  2.6875,  1.8594],
         [ 2.0938,  1.2656,  0.6406,  ..., -0.5234, -0.3047, -0.0488]],

        ...,

        [[-0.0070, -0.0035, -0.0055,  ..., -0.0588, -0.0309,  0.0552],
         [ 1.9453,  1.6719,  0.0781,  ...,  1.7734, -0.5156,  1.0938],
         [ 2.8125,  1.0000,  0.5859,  ...,  1.2578,  0.5547,  1.7344],
         ...,
         [-0.3086, -1.2266, -0.3398,  ...,  0.7422,  0.0845, -0.8867],
         [-0.1670, -1.1250, -0.3242,  ...,  0.3008,  1.4531, -0.9492],
         [-1.3906, -0.5273, -1.0625,  ..., -0.7695, -0.3340, -2.1250]],

        [[-0.0070, -0.0035, -0.0055,  ..., -0.0588, -0.0309,  0.0552],
         [ 1.9453,  1.6719,  0.0781,  ...,  1.7734, -0.5156,  1.0938],
         [ 2.8125,  1.0000,  0.5859,  ...,  1.2578,  0.5547,  1.7344],
         ...,
         [-0.3086, -1.2266, -0.3398,  ...,  0.7422,  0.0845, -0.8867],
         [-0.1670, -1.1250, -0.3242,  ...,  0.3008,  1.4531, -0.9492],
         [-1.3906, -0.5273, -1.0625,  ..., -0.7695, -0.3340, -2.1250]],

        [[-0.0070, -0.0035, -0.0055,  ..., -0.0588, -0.0309,  0.0552],
         [ 1.9453,  1.6719,  0.0781,  ...,  1.7734, -0.5156,  1.0938],
         [ 2.8125,  1.0000,  0.5859,  ...,  1.2578,  0.5547,  1.7344],
         ...,
         [-0.3086, -1.2266, -0.3398,  ...,  0.7422,  0.0845, -0.8867],
         [-0.1670, -1.1250, -0.3242,  ...,  0.3008,  1.4531, -0.9492],
         [-1.3906, -0.5273, -1.0625,  ..., -0.7695, -0.3340, -2.1250]]],
       dtype=torch.bfloat16), tensor([[[-0.0125,  0.0048, -0.0069,  ...,  0.0050, -0.0082, -0.0070],
         [ 0.1729, -0.1914,  0.0452,  ...,  0.5156,  0.0698,  0.3691],
         [-1.0156,  0.7773,  0.3066,  ..., -0.1240,  0.8320,  0.8203],
         ...,
         [-0.4062,  0.7812, -0.6992,  ...,  0.2148, -0.0635, -0.7734],
         [-0.4746, -0.3242,  0.3184,  ...,  0.2715, -0.0131, -0.3809],
         [ 0.0299,  0.4707, -0.2832,  ...,  0.1592, -0.0796, -0.3047]],

        [[-0.0125,  0.0048, -0.0069,  ...,  0.0050, -0.0082, -0.0070],
         [ 0.1729, -0.1914,  0.0452,  ...,  0.5156,  0.0698,  0.3691],
         [-1.0156,  0.7773,  0.3066,  ..., -0.1240,  0.8320,  0.8203],
         ...,
         [-0.4062,  0.7812, -0.6992,  ...,  0.2148, -0.0635, -0.7734],
         [-0.4746, -0.3242,  0.3184,  ...,  0.2715, -0.0131, -0.3809],
         [ 0.0299,  0.4707, -0.2832,  ...,  0.1592, -0.0796, -0.3047]],

        [[-0.0125,  0.0048, -0.0069,  ...,  0.0050, -0.0082, -0.0070],
         [ 0.1729, -0.1914,  0.0452,  ...,  0.5156,  0.0698,  0.3691],
         [-1.0156,  0.7773,  0.3066,  ..., -0.1240,  0.8320,  0.8203],
         ...,
         [-0.4062,  0.7812, -0.6992,  ...,  0.2148, -0.0635, -0.7734],
         [-0.4746, -0.3242,  0.3184,  ...,  0.2715, -0.0131, -0.3809],
         [ 0.0299,  0.4707, -0.2832,  ...,  0.1592, -0.0796, -0.3047]],

        ...,

        [[-0.0022, -0.0028, -0.0028,  ..., -0.0060,  0.0087, -0.0074],
         [ 0.0129,  0.1001,  0.2002,  ..., -0.2734,  0.4023,  0.2168],
         [ 0.3848,  0.5156,  0.6914,  ..., -0.5273,  0.2832,  0.9570],
         ...,
         [ 0.2891,  0.0508, -0.5742,  ...,  0.3242, -0.5742, -0.2012],
         [-0.1445,  0.5586, -0.2139,  ..., -0.3965, -0.1055,  0.3711],
         [ 0.3828,  0.5156,  0.1934,  ...,  0.0635,  0.2578,  0.0045]],

        [[-0.0022, -0.0028, -0.0028,  ..., -0.0060,  0.0087, -0.0074],
         [ 0.0129,  0.1001,  0.2002,  ..., -0.2734,  0.4023,  0.2168],
         [ 0.3848,  0.5156,  0.6914,  ..., -0.5273,  0.2832,  0.9570],
         ...,
         [ 0.2891,  0.0508, -0.5742,  ...,  0.3242, -0.5742, -0.2012],
         [-0.1445,  0.5586, -0.2139,  ..., -0.3965, -0.1055,  0.3711],
         [ 0.3828,  0.5156,  0.1934,  ...,  0.0635,  0.2578,  0.0045]],

        [[-0.0022, -0.0028, -0.0028,  ..., -0.0060,  0.0087, -0.0074],
         [ 0.0129,  0.1001,  0.2002,  ..., -0.2734,  0.4023,  0.2168],
         [ 0.3848,  0.5156,  0.6914,  ..., -0.5273,  0.2832,  0.9570],
         ...,
         [ 0.2891,  0.0508, -0.5742,  ...,  0.3242, -0.5742, -0.2012],
         [-0.1445,  0.5586, -0.2139,  ..., -0.3965, -0.1055,  0.3711],
         [ 0.3828,  0.5156,  0.1934,  ...,  0.0635,  0.2578,  0.0045]]],
       dtype=torch.bfloat16)), (tensor([[[-1.5015e-02, -4.2725e-03, -8.5831e-04,  ..., -5.9375e-01,
          -1.5723e-01,  2.8076e-02],
         [-1.6641e+00,  2.8125e-01, -1.2188e+00,  ...,  2.2031e+00,
           6.6797e-01,  4.2188e-01],
         [-1.6406e-01,  1.0391e+00, -1.2695e-01,  ..., -1.5442e-02,
           2.4531e+00,  1.2109e+00],
         ...,
         [ 1.2656e+00, -9.4531e-01, -1.6562e+00,  ...,  1.2891e+00,
           2.1562e+00, -1.0234e+00],
         [ 5.5664e-02, -4.8828e-01, -8.7500e-01,  ...,  9.2578e-01,
           1.2500e+00, -6.9336e-02],
         [ 5.7812e-01, -1.1406e+00, -1.0000e+00,  ...,  2.0156e+00,
           3.9258e-01,  6.5625e-01]],

        [[-1.5015e-02, -4.2725e-03, -8.5831e-04,  ..., -5.9375e-01,
          -1.5723e-01,  2.8076e-02],
         [-1.6641e+00,  2.8125e-01, -1.2188e+00,  ...,  2.2031e+00,
           6.6797e-01,  4.2188e-01],
         [-1.6406e-01,  1.0391e+00, -1.2695e-01,  ..., -1.5442e-02,
           2.4531e+00,  1.2109e+00],
         ...,
         [ 1.2656e+00, -9.4531e-01, -1.6562e+00,  ...,  1.2891e+00,
           2.1562e+00, -1.0234e+00],
         [ 5.5664e-02, -4.8828e-01, -8.7500e-01,  ...,  9.2578e-01,
           1.2500e+00, -6.9336e-02],
         [ 5.7812e-01, -1.1406e+00, -1.0000e+00,  ...,  2.0156e+00,
           3.9258e-01,  6.5625e-01]],

        [[-1.5015e-02, -4.2725e-03, -8.5831e-04,  ..., -5.9375e-01,
          -1.5723e-01,  2.8076e-02],
         [-1.6641e+00,  2.8125e-01, -1.2188e+00,  ...,  2.2031e+00,
           6.6797e-01,  4.2188e-01],
         [-1.6406e-01,  1.0391e+00, -1.2695e-01,  ..., -1.5442e-02,
           2.4531e+00,  1.2109e+00],
         ...,
         [ 1.2656e+00, -9.4531e-01, -1.6562e+00,  ...,  1.2891e+00,
           2.1562e+00, -1.0234e+00],
         [ 5.5664e-02, -4.8828e-01, -8.7500e-01,  ...,  9.2578e-01,
           1.2500e+00, -6.9336e-02],
         [ 5.7812e-01, -1.1406e+00, -1.0000e+00,  ...,  2.0156e+00,
           3.9258e-01,  6.5625e-01]],

        ...,

        [[ 3.1433e-03,  8.4839e-03,  4.5967e-04,  ...,  1.1914e-01,
           5.0781e-02, -9.8267e-03],
         [ 1.4688e+00, -8.5547e-01, -1.2734e+00,  ..., -4.9805e-01,
          -4.1504e-02,  1.1016e+00],
         [-1.2422e+00, -9.2969e-01, -9.9219e-01,  ...,  5.4297e-01,
          -1.6328e+00,  5.5078e-01],
         ...,
         [-2.5156e+00,  3.3398e-01, -1.7422e+00,  ...,  3.5156e-01,
           3.1641e-01,  1.6797e+00],
         [ 2.0508e-01,  9.4531e-01, -4.4531e-01,  ...,  1.0859e+00,
          -6.9531e-01,  2.2031e+00],
         [-1.0547e-01,  7.7344e-01, -8.3203e-01,  ..., -4.9023e-01,
          -5.7373e-02,  1.0234e+00]],

        [[ 3.1433e-03,  8.4839e-03,  4.5967e-04,  ...,  1.1914e-01,
           5.0781e-02, -9.8267e-03],
         [ 1.4688e+00, -8.5547e-01, -1.2734e+00,  ..., -4.9805e-01,
          -4.1504e-02,  1.1016e+00],
         [-1.2422e+00, -9.2969e-01, -9.9219e-01,  ...,  5.4297e-01,
          -1.6328e+00,  5.5078e-01],
         ...,
         [-2.5156e+00,  3.3398e-01, -1.7422e+00,  ...,  3.5156e-01,
           3.1641e-01,  1.6797e+00],
         [ 2.0508e-01,  9.4531e-01, -4.4531e-01,  ...,  1.0859e+00,
          -6.9531e-01,  2.2031e+00],
         [-1.0547e-01,  7.7344e-01, -8.3203e-01,  ..., -4.9023e-01,
          -5.7373e-02,  1.0234e+00]],

        [[ 3.1433e-03,  8.4839e-03,  4.5967e-04,  ...,  1.1914e-01,
           5.0781e-02, -9.8267e-03],
         [ 1.4688e+00, -8.5547e-01, -1.2734e+00,  ..., -4.9805e-01,
          -4.1504e-02,  1.1016e+00],
         [-1.2422e+00, -9.2969e-01, -9.9219e-01,  ...,  5.4297e-01,
          -1.6328e+00,  5.5078e-01],
         ...,
         [-2.5156e+00,  3.3398e-01, -1.7422e+00,  ...,  3.5156e-01,
           3.1641e-01,  1.6797e+00],
         [ 2.0508e-01,  9.4531e-01, -4.4531e-01,  ...,  1.0859e+00,
          -6.9531e-01,  2.2031e+00],
         [-1.0547e-01,  7.7344e-01, -8.3203e-01,  ..., -4.9023e-01,
          -5.7373e-02,  1.0234e+00]]], dtype=torch.bfloat16), tensor([[[-2.3651e-03, -5.9204e-03,  4.3030e-03,  ...,  5.5313e-04,
           9.9182e-04, -5.1880e-03],
         [ 4.3750e-01,  5.5859e-01,  6.6016e-01,  ...,  1.4160e-01,
          -8.3594e-01, -5.1514e-02],
         [ 5.1953e-01, -4.2725e-02,  1.4062e-01,  ..., -4.8828e-01,
           2.5586e-01,  2.6367e-01],
         ...,
         [-3.1641e-01,  2.7734e-01, -3.7354e-02,  ...,  1.3281e-01,
          -4.4336e-01,  1.0156e+00],
         [ 1.6895e-01,  3.0469e-01, -4.8633e-01,  ..., -3.6914e-01,
           8.3008e-02,  7.0312e-01],
         [-8.0566e-02,  7.8516e-01, -3.4668e-02,  ..., -4.6143e-02,
          -8.0469e-01,  1.6504e-01]],

        [[-2.3651e-03, -5.9204e-03,  4.3030e-03,  ...,  5.5313e-04,
           9.9182e-04, -5.1880e-03],
         [ 4.3750e-01,  5.5859e-01,  6.6016e-01,  ...,  1.4160e-01,
          -8.3594e-01, -5.1514e-02],
         [ 5.1953e-01, -4.2725e-02,  1.4062e-01,  ..., -4.8828e-01,
           2.5586e-01,  2.6367e-01],
         ...,
         [-3.1641e-01,  2.7734e-01, -3.7354e-02,  ...,  1.3281e-01,
          -4.4336e-01,  1.0156e+00],
         [ 1.6895e-01,  3.0469e-01, -4.8633e-01,  ..., -3.6914e-01,
           8.3008e-02,  7.0312e-01],
         [-8.0566e-02,  7.8516e-01, -3.4668e-02,  ..., -4.6143e-02,
          -8.0469e-01,  1.6504e-01]],

        [[-2.3651e-03, -5.9204e-03,  4.3030e-03,  ...,  5.5313e-04,
           9.9182e-04, -5.1880e-03],
         [ 4.3750e-01,  5.5859e-01,  6.6016e-01,  ...,  1.4160e-01,
          -8.3594e-01, -5.1514e-02],
         [ 5.1953e-01, -4.2725e-02,  1.4062e-01,  ..., -4.8828e-01,
           2.5586e-01,  2.6367e-01],
         ...,
         [-3.1641e-01,  2.7734e-01, -3.7354e-02,  ...,  1.3281e-01,
          -4.4336e-01,  1.0156e+00],
         [ 1.6895e-01,  3.0469e-01, -4.8633e-01,  ..., -3.6914e-01,
           8.3008e-02,  7.0312e-01],
         [-8.0566e-02,  7.8516e-01, -3.4668e-02,  ..., -4.6143e-02,
          -8.0469e-01,  1.6504e-01]],

        ...,

        [[-4.8065e-04,  9.7275e-04,  5.5176e-02,  ..., -3.0518e-02,
          -1.1108e-02,  2.5635e-02],
         [-2.1680e-01, -1.3750e+00, -3.3984e-01,  ...,  4.8438e-01,
          -1.5039e-01,  7.4219e-01],
         [-1.3965e-01, -8.9453e-01, -3.2617e-01,  ...,  3.6914e-01,
          -7.4609e-01,  1.6699e-01],
         ...,
         [-8.6328e-01, -1.7383e-01,  1.3965e-01,  ...,  1.3125e+00,
          -2.1289e-01,  1.3672e+00],
         [-9.9219e-01,  2.3633e-01,  2.4609e-01,  ...,  6.0156e-01,
          -4.3750e-01,  4.5117e-01],
         [ 1.0234e+00, -7.7344e-01, -5.0781e-01,  ..., -1.6504e-01,
           1.6309e-01, -5.8203e-01]],

        [[-4.8065e-04,  9.7275e-04,  5.5176e-02,  ..., -3.0518e-02,
          -1.1108e-02,  2.5635e-02],
         [-2.1680e-01, -1.3750e+00, -3.3984e-01,  ...,  4.8438e-01,
          -1.5039e-01,  7.4219e-01],
         [-1.3965e-01, -8.9453e-01, -3.2617e-01,  ...,  3.6914e-01,
          -7.4609e-01,  1.6699e-01],
         ...,
         [-8.6328e-01, -1.7383e-01,  1.3965e-01,  ...,  1.3125e+00,
          -2.1289e-01,  1.3672e+00],
         [-9.9219e-01,  2.3633e-01,  2.4609e-01,  ...,  6.0156e-01,
          -4.3750e-01,  4.5117e-01],
         [ 1.0234e+00, -7.7344e-01, -5.0781e-01,  ..., -1.6504e-01,
           1.6309e-01, -5.8203e-01]],

        [[-4.8065e-04,  9.7275e-04,  5.5176e-02,  ..., -3.0518e-02,
          -1.1108e-02,  2.5635e-02],
         [-2.1680e-01, -1.3750e+00, -3.3984e-01,  ...,  4.8438e-01,
          -1.5039e-01,  7.4219e-01],
         [-1.3965e-01, -8.9453e-01, -3.2617e-01,  ...,  3.6914e-01,
          -7.4609e-01,  1.6699e-01],
         ...,
         [-8.6328e-01, -1.7383e-01,  1.3965e-01,  ...,  1.3125e+00,
          -2.1289e-01,  1.3672e+00],
         [-9.9219e-01,  2.3633e-01,  2.4609e-01,  ...,  6.0156e-01,
          -4.3750e-01,  4.5117e-01],
         [ 1.0234e+00, -7.7344e-01, -5.0781e-01,  ..., -1.6504e-01,
           1.6309e-01, -5.8203e-01]]], dtype=torch.bfloat16)), (tensor([[[ 1.1780e-02, -3.5858e-03, -3.2196e-03,  ..., -3.6328e-01,
          -1.3733e-02, -4.6875e-02],
         [-1.5234e-01,  1.4062e-01, -9.8438e-01,  ..., -8.7500e-01,
           1.2969e+00,  1.2969e+00],
         [ 8.9062e-01, -1.4062e+00, -6.4062e-01,  ...,  1.8828e+00,
           2.2656e+00,  1.3438e+00],
         ...,
         [-3.5400e-02,  8.9844e-01, -4.5312e-01,  ..., -8.9844e-02,
          -9.1016e-01,  2.4780e-02],
         [-6.1328e-01,  4.3359e-01, -4.3750e-01,  ...,  2.2344e+00,
          -2.4375e+00, -1.1562e+00],
         [-1.3047e+00,  1.6484e+00, -1.5312e+00,  ...,  1.7266e+00,
           2.4375e+00,  2.0625e+00]],

        [[ 1.1780e-02, -3.5858e-03, -3.2196e-03,  ..., -3.6328e-01,
          -1.3733e-02, -4.6875e-02],
         [-1.5234e-01,  1.4062e-01, -9.8438e-01,  ..., -8.7500e-01,
           1.2969e+00,  1.2969e+00],
         [ 8.9062e-01, -1.4062e+00, -6.4062e-01,  ...,  1.8828e+00,
           2.2656e+00,  1.3438e+00],
         ...,
         [-3.5400e-02,  8.9844e-01, -4.5312e-01,  ..., -8.9844e-02,
          -9.1016e-01,  2.4780e-02],
         [-6.1328e-01,  4.3359e-01, -4.3750e-01,  ...,  2.2344e+00,
          -2.4375e+00, -1.1562e+00],
         [-1.3047e+00,  1.6484e+00, -1.5312e+00,  ...,  1.7266e+00,
           2.4375e+00,  2.0625e+00]],

        [[ 1.1780e-02, -3.5858e-03, -3.2196e-03,  ..., -3.6328e-01,
          -1.3733e-02, -4.6875e-02],
         [-1.5234e-01,  1.4062e-01, -9.8438e-01,  ..., -8.7500e-01,
           1.2969e+00,  1.2969e+00],
         [ 8.9062e-01, -1.4062e+00, -6.4062e-01,  ...,  1.8828e+00,
           2.2656e+00,  1.3438e+00],
         ...,
         [-3.5400e-02,  8.9844e-01, -4.5312e-01,  ..., -8.9844e-02,
          -9.1016e-01,  2.4780e-02],
         [-6.1328e-01,  4.3359e-01, -4.3750e-01,  ...,  2.2344e+00,
          -2.4375e+00, -1.1562e+00],
         [-1.3047e+00,  1.6484e+00, -1.5312e+00,  ...,  1.7266e+00,
           2.4375e+00,  2.0625e+00]],

        ...,

        [[-6.5613e-04,  1.4099e-02, -2.3956e-03,  ..., -2.5977e-01,
          -1.6406e-01, -4.7607e-03],
         [-1.0156e+00,  3.3203e-01,  2.2656e+00,  ...,  1.3516e+00,
          -9.7266e-01, -3.9844e-01],
         [-2.8750e+00,  5.5469e-01,  2.0625e+00,  ...,  1.0391e+00,
          -9.7168e-02, -1.7109e+00],
         ...,
         [-1.2695e-01,  1.6406e+00,  1.6016e+00,  ..., -1.0703e+00,
           7.0312e-01, -4.1797e-01],
         [ 2.7344e-01,  2.4414e-02,  3.6719e-01,  ..., -1.3906e+00,
          -7.1094e-01,  5.7812e-01],
         [ 1.9219e+00, -8.5938e-01,  1.3750e+00,  ..., -3.5352e-01,
           4.5898e-01, -1.3477e-01]],

        [[-6.5613e-04,  1.4099e-02, -2.3956e-03,  ..., -2.5977e-01,
          -1.6406e-01, -4.7607e-03],
         [-1.0156e+00,  3.3203e-01,  2.2656e+00,  ...,  1.3516e+00,
          -9.7266e-01, -3.9844e-01],
         [-2.8750e+00,  5.5469e-01,  2.0625e+00,  ...,  1.0391e+00,
          -9.7168e-02, -1.7109e+00],
         ...,
         [-1.2695e-01,  1.6406e+00,  1.6016e+00,  ..., -1.0703e+00,
           7.0312e-01, -4.1797e-01],
         [ 2.7344e-01,  2.4414e-02,  3.6719e-01,  ..., -1.3906e+00,
          -7.1094e-01,  5.7812e-01],
         [ 1.9219e+00, -8.5938e-01,  1.3750e+00,  ..., -3.5352e-01,
           4.5898e-01, -1.3477e-01]],

        [[-6.5613e-04,  1.4099e-02, -2.3956e-03,  ..., -2.5977e-01,
          -1.6406e-01, -4.7607e-03],
         [-1.0156e+00,  3.3203e-01,  2.2656e+00,  ...,  1.3516e+00,
          -9.7266e-01, -3.9844e-01],
         [-2.8750e+00,  5.5469e-01,  2.0625e+00,  ...,  1.0391e+00,
          -9.7168e-02, -1.7109e+00],
         ...,
         [-1.2695e-01,  1.6406e+00,  1.6016e+00,  ..., -1.0703e+00,
           7.0312e-01, -4.1797e-01],
         [ 2.7344e-01,  2.4414e-02,  3.6719e-01,  ..., -1.3906e+00,
          -7.1094e-01,  5.7812e-01],
         [ 1.9219e+00, -8.5938e-01,  1.3750e+00,  ..., -3.5352e-01,
           4.5898e-01, -1.3477e-01]]], dtype=torch.bfloat16), tensor([[[ 3.7842e-02, -7.5073e-03, -7.4768e-04,  ...,  8.5449e-02,
           3.0762e-02,  1.2695e-02],
         [-6.9141e-01,  4.3359e-01, -2.8320e-02,  ...,  3.4570e-01,
          -2.2949e-01, -1.6699e-01],
         [-4.7070e-01,  7.6172e-02, -4.6631e-02,  ...,  1.8750e-01,
          -2.4219e-01, -5.5859e-01],
         ...,
         [-1.9775e-02,  4.9609e-01,  4.1992e-01,  ..., -4.6289e-01,
          -6.9531e-01, -3.0469e-01],
         [-6.0156e-01,  3.9844e-01,  6.2891e-01,  ..., -2.0605e-01,
          -7.5000e-01,  3.0664e-01],
         [ 3.6133e-01,  5.5469e-01, -7.0801e-02,  ...,  3.1641e-01,
          -4.0820e-01, -9.3359e-01]],

        [[ 3.7842e-02, -7.5073e-03, -7.4768e-04,  ...,  8.5449e-02,
           3.0762e-02,  1.2695e-02],
         [-6.9141e-01,  4.3359e-01, -2.8320e-02,  ...,  3.4570e-01,
          -2.2949e-01, -1.6699e-01],
         [-4.7070e-01,  7.6172e-02, -4.6631e-02,  ...,  1.8750e-01,
          -2.4219e-01, -5.5859e-01],
         ...,
         [-1.9775e-02,  4.9609e-01,  4.1992e-01,  ..., -4.6289e-01,
          -6.9531e-01, -3.0469e-01],
         [-6.0156e-01,  3.9844e-01,  6.2891e-01,  ..., -2.0605e-01,
          -7.5000e-01,  3.0664e-01],
         [ 3.6133e-01,  5.5469e-01, -7.0801e-02,  ...,  3.1641e-01,
          -4.0820e-01, -9.3359e-01]],

        [[ 3.7842e-02, -7.5073e-03, -7.4768e-04,  ...,  8.5449e-02,
           3.0762e-02,  1.2695e-02],
         [-6.9141e-01,  4.3359e-01, -2.8320e-02,  ...,  3.4570e-01,
          -2.2949e-01, -1.6699e-01],
         [-4.7070e-01,  7.6172e-02, -4.6631e-02,  ...,  1.8750e-01,
          -2.4219e-01, -5.5859e-01],
         ...,
         [-1.9775e-02,  4.9609e-01,  4.1992e-01,  ..., -4.6289e-01,
          -6.9531e-01, -3.0469e-01],
         [-6.0156e-01,  3.9844e-01,  6.2891e-01,  ..., -2.0605e-01,
          -7.5000e-01,  3.0664e-01],
         [ 3.6133e-01,  5.5469e-01, -7.0801e-02,  ...,  3.1641e-01,
          -4.0820e-01, -9.3359e-01]],

        ...,

        [[ 1.3580e-03, -1.8082e-03, -3.6812e-04,  ...,  1.6861e-03,
          -4.3945e-03,  9.2163e-03],
         [ 6.5625e-01, -1.4954e-03, -7.1484e-01,  ...,  2.5000e-01,
           8.4375e-01,  1.0781e+00],
         [-1.8555e-01, -9.1309e-02, -5.0000e-01,  ..., -4.0430e-01,
           1.7188e-01,  5.8203e-01],
         ...,
         [ 2.4512e-01,  2.2168e-01,  6.9922e-01,  ...,  2.1973e-02,
           2.6172e-01, -1.5039e-01],
         [-4.6631e-02, -2.8687e-02,  1.0078e+00,  ...,  5.0781e-01,
          -8.2422e-01, -2.4805e-01],
         [-1.1172e+00, -4.8242e-01,  5.8203e-01,  ..., -6.2988e-02,
           5.7031e-01, -1.7676e-01]],

        [[ 1.3580e-03, -1.8082e-03, -3.6812e-04,  ...,  1.6861e-03,
          -4.3945e-03,  9.2163e-03],
         [ 6.5625e-01, -1.4954e-03, -7.1484e-01,  ...,  2.5000e-01,
           8.4375e-01,  1.0781e+00],
         [-1.8555e-01, -9.1309e-02, -5.0000e-01,  ..., -4.0430e-01,
           1.7188e-01,  5.8203e-01],
         ...,
         [ 2.4512e-01,  2.2168e-01,  6.9922e-01,  ...,  2.1973e-02,
           2.6172e-01, -1.5039e-01],
         [-4.6631e-02, -2.8687e-02,  1.0078e+00,  ...,  5.0781e-01,
          -8.2422e-01, -2.4805e-01],
         [-1.1172e+00, -4.8242e-01,  5.8203e-01,  ..., -6.2988e-02,
           5.7031e-01, -1.7676e-01]],

        [[ 1.3580e-03, -1.8082e-03, -3.6812e-04,  ...,  1.6861e-03,
          -4.3945e-03,  9.2163e-03],
         [ 6.5625e-01, -1.4954e-03, -7.1484e-01,  ...,  2.5000e-01,
           8.4375e-01,  1.0781e+00],
         [-1.8555e-01, -9.1309e-02, -5.0000e-01,  ..., -4.0430e-01,
           1.7188e-01,  5.8203e-01],
         ...,
         [ 2.4512e-01,  2.2168e-01,  6.9922e-01,  ...,  2.1973e-02,
           2.6172e-01, -1.5039e-01],
         [-4.6631e-02, -2.8687e-02,  1.0078e+00,  ...,  5.0781e-01,
          -8.2422e-01, -2.4805e-01],
         [-1.1172e+00, -4.8242e-01,  5.8203e-01,  ..., -6.2988e-02,
           5.7031e-01, -1.7676e-01]]], dtype=torch.bfloat16)), (tensor([[[-1.1719e-02,  1.8799e-02,  3.8605e-03,  ..., -1.9629e-01,
          -1.7090e-01, -1.7578e-01],
         [ 3.2344e+00, -3.0781e+00,  3.3203e-01,  ...,  5.3516e-01,
          -2.3125e+00,  9.1016e-01],
         [ 3.6914e-01,  1.4160e-01,  1.7031e+00,  ...,  9.1016e-01,
           3.7500e-01, -1.6602e-01],
         ...,
         [-2.1406e+00,  1.2793e-01,  1.1641e+00,  ..., -9.4531e-01,
          -1.1328e+00,  2.2656e+00],
         [ 2.6953e-01,  3.5352e-01, -1.2188e+00,  ...,  1.6504e-01,
          -9.4531e-01,  1.9297e+00],
         [ 0.0000e+00,  1.1484e+00, -3.9062e-03,  ...,  8.5156e-01,
           2.5156e+00, -1.6406e+00]],

        [[-1.1719e-02,  1.8799e-02,  3.8605e-03,  ..., -1.9629e-01,
          -1.7090e-01, -1.7578e-01],
         [ 3.2344e+00, -3.0781e+00,  3.3203e-01,  ...,  5.3516e-01,
          -2.3125e+00,  9.1016e-01],
         [ 3.6914e-01,  1.4160e-01,  1.7031e+00,  ...,  9.1016e-01,
           3.7500e-01, -1.6602e-01],
         ...,
         [-2.1406e+00,  1.2793e-01,  1.1641e+00,  ..., -9.4531e-01,
          -1.1328e+00,  2.2656e+00],
         [ 2.6953e-01,  3.5352e-01, -1.2188e+00,  ...,  1.6504e-01,
          -9.4531e-01,  1.9297e+00],
         [ 0.0000e+00,  1.1484e+00, -3.9062e-03,  ...,  8.5156e-01,
           2.5156e+00, -1.6406e+00]],

        [[-1.1719e-02,  1.8799e-02,  3.8605e-03,  ..., -1.9629e-01,
          -1.7090e-01, -1.7578e-01],
         [ 3.2344e+00, -3.0781e+00,  3.3203e-01,  ...,  5.3516e-01,
          -2.3125e+00,  9.1016e-01],
         [ 3.6914e-01,  1.4160e-01,  1.7031e+00,  ...,  9.1016e-01,
           3.7500e-01, -1.6602e-01],
         ...,
         [-2.1406e+00,  1.2793e-01,  1.1641e+00,  ..., -9.4531e-01,
          -1.1328e+00,  2.2656e+00],
         [ 2.6953e-01,  3.5352e-01, -1.2188e+00,  ...,  1.6504e-01,
          -9.4531e-01,  1.9297e+00],
         [ 0.0000e+00,  1.1484e+00, -3.9062e-03,  ...,  8.5156e-01,
           2.5156e+00, -1.6406e+00]],

        ...,

        [[ 6.1798e-04,  1.2024e-02, -1.5076e-02,  ...,  1.9141e-01,
           4.7656e-01,  4.5703e-01],
         [-3.6406e+00, -1.2109e-01,  1.8203e+00,  ...,  1.6797e+00,
          -2.1719e+00,  1.7109e+00],
         [-1.5938e+00,  1.7383e-01,  1.0625e+00,  ...,  1.1250e+00,
          -4.1875e+00,  7.7637e-02],
         ...,
         [ 1.2188e+00,  1.4297e+00, -6.4844e-01,  ..., -4.9609e-01,
          -3.7031e+00,  1.6016e-01],
         [ 7.7344e-01,  7.6172e-01, -4.4531e-01,  ..., -7.7148e-02,
          -3.6094e+00, -1.8906e+00],
         [ 1.7344e+00, -2.1094e+00,  1.2578e+00,  ...,  2.2031e+00,
          -1.5312e+00,  9.9609e-01]],

        [[ 6.1798e-04,  1.2024e-02, -1.5076e-02,  ...,  1.9141e-01,
           4.7656e-01,  4.5703e-01],
         [-3.6406e+00, -1.2109e-01,  1.8203e+00,  ...,  1.6797e+00,
          -2.1719e+00,  1.7109e+00],
         [-1.5938e+00,  1.7383e-01,  1.0625e+00,  ...,  1.1250e+00,
          -4.1875e+00,  7.7637e-02],
         ...,
         [ 1.2188e+00,  1.4297e+00, -6.4844e-01,  ..., -4.9609e-01,
          -3.7031e+00,  1.6016e-01],
         [ 7.7344e-01,  7.6172e-01, -4.4531e-01,  ..., -7.7148e-02,
          -3.6094e+00, -1.8906e+00],
         [ 1.7344e+00, -2.1094e+00,  1.2578e+00,  ...,  2.2031e+00,
          -1.5312e+00,  9.9609e-01]],

        [[ 6.1798e-04,  1.2024e-02, -1.5076e-02,  ...,  1.9141e-01,
           4.7656e-01,  4.5703e-01],
         [-3.6406e+00, -1.2109e-01,  1.8203e+00,  ...,  1.6797e+00,
          -2.1719e+00,  1.7109e+00],
         [-1.5938e+00,  1.7383e-01,  1.0625e+00,  ...,  1.1250e+00,
          -4.1875e+00,  7.7637e-02],
         ...,
         [ 1.2188e+00,  1.4297e+00, -6.4844e-01,  ..., -4.9609e-01,
          -3.7031e+00,  1.6016e-01],
         [ 7.7344e-01,  7.6172e-01, -4.4531e-01,  ..., -7.7148e-02,
          -3.6094e+00, -1.8906e+00],
         [ 1.7344e+00, -2.1094e+00,  1.2578e+00,  ...,  2.2031e+00,
          -1.5312e+00,  9.9609e-01]]], dtype=torch.bfloat16), tensor([[[-3.8300e-03,  1.1749e-03, -5.0964e-03,  ...,  1.0193e-02,
           1.3489e-02, -1.4221e-02],
         [ 2.2559e-01, -1.7212e-02, -2.6562e-01,  ...,  2.5586e-01,
           1.8457e-01,  7.3828e-01],
         [-4.7656e-01,  3.0664e-01,  2.2559e-01,  ..., -1.1084e-01,
          -3.3594e-01, -2.4023e-01],
         ...,
         [ 8.0859e-01,  3.3203e-01, -1.1953e+00,  ..., -3.5547e-01,
           8.0078e-02, -1.6992e-01],
         [ 2.3242e-01,  1.1377e-01, -5.1953e-01,  ..., -7.8613e-02,
          -1.3086e-01, -4.4922e-01],
         [ 9.6680e-02, -5.1562e-01,  6.2109e-01,  ..., -1.8359e-01,
           3.6133e-01,  4.0820e-01]],

        [[-3.8300e-03,  1.1749e-03, -5.0964e-03,  ...,  1.0193e-02,
           1.3489e-02, -1.4221e-02],
         [ 2.2559e-01, -1.7212e-02, -2.6562e-01,  ...,  2.5586e-01,
           1.8457e-01,  7.3828e-01],
         [-4.7656e-01,  3.0664e-01,  2.2559e-01,  ..., -1.1084e-01,
          -3.3594e-01, -2.4023e-01],
         ...,
         [ 8.0859e-01,  3.3203e-01, -1.1953e+00,  ..., -3.5547e-01,
           8.0078e-02, -1.6992e-01],
         [ 2.3242e-01,  1.1377e-01, -5.1953e-01,  ..., -7.8613e-02,
          -1.3086e-01, -4.4922e-01],
         [ 9.6680e-02, -5.1562e-01,  6.2109e-01,  ..., -1.8359e-01,
           3.6133e-01,  4.0820e-01]],

        [[-3.8300e-03,  1.1749e-03, -5.0964e-03,  ...,  1.0193e-02,
           1.3489e-02, -1.4221e-02],
         [ 2.2559e-01, -1.7212e-02, -2.6562e-01,  ...,  2.5586e-01,
           1.8457e-01,  7.3828e-01],
         [-4.7656e-01,  3.0664e-01,  2.2559e-01,  ..., -1.1084e-01,
          -3.3594e-01, -2.4023e-01],
         ...,
         [ 8.0859e-01,  3.3203e-01, -1.1953e+00,  ..., -3.5547e-01,
           8.0078e-02, -1.6992e-01],
         [ 2.3242e-01,  1.1377e-01, -5.1953e-01,  ..., -7.8613e-02,
          -1.3086e-01, -4.4922e-01],
         [ 9.6680e-02, -5.1562e-01,  6.2109e-01,  ..., -1.8359e-01,
           3.6133e-01,  4.0820e-01]],

        ...,

        [[-8.8501e-03, -3.8452e-03, -2.5940e-03,  ...,  2.5635e-03,
           3.6469e-03, -3.3875e-03],
         [ 1.6504e-01, -4.8242e-01, -6.4453e-02,  ..., -2.0508e-01,
          -2.4902e-01, -2.7344e-01],
         [-7.1484e-01,  8.3203e-01,  2.2656e-01,  ..., -7.3242e-02,
          -9.1797e-01,  1.1016e+00],
         ...,
         [-3.3594e-01,  2.6367e-01, -4.7852e-01,  ..., -1.7578e-01,
          -3.4766e-01, -3.2422e-01],
         [-1.2891e-01,  4.1992e-01, -3.9258e-01,  ...,  3.5742e-01,
          -5.8594e-02, -1.7090e-01],
         [ 9.7656e-01, -3.3203e-02, -8.0078e-01,  ..., -7.3730e-02,
           1.7383e-01,  6.6406e-01]],

        [[-8.8501e-03, -3.8452e-03, -2.5940e-03,  ...,  2.5635e-03,
           3.6469e-03, -3.3875e-03],
         [ 1.6504e-01, -4.8242e-01, -6.4453e-02,  ..., -2.0508e-01,
          -2.4902e-01, -2.7344e-01],
         [-7.1484e-01,  8.3203e-01,  2.2656e-01,  ..., -7.3242e-02,
          -9.1797e-01,  1.1016e+00],
         ...,
         [-3.3594e-01,  2.6367e-01, -4.7852e-01,  ..., -1.7578e-01,
          -3.4766e-01, -3.2422e-01],
         [-1.2891e-01,  4.1992e-01, -3.9258e-01,  ...,  3.5742e-01,
          -5.8594e-02, -1.7090e-01],
         [ 9.7656e-01, -3.3203e-02, -8.0078e-01,  ..., -7.3730e-02,
           1.7383e-01,  6.6406e-01]],

        [[-8.8501e-03, -3.8452e-03, -2.5940e-03,  ...,  2.5635e-03,
           3.6469e-03, -3.3875e-03],
         [ 1.6504e-01, -4.8242e-01, -6.4453e-02,  ..., -2.0508e-01,
          -2.4902e-01, -2.7344e-01],
         [-7.1484e-01,  8.3203e-01,  2.2656e-01,  ..., -7.3242e-02,
          -9.1797e-01,  1.1016e+00],
         ...,
         [-3.3594e-01,  2.6367e-01, -4.7852e-01,  ..., -1.7578e-01,
          -3.4766e-01, -3.2422e-01],
         [-1.2891e-01,  4.1992e-01, -3.9258e-01,  ...,  3.5742e-01,
          -5.8594e-02, -1.7090e-01],
         [ 9.7656e-01, -3.3203e-02, -8.0078e-01,  ..., -7.3730e-02,
           1.7383e-01,  6.6406e-01]]], dtype=torch.bfloat16)), (tensor([[[ 3.0060e-03,  8.0872e-04, -1.0437e-02,  ..., -2.7344e-02,
          -3.0518e-03, -2.2461e-01],
         [-1.8594e+00,  2.4688e+00,  2.5000e+00,  ...,  1.6328e+00,
          -5.8350e-02,  1.6328e+00],
         [ 1.2500e+00, -1.9336e-01,  1.2656e+00,  ...,  1.1562e+00,
          -5.8203e-01,  1.1562e+00],
         ...,
         [ 1.4219e+00,  2.2266e-01,  3.5938e-01,  ...,  2.0781e+00,
          -4.6875e-01, -3.9258e-01],
         [ 2.9883e-01, -2.2461e-02,  1.1963e-01,  ...,  2.1406e+00,
           4.1504e-02, -2.5781e+00],
         [-1.2812e+00,  1.1797e+00,  1.4844e+00,  ...,  8.4766e-01,
          -5.6250e-01, -8.7109e-01]],

        [[ 3.0060e-03,  8.0872e-04, -1.0437e-02,  ..., -2.7344e-02,
          -3.0518e-03, -2.2461e-01],
         [-1.8594e+00,  2.4688e+00,  2.5000e+00,  ...,  1.6328e+00,
          -5.8350e-02,  1.6328e+00],
         [ 1.2500e+00, -1.9336e-01,  1.2656e+00,  ...,  1.1562e+00,
          -5.8203e-01,  1.1562e+00],
         ...,
         [ 1.4219e+00,  2.2266e-01,  3.5938e-01,  ...,  2.0781e+00,
          -4.6875e-01, -3.9258e-01],
         [ 2.9883e-01, -2.2461e-02,  1.1963e-01,  ...,  2.1406e+00,
           4.1504e-02, -2.5781e+00],
         [-1.2812e+00,  1.1797e+00,  1.4844e+00,  ...,  8.4766e-01,
          -5.6250e-01, -8.7109e-01]],

        [[ 3.0060e-03,  8.0872e-04, -1.0437e-02,  ..., -2.7344e-02,
          -3.0518e-03, -2.2461e-01],
         [-1.8594e+00,  2.4688e+00,  2.5000e+00,  ...,  1.6328e+00,
          -5.8350e-02,  1.6328e+00],
         [ 1.2500e+00, -1.9336e-01,  1.2656e+00,  ...,  1.1562e+00,
          -5.8203e-01,  1.1562e+00],
         ...,
         [ 1.4219e+00,  2.2266e-01,  3.5938e-01,  ...,  2.0781e+00,
          -4.6875e-01, -3.9258e-01],
         [ 2.9883e-01, -2.2461e-02,  1.1963e-01,  ...,  2.1406e+00,
           4.1504e-02, -2.5781e+00],
         [-1.2812e+00,  1.1797e+00,  1.4844e+00,  ...,  8.4766e-01,
          -5.6250e-01, -8.7109e-01]],

        ...,

        [[-8.2779e-04,  5.0964e-03, -1.3123e-03,  ...,  2.4609e-01,
           7.2656e-01,  5.5078e-01],
         [-1.8203e+00,  7.3047e-01, -2.6875e+00,  ...,  4.5312e-01,
          -3.4219e+00,  3.5156e-01],
         [ 1.9824e-01,  3.9453e-01, -7.2266e-01,  ...,  5.5859e-01,
          -6.7500e+00,  4.4141e-01],
         ...,
         [ 5.8984e-01,  9.4922e-01, -6.9141e-01,  ...,  1.9922e+00,
          -4.2812e+00,  2.5625e+00],
         [ 4.2188e-01,  6.3281e-01,  2.0508e-02,  ...,  9.2773e-02,
          -6.5312e+00,  3.1562e+00],
         [ 5.8594e-01,  5.3516e-01, -6.2500e-01,  ..., -8.4375e-01,
          -2.8594e+00,  2.3906e+00]],

        [[-8.2779e-04,  5.0964e-03, -1.3123e-03,  ...,  2.4609e-01,
           7.2656e-01,  5.5078e-01],
         [-1.8203e+00,  7.3047e-01, -2.6875e+00,  ...,  4.5312e-01,
          -3.4219e+00,  3.5156e-01],
         [ 1.9824e-01,  3.9453e-01, -7.2266e-01,  ...,  5.5859e-01,
          -6.7500e+00,  4.4141e-01],
         ...,
         [ 5.8984e-01,  9.4922e-01, -6.9141e-01,  ...,  1.9922e+00,
          -4.2812e+00,  2.5625e+00],
         [ 4.2188e-01,  6.3281e-01,  2.0508e-02,  ...,  9.2773e-02,
          -6.5312e+00,  3.1562e+00],
         [ 5.8594e-01,  5.3516e-01, -6.2500e-01,  ..., -8.4375e-01,
          -2.8594e+00,  2.3906e+00]],

        [[-8.2779e-04,  5.0964e-03, -1.3123e-03,  ...,  2.4609e-01,
           7.2656e-01,  5.5078e-01],
         [-1.8203e+00,  7.3047e-01, -2.6875e+00,  ...,  4.5312e-01,
          -3.4219e+00,  3.5156e-01],
         [ 1.9824e-01,  3.9453e-01, -7.2266e-01,  ...,  5.5859e-01,
          -6.7500e+00,  4.4141e-01],
         ...,
         [ 5.8984e-01,  9.4922e-01, -6.9141e-01,  ...,  1.9922e+00,
          -4.2812e+00,  2.5625e+00],
         [ 4.2188e-01,  6.3281e-01,  2.0508e-02,  ...,  9.2773e-02,
          -6.5312e+00,  3.1562e+00],
         [ 5.8594e-01,  5.3516e-01, -6.2500e-01,  ..., -8.4375e-01,
          -2.8594e+00,  2.3906e+00]]], dtype=torch.bfloat16), tensor([[[-1.3046e-03,  9.9487e-03, -1.2756e-02,  ..., -1.3000e-02,
           3.1281e-03, -3.1891e-03],
         [ 3.8672e-01, -2.0215e-01, -6.1035e-02,  ...,  8.5449e-02,
           4.7656e-01,  5.9570e-02],
         [ 1.0391e+00, -6.9336e-02, -2.5391e-01,  ..., -3.4375e-01,
          -4.6387e-02, -2.8906e-01],
         ...,
         [-6.8359e-01,  3.2617e-01,  2.4512e-01,  ..., -1.0742e-01,
          -2.1484e-01, -4.5117e-01],
         [-1.0791e-01,  4.7656e-01, -4.2773e-01,  ...,  8.7891e-01,
           1.3379e-01,  7.1777e-02],
         [ 7.0801e-02,  3.0029e-02, -8.9355e-02,  ..., -6.2500e-02,
          -4.1211e-01, -4.1016e-01]],

        [[-1.3046e-03,  9.9487e-03, -1.2756e-02,  ..., -1.3000e-02,
           3.1281e-03, -3.1891e-03],
         [ 3.8672e-01, -2.0215e-01, -6.1035e-02,  ...,  8.5449e-02,
           4.7656e-01,  5.9570e-02],
         [ 1.0391e+00, -6.9336e-02, -2.5391e-01,  ..., -3.4375e-01,
          -4.6387e-02, -2.8906e-01],
         ...,
         [-6.8359e-01,  3.2617e-01,  2.4512e-01,  ..., -1.0742e-01,
          -2.1484e-01, -4.5117e-01],
         [-1.0791e-01,  4.7656e-01, -4.2773e-01,  ...,  8.7891e-01,
           1.3379e-01,  7.1777e-02],
         [ 7.0801e-02,  3.0029e-02, -8.9355e-02,  ..., -6.2500e-02,
          -4.1211e-01, -4.1016e-01]],

        [[-1.3046e-03,  9.9487e-03, -1.2756e-02,  ..., -1.3000e-02,
           3.1281e-03, -3.1891e-03],
         [ 3.8672e-01, -2.0215e-01, -6.1035e-02,  ...,  8.5449e-02,
           4.7656e-01,  5.9570e-02],
         [ 1.0391e+00, -6.9336e-02, -2.5391e-01,  ..., -3.4375e-01,
          -4.6387e-02, -2.8906e-01],
         ...,
         [-6.8359e-01,  3.2617e-01,  2.4512e-01,  ..., -1.0742e-01,
          -2.1484e-01, -4.5117e-01],
         [-1.0791e-01,  4.7656e-01, -4.2773e-01,  ...,  8.7891e-01,
           1.3379e-01,  7.1777e-02],
         [ 7.0801e-02,  3.0029e-02, -8.9355e-02,  ..., -6.2500e-02,
          -4.1211e-01, -4.1016e-01]],

        ...,

        [[ 2.9785e-02, -4.5166e-02,  2.2949e-02,  ...,  2.6001e-02,
           3.2715e-02,  8.3618e-03],
         [-6.6016e-01, -8.2520e-02, -2.0410e-01,  ...,  9.5825e-03,
          -3.3008e-01, -7.3730e-02],
         [-8.3203e-01,  3.6719e-01, -1.0391e+00,  ..., -5.0781e-01,
           1.6211e-01, -2.3242e-01],
         ...,
         [-5.4688e-01,  3.5742e-01, -8.1250e-01,  ...,  1.0234e+00,
          -6.2109e-01,  5.7031e-01],
         [-2.0215e-01, -7.9102e-02, -1.5781e+00,  ...,  3.5938e-01,
          -2.9883e-01, -2.0142e-02],
         [-6.3281e-01, -3.1836e-01, -4.6484e-01,  ..., -2.8320e-01,
          -1.5430e-01, -1.0742e-01]],

        [[ 2.9785e-02, -4.5166e-02,  2.2949e-02,  ...,  2.6001e-02,
           3.2715e-02,  8.3618e-03],
         [-6.6016e-01, -8.2520e-02, -2.0410e-01,  ...,  9.5825e-03,
          -3.3008e-01, -7.3730e-02],
         [-8.3203e-01,  3.6719e-01, -1.0391e+00,  ..., -5.0781e-01,
           1.6211e-01, -2.3242e-01],
         ...,
         [-5.4688e-01,  3.5742e-01, -8.1250e-01,  ...,  1.0234e+00,
          -6.2109e-01,  5.7031e-01],
         [-2.0215e-01, -7.9102e-02, -1.5781e+00,  ...,  3.5938e-01,
          -2.9883e-01, -2.0142e-02],
         [-6.3281e-01, -3.1836e-01, -4.6484e-01,  ..., -2.8320e-01,
          -1.5430e-01, -1.0742e-01]],

        [[ 2.9785e-02, -4.5166e-02,  2.2949e-02,  ...,  2.6001e-02,
           3.2715e-02,  8.3618e-03],
         [-6.6016e-01, -8.2520e-02, -2.0410e-01,  ...,  9.5825e-03,
          -3.3008e-01, -7.3730e-02],
         [-8.3203e-01,  3.6719e-01, -1.0391e+00,  ..., -5.0781e-01,
           1.6211e-01, -2.3242e-01],
         ...,
         [-5.4688e-01,  3.5742e-01, -8.1250e-01,  ...,  1.0234e+00,
          -6.2109e-01,  5.7031e-01],
         [-2.0215e-01, -7.9102e-02, -1.5781e+00,  ...,  3.5938e-01,
          -2.9883e-01, -2.0142e-02],
         [-6.3281e-01, -3.1836e-01, -4.6484e-01,  ..., -2.8320e-01,
          -1.5430e-01, -1.0742e-01]]], dtype=torch.bfloat16)), (tensor([[[-1.6602e-02,  7.8735e-03, -1.5869e-02,  ..., -4.2969e-01,
           3.4570e-01,  9.5215e-02],
         [-2.8281e+00, -1.3750e+00,  1.6641e+00,  ...,  1.5625e+00,
           3.0273e-02,  7.7637e-02],
         [ 5.7812e-01, -7.7344e-01,  1.5781e+00,  ...,  2.2969e+00,
           1.6641e+00, -5.8984e-01],
         ...,
         [ 1.8906e+00,  5.3906e-01,  5.9766e-01,  ...,  1.4453e+00,
          -1.3672e+00,  5.4297e-01],
         [ 2.9492e-01,  1.1875e+00, -7.7148e-02,  ...,  1.1797e+00,
          -1.1172e+00,  9.2188e-01],
         [-5.3906e-01,  2.0625e+00,  2.3750e+00,  ...,  7.0312e-01,
           5.9082e-02, -1.9297e+00]],

        [[-1.6602e-02,  7.8735e-03, -1.5869e-02,  ..., -4.2969e-01,
           3.4570e-01,  9.5215e-02],
         [-2.8281e+00, -1.3750e+00,  1.6641e+00,  ...,  1.5625e+00,
           3.0273e-02,  7.7637e-02],
         [ 5.7812e-01, -7.7344e-01,  1.5781e+00,  ...,  2.2969e+00,
           1.6641e+00, -5.8984e-01],
         ...,
         [ 1.8906e+00,  5.3906e-01,  5.9766e-01,  ...,  1.4453e+00,
          -1.3672e+00,  5.4297e-01],
         [ 2.9492e-01,  1.1875e+00, -7.7148e-02,  ...,  1.1797e+00,
          -1.1172e+00,  9.2188e-01],
         [-5.3906e-01,  2.0625e+00,  2.3750e+00,  ...,  7.0312e-01,
           5.9082e-02, -1.9297e+00]],

        [[-1.6602e-02,  7.8735e-03, -1.5869e-02,  ..., -4.2969e-01,
           3.4570e-01,  9.5215e-02],
         [-2.8281e+00, -1.3750e+00,  1.6641e+00,  ...,  1.5625e+00,
           3.0273e-02,  7.7637e-02],
         [ 5.7812e-01, -7.7344e-01,  1.5781e+00,  ...,  2.2969e+00,
           1.6641e+00, -5.8984e-01],
         ...,
         [ 1.8906e+00,  5.3906e-01,  5.9766e-01,  ...,  1.4453e+00,
          -1.3672e+00,  5.4297e-01],
         [ 2.9492e-01,  1.1875e+00, -7.7148e-02,  ...,  1.1797e+00,
          -1.1172e+00,  9.2188e-01],
         [-5.3906e-01,  2.0625e+00,  2.3750e+00,  ...,  7.0312e-01,
           5.9082e-02, -1.9297e+00]],

        ...,

        [[-2.3315e-02, -1.9043e-02, -2.8687e-03,  ...,  7.2754e-02,
          -2.9219e+00, -4.1602e-01],
         [ 3.3203e-01, -5.1953e-01, -1.9922e-01,  ...,  5.7422e-01,
           8.7500e+00,  5.1875e+00],
         [-2.8516e-01, -3.2422e-01, -1.7578e-01,  ...,  8.8281e-01,
           8.1875e+00,  5.2500e+00],
         ...,
         [-3.2617e-01,  5.9375e-01,  4.4727e-01,  ...,  4.8828e-01,
           8.1875e+00, -9.4922e-01],
         [ 3.9648e-01, -1.2988e-01, -4.2480e-02,  ...,  1.1562e+00,
           5.9375e+00, -2.1562e+00],
         [ 3.2227e-01,  1.4551e-01, -6.5625e-01,  ..., -1.8047e+00,
           7.4062e+00,  2.4844e+00]],

        [[-2.3315e-02, -1.9043e-02, -2.8687e-03,  ...,  7.2754e-02,
          -2.9219e+00, -4.1602e-01],
         [ 3.3203e-01, -5.1953e-01, -1.9922e-01,  ...,  5.7422e-01,
           8.7500e+00,  5.1875e+00],
         [-2.8516e-01, -3.2422e-01, -1.7578e-01,  ...,  8.8281e-01,
           8.1875e+00,  5.2500e+00],
         ...,
         [-3.2617e-01,  5.9375e-01,  4.4727e-01,  ...,  4.8828e-01,
           8.1875e+00, -9.4922e-01],
         [ 3.9648e-01, -1.2988e-01, -4.2480e-02,  ...,  1.1562e+00,
           5.9375e+00, -2.1562e+00],
         [ 3.2227e-01,  1.4551e-01, -6.5625e-01,  ..., -1.8047e+00,
           7.4062e+00,  2.4844e+00]],

        [[-2.3315e-02, -1.9043e-02, -2.8687e-03,  ...,  7.2754e-02,
          -2.9219e+00, -4.1602e-01],
         [ 3.3203e-01, -5.1953e-01, -1.9922e-01,  ...,  5.7422e-01,
           8.7500e+00,  5.1875e+00],
         [-2.8516e-01, -3.2422e-01, -1.7578e-01,  ...,  8.8281e-01,
           8.1875e+00,  5.2500e+00],
         ...,
         [-3.2617e-01,  5.9375e-01,  4.4727e-01,  ...,  4.8828e-01,
           8.1875e+00, -9.4922e-01],
         [ 3.9648e-01, -1.2988e-01, -4.2480e-02,  ...,  1.1562e+00,
           5.9375e+00, -2.1562e+00],
         [ 3.2227e-01,  1.4551e-01, -6.5625e-01,  ..., -1.8047e+00,
           7.4062e+00,  2.4844e+00]]], dtype=torch.bfloat16), tensor([[[-2.7832e-02, -1.5869e-02, -4.3701e-02,  ...,  5.1270e-03,
           9.5215e-03,  1.0620e-02],
         [ 5.9375e-01, -4.9414e-01,  4.2969e-01,  ...,  1.3086e-01,
          -1.9434e-01,  3.2812e-01],
         [ 2.7734e-01, -1.6504e-01,  7.1484e-01,  ...,  1.1797e+00,
           2.8906e-01,  2.6245e-02],
         ...,
         [ 2.7148e-01, -1.4531e+00, -3.6133e-01,  ...,  2.7148e-01,
           5.0000e-01,  3.4424e-02],
         [ 2.4512e-01, -1.9297e+00, -9.1016e-01,  ...,  1.0781e+00,
           3.9844e-01,  8.2031e-01],
         [ 5.4297e-01,  4.6387e-02,  3.7842e-02,  ..., -1.7090e-01,
           7.4609e-01,  1.2109e-01]],

        [[-2.7832e-02, -1.5869e-02, -4.3701e-02,  ...,  5.1270e-03,
           9.5215e-03,  1.0620e-02],
         [ 5.9375e-01, -4.9414e-01,  4.2969e-01,  ...,  1.3086e-01,
          -1.9434e-01,  3.2812e-01],
         [ 2.7734e-01, -1.6504e-01,  7.1484e-01,  ...,  1.1797e+00,
           2.8906e-01,  2.6245e-02],
         ...,
         [ 2.7148e-01, -1.4531e+00, -3.6133e-01,  ...,  2.7148e-01,
           5.0000e-01,  3.4424e-02],
         [ 2.4512e-01, -1.9297e+00, -9.1016e-01,  ...,  1.0781e+00,
           3.9844e-01,  8.2031e-01],
         [ 5.4297e-01,  4.6387e-02,  3.7842e-02,  ..., -1.7090e-01,
           7.4609e-01,  1.2109e-01]],

        [[-2.7832e-02, -1.5869e-02, -4.3701e-02,  ...,  5.1270e-03,
           9.5215e-03,  1.0620e-02],
         [ 5.9375e-01, -4.9414e-01,  4.2969e-01,  ...,  1.3086e-01,
          -1.9434e-01,  3.2812e-01],
         [ 2.7734e-01, -1.6504e-01,  7.1484e-01,  ...,  1.1797e+00,
           2.8906e-01,  2.6245e-02],
         ...,
         [ 2.7148e-01, -1.4531e+00, -3.6133e-01,  ...,  2.7148e-01,
           5.0000e-01,  3.4424e-02],
         [ 2.4512e-01, -1.9297e+00, -9.1016e-01,  ...,  1.0781e+00,
           3.9844e-01,  8.2031e-01],
         [ 5.4297e-01,  4.6387e-02,  3.7842e-02,  ..., -1.7090e-01,
           7.4609e-01,  1.2109e-01]],

        ...,

        [[-2.3438e-02,  2.6978e-02,  1.3062e-02,  ...,  7.7438e-04,
           1.4954e-02,  1.2283e-03],
         [ 6.6406e-01, -8.2031e-01, -3.6523e-01,  ..., -5.2795e-03,
          -1.8359e-01,  2.6367e-01],
         [-1.0498e-01, -6.3281e-01, -3.3789e-01,  ..., -1.1250e+00,
          -7.9688e-01,  8.9844e-02],
         ...,
         [ 2.0625e+00, -4.9609e-01,  2.4805e-01,  ..., -3.1445e-01,
          -1.8516e+00, -1.2188e+00],
         [ 6.5234e-01,  1.9238e-01,  3.2422e-01,  ..., -2.5757e-02,
          -1.3359e+00, -8.3984e-01],
         [-2.1289e-01, -6.4062e-01, -8.8672e-01,  ..., -2.0703e-01,
           4.8340e-02,  7.3828e-01]],

        [[-2.3438e-02,  2.6978e-02,  1.3062e-02,  ...,  7.7438e-04,
           1.4954e-02,  1.2283e-03],
         [ 6.6406e-01, -8.2031e-01, -3.6523e-01,  ..., -5.2795e-03,
          -1.8359e-01,  2.6367e-01],
         [-1.0498e-01, -6.3281e-01, -3.3789e-01,  ..., -1.1250e+00,
          -7.9688e-01,  8.9844e-02],
         ...,
         [ 2.0625e+00, -4.9609e-01,  2.4805e-01,  ..., -3.1445e-01,
          -1.8516e+00, -1.2188e+00],
         [ 6.5234e-01,  1.9238e-01,  3.2422e-01,  ..., -2.5757e-02,
          -1.3359e+00, -8.3984e-01],
         [-2.1289e-01, -6.4062e-01, -8.8672e-01,  ..., -2.0703e-01,
           4.8340e-02,  7.3828e-01]],

        [[-2.3438e-02,  2.6978e-02,  1.3062e-02,  ...,  7.7438e-04,
           1.4954e-02,  1.2283e-03],
         [ 6.6406e-01, -8.2031e-01, -3.6523e-01,  ..., -5.2795e-03,
          -1.8359e-01,  2.6367e-01],
         [-1.0498e-01, -6.3281e-01, -3.3789e-01,  ..., -1.1250e+00,
          -7.9688e-01,  8.9844e-02],
         ...,
         [ 2.0625e+00, -4.9609e-01,  2.4805e-01,  ..., -3.1445e-01,
          -1.8516e+00, -1.2188e+00],
         [ 6.5234e-01,  1.9238e-01,  3.2422e-01,  ..., -2.5757e-02,
          -1.3359e+00, -8.3984e-01],
         [-2.1289e-01, -6.4062e-01, -8.8672e-01,  ..., -2.0703e-01,
           4.8340e-02,  7.3828e-01]]], dtype=torch.bfloat16)), (tensor([[[ 7.9727e-04,  1.0925e-02, -3.7842e-03,  ..., -1.3184e-01,
          -2.9297e-01, -9.7168e-02],
         [ 2.0000e+00,  7.6172e-01, -2.5938e+00,  ...,  9.4238e-02,
           6.7188e-01, -9.1406e-01],
         [ 6.0156e-01, -1.2344e+00, -1.1250e+00,  ...,  8.0078e-02,
           1.8672e+00, -4.7461e-01],
         ...,
         [-8.2422e-01,  7.7344e-01, -9.0820e-02,  ...,  1.4688e+00,
          -6.8750e-01, -7.1094e-01],
         [ 4.2969e-01,  3.3203e-01,  4.6875e-01,  ...,  3.3594e-01,
          -3.2617e-01, -1.0469e+00],
         [ 2.6562e-01,  1.7344e+00, -1.6250e+00,  ..., -2.1562e+00,
          -2.4219e+00, -2.5625e+00]],

        [[ 7.9727e-04,  1.0925e-02, -3.7842e-03,  ..., -1.3184e-01,
          -2.9297e-01, -9.7168e-02],
         [ 2.0000e+00,  7.6172e-01, -2.5938e+00,  ...,  9.4238e-02,
           6.7188e-01, -9.1406e-01],
         [ 6.0156e-01, -1.2344e+00, -1.1250e+00,  ...,  8.0078e-02,
           1.8672e+00, -4.7461e-01],
         ...,
         [-8.2422e-01,  7.7344e-01, -9.0820e-02,  ...,  1.4688e+00,
          -6.8750e-01, -7.1094e-01],
         [ 4.2969e-01,  3.3203e-01,  4.6875e-01,  ...,  3.3594e-01,
          -3.2617e-01, -1.0469e+00],
         [ 2.6562e-01,  1.7344e+00, -1.6250e+00,  ..., -2.1562e+00,
          -2.4219e+00, -2.5625e+00]],

        [[ 7.9727e-04,  1.0925e-02, -3.7842e-03,  ..., -1.3184e-01,
          -2.9297e-01, -9.7168e-02],
         [ 2.0000e+00,  7.6172e-01, -2.5938e+00,  ...,  9.4238e-02,
           6.7188e-01, -9.1406e-01],
         [ 6.0156e-01, -1.2344e+00, -1.1250e+00,  ...,  8.0078e-02,
           1.8672e+00, -4.7461e-01],
         ...,
         [-8.2422e-01,  7.7344e-01, -9.0820e-02,  ...,  1.4688e+00,
          -6.8750e-01, -7.1094e-01],
         [ 4.2969e-01,  3.3203e-01,  4.6875e-01,  ...,  3.3594e-01,
          -3.2617e-01, -1.0469e+00],
         [ 2.6562e-01,  1.7344e+00, -1.6250e+00,  ..., -2.1562e+00,
          -2.4219e+00, -2.5625e+00]],

        ...,

        [[ 3.6469e-03,  1.0742e-02,  3.5706e-03,  ..., -2.5156e+00,
          -1.3672e-02,  1.0498e-01],
         [ 1.4219e+00,  1.2500e-01,  5.7812e-01,  ...,  8.2500e+00,
          -4.3945e-01, -1.3047e+00],
         [-5.0781e-01,  6.7969e-01, -1.4766e+00,  ...,  8.4375e+00,
           1.2422e+00, -1.2578e+00],
         ...,
         [-2.9492e-01, -2.5977e-01,  3.3203e-01,  ...,  7.0938e+00,
          -1.1875e+00,  8.1250e-01],
         [ 3.6719e-01, -1.1562e+00,  1.0400e-01,  ...,  5.2812e+00,
          -8.9355e-02,  8.0469e-01],
         [ 1.1172e+00, -4.6875e-01,  1.1719e-02,  ...,  7.2812e+00,
          -2.0781e+00, -9.6484e-01]],

        [[ 3.6469e-03,  1.0742e-02,  3.5706e-03,  ..., -2.5156e+00,
          -1.3672e-02,  1.0498e-01],
         [ 1.4219e+00,  1.2500e-01,  5.7812e-01,  ...,  8.2500e+00,
          -4.3945e-01, -1.3047e+00],
         [-5.0781e-01,  6.7969e-01, -1.4766e+00,  ...,  8.4375e+00,
           1.2422e+00, -1.2578e+00],
         ...,
         [-2.9492e-01, -2.5977e-01,  3.3203e-01,  ...,  7.0938e+00,
          -1.1875e+00,  8.1250e-01],
         [ 3.6719e-01, -1.1562e+00,  1.0400e-01,  ...,  5.2812e+00,
          -8.9355e-02,  8.0469e-01],
         [ 1.1172e+00, -4.6875e-01,  1.1719e-02,  ...,  7.2812e+00,
          -2.0781e+00, -9.6484e-01]],

        [[ 3.6469e-03,  1.0742e-02,  3.5706e-03,  ..., -2.5156e+00,
          -1.3672e-02,  1.0498e-01],
         [ 1.4219e+00,  1.2500e-01,  5.7812e-01,  ...,  8.2500e+00,
          -4.3945e-01, -1.3047e+00],
         [-5.0781e-01,  6.7969e-01, -1.4766e+00,  ...,  8.4375e+00,
           1.2422e+00, -1.2578e+00],
         ...,
         [-2.9492e-01, -2.5977e-01,  3.3203e-01,  ...,  7.0938e+00,
          -1.1875e+00,  8.1250e-01],
         [ 3.6719e-01, -1.1562e+00,  1.0400e-01,  ...,  5.2812e+00,
          -8.9355e-02,  8.0469e-01],
         [ 1.1172e+00, -4.6875e-01,  1.1719e-02,  ...,  7.2812e+00,
          -2.0781e+00, -9.6484e-01]]], dtype=torch.bfloat16), tensor([[[-0.0193,  0.0225, -0.0356,  ..., -0.0254, -0.0118, -0.0156],
         [ 0.4336,  0.2480, -0.4023,  ...,  0.2812, -0.6602, -0.6875],
         [-0.0229, -0.0767, -0.1055,  ..., -0.5859, -0.3262,  0.0864],
         ...,
         [-0.3730, -0.6250,  0.5508,  ..., -0.0092, -1.2188, -0.0742],
         [ 0.1465, -0.7773, -0.0559,  ...,  0.0938, -0.7266,  0.3613],
         [ 0.7227,  1.0469, -0.1816,  ...,  0.7930,  0.0243, -0.2773]],

        [[-0.0193,  0.0225, -0.0356,  ..., -0.0254, -0.0118, -0.0156],
         [ 0.4336,  0.2480, -0.4023,  ...,  0.2812, -0.6602, -0.6875],
         [-0.0229, -0.0767, -0.1055,  ..., -0.5859, -0.3262,  0.0864],
         ...,
         [-0.3730, -0.6250,  0.5508,  ..., -0.0092, -1.2188, -0.0742],
         [ 0.1465, -0.7773, -0.0559,  ...,  0.0938, -0.7266,  0.3613],
         [ 0.7227,  1.0469, -0.1816,  ...,  0.7930,  0.0243, -0.2773]],

        [[-0.0193,  0.0225, -0.0356,  ..., -0.0254, -0.0118, -0.0156],
         [ 0.4336,  0.2480, -0.4023,  ...,  0.2812, -0.6602, -0.6875],
         [-0.0229, -0.0767, -0.1055,  ..., -0.5859, -0.3262,  0.0864],
         ...,
         [-0.3730, -0.6250,  0.5508,  ..., -0.0092, -1.2188, -0.0742],
         [ 0.1465, -0.7773, -0.0559,  ...,  0.0938, -0.7266,  0.3613],
         [ 0.7227,  1.0469, -0.1816,  ...,  0.7930,  0.0243, -0.2773]],

        ...,

        [[ 0.0197,  0.0050, -0.0060,  ...,  0.0374, -0.0171,  0.0016],
         [-0.0884, -0.1924,  0.2832,  ...,  0.3457, -0.0591,  0.0559],
         [ 0.3125, -0.8359,  0.3438,  ...,  0.0593,  0.3125, -0.3301],
         ...,
         [-0.7734,  0.5352,  0.1973,  ..., -0.1650, -0.4180,  0.0596],
         [-0.0215,  0.0084,  0.1279,  ..., -0.2109, -0.2266,  0.3047],
         [-0.1123,  0.0981, -0.0510,  ..., -0.6797, -1.0938,  0.0583]],

        [[ 0.0197,  0.0050, -0.0060,  ...,  0.0374, -0.0171,  0.0016],
         [-0.0884, -0.1924,  0.2832,  ...,  0.3457, -0.0591,  0.0559],
         [ 0.3125, -0.8359,  0.3438,  ...,  0.0593,  0.3125, -0.3301],
         ...,
         [-0.7734,  0.5352,  0.1973,  ..., -0.1650, -0.4180,  0.0596],
         [-0.0215,  0.0084,  0.1279,  ..., -0.2109, -0.2266,  0.3047],
         [-0.1123,  0.0981, -0.0510,  ..., -0.6797, -1.0938,  0.0583]],

        [[ 0.0197,  0.0050, -0.0060,  ...,  0.0374, -0.0171,  0.0016],
         [-0.0884, -0.1924,  0.2832,  ...,  0.3457, -0.0591,  0.0559],
         [ 0.3125, -0.8359,  0.3438,  ...,  0.0593,  0.3125, -0.3301],
         ...,
         [-0.7734,  0.5352,  0.1973,  ..., -0.1650, -0.4180,  0.0596],
         [-0.0215,  0.0084,  0.1279,  ..., -0.2109, -0.2266,  0.3047],
         [-0.1123,  0.0981, -0.0510,  ..., -0.6797, -1.0938,  0.0583]]],
       dtype=torch.bfloat16)), (tensor([[[-1.7700e-02, -1.2329e-02,  1.7456e-02,  ...,  7.7637e-02,
           5.0537e-02, -1.0107e-01],
         [ 2.0000e+00,  1.4551e-01,  2.0469e+00,  ..., -4.6289e-01,
          -6.8359e-01, -7.5391e-01],
         [ 1.6562e+00,  1.2734e+00, -2.5977e-01,  ...,  2.4512e-01,
          -2.8594e+00,  1.6562e+00],
         ...,
         [-6.8359e-01, -1.5156e+00, -2.2852e-01,  ..., -1.6484e+00,
           5.1562e+00, -8.9453e-01],
         [ 2.0117e-01, -6.7578e-01, -4.4922e-02,  ...,  7.9102e-02,
           1.9434e-01, -2.6719e+00],
         [-1.0938e+00, -1.5469e+00,  1.4844e-01,  ..., -5.3906e-01,
          -9.2163e-03,  1.0703e+00]],

        [[-1.7700e-02, -1.2329e-02,  1.7456e-02,  ...,  7.7637e-02,
           5.0537e-02, -1.0107e-01],
         [ 2.0000e+00,  1.4551e-01,  2.0469e+00,  ..., -4.6289e-01,
          -6.8359e-01, -7.5391e-01],
         [ 1.6562e+00,  1.2734e+00, -2.5977e-01,  ...,  2.4512e-01,
          -2.8594e+00,  1.6562e+00],
         ...,
         [-6.8359e-01, -1.5156e+00, -2.2852e-01,  ..., -1.6484e+00,
           5.1562e+00, -8.9453e-01],
         [ 2.0117e-01, -6.7578e-01, -4.4922e-02,  ...,  7.9102e-02,
           1.9434e-01, -2.6719e+00],
         [-1.0938e+00, -1.5469e+00,  1.4844e-01,  ..., -5.3906e-01,
          -9.2163e-03,  1.0703e+00]],

        [[-1.7700e-02, -1.2329e-02,  1.7456e-02,  ...,  7.7637e-02,
           5.0537e-02, -1.0107e-01],
         [ 2.0000e+00,  1.4551e-01,  2.0469e+00,  ..., -4.6289e-01,
          -6.8359e-01, -7.5391e-01],
         [ 1.6562e+00,  1.2734e+00, -2.5977e-01,  ...,  2.4512e-01,
          -2.8594e+00,  1.6562e+00],
         ...,
         [-6.8359e-01, -1.5156e+00, -2.2852e-01,  ..., -1.6484e+00,
           5.1562e+00, -8.9453e-01],
         [ 2.0117e-01, -6.7578e-01, -4.4922e-02,  ...,  7.9102e-02,
           1.9434e-01, -2.6719e+00],
         [-1.0938e+00, -1.5469e+00,  1.4844e-01,  ..., -5.3906e-01,
          -9.2163e-03,  1.0703e+00]],

        ...,

        [[-3.5706e-03,  6.9275e-03,  1.4221e-02,  ..., -1.7285e-01,
           1.7773e-01, -1.6968e-02],
         [ 2.4062e+00,  1.9062e+00,  2.1777e-01,  ..., -1.5156e+00,
          -1.3750e+00, -9.2578e-01],
         [ 9.4922e-01,  1.1484e+00,  3.0273e-01,  ..., -3.0938e+00,
          -2.1250e+00, -2.0938e+00],
         ...,
         [-2.1406e+00, -1.0547e+00, -2.0625e+00,  ...,  5.8984e-01,
           5.7422e-01,  1.9766e+00],
         [-3.4570e-01, -3.5156e-01, -7.7734e-01,  ...,  9.7656e-02,
           7.9688e-01,  2.7031e+00],
         [-6.9531e-01,  2.8516e-01, -1.9336e-01,  ..., -5.1953e-01,
          -1.6484e+00,  1.3047e+00]],

        [[-3.5706e-03,  6.9275e-03,  1.4221e-02,  ..., -1.7285e-01,
           1.7773e-01, -1.6968e-02],
         [ 2.4062e+00,  1.9062e+00,  2.1777e-01,  ..., -1.5156e+00,
          -1.3750e+00, -9.2578e-01],
         [ 9.4922e-01,  1.1484e+00,  3.0273e-01,  ..., -3.0938e+00,
          -2.1250e+00, -2.0938e+00],
         ...,
         [-2.1406e+00, -1.0547e+00, -2.0625e+00,  ...,  5.8984e-01,
           5.7422e-01,  1.9766e+00],
         [-3.4570e-01, -3.5156e-01, -7.7734e-01,  ...,  9.7656e-02,
           7.9688e-01,  2.7031e+00],
         [-6.9531e-01,  2.8516e-01, -1.9336e-01,  ..., -5.1953e-01,
          -1.6484e+00,  1.3047e+00]],

        [[-3.5706e-03,  6.9275e-03,  1.4221e-02,  ..., -1.7285e-01,
           1.7773e-01, -1.6968e-02],
         [ 2.4062e+00,  1.9062e+00,  2.1777e-01,  ..., -1.5156e+00,
          -1.3750e+00, -9.2578e-01],
         [ 9.4922e-01,  1.1484e+00,  3.0273e-01,  ..., -3.0938e+00,
          -2.1250e+00, -2.0938e+00],
         ...,
         [-2.1406e+00, -1.0547e+00, -2.0625e+00,  ...,  5.8984e-01,
           5.7422e-01,  1.9766e+00],
         [-3.4570e-01, -3.5156e-01, -7.7734e-01,  ...,  9.7656e-02,
           7.9688e-01,  2.7031e+00],
         [-6.9531e-01,  2.8516e-01, -1.9336e-01,  ..., -5.1953e-01,
          -1.6484e+00,  1.3047e+00]]], dtype=torch.bfloat16), tensor([[[ 0.0537, -0.0067, -0.0131,  ...,  0.0227, -0.0175,  0.0188],
         [-0.4023, -0.3379, -0.4727,  ..., -0.3730, -0.2930,  0.3906],
         [-0.2852, -0.7734,  0.0374,  ..., -0.4922, -0.1074,  0.3262],
         ...,
         [ 0.7656, -0.0801, -0.0025,  ...,  0.2656,  0.7109, -0.1611],
         [ 0.2539, -0.4785, -0.9531,  ..., -0.4551,  0.8828, -0.2119],
         [-0.3438, -0.1406, -0.0776,  ...,  0.0947,  0.0938,  0.3672]],

        [[ 0.0537, -0.0067, -0.0131,  ...,  0.0227, -0.0175,  0.0188],
         [-0.4023, -0.3379, -0.4727,  ..., -0.3730, -0.2930,  0.3906],
         [-0.2852, -0.7734,  0.0374,  ..., -0.4922, -0.1074,  0.3262],
         ...,
         [ 0.7656, -0.0801, -0.0025,  ...,  0.2656,  0.7109, -0.1611],
         [ 0.2539, -0.4785, -0.9531,  ..., -0.4551,  0.8828, -0.2119],
         [-0.3438, -0.1406, -0.0776,  ...,  0.0947,  0.0938,  0.3672]],

        [[ 0.0537, -0.0067, -0.0131,  ...,  0.0227, -0.0175,  0.0188],
         [-0.4023, -0.3379, -0.4727,  ..., -0.3730, -0.2930,  0.3906],
         [-0.2852, -0.7734,  0.0374,  ..., -0.4922, -0.1074,  0.3262],
         ...,
         [ 0.7656, -0.0801, -0.0025,  ...,  0.2656,  0.7109, -0.1611],
         [ 0.2539, -0.4785, -0.9531,  ..., -0.4551,  0.8828, -0.2119],
         [-0.3438, -0.1406, -0.0776,  ...,  0.0947,  0.0938,  0.3672]],

        ...,

        [[-0.0300, -0.0415, -0.0488,  ..., -0.0352,  0.0204,  0.0442],
         [-0.2051,  0.1680,  0.2500,  ..., -0.3320, -0.2949,  0.1670],
         [ 0.0203, -0.3398, -0.0371,  ..., -0.2354, -0.0206,  0.0171],
         ...,
         [ 0.3496, -0.0040,  0.5586,  ...,  0.6992,  0.4531,  0.1025],
         [ 0.0669, -0.3203,  0.1123,  ...,  0.5859, -1.0703, -0.2539],
         [ 0.2539,  0.2891,  0.6719,  ...,  0.6758,  0.2256,  0.3047]],

        [[-0.0300, -0.0415, -0.0488,  ..., -0.0352,  0.0204,  0.0442],
         [-0.2051,  0.1680,  0.2500,  ..., -0.3320, -0.2949,  0.1670],
         [ 0.0203, -0.3398, -0.0371,  ..., -0.2354, -0.0206,  0.0171],
         ...,
         [ 0.3496, -0.0040,  0.5586,  ...,  0.6992,  0.4531,  0.1025],
         [ 0.0669, -0.3203,  0.1123,  ...,  0.5859, -1.0703, -0.2539],
         [ 0.2539,  0.2891,  0.6719,  ...,  0.6758,  0.2256,  0.3047]],

        [[-0.0300, -0.0415, -0.0488,  ..., -0.0352,  0.0204,  0.0442],
         [-0.2051,  0.1680,  0.2500,  ..., -0.3320, -0.2949,  0.1670],
         [ 0.0203, -0.3398, -0.0371,  ..., -0.2354, -0.0206,  0.0171],
         ...,
         [ 0.3496, -0.0040,  0.5586,  ...,  0.6992,  0.4531,  0.1025],
         [ 0.0669, -0.3203,  0.1123,  ...,  0.5859, -1.0703, -0.2539],
         [ 0.2539,  0.2891,  0.6719,  ...,  0.6758,  0.2256,  0.3047]]],
       dtype=torch.bfloat16)), (tensor([[[-7.9956e-03, -3.5553e-03, -3.8147e-03,  ..., -1.2500e+00,
           1.5747e-02,  6.2012e-02],
         [ 7.0312e-01,  3.6719e-01,  1.5820e-01,  ...,  2.6172e-01,
          -5.8594e-01, -1.5625e-01],
         [-6.3281e-01,  7.7734e-01, -8.4229e-03,  ...,  1.1719e+00,
          -1.7109e+00, -2.7500e+00],
         ...,
         [-8.8867e-02, -3.4570e-01, -1.0859e+00,  ..., -1.3984e+00,
          -2.6719e+00,  7.9297e-01],
         [ 5.1172e-01, -3.7109e-01, -7.5781e-01,  ...,  4.8438e-01,
          -1.8438e+00, -8.3203e-01],
         [ 9.4141e-01, -8.5938e-01, -6.0938e-01,  ...,  2.9297e-01,
           1.5938e+00, -1.7578e+00]],

        [[-7.9956e-03, -3.5553e-03, -3.8147e-03,  ..., -1.2500e+00,
           1.5747e-02,  6.2012e-02],
         [ 7.0312e-01,  3.6719e-01,  1.5820e-01,  ...,  2.6172e-01,
          -5.8594e-01, -1.5625e-01],
         [-6.3281e-01,  7.7734e-01, -8.4229e-03,  ...,  1.1719e+00,
          -1.7109e+00, -2.7500e+00],
         ...,
         [-8.8867e-02, -3.4570e-01, -1.0859e+00,  ..., -1.3984e+00,
          -2.6719e+00,  7.9297e-01],
         [ 5.1172e-01, -3.7109e-01, -7.5781e-01,  ...,  4.8438e-01,
          -1.8438e+00, -8.3203e-01],
         [ 9.4141e-01, -8.5938e-01, -6.0938e-01,  ...,  2.9297e-01,
           1.5938e+00, -1.7578e+00]],

        [[-7.9956e-03, -3.5553e-03, -3.8147e-03,  ..., -1.2500e+00,
           1.5747e-02,  6.2012e-02],
         [ 7.0312e-01,  3.6719e-01,  1.5820e-01,  ...,  2.6172e-01,
          -5.8594e-01, -1.5625e-01],
         [-6.3281e-01,  7.7734e-01, -8.4229e-03,  ...,  1.1719e+00,
          -1.7109e+00, -2.7500e+00],
         ...,
         [-8.8867e-02, -3.4570e-01, -1.0859e+00,  ..., -1.3984e+00,
          -2.6719e+00,  7.9297e-01],
         [ 5.1172e-01, -3.7109e-01, -7.5781e-01,  ...,  4.8438e-01,
          -1.8438e+00, -8.3203e-01],
         [ 9.4141e-01, -8.5938e-01, -6.0938e-01,  ...,  2.9297e-01,
           1.5938e+00, -1.7578e+00]],

        ...,

        [[-1.8692e-03,  1.0300e-03, -7.0496e-03,  ...,  1.4375e+00,
          -2.7969e+00, -3.5156e-01],
         [ 1.7188e-01, -2.3281e+00, -9.9219e-01,  ..., -1.9219e+00,
           6.9375e+00,  1.6641e+00],
         [ 1.5527e-01, -3.7109e-01, -6.2500e-01,  ..., -2.3906e+00,
           7.8438e+00,  4.6562e+00],
         ...,
         [-1.1641e+00,  8.4375e-01, -1.6211e-01,  ..., -1.2891e+00,
           6.0938e+00, -4.3438e+00],
         [ 6.0547e-02, -1.7676e-01, -4.6875e-01,  ..., -9.0234e-01,
           4.4688e+00, -1.2344e+00],
         [-1.2266e+00,  1.5625e-02, -9.3750e-01,  ..., -3.4180e-01,
           5.8438e+00,  2.3281e+00]],

        [[-1.8692e-03,  1.0300e-03, -7.0496e-03,  ...,  1.4375e+00,
          -2.7969e+00, -3.5156e-01],
         [ 1.7188e-01, -2.3281e+00, -9.9219e-01,  ..., -1.9219e+00,
           6.9375e+00,  1.6641e+00],
         [ 1.5527e-01, -3.7109e-01, -6.2500e-01,  ..., -2.3906e+00,
           7.8438e+00,  4.6562e+00],
         ...,
         [-1.1641e+00,  8.4375e-01, -1.6211e-01,  ..., -1.2891e+00,
           6.0938e+00, -4.3438e+00],
         [ 6.0547e-02, -1.7676e-01, -4.6875e-01,  ..., -9.0234e-01,
           4.4688e+00, -1.2344e+00],
         [-1.2266e+00,  1.5625e-02, -9.3750e-01,  ..., -3.4180e-01,
           5.8438e+00,  2.3281e+00]],

        [[-1.8692e-03,  1.0300e-03, -7.0496e-03,  ...,  1.4375e+00,
          -2.7969e+00, -3.5156e-01],
         [ 1.7188e-01, -2.3281e+00, -9.9219e-01,  ..., -1.9219e+00,
           6.9375e+00,  1.6641e+00],
         [ 1.5527e-01, -3.7109e-01, -6.2500e-01,  ..., -2.3906e+00,
           7.8438e+00,  4.6562e+00],
         ...,
         [-1.1641e+00,  8.4375e-01, -1.6211e-01,  ..., -1.2891e+00,
           6.0938e+00, -4.3438e+00],
         [ 6.0547e-02, -1.7676e-01, -4.6875e-01,  ..., -9.0234e-01,
           4.4688e+00, -1.2344e+00],
         [-1.2266e+00,  1.5625e-02, -9.3750e-01,  ..., -3.4180e-01,
           5.8438e+00,  2.3281e+00]]], dtype=torch.bfloat16), tensor([[[ 0.0623,  0.0913, -0.0081,  ...,  0.0413, -0.0123,  0.0043],
         [-0.5742,  0.0535, -0.3262,  ...,  0.3320,  0.4824,  0.2266],
         [-0.7031, -0.4531,  0.2227,  ..., -0.0579,  0.4180,  0.1084],
         ...,
         [ 0.4590, -0.3926,  0.3984,  ...,  0.2559, -0.1855,  0.7461],
         [-0.0698, -0.4395,  0.8242,  ..., -0.0913,  0.0625,  0.7383],
         [ 0.2676, -0.3828, -1.4062,  ...,  0.6211,  0.8164, -0.4180]],

        [[ 0.0623,  0.0913, -0.0081,  ...,  0.0413, -0.0123,  0.0043],
         [-0.5742,  0.0535, -0.3262,  ...,  0.3320,  0.4824,  0.2266],
         [-0.7031, -0.4531,  0.2227,  ..., -0.0579,  0.4180,  0.1084],
         ...,
         [ 0.4590, -0.3926,  0.3984,  ...,  0.2559, -0.1855,  0.7461],
         [-0.0698, -0.4395,  0.8242,  ..., -0.0913,  0.0625,  0.7383],
         [ 0.2676, -0.3828, -1.4062,  ...,  0.6211,  0.8164, -0.4180]],

        [[ 0.0623,  0.0913, -0.0081,  ...,  0.0413, -0.0123,  0.0043],
         [-0.5742,  0.0535, -0.3262,  ...,  0.3320,  0.4824,  0.2266],
         [-0.7031, -0.4531,  0.2227,  ..., -0.0579,  0.4180,  0.1084],
         ...,
         [ 0.4590, -0.3926,  0.3984,  ...,  0.2559, -0.1855,  0.7461],
         [-0.0698, -0.4395,  0.8242,  ..., -0.0913,  0.0625,  0.7383],
         [ 0.2676, -0.3828, -1.4062,  ...,  0.6211,  0.8164, -0.4180]],

        ...,

        [[-0.0090,  0.0481,  0.0194,  ...,  0.0170,  0.0413,  0.0540],
         [-0.3809, -0.5859,  0.0077,  ...,  0.2715,  0.6562, -0.1631],
         [-0.8516, -0.4805, -0.6523,  ..., -0.4141,  0.1099, -0.4941],
         ...,
         [ 0.7930, -0.0747, -0.0684,  ..., -0.5898,  0.0684, -0.4922],
         [ 1.3281,  0.2334, -0.2139,  ..., -0.4824,  0.0820, -0.4980],
         [-0.2598, -0.6133, -0.6797,  ...,  0.3457, -0.5742,  0.2832]],

        [[-0.0090,  0.0481,  0.0194,  ...,  0.0170,  0.0413,  0.0540],
         [-0.3809, -0.5859,  0.0077,  ...,  0.2715,  0.6562, -0.1631],
         [-0.8516, -0.4805, -0.6523,  ..., -0.4141,  0.1099, -0.4941],
         ...,
         [ 0.7930, -0.0747, -0.0684,  ..., -0.5898,  0.0684, -0.4922],
         [ 1.3281,  0.2334, -0.2139,  ..., -0.4824,  0.0820, -0.4980],
         [-0.2598, -0.6133, -0.6797,  ...,  0.3457, -0.5742,  0.2832]],

        [[-0.0090,  0.0481,  0.0194,  ...,  0.0170,  0.0413,  0.0540],
         [-0.3809, -0.5859,  0.0077,  ...,  0.2715,  0.6562, -0.1631],
         [-0.8516, -0.4805, -0.6523,  ..., -0.4141,  0.1099, -0.4941],
         ...,
         [ 0.7930, -0.0747, -0.0684,  ..., -0.5898,  0.0684, -0.4922],
         [ 1.3281,  0.2334, -0.2139,  ..., -0.4824,  0.0820, -0.4980],
         [-0.2598, -0.6133, -0.6797,  ...,  0.3457, -0.5742,  0.2832]]],
       dtype=torch.bfloat16)), (tensor([[[-7.8125e-03,  3.9673e-03, -4.2419e-03,  ..., -6.7188e-01,
          -1.7090e-01,  9.5215e-02],
         [ 1.6406e+00,  9.9609e-01,  1.7422e+00,  ..., -1.0781e+00,
           1.0469e+00,  1.4141e+00],
         [ 9.2578e-01,  9.8047e-01,  1.8828e+00,  ..., -2.4062e+00,
          -1.1328e+00,  2.3750e+00],
         ...,
         [-5.9375e-01,  8.7891e-02,  1.4219e+00,  ...,  4.4189e-02,
           5.5469e-01,  8.5938e-01],
         [-2.6562e-01, -4.3555e-01,  4.1406e-01,  ..., -3.9453e-01,
          -1.3359e+00,  1.3359e+00],
         [-1.7500e+00, -2.9492e-01,  1.0000e+00,  ..., -1.2158e-01,
           1.3516e+00,  7.1484e-01]],

        [[-7.8125e-03,  3.9673e-03, -4.2419e-03,  ..., -6.7188e-01,
          -1.7090e-01,  9.5215e-02],
         [ 1.6406e+00,  9.9609e-01,  1.7422e+00,  ..., -1.0781e+00,
           1.0469e+00,  1.4141e+00],
         [ 9.2578e-01,  9.8047e-01,  1.8828e+00,  ..., -2.4062e+00,
          -1.1328e+00,  2.3750e+00],
         ...,
         [-5.9375e-01,  8.7891e-02,  1.4219e+00,  ...,  4.4189e-02,
           5.5469e-01,  8.5938e-01],
         [-2.6562e-01, -4.3555e-01,  4.1406e-01,  ..., -3.9453e-01,
          -1.3359e+00,  1.3359e+00],
         [-1.7500e+00, -2.9492e-01,  1.0000e+00,  ..., -1.2158e-01,
           1.3516e+00,  7.1484e-01]],

        [[-7.8125e-03,  3.9673e-03, -4.2419e-03,  ..., -6.7188e-01,
          -1.7090e-01,  9.5215e-02],
         [ 1.6406e+00,  9.9609e-01,  1.7422e+00,  ..., -1.0781e+00,
           1.0469e+00,  1.4141e+00],
         [ 9.2578e-01,  9.8047e-01,  1.8828e+00,  ..., -2.4062e+00,
          -1.1328e+00,  2.3750e+00],
         ...,
         [-5.9375e-01,  8.7891e-02,  1.4219e+00,  ...,  4.4189e-02,
           5.5469e-01,  8.5938e-01],
         [-2.6562e-01, -4.3555e-01,  4.1406e-01,  ..., -3.9453e-01,
          -1.3359e+00,  1.3359e+00],
         [-1.7500e+00, -2.9492e-01,  1.0000e+00,  ..., -1.2158e-01,
           1.3516e+00,  7.1484e-01]],

        ...,

        [[ 1.5198e-02,  1.0864e-02, -6.3477e-03,  ...,  1.2146e-02,
           1.2451e-02,  2.3281e+00],
         [ 8.3984e-01,  2.2344e+00,  1.0391e+00,  ...,  3.3398e-01,
          -3.8438e+00, -7.8438e+00],
         [-2.1484e-01,  1.6562e+00,  1.2695e-01,  ...,  5.5859e-01,
          -1.9219e+00, -1.0062e+01],
         ...,
         [-7.3438e-01,  1.1719e-01, -9.9609e-01,  ...,  9.6875e-01,
           1.6211e-01, -4.9062e+00],
         [ 9.7656e-03,  4.2383e-01,  1.4258e-01,  ...,  1.5312e+00,
          -6.5234e-01, -5.3438e+00],
         [ 1.2969e+00,  2.6953e-01,  1.4531e+00,  ...,  3.4570e-01,
          -2.7656e+00, -5.7500e+00]],

        [[ 1.5198e-02,  1.0864e-02, -6.3477e-03,  ...,  1.2146e-02,
           1.2451e-02,  2.3281e+00],
         [ 8.3984e-01,  2.2344e+00,  1.0391e+00,  ...,  3.3398e-01,
          -3.8438e+00, -7.8438e+00],
         [-2.1484e-01,  1.6562e+00,  1.2695e-01,  ...,  5.5859e-01,
          -1.9219e+00, -1.0062e+01],
         ...,
         [-7.3438e-01,  1.1719e-01, -9.9609e-01,  ...,  9.6875e-01,
           1.6211e-01, -4.9062e+00],
         [ 9.7656e-03,  4.2383e-01,  1.4258e-01,  ...,  1.5312e+00,
          -6.5234e-01, -5.3438e+00],
         [ 1.2969e+00,  2.6953e-01,  1.4531e+00,  ...,  3.4570e-01,
          -2.7656e+00, -5.7500e+00]],

        [[ 1.5198e-02,  1.0864e-02, -6.3477e-03,  ...,  1.2146e-02,
           1.2451e-02,  2.3281e+00],
         [ 8.3984e-01,  2.2344e+00,  1.0391e+00,  ...,  3.3398e-01,
          -3.8438e+00, -7.8438e+00],
         [-2.1484e-01,  1.6562e+00,  1.2695e-01,  ...,  5.5859e-01,
          -1.9219e+00, -1.0062e+01],
         ...,
         [-7.3438e-01,  1.1719e-01, -9.9609e-01,  ...,  9.6875e-01,
           1.6211e-01, -4.9062e+00],
         [ 9.7656e-03,  4.2383e-01,  1.4258e-01,  ...,  1.5312e+00,
          -6.5234e-01, -5.3438e+00],
         [ 1.2969e+00,  2.6953e-01,  1.4531e+00,  ...,  3.4570e-01,
          -2.7656e+00, -5.7500e+00]]], dtype=torch.bfloat16), tensor([[[-1.0437e-02,  4.6730e-04,  1.5869e-02,  ...,  4.7607e-02,
          -4.8584e-02, -2.1118e-02],
         [-4.8438e-01,  6.2256e-02,  3.7500e-01,  ..., -1.2329e-02,
           6.8054e-03,  4.7461e-01],
         [ 2.6367e-01, -3.1641e-01,  3.7109e-01,  ..., -4.0039e-02,
           4.7461e-01,  4.8242e-01],
         ...,
         [ 1.3184e-01,  4.9023e-01,  6.1328e-01,  ..., -6.9141e-01,
          -6.6797e-01,  6.2256e-02],
         [ 1.6504e-01,  2.4048e-02, -4.0820e-01,  ..., -2.3242e-01,
           8.3203e-01, -3.0664e-01],
         [-9.4238e-02,  1.7188e-01, -1.5820e-01,  ...,  2.5586e-01,
          -3.6914e-01,  1.0596e-01]],

        [[-1.0437e-02,  4.6730e-04,  1.5869e-02,  ...,  4.7607e-02,
          -4.8584e-02, -2.1118e-02],
         [-4.8438e-01,  6.2256e-02,  3.7500e-01,  ..., -1.2329e-02,
           6.8054e-03,  4.7461e-01],
         [ 2.6367e-01, -3.1641e-01,  3.7109e-01,  ..., -4.0039e-02,
           4.7461e-01,  4.8242e-01],
         ...,
         [ 1.3184e-01,  4.9023e-01,  6.1328e-01,  ..., -6.9141e-01,
          -6.6797e-01,  6.2256e-02],
         [ 1.6504e-01,  2.4048e-02, -4.0820e-01,  ..., -2.3242e-01,
           8.3203e-01, -3.0664e-01],
         [-9.4238e-02,  1.7188e-01, -1.5820e-01,  ...,  2.5586e-01,
          -3.6914e-01,  1.0596e-01]],

        [[-1.0437e-02,  4.6730e-04,  1.5869e-02,  ...,  4.7607e-02,
          -4.8584e-02, -2.1118e-02],
         [-4.8438e-01,  6.2256e-02,  3.7500e-01,  ..., -1.2329e-02,
           6.8054e-03,  4.7461e-01],
         [ 2.6367e-01, -3.1641e-01,  3.7109e-01,  ..., -4.0039e-02,
           4.7461e-01,  4.8242e-01],
         ...,
         [ 1.3184e-01,  4.9023e-01,  6.1328e-01,  ..., -6.9141e-01,
          -6.6797e-01,  6.2256e-02],
         [ 1.6504e-01,  2.4048e-02, -4.0820e-01,  ..., -2.3242e-01,
           8.3203e-01, -3.0664e-01],
         [-9.4238e-02,  1.7188e-01, -1.5820e-01,  ...,  2.5586e-01,
          -3.6914e-01,  1.0596e-01]],

        ...,

        [[-1.8799e-02,  1.3245e-02, -1.6602e-02,  ..., -1.1292e-02,
           1.3489e-02,  2.1820e-03],
         [ 2.6758e-01, -1.4062e-01, -9.5215e-02,  ..., -1.2500e-01,
           4.9219e-01,  2.7930e-01],
         [-3.2422e-01,  1.4453e-01, -4.1211e-01,  ...,  3.0078e-01,
           3.6328e-01, -1.3672e-01],
         ...,
         [ 3.3008e-01,  3.2031e-01,  9.5703e-02,  ..., -3.2617e-01,
           6.5234e-01, -8.2031e-01],
         [ 1.0400e-01, -5.0781e-01, -5.6250e-01,  ...,  1.3086e-01,
          -1.4941e-01, -3.8818e-02],
         [ 7.1777e-02, -4.2969e-01, -5.5859e-01,  ..., -3.1250e-02,
          -4.4531e-01,  2.2852e-01]],

        [[-1.8799e-02,  1.3245e-02, -1.6602e-02,  ..., -1.1292e-02,
           1.3489e-02,  2.1820e-03],
         [ 2.6758e-01, -1.4062e-01, -9.5215e-02,  ..., -1.2500e-01,
           4.9219e-01,  2.7930e-01],
         [-3.2422e-01,  1.4453e-01, -4.1211e-01,  ...,  3.0078e-01,
           3.6328e-01, -1.3672e-01],
         ...,
         [ 3.3008e-01,  3.2031e-01,  9.5703e-02,  ..., -3.2617e-01,
           6.5234e-01, -8.2031e-01],
         [ 1.0400e-01, -5.0781e-01, -5.6250e-01,  ...,  1.3086e-01,
          -1.4941e-01, -3.8818e-02],
         [ 7.1777e-02, -4.2969e-01, -5.5859e-01,  ..., -3.1250e-02,
          -4.4531e-01,  2.2852e-01]],

        [[-1.8799e-02,  1.3245e-02, -1.6602e-02,  ..., -1.1292e-02,
           1.3489e-02,  2.1820e-03],
         [ 2.6758e-01, -1.4062e-01, -9.5215e-02,  ..., -1.2500e-01,
           4.9219e-01,  2.7930e-01],
         [-3.2422e-01,  1.4453e-01, -4.1211e-01,  ...,  3.0078e-01,
           3.6328e-01, -1.3672e-01],
         ...,
         [ 3.3008e-01,  3.2031e-01,  9.5703e-02,  ..., -3.2617e-01,
           6.5234e-01, -8.2031e-01],
         [ 1.0400e-01, -5.0781e-01, -5.6250e-01,  ...,  1.3086e-01,
          -1.4941e-01, -3.8818e-02],
         [ 7.1777e-02, -4.2969e-01, -5.5859e-01,  ..., -3.1250e-02,
          -4.4531e-01,  2.2852e-01]]], dtype=torch.bfloat16)), (tensor([[[-2.4567e-03, -9.7046e-03,  2.9449e-03,  ..., -2.8711e-01,
          -1.3770e-01, -1.1230e-01],
         [-1.0078e+00,  2.8711e-01, -1.0469e+00,  ..., -1.1562e+00,
           3.9688e+00, -3.5156e-01],
         [ 3.0469e-01,  6.7578e-01, -3.7109e-01,  ...,  1.1230e-01,
           7.0703e-01,  2.5625e+00],
         ...,
         [ 1.1641e+00,  6.8750e-01, -3.0664e-01,  ..., -3.0078e-01,
          -1.2969e+00, -1.3594e+00],
         [ 4.4922e-02,  7.8125e-01,  4.9805e-02,  ..., -6.8359e-01,
           9.1016e-01,  2.6758e-01],
         [-7.7734e-01,  1.1406e+00, -8.6719e-01,  ..., -3.7344e+00,
           1.8594e+00,  1.8906e+00]],

        [[-2.4567e-03, -9.7046e-03,  2.9449e-03,  ..., -2.8711e-01,
          -1.3770e-01, -1.1230e-01],
         [-1.0078e+00,  2.8711e-01, -1.0469e+00,  ..., -1.1562e+00,
           3.9688e+00, -3.5156e-01],
         [ 3.0469e-01,  6.7578e-01, -3.7109e-01,  ...,  1.1230e-01,
           7.0703e-01,  2.5625e+00],
         ...,
         [ 1.1641e+00,  6.8750e-01, -3.0664e-01,  ..., -3.0078e-01,
          -1.2969e+00, -1.3594e+00],
         [ 4.4922e-02,  7.8125e-01,  4.9805e-02,  ..., -6.8359e-01,
           9.1016e-01,  2.6758e-01],
         [-7.7734e-01,  1.1406e+00, -8.6719e-01,  ..., -3.7344e+00,
           1.8594e+00,  1.8906e+00]],

        [[-2.4567e-03, -9.7046e-03,  2.9449e-03,  ..., -2.8711e-01,
          -1.3770e-01, -1.1230e-01],
         [-1.0078e+00,  2.8711e-01, -1.0469e+00,  ..., -1.1562e+00,
           3.9688e+00, -3.5156e-01],
         [ 3.0469e-01,  6.7578e-01, -3.7109e-01,  ...,  1.1230e-01,
           7.0703e-01,  2.5625e+00],
         ...,
         [ 1.1641e+00,  6.8750e-01, -3.0664e-01,  ..., -3.0078e-01,
          -1.2969e+00, -1.3594e+00],
         [ 4.4922e-02,  7.8125e-01,  4.9805e-02,  ..., -6.8359e-01,
           9.1016e-01,  2.6758e-01],
         [-7.7734e-01,  1.1406e+00, -8.6719e-01,  ..., -3.7344e+00,
           1.8594e+00,  1.8906e+00]],

        ...,

        [[ 3.6163e-03, -1.3794e-02,  1.1597e-02,  ...,  3.6914e-01,
           7.5684e-02,  1.3184e-01],
         [ 7.3047e-01,  1.1328e+00,  1.2109e+00,  ..., -3.8672e-01,
          -5.5469e-01, -7.9688e-01],
         [ 6.7578e-01, -7.4707e-02,  8.4766e-01,  ...,  3.9844e+00,
          -1.8125e+00,  1.0000e+00],
         ...,
         [-1.2969e+00, -1.0938e+00,  1.0000e+00,  ..., -2.0781e+00,
           2.3750e+00,  1.3672e+00],
         [ 8.5938e-02, -3.5156e-01, -2.2070e-01,  ..., -2.5625e+00,
           8.8672e-01,  1.5234e+00],
         [-1.0000e+00,  1.3594e+00,  2.1875e-01,  ..., -4.6875e-01,
           2.0801e-01, -9.8438e-01]],

        [[ 3.6163e-03, -1.3794e-02,  1.1597e-02,  ...,  3.6914e-01,
           7.5684e-02,  1.3184e-01],
         [ 7.3047e-01,  1.1328e+00,  1.2109e+00,  ..., -3.8672e-01,
          -5.5469e-01, -7.9688e-01],
         [ 6.7578e-01, -7.4707e-02,  8.4766e-01,  ...,  3.9844e+00,
          -1.8125e+00,  1.0000e+00],
         ...,
         [-1.2969e+00, -1.0938e+00,  1.0000e+00,  ..., -2.0781e+00,
           2.3750e+00,  1.3672e+00],
         [ 8.5938e-02, -3.5156e-01, -2.2070e-01,  ..., -2.5625e+00,
           8.8672e-01,  1.5234e+00],
         [-1.0000e+00,  1.3594e+00,  2.1875e-01,  ..., -4.6875e-01,
           2.0801e-01, -9.8438e-01]],

        [[ 3.6163e-03, -1.3794e-02,  1.1597e-02,  ...,  3.6914e-01,
           7.5684e-02,  1.3184e-01],
         [ 7.3047e-01,  1.1328e+00,  1.2109e+00,  ..., -3.8672e-01,
          -5.5469e-01, -7.9688e-01],
         [ 6.7578e-01, -7.4707e-02,  8.4766e-01,  ...,  3.9844e+00,
          -1.8125e+00,  1.0000e+00],
         ...,
         [-1.2969e+00, -1.0938e+00,  1.0000e+00,  ..., -2.0781e+00,
           2.3750e+00,  1.3672e+00],
         [ 8.5938e-02, -3.5156e-01, -2.2070e-01,  ..., -2.5625e+00,
           8.8672e-01,  1.5234e+00],
         [-1.0000e+00,  1.3594e+00,  2.1875e-01,  ..., -4.6875e-01,
           2.0801e-01, -9.8438e-01]]], dtype=torch.bfloat16), tensor([[[-1.9775e-02,  3.9551e-02, -2.7344e-02,  ..., -4.4434e-02,
          -7.5684e-02,  1.7944e-02],
         [ 2.5781e-01, -4.1797e-01,  1.6406e-01,  ...,  6.4453e-01,
          -3.4766e-01, -4.9414e-01],
         [-5.0781e-01,  4.7070e-01, -2.5781e-01,  ...,  7.9297e-01,
          -6.0547e-01,  4.2969e-02],
         ...,
         [-6.4844e-01,  9.4727e-02, -5.2979e-02,  ..., -6.2109e-01,
          -1.9043e-01, -6.0547e-01],
         [-8.9844e-01, -3.6914e-01, -2.8906e-01,  ..., -3.0078e-01,
          -1.9922e-01, -9.9609e-01],
         [-2.8125e-01,  1.6406e-01, -5.0391e-01,  ...,  1.8457e-01,
          -7.3242e-02,  1.6406e-01]],

        [[-1.9775e-02,  3.9551e-02, -2.7344e-02,  ..., -4.4434e-02,
          -7.5684e-02,  1.7944e-02],
         [ 2.5781e-01, -4.1797e-01,  1.6406e-01,  ...,  6.4453e-01,
          -3.4766e-01, -4.9414e-01],
         [-5.0781e-01,  4.7070e-01, -2.5781e-01,  ...,  7.9297e-01,
          -6.0547e-01,  4.2969e-02],
         ...,
         [-6.4844e-01,  9.4727e-02, -5.2979e-02,  ..., -6.2109e-01,
          -1.9043e-01, -6.0547e-01],
         [-8.9844e-01, -3.6914e-01, -2.8906e-01,  ..., -3.0078e-01,
          -1.9922e-01, -9.9609e-01],
         [-2.8125e-01,  1.6406e-01, -5.0391e-01,  ...,  1.8457e-01,
          -7.3242e-02,  1.6406e-01]],

        [[-1.9775e-02,  3.9551e-02, -2.7344e-02,  ..., -4.4434e-02,
          -7.5684e-02,  1.7944e-02],
         [ 2.5781e-01, -4.1797e-01,  1.6406e-01,  ...,  6.4453e-01,
          -3.4766e-01, -4.9414e-01],
         [-5.0781e-01,  4.7070e-01, -2.5781e-01,  ...,  7.9297e-01,
          -6.0547e-01,  4.2969e-02],
         ...,
         [-6.4844e-01,  9.4727e-02, -5.2979e-02,  ..., -6.2109e-01,
          -1.9043e-01, -6.0547e-01],
         [-8.9844e-01, -3.6914e-01, -2.8906e-01,  ..., -3.0078e-01,
          -1.9922e-01, -9.9609e-01],
         [-2.8125e-01,  1.6406e-01, -5.0391e-01,  ...,  1.8457e-01,
          -7.3242e-02,  1.6406e-01]],

        ...,

        [[ 4.2725e-03, -1.0498e-02, -1.6357e-02,  ..., -2.4261e-03,
          -2.6489e-02,  8.6975e-04],
         [-3.2031e-01, -3.2617e-01,  8.8379e-02,  ..., -8.1250e-01,
           1.6504e-01,  4.5508e-01],
         [-2.8711e-01, -3.9453e-01,  1.5137e-01,  ...,  5.1514e-02,
           6.3965e-02,  1.4746e-01],
         ...,
         [-3.2812e-01,  9.0332e-02,  2.4536e-02,  ...,  7.3828e-01,
           5.6641e-01,  1.2891e-01],
         [-7.0312e-01,  1.7383e-01,  4.1602e-01,  ..., -5.8594e-02,
           8.0469e-01,  2.6758e-01],
         [ 2.3315e-02, -4.7607e-03, -2.8125e-01,  ...,  1.9238e-01,
           6.2500e-01,  9.6680e-02]],

        [[ 4.2725e-03, -1.0498e-02, -1.6357e-02,  ..., -2.4261e-03,
          -2.6489e-02,  8.6975e-04],
         [-3.2031e-01, -3.2617e-01,  8.8379e-02,  ..., -8.1250e-01,
           1.6504e-01,  4.5508e-01],
         [-2.8711e-01, -3.9453e-01,  1.5137e-01,  ...,  5.1514e-02,
           6.3965e-02,  1.4746e-01],
         ...,
         [-3.2812e-01,  9.0332e-02,  2.4536e-02,  ...,  7.3828e-01,
           5.6641e-01,  1.2891e-01],
         [-7.0312e-01,  1.7383e-01,  4.1602e-01,  ..., -5.8594e-02,
           8.0469e-01,  2.6758e-01],
         [ 2.3315e-02, -4.7607e-03, -2.8125e-01,  ...,  1.9238e-01,
           6.2500e-01,  9.6680e-02]],

        [[ 4.2725e-03, -1.0498e-02, -1.6357e-02,  ..., -2.4261e-03,
          -2.6489e-02,  8.6975e-04],
         [-3.2031e-01, -3.2617e-01,  8.8379e-02,  ..., -8.1250e-01,
           1.6504e-01,  4.5508e-01],
         [-2.8711e-01, -3.9453e-01,  1.5137e-01,  ...,  5.1514e-02,
           6.3965e-02,  1.4746e-01],
         ...,
         [-3.2812e-01,  9.0332e-02,  2.4536e-02,  ...,  7.3828e-01,
           5.6641e-01,  1.2891e-01],
         [-7.0312e-01,  1.7383e-01,  4.1602e-01,  ..., -5.8594e-02,
           8.0469e-01,  2.6758e-01],
         [ 2.3315e-02, -4.7607e-03, -2.8125e-01,  ...,  1.9238e-01,
           6.2500e-01,  9.6680e-02]]], dtype=torch.bfloat16)), (tensor([[[ 2.0504e-04, -5.9509e-04, -8.3618e-03,  ...,  3.5706e-03,
           7.7637e-02, -1.4258e-01],
         [-1.5547e+00, -6.5234e-01, -7.3047e-01,  ...,  2.0312e+00,
          -6.8750e-01,  6.7578e-01],
         [ 6.4844e-01, -6.2500e-01, -6.1328e-01,  ..., -7.3047e-01,
          -1.5078e+00, -6.0547e-01],
         ...,
         [ 1.2031e+00,  7.3438e-01, -6.9531e-01,  ...,  9.8828e-01,
           2.0000e+00,  2.5312e+00],
         [-3.1250e-01,  2.0605e-01, -3.0469e-01,  ..., -6.2988e-02,
           4.4141e-01,  1.4453e+00],
         [-8.6719e-01,  3.8477e-01, -8.6719e-01,  ...,  1.1016e+00,
          -1.4922e+00,  8.3984e-01]],

        [[ 2.0504e-04, -5.9509e-04, -8.3618e-03,  ...,  3.5706e-03,
           7.7637e-02, -1.4258e-01],
         [-1.5547e+00, -6.5234e-01, -7.3047e-01,  ...,  2.0312e+00,
          -6.8750e-01,  6.7578e-01],
         [ 6.4844e-01, -6.2500e-01, -6.1328e-01,  ..., -7.3047e-01,
          -1.5078e+00, -6.0547e-01],
         ...,
         [ 1.2031e+00,  7.3438e-01, -6.9531e-01,  ...,  9.8828e-01,
           2.0000e+00,  2.5312e+00],
         [-3.1250e-01,  2.0605e-01, -3.0469e-01,  ..., -6.2988e-02,
           4.4141e-01,  1.4453e+00],
         [-8.6719e-01,  3.8477e-01, -8.6719e-01,  ...,  1.1016e+00,
          -1.4922e+00,  8.3984e-01]],

        [[ 2.0504e-04, -5.9509e-04, -8.3618e-03,  ...,  3.5706e-03,
           7.7637e-02, -1.4258e-01],
         [-1.5547e+00, -6.5234e-01, -7.3047e-01,  ...,  2.0312e+00,
          -6.8750e-01,  6.7578e-01],
         [ 6.4844e-01, -6.2500e-01, -6.1328e-01,  ..., -7.3047e-01,
          -1.5078e+00, -6.0547e-01],
         ...,
         [ 1.2031e+00,  7.3438e-01, -6.9531e-01,  ...,  9.8828e-01,
           2.0000e+00,  2.5312e+00],
         [-3.1250e-01,  2.0605e-01, -3.0469e-01,  ..., -6.2988e-02,
           4.4141e-01,  1.4453e+00],
         [-8.6719e-01,  3.8477e-01, -8.6719e-01,  ...,  1.1016e+00,
          -1.4922e+00,  8.3984e-01]],

        ...,

        [[-8.8882e-04, -6.9275e-03, -3.6926e-03,  ..., -2.1094e-01,
          -4.8584e-02, -6.6406e-02],
         [ 7.1875e-01,  4.1016e-01, -6.3281e-01,  ...,  2.8438e+00,
          -3.8906e+00, -7.0703e-01],
         [ 1.3086e-01, -2.4805e-01, -8.0859e-01,  ...,  2.0000e+00,
          -4.7812e+00, -3.4531e+00],
         ...,
         [ 5.3711e-02,  2.3438e-01, -3.9062e-02,  ..., -3.8750e+00,
          -1.5000e+00, -1.4453e-01],
         [-2.4219e-01, -2.3145e-01, -2.9102e-01,  ..., -3.5156e+00,
           2.3281e+00, -1.8203e+00],
         [-3.3691e-02, -2.7832e-02, -2.0898e-01,  ...,  4.0312e+00,
           1.5391e+00,  1.8359e-01]],

        [[-8.8882e-04, -6.9275e-03, -3.6926e-03,  ..., -2.1094e-01,
          -4.8584e-02, -6.6406e-02],
         [ 7.1875e-01,  4.1016e-01, -6.3281e-01,  ...,  2.8438e+00,
          -3.8906e+00, -7.0703e-01],
         [ 1.3086e-01, -2.4805e-01, -8.0859e-01,  ...,  2.0000e+00,
          -4.7812e+00, -3.4531e+00],
         ...,
         [ 5.3711e-02,  2.3438e-01, -3.9062e-02,  ..., -3.8750e+00,
          -1.5000e+00, -1.4453e-01],
         [-2.4219e-01, -2.3145e-01, -2.9102e-01,  ..., -3.5156e+00,
           2.3281e+00, -1.8203e+00],
         [-3.3691e-02, -2.7832e-02, -2.0898e-01,  ...,  4.0312e+00,
           1.5391e+00,  1.8359e-01]],

        [[-8.8882e-04, -6.9275e-03, -3.6926e-03,  ..., -2.1094e-01,
          -4.8584e-02, -6.6406e-02],
         [ 7.1875e-01,  4.1016e-01, -6.3281e-01,  ...,  2.8438e+00,
          -3.8906e+00, -7.0703e-01],
         [ 1.3086e-01, -2.4805e-01, -8.0859e-01,  ...,  2.0000e+00,
          -4.7812e+00, -3.4531e+00],
         ...,
         [ 5.3711e-02,  2.3438e-01, -3.9062e-02,  ..., -3.8750e+00,
          -1.5000e+00, -1.4453e-01],
         [-2.4219e-01, -2.3145e-01, -2.9102e-01,  ..., -3.5156e+00,
           2.3281e+00, -1.8203e+00],
         [-3.3691e-02, -2.7832e-02, -2.0898e-01,  ...,  4.0312e+00,
           1.5391e+00,  1.8359e-01]]], dtype=torch.bfloat16), tensor([[[-4.9210e-04, -2.6398e-03,  5.1270e-03,  ...,  1.7929e-03,
           4.4556e-03,  9.5825e-03],
         [-7.7637e-02, -2.2949e-01, -3.5742e-01,  ...,  1.1597e-02,
          -8.4766e-01, -5.3906e-01],
         [-3.6719e-01,  3.7500e-01,  4.3164e-01,  ..., -1.7188e-01,
          -4.4922e-01,  4.8438e-01],
         ...,
         [-4.8242e-01,  2.4023e-01, -1.5430e-01,  ...,  5.6250e-01,
           4.3945e-01, -5.1953e-01],
         [-3.6133e-01, -6.0059e-02,  1.7480e-01,  ...,  6.4062e-01,
          -2.8906e-01, -3.7500e-01],
         [-5.4688e-01,  2.2363e-01,  1.1572e-01,  ..., -4.1992e-01,
           4.5312e-01,  2.9688e-01]],

        [[-4.9210e-04, -2.6398e-03,  5.1270e-03,  ...,  1.7929e-03,
           4.4556e-03,  9.5825e-03],
         [-7.7637e-02, -2.2949e-01, -3.5742e-01,  ...,  1.1597e-02,
          -8.4766e-01, -5.3906e-01],
         [-3.6719e-01,  3.7500e-01,  4.3164e-01,  ..., -1.7188e-01,
          -4.4922e-01,  4.8438e-01],
         ...,
         [-4.8242e-01,  2.4023e-01, -1.5430e-01,  ...,  5.6250e-01,
           4.3945e-01, -5.1953e-01],
         [-3.6133e-01, -6.0059e-02,  1.7480e-01,  ...,  6.4062e-01,
          -2.8906e-01, -3.7500e-01],
         [-5.4688e-01,  2.2363e-01,  1.1572e-01,  ..., -4.1992e-01,
           4.5312e-01,  2.9688e-01]],

        [[-4.9210e-04, -2.6398e-03,  5.1270e-03,  ...,  1.7929e-03,
           4.4556e-03,  9.5825e-03],
         [-7.7637e-02, -2.2949e-01, -3.5742e-01,  ...,  1.1597e-02,
          -8.4766e-01, -5.3906e-01],
         [-3.6719e-01,  3.7500e-01,  4.3164e-01,  ..., -1.7188e-01,
          -4.4922e-01,  4.8438e-01],
         ...,
         [-4.8242e-01,  2.4023e-01, -1.5430e-01,  ...,  5.6250e-01,
           4.3945e-01, -5.1953e-01],
         [-3.6133e-01, -6.0059e-02,  1.7480e-01,  ...,  6.4062e-01,
          -2.8906e-01, -3.7500e-01],
         [-5.4688e-01,  2.2363e-01,  1.1572e-01,  ..., -4.1992e-01,
           4.5312e-01,  2.9688e-01]],

        ...,

        [[-1.7456e-02,  5.1514e-02, -1.7262e-04,  ...,  2.8968e-05,
           2.8687e-03,  1.8066e-02],
         [-1.3125e+00,  4.2969e-02,  1.1719e-01,  ...,  7.4219e-01,
          -8.2812e-01, -4.1602e-01],
         [-8.3984e-01, -2.1406e+00, -1.8828e+00,  ..., -1.0312e+00,
          -6.6797e-01,  1.4355e-01],
         ...,
         [ 6.3672e-01, -1.2344e+00, -4.1211e-01,  ...,  3.9368e-03,
           5.5176e-02,  5.3125e-01],
         [ 2.6172e-01, -4.1260e-02,  7.6660e-02,  ...,  3.4961e-01,
           6.6797e-01,  2.0630e-02],
         [-4.4531e-01, -5.5859e-01, -1.0938e+00,  ...,  7.1875e-01,
          -3.8672e-01, -2.6758e-01]],

        [[-1.7456e-02,  5.1514e-02, -1.7262e-04,  ...,  2.8968e-05,
           2.8687e-03,  1.8066e-02],
         [-1.3125e+00,  4.2969e-02,  1.1719e-01,  ...,  7.4219e-01,
          -8.2812e-01, -4.1602e-01],
         [-8.3984e-01, -2.1406e+00, -1.8828e+00,  ..., -1.0312e+00,
          -6.6797e-01,  1.4355e-01],
         ...,
         [ 6.3672e-01, -1.2344e+00, -4.1211e-01,  ...,  3.9368e-03,
           5.5176e-02,  5.3125e-01],
         [ 2.6172e-01, -4.1260e-02,  7.6660e-02,  ...,  3.4961e-01,
           6.6797e-01,  2.0630e-02],
         [-4.4531e-01, -5.5859e-01, -1.0938e+00,  ...,  7.1875e-01,
          -3.8672e-01, -2.6758e-01]],

        [[-1.7456e-02,  5.1514e-02, -1.7262e-04,  ...,  2.8968e-05,
           2.8687e-03,  1.8066e-02],
         [-1.3125e+00,  4.2969e-02,  1.1719e-01,  ...,  7.4219e-01,
          -8.2812e-01, -4.1602e-01],
         [-8.3984e-01, -2.1406e+00, -1.8828e+00,  ..., -1.0312e+00,
          -6.6797e-01,  1.4355e-01],
         ...,
         [ 6.3672e-01, -1.2344e+00, -4.1211e-01,  ...,  3.9368e-03,
           5.5176e-02,  5.3125e-01],
         [ 2.6172e-01, -4.1260e-02,  7.6660e-02,  ...,  3.4961e-01,
           6.6797e-01,  2.0630e-02],
         [-4.4531e-01, -5.5859e-01, -1.0938e+00,  ...,  7.1875e-01,
          -3.8672e-01, -2.6758e-01]]], dtype=torch.bfloat16)), (tensor([[[ 6.1646e-03, -5.6763e-03,  5.2490e-03,  ..., -9.8633e-02,
          -2.2656e-01, -2.7954e-02],
         [-1.6406e+00, -1.4844e-01,  2.5391e-01,  ..., -2.9688e-01,
          -2.1562e+00,  4.9072e-02],
         [-2.0625e+00,  9.0234e-01, -6.2109e-01,  ...,  1.6875e+00,
          -1.6250e+00, -1.6562e+00],
         ...,
         [ 6.8359e-02, -9.4531e-01,  2.7344e-02,  ..., -2.6094e+00,
           1.7656e+00, -1.4297e+00],
         [ 3.0078e-01,  1.1475e-01, -6.1719e-01,  ..., -1.2422e+00,
           5.9375e-01, -2.5781e+00],
         [ 1.4219e+00,  2.0156e+00,  3.4570e-01,  ..., -1.0312e+00,
          -3.3281e+00,  3.0938e+00]],

        [[ 6.1646e-03, -5.6763e-03,  5.2490e-03,  ..., -9.8633e-02,
          -2.2656e-01, -2.7954e-02],
         [-1.6406e+00, -1.4844e-01,  2.5391e-01,  ..., -2.9688e-01,
          -2.1562e+00,  4.9072e-02],
         [-2.0625e+00,  9.0234e-01, -6.2109e-01,  ...,  1.6875e+00,
          -1.6250e+00, -1.6562e+00],
         ...,
         [ 6.8359e-02, -9.4531e-01,  2.7344e-02,  ..., -2.6094e+00,
           1.7656e+00, -1.4297e+00],
         [ 3.0078e-01,  1.1475e-01, -6.1719e-01,  ..., -1.2422e+00,
           5.9375e-01, -2.5781e+00],
         [ 1.4219e+00,  2.0156e+00,  3.4570e-01,  ..., -1.0312e+00,
          -3.3281e+00,  3.0938e+00]],

        [[ 6.1646e-03, -5.6763e-03,  5.2490e-03,  ..., -9.8633e-02,
          -2.2656e-01, -2.7954e-02],
         [-1.6406e+00, -1.4844e-01,  2.5391e-01,  ..., -2.9688e-01,
          -2.1562e+00,  4.9072e-02],
         [-2.0625e+00,  9.0234e-01, -6.2109e-01,  ...,  1.6875e+00,
          -1.6250e+00, -1.6562e+00],
         ...,
         [ 6.8359e-02, -9.4531e-01,  2.7344e-02,  ..., -2.6094e+00,
           1.7656e+00, -1.4297e+00],
         [ 3.0078e-01,  1.1475e-01, -6.1719e-01,  ..., -1.2422e+00,
           5.9375e-01, -2.5781e+00],
         [ 1.4219e+00,  2.0156e+00,  3.4570e-01,  ..., -1.0312e+00,
          -3.3281e+00,  3.0938e+00]],

        ...,

        [[ 2.3193e-03,  1.8921e-02,  1.1902e-02,  ...,  3.8574e-02,
           1.0620e-02,  1.2207e-01],
         [ 9.7656e-01, -4.1016e-02,  1.2031e+00,  ...,  4.9805e-01,
           4.1211e-01, -8.9453e-01],
         [ 9.9609e-01,  1.0938e+00,  2.0781e+00,  ...,  2.4062e+00,
          -1.6357e-02, -6.0156e-01],
         ...,
         [-6.6016e-01, -1.5938e+00,  6.1719e-01,  ..., -9.5312e-01,
           1.6172e+00, -2.0781e+00],
         [-6.6016e-01, -2.3047e-01,  3.8281e-01,  ...,  1.3906e+00,
           2.0938e+00, -1.7188e+00],
         [-2.1562e+00,  5.9375e-01,  2.1875e+00,  ...,  1.6406e+00,
           6.1328e-01, -4.8047e-01]],

        [[ 2.3193e-03,  1.8921e-02,  1.1902e-02,  ...,  3.8574e-02,
           1.0620e-02,  1.2207e-01],
         [ 9.7656e-01, -4.1016e-02,  1.2031e+00,  ...,  4.9805e-01,
           4.1211e-01, -8.9453e-01],
         [ 9.9609e-01,  1.0938e+00,  2.0781e+00,  ...,  2.4062e+00,
          -1.6357e-02, -6.0156e-01],
         ...,
         [-6.6016e-01, -1.5938e+00,  6.1719e-01,  ..., -9.5312e-01,
           1.6172e+00, -2.0781e+00],
         [-6.6016e-01, -2.3047e-01,  3.8281e-01,  ...,  1.3906e+00,
           2.0938e+00, -1.7188e+00],
         [-2.1562e+00,  5.9375e-01,  2.1875e+00,  ...,  1.6406e+00,
           6.1328e-01, -4.8047e-01]],

        [[ 2.3193e-03,  1.8921e-02,  1.1902e-02,  ...,  3.8574e-02,
           1.0620e-02,  1.2207e-01],
         [ 9.7656e-01, -4.1016e-02,  1.2031e+00,  ...,  4.9805e-01,
           4.1211e-01, -8.9453e-01],
         [ 9.9609e-01,  1.0938e+00,  2.0781e+00,  ...,  2.4062e+00,
          -1.6357e-02, -6.0156e-01],
         ...,
         [-6.6016e-01, -1.5938e+00,  6.1719e-01,  ..., -9.5312e-01,
           1.6172e+00, -2.0781e+00],
         [-6.6016e-01, -2.3047e-01,  3.8281e-01,  ...,  1.3906e+00,
           2.0938e+00, -1.7188e+00],
         [-2.1562e+00,  5.9375e-01,  2.1875e+00,  ...,  1.6406e+00,
           6.1328e-01, -4.8047e-01]]], dtype=torch.bfloat16), tensor([[[-0.0067, -0.0065, -0.0085,  ..., -0.0104, -0.0156, -0.0349],
         [-0.1035,  0.9180,  0.8984,  ..., -0.1216,  0.4629,  0.6289],
         [-0.7266,  1.1094,  0.8711,  ...,  0.2344, -0.6406,  1.3047],
         ...,
         [ 1.2578,  0.3184, -0.1357,  ...,  0.3320,  0.8125, -1.1094],
         [-0.2871, -0.1533, -0.4902,  ...,  0.5195,  0.3477,  0.1377],
         [ 0.6992, -1.3984,  0.2617,  ..., -0.1904,  1.1797,  0.1328]],

        [[-0.0067, -0.0065, -0.0085,  ..., -0.0104, -0.0156, -0.0349],
         [-0.1035,  0.9180,  0.8984,  ..., -0.1216,  0.4629,  0.6289],
         [-0.7266,  1.1094,  0.8711,  ...,  0.2344, -0.6406,  1.3047],
         ...,
         [ 1.2578,  0.3184, -0.1357,  ...,  0.3320,  0.8125, -1.1094],
         [-0.2871, -0.1533, -0.4902,  ...,  0.5195,  0.3477,  0.1377],
         [ 0.6992, -1.3984,  0.2617,  ..., -0.1904,  1.1797,  0.1328]],

        [[-0.0067, -0.0065, -0.0085,  ..., -0.0104, -0.0156, -0.0349],
         [-0.1035,  0.9180,  0.8984,  ..., -0.1216,  0.4629,  0.6289],
         [-0.7266,  1.1094,  0.8711,  ...,  0.2344, -0.6406,  1.3047],
         ...,
         [ 1.2578,  0.3184, -0.1357,  ...,  0.3320,  0.8125, -1.1094],
         [-0.2871, -0.1533, -0.4902,  ...,  0.5195,  0.3477,  0.1377],
         [ 0.6992, -1.3984,  0.2617,  ..., -0.1904,  1.1797,  0.1328]],

        ...,

        [[ 0.0139, -0.0273,  0.0087,  ..., -0.0417,  0.0271,  0.0422],
         [ 0.2275,  0.4062, -0.0045,  ...,  0.7266, -0.3867, -0.6211],
         [-0.2422, -0.1992,  1.3438,  ...,  0.9805, -0.5234, -0.2324],
         ...,
         [ 0.6680, -1.5312, -0.2910,  ...,  0.2676,  0.0425,  0.1152],
         [ 0.5078, -0.8750, -0.6211,  ...,  0.1143,  0.4062,  0.1768],
         [-0.1030,  0.5312, -0.5469,  ..., -0.6172, -1.0312, -0.5117]],

        [[ 0.0139, -0.0273,  0.0087,  ..., -0.0417,  0.0271,  0.0422],
         [ 0.2275,  0.4062, -0.0045,  ...,  0.7266, -0.3867, -0.6211],
         [-0.2422, -0.1992,  1.3438,  ...,  0.9805, -0.5234, -0.2324],
         ...,
         [ 0.6680, -1.5312, -0.2910,  ...,  0.2676,  0.0425,  0.1152],
         [ 0.5078, -0.8750, -0.6211,  ...,  0.1143,  0.4062,  0.1768],
         [-0.1030,  0.5312, -0.5469,  ..., -0.6172, -1.0312, -0.5117]],

        [[ 0.0139, -0.0273,  0.0087,  ..., -0.0417,  0.0271,  0.0422],
         [ 0.2275,  0.4062, -0.0045,  ...,  0.7266, -0.3867, -0.6211],
         [-0.2422, -0.1992,  1.3438,  ...,  0.9805, -0.5234, -0.2324],
         ...,
         [ 0.6680, -1.5312, -0.2910,  ...,  0.2676,  0.0425,  0.1152],
         [ 0.5078, -0.8750, -0.6211,  ...,  0.1143,  0.4062,  0.1768],
         [-0.1030,  0.5312, -0.5469,  ..., -0.6172, -1.0312, -0.5117]]],
       dtype=torch.bfloat16)), (tensor([[[ 6.0425e-03, -9.7656e-03,  6.5002e-03,  ...,  1.8164e-01,
          -2.0938e+00, -4.0039e-02],
         [ 2.4023e-01,  2.0605e-01,  1.4297e+00,  ..., -9.3359e-01,
           6.4688e+00, -7.0312e-01],
         [-8.1543e-02,  7.2266e-01,  9.5312e-01,  ..., -9.1406e-01,
           7.5000e+00, -3.5938e-01],
         ...,
         [ 3.4570e-01, -3.7109e-01,  8.2422e-01,  ...,  6.2891e-01,
           4.9688e+00,  1.4688e+00],
         [-5.6250e-01, -2.4316e-01,  9.3750e-02,  ..., -8.0859e-01,
           4.9062e+00,  2.0938e+00],
         [-2.0156e+00, -1.1484e+00,  1.1094e+00,  ...,  1.5547e+00,
           6.3750e+00, -3.4531e+00]],

        [[ 6.0425e-03, -9.7656e-03,  6.5002e-03,  ...,  1.8164e-01,
          -2.0938e+00, -4.0039e-02],
         [ 2.4023e-01,  2.0605e-01,  1.4297e+00,  ..., -9.3359e-01,
           6.4688e+00, -7.0312e-01],
         [-8.1543e-02,  7.2266e-01,  9.5312e-01,  ..., -9.1406e-01,
           7.5000e+00, -3.5938e-01],
         ...,
         [ 3.4570e-01, -3.7109e-01,  8.2422e-01,  ...,  6.2891e-01,
           4.9688e+00,  1.4688e+00],
         [-5.6250e-01, -2.4316e-01,  9.3750e-02,  ..., -8.0859e-01,
           4.9062e+00,  2.0938e+00],
         [-2.0156e+00, -1.1484e+00,  1.1094e+00,  ...,  1.5547e+00,
           6.3750e+00, -3.4531e+00]],

        [[ 6.0425e-03, -9.7656e-03,  6.5002e-03,  ...,  1.8164e-01,
          -2.0938e+00, -4.0039e-02],
         [ 2.4023e-01,  2.0605e-01,  1.4297e+00,  ..., -9.3359e-01,
           6.4688e+00, -7.0312e-01],
         [-8.1543e-02,  7.2266e-01,  9.5312e-01,  ..., -9.1406e-01,
           7.5000e+00, -3.5938e-01],
         ...,
         [ 3.4570e-01, -3.7109e-01,  8.2422e-01,  ...,  6.2891e-01,
           4.9688e+00,  1.4688e+00],
         [-5.6250e-01, -2.4316e-01,  9.3750e-02,  ..., -8.0859e-01,
           4.9062e+00,  2.0938e+00],
         [-2.0156e+00, -1.1484e+00,  1.1094e+00,  ...,  1.5547e+00,
           6.3750e+00, -3.4531e+00]],

        ...,

        [[-4.9744e-03, -7.0190e-04, -9.7656e-03,  ...,  3.0859e-01,
          -4.3750e-01, -6.5430e-02],
         [ 1.8906e+00, -1.2500e+00,  9.7656e-01,  ..., -7.7344e-01,
           1.6250e+00, -5.1172e-01],
         [-1.9531e-01,  5.0781e-01, -6.7188e-01,  ..., -1.9531e+00,
           1.1719e+00,  3.0859e-01],
         ...,
         [-1.0781e+00, -4.9414e-01,  8.0078e-01,  ..., -1.6172e+00,
           2.1680e-01,  5.0537e-02],
         [-6.1328e-01, -1.0391e+00,  5.9766e-01,  ..., -2.0156e+00,
          -4.0039e-01, -4.5166e-02],
         [-3.2812e-01, -2.2188e+00,  2.3242e-01,  ..., -1.7344e+00,
           1.0938e+00, -1.6211e-01]],

        [[-4.9744e-03, -7.0190e-04, -9.7656e-03,  ...,  3.0859e-01,
          -4.3750e-01, -6.5430e-02],
         [ 1.8906e+00, -1.2500e+00,  9.7656e-01,  ..., -7.7344e-01,
           1.6250e+00, -5.1172e-01],
         [-1.9531e-01,  5.0781e-01, -6.7188e-01,  ..., -1.9531e+00,
           1.1719e+00,  3.0859e-01],
         ...,
         [-1.0781e+00, -4.9414e-01,  8.0078e-01,  ..., -1.6172e+00,
           2.1680e-01,  5.0537e-02],
         [-6.1328e-01, -1.0391e+00,  5.9766e-01,  ..., -2.0156e+00,
          -4.0039e-01, -4.5166e-02],
         [-3.2812e-01, -2.2188e+00,  2.3242e-01,  ..., -1.7344e+00,
           1.0938e+00, -1.6211e-01]],

        [[-4.9744e-03, -7.0190e-04, -9.7656e-03,  ...,  3.0859e-01,
          -4.3750e-01, -6.5430e-02],
         [ 1.8906e+00, -1.2500e+00,  9.7656e-01,  ..., -7.7344e-01,
           1.6250e+00, -5.1172e-01],
         [-1.9531e-01,  5.0781e-01, -6.7188e-01,  ..., -1.9531e+00,
           1.1719e+00,  3.0859e-01],
         ...,
         [-1.0781e+00, -4.9414e-01,  8.0078e-01,  ..., -1.6172e+00,
           2.1680e-01,  5.0537e-02],
         [-6.1328e-01, -1.0391e+00,  5.9766e-01,  ..., -2.0156e+00,
          -4.0039e-01, -4.5166e-02],
         [-3.2812e-01, -2.2188e+00,  2.3242e-01,  ..., -1.7344e+00,
           1.0938e+00, -1.6211e-01]]], dtype=torch.bfloat16), tensor([[[ 0.0245,  0.0625, -0.0684,  ..., -0.0576, -0.0496,  0.0322],
         [-0.6133,  0.3223, -0.2158,  ...,  0.2930,  0.2852, -0.3613],
         [ 0.2656,  0.1973,  0.1157,  ..., -0.3672, -0.1279,  0.3633],
         ...,
         [ 0.4941, -0.6875,  0.2812,  ...,  0.3164,  0.3203, -0.3379],
         [-0.2754, -0.9141, -0.3926,  ...,  0.4922,  0.0208, -0.1631],
         [ 0.2891,  0.6836, -0.7148,  ...,  0.1895, -0.1787,  0.4590]],

        [[ 0.0245,  0.0625, -0.0684,  ..., -0.0576, -0.0496,  0.0322],
         [-0.6133,  0.3223, -0.2158,  ...,  0.2930,  0.2852, -0.3613],
         [ 0.2656,  0.1973,  0.1157,  ..., -0.3672, -0.1279,  0.3633],
         ...,
         [ 0.4941, -0.6875,  0.2812,  ...,  0.3164,  0.3203, -0.3379],
         [-0.2754, -0.9141, -0.3926,  ...,  0.4922,  0.0208, -0.1631],
         [ 0.2891,  0.6836, -0.7148,  ...,  0.1895, -0.1787,  0.4590]],

        [[ 0.0245,  0.0625, -0.0684,  ..., -0.0576, -0.0496,  0.0322],
         [-0.6133,  0.3223, -0.2158,  ...,  0.2930,  0.2852, -0.3613],
         [ 0.2656,  0.1973,  0.1157,  ..., -0.3672, -0.1279,  0.3633],
         ...,
         [ 0.4941, -0.6875,  0.2812,  ...,  0.3164,  0.3203, -0.3379],
         [-0.2754, -0.9141, -0.3926,  ...,  0.4922,  0.0208, -0.1631],
         [ 0.2891,  0.6836, -0.7148,  ...,  0.1895, -0.1787,  0.4590]],

        ...,

        [[ 0.0066, -0.0017,  0.0037,  ...,  0.0084,  0.0132,  0.0036],
         [-0.4062, -1.0234,  0.8516,  ...,  0.2617,  1.2031,  0.4395],
         [-0.5898,  0.8008,  0.1118,  ...,  0.6875,  1.5156, -0.1387],
         ...,
         [ 0.5977,  1.6406,  0.4316,  ...,  0.0053, -1.3516, -0.8086],
         [ 0.8164, -0.5195,  0.6992,  ...,  1.5391,  0.2559,  0.5469],
         [-0.4102, -0.0165,  0.1055,  ..., -0.2617,  0.9688,  0.5547]],

        [[ 0.0066, -0.0017,  0.0037,  ...,  0.0084,  0.0132,  0.0036],
         [-0.4062, -1.0234,  0.8516,  ...,  0.2617,  1.2031,  0.4395],
         [-0.5898,  0.8008,  0.1118,  ...,  0.6875,  1.5156, -0.1387],
         ...,
         [ 0.5977,  1.6406,  0.4316,  ...,  0.0053, -1.3516, -0.8086],
         [ 0.8164, -0.5195,  0.6992,  ...,  1.5391,  0.2559,  0.5469],
         [-0.4102, -0.0165,  0.1055,  ..., -0.2617,  0.9688,  0.5547]],

        [[ 0.0066, -0.0017,  0.0037,  ...,  0.0084,  0.0132,  0.0036],
         [-0.4062, -1.0234,  0.8516,  ...,  0.2617,  1.2031,  0.4395],
         [-0.5898,  0.8008,  0.1118,  ...,  0.6875,  1.5156, -0.1387],
         ...,
         [ 0.5977,  1.6406,  0.4316,  ...,  0.0053, -1.3516, -0.8086],
         [ 0.8164, -0.5195,  0.6992,  ...,  1.5391,  0.2559,  0.5469],
         [-0.4102, -0.0165,  0.1055,  ..., -0.2617,  0.9688,  0.5547]]],
       dtype=torch.bfloat16)), (tensor([[[-3.7537e-03, -2.6245e-03,  1.2329e-02,  ...,  1.1768e-01,
          -4.6387e-02,  2.0508e-01],
         [-1.5781e+00, -1.7578e-02, -8.6719e-01,  ..., -1.2500e+00,
           7.1875e-01, -3.4961e-01],
         [-7.1875e-01, -1.1182e-01,  2.5195e-01,  ..., -1.3438e+00,
           3.3594e+00, -1.1172e+00],
         ...,
         [ 5.2344e-01,  1.4453e-01, -9.5703e-01,  ..., -1.8203e+00,
          -3.0664e-01,  2.5000e-01],
         [ 6.6406e-02, -5.5469e-01, -4.1406e-01,  ..., -6.2891e-01,
           3.4375e-01, -8.7402e-02],
         [ 1.0781e+00, -8.1641e-01, -5.1172e-01,  ...,  1.6094e+00,
           2.8125e+00, -2.4844e+00]],

        [[-3.7537e-03, -2.6245e-03,  1.2329e-02,  ...,  1.1768e-01,
          -4.6387e-02,  2.0508e-01],
         [-1.5781e+00, -1.7578e-02, -8.6719e-01,  ..., -1.2500e+00,
           7.1875e-01, -3.4961e-01],
         [-7.1875e-01, -1.1182e-01,  2.5195e-01,  ..., -1.3438e+00,
           3.3594e+00, -1.1172e+00],
         ...,
         [ 5.2344e-01,  1.4453e-01, -9.5703e-01,  ..., -1.8203e+00,
          -3.0664e-01,  2.5000e-01],
         [ 6.6406e-02, -5.5469e-01, -4.1406e-01,  ..., -6.2891e-01,
           3.4375e-01, -8.7402e-02],
         [ 1.0781e+00, -8.1641e-01, -5.1172e-01,  ...,  1.6094e+00,
           2.8125e+00, -2.4844e+00]],

        [[-3.7537e-03, -2.6245e-03,  1.2329e-02,  ...,  1.1768e-01,
          -4.6387e-02,  2.0508e-01],
         [-1.5781e+00, -1.7578e-02, -8.6719e-01,  ..., -1.2500e+00,
           7.1875e-01, -3.4961e-01],
         [-7.1875e-01, -1.1182e-01,  2.5195e-01,  ..., -1.3438e+00,
           3.3594e+00, -1.1172e+00],
         ...,
         [ 5.2344e-01,  1.4453e-01, -9.5703e-01,  ..., -1.8203e+00,
          -3.0664e-01,  2.5000e-01],
         [ 6.6406e-02, -5.5469e-01, -4.1406e-01,  ..., -6.2891e-01,
           3.4375e-01, -8.7402e-02],
         [ 1.0781e+00, -8.1641e-01, -5.1172e-01,  ...,  1.6094e+00,
           2.8125e+00, -2.4844e+00]],

        ...,

        [[-1.0315e-02,  6.5804e-05,  1.0681e-02,  ...,  5.1172e-01,
           1.9043e-02, -2.9907e-02],
         [ 9.0625e-01, -7.8125e-02, -1.0156e+00,  ..., -3.0000e+00,
           4.2578e-01, -9.8047e-01],
         [ 6.7188e-01,  3.4180e-01,  4.0625e-01,  ..., -1.3594e+00,
           1.4844e-01, -1.0312e+00],
         ...,
         [-7.0312e-01, -9.0625e-01, -5.6641e-01,  ..., -1.2109e+00,
          -3.1094e+00, -8.7500e-01],
         [ 2.1680e-01, -1.4648e-01,  4.7852e-02,  ..., -2.1362e-03,
          -9.1016e-01,  3.8867e-01],
         [ 8.3203e-01, -1.8594e+00, -1.6875e+00,  ...,  2.7148e-01,
          -4.1992e-01,  4.6680e-01]],

        [[-1.0315e-02,  6.5804e-05,  1.0681e-02,  ...,  5.1172e-01,
           1.9043e-02, -2.9907e-02],
         [ 9.0625e-01, -7.8125e-02, -1.0156e+00,  ..., -3.0000e+00,
           4.2578e-01, -9.8047e-01],
         [ 6.7188e-01,  3.4180e-01,  4.0625e-01,  ..., -1.3594e+00,
           1.4844e-01, -1.0312e+00],
         ...,
         [-7.0312e-01, -9.0625e-01, -5.6641e-01,  ..., -1.2109e+00,
          -3.1094e+00, -8.7500e-01],
         [ 2.1680e-01, -1.4648e-01,  4.7852e-02,  ..., -2.1362e-03,
          -9.1016e-01,  3.8867e-01],
         [ 8.3203e-01, -1.8594e+00, -1.6875e+00,  ...,  2.7148e-01,
          -4.1992e-01,  4.6680e-01]],

        [[-1.0315e-02,  6.5804e-05,  1.0681e-02,  ...,  5.1172e-01,
           1.9043e-02, -2.9907e-02],
         [ 9.0625e-01, -7.8125e-02, -1.0156e+00,  ..., -3.0000e+00,
           4.2578e-01, -9.8047e-01],
         [ 6.7188e-01,  3.4180e-01,  4.0625e-01,  ..., -1.3594e+00,
           1.4844e-01, -1.0312e+00],
         ...,
         [-7.0312e-01, -9.0625e-01, -5.6641e-01,  ..., -1.2109e+00,
          -3.1094e+00, -8.7500e-01],
         [ 2.1680e-01, -1.4648e-01,  4.7852e-02,  ..., -2.1362e-03,
          -9.1016e-01,  3.8867e-01],
         [ 8.3203e-01, -1.8594e+00, -1.6875e+00,  ...,  2.7148e-01,
          -4.1992e-01,  4.6680e-01]]], dtype=torch.bfloat16), tensor([[[-2.2339e-02, -2.2583e-02, -1.3672e-01,  ...,  2.0386e-02,
           5.4016e-03, -2.3193e-02],
         [-7.5000e-01, -3.5547e-01,  6.2500e-01,  ...,  3.8281e-01,
           3.4961e-01,  3.8477e-01],
         [ 2.2266e-01, -3.2617e-01,  4.4727e-01,  ..., -1.6504e-01,
           5.8203e-01,  2.3047e-01],
         ...,
         [-1.2109e-01, -1.7676e-01, -8.2422e-01,  ...,  4.3945e-01,
           2.3242e-01,  1.0312e+00],
         [-8.5938e-01,  1.5547e+00,  2.0801e-01,  ...,  9.2578e-01,
           4.7266e-01,  7.3438e-01],
         [ 9.5703e-01, -1.5918e-01, -5.7812e-01,  ..., -2.7930e-01,
           4.8242e-01,  4.6094e-01]],

        [[-2.2339e-02, -2.2583e-02, -1.3672e-01,  ...,  2.0386e-02,
           5.4016e-03, -2.3193e-02],
         [-7.5000e-01, -3.5547e-01,  6.2500e-01,  ...,  3.8281e-01,
           3.4961e-01,  3.8477e-01],
         [ 2.2266e-01, -3.2617e-01,  4.4727e-01,  ..., -1.6504e-01,
           5.8203e-01,  2.3047e-01],
         ...,
         [-1.2109e-01, -1.7676e-01, -8.2422e-01,  ...,  4.3945e-01,
           2.3242e-01,  1.0312e+00],
         [-8.5938e-01,  1.5547e+00,  2.0801e-01,  ...,  9.2578e-01,
           4.7266e-01,  7.3438e-01],
         [ 9.5703e-01, -1.5918e-01, -5.7812e-01,  ..., -2.7930e-01,
           4.8242e-01,  4.6094e-01]],

        [[-2.2339e-02, -2.2583e-02, -1.3672e-01,  ...,  2.0386e-02,
           5.4016e-03, -2.3193e-02],
         [-7.5000e-01, -3.5547e-01,  6.2500e-01,  ...,  3.8281e-01,
           3.4961e-01,  3.8477e-01],
         [ 2.2266e-01, -3.2617e-01,  4.4727e-01,  ..., -1.6504e-01,
           5.8203e-01,  2.3047e-01],
         ...,
         [-1.2109e-01, -1.7676e-01, -8.2422e-01,  ...,  4.3945e-01,
           2.3242e-01,  1.0312e+00],
         [-8.5938e-01,  1.5547e+00,  2.0801e-01,  ...,  9.2578e-01,
           4.7266e-01,  7.3438e-01],
         [ 9.5703e-01, -1.5918e-01, -5.7812e-01,  ..., -2.7930e-01,
           4.8242e-01,  4.6094e-01]],

        ...,

        [[-1.1963e-02,  1.4801e-03, -6.3171e-03,  ...,  3.9062e-03,
          -2.3315e-02, -1.6327e-03],
         [-4.7852e-01, -4.9219e-01,  1.7285e-01,  ..., -3.8086e-01,
           7.8516e-01,  5.4297e-01],
         [-1.1953e+00, -3.8281e-01, -4.9609e-01,  ..., -6.0938e-01,
           5.3516e-01,  4.7266e-01],
         ...,
         [-2.1289e-01,  2.1484e-01,  7.0312e-01,  ...,  5.2344e-01,
          -5.7031e-01, -6.7188e-01],
         [-1.8359e-01,  2.7100e-02,  2.2754e-01,  ...,  7.3047e-01,
           6.4941e-02,  1.7871e-01],
         [-1.3477e-01,  1.2329e-02,  1.0156e+00,  ...,  6.4062e-01,
           1.2817e-02,  1.2878e-02]],

        [[-1.1963e-02,  1.4801e-03, -6.3171e-03,  ...,  3.9062e-03,
          -2.3315e-02, -1.6327e-03],
         [-4.7852e-01, -4.9219e-01,  1.7285e-01,  ..., -3.8086e-01,
           7.8516e-01,  5.4297e-01],
         [-1.1953e+00, -3.8281e-01, -4.9609e-01,  ..., -6.0938e-01,
           5.3516e-01,  4.7266e-01],
         ...,
         [-2.1289e-01,  2.1484e-01,  7.0312e-01,  ...,  5.2344e-01,
          -5.7031e-01, -6.7188e-01],
         [-1.8359e-01,  2.7100e-02,  2.2754e-01,  ...,  7.3047e-01,
           6.4941e-02,  1.7871e-01],
         [-1.3477e-01,  1.2329e-02,  1.0156e+00,  ...,  6.4062e-01,
           1.2817e-02,  1.2878e-02]],

        [[-1.1963e-02,  1.4801e-03, -6.3171e-03,  ...,  3.9062e-03,
          -2.3315e-02, -1.6327e-03],
         [-4.7852e-01, -4.9219e-01,  1.7285e-01,  ..., -3.8086e-01,
           7.8516e-01,  5.4297e-01],
         [-1.1953e+00, -3.8281e-01, -4.9609e-01,  ..., -6.0938e-01,
           5.3516e-01,  4.7266e-01],
         ...,
         [-2.1289e-01,  2.1484e-01,  7.0312e-01,  ...,  5.2344e-01,
          -5.7031e-01, -6.7188e-01],
         [-1.8359e-01,  2.7100e-02,  2.2754e-01,  ...,  7.3047e-01,
           6.4941e-02,  1.7871e-01],
         [-1.3477e-01,  1.2329e-02,  1.0156e+00,  ...,  6.4062e-01,
           1.2817e-02,  1.2878e-02]]], dtype=torch.bfloat16)), (tensor([[[ 7.9956e-03,  4.7302e-03,  9.1553e-05,  ...,  4.0234e-01,
          -7.4707e-02, -5.3516e-01],
         [ 1.8750e-01, -2.0312e+00,  1.7109e+00,  ..., -2.7344e+00,
           9.1406e-01,  3.5938e-01],
         [-6.9922e-01,  2.0508e-01,  7.5781e-01,  ..., -2.6875e+00,
           4.7070e-01, -4.1504e-02],
         ...,
         [-9.1797e-02,  1.2812e+00,  9.9609e-01,  ..., -3.6406e+00,
          -3.3203e-02,  1.7891e+00],
         [ 7.5781e-01,  3.4570e-01,  9.4141e-01,  ..., -3.6875e+00,
          -2.3750e+00,  3.2656e+00],
         [ 2.7656e+00, -6.6797e-01,  1.9688e+00,  ..., -4.1250e+00,
           2.7188e+00,  1.1250e+00]],

        [[ 7.9956e-03,  4.7302e-03,  9.1553e-05,  ...,  4.0234e-01,
          -7.4707e-02, -5.3516e-01],
         [ 1.8750e-01, -2.0312e+00,  1.7109e+00,  ..., -2.7344e+00,
           9.1406e-01,  3.5938e-01],
         [-6.9922e-01,  2.0508e-01,  7.5781e-01,  ..., -2.6875e+00,
           4.7070e-01, -4.1504e-02],
         ...,
         [-9.1797e-02,  1.2812e+00,  9.9609e-01,  ..., -3.6406e+00,
          -3.3203e-02,  1.7891e+00],
         [ 7.5781e-01,  3.4570e-01,  9.4141e-01,  ..., -3.6875e+00,
          -2.3750e+00,  3.2656e+00],
         [ 2.7656e+00, -6.6797e-01,  1.9688e+00,  ..., -4.1250e+00,
           2.7188e+00,  1.1250e+00]],

        [[ 7.9956e-03,  4.7302e-03,  9.1553e-05,  ...,  4.0234e-01,
          -7.4707e-02, -5.3516e-01],
         [ 1.8750e-01, -2.0312e+00,  1.7109e+00,  ..., -2.7344e+00,
           9.1406e-01,  3.5938e-01],
         [-6.9922e-01,  2.0508e-01,  7.5781e-01,  ..., -2.6875e+00,
           4.7070e-01, -4.1504e-02],
         ...,
         [-9.1797e-02,  1.2812e+00,  9.9609e-01,  ..., -3.6406e+00,
          -3.3203e-02,  1.7891e+00],
         [ 7.5781e-01,  3.4570e-01,  9.4141e-01,  ..., -3.6875e+00,
          -2.3750e+00,  3.2656e+00],
         [ 2.7656e+00, -6.6797e-01,  1.9688e+00,  ..., -4.1250e+00,
           2.7188e+00,  1.1250e+00]],

        ...,

        [[ 3.9368e-03,  7.5378e-03, -1.5137e-02,  ...,  9.0820e-02,
          -4.9805e-01, -1.0469e+00],
         [ 1.6719e+00, -2.3633e-01,  4.3359e-01,  ..., -6.9531e-01,
           4.8438e+00,  3.9531e+00],
         [ 2.9297e-01,  6.9922e-01,  4.2383e-01,  ..., -1.3359e+00,
           3.0156e+00,  5.8125e+00],
         ...,
         [-1.7109e+00,  3.1250e-02, -6.0938e-01,  ..., -2.2070e-01,
          -5.8984e-01,  3.6562e+00],
         [-2.3242e-01,  1.3086e-01,  5.7422e-01,  ...,  1.1094e+00,
           2.0386e-02,  4.0625e+00],
         [ 7.1875e-01, -1.0625e+00,  1.0312e+00,  ..., -3.7344e+00,
           3.1719e+00,  5.3750e+00]],

        [[ 3.9368e-03,  7.5378e-03, -1.5137e-02,  ...,  9.0820e-02,
          -4.9805e-01, -1.0469e+00],
         [ 1.6719e+00, -2.3633e-01,  4.3359e-01,  ..., -6.9531e-01,
           4.8438e+00,  3.9531e+00],
         [ 2.9297e-01,  6.9922e-01,  4.2383e-01,  ..., -1.3359e+00,
           3.0156e+00,  5.8125e+00],
         ...,
         [-1.7109e+00,  3.1250e-02, -6.0938e-01,  ..., -2.2070e-01,
          -5.8984e-01,  3.6562e+00],
         [-2.3242e-01,  1.3086e-01,  5.7422e-01,  ...,  1.1094e+00,
           2.0386e-02,  4.0625e+00],
         [ 7.1875e-01, -1.0625e+00,  1.0312e+00,  ..., -3.7344e+00,
           3.1719e+00,  5.3750e+00]],

        [[ 3.9368e-03,  7.5378e-03, -1.5137e-02,  ...,  9.0820e-02,
          -4.9805e-01, -1.0469e+00],
         [ 1.6719e+00, -2.3633e-01,  4.3359e-01,  ..., -6.9531e-01,
           4.8438e+00,  3.9531e+00],
         [ 2.9297e-01,  6.9922e-01,  4.2383e-01,  ..., -1.3359e+00,
           3.0156e+00,  5.8125e+00],
         ...,
         [-1.7109e+00,  3.1250e-02, -6.0938e-01,  ..., -2.2070e-01,
          -5.8984e-01,  3.6562e+00],
         [-2.3242e-01,  1.3086e-01,  5.7422e-01,  ...,  1.1094e+00,
           2.0386e-02,  4.0625e+00],
         [ 7.1875e-01, -1.0625e+00,  1.0312e+00,  ..., -3.7344e+00,
           3.1719e+00,  5.3750e+00]]], dtype=torch.bfloat16), tensor([[[-7.3547e-03, -4.1504e-03, -1.7334e-02,  ...,  1.0132e-02,
          -1.4725e-03,  4.8828e-03],
         [-9.4531e-01,  1.7188e-01,  2.5781e+00,  ..., -4.9219e-01,
           1.5234e+00,  1.3594e+00],
         [-2.6562e-01, -1.1484e+00,  1.6406e+00,  ..., -1.5078e+00,
           9.1016e-01,  1.6406e+00],
         ...,
         [-5.7031e-01,  3.6865e-02, -3.2227e-01,  ...,  3.5742e-01,
          -2.7930e-01, -3.2422e-01],
         [-1.4258e-01, -4.2578e-01, -6.1719e-01,  ...,  1.3770e-01,
          -2.7344e-01,  4.9023e-01],
         [-1.2969e+00, -4.5508e-01,  6.1719e-01,  ...,  1.6602e-01,
          -1.2512e-02, -3.4766e-01]],

        [[-7.3547e-03, -4.1504e-03, -1.7334e-02,  ...,  1.0132e-02,
          -1.4725e-03,  4.8828e-03],
         [-9.4531e-01,  1.7188e-01,  2.5781e+00,  ..., -4.9219e-01,
           1.5234e+00,  1.3594e+00],
         [-2.6562e-01, -1.1484e+00,  1.6406e+00,  ..., -1.5078e+00,
           9.1016e-01,  1.6406e+00],
         ...,
         [-5.7031e-01,  3.6865e-02, -3.2227e-01,  ...,  3.5742e-01,
          -2.7930e-01, -3.2422e-01],
         [-1.4258e-01, -4.2578e-01, -6.1719e-01,  ...,  1.3770e-01,
          -2.7344e-01,  4.9023e-01],
         [-1.2969e+00, -4.5508e-01,  6.1719e-01,  ...,  1.6602e-01,
          -1.2512e-02, -3.4766e-01]],

        [[-7.3547e-03, -4.1504e-03, -1.7334e-02,  ...,  1.0132e-02,
          -1.4725e-03,  4.8828e-03],
         [-9.4531e-01,  1.7188e-01,  2.5781e+00,  ..., -4.9219e-01,
           1.5234e+00,  1.3594e+00],
         [-2.6562e-01, -1.1484e+00,  1.6406e+00,  ..., -1.5078e+00,
           9.1016e-01,  1.6406e+00],
         ...,
         [-5.7031e-01,  3.6865e-02, -3.2227e-01,  ...,  3.5742e-01,
          -2.7930e-01, -3.2422e-01],
         [-1.4258e-01, -4.2578e-01, -6.1719e-01,  ...,  1.3770e-01,
          -2.7344e-01,  4.9023e-01],
         [-1.2969e+00, -4.5508e-01,  6.1719e-01,  ...,  1.6602e-01,
          -1.2512e-02, -3.4766e-01]],

        ...,

        [[-7.8735e-03,  1.6357e-02,  8.4686e-04,  ..., -9.6436e-03,
          -1.5747e-02,  1.2512e-02],
         [-1.0840e-01, -1.2031e+00, -2.6953e-01,  ..., -1.9653e-02,
          -2.6953e-01,  1.5991e-02],
         [ 5.1172e-01,  1.1328e+00, -1.2402e-01,  ...,  1.3281e-01,
           4.2773e-01, -2.9297e-01],
         ...,
         [-9.0625e-01, -5.1172e-01,  2.8076e-02,  ...,  5.0000e-01,
          -1.0469e+00, -1.0559e-02],
         [ 7.3242e-02, -2.7539e-01,  1.1963e-01,  ..., -6.3281e-01,
           6.6406e-01, -1.3086e-01],
         [-1.1172e+00,  1.8281e+00,  1.7031e+00,  ..., -6.9141e-01,
           9.6094e-01, -5.0391e-01]],

        [[-7.8735e-03,  1.6357e-02,  8.4686e-04,  ..., -9.6436e-03,
          -1.5747e-02,  1.2512e-02],
         [-1.0840e-01, -1.2031e+00, -2.6953e-01,  ..., -1.9653e-02,
          -2.6953e-01,  1.5991e-02],
         [ 5.1172e-01,  1.1328e+00, -1.2402e-01,  ...,  1.3281e-01,
           4.2773e-01, -2.9297e-01],
         ...,
         [-9.0625e-01, -5.1172e-01,  2.8076e-02,  ...,  5.0000e-01,
          -1.0469e+00, -1.0559e-02],
         [ 7.3242e-02, -2.7539e-01,  1.1963e-01,  ..., -6.3281e-01,
           6.6406e-01, -1.3086e-01],
         [-1.1172e+00,  1.8281e+00,  1.7031e+00,  ..., -6.9141e-01,
           9.6094e-01, -5.0391e-01]],

        [[-7.8735e-03,  1.6357e-02,  8.4686e-04,  ..., -9.6436e-03,
          -1.5747e-02,  1.2512e-02],
         [-1.0840e-01, -1.2031e+00, -2.6953e-01,  ..., -1.9653e-02,
          -2.6953e-01,  1.5991e-02],
         [ 5.1172e-01,  1.1328e+00, -1.2402e-01,  ...,  1.3281e-01,
           4.2773e-01, -2.9297e-01],
         ...,
         [-9.0625e-01, -5.1172e-01,  2.8076e-02,  ...,  5.0000e-01,
          -1.0469e+00, -1.0559e-02],
         [ 7.3242e-02, -2.7539e-01,  1.1963e-01,  ..., -6.3281e-01,
           6.6406e-01, -1.3086e-01],
         [-1.1172e+00,  1.8281e+00,  1.7031e+00,  ..., -6.9141e-01,
           9.6094e-01, -5.0391e-01]]], dtype=torch.bfloat16)), (tensor([[[ 3.4332e-04,  1.2634e-02, -1.8433e-02,  ...,  3.8330e-02,
           1.4941e-01, -1.8828e+00],
         [ 9.5312e-01, -1.9219e+00, -1.5820e-01,  ..., -1.0391e+00,
          -6.2891e-01,  7.4688e+00],
         [-3.3594e-01, -6.1719e-01, -3.4180e-01,  ...,  6.3281e-01,
           2.0156e+00,  8.3750e+00],
         ...,
         [-1.1621e-01,  3.0469e-01, -4.6094e-01,  ..., -2.9688e+00,
          -3.5312e+00,  5.2812e+00],
         [ 3.3203e-01,  1.1621e-01,  2.7148e-01,  ..., -2.9531e+00,
          -1.5938e+00,  5.8438e+00],
         [ 1.4844e+00, -7.9297e-01, -1.1719e-02,  ..., -8.4961e-02,
           2.4062e+00,  7.5000e+00]],

        [[ 3.4332e-04,  1.2634e-02, -1.8433e-02,  ...,  3.8330e-02,
           1.4941e-01, -1.8828e+00],
         [ 9.5312e-01, -1.9219e+00, -1.5820e-01,  ..., -1.0391e+00,
          -6.2891e-01,  7.4688e+00],
         [-3.3594e-01, -6.1719e-01, -3.4180e-01,  ...,  6.3281e-01,
           2.0156e+00,  8.3750e+00],
         ...,
         [-1.1621e-01,  3.0469e-01, -4.6094e-01,  ..., -2.9688e+00,
          -3.5312e+00,  5.2812e+00],
         [ 3.3203e-01,  1.1621e-01,  2.7148e-01,  ..., -2.9531e+00,
          -1.5938e+00,  5.8438e+00],
         [ 1.4844e+00, -7.9297e-01, -1.1719e-02,  ..., -8.4961e-02,
           2.4062e+00,  7.5000e+00]],

        [[ 3.4332e-04,  1.2634e-02, -1.8433e-02,  ...,  3.8330e-02,
           1.4941e-01, -1.8828e+00],
         [ 9.5312e-01, -1.9219e+00, -1.5820e-01,  ..., -1.0391e+00,
          -6.2891e-01,  7.4688e+00],
         [-3.3594e-01, -6.1719e-01, -3.4180e-01,  ...,  6.3281e-01,
           2.0156e+00,  8.3750e+00],
         ...,
         [-1.1621e-01,  3.0469e-01, -4.6094e-01,  ..., -2.9688e+00,
          -3.5312e+00,  5.2812e+00],
         [ 3.3203e-01,  1.1621e-01,  2.7148e-01,  ..., -2.9531e+00,
          -1.5938e+00,  5.8438e+00],
         [ 1.4844e+00, -7.9297e-01, -1.1719e-02,  ..., -8.4961e-02,
           2.4062e+00,  7.5000e+00]],

        ...,

        [[ 4.5776e-03, -9.3994e-03, -8.3618e-03,  ..., -1.5156e+00,
          -4.8828e-01,  8.8379e-02],
         [-2.8438e+00,  2.7188e+00,  1.3594e+00,  ...,  8.3750e+00,
          -1.0234e+00, -5.2344e-01],
         [-7.5000e-01,  4.7266e-01,  3.7305e-01,  ...,  9.1875e+00,
          -2.5000e+00, -2.6406e+00],
         ...,
         [ 9.3359e-01, -7.5391e-01, -8.5547e-01,  ...,  7.8438e+00,
          -3.4375e-01,  8.0859e-01],
         [ 7.2266e-01, -9.7168e-02,  3.1836e-01,  ...,  7.0938e+00,
           1.1719e+00,  6.3672e-01],
         [ 1.2422e+00,  5.5078e-01,  5.7812e-01,  ...,  7.6250e+00,
           7.0312e-02, -2.3750e+00]],

        [[ 4.5776e-03, -9.3994e-03, -8.3618e-03,  ..., -1.5156e+00,
          -4.8828e-01,  8.8379e-02],
         [-2.8438e+00,  2.7188e+00,  1.3594e+00,  ...,  8.3750e+00,
          -1.0234e+00, -5.2344e-01],
         [-7.5000e-01,  4.7266e-01,  3.7305e-01,  ...,  9.1875e+00,
          -2.5000e+00, -2.6406e+00],
         ...,
         [ 9.3359e-01, -7.5391e-01, -8.5547e-01,  ...,  7.8438e+00,
          -3.4375e-01,  8.0859e-01],
         [ 7.2266e-01, -9.7168e-02,  3.1836e-01,  ...,  7.0938e+00,
           1.1719e+00,  6.3672e-01],
         [ 1.2422e+00,  5.5078e-01,  5.7812e-01,  ...,  7.6250e+00,
           7.0312e-02, -2.3750e+00]],

        [[ 4.5776e-03, -9.3994e-03, -8.3618e-03,  ..., -1.5156e+00,
          -4.8828e-01,  8.8379e-02],
         [-2.8438e+00,  2.7188e+00,  1.3594e+00,  ...,  8.3750e+00,
          -1.0234e+00, -5.2344e-01],
         [-7.5000e-01,  4.7266e-01,  3.7305e-01,  ...,  9.1875e+00,
          -2.5000e+00, -2.6406e+00],
         ...,
         [ 9.3359e-01, -7.5391e-01, -8.5547e-01,  ...,  7.8438e+00,
          -3.4375e-01,  8.0859e-01],
         [ 7.2266e-01, -9.7168e-02,  3.1836e-01,  ...,  7.0938e+00,
           1.1719e+00,  6.3672e-01],
         [ 1.2422e+00,  5.5078e-01,  5.7812e-01,  ...,  7.6250e+00,
           7.0312e-02, -2.3750e+00]]], dtype=torch.bfloat16), tensor([[[ 8.1543e-02,  4.4556e-03,  5.4199e-02,  ..., -3.0212e-03,
          -8.1787e-03, -3.4668e-02],
         [-3.1641e-01, -2.2461e-01, -2.6758e-01,  ...,  4.6289e-01,
          -2.1191e-01, -3.3789e-01],
         [-8.7402e-02,  3.1250e-01,  4.8828e-01,  ...,  7.9297e-01,
           9.8633e-02,  1.7090e-01],
         ...,
         [ 6.1719e-01,  9.8047e-01,  8.5938e-02,  ...,  3.6377e-02,
           4.3945e-01, -5.7031e-01],
         [-9.4238e-02,  4.3359e-01,  2.4707e-01,  ...,  8.7891e-01,
           7.1875e-01, -8.4375e-01],
         [-6.8359e-02,  3.2422e-01, -3.3398e-01,  ...,  3.8672e-01,
           5.4688e-01, -3.1055e-01]],

        [[ 8.1543e-02,  4.4556e-03,  5.4199e-02,  ..., -3.0212e-03,
          -8.1787e-03, -3.4668e-02],
         [-3.1641e-01, -2.2461e-01, -2.6758e-01,  ...,  4.6289e-01,
          -2.1191e-01, -3.3789e-01],
         [-8.7402e-02,  3.1250e-01,  4.8828e-01,  ...,  7.9297e-01,
           9.8633e-02,  1.7090e-01],
         ...,
         [ 6.1719e-01,  9.8047e-01,  8.5938e-02,  ...,  3.6377e-02,
           4.3945e-01, -5.7031e-01],
         [-9.4238e-02,  4.3359e-01,  2.4707e-01,  ...,  8.7891e-01,
           7.1875e-01, -8.4375e-01],
         [-6.8359e-02,  3.2422e-01, -3.3398e-01,  ...,  3.8672e-01,
           5.4688e-01, -3.1055e-01]],

        [[ 8.1543e-02,  4.4556e-03,  5.4199e-02,  ..., -3.0212e-03,
          -8.1787e-03, -3.4668e-02],
         [-3.1641e-01, -2.2461e-01, -2.6758e-01,  ...,  4.6289e-01,
          -2.1191e-01, -3.3789e-01],
         [-8.7402e-02,  3.1250e-01,  4.8828e-01,  ...,  7.9297e-01,
           9.8633e-02,  1.7090e-01],
         ...,
         [ 6.1719e-01,  9.8047e-01,  8.5938e-02,  ...,  3.6377e-02,
           4.3945e-01, -5.7031e-01],
         [-9.4238e-02,  4.3359e-01,  2.4707e-01,  ...,  8.7891e-01,
           7.1875e-01, -8.4375e-01],
         [-6.8359e-02,  3.2422e-01, -3.3398e-01,  ...,  3.8672e-01,
           5.4688e-01, -3.1055e-01]],

        ...,

        [[ 5.6763e-03, -2.8534e-03,  6.3705e-04,  ...,  3.8300e-03,
          -9.3994e-03,  3.9673e-03],
         [-2.5938e+00,  3.1406e+00, -4.9023e-01,  ..., -5.6250e-01,
          -1.3281e+00,  2.0142e-02],
         [-2.4844e+00,  1.4297e+00, -1.9922e+00,  ...,  5.3516e-01,
          -4.1797e-01, -2.4609e-01],
         ...,
         [ 5.4688e-01, -5.9766e-01,  8.5938e-01,  ...,  6.7969e-01,
          -7.6172e-01,  4.8438e-01],
         [ 3.1445e-01, -3.9795e-02,  5.6250e-01,  ..., -7.1289e-02,
          -6.4453e-02,  3.7598e-02],
         [-1.8262e-01, -1.9141e-01, -5.1562e-01,  ...,  2.3828e-01,
           2.0781e+00, -9.9219e-01]],

        [[ 5.6763e-03, -2.8534e-03,  6.3705e-04,  ...,  3.8300e-03,
          -9.3994e-03,  3.9673e-03],
         [-2.5938e+00,  3.1406e+00, -4.9023e-01,  ..., -5.6250e-01,
          -1.3281e+00,  2.0142e-02],
         [-2.4844e+00,  1.4297e+00, -1.9922e+00,  ...,  5.3516e-01,
          -4.1797e-01, -2.4609e-01],
         ...,
         [ 5.4688e-01, -5.9766e-01,  8.5938e-01,  ...,  6.7969e-01,
          -7.6172e-01,  4.8438e-01],
         [ 3.1445e-01, -3.9795e-02,  5.6250e-01,  ..., -7.1289e-02,
          -6.4453e-02,  3.7598e-02],
         [-1.8262e-01, -1.9141e-01, -5.1562e-01,  ...,  2.3828e-01,
           2.0781e+00, -9.9219e-01]],

        [[ 5.6763e-03, -2.8534e-03,  6.3705e-04,  ...,  3.8300e-03,
          -9.3994e-03,  3.9673e-03],
         [-2.5938e+00,  3.1406e+00, -4.9023e-01,  ..., -5.6250e-01,
          -1.3281e+00,  2.0142e-02],
         [-2.4844e+00,  1.4297e+00, -1.9922e+00,  ...,  5.3516e-01,
          -4.1797e-01, -2.4609e-01],
         ...,
         [ 5.4688e-01, -5.9766e-01,  8.5938e-01,  ...,  6.7969e-01,
          -7.6172e-01,  4.8438e-01],
         [ 3.1445e-01, -3.9795e-02,  5.6250e-01,  ..., -7.1289e-02,
          -6.4453e-02,  3.7598e-02],
         [-1.8262e-01, -1.9141e-01, -5.1562e-01,  ...,  2.3828e-01,
           2.0781e+00, -9.9219e-01]]], dtype=torch.bfloat16)), (tensor([[[-4.6921e-04,  1.2268e-02,  1.6113e-02,  ...,  2.0703e-01,
           7.5684e-02,  7.1716e-03],
         [-1.4531e+00, -7.7344e-01, -3.1055e-01,  ..., -5.7812e-01,
          -1.6968e-02,  5.7068e-03],
         [-9.0234e-01,  3.5938e-01, -5.3125e-01,  ..., -1.9922e+00,
           2.1250e+00, -2.1406e+00],
         ...,
         [ 2.5391e-01,  3.5156e-01, -2.4023e-01,  ...,  1.4609e+00,
          -3.7031e+00, -1.8594e+00],
         [-2.1484e-01,  5.2344e-01,  4.8828e-04,  ...,  6.0156e-01,
          -5.0781e-01, -2.0938e+00],
         [ 2.3125e+00,  9.6875e-01, -1.5859e+00,  ..., -1.2500e+00,
           8.1250e-01, -9.9609e-01]],

        [[-4.6921e-04,  1.2268e-02,  1.6113e-02,  ...,  2.0703e-01,
           7.5684e-02,  7.1716e-03],
         [-1.4531e+00, -7.7344e-01, -3.1055e-01,  ..., -5.7812e-01,
          -1.6968e-02,  5.7068e-03],
         [-9.0234e-01,  3.5938e-01, -5.3125e-01,  ..., -1.9922e+00,
           2.1250e+00, -2.1406e+00],
         ...,
         [ 2.5391e-01,  3.5156e-01, -2.4023e-01,  ...,  1.4609e+00,
          -3.7031e+00, -1.8594e+00],
         [-2.1484e-01,  5.2344e-01,  4.8828e-04,  ...,  6.0156e-01,
          -5.0781e-01, -2.0938e+00],
         [ 2.3125e+00,  9.6875e-01, -1.5859e+00,  ..., -1.2500e+00,
           8.1250e-01, -9.9609e-01]],

        [[-4.6921e-04,  1.2268e-02,  1.6113e-02,  ...,  2.0703e-01,
           7.5684e-02,  7.1716e-03],
         [-1.4531e+00, -7.7344e-01, -3.1055e-01,  ..., -5.7812e-01,
          -1.6968e-02,  5.7068e-03],
         [-9.0234e-01,  3.5938e-01, -5.3125e-01,  ..., -1.9922e+00,
           2.1250e+00, -2.1406e+00],
         ...,
         [ 2.5391e-01,  3.5156e-01, -2.4023e-01,  ...,  1.4609e+00,
          -3.7031e+00, -1.8594e+00],
         [-2.1484e-01,  5.2344e-01,  4.8828e-04,  ...,  6.0156e-01,
          -5.0781e-01, -2.0938e+00],
         [ 2.3125e+00,  9.6875e-01, -1.5859e+00,  ..., -1.2500e+00,
           8.1250e-01, -9.9609e-01]],

        ...,

        [[-1.4221e-02,  1.0681e-02,  2.6398e-03,  ...,  5.9082e-02,
          -9.4727e-02, -1.1328e-01],
         [ 9.3359e-01, -9.9219e-01,  3.9062e-01,  ..., -4.4727e-01,
          -3.0625e+00, -1.9727e-01],
         [-1.7969e-01, -1.2266e+00,  8.1641e-01,  ...,  1.0234e+00,
          -3.5938e-01,  8.9062e-01],
         ...,
         [-9.2578e-01,  2.4219e-01, -1.3281e+00,  ...,  1.7344e+00,
           3.0273e-01,  1.4766e+00],
         [-2.3633e-01,  3.8672e-01, -3.7695e-01,  ...,  7.6953e-01,
           4.2969e-01,  1.3281e+00],
         [ 4.1016e-01,  1.4375e+00,  2.0703e-01,  ...,  1.9922e+00,
           1.0938e+00, -4.3750e+00]],

        [[-1.4221e-02,  1.0681e-02,  2.6398e-03,  ...,  5.9082e-02,
          -9.4727e-02, -1.1328e-01],
         [ 9.3359e-01, -9.9219e-01,  3.9062e-01,  ..., -4.4727e-01,
          -3.0625e+00, -1.9727e-01],
         [-1.7969e-01, -1.2266e+00,  8.1641e-01,  ...,  1.0234e+00,
          -3.5938e-01,  8.9062e-01],
         ...,
         [-9.2578e-01,  2.4219e-01, -1.3281e+00,  ...,  1.7344e+00,
           3.0273e-01,  1.4766e+00],
         [-2.3633e-01,  3.8672e-01, -3.7695e-01,  ...,  7.6953e-01,
           4.2969e-01,  1.3281e+00],
         [ 4.1016e-01,  1.4375e+00,  2.0703e-01,  ...,  1.9922e+00,
           1.0938e+00, -4.3750e+00]],

        [[-1.4221e-02,  1.0681e-02,  2.6398e-03,  ...,  5.9082e-02,
          -9.4727e-02, -1.1328e-01],
         [ 9.3359e-01, -9.9219e-01,  3.9062e-01,  ..., -4.4727e-01,
          -3.0625e+00, -1.9727e-01],
         [-1.7969e-01, -1.2266e+00,  8.1641e-01,  ...,  1.0234e+00,
          -3.5938e-01,  8.9062e-01],
         ...,
         [-9.2578e-01,  2.4219e-01, -1.3281e+00,  ...,  1.7344e+00,
           3.0273e-01,  1.4766e+00],
         [-2.3633e-01,  3.8672e-01, -3.7695e-01,  ...,  7.6953e-01,
           4.2969e-01,  1.3281e+00],
         [ 4.1016e-01,  1.4375e+00,  2.0703e-01,  ...,  1.9922e+00,
           1.0938e+00, -4.3750e+00]]], dtype=torch.bfloat16), tensor([[[ 9.8877e-03,  2.3193e-03,  7.3547e-03,  ..., -1.6357e-02,
           8.9722e-03, -5.0781e-02],
         [ 2.2461e-01, -3.2422e-01, -6.4062e-01,  ...,  3.5938e-01,
          -6.1328e-01, -2.9053e-02],
         [ 1.3086e-01, -2.3633e-01,  6.6406e-01,  ...,  6.1719e-01,
          -1.0625e+00,  4.8438e-01],
         ...,
         [ 1.2109e+00,  1.9684e-03, -2.0801e-01,  ...,  5.1953e-01,
          -6.2891e-01,  1.4160e-01],
         [ 1.0234e+00,  2.7148e-01, -6.6797e-01,  ...,  1.7456e-02,
          -6.4062e-01,  1.2422e+00],
         [ 4.0430e-01,  5.1172e-01, -5.2734e-01,  ..., -1.9922e-01,
          -7.1484e-01,  1.4453e+00]],

        [[ 9.8877e-03,  2.3193e-03,  7.3547e-03,  ..., -1.6357e-02,
           8.9722e-03, -5.0781e-02],
         [ 2.2461e-01, -3.2422e-01, -6.4062e-01,  ...,  3.5938e-01,
          -6.1328e-01, -2.9053e-02],
         [ 1.3086e-01, -2.3633e-01,  6.6406e-01,  ...,  6.1719e-01,
          -1.0625e+00,  4.8438e-01],
         ...,
         [ 1.2109e+00,  1.9684e-03, -2.0801e-01,  ...,  5.1953e-01,
          -6.2891e-01,  1.4160e-01],
         [ 1.0234e+00,  2.7148e-01, -6.6797e-01,  ...,  1.7456e-02,
          -6.4062e-01,  1.2422e+00],
         [ 4.0430e-01,  5.1172e-01, -5.2734e-01,  ..., -1.9922e-01,
          -7.1484e-01,  1.4453e+00]],

        [[ 9.8877e-03,  2.3193e-03,  7.3547e-03,  ..., -1.6357e-02,
           8.9722e-03, -5.0781e-02],
         [ 2.2461e-01, -3.2422e-01, -6.4062e-01,  ...,  3.5938e-01,
          -6.1328e-01, -2.9053e-02],
         [ 1.3086e-01, -2.3633e-01,  6.6406e-01,  ...,  6.1719e-01,
          -1.0625e+00,  4.8438e-01],
         ...,
         [ 1.2109e+00,  1.9684e-03, -2.0801e-01,  ...,  5.1953e-01,
          -6.2891e-01,  1.4160e-01],
         [ 1.0234e+00,  2.7148e-01, -6.6797e-01,  ...,  1.7456e-02,
          -6.4062e-01,  1.2422e+00],
         [ 4.0430e-01,  5.1172e-01, -5.2734e-01,  ..., -1.9922e-01,
          -7.1484e-01,  1.4453e+00]],

        ...,

        [[ 3.3569e-03, -6.7139e-03, -8.9264e-04,  ...,  1.3046e-03,
           2.6703e-03,  5.0659e-03],
         [-2.7812e+00, -3.8867e-01, -9.0234e-01,  ..., -2.7969e+00,
          -6.2109e-01, -2.7500e+00],
         [-1.2578e+00,  3.4961e-01, -1.9531e-01,  ..., -3.0078e-01,
           8.3984e-01, -1.6484e+00],
         ...,
         [ 3.9307e-02,  7.9346e-03, -2.7539e-01,  ...,  3.5938e-01,
           3.8281e-01, -6.0156e-01],
         [-3.7598e-02,  4.1602e-01, -2.5195e-01,  ..., -1.2988e-01,
          -4.0283e-02, -7.0312e-01],
         [-7.5391e-01,  3.6377e-02,  1.1641e+00,  ...,  9.5703e-01,
           1.5781e+00,  2.9883e-01]],

        [[ 3.3569e-03, -6.7139e-03, -8.9264e-04,  ...,  1.3046e-03,
           2.6703e-03,  5.0659e-03],
         [-2.7812e+00, -3.8867e-01, -9.0234e-01,  ..., -2.7969e+00,
          -6.2109e-01, -2.7500e+00],
         [-1.2578e+00,  3.4961e-01, -1.9531e-01,  ..., -3.0078e-01,
           8.3984e-01, -1.6484e+00],
         ...,
         [ 3.9307e-02,  7.9346e-03, -2.7539e-01,  ...,  3.5938e-01,
           3.8281e-01, -6.0156e-01],
         [-3.7598e-02,  4.1602e-01, -2.5195e-01,  ..., -1.2988e-01,
          -4.0283e-02, -7.0312e-01],
         [-7.5391e-01,  3.6377e-02,  1.1641e+00,  ...,  9.5703e-01,
           1.5781e+00,  2.9883e-01]],

        [[ 3.3569e-03, -6.7139e-03, -8.9264e-04,  ...,  1.3046e-03,
           2.6703e-03,  5.0659e-03],
         [-2.7812e+00, -3.8867e-01, -9.0234e-01,  ..., -2.7969e+00,
          -6.2109e-01, -2.7500e+00],
         [-1.2578e+00,  3.4961e-01, -1.9531e-01,  ..., -3.0078e-01,
           8.3984e-01, -1.6484e+00],
         ...,
         [ 3.9307e-02,  7.9346e-03, -2.7539e-01,  ...,  3.5938e-01,
           3.8281e-01, -6.0156e-01],
         [-3.7598e-02,  4.1602e-01, -2.5195e-01,  ..., -1.2988e-01,
          -4.0283e-02, -7.0312e-01],
         [-7.5391e-01,  3.6377e-02,  1.1641e+00,  ...,  9.5703e-01,
           1.5781e+00,  2.9883e-01]]], dtype=torch.bfloat16)), (tensor([[[-1.6724e-02, -4.3030e-03,  4.7607e-03,  ...,  1.0352e-01,
           8.0566e-02,  3.9062e-01],
         [-4.0625e-01, -7.0312e-01, -8.1250e-01,  ..., -2.0781e+00,
          -1.4922e+00, -6.7812e+00],
         [-1.4688e+00, -1.0312e+00, -9.8828e-01,  ..., -9.4531e-01,
          -2.5469e+00, -6.0625e+00],
         ...,
         [-3.4375e-01,  7.0312e-01,  4.0234e-01,  ..., -1.2656e+00,
          -2.8125e+00, -3.6094e+00],
         [ 1.1719e-02,  1.0156e+00, -1.0938e-01,  ..., -1.5156e+00,
          -2.9375e+00, -5.5000e+00],
         [ 3.3594e-01,  1.7266e+00, -1.5234e+00,  ..., -1.0000e+00,
          -3.4375e+00, -5.5938e+00]],

        [[-1.6724e-02, -4.3030e-03,  4.7607e-03,  ...,  1.0352e-01,
           8.0566e-02,  3.9062e-01],
         [-4.0625e-01, -7.0312e-01, -8.1250e-01,  ..., -2.0781e+00,
          -1.4922e+00, -6.7812e+00],
         [-1.4688e+00, -1.0312e+00, -9.8828e-01,  ..., -9.4531e-01,
          -2.5469e+00, -6.0625e+00],
         ...,
         [-3.4375e-01,  7.0312e-01,  4.0234e-01,  ..., -1.2656e+00,
          -2.8125e+00, -3.6094e+00],
         [ 1.1719e-02,  1.0156e+00, -1.0938e-01,  ..., -1.5156e+00,
          -2.9375e+00, -5.5000e+00],
         [ 3.3594e-01,  1.7266e+00, -1.5234e+00,  ..., -1.0000e+00,
          -3.4375e+00, -5.5938e+00]],

        [[-1.6724e-02, -4.3030e-03,  4.7607e-03,  ...,  1.0352e-01,
           8.0566e-02,  3.9062e-01],
         [-4.0625e-01, -7.0312e-01, -8.1250e-01,  ..., -2.0781e+00,
          -1.4922e+00, -6.7812e+00],
         [-1.4688e+00, -1.0312e+00, -9.8828e-01,  ..., -9.4531e-01,
          -2.5469e+00, -6.0625e+00],
         ...,
         [-3.4375e-01,  7.0312e-01,  4.0234e-01,  ..., -1.2656e+00,
          -2.8125e+00, -3.6094e+00],
         [ 1.1719e-02,  1.0156e+00, -1.0938e-01,  ..., -1.5156e+00,
          -2.9375e+00, -5.5000e+00],
         [ 3.3594e-01,  1.7266e+00, -1.5234e+00,  ..., -1.0000e+00,
          -3.4375e+00, -5.5938e+00]],

        ...,

        [[-5.0354e-04,  4.9744e-03,  1.7471e-03,  ...,  4.1504e-02,
          -3.3203e-02,  4.9316e-02],
         [ 9.1016e-01, -8.8672e-01, -8.7891e-01,  ..., -1.8828e+00,
           1.9297e+00,  2.3047e-01],
         [-1.1406e+00,  1.6895e-01, -6.6797e-01,  ..., -1.1094e+00,
           4.7852e-01,  2.0312e+00],
         ...,
         [-9.3359e-01,  7.5391e-01,  9.0820e-02,  ...,  1.0156e+00,
           1.3047e+00,  7.8613e-02],
         [-3.0469e-01,  2.6367e-01, -7.8125e-01,  ..., -6.3672e-01,
           9.7266e-01, -8.7500e-01],
         [ 1.0625e+00,  2.0938e+00, -1.1484e+00,  ..., -9.0820e-02,
           1.6562e+00,  2.3594e+00]],

        [[-5.0354e-04,  4.9744e-03,  1.7471e-03,  ...,  4.1504e-02,
          -3.3203e-02,  4.9316e-02],
         [ 9.1016e-01, -8.8672e-01, -8.7891e-01,  ..., -1.8828e+00,
           1.9297e+00,  2.3047e-01],
         [-1.1406e+00,  1.6895e-01, -6.6797e-01,  ..., -1.1094e+00,
           4.7852e-01,  2.0312e+00],
         ...,
         [-9.3359e-01,  7.5391e-01,  9.0820e-02,  ...,  1.0156e+00,
           1.3047e+00,  7.8613e-02],
         [-3.0469e-01,  2.6367e-01, -7.8125e-01,  ..., -6.3672e-01,
           9.7266e-01, -8.7500e-01],
         [ 1.0625e+00,  2.0938e+00, -1.1484e+00,  ..., -9.0820e-02,
           1.6562e+00,  2.3594e+00]],

        [[-5.0354e-04,  4.9744e-03,  1.7471e-03,  ...,  4.1504e-02,
          -3.3203e-02,  4.9316e-02],
         [ 9.1016e-01, -8.8672e-01, -8.7891e-01,  ..., -1.8828e+00,
           1.9297e+00,  2.3047e-01],
         [-1.1406e+00,  1.6895e-01, -6.6797e-01,  ..., -1.1094e+00,
           4.7852e-01,  2.0312e+00],
         ...,
         [-9.3359e-01,  7.5391e-01,  9.0820e-02,  ...,  1.0156e+00,
           1.3047e+00,  7.8613e-02],
         [-3.0469e-01,  2.6367e-01, -7.8125e-01,  ..., -6.3672e-01,
           9.7266e-01, -8.7500e-01],
         [ 1.0625e+00,  2.0938e+00, -1.1484e+00,  ..., -9.0820e-02,
           1.6562e+00,  2.3594e+00]]], dtype=torch.bfloat16), tensor([[[ 1.1902e-03, -6.4392e-03,  8.5449e-03,  ...,  5.4626e-03,
           7.1716e-03,  7.1106e-03],
         [ 5.1562e-01,  1.0625e+00, -7.3438e-01,  ...,  1.2266e+00,
          -1.7422e+00,  6.5625e-01],
         [ 3.5156e-01, -6.4844e-01, -9.4531e-01,  ...,  4.4727e-01,
          -1.4531e+00,  6.0938e-01],
         ...,
         [ 2.8076e-02,  1.6406e+00, -2.3926e-01,  ..., -1.0703e+00,
          -1.5234e+00,  4.5508e-01],
         [ 3.6523e-01,  1.3184e-01,  4.9316e-02,  ...,  1.6211e-01,
          -3.7305e-01,  5.3516e-01],
         [ 1.1875e+00,  1.4648e-01, -1.1406e+00,  ...,  7.0801e-02,
           4.4189e-02, -7.7148e-02]],

        [[ 1.1902e-03, -6.4392e-03,  8.5449e-03,  ...,  5.4626e-03,
           7.1716e-03,  7.1106e-03],
         [ 5.1562e-01,  1.0625e+00, -7.3438e-01,  ...,  1.2266e+00,
          -1.7422e+00,  6.5625e-01],
         [ 3.5156e-01, -6.4844e-01, -9.4531e-01,  ...,  4.4727e-01,
          -1.4531e+00,  6.0938e-01],
         ...,
         [ 2.8076e-02,  1.6406e+00, -2.3926e-01,  ..., -1.0703e+00,
          -1.5234e+00,  4.5508e-01],
         [ 3.6523e-01,  1.3184e-01,  4.9316e-02,  ...,  1.6211e-01,
          -3.7305e-01,  5.3516e-01],
         [ 1.1875e+00,  1.4648e-01, -1.1406e+00,  ...,  7.0801e-02,
           4.4189e-02, -7.7148e-02]],

        [[ 1.1902e-03, -6.4392e-03,  8.5449e-03,  ...,  5.4626e-03,
           7.1716e-03,  7.1106e-03],
         [ 5.1562e-01,  1.0625e+00, -7.3438e-01,  ...,  1.2266e+00,
          -1.7422e+00,  6.5625e-01],
         [ 3.5156e-01, -6.4844e-01, -9.4531e-01,  ...,  4.4727e-01,
          -1.4531e+00,  6.0938e-01],
         ...,
         [ 2.8076e-02,  1.6406e+00, -2.3926e-01,  ..., -1.0703e+00,
          -1.5234e+00,  4.5508e-01],
         [ 3.6523e-01,  1.3184e-01,  4.9316e-02,  ...,  1.6211e-01,
          -3.7305e-01,  5.3516e-01],
         [ 1.1875e+00,  1.4648e-01, -1.1406e+00,  ...,  7.0801e-02,
           4.4189e-02, -7.7148e-02]],

        ...,

        [[-5.5847e-03, -3.2501e-03,  4.8523e-03,  ...,  1.0010e-02,
          -5.5542e-03,  7.4768e-03],
         [ 5.9375e-01, -1.4160e-01, -1.7383e-01,  ...,  2.6758e-01,
          -4.1992e-01,  2.2949e-01],
         [-2.5781e-01,  1.6562e+00,  1.6172e+00,  ...,  1.5781e+00,
           7.8516e-01,  3.4375e-01],
         ...,
         [-1.1719e+00,  9.8438e-01, -1.4375e+00,  ..., -8.8281e-01,
           1.0078e+00, -5.7031e-01],
         [-2.3730e-01,  2.6367e-01, -3.3398e-01,  ...,  2.0605e-01,
           8.3594e-01,  4.0234e-01],
         [-3.6719e-01, -1.4609e+00,  2.3242e-01,  ..., -1.0469e+00,
           3.7695e-01, -2.0703e-01]],

        [[-5.5847e-03, -3.2501e-03,  4.8523e-03,  ...,  1.0010e-02,
          -5.5542e-03,  7.4768e-03],
         [ 5.9375e-01, -1.4160e-01, -1.7383e-01,  ...,  2.6758e-01,
          -4.1992e-01,  2.2949e-01],
         [-2.5781e-01,  1.6562e+00,  1.6172e+00,  ...,  1.5781e+00,
           7.8516e-01,  3.4375e-01],
         ...,
         [-1.1719e+00,  9.8438e-01, -1.4375e+00,  ..., -8.8281e-01,
           1.0078e+00, -5.7031e-01],
         [-2.3730e-01,  2.6367e-01, -3.3398e-01,  ...,  2.0605e-01,
           8.3594e-01,  4.0234e-01],
         [-3.6719e-01, -1.4609e+00,  2.3242e-01,  ..., -1.0469e+00,
           3.7695e-01, -2.0703e-01]],

        [[-5.5847e-03, -3.2501e-03,  4.8523e-03,  ...,  1.0010e-02,
          -5.5542e-03,  7.4768e-03],
         [ 5.9375e-01, -1.4160e-01, -1.7383e-01,  ...,  2.6758e-01,
          -4.1992e-01,  2.2949e-01],
         [-2.5781e-01,  1.6562e+00,  1.6172e+00,  ...,  1.5781e+00,
           7.8516e-01,  3.4375e-01],
         ...,
         [-1.1719e+00,  9.8438e-01, -1.4375e+00,  ..., -8.8281e-01,
           1.0078e+00, -5.7031e-01],
         [-2.3730e-01,  2.6367e-01, -3.3398e-01,  ...,  2.0605e-01,
           8.3594e-01,  4.0234e-01],
         [-3.6719e-01, -1.4609e+00,  2.3242e-01,  ..., -1.0469e+00,
           3.7695e-01, -2.0703e-01]]], dtype=torch.bfloat16)), (tensor([[[ 2.4414e-03,  9.9487e-03,  9.2773e-03,  ..., -2.6367e-01,
           4.9316e-02,  1.9043e-02],
         [-8.3594e-01, -7.3828e-01,  3.5352e-01,  ...,  4.2812e+00,
          -3.0469e-01, -2.1875e+00],
         [ 6.0156e-01,  1.6406e-01,  8.0078e-02,  ...,  6.3125e+00,
          -2.9883e-01,  1.1658e-02],
         ...,
         [ 1.0547e+00, -8.3594e-01, -9.4922e-01,  ...,  2.2656e+00,
           1.3281e-01,  8.7891e-01],
         [-7.4219e-02, -5.0781e-01, -8.7500e-01,  ...,  3.4688e+00,
           3.3398e-01,  8.4473e-02],
         [-1.1172e+00, -8.5156e-01, -6.6406e-01,  ...,  3.9375e+00,
          -9.5703e-01, -3.9062e-01]],

        [[ 2.4414e-03,  9.9487e-03,  9.2773e-03,  ..., -2.6367e-01,
           4.9316e-02,  1.9043e-02],
         [-8.3594e-01, -7.3828e-01,  3.5352e-01,  ...,  4.2812e+00,
          -3.0469e-01, -2.1875e+00],
         [ 6.0156e-01,  1.6406e-01,  8.0078e-02,  ...,  6.3125e+00,
          -2.9883e-01,  1.1658e-02],
         ...,
         [ 1.0547e+00, -8.3594e-01, -9.4922e-01,  ...,  2.2656e+00,
           1.3281e-01,  8.7891e-01],
         [-7.4219e-02, -5.0781e-01, -8.7500e-01,  ...,  3.4688e+00,
           3.3398e-01,  8.4473e-02],
         [-1.1172e+00, -8.5156e-01, -6.6406e-01,  ...,  3.9375e+00,
          -9.5703e-01, -3.9062e-01]],

        [[ 2.4414e-03,  9.9487e-03,  9.2773e-03,  ..., -2.6367e-01,
           4.9316e-02,  1.9043e-02],
         [-8.3594e-01, -7.3828e-01,  3.5352e-01,  ...,  4.2812e+00,
          -3.0469e-01, -2.1875e+00],
         [ 6.0156e-01,  1.6406e-01,  8.0078e-02,  ...,  6.3125e+00,
          -2.9883e-01,  1.1658e-02],
         ...,
         [ 1.0547e+00, -8.3594e-01, -9.4922e-01,  ...,  2.2656e+00,
           1.3281e-01,  8.7891e-01],
         [-7.4219e-02, -5.0781e-01, -8.7500e-01,  ...,  3.4688e+00,
           3.3398e-01,  8.4473e-02],
         [-1.1172e+00, -8.5156e-01, -6.6406e-01,  ...,  3.9375e+00,
          -9.5703e-01, -3.9062e-01]],

        ...,

        [[-6.4850e-04,  4.4250e-03, -1.9684e-03,  ...,  9.6680e-02,
          -5.8984e-01, -5.2246e-02],
         [-2.4609e-01, -6.4844e-01, -1.1133e-01,  ...,  2.4062e+00,
           8.1250e+00,  3.2188e+00],
         [-8.2031e-02,  6.2500e-01, -2.9492e-01,  ..., -3.3750e+00,
           1.0375e+01,  4.1406e-01],
         ...,
         [-7.4609e-01, -3.5547e-01, -2.1875e-01,  ...,  3.2812e-01,
           8.5000e+00,  5.8125e+00],
         [-9.0234e-01, -2.9492e-01, -6.7969e-01,  ..., -1.2109e+00,
           7.6562e+00,  1.9141e+00],
         [ 5.4688e-01, -8.9844e-01, -3.2227e-01,  ...,  3.7305e-01,
           9.4375e+00, -5.3750e+00]],

        [[-6.4850e-04,  4.4250e-03, -1.9684e-03,  ...,  9.6680e-02,
          -5.8984e-01, -5.2246e-02],
         [-2.4609e-01, -6.4844e-01, -1.1133e-01,  ...,  2.4062e+00,
           8.1250e+00,  3.2188e+00],
         [-8.2031e-02,  6.2500e-01, -2.9492e-01,  ..., -3.3750e+00,
           1.0375e+01,  4.1406e-01],
         ...,
         [-7.4609e-01, -3.5547e-01, -2.1875e-01,  ...,  3.2812e-01,
           8.5000e+00,  5.8125e+00],
         [-9.0234e-01, -2.9492e-01, -6.7969e-01,  ..., -1.2109e+00,
           7.6562e+00,  1.9141e+00],
         [ 5.4688e-01, -8.9844e-01, -3.2227e-01,  ...,  3.7305e-01,
           9.4375e+00, -5.3750e+00]],

        [[-6.4850e-04,  4.4250e-03, -1.9684e-03,  ...,  9.6680e-02,
          -5.8984e-01, -5.2246e-02],
         [-2.4609e-01, -6.4844e-01, -1.1133e-01,  ...,  2.4062e+00,
           8.1250e+00,  3.2188e+00],
         [-8.2031e-02,  6.2500e-01, -2.9492e-01,  ..., -3.3750e+00,
           1.0375e+01,  4.1406e-01],
         ...,
         [-7.4609e-01, -3.5547e-01, -2.1875e-01,  ...,  3.2812e-01,
           8.5000e+00,  5.8125e+00],
         [-9.0234e-01, -2.9492e-01, -6.7969e-01,  ..., -1.2109e+00,
           7.6562e+00,  1.9141e+00],
         [ 5.4688e-01, -8.9844e-01, -3.2227e-01,  ...,  3.7305e-01,
           9.4375e+00, -5.3750e+00]]], dtype=torch.bfloat16), tensor([[[-4.2114e-03,  1.7700e-03,  8.1787e-03,  ...,  2.0599e-03,
          -3.2616e-04, -2.4414e-03],
         [ 2.2031e+00, -1.1377e-01, -2.0156e+00,  ...,  1.6719e+00,
          -4.4375e+00, -2.3594e+00],
         [ 5.7812e-01, -7.9297e-01, -7.2266e-01,  ...,  6.5430e-02,
           8.1543e-02, -1.7891e+00],
         ...,
         [ 9.1406e-01, -1.8672e+00, -6.3281e-01,  ..., -3.0469e-01,
          -6.7383e-02, -3.3203e-01],
         [ 1.0938e+00,  3.7891e-01, -1.0498e-01,  ..., -3.2422e-01,
           1.5137e-01, -4.5300e-05],
         [ 4.9219e-01,  2.0312e+00,  2.0469e+00,  ...,  1.5391e+00,
          -2.3730e-01, -2.8320e-01]],

        [[-4.2114e-03,  1.7700e-03,  8.1787e-03,  ...,  2.0599e-03,
          -3.2616e-04, -2.4414e-03],
         [ 2.2031e+00, -1.1377e-01, -2.0156e+00,  ...,  1.6719e+00,
          -4.4375e+00, -2.3594e+00],
         [ 5.7812e-01, -7.9297e-01, -7.2266e-01,  ...,  6.5430e-02,
           8.1543e-02, -1.7891e+00],
         ...,
         [ 9.1406e-01, -1.8672e+00, -6.3281e-01,  ..., -3.0469e-01,
          -6.7383e-02, -3.3203e-01],
         [ 1.0938e+00,  3.7891e-01, -1.0498e-01,  ..., -3.2422e-01,
           1.5137e-01, -4.5300e-05],
         [ 4.9219e-01,  2.0312e+00,  2.0469e+00,  ...,  1.5391e+00,
          -2.3730e-01, -2.8320e-01]],

        [[-4.2114e-03,  1.7700e-03,  8.1787e-03,  ...,  2.0599e-03,
          -3.2616e-04, -2.4414e-03],
         [ 2.2031e+00, -1.1377e-01, -2.0156e+00,  ...,  1.6719e+00,
          -4.4375e+00, -2.3594e+00],
         [ 5.7812e-01, -7.9297e-01, -7.2266e-01,  ...,  6.5430e-02,
           8.1543e-02, -1.7891e+00],
         ...,
         [ 9.1406e-01, -1.8672e+00, -6.3281e-01,  ..., -3.0469e-01,
          -6.7383e-02, -3.3203e-01],
         [ 1.0938e+00,  3.7891e-01, -1.0498e-01,  ..., -3.2422e-01,
           1.5137e-01, -4.5300e-05],
         [ 4.9219e-01,  2.0312e+00,  2.0469e+00,  ...,  1.5391e+00,
          -2.3730e-01, -2.8320e-01]],

        ...,

        [[ 1.8311e-02, -2.3499e-03,  1.9455e-04,  ..., -1.7334e-02,
           4.1389e-04,  4.0283e-03],
         [ 2.8125e-01,  3.6914e-01, -3.3398e-01,  ...,  1.3672e+00,
           4.8047e-01,  6.4062e-01],
         [-1.0391e+00, -1.2695e-01,  4.0039e-01,  ...,  5.3467e-02,
           1.9043e-01,  2.7148e-01],
         ...,
         [ 4.4727e-01,  2.8320e-01,  5.4688e-01,  ..., -5.1562e-01,
          -4.1211e-01,  5.3516e-01],
         [-1.3770e-01,  5.0391e-01,  1.8945e-01,  ...,  2.1289e-01,
           6.9885e-03, -1.1084e-01],
         [-7.6172e-01,  1.3438e+00, -1.7700e-02,  ..., -1.0469e+00,
           1.3828e+00,  2.2949e-02]],

        [[ 1.8311e-02, -2.3499e-03,  1.9455e-04,  ..., -1.7334e-02,
           4.1389e-04,  4.0283e-03],
         [ 2.8125e-01,  3.6914e-01, -3.3398e-01,  ...,  1.3672e+00,
           4.8047e-01,  6.4062e-01],
         [-1.0391e+00, -1.2695e-01,  4.0039e-01,  ...,  5.3467e-02,
           1.9043e-01,  2.7148e-01],
         ...,
         [ 4.4727e-01,  2.8320e-01,  5.4688e-01,  ..., -5.1562e-01,
          -4.1211e-01,  5.3516e-01],
         [-1.3770e-01,  5.0391e-01,  1.8945e-01,  ...,  2.1289e-01,
           6.9885e-03, -1.1084e-01],
         [-7.6172e-01,  1.3438e+00, -1.7700e-02,  ..., -1.0469e+00,
           1.3828e+00,  2.2949e-02]],

        [[ 1.8311e-02, -2.3499e-03,  1.9455e-04,  ..., -1.7334e-02,
           4.1389e-04,  4.0283e-03],
         [ 2.8125e-01,  3.6914e-01, -3.3398e-01,  ...,  1.3672e+00,
           4.8047e-01,  6.4062e-01],
         [-1.0391e+00, -1.2695e-01,  4.0039e-01,  ...,  5.3467e-02,
           1.9043e-01,  2.7148e-01],
         ...,
         [ 4.4727e-01,  2.8320e-01,  5.4688e-01,  ..., -5.1562e-01,
          -4.1211e-01,  5.3516e-01],
         [-1.3770e-01,  5.0391e-01,  1.8945e-01,  ...,  2.1289e-01,
           6.9885e-03, -1.1084e-01],
         [-7.6172e-01,  1.3438e+00, -1.7700e-02,  ..., -1.0469e+00,
           1.3828e+00,  2.2949e-02]]], dtype=torch.bfloat16)), (tensor([[[-3.4943e-03,  2.4414e-02,  5.6839e-04,  ...,  5.4932e-02,
          -1.2354e-01, -2.4023e-01],
         [-1.8359e-01, -8.9844e-02, -1.6953e+00,  ..., -1.0156e+00,
           3.7188e+00,  4.8047e-01],
         [-1.4062e+00, -1.9219e+00, -7.4609e-01,  ...,  1.8672e+00,
          -7.1716e-03,  3.0625e+00],
         ...,
         [ 1.8164e-01, -3.8086e-02, -7.7148e-02,  ..., -1.6328e+00,
           7.0703e-01,  1.2344e+00],
         [ 2.1094e-01,  3.1250e-01, -5.7031e-01,  ..., -6.0156e-01,
           7.6953e-01,  2.1875e+00],
         [ 7.4219e-01,  8.6719e-01, -5.8594e-01,  ...,  1.5527e-01,
           8.8281e-01,  8.6914e-02]],

        [[-3.4943e-03,  2.4414e-02,  5.6839e-04,  ...,  5.4932e-02,
          -1.2354e-01, -2.4023e-01],
         [-1.8359e-01, -8.9844e-02, -1.6953e+00,  ..., -1.0156e+00,
           3.7188e+00,  4.8047e-01],
         [-1.4062e+00, -1.9219e+00, -7.4609e-01,  ...,  1.8672e+00,
          -7.1716e-03,  3.0625e+00],
         ...,
         [ 1.8164e-01, -3.8086e-02, -7.7148e-02,  ..., -1.6328e+00,
           7.0703e-01,  1.2344e+00],
         [ 2.1094e-01,  3.1250e-01, -5.7031e-01,  ..., -6.0156e-01,
           7.6953e-01,  2.1875e+00],
         [ 7.4219e-01,  8.6719e-01, -5.8594e-01,  ...,  1.5527e-01,
           8.8281e-01,  8.6914e-02]],

        [[-3.4943e-03,  2.4414e-02,  5.6839e-04,  ...,  5.4932e-02,
          -1.2354e-01, -2.4023e-01],
         [-1.8359e-01, -8.9844e-02, -1.6953e+00,  ..., -1.0156e+00,
           3.7188e+00,  4.8047e-01],
         [-1.4062e+00, -1.9219e+00, -7.4609e-01,  ...,  1.8672e+00,
          -7.1716e-03,  3.0625e+00],
         ...,
         [ 1.8164e-01, -3.8086e-02, -7.7148e-02,  ..., -1.6328e+00,
           7.0703e-01,  1.2344e+00],
         [ 2.1094e-01,  3.1250e-01, -5.7031e-01,  ..., -6.0156e-01,
           7.6953e-01,  2.1875e+00],
         [ 7.4219e-01,  8.6719e-01, -5.8594e-01,  ...,  1.5527e-01,
           8.8281e-01,  8.6914e-02]],

        ...,

        [[ 1.0729e-04,  8.3008e-03, -1.4099e-02,  ...,  5.6763e-03,
          -5.2979e-02,  1.0864e-02],
         [-3.2031e-01, -1.0498e-01,  6.7188e-01,  ..., -9.3750e-01,
           4.4688e+00, -1.2344e+00],
         [-5.8203e-01, -7.8906e-01,  1.0156e+00,  ...,  1.8516e+00,
           2.5312e+00, -1.0547e+00],
         ...,
         [-2.4414e-01,  3.3008e-01,  3.3398e-01,  ...,  6.2012e-02,
           1.0859e+00, -2.2656e+00],
         [ 2.3438e-01,  3.1836e-01,  3.5547e-01,  ...,  3.4961e-01,
           2.1250e+00, -1.1621e-01],
         [ 8.1250e-01,  8.7891e-01,  1.1406e+00,  ..., -1.0703e+00,
           2.3242e-01, -5.5078e-01]],

        [[ 1.0729e-04,  8.3008e-03, -1.4099e-02,  ...,  5.6763e-03,
          -5.2979e-02,  1.0864e-02],
         [-3.2031e-01, -1.0498e-01,  6.7188e-01,  ..., -9.3750e-01,
           4.4688e+00, -1.2344e+00],
         [-5.8203e-01, -7.8906e-01,  1.0156e+00,  ...,  1.8516e+00,
           2.5312e+00, -1.0547e+00],
         ...,
         [-2.4414e-01,  3.3008e-01,  3.3398e-01,  ...,  6.2012e-02,
           1.0859e+00, -2.2656e+00],
         [ 2.3438e-01,  3.1836e-01,  3.5547e-01,  ...,  3.4961e-01,
           2.1250e+00, -1.1621e-01],
         [ 8.1250e-01,  8.7891e-01,  1.1406e+00,  ..., -1.0703e+00,
           2.3242e-01, -5.5078e-01]],

        [[ 1.0729e-04,  8.3008e-03, -1.4099e-02,  ...,  5.6763e-03,
          -5.2979e-02,  1.0864e-02],
         [-3.2031e-01, -1.0498e-01,  6.7188e-01,  ..., -9.3750e-01,
           4.4688e+00, -1.2344e+00],
         [-5.8203e-01, -7.8906e-01,  1.0156e+00,  ...,  1.8516e+00,
           2.5312e+00, -1.0547e+00],
         ...,
         [-2.4414e-01,  3.3008e-01,  3.3398e-01,  ...,  6.2012e-02,
           1.0859e+00, -2.2656e+00],
         [ 2.3438e-01,  3.1836e-01,  3.5547e-01,  ...,  3.4961e-01,
           2.1250e+00, -1.1621e-01],
         [ 8.1250e-01,  8.7891e-01,  1.1406e+00,  ..., -1.0703e+00,
           2.3242e-01, -5.5078e-01]]], dtype=torch.bfloat16), tensor([[[ 4.5776e-03,  1.1414e-02,  6.9275e-03,  ...,  3.7994e-03,
          -2.7161e-03, -2.9755e-03],
         [ 1.0469e+00, -4.7812e+00, -1.7734e+00,  ..., -2.7188e+00,
           1.1641e+00,  6.6016e-01],
         [ 6.2891e-01, -1.8359e+00, -9.1797e-01,  ...,  1.0010e-02,
           1.4453e+00,  1.7266e+00],
         ...,
         [-3.3984e-01, -1.1250e+00, -1.0078e+00,  ...,  1.2891e+00,
           1.0391e+00,  1.1562e+00],
         [ 1.3477e-01, -2.9053e-02, -8.8281e-01,  ...,  6.7188e-01,
           8.5547e-01,  5.8594e-01],
         [-6.9922e-01,  2.9102e-01,  1.0312e+00,  ..., -4.5508e-01,
           4.6094e-01,  1.3125e+00]],

        [[ 4.5776e-03,  1.1414e-02,  6.9275e-03,  ...,  3.7994e-03,
          -2.7161e-03, -2.9755e-03],
         [ 1.0469e+00, -4.7812e+00, -1.7734e+00,  ..., -2.7188e+00,
           1.1641e+00,  6.6016e-01],
         [ 6.2891e-01, -1.8359e+00, -9.1797e-01,  ...,  1.0010e-02,
           1.4453e+00,  1.7266e+00],
         ...,
         [-3.3984e-01, -1.1250e+00, -1.0078e+00,  ...,  1.2891e+00,
           1.0391e+00,  1.1562e+00],
         [ 1.3477e-01, -2.9053e-02, -8.8281e-01,  ...,  6.7188e-01,
           8.5547e-01,  5.8594e-01],
         [-6.9922e-01,  2.9102e-01,  1.0312e+00,  ..., -4.5508e-01,
           4.6094e-01,  1.3125e+00]],

        [[ 4.5776e-03,  1.1414e-02,  6.9275e-03,  ...,  3.7994e-03,
          -2.7161e-03, -2.9755e-03],
         [ 1.0469e+00, -4.7812e+00, -1.7734e+00,  ..., -2.7188e+00,
           1.1641e+00,  6.6016e-01],
         [ 6.2891e-01, -1.8359e+00, -9.1797e-01,  ...,  1.0010e-02,
           1.4453e+00,  1.7266e+00],
         ...,
         [-3.3984e-01, -1.1250e+00, -1.0078e+00,  ...,  1.2891e+00,
           1.0391e+00,  1.1562e+00],
         [ 1.3477e-01, -2.9053e-02, -8.8281e-01,  ...,  6.7188e-01,
           8.5547e-01,  5.8594e-01],
         [-6.9922e-01,  2.9102e-01,  1.0312e+00,  ..., -4.5508e-01,
           4.6094e-01,  1.3125e+00]],

        ...,

        [[-3.8757e-03,  1.6327e-03,  7.3242e-03,  ..., -1.7822e-02,
          -6.9336e-02,  4.8447e-04],
         [-1.7344e+00,  4.6094e-01, -3.1875e+00,  ..., -2.4609e-01,
           2.3125e+00, -3.9688e+00],
         [ 8.3984e-01,  1.3281e+00, -1.0938e+00,  ...,  2.6953e-01,
          -2.3145e-01, -3.3203e-01],
         ...,
         [ 6.2500e-01, -6.8359e-01,  5.4297e-01,  ..., -7.5391e-01,
          -7.6172e-01,  9.2578e-01],
         [ 7.4219e-01,  4.6289e-01, -2.3535e-01,  ...,  4.4922e-01,
           1.9453e+00,  4.7070e-01],
         [ 7.8516e-01,  3.6914e-01, -2.0781e+00,  ...,  1.9375e+00,
           1.0234e+00,  1.1797e+00]],

        [[-3.8757e-03,  1.6327e-03,  7.3242e-03,  ..., -1.7822e-02,
          -6.9336e-02,  4.8447e-04],
         [-1.7344e+00,  4.6094e-01, -3.1875e+00,  ..., -2.4609e-01,
           2.3125e+00, -3.9688e+00],
         [ 8.3984e-01,  1.3281e+00, -1.0938e+00,  ...,  2.6953e-01,
          -2.3145e-01, -3.3203e-01],
         ...,
         [ 6.2500e-01, -6.8359e-01,  5.4297e-01,  ..., -7.5391e-01,
          -7.6172e-01,  9.2578e-01],
         [ 7.4219e-01,  4.6289e-01, -2.3535e-01,  ...,  4.4922e-01,
           1.9453e+00,  4.7070e-01],
         [ 7.8516e-01,  3.6914e-01, -2.0781e+00,  ...,  1.9375e+00,
           1.0234e+00,  1.1797e+00]],

        [[-3.8757e-03,  1.6327e-03,  7.3242e-03,  ..., -1.7822e-02,
          -6.9336e-02,  4.8447e-04],
         [-1.7344e+00,  4.6094e-01, -3.1875e+00,  ..., -2.4609e-01,
           2.3125e+00, -3.9688e+00],
         [ 8.3984e-01,  1.3281e+00, -1.0938e+00,  ...,  2.6953e-01,
          -2.3145e-01, -3.3203e-01],
         ...,
         [ 6.2500e-01, -6.8359e-01,  5.4297e-01,  ..., -7.5391e-01,
          -7.6172e-01,  9.2578e-01],
         [ 7.4219e-01,  4.6289e-01, -2.3535e-01,  ...,  4.4922e-01,
           1.9453e+00,  4.7070e-01],
         [ 7.8516e-01,  3.6914e-01, -2.0781e+00,  ...,  1.9375e+00,
           1.0234e+00,  1.1797e+00]]], dtype=torch.bfloat16)), (tensor([[[-0.0067,  0.0189, -0.0097,  ...,  0.1387,  0.0334,  0.0076],
         [-0.2676, -0.6953,  1.0625,  ..., -3.4531, -0.6523,  1.7344],
         [-0.5508,  1.1797,  0.9883,  ..., -4.2188,  0.6836,  0.8125],
         ...,
         [ 0.1611,  0.0879,  1.6562,  ..., -4.1250, -1.5078, -1.5156],
         [ 0.4375, -0.8086,  1.0938,  ..., -4.0000, -1.8438, -0.5469],
         [ 0.2734, -1.9844,  1.8047,  ..., -3.6250, -0.5430,  0.0093]],

        [[-0.0067,  0.0189, -0.0097,  ...,  0.1387,  0.0334,  0.0076],
         [-0.2676, -0.6953,  1.0625,  ..., -3.4531, -0.6523,  1.7344],
         [-0.5508,  1.1797,  0.9883,  ..., -4.2188,  0.6836,  0.8125],
         ...,
         [ 0.1611,  0.0879,  1.6562,  ..., -4.1250, -1.5078, -1.5156],
         [ 0.4375, -0.8086,  1.0938,  ..., -4.0000, -1.8438, -0.5469],
         [ 0.2734, -1.9844,  1.8047,  ..., -3.6250, -0.5430,  0.0093]],

        [[-0.0067,  0.0189, -0.0097,  ...,  0.1387,  0.0334,  0.0076],
         [-0.2676, -0.6953,  1.0625,  ..., -3.4531, -0.6523,  1.7344],
         [-0.5508,  1.1797,  0.9883,  ..., -4.2188,  0.6836,  0.8125],
         ...,
         [ 0.1611,  0.0879,  1.6562,  ..., -4.1250, -1.5078, -1.5156],
         [ 0.4375, -0.8086,  1.0938,  ..., -4.0000, -1.8438, -0.5469],
         [ 0.2734, -1.9844,  1.8047,  ..., -3.6250, -0.5430,  0.0093]],

        ...,

        [[-0.0256,  0.0177,  0.0053,  ...,  0.0048, -0.0129, -0.0820],
         [ 1.0859, -0.6836, -0.5078,  ...,  0.0684, -0.6016,  1.6953],
         [-1.2891, -1.2500,  0.5430,  ...,  1.9766,  0.4082,  2.3125],
         ...,
         [-1.8438,  1.1328, -0.5898,  ..., -0.2715, -0.8984,  0.7969],
         [ 0.0156,  0.4453, -0.3672,  ...,  0.0508, -0.8281,  0.9844],
         [ 1.5625, -1.1094, -0.3203,  ..., -1.5312, -0.0349,  0.5547]],

        [[-0.0256,  0.0177,  0.0053,  ...,  0.0048, -0.0129, -0.0820],
         [ 1.0859, -0.6836, -0.5078,  ...,  0.0684, -0.6016,  1.6953],
         [-1.2891, -1.2500,  0.5430,  ...,  1.9766,  0.4082,  2.3125],
         ...,
         [-1.8438,  1.1328, -0.5898,  ..., -0.2715, -0.8984,  0.7969],
         [ 0.0156,  0.4453, -0.3672,  ...,  0.0508, -0.8281,  0.9844],
         [ 1.5625, -1.1094, -0.3203,  ..., -1.5312, -0.0349,  0.5547]],

        [[-0.0256,  0.0177,  0.0053,  ...,  0.0048, -0.0129, -0.0820],
         [ 1.0859, -0.6836, -0.5078,  ...,  0.0684, -0.6016,  1.6953],
         [-1.2891, -1.2500,  0.5430,  ...,  1.9766,  0.4082,  2.3125],
         ...,
         [-1.8438,  1.1328, -0.5898,  ..., -0.2715, -0.8984,  0.7969],
         [ 0.0156,  0.4453, -0.3672,  ...,  0.0508, -0.8281,  0.9844],
         [ 1.5625, -1.1094, -0.3203,  ..., -1.5312, -0.0349,  0.5547]]],
       dtype=torch.bfloat16), tensor([[[-5.5542e-03, -1.3123e-03, -4.8218e-03,  ...,  5.2185e-03,
          -5.2185e-03, -7.4768e-04],
         [-1.1562e+00,  5.5469e-01,  4.4922e-01,  ..., -1.0078e+00,
          -3.0859e-01,  2.7734e-01],
         [ 5.1172e-01, -2.8711e-01,  3.6865e-02,  ...,  5.1953e-01,
           8.7109e-01, -9.5312e-01],
         ...,
         [ 4.5508e-01,  4.0430e-01, -2.4292e-02,  ..., -5.8350e-02,
          -1.7031e+00, -7.5781e-01],
         [ 1.9141e-01,  3.3936e-02,  1.7969e-01,  ..., -1.0400e-01,
          -1.0781e+00, -7.8906e-01],
         [ 2.2461e-01,  9.8828e-01, -3.6133e-01,  ...,  4.0625e-01,
          -5.4297e-01,  2.0312e-01]],

        [[-5.5542e-03, -1.3123e-03, -4.8218e-03,  ...,  5.2185e-03,
          -5.2185e-03, -7.4768e-04],
         [-1.1562e+00,  5.5469e-01,  4.4922e-01,  ..., -1.0078e+00,
          -3.0859e-01,  2.7734e-01],
         [ 5.1172e-01, -2.8711e-01,  3.6865e-02,  ...,  5.1953e-01,
           8.7109e-01, -9.5312e-01],
         ...,
         [ 4.5508e-01,  4.0430e-01, -2.4292e-02,  ..., -5.8350e-02,
          -1.7031e+00, -7.5781e-01],
         [ 1.9141e-01,  3.3936e-02,  1.7969e-01,  ..., -1.0400e-01,
          -1.0781e+00, -7.8906e-01],
         [ 2.2461e-01,  9.8828e-01, -3.6133e-01,  ...,  4.0625e-01,
          -5.4297e-01,  2.0312e-01]],

        [[-5.5542e-03, -1.3123e-03, -4.8218e-03,  ...,  5.2185e-03,
          -5.2185e-03, -7.4768e-04],
         [-1.1562e+00,  5.5469e-01,  4.4922e-01,  ..., -1.0078e+00,
          -3.0859e-01,  2.7734e-01],
         [ 5.1172e-01, -2.8711e-01,  3.6865e-02,  ...,  5.1953e-01,
           8.7109e-01, -9.5312e-01],
         ...,
         [ 4.5508e-01,  4.0430e-01, -2.4292e-02,  ..., -5.8350e-02,
          -1.7031e+00, -7.5781e-01],
         [ 1.9141e-01,  3.3936e-02,  1.7969e-01,  ..., -1.0400e-01,
          -1.0781e+00, -7.8906e-01],
         [ 2.2461e-01,  9.8828e-01, -3.6133e-01,  ...,  4.0625e-01,
          -5.4297e-01,  2.0312e-01]],

        ...,

        [[-7.5378e-03, -4.8161e-05,  1.0132e-02,  ..., -6.2256e-03,
          -5.2185e-03, -4.3640e-03],
         [ 1.0156e+00, -2.6367e-01, -4.8242e-01,  ..., -1.2969e+00,
           5.2734e-01,  2.0996e-01],
         [-4.5898e-02,  1.1719e+00,  9.1797e-01,  ..., -1.0000e+00,
           8.6719e-01, -4.1211e-01],
         ...,
         [-2.7930e-01, -1.7734e+00, -4.6484e-01,  ..., -1.1484e+00,
          -7.7344e-01,  1.9434e-01],
         [-2.1484e-02, -4.3164e-01, -3.5547e-01,  ..., -4.2578e-01,
          -1.4844e+00, -4.3750e-01],
         [-1.2188e+00, -5.0781e-01,  4.6875e-02,  ..., -7.5195e-02,
           7.5781e-01,  2.1387e-01]],

        [[-7.5378e-03, -4.8161e-05,  1.0132e-02,  ..., -6.2256e-03,
          -5.2185e-03, -4.3640e-03],
         [ 1.0156e+00, -2.6367e-01, -4.8242e-01,  ..., -1.2969e+00,
           5.2734e-01,  2.0996e-01],
         [-4.5898e-02,  1.1719e+00,  9.1797e-01,  ..., -1.0000e+00,
           8.6719e-01, -4.1211e-01],
         ...,
         [-2.7930e-01, -1.7734e+00, -4.6484e-01,  ..., -1.1484e+00,
          -7.7344e-01,  1.9434e-01],
         [-2.1484e-02, -4.3164e-01, -3.5547e-01,  ..., -4.2578e-01,
          -1.4844e+00, -4.3750e-01],
         [-1.2188e+00, -5.0781e-01,  4.6875e-02,  ..., -7.5195e-02,
           7.5781e-01,  2.1387e-01]],

        [[-7.5378e-03, -4.8161e-05,  1.0132e-02,  ..., -6.2256e-03,
          -5.2185e-03, -4.3640e-03],
         [ 1.0156e+00, -2.6367e-01, -4.8242e-01,  ..., -1.2969e+00,
           5.2734e-01,  2.0996e-01],
         [-4.5898e-02,  1.1719e+00,  9.1797e-01,  ..., -1.0000e+00,
           8.6719e-01, -4.1211e-01],
         ...,
         [-2.7930e-01, -1.7734e+00, -4.6484e-01,  ..., -1.1484e+00,
          -7.7344e-01,  1.9434e-01],
         [-2.1484e-02, -4.3164e-01, -3.5547e-01,  ..., -4.2578e-01,
          -1.4844e+00, -4.3750e-01],
         [-1.2188e+00, -5.0781e-01,  4.6875e-02,  ..., -7.5195e-02,
           7.5781e-01,  2.1387e-01]]], dtype=torch.bfloat16)), (tensor([[[ 1.5198e-02,  4.3030e-03,  1.2131e-03,  ...,  3.6865e-02,
          -2.1387e-01, -1.9287e-02],
         [-1.3828e+00,  2.1875e-01, -4.2480e-02,  ...,  1.6562e+00,
           4.1406e-01,  3.2344e+00],
         [-9.0234e-01,  2.6001e-02, -2.6172e-01,  ...,  2.2656e+00,
           1.5625e+00,  2.1719e+00],
         ...,
         [-4.4336e-01, -5.3906e-01,  3.9453e-01,  ..., -3.8086e-02,
          -2.5469e+00,  2.5000e+00],
         [-8.5938e-02, -1.6895e-01,  7.3242e-03,  ..., -2.3242e-01,
          -1.8203e+00, -7.3828e-01],
         [ 4.6289e-01,  8.2031e-01, -6.4844e-01,  ..., -3.5469e+00,
          -3.0469e+00,  2.0000e+00]],

        [[ 1.5198e-02,  4.3030e-03,  1.2131e-03,  ...,  3.6865e-02,
          -2.1387e-01, -1.9287e-02],
         [-1.3828e+00,  2.1875e-01, -4.2480e-02,  ...,  1.6562e+00,
           4.1406e-01,  3.2344e+00],
         [-9.0234e-01,  2.6001e-02, -2.6172e-01,  ...,  2.2656e+00,
           1.5625e+00,  2.1719e+00],
         ...,
         [-4.4336e-01, -5.3906e-01,  3.9453e-01,  ..., -3.8086e-02,
          -2.5469e+00,  2.5000e+00],
         [-8.5938e-02, -1.6895e-01,  7.3242e-03,  ..., -2.3242e-01,
          -1.8203e+00, -7.3828e-01],
         [ 4.6289e-01,  8.2031e-01, -6.4844e-01,  ..., -3.5469e+00,
          -3.0469e+00,  2.0000e+00]],

        [[ 1.5198e-02,  4.3030e-03,  1.2131e-03,  ...,  3.6865e-02,
          -2.1387e-01, -1.9287e-02],
         [-1.3828e+00,  2.1875e-01, -4.2480e-02,  ...,  1.6562e+00,
           4.1406e-01,  3.2344e+00],
         [-9.0234e-01,  2.6001e-02, -2.6172e-01,  ...,  2.2656e+00,
           1.5625e+00,  2.1719e+00],
         ...,
         [-4.4336e-01, -5.3906e-01,  3.9453e-01,  ..., -3.8086e-02,
          -2.5469e+00,  2.5000e+00],
         [-8.5938e-02, -1.6895e-01,  7.3242e-03,  ..., -2.3242e-01,
          -1.8203e+00, -7.3828e-01],
         [ 4.6289e-01,  8.2031e-01, -6.4844e-01,  ..., -3.5469e+00,
          -3.0469e+00,  2.0000e+00]],

        ...,

        [[-1.1536e-02, -6.7139e-03,  3.7003e-04,  ...,  1.7944e-02,
           6.1279e-02,  1.0938e+00],
         [-4.4922e-02, -1.0078e+00,  9.3750e-02,  ..., -1.9609e+00,
           3.4219e+00, -8.7500e+00],
         [-1.6602e-01, -6.3672e-01,  5.1270e-03,  ...,  7.3828e-01,
          -2.9375e+00, -1.0000e+01],
         ...,
         [-3.1641e-01,  2.7148e-01, -2.7734e-01,  ...,  2.8906e+00,
           6.4844e-01, -6.7188e+00],
         [ 1.6992e-01,  3.1641e-01, -1.6211e-01,  ..., -5.4688e-01,
          -1.4609e+00, -7.9062e+00],
         [ 2.4902e-01,  7.0312e-01, -8.1250e-01,  ..., -5.4297e-01,
          -2.4531e+00, -7.5000e+00]],

        [[-1.1536e-02, -6.7139e-03,  3.7003e-04,  ...,  1.7944e-02,
           6.1279e-02,  1.0938e+00],
         [-4.4922e-02, -1.0078e+00,  9.3750e-02,  ..., -1.9609e+00,
           3.4219e+00, -8.7500e+00],
         [-1.6602e-01, -6.3672e-01,  5.1270e-03,  ...,  7.3828e-01,
          -2.9375e+00, -1.0000e+01],
         ...,
         [-3.1641e-01,  2.7148e-01, -2.7734e-01,  ...,  2.8906e+00,
           6.4844e-01, -6.7188e+00],
         [ 1.6992e-01,  3.1641e-01, -1.6211e-01,  ..., -5.4688e-01,
          -1.4609e+00, -7.9062e+00],
         [ 2.4902e-01,  7.0312e-01, -8.1250e-01,  ..., -5.4297e-01,
          -2.4531e+00, -7.5000e+00]],

        [[-1.1536e-02, -6.7139e-03,  3.7003e-04,  ...,  1.7944e-02,
           6.1279e-02,  1.0938e+00],
         [-4.4922e-02, -1.0078e+00,  9.3750e-02,  ..., -1.9609e+00,
           3.4219e+00, -8.7500e+00],
         [-1.6602e-01, -6.3672e-01,  5.1270e-03,  ...,  7.3828e-01,
          -2.9375e+00, -1.0000e+01],
         ...,
         [-3.1641e-01,  2.7148e-01, -2.7734e-01,  ...,  2.8906e+00,
           6.4844e-01, -6.7188e+00],
         [ 1.6992e-01,  3.1641e-01, -1.6211e-01,  ..., -5.4688e-01,
          -1.4609e+00, -7.9062e+00],
         [ 2.4902e-01,  7.0312e-01, -8.1250e-01,  ..., -5.4297e-01,
          -2.4531e+00, -7.5000e+00]]], dtype=torch.bfloat16), tensor([[[ 2.9755e-04,  3.8757e-03,  6.9427e-04,  ...,  5.5237e-03,
           1.3828e-04, -6.5918e-03],
         [ 1.3594e+00,  5.7422e-01, -1.0938e+00,  ..., -1.8828e+00,
           5.3516e-01,  5.5859e-01],
         [ 2.3315e-02, -2.5938e+00,  8.9844e-01,  ..., -9.7266e-01,
           3.0664e-01,  1.2266e+00],
         ...,
         [ 1.4062e+00, -5.8594e-01,  2.3438e-01,  ..., -6.3281e-01,
          -1.9043e-01, -1.8457e-01],
         [ 1.1875e+00, -9.2969e-01, -3.5156e-01,  ..., -2.8516e-01,
          -3.4766e-01,  5.4443e-02],
         [ 6.3281e-01,  3.4375e-01, -4.7266e-01,  ...,  1.0000e+00,
          -8.9355e-02,  3.7354e-02]],

        [[ 2.9755e-04,  3.8757e-03,  6.9427e-04,  ...,  5.5237e-03,
           1.3828e-04, -6.5918e-03],
         [ 1.3594e+00,  5.7422e-01, -1.0938e+00,  ..., -1.8828e+00,
           5.3516e-01,  5.5859e-01],
         [ 2.3315e-02, -2.5938e+00,  8.9844e-01,  ..., -9.7266e-01,
           3.0664e-01,  1.2266e+00],
         ...,
         [ 1.4062e+00, -5.8594e-01,  2.3438e-01,  ..., -6.3281e-01,
          -1.9043e-01, -1.8457e-01],
         [ 1.1875e+00, -9.2969e-01, -3.5156e-01,  ..., -2.8516e-01,
          -3.4766e-01,  5.4443e-02],
         [ 6.3281e-01,  3.4375e-01, -4.7266e-01,  ...,  1.0000e+00,
          -8.9355e-02,  3.7354e-02]],

        [[ 2.9755e-04,  3.8757e-03,  6.9427e-04,  ...,  5.5237e-03,
           1.3828e-04, -6.5918e-03],
         [ 1.3594e+00,  5.7422e-01, -1.0938e+00,  ..., -1.8828e+00,
           5.3516e-01,  5.5859e-01],
         [ 2.3315e-02, -2.5938e+00,  8.9844e-01,  ..., -9.7266e-01,
           3.0664e-01,  1.2266e+00],
         ...,
         [ 1.4062e+00, -5.8594e-01,  2.3438e-01,  ..., -6.3281e-01,
          -1.9043e-01, -1.8457e-01],
         [ 1.1875e+00, -9.2969e-01, -3.5156e-01,  ..., -2.8516e-01,
          -3.4766e-01,  5.4443e-02],
         [ 6.3281e-01,  3.4375e-01, -4.7266e-01,  ...,  1.0000e+00,
          -8.9355e-02,  3.7354e-02]],

        ...,

        [[ 2.6703e-03, -6.5918e-03, -9.8267e-03,  ..., -7.3242e-03,
          -2.3460e-04, -4.3945e-03],
         [-1.8906e+00,  7.5391e-01, -9.6191e-02,  ...,  2.6367e-01,
           1.5156e+00, -1.2031e+00],
         [ 1.4531e+00, -1.3359e+00,  9.0820e-02,  ..., -2.6406e+00,
           4.2188e+00,  3.1875e+00],
         ...,
         [ 5.2734e-01,  1.1562e+00, -9.4922e-01,  ...,  2.3071e-02,
          -2.3828e-01, -8.9844e-01],
         [-3.6719e-01, -1.5918e-01, -5.4297e-01,  ...,  2.0312e-01,
           1.0205e-01,  7.2266e-02],
         [ 3.8867e-01,  2.6758e-01,  2.2344e+00,  ...,  1.9531e+00,
          -3.7500e-01,  6.6895e-02]],

        [[ 2.6703e-03, -6.5918e-03, -9.8267e-03,  ..., -7.3242e-03,
          -2.3460e-04, -4.3945e-03],
         [-1.8906e+00,  7.5391e-01, -9.6191e-02,  ...,  2.6367e-01,
           1.5156e+00, -1.2031e+00],
         [ 1.4531e+00, -1.3359e+00,  9.0820e-02,  ..., -2.6406e+00,
           4.2188e+00,  3.1875e+00],
         ...,
         [ 5.2734e-01,  1.1562e+00, -9.4922e-01,  ...,  2.3071e-02,
          -2.3828e-01, -8.9844e-01],
         [-3.6719e-01, -1.5918e-01, -5.4297e-01,  ...,  2.0312e-01,
           1.0205e-01,  7.2266e-02],
         [ 3.8867e-01,  2.6758e-01,  2.2344e+00,  ...,  1.9531e+00,
          -3.7500e-01,  6.6895e-02]],

        [[ 2.6703e-03, -6.5918e-03, -9.8267e-03,  ..., -7.3242e-03,
          -2.3460e-04, -4.3945e-03],
         [-1.8906e+00,  7.5391e-01, -9.6191e-02,  ...,  2.6367e-01,
           1.5156e+00, -1.2031e+00],
         [ 1.4531e+00, -1.3359e+00,  9.0820e-02,  ..., -2.6406e+00,
           4.2188e+00,  3.1875e+00],
         ...,
         [ 5.2734e-01,  1.1562e+00, -9.4922e-01,  ...,  2.3071e-02,
          -2.3828e-01, -8.9844e-01],
         [-3.6719e-01, -1.5918e-01, -5.4297e-01,  ...,  2.0312e-01,
           1.0205e-01,  7.2266e-02],
         [ 3.8867e-01,  2.6758e-01,  2.2344e+00,  ...,  1.9531e+00,
          -3.7500e-01,  6.6895e-02]]], dtype=torch.bfloat16)), (tensor([[[-3.6926e-03,  3.5858e-03, -1.4191e-03,  ...,  2.9907e-02,
          -6.6406e-02,  5.3516e-01],
         [ 2.0938e+00,  1.7031e+00, -1.5156e+00,  ..., -2.1875e+00,
           1.8594e+00, -5.1875e+00],
         [ 7.6172e-01,  1.2109e+00, -1.4062e-01,  ..., -1.0938e+00,
          -1.8047e+00, -3.8281e+00],
         ...,
         [-8.5547e-01, -1.2812e+00, -2.6562e-01,  ..., -4.0820e-01,
           4.6562e+00, -5.0938e+00],
         [ 1.1230e-01, -3.1445e-01,  5.0000e-01,  ...,  8.0469e-01,
           5.0781e-01, -4.2812e+00],
         [ 5.5078e-01, -2.0312e-01, -6.4062e-01,  ..., -1.8984e+00,
          -5.2344e-01, -4.5938e+00]],

        [[-3.6926e-03,  3.5858e-03, -1.4191e-03,  ...,  2.9907e-02,
          -6.6406e-02,  5.3516e-01],
         [ 2.0938e+00,  1.7031e+00, -1.5156e+00,  ..., -2.1875e+00,
           1.8594e+00, -5.1875e+00],
         [ 7.6172e-01,  1.2109e+00, -1.4062e-01,  ..., -1.0938e+00,
          -1.8047e+00, -3.8281e+00],
         ...,
         [-8.5547e-01, -1.2812e+00, -2.6562e-01,  ..., -4.0820e-01,
           4.6562e+00, -5.0938e+00],
         [ 1.1230e-01, -3.1445e-01,  5.0000e-01,  ...,  8.0469e-01,
           5.0781e-01, -4.2812e+00],
         [ 5.5078e-01, -2.0312e-01, -6.4062e-01,  ..., -1.8984e+00,
          -5.2344e-01, -4.5938e+00]],

        [[-3.6926e-03,  3.5858e-03, -1.4191e-03,  ...,  2.9907e-02,
          -6.6406e-02,  5.3516e-01],
         [ 2.0938e+00,  1.7031e+00, -1.5156e+00,  ..., -2.1875e+00,
           1.8594e+00, -5.1875e+00],
         [ 7.6172e-01,  1.2109e+00, -1.4062e-01,  ..., -1.0938e+00,
          -1.8047e+00, -3.8281e+00],
         ...,
         [-8.5547e-01, -1.2812e+00, -2.6562e-01,  ..., -4.0820e-01,
           4.6562e+00, -5.0938e+00],
         [ 1.1230e-01, -3.1445e-01,  5.0000e-01,  ...,  8.0469e-01,
           5.0781e-01, -4.2812e+00],
         [ 5.5078e-01, -2.0312e-01, -6.4062e-01,  ..., -1.8984e+00,
          -5.2344e-01, -4.5938e+00]],

        ...,

        [[-1.0834e-03,  1.1597e-02,  6.6223e-03,  ...,  1.1426e-01,
           5.9082e-02,  7.5195e-02],
         [ 1.6719e+00, -5.2344e-01, -8.1641e-01,  ..., -6.9922e-01,
          -5.8594e-01, -1.0781e+00],
         [ 1.2344e+00, -1.4453e+00, -9.9219e-01,  ..., -6.4844e-01,
          -4.3945e-01, -6.7969e-01],
         ...,
         [-1.8945e-01,  1.4453e-01,  4.4141e-01,  ..., -1.1875e+00,
          -1.1953e+00, -1.0391e+00],
         [-4.3945e-01,  1.3203e+00, -2.2070e-01,  ..., -5.1172e-01,
          -1.0107e-01, -1.1523e-01],
         [-4.1992e-01,  3.9062e-01, -7.5000e-01,  ..., -9.4141e-01,
           4.8340e-02, -1.2422e+00]],

        [[-1.0834e-03,  1.1597e-02,  6.6223e-03,  ...,  1.1426e-01,
           5.9082e-02,  7.5195e-02],
         [ 1.6719e+00, -5.2344e-01, -8.1641e-01,  ..., -6.9922e-01,
          -5.8594e-01, -1.0781e+00],
         [ 1.2344e+00, -1.4453e+00, -9.9219e-01,  ..., -6.4844e-01,
          -4.3945e-01, -6.7969e-01],
         ...,
         [-1.8945e-01,  1.4453e-01,  4.4141e-01,  ..., -1.1875e+00,
          -1.1953e+00, -1.0391e+00],
         [-4.3945e-01,  1.3203e+00, -2.2070e-01,  ..., -5.1172e-01,
          -1.0107e-01, -1.1523e-01],
         [-4.1992e-01,  3.9062e-01, -7.5000e-01,  ..., -9.4141e-01,
           4.8340e-02, -1.2422e+00]],

        [[-1.0834e-03,  1.1597e-02,  6.6223e-03,  ...,  1.1426e-01,
           5.9082e-02,  7.5195e-02],
         [ 1.6719e+00, -5.2344e-01, -8.1641e-01,  ..., -6.9922e-01,
          -5.8594e-01, -1.0781e+00],
         [ 1.2344e+00, -1.4453e+00, -9.9219e-01,  ..., -6.4844e-01,
          -4.3945e-01, -6.7969e-01],
         ...,
         [-1.8945e-01,  1.4453e-01,  4.4141e-01,  ..., -1.1875e+00,
          -1.1953e+00, -1.0391e+00],
         [-4.3945e-01,  1.3203e+00, -2.2070e-01,  ..., -5.1172e-01,
          -1.0107e-01, -1.1523e-01],
         [-4.1992e-01,  3.9062e-01, -7.5000e-01,  ..., -9.4141e-01,
           4.8340e-02, -1.2422e+00]]], dtype=torch.bfloat16), tensor([[[-1.1215e-03, -1.3184e-02,  2.7313e-03,  ...,  1.4160e-02,
           2.8038e-04,  7.0496e-03],
         [ 6.5625e-01,  3.0078e-01,  3.0664e-01,  ..., -2.2754e-01,
          -8.6719e-01, -2.0703e-01],
         [ 3.8672e-01, -2.0020e-01,  3.7305e-01,  ...,  1.3477e-01,
           7.9297e-01,  3.1250e-01],
         ...,
         [-1.5938e+00,  2.8125e-01,  1.6357e-02,  ..., -7.7637e-02,
           2.5781e-01, -4.3359e-01],
         [-1.1719e+00, -1.6406e-01,  2.2583e-02,  ...,  4.1211e-01,
          -2.0117e-01, -1.8066e-01],
         [-7.1094e-01, -8.0566e-03,  7.1484e-01,  ...,  4.2578e-01,
          -1.9727e-01, -8.1641e-01]],

        [[-1.1215e-03, -1.3184e-02,  2.7313e-03,  ...,  1.4160e-02,
           2.8038e-04,  7.0496e-03],
         [ 6.5625e-01,  3.0078e-01,  3.0664e-01,  ..., -2.2754e-01,
          -8.6719e-01, -2.0703e-01],
         [ 3.8672e-01, -2.0020e-01,  3.7305e-01,  ...,  1.3477e-01,
           7.9297e-01,  3.1250e-01],
         ...,
         [-1.5938e+00,  2.8125e-01,  1.6357e-02,  ..., -7.7637e-02,
           2.5781e-01, -4.3359e-01],
         [-1.1719e+00, -1.6406e-01,  2.2583e-02,  ...,  4.1211e-01,
          -2.0117e-01, -1.8066e-01],
         [-7.1094e-01, -8.0566e-03,  7.1484e-01,  ...,  4.2578e-01,
          -1.9727e-01, -8.1641e-01]],

        [[-1.1215e-03, -1.3184e-02,  2.7313e-03,  ...,  1.4160e-02,
           2.8038e-04,  7.0496e-03],
         [ 6.5625e-01,  3.0078e-01,  3.0664e-01,  ..., -2.2754e-01,
          -8.6719e-01, -2.0703e-01],
         [ 3.8672e-01, -2.0020e-01,  3.7305e-01,  ...,  1.3477e-01,
           7.9297e-01,  3.1250e-01],
         ...,
         [-1.5938e+00,  2.8125e-01,  1.6357e-02,  ..., -7.7637e-02,
           2.5781e-01, -4.3359e-01],
         [-1.1719e+00, -1.6406e-01,  2.2583e-02,  ...,  4.1211e-01,
          -2.0117e-01, -1.8066e-01],
         [-7.1094e-01, -8.0566e-03,  7.1484e-01,  ...,  4.2578e-01,
          -1.9727e-01, -8.1641e-01]],

        ...,

        [[-9.1553e-03, -9.7656e-03, -6.1951e-03,  ...,  5.8899e-03,
           6.4697e-03, -7.2098e-04],
         [-5.0781e-01,  1.3281e+00, -1.4766e+00,  ..., -7.6172e-01,
          -1.5430e-01, -1.7656e+00],
         [-4.9414e-01,  2.2949e-01,  1.8203e+00,  ..., -3.0273e-01,
          -2.4805e-01, -6.5625e-01],
         ...,
         [-5.9814e-02,  2.3535e-01,  1.6250e+00,  ...,  1.0703e+00,
          -1.2656e+00,  2.3340e-01],
         [-2.2168e-01,  4.7461e-01,  8.6719e-01,  ...,  1.3672e-01,
           5.6885e-02, -7.2021e-03],
         [-9.2188e-01, -1.1328e+00,  1.0312e+00,  ...,  1.2969e+00,
           1.3203e+00, -7.2656e-01]],

        [[-9.1553e-03, -9.7656e-03, -6.1951e-03,  ...,  5.8899e-03,
           6.4697e-03, -7.2098e-04],
         [-5.0781e-01,  1.3281e+00, -1.4766e+00,  ..., -7.6172e-01,
          -1.5430e-01, -1.7656e+00],
         [-4.9414e-01,  2.2949e-01,  1.8203e+00,  ..., -3.0273e-01,
          -2.4805e-01, -6.5625e-01],
         ...,
         [-5.9814e-02,  2.3535e-01,  1.6250e+00,  ...,  1.0703e+00,
          -1.2656e+00,  2.3340e-01],
         [-2.2168e-01,  4.7461e-01,  8.6719e-01,  ...,  1.3672e-01,
           5.6885e-02, -7.2021e-03],
         [-9.2188e-01, -1.1328e+00,  1.0312e+00,  ...,  1.2969e+00,
           1.3203e+00, -7.2656e-01]],

        [[-9.1553e-03, -9.7656e-03, -6.1951e-03,  ...,  5.8899e-03,
           6.4697e-03, -7.2098e-04],
         [-5.0781e-01,  1.3281e+00, -1.4766e+00,  ..., -7.6172e-01,
          -1.5430e-01, -1.7656e+00],
         [-4.9414e-01,  2.2949e-01,  1.8203e+00,  ..., -3.0273e-01,
          -2.4805e-01, -6.5625e-01],
         ...,
         [-5.9814e-02,  2.3535e-01,  1.6250e+00,  ...,  1.0703e+00,
          -1.2656e+00,  2.3340e-01],
         [-2.2168e-01,  4.7461e-01,  8.6719e-01,  ...,  1.3672e-01,
           5.6885e-02, -7.2021e-03],
         [-9.2188e-01, -1.1328e+00,  1.0312e+00,  ...,  1.2969e+00,
           1.3203e+00, -7.2656e-01]]], dtype=torch.bfloat16)), (tensor([[[-2.0386e-02,  2.8801e-04,  1.3504e-03,  ..., -6.3281e-01,
          -2.7222e-02, -1.0645e-01],
         [-3.9062e-02,  2.2969e+00,  1.6562e+00,  ...,  1.0688e+01,
          -2.3750e+00,  5.7031e-01],
         [-1.7656e+00,  8.3594e-01,  4.0234e-01,  ...,  1.1438e+01,
          -7.1094e-01,  8.7109e-01],
         ...,
         [-7.3438e-01, -1.1875e+00,  7.0312e-01,  ...,  9.6875e+00,
           1.3047e+00,  1.0391e+00],
         [ 2.1484e-02, -1.3828e+00,  8.9844e-01,  ...,  8.6250e+00,
           6.4844e-01,  4.7656e-01],
         [ 1.7578e+00, -4.8047e-01,  3.2031e-01,  ...,  9.4375e+00,
           1.2305e-01, -1.7188e-01]],

        [[-2.0386e-02,  2.8801e-04,  1.3504e-03,  ..., -6.3281e-01,
          -2.7222e-02, -1.0645e-01],
         [-3.9062e-02,  2.2969e+00,  1.6562e+00,  ...,  1.0688e+01,
          -2.3750e+00,  5.7031e-01],
         [-1.7656e+00,  8.3594e-01,  4.0234e-01,  ...,  1.1438e+01,
          -7.1094e-01,  8.7109e-01],
         ...,
         [-7.3438e-01, -1.1875e+00,  7.0312e-01,  ...,  9.6875e+00,
           1.3047e+00,  1.0391e+00],
         [ 2.1484e-02, -1.3828e+00,  8.9844e-01,  ...,  8.6250e+00,
           6.4844e-01,  4.7656e-01],
         [ 1.7578e+00, -4.8047e-01,  3.2031e-01,  ...,  9.4375e+00,
           1.2305e-01, -1.7188e-01]],

        [[-2.0386e-02,  2.8801e-04,  1.3504e-03,  ..., -6.3281e-01,
          -2.7222e-02, -1.0645e-01],
         [-3.9062e-02,  2.2969e+00,  1.6562e+00,  ...,  1.0688e+01,
          -2.3750e+00,  5.7031e-01],
         [-1.7656e+00,  8.3594e-01,  4.0234e-01,  ...,  1.1438e+01,
          -7.1094e-01,  8.7109e-01],
         ...,
         [-7.3438e-01, -1.1875e+00,  7.0312e-01,  ...,  9.6875e+00,
           1.3047e+00,  1.0391e+00],
         [ 2.1484e-02, -1.3828e+00,  8.9844e-01,  ...,  8.6250e+00,
           6.4844e-01,  4.7656e-01],
         [ 1.7578e+00, -4.8047e-01,  3.2031e-01,  ...,  9.4375e+00,
           1.2305e-01, -1.7188e-01]],

        ...,

        [[ 4.9744e-03,  5.5237e-03,  7.2937e-03,  ..., -2.2949e-02,
           8.1787e-03,  1.5723e-01],
         [-1.3750e+00, -8.7891e-01, -8.5156e-01,  ...,  1.4453e+00,
          -2.7734e-01, -3.7812e+00],
         [-8.2812e-01, -5.9375e-01,  1.8652e-01,  ...,  3.1875e+00,
           2.5977e-01,  1.8555e-01],
         ...,
         [ 1.1328e+00,  5.0000e-01, -9.5703e-01,  ..., -1.5078e+00,
          -2.6250e+00, -4.0000e+00],
         [ 2.1094e-01, -8.7402e-02, -7.1484e-01,  ...,  5.3125e-01,
           4.9219e-01, -7.1094e-01],
         [ 1.1484e+00, -1.6016e+00, -7.5391e-01,  ...,  2.5000e+00,
          -1.0859e+00, -2.0938e+00]],

        [[ 4.9744e-03,  5.5237e-03,  7.2937e-03,  ..., -2.2949e-02,
           8.1787e-03,  1.5723e-01],
         [-1.3750e+00, -8.7891e-01, -8.5156e-01,  ...,  1.4453e+00,
          -2.7734e-01, -3.7812e+00],
         [-8.2812e-01, -5.9375e-01,  1.8652e-01,  ...,  3.1875e+00,
           2.5977e-01,  1.8555e-01],
         ...,
         [ 1.1328e+00,  5.0000e-01, -9.5703e-01,  ..., -1.5078e+00,
          -2.6250e+00, -4.0000e+00],
         [ 2.1094e-01, -8.7402e-02, -7.1484e-01,  ...,  5.3125e-01,
           4.9219e-01, -7.1094e-01],
         [ 1.1484e+00, -1.6016e+00, -7.5391e-01,  ...,  2.5000e+00,
          -1.0859e+00, -2.0938e+00]],

        [[ 4.9744e-03,  5.5237e-03,  7.2937e-03,  ..., -2.2949e-02,
           8.1787e-03,  1.5723e-01],
         [-1.3750e+00, -8.7891e-01, -8.5156e-01,  ...,  1.4453e+00,
          -2.7734e-01, -3.7812e+00],
         [-8.2812e-01, -5.9375e-01,  1.8652e-01,  ...,  3.1875e+00,
           2.5977e-01,  1.8555e-01],
         ...,
         [ 1.1328e+00,  5.0000e-01, -9.5703e-01,  ..., -1.5078e+00,
          -2.6250e+00, -4.0000e+00],
         [ 2.1094e-01, -8.7402e-02, -7.1484e-01,  ...,  5.3125e-01,
           4.9219e-01, -7.1094e-01],
         [ 1.1484e+00, -1.6016e+00, -7.5391e-01,  ...,  2.5000e+00,
          -1.0859e+00, -2.0938e+00]]], dtype=torch.bfloat16), tensor([[[-0.0063, -0.0132, -0.0074,  ...,  0.0058, -0.0055, -0.0167],
         [-1.8672, -0.9336, -1.5859,  ..., -1.7891, -0.7305, -0.5938],
         [-1.5156,  0.0366,  0.4688,  ..., -0.1069,  0.0542,  0.1040],
         ...,
         [ 0.3672,  0.3418, -1.2734,  ...,  0.5977, -0.7734, -0.1543],
         [ 0.6211, -0.0996, -0.0452,  ..., -0.3730, -0.6523,  0.3340],
         [-1.1875,  0.2227,  1.3828,  ...,  0.5156,  0.0330,  0.7070]],

        [[-0.0063, -0.0132, -0.0074,  ...,  0.0058, -0.0055, -0.0167],
         [-1.8672, -0.9336, -1.5859,  ..., -1.7891, -0.7305, -0.5938],
         [-1.5156,  0.0366,  0.4688,  ..., -0.1069,  0.0542,  0.1040],
         ...,
         [ 0.3672,  0.3418, -1.2734,  ...,  0.5977, -0.7734, -0.1543],
         [ 0.6211, -0.0996, -0.0452,  ..., -0.3730, -0.6523,  0.3340],
         [-1.1875,  0.2227,  1.3828,  ...,  0.5156,  0.0330,  0.7070]],

        [[-0.0063, -0.0132, -0.0074,  ...,  0.0058, -0.0055, -0.0167],
         [-1.8672, -0.9336, -1.5859,  ..., -1.7891, -0.7305, -0.5938],
         [-1.5156,  0.0366,  0.4688,  ..., -0.1069,  0.0542,  0.1040],
         ...,
         [ 0.3672,  0.3418, -1.2734,  ...,  0.5977, -0.7734, -0.1543],
         [ 0.6211, -0.0996, -0.0452,  ..., -0.3730, -0.6523,  0.3340],
         [-1.1875,  0.2227,  1.3828,  ...,  0.5156,  0.0330,  0.7070]],

        ...,

        [[ 0.0068,  0.0091,  0.0122,  ..., -0.0236, -0.0041,  0.0050],
         [ 0.5156,  0.7461,  1.4219,  ...,  1.0859,  1.1172,  0.5859],
         [-0.3027,  0.3945,  1.0859,  ...,  0.7852,  0.3652, -0.1992],
         ...,
         [-0.2930, -0.4766, -0.0162,  ...,  0.1172, -1.1328,  0.0571],
         [ 0.2930, -0.0217, -0.4297,  ...,  0.1543,  0.2285,  0.2598],
         [-0.9648, -0.8828,  0.5430,  ...,  0.3125, -0.2676,  0.3633]],

        [[ 0.0068,  0.0091,  0.0122,  ..., -0.0236, -0.0041,  0.0050],
         [ 0.5156,  0.7461,  1.4219,  ...,  1.0859,  1.1172,  0.5859],
         [-0.3027,  0.3945,  1.0859,  ...,  0.7852,  0.3652, -0.1992],
         ...,
         [-0.2930, -0.4766, -0.0162,  ...,  0.1172, -1.1328,  0.0571],
         [ 0.2930, -0.0217, -0.4297,  ...,  0.1543,  0.2285,  0.2598],
         [-0.9648, -0.8828,  0.5430,  ...,  0.3125, -0.2676,  0.3633]],

        [[ 0.0068,  0.0091,  0.0122,  ..., -0.0236, -0.0041,  0.0050],
         [ 0.5156,  0.7461,  1.4219,  ...,  1.0859,  1.1172,  0.5859],
         [-0.3027,  0.3945,  1.0859,  ...,  0.7852,  0.3652, -0.1992],
         ...,
         [-0.2930, -0.4766, -0.0162,  ...,  0.1172, -1.1328,  0.0571],
         [ 0.2930, -0.0217, -0.4297,  ...,  0.1543,  0.2285,  0.2598],
         [-0.9648, -0.8828,  0.5430,  ...,  0.3125, -0.2676,  0.3633]]],
       dtype=torch.bfloat16)), (tensor([[[ 2.4261e-03,  2.3804e-03,  2.8076e-03,  ..., -3.6377e-02,
          -2.4219e-01,  8.2812e-01],
         [-4.5312e-01, -1.1094e+00,  3.9844e-01,  ...,  6.1328e-01,
          -1.2354e-01, -7.6875e+00],
         [-8.2422e-01, -1.3516e+00,  9.1406e-01,  ..., -1.0859e+00,
           2.6875e+00, -9.0625e+00],
         ...,
         [-1.3574e-01,  3.9648e-01, -7.7734e-01,  ...,  3.4531e+00,
          -5.5625e+00, -7.4688e+00],
         [-2.7344e-01, -4.7852e-01, -3.4961e-01,  ...,  7.0312e-01,
          -2.4375e+00, -7.1562e+00],
         [ 1.2031e+00, -1.5234e-01,  6.7188e-01,  ..., -4.6875e-01,
           1.8203e+00, -8.0625e+00]],

        [[ 2.4261e-03,  2.3804e-03,  2.8076e-03,  ..., -3.6377e-02,
          -2.4219e-01,  8.2812e-01],
         [-4.5312e-01, -1.1094e+00,  3.9844e-01,  ...,  6.1328e-01,
          -1.2354e-01, -7.6875e+00],
         [-8.2422e-01, -1.3516e+00,  9.1406e-01,  ..., -1.0859e+00,
           2.6875e+00, -9.0625e+00],
         ...,
         [-1.3574e-01,  3.9648e-01, -7.7734e-01,  ...,  3.4531e+00,
          -5.5625e+00, -7.4688e+00],
         [-2.7344e-01, -4.7852e-01, -3.4961e-01,  ...,  7.0312e-01,
          -2.4375e+00, -7.1562e+00],
         [ 1.2031e+00, -1.5234e-01,  6.7188e-01,  ..., -4.6875e-01,
           1.8203e+00, -8.0625e+00]],

        [[ 2.4261e-03,  2.3804e-03,  2.8076e-03,  ..., -3.6377e-02,
          -2.4219e-01,  8.2812e-01],
         [-4.5312e-01, -1.1094e+00,  3.9844e-01,  ...,  6.1328e-01,
          -1.2354e-01, -7.6875e+00],
         [-8.2422e-01, -1.3516e+00,  9.1406e-01,  ..., -1.0859e+00,
           2.6875e+00, -9.0625e+00],
         ...,
         [-1.3574e-01,  3.9648e-01, -7.7734e-01,  ...,  3.4531e+00,
          -5.5625e+00, -7.4688e+00],
         [-2.7344e-01, -4.7852e-01, -3.4961e-01,  ...,  7.0312e-01,
          -2.4375e+00, -7.1562e+00],
         [ 1.2031e+00, -1.5234e-01,  6.7188e-01,  ..., -4.6875e-01,
           1.8203e+00, -8.0625e+00]],

        ...,

        [[ 1.5747e-02, -5.4016e-03,  8.7891e-03,  ..., -1.3184e-02,
          -5.9326e-02,  3.2812e-01],
         [-2.4688e+00,  2.1875e-01, -9.1797e-01,  ..., -9.3750e-01,
           1.6250e+00, -2.6875e+00],
         [-2.8125e-01,  2.2461e-01, -4.1406e-01,  ..., -2.9844e+00,
           1.8555e-01, -2.4531e+00],
         ...,
         [-1.4941e-01, -6.0547e-01,  1.0059e-01,  ..., -1.7109e+00,
          -3.5742e-01, -3.2969e+00],
         [ 5.1562e-01, -2.5781e-01, -1.9922e-01,  ..., -3.3750e+00,
          -1.2109e+00, -1.0156e+00],
         [ 9.1797e-01, -6.1719e-01, -1.4766e+00,  ..., -3.3281e+00,
           2.3438e+00, -1.8047e+00]],

        [[ 1.5747e-02, -5.4016e-03,  8.7891e-03,  ..., -1.3184e-02,
          -5.9326e-02,  3.2812e-01],
         [-2.4688e+00,  2.1875e-01, -9.1797e-01,  ..., -9.3750e-01,
           1.6250e+00, -2.6875e+00],
         [-2.8125e-01,  2.2461e-01, -4.1406e-01,  ..., -2.9844e+00,
           1.8555e-01, -2.4531e+00],
         ...,
         [-1.4941e-01, -6.0547e-01,  1.0059e-01,  ..., -1.7109e+00,
          -3.5742e-01, -3.2969e+00],
         [ 5.1562e-01, -2.5781e-01, -1.9922e-01,  ..., -3.3750e+00,
          -1.2109e+00, -1.0156e+00],
         [ 9.1797e-01, -6.1719e-01, -1.4766e+00,  ..., -3.3281e+00,
           2.3438e+00, -1.8047e+00]],

        [[ 1.5747e-02, -5.4016e-03,  8.7891e-03,  ..., -1.3184e-02,
          -5.9326e-02,  3.2812e-01],
         [-2.4688e+00,  2.1875e-01, -9.1797e-01,  ..., -9.3750e-01,
           1.6250e+00, -2.6875e+00],
         [-2.8125e-01,  2.2461e-01, -4.1406e-01,  ..., -2.9844e+00,
           1.8555e-01, -2.4531e+00],
         ...,
         [-1.4941e-01, -6.0547e-01,  1.0059e-01,  ..., -1.7109e+00,
          -3.5742e-01, -3.2969e+00],
         [ 5.1562e-01, -2.5781e-01, -1.9922e-01,  ..., -3.3750e+00,
          -1.2109e+00, -1.0156e+00],
         [ 9.1797e-01, -6.1719e-01, -1.4766e+00,  ..., -3.3281e+00,
           2.3438e+00, -1.8047e+00]]], dtype=torch.bfloat16), tensor([[[-5.2795e-03,  2.1648e-04, -1.0498e-02,  ..., -1.2054e-03,
           1.2665e-03,  1.0254e-02],
         [ 1.1016e+00, -6.9922e-01,  2.5391e-01,  ..., -3.6914e-01,
           2.3438e-01, -1.0791e-01],
         [ 1.4551e-01, -2.8198e-02,  4.4727e-01,  ..., -1.1719e+00,
           6.7188e-01, -6.7188e-01],
         ...,
         [ 1.4297e+00,  1.7734e+00,  2.4062e+00,  ...,  2.2188e+00,
          -4.1562e+00,  3.7344e+00],
         [ 1.1250e+00,  8.6719e-01,  1.9824e-01,  ...,  1.0498e-01,
          -5.9766e-01,  9.2578e-01],
         [-1.2793e-01,  2.3438e-01,  4.6680e-01,  ..., -9.1406e-01,
           3.9258e-01, -3.1055e-01]],

        [[-5.2795e-03,  2.1648e-04, -1.0498e-02,  ..., -1.2054e-03,
           1.2665e-03,  1.0254e-02],
         [ 1.1016e+00, -6.9922e-01,  2.5391e-01,  ..., -3.6914e-01,
           2.3438e-01, -1.0791e-01],
         [ 1.4551e-01, -2.8198e-02,  4.4727e-01,  ..., -1.1719e+00,
           6.7188e-01, -6.7188e-01],
         ...,
         [ 1.4297e+00,  1.7734e+00,  2.4062e+00,  ...,  2.2188e+00,
          -4.1562e+00,  3.7344e+00],
         [ 1.1250e+00,  8.6719e-01,  1.9824e-01,  ...,  1.0498e-01,
          -5.9766e-01,  9.2578e-01],
         [-1.2793e-01,  2.3438e-01,  4.6680e-01,  ..., -9.1406e-01,
           3.9258e-01, -3.1055e-01]],

        [[-5.2795e-03,  2.1648e-04, -1.0498e-02,  ..., -1.2054e-03,
           1.2665e-03,  1.0254e-02],
         [ 1.1016e+00, -6.9922e-01,  2.5391e-01,  ..., -3.6914e-01,
           2.3438e-01, -1.0791e-01],
         [ 1.4551e-01, -2.8198e-02,  4.4727e-01,  ..., -1.1719e+00,
           6.7188e-01, -6.7188e-01],
         ...,
         [ 1.4297e+00,  1.7734e+00,  2.4062e+00,  ...,  2.2188e+00,
          -4.1562e+00,  3.7344e+00],
         [ 1.1250e+00,  8.6719e-01,  1.9824e-01,  ...,  1.0498e-01,
          -5.9766e-01,  9.2578e-01],
         [-1.2793e-01,  2.3438e-01,  4.6680e-01,  ..., -9.1406e-01,
           3.9258e-01, -3.1055e-01]],

        ...,

        [[-1.2054e-03,  1.3611e-02, -3.5553e-03,  ..., -1.2817e-03,
          -1.3123e-03,  2.8839e-03],
         [-1.1377e-01, -1.1875e+00,  1.1719e+00,  ..., -1.6250e+00,
           4.6680e-01, -6.9531e-01],
         [ 1.9453e+00,  1.0840e-01,  1.4038e-02,  ...,  8.9844e-02,
          -5.5078e-01,  4.6484e-01],
         ...,
         [-3.6719e-01, -7.7344e-01,  5.8984e-01,  ..., -6.6895e-02,
           4.2773e-01, -7.6562e-01],
         [ 7.6172e-02, -5.0781e-01,  1.3965e-01,  ...,  7.7344e-01,
           6.1719e-01, -5.9766e-01],
         [-3.9062e-01,  1.7334e-02, -5.6250e-01,  ...,  1.3438e+00,
           4.4336e-01,  1.1768e-01]],

        [[-1.2054e-03,  1.3611e-02, -3.5553e-03,  ..., -1.2817e-03,
          -1.3123e-03,  2.8839e-03],
         [-1.1377e-01, -1.1875e+00,  1.1719e+00,  ..., -1.6250e+00,
           4.6680e-01, -6.9531e-01],
         [ 1.9453e+00,  1.0840e-01,  1.4038e-02,  ...,  8.9844e-02,
          -5.5078e-01,  4.6484e-01],
         ...,
         [-3.6719e-01, -7.7344e-01,  5.8984e-01,  ..., -6.6895e-02,
           4.2773e-01, -7.6562e-01],
         [ 7.6172e-02, -5.0781e-01,  1.3965e-01,  ...,  7.7344e-01,
           6.1719e-01, -5.9766e-01],
         [-3.9062e-01,  1.7334e-02, -5.6250e-01,  ...,  1.3438e+00,
           4.4336e-01,  1.1768e-01]],

        [[-1.2054e-03,  1.3611e-02, -3.5553e-03,  ..., -1.2817e-03,
          -1.3123e-03,  2.8839e-03],
         [-1.1377e-01, -1.1875e+00,  1.1719e+00,  ..., -1.6250e+00,
           4.6680e-01, -6.9531e-01],
         [ 1.9453e+00,  1.0840e-01,  1.4038e-02,  ...,  8.9844e-02,
          -5.5078e-01,  4.6484e-01],
         ...,
         [-3.6719e-01, -7.7344e-01,  5.8984e-01,  ..., -6.6895e-02,
           4.2773e-01, -7.6562e-01],
         [ 7.6172e-02, -5.0781e-01,  1.3965e-01,  ...,  7.7344e-01,
           6.1719e-01, -5.9766e-01],
         [-3.9062e-01,  1.7334e-02, -5.6250e-01,  ...,  1.3438e+00,
           4.4336e-01,  1.1768e-01]]], dtype=torch.bfloat16)), (tensor([[[ 1.3977e-02, -1.9302e-03,  5.9509e-03,  ..., -5.0537e-02,
          -9.7656e-02, -5.2734e-01],
         [-1.3203e+00, -7.7344e-01, -8.3984e-01,  ..., -8.1641e-01,
          -6.2109e-01,  8.3750e+00],
         [ 3.7500e-01,  5.3125e-01, -1.0938e+00,  ..., -1.4062e-01,
           9.2285e-02,  8.3750e+00],
         ...,
         [ 1.4766e+00, -5.0391e-01,  1.0156e+00,  ...,  2.3594e+00,
           8.2812e-01,  7.3438e+00],
         [ 7.5000e-01, -1.3984e+00, -9.0625e-01,  ...,  1.5781e+00,
           1.3047e+00,  7.1250e+00],
         [-7.2656e-01, -2.0469e+00, -1.5156e+00,  ...,  1.2969e+00,
           1.9688e+00,  7.9062e+00]],

        [[ 1.3977e-02, -1.9302e-03,  5.9509e-03,  ..., -5.0537e-02,
          -9.7656e-02, -5.2734e-01],
         [-1.3203e+00, -7.7344e-01, -8.3984e-01,  ..., -8.1641e-01,
          -6.2109e-01,  8.3750e+00],
         [ 3.7500e-01,  5.3125e-01, -1.0938e+00,  ..., -1.4062e-01,
           9.2285e-02,  8.3750e+00],
         ...,
         [ 1.4766e+00, -5.0391e-01,  1.0156e+00,  ...,  2.3594e+00,
           8.2812e-01,  7.3438e+00],
         [ 7.5000e-01, -1.3984e+00, -9.0625e-01,  ...,  1.5781e+00,
           1.3047e+00,  7.1250e+00],
         [-7.2656e-01, -2.0469e+00, -1.5156e+00,  ...,  1.2969e+00,
           1.9688e+00,  7.9062e+00]],

        [[ 1.3977e-02, -1.9302e-03,  5.9509e-03,  ..., -5.0537e-02,
          -9.7656e-02, -5.2734e-01],
         [-1.3203e+00, -7.7344e-01, -8.3984e-01,  ..., -8.1641e-01,
          -6.2109e-01,  8.3750e+00],
         [ 3.7500e-01,  5.3125e-01, -1.0938e+00,  ..., -1.4062e-01,
           9.2285e-02,  8.3750e+00],
         ...,
         [ 1.4766e+00, -5.0391e-01,  1.0156e+00,  ...,  2.3594e+00,
           8.2812e-01,  7.3438e+00],
         [ 7.5000e-01, -1.3984e+00, -9.0625e-01,  ...,  1.5781e+00,
           1.3047e+00,  7.1250e+00],
         [-7.2656e-01, -2.0469e+00, -1.5156e+00,  ...,  1.2969e+00,
           1.9688e+00,  7.9062e+00]],

        ...,

        [[ 3.8719e-04, -5.5237e-03,  4.5013e-04,  ...,  4.7852e-02,
           4.2114e-03,  6.7383e-02],
         [-2.9062e+00,  1.9297e+00,  1.4219e+00,  ...,  3.7305e-01,
          -9.2578e-01,  1.9766e+00],
         [-1.4062e+00,  7.8906e-01,  4.8047e-01,  ..., -1.0840e-01,
           1.9922e+00,  8.5938e-01],
         ...,
         [ 6.9531e-01,  6.0938e-01,  1.4297e+00,  ...,  8.3203e-01,
           7.2656e-01, -1.8848e-01],
         [ 2.4414e-01,  8.5938e-02,  7.0801e-02,  ...,  3.3203e-02,
           1.9219e+00, -1.5938e+00],
         [ 1.6953e+00,  1.1328e+00,  2.1875e+00,  ..., -2.9883e-01,
          -7.6953e-01, -5.1953e-01]],

        [[ 3.8719e-04, -5.5237e-03,  4.5013e-04,  ...,  4.7852e-02,
           4.2114e-03,  6.7383e-02],
         [-2.9062e+00,  1.9297e+00,  1.4219e+00,  ...,  3.7305e-01,
          -9.2578e-01,  1.9766e+00],
         [-1.4062e+00,  7.8906e-01,  4.8047e-01,  ..., -1.0840e-01,
           1.9922e+00,  8.5938e-01],
         ...,
         [ 6.9531e-01,  6.0938e-01,  1.4297e+00,  ...,  8.3203e-01,
           7.2656e-01, -1.8848e-01],
         [ 2.4414e-01,  8.5938e-02,  7.0801e-02,  ...,  3.3203e-02,
           1.9219e+00, -1.5938e+00],
         [ 1.6953e+00,  1.1328e+00,  2.1875e+00,  ..., -2.9883e-01,
          -7.6953e-01, -5.1953e-01]],

        [[ 3.8719e-04, -5.5237e-03,  4.5013e-04,  ...,  4.7852e-02,
           4.2114e-03,  6.7383e-02],
         [-2.9062e+00,  1.9297e+00,  1.4219e+00,  ...,  3.7305e-01,
          -9.2578e-01,  1.9766e+00],
         [-1.4062e+00,  7.8906e-01,  4.8047e-01,  ..., -1.0840e-01,
           1.9922e+00,  8.5938e-01],
         ...,
         [ 6.9531e-01,  6.0938e-01,  1.4297e+00,  ...,  8.3203e-01,
           7.2656e-01, -1.8848e-01],
         [ 2.4414e-01,  8.5938e-02,  7.0801e-02,  ...,  3.3203e-02,
           1.9219e+00, -1.5938e+00],
         [ 1.6953e+00,  1.1328e+00,  2.1875e+00,  ..., -2.9883e-01,
          -7.6953e-01, -5.1953e-01]]], dtype=torch.bfloat16), tensor([[[ 6.1035e-03,  4.6997e-03,  3.7384e-04,  ..., -1.2817e-03,
           8.5449e-03,  3.8757e-03],
         [-1.0234e+00, -6.3672e-01, -7.3047e-01,  ...,  4.5312e-01,
          -1.2598e-01,  1.3750e+00],
         [ 2.8516e-01, -2.5312e+00, -1.6309e-01,  ...,  1.1953e+00,
           8.6328e-01,  5.8594e-01],
         ...,
         [ 2.3125e+00,  4.6094e-01, -1.4688e+00,  ..., -6.4062e-01,
           2.8906e-01, -4.1602e-01],
         [ 8.6328e-01, -4.2773e-01, -1.3516e+00,  ..., -4.7852e-01,
           9.4922e-01, -4.8242e-01],
         [ 9.8438e-01, -1.0156e+00, -1.1406e+00,  ..., -1.1172e+00,
           1.4941e-01, -3.0859e-01]],

        [[ 6.1035e-03,  4.6997e-03,  3.7384e-04,  ..., -1.2817e-03,
           8.5449e-03,  3.8757e-03],
         [-1.0234e+00, -6.3672e-01, -7.3047e-01,  ...,  4.5312e-01,
          -1.2598e-01,  1.3750e+00],
         [ 2.8516e-01, -2.5312e+00, -1.6309e-01,  ...,  1.1953e+00,
           8.6328e-01,  5.8594e-01],
         ...,
         [ 2.3125e+00,  4.6094e-01, -1.4688e+00,  ..., -6.4062e-01,
           2.8906e-01, -4.1602e-01],
         [ 8.6328e-01, -4.2773e-01, -1.3516e+00,  ..., -4.7852e-01,
           9.4922e-01, -4.8242e-01],
         [ 9.8438e-01, -1.0156e+00, -1.1406e+00,  ..., -1.1172e+00,
           1.4941e-01, -3.0859e-01]],

        [[ 6.1035e-03,  4.6997e-03,  3.7384e-04,  ..., -1.2817e-03,
           8.5449e-03,  3.8757e-03],
         [-1.0234e+00, -6.3672e-01, -7.3047e-01,  ...,  4.5312e-01,
          -1.2598e-01,  1.3750e+00],
         [ 2.8516e-01, -2.5312e+00, -1.6309e-01,  ...,  1.1953e+00,
           8.6328e-01,  5.8594e-01],
         ...,
         [ 2.3125e+00,  4.6094e-01, -1.4688e+00,  ..., -6.4062e-01,
           2.8906e-01, -4.1602e-01],
         [ 8.6328e-01, -4.2773e-01, -1.3516e+00,  ..., -4.7852e-01,
           9.4922e-01, -4.8242e-01],
         [ 9.8438e-01, -1.0156e+00, -1.1406e+00,  ..., -1.1172e+00,
           1.4941e-01, -3.0859e-01]],

        ...,

        [[-3.0136e-04, -1.9379e-03, -3.3569e-03,  ..., -4.6692e-03,
          -7.5989e-03, -9.1934e-04],
         [-2.0117e-01, -1.7969e+00,  1.0791e-01,  ...,  2.6406e+00,
          -6.2500e-01, -4.7656e-01],
         [-1.2734e+00,  4.7070e-01, -9.9609e-01,  ..., -2.3438e+00,
           1.2344e+00,  9.4141e-01],
         ...,
         [-1.7773e-01, -8.2031e-01, -7.7148e-02,  ..., -4.6875e-01,
           1.4355e-01, -1.6309e-01],
         [-4.1797e-01,  1.2188e+00, -3.3984e-01,  ..., -1.3672e+00,
           1.1484e+00,  4.8633e-01],
         [-1.6406e-01,  7.0801e-02, -1.2451e-01,  ..., -2.0801e-01,
           8.0859e-01,  1.4038e-02]],

        [[-3.0136e-04, -1.9379e-03, -3.3569e-03,  ..., -4.6692e-03,
          -7.5989e-03, -9.1934e-04],
         [-2.0117e-01, -1.7969e+00,  1.0791e-01,  ...,  2.6406e+00,
          -6.2500e-01, -4.7656e-01],
         [-1.2734e+00,  4.7070e-01, -9.9609e-01,  ..., -2.3438e+00,
           1.2344e+00,  9.4141e-01],
         ...,
         [-1.7773e-01, -8.2031e-01, -7.7148e-02,  ..., -4.6875e-01,
           1.4355e-01, -1.6309e-01],
         [-4.1797e-01,  1.2188e+00, -3.3984e-01,  ..., -1.3672e+00,
           1.1484e+00,  4.8633e-01],
         [-1.6406e-01,  7.0801e-02, -1.2451e-01,  ..., -2.0801e-01,
           8.0859e-01,  1.4038e-02]],

        [[-3.0136e-04, -1.9379e-03, -3.3569e-03,  ..., -4.6692e-03,
          -7.5989e-03, -9.1934e-04],
         [-2.0117e-01, -1.7969e+00,  1.0791e-01,  ...,  2.6406e+00,
          -6.2500e-01, -4.7656e-01],
         [-1.2734e+00,  4.7070e-01, -9.9609e-01,  ..., -2.3438e+00,
           1.2344e+00,  9.4141e-01],
         ...,
         [-1.7773e-01, -8.2031e-01, -7.7148e-02,  ..., -4.6875e-01,
           1.4355e-01, -1.6309e-01],
         [-4.1797e-01,  1.2188e+00, -3.3984e-01,  ..., -1.3672e+00,
           1.1484e+00,  4.8633e-01],
         [-1.6406e-01,  7.0801e-02, -1.2451e-01,  ..., -2.0801e-01,
           8.0859e-01,  1.4038e-02]]], dtype=torch.bfloat16)), (tensor([[[-1.1215e-03,  1.2283e-03,  4.0588e-03,  ...,  6.5430e-02,
           7.5684e-03, -4.1748e-02],
         [-5.1953e-01, -8.5938e-01, -1.2812e+00,  ..., -8.3203e-01,
           1.4141e+00,  1.0859e+00],
         [-1.6719e+00, -1.0156e+00, -2.2500e+00,  ..., -1.5859e+00,
          -1.4297e+00,  2.0781e+00],
         ...,
         [-3.5938e-01,  2.8516e-01,  2.2070e-01,  ..., -2.7344e+00,
          -8.9844e-02,  2.7812e+00],
         [-2.6758e-01,  4.6680e-01,  2.0508e-02,  ..., -4.0625e+00,
           6.5234e-01,  2.2812e+00],
         [ 1.6953e+00, -2.8906e-01, -9.8828e-01,  ..., -2.5938e+00,
          -2.1484e-01,  1.8281e+00]],

        [[-1.1215e-03,  1.2283e-03,  4.0588e-03,  ...,  6.5430e-02,
           7.5684e-03, -4.1748e-02],
         [-5.1953e-01, -8.5938e-01, -1.2812e+00,  ..., -8.3203e-01,
           1.4141e+00,  1.0859e+00],
         [-1.6719e+00, -1.0156e+00, -2.2500e+00,  ..., -1.5859e+00,
          -1.4297e+00,  2.0781e+00],
         ...,
         [-3.5938e-01,  2.8516e-01,  2.2070e-01,  ..., -2.7344e+00,
          -8.9844e-02,  2.7812e+00],
         [-2.6758e-01,  4.6680e-01,  2.0508e-02,  ..., -4.0625e+00,
           6.5234e-01,  2.2812e+00],
         [ 1.6953e+00, -2.8906e-01, -9.8828e-01,  ..., -2.5938e+00,
          -2.1484e-01,  1.8281e+00]],

        [[-1.1215e-03,  1.2283e-03,  4.0588e-03,  ...,  6.5430e-02,
           7.5684e-03, -4.1748e-02],
         [-5.1953e-01, -8.5938e-01, -1.2812e+00,  ..., -8.3203e-01,
           1.4141e+00,  1.0859e+00],
         [-1.6719e+00, -1.0156e+00, -2.2500e+00,  ..., -1.5859e+00,
          -1.4297e+00,  2.0781e+00],
         ...,
         [-3.5938e-01,  2.8516e-01,  2.2070e-01,  ..., -2.7344e+00,
          -8.9844e-02,  2.7812e+00],
         [-2.6758e-01,  4.6680e-01,  2.0508e-02,  ..., -4.0625e+00,
           6.5234e-01,  2.2812e+00],
         [ 1.6953e+00, -2.8906e-01, -9.8828e-01,  ..., -2.5938e+00,
          -2.1484e-01,  1.8281e+00]],

        ...,

        [[-2.6245e-03,  7.7820e-03, -2.7924e-03,  ...,  9.4604e-03,
           4.3640e-03, -3.2227e-01],
         [-5.1172e-01, -1.6953e+00, -1.0156e+00,  ..., -1.5156e+00,
           2.1250e+00,  1.0250e+01],
         [-1.2812e+00, -1.2109e+00, -8.0078e-01,  ..., -3.9648e-01,
           2.2559e-01,  1.0125e+01],
         ...,
         [ 3.4375e-01,  9.1406e-01,  4.7656e-01,  ..., -2.1484e-02,
           2.7031e+00,  7.5625e+00],
         [ 6.3672e-01,  4.1406e-01, -7.1875e-01,  ..., -9.0625e-01,
           1.1250e+00,  8.0000e+00],
         [ 1.3516e+00, -1.4844e-01, -8.2812e-01,  ..., -1.6562e+00,
           4.6484e-01,  8.8125e+00]],

        [[-2.6245e-03,  7.7820e-03, -2.7924e-03,  ...,  9.4604e-03,
           4.3640e-03, -3.2227e-01],
         [-5.1172e-01, -1.6953e+00, -1.0156e+00,  ..., -1.5156e+00,
           2.1250e+00,  1.0250e+01],
         [-1.2812e+00, -1.2109e+00, -8.0078e-01,  ..., -3.9648e-01,
           2.2559e-01,  1.0125e+01],
         ...,
         [ 3.4375e-01,  9.1406e-01,  4.7656e-01,  ..., -2.1484e-02,
           2.7031e+00,  7.5625e+00],
         [ 6.3672e-01,  4.1406e-01, -7.1875e-01,  ..., -9.0625e-01,
           1.1250e+00,  8.0000e+00],
         [ 1.3516e+00, -1.4844e-01, -8.2812e-01,  ..., -1.6562e+00,
           4.6484e-01,  8.8125e+00]],

        [[-2.6245e-03,  7.7820e-03, -2.7924e-03,  ...,  9.4604e-03,
           4.3640e-03, -3.2227e-01],
         [-5.1172e-01, -1.6953e+00, -1.0156e+00,  ..., -1.5156e+00,
           2.1250e+00,  1.0250e+01],
         [-1.2812e+00, -1.2109e+00, -8.0078e-01,  ..., -3.9648e-01,
           2.2559e-01,  1.0125e+01],
         ...,
         [ 3.4375e-01,  9.1406e-01,  4.7656e-01,  ..., -2.1484e-02,
           2.7031e+00,  7.5625e+00],
         [ 6.3672e-01,  4.1406e-01, -7.1875e-01,  ..., -9.0625e-01,
           1.1250e+00,  8.0000e+00],
         [ 1.3516e+00, -1.4844e-01, -8.2812e-01,  ..., -1.6562e+00,
           4.6484e-01,  8.8125e+00]]], dtype=torch.bfloat16), tensor([[[ 5.8289e-03, -2.0905e-03, -2.3041e-03,  ...,  6.0654e-04,
           6.7139e-03, -1.5717e-03],
         [ 5.4688e-01,  5.8899e-03,  2.3145e-01,  ...,  1.2695e-01,
           4.5508e-01,  6.8750e-01],
         [ 9.8828e-01, -4.2969e-01,  9.2969e-01,  ..., -1.8828e+00,
           1.0547e+00, -1.2578e+00],
         ...,
         [ 9.4922e-01, -6.6797e-01,  1.3477e-01,  ..., -6.0156e-01,
           6.4941e-02,  2.5391e-01],
         [ 2.7539e-01,  6.7969e-01,  1.0791e-01,  ..., -8.9062e-01,
          -8.8281e-01,  1.0547e+00],
         [ 2.5000e-01,  4.5312e-01,  3.7109e-01,  ...,  6.5918e-02,
           7.6953e-01, -2.4023e-01]],

        [[ 5.8289e-03, -2.0905e-03, -2.3041e-03,  ...,  6.0654e-04,
           6.7139e-03, -1.5717e-03],
         [ 5.4688e-01,  5.8899e-03,  2.3145e-01,  ...,  1.2695e-01,
           4.5508e-01,  6.8750e-01],
         [ 9.8828e-01, -4.2969e-01,  9.2969e-01,  ..., -1.8828e+00,
           1.0547e+00, -1.2578e+00],
         ...,
         [ 9.4922e-01, -6.6797e-01,  1.3477e-01,  ..., -6.0156e-01,
           6.4941e-02,  2.5391e-01],
         [ 2.7539e-01,  6.7969e-01,  1.0791e-01,  ..., -8.9062e-01,
          -8.8281e-01,  1.0547e+00],
         [ 2.5000e-01,  4.5312e-01,  3.7109e-01,  ...,  6.5918e-02,
           7.6953e-01, -2.4023e-01]],

        [[ 5.8289e-03, -2.0905e-03, -2.3041e-03,  ...,  6.0654e-04,
           6.7139e-03, -1.5717e-03],
         [ 5.4688e-01,  5.8899e-03,  2.3145e-01,  ...,  1.2695e-01,
           4.5508e-01,  6.8750e-01],
         [ 9.8828e-01, -4.2969e-01,  9.2969e-01,  ..., -1.8828e+00,
           1.0547e+00, -1.2578e+00],
         ...,
         [ 9.4922e-01, -6.6797e-01,  1.3477e-01,  ..., -6.0156e-01,
           6.4941e-02,  2.5391e-01],
         [ 2.7539e-01,  6.7969e-01,  1.0791e-01,  ..., -8.9062e-01,
          -8.8281e-01,  1.0547e+00],
         [ 2.5000e-01,  4.5312e-01,  3.7109e-01,  ...,  6.5918e-02,
           7.6953e-01, -2.4023e-01]],

        ...,

        [[-3.1128e-03, -7.8583e-04,  5.3406e-03,  ..., -7.8125e-03,
           1.1963e-02, -1.4420e-03],
         [ 8.1250e-01, -4.7461e-01, -8.2812e-01,  ...,  4.3555e-01,
          -1.8125e+00, -8.7891e-01],
         [-1.0391e+00, -8.7500e-01, -4.8242e-01,  ..., -1.1963e-01,
          -1.3984e+00, -1.8359e-01],
         ...,
         [ 2.8320e-01,  3.9844e-01,  4.3555e-01,  ...,  9.3750e-01,
          -6.4453e-01, -2.7539e-01],
         [ 8.2812e-01, -1.3770e-01,  5.3125e-01,  ...,  5.3516e-01,
          -3.3789e-01,  4.6484e-01],
         [ 9.4141e-01, -6.2109e-01,  8.3984e-01,  ..., -5.2344e-01,
          -1.8262e-01, -1.5918e-01]],

        [[-3.1128e-03, -7.8583e-04,  5.3406e-03,  ..., -7.8125e-03,
           1.1963e-02, -1.4420e-03],
         [ 8.1250e-01, -4.7461e-01, -8.2812e-01,  ...,  4.3555e-01,
          -1.8125e+00, -8.7891e-01],
         [-1.0391e+00, -8.7500e-01, -4.8242e-01,  ..., -1.1963e-01,
          -1.3984e+00, -1.8359e-01],
         ...,
         [ 2.8320e-01,  3.9844e-01,  4.3555e-01,  ...,  9.3750e-01,
          -6.4453e-01, -2.7539e-01],
         [ 8.2812e-01, -1.3770e-01,  5.3125e-01,  ...,  5.3516e-01,
          -3.3789e-01,  4.6484e-01],
         [ 9.4141e-01, -6.2109e-01,  8.3984e-01,  ..., -5.2344e-01,
          -1.8262e-01, -1.5918e-01]],

        [[-3.1128e-03, -7.8583e-04,  5.3406e-03,  ..., -7.8125e-03,
           1.1963e-02, -1.4420e-03],
         [ 8.1250e-01, -4.7461e-01, -8.2812e-01,  ...,  4.3555e-01,
          -1.8125e+00, -8.7891e-01],
         [-1.0391e+00, -8.7500e-01, -4.8242e-01,  ..., -1.1963e-01,
          -1.3984e+00, -1.8359e-01],
         ...,
         [ 2.8320e-01,  3.9844e-01,  4.3555e-01,  ...,  9.3750e-01,
          -6.4453e-01, -2.7539e-01],
         [ 8.2812e-01, -1.3770e-01,  5.3125e-01,  ...,  5.3516e-01,
          -3.3789e-01,  4.6484e-01],
         [ 9.4141e-01, -6.2109e-01,  8.3984e-01,  ..., -5.2344e-01,
          -1.8262e-01, -1.5918e-01]]], dtype=torch.bfloat16)), (tensor([[[-2.0508e-02,  6.1340e-03, -8.1787e-03,  ..., -1.7285e-01,
          -1.0010e-01,  9.5703e-02],
         [ 2.2656e+00, -1.7500e+00,  1.6562e+00,  ...,  1.0625e+00,
           8.9453e-01, -1.7969e+00],
         [ 1.6172e+00, -1.5234e+00,  1.7031e+00,  ...,  7.7344e-01,
           8.0078e-02,  5.8984e-01],
         ...,
         [-6.4844e-01, -1.1328e-01, -3.1250e-01,  ...,  1.3750e+00,
           1.0703e+00, -1.3984e+00],
         [-9.0625e-01,  7.6562e-01,  9.9609e-01,  ...,  1.2422e+00,
           1.3906e+00, -6.6406e-01],
         [-1.0938e+00,  1.5156e+00,  6.9141e-01,  ...,  1.1562e+00,
           3.0781e+00, -8.7109e-01]],

        [[-2.0508e-02,  6.1340e-03, -8.1787e-03,  ..., -1.7285e-01,
          -1.0010e-01,  9.5703e-02],
         [ 2.2656e+00, -1.7500e+00,  1.6562e+00,  ...,  1.0625e+00,
           8.9453e-01, -1.7969e+00],
         [ 1.6172e+00, -1.5234e+00,  1.7031e+00,  ...,  7.7344e-01,
           8.0078e-02,  5.8984e-01],
         ...,
         [-6.4844e-01, -1.1328e-01, -3.1250e-01,  ...,  1.3750e+00,
           1.0703e+00, -1.3984e+00],
         [-9.0625e-01,  7.6562e-01,  9.9609e-01,  ...,  1.2422e+00,
           1.3906e+00, -6.6406e-01],
         [-1.0938e+00,  1.5156e+00,  6.9141e-01,  ...,  1.1562e+00,
           3.0781e+00, -8.7109e-01]],

        [[-2.0508e-02,  6.1340e-03, -8.1787e-03,  ..., -1.7285e-01,
          -1.0010e-01,  9.5703e-02],
         [ 2.2656e+00, -1.7500e+00,  1.6562e+00,  ...,  1.0625e+00,
           8.9453e-01, -1.7969e+00],
         [ 1.6172e+00, -1.5234e+00,  1.7031e+00,  ...,  7.7344e-01,
           8.0078e-02,  5.8984e-01],
         ...,
         [-6.4844e-01, -1.1328e-01, -3.1250e-01,  ...,  1.3750e+00,
           1.0703e+00, -1.3984e+00],
         [-9.0625e-01,  7.6562e-01,  9.9609e-01,  ...,  1.2422e+00,
           1.3906e+00, -6.6406e-01],
         [-1.0938e+00,  1.5156e+00,  6.9141e-01,  ...,  1.1562e+00,
           3.0781e+00, -8.7109e-01]],

        ...,

        [[-7.2021e-03, -4.3335e-03,  4.3335e-03,  ..., -1.3245e-02,
          -7.4609e-01,  1.4941e-01],
         [-1.0156e+00,  2.8125e+00,  1.6016e-01,  ...,  1.0391e+00,
           7.6562e+00, -8.5156e-01],
         [-2.3750e+00,  1.4844e+00,  1.1094e+00,  ...,  5.0781e-01,
           8.9375e+00, -2.2812e+00],
         ...,
         [-5.1953e-01, -7.2266e-01, -1.1797e+00,  ...,  5.1953e-01,
           7.0938e+00, -3.4570e-01],
         [-7.8125e-03, -1.4922e+00, -1.1328e+00,  ..., -5.7812e-01,
           7.4688e+00, -7.3047e-01],
         [ 1.5469e+00, -1.4062e-01, -4.6875e-01,  ...,  7.4219e-01,
           7.5938e+00, -6.0156e-01]],

        [[-7.2021e-03, -4.3335e-03,  4.3335e-03,  ..., -1.3245e-02,
          -7.4609e-01,  1.4941e-01],
         [-1.0156e+00,  2.8125e+00,  1.6016e-01,  ...,  1.0391e+00,
           7.6562e+00, -8.5156e-01],
         [-2.3750e+00,  1.4844e+00,  1.1094e+00,  ...,  5.0781e-01,
           8.9375e+00, -2.2812e+00],
         ...,
         [-5.1953e-01, -7.2266e-01, -1.1797e+00,  ...,  5.1953e-01,
           7.0938e+00, -3.4570e-01],
         [-7.8125e-03, -1.4922e+00, -1.1328e+00,  ..., -5.7812e-01,
           7.4688e+00, -7.3047e-01],
         [ 1.5469e+00, -1.4062e-01, -4.6875e-01,  ...,  7.4219e-01,
           7.5938e+00, -6.0156e-01]],

        [[-7.2021e-03, -4.3335e-03,  4.3335e-03,  ..., -1.3245e-02,
          -7.4609e-01,  1.4941e-01],
         [-1.0156e+00,  2.8125e+00,  1.6016e-01,  ...,  1.0391e+00,
           7.6562e+00, -8.5156e-01],
         [-2.3750e+00,  1.4844e+00,  1.1094e+00,  ...,  5.0781e-01,
           8.9375e+00, -2.2812e+00],
         ...,
         [-5.1953e-01, -7.2266e-01, -1.1797e+00,  ...,  5.1953e-01,
           7.0938e+00, -3.4570e-01],
         [-7.8125e-03, -1.4922e+00, -1.1328e+00,  ..., -5.7812e-01,
           7.4688e+00, -7.3047e-01],
         [ 1.5469e+00, -1.4062e-01, -4.6875e-01,  ...,  7.4219e-01,
           7.5938e+00, -6.0156e-01]]], dtype=torch.bfloat16), tensor([[[ 2.0752e-03, -2.1210e-03, -6.5002e-03,  ...,  5.2185e-03,
          -5.4626e-03, -2.9144e-03],
         [ 8.2397e-03, -2.9102e-01,  2.8516e-01,  ..., -9.3750e-01,
          -5.7812e-01,  6.2109e-01],
         [ 3.7500e-01, -6.1328e-01,  1.0547e-01,  ..., -1.8438e+00,
           4.2969e-01, -1.8281e+00],
         ...,
         [ 2.6367e-01, -1.3047e+00, -4.6680e-01,  ...,  6.3965e-02,
          -9.1406e-01,  1.2891e-01],
         [ 1.0156e-01, -1.0000e+00, -1.1094e+00,  ...,  8.4766e-01,
          -1.0938e+00, -1.3984e+00],
         [ 9.5312e-01, -6.4062e-01, -1.6406e-01,  ...,  7.2266e-01,
          -1.3359e+00, -6.6016e-01]],

        [[ 2.0752e-03, -2.1210e-03, -6.5002e-03,  ...,  5.2185e-03,
          -5.4626e-03, -2.9144e-03],
         [ 8.2397e-03, -2.9102e-01,  2.8516e-01,  ..., -9.3750e-01,
          -5.7812e-01,  6.2109e-01],
         [ 3.7500e-01, -6.1328e-01,  1.0547e-01,  ..., -1.8438e+00,
           4.2969e-01, -1.8281e+00],
         ...,
         [ 2.6367e-01, -1.3047e+00, -4.6680e-01,  ...,  6.3965e-02,
          -9.1406e-01,  1.2891e-01],
         [ 1.0156e-01, -1.0000e+00, -1.1094e+00,  ...,  8.4766e-01,
          -1.0938e+00, -1.3984e+00],
         [ 9.5312e-01, -6.4062e-01, -1.6406e-01,  ...,  7.2266e-01,
          -1.3359e+00, -6.6016e-01]],

        [[ 2.0752e-03, -2.1210e-03, -6.5002e-03,  ...,  5.2185e-03,
          -5.4626e-03, -2.9144e-03],
         [ 8.2397e-03, -2.9102e-01,  2.8516e-01,  ..., -9.3750e-01,
          -5.7812e-01,  6.2109e-01],
         [ 3.7500e-01, -6.1328e-01,  1.0547e-01,  ..., -1.8438e+00,
           4.2969e-01, -1.8281e+00],
         ...,
         [ 2.6367e-01, -1.3047e+00, -4.6680e-01,  ...,  6.3965e-02,
          -9.1406e-01,  1.2891e-01],
         [ 1.0156e-01, -1.0000e+00, -1.1094e+00,  ...,  8.4766e-01,
          -1.0938e+00, -1.3984e+00],
         [ 9.5312e-01, -6.4062e-01, -1.6406e-01,  ...,  7.2266e-01,
          -1.3359e+00, -6.6016e-01]],

        ...,

        [[ 4.7302e-03,  9.9182e-04, -4.4556e-03,  ...,  1.8921e-03,
          -6.1340e-03,  2.1820e-03],
         [ 5.3906e-01,  8.6426e-02, -3.4424e-02,  ...,  9.8438e-01,
           8.1250e-01,  7.6172e-01],
         [-3.0859e-01,  3.6719e-01,  1.8906e+00,  ...,  1.7871e-01,
           6.7871e-02, -2.1562e+00],
         ...,
         [-5.6250e-01, -1.3867e-01, -5.7373e-02,  ...,  1.0449e-01,
           1.4355e-01,  1.0547e-01],
         [ 1.5723e-01,  6.2109e-01, -7.5391e-01,  ..., -1.1250e+00,
          -1.0010e-01, -8.0859e-01],
         [-4.2578e-01, -2.8711e-01,  8.2422e-01,  ..., -1.1523e-01,
          -3.2617e-01,  4.1797e-01]],

        [[ 4.7302e-03,  9.9182e-04, -4.4556e-03,  ...,  1.8921e-03,
          -6.1340e-03,  2.1820e-03],
         [ 5.3906e-01,  8.6426e-02, -3.4424e-02,  ...,  9.8438e-01,
           8.1250e-01,  7.6172e-01],
         [-3.0859e-01,  3.6719e-01,  1.8906e+00,  ...,  1.7871e-01,
           6.7871e-02, -2.1562e+00],
         ...,
         [-5.6250e-01, -1.3867e-01, -5.7373e-02,  ...,  1.0449e-01,
           1.4355e-01,  1.0547e-01],
         [ 1.5723e-01,  6.2109e-01, -7.5391e-01,  ..., -1.1250e+00,
          -1.0010e-01, -8.0859e-01],
         [-4.2578e-01, -2.8711e-01,  8.2422e-01,  ..., -1.1523e-01,
          -3.2617e-01,  4.1797e-01]],

        [[ 4.7302e-03,  9.9182e-04, -4.4556e-03,  ...,  1.8921e-03,
          -6.1340e-03,  2.1820e-03],
         [ 5.3906e-01,  8.6426e-02, -3.4424e-02,  ...,  9.8438e-01,
           8.1250e-01,  7.6172e-01],
         [-3.0859e-01,  3.6719e-01,  1.8906e+00,  ...,  1.7871e-01,
           6.7871e-02, -2.1562e+00],
         ...,
         [-5.6250e-01, -1.3867e-01, -5.7373e-02,  ...,  1.0449e-01,
           1.4355e-01,  1.0547e-01],
         [ 1.5723e-01,  6.2109e-01, -7.5391e-01,  ..., -1.1250e+00,
          -1.0010e-01, -8.0859e-01],
         [-4.2578e-01, -2.8711e-01,  8.2422e-01,  ..., -1.1523e-01,
          -3.2617e-01,  4.1797e-01]]], dtype=torch.bfloat16)), (tensor([[[ 3.7994e-03, -2.5146e-02, -1.2360e-03,  ..., -2.3633e-01,
          -9.6680e-02,  1.0059e-01],
         [ 2.5000e-01, -3.6328e-01,  3.5156e-01,  ..., -8.7402e-02,
          -9.6484e-01, -1.3516e+00],
         [ 1.1797e+00, -1.0859e+00, -7.4609e-01,  ..., -5.4062e+00,
          -4.2500e+00, -3.3750e+00],
         ...,
         [ 1.1250e+00,  4.6289e-01,  4.1602e-01,  ..., -4.9062e+00,
          -5.9766e-01, -1.3359e+00],
         [-2.5977e-01, -3.4180e-01,  6.4062e-01,  ..., -4.8750e+00,
           1.2793e-01,  8.6719e-01],
         [-6.8359e-01,  4.2578e-01, -1.1406e+00,  ..., -3.5938e+00,
          -2.6719e+00, -1.1641e+00]],

        [[ 3.7994e-03, -2.5146e-02, -1.2360e-03,  ..., -2.3633e-01,
          -9.6680e-02,  1.0059e-01],
         [ 2.5000e-01, -3.6328e-01,  3.5156e-01,  ..., -8.7402e-02,
          -9.6484e-01, -1.3516e+00],
         [ 1.1797e+00, -1.0859e+00, -7.4609e-01,  ..., -5.4062e+00,
          -4.2500e+00, -3.3750e+00],
         ...,
         [ 1.1250e+00,  4.6289e-01,  4.1602e-01,  ..., -4.9062e+00,
          -5.9766e-01, -1.3359e+00],
         [-2.5977e-01, -3.4180e-01,  6.4062e-01,  ..., -4.8750e+00,
           1.2793e-01,  8.6719e-01],
         [-6.8359e-01,  4.2578e-01, -1.1406e+00,  ..., -3.5938e+00,
          -2.6719e+00, -1.1641e+00]],

        [[ 3.7994e-03, -2.5146e-02, -1.2360e-03,  ..., -2.3633e-01,
          -9.6680e-02,  1.0059e-01],
         [ 2.5000e-01, -3.6328e-01,  3.5156e-01,  ..., -8.7402e-02,
          -9.6484e-01, -1.3516e+00],
         [ 1.1797e+00, -1.0859e+00, -7.4609e-01,  ..., -5.4062e+00,
          -4.2500e+00, -3.3750e+00],
         ...,
         [ 1.1250e+00,  4.6289e-01,  4.1602e-01,  ..., -4.9062e+00,
          -5.9766e-01, -1.3359e+00],
         [-2.5977e-01, -3.4180e-01,  6.4062e-01,  ..., -4.8750e+00,
           1.2793e-01,  8.6719e-01],
         [-6.8359e-01,  4.2578e-01, -1.1406e+00,  ..., -3.5938e+00,
          -2.6719e+00, -1.1641e+00]],

        ...,

        [[-8.1177e-03, -2.3041e-03, -8.8501e-04,  ...,  4.3213e-02,
           1.1426e-01,  1.2305e-01],
         [ 1.2969e+00,  1.7734e+00,  1.0703e+00,  ...,  2.6406e+00,
           9.8828e-01,  1.5723e-01],
         [ 1.7188e-01,  6.2500e-01,  2.2344e+00,  ...,  1.0625e+00,
          -7.8906e-01,  2.1484e-01],
         ...,
         [ 6.8359e-01, -8.3984e-01,  2.1719e+00,  ...,  5.4297e-01,
           5.5078e-01, -2.9688e+00],
         [ 2.4023e-01, -6.1719e-01,  1.1328e+00,  ...,  2.5977e-01,
           7.6172e-02, -2.4023e-01],
         [-1.1719e+00,  9.5703e-02,  1.2422e+00,  ...,  2.5156e+00,
           1.2578e+00, -1.1797e+00]],

        [[-8.1177e-03, -2.3041e-03, -8.8501e-04,  ...,  4.3213e-02,
           1.1426e-01,  1.2305e-01],
         [ 1.2969e+00,  1.7734e+00,  1.0703e+00,  ...,  2.6406e+00,
           9.8828e-01,  1.5723e-01],
         [ 1.7188e-01,  6.2500e-01,  2.2344e+00,  ...,  1.0625e+00,
          -7.8906e-01,  2.1484e-01],
         ...,
         [ 6.8359e-01, -8.3984e-01,  2.1719e+00,  ...,  5.4297e-01,
           5.5078e-01, -2.9688e+00],
         [ 2.4023e-01, -6.1719e-01,  1.1328e+00,  ...,  2.5977e-01,
           7.6172e-02, -2.4023e-01],
         [-1.1719e+00,  9.5703e-02,  1.2422e+00,  ...,  2.5156e+00,
           1.2578e+00, -1.1797e+00]],

        [[-8.1177e-03, -2.3041e-03, -8.8501e-04,  ...,  4.3213e-02,
           1.1426e-01,  1.2305e-01],
         [ 1.2969e+00,  1.7734e+00,  1.0703e+00,  ...,  2.6406e+00,
           9.8828e-01,  1.5723e-01],
         [ 1.7188e-01,  6.2500e-01,  2.2344e+00,  ...,  1.0625e+00,
          -7.8906e-01,  2.1484e-01],
         ...,
         [ 6.8359e-01, -8.3984e-01,  2.1719e+00,  ...,  5.4297e-01,
           5.5078e-01, -2.9688e+00],
         [ 2.4023e-01, -6.1719e-01,  1.1328e+00,  ...,  2.5977e-01,
           7.6172e-02, -2.4023e-01],
         [-1.1719e+00,  9.5703e-02,  1.2422e+00,  ...,  2.5156e+00,
           1.2578e+00, -1.1797e+00]]], dtype=torch.bfloat16), tensor([[[ 7.2327e-03, -4.8218e-03, -5.1575e-03,  ...,  1.7090e-03,
           1.6479e-03,  5.5847e-03],
         [-1.0781e+00, -9.1406e-01,  2.9688e+00,  ..., -1.7031e+00,
          -5.6641e-01, -1.1016e+00],
         [-1.9219e+00, -4.1562e+00,  4.4688e+00,  ..., -7.5781e-01,
          -1.3359e+00, -9.6094e-01],
         ...,
         [-1.5469e+00, -5.3516e-01, -2.2461e-01,  ...,  1.4062e+00,
           9.1016e-01,  1.1406e+00],
         [-4.6094e-01, -2.2656e+00,  1.8750e+00,  ..., -7.8125e-01,
           5.0781e-01,  1.2891e+00],
         [-7.1875e-01, -2.1875e+00,  2.0938e+00,  ...,  1.0938e-01,
          -2.0996e-01,  9.2188e-01]],

        [[ 7.2327e-03, -4.8218e-03, -5.1575e-03,  ...,  1.7090e-03,
           1.6479e-03,  5.5847e-03],
         [-1.0781e+00, -9.1406e-01,  2.9688e+00,  ..., -1.7031e+00,
          -5.6641e-01, -1.1016e+00],
         [-1.9219e+00, -4.1562e+00,  4.4688e+00,  ..., -7.5781e-01,
          -1.3359e+00, -9.6094e-01],
         ...,
         [-1.5469e+00, -5.3516e-01, -2.2461e-01,  ...,  1.4062e+00,
           9.1016e-01,  1.1406e+00],
         [-4.6094e-01, -2.2656e+00,  1.8750e+00,  ..., -7.8125e-01,
           5.0781e-01,  1.2891e+00],
         [-7.1875e-01, -2.1875e+00,  2.0938e+00,  ...,  1.0938e-01,
          -2.0996e-01,  9.2188e-01]],

        [[ 7.2327e-03, -4.8218e-03, -5.1575e-03,  ...,  1.7090e-03,
           1.6479e-03,  5.5847e-03],
         [-1.0781e+00, -9.1406e-01,  2.9688e+00,  ..., -1.7031e+00,
          -5.6641e-01, -1.1016e+00],
         [-1.9219e+00, -4.1562e+00,  4.4688e+00,  ..., -7.5781e-01,
          -1.3359e+00, -9.6094e-01],
         ...,
         [-1.5469e+00, -5.3516e-01, -2.2461e-01,  ...,  1.4062e+00,
           9.1016e-01,  1.1406e+00],
         [-4.6094e-01, -2.2656e+00,  1.8750e+00,  ..., -7.8125e-01,
           5.0781e-01,  1.2891e+00],
         [-7.1875e-01, -2.1875e+00,  2.0938e+00,  ...,  1.0938e-01,
          -2.0996e-01,  9.2188e-01]],

        ...,

        [[ 5.9509e-04,  1.7090e-02, -3.9368e-03,  ...,  3.6011e-03,
           3.3569e-03, -2.1362e-03],
         [-1.0547e-01,  1.1963e-01,  1.7266e+00,  ..., -1.3359e+00,
          -6.9141e-01, -1.7422e+00],
         [-4.1809e-03, -8.7402e-02,  5.8594e-01,  ..., -8.2422e-01,
          -4.3701e-02, -3.7598e-02],
         ...,
         [-1.6953e+00, -2.0508e-01,  2.5977e-01,  ..., -1.7266e+00,
          -2.5156e+00, -1.0391e+00],
         [-9.4922e-01, -5.0000e-01, -5.4688e-01,  ..., -7.5391e-01,
          -1.5781e+00, -1.8359e+00],
         [-1.2109e+00,  8.8281e-01,  3.9453e-01,  ..., -2.0469e+00,
          -7.4219e-01,  1.3984e+00]],

        [[ 5.9509e-04,  1.7090e-02, -3.9368e-03,  ...,  3.6011e-03,
           3.3569e-03, -2.1362e-03],
         [-1.0547e-01,  1.1963e-01,  1.7266e+00,  ..., -1.3359e+00,
          -6.9141e-01, -1.7422e+00],
         [-4.1809e-03, -8.7402e-02,  5.8594e-01,  ..., -8.2422e-01,
          -4.3701e-02, -3.7598e-02],
         ...,
         [-1.6953e+00, -2.0508e-01,  2.5977e-01,  ..., -1.7266e+00,
          -2.5156e+00, -1.0391e+00],
         [-9.4922e-01, -5.0000e-01, -5.4688e-01,  ..., -7.5391e-01,
          -1.5781e+00, -1.8359e+00],
         [-1.2109e+00,  8.8281e-01,  3.9453e-01,  ..., -2.0469e+00,
          -7.4219e-01,  1.3984e+00]],

        [[ 5.9509e-04,  1.7090e-02, -3.9368e-03,  ...,  3.6011e-03,
           3.3569e-03, -2.1362e-03],
         [-1.0547e-01,  1.1963e-01,  1.7266e+00,  ..., -1.3359e+00,
          -6.9141e-01, -1.7422e+00],
         [-4.1809e-03, -8.7402e-02,  5.8594e-01,  ..., -8.2422e-01,
          -4.3701e-02, -3.7598e-02],
         ...,
         [-1.6953e+00, -2.0508e-01,  2.5977e-01,  ..., -1.7266e+00,
          -2.5156e+00, -1.0391e+00],
         [-9.4922e-01, -5.0000e-01, -5.4688e-01,  ..., -7.5391e-01,
          -1.5781e+00, -1.8359e+00],
         [-1.2109e+00,  8.8281e-01,  3.9453e-01,  ..., -2.0469e+00,
          -7.4219e-01,  1.3984e+00]]], dtype=torch.bfloat16)), (tensor([[[-1.2024e-02,  5.1270e-03,  5.8289e-03,  ..., -1.3594e+00,
           1.7090e-03,  2.0215e-01],
         [ 3.5547e-01,  1.9609e+00, -6.4844e-01,  ...,  9.5625e+00,
           1.6113e-01, -2.9062e+00],
         [-9.1016e-01, -1.1621e-01,  2.7734e-01,  ...,  1.0000e+01,
           1.7188e+00, -6.8359e-01],
         ...,
         [-6.9141e-01, -1.5918e-01, -6.6406e-01,  ...,  9.1875e+00,
           6.7383e-02, -1.1641e+00],
         [-7.6172e-02,  1.3672e-01, -4.2773e-01,  ...,  8.8750e+00,
           5.4297e-01, -1.1250e+00],
         [ 1.9062e+00,  1.2188e+00, -7.5391e-01,  ...,  9.0625e+00,
          -5.1953e-01, -2.5000e+00]],

        [[-1.2024e-02,  5.1270e-03,  5.8289e-03,  ..., -1.3594e+00,
           1.7090e-03,  2.0215e-01],
         [ 3.5547e-01,  1.9609e+00, -6.4844e-01,  ...,  9.5625e+00,
           1.6113e-01, -2.9062e+00],
         [-9.1016e-01, -1.1621e-01,  2.7734e-01,  ...,  1.0000e+01,
           1.7188e+00, -6.8359e-01],
         ...,
         [-6.9141e-01, -1.5918e-01, -6.6406e-01,  ...,  9.1875e+00,
           6.7383e-02, -1.1641e+00],
         [-7.6172e-02,  1.3672e-01, -4.2773e-01,  ...,  8.8750e+00,
           5.4297e-01, -1.1250e+00],
         [ 1.9062e+00,  1.2188e+00, -7.5391e-01,  ...,  9.0625e+00,
          -5.1953e-01, -2.5000e+00]],

        [[-1.2024e-02,  5.1270e-03,  5.8289e-03,  ..., -1.3594e+00,
           1.7090e-03,  2.0215e-01],
         [ 3.5547e-01,  1.9609e+00, -6.4844e-01,  ...,  9.5625e+00,
           1.6113e-01, -2.9062e+00],
         [-9.1016e-01, -1.1621e-01,  2.7734e-01,  ...,  1.0000e+01,
           1.7188e+00, -6.8359e-01],
         ...,
         [-6.9141e-01, -1.5918e-01, -6.6406e-01,  ...,  9.1875e+00,
           6.7383e-02, -1.1641e+00],
         [-7.6172e-02,  1.3672e-01, -4.2773e-01,  ...,  8.8750e+00,
           5.4297e-01, -1.1250e+00],
         [ 1.9062e+00,  1.2188e+00, -7.5391e-01,  ...,  9.0625e+00,
          -5.1953e-01, -2.5000e+00]],

        ...,

        [[ 9.3384e-03,  1.5198e-02, -1.2146e-02,  ...,  7.6660e-02,
          -3.2422e-01,  1.1523e-01],
         [-1.5078e+00,  2.6406e+00,  1.2656e+00,  ...,  2.8711e-01,
           2.0469e+00,  1.1523e-01],
         [-7.0312e-01,  9.1309e-02,  5.9375e-01,  ...,  4.6875e+00,
          -4.1562e+00,  1.5234e+00],
         ...,
         [ 2.0312e+00, -7.4219e-02,  8.0469e-01,  ...,  1.7266e+00,
          -3.5625e+00,  2.6406e+00],
         [ 4.2969e-01, -5.1562e-01,  7.5000e-01,  ...,  8.2812e-01,
          -3.5625e+00,  2.9062e+00],
         [-1.1953e+00,  2.3438e-01,  1.4844e+00,  ...,  6.1719e-01,
          -2.1875e+00,  1.4688e+00]],

        [[ 9.3384e-03,  1.5198e-02, -1.2146e-02,  ...,  7.6660e-02,
          -3.2422e-01,  1.1523e-01],
         [-1.5078e+00,  2.6406e+00,  1.2656e+00,  ...,  2.8711e-01,
           2.0469e+00,  1.1523e-01],
         [-7.0312e-01,  9.1309e-02,  5.9375e-01,  ...,  4.6875e+00,
          -4.1562e+00,  1.5234e+00],
         ...,
         [ 2.0312e+00, -7.4219e-02,  8.0469e-01,  ...,  1.7266e+00,
          -3.5625e+00,  2.6406e+00],
         [ 4.2969e-01, -5.1562e-01,  7.5000e-01,  ...,  8.2812e-01,
          -3.5625e+00,  2.9062e+00],
         [-1.1953e+00,  2.3438e-01,  1.4844e+00,  ...,  6.1719e-01,
          -2.1875e+00,  1.4688e+00]],

        [[ 9.3384e-03,  1.5198e-02, -1.2146e-02,  ...,  7.6660e-02,
          -3.2422e-01,  1.1523e-01],
         [-1.5078e+00,  2.6406e+00,  1.2656e+00,  ...,  2.8711e-01,
           2.0469e+00,  1.1523e-01],
         [-7.0312e-01,  9.1309e-02,  5.9375e-01,  ...,  4.6875e+00,
          -4.1562e+00,  1.5234e+00],
         ...,
         [ 2.0312e+00, -7.4219e-02,  8.0469e-01,  ...,  1.7266e+00,
          -3.5625e+00,  2.6406e+00],
         [ 4.2969e-01, -5.1562e-01,  7.5000e-01,  ...,  8.2812e-01,
          -3.5625e+00,  2.9062e+00],
         [-1.1953e+00,  2.3438e-01,  1.4844e+00,  ...,  6.1719e-01,
          -2.1875e+00,  1.4688e+00]]], dtype=torch.bfloat16), tensor([[[-9.0942e-03,  5.8289e-03, -4.0894e-03,  ..., -1.3184e-02,
          -4.6997e-03,  2.1172e-04],
         [ 1.4062e-01, -7.3047e-01, -4.3164e-01,  ..., -9.0625e-01,
           6.6797e-01, -9.2969e-01],
         [-9.3359e-01, -2.1719e+00,  3.2959e-02,  ...,  9.1406e-01,
           5.8984e-01,  5.2734e-01],
         ...,
         [-7.1094e-01, -7.1289e-02,  1.1094e+00,  ...,  5.1953e-01,
           5.9375e-01,  3.6719e-01],
         [-8.4766e-01,  1.0547e+00, -5.2734e-01,  ...,  1.2578e+00,
           9.4922e-01,  8.6914e-02],
         [-1.9043e-01,  1.2734e+00,  1.0781e+00,  ...,  1.0391e+00,
          -5.1172e-01, -2.2461e-01]],

        [[-9.0942e-03,  5.8289e-03, -4.0894e-03,  ..., -1.3184e-02,
          -4.6997e-03,  2.1172e-04],
         [ 1.4062e-01, -7.3047e-01, -4.3164e-01,  ..., -9.0625e-01,
           6.6797e-01, -9.2969e-01],
         [-9.3359e-01, -2.1719e+00,  3.2959e-02,  ...,  9.1406e-01,
           5.8984e-01,  5.2734e-01],
         ...,
         [-7.1094e-01, -7.1289e-02,  1.1094e+00,  ...,  5.1953e-01,
           5.9375e-01,  3.6719e-01],
         [-8.4766e-01,  1.0547e+00, -5.2734e-01,  ...,  1.2578e+00,
           9.4922e-01,  8.6914e-02],
         [-1.9043e-01,  1.2734e+00,  1.0781e+00,  ...,  1.0391e+00,
          -5.1172e-01, -2.2461e-01]],

        [[-9.0942e-03,  5.8289e-03, -4.0894e-03,  ..., -1.3184e-02,
          -4.6997e-03,  2.1172e-04],
         [ 1.4062e-01, -7.3047e-01, -4.3164e-01,  ..., -9.0625e-01,
           6.6797e-01, -9.2969e-01],
         [-9.3359e-01, -2.1719e+00,  3.2959e-02,  ...,  9.1406e-01,
           5.8984e-01,  5.2734e-01],
         ...,
         [-7.1094e-01, -7.1289e-02,  1.1094e+00,  ...,  5.1953e-01,
           5.9375e-01,  3.6719e-01],
         [-8.4766e-01,  1.0547e+00, -5.2734e-01,  ...,  1.2578e+00,
           9.4922e-01,  8.6914e-02],
         [-1.9043e-01,  1.2734e+00,  1.0781e+00,  ...,  1.0391e+00,
          -5.1172e-01, -2.2461e-01]],

        ...,

        [[ 3.6955e-05, -7.8735e-03,  3.1853e-04,  ...,  1.9531e-02,
           2.4986e-04, -6.5308e-03],
         [-1.0449e-01, -9.7656e-02,  5.8984e-01,  ...,  8.9844e-01,
          -1.6094e+00,  1.9434e-01],
         [-1.0781e+00,  3.4219e+00,  2.5000e+00,  ..., -6.2500e-01,
          -1.0859e+00,  1.2812e+00],
         ...,
         [ 3.5938e-01, -1.6309e-01,  1.0703e+00,  ...,  7.0801e-02,
          -1.9688e+00,  3.2031e-01],
         [-8.0078e-01,  1.4062e+00,  1.7656e+00,  ..., -6.3672e-01,
          -3.0312e+00, -3.1445e-01],
         [-3.6328e-01,  1.1094e+00,  2.4219e+00,  ...,  2.8711e-01,
          -1.6797e+00,  6.4062e-01]],

        [[ 3.6955e-05, -7.8735e-03,  3.1853e-04,  ...,  1.9531e-02,
           2.4986e-04, -6.5308e-03],
         [-1.0449e-01, -9.7656e-02,  5.8984e-01,  ...,  8.9844e-01,
          -1.6094e+00,  1.9434e-01],
         [-1.0781e+00,  3.4219e+00,  2.5000e+00,  ..., -6.2500e-01,
          -1.0859e+00,  1.2812e+00],
         ...,
         [ 3.5938e-01, -1.6309e-01,  1.0703e+00,  ...,  7.0801e-02,
          -1.9688e+00,  3.2031e-01],
         [-8.0078e-01,  1.4062e+00,  1.7656e+00,  ..., -6.3672e-01,
          -3.0312e+00, -3.1445e-01],
         [-3.6328e-01,  1.1094e+00,  2.4219e+00,  ...,  2.8711e-01,
          -1.6797e+00,  6.4062e-01]],

        [[ 3.6955e-05, -7.8735e-03,  3.1853e-04,  ...,  1.9531e-02,
           2.4986e-04, -6.5308e-03],
         [-1.0449e-01, -9.7656e-02,  5.8984e-01,  ...,  8.9844e-01,
          -1.6094e+00,  1.9434e-01],
         [-1.0781e+00,  3.4219e+00,  2.5000e+00,  ..., -6.2500e-01,
          -1.0859e+00,  1.2812e+00],
         ...,
         [ 3.5938e-01, -1.6309e-01,  1.0703e+00,  ...,  7.0801e-02,
          -1.9688e+00,  3.2031e-01],
         [-8.0078e-01,  1.4062e+00,  1.7656e+00,  ..., -6.3672e-01,
          -3.0312e+00, -3.1445e-01],
         [-3.6328e-01,  1.1094e+00,  2.4219e+00,  ...,  2.8711e-01,
          -1.6797e+00,  6.4062e-01]]], dtype=torch.bfloat16)), (tensor([[[ 2.8076e-03, -3.0212e-03,  3.6774e-03,  ...,  4.7266e-01,
           5.4443e-02, -6.8848e-02],
         [-4.5625e+00,  8.1250e-01,  1.7500e+00,  ..., -5.6641e-01,
          -1.9141e-01,  8.0078e-02],
         [-1.1719e+00,  2.5938e+00,  1.4844e+00,  ..., -6.6406e-01,
           4.4727e-01,  4.9609e-01],
         ...,
         [ 6.5625e-01, -1.5312e+00, -1.4531e+00,  ...,  9.7656e-01,
           4.0820e-01,  1.3203e+00],
         [ 1.6250e+00, -1.4922e+00,  2.1484e-01,  ...,  5.0000e-01,
          -6.2500e-02, -4.8047e-01],
         [ 1.1719e+00, -1.3984e+00,  2.0312e+00,  ...,  1.6406e+00,
          -2.4688e+00, -1.5391e+00]],

        [[ 2.8076e-03, -3.0212e-03,  3.6774e-03,  ...,  4.7266e-01,
           5.4443e-02, -6.8848e-02],
         [-4.5625e+00,  8.1250e-01,  1.7500e+00,  ..., -5.6641e-01,
          -1.9141e-01,  8.0078e-02],
         [-1.1719e+00,  2.5938e+00,  1.4844e+00,  ..., -6.6406e-01,
           4.4727e-01,  4.9609e-01],
         ...,
         [ 6.5625e-01, -1.5312e+00, -1.4531e+00,  ...,  9.7656e-01,
           4.0820e-01,  1.3203e+00],
         [ 1.6250e+00, -1.4922e+00,  2.1484e-01,  ...,  5.0000e-01,
          -6.2500e-02, -4.8047e-01],
         [ 1.1719e+00, -1.3984e+00,  2.0312e+00,  ...,  1.6406e+00,
          -2.4688e+00, -1.5391e+00]],

        [[ 2.8076e-03, -3.0212e-03,  3.6774e-03,  ...,  4.7266e-01,
           5.4443e-02, -6.8848e-02],
         [-4.5625e+00,  8.1250e-01,  1.7500e+00,  ..., -5.6641e-01,
          -1.9141e-01,  8.0078e-02],
         [-1.1719e+00,  2.5938e+00,  1.4844e+00,  ..., -6.6406e-01,
           4.4727e-01,  4.9609e-01],
         ...,
         [ 6.5625e-01, -1.5312e+00, -1.4531e+00,  ...,  9.7656e-01,
           4.0820e-01,  1.3203e+00],
         [ 1.6250e+00, -1.4922e+00,  2.1484e-01,  ...,  5.0000e-01,
          -6.2500e-02, -4.8047e-01],
         [ 1.1719e+00, -1.3984e+00,  2.0312e+00,  ...,  1.6406e+00,
          -2.4688e+00, -1.5391e+00]],

        ...,

        [[ 1.2390e-02,  3.0670e-03, -8.9111e-03,  ...,  2.3438e-01,
          -5.9326e-02,  6.6895e-02],
         [-1.9043e-01, -1.5625e-02, -2.5000e-01,  ..., -2.9688e-01,
           6.7188e-01, -3.9258e-01],
         [-3.3789e-01,  2.4316e-01,  9.7168e-02,  ..., -1.0703e+00,
           5.2734e-01, -1.5000e+00],
         ...,
         [ 4.2578e-01, -5.8984e-01, -1.1719e+00,  ...,  5.2734e-01,
          -7.6953e-01, -7.0703e-01],
         [ 9.4727e-02, -2.1387e-01, -2.8516e-01,  ...,  1.1094e+00,
          -8.3496e-02,  1.2793e-01],
         [-3.1445e-01, -8.5938e-01,  6.6406e-02,  ...,  1.1406e+00,
           5.6250e-01,  2.9102e-01]],

        [[ 1.2390e-02,  3.0670e-03, -8.9111e-03,  ...,  2.3438e-01,
          -5.9326e-02,  6.6895e-02],
         [-1.9043e-01, -1.5625e-02, -2.5000e-01,  ..., -2.9688e-01,
           6.7188e-01, -3.9258e-01],
         [-3.3789e-01,  2.4316e-01,  9.7168e-02,  ..., -1.0703e+00,
           5.2734e-01, -1.5000e+00],
         ...,
         [ 4.2578e-01, -5.8984e-01, -1.1719e+00,  ...,  5.2734e-01,
          -7.6953e-01, -7.0703e-01],
         [ 9.4727e-02, -2.1387e-01, -2.8516e-01,  ...,  1.1094e+00,
          -8.3496e-02,  1.2793e-01],
         [-3.1445e-01, -8.5938e-01,  6.6406e-02,  ...,  1.1406e+00,
           5.6250e-01,  2.9102e-01]],

        [[ 1.2390e-02,  3.0670e-03, -8.9111e-03,  ...,  2.3438e-01,
          -5.9326e-02,  6.6895e-02],
         [-1.9043e-01, -1.5625e-02, -2.5000e-01,  ..., -2.9688e-01,
           6.7188e-01, -3.9258e-01],
         [-3.3789e-01,  2.4316e-01,  9.7168e-02,  ..., -1.0703e+00,
           5.2734e-01, -1.5000e+00],
         ...,
         [ 4.2578e-01, -5.8984e-01, -1.1719e+00,  ...,  5.2734e-01,
          -7.6953e-01, -7.0703e-01],
         [ 9.4727e-02, -2.1387e-01, -2.8516e-01,  ...,  1.1094e+00,
          -8.3496e-02,  1.2793e-01],
         [-3.1445e-01, -8.5938e-01,  6.6406e-02,  ...,  1.1406e+00,
           5.6250e-01,  2.9102e-01]]], dtype=torch.bfloat16), tensor([[[-3.1891e-03,  2.6703e-03, -3.2196e-03,  ..., -9.7656e-03,
          -9.3994e-03,  2.1667e-03],
         [-3.2422e-01,  1.2891e-01,  5.2344e-01,  ...,  4.4922e-01,
           1.2598e-01, -9.6484e-01],
         [ 8.7500e-01, -7.2656e-01,  8.9355e-02,  ...,  5.7812e-01,
          -7.8516e-01, -1.2793e-01],
         ...,
         [-2.9883e-01, -1.0078e+00, -1.0254e-01,  ...,  1.4531e+00,
           1.0000e+00, -2.3340e-01],
         [ 1.0156e+00, -5.3125e-01, -1.0312e+00,  ...,  2.6172e-01,
          -6.4844e-01, -2.7148e-01],
         [-1.4453e+00, -1.5391e+00, -6.8750e-01,  ...,  2.1562e+00,
           2.5781e+00,  1.0391e+00]],

        [[-3.1891e-03,  2.6703e-03, -3.2196e-03,  ..., -9.7656e-03,
          -9.3994e-03,  2.1667e-03],
         [-3.2422e-01,  1.2891e-01,  5.2344e-01,  ...,  4.4922e-01,
           1.2598e-01, -9.6484e-01],
         [ 8.7500e-01, -7.2656e-01,  8.9355e-02,  ...,  5.7812e-01,
          -7.8516e-01, -1.2793e-01],
         ...,
         [-2.9883e-01, -1.0078e+00, -1.0254e-01,  ...,  1.4531e+00,
           1.0000e+00, -2.3340e-01],
         [ 1.0156e+00, -5.3125e-01, -1.0312e+00,  ...,  2.6172e-01,
          -6.4844e-01, -2.7148e-01],
         [-1.4453e+00, -1.5391e+00, -6.8750e-01,  ...,  2.1562e+00,
           2.5781e+00,  1.0391e+00]],

        [[-3.1891e-03,  2.6703e-03, -3.2196e-03,  ..., -9.7656e-03,
          -9.3994e-03,  2.1667e-03],
         [-3.2422e-01,  1.2891e-01,  5.2344e-01,  ...,  4.4922e-01,
           1.2598e-01, -9.6484e-01],
         [ 8.7500e-01, -7.2656e-01,  8.9355e-02,  ...,  5.7812e-01,
          -7.8516e-01, -1.2793e-01],
         ...,
         [-2.9883e-01, -1.0078e+00, -1.0254e-01,  ...,  1.4531e+00,
           1.0000e+00, -2.3340e-01],
         [ 1.0156e+00, -5.3125e-01, -1.0312e+00,  ...,  2.6172e-01,
          -6.4844e-01, -2.7148e-01],
         [-1.4453e+00, -1.5391e+00, -6.8750e-01,  ...,  2.1562e+00,
           2.5781e+00,  1.0391e+00]],

        ...,

        [[-4.9133e-03, -5.0049e-03, -8.3008e-03,  ..., -8.8501e-04,
          -2.7618e-03,  1.4587e-02],
         [-5.5469e-01, -1.0859e+00, -3.4570e-01,  ..., -1.1641e+00,
           5.1172e-01,  1.9629e-01],
         [ 2.1875e-01, -9.4531e-01,  2.5781e-01,  ..., -7.8516e-01,
           5.6641e-01,  1.0000e+00],
         ...,
         [-1.1865e-01,  7.1875e-01,  1.5234e+00,  ..., -1.6250e+00,
           8.0469e-01, -4.2383e-01],
         [ 1.2578e+00,  2.4414e-02,  4.9219e-01,  ..., -1.1641e+00,
          -1.0703e+00, -2.7930e-01],
         [ 1.9336e-01, -4.8242e-01,  1.4531e+00,  ..., -9.9609e-01,
           1.9766e+00, -1.4219e+00]],

        [[-4.9133e-03, -5.0049e-03, -8.3008e-03,  ..., -8.8501e-04,
          -2.7618e-03,  1.4587e-02],
         [-5.5469e-01, -1.0859e+00, -3.4570e-01,  ..., -1.1641e+00,
           5.1172e-01,  1.9629e-01],
         [ 2.1875e-01, -9.4531e-01,  2.5781e-01,  ..., -7.8516e-01,
           5.6641e-01,  1.0000e+00],
         ...,
         [-1.1865e-01,  7.1875e-01,  1.5234e+00,  ..., -1.6250e+00,
           8.0469e-01, -4.2383e-01],
         [ 1.2578e+00,  2.4414e-02,  4.9219e-01,  ..., -1.1641e+00,
          -1.0703e+00, -2.7930e-01],
         [ 1.9336e-01, -4.8242e-01,  1.4531e+00,  ..., -9.9609e-01,
           1.9766e+00, -1.4219e+00]],

        [[-4.9133e-03, -5.0049e-03, -8.3008e-03,  ..., -8.8501e-04,
          -2.7618e-03,  1.4587e-02],
         [-5.5469e-01, -1.0859e+00, -3.4570e-01,  ..., -1.1641e+00,
           5.1172e-01,  1.9629e-01],
         [ 2.1875e-01, -9.4531e-01,  2.5781e-01,  ..., -7.8516e-01,
           5.6641e-01,  1.0000e+00],
         ...,
         [-1.1865e-01,  7.1875e-01,  1.5234e+00,  ..., -1.6250e+00,
           8.0469e-01, -4.2383e-01],
         [ 1.2578e+00,  2.4414e-02,  4.9219e-01,  ..., -1.1641e+00,
          -1.0703e+00, -2.7930e-01],
         [ 1.9336e-01, -4.8242e-01,  1.4531e+00,  ..., -9.9609e-01,
           1.9766e+00, -1.4219e+00]]], dtype=torch.bfloat16)), (tensor([[[-1.6724e-02, -4.4861e-03, -2.8076e-03,  ...,  1.1523e-01,
           1.3574e-01, -1.5747e-02],
         [ 4.4141e-01,  3.1250e-01, -9.4141e-01,  ...,  7.8906e-01,
           4.4531e-01,  4.8438e-01],
         [-1.5469e+00,  1.9727e-01, -2.2188e+00,  ...,  6.8359e-01,
          -1.1426e-01,  2.2500e+00],
         ...,
         [-5.7812e-01, -7.6172e-02,  3.9062e-01,  ..., -4.9316e-02,
          -1.6484e+00, -2.3828e-01],
         [-2.4219e-01, -7.0312e-01, -2.3047e-01,  ...,  5.1172e-01,
          -4.8438e-01, -5.2246e-02],
         [ 1.7031e+00,  1.5156e+00, -1.1797e+00,  ..., -1.5703e+00,
          -4.8633e-01,  6.5625e-01]],

        [[-1.6724e-02, -4.4861e-03, -2.8076e-03,  ...,  1.1523e-01,
           1.3574e-01, -1.5747e-02],
         [ 4.4141e-01,  3.1250e-01, -9.4141e-01,  ...,  7.8906e-01,
           4.4531e-01,  4.8438e-01],
         [-1.5469e+00,  1.9727e-01, -2.2188e+00,  ...,  6.8359e-01,
          -1.1426e-01,  2.2500e+00],
         ...,
         [-5.7812e-01, -7.6172e-02,  3.9062e-01,  ..., -4.9316e-02,
          -1.6484e+00, -2.3828e-01],
         [-2.4219e-01, -7.0312e-01, -2.3047e-01,  ...,  5.1172e-01,
          -4.8438e-01, -5.2246e-02],
         [ 1.7031e+00,  1.5156e+00, -1.1797e+00,  ..., -1.5703e+00,
          -4.8633e-01,  6.5625e-01]],

        [[-1.6724e-02, -4.4861e-03, -2.8076e-03,  ...,  1.1523e-01,
           1.3574e-01, -1.5747e-02],
         [ 4.4141e-01,  3.1250e-01, -9.4141e-01,  ...,  7.8906e-01,
           4.4531e-01,  4.8438e-01],
         [-1.5469e+00,  1.9727e-01, -2.2188e+00,  ...,  6.8359e-01,
          -1.1426e-01,  2.2500e+00],
         ...,
         [-5.7812e-01, -7.6172e-02,  3.9062e-01,  ..., -4.9316e-02,
          -1.6484e+00, -2.3828e-01],
         [-2.4219e-01, -7.0312e-01, -2.3047e-01,  ...,  5.1172e-01,
          -4.8438e-01, -5.2246e-02],
         [ 1.7031e+00,  1.5156e+00, -1.1797e+00,  ..., -1.5703e+00,
          -4.8633e-01,  6.5625e-01]],

        ...,

        [[ 2.2125e-03,  1.2573e-02,  1.5335e-03,  ...,  9.7168e-02,
          -4.8523e-03, -1.6504e-01],
         [ 1.8203e+00,  8.9062e-01, -2.1484e-01,  ..., -2.0801e-01,
          -4.8242e-01,  1.1562e+00],
         [ 1.4375e+00,  2.5156e+00, -8.4375e-01,  ...,  2.4805e-01,
          -2.2949e-01,  2.6406e+00],
         ...,
         [ 8.1250e-01, -1.9531e-02, -2.6758e-01,  ..., -1.6328e+00,
          -2.3633e-01,  3.0273e-01],
         [ 5.3906e-01, -8.6328e-01, -1.3516e+00,  ...,  1.2031e+00,
           3.2422e-01,  1.1641e+00],
         [-1.4688e+00, -2.4375e+00, -4.5508e-01,  ..., -9.2188e-01,
          -1.4282e-02,  2.6094e+00]],

        [[ 2.2125e-03,  1.2573e-02,  1.5335e-03,  ...,  9.7168e-02,
          -4.8523e-03, -1.6504e-01],
         [ 1.8203e+00,  8.9062e-01, -2.1484e-01,  ..., -2.0801e-01,
          -4.8242e-01,  1.1562e+00],
         [ 1.4375e+00,  2.5156e+00, -8.4375e-01,  ...,  2.4805e-01,
          -2.2949e-01,  2.6406e+00],
         ...,
         [ 8.1250e-01, -1.9531e-02, -2.6758e-01,  ..., -1.6328e+00,
          -2.3633e-01,  3.0273e-01],
         [ 5.3906e-01, -8.6328e-01, -1.3516e+00,  ...,  1.2031e+00,
           3.2422e-01,  1.1641e+00],
         [-1.4688e+00, -2.4375e+00, -4.5508e-01,  ..., -9.2188e-01,
          -1.4282e-02,  2.6094e+00]],

        [[ 2.2125e-03,  1.2573e-02,  1.5335e-03,  ...,  9.7168e-02,
          -4.8523e-03, -1.6504e-01],
         [ 1.8203e+00,  8.9062e-01, -2.1484e-01,  ..., -2.0801e-01,
          -4.8242e-01,  1.1562e+00],
         [ 1.4375e+00,  2.5156e+00, -8.4375e-01,  ...,  2.4805e-01,
          -2.2949e-01,  2.6406e+00],
         ...,
         [ 8.1250e-01, -1.9531e-02, -2.6758e-01,  ..., -1.6328e+00,
          -2.3633e-01,  3.0273e-01],
         [ 5.3906e-01, -8.6328e-01, -1.3516e+00,  ...,  1.2031e+00,
           3.2422e-01,  1.1641e+00],
         [-1.4688e+00, -2.4375e+00, -4.5508e-01,  ..., -9.2188e-01,
          -1.4282e-02,  2.6094e+00]]], dtype=torch.bfloat16), tensor([[[ 0.0064,  0.0118,  0.0106,  ...,  0.0147,  0.0085, -0.0025],
         [-1.1484, -1.8438,  0.8281,  ...,  1.4219, -1.5938, -0.3262],
         [-0.4883, -0.5859, -0.4316,  ...,  0.0128, -0.0199,  1.2422],
         ...,
         [ 0.1182,  0.7305, -0.0317,  ..., -0.6523,  0.8633,  0.6562],
         [ 0.4688,  0.2949, -0.9844,  ...,  0.5078,  0.6172,  0.7578],
         [ 0.8516, -1.6094, -0.2246,  ..., -0.1001,  2.1562,  1.0547]],

        [[ 0.0064,  0.0118,  0.0106,  ...,  0.0147,  0.0085, -0.0025],
         [-1.1484, -1.8438,  0.8281,  ...,  1.4219, -1.5938, -0.3262],
         [-0.4883, -0.5859, -0.4316,  ...,  0.0128, -0.0199,  1.2422],
         ...,
         [ 0.1182,  0.7305, -0.0317,  ..., -0.6523,  0.8633,  0.6562],
         [ 0.4688,  0.2949, -0.9844,  ...,  0.5078,  0.6172,  0.7578],
         [ 0.8516, -1.6094, -0.2246,  ..., -0.1001,  2.1562,  1.0547]],

        [[ 0.0064,  0.0118,  0.0106,  ...,  0.0147,  0.0085, -0.0025],
         [-1.1484, -1.8438,  0.8281,  ...,  1.4219, -1.5938, -0.3262],
         [-0.4883, -0.5859, -0.4316,  ...,  0.0128, -0.0199,  1.2422],
         ...,
         [ 0.1182,  0.7305, -0.0317,  ..., -0.6523,  0.8633,  0.6562],
         [ 0.4688,  0.2949, -0.9844,  ...,  0.5078,  0.6172,  0.7578],
         [ 0.8516, -1.6094, -0.2246,  ..., -0.1001,  2.1562,  1.0547]],

        ...,

        [[-0.0048,  0.0029, -0.0053,  ..., -0.0062, -0.0024,  0.0045],
         [ 0.4160, -0.6328,  0.5195,  ...,  0.7461, -0.2754,  0.5273],
         [-1.4375,  1.0703,  0.4629,  ..., -0.9336, -1.2734, -0.4043],
         ...,
         [-0.8711, -0.2246,  0.1895,  ..., -0.2451, -0.4746, -0.7812],
         [-2.2500, -0.7812, -1.2812,  ..., -0.8164, -1.1484, -0.1445],
         [ 1.1094,  0.5508,  0.0986,  ...,  0.4297, -0.0275, -0.0444]],

        [[-0.0048,  0.0029, -0.0053,  ..., -0.0062, -0.0024,  0.0045],
         [ 0.4160, -0.6328,  0.5195,  ...,  0.7461, -0.2754,  0.5273],
         [-1.4375,  1.0703,  0.4629,  ..., -0.9336, -1.2734, -0.4043],
         ...,
         [-0.8711, -0.2246,  0.1895,  ..., -0.2451, -0.4746, -0.7812],
         [-2.2500, -0.7812, -1.2812,  ..., -0.8164, -1.1484, -0.1445],
         [ 1.1094,  0.5508,  0.0986,  ...,  0.4297, -0.0275, -0.0444]],

        [[-0.0048,  0.0029, -0.0053,  ..., -0.0062, -0.0024,  0.0045],
         [ 0.4160, -0.6328,  0.5195,  ...,  0.7461, -0.2754,  0.5273],
         [-1.4375,  1.0703,  0.4629,  ..., -0.9336, -1.2734, -0.4043],
         ...,
         [-0.8711, -0.2246,  0.1895,  ..., -0.2451, -0.4746, -0.7812],
         [-2.2500, -0.7812, -1.2812,  ..., -0.8164, -1.1484, -0.1445],
         [ 1.1094,  0.5508,  0.0986,  ...,  0.4297, -0.0275, -0.0444]]],
       dtype=torch.bfloat16)), (tensor([[[-7.9346e-04,  2.2430e-03,  1.5717e-03,  ..., -7.9956e-03,
          -1.6406e+00,  8.7402e-02],
         [ 3.9453e-01,  1.5625e-01, -1.2266e+00,  ..., -1.7285e-01,
           6.5938e+00, -1.5527e-01],
         [ 4.2188e-01, -3.2812e-01,  1.5430e-01,  ..., -2.4375e+00,
           9.6875e+00,  1.8281e+00],
         ...,
         [-2.2266e-01,  1.0205e-01, -5.9766e-01,  ...,  4.7812e+00,
           7.6875e+00, -4.1875e+00],
         [-1.8555e-01,  6.9336e-02, -3.0078e-01,  ...,  1.3594e+00,
           5.8125e+00, -2.3594e+00],
         [-7.2656e-01,  9.2969e-01, -1.8945e-01,  ..., -2.3281e+00,
           9.1250e+00, -1.1484e+00]],

        [[-7.9346e-04,  2.2430e-03,  1.5717e-03,  ..., -7.9956e-03,
          -1.6406e+00,  8.7402e-02],
         [ 3.9453e-01,  1.5625e-01, -1.2266e+00,  ..., -1.7285e-01,
           6.5938e+00, -1.5527e-01],
         [ 4.2188e-01, -3.2812e-01,  1.5430e-01,  ..., -2.4375e+00,
           9.6875e+00,  1.8281e+00],
         ...,
         [-2.2266e-01,  1.0205e-01, -5.9766e-01,  ...,  4.7812e+00,
           7.6875e+00, -4.1875e+00],
         [-1.8555e-01,  6.9336e-02, -3.0078e-01,  ...,  1.3594e+00,
           5.8125e+00, -2.3594e+00],
         [-7.2656e-01,  9.2969e-01, -1.8945e-01,  ..., -2.3281e+00,
           9.1250e+00, -1.1484e+00]],

        [[-7.9346e-04,  2.2430e-03,  1.5717e-03,  ..., -7.9956e-03,
          -1.6406e+00,  8.7402e-02],
         [ 3.9453e-01,  1.5625e-01, -1.2266e+00,  ..., -1.7285e-01,
           6.5938e+00, -1.5527e-01],
         [ 4.2188e-01, -3.2812e-01,  1.5430e-01,  ..., -2.4375e+00,
           9.6875e+00,  1.8281e+00],
         ...,
         [-2.2266e-01,  1.0205e-01, -5.9766e-01,  ...,  4.7812e+00,
           7.6875e+00, -4.1875e+00],
         [-1.8555e-01,  6.9336e-02, -3.0078e-01,  ...,  1.3594e+00,
           5.8125e+00, -2.3594e+00],
         [-7.2656e-01,  9.2969e-01, -1.8945e-01,  ..., -2.3281e+00,
           9.1250e+00, -1.1484e+00]],

        ...,

        [[-3.6774e-03, -1.3000e-02, -1.9989e-03,  ..., -5.5908e-02,
           2.0386e-02, -7.2754e-02],
         [ 1.8594e+00, -1.2344e+00, -7.7344e-01,  ...,  5.4297e-01,
          -1.0703e+00,  1.8359e+00],
         [ 1.0000e+00, -7.2656e-01, -4.6875e-01,  ...,  7.3438e-01,
           3.0273e-01, -2.5781e-01],
         ...,
         [ 1.0781e+00,  4.1016e-01,  1.3594e+00,  ..., -8.6719e-01,
           2.6719e+00,  8.8672e-01],
         [ 5.7422e-01,  7.3438e-01,  5.8594e-03,  ..., -1.0547e+00,
           1.6602e-01, -1.6406e+00],
         [-1.5156e+00, -5.6250e-01, -7.2656e-01,  ..., -1.2891e+00,
           1.2734e+00,  7.7344e-01]],

        [[-3.6774e-03, -1.3000e-02, -1.9989e-03,  ..., -5.5908e-02,
           2.0386e-02, -7.2754e-02],
         [ 1.8594e+00, -1.2344e+00, -7.7344e-01,  ...,  5.4297e-01,
          -1.0703e+00,  1.8359e+00],
         [ 1.0000e+00, -7.2656e-01, -4.6875e-01,  ...,  7.3438e-01,
           3.0273e-01, -2.5781e-01],
         ...,
         [ 1.0781e+00,  4.1016e-01,  1.3594e+00,  ..., -8.6719e-01,
           2.6719e+00,  8.8672e-01],
         [ 5.7422e-01,  7.3438e-01,  5.8594e-03,  ..., -1.0547e+00,
           1.6602e-01, -1.6406e+00],
         [-1.5156e+00, -5.6250e-01, -7.2656e-01,  ..., -1.2891e+00,
           1.2734e+00,  7.7344e-01]],

        [[-3.6774e-03, -1.3000e-02, -1.9989e-03,  ..., -5.5908e-02,
           2.0386e-02, -7.2754e-02],
         [ 1.8594e+00, -1.2344e+00, -7.7344e-01,  ...,  5.4297e-01,
          -1.0703e+00,  1.8359e+00],
         [ 1.0000e+00, -7.2656e-01, -4.6875e-01,  ...,  7.3438e-01,
           3.0273e-01, -2.5781e-01],
         ...,
         [ 1.0781e+00,  4.1016e-01,  1.3594e+00,  ..., -8.6719e-01,
           2.6719e+00,  8.8672e-01],
         [ 5.7422e-01,  7.3438e-01,  5.8594e-03,  ..., -1.0547e+00,
           1.6602e-01, -1.6406e+00],
         [-1.5156e+00, -5.6250e-01, -7.2656e-01,  ..., -1.2891e+00,
           1.2734e+00,  7.7344e-01]]], dtype=torch.bfloat16), tensor([[[ 4.6692e-03,  2.2095e-02, -3.2043e-03,  ..., -1.2268e-02,
           1.4160e-02,  3.9307e-02],
         [ 5.5078e-01, -2.9102e-01, -6.2500e-01,  ..., -8.1635e-04,
          -1.5078e+00, -1.0781e+00],
         [-2.6367e-02,  6.0156e-01,  2.3047e-01,  ..., -8.0469e-01,
           4.0820e-01, -1.7578e-01],
         ...,
         [-9.0234e-01, -2.0410e-01, -1.0859e+00,  ...,  6.3672e-01,
          -9.4238e-02,  1.5938e+00],
         [-8.0859e-01,  8.3008e-03, -9.4141e-01,  ..., -5.5078e-01,
           4.8438e-01,  1.5234e+00],
         [ 1.0781e+00, -1.1328e+00, -1.7422e+00,  ..., -1.2578e+00,
          -2.3750e+00,  5.4688e-01]],

        [[ 4.6692e-03,  2.2095e-02, -3.2043e-03,  ..., -1.2268e-02,
           1.4160e-02,  3.9307e-02],
         [ 5.5078e-01, -2.9102e-01, -6.2500e-01,  ..., -8.1635e-04,
          -1.5078e+00, -1.0781e+00],
         [-2.6367e-02,  6.0156e-01,  2.3047e-01,  ..., -8.0469e-01,
           4.0820e-01, -1.7578e-01],
         ...,
         [-9.0234e-01, -2.0410e-01, -1.0859e+00,  ...,  6.3672e-01,
          -9.4238e-02,  1.5938e+00],
         [-8.0859e-01,  8.3008e-03, -9.4141e-01,  ..., -5.5078e-01,
           4.8438e-01,  1.5234e+00],
         [ 1.0781e+00, -1.1328e+00, -1.7422e+00,  ..., -1.2578e+00,
          -2.3750e+00,  5.4688e-01]],

        [[ 4.6692e-03,  2.2095e-02, -3.2043e-03,  ..., -1.2268e-02,
           1.4160e-02,  3.9307e-02],
         [ 5.5078e-01, -2.9102e-01, -6.2500e-01,  ..., -8.1635e-04,
          -1.5078e+00, -1.0781e+00],
         [-2.6367e-02,  6.0156e-01,  2.3047e-01,  ..., -8.0469e-01,
           4.0820e-01, -1.7578e-01],
         ...,
         [-9.0234e-01, -2.0410e-01, -1.0859e+00,  ...,  6.3672e-01,
          -9.4238e-02,  1.5938e+00],
         [-8.0859e-01,  8.3008e-03, -9.4141e-01,  ..., -5.5078e-01,
           4.8438e-01,  1.5234e+00],
         [ 1.0781e+00, -1.1328e+00, -1.7422e+00,  ..., -1.2578e+00,
          -2.3750e+00,  5.4688e-01]],

        ...,

        [[ 1.7319e-03,  1.0872e-04, -7.7209e-03,  ..., -1.3489e-02,
          -8.0566e-03,  6.9580e-03],
         [-7.5000e-01,  3.8086e-01, -8.7500e-01,  ..., -4.7070e-01,
          -1.2578e+00,  4.1602e-01],
         [-2.5000e-01,  2.1289e-01,  5.5469e-01,  ...,  8.9844e-01,
          -1.0469e+00,  1.2266e+00],
         ...,
         [-1.3867e-01,  9.3750e-02,  3.3203e-01,  ..., -1.2598e-01,
          -2.4414e-01, -4.9072e-02],
         [ 3.9258e-01, -6.9336e-02,  3.7500e-01,  ...,  1.1250e+00,
           1.5078e+00, -1.6504e-01],
         [ 3.3789e-01,  9.3750e-01,  1.0781e+00,  ...,  4.0625e-01,
           1.9455e-03,  3.0469e-01]],

        [[ 1.7319e-03,  1.0872e-04, -7.7209e-03,  ..., -1.3489e-02,
          -8.0566e-03,  6.9580e-03],
         [-7.5000e-01,  3.8086e-01, -8.7500e-01,  ..., -4.7070e-01,
          -1.2578e+00,  4.1602e-01],
         [-2.5000e-01,  2.1289e-01,  5.5469e-01,  ...,  8.9844e-01,
          -1.0469e+00,  1.2266e+00],
         ...,
         [-1.3867e-01,  9.3750e-02,  3.3203e-01,  ..., -1.2598e-01,
          -2.4414e-01, -4.9072e-02],
         [ 3.9258e-01, -6.9336e-02,  3.7500e-01,  ...,  1.1250e+00,
           1.5078e+00, -1.6504e-01],
         [ 3.3789e-01,  9.3750e-01,  1.0781e+00,  ...,  4.0625e-01,
           1.9455e-03,  3.0469e-01]],

        [[ 1.7319e-03,  1.0872e-04, -7.7209e-03,  ..., -1.3489e-02,
          -8.0566e-03,  6.9580e-03],
         [-7.5000e-01,  3.8086e-01, -8.7500e-01,  ..., -4.7070e-01,
          -1.2578e+00,  4.1602e-01],
         [-2.5000e-01,  2.1289e-01,  5.5469e-01,  ...,  8.9844e-01,
          -1.0469e+00,  1.2266e+00],
         ...,
         [-1.3867e-01,  9.3750e-02,  3.3203e-01,  ..., -1.2598e-01,
          -2.4414e-01, -4.9072e-02],
         [ 3.9258e-01, -6.9336e-02,  3.7500e-01,  ...,  1.1250e+00,
           1.5078e+00, -1.6504e-01],
         [ 3.3789e-01,  9.3750e-01,  1.0781e+00,  ...,  4.0625e-01,
           1.9455e-03,  3.0469e-01]]], dtype=torch.bfloat16)), (tensor([[[ 1.2695e-02,  1.6098e-03,  1.1658e-02,  ...,  2.1680e-01,
           1.3672e-01, -1.2812e+00],
         [ 7.8516e-01,  4.4531e-01,  3.4375e-01,  ...,  1.6406e+00,
           8.2031e-01,  1.7891e+00],
         [ 2.5781e-01, -4.2969e-01, -2.9785e-02,  ...,  1.4766e+00,
          -2.0781e+00,  2.2656e+00],
         ...,
         [ 4.5117e-01,  2.4609e-01,  3.5547e-01,  ..., -3.2188e+00,
           8.6328e-01,  3.9531e+00],
         [-5.4297e-01, -1.9824e-01,  1.5625e-01,  ..., -1.3672e+00,
          -1.2109e+00,  2.7031e+00],
         [-2.0703e-01,  3.3008e-01,  2.9297e-01,  ...,  1.1953e+00,
          -1.2891e+00,  4.5625e+00]],

        [[ 1.2695e-02,  1.6098e-03,  1.1658e-02,  ...,  2.1680e-01,
           1.3672e-01, -1.2812e+00],
         [ 7.8516e-01,  4.4531e-01,  3.4375e-01,  ...,  1.6406e+00,
           8.2031e-01,  1.7891e+00],
         [ 2.5781e-01, -4.2969e-01, -2.9785e-02,  ...,  1.4766e+00,
          -2.0781e+00,  2.2656e+00],
         ...,
         [ 4.5117e-01,  2.4609e-01,  3.5547e-01,  ..., -3.2188e+00,
           8.6328e-01,  3.9531e+00],
         [-5.4297e-01, -1.9824e-01,  1.5625e-01,  ..., -1.3672e+00,
          -1.2109e+00,  2.7031e+00],
         [-2.0703e-01,  3.3008e-01,  2.9297e-01,  ...,  1.1953e+00,
          -1.2891e+00,  4.5625e+00]],

        [[ 1.2695e-02,  1.6098e-03,  1.1658e-02,  ...,  2.1680e-01,
           1.3672e-01, -1.2812e+00],
         [ 7.8516e-01,  4.4531e-01,  3.4375e-01,  ...,  1.6406e+00,
           8.2031e-01,  1.7891e+00],
         [ 2.5781e-01, -4.2969e-01, -2.9785e-02,  ...,  1.4766e+00,
          -2.0781e+00,  2.2656e+00],
         ...,
         [ 4.5117e-01,  2.4609e-01,  3.5547e-01,  ..., -3.2188e+00,
           8.6328e-01,  3.9531e+00],
         [-5.4297e-01, -1.9824e-01,  1.5625e-01,  ..., -1.3672e+00,
          -1.2109e+00,  2.7031e+00],
         [-2.0703e-01,  3.3008e-01,  2.9297e-01,  ...,  1.1953e+00,
          -1.2891e+00,  4.5625e+00]],

        ...,

        [[ 9.3384e-03, -2.3438e-02, -4.6692e-03,  ...,  5.3516e-01,
          -2.9297e-01, -8.1543e-02],
         [ 4.1211e-01,  9.6875e-01, -4.1992e-02,  ..., -1.7188e+00,
           9.1016e-01, -3.4062e+00],
         [ 2.2852e-01,  7.6904e-03,  3.0469e-01,  ..., -2.0703e-01,
           8.8281e-01, -1.5000e+00],
         ...,
         [-3.8574e-02, -5.1172e-01, -8.4961e-02,  ..., -1.6641e+00,
          -1.5000e+00, -2.7500e+00],
         [-4.8096e-02, -3.9453e-01, -1.1328e-01,  ...,  1.5547e+00,
          -8.3594e-01,  3.3008e-01],
         [-8.6719e-01, -1.0547e-01,  9.9219e-01,  ...,  1.0547e+00,
           8.8281e-01, -6.6250e+00]],

        [[ 9.3384e-03, -2.3438e-02, -4.6692e-03,  ...,  5.3516e-01,
          -2.9297e-01, -8.1543e-02],
         [ 4.1211e-01,  9.6875e-01, -4.1992e-02,  ..., -1.7188e+00,
           9.1016e-01, -3.4062e+00],
         [ 2.2852e-01,  7.6904e-03,  3.0469e-01,  ..., -2.0703e-01,
           8.8281e-01, -1.5000e+00],
         ...,
         [-3.8574e-02, -5.1172e-01, -8.4961e-02,  ..., -1.6641e+00,
          -1.5000e+00, -2.7500e+00],
         [-4.8096e-02, -3.9453e-01, -1.1328e-01,  ...,  1.5547e+00,
          -8.3594e-01,  3.3008e-01],
         [-8.6719e-01, -1.0547e-01,  9.9219e-01,  ...,  1.0547e+00,
           8.8281e-01, -6.6250e+00]],

        [[ 9.3384e-03, -2.3438e-02, -4.6692e-03,  ...,  5.3516e-01,
          -2.9297e-01, -8.1543e-02],
         [ 4.1211e-01,  9.6875e-01, -4.1992e-02,  ..., -1.7188e+00,
           9.1016e-01, -3.4062e+00],
         [ 2.2852e-01,  7.6904e-03,  3.0469e-01,  ..., -2.0703e-01,
           8.8281e-01, -1.5000e+00],
         ...,
         [-3.8574e-02, -5.1172e-01, -8.4961e-02,  ..., -1.6641e+00,
          -1.5000e+00, -2.7500e+00],
         [-4.8096e-02, -3.9453e-01, -1.1328e-01,  ...,  1.5547e+00,
          -8.3594e-01,  3.3008e-01],
         [-8.6719e-01, -1.0547e-01,  9.9219e-01,  ...,  1.0547e+00,
           8.8281e-01, -6.6250e+00]]], dtype=torch.bfloat16), tensor([[[-5.3406e-03,  1.2398e-04,  1.0864e-02,  ..., -7.3547e-03,
           1.1475e-02,  7.3547e-03],
         [ 2.5156e+00, -1.6172e+00, -5.4688e-01,  ...,  1.7773e-01,
           7.1094e-01,  2.1406e+00],
         [ 2.7930e-01, -1.3125e+00, -1.1172e+00,  ...,  2.2754e-01,
           3.7109e-01,  7.1094e-01],
         ...,
         [-9.7656e-02, -3.7500e-01,  6.1328e-01,  ...,  4.0820e-01,
          -1.4258e-01,  9.9609e-01],
         [-1.8359e-01,  1.9922e-01, -7.2266e-01,  ...,  4.2578e-01,
           8.9111e-03,  8.5156e-01],
         [ 1.8906e+00, -2.2852e-01,  3.4961e-01,  ...,  1.7344e+00,
          -1.5469e+00,  1.1797e+00]],

        [[-5.3406e-03,  1.2398e-04,  1.0864e-02,  ..., -7.3547e-03,
           1.1475e-02,  7.3547e-03],
         [ 2.5156e+00, -1.6172e+00, -5.4688e-01,  ...,  1.7773e-01,
           7.1094e-01,  2.1406e+00],
         [ 2.7930e-01, -1.3125e+00, -1.1172e+00,  ...,  2.2754e-01,
           3.7109e-01,  7.1094e-01],
         ...,
         [-9.7656e-02, -3.7500e-01,  6.1328e-01,  ...,  4.0820e-01,
          -1.4258e-01,  9.9609e-01],
         [-1.8359e-01,  1.9922e-01, -7.2266e-01,  ...,  4.2578e-01,
           8.9111e-03,  8.5156e-01],
         [ 1.8906e+00, -2.2852e-01,  3.4961e-01,  ...,  1.7344e+00,
          -1.5469e+00,  1.1797e+00]],

        [[-5.3406e-03,  1.2398e-04,  1.0864e-02,  ..., -7.3547e-03,
           1.1475e-02,  7.3547e-03],
         [ 2.5156e+00, -1.6172e+00, -5.4688e-01,  ...,  1.7773e-01,
           7.1094e-01,  2.1406e+00],
         [ 2.7930e-01, -1.3125e+00, -1.1172e+00,  ...,  2.2754e-01,
           3.7109e-01,  7.1094e-01],
         ...,
         [-9.7656e-02, -3.7500e-01,  6.1328e-01,  ...,  4.0820e-01,
          -1.4258e-01,  9.9609e-01],
         [-1.8359e-01,  1.9922e-01, -7.2266e-01,  ...,  4.2578e-01,
           8.9111e-03,  8.5156e-01],
         [ 1.8906e+00, -2.2852e-01,  3.4961e-01,  ...,  1.7344e+00,
          -1.5469e+00,  1.1797e+00]],

        ...,

        [[ 5.8289e-03,  1.7822e-02, -2.7222e-02,  ..., -1.2360e-03,
           1.0071e-02,  9.7656e-04],
         [-6.9531e-01, -1.4062e+00,  4.2773e-01,  ...,  6.6016e-01,
           1.1719e-01,  2.8438e+00],
         [-3.5352e-01, -1.9141e+00,  2.0996e-01,  ..., -3.6328e-01,
           1.7383e-01, -1.1562e+00],
         ...,
         [-3.3594e-01, -7.4609e-01, -1.2988e-01,  ...,  3.0078e-01,
          -9.7656e-02, -2.6172e-01],
         [-1.0781e+00, -3.1641e-01, -2.8125e-01,  ...,  3.0469e-01,
           1.0547e+00,  6.4941e-02],
         [ 1.7266e+00,  4.6143e-02,  2.1719e+00,  ...,  1.1670e-01,
           9.8828e-01, -1.0703e+00]],

        [[ 5.8289e-03,  1.7822e-02, -2.7222e-02,  ..., -1.2360e-03,
           1.0071e-02,  9.7656e-04],
         [-6.9531e-01, -1.4062e+00,  4.2773e-01,  ...,  6.6016e-01,
           1.1719e-01,  2.8438e+00],
         [-3.5352e-01, -1.9141e+00,  2.0996e-01,  ..., -3.6328e-01,
           1.7383e-01, -1.1562e+00],
         ...,
         [-3.3594e-01, -7.4609e-01, -1.2988e-01,  ...,  3.0078e-01,
          -9.7656e-02, -2.6172e-01],
         [-1.0781e+00, -3.1641e-01, -2.8125e-01,  ...,  3.0469e-01,
           1.0547e+00,  6.4941e-02],
         [ 1.7266e+00,  4.6143e-02,  2.1719e+00,  ...,  1.1670e-01,
           9.8828e-01, -1.0703e+00]],

        [[ 5.8289e-03,  1.7822e-02, -2.7222e-02,  ..., -1.2360e-03,
           1.0071e-02,  9.7656e-04],
         [-6.9531e-01, -1.4062e+00,  4.2773e-01,  ...,  6.6016e-01,
           1.1719e-01,  2.8438e+00],
         [-3.5352e-01, -1.9141e+00,  2.0996e-01,  ..., -3.6328e-01,
           1.7383e-01, -1.1562e+00],
         ...,
         [-3.3594e-01, -7.4609e-01, -1.2988e-01,  ...,  3.0078e-01,
          -9.7656e-02, -2.6172e-01],
         [-1.0781e+00, -3.1641e-01, -2.8125e-01,  ...,  3.0469e-01,
           1.0547e+00,  6.4941e-02],
         [ 1.7266e+00,  4.6143e-02,  2.1719e+00,  ...,  1.1670e-01,
           9.8828e-01, -1.0703e+00]]], dtype=torch.bfloat16)), (tensor([[[ 8.9722e-03,  1.1139e-03, -4.6997e-03,  ...,  2.6611e-02,
          -1.4531e+00,  2.0020e-01],
         [-3.1875e+00,  1.5625e-01, -1.3438e+00,  ..., -1.1797e+00,
           7.4688e+00, -3.3750e+00],
         [-7.5000e-01,  7.6172e-01, -5.3125e-01,  ...,  7.8516e-01,
           8.7500e+00, -2.4844e+00],
         ...,
         [ 4.1406e-01, -4.9609e-01, -2.8125e-01,  ..., -5.5859e-01,
           5.4375e+00, -3.8750e+00],
         [ 8.1250e-01, -3.7305e-01, -7.9688e-01,  ...,  1.6211e-01,
           6.7500e+00, -2.1719e+00],
         [ 6.2500e-01, -1.1797e+00, -1.3750e+00,  ..., -8.6719e-01,
           7.7500e+00, -1.3359e+00]],

        [[ 8.9722e-03,  1.1139e-03, -4.6997e-03,  ...,  2.6611e-02,
          -1.4531e+00,  2.0020e-01],
         [-3.1875e+00,  1.5625e-01, -1.3438e+00,  ..., -1.1797e+00,
           7.4688e+00, -3.3750e+00],
         [-7.5000e-01,  7.6172e-01, -5.3125e-01,  ...,  7.8516e-01,
           8.7500e+00, -2.4844e+00],
         ...,
         [ 4.1406e-01, -4.9609e-01, -2.8125e-01,  ..., -5.5859e-01,
           5.4375e+00, -3.8750e+00],
         [ 8.1250e-01, -3.7305e-01, -7.9688e-01,  ...,  1.6211e-01,
           6.7500e+00, -2.1719e+00],
         [ 6.2500e-01, -1.1797e+00, -1.3750e+00,  ..., -8.6719e-01,
           7.7500e+00, -1.3359e+00]],

        [[ 8.9722e-03,  1.1139e-03, -4.6997e-03,  ...,  2.6611e-02,
          -1.4531e+00,  2.0020e-01],
         [-3.1875e+00,  1.5625e-01, -1.3438e+00,  ..., -1.1797e+00,
           7.4688e+00, -3.3750e+00],
         [-7.5000e-01,  7.6172e-01, -5.3125e-01,  ...,  7.8516e-01,
           8.7500e+00, -2.4844e+00],
         ...,
         [ 4.1406e-01, -4.9609e-01, -2.8125e-01,  ..., -5.5859e-01,
           5.4375e+00, -3.8750e+00],
         [ 8.1250e-01, -3.7305e-01, -7.9688e-01,  ...,  1.6211e-01,
           6.7500e+00, -2.1719e+00],
         [ 6.2500e-01, -1.1797e+00, -1.3750e+00,  ..., -8.6719e-01,
           7.7500e+00, -1.3359e+00]],

        ...,

        [[-6.3782e-03, -8.7280e-03, -9.7046e-03,  ...,  8.0078e-02,
           8.1055e-02, -2.4707e-01],
         [ 3.1562e+00,  3.8672e-01,  2.2266e-01,  ..., -4.9805e-01,
          -1.1484e+00,  1.0645e-01],
         [ 9.3359e-01,  1.5391e+00,  1.1328e-01,  ...,  1.3203e+00,
           2.0938e+00,  7.9688e-01],
         ...,
         [-3.0469e-01,  6.2891e-01,  6.9922e-01,  ..., -2.6875e+00,
           2.8320e-01,  1.9922e+00],
         [-7.3047e-01, -3.1836e-01,  5.7422e-01,  ..., -1.9766e+00,
           2.2344e+00,  1.0078e+00],
         [-6.1719e-01, -1.3438e+00,  1.0703e+00,  ...,  4.0430e-01,
           1.1250e+00, -3.8086e-01]],

        [[-6.3782e-03, -8.7280e-03, -9.7046e-03,  ...,  8.0078e-02,
           8.1055e-02, -2.4707e-01],
         [ 3.1562e+00,  3.8672e-01,  2.2266e-01,  ..., -4.9805e-01,
          -1.1484e+00,  1.0645e-01],
         [ 9.3359e-01,  1.5391e+00,  1.1328e-01,  ...,  1.3203e+00,
           2.0938e+00,  7.9688e-01],
         ...,
         [-3.0469e-01,  6.2891e-01,  6.9922e-01,  ..., -2.6875e+00,
           2.8320e-01,  1.9922e+00],
         [-7.3047e-01, -3.1836e-01,  5.7422e-01,  ..., -1.9766e+00,
           2.2344e+00,  1.0078e+00],
         [-6.1719e-01, -1.3438e+00,  1.0703e+00,  ...,  4.0430e-01,
           1.1250e+00, -3.8086e-01]],

        [[-6.3782e-03, -8.7280e-03, -9.7046e-03,  ...,  8.0078e-02,
           8.1055e-02, -2.4707e-01],
         [ 3.1562e+00,  3.8672e-01,  2.2266e-01,  ..., -4.9805e-01,
          -1.1484e+00,  1.0645e-01],
         [ 9.3359e-01,  1.5391e+00,  1.1328e-01,  ...,  1.3203e+00,
           2.0938e+00,  7.9688e-01],
         ...,
         [-3.0469e-01,  6.2891e-01,  6.9922e-01,  ..., -2.6875e+00,
           2.8320e-01,  1.9922e+00],
         [-7.3047e-01, -3.1836e-01,  5.7422e-01,  ..., -1.9766e+00,
           2.2344e+00,  1.0078e+00],
         [-6.1719e-01, -1.3438e+00,  1.0703e+00,  ...,  4.0430e-01,
           1.1250e+00, -3.8086e-01]]], dtype=torch.bfloat16), tensor([[[ 7.6294e-03, -6.6528e-03, -5.7678e-03,  ...,  6.5613e-03,
          -5.0354e-03,  2.7313e-03],
         [ 9.2969e-01,  1.1641e+00, -3.2812e-01,  ...,  1.3281e+00,
           2.6367e-01,  2.0625e+00],
         [-4.9609e-01, -8.0078e-01,  4.6680e-01,  ...,  5.9766e-01,
          -2.7734e-01,  8.7500e-01],
         ...,
         [ 1.1562e+00, -5.1562e-01, -1.0234e+00,  ...,  4.6484e-01,
           1.3438e+00, -6.7578e-01],
         [ 5.3125e-01, -7.6953e-01,  1.4746e-01,  ...,  8.7891e-01,
           4.4531e-01, -3.4180e-01],
         [ 1.8281e+00, -4.3750e-01,  1.8828e+00,  ..., -1.3750e+00,
          -1.1406e+00,  1.0469e+00]],

        [[ 7.6294e-03, -6.6528e-03, -5.7678e-03,  ...,  6.5613e-03,
          -5.0354e-03,  2.7313e-03],
         [ 9.2969e-01,  1.1641e+00, -3.2812e-01,  ...,  1.3281e+00,
           2.6367e-01,  2.0625e+00],
         [-4.9609e-01, -8.0078e-01,  4.6680e-01,  ...,  5.9766e-01,
          -2.7734e-01,  8.7500e-01],
         ...,
         [ 1.1562e+00, -5.1562e-01, -1.0234e+00,  ...,  4.6484e-01,
           1.3438e+00, -6.7578e-01],
         [ 5.3125e-01, -7.6953e-01,  1.4746e-01,  ...,  8.7891e-01,
           4.4531e-01, -3.4180e-01],
         [ 1.8281e+00, -4.3750e-01,  1.8828e+00,  ..., -1.3750e+00,
          -1.1406e+00,  1.0469e+00]],

        [[ 7.6294e-03, -6.6528e-03, -5.7678e-03,  ...,  6.5613e-03,
          -5.0354e-03,  2.7313e-03],
         [ 9.2969e-01,  1.1641e+00, -3.2812e-01,  ...,  1.3281e+00,
           2.6367e-01,  2.0625e+00],
         [-4.9609e-01, -8.0078e-01,  4.6680e-01,  ...,  5.9766e-01,
          -2.7734e-01,  8.7500e-01],
         ...,
         [ 1.1562e+00, -5.1562e-01, -1.0234e+00,  ...,  4.6484e-01,
           1.3438e+00, -6.7578e-01],
         [ 5.3125e-01, -7.6953e-01,  1.4746e-01,  ...,  8.7891e-01,
           4.4531e-01, -3.4180e-01],
         [ 1.8281e+00, -4.3750e-01,  1.8828e+00,  ..., -1.3750e+00,
          -1.1406e+00,  1.0469e+00]],

        ...,

        [[-6.7234e-05,  2.8381e-03, -5.2795e-03,  ...,  2.2736e-03,
          -6.7444e-03,  6.7749e-03],
         [-7.3242e-02,  5.4688e-01,  8.6328e-01,  ...,  7.6953e-01,
           4.8633e-01,  3.1055e-01],
         [-6.3672e-01,  3.1445e-01,  2.5977e-01,  ..., -3.6865e-02,
           1.9165e-02, -8.3594e-01],
         ...,
         [-2.8320e-01,  8.5156e-01,  2.9688e-01,  ..., -9.0625e-01,
           6.4844e-01,  5.4688e-01],
         [-1.9238e-01, -7.3047e-01,  3.6133e-02,  ...,  6.4844e-01,
           1.8125e+00, -1.9336e-01],
         [ 1.1953e+00,  1.9141e-01,  1.0312e+00,  ..., -1.3580e-03,
           9.6094e-01,  1.2422e+00]],

        [[-6.7234e-05,  2.8381e-03, -5.2795e-03,  ...,  2.2736e-03,
          -6.7444e-03,  6.7749e-03],
         [-7.3242e-02,  5.4688e-01,  8.6328e-01,  ...,  7.6953e-01,
           4.8633e-01,  3.1055e-01],
         [-6.3672e-01,  3.1445e-01,  2.5977e-01,  ..., -3.6865e-02,
           1.9165e-02, -8.3594e-01],
         ...,
         [-2.8320e-01,  8.5156e-01,  2.9688e-01,  ..., -9.0625e-01,
           6.4844e-01,  5.4688e-01],
         [-1.9238e-01, -7.3047e-01,  3.6133e-02,  ...,  6.4844e-01,
           1.8125e+00, -1.9336e-01],
         [ 1.1953e+00,  1.9141e-01,  1.0312e+00,  ..., -1.3580e-03,
           9.6094e-01,  1.2422e+00]],

        [[-6.7234e-05,  2.8381e-03, -5.2795e-03,  ...,  2.2736e-03,
          -6.7444e-03,  6.7749e-03],
         [-7.3242e-02,  5.4688e-01,  8.6328e-01,  ...,  7.6953e-01,
           4.8633e-01,  3.1055e-01],
         [-6.3672e-01,  3.1445e-01,  2.5977e-01,  ..., -3.6865e-02,
           1.9165e-02, -8.3594e-01],
         ...,
         [-2.8320e-01,  8.5156e-01,  2.9688e-01,  ..., -9.0625e-01,
           6.4844e-01,  5.4688e-01],
         [-1.9238e-01, -7.3047e-01,  3.6133e-02,  ...,  6.4844e-01,
           1.8125e+00, -1.9336e-01],
         [ 1.1953e+00,  1.9141e-01,  1.0312e+00,  ..., -1.3580e-03,
           9.6094e-01,  1.2422e+00]]], dtype=torch.bfloat16)), (tensor([[[ 1.3245e-02, -6.1035e-04,  2.7618e-03,  ..., -1.0303e-01,
          -1.4941e-01,  1.9727e-01],
         [-1.1875e+00, -4.7070e-01,  2.4531e+00,  ...,  9.4141e-01,
           2.5195e-01,  1.8125e+00],
         [ 1.2451e-01,  2.1387e-01,  2.1973e-01,  ...,  1.1719e+00,
          -3.3398e-01, -2.8125e+00],
         ...,
         [-4.5508e-01,  7.6172e-01,  9.5703e-01,  ...,  5.8203e-01,
           1.1719e+00, -4.7266e-01],
         [-4.4922e-01,  3.0762e-02, -4.8828e-02,  ..., -1.3965e-01,
           7.4219e-02, -1.6406e+00],
         [ 5.5859e-01, -1.0625e+00,  5.0391e-01,  ...,  1.5156e+00,
           5.2344e-01,  6.2109e-01]],

        [[ 1.3245e-02, -6.1035e-04,  2.7618e-03,  ..., -1.0303e-01,
          -1.4941e-01,  1.9727e-01],
         [-1.1875e+00, -4.7070e-01,  2.4531e+00,  ...,  9.4141e-01,
           2.5195e-01,  1.8125e+00],
         [ 1.2451e-01,  2.1387e-01,  2.1973e-01,  ...,  1.1719e+00,
          -3.3398e-01, -2.8125e+00],
         ...,
         [-4.5508e-01,  7.6172e-01,  9.5703e-01,  ...,  5.8203e-01,
           1.1719e+00, -4.7266e-01],
         [-4.4922e-01,  3.0762e-02, -4.8828e-02,  ..., -1.3965e-01,
           7.4219e-02, -1.6406e+00],
         [ 5.5859e-01, -1.0625e+00,  5.0391e-01,  ...,  1.5156e+00,
           5.2344e-01,  6.2109e-01]],

        [[ 1.3245e-02, -6.1035e-04,  2.7618e-03,  ..., -1.0303e-01,
          -1.4941e-01,  1.9727e-01],
         [-1.1875e+00, -4.7070e-01,  2.4531e+00,  ...,  9.4141e-01,
           2.5195e-01,  1.8125e+00],
         [ 1.2451e-01,  2.1387e-01,  2.1973e-01,  ...,  1.1719e+00,
          -3.3398e-01, -2.8125e+00],
         ...,
         [-4.5508e-01,  7.6172e-01,  9.5703e-01,  ...,  5.8203e-01,
           1.1719e+00, -4.7266e-01],
         [-4.4922e-01,  3.0762e-02, -4.8828e-02,  ..., -1.3965e-01,
           7.4219e-02, -1.6406e+00],
         [ 5.5859e-01, -1.0625e+00,  5.0391e-01,  ...,  1.5156e+00,
           5.2344e-01,  6.2109e-01]],

        ...,

        [[ 7.3242e-03,  2.6398e-03,  3.4027e-03,  ..., -1.0352e-01,
           8.5938e-01, -3.6914e-01],
         [-1.3281e-01,  6.6016e-01,  1.0059e-01,  ...,  2.9531e+00,
          -3.9062e+00, -5.7031e-01],
         [-2.3828e-01,  6.2012e-02, -2.3535e-01,  ...,  1.9609e+00,
          -5.2500e+00,  1.8516e+00],
         ...,
         [ 6.1768e-02,  1.2891e-01,  1.5820e-01,  ..., -1.9609e+00,
           7.5684e-02,  1.1250e+00],
         [ 3.3203e-01,  5.9814e-02,  2.6172e-01,  ..., -7.5781e-01,
          -2.4844e+00,  1.6328e+00],
         [ 8.4961e-02, -4.3555e-01,  1.2109e-01,  ...,  1.8281e+00,
          -3.2812e+00,  4.8828e-01]],

        [[ 7.3242e-03,  2.6398e-03,  3.4027e-03,  ..., -1.0352e-01,
           8.5938e-01, -3.6914e-01],
         [-1.3281e-01,  6.6016e-01,  1.0059e-01,  ...,  2.9531e+00,
          -3.9062e+00, -5.7031e-01],
         [-2.3828e-01,  6.2012e-02, -2.3535e-01,  ...,  1.9609e+00,
          -5.2500e+00,  1.8516e+00],
         ...,
         [ 6.1768e-02,  1.2891e-01,  1.5820e-01,  ..., -1.9609e+00,
           7.5684e-02,  1.1250e+00],
         [ 3.3203e-01,  5.9814e-02,  2.6172e-01,  ..., -7.5781e-01,
          -2.4844e+00,  1.6328e+00],
         [ 8.4961e-02, -4.3555e-01,  1.2109e-01,  ...,  1.8281e+00,
          -3.2812e+00,  4.8828e-01]],

        [[ 7.3242e-03,  2.6398e-03,  3.4027e-03,  ..., -1.0352e-01,
           8.5938e-01, -3.6914e-01],
         [-1.3281e-01,  6.6016e-01,  1.0059e-01,  ...,  2.9531e+00,
          -3.9062e+00, -5.7031e-01],
         [-2.3828e-01,  6.2012e-02, -2.3535e-01,  ...,  1.9609e+00,
          -5.2500e+00,  1.8516e+00],
         ...,
         [ 6.1768e-02,  1.2891e-01,  1.5820e-01,  ..., -1.9609e+00,
           7.5684e-02,  1.1250e+00],
         [ 3.3203e-01,  5.9814e-02,  2.6172e-01,  ..., -7.5781e-01,
          -2.4844e+00,  1.6328e+00],
         [ 8.4961e-02, -4.3555e-01,  1.2109e-01,  ...,  1.8281e+00,
          -3.2812e+00,  4.8828e-01]]], dtype=torch.bfloat16), tensor([[[-7.5378e-03,  4.1504e-03,  5.4932e-03,  ..., -3.3569e-03,
           1.2024e-02,  3.7384e-03],
         [-1.2656e+00, -1.1250e+00, -8.9453e-01,  ...,  3.3008e-01,
          -6.0156e-01, -7.5391e-01],
         [-2.0156e+00, -9.3750e-01, -2.6245e-03,  ..., -3.0000e+00,
           9.2578e-01, -1.0234e+00],
         ...,
         [ 2.3047e-01, -1.2656e+00, -8.7109e-01,  ..., -3.1250e+00,
           1.0312e+00,  6.4453e-01],
         [ 4.9414e-01, -7.6562e-01, -1.8828e+00,  ..., -5.6562e+00,
           2.2344e+00, -1.2422e+00],
         [-5.0000e-01, -5.0391e-01, -1.3438e+00,  ..., -2.5938e+00,
          -1.1016e+00, -2.6367e-01]],

        [[-7.5378e-03,  4.1504e-03,  5.4932e-03,  ..., -3.3569e-03,
           1.2024e-02,  3.7384e-03],
         [-1.2656e+00, -1.1250e+00, -8.9453e-01,  ...,  3.3008e-01,
          -6.0156e-01, -7.5391e-01],
         [-2.0156e+00, -9.3750e-01, -2.6245e-03,  ..., -3.0000e+00,
           9.2578e-01, -1.0234e+00],
         ...,
         [ 2.3047e-01, -1.2656e+00, -8.7109e-01,  ..., -3.1250e+00,
           1.0312e+00,  6.4453e-01],
         [ 4.9414e-01, -7.6562e-01, -1.8828e+00,  ..., -5.6562e+00,
           2.2344e+00, -1.2422e+00],
         [-5.0000e-01, -5.0391e-01, -1.3438e+00,  ..., -2.5938e+00,
          -1.1016e+00, -2.6367e-01]],

        [[-7.5378e-03,  4.1504e-03,  5.4932e-03,  ..., -3.3569e-03,
           1.2024e-02,  3.7384e-03],
         [-1.2656e+00, -1.1250e+00, -8.9453e-01,  ...,  3.3008e-01,
          -6.0156e-01, -7.5391e-01],
         [-2.0156e+00, -9.3750e-01, -2.6245e-03,  ..., -3.0000e+00,
           9.2578e-01, -1.0234e+00],
         ...,
         [ 2.3047e-01, -1.2656e+00, -8.7109e-01,  ..., -3.1250e+00,
           1.0312e+00,  6.4453e-01],
         [ 4.9414e-01, -7.6562e-01, -1.8828e+00,  ..., -5.6562e+00,
           2.2344e+00, -1.2422e+00],
         [-5.0000e-01, -5.0391e-01, -1.3438e+00,  ..., -2.5938e+00,
          -1.1016e+00, -2.6367e-01]],

        ...,

        [[ 7.0190e-03, -1.6556e-03, -3.8605e-03,  ..., -3.7384e-03,
           2.1973e-03,  1.5564e-03],
         [-8.9844e-01, -1.7422e+00,  4.0312e+00,  ..., -1.5469e+00,
           9.5312e-01,  2.0625e+00],
         [ 1.4954e-03, -8.6719e-01,  2.4375e+00,  ..., -1.3672e+00,
          -6.4062e-01, -1.4453e-01],
         ...,
         [-3.0664e-01,  1.4062e+00, -1.9062e+00,  ..., -3.3447e-02,
           3.8867e-01,  1.0391e+00],
         [ 7.2754e-02,  7.9688e-01, -5.7031e-01,  ...,  2.6367e-01,
           2.5977e-01,  6.9531e-01],
         [-1.5078e+00, -7.9297e-01,  9.1406e-01,  ..., -1.8516e+00,
           8.2812e-01,  4.5703e-01]],

        [[ 7.0190e-03, -1.6556e-03, -3.8605e-03,  ..., -3.7384e-03,
           2.1973e-03,  1.5564e-03],
         [-8.9844e-01, -1.7422e+00,  4.0312e+00,  ..., -1.5469e+00,
           9.5312e-01,  2.0625e+00],
         [ 1.4954e-03, -8.6719e-01,  2.4375e+00,  ..., -1.3672e+00,
          -6.4062e-01, -1.4453e-01],
         ...,
         [-3.0664e-01,  1.4062e+00, -1.9062e+00,  ..., -3.3447e-02,
           3.8867e-01,  1.0391e+00],
         [ 7.2754e-02,  7.9688e-01, -5.7031e-01,  ...,  2.6367e-01,
           2.5977e-01,  6.9531e-01],
         [-1.5078e+00, -7.9297e-01,  9.1406e-01,  ..., -1.8516e+00,
           8.2812e-01,  4.5703e-01]],

        [[ 7.0190e-03, -1.6556e-03, -3.8605e-03,  ..., -3.7384e-03,
           2.1973e-03,  1.5564e-03],
         [-8.9844e-01, -1.7422e+00,  4.0312e+00,  ..., -1.5469e+00,
           9.5312e-01,  2.0625e+00],
         [ 1.4954e-03, -8.6719e-01,  2.4375e+00,  ..., -1.3672e+00,
          -6.4062e-01, -1.4453e-01],
         ...,
         [-3.0664e-01,  1.4062e+00, -1.9062e+00,  ..., -3.3447e-02,
           3.8867e-01,  1.0391e+00],
         [ 7.2754e-02,  7.9688e-01, -5.7031e-01,  ...,  2.6367e-01,
           2.5977e-01,  6.9531e-01],
         [-1.5078e+00, -7.9297e-01,  9.1406e-01,  ..., -1.8516e+00,
           8.2812e-01,  4.5703e-01]]], dtype=torch.bfloat16)), (tensor([[[ 8.5449e-03,  1.2741e-03, -8.4229e-03,  ..., -1.3379e-01,
          -6.2866e-03,  7.2327e-03],
         [-1.0781e+00, -6.3281e-01, -2.5391e-01,  ...,  5.6641e-01,
          -1.0625e+00,  2.5156e+00],
         [ 1.4258e-01,  2.4609e-01, -9.9219e-01,  ...,  1.7344e+00,
           2.8125e-01,  3.3906e+00],
         ...,
         [ 5.6641e-01,  5.5859e-01, -9.5703e-02,  ..., -2.7812e+00,
           1.7822e-02,  4.1992e-01],
         [ 7.6562e-01,  8.0078e-01, -7.8906e-01,  ...,  1.9844e+00,
          -1.1016e+00,  1.7109e+00],
         [-4.3750e-01, -1.1094e+00, -1.1250e+00,  ..., -8.3984e-02,
          -2.2656e-01,  3.4844e+00]],

        [[ 8.5449e-03,  1.2741e-03, -8.4229e-03,  ..., -1.3379e-01,
          -6.2866e-03,  7.2327e-03],
         [-1.0781e+00, -6.3281e-01, -2.5391e-01,  ...,  5.6641e-01,
          -1.0625e+00,  2.5156e+00],
         [ 1.4258e-01,  2.4609e-01, -9.9219e-01,  ...,  1.7344e+00,
           2.8125e-01,  3.3906e+00],
         ...,
         [ 5.6641e-01,  5.5859e-01, -9.5703e-02,  ..., -2.7812e+00,
           1.7822e-02,  4.1992e-01],
         [ 7.6562e-01,  8.0078e-01, -7.8906e-01,  ...,  1.9844e+00,
          -1.1016e+00,  1.7109e+00],
         [-4.3750e-01, -1.1094e+00, -1.1250e+00,  ..., -8.3984e-02,
          -2.2656e-01,  3.4844e+00]],

        [[ 8.5449e-03,  1.2741e-03, -8.4229e-03,  ..., -1.3379e-01,
          -6.2866e-03,  7.2327e-03],
         [-1.0781e+00, -6.3281e-01, -2.5391e-01,  ...,  5.6641e-01,
          -1.0625e+00,  2.5156e+00],
         [ 1.4258e-01,  2.4609e-01, -9.9219e-01,  ...,  1.7344e+00,
           2.8125e-01,  3.3906e+00],
         ...,
         [ 5.6641e-01,  5.5859e-01, -9.5703e-02,  ..., -2.7812e+00,
           1.7822e-02,  4.1992e-01],
         [ 7.6562e-01,  8.0078e-01, -7.8906e-01,  ...,  1.9844e+00,
          -1.1016e+00,  1.7109e+00],
         [-4.3750e-01, -1.1094e+00, -1.1250e+00,  ..., -8.3984e-02,
          -2.2656e-01,  3.4844e+00]],

        ...,

        [[ 4.4861e-03, -7.5073e-03, -7.8735e-03,  ..., -7.3242e-02,
          -5.5908e-02,  5.2490e-02],
         [ 1.3125e+00, -1.4844e-01,  1.0781e+00,  ..., -9.2969e-01,
           2.9883e-01,  5.6641e-01],
         [-1.1484e+00,  1.1875e+00,  1.0840e-01,  ...,  8.4766e-01,
           3.3789e-01,  1.4219e+00],
         ...,
         [-4.1602e-01, -2.0703e-01, -6.2891e-01,  ..., -3.0312e+00,
          -1.6250e+00, -3.7656e+00],
         [-6.7188e-01, -6.3672e-01,  2.2461e-02,  ..., -3.0625e+00,
          -2.2031e+00, -2.2500e+00],
         [ 9.5312e-01, -1.6953e+00,  1.5938e+00,  ..., -4.2188e-01,
           5.0781e-01,  3.1406e+00]],

        [[ 4.4861e-03, -7.5073e-03, -7.8735e-03,  ..., -7.3242e-02,
          -5.5908e-02,  5.2490e-02],
         [ 1.3125e+00, -1.4844e-01,  1.0781e+00,  ..., -9.2969e-01,
           2.9883e-01,  5.6641e-01],
         [-1.1484e+00,  1.1875e+00,  1.0840e-01,  ...,  8.4766e-01,
           3.3789e-01,  1.4219e+00],
         ...,
         [-4.1602e-01, -2.0703e-01, -6.2891e-01,  ..., -3.0312e+00,
          -1.6250e+00, -3.7656e+00],
         [-6.7188e-01, -6.3672e-01,  2.2461e-02,  ..., -3.0625e+00,
          -2.2031e+00, -2.2500e+00],
         [ 9.5312e-01, -1.6953e+00,  1.5938e+00,  ..., -4.2188e-01,
           5.0781e-01,  3.1406e+00]],

        [[ 4.4861e-03, -7.5073e-03, -7.8735e-03,  ..., -7.3242e-02,
          -5.5908e-02,  5.2490e-02],
         [ 1.3125e+00, -1.4844e-01,  1.0781e+00,  ..., -9.2969e-01,
           2.9883e-01,  5.6641e-01],
         [-1.1484e+00,  1.1875e+00,  1.0840e-01,  ...,  8.4766e-01,
           3.3789e-01,  1.4219e+00],
         ...,
         [-4.1602e-01, -2.0703e-01, -6.2891e-01,  ..., -3.0312e+00,
          -1.6250e+00, -3.7656e+00],
         [-6.7188e-01, -6.3672e-01,  2.2461e-02,  ..., -3.0625e+00,
          -2.2031e+00, -2.2500e+00],
         [ 9.5312e-01, -1.6953e+00,  1.5938e+00,  ..., -4.2188e-01,
           5.0781e-01,  3.1406e+00]]], dtype=torch.bfloat16), tensor([[[ 5.7068e-03,  9.1171e-04, -8.8501e-03,  ..., -4.3869e-04,
          -9.8877e-03, -9.2316e-04],
         [-5.3906e-01, -3.3008e-01,  1.1035e-01,  ...,  1.0791e-01,
          -5.3125e-01, -3.2422e-01],
         [-1.3245e-02, -2.9492e-01, -7.5000e-01,  ..., -4.2773e-01,
          -9.1016e-01,  1.9824e-01],
         ...,
         [ 9.8438e-01,  2.3535e-01, -9.6875e-01,  ...,  8.7891e-02,
          -1.1562e+00, -1.0547e+00],
         [-3.1738e-02,  4.5508e-01, -1.8555e-01,  ..., -3.3936e-02,
          -9.8047e-01,  7.0703e-01],
         [-5.8203e-01,  2.0469e+00, -3.2617e-01,  ..., -6.2891e-01,
           3.3594e-01,  2.2363e-01]],

        [[ 5.7068e-03,  9.1171e-04, -8.8501e-03,  ..., -4.3869e-04,
          -9.8877e-03, -9.2316e-04],
         [-5.3906e-01, -3.3008e-01,  1.1035e-01,  ...,  1.0791e-01,
          -5.3125e-01, -3.2422e-01],
         [-1.3245e-02, -2.9492e-01, -7.5000e-01,  ..., -4.2773e-01,
          -9.1016e-01,  1.9824e-01],
         ...,
         [ 9.8438e-01,  2.3535e-01, -9.6875e-01,  ...,  8.7891e-02,
          -1.1562e+00, -1.0547e+00],
         [-3.1738e-02,  4.5508e-01, -1.8555e-01,  ..., -3.3936e-02,
          -9.8047e-01,  7.0703e-01],
         [-5.8203e-01,  2.0469e+00, -3.2617e-01,  ..., -6.2891e-01,
           3.3594e-01,  2.2363e-01]],

        [[ 5.7068e-03,  9.1171e-04, -8.8501e-03,  ..., -4.3869e-04,
          -9.8877e-03, -9.2316e-04],
         [-5.3906e-01, -3.3008e-01,  1.1035e-01,  ...,  1.0791e-01,
          -5.3125e-01, -3.2422e-01],
         [-1.3245e-02, -2.9492e-01, -7.5000e-01,  ..., -4.2773e-01,
          -9.1016e-01,  1.9824e-01],
         ...,
         [ 9.8438e-01,  2.3535e-01, -9.6875e-01,  ...,  8.7891e-02,
          -1.1562e+00, -1.0547e+00],
         [-3.1738e-02,  4.5508e-01, -1.8555e-01,  ..., -3.3936e-02,
          -9.8047e-01,  7.0703e-01],
         [-5.8203e-01,  2.0469e+00, -3.2617e-01,  ..., -6.2891e-01,
           3.3594e-01,  2.2363e-01]],

        ...,

        [[ 3.2227e-02,  2.3682e-02,  8.5449e-02,  ...,  9.9487e-03,
          -4.2236e-02,  3.6377e-02],
         [-3.1836e-01, -1.2695e-01,  9.5312e-01,  ..., -5.3516e-01,
           2.8711e-01, -3.7305e-01],
         [ 2.5000e-01, -4.2725e-02,  1.6797e-01,  ...,  1.0254e-01,
          -1.0078e+00,  5.4199e-02],
         ...,
         [-1.2500e-01, -1.3125e+00,  1.9688e+00,  ...,  5.1953e-01,
           2.6562e-01, -6.4453e-01],
         [-5.0781e-01, -8.6719e-01,  4.6289e-01,  ..., -1.4551e-01,
           3.2031e-01, -1.0859e+00],
         [ 5.2734e-01, -6.6797e-01,  1.1523e-01,  ...,  1.0703e+00,
           1.4141e+00,  6.3672e-01]],

        [[ 3.2227e-02,  2.3682e-02,  8.5449e-02,  ...,  9.9487e-03,
          -4.2236e-02,  3.6377e-02],
         [-3.1836e-01, -1.2695e-01,  9.5312e-01,  ..., -5.3516e-01,
           2.8711e-01, -3.7305e-01],
         [ 2.5000e-01, -4.2725e-02,  1.6797e-01,  ...,  1.0254e-01,
          -1.0078e+00,  5.4199e-02],
         ...,
         [-1.2500e-01, -1.3125e+00,  1.9688e+00,  ...,  5.1953e-01,
           2.6562e-01, -6.4453e-01],
         [-5.0781e-01, -8.6719e-01,  4.6289e-01,  ..., -1.4551e-01,
           3.2031e-01, -1.0859e+00],
         [ 5.2734e-01, -6.6797e-01,  1.1523e-01,  ...,  1.0703e+00,
           1.4141e+00,  6.3672e-01]],

        [[ 3.2227e-02,  2.3682e-02,  8.5449e-02,  ...,  9.9487e-03,
          -4.2236e-02,  3.6377e-02],
         [-3.1836e-01, -1.2695e-01,  9.5312e-01,  ..., -5.3516e-01,
           2.8711e-01, -3.7305e-01],
         [ 2.5000e-01, -4.2725e-02,  1.6797e-01,  ...,  1.0254e-01,
          -1.0078e+00,  5.4199e-02],
         ...,
         [-1.2500e-01, -1.3125e+00,  1.9688e+00,  ...,  5.1953e-01,
           2.6562e-01, -6.4453e-01],
         [-5.0781e-01, -8.6719e-01,  4.6289e-01,  ..., -1.4551e-01,
           3.2031e-01, -1.0859e+00],
         [ 5.2734e-01, -6.6797e-01,  1.1523e-01,  ...,  1.0703e+00,
           1.4141e+00,  6.3672e-01]]], dtype=torch.bfloat16)), (tensor([[[ 1.4893e-02,  5.1270e-03,  1.1780e-02,  ...,  1.9375e+00,
          -9.0820e-02,  3.8086e-02],
         [ 2.3438e-02,  2.0508e-01,  1.1572e-01,  ..., -6.1250e+00,
          -3.9648e-01,  3.4766e-01],
         [ 6.9336e-02, -4.5508e-01, -1.8555e-02,  ..., -7.1875e+00,
           5.3516e-01,  2.1406e+00],
         ...,
         [-3.8867e-01, -1.0156e-01, -4.8828e-03,  ..., -4.1875e+00,
           1.5391e+00, -4.2500e+00],
         [-2.5977e-01,  1.8164e-01,  3.3594e-01,  ..., -5.7500e+00,
           1.2188e+00, -6.2891e-01],
         [-4.0430e-01,  2.1289e-01,  2.4805e-01,  ..., -6.2812e+00,
          -3.4570e-01,  1.2061e-01]],

        [[ 1.4893e-02,  5.1270e-03,  1.1780e-02,  ...,  1.9375e+00,
          -9.0820e-02,  3.8086e-02],
         [ 2.3438e-02,  2.0508e-01,  1.1572e-01,  ..., -6.1250e+00,
          -3.9648e-01,  3.4766e-01],
         [ 6.9336e-02, -4.5508e-01, -1.8555e-02,  ..., -7.1875e+00,
           5.3516e-01,  2.1406e+00],
         ...,
         [-3.8867e-01, -1.0156e-01, -4.8828e-03,  ..., -4.1875e+00,
           1.5391e+00, -4.2500e+00],
         [-2.5977e-01,  1.8164e-01,  3.3594e-01,  ..., -5.7500e+00,
           1.2188e+00, -6.2891e-01],
         [-4.0430e-01,  2.1289e-01,  2.4805e-01,  ..., -6.2812e+00,
          -3.4570e-01,  1.2061e-01]],

        [[ 1.4893e-02,  5.1270e-03,  1.1780e-02,  ...,  1.9375e+00,
          -9.0820e-02,  3.8086e-02],
         [ 2.3438e-02,  2.0508e-01,  1.1572e-01,  ..., -6.1250e+00,
          -3.9648e-01,  3.4766e-01],
         [ 6.9336e-02, -4.5508e-01, -1.8555e-02,  ..., -7.1875e+00,
           5.3516e-01,  2.1406e+00],
         ...,
         [-3.8867e-01, -1.0156e-01, -4.8828e-03,  ..., -4.1875e+00,
           1.5391e+00, -4.2500e+00],
         [-2.5977e-01,  1.8164e-01,  3.3594e-01,  ..., -5.7500e+00,
           1.2188e+00, -6.2891e-01],
         [-4.0430e-01,  2.1289e-01,  2.4805e-01,  ..., -6.2812e+00,
          -3.4570e-01,  1.2061e-01]],

        ...,

        [[ 2.5787e-03,  3.2501e-03,  1.0620e-02,  ...,  6.8359e-02,
           9.8047e-01, -1.3281e-01],
         [-3.9062e-03,  4.6484e-01, -1.1250e+00,  ...,  4.6631e-02,
          -1.1953e+00, -3.4766e-01],
         [-4.2188e-01, -9.4531e-01, -8.1641e-01,  ...,  1.5332e-01,
          -4.5938e+00,  2.9688e-01],
         ...,
         [-1.1641e+00,  8.3594e-01, -1.0254e-01,  ...,  2.6489e-02,
          -2.1094e+00, -1.4062e-01],
         [-3.0469e-01,  2.9297e-01, -2.6562e-01,  ...,  2.4219e-01,
          -2.7812e+00, -1.8457e-01],
         [ 1.5938e+00,  1.6094e+00, -1.6641e+00,  ..., -5.5469e-01,
          -2.8594e+00,  5.2344e-01]],

        [[ 2.5787e-03,  3.2501e-03,  1.0620e-02,  ...,  6.8359e-02,
           9.8047e-01, -1.3281e-01],
         [-3.9062e-03,  4.6484e-01, -1.1250e+00,  ...,  4.6631e-02,
          -1.1953e+00, -3.4766e-01],
         [-4.2188e-01, -9.4531e-01, -8.1641e-01,  ...,  1.5332e-01,
          -4.5938e+00,  2.9688e-01],
         ...,
         [-1.1641e+00,  8.3594e-01, -1.0254e-01,  ...,  2.6489e-02,
          -2.1094e+00, -1.4062e-01],
         [-3.0469e-01,  2.9297e-01, -2.6562e-01,  ...,  2.4219e-01,
          -2.7812e+00, -1.8457e-01],
         [ 1.5938e+00,  1.6094e+00, -1.6641e+00,  ..., -5.5469e-01,
          -2.8594e+00,  5.2344e-01]],

        [[ 2.5787e-03,  3.2501e-03,  1.0620e-02,  ...,  6.8359e-02,
           9.8047e-01, -1.3281e-01],
         [-3.9062e-03,  4.6484e-01, -1.1250e+00,  ...,  4.6631e-02,
          -1.1953e+00, -3.4766e-01],
         [-4.2188e-01, -9.4531e-01, -8.1641e-01,  ...,  1.5332e-01,
          -4.5938e+00,  2.9688e-01],
         ...,
         [-1.1641e+00,  8.3594e-01, -1.0254e-01,  ...,  2.6489e-02,
          -2.1094e+00, -1.4062e-01],
         [-3.0469e-01,  2.9297e-01, -2.6562e-01,  ...,  2.4219e-01,
          -2.7812e+00, -1.8457e-01],
         [ 1.5938e+00,  1.6094e+00, -1.6641e+00,  ..., -5.5469e-01,
          -2.8594e+00,  5.2344e-01]]], dtype=torch.bfloat16), tensor([[[-0.0053,  0.0099, -0.0089,  ...,  0.0029, -0.0044, -0.0042],
         [ 2.7031, -0.5781,  0.9375,  ...,  2.6094, -1.6562, -0.2637],
         [ 0.5664, -1.2891,  0.1299,  ...,  0.7930, -0.0786, -1.3047],
         ...,
         [-0.9766, -1.0312, -0.3301,  ...,  0.1187,  0.3691, -0.3945],
         [-1.4297,  0.5469, -0.8711,  ..., -0.2305, -0.8359, -0.7109],
         [ 1.0469,  0.2207, -0.1836,  ...,  0.0776,  0.9375,  0.6406]],

        [[-0.0053,  0.0099, -0.0089,  ...,  0.0029, -0.0044, -0.0042],
         [ 2.7031, -0.5781,  0.9375,  ...,  2.6094, -1.6562, -0.2637],
         [ 0.5664, -1.2891,  0.1299,  ...,  0.7930, -0.0786, -1.3047],
         ...,
         [-0.9766, -1.0312, -0.3301,  ...,  0.1187,  0.3691, -0.3945],
         [-1.4297,  0.5469, -0.8711,  ..., -0.2305, -0.8359, -0.7109],
         [ 1.0469,  0.2207, -0.1836,  ...,  0.0776,  0.9375,  0.6406]],

        [[-0.0053,  0.0099, -0.0089,  ...,  0.0029, -0.0044, -0.0042],
         [ 2.7031, -0.5781,  0.9375,  ...,  2.6094, -1.6562, -0.2637],
         [ 0.5664, -1.2891,  0.1299,  ...,  0.7930, -0.0786, -1.3047],
         ...,
         [-0.9766, -1.0312, -0.3301,  ...,  0.1187,  0.3691, -0.3945],
         [-1.4297,  0.5469, -0.8711,  ..., -0.2305, -0.8359, -0.7109],
         [ 1.0469,  0.2207, -0.1836,  ...,  0.0776,  0.9375,  0.6406]],

        ...,

        [[-0.0347,  0.0099,  0.0165,  ..., -0.0503, -0.0281,  0.0081],
         [ 0.5117,  0.2988,  0.5977,  ...,  0.2715,  0.8438, -0.1855],
         [ 1.2891,  0.1729,  0.1484,  ...,  0.2363,  0.6758, -0.1934],
         ...,
         [ 1.0312,  0.2168,  0.1875,  ...,  0.7891, -0.2988, -0.1855],
         [ 1.6641,  0.8945,  0.0201,  ..., -0.1406, -0.2793, -0.1436],
         [ 0.1865, -0.0791, -0.4160,  ..., -0.0747, -0.1514, -0.4941]],

        [[-0.0347,  0.0099,  0.0165,  ..., -0.0503, -0.0281,  0.0081],
         [ 0.5117,  0.2988,  0.5977,  ...,  0.2715,  0.8438, -0.1855],
         [ 1.2891,  0.1729,  0.1484,  ...,  0.2363,  0.6758, -0.1934],
         ...,
         [ 1.0312,  0.2168,  0.1875,  ...,  0.7891, -0.2988, -0.1855],
         [ 1.6641,  0.8945,  0.0201,  ..., -0.1406, -0.2793, -0.1436],
         [ 0.1865, -0.0791, -0.4160,  ..., -0.0747, -0.1514, -0.4941]],

        [[-0.0347,  0.0099,  0.0165,  ..., -0.0503, -0.0281,  0.0081],
         [ 0.5117,  0.2988,  0.5977,  ...,  0.2715,  0.8438, -0.1855],
         [ 1.2891,  0.1729,  0.1484,  ...,  0.2363,  0.6758, -0.1934],
         ...,
         [ 1.0312,  0.2168,  0.1875,  ...,  0.7891, -0.2988, -0.1855],
         [ 1.6641,  0.8945,  0.0201,  ..., -0.1406, -0.2793, -0.1436],
         [ 0.1865, -0.0791, -0.4160,  ..., -0.0747, -0.1514, -0.4941]]],
       dtype=torch.bfloat16)), (tensor([[[ 1.6113e-02,  5.1498e-04, -1.2268e-02,  ...,  2.3047e-01,
           6.2500e-02,  7.8906e-01],
         [ 1.0156e-01, -8.4766e-01,  2.4219e-01,  ...,  5.6641e-01,
          -2.0781e+00, -3.1562e+00],
         [ 1.3672e-01,  1.2793e-01,  9.2188e-01,  ..., -3.5938e-01,
          -1.2109e+00, -4.1875e+00],
         ...,
         [-3.0469e-01,  2.6367e-02, -4.1504e-02,  ..., -5.5859e-01,
          -4.3750e-01,  1.1572e-01],
         [-6.3672e-01,  7.0312e-01,  1.0938e+00,  ..., -6.6797e-01,
           6.6016e-01, -1.3984e+00],
         [-2.0508e-02,  5.9570e-02,  5.3906e-01,  ..., -8.5938e-01,
           1.6172e+00, -2.9375e+00]],

        [[ 1.6113e-02,  5.1498e-04, -1.2268e-02,  ...,  2.3047e-01,
           6.2500e-02,  7.8906e-01],
         [ 1.0156e-01, -8.4766e-01,  2.4219e-01,  ...,  5.6641e-01,
          -2.0781e+00, -3.1562e+00],
         [ 1.3672e-01,  1.2793e-01,  9.2188e-01,  ..., -3.5938e-01,
          -1.2109e+00, -4.1875e+00],
         ...,
         [-3.0469e-01,  2.6367e-02, -4.1504e-02,  ..., -5.5859e-01,
          -4.3750e-01,  1.1572e-01],
         [-6.3672e-01,  7.0312e-01,  1.0938e+00,  ..., -6.6797e-01,
           6.6016e-01, -1.3984e+00],
         [-2.0508e-02,  5.9570e-02,  5.3906e-01,  ..., -8.5938e-01,
           1.6172e+00, -2.9375e+00]],

        [[ 1.6113e-02,  5.1498e-04, -1.2268e-02,  ...,  2.3047e-01,
           6.2500e-02,  7.8906e-01],
         [ 1.0156e-01, -8.4766e-01,  2.4219e-01,  ...,  5.6641e-01,
          -2.0781e+00, -3.1562e+00],
         [ 1.3672e-01,  1.2793e-01,  9.2188e-01,  ..., -3.5938e-01,
          -1.2109e+00, -4.1875e+00],
         ...,
         [-3.0469e-01,  2.6367e-02, -4.1504e-02,  ..., -5.5859e-01,
          -4.3750e-01,  1.1572e-01],
         [-6.3672e-01,  7.0312e-01,  1.0938e+00,  ..., -6.6797e-01,
           6.6016e-01, -1.3984e+00],
         [-2.0508e-02,  5.9570e-02,  5.3906e-01,  ..., -8.5938e-01,
           1.6172e+00, -2.9375e+00]],

        ...,

        [[-6.3171e-03,  1.1536e-02, -4.2152e-04,  ..., -9.0820e-02,
          -2.0605e-01,  7.3730e-02],
         [ 5.0781e-02,  6.1523e-02,  7.6172e-02,  ...,  2.1562e+00,
           1.2988e-01, -1.7812e+00],
         [-2.1484e-01,  3.0273e-01,  4.9219e-01,  ...,  3.0938e+00,
          -1.2266e+00, -3.8594e+00],
         ...,
         [ 3.1250e-01,  2.9688e-01, -1.5820e-01,  ..., -1.8906e+00,
          -2.1719e+00,  3.1250e-01],
         [ 4.3359e-01, -3.0078e-01,  2.0508e-02,  ..., -3.2471e-02,
           3.0078e-01, -1.3184e-01],
         [-4.3555e-01, -4.3945e-02,  2.0508e-01,  ...,  2.4531e+00,
           3.5156e-01,  9.6875e-01]],

        [[-6.3171e-03,  1.1536e-02, -4.2152e-04,  ..., -9.0820e-02,
          -2.0605e-01,  7.3730e-02],
         [ 5.0781e-02,  6.1523e-02,  7.6172e-02,  ...,  2.1562e+00,
           1.2988e-01, -1.7812e+00],
         [-2.1484e-01,  3.0273e-01,  4.9219e-01,  ...,  3.0938e+00,
          -1.2266e+00, -3.8594e+00],
         ...,
         [ 3.1250e-01,  2.9688e-01, -1.5820e-01,  ..., -1.8906e+00,
          -2.1719e+00,  3.1250e-01],
         [ 4.3359e-01, -3.0078e-01,  2.0508e-02,  ..., -3.2471e-02,
           3.0078e-01, -1.3184e-01],
         [-4.3555e-01, -4.3945e-02,  2.0508e-01,  ...,  2.4531e+00,
           3.5156e-01,  9.6875e-01]],

        [[-6.3171e-03,  1.1536e-02, -4.2152e-04,  ..., -9.0820e-02,
          -2.0605e-01,  7.3730e-02],
         [ 5.0781e-02,  6.1523e-02,  7.6172e-02,  ...,  2.1562e+00,
           1.2988e-01, -1.7812e+00],
         [-2.1484e-01,  3.0273e-01,  4.9219e-01,  ...,  3.0938e+00,
          -1.2266e+00, -3.8594e+00],
         ...,
         [ 3.1250e-01,  2.9688e-01, -1.5820e-01,  ..., -1.8906e+00,
          -2.1719e+00,  3.1250e-01],
         [ 4.3359e-01, -3.0078e-01,  2.0508e-02,  ..., -3.2471e-02,
           3.0078e-01, -1.3184e-01],
         [-4.3555e-01, -4.3945e-02,  2.0508e-01,  ...,  2.4531e+00,
           3.5156e-01,  9.6875e-01]]], dtype=torch.bfloat16), tensor([[[ 6.7444e-03, -1.3672e-02,  2.2278e-03,  ...,  3.4180e-03,
          -8.3008e-03, -3.9062e-03],
         [ 9.1309e-02,  9.2773e-02, -8.5156e-01,  ..., -1.7734e+00,
          -1.3203e+00,  1.1094e+00],
         [ 8.8281e-01, -3.5352e-01, -1.4258e-01,  ..., -5.0391e-01,
           6.4941e-02,  2.2656e-01],
         ...,
         [ 1.3574e-01, -4.7070e-01, -5.9375e-01,  ..., -1.8262e-01,
           1.1406e+00,  1.4355e-01],
         [-9.4922e-01,  1.7969e-01,  4.5312e-01,  ...,  4.4141e-01,
          -6.1719e-01, -1.5137e-01],
         [-5.2734e-01, -1.3906e+00, -1.7285e-01,  ...,  1.8652e-01,
           1.2188e+00,  1.0791e-01]],

        [[ 6.7444e-03, -1.3672e-02,  2.2278e-03,  ...,  3.4180e-03,
          -8.3008e-03, -3.9062e-03],
         [ 9.1309e-02,  9.2773e-02, -8.5156e-01,  ..., -1.7734e+00,
          -1.3203e+00,  1.1094e+00],
         [ 8.8281e-01, -3.5352e-01, -1.4258e-01,  ..., -5.0391e-01,
           6.4941e-02,  2.2656e-01],
         ...,
         [ 1.3574e-01, -4.7070e-01, -5.9375e-01,  ..., -1.8262e-01,
           1.1406e+00,  1.4355e-01],
         [-9.4922e-01,  1.7969e-01,  4.5312e-01,  ...,  4.4141e-01,
          -6.1719e-01, -1.5137e-01],
         [-5.2734e-01, -1.3906e+00, -1.7285e-01,  ...,  1.8652e-01,
           1.2188e+00,  1.0791e-01]],

        [[ 6.7444e-03, -1.3672e-02,  2.2278e-03,  ...,  3.4180e-03,
          -8.3008e-03, -3.9062e-03],
         [ 9.1309e-02,  9.2773e-02, -8.5156e-01,  ..., -1.7734e+00,
          -1.3203e+00,  1.1094e+00],
         [ 8.8281e-01, -3.5352e-01, -1.4258e-01,  ..., -5.0391e-01,
           6.4941e-02,  2.2656e-01],
         ...,
         [ 1.3574e-01, -4.7070e-01, -5.9375e-01,  ..., -1.8262e-01,
           1.1406e+00,  1.4355e-01],
         [-9.4922e-01,  1.7969e-01,  4.5312e-01,  ...,  4.4141e-01,
          -6.1719e-01, -1.5137e-01],
         [-5.2734e-01, -1.3906e+00, -1.7285e-01,  ...,  1.8652e-01,
           1.2188e+00,  1.0791e-01]],

        ...,

        [[-2.0142e-02,  5.4626e-03,  2.3499e-03,  ..., -1.0147e-03,
          -1.6113e-02,  8.1177e-03],
         [ 1.0078e+00,  2.0801e-01, -1.6250e+00,  ...,  2.9297e-01,
           7.2656e-01,  2.6094e+00],
         [ 8.9453e-01,  3.0156e+00, -7.5000e-01,  ...,  6.4453e-01,
           1.6602e-01,  2.5625e+00],
         ...,
         [-3.8086e-02,  7.1094e-01, -3.0664e-01,  ..., -9.5703e-02,
          -2.1289e-01,  1.6113e-01],
         [-5.7031e-01,  8.2812e-01, -2.5781e-01,  ...,  5.5469e-01,
           6.4453e-01, -2.0605e-01],
         [-7.8906e-01, -9.3750e-01, -1.6406e+00,  ..., -6.4453e-01,
           1.4141e+00, -2.1250e+00]],

        [[-2.0142e-02,  5.4626e-03,  2.3499e-03,  ..., -1.0147e-03,
          -1.6113e-02,  8.1177e-03],
         [ 1.0078e+00,  2.0801e-01, -1.6250e+00,  ...,  2.9297e-01,
           7.2656e-01,  2.6094e+00],
         [ 8.9453e-01,  3.0156e+00, -7.5000e-01,  ...,  6.4453e-01,
           1.6602e-01,  2.5625e+00],
         ...,
         [-3.8086e-02,  7.1094e-01, -3.0664e-01,  ..., -9.5703e-02,
          -2.1289e-01,  1.6113e-01],
         [-5.7031e-01,  8.2812e-01, -2.5781e-01,  ...,  5.5469e-01,
           6.4453e-01, -2.0605e-01],
         [-7.8906e-01, -9.3750e-01, -1.6406e+00,  ..., -6.4453e-01,
           1.4141e+00, -2.1250e+00]],

        [[-2.0142e-02,  5.4626e-03,  2.3499e-03,  ..., -1.0147e-03,
          -1.6113e-02,  8.1177e-03],
         [ 1.0078e+00,  2.0801e-01, -1.6250e+00,  ...,  2.9297e-01,
           7.2656e-01,  2.6094e+00],
         [ 8.9453e-01,  3.0156e+00, -7.5000e-01,  ...,  6.4453e-01,
           1.6602e-01,  2.5625e+00],
         ...,
         [-3.8086e-02,  7.1094e-01, -3.0664e-01,  ..., -9.5703e-02,
          -2.1289e-01,  1.6113e-01],
         [-5.7031e-01,  8.2812e-01, -2.5781e-01,  ...,  5.5469e-01,
           6.4453e-01, -2.0605e-01],
         [-7.8906e-01, -9.3750e-01, -1.6406e+00,  ..., -6.4453e-01,
           1.4141e+00, -2.1250e+00]]], dtype=torch.bfloat16)), (tensor([[[ 2.2583e-03,  1.4648e-03, -5.1270e-03,  ...,  8.6426e-02,
          -2.5586e-01, -1.3062e-02],
         [-2.4531e+00, -2.8516e-01, -9.7266e-01,  ..., -4.1797e-01,
          -6.9141e-01, -7.8516e-01],
         [-2.7031e+00, -9.8828e-01, -1.5234e+00,  ...,  1.9141e+00,
          -1.3672e+00,  1.1797e+00],
         ...,
         [-4.6484e-01,  9.7656e-03, -1.3984e+00,  ..., -1.7266e+00,
          -6.2891e-01, -5.4688e-01],
         [ 4.6094e-01, -4.6094e-01, -9.1406e-01,  ..., -1.8438e+00,
          -1.3125e+00,  5.9375e-01],
         [ 1.9844e+00,  1.8594e+00, -6.4844e-01,  ..., -1.7734e+00,
          -1.3594e+00, -2.2656e+00]],

        [[ 2.2583e-03,  1.4648e-03, -5.1270e-03,  ...,  8.6426e-02,
          -2.5586e-01, -1.3062e-02],
         [-2.4531e+00, -2.8516e-01, -9.7266e-01,  ..., -4.1797e-01,
          -6.9141e-01, -7.8516e-01],
         [-2.7031e+00, -9.8828e-01, -1.5234e+00,  ...,  1.9141e+00,
          -1.3672e+00,  1.1797e+00],
         ...,
         [-4.6484e-01,  9.7656e-03, -1.3984e+00,  ..., -1.7266e+00,
          -6.2891e-01, -5.4688e-01],
         [ 4.6094e-01, -4.6094e-01, -9.1406e-01,  ..., -1.8438e+00,
          -1.3125e+00,  5.9375e-01],
         [ 1.9844e+00,  1.8594e+00, -6.4844e-01,  ..., -1.7734e+00,
          -1.3594e+00, -2.2656e+00]],

        [[ 2.2583e-03,  1.4648e-03, -5.1270e-03,  ...,  8.6426e-02,
          -2.5586e-01, -1.3062e-02],
         [-2.4531e+00, -2.8516e-01, -9.7266e-01,  ..., -4.1797e-01,
          -6.9141e-01, -7.8516e-01],
         [-2.7031e+00, -9.8828e-01, -1.5234e+00,  ...,  1.9141e+00,
          -1.3672e+00,  1.1797e+00],
         ...,
         [-4.6484e-01,  9.7656e-03, -1.3984e+00,  ..., -1.7266e+00,
          -6.2891e-01, -5.4688e-01],
         [ 4.6094e-01, -4.6094e-01, -9.1406e-01,  ..., -1.8438e+00,
          -1.3125e+00,  5.9375e-01],
         [ 1.9844e+00,  1.8594e+00, -6.4844e-01,  ..., -1.7734e+00,
          -1.3594e+00, -2.2656e+00]],

        ...,

        [[-3.2349e-03,  5.8594e-03,  7.5378e-03,  ..., -8.3008e-02,
           6.6406e-02, -8.7891e-02],
         [-3.1250e+00,  9.4141e-01, -2.5781e-01,  ...,  6.3281e-01,
          -1.2969e+00,  9.1797e-02],
         [ 1.6406e-01, -5.3906e-01, -2.2188e+00,  ..., -1.6602e-01,
          -2.3906e+00,  1.1328e-01],
         ...,
         [ 2.0156e+00,  1.0781e+00,  5.5469e-01,  ...,  3.1250e-01,
           6.3281e-01, -1.3906e+00],
         [ 1.5469e+00, -5.1172e-01,  1.2500e-01,  ...,  1.1182e-01,
           1.3672e-02, -1.5859e+00],
         [-2.1094e-01, -1.8125e+00, -1.2031e+00,  ..., -4.7461e-01,
          -8.0469e-01, -4.3359e-01]],

        [[-3.2349e-03,  5.8594e-03,  7.5378e-03,  ..., -8.3008e-02,
           6.6406e-02, -8.7891e-02],
         [-3.1250e+00,  9.4141e-01, -2.5781e-01,  ...,  6.3281e-01,
          -1.2969e+00,  9.1797e-02],
         [ 1.6406e-01, -5.3906e-01, -2.2188e+00,  ..., -1.6602e-01,
          -2.3906e+00,  1.1328e-01],
         ...,
         [ 2.0156e+00,  1.0781e+00,  5.5469e-01,  ...,  3.1250e-01,
           6.3281e-01, -1.3906e+00],
         [ 1.5469e+00, -5.1172e-01,  1.2500e-01,  ...,  1.1182e-01,
           1.3672e-02, -1.5859e+00],
         [-2.1094e-01, -1.8125e+00, -1.2031e+00,  ..., -4.7461e-01,
          -8.0469e-01, -4.3359e-01]],

        [[-3.2349e-03,  5.8594e-03,  7.5378e-03,  ..., -8.3008e-02,
           6.6406e-02, -8.7891e-02],
         [-3.1250e+00,  9.4141e-01, -2.5781e-01,  ...,  6.3281e-01,
          -1.2969e+00,  9.1797e-02],
         [ 1.6406e-01, -5.3906e-01, -2.2188e+00,  ..., -1.6602e-01,
          -2.3906e+00,  1.1328e-01],
         ...,
         [ 2.0156e+00,  1.0781e+00,  5.5469e-01,  ...,  3.1250e-01,
           6.3281e-01, -1.3906e+00],
         [ 1.5469e+00, -5.1172e-01,  1.2500e-01,  ...,  1.1182e-01,
           1.3672e-02, -1.5859e+00],
         [-2.1094e-01, -1.8125e+00, -1.2031e+00,  ..., -4.7461e-01,
          -8.0469e-01, -4.3359e-01]]], dtype=torch.bfloat16), tensor([[[ 5.5847e-03, -3.1738e-03,  1.5137e-02,  ...,  1.8692e-03,
           1.3733e-02,  1.8433e-02],
         [ 1.6992e-01, -1.4609e+00,  1.2012e-01,  ..., -3.7305e-01,
           1.3359e+00, -2.1191e-01],
         [ 3.2617e-01, -1.7090e-01,  1.1172e+00,  ...,  1.2656e+00,
          -4.8633e-01, -4.1602e-01],
         ...,
         [-1.1621e-01,  2.3926e-01,  7.4219e-01,  ..., -3.0640e-02,
          -1.5430e-01, -2.7344e-01],
         [ 5.2734e-01,  4.9023e-01, -2.8320e-02,  ...,  4.4727e-01,
           7.2656e-01, -4.6875e-01],
         [-3.6914e-01, -2.9492e-01, -7.0312e-01,  ..., -5.3711e-02,
          -6.3281e-01, -1.4062e+00]],

        [[ 5.5847e-03, -3.1738e-03,  1.5137e-02,  ...,  1.8692e-03,
           1.3733e-02,  1.8433e-02],
         [ 1.6992e-01, -1.4609e+00,  1.2012e-01,  ..., -3.7305e-01,
           1.3359e+00, -2.1191e-01],
         [ 3.2617e-01, -1.7090e-01,  1.1172e+00,  ...,  1.2656e+00,
          -4.8633e-01, -4.1602e-01],
         ...,
         [-1.1621e-01,  2.3926e-01,  7.4219e-01,  ..., -3.0640e-02,
          -1.5430e-01, -2.7344e-01],
         [ 5.2734e-01,  4.9023e-01, -2.8320e-02,  ...,  4.4727e-01,
           7.2656e-01, -4.6875e-01],
         [-3.6914e-01, -2.9492e-01, -7.0312e-01,  ..., -5.3711e-02,
          -6.3281e-01, -1.4062e+00]],

        [[ 5.5847e-03, -3.1738e-03,  1.5137e-02,  ...,  1.8692e-03,
           1.3733e-02,  1.8433e-02],
         [ 1.6992e-01, -1.4609e+00,  1.2012e-01,  ..., -3.7305e-01,
           1.3359e+00, -2.1191e-01],
         [ 3.2617e-01, -1.7090e-01,  1.1172e+00,  ...,  1.2656e+00,
          -4.8633e-01, -4.1602e-01],
         ...,
         [-1.1621e-01,  2.3926e-01,  7.4219e-01,  ..., -3.0640e-02,
          -1.5430e-01, -2.7344e-01],
         [ 5.2734e-01,  4.9023e-01, -2.8320e-02,  ...,  4.4727e-01,
           7.2656e-01, -4.6875e-01],
         [-3.6914e-01, -2.9492e-01, -7.0312e-01,  ..., -5.3711e-02,
          -6.3281e-01, -1.4062e+00]],

        ...,

        [[ 3.8910e-03, -1.4572e-03, -4.1809e-03,  ...,  1.4282e-02,
          -5.4321e-03,  6.4392e-03],
         [ 1.1250e+00, -6.4453e-01, -1.4551e-01,  ..., -6.5430e-02,
           6.2891e-01, -6.4453e-01],
         [-3.3398e-01, -1.8457e-01,  5.4688e-01,  ...,  8.7500e-01,
           2.1191e-01,  1.0781e+00],
         ...,
         [-2.2168e-01, -6.6797e-01,  1.1953e+00,  ...,  6.6406e-01,
           1.6250e+00, -1.9824e-01],
         [ 1.8652e-01,  5.7031e-01, -5.4297e-01,  ..., -6.9141e-01,
          -1.0547e+00, -7.7344e-01],
         [ 1.4609e+00, -6.9531e-01, -3.5889e-02,  ...,  6.1719e-01,
          -6.0547e-01,  2.4121e-01]],

        [[ 3.8910e-03, -1.4572e-03, -4.1809e-03,  ...,  1.4282e-02,
          -5.4321e-03,  6.4392e-03],
         [ 1.1250e+00, -6.4453e-01, -1.4551e-01,  ..., -6.5430e-02,
           6.2891e-01, -6.4453e-01],
         [-3.3398e-01, -1.8457e-01,  5.4688e-01,  ...,  8.7500e-01,
           2.1191e-01,  1.0781e+00],
         ...,
         [-2.2168e-01, -6.6797e-01,  1.1953e+00,  ...,  6.6406e-01,
           1.6250e+00, -1.9824e-01],
         [ 1.8652e-01,  5.7031e-01, -5.4297e-01,  ..., -6.9141e-01,
          -1.0547e+00, -7.7344e-01],
         [ 1.4609e+00, -6.9531e-01, -3.5889e-02,  ...,  6.1719e-01,
          -6.0547e-01,  2.4121e-01]],

        [[ 3.8910e-03, -1.4572e-03, -4.1809e-03,  ...,  1.4282e-02,
          -5.4321e-03,  6.4392e-03],
         [ 1.1250e+00, -6.4453e-01, -1.4551e-01,  ..., -6.5430e-02,
           6.2891e-01, -6.4453e-01],
         [-3.3398e-01, -1.8457e-01,  5.4688e-01,  ...,  8.7500e-01,
           2.1191e-01,  1.0781e+00],
         ...,
         [-2.2168e-01, -6.6797e-01,  1.1953e+00,  ...,  6.6406e-01,
           1.6250e+00, -1.9824e-01],
         [ 1.8652e-01,  5.7031e-01, -5.4297e-01,  ..., -6.9141e-01,
          -1.0547e+00, -7.7344e-01],
         [ 1.4609e+00, -6.9531e-01, -3.5889e-02,  ...,  6.1719e-01,
          -6.0547e-01,  2.4121e-01]]], dtype=torch.bfloat16)), (tensor([[[ 9.3994e-03, -1.4465e-02, -9.5825e-03,  ..., -1.4551e-01,
           5.5469e-01,  2.6562e+00],
         [ 9.4922e-01, -3.2812e-01,  2.4414e-02,  ..., -9.5312e-01,
          -3.2422e-01, -6.4062e+00],
         [-4.3750e-01, -2.2070e-01,  1.6016e-01,  ...,  1.9453e+00,
          -2.6978e-02, -9.3125e+00],
         ...,
         [ 2.5977e-01, -3.3203e-02,  1.1865e-01,  ...,  5.2344e-01,
          -1.6992e-01, -6.2500e+00],
         [-3.5938e-01,  4.0625e-01,  1.2793e-01,  ...,  3.2031e-01,
          -1.1953e+00, -4.8750e+00],
         [-2.5781e-01,  8.5547e-01,  1.1875e+00,  ...,  2.5625e+00,
          -3.0078e-01, -7.2812e+00]],

        [[ 9.3994e-03, -1.4465e-02, -9.5825e-03,  ..., -1.4551e-01,
           5.5469e-01,  2.6562e+00],
         [ 9.4922e-01, -3.2812e-01,  2.4414e-02,  ..., -9.5312e-01,
          -3.2422e-01, -6.4062e+00],
         [-4.3750e-01, -2.2070e-01,  1.6016e-01,  ...,  1.9453e+00,
          -2.6978e-02, -9.3125e+00],
         ...,
         [ 2.5977e-01, -3.3203e-02,  1.1865e-01,  ...,  5.2344e-01,
          -1.6992e-01, -6.2500e+00],
         [-3.5938e-01,  4.0625e-01,  1.2793e-01,  ...,  3.2031e-01,
          -1.1953e+00, -4.8750e+00],
         [-2.5781e-01,  8.5547e-01,  1.1875e+00,  ...,  2.5625e+00,
          -3.0078e-01, -7.2812e+00]],

        [[ 9.3994e-03, -1.4465e-02, -9.5825e-03,  ..., -1.4551e-01,
           5.5469e-01,  2.6562e+00],
         [ 9.4922e-01, -3.2812e-01,  2.4414e-02,  ..., -9.5312e-01,
          -3.2422e-01, -6.4062e+00],
         [-4.3750e-01, -2.2070e-01,  1.6016e-01,  ...,  1.9453e+00,
          -2.6978e-02, -9.3125e+00],
         ...,
         [ 2.5977e-01, -3.3203e-02,  1.1865e-01,  ...,  5.2344e-01,
          -1.6992e-01, -6.2500e+00],
         [-3.5938e-01,  4.0625e-01,  1.2793e-01,  ...,  3.2031e-01,
          -1.1953e+00, -4.8750e+00],
         [-2.5781e-01,  8.5547e-01,  1.1875e+00,  ...,  2.5625e+00,
          -3.0078e-01, -7.2812e+00]],

        ...,

        [[-2.3926e-02, -3.6240e-04, -4.7607e-03,  ...,  2.4048e-02,
           2.1875e+00,  3.8086e-01],
         [ 1.6250e+00, -7.9297e-01, -4.3359e-01,  ..., -3.8086e-01,
          -4.6250e+00,  2.2852e-01],
         [ 7.5195e-02, -5.9766e-01,  7.3828e-01,  ...,  9.8047e-01,
          -3.7500e+00,  1.2256e-01],
         ...,
         [-1.8555e-01, -4.6875e-01,  1.9922e-01,  ...,  2.1562e+00,
          -3.7344e+00,  9.8438e-01],
         [ 1.5625e-02,  2.0312e-01, -8.4961e-02,  ...,  1.0625e+00,
          -4.6875e+00,  8.7891e-01],
         [ 1.0156e+00,  7.4609e-01, -2.8711e-01,  ...,  1.9922e+00,
          -3.7188e+00,  1.3672e+00]],

        [[-2.3926e-02, -3.6240e-04, -4.7607e-03,  ...,  2.4048e-02,
           2.1875e+00,  3.8086e-01],
         [ 1.6250e+00, -7.9297e-01, -4.3359e-01,  ..., -3.8086e-01,
          -4.6250e+00,  2.2852e-01],
         [ 7.5195e-02, -5.9766e-01,  7.3828e-01,  ...,  9.8047e-01,
          -3.7500e+00,  1.2256e-01],
         ...,
         [-1.8555e-01, -4.6875e-01,  1.9922e-01,  ...,  2.1562e+00,
          -3.7344e+00,  9.8438e-01],
         [ 1.5625e-02,  2.0312e-01, -8.4961e-02,  ...,  1.0625e+00,
          -4.6875e+00,  8.7891e-01],
         [ 1.0156e+00,  7.4609e-01, -2.8711e-01,  ...,  1.9922e+00,
          -3.7188e+00,  1.3672e+00]],

        [[-2.3926e-02, -3.6240e-04, -4.7607e-03,  ...,  2.4048e-02,
           2.1875e+00,  3.8086e-01],
         [ 1.6250e+00, -7.9297e-01, -4.3359e-01,  ..., -3.8086e-01,
          -4.6250e+00,  2.2852e-01],
         [ 7.5195e-02, -5.9766e-01,  7.3828e-01,  ...,  9.8047e-01,
          -3.7500e+00,  1.2256e-01],
         ...,
         [-1.8555e-01, -4.6875e-01,  1.9922e-01,  ...,  2.1562e+00,
          -3.7344e+00,  9.8438e-01],
         [ 1.5625e-02,  2.0312e-01, -8.4961e-02,  ...,  1.0625e+00,
          -4.6875e+00,  8.7891e-01],
         [ 1.0156e+00,  7.4609e-01, -2.8711e-01,  ...,  1.9922e+00,
          -3.7188e+00,  1.3672e+00]]], dtype=torch.bfloat16), tensor([[[ 1.2268e-02, -1.1047e-02, -2.2949e-02,  ...,  1.1963e-02,
          -4.4861e-03,  2.0508e-02],
         [-3.2031e-01, -8.4766e-01,  5.0391e-01,  ..., -1.3184e-01,
           2.0801e-01, -1.2188e+00],
         [-5.3711e-02,  1.8281e+00,  2.7656e+00,  ..., -1.2256e-01,
           2.0000e+00,  1.0234e+00],
         ...,
         [-1.8750e+00,  1.3438e+00,  3.4180e-01,  ...,  1.5391e+00,
           1.1484e+00,  2.1719e+00],
         [-3.2617e-01,  3.3594e-01,  9.2578e-01,  ..., -6.3965e-02,
           1.5156e+00,  1.9922e+00],
         [-3.3984e-01, -5.0781e-01, -1.0391e+00,  ..., -1.4531e+00,
           2.8906e-01, -5.8984e-01]],

        [[ 1.2268e-02, -1.1047e-02, -2.2949e-02,  ...,  1.1963e-02,
          -4.4861e-03,  2.0508e-02],
         [-3.2031e-01, -8.4766e-01,  5.0391e-01,  ..., -1.3184e-01,
           2.0801e-01, -1.2188e+00],
         [-5.3711e-02,  1.8281e+00,  2.7656e+00,  ..., -1.2256e-01,
           2.0000e+00,  1.0234e+00],
         ...,
         [-1.8750e+00,  1.3438e+00,  3.4180e-01,  ...,  1.5391e+00,
           1.1484e+00,  2.1719e+00],
         [-3.2617e-01,  3.3594e-01,  9.2578e-01,  ..., -6.3965e-02,
           1.5156e+00,  1.9922e+00],
         [-3.3984e-01, -5.0781e-01, -1.0391e+00,  ..., -1.4531e+00,
           2.8906e-01, -5.8984e-01]],

        [[ 1.2268e-02, -1.1047e-02, -2.2949e-02,  ...,  1.1963e-02,
          -4.4861e-03,  2.0508e-02],
         [-3.2031e-01, -8.4766e-01,  5.0391e-01,  ..., -1.3184e-01,
           2.0801e-01, -1.2188e+00],
         [-5.3711e-02,  1.8281e+00,  2.7656e+00,  ..., -1.2256e-01,
           2.0000e+00,  1.0234e+00],
         ...,
         [-1.8750e+00,  1.3438e+00,  3.4180e-01,  ...,  1.5391e+00,
           1.1484e+00,  2.1719e+00],
         [-3.2617e-01,  3.3594e-01,  9.2578e-01,  ..., -6.3965e-02,
           1.5156e+00,  1.9922e+00],
         [-3.3984e-01, -5.0781e-01, -1.0391e+00,  ..., -1.4531e+00,
           2.8906e-01, -5.8984e-01]],

        ...,

        [[ 4.6082e-03, -1.6022e-04, -2.8687e-02,  ..., -4.7607e-03,
          -8.6060e-03, -6.3705e-04],
         [-1.5938e+00, -1.2305e-01,  2.7148e-01,  ..., -1.8555e-02,
          -4.1016e-01,  4.1406e-01],
         [-1.4062e+00,  6.1328e-01, -2.0605e-01,  ..., -3.7305e-01,
           5.9375e-01,  4.0625e-01],
         ...,
         [-2.7148e-01,  1.1406e+00,  7.8516e-01,  ..., -1.3086e-01,
          -5.7031e-01, -1.9297e+00],
         [ 5.4688e-01,  4.3701e-02,  7.5391e-01,  ..., -7.3438e-01,
           8.0078e-01,  1.9727e-01],
         [ 8.7500e-01,  2.2559e-01,  5.5469e-01,  ..., -9.8047e-01,
          -1.9238e-01, -1.0391e+00]],

        [[ 4.6082e-03, -1.6022e-04, -2.8687e-02,  ..., -4.7607e-03,
          -8.6060e-03, -6.3705e-04],
         [-1.5938e+00, -1.2305e-01,  2.7148e-01,  ..., -1.8555e-02,
          -4.1016e-01,  4.1406e-01],
         [-1.4062e+00,  6.1328e-01, -2.0605e-01,  ..., -3.7305e-01,
           5.9375e-01,  4.0625e-01],
         ...,
         [-2.7148e-01,  1.1406e+00,  7.8516e-01,  ..., -1.3086e-01,
          -5.7031e-01, -1.9297e+00],
         [ 5.4688e-01,  4.3701e-02,  7.5391e-01,  ..., -7.3438e-01,
           8.0078e-01,  1.9727e-01],
         [ 8.7500e-01,  2.2559e-01,  5.5469e-01,  ..., -9.8047e-01,
          -1.9238e-01, -1.0391e+00]],

        [[ 4.6082e-03, -1.6022e-04, -2.8687e-02,  ..., -4.7607e-03,
          -8.6060e-03, -6.3705e-04],
         [-1.5938e+00, -1.2305e-01,  2.7148e-01,  ..., -1.8555e-02,
          -4.1016e-01,  4.1406e-01],
         [-1.4062e+00,  6.1328e-01, -2.0605e-01,  ..., -3.7305e-01,
           5.9375e-01,  4.0625e-01],
         ...,
         [-2.7148e-01,  1.1406e+00,  7.8516e-01,  ..., -1.3086e-01,
          -5.7031e-01, -1.9297e+00],
         [ 5.4688e-01,  4.3701e-02,  7.5391e-01,  ..., -7.3438e-01,
           8.0078e-01,  1.9727e-01],
         [ 8.7500e-01,  2.2559e-01,  5.5469e-01,  ..., -9.8047e-01,
          -1.9238e-01, -1.0391e+00]]], dtype=torch.bfloat16)), (tensor([[[ 7.6904e-03, -4.5166e-03,  4.1809e-03,  ...,  7.6599e-03,
          -2.2754e-01,  7.9102e-02],
         [-2.2031e+00,  9.3750e-02, -2.4512e-01,  ...,  1.4609e+00,
           7.2266e-01, -8.1641e-01],
         [-1.9062e+00, -5.0781e-01,  7.1289e-02,  ...,  1.2031e+00,
           1.2344e+00, -8.0078e-01],
         ...,
         [-2.8516e-01, -2.5195e-01, -1.5312e+00,  ...,  3.8594e+00,
          -2.4707e-01,  1.6172e+00],
         [ 5.8594e-01, -6.5625e-01,  1.3281e-01,  ...,  4.0625e+00,
          -9.5312e-01,  5.3516e-01],
         [ 1.7578e+00,  3.3398e-01, -5.1953e-01,  ...,  1.0391e+00,
           2.8906e+00, -3.9368e-03]],

        [[ 7.6904e-03, -4.5166e-03,  4.1809e-03,  ...,  7.6599e-03,
          -2.2754e-01,  7.9102e-02],
         [-2.2031e+00,  9.3750e-02, -2.4512e-01,  ...,  1.4609e+00,
           7.2266e-01, -8.1641e-01],
         [-1.9062e+00, -5.0781e-01,  7.1289e-02,  ...,  1.2031e+00,
           1.2344e+00, -8.0078e-01],
         ...,
         [-2.8516e-01, -2.5195e-01, -1.5312e+00,  ...,  3.8594e+00,
          -2.4707e-01,  1.6172e+00],
         [ 5.8594e-01, -6.5625e-01,  1.3281e-01,  ...,  4.0625e+00,
          -9.5312e-01,  5.3516e-01],
         [ 1.7578e+00,  3.3398e-01, -5.1953e-01,  ...,  1.0391e+00,
           2.8906e+00, -3.9368e-03]],

        [[ 7.6904e-03, -4.5166e-03,  4.1809e-03,  ...,  7.6599e-03,
          -2.2754e-01,  7.9102e-02],
         [-2.2031e+00,  9.3750e-02, -2.4512e-01,  ...,  1.4609e+00,
           7.2266e-01, -8.1641e-01],
         [-1.9062e+00, -5.0781e-01,  7.1289e-02,  ...,  1.2031e+00,
           1.2344e+00, -8.0078e-01],
         ...,
         [-2.8516e-01, -2.5195e-01, -1.5312e+00,  ...,  3.8594e+00,
          -2.4707e-01,  1.6172e+00],
         [ 5.8594e-01, -6.5625e-01,  1.3281e-01,  ...,  4.0625e+00,
          -9.5312e-01,  5.3516e-01],
         [ 1.7578e+00,  3.3398e-01, -5.1953e-01,  ...,  1.0391e+00,
           2.8906e+00, -3.9368e-03]],

        ...,

        [[ 4.9744e-03, -6.3171e-03,  2.1076e-04,  ..., -7.6172e-02,
           2.6562e-01,  1.2354e-01],
         [ 1.2969e+00,  1.5625e+00,  1.3906e+00,  ...,  1.0703e+00,
           5.5078e-01,  1.7812e+00],
         [ 4.7266e-01, -8.3984e-01,  1.8750e+00,  ..., -9.1406e-01,
           1.4453e+00,  4.1992e-01],
         ...,
         [-1.3203e+00, -6.0156e-01, -3.9453e-01,  ..., -2.9102e-01,
           2.1250e+00, -1.2344e+00],
         [-2.4707e-01, -3.7109e-01,  5.2344e-01,  ..., -2.1719e+00,
           1.7578e+00, -2.6250e+00],
         [-2.5586e-01, -1.3086e-01,  6.1719e-01,  ..., -5.3125e-01,
           1.4062e+00, -7.2266e-01]],

        [[ 4.9744e-03, -6.3171e-03,  2.1076e-04,  ..., -7.6172e-02,
           2.6562e-01,  1.2354e-01],
         [ 1.2969e+00,  1.5625e+00,  1.3906e+00,  ...,  1.0703e+00,
           5.5078e-01,  1.7812e+00],
         [ 4.7266e-01, -8.3984e-01,  1.8750e+00,  ..., -9.1406e-01,
           1.4453e+00,  4.1992e-01],
         ...,
         [-1.3203e+00, -6.0156e-01, -3.9453e-01,  ..., -2.9102e-01,
           2.1250e+00, -1.2344e+00],
         [-2.4707e-01, -3.7109e-01,  5.2344e-01,  ..., -2.1719e+00,
           1.7578e+00, -2.6250e+00],
         [-2.5586e-01, -1.3086e-01,  6.1719e-01,  ..., -5.3125e-01,
           1.4062e+00, -7.2266e-01]],

        [[ 4.9744e-03, -6.3171e-03,  2.1076e-04,  ..., -7.6172e-02,
           2.6562e-01,  1.2354e-01],
         [ 1.2969e+00,  1.5625e+00,  1.3906e+00,  ...,  1.0703e+00,
           5.5078e-01,  1.7812e+00],
         [ 4.7266e-01, -8.3984e-01,  1.8750e+00,  ..., -9.1406e-01,
           1.4453e+00,  4.1992e-01],
         ...,
         [-1.3203e+00, -6.0156e-01, -3.9453e-01,  ..., -2.9102e-01,
           2.1250e+00, -1.2344e+00],
         [-2.4707e-01, -3.7109e-01,  5.2344e-01,  ..., -2.1719e+00,
           1.7578e+00, -2.6250e+00],
         [-2.5586e-01, -1.3086e-01,  6.1719e-01,  ..., -5.3125e-01,
           1.4062e+00, -7.2266e-01]]], dtype=torch.bfloat16), tensor([[[-4.2725e-03,  5.8594e-03,  1.6937e-03,  ...,  5.0049e-03,
          -1.9775e-02, -1.0620e-02],
         [ 3.3203e-01,  6.7969e-01,  8.1641e-01,  ...,  3.4375e-01,
          -1.2207e-01, -5.8984e-01],
         [-2.0215e-01,  7.8516e-01, -8.2812e-01,  ...,  1.2969e+00,
          -4.8047e-01, -6.5234e-01],
         ...,
         [-1.2500e+00,  1.3203e+00,  1.0449e-01,  ...,  2.0312e+00,
           3.2812e+00, -5.7422e-01],
         [-1.3828e+00,  5.7422e-01,  1.5547e+00,  ...,  1.3359e+00,
           7.7734e-01, -1.7969e+00],
         [-4.0820e-01,  1.7090e-01, -1.6250e+00,  ..., -1.4844e+00,
           2.6953e-01, -2.9102e-01]],

        [[-4.2725e-03,  5.8594e-03,  1.6937e-03,  ...,  5.0049e-03,
          -1.9775e-02, -1.0620e-02],
         [ 3.3203e-01,  6.7969e-01,  8.1641e-01,  ...,  3.4375e-01,
          -1.2207e-01, -5.8984e-01],
         [-2.0215e-01,  7.8516e-01, -8.2812e-01,  ...,  1.2969e+00,
          -4.8047e-01, -6.5234e-01],
         ...,
         [-1.2500e+00,  1.3203e+00,  1.0449e-01,  ...,  2.0312e+00,
           3.2812e+00, -5.7422e-01],
         [-1.3828e+00,  5.7422e-01,  1.5547e+00,  ...,  1.3359e+00,
           7.7734e-01, -1.7969e+00],
         [-4.0820e-01,  1.7090e-01, -1.6250e+00,  ..., -1.4844e+00,
           2.6953e-01, -2.9102e-01]],

        [[-4.2725e-03,  5.8594e-03,  1.6937e-03,  ...,  5.0049e-03,
          -1.9775e-02, -1.0620e-02],
         [ 3.3203e-01,  6.7969e-01,  8.1641e-01,  ...,  3.4375e-01,
          -1.2207e-01, -5.8984e-01],
         [-2.0215e-01,  7.8516e-01, -8.2812e-01,  ...,  1.2969e+00,
          -4.8047e-01, -6.5234e-01],
         ...,
         [-1.2500e+00,  1.3203e+00,  1.0449e-01,  ...,  2.0312e+00,
           3.2812e+00, -5.7422e-01],
         [-1.3828e+00,  5.7422e-01,  1.5547e+00,  ...,  1.3359e+00,
           7.7734e-01, -1.7969e+00],
         [-4.0820e-01,  1.7090e-01, -1.6250e+00,  ..., -1.4844e+00,
           2.6953e-01, -2.9102e-01]],

        ...,

        [[ 2.7588e-02, -9.9487e-03, -1.7700e-02,  ..., -6.4087e-03,
           1.4038e-03,  5.3711e-03],
         [ 2.7930e-01, -2.5977e-01,  7.3047e-01,  ...,  8.4961e-02,
           3.4570e-01, -1.5332e-01],
         [-8.1177e-03,  1.7578e+00, -4.1016e-01,  ...,  3.9062e-01,
           3.3398e-01,  4.3750e-01],
         ...,
         [-4.6680e-01, -3.0469e-01,  4.7656e-01,  ...,  3.2617e-01,
          -8.2422e-01,  3.4180e-01],
         [ 8.0859e-01, -1.5234e+00,  1.1406e+00,  ...,  3.2969e+00,
           1.2734e+00, -5.5469e-01],
         [ 3.6914e-01, -7.6172e-01,  1.7969e-01,  ..., -1.8750e-01,
          -4.5508e-01,  1.8457e-01]],

        [[ 2.7588e-02, -9.9487e-03, -1.7700e-02,  ..., -6.4087e-03,
           1.4038e-03,  5.3711e-03],
         [ 2.7930e-01, -2.5977e-01,  7.3047e-01,  ...,  8.4961e-02,
           3.4570e-01, -1.5332e-01],
         [-8.1177e-03,  1.7578e+00, -4.1016e-01,  ...,  3.9062e-01,
           3.3398e-01,  4.3750e-01],
         ...,
         [-4.6680e-01, -3.0469e-01,  4.7656e-01,  ...,  3.2617e-01,
          -8.2422e-01,  3.4180e-01],
         [ 8.0859e-01, -1.5234e+00,  1.1406e+00,  ...,  3.2969e+00,
           1.2734e+00, -5.5469e-01],
         [ 3.6914e-01, -7.6172e-01,  1.7969e-01,  ..., -1.8750e-01,
          -4.5508e-01,  1.8457e-01]],

        [[ 2.7588e-02, -9.9487e-03, -1.7700e-02,  ..., -6.4087e-03,
           1.4038e-03,  5.3711e-03],
         [ 2.7930e-01, -2.5977e-01,  7.3047e-01,  ...,  8.4961e-02,
           3.4570e-01, -1.5332e-01],
         [-8.1177e-03,  1.7578e+00, -4.1016e-01,  ...,  3.9062e-01,
           3.3398e-01,  4.3750e-01],
         ...,
         [-4.6680e-01, -3.0469e-01,  4.7656e-01,  ...,  3.2617e-01,
          -8.2422e-01,  3.4180e-01],
         [ 8.0859e-01, -1.5234e+00,  1.1406e+00,  ...,  3.2969e+00,
           1.2734e+00, -5.5469e-01],
         [ 3.6914e-01, -7.6172e-01,  1.7969e-01,  ..., -1.8750e-01,
          -4.5508e-01,  1.8457e-01]]], dtype=torch.bfloat16)), (tensor([[[-2.3346e-03,  9.5215e-03, -9.2163e-03,  ..., -1.6602e-01,
          -1.3672e-01,  2.4414e-01],
         [ 2.0312e+00, -6.0156e-01, -2.6953e-01,  ...,  1.7734e+00,
           7.8516e-01, -8.5547e-01],
         [ 1.6016e+00, -2.7734e-01, -2.1289e-01,  ...,  2.1406e+00,
           1.4551e-01, -5.0000e-01],
         ...,
         [ 9.3262e-02,  4.5312e-01,  2.9102e-01,  ...,  6.5918e-02,
          -1.6016e-01,  3.4961e-01],
         [-2.9688e-01,  3.9258e-01, -4.5117e-01,  ...,  4.1992e-02,
           1.5781e+00,  1.1016e+00],
         [-1.9688e+00,  1.1523e-01, -5.0781e-01,  ...,  5.5859e-01,
           3.7305e-01, -2.4414e-01]],

        [[-2.3346e-03,  9.5215e-03, -9.2163e-03,  ..., -1.6602e-01,
          -1.3672e-01,  2.4414e-01],
         [ 2.0312e+00, -6.0156e-01, -2.6953e-01,  ...,  1.7734e+00,
           7.8516e-01, -8.5547e-01],
         [ 1.6016e+00, -2.7734e-01, -2.1289e-01,  ...,  2.1406e+00,
           1.4551e-01, -5.0000e-01],
         ...,
         [ 9.3262e-02,  4.5312e-01,  2.9102e-01,  ...,  6.5918e-02,
          -1.6016e-01,  3.4961e-01],
         [-2.9688e-01,  3.9258e-01, -4.5117e-01,  ...,  4.1992e-02,
           1.5781e+00,  1.1016e+00],
         [-1.9688e+00,  1.1523e-01, -5.0781e-01,  ...,  5.5859e-01,
           3.7305e-01, -2.4414e-01]],

        [[-2.3346e-03,  9.5215e-03, -9.2163e-03,  ..., -1.6602e-01,
          -1.3672e-01,  2.4414e-01],
         [ 2.0312e+00, -6.0156e-01, -2.6953e-01,  ...,  1.7734e+00,
           7.8516e-01, -8.5547e-01],
         [ 1.6016e+00, -2.7734e-01, -2.1289e-01,  ...,  2.1406e+00,
           1.4551e-01, -5.0000e-01],
         ...,
         [ 9.3262e-02,  4.5312e-01,  2.9102e-01,  ...,  6.5918e-02,
          -1.6016e-01,  3.4961e-01],
         [-2.9688e-01,  3.9258e-01, -4.5117e-01,  ...,  4.1992e-02,
           1.5781e+00,  1.1016e+00],
         [-1.9688e+00,  1.1523e-01, -5.0781e-01,  ...,  5.5859e-01,
           3.7305e-01, -2.4414e-01]],

        ...,

        [[ 1.5137e-02, -7.3547e-03, -4.0894e-03,  ..., -2.2852e-01,
           8.4375e-01, -1.9141e-01],
         [-1.2422e+00, -1.0703e+00,  6.9922e-01,  ..., -3.3008e-01,
           1.2266e+00,  3.3125e+00],
         [-9.8438e-01, -7.5391e-01,  7.5000e-01,  ..., -1.3184e-01,
           2.2266e-01, -2.5938e+00],
         ...,
         [ 4.0625e-01,  1.1094e+00, -1.2012e-01,  ...,  3.0156e+00,
          -5.7422e-01,  7.6953e-01],
         [ 8.5547e-01,  8.7891e-01, -1.2695e-01,  ...,  4.1211e-01,
          -1.7188e-01,  4.9805e-01],
         [ 1.0078e+00,  1.0938e-01, -1.7344e+00,  ..., -6.4844e-01,
          -1.9629e-01, -1.8906e+00]],

        [[ 1.5137e-02, -7.3547e-03, -4.0894e-03,  ..., -2.2852e-01,
           8.4375e-01, -1.9141e-01],
         [-1.2422e+00, -1.0703e+00,  6.9922e-01,  ..., -3.3008e-01,
           1.2266e+00,  3.3125e+00],
         [-9.8438e-01, -7.5391e-01,  7.5000e-01,  ..., -1.3184e-01,
           2.2266e-01, -2.5938e+00],
         ...,
         [ 4.0625e-01,  1.1094e+00, -1.2012e-01,  ...,  3.0156e+00,
          -5.7422e-01,  7.6953e-01],
         [ 8.5547e-01,  8.7891e-01, -1.2695e-01,  ...,  4.1211e-01,
          -1.7188e-01,  4.9805e-01],
         [ 1.0078e+00,  1.0938e-01, -1.7344e+00,  ..., -6.4844e-01,
          -1.9629e-01, -1.8906e+00]],

        [[ 1.5137e-02, -7.3547e-03, -4.0894e-03,  ..., -2.2852e-01,
           8.4375e-01, -1.9141e-01],
         [-1.2422e+00, -1.0703e+00,  6.9922e-01,  ..., -3.3008e-01,
           1.2266e+00,  3.3125e+00],
         [-9.8438e-01, -7.5391e-01,  7.5000e-01,  ..., -1.3184e-01,
           2.2266e-01, -2.5938e+00],
         ...,
         [ 4.0625e-01,  1.1094e+00, -1.2012e-01,  ...,  3.0156e+00,
          -5.7422e-01,  7.6953e-01],
         [ 8.5547e-01,  8.7891e-01, -1.2695e-01,  ...,  4.1211e-01,
          -1.7188e-01,  4.9805e-01],
         [ 1.0078e+00,  1.0938e-01, -1.7344e+00,  ..., -6.4844e-01,
          -1.9629e-01, -1.8906e+00]]], dtype=torch.bfloat16), tensor([[[-7.1106e-03, -1.7944e-02,  7.7820e-03,  ...,  1.2756e-02,
          -1.6724e-02,  4.2725e-03],
         [-1.6797e-01, -6.9922e-01,  4.2969e-02,  ...,  7.4707e-02,
          -1.8066e-01, -3.6133e-02],
         [ 2.7539e-01,  1.1562e+00, -1.1875e+00,  ..., -6.7383e-02,
          -3.1055e-01,  2.6611e-02],
         ...,
         [ 1.6895e-01,  1.6211e-01, -1.1572e-01,  ...,  2.7734e-01,
          -5.3906e-01,  3.3789e-01],
         [-5.9766e-01,  5.2344e-01,  6.7969e-01,  ..., -1.1406e+00,
          -4.1602e-01, -5.8984e-01],
         [-2.8125e-01,  2.0938e+00,  1.3125e+00,  ...,  7.5000e-01,
           1.1797e+00,  2.0469e+00]],

        [[-7.1106e-03, -1.7944e-02,  7.7820e-03,  ...,  1.2756e-02,
          -1.6724e-02,  4.2725e-03],
         [-1.6797e-01, -6.9922e-01,  4.2969e-02,  ...,  7.4707e-02,
          -1.8066e-01, -3.6133e-02],
         [ 2.7539e-01,  1.1562e+00, -1.1875e+00,  ..., -6.7383e-02,
          -3.1055e-01,  2.6611e-02],
         ...,
         [ 1.6895e-01,  1.6211e-01, -1.1572e-01,  ...,  2.7734e-01,
          -5.3906e-01,  3.3789e-01],
         [-5.9766e-01,  5.2344e-01,  6.7969e-01,  ..., -1.1406e+00,
          -4.1602e-01, -5.8984e-01],
         [-2.8125e-01,  2.0938e+00,  1.3125e+00,  ...,  7.5000e-01,
           1.1797e+00,  2.0469e+00]],

        [[-7.1106e-03, -1.7944e-02,  7.7820e-03,  ...,  1.2756e-02,
          -1.6724e-02,  4.2725e-03],
         [-1.6797e-01, -6.9922e-01,  4.2969e-02,  ...,  7.4707e-02,
          -1.8066e-01, -3.6133e-02],
         [ 2.7539e-01,  1.1562e+00, -1.1875e+00,  ..., -6.7383e-02,
          -3.1055e-01,  2.6611e-02],
         ...,
         [ 1.6895e-01,  1.6211e-01, -1.1572e-01,  ...,  2.7734e-01,
          -5.3906e-01,  3.3789e-01],
         [-5.9766e-01,  5.2344e-01,  6.7969e-01,  ..., -1.1406e+00,
          -4.1602e-01, -5.8984e-01],
         [-2.8125e-01,  2.0938e+00,  1.3125e+00,  ...,  7.5000e-01,
           1.1797e+00,  2.0469e+00]],

        ...,

        [[-3.6469e-03, -8.3618e-03,  3.2349e-03,  ...,  2.0504e-04,
          -3.1128e-03,  1.9836e-03],
         [-2.4844e+00,  9.6484e-01, -8.0469e-01,  ...,  1.8047e+00,
           2.6562e+00, -2.4688e+00],
         [-8.5938e-01, -1.7344e+00, -9.7266e-01,  ...,  6.6797e-01,
           5.0781e-01, -2.2500e+00],
         ...,
         [ 1.1797e+00, -8.7891e-01, -5.4297e-01,  ...,  6.0156e-01,
          -9.7656e-01,  5.5469e-01],
         [-4.7656e-01, -2.6562e-01, -1.3672e+00,  ...,  1.3516e+00,
           1.1865e-01, -8.7109e-01],
         [-7.6562e-01, -3.3789e-01, -2.7344e+00,  ...,  1.2031e+00,
          -1.1797e+00,  7.7637e-02]],

        [[-3.6469e-03, -8.3618e-03,  3.2349e-03,  ...,  2.0504e-04,
          -3.1128e-03,  1.9836e-03],
         [-2.4844e+00,  9.6484e-01, -8.0469e-01,  ...,  1.8047e+00,
           2.6562e+00, -2.4688e+00],
         [-8.5938e-01, -1.7344e+00, -9.7266e-01,  ...,  6.6797e-01,
           5.0781e-01, -2.2500e+00],
         ...,
         [ 1.1797e+00, -8.7891e-01, -5.4297e-01,  ...,  6.0156e-01,
          -9.7656e-01,  5.5469e-01],
         [-4.7656e-01, -2.6562e-01, -1.3672e+00,  ...,  1.3516e+00,
           1.1865e-01, -8.7109e-01],
         [-7.6562e-01, -3.3789e-01, -2.7344e+00,  ...,  1.2031e+00,
          -1.1797e+00,  7.7637e-02]],

        [[-3.6469e-03, -8.3618e-03,  3.2349e-03,  ...,  2.0504e-04,
          -3.1128e-03,  1.9836e-03],
         [-2.4844e+00,  9.6484e-01, -8.0469e-01,  ...,  1.8047e+00,
           2.6562e+00, -2.4688e+00],
         [-8.5938e-01, -1.7344e+00, -9.7266e-01,  ...,  6.6797e-01,
           5.0781e-01, -2.2500e+00],
         ...,
         [ 1.1797e+00, -8.7891e-01, -5.4297e-01,  ...,  6.0156e-01,
          -9.7656e-01,  5.5469e-01],
         [-4.7656e-01, -2.6562e-01, -1.3672e+00,  ...,  1.3516e+00,
           1.1865e-01, -8.7109e-01],
         [-7.6562e-01, -3.3789e-01, -2.7344e+00,  ...,  1.2031e+00,
          -1.1797e+00,  7.7637e-02]]], dtype=torch.bfloat16)), (tensor([[[ 6.3324e-04,  3.0518e-03, -2.5635e-03,  ..., -1.9043e-01,
          -1.2695e-01,  1.2598e-01],
         [ 0.0000e+00,  9.1406e-01, -2.5000e+00,  ...,  1.4062e-01,
           6.0938e-01, -8.7891e-01],
         [-1.0078e+00,  1.2500e+00, -2.9492e-01,  ..., -2.1680e-01,
          -4.3555e-01,  2.7734e-01],
         ...,
         [-3.6719e-01, -8.8672e-01,  2.0801e-01,  ..., -2.3633e-01,
          -5.4688e-01, -6.9922e-01],
         [ 2.1973e-01, -9.0234e-01, -9.9609e-02,  ..., -1.8672e+00,
          -4.6289e-01,  7.0312e-01],
         [ 1.6328e+00, -1.4453e+00, -4.0820e-01,  ...,  1.0498e-01,
          -1.6797e+00, -1.1562e+00]],

        [[ 6.3324e-04,  3.0518e-03, -2.5635e-03,  ..., -1.9043e-01,
          -1.2695e-01,  1.2598e-01],
         [ 0.0000e+00,  9.1406e-01, -2.5000e+00,  ...,  1.4062e-01,
           6.0938e-01, -8.7891e-01],
         [-1.0078e+00,  1.2500e+00, -2.9492e-01,  ..., -2.1680e-01,
          -4.3555e-01,  2.7734e-01],
         ...,
         [-3.6719e-01, -8.8672e-01,  2.0801e-01,  ..., -2.3633e-01,
          -5.4688e-01, -6.9922e-01],
         [ 2.1973e-01, -9.0234e-01, -9.9609e-02,  ..., -1.8672e+00,
          -4.6289e-01,  7.0312e-01],
         [ 1.6328e+00, -1.4453e+00, -4.0820e-01,  ...,  1.0498e-01,
          -1.6797e+00, -1.1562e+00]],

        [[ 6.3324e-04,  3.0518e-03, -2.5635e-03,  ..., -1.9043e-01,
          -1.2695e-01,  1.2598e-01],
         [ 0.0000e+00,  9.1406e-01, -2.5000e+00,  ...,  1.4062e-01,
           6.0938e-01, -8.7891e-01],
         [-1.0078e+00,  1.2500e+00, -2.9492e-01,  ..., -2.1680e-01,
          -4.3555e-01,  2.7734e-01],
         ...,
         [-3.6719e-01, -8.8672e-01,  2.0801e-01,  ..., -2.3633e-01,
          -5.4688e-01, -6.9922e-01],
         [ 2.1973e-01, -9.0234e-01, -9.9609e-02,  ..., -1.8672e+00,
          -4.6289e-01,  7.0312e-01],
         [ 1.6328e+00, -1.4453e+00, -4.0820e-01,  ...,  1.0498e-01,
          -1.6797e+00, -1.1562e+00]],

        ...,

        [[-1.1978e-03, -1.1841e-02, -8.0566e-03,  ...,  1.3965e-01,
           5.9082e-02,  1.1865e-01],
         [-2.5781e+00, -1.0078e+00,  1.7578e+00,  ..., -2.5156e+00,
           2.7734e-01, -6.7188e-01],
         [-1.5938e+00, -1.5312e+00,  1.4609e+00,  ..., -3.6719e-01,
           1.8125e+00, -1.7188e+00],
         ...,
         [-1.1719e+00,  1.3594e+00,  1.4609e+00,  ...,  7.3047e-01,
          -2.5469e+00,  2.1562e+00],
         [ 8.1250e-01,  6.4844e-01,  1.2031e+00,  ...,  1.8281e+00,
          -3.8672e-01,  2.3594e+00],
         [ 2.1875e+00, -5.3125e-01,  1.3203e+00,  ..., -3.3984e-01,
           1.5332e-01, -8.5449e-03]],

        [[-1.1978e-03, -1.1841e-02, -8.0566e-03,  ...,  1.3965e-01,
           5.9082e-02,  1.1865e-01],
         [-2.5781e+00, -1.0078e+00,  1.7578e+00,  ..., -2.5156e+00,
           2.7734e-01, -6.7188e-01],
         [-1.5938e+00, -1.5312e+00,  1.4609e+00,  ..., -3.6719e-01,
           1.8125e+00, -1.7188e+00],
         ...,
         [-1.1719e+00,  1.3594e+00,  1.4609e+00,  ...,  7.3047e-01,
          -2.5469e+00,  2.1562e+00],
         [ 8.1250e-01,  6.4844e-01,  1.2031e+00,  ...,  1.8281e+00,
          -3.8672e-01,  2.3594e+00],
         [ 2.1875e+00, -5.3125e-01,  1.3203e+00,  ..., -3.3984e-01,
           1.5332e-01, -8.5449e-03]],

        [[-1.1978e-03, -1.1841e-02, -8.0566e-03,  ...,  1.3965e-01,
           5.9082e-02,  1.1865e-01],
         [-2.5781e+00, -1.0078e+00,  1.7578e+00,  ..., -2.5156e+00,
           2.7734e-01, -6.7188e-01],
         [-1.5938e+00, -1.5312e+00,  1.4609e+00,  ..., -3.6719e-01,
           1.8125e+00, -1.7188e+00],
         ...,
         [-1.1719e+00,  1.3594e+00,  1.4609e+00,  ...,  7.3047e-01,
          -2.5469e+00,  2.1562e+00],
         [ 8.1250e-01,  6.4844e-01,  1.2031e+00,  ...,  1.8281e+00,
          -3.8672e-01,  2.3594e+00],
         [ 2.1875e+00, -5.3125e-01,  1.3203e+00,  ..., -3.3984e-01,
           1.5332e-01, -8.5449e-03]]], dtype=torch.bfloat16), tensor([[[ 2.4658e-02, -2.5558e-04, -4.5166e-03,  ..., -9.0942e-03,
           1.7822e-02, -9.8877e-03],
         [-3.9648e-01,  2.2656e-01,  2.7344e-01,  ..., -8.3984e-01,
           4.1602e-01, -3.8086e-01],
         [-1.7383e-01,  4.6094e-01, -1.5859e+00,  ...,  3.0469e-01,
          -1.4375e+00, -8.2812e-01],
         ...,
         [ 8.8281e-01,  2.2344e+00,  2.2656e+00,  ...,  2.0117e-01,
           1.2578e+00,  9.7656e-01],
         [ 1.3047e+00,  1.3203e+00,  2.3730e-01,  ..., -4.3945e-01,
           7.7344e-01, -8.5449e-02],
         [-8.9453e-01, -9.7168e-02, -4.8242e-01,  ...,  3.5156e-01,
           1.3438e+00, -3.8086e-01]],

        [[ 2.4658e-02, -2.5558e-04, -4.5166e-03,  ..., -9.0942e-03,
           1.7822e-02, -9.8877e-03],
         [-3.9648e-01,  2.2656e-01,  2.7344e-01,  ..., -8.3984e-01,
           4.1602e-01, -3.8086e-01],
         [-1.7383e-01,  4.6094e-01, -1.5859e+00,  ...,  3.0469e-01,
          -1.4375e+00, -8.2812e-01],
         ...,
         [ 8.8281e-01,  2.2344e+00,  2.2656e+00,  ...,  2.0117e-01,
           1.2578e+00,  9.7656e-01],
         [ 1.3047e+00,  1.3203e+00,  2.3730e-01,  ..., -4.3945e-01,
           7.7344e-01, -8.5449e-02],
         [-8.9453e-01, -9.7168e-02, -4.8242e-01,  ...,  3.5156e-01,
           1.3438e+00, -3.8086e-01]],

        [[ 2.4658e-02, -2.5558e-04, -4.5166e-03,  ..., -9.0942e-03,
           1.7822e-02, -9.8877e-03],
         [-3.9648e-01,  2.2656e-01,  2.7344e-01,  ..., -8.3984e-01,
           4.1602e-01, -3.8086e-01],
         [-1.7383e-01,  4.6094e-01, -1.5859e+00,  ...,  3.0469e-01,
          -1.4375e+00, -8.2812e-01],
         ...,
         [ 8.8281e-01,  2.2344e+00,  2.2656e+00,  ...,  2.0117e-01,
           1.2578e+00,  9.7656e-01],
         [ 1.3047e+00,  1.3203e+00,  2.3730e-01,  ..., -4.3945e-01,
           7.7344e-01, -8.5449e-02],
         [-8.9453e-01, -9.7168e-02, -4.8242e-01,  ...,  3.5156e-01,
           1.3438e+00, -3.8086e-01]],

        ...,

        [[-1.2085e-02, -7.4463e-03, -1.9836e-03,  ..., -1.0315e-02,
          -3.5858e-03, -1.2756e-02],
         [-3.5742e-01, -1.9336e-01, -2.2656e-01,  ..., -8.3984e-01,
           5.3906e-01, -1.9141e+00],
         [-9.7656e-03, -3.3594e-01,  3.7109e-02,  ...,  8.7109e-01,
          -7.7734e-01, -7.7734e-01],
         ...,
         [ 4.0625e-01,  1.0781e+00, -8.5938e-01,  ..., -9.2578e-01,
           1.7422e+00,  4.0234e-01],
         [ 3.3398e-01, -1.2500e-01, -1.4531e+00,  ...,  3.7305e-01,
           8.7500e-01,  1.2256e-01],
         [-9.5703e-01, -7.2266e-01, -2.3730e-01,  ..., -5.7031e-01,
           3.2031e-01,  2.2070e-01]],

        [[-1.2085e-02, -7.4463e-03, -1.9836e-03,  ..., -1.0315e-02,
          -3.5858e-03, -1.2756e-02],
         [-3.5742e-01, -1.9336e-01, -2.2656e-01,  ..., -8.3984e-01,
           5.3906e-01, -1.9141e+00],
         [-9.7656e-03, -3.3594e-01,  3.7109e-02,  ...,  8.7109e-01,
          -7.7734e-01, -7.7734e-01],
         ...,
         [ 4.0625e-01,  1.0781e+00, -8.5938e-01,  ..., -9.2578e-01,
           1.7422e+00,  4.0234e-01],
         [ 3.3398e-01, -1.2500e-01, -1.4531e+00,  ...,  3.7305e-01,
           8.7500e-01,  1.2256e-01],
         [-9.5703e-01, -7.2266e-01, -2.3730e-01,  ..., -5.7031e-01,
           3.2031e-01,  2.2070e-01]],

        [[-1.2085e-02, -7.4463e-03, -1.9836e-03,  ..., -1.0315e-02,
          -3.5858e-03, -1.2756e-02],
         [-3.5742e-01, -1.9336e-01, -2.2656e-01,  ..., -8.3984e-01,
           5.3906e-01, -1.9141e+00],
         [-9.7656e-03, -3.3594e-01,  3.7109e-02,  ...,  8.7109e-01,
          -7.7734e-01, -7.7734e-01],
         ...,
         [ 4.0625e-01,  1.0781e+00, -8.5938e-01,  ..., -9.2578e-01,
           1.7422e+00,  4.0234e-01],
         [ 3.3398e-01, -1.2500e-01, -1.4531e+00,  ...,  3.7305e-01,
           8.7500e-01,  1.2256e-01],
         [-9.5703e-01, -7.2266e-01, -2.3730e-01,  ..., -5.7031e-01,
           3.2031e-01,  2.2070e-01]]], dtype=torch.bfloat16)), (tensor([[[-3.1433e-03,  4.5776e-03,  1.3306e-02,  ..., -1.6016e-01,
           2.3535e-01,  3.0273e-01],
         [-3.1719e+00, -1.2656e+00, -1.2812e+00,  ...,  2.9883e-01,
          -2.4805e-01, -5.2734e-01],
         [-1.9141e-01, -8.1250e-01, -4.2969e-01,  ..., -3.5889e-02,
          -1.2031e+00, -1.0234e+00],
         ...,
         [ 1.4609e+00,  1.1875e+00, -1.1172e+00,  ...,  9.9609e-01,
          -6.0156e-01, -8.4375e-01],
         [ 1.9062e+00,  4.4727e-01, -1.7188e+00,  ...,  9.2969e-01,
          -5.7812e-01, -2.2266e-01],
         [ 2.6562e-01, -1.5078e+00, -5.0391e-01,  ...,  4.1602e-01,
           1.3281e+00, -1.7344e+00]],

        [[-3.1433e-03,  4.5776e-03,  1.3306e-02,  ..., -1.6016e-01,
           2.3535e-01,  3.0273e-01],
         [-3.1719e+00, -1.2656e+00, -1.2812e+00,  ...,  2.9883e-01,
          -2.4805e-01, -5.2734e-01],
         [-1.9141e-01, -8.1250e-01, -4.2969e-01,  ..., -3.5889e-02,
          -1.2031e+00, -1.0234e+00],
         ...,
         [ 1.4609e+00,  1.1875e+00, -1.1172e+00,  ...,  9.9609e-01,
          -6.0156e-01, -8.4375e-01],
         [ 1.9062e+00,  4.4727e-01, -1.7188e+00,  ...,  9.2969e-01,
          -5.7812e-01, -2.2266e-01],
         [ 2.6562e-01, -1.5078e+00, -5.0391e-01,  ...,  4.1602e-01,
           1.3281e+00, -1.7344e+00]],

        [[-3.1433e-03,  4.5776e-03,  1.3306e-02,  ..., -1.6016e-01,
           2.3535e-01,  3.0273e-01],
         [-3.1719e+00, -1.2656e+00, -1.2812e+00,  ...,  2.9883e-01,
          -2.4805e-01, -5.2734e-01],
         [-1.9141e-01, -8.1250e-01, -4.2969e-01,  ..., -3.5889e-02,
          -1.2031e+00, -1.0234e+00],
         ...,
         [ 1.4609e+00,  1.1875e+00, -1.1172e+00,  ...,  9.9609e-01,
          -6.0156e-01, -8.4375e-01],
         [ 1.9062e+00,  4.4727e-01, -1.7188e+00,  ...,  9.2969e-01,
          -5.7812e-01, -2.2266e-01],
         [ 2.6562e-01, -1.5078e+00, -5.0391e-01,  ...,  4.1602e-01,
           1.3281e+00, -1.7344e+00]],

        ...,

        [[ 6.1798e-04,  9.0027e-04, -2.0996e-02,  ...,  1.0681e-03,
          -4.8828e-02, -1.9336e-01],
         [-3.7031e+00, -1.8047e+00, -6.6406e-02,  ...,  2.6562e-01,
           1.8594e+00,  1.1914e-01],
         [-7.6953e-01, -1.3828e+00,  2.4414e-02,  ...,  5.9375e-01,
          -1.5391e+00, -1.1484e+00],
         ...,
         [ 1.7500e+00,  5.2344e-01, -1.0391e+00,  ...,  1.0312e+00,
           8.5547e-01, -4.9805e-01],
         [ 9.5312e-01,  3.4375e-01,  1.3281e-01,  ...,  3.4961e-01,
          -2.2344e+00, -6.3672e-01],
         [ 8.1641e-01, -8.9844e-01,  4.1016e-01,  ...,  1.2969e+00,
           3.2031e-01, -1.2734e+00]],

        [[ 6.1798e-04,  9.0027e-04, -2.0996e-02,  ...,  1.0681e-03,
          -4.8828e-02, -1.9336e-01],
         [-3.7031e+00, -1.8047e+00, -6.6406e-02,  ...,  2.6562e-01,
           1.8594e+00,  1.1914e-01],
         [-7.6953e-01, -1.3828e+00,  2.4414e-02,  ...,  5.9375e-01,
          -1.5391e+00, -1.1484e+00],
         ...,
         [ 1.7500e+00,  5.2344e-01, -1.0391e+00,  ...,  1.0312e+00,
           8.5547e-01, -4.9805e-01],
         [ 9.5312e-01,  3.4375e-01,  1.3281e-01,  ...,  3.4961e-01,
          -2.2344e+00, -6.3672e-01],
         [ 8.1641e-01, -8.9844e-01,  4.1016e-01,  ...,  1.2969e+00,
           3.2031e-01, -1.2734e+00]],

        [[ 6.1798e-04,  9.0027e-04, -2.0996e-02,  ...,  1.0681e-03,
          -4.8828e-02, -1.9336e-01],
         [-3.7031e+00, -1.8047e+00, -6.6406e-02,  ...,  2.6562e-01,
           1.8594e+00,  1.1914e-01],
         [-7.6953e-01, -1.3828e+00,  2.4414e-02,  ...,  5.9375e-01,
          -1.5391e+00, -1.1484e+00],
         ...,
         [ 1.7500e+00,  5.2344e-01, -1.0391e+00,  ...,  1.0312e+00,
           8.5547e-01, -4.9805e-01],
         [ 9.5312e-01,  3.4375e-01,  1.3281e-01,  ...,  3.4961e-01,
          -2.2344e+00, -6.3672e-01],
         [ 8.1641e-01, -8.9844e-01,  4.1016e-01,  ...,  1.2969e+00,
           3.2031e-01, -1.2734e+00]]], dtype=torch.bfloat16), tensor([[[ 1.0925e-02, -2.8229e-03, -1.0071e-02,  ..., -1.0803e-02,
           2.3193e-02, -4.3030e-03],
         [-3.6328e-01,  2.4512e-01, -1.2031e+00,  ..., -1.0498e-02,
           2.5781e-01, -9.3359e-01],
         [-7.7637e-02,  4.9219e-01, -2.6367e-01,  ...,  1.1016e+00,
          -1.1562e+00, -8.4473e-02],
         ...,
         [-9.2578e-01,  8.5938e-01, -5.4297e-01,  ...,  4.4727e-01,
          -5.5469e-01, -4.5312e-01],
         [-2.0996e-02, -5.0391e-01, -8.0566e-02,  ...,  7.0312e-01,
          -1.3906e+00,  1.2812e+00],
         [ 2.0020e-01,  5.0781e-01, -1.0547e+00,  ...,  5.4688e-01,
          -9.6484e-01,  7.9688e-01]],

        [[ 1.0925e-02, -2.8229e-03, -1.0071e-02,  ..., -1.0803e-02,
           2.3193e-02, -4.3030e-03],
         [-3.6328e-01,  2.4512e-01, -1.2031e+00,  ..., -1.0498e-02,
           2.5781e-01, -9.3359e-01],
         [-7.7637e-02,  4.9219e-01, -2.6367e-01,  ...,  1.1016e+00,
          -1.1562e+00, -8.4473e-02],
         ...,
         [-9.2578e-01,  8.5938e-01, -5.4297e-01,  ...,  4.4727e-01,
          -5.5469e-01, -4.5312e-01],
         [-2.0996e-02, -5.0391e-01, -8.0566e-02,  ...,  7.0312e-01,
          -1.3906e+00,  1.2812e+00],
         [ 2.0020e-01,  5.0781e-01, -1.0547e+00,  ...,  5.4688e-01,
          -9.6484e-01,  7.9688e-01]],

        [[ 1.0925e-02, -2.8229e-03, -1.0071e-02,  ..., -1.0803e-02,
           2.3193e-02, -4.3030e-03],
         [-3.6328e-01,  2.4512e-01, -1.2031e+00,  ..., -1.0498e-02,
           2.5781e-01, -9.3359e-01],
         [-7.7637e-02,  4.9219e-01, -2.6367e-01,  ...,  1.1016e+00,
          -1.1562e+00, -8.4473e-02],
         ...,
         [-9.2578e-01,  8.5938e-01, -5.4297e-01,  ...,  4.4727e-01,
          -5.5469e-01, -4.5312e-01],
         [-2.0996e-02, -5.0391e-01, -8.0566e-02,  ...,  7.0312e-01,
          -1.3906e+00,  1.2812e+00],
         [ 2.0020e-01,  5.0781e-01, -1.0547e+00,  ...,  5.4688e-01,
          -9.6484e-01,  7.9688e-01]],

        ...,

        [[-6.8283e-04, -3.7079e-03,  7.7820e-04,  ...,  2.1118e-02,
           6.7444e-03,  5.3406e-03],
         [-2.0508e-01, -2.8516e-01,  4.9609e-01,  ...,  1.5938e+00,
           1.1719e+00,  5.5078e-01],
         [-2.8125e-01, -1.6992e-01, -5.9766e-01,  ...,  4.9414e-01,
           1.4160e-02,  4.6289e-01],
         ...,
         [ 1.8164e-01, -1.5312e+00, -2.0117e-01,  ...,  1.2188e+00,
          -2.3594e+00,  3.1055e-01],
         [ 7.1777e-02,  7.4609e-01,  1.0156e+00,  ..., -2.9297e-01,
          -8.0859e-01,  1.7031e+00],
         [-5.6250e-01,  8.3496e-02,  2.1973e-01,  ..., -5.0391e-01,
           2.4121e-01, -2.0142e-02]],

        [[-6.8283e-04, -3.7079e-03,  7.7820e-04,  ...,  2.1118e-02,
           6.7444e-03,  5.3406e-03],
         [-2.0508e-01, -2.8516e-01,  4.9609e-01,  ...,  1.5938e+00,
           1.1719e+00,  5.5078e-01],
         [-2.8125e-01, -1.6992e-01, -5.9766e-01,  ...,  4.9414e-01,
           1.4160e-02,  4.6289e-01],
         ...,
         [ 1.8164e-01, -1.5312e+00, -2.0117e-01,  ...,  1.2188e+00,
          -2.3594e+00,  3.1055e-01],
         [ 7.1777e-02,  7.4609e-01,  1.0156e+00,  ..., -2.9297e-01,
          -8.0859e-01,  1.7031e+00],
         [-5.6250e-01,  8.3496e-02,  2.1973e-01,  ..., -5.0391e-01,
           2.4121e-01, -2.0142e-02]],

        [[-6.8283e-04, -3.7079e-03,  7.7820e-04,  ...,  2.1118e-02,
           6.7444e-03,  5.3406e-03],
         [-2.0508e-01, -2.8516e-01,  4.9609e-01,  ...,  1.5938e+00,
           1.1719e+00,  5.5078e-01],
         [-2.8125e-01, -1.6992e-01, -5.9766e-01,  ...,  4.9414e-01,
           1.4160e-02,  4.6289e-01],
         ...,
         [ 1.8164e-01, -1.5312e+00, -2.0117e-01,  ...,  1.2188e+00,
          -2.3594e+00,  3.1055e-01],
         [ 7.1777e-02,  7.4609e-01,  1.0156e+00,  ..., -2.9297e-01,
          -8.0859e-01,  1.7031e+00],
         [-5.6250e-01,  8.3496e-02,  2.1973e-01,  ..., -5.0391e-01,
           2.4121e-01, -2.0142e-02]]], dtype=torch.bfloat16)), (tensor([[[ 1.0437e-02, -6.3782e-03, -7.9346e-03,  ..., -1.8750e+00,
           8.4473e-02,  7.5781e-01],
         [-4.1016e-01,  6.1719e-01,  5.3906e-01,  ...,  1.1016e+00,
           1.3281e-01, -5.6562e+00],
         [ 4.3945e-01, -1.9336e-01,  2.6562e-01,  ...,  4.3438e+00,
          -3.1250e+00, -2.8594e+00],
         ...,
         [-1.1816e-01, -6.2109e-01, -6.2891e-01,  ...,  3.6562e+00,
          -1.0469e+00, -5.0391e-01],
         [-3.8281e-01,  2.4609e-01, -2.2461e-02,  ...,  4.1562e+00,
           1.1914e-01, -2.3281e+00],
         [ 4.9805e-01, -4.1016e-01,  5.5078e-01,  ...,  3.8125e+00,
           3.5352e-01, -2.0781e+00]],

        [[ 1.0437e-02, -6.3782e-03, -7.9346e-03,  ..., -1.8750e+00,
           8.4473e-02,  7.5781e-01],
         [-4.1016e-01,  6.1719e-01,  5.3906e-01,  ...,  1.1016e+00,
           1.3281e-01, -5.6562e+00],
         [ 4.3945e-01, -1.9336e-01,  2.6562e-01,  ...,  4.3438e+00,
          -3.1250e+00, -2.8594e+00],
         ...,
         [-1.1816e-01, -6.2109e-01, -6.2891e-01,  ...,  3.6562e+00,
          -1.0469e+00, -5.0391e-01],
         [-3.8281e-01,  2.4609e-01, -2.2461e-02,  ...,  4.1562e+00,
           1.1914e-01, -2.3281e+00],
         [ 4.9805e-01, -4.1016e-01,  5.5078e-01,  ...,  3.8125e+00,
           3.5352e-01, -2.0781e+00]],

        [[ 1.0437e-02, -6.3782e-03, -7.9346e-03,  ..., -1.8750e+00,
           8.4473e-02,  7.5781e-01],
         [-4.1016e-01,  6.1719e-01,  5.3906e-01,  ...,  1.1016e+00,
           1.3281e-01, -5.6562e+00],
         [ 4.3945e-01, -1.9336e-01,  2.6562e-01,  ...,  4.3438e+00,
          -3.1250e+00, -2.8594e+00],
         ...,
         [-1.1816e-01, -6.2109e-01, -6.2891e-01,  ...,  3.6562e+00,
          -1.0469e+00, -5.0391e-01],
         [-3.8281e-01,  2.4609e-01, -2.2461e-02,  ...,  4.1562e+00,
           1.1914e-01, -2.3281e+00],
         [ 4.9805e-01, -4.1016e-01,  5.5078e-01,  ...,  3.8125e+00,
           3.5352e-01, -2.0781e+00]],

        ...,

        [[-6.0730e-03, -9.9487e-03, -2.1057e-03,  ...,  1.2158e-01,
           2.1484e-01,  6.1646e-03],
         [-1.2656e+00, -9.0625e-01,  1.5625e+00,  ...,  1.6641e+00,
           8.0469e-01, -7.3438e-01],
         [-1.6211e-01,  7.6172e-02,  1.0312e+00,  ...,  5.5625e+00,
           3.5312e+00,  3.0664e-01],
         ...,
         [ 1.1670e-01,  7.7344e-01,  5.9375e-01,  ...,  3.0312e+00,
           3.0000e+00,  2.7969e+00],
         [-1.6211e-01,  8.7891e-01,  1.2109e+00,  ...,  4.0625e+00,
           4.0312e+00,  7.2188e+00],
         [ 1.2734e+00,  9.9219e-01,  6.0156e-01,  ...,  4.4062e+00,
           2.1875e+00,  2.8594e+00]],

        [[-6.0730e-03, -9.9487e-03, -2.1057e-03,  ...,  1.2158e-01,
           2.1484e-01,  6.1646e-03],
         [-1.2656e+00, -9.0625e-01,  1.5625e+00,  ...,  1.6641e+00,
           8.0469e-01, -7.3438e-01],
         [-1.6211e-01,  7.6172e-02,  1.0312e+00,  ...,  5.5625e+00,
           3.5312e+00,  3.0664e-01],
         ...,
         [ 1.1670e-01,  7.7344e-01,  5.9375e-01,  ...,  3.0312e+00,
           3.0000e+00,  2.7969e+00],
         [-1.6211e-01,  8.7891e-01,  1.2109e+00,  ...,  4.0625e+00,
           4.0312e+00,  7.2188e+00],
         [ 1.2734e+00,  9.9219e-01,  6.0156e-01,  ...,  4.4062e+00,
           2.1875e+00,  2.8594e+00]],

        [[-6.0730e-03, -9.9487e-03, -2.1057e-03,  ...,  1.2158e-01,
           2.1484e-01,  6.1646e-03],
         [-1.2656e+00, -9.0625e-01,  1.5625e+00,  ...,  1.6641e+00,
           8.0469e-01, -7.3438e-01],
         [-1.6211e-01,  7.6172e-02,  1.0312e+00,  ...,  5.5625e+00,
           3.5312e+00,  3.0664e-01],
         ...,
         [ 1.1670e-01,  7.7344e-01,  5.9375e-01,  ...,  3.0312e+00,
           3.0000e+00,  2.7969e+00],
         [-1.6211e-01,  8.7891e-01,  1.2109e+00,  ...,  4.0625e+00,
           4.0312e+00,  7.2188e+00],
         [ 1.2734e+00,  9.9219e-01,  6.0156e-01,  ...,  4.4062e+00,
           2.1875e+00,  2.8594e+00]]], dtype=torch.bfloat16), tensor([[[-1.1658e-02,  3.4332e-03,  1.4099e-02,  ...,  2.7954e-02,
           8.4229e-03, -1.0010e-02],
         [-8.8672e-01,  2.0156e+00, -1.7500e+00,  ..., -9.3750e-01,
           2.1719e+00,  2.9531e+00],
         [ 1.2969e+00, -3.3750e+00,  9.1406e-01,  ...,  8.5156e-01,
           4.4727e-01,  2.2031e+00],
         ...,
         [ 7.4609e-01,  1.5469e+00, -8.4375e-01,  ..., -7.7344e-01,
           7.9297e-01, -1.8433e-02],
         [ 4.6484e-01, -8.1250e-01, -7.1289e-02,  ...,  4.9219e-01,
          -4.0039e-01,  2.4316e-01],
         [ 2.3438e+00, -8.5156e-01,  4.4727e-01,  ..., -3.8867e-01,
          -8.7891e-01, -5.0293e-02]],

        [[-1.1658e-02,  3.4332e-03,  1.4099e-02,  ...,  2.7954e-02,
           8.4229e-03, -1.0010e-02],
         [-8.8672e-01,  2.0156e+00, -1.7500e+00,  ..., -9.3750e-01,
           2.1719e+00,  2.9531e+00],
         [ 1.2969e+00, -3.3750e+00,  9.1406e-01,  ...,  8.5156e-01,
           4.4727e-01,  2.2031e+00],
         ...,
         [ 7.4609e-01,  1.5469e+00, -8.4375e-01,  ..., -7.7344e-01,
           7.9297e-01, -1.8433e-02],
         [ 4.6484e-01, -8.1250e-01, -7.1289e-02,  ...,  4.9219e-01,
          -4.0039e-01,  2.4316e-01],
         [ 2.3438e+00, -8.5156e-01,  4.4727e-01,  ..., -3.8867e-01,
          -8.7891e-01, -5.0293e-02]],

        [[-1.1658e-02,  3.4332e-03,  1.4099e-02,  ...,  2.7954e-02,
           8.4229e-03, -1.0010e-02],
         [-8.8672e-01,  2.0156e+00, -1.7500e+00,  ..., -9.3750e-01,
           2.1719e+00,  2.9531e+00],
         [ 1.2969e+00, -3.3750e+00,  9.1406e-01,  ...,  8.5156e-01,
           4.4727e-01,  2.2031e+00],
         ...,
         [ 7.4609e-01,  1.5469e+00, -8.4375e-01,  ..., -7.7344e-01,
           7.9297e-01, -1.8433e-02],
         [ 4.6484e-01, -8.1250e-01, -7.1289e-02,  ...,  4.9219e-01,
          -4.0039e-01,  2.4316e-01],
         [ 2.3438e+00, -8.5156e-01,  4.4727e-01,  ..., -3.8867e-01,
          -8.7891e-01, -5.0293e-02]],

        ...,

        [[ 2.0599e-03,  1.2146e-02, -2.2095e-02,  ...,  1.0620e-02,
          -2.1057e-03,  3.3264e-03],
         [ 2.0000e+00,  6.7578e-01,  7.5391e-01,  ...,  1.5469e+00,
           4.0527e-02, -1.3438e+00],
         [ 4.1797e-01,  1.6406e+00,  1.8672e+00,  ...,  1.5234e+00,
          -2.3281e+00, -2.5625e+00],
         ...,
         [ 4.7500e+00,  3.3594e+00,  2.2188e+00,  ...,  1.7656e+00,
          -4.9688e+00, -2.3281e+00],
         [ 6.2812e+00,  3.4844e+00,  3.2188e+00,  ...,  4.0000e+00,
          -5.7500e+00, -4.1250e+00],
         [ 4.0000e+00,  1.9141e+00,  1.7109e+00,  ...,  3.0000e+00,
          -4.2500e+00, -2.1094e+00]],

        [[ 2.0599e-03,  1.2146e-02, -2.2095e-02,  ...,  1.0620e-02,
          -2.1057e-03,  3.3264e-03],
         [ 2.0000e+00,  6.7578e-01,  7.5391e-01,  ...,  1.5469e+00,
           4.0527e-02, -1.3438e+00],
         [ 4.1797e-01,  1.6406e+00,  1.8672e+00,  ...,  1.5234e+00,
          -2.3281e+00, -2.5625e+00],
         ...,
         [ 4.7500e+00,  3.3594e+00,  2.2188e+00,  ...,  1.7656e+00,
          -4.9688e+00, -2.3281e+00],
         [ 6.2812e+00,  3.4844e+00,  3.2188e+00,  ...,  4.0000e+00,
          -5.7500e+00, -4.1250e+00],
         [ 4.0000e+00,  1.9141e+00,  1.7109e+00,  ...,  3.0000e+00,
          -4.2500e+00, -2.1094e+00]],

        [[ 2.0599e-03,  1.2146e-02, -2.2095e-02,  ...,  1.0620e-02,
          -2.1057e-03,  3.3264e-03],
         [ 2.0000e+00,  6.7578e-01,  7.5391e-01,  ...,  1.5469e+00,
           4.0527e-02, -1.3438e+00],
         [ 4.1797e-01,  1.6406e+00,  1.8672e+00,  ...,  1.5234e+00,
          -2.3281e+00, -2.5625e+00],
         ...,
         [ 4.7500e+00,  3.3594e+00,  2.2188e+00,  ...,  1.7656e+00,
          -4.9688e+00, -2.3281e+00],
         [ 6.2812e+00,  3.4844e+00,  3.2188e+00,  ...,  4.0000e+00,
          -5.7500e+00, -4.1250e+00],
         [ 4.0000e+00,  1.9141e+00,  1.7109e+00,  ...,  3.0000e+00,
          -4.2500e+00, -2.1094e+00]]], dtype=torch.bfloat16)), (tensor([[[ 8.5831e-04, -4.5166e-03,  2.7008e-03,  ...,  4.5166e-02,
           3.1055e-01, -8.4961e-02],
         [ 3.1250e+00,  2.2656e+00, -6.4062e-01,  ...,  6.3281e-01,
          -5.9375e-01,  1.2500e+00],
         [ 1.9375e+00,  1.9688e+00,  1.0625e+00,  ...,  5.1953e-01,
           5.8105e-02,  9.5312e-01],
         ...,
         [-2.2363e-01, -1.3984e+00, -1.7578e-01,  ...,  8.3984e-02,
           8.5547e-01,  1.1182e-01],
         [-9.4922e-01,  4.0430e-01, -1.0547e+00,  ...,  2.3340e-01,
          -5.6396e-02, -5.5859e-01],
         [-2.4375e+00,  7.4219e-02, -1.4844e+00,  ...,  9.2188e-01,
          -8.8281e-01, -3.1055e-01]],

        [[ 8.5831e-04, -4.5166e-03,  2.7008e-03,  ...,  4.5166e-02,
           3.1055e-01, -8.4961e-02],
         [ 3.1250e+00,  2.2656e+00, -6.4062e-01,  ...,  6.3281e-01,
          -5.9375e-01,  1.2500e+00],
         [ 1.9375e+00,  1.9688e+00,  1.0625e+00,  ...,  5.1953e-01,
           5.8105e-02,  9.5312e-01],
         ...,
         [-2.2363e-01, -1.3984e+00, -1.7578e-01,  ...,  8.3984e-02,
           8.5547e-01,  1.1182e-01],
         [-9.4922e-01,  4.0430e-01, -1.0547e+00,  ...,  2.3340e-01,
          -5.6396e-02, -5.5859e-01],
         [-2.4375e+00,  7.4219e-02, -1.4844e+00,  ...,  9.2188e-01,
          -8.8281e-01, -3.1055e-01]],

        [[ 8.5831e-04, -4.5166e-03,  2.7008e-03,  ...,  4.5166e-02,
           3.1055e-01, -8.4961e-02],
         [ 3.1250e+00,  2.2656e+00, -6.4062e-01,  ...,  6.3281e-01,
          -5.9375e-01,  1.2500e+00],
         [ 1.9375e+00,  1.9688e+00,  1.0625e+00,  ...,  5.1953e-01,
           5.8105e-02,  9.5312e-01],
         ...,
         [-2.2363e-01, -1.3984e+00, -1.7578e-01,  ...,  8.3984e-02,
           8.5547e-01,  1.1182e-01],
         [-9.4922e-01,  4.0430e-01, -1.0547e+00,  ...,  2.3340e-01,
          -5.6396e-02, -5.5859e-01],
         [-2.4375e+00,  7.4219e-02, -1.4844e+00,  ...,  9.2188e-01,
          -8.8281e-01, -3.1055e-01]],

        ...,

        [[ 3.2959e-03, -2.9755e-03, -6.4392e-03,  ...,  2.6562e-01,
           4.3945e-01,  1.4551e-01],
         [-3.2031e+00,  1.4766e+00, -1.5938e+00,  ..., -3.1641e-01,
          -1.6484e+00, -4.7852e-01],
         [ 2.7344e-02,  1.8262e-01, -1.2031e+00,  ..., -1.3359e+00,
          -2.0781e+00,  1.5938e+00],
         ...,
         [ 1.8281e+00, -2.1562e+00, -1.8203e+00,  ..., -9.1797e-01,
          -1.2500e+00,  4.3750e-01],
         [ 9.6094e-01,  2.9297e-01, -8.2422e-01,  ...,  7.1484e-01,
          -2.2812e+00,  1.5859e+00],
         [ 3.1250e-01,  1.4219e+00, -1.6250e+00,  ..., -1.0547e+00,
          -2.0156e+00,  2.5391e-01]],

        [[ 3.2959e-03, -2.9755e-03, -6.4392e-03,  ...,  2.6562e-01,
           4.3945e-01,  1.4551e-01],
         [-3.2031e+00,  1.4766e+00, -1.5938e+00,  ..., -3.1641e-01,
          -1.6484e+00, -4.7852e-01],
         [ 2.7344e-02,  1.8262e-01, -1.2031e+00,  ..., -1.3359e+00,
          -2.0781e+00,  1.5938e+00],
         ...,
         [ 1.8281e+00, -2.1562e+00, -1.8203e+00,  ..., -9.1797e-01,
          -1.2500e+00,  4.3750e-01],
         [ 9.6094e-01,  2.9297e-01, -8.2422e-01,  ...,  7.1484e-01,
          -2.2812e+00,  1.5859e+00],
         [ 3.1250e-01,  1.4219e+00, -1.6250e+00,  ..., -1.0547e+00,
          -2.0156e+00,  2.5391e-01]],

        [[ 3.2959e-03, -2.9755e-03, -6.4392e-03,  ...,  2.6562e-01,
           4.3945e-01,  1.4551e-01],
         [-3.2031e+00,  1.4766e+00, -1.5938e+00,  ..., -3.1641e-01,
          -1.6484e+00, -4.7852e-01],
         [ 2.7344e-02,  1.8262e-01, -1.2031e+00,  ..., -1.3359e+00,
          -2.0781e+00,  1.5938e+00],
         ...,
         [ 1.8281e+00, -2.1562e+00, -1.8203e+00,  ..., -9.1797e-01,
          -1.2500e+00,  4.3750e-01],
         [ 9.6094e-01,  2.9297e-01, -8.2422e-01,  ...,  7.1484e-01,
          -2.2812e+00,  1.5859e+00],
         [ 3.1250e-01,  1.4219e+00, -1.6250e+00,  ..., -1.0547e+00,
          -2.0156e+00,  2.5391e-01]]], dtype=torch.bfloat16), tensor([[[-2.0264e-02,  7.4005e-04, -1.5869e-02,  ..., -1.4893e-02,
           4.4250e-03, -2.7771e-03],
         [ 5.3125e-01,  7.6172e-01,  1.8203e+00,  ..., -3.2812e-01,
          -6.6016e-01, -5.2734e-01],
         [ 1.6328e+00,  1.3438e+00,  4.5898e-01,  ...,  7.7344e-01,
          -7.0703e-01,  9.9219e-01],
         ...,
         [ 3.2812e-01, -8.9062e-01,  1.7031e+00,  ...,  2.9102e-01,
           1.0078e+00,  5.1953e-01],
         [-2.7148e-01, -9.5703e-01, -7.4219e-01,  ..., -3.8281e-01,
           6.9531e-01, -4.5410e-02],
         [ 6.9141e-01,  8.4766e-01, -4.0039e-01,  ..., -1.0000e+00,
           7.8906e-01, -1.0781e+00]],

        [[-2.0264e-02,  7.4005e-04, -1.5869e-02,  ..., -1.4893e-02,
           4.4250e-03, -2.7771e-03],
         [ 5.3125e-01,  7.6172e-01,  1.8203e+00,  ..., -3.2812e-01,
          -6.6016e-01, -5.2734e-01],
         [ 1.6328e+00,  1.3438e+00,  4.5898e-01,  ...,  7.7344e-01,
          -7.0703e-01,  9.9219e-01],
         ...,
         [ 3.2812e-01, -8.9062e-01,  1.7031e+00,  ...,  2.9102e-01,
           1.0078e+00,  5.1953e-01],
         [-2.7148e-01, -9.5703e-01, -7.4219e-01,  ..., -3.8281e-01,
           6.9531e-01, -4.5410e-02],
         [ 6.9141e-01,  8.4766e-01, -4.0039e-01,  ..., -1.0000e+00,
           7.8906e-01, -1.0781e+00]],

        [[-2.0264e-02,  7.4005e-04, -1.5869e-02,  ..., -1.4893e-02,
           4.4250e-03, -2.7771e-03],
         [ 5.3125e-01,  7.6172e-01,  1.8203e+00,  ..., -3.2812e-01,
          -6.6016e-01, -5.2734e-01],
         [ 1.6328e+00,  1.3438e+00,  4.5898e-01,  ...,  7.7344e-01,
          -7.0703e-01,  9.9219e-01],
         ...,
         [ 3.2812e-01, -8.9062e-01,  1.7031e+00,  ...,  2.9102e-01,
           1.0078e+00,  5.1953e-01],
         [-2.7148e-01, -9.5703e-01, -7.4219e-01,  ..., -3.8281e-01,
           6.9531e-01, -4.5410e-02],
         [ 6.9141e-01,  8.4766e-01, -4.0039e-01,  ..., -1.0000e+00,
           7.8906e-01, -1.0781e+00]],

        ...,

        [[ 6.8665e-04, -4.1809e-03, -1.6968e-02,  ...,  1.6235e-02,
           5.7068e-03,  1.3062e-02],
         [ 2.1191e-01, -4.6289e-01,  7.0703e-01,  ..., -3.9648e-01,
          -3.3984e-01,  4.9805e-01],
         [-1.2793e-01, -1.2500e-01,  1.2734e+00,  ...,  2.5391e-01,
           2.8906e-01, -1.5469e+00],
         ...,
         [-1.2734e+00,  1.2734e+00,  1.4453e+00,  ...,  1.0938e+00,
          -9.8047e-01, -9.0625e-01],
         [ 4.0430e-01, -1.2891e-01,  1.4141e+00,  ..., -4.1562e+00,
           3.5938e-01, -1.0703e+00],
         [ 8.5938e-01,  7.6562e-01,  4.2969e-01,  ...,  5.6641e-02,
          -6.9141e-01, -9.3750e-01]],

        [[ 6.8665e-04, -4.1809e-03, -1.6968e-02,  ...,  1.6235e-02,
           5.7068e-03,  1.3062e-02],
         [ 2.1191e-01, -4.6289e-01,  7.0703e-01,  ..., -3.9648e-01,
          -3.3984e-01,  4.9805e-01],
         [-1.2793e-01, -1.2500e-01,  1.2734e+00,  ...,  2.5391e-01,
           2.8906e-01, -1.5469e+00],
         ...,
         [-1.2734e+00,  1.2734e+00,  1.4453e+00,  ...,  1.0938e+00,
          -9.8047e-01, -9.0625e-01],
         [ 4.0430e-01, -1.2891e-01,  1.4141e+00,  ..., -4.1562e+00,
           3.5938e-01, -1.0703e+00],
         [ 8.5938e-01,  7.6562e-01,  4.2969e-01,  ...,  5.6641e-02,
          -6.9141e-01, -9.3750e-01]],

        [[ 6.8665e-04, -4.1809e-03, -1.6968e-02,  ...,  1.6235e-02,
           5.7068e-03,  1.3062e-02],
         [ 2.1191e-01, -4.6289e-01,  7.0703e-01,  ..., -3.9648e-01,
          -3.3984e-01,  4.9805e-01],
         [-1.2793e-01, -1.2500e-01,  1.2734e+00,  ...,  2.5391e-01,
           2.8906e-01, -1.5469e+00],
         ...,
         [-1.2734e+00,  1.2734e+00,  1.4453e+00,  ...,  1.0938e+00,
          -9.8047e-01, -9.0625e-01],
         [ 4.0430e-01, -1.2891e-01,  1.4141e+00,  ..., -4.1562e+00,
           3.5938e-01, -1.0703e+00],
         [ 8.5938e-01,  7.6562e-01,  4.2969e-01,  ...,  5.6641e-02,
          -6.9141e-01, -9.3750e-01]]], dtype=torch.bfloat16)), (tensor([[[ 6.1340e-03, -1.6724e-02,  5.8365e-04,  ...,  3.6328e-01,
          -4.5898e-02, -1.3086e-01],
         [ 1.1641e+00,  6.7969e-01, -1.1875e+00,  ..., -9.2188e-01,
          -5.7373e-02,  9.6094e-01],
         [ 1.6875e+00,  1.3438e+00, -6.3281e-01,  ..., -1.5078e+00,
          -4.3164e-01,  1.9375e+00],
         ...,
         [ 1.7031e+00, -1.3438e+00, -4.6094e-01,  ..., -4.1016e-01,
          -1.0938e+00,  1.2344e+00],
         [ 5.5469e-01, -5.4297e-01, -1.8438e+00,  ..., -9.7656e-01,
          -8.5938e-02,  2.0938e+00],
         [-2.5000e+00,  1.5156e+00, -3.6133e-01,  ..., -8.9062e-01,
          -3.5742e-01,  1.1250e+00]],

        [[ 6.1340e-03, -1.6724e-02,  5.8365e-04,  ...,  3.6328e-01,
          -4.5898e-02, -1.3086e-01],
         [ 1.1641e+00,  6.7969e-01, -1.1875e+00,  ..., -9.2188e-01,
          -5.7373e-02,  9.6094e-01],
         [ 1.6875e+00,  1.3438e+00, -6.3281e-01,  ..., -1.5078e+00,
          -4.3164e-01,  1.9375e+00],
         ...,
         [ 1.7031e+00, -1.3438e+00, -4.6094e-01,  ..., -4.1016e-01,
          -1.0938e+00,  1.2344e+00],
         [ 5.5469e-01, -5.4297e-01, -1.8438e+00,  ..., -9.7656e-01,
          -8.5938e-02,  2.0938e+00],
         [-2.5000e+00,  1.5156e+00, -3.6133e-01,  ..., -8.9062e-01,
          -3.5742e-01,  1.1250e+00]],

        [[ 6.1340e-03, -1.6724e-02,  5.8365e-04,  ...,  3.6328e-01,
          -4.5898e-02, -1.3086e-01],
         [ 1.1641e+00,  6.7969e-01, -1.1875e+00,  ..., -9.2188e-01,
          -5.7373e-02,  9.6094e-01],
         [ 1.6875e+00,  1.3438e+00, -6.3281e-01,  ..., -1.5078e+00,
          -4.3164e-01,  1.9375e+00],
         ...,
         [ 1.7031e+00, -1.3438e+00, -4.6094e-01,  ..., -4.1016e-01,
          -1.0938e+00,  1.2344e+00],
         [ 5.5469e-01, -5.4297e-01, -1.8438e+00,  ..., -9.7656e-01,
          -8.5938e-02,  2.0938e+00],
         [-2.5000e+00,  1.5156e+00, -3.6133e-01,  ..., -8.9062e-01,
          -3.5742e-01,  1.1250e+00]],

        ...,

        [[ 1.5076e-02,  1.1780e-02,  4.5471e-03,  ..., -1.2012e-01,
          -1.2024e-02, -1.5391e+00],
         [ 3.3125e+00, -1.0156e-01, -9.9219e-01,  ..., -3.9062e-01,
           5.7812e-01,  6.6250e+00],
         [ 2.0781e+00,  1.8438e+00, -9.1406e-01,  ..., -1.8359e+00,
           4.6875e-02,  8.0625e+00],
         ...,
         [ 4.1406e-01, -7.4219e-01, -2.2461e-01,  ...,  7.3438e-01,
          -1.3984e+00,  7.0938e+00],
         [-1.5859e+00, -2.0781e+00,  7.0312e-02,  ...,  1.1875e+00,
          -3.4961e-01,  7.5938e+00],
         [-2.8281e+00, -1.8594e+00, -4.6094e-01,  ...,  9.6094e-01,
           1.1172e+00,  7.6562e+00]],

        [[ 1.5076e-02,  1.1780e-02,  4.5471e-03,  ..., -1.2012e-01,
          -1.2024e-02, -1.5391e+00],
         [ 3.3125e+00, -1.0156e-01, -9.9219e-01,  ..., -3.9062e-01,
           5.7812e-01,  6.6250e+00],
         [ 2.0781e+00,  1.8438e+00, -9.1406e-01,  ..., -1.8359e+00,
           4.6875e-02,  8.0625e+00],
         ...,
         [ 4.1406e-01, -7.4219e-01, -2.2461e-01,  ...,  7.3438e-01,
          -1.3984e+00,  7.0938e+00],
         [-1.5859e+00, -2.0781e+00,  7.0312e-02,  ...,  1.1875e+00,
          -3.4961e-01,  7.5938e+00],
         [-2.8281e+00, -1.8594e+00, -4.6094e-01,  ...,  9.6094e-01,
           1.1172e+00,  7.6562e+00]],

        [[ 1.5076e-02,  1.1780e-02,  4.5471e-03,  ..., -1.2012e-01,
          -1.2024e-02, -1.5391e+00],
         [ 3.3125e+00, -1.0156e-01, -9.9219e-01,  ..., -3.9062e-01,
           5.7812e-01,  6.6250e+00],
         [ 2.0781e+00,  1.8438e+00, -9.1406e-01,  ..., -1.8359e+00,
           4.6875e-02,  8.0625e+00],
         ...,
         [ 4.1406e-01, -7.4219e-01, -2.2461e-01,  ...,  7.3438e-01,
          -1.3984e+00,  7.0938e+00],
         [-1.5859e+00, -2.0781e+00,  7.0312e-02,  ...,  1.1875e+00,
          -3.4961e-01,  7.5938e+00],
         [-2.8281e+00, -1.8594e+00, -4.6094e-01,  ...,  9.6094e-01,
           1.1172e+00,  7.6562e+00]]], dtype=torch.bfloat16), tensor([[[ 4.6387e-03,  8.9722e-03, -1.2268e-02,  ...,  2.8687e-02,
           2.5024e-03,  4.5776e-03],
         [-2.1094e-01,  3.4180e-02,  6.0938e-01,  ...,  7.0312e-01,
           3.2031e-01,  3.0078e-01],
         [ 7.0312e-01,  9.6094e-01, -2.3730e-01,  ..., -2.8125e-01,
           9.0625e-01, -6.3281e-01],
         ...,
         [-8.1250e-01,  1.5391e+00, -8.0078e-02,  ...,  2.1719e+00,
           8.6328e-01,  2.9688e-01],
         [ 6.1328e-01, -1.0938e+00,  1.8281e+00,  ..., -1.4609e+00,
          -8.0078e-02, -3.3594e-01],
         [ 3.5938e-01, -1.4062e+00,  1.0391e+00,  ..., -1.8047e+00,
           1.4062e+00,  9.7656e-01]],

        [[ 4.6387e-03,  8.9722e-03, -1.2268e-02,  ...,  2.8687e-02,
           2.5024e-03,  4.5776e-03],
         [-2.1094e-01,  3.4180e-02,  6.0938e-01,  ...,  7.0312e-01,
           3.2031e-01,  3.0078e-01],
         [ 7.0312e-01,  9.6094e-01, -2.3730e-01,  ..., -2.8125e-01,
           9.0625e-01, -6.3281e-01],
         ...,
         [-8.1250e-01,  1.5391e+00, -8.0078e-02,  ...,  2.1719e+00,
           8.6328e-01,  2.9688e-01],
         [ 6.1328e-01, -1.0938e+00,  1.8281e+00,  ..., -1.4609e+00,
          -8.0078e-02, -3.3594e-01],
         [ 3.5938e-01, -1.4062e+00,  1.0391e+00,  ..., -1.8047e+00,
           1.4062e+00,  9.7656e-01]],

        [[ 4.6387e-03,  8.9722e-03, -1.2268e-02,  ...,  2.8687e-02,
           2.5024e-03,  4.5776e-03],
         [-2.1094e-01,  3.4180e-02,  6.0938e-01,  ...,  7.0312e-01,
           3.2031e-01,  3.0078e-01],
         [ 7.0312e-01,  9.6094e-01, -2.3730e-01,  ..., -2.8125e-01,
           9.0625e-01, -6.3281e-01],
         ...,
         [-8.1250e-01,  1.5391e+00, -8.0078e-02,  ...,  2.1719e+00,
           8.6328e-01,  2.9688e-01],
         [ 6.1328e-01, -1.0938e+00,  1.8281e+00,  ..., -1.4609e+00,
          -8.0078e-02, -3.3594e-01],
         [ 3.5938e-01, -1.4062e+00,  1.0391e+00,  ..., -1.8047e+00,
           1.4062e+00,  9.7656e-01]],

        ...,

        [[-7.5684e-03,  1.5869e-02, -2.8198e-02,  ...,  1.9531e-02,
           6.4697e-03, -8.6975e-04],
         [ 1.1914e-01,  4.0283e-02,  1.3438e+00,  ..., -8.8672e-01,
           1.7031e+00, -2.0996e-01],
         [ 5.4297e-01, -9.4922e-01,  1.1484e+00,  ..., -2.2656e+00,
          -1.2266e+00, -2.0781e+00],
         ...,
         [ 2.7656e+00,  1.3984e+00,  3.0859e-01,  ..., -9.2188e-01,
           4.2773e-01, -3.7656e+00],
         [ 1.3672e+00, -1.2500e-01, -5.1172e-01,  ..., -7.6953e-01,
          -1.7656e+00, -1.7422e+00],
         [ 3.9688e+00, -7.2266e-01, -3.8477e-01,  ..., -2.0156e+00,
           9.6484e-01, -2.2344e+00]],

        [[-7.5684e-03,  1.5869e-02, -2.8198e-02,  ...,  1.9531e-02,
           6.4697e-03, -8.6975e-04],
         [ 1.1914e-01,  4.0283e-02,  1.3438e+00,  ..., -8.8672e-01,
           1.7031e+00, -2.0996e-01],
         [ 5.4297e-01, -9.4922e-01,  1.1484e+00,  ..., -2.2656e+00,
          -1.2266e+00, -2.0781e+00],
         ...,
         [ 2.7656e+00,  1.3984e+00,  3.0859e-01,  ..., -9.2188e-01,
           4.2773e-01, -3.7656e+00],
         [ 1.3672e+00, -1.2500e-01, -5.1172e-01,  ..., -7.6953e-01,
          -1.7656e+00, -1.7422e+00],
         [ 3.9688e+00, -7.2266e-01, -3.8477e-01,  ..., -2.0156e+00,
           9.6484e-01, -2.2344e+00]],

        [[-7.5684e-03,  1.5869e-02, -2.8198e-02,  ...,  1.9531e-02,
           6.4697e-03, -8.6975e-04],
         [ 1.1914e-01,  4.0283e-02,  1.3438e+00,  ..., -8.8672e-01,
           1.7031e+00, -2.0996e-01],
         [ 5.4297e-01, -9.4922e-01,  1.1484e+00,  ..., -2.2656e+00,
          -1.2266e+00, -2.0781e+00],
         ...,
         [ 2.7656e+00,  1.3984e+00,  3.0859e-01,  ..., -9.2188e-01,
           4.2773e-01, -3.7656e+00],
         [ 1.3672e+00, -1.2500e-01, -5.1172e-01,  ..., -7.6953e-01,
          -1.7656e+00, -1.7422e+00],
         [ 3.9688e+00, -7.2266e-01, -3.8477e-01,  ..., -2.0156e+00,
           9.6484e-01, -2.2344e+00]]], dtype=torch.bfloat16)), (tensor([[[-1.7090e-02,  5.7983e-03,  1.4267e-03,  ..., -4.4189e-02,
          -2.8076e-03, -2.6001e-02],
         [ 2.1250e+00, -6.4844e-01, -5.2734e-01,  ...,  1.3359e+00,
          -8.1641e-01,  7.1484e-01],
         [-1.7383e-01,  3.6328e-01,  2.1406e+00,  ...,  1.6953e+00,
          -2.8750e+00, -4.3555e-01],
         ...,
         [-1.9219e+00,  6.0156e-01,  8.5938e-01,  ..., -3.5000e+00,
           6.7578e-01,  1.9629e-01],
         [-1.6406e-01,  6.8750e-01,  1.3359e+00,  ..., -2.1719e+00,
          -1.7344e+00,  1.4375e+00],
         [ 4.7852e-01,  2.8564e-02,  7.4219e-01,  ..., -1.1641e+00,
          -1.3594e+00, -5.1562e-01]],

        [[-1.7090e-02,  5.7983e-03,  1.4267e-03,  ..., -4.4189e-02,
          -2.8076e-03, -2.6001e-02],
         [ 2.1250e+00, -6.4844e-01, -5.2734e-01,  ...,  1.3359e+00,
          -8.1641e-01,  7.1484e-01],
         [-1.7383e-01,  3.6328e-01,  2.1406e+00,  ...,  1.6953e+00,
          -2.8750e+00, -4.3555e-01],
         ...,
         [-1.9219e+00,  6.0156e-01,  8.5938e-01,  ..., -3.5000e+00,
           6.7578e-01,  1.9629e-01],
         [-1.6406e-01,  6.8750e-01,  1.3359e+00,  ..., -2.1719e+00,
          -1.7344e+00,  1.4375e+00],
         [ 4.7852e-01,  2.8564e-02,  7.4219e-01,  ..., -1.1641e+00,
          -1.3594e+00, -5.1562e-01]],

        [[-1.7090e-02,  5.7983e-03,  1.4267e-03,  ..., -4.4189e-02,
          -2.8076e-03, -2.6001e-02],
         [ 2.1250e+00, -6.4844e-01, -5.2734e-01,  ...,  1.3359e+00,
          -8.1641e-01,  7.1484e-01],
         [-1.7383e-01,  3.6328e-01,  2.1406e+00,  ...,  1.6953e+00,
          -2.8750e+00, -4.3555e-01],
         ...,
         [-1.9219e+00,  6.0156e-01,  8.5938e-01,  ..., -3.5000e+00,
           6.7578e-01,  1.9629e-01],
         [-1.6406e-01,  6.8750e-01,  1.3359e+00,  ..., -2.1719e+00,
          -1.7344e+00,  1.4375e+00],
         [ 4.7852e-01,  2.8564e-02,  7.4219e-01,  ..., -1.1641e+00,
          -1.3594e+00, -5.1562e-01]],

        ...,

        [[-1.4221e-02,  2.1240e-02, -1.3855e-02,  ...,  1.4258e-01,
           8.3984e-02, -1.8906e+00],
         [ 2.9297e-02, -2.5391e-01, -9.0625e-01,  ...,  4.1016e-01,
          -1.7344e+00,  5.9375e+00],
         [ 4.1406e-01,  3.7842e-02,  1.7578e-01,  ..., -5.4297e-01,
          -2.0469e+00,  6.5312e+00],
         ...,
         [ 6.7578e-01,  1.6797e-01, -3.6523e-01,  ..., -9.3750e-01,
           1.4375e+00,  5.0938e+00],
         [-3.5547e-01, -2.9102e-01,  3.9453e-01,  ..., -2.9375e+00,
          -5.1562e-01,  5.3125e+00],
         [-1.1250e+00, -6.1719e-01, -5.2344e-01,  ..., -2.2656e+00,
          -1.3594e+00,  4.8750e+00]],

        [[-1.4221e-02,  2.1240e-02, -1.3855e-02,  ...,  1.4258e-01,
           8.3984e-02, -1.8906e+00],
         [ 2.9297e-02, -2.5391e-01, -9.0625e-01,  ...,  4.1016e-01,
          -1.7344e+00,  5.9375e+00],
         [ 4.1406e-01,  3.7842e-02,  1.7578e-01,  ..., -5.4297e-01,
          -2.0469e+00,  6.5312e+00],
         ...,
         [ 6.7578e-01,  1.6797e-01, -3.6523e-01,  ..., -9.3750e-01,
           1.4375e+00,  5.0938e+00],
         [-3.5547e-01, -2.9102e-01,  3.9453e-01,  ..., -2.9375e+00,
          -5.1562e-01,  5.3125e+00],
         [-1.1250e+00, -6.1719e-01, -5.2344e-01,  ..., -2.2656e+00,
          -1.3594e+00,  4.8750e+00]],

        [[-1.4221e-02,  2.1240e-02, -1.3855e-02,  ...,  1.4258e-01,
           8.3984e-02, -1.8906e+00],
         [ 2.9297e-02, -2.5391e-01, -9.0625e-01,  ...,  4.1016e-01,
          -1.7344e+00,  5.9375e+00],
         [ 4.1406e-01,  3.7842e-02,  1.7578e-01,  ..., -5.4297e-01,
          -2.0469e+00,  6.5312e+00],
         ...,
         [ 6.7578e-01,  1.6797e-01, -3.6523e-01,  ..., -9.3750e-01,
           1.4375e+00,  5.0938e+00],
         [-3.5547e-01, -2.9102e-01,  3.9453e-01,  ..., -2.9375e+00,
          -5.1562e-01,  5.3125e+00],
         [-1.1250e+00, -6.1719e-01, -5.2344e-01,  ..., -2.2656e+00,
          -1.3594e+00,  4.8750e+00]]], dtype=torch.bfloat16), tensor([[[-4.0039e-02, -1.9409e-02, -2.3804e-02,  ...,  3.9795e-02,
           6.1035e-02, -2.9785e-02],
         [-1.0938e+00,  4.3750e-01, -2.5781e-01,  ..., -3.9258e-01,
           7.9590e-02,  1.3828e+00],
         [-7.3828e-01, -2.5391e-01,  6.8359e-01,  ..., -4.2578e-01,
           6.7969e-01,  4.2578e-01],
         ...,
         [-1.3125e+00,  1.0312e+00, -2.1094e+00,  ...,  4.8242e-01,
           1.7422e+00,  1.3125e+00],
         [-6.2109e-01, -8.0469e-01,  2.7148e-01,  ..., -1.7812e+00,
           1.7383e-01, -7.6953e-01],
         [-2.8809e-02, -1.3594e+00, -6.3672e-01,  ..., -2.1387e-01,
           2.0469e+00, -5.4688e-01]],

        [[-4.0039e-02, -1.9409e-02, -2.3804e-02,  ...,  3.9795e-02,
           6.1035e-02, -2.9785e-02],
         [-1.0938e+00,  4.3750e-01, -2.5781e-01,  ..., -3.9258e-01,
           7.9590e-02,  1.3828e+00],
         [-7.3828e-01, -2.5391e-01,  6.8359e-01,  ..., -4.2578e-01,
           6.7969e-01,  4.2578e-01],
         ...,
         [-1.3125e+00,  1.0312e+00, -2.1094e+00,  ...,  4.8242e-01,
           1.7422e+00,  1.3125e+00],
         [-6.2109e-01, -8.0469e-01,  2.7148e-01,  ..., -1.7812e+00,
           1.7383e-01, -7.6953e-01],
         [-2.8809e-02, -1.3594e+00, -6.3672e-01,  ..., -2.1387e-01,
           2.0469e+00, -5.4688e-01]],

        [[-4.0039e-02, -1.9409e-02, -2.3804e-02,  ...,  3.9795e-02,
           6.1035e-02, -2.9785e-02],
         [-1.0938e+00,  4.3750e-01, -2.5781e-01,  ..., -3.9258e-01,
           7.9590e-02,  1.3828e+00],
         [-7.3828e-01, -2.5391e-01,  6.8359e-01,  ..., -4.2578e-01,
           6.7969e-01,  4.2578e-01],
         ...,
         [-1.3125e+00,  1.0312e+00, -2.1094e+00,  ...,  4.8242e-01,
           1.7422e+00,  1.3125e+00],
         [-6.2109e-01, -8.0469e-01,  2.7148e-01,  ..., -1.7812e+00,
           1.7383e-01, -7.6953e-01],
         [-2.8809e-02, -1.3594e+00, -6.3672e-01,  ..., -2.1387e-01,
           2.0469e+00, -5.4688e-01]],

        ...,

        [[ 6.0730e-03, -1.4343e-02, -1.2146e-02,  ...,  1.8692e-03,
          -4.9744e-03,  3.8757e-03],
         [-4.6875e-01,  2.4414e-01, -9.5312e-01,  ...,  5.5859e-01,
          -7.8125e-01,  4.4727e-01],
         [-8.5547e-01, -5.8984e-01,  4.5898e-01,  ..., -2.5625e+00,
          -5.1270e-02,  3.3008e-01],
         ...,
         [-6.4844e-01, -3.5938e-01,  9.7656e-01,  ..., -3.0859e-01,
          -1.8164e-01,  6.2500e-01],
         [ 9.0332e-02, -4.6875e-01, -7.1875e-01,  ..., -7.6953e-01,
          -2.5586e-01,  2.7930e-01],
         [-2.2812e+00, -8.9453e-01,  5.0391e-01,  ..., -1.5859e+00,
          -5.3125e-01, -9.1406e-01]],

        [[ 6.0730e-03, -1.4343e-02, -1.2146e-02,  ...,  1.8692e-03,
          -4.9744e-03,  3.8757e-03],
         [-4.6875e-01,  2.4414e-01, -9.5312e-01,  ...,  5.5859e-01,
          -7.8125e-01,  4.4727e-01],
         [-8.5547e-01, -5.8984e-01,  4.5898e-01,  ..., -2.5625e+00,
          -5.1270e-02,  3.3008e-01],
         ...,
         [-6.4844e-01, -3.5938e-01,  9.7656e-01,  ..., -3.0859e-01,
          -1.8164e-01,  6.2500e-01],
         [ 9.0332e-02, -4.6875e-01, -7.1875e-01,  ..., -7.6953e-01,
          -2.5586e-01,  2.7930e-01],
         [-2.2812e+00, -8.9453e-01,  5.0391e-01,  ..., -1.5859e+00,
          -5.3125e-01, -9.1406e-01]],

        [[ 6.0730e-03, -1.4343e-02, -1.2146e-02,  ...,  1.8692e-03,
          -4.9744e-03,  3.8757e-03],
         [-4.6875e-01,  2.4414e-01, -9.5312e-01,  ...,  5.5859e-01,
          -7.8125e-01,  4.4727e-01],
         [-8.5547e-01, -5.8984e-01,  4.5898e-01,  ..., -2.5625e+00,
          -5.1270e-02,  3.3008e-01],
         ...,
         [-6.4844e-01, -3.5938e-01,  9.7656e-01,  ..., -3.0859e-01,
          -1.8164e-01,  6.2500e-01],
         [ 9.0332e-02, -4.6875e-01, -7.1875e-01,  ..., -7.6953e-01,
          -2.5586e-01,  2.7930e-01],
         [-2.2812e+00, -8.9453e-01,  5.0391e-01,  ..., -1.5859e+00,
          -5.3125e-01, -9.1406e-01]]], dtype=torch.bfloat16)), (tensor([[[-0.0069,  0.0071, -0.0229,  ..., -0.2656, -0.3203,  0.5664],
         [ 1.9844, -0.9141,  1.1562,  ..., -2.0938, -0.4727,  0.7070],
         [ 1.4766, -0.3887,  0.7656,  ..., -2.5625, -0.1885, -0.8750],
         ...,
         [ 0.2070,  0.9141,  0.1455,  ..., -1.2422,  0.0520, -2.4375],
         [-0.1982,  0.2090,  0.4023,  ..., -2.4219,  1.5391, -3.0469],
         [-1.8281, -0.8906,  0.7656,  ..., -2.2500, -2.0156, -1.9375]],

        [[-0.0069,  0.0071, -0.0229,  ..., -0.2656, -0.3203,  0.5664],
         [ 1.9844, -0.9141,  1.1562,  ..., -2.0938, -0.4727,  0.7070],
         [ 1.4766, -0.3887,  0.7656,  ..., -2.5625, -0.1885, -0.8750],
         ...,
         [ 0.2070,  0.9141,  0.1455,  ..., -1.2422,  0.0520, -2.4375],
         [-0.1982,  0.2090,  0.4023,  ..., -2.4219,  1.5391, -3.0469],
         [-1.8281, -0.8906,  0.7656,  ..., -2.2500, -2.0156, -1.9375]],

        [[-0.0069,  0.0071, -0.0229,  ..., -0.2656, -0.3203,  0.5664],
         [ 1.9844, -0.9141,  1.1562,  ..., -2.0938, -0.4727,  0.7070],
         [ 1.4766, -0.3887,  0.7656,  ..., -2.5625, -0.1885, -0.8750],
         ...,
         [ 0.2070,  0.9141,  0.1455,  ..., -1.2422,  0.0520, -2.4375],
         [-0.1982,  0.2090,  0.4023,  ..., -2.4219,  1.5391, -3.0469],
         [-1.8281, -0.8906,  0.7656,  ..., -2.2500, -2.0156, -1.9375]],

        ...,

        [[-0.0102, -0.0106,  0.0152,  ..., -0.0306,  0.2354,  0.2734],
         [ 0.5859,  0.5820,  1.4844,  ...,  0.5117, -0.1738, -0.3418],
         [-0.3145,  2.4375,  0.7930,  ..., -0.0059,  0.1729, -0.4980],
         ...,
         [-1.7578, -1.1875, -0.5273,  ..., -0.6836, -0.0330, -0.4785],
         [ 0.4180, -0.4766,  0.5195,  ...,  1.0938, -1.3750,  1.7812],
         [ 1.3281, -0.9102,  1.0312,  ..., -0.0315,  0.2432,  1.0547]],

        [[-0.0102, -0.0106,  0.0152,  ..., -0.0306,  0.2354,  0.2734],
         [ 0.5859,  0.5820,  1.4844,  ...,  0.5117, -0.1738, -0.3418],
         [-0.3145,  2.4375,  0.7930,  ..., -0.0059,  0.1729, -0.4980],
         ...,
         [-1.7578, -1.1875, -0.5273,  ..., -0.6836, -0.0330, -0.4785],
         [ 0.4180, -0.4766,  0.5195,  ...,  1.0938, -1.3750,  1.7812],
         [ 1.3281, -0.9102,  1.0312,  ..., -0.0315,  0.2432,  1.0547]],

        [[-0.0102, -0.0106,  0.0152,  ..., -0.0306,  0.2354,  0.2734],
         [ 0.5859,  0.5820,  1.4844,  ...,  0.5117, -0.1738, -0.3418],
         [-0.3145,  2.4375,  0.7930,  ..., -0.0059,  0.1729, -0.4980],
         ...,
         [-1.7578, -1.1875, -0.5273,  ..., -0.6836, -0.0330, -0.4785],
         [ 0.4180, -0.4766,  0.5195,  ...,  1.0938, -1.3750,  1.7812],
         [ 1.3281, -0.9102,  1.0312,  ..., -0.0315,  0.2432,  1.0547]]],
       dtype=torch.bfloat16), tensor([[[ 0.0066,  0.0254,  0.0081,  ...,  0.0171,  0.0244, -0.0117],
         [ 0.2949, -0.6914,  0.5703,  ...,  0.0679,  0.4258, -0.1797],
         [-0.5859, -1.0625, -0.5781,  ...,  0.0036,  0.9961, -1.3984],
         ...,
         [-0.1953,  0.5703, -0.2031,  ..., -0.9609,  0.7617,  0.6719],
         [ 0.1289,  0.4941,  1.1328,  ..., -1.1484,  0.3223,  0.4355],
         [-0.1719,  0.3711,  0.7031,  ...,  0.4648,  0.0403, -0.8555]],

        [[ 0.0066,  0.0254,  0.0081,  ...,  0.0171,  0.0244, -0.0117],
         [ 0.2949, -0.6914,  0.5703,  ...,  0.0679,  0.4258, -0.1797],
         [-0.5859, -1.0625, -0.5781,  ...,  0.0036,  0.9961, -1.3984],
         ...,
         [-0.1953,  0.5703, -0.2031,  ..., -0.9609,  0.7617,  0.6719],
         [ 0.1289,  0.4941,  1.1328,  ..., -1.1484,  0.3223,  0.4355],
         [-0.1719,  0.3711,  0.7031,  ...,  0.4648,  0.0403, -0.8555]],

        [[ 0.0066,  0.0254,  0.0081,  ...,  0.0171,  0.0244, -0.0117],
         [ 0.2949, -0.6914,  0.5703,  ...,  0.0679,  0.4258, -0.1797],
         [-0.5859, -1.0625, -0.5781,  ...,  0.0036,  0.9961, -1.3984],
         ...,
         [-0.1953,  0.5703, -0.2031,  ..., -0.9609,  0.7617,  0.6719],
         [ 0.1289,  0.4941,  1.1328,  ..., -1.1484,  0.3223,  0.4355],
         [-0.1719,  0.3711,  0.7031,  ...,  0.4648,  0.0403, -0.8555]],

        ...,

        [[-0.0140, -0.0067,  0.0165,  ..., -0.0031, -0.0269,  0.0166],
         [ 0.7578,  0.6719, -0.7812,  ..., -0.4512, -1.5156, -0.4316],
         [ 1.1406,  1.6641, -0.1426,  ..., -0.7422,  0.7422, -2.5156],
         ...,
         [-1.2188, -1.5234,  0.0674,  ..., -0.2891, -0.5664,  0.3047],
         [-1.3750, -0.3125, -0.2129,  ...,  1.0547, -0.4668, -1.4531],
         [-1.1328, -0.2930, -0.3711,  ..., -1.0469, -0.6289,  0.2949]],

        [[-0.0140, -0.0067,  0.0165,  ..., -0.0031, -0.0269,  0.0166],
         [ 0.7578,  0.6719, -0.7812,  ..., -0.4512, -1.5156, -0.4316],
         [ 1.1406,  1.6641, -0.1426,  ..., -0.7422,  0.7422, -2.5156],
         ...,
         [-1.2188, -1.5234,  0.0674,  ..., -0.2891, -0.5664,  0.3047],
         [-1.3750, -0.3125, -0.2129,  ...,  1.0547, -0.4668, -1.4531],
         [-1.1328, -0.2930, -0.3711,  ..., -1.0469, -0.6289,  0.2949]],

        [[-0.0140, -0.0067,  0.0165,  ..., -0.0031, -0.0269,  0.0166],
         [ 0.7578,  0.6719, -0.7812,  ..., -0.4512, -1.5156, -0.4316],
         [ 1.1406,  1.6641, -0.1426,  ..., -0.7422,  0.7422, -2.5156],
         ...,
         [-1.2188, -1.5234,  0.0674,  ..., -0.2891, -0.5664,  0.3047],
         [-1.3750, -0.3125, -0.2129,  ...,  1.0547, -0.4668, -1.4531],
         [-1.1328, -0.2930, -0.3711,  ..., -1.0469, -0.6289,  0.2949]]],
       dtype=torch.bfloat16)), (tensor([[[ 4.0894e-03,  4.7607e-03, -2.6245e-03,  ..., -6.8848e-02,
           4.7852e-02,  1.8262e-01],
         [-3.7500e+00,  1.0625e+00, -1.1094e+00,  ..., -1.0303e-01,
           3.3203e-01, -1.0547e+00],
         [-1.4766e+00,  3.1562e+00, -1.0625e+00,  ..., -1.7578e-02,
          -9.4531e-01, -2.1250e+00],
         ...,
         [ 1.4375e+00, -7.1484e-01, -1.8555e-01,  ...,  5.4297e-01,
           1.7109e+00, -9.3359e-01],
         [ 1.2578e+00, -1.4531e+00, -1.1094e+00,  ...,  7.8516e-01,
          -1.2891e+00,  1.8359e+00],
         [ 1.6094e+00, -2.8438e+00, -1.7969e+00,  ...,  2.5391e-01,
           1.4062e-01, -5.1025e-02]],

        [[ 4.0894e-03,  4.7607e-03, -2.6245e-03,  ..., -6.8848e-02,
           4.7852e-02,  1.8262e-01],
         [-3.7500e+00,  1.0625e+00, -1.1094e+00,  ..., -1.0303e-01,
           3.3203e-01, -1.0547e+00],
         [-1.4766e+00,  3.1562e+00, -1.0625e+00,  ..., -1.7578e-02,
          -9.4531e-01, -2.1250e+00],
         ...,
         [ 1.4375e+00, -7.1484e-01, -1.8555e-01,  ...,  5.4297e-01,
           1.7109e+00, -9.3359e-01],
         [ 1.2578e+00, -1.4531e+00, -1.1094e+00,  ...,  7.8516e-01,
          -1.2891e+00,  1.8359e+00],
         [ 1.6094e+00, -2.8438e+00, -1.7969e+00,  ...,  2.5391e-01,
           1.4062e-01, -5.1025e-02]],

        [[ 4.0894e-03,  4.7607e-03, -2.6245e-03,  ..., -6.8848e-02,
           4.7852e-02,  1.8262e-01],
         [-3.7500e+00,  1.0625e+00, -1.1094e+00,  ..., -1.0303e-01,
           3.3203e-01, -1.0547e+00],
         [-1.4766e+00,  3.1562e+00, -1.0625e+00,  ..., -1.7578e-02,
          -9.4531e-01, -2.1250e+00],
         ...,
         [ 1.4375e+00, -7.1484e-01, -1.8555e-01,  ...,  5.4297e-01,
           1.7109e+00, -9.3359e-01],
         [ 1.2578e+00, -1.4531e+00, -1.1094e+00,  ...,  7.8516e-01,
          -1.2891e+00,  1.8359e+00],
         [ 1.6094e+00, -2.8438e+00, -1.7969e+00,  ...,  2.5391e-01,
           1.4062e-01, -5.1025e-02]],

        ...,

        [[ 8.1787e-03,  1.9165e-02, -1.0620e-02,  ...,  1.6699e-01,
          -1.3965e-01, -1.6016e-01],
         [-1.8906e+00,  1.1406e+00,  4.0039e-01,  ...,  1.2305e-01,
           4.7852e-01, -2.8906e-01],
         [ 5.7031e-01, -1.0469e+00, -5.1172e-01,  ..., -6.7578e-01,
          -2.7734e-01, -1.0781e+00],
         ...,
         [ 1.2656e+00, -6.4844e-01, -4.3359e-01,  ...,  5.9766e-01,
          -2.4219e-01, -4.4678e-02],
         [ 1.1875e+00,  1.7188e-01, -1.2695e-01,  ...,  9.5703e-01,
          -5.8594e-01,  9.5703e-01],
         [-5.6641e-01,  1.3047e+00,  5.9375e-01,  ...,  6.3281e-01,
           9.4531e-01, -9.1797e-02]],

        [[ 8.1787e-03,  1.9165e-02, -1.0620e-02,  ...,  1.6699e-01,
          -1.3965e-01, -1.6016e-01],
         [-1.8906e+00,  1.1406e+00,  4.0039e-01,  ...,  1.2305e-01,
           4.7852e-01, -2.8906e-01],
         [ 5.7031e-01, -1.0469e+00, -5.1172e-01,  ..., -6.7578e-01,
          -2.7734e-01, -1.0781e+00],
         ...,
         [ 1.2656e+00, -6.4844e-01, -4.3359e-01,  ...,  5.9766e-01,
          -2.4219e-01, -4.4678e-02],
         [ 1.1875e+00,  1.7188e-01, -1.2695e-01,  ...,  9.5703e-01,
          -5.8594e-01,  9.5703e-01],
         [-5.6641e-01,  1.3047e+00,  5.9375e-01,  ...,  6.3281e-01,
           9.4531e-01, -9.1797e-02]],

        [[ 8.1787e-03,  1.9165e-02, -1.0620e-02,  ...,  1.6699e-01,
          -1.3965e-01, -1.6016e-01],
         [-1.8906e+00,  1.1406e+00,  4.0039e-01,  ...,  1.2305e-01,
           4.7852e-01, -2.8906e-01],
         [ 5.7031e-01, -1.0469e+00, -5.1172e-01,  ..., -6.7578e-01,
          -2.7734e-01, -1.0781e+00],
         ...,
         [ 1.2656e+00, -6.4844e-01, -4.3359e-01,  ...,  5.9766e-01,
          -2.4219e-01, -4.4678e-02],
         [ 1.1875e+00,  1.7188e-01, -1.2695e-01,  ...,  9.5703e-01,
          -5.8594e-01,  9.5703e-01],
         [-5.6641e-01,  1.3047e+00,  5.9375e-01,  ...,  6.3281e-01,
           9.4531e-01, -9.1797e-02]]], dtype=torch.bfloat16), tensor([[[-0.0254, -0.0189,  0.0183,  ..., -0.0042,  0.1099, -0.0179],
         [-0.3301,  0.7344,  1.5312,  ..., -0.8516, -0.1631,  0.8555],
         [ 0.0043,  0.4824, -0.2041,  ..., -0.1660,  0.5938,  0.3242],
         ...,
         [-0.4434,  0.4688, -0.1040,  ..., -1.4453,  0.7422,  1.6016],
         [-0.4941, -0.0325,  0.3105,  ..., -1.4297, -1.3906,  1.2891],
         [-0.0786,  0.7031,  0.8906,  ..., -0.1040,  0.0566,  0.0618]],

        [[-0.0254, -0.0189,  0.0183,  ..., -0.0042,  0.1099, -0.0179],
         [-0.3301,  0.7344,  1.5312,  ..., -0.8516, -0.1631,  0.8555],
         [ 0.0043,  0.4824, -0.2041,  ..., -0.1660,  0.5938,  0.3242],
         ...,
         [-0.4434,  0.4688, -0.1040,  ..., -1.4453,  0.7422,  1.6016],
         [-0.4941, -0.0325,  0.3105,  ..., -1.4297, -1.3906,  1.2891],
         [-0.0786,  0.7031,  0.8906,  ..., -0.1040,  0.0566,  0.0618]],

        [[-0.0254, -0.0189,  0.0183,  ..., -0.0042,  0.1099, -0.0179],
         [-0.3301,  0.7344,  1.5312,  ..., -0.8516, -0.1631,  0.8555],
         [ 0.0043,  0.4824, -0.2041,  ..., -0.1660,  0.5938,  0.3242],
         ...,
         [-0.4434,  0.4688, -0.1040,  ..., -1.4453,  0.7422,  1.6016],
         [-0.4941, -0.0325,  0.3105,  ..., -1.4297, -1.3906,  1.2891],
         [-0.0786,  0.7031,  0.8906,  ..., -0.1040,  0.0566,  0.0618]],

        ...,

        [[-0.0245,  0.0131,  0.0088,  ...,  0.0496, -0.0391, -0.0024],
         [-0.2373, -0.2256,  0.3340,  ..., -0.0225, -0.1021,  0.9844],
         [-0.2637, -1.0156, -0.0444,  ..., -0.5508, -0.2002, -0.1064],
         ...,
         [ 0.2617,  0.0302,  1.6172,  ..., -2.2812,  0.2930,  0.2539],
         [ 0.1191, -0.2949, -0.1289,  ..., -1.3281,  0.3242,  0.9062],
         [ 0.2090,  0.1377, -0.5859,  ...,  0.2871,  0.3379, -0.2178]],

        [[-0.0245,  0.0131,  0.0088,  ...,  0.0496, -0.0391, -0.0024],
         [-0.2373, -0.2256,  0.3340,  ..., -0.0225, -0.1021,  0.9844],
         [-0.2637, -1.0156, -0.0444,  ..., -0.5508, -0.2002, -0.1064],
         ...,
         [ 0.2617,  0.0302,  1.6172,  ..., -2.2812,  0.2930,  0.2539],
         [ 0.1191, -0.2949, -0.1289,  ..., -1.3281,  0.3242,  0.9062],
         [ 0.2090,  0.1377, -0.5859,  ...,  0.2871,  0.3379, -0.2178]],

        [[-0.0245,  0.0131,  0.0088,  ...,  0.0496, -0.0391, -0.0024],
         [-0.2373, -0.2256,  0.3340,  ..., -0.0225, -0.1021,  0.9844],
         [-0.2637, -1.0156, -0.0444,  ..., -0.5508, -0.2002, -0.1064],
         ...,
         [ 0.2617,  0.0302,  1.6172,  ..., -2.2812,  0.2930,  0.2539],
         [ 0.1191, -0.2949, -0.1289,  ..., -1.3281,  0.3242,  0.9062],
         [ 0.2090,  0.1377, -0.5859,  ...,  0.2871,  0.3379, -0.2178]]],
       dtype=torch.bfloat16)), (tensor([[[ 2.4048e-02,  1.5381e-02, -2.6489e-02,  ...,  1.5442e-02,
           1.6094e+00, -2.5586e-01],
         [-1.9297e+00,  1.5312e+00,  3.8281e-01,  ..., -1.4844e+00,
          -4.0000e+00,  7.5000e-01],
         [-4.2188e-01,  1.6562e+00,  1.6953e+00,  ...,  1.9453e+00,
          -5.9062e+00,  1.8555e-01],
         ...,
         [ 1.1406e+00, -1.7422e+00, -1.0625e+00,  ..., -3.6562e+00,
          -5.2188e+00,  5.5469e-01],
         [ 9.0234e-01, -9.3750e-01, -5.2734e-01,  ..., -1.5078e+00,
          -6.4688e+00,  9.8438e-01],
         [-5.7812e-01, -8.3594e-01,  1.4062e+00,  ..., -1.6719e+00,
          -4.2812e+00,  3.0518e-02]],

        [[ 2.4048e-02,  1.5381e-02, -2.6489e-02,  ...,  1.5442e-02,
           1.6094e+00, -2.5586e-01],
         [-1.9297e+00,  1.5312e+00,  3.8281e-01,  ..., -1.4844e+00,
          -4.0000e+00,  7.5000e-01],
         [-4.2188e-01,  1.6562e+00,  1.6953e+00,  ...,  1.9453e+00,
          -5.9062e+00,  1.8555e-01],
         ...,
         [ 1.1406e+00, -1.7422e+00, -1.0625e+00,  ..., -3.6562e+00,
          -5.2188e+00,  5.5469e-01],
         [ 9.0234e-01, -9.3750e-01, -5.2734e-01,  ..., -1.5078e+00,
          -6.4688e+00,  9.8438e-01],
         [-5.7812e-01, -8.3594e-01,  1.4062e+00,  ..., -1.6719e+00,
          -4.2812e+00,  3.0518e-02]],

        [[ 2.4048e-02,  1.5381e-02, -2.6489e-02,  ...,  1.5442e-02,
           1.6094e+00, -2.5586e-01],
         [-1.9297e+00,  1.5312e+00,  3.8281e-01,  ..., -1.4844e+00,
          -4.0000e+00,  7.5000e-01],
         [-4.2188e-01,  1.6562e+00,  1.6953e+00,  ...,  1.9453e+00,
          -5.9062e+00,  1.8555e-01],
         ...,
         [ 1.1406e+00, -1.7422e+00, -1.0625e+00,  ..., -3.6562e+00,
          -5.2188e+00,  5.5469e-01],
         [ 9.0234e-01, -9.3750e-01, -5.2734e-01,  ..., -1.5078e+00,
          -6.4688e+00,  9.8438e-01],
         [-5.7812e-01, -8.3594e-01,  1.4062e+00,  ..., -1.6719e+00,
          -4.2812e+00,  3.0518e-02]],

        ...,

        [[-1.1475e-02,  4.6997e-03, -4.4861e-03,  ...,  4.9438e-03,
           5.2002e-02, -1.9844e+00],
         [-9.8047e-01, -8.5547e-01,  8.7500e-01,  ...,  9.5312e-01,
           1.0681e-02,  5.5000e+00],
         [-2.0312e+00, -2.3594e+00, -6.3672e-01,  ..., -8.6328e-01,
          -3.6719e-01,  7.7500e+00],
         ...,
         [-2.6953e-01, -1.3750e+00, -7.6953e-01,  ...,  6.3672e-01,
           4.2578e-01,  4.8438e+00],
         [ 5.2344e-01, -8.3496e-02,  8.9062e-01,  ...,  1.3047e+00,
           1.5703e+00,  6.6250e+00],
         [ 2.0469e+00,  1.0859e+00,  8.3594e-01,  ...,  2.9531e+00,
           3.6328e-01,  4.9375e+00]],

        [[-1.1475e-02,  4.6997e-03, -4.4861e-03,  ...,  4.9438e-03,
           5.2002e-02, -1.9844e+00],
         [-9.8047e-01, -8.5547e-01,  8.7500e-01,  ...,  9.5312e-01,
           1.0681e-02,  5.5000e+00],
         [-2.0312e+00, -2.3594e+00, -6.3672e-01,  ..., -8.6328e-01,
          -3.6719e-01,  7.7500e+00],
         ...,
         [-2.6953e-01, -1.3750e+00, -7.6953e-01,  ...,  6.3672e-01,
           4.2578e-01,  4.8438e+00],
         [ 5.2344e-01, -8.3496e-02,  8.9062e-01,  ...,  1.3047e+00,
           1.5703e+00,  6.6250e+00],
         [ 2.0469e+00,  1.0859e+00,  8.3594e-01,  ...,  2.9531e+00,
           3.6328e-01,  4.9375e+00]],

        [[-1.1475e-02,  4.6997e-03, -4.4861e-03,  ...,  4.9438e-03,
           5.2002e-02, -1.9844e+00],
         [-9.8047e-01, -8.5547e-01,  8.7500e-01,  ...,  9.5312e-01,
           1.0681e-02,  5.5000e+00],
         [-2.0312e+00, -2.3594e+00, -6.3672e-01,  ..., -8.6328e-01,
          -3.6719e-01,  7.7500e+00],
         ...,
         [-2.6953e-01, -1.3750e+00, -7.6953e-01,  ...,  6.3672e-01,
           4.2578e-01,  4.8438e+00],
         [ 5.2344e-01, -8.3496e-02,  8.9062e-01,  ...,  1.3047e+00,
           1.5703e+00,  6.6250e+00],
         [ 2.0469e+00,  1.0859e+00,  8.3594e-01,  ...,  2.9531e+00,
           3.6328e-01,  4.9375e+00]]], dtype=torch.bfloat16), tensor([[[ 9.2163e-03,  1.6724e-02,  7.9346e-03,  ..., -2.8442e-02,
          -1.2573e-02,  1.4038e-02],
         [ 3.3984e-01,  3.1836e-01,  3.7305e-01,  ...,  4.7852e-01,
          -5.5078e-01,  4.5703e-01],
         [-2.9883e-01, -6.1328e-01,  1.1094e+00,  ..., -9.4531e-01,
           6.4844e-01, -4.6680e-01],
         ...,
         [ 2.8906e-01,  7.3828e-01, -1.6562e+00,  ...,  3.6719e-01,
           1.8828e+00,  1.8945e-01],
         [ 1.2031e+00,  1.4531e+00, -1.9409e-02,  ...,  1.5625e+00,
           3.8750e+00, -3.4180e-01],
         [ 1.2969e+00,  8.7500e-01,  6.4062e-01,  ...,  1.0859e+00,
           1.3594e+00, -3.1641e-01]],

        [[ 9.2163e-03,  1.6724e-02,  7.9346e-03,  ..., -2.8442e-02,
          -1.2573e-02,  1.4038e-02],
         [ 3.3984e-01,  3.1836e-01,  3.7305e-01,  ...,  4.7852e-01,
          -5.5078e-01,  4.5703e-01],
         [-2.9883e-01, -6.1328e-01,  1.1094e+00,  ..., -9.4531e-01,
           6.4844e-01, -4.6680e-01],
         ...,
         [ 2.8906e-01,  7.3828e-01, -1.6562e+00,  ...,  3.6719e-01,
           1.8828e+00,  1.8945e-01],
         [ 1.2031e+00,  1.4531e+00, -1.9409e-02,  ...,  1.5625e+00,
           3.8750e+00, -3.4180e-01],
         [ 1.2969e+00,  8.7500e-01,  6.4062e-01,  ...,  1.0859e+00,
           1.3594e+00, -3.1641e-01]],

        [[ 9.2163e-03,  1.6724e-02,  7.9346e-03,  ..., -2.8442e-02,
          -1.2573e-02,  1.4038e-02],
         [ 3.3984e-01,  3.1836e-01,  3.7305e-01,  ...,  4.7852e-01,
          -5.5078e-01,  4.5703e-01],
         [-2.9883e-01, -6.1328e-01,  1.1094e+00,  ..., -9.4531e-01,
           6.4844e-01, -4.6680e-01],
         ...,
         [ 2.8906e-01,  7.3828e-01, -1.6562e+00,  ...,  3.6719e-01,
           1.8828e+00,  1.8945e-01],
         [ 1.2031e+00,  1.4531e+00, -1.9409e-02,  ...,  1.5625e+00,
           3.8750e+00, -3.4180e-01],
         [ 1.2969e+00,  8.7500e-01,  6.4062e-01,  ...,  1.0859e+00,
           1.3594e+00, -3.1641e-01]],

        ...,

        [[ 1.2817e-02,  1.9897e-02, -2.3682e-02,  ...,  7.8201e-05,
          -3.9795e-02,  1.3306e-02],
         [ 1.4922e+00, -5.1172e-01, -7.1094e-01,  ...,  1.6016e+00,
          -3.3594e-01,  2.4707e-01],
         [ 2.9297e-01, -3.4062e+00, -3.9648e-01,  ..., -7.1094e-01,
          -1.3828e+00, -2.6875e+00],
         ...,
         [-8.3594e-01, -4.2188e-01, -3.3789e-01,  ...,  4.0625e-01,
          -1.1641e+00,  1.8984e+00],
         [-1.2578e+00, -1.3281e+00,  4.6875e-01,  ..., -1.2656e+00,
           5.3516e-01, -1.3438e+00],
         [ 4.7656e-01,  2.2500e+00, -4.8633e-01,  ...,  3.6719e-01,
          -4.5117e-01,  2.6367e-01]],

        [[ 1.2817e-02,  1.9897e-02, -2.3682e-02,  ...,  7.8201e-05,
          -3.9795e-02,  1.3306e-02],
         [ 1.4922e+00, -5.1172e-01, -7.1094e-01,  ...,  1.6016e+00,
          -3.3594e-01,  2.4707e-01],
         [ 2.9297e-01, -3.4062e+00, -3.9648e-01,  ..., -7.1094e-01,
          -1.3828e+00, -2.6875e+00],
         ...,
         [-8.3594e-01, -4.2188e-01, -3.3789e-01,  ...,  4.0625e-01,
          -1.1641e+00,  1.8984e+00],
         [-1.2578e+00, -1.3281e+00,  4.6875e-01,  ..., -1.2656e+00,
           5.3516e-01, -1.3438e+00],
         [ 4.7656e-01,  2.2500e+00, -4.8633e-01,  ...,  3.6719e-01,
          -4.5117e-01,  2.6367e-01]],

        [[ 1.2817e-02,  1.9897e-02, -2.3682e-02,  ...,  7.8201e-05,
          -3.9795e-02,  1.3306e-02],
         [ 1.4922e+00, -5.1172e-01, -7.1094e-01,  ...,  1.6016e+00,
          -3.3594e-01,  2.4707e-01],
         [ 2.9297e-01, -3.4062e+00, -3.9648e-01,  ..., -7.1094e-01,
          -1.3828e+00, -2.6875e+00],
         ...,
         [-8.3594e-01, -4.2188e-01, -3.3789e-01,  ...,  4.0625e-01,
          -1.1641e+00,  1.8984e+00],
         [-1.2578e+00, -1.3281e+00,  4.6875e-01,  ..., -1.2656e+00,
           5.3516e-01, -1.3438e+00],
         [ 4.7656e-01,  2.2500e+00, -4.8633e-01,  ...,  3.6719e-01,
          -4.5117e-01,  2.6367e-01]]], dtype=torch.bfloat16)), (tensor([[[-0.2656, -0.6289, -0.9727,  ...,  0.3301,  0.3516, -1.3516],
         [ 0.3125,  0.0537,  0.1123,  ...,  0.2930,  0.2480,  0.2793],
         [-0.2490, -0.2793, -0.0850,  ...,  0.4082, -2.8750, -1.3750],
         ...,
         [-0.5508, -0.0742,  0.0664,  ...,  3.8750, -3.5156, -2.2500],
         [-0.4023, -1.2109,  0.1289,  ...,  1.0547, -1.4141, -0.2832],
         [-0.3340, -1.6406, -1.1797,  ...,  2.3906, -1.9688,  0.0256]],

        [[-0.2656, -0.6289, -0.9727,  ...,  0.3301,  0.3516, -1.3516],
         [ 0.3125,  0.0537,  0.1123,  ...,  0.2930,  0.2480,  0.2793],
         [-0.2490, -0.2793, -0.0850,  ...,  0.4082, -2.8750, -1.3750],
         ...,
         [-0.5508, -0.0742,  0.0664,  ...,  3.8750, -3.5156, -2.2500],
         [-0.4023, -1.2109,  0.1289,  ...,  1.0547, -1.4141, -0.2832],
         [-0.3340, -1.6406, -1.1797,  ...,  2.3906, -1.9688,  0.0256]],

        [[-0.2656, -0.6289, -0.9727,  ...,  0.3301,  0.3516, -1.3516],
         [ 0.3125,  0.0537,  0.1123,  ...,  0.2930,  0.2480,  0.2793],
         [-0.2490, -0.2793, -0.0850,  ...,  0.4082, -2.8750, -1.3750],
         ...,
         [-0.5508, -0.0742,  0.0664,  ...,  3.8750, -3.5156, -2.2500],
         [-0.4023, -1.2109,  0.1289,  ...,  1.0547, -1.4141, -0.2832],
         [-0.3340, -1.6406, -1.1797,  ...,  2.3906, -1.9688,  0.0256]],

        ...,

        [[-0.0134,  0.9180, -0.1533,  ...,  2.3906, -3.7188,  1.6250],
         [ 0.9141, -0.7188, -0.4258,  ..., -0.6445, -0.3828,  1.3047],
         [-0.3672,  0.5781,  0.1060,  ...,  0.2412,  0.2207, -0.4941],
         ...,
         [-0.8477,  0.4551,  0.4844,  ..., -3.3906,  0.4434, -2.7500],
         [-0.4531,  0.4902,  0.3691,  ..., -2.9375, -4.0312, -0.2344],
         [ 0.0352, -0.5000, -0.5938,  ..., -2.5312,  1.0547, -2.2500]],

        [[-0.0134,  0.9180, -0.1533,  ...,  2.3906, -3.7188,  1.6250],
         [ 0.9141, -0.7188, -0.4258,  ..., -0.6445, -0.3828,  1.3047],
         [-0.3672,  0.5781,  0.1060,  ...,  0.2412,  0.2207, -0.4941],
         ...,
         [-0.8477,  0.4551,  0.4844,  ..., -3.3906,  0.4434, -2.7500],
         [-0.4531,  0.4902,  0.3691,  ..., -2.9375, -4.0312, -0.2344],
         [ 0.0352, -0.5000, -0.5938,  ..., -2.5312,  1.0547, -2.2500]],

        [[-0.0134,  0.9180, -0.1533,  ...,  2.3906, -3.7188,  1.6250],
         [ 0.9141, -0.7188, -0.4258,  ..., -0.6445, -0.3828,  1.3047],
         [-0.3672,  0.5781,  0.1060,  ...,  0.2412,  0.2207, -0.4941],
         ...,
         [-0.8477,  0.4551,  0.4844,  ..., -3.3906,  0.4434, -2.7500],
         [-0.4531,  0.4902,  0.3691,  ..., -2.9375, -4.0312, -0.2344],
         [ 0.0352, -0.5000, -0.5938,  ..., -2.5312,  1.0547, -2.2500]]],
       dtype=torch.bfloat16), tensor([[[ 0.4805, -0.1099, -0.3379,  ..., -0.0300,  0.0986,  0.2188],
         [-0.0181, -0.2637, -0.2461,  ..., -0.2080, -0.1777, -0.1348],
         [-0.4180, -0.0654,  0.0275,  ...,  0.0275, -0.0159,  0.0723],
         ...,
         [ 0.2324,  0.0518,  0.4141,  ..., -0.6719,  0.6680, -0.7969],
         [ 0.2227,  0.3457, -0.1729,  ..., -0.1299,  0.3965, -0.1816],
         [ 0.2949,  0.3594,  0.3672,  ...,  0.3105,  0.5742, -0.4805]],

        [[ 0.4805, -0.1099, -0.3379,  ..., -0.0300,  0.0986,  0.2188],
         [-0.0181, -0.2637, -0.2461,  ..., -0.2080, -0.1777, -0.1348],
         [-0.4180, -0.0654,  0.0275,  ...,  0.0275, -0.0159,  0.0723],
         ...,
         [ 0.2324,  0.0518,  0.4141,  ..., -0.6719,  0.6680, -0.7969],
         [ 0.2227,  0.3457, -0.1729,  ..., -0.1299,  0.3965, -0.1816],
         [ 0.2949,  0.3594,  0.3672,  ...,  0.3105,  0.5742, -0.4805]],

        [[ 0.4805, -0.1099, -0.3379,  ..., -0.0300,  0.0986,  0.2188],
         [-0.0181, -0.2637, -0.2461,  ..., -0.2080, -0.1777, -0.1348],
         [-0.4180, -0.0654,  0.0275,  ...,  0.0275, -0.0159,  0.0723],
         ...,
         [ 0.2324,  0.0518,  0.4141,  ..., -0.6719,  0.6680, -0.7969],
         [ 0.2227,  0.3457, -0.1729,  ..., -0.1299,  0.3965, -0.1816],
         [ 0.2949,  0.3594,  0.3672,  ...,  0.3105,  0.5742, -0.4805]],

        ...,

        [[ 0.3242, -0.4062, -0.1709,  ...,  0.2695, -0.5391, -0.3926],
         [ 0.2695, -0.2197,  0.2227,  ..., -0.2734, -0.5898,  0.0435],
         [ 0.8477, -0.6094, -0.0427,  ..., -0.1914, -0.5508, -0.7031],
         ...,
         [-0.9922,  1.2812, -0.1426,  ...,  0.1475,  0.4668,  0.4688],
         [-1.4141,  1.1797, -0.3125,  ...,  0.4570,  0.0796,  0.2695],
         [-0.1904, -0.0030, -0.6992,  ...,  0.3398, -0.4238,  0.1924]],

        [[ 0.3242, -0.4062, -0.1709,  ...,  0.2695, -0.5391, -0.3926],
         [ 0.2695, -0.2197,  0.2227,  ..., -0.2734, -0.5898,  0.0435],
         [ 0.8477, -0.6094, -0.0427,  ..., -0.1914, -0.5508, -0.7031],
         ...,
         [-0.9922,  1.2812, -0.1426,  ...,  0.1475,  0.4668,  0.4688],
         [-1.4141,  1.1797, -0.3125,  ...,  0.4570,  0.0796,  0.2695],
         [-0.1904, -0.0030, -0.6992,  ...,  0.3398, -0.4238,  0.1924]],

        [[ 0.3242, -0.4062, -0.1709,  ...,  0.2695, -0.5391, -0.3926],
         [ 0.2695, -0.2197,  0.2227,  ..., -0.2734, -0.5898,  0.0435],
         [ 0.8477, -0.6094, -0.0427,  ..., -0.1914, -0.5508, -0.7031],
         ...,
         [-0.9922,  1.2812, -0.1426,  ...,  0.1475,  0.4668,  0.4688],
         [-1.4141,  1.1797, -0.3125,  ...,  0.4570,  0.0796,  0.2695],
         [-0.1904, -0.0030, -0.6992,  ...,  0.3398, -0.4238,  0.1924]]],
       dtype=torch.bfloat16)))}, logits=tensor([[[ -8.7500, -10.6250, -11.9375,  ..., -10.8125, -12.0000,  -9.5000],
         [ -9.1250, -10.5000, -12.1250,  ...,  -8.1250,  -9.4375,  -7.1562],
         [-16.1250, -22.2500, -24.0000,  ..., -19.1250, -19.5000, -18.6250],
         ...,
         [-13.8750, -16.2500, -20.3750,  ..., -12.1250, -17.5000, -11.3750],
         [-13.4375, -14.7500, -19.0000,  ..., -15.5000, -17.7500, -14.8125],
         [-13.3750, -15.8125, -17.7500,  ..., -11.9375, -16.0000,  -9.5000]]],
       dtype=torch.bfloat16), past_key_values=((tensor([[[ 0.5547, -1.4062, -0.3340,  ...,  0.1514, -0.9297, -1.0312],
         [-0.4785, -1.4062, -0.4316,  ...,  0.0583, -1.4062,  0.1992],
         [-2.2500, -2.3281,  0.4707,  ..., -1.1562, -0.1973,  0.9883],
         ...,
         [-0.1641,  1.3672, -0.2100,  ...,  0.0522, -1.4062,  0.1992],
         [-0.1484, -0.2168, -0.0801,  ..., -0.9414, -0.6211, -1.1953],
         [ 1.5781, -1.5859,  0.1396,  ..., -0.1191, -0.2041, -0.2812]],

        [[ 0.5547, -1.4062, -0.3340,  ...,  0.1514, -0.9297, -1.0312],
         [-0.4785, -1.4062, -0.4316,  ...,  0.0583, -1.4062,  0.1992],
         [-2.2500, -2.3281,  0.4707,  ..., -1.1562, -0.1973,  0.9883],
         ...,
         [-0.1641,  1.3672, -0.2100,  ...,  0.0522, -1.4062,  0.1992],
         [-0.1484, -0.2168, -0.0801,  ..., -0.9414, -0.6211, -1.1953],
         [ 1.5781, -1.5859,  0.1396,  ..., -0.1191, -0.2041, -0.2812]],

        [[ 0.5547, -1.4062, -0.3340,  ...,  0.1514, -0.9297, -1.0312],
         [-0.4785, -1.4062, -0.4316,  ...,  0.0583, -1.4062,  0.1992],
         [-2.2500, -2.3281,  0.4707,  ..., -1.1562, -0.1973,  0.9883],
         ...,
         [-0.1641,  1.3672, -0.2100,  ...,  0.0522, -1.4062,  0.1992],
         [-0.1484, -0.2168, -0.0801,  ..., -0.9414, -0.6211, -1.1953],
         [ 1.5781, -1.5859,  0.1396,  ..., -0.1191, -0.2041, -0.2812]],

        ...,

        [[ 0.4551, -0.1377, -0.2383,  ...,  0.9648, -1.7422,  0.1562],
         [ 0.2422, -1.1875, -0.4629,  ..., -0.1816, -0.8867, -0.7070],
         [-1.3672, -1.3750, -1.0781,  ...,  0.2119,  1.6172,  0.6758],
         ...,
         [-1.4375,  1.2188, -0.1250,  ..., -0.1787, -0.8945, -0.7070],
         [-0.1367,  0.2949, -0.0059,  ...,  0.7617, -1.8203,  1.6328],
         [ 1.4922,  0.6953, -1.4297,  ...,  0.9961,  1.2109,  1.5938]],

        [[ 0.4551, -0.1377, -0.2383,  ...,  0.9648, -1.7422,  0.1562],
         [ 0.2422, -1.1875, -0.4629,  ..., -0.1816, -0.8867, -0.7070],
         [-1.3672, -1.3750, -1.0781,  ...,  0.2119,  1.6172,  0.6758],
         ...,
         [-1.4375,  1.2188, -0.1250,  ..., -0.1787, -0.8945, -0.7070],
         [-0.1367,  0.2949, -0.0059,  ...,  0.7617, -1.8203,  1.6328],
         [ 1.4922,  0.6953, -1.4297,  ...,  0.9961,  1.2109,  1.5938]],

        [[ 0.4551, -0.1377, -0.2383,  ...,  0.9648, -1.7422,  0.1562],
         [ 0.2422, -1.1875, -0.4629,  ..., -0.1816, -0.8867, -0.7070],
         [-1.3672, -1.3750, -1.0781,  ...,  0.2119,  1.6172,  0.6758],
         ...,
         [-1.4375,  1.2188, -0.1250,  ..., -0.1787, -0.8945, -0.7070],
         [-0.1367,  0.2949, -0.0059,  ...,  0.7617, -1.8203,  1.6328],
         [ 1.4922,  0.6953, -1.4297,  ...,  0.9961,  1.2109,  1.5938]]],
       dtype=torch.bfloat16), tensor([[[-0.0036,  0.0417,  0.0364,  ..., -0.0087, -0.0391, -0.0474],
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [-0.0591, -0.0118, -0.1279,  ..., -0.1592,  0.1338,  0.0255],
         ...,
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [ 0.0087,  0.0205,  0.0557,  ...,  0.0085,  0.0039, -0.0454],
         [-0.0659,  0.0427,  0.1006,  ...,  0.1055, -0.0527, -0.0339]],

        [[-0.0036,  0.0417,  0.0364,  ..., -0.0087, -0.0391, -0.0474],
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [-0.0591, -0.0118, -0.1279,  ..., -0.1592,  0.1338,  0.0255],
         ...,
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [ 0.0087,  0.0205,  0.0557,  ...,  0.0085,  0.0039, -0.0454],
         [-0.0659,  0.0427,  0.1006,  ...,  0.1055, -0.0527, -0.0339]],

        [[-0.0036,  0.0417,  0.0364,  ..., -0.0087, -0.0391, -0.0474],
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [-0.0591, -0.0118, -0.1279,  ..., -0.1592,  0.1338,  0.0255],
         ...,
         [-0.0889,  0.0193, -0.0654,  ..., -0.0253,  0.0302, -0.0610],
         [ 0.0087,  0.0205,  0.0557,  ...,  0.0085,  0.0039, -0.0454],
         [-0.0659,  0.0427,  0.1006,  ...,  0.1055, -0.0527, -0.0339]],

        ...,

        [[-0.0593, -0.0640, -0.0276,  ..., -0.0116, -0.0459, -0.0016],
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0708, -0.0918,  0.2285,  ..., -0.0635,  0.1396,  0.0603],
         ...,
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0923,  0.0505,  0.0068,  ...,  0.0239, -0.0119,  0.0031],
         [-0.2773, -0.3125, -0.3086,  ...,  0.0464,  0.1826, -0.2871]],

        [[-0.0593, -0.0640, -0.0276,  ..., -0.0116, -0.0459, -0.0016],
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0708, -0.0918,  0.2285,  ..., -0.0635,  0.1396,  0.0603],
         ...,
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0923,  0.0505,  0.0068,  ...,  0.0239, -0.0119,  0.0031],
         [-0.2773, -0.3125, -0.3086,  ...,  0.0464,  0.1826, -0.2871]],

        [[-0.0593, -0.0640, -0.0276,  ..., -0.0116, -0.0459, -0.0016],
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0708, -0.0918,  0.2285,  ..., -0.0635,  0.1396,  0.0603],
         ...,
         [-0.1680,  0.1338,  0.0145,  ...,  0.0097, -0.0281,  0.0104],
         [-0.0923,  0.0505,  0.0068,  ...,  0.0239, -0.0119,  0.0031],
         [-0.2773, -0.3125, -0.3086,  ...,  0.0464,  0.1826, -0.2871]]],
       dtype=torch.bfloat16)), (tensor([[[ 2.4531,  1.6719, -4.0938,  ..., -1.4219,  0.8477,  1.2422],
         [ 7.1250,  2.4375, -2.5938,  ..., -1.1406,  0.6445,  1.5938],
         [ 4.6875,  4.4375, -2.6875,  ..., -0.1709,  0.5898,  3.0156],
         ...,
         [-3.0156, -2.6562, -2.6562,  ..., -2.2656,  0.7383,  3.3906],
         [-2.3125, -2.8281, -2.1562,  ...,  0.1621,  2.0000, -1.0938],
         [-5.6250, -4.4062, -2.9062,  ..., -1.9922,  1.0938,  1.3516]],

        [[ 2.4531,  1.6719, -4.0938,  ..., -1.4219,  0.8477,  1.2422],
         [ 7.1250,  2.4375, -2.5938,  ..., -1.1406,  0.6445,  1.5938],
         [ 4.6875,  4.4375, -2.6875,  ..., -0.1709,  0.5898,  3.0156],
         ...,
         [-3.0156, -2.6562, -2.6562,  ..., -2.2656,  0.7383,  3.3906],
         [-2.3125, -2.8281, -2.1562,  ...,  0.1621,  2.0000, -1.0938],
         [-5.6250, -4.4062, -2.9062,  ..., -1.9922,  1.0938,  1.3516]],

        [[ 2.4531,  1.6719, -4.0938,  ..., -1.4219,  0.8477,  1.2422],
         [ 7.1250,  2.4375, -2.5938,  ..., -1.1406,  0.6445,  1.5938],
         [ 4.6875,  4.4375, -2.6875,  ..., -0.1709,  0.5898,  3.0156],
         ...,
         [-3.0156, -2.6562, -2.6562,  ..., -2.2656,  0.7383,  3.3906],
         [-2.3125, -2.8281, -2.1562,  ...,  0.1621,  2.0000, -1.0938],
         [-5.6250, -4.4062, -2.9062,  ..., -1.9922,  1.0938,  1.3516]],

        ...,

        [[ 0.1846,  0.7070, -2.5938,  ..., -1.6562,  3.1719, -1.0938],
         [ 3.0312,  1.2500, -1.8672,  ..., -2.9219,  2.3594, -0.6680],
         [ 5.0625,  0.7070, -0.4258,  ..., -4.0938,  3.5781, -3.2344],
         ...,
         [-0.9375, -2.8438, -1.4531,  ..., -2.8594,  4.0312, -3.2188],
         [ 0.1758, -0.5859, -0.0127,  ..., -4.4688,  2.4375, -1.9375],
         [-5.9688, -1.0625, -4.2500,  ..., -2.6406,  3.4844, -0.4336]],

        [[ 0.1846,  0.7070, -2.5938,  ..., -1.6562,  3.1719, -1.0938],
         [ 3.0312,  1.2500, -1.8672,  ..., -2.9219,  2.3594, -0.6680],
         [ 5.0625,  0.7070, -0.4258,  ..., -4.0938,  3.5781, -3.2344],
         ...,
         [-0.9375, -2.8438, -1.4531,  ..., -2.8594,  4.0312, -3.2188],
         [ 0.1758, -0.5859, -0.0127,  ..., -4.4688,  2.4375, -1.9375],
         [-5.9688, -1.0625, -4.2500,  ..., -2.6406,  3.4844, -0.4336]],

        [[ 0.1846,  0.7070, -2.5938,  ..., -1.6562,  3.1719, -1.0938],
         [ 3.0312,  1.2500, -1.8672,  ..., -2.9219,  2.3594, -0.6680],
         [ 5.0625,  0.7070, -0.4258,  ..., -4.0938,  3.5781, -3.2344],
         ...,
         [-0.9375, -2.8438, -1.4531,  ..., -2.8594,  4.0312, -3.2188],
         [ 0.1758, -0.5859, -0.0127,  ..., -4.4688,  2.4375, -1.9375],
         [-5.9688, -1.0625, -4.2500,  ..., -2.6406,  3.4844, -0.4336]]],
       dtype=torch.bfloat16), tensor([[[ 0.0366,  0.0337,  0.0234,  ...,  0.1299,  0.2031, -0.1338],
         [-0.0269, -0.0304,  0.0752,  ..., -0.1118, -0.0603,  0.0413],
         [-0.0747,  0.0601,  0.0187,  ..., -0.0109, -0.2598, -0.1670],
         ...,
         [ 0.1973, -0.0928,  0.0315,  ..., -0.1162,  0.1777,  0.1699],
         [-0.0259, -0.0040,  0.0400,  ...,  0.0557, -0.0082, -0.0427],
         [-0.1729,  0.1797,  0.0096,  ...,  0.0938, -0.3145, -0.1206]],

        [[ 0.0366,  0.0337,  0.0234,  ...,  0.1299,  0.2031, -0.1338],
         [-0.0269, -0.0304,  0.0752,  ..., -0.1118, -0.0603,  0.0413],
         [-0.0747,  0.0601,  0.0187,  ..., -0.0109, -0.2598, -0.1670],
         ...,
         [ 0.1973, -0.0928,  0.0315,  ..., -0.1162,  0.1777,  0.1699],
         [-0.0259, -0.0040,  0.0400,  ...,  0.0557, -0.0082, -0.0427],
         [-0.1729,  0.1797,  0.0096,  ...,  0.0938, -0.3145, -0.1206]],

        [[ 0.0366,  0.0337,  0.0234,  ...,  0.1299,  0.2031, -0.1338],
         [-0.0269, -0.0304,  0.0752,  ..., -0.1118, -0.0603,  0.0413],
         [-0.0747,  0.0601,  0.0187,  ..., -0.0109, -0.2598, -0.1670],
         ...,
         [ 0.1973, -0.0928,  0.0315,  ..., -0.1162,  0.1777,  0.1699],
         [-0.0259, -0.0040,  0.0400,  ...,  0.0557, -0.0082, -0.0427],
         [-0.1729,  0.1797,  0.0096,  ...,  0.0938, -0.3145, -0.1206]],

        ...,

        [[ 0.0688,  0.0198,  0.0096,  ..., -0.0469, -0.0825, -0.0283],
         [-0.0437,  0.0231, -0.0981,  ...,  0.0354, -0.0835, -0.0356],
         [-0.1641,  0.0330, -0.0334,  ...,  0.0674, -0.1543,  0.1328],
         ...,
         [ 0.0874, -0.0222, -0.1260,  ...,  0.0718, -0.0569, -0.0508],
         [-0.0073, -0.0215,  0.0016,  ...,  0.0200,  0.0398,  0.0527],
         [ 0.0408,  0.1011,  0.2637,  ..., -0.0903,  0.0087,  0.0859]],

        [[ 0.0688,  0.0198,  0.0096,  ..., -0.0469, -0.0825, -0.0283],
         [-0.0437,  0.0231, -0.0981,  ...,  0.0354, -0.0835, -0.0356],
         [-0.1641,  0.0330, -0.0334,  ...,  0.0674, -0.1543,  0.1328],
         ...,
         [ 0.0874, -0.0222, -0.1260,  ...,  0.0718, -0.0569, -0.0508],
         [-0.0073, -0.0215,  0.0016,  ...,  0.0200,  0.0398,  0.0527],
         [ 0.0408,  0.1011,  0.2637,  ..., -0.0903,  0.0087,  0.0859]],

        [[ 0.0688,  0.0198,  0.0096,  ..., -0.0469, -0.0825, -0.0283],
         [-0.0437,  0.0231, -0.0981,  ...,  0.0354, -0.0835, -0.0356],
         [-0.1641,  0.0330, -0.0334,  ...,  0.0674, -0.1543,  0.1328],
         ...,
         [ 0.0874, -0.0222, -0.1260,  ...,  0.0718, -0.0569, -0.0508],
         [-0.0073, -0.0215,  0.0016,  ...,  0.0200,  0.0398,  0.0527],
         [ 0.0408,  0.1011,  0.2637,  ..., -0.0903,  0.0087,  0.0859]]],
       dtype=torch.bfloat16)), (tensor([[[-2.1820e-03, -5.0964e-03,  8.5831e-04,  ..., -9.5312e-01,
           6.8970e-03,  4.1406e-01],
         [ 1.4219e+00,  1.4141e+00,  2.2188e+00,  ...,  6.9062e+00,
           1.8164e-01, -1.7188e+00],
         [ 1.3672e+00,  7.8613e-02,  1.4531e+00,  ...,  6.3438e+00,
          -1.0547e+00, -2.1562e+00],
         ...,
         [ 8.2422e-01, -7.6953e-01,  1.9297e+00,  ...,  7.3125e+00,
          -2.0625e+00, -2.8750e+00],
         [ 7.0312e-02,  1.0352e-01,  2.8125e-01,  ...,  6.2500e+00,
          -7.6953e-01, -1.5312e+00],
         [-1.4062e+00, -1.1562e+00,  6.0547e-01,  ...,  6.2188e+00,
           1.2031e+00, -1.1016e+00]],

        [[-2.1820e-03, -5.0964e-03,  8.5831e-04,  ..., -9.5312e-01,
           6.8970e-03,  4.1406e-01],
         [ 1.4219e+00,  1.4141e+00,  2.2188e+00,  ...,  6.9062e+00,
           1.8164e-01, -1.7188e+00],
         [ 1.3672e+00,  7.8613e-02,  1.4531e+00,  ...,  6.3438e+00,
          -1.0547e+00, -2.1562e+00],
         ...,
         [ 8.2422e-01, -7.6953e-01,  1.9297e+00,  ...,  7.3125e+00,
          -2.0625e+00, -2.8750e+00],
         [ 7.0312e-02,  1.0352e-01,  2.8125e-01,  ...,  6.2500e+00,
          -7.6953e-01, -1.5312e+00],
         [-1.4062e+00, -1.1562e+00,  6.0547e-01,  ...,  6.2188e+00,
           1.2031e+00, -1.1016e+00]],

        [[-2.1820e-03, -5.0964e-03,  8.5831e-04,  ..., -9.5312e-01,
           6.8970e-03,  4.1406e-01],
         [ 1.4219e+00,  1.4141e+00,  2.2188e+00,  ...,  6.9062e+00,
           1.8164e-01, -1.7188e+00],
         [ 1.3672e+00,  7.8613e-02,  1.4531e+00,  ...,  6.3438e+00,
          -1.0547e+00, -2.1562e+00],
         ...,
         [ 8.2422e-01, -7.6953e-01,  1.9297e+00,  ...,  7.3125e+00,
          -2.0625e+00, -2.8750e+00],
         [ 7.0312e-02,  1.0352e-01,  2.8125e-01,  ...,  6.2500e+00,
          -7.6953e-01, -1.5312e+00],
         [-1.4062e+00, -1.1562e+00,  6.0547e-01,  ...,  6.2188e+00,
           1.2031e+00, -1.1016e+00]],

        ...,

        [[ 1.0315e-02,  7.4005e-04, -5.4626e-03,  ..., -4.5654e-02,
           9.8145e-02,  1.5015e-02],
         [ 4.0000e+00,  1.5156e+00,  1.7188e+00,  ...,  8.3008e-02,
          -2.4844e+00,  1.7188e+00],
         [ 2.1875e+00, -4.4434e-02,  1.3672e+00,  ...,  1.4609e+00,
          -1.6250e+00,  1.2578e+00],
         ...,
         [-2.4062e+00, -4.2383e-01,  2.0156e+00,  ...,  1.1406e+00,
          -1.5859e+00,  8.0469e-01],
         [-1.1875e+00, -2.2363e-01,  1.2422e+00,  ...,  1.6016e+00,
           4.3164e-01,  2.5000e+00],
         [-2.5781e+00, -9.8438e-01,  7.4219e-01,  ...,  6.5234e-01,
          -1.0391e+00,  1.8359e-01]],

        [[ 1.0315e-02,  7.4005e-04, -5.4626e-03,  ..., -4.5654e-02,
           9.8145e-02,  1.5015e-02],
         [ 4.0000e+00,  1.5156e+00,  1.7188e+00,  ...,  8.3008e-02,
          -2.4844e+00,  1.7188e+00],
         [ 2.1875e+00, -4.4434e-02,  1.3672e+00,  ...,  1.4609e+00,
          -1.6250e+00,  1.2578e+00],
         ...,
         [-2.4062e+00, -4.2383e-01,  2.0156e+00,  ...,  1.1406e+00,
          -1.5859e+00,  8.0469e-01],
         [-1.1875e+00, -2.2363e-01,  1.2422e+00,  ...,  1.6016e+00,
           4.3164e-01,  2.5000e+00],
         [-2.5781e+00, -9.8438e-01,  7.4219e-01,  ...,  6.5234e-01,
          -1.0391e+00,  1.8359e-01]],

        [[ 1.0315e-02,  7.4005e-04, -5.4626e-03,  ..., -4.5654e-02,
           9.8145e-02,  1.5015e-02],
         [ 4.0000e+00,  1.5156e+00,  1.7188e+00,  ...,  8.3008e-02,
          -2.4844e+00,  1.7188e+00],
         [ 2.1875e+00, -4.4434e-02,  1.3672e+00,  ...,  1.4609e+00,
          -1.6250e+00,  1.2578e+00],
         ...,
         [-2.4062e+00, -4.2383e-01,  2.0156e+00,  ...,  1.1406e+00,
          -1.5859e+00,  8.0469e-01],
         [-1.1875e+00, -2.2363e-01,  1.2422e+00,  ...,  1.6016e+00,
           4.3164e-01,  2.5000e+00],
         [-2.5781e+00, -9.8438e-01,  7.4219e-01,  ...,  6.5234e-01,
          -1.0391e+00,  1.8359e-01]]], dtype=torch.bfloat16), tensor([[[-1.8997e-03, -1.0864e-02,  2.7313e-03,  ..., -5.3711e-03,
           7.7820e-04,  3.0975e-03],
         [ 1.2793e-01,  3.1445e-01,  5.7812e-01,  ...,  5.1953e-01,
          -1.4771e-02, -3.0078e-01],
         [-2.1973e-01,  5.4443e-02, -1.6699e-01,  ..., -5.3516e-01,
           2.7930e-01,  1.0205e-01],
         ...,
         [-1.8262e-01,  3.0273e-02,  2.2461e-01,  ..., -1.0193e-02,
           2.6953e-01,  5.9891e-04],
         [ 3.1250e-01, -3.7354e-02,  4.6631e-02,  ...,  4.3945e-01,
          -9.2285e-02, -1.6895e-01],
         [ 3.7500e-01,  1.1406e+00, -1.2969e+00,  ...,  1.0156e+00,
          -5.2344e-01,  5.2734e-01]],

        [[-1.8997e-03, -1.0864e-02,  2.7313e-03,  ..., -5.3711e-03,
           7.7820e-04,  3.0975e-03],
         [ 1.2793e-01,  3.1445e-01,  5.7812e-01,  ...,  5.1953e-01,
          -1.4771e-02, -3.0078e-01],
         [-2.1973e-01,  5.4443e-02, -1.6699e-01,  ..., -5.3516e-01,
           2.7930e-01,  1.0205e-01],
         ...,
         [-1.8262e-01,  3.0273e-02,  2.2461e-01,  ..., -1.0193e-02,
           2.6953e-01,  5.9891e-04],
         [ 3.1250e-01, -3.7354e-02,  4.6631e-02,  ...,  4.3945e-01,
          -9.2285e-02, -1.6895e-01],
         [ 3.7500e-01,  1.1406e+00, -1.2969e+00,  ...,  1.0156e+00,
          -5.2344e-01,  5.2734e-01]],

        [[-1.8997e-03, -1.0864e-02,  2.7313e-03,  ..., -5.3711e-03,
           7.7820e-04,  3.0975e-03],
         [ 1.2793e-01,  3.1445e-01,  5.7812e-01,  ...,  5.1953e-01,
          -1.4771e-02, -3.0078e-01],
         [-2.1973e-01,  5.4443e-02, -1.6699e-01,  ..., -5.3516e-01,
           2.7930e-01,  1.0205e-01],
         ...,
         [-1.8262e-01,  3.0273e-02,  2.2461e-01,  ..., -1.0193e-02,
           2.6953e-01,  5.9891e-04],
         [ 3.1250e-01, -3.7354e-02,  4.6631e-02,  ...,  4.3945e-01,
          -9.2285e-02, -1.6895e-01],
         [ 3.7500e-01,  1.1406e+00, -1.2969e+00,  ...,  1.0156e+00,
          -5.2344e-01,  5.2734e-01]],

        ...,

        [[ 1.1658e-02, -7.6599e-03, -4.5967e-04,  ..., -3.1128e-03,
           3.2349e-03,  9.1934e-04],
         [ 2.6953e-01,  2.1582e-01, -2.3633e-01,  ...,  2.2095e-02,
           2.5195e-01,  5.7373e-02],
         [-3.2031e-01,  1.0315e-02,  5.2979e-02,  ...,  4.0820e-01,
           1.6895e-01, -2.3047e-01],
         ...,
         [-2.9785e-02,  1.4258e-01, -1.9824e-01,  ...,  2.0312e-01,
           1.8848e-01,  2.8711e-01],
         [-2.9688e-01,  1.5234e-01, -1.6797e-01,  ...,  1.9824e-01,
          -2.8711e-01, -4.9072e-02],
         [-1.3184e-01, -1.4160e-01,  2.1973e-02,  ..., -2.5781e-01,
          -2.5977e-01, -4.6680e-01]],

        [[ 1.1658e-02, -7.6599e-03, -4.5967e-04,  ..., -3.1128e-03,
           3.2349e-03,  9.1934e-04],
         [ 2.6953e-01,  2.1582e-01, -2.3633e-01,  ...,  2.2095e-02,
           2.5195e-01,  5.7373e-02],
         [-3.2031e-01,  1.0315e-02,  5.2979e-02,  ...,  4.0820e-01,
           1.6895e-01, -2.3047e-01],
         ...,
         [-2.9785e-02,  1.4258e-01, -1.9824e-01,  ...,  2.0312e-01,
           1.8848e-01,  2.8711e-01],
         [-2.9688e-01,  1.5234e-01, -1.6797e-01,  ...,  1.9824e-01,
          -2.8711e-01, -4.9072e-02],
         [-1.3184e-01, -1.4160e-01,  2.1973e-02,  ..., -2.5781e-01,
          -2.5977e-01, -4.6680e-01]],

        [[ 1.1658e-02, -7.6599e-03, -4.5967e-04,  ..., -3.1128e-03,
           3.2349e-03,  9.1934e-04],
         [ 2.6953e-01,  2.1582e-01, -2.3633e-01,  ...,  2.2095e-02,
           2.5195e-01,  5.7373e-02],
         [-3.2031e-01,  1.0315e-02,  5.2979e-02,  ...,  4.0820e-01,
           1.6895e-01, -2.3047e-01],
         ...,
         [-2.9785e-02,  1.4258e-01, -1.9824e-01,  ...,  2.0312e-01,
           1.8848e-01,  2.8711e-01],
         [-2.9688e-01,  1.5234e-01, -1.6797e-01,  ...,  1.9824e-01,
          -2.8711e-01, -4.9072e-02],
         [-1.3184e-01, -1.4160e-01,  2.1973e-02,  ..., -2.5781e-01,
          -2.5977e-01, -4.6680e-01]]], dtype=torch.bfloat16)), (tensor([[[ 7.3242e-03,  4.6387e-03,  8.1787e-03,  ...,  6.4844e-01,
           5.0000e-01, -3.8477e-01],
         [-1.9688e+00,  1.5000e+00, -7.1484e-01,  ..., -3.2031e+00,
          -2.9375e+00,  2.9883e-01],
         [-5.7031e-01,  5.5078e-01, -1.0781e+00,  ..., -3.3281e+00,
          -3.3906e+00,  1.0469e+00],
         ...,
         [ 5.1172e-01, -8.0469e-01,  1.9434e-01,  ..., -4.8438e+00,
          -3.0312e+00,  8.8867e-02],
         [ 1.2500e-01, -3.7305e-01, -3.4570e-01,  ..., -4.9375e+00,
          -3.5000e+00,  5.1953e-01],
         [ 1.6309e-01,  1.2891e-01,  3.4961e-01,  ..., -5.2500e+00,
          -2.3750e+00,  1.0078e+00]],

        [[ 7.3242e-03,  4.6387e-03,  8.1787e-03,  ...,  6.4844e-01,
           5.0000e-01, -3.8477e-01],
         [-1.9688e+00,  1.5000e+00, -7.1484e-01,  ..., -3.2031e+00,
          -2.9375e+00,  2.9883e-01],
         [-5.7031e-01,  5.5078e-01, -1.0781e+00,  ..., -3.3281e+00,
          -3.3906e+00,  1.0469e+00],
         ...,
         [ 5.1172e-01, -8.0469e-01,  1.9434e-01,  ..., -4.8438e+00,
          -3.0312e+00,  8.8867e-02],
         [ 1.2500e-01, -3.7305e-01, -3.4570e-01,  ..., -4.9375e+00,
          -3.5000e+00,  5.1953e-01],
         [ 1.6309e-01,  1.2891e-01,  3.4961e-01,  ..., -5.2500e+00,
          -2.3750e+00,  1.0078e+00]],

        [[ 7.3242e-03,  4.6387e-03,  8.1787e-03,  ...,  6.4844e-01,
           5.0000e-01, -3.8477e-01],
         [-1.9688e+00,  1.5000e+00, -7.1484e-01,  ..., -3.2031e+00,
          -2.9375e+00,  2.9883e-01],
         [-5.7031e-01,  5.5078e-01, -1.0781e+00,  ..., -3.3281e+00,
          -3.3906e+00,  1.0469e+00],
         ...,
         [ 5.1172e-01, -8.0469e-01,  1.9434e-01,  ..., -4.8438e+00,
          -3.0312e+00,  8.8867e-02],
         [ 1.2500e-01, -3.7305e-01, -3.4570e-01,  ..., -4.9375e+00,
          -3.5000e+00,  5.1953e-01],
         [ 1.6309e-01,  1.2891e-01,  3.4961e-01,  ..., -5.2500e+00,
          -2.3750e+00,  1.0078e+00]],

        ...,

        [[ 2.6398e-03, -5.1575e-03,  3.1586e-03,  ...,  8.0078e-02,
          -5.4321e-03, -2.4219e-01],
         [-5.1875e+00, -8.9062e-01,  7.8516e-01,  ..., -6.6406e-01,
          -1.3281e+00, -6.8750e-01],
         [-3.2188e+00, -5.5664e-02,  1.1719e-02,  ...,  5.1172e-01,
           3.5352e-01, -9.7266e-01],
         ...,
         [ 2.1562e+00,  1.8438e+00,  1.5703e+00,  ..., -6.1328e-01,
           5.1953e-01, -7.0312e-01],
         [ 3.2500e+00, -3.3789e-01,  1.6328e+00,  ...,  7.4219e-01,
           8.8672e-01, -7.8125e-01],
         [ 2.7812e+00, -2.0312e+00, -7.5781e-01,  ..., -1.0391e+00,
           2.1875e-01, -1.5039e-01]],

        [[ 2.6398e-03, -5.1575e-03,  3.1586e-03,  ...,  8.0078e-02,
          -5.4321e-03, -2.4219e-01],
         [-5.1875e+00, -8.9062e-01,  7.8516e-01,  ..., -6.6406e-01,
          -1.3281e+00, -6.8750e-01],
         [-3.2188e+00, -5.5664e-02,  1.1719e-02,  ...,  5.1172e-01,
           3.5352e-01, -9.7266e-01],
         ...,
         [ 2.1562e+00,  1.8438e+00,  1.5703e+00,  ..., -6.1328e-01,
           5.1953e-01, -7.0312e-01],
         [ 3.2500e+00, -3.3789e-01,  1.6328e+00,  ...,  7.4219e-01,
           8.8672e-01, -7.8125e-01],
         [ 2.7812e+00, -2.0312e+00, -7.5781e-01,  ..., -1.0391e+00,
           2.1875e-01, -1.5039e-01]],

        [[ 2.6398e-03, -5.1575e-03,  3.1586e-03,  ...,  8.0078e-02,
          -5.4321e-03, -2.4219e-01],
         [-5.1875e+00, -8.9062e-01,  7.8516e-01,  ..., -6.6406e-01,
          -1.3281e+00, -6.8750e-01],
         [-3.2188e+00, -5.5664e-02,  1.1719e-02,  ...,  5.1172e-01,
           3.5352e-01, -9.7266e-01],
         ...,
         [ 2.1562e+00,  1.8438e+00,  1.5703e+00,  ..., -6.1328e-01,
           5.1953e-01, -7.0312e-01],
         [ 3.2500e+00, -3.3789e-01,  1.6328e+00,  ...,  7.4219e-01,
           8.8672e-01, -7.8125e-01],
         [ 2.7812e+00, -2.0312e+00, -7.5781e-01,  ..., -1.0391e+00,
           2.1875e-01, -1.5039e-01]]], dtype=torch.bfloat16), tensor([[[ 5.3406e-04,  6.9580e-03, -4.0283e-03,  ...,  1.3428e-03,
          -3.2349e-03, -1.6174e-03],
         [ 1.1963e-01, -6.2109e-01,  2.1680e-01,  ..., -2.0874e-02,
           1.0193e-02,  2.7930e-01],
         [-3.1055e-01,  4.1016e-01,  8.4375e-01,  ...,  1.8457e-01,
          -7.3730e-02, -3.7305e-01],
         ...,
         [ 1.2891e-01,  4.1992e-01,  4.2188e-01,  ...,  1.4453e-01,
           3.4961e-01, -2.1118e-02],
         [ 7.2266e-01, -7.5000e-01,  1.8750e-01,  ..., -2.8516e-01,
           5.1172e-01,  2.8198e-02],
         [-1.3184e-01,  1.5332e-01,  1.0693e-01,  ...,  4.2578e-01,
          -6.1646e-03, -3.0859e-01]],

        [[ 5.3406e-04,  6.9580e-03, -4.0283e-03,  ...,  1.3428e-03,
          -3.2349e-03, -1.6174e-03],
         [ 1.1963e-01, -6.2109e-01,  2.1680e-01,  ..., -2.0874e-02,
           1.0193e-02,  2.7930e-01],
         [-3.1055e-01,  4.1016e-01,  8.4375e-01,  ...,  1.8457e-01,
          -7.3730e-02, -3.7305e-01],
         ...,
         [ 1.2891e-01,  4.1992e-01,  4.2188e-01,  ...,  1.4453e-01,
           3.4961e-01, -2.1118e-02],
         [ 7.2266e-01, -7.5000e-01,  1.8750e-01,  ..., -2.8516e-01,
           5.1172e-01,  2.8198e-02],
         [-1.3184e-01,  1.5332e-01,  1.0693e-01,  ...,  4.2578e-01,
          -6.1646e-03, -3.0859e-01]],

        [[ 5.3406e-04,  6.9580e-03, -4.0283e-03,  ...,  1.3428e-03,
          -3.2349e-03, -1.6174e-03],
         [ 1.1963e-01, -6.2109e-01,  2.1680e-01,  ..., -2.0874e-02,
           1.0193e-02,  2.7930e-01],
         [-3.1055e-01,  4.1016e-01,  8.4375e-01,  ...,  1.8457e-01,
          -7.3730e-02, -3.7305e-01],
         ...,
         [ 1.2891e-01,  4.1992e-01,  4.2188e-01,  ...,  1.4453e-01,
           3.4961e-01, -2.1118e-02],
         [ 7.2266e-01, -7.5000e-01,  1.8750e-01,  ..., -2.8516e-01,
           5.1172e-01,  2.8198e-02],
         [-1.3184e-01,  1.5332e-01,  1.0693e-01,  ...,  4.2578e-01,
          -6.1646e-03, -3.0859e-01]],

        ...,

        [[-3.9368e-03, -2.0752e-03,  3.1433e-03,  ...,  3.3569e-04,
           3.0212e-03, -7.6675e-04],
         [-3.5547e-01,  2.2339e-02, -1.4648e-01,  ..., -1.3770e-01,
          -2.1973e-01,  1.2024e-02],
         [ 8.7891e-02, -1.6699e-01,  6.2500e-02,  ...,  4.0039e-01,
          -2.6953e-01,  1.7383e-01],
         ...,
         [ 2.7222e-02,  2.2070e-01, -2.7930e-01,  ...,  1.5527e-01,
           8.2031e-02, -3.4790e-03],
         [ 1.3867e-01,  2.8198e-02,  8.3008e-02,  ..., -6.5308e-03,
           2.6367e-01, -1.5137e-01],
         [-1.4062e-01, -2.9297e-01, -4.2578e-01,  ...,  3.3691e-02,
          -9.7656e-02,  6.5918e-02]],

        [[-3.9368e-03, -2.0752e-03,  3.1433e-03,  ...,  3.3569e-04,
           3.0212e-03, -7.6675e-04],
         [-3.5547e-01,  2.2339e-02, -1.4648e-01,  ..., -1.3770e-01,
          -2.1973e-01,  1.2024e-02],
         [ 8.7891e-02, -1.6699e-01,  6.2500e-02,  ...,  4.0039e-01,
          -2.6953e-01,  1.7383e-01],
         ...,
         [ 2.7222e-02,  2.2070e-01, -2.7930e-01,  ...,  1.5527e-01,
           8.2031e-02, -3.4790e-03],
         [ 1.3867e-01,  2.8198e-02,  8.3008e-02,  ..., -6.5308e-03,
           2.6367e-01, -1.5137e-01],
         [-1.4062e-01, -2.9297e-01, -4.2578e-01,  ...,  3.3691e-02,
          -9.7656e-02,  6.5918e-02]],

        [[-3.9368e-03, -2.0752e-03,  3.1433e-03,  ...,  3.3569e-04,
           3.0212e-03, -7.6675e-04],
         [-3.5547e-01,  2.2339e-02, -1.4648e-01,  ..., -1.3770e-01,
          -2.1973e-01,  1.2024e-02],
         [ 8.7891e-02, -1.6699e-01,  6.2500e-02,  ...,  4.0039e-01,
          -2.6953e-01,  1.7383e-01],
         ...,
         [ 2.7222e-02,  2.2070e-01, -2.7930e-01,  ...,  1.5527e-01,
           8.2031e-02, -3.4790e-03],
         [ 1.3867e-01,  2.8198e-02,  8.3008e-02,  ..., -6.5308e-03,
           2.6367e-01, -1.5137e-01],
         [-1.4062e-01, -2.9297e-01, -4.2578e-01,  ...,  3.3691e-02,
          -9.7656e-02,  6.5918e-02]]], dtype=torch.bfloat16)), (tensor([[[ 0.0146,  0.0317, -0.0143,  ..., -0.2393, -0.6523, -0.1162],
         [ 1.6094, -1.1406, -0.0879,  ...,  0.4395,  2.6562, -0.4863],
         [ 2.5781, -0.8555, -0.3262,  ...,  0.2451,  3.0469, -0.1484],
         ...,
         [-0.6953,  0.5469, -1.6641,  ...,  2.4375,  3.4062,  1.3047],
         [-0.4199,  0.4336,  0.5078,  ...,  1.5391,  0.7188,  0.7188],
         [-0.5078,  2.0938, -0.6367,  ...,  1.0078,  2.1406, -2.1406]],

        [[ 0.0146,  0.0317, -0.0143,  ..., -0.2393, -0.6523, -0.1162],
         [ 1.6094, -1.1406, -0.0879,  ...,  0.4395,  2.6562, -0.4863],
         [ 2.5781, -0.8555, -0.3262,  ...,  0.2451,  3.0469, -0.1484],
         ...,
         [-0.6953,  0.5469, -1.6641,  ...,  2.4375,  3.4062,  1.3047],
         [-0.4199,  0.4336,  0.5078,  ...,  1.5391,  0.7188,  0.7188],
         [-0.5078,  2.0938, -0.6367,  ...,  1.0078,  2.1406, -2.1406]],

        [[ 0.0146,  0.0317, -0.0143,  ..., -0.2393, -0.6523, -0.1162],
         [ 1.6094, -1.1406, -0.0879,  ...,  0.4395,  2.6562, -0.4863],
         [ 2.5781, -0.8555, -0.3262,  ...,  0.2451,  3.0469, -0.1484],
         ...,
         [-0.6953,  0.5469, -1.6641,  ...,  2.4375,  3.4062,  1.3047],
         [-0.4199,  0.4336,  0.5078,  ...,  1.5391,  0.7188,  0.7188],
         [-0.5078,  2.0938, -0.6367,  ...,  1.0078,  2.1406, -2.1406]],

        ...,

        [[-0.0132, -0.0184,  0.0106,  ...,  0.0322, -0.2773,  0.1357],
         [ 2.0625, -1.0625, -1.7812,  ..., -0.1167, -0.0369, -0.5156],
         [ 0.3672, -0.5586, -0.3867,  ...,  0.9570, -0.5742, -0.1230],
         ...,
         [-1.0391,  0.9609, -1.2891,  ...,  1.7422,  1.3750, -1.1016],
         [-0.5234,  0.6836, -0.7266,  ...,  0.2891, -1.3828,  1.3281],
         [-0.7422,  0.3984, -0.5547,  ...,  1.7734,  0.4082, -4.3750]],

        [[-0.0132, -0.0184,  0.0106,  ...,  0.0322, -0.2773,  0.1357],
         [ 2.0625, -1.0625, -1.7812,  ..., -0.1167, -0.0369, -0.5156],
         [ 0.3672, -0.5586, -0.3867,  ...,  0.9570, -0.5742, -0.1230],
         ...,
         [-1.0391,  0.9609, -1.2891,  ...,  1.7422,  1.3750, -1.1016],
         [-0.5234,  0.6836, -0.7266,  ...,  0.2891, -1.3828,  1.3281],
         [-0.7422,  0.3984, -0.5547,  ...,  1.7734,  0.4082, -4.3750]],

        [[-0.0132, -0.0184,  0.0106,  ...,  0.0322, -0.2773,  0.1357],
         [ 2.0625, -1.0625, -1.7812,  ..., -0.1167, -0.0369, -0.5156],
         [ 0.3672, -0.5586, -0.3867,  ...,  0.9570, -0.5742, -0.1230],
         ...,
         [-1.0391,  0.9609, -1.2891,  ...,  1.7422,  1.3750, -1.1016],
         [-0.5234,  0.6836, -0.7266,  ...,  0.2891, -1.3828,  1.3281],
         [-0.7422,  0.3984, -0.5547,  ...,  1.7734,  0.4082, -4.3750]]],
       dtype=torch.bfloat16), tensor([[[-9.8267e-03,  3.4027e-03, -1.1963e-02,  ...,  4.8523e-03,
          -4.0894e-03, -9.4604e-03],
         [-1.2988e-01,  6.7578e-01, -2.5586e-01,  ..., -2.1191e-01,
          -2.3828e-01,  6.7188e-01],
         [ 3.7109e-01, -2.2363e-01, -2.2559e-01,  ...,  5.8594e-01,
           6.9531e-01, -2.0117e-01],
         ...,
         [ 5.6641e-02,  1.3867e-01,  7.0312e-01,  ...,  5.5469e-01,
           2.1191e-01,  1.2793e-01],
         [ 8.5938e-01, -5.3906e-01,  6.1328e-01,  ..., -4.2383e-01,
          -2.2168e-01,  5.3516e-01],
         [-1.4160e-01,  1.6357e-02, -2.3633e-01,  ...,  4.9316e-02,
           2.5391e-01, -2.0410e-01]],

        [[-9.8267e-03,  3.4027e-03, -1.1963e-02,  ...,  4.8523e-03,
          -4.0894e-03, -9.4604e-03],
         [-1.2988e-01,  6.7578e-01, -2.5586e-01,  ..., -2.1191e-01,
          -2.3828e-01,  6.7188e-01],
         [ 3.7109e-01, -2.2363e-01, -2.2559e-01,  ...,  5.8594e-01,
           6.9531e-01, -2.0117e-01],
         ...,
         [ 5.6641e-02,  1.3867e-01,  7.0312e-01,  ...,  5.5469e-01,
           2.1191e-01,  1.2793e-01],
         [ 8.5938e-01, -5.3906e-01,  6.1328e-01,  ..., -4.2383e-01,
          -2.2168e-01,  5.3516e-01],
         [-1.4160e-01,  1.6357e-02, -2.3633e-01,  ...,  4.9316e-02,
           2.5391e-01, -2.0410e-01]],

        [[-9.8267e-03,  3.4027e-03, -1.1963e-02,  ...,  4.8523e-03,
          -4.0894e-03, -9.4604e-03],
         [-1.2988e-01,  6.7578e-01, -2.5586e-01,  ..., -2.1191e-01,
          -2.3828e-01,  6.7188e-01],
         [ 3.7109e-01, -2.2363e-01, -2.2559e-01,  ...,  5.8594e-01,
           6.9531e-01, -2.0117e-01],
         ...,
         [ 5.6641e-02,  1.3867e-01,  7.0312e-01,  ...,  5.5469e-01,
           2.1191e-01,  1.2793e-01],
         [ 8.5938e-01, -5.3906e-01,  6.1328e-01,  ..., -4.2383e-01,
          -2.2168e-01,  5.3516e-01],
         [-1.4160e-01,  1.6357e-02, -2.3633e-01,  ...,  4.9316e-02,
           2.5391e-01, -2.0410e-01]],

        ...,

        [[-4.5967e-04, -4.8523e-03, -3.3447e-02,  ...,  6.7139e-03,
           7.4768e-03,  5.1880e-03],
         [-1.9629e-01, -2.9541e-02,  9.2163e-03,  ..., -6.4844e-01,
          -5.3516e-01,  3.8574e-02],
         [ 4.1602e-01, -6.7578e-01,  1.9531e-01,  ..., -4.4531e-01,
           3.5156e-02,  4.7070e-01],
         ...,
         [ 4.1602e-01,  5.5859e-01,  8.3923e-04,  ..., -2.9297e-01,
           7.3730e-02,  2.3340e-01],
         [ 2.1973e-01,  5.0391e-01,  4.7070e-01,  ..., -2.5391e-01,
          -7.0312e-02,  3.0859e-01],
         [-1.0547e-01, -6.4453e-01, -1.1670e-01,  ..., -1.3281e-01,
           3.1738e-02, -3.0469e-01]],

        [[-4.5967e-04, -4.8523e-03, -3.3447e-02,  ...,  6.7139e-03,
           7.4768e-03,  5.1880e-03],
         [-1.9629e-01, -2.9541e-02,  9.2163e-03,  ..., -6.4844e-01,
          -5.3516e-01,  3.8574e-02],
         [ 4.1602e-01, -6.7578e-01,  1.9531e-01,  ..., -4.4531e-01,
           3.5156e-02,  4.7070e-01],
         ...,
         [ 4.1602e-01,  5.5859e-01,  8.3923e-04,  ..., -2.9297e-01,
           7.3730e-02,  2.3340e-01],
         [ 2.1973e-01,  5.0391e-01,  4.7070e-01,  ..., -2.5391e-01,
          -7.0312e-02,  3.0859e-01],
         [-1.0547e-01, -6.4453e-01, -1.1670e-01,  ..., -1.3281e-01,
           3.1738e-02, -3.0469e-01]],

        [[-4.5967e-04, -4.8523e-03, -3.3447e-02,  ...,  6.7139e-03,
           7.4768e-03,  5.1880e-03],
         [-1.9629e-01, -2.9541e-02,  9.2163e-03,  ..., -6.4844e-01,
          -5.3516e-01,  3.8574e-02],
         [ 4.1602e-01, -6.7578e-01,  1.9531e-01,  ..., -4.4531e-01,
           3.5156e-02,  4.7070e-01],
         ...,
         [ 4.1602e-01,  5.5859e-01,  8.3923e-04,  ..., -2.9297e-01,
           7.3730e-02,  2.3340e-01],
         [ 2.1973e-01,  5.0391e-01,  4.7070e-01,  ..., -2.5391e-01,
          -7.0312e-02,  3.0859e-01],
         [-1.0547e-01, -6.4453e-01, -1.1670e-01,  ..., -1.3281e-01,
           3.1738e-02, -3.0469e-01]]], dtype=torch.bfloat16)), (tensor([[[-5.8289e-03, -5.2185e-03, -3.2043e-03,  ...,  1.9141e-01,
          -9.6191e-02, -6.4062e-01],
         [-3.3750e+00, -1.9297e+00,  2.3828e-01,  ..., -1.8359e-01,
           7.5391e-01,  2.9375e+00],
         [ 5.9375e-01, -2.5312e+00, -1.2266e+00,  ..., -3.8574e-02,
           8.2812e-01,  1.7812e+00],
         ...,
         [ 4.4375e+00,  2.1875e+00,  6.8359e-01,  ..., -1.2500e+00,
          -4.7070e-01,  2.8438e+00],
         [ 2.8906e+00,  1.8906e+00, -2.6367e-01,  ..., -2.2188e+00,
          -2.5977e-01,  3.6562e+00],
         [-2.9688e-01,  1.1016e+00, -2.8906e-01,  ...,  7.5781e-01,
          -2.6953e-01,  2.0938e+00]],

        [[-5.8289e-03, -5.2185e-03, -3.2043e-03,  ...,  1.9141e-01,
          -9.6191e-02, -6.4062e-01],
         [-3.3750e+00, -1.9297e+00,  2.3828e-01,  ..., -1.8359e-01,
           7.5391e-01,  2.9375e+00],
         [ 5.9375e-01, -2.5312e+00, -1.2266e+00,  ..., -3.8574e-02,
           8.2812e-01,  1.7812e+00],
         ...,
         [ 4.4375e+00,  2.1875e+00,  6.8359e-01,  ..., -1.2500e+00,
          -4.7070e-01,  2.8438e+00],
         [ 2.8906e+00,  1.8906e+00, -2.6367e-01,  ..., -2.2188e+00,
          -2.5977e-01,  3.6562e+00],
         [-2.9688e-01,  1.1016e+00, -2.8906e-01,  ...,  7.5781e-01,
          -2.6953e-01,  2.0938e+00]],

        [[-5.8289e-03, -5.2185e-03, -3.2043e-03,  ...,  1.9141e-01,
          -9.6191e-02, -6.4062e-01],
         [-3.3750e+00, -1.9297e+00,  2.3828e-01,  ..., -1.8359e-01,
           7.5391e-01,  2.9375e+00],
         [ 5.9375e-01, -2.5312e+00, -1.2266e+00,  ..., -3.8574e-02,
           8.2812e-01,  1.7812e+00],
         ...,
         [ 4.4375e+00,  2.1875e+00,  6.8359e-01,  ..., -1.2500e+00,
          -4.7070e-01,  2.8438e+00],
         [ 2.8906e+00,  1.8906e+00, -2.6367e-01,  ..., -2.2188e+00,
          -2.5977e-01,  3.6562e+00],
         [-2.9688e-01,  1.1016e+00, -2.8906e-01,  ...,  7.5781e-01,
          -2.6953e-01,  2.0938e+00]],

        ...,

        [[-2.0599e-03, -3.8338e-04,  5.0049e-03,  ...,  1.0437e-02,
           1.6235e-02,  2.6758e-01],
         [-6.3672e-01,  5.5469e-01, -4.7656e-01,  ..., -2.0625e+00,
           9.9609e-01, -1.2969e+00],
         [ 1.2266e+00,  1.9531e+00,  3.3984e-01,  ..., -7.0703e-01,
          -2.5781e+00,  2.0605e-01],
         ...,
         [ 2.7656e+00, -1.6562e+00, -4.3555e-01,  ...,  5.7031e-01,
           2.2266e-01, -4.2969e-01],
         [ 4.8340e-02, -6.9531e-01, -6.7969e-01,  ...,  1.0234e+00,
          -2.1719e+00, -1.3828e+00],
         [-1.5938e+00, -1.5469e+00,  3.7109e-01,  ...,  1.5469e+00,
          -1.0234e+00, -6.2891e-01]],

        [[-2.0599e-03, -3.8338e-04,  5.0049e-03,  ...,  1.0437e-02,
           1.6235e-02,  2.6758e-01],
         [-6.3672e-01,  5.5469e-01, -4.7656e-01,  ..., -2.0625e+00,
           9.9609e-01, -1.2969e+00],
         [ 1.2266e+00,  1.9531e+00,  3.3984e-01,  ..., -7.0703e-01,
          -2.5781e+00,  2.0605e-01],
         ...,
         [ 2.7656e+00, -1.6562e+00, -4.3555e-01,  ...,  5.7031e-01,
           2.2266e-01, -4.2969e-01],
         [ 4.8340e-02, -6.9531e-01, -6.7969e-01,  ...,  1.0234e+00,
          -2.1719e+00, -1.3828e+00],
         [-1.5938e+00, -1.5469e+00,  3.7109e-01,  ...,  1.5469e+00,
          -1.0234e+00, -6.2891e-01]],

        [[-2.0599e-03, -3.8338e-04,  5.0049e-03,  ...,  1.0437e-02,
           1.6235e-02,  2.6758e-01],
         [-6.3672e-01,  5.5469e-01, -4.7656e-01,  ..., -2.0625e+00,
           9.9609e-01, -1.2969e+00],
         [ 1.2266e+00,  1.9531e+00,  3.3984e-01,  ..., -7.0703e-01,
          -2.5781e+00,  2.0605e-01],
         ...,
         [ 2.7656e+00, -1.6562e+00, -4.3555e-01,  ...,  5.7031e-01,
           2.2266e-01, -4.2969e-01],
         [ 4.8340e-02, -6.9531e-01, -6.7969e-01,  ...,  1.0234e+00,
          -2.1719e+00, -1.3828e+00],
         [-1.5938e+00, -1.5469e+00,  3.7109e-01,  ...,  1.5469e+00,
          -1.0234e+00, -6.2891e-01]]], dtype=torch.bfloat16), tensor([[[-6.6833e-03,  9.1553e-03, -9.5215e-03,  ...,  9.5825e-03,
          -1.5747e-02,  8.4229e-03],
         [ 1.8164e-01,  5.7861e-02,  8.6719e-01,  ..., -2.4414e-01,
           3.2812e-01,  5.7422e-01],
         [ 1.6113e-01, -7.6953e-01,  4.9414e-01,  ..., -5.0391e-01,
          -2.1191e-01, -4.2578e-01],
         ...,
         [ 2.1191e-01, -1.1182e-01,  1.4160e-02,  ..., -2.5781e-01,
           6.4062e-01,  5.7422e-01],
         [ 1.1621e-01, -3.2031e-01, -2.4048e-02,  ..., -3.6133e-01,
           7.3730e-02, -2.4219e-01],
         [-2.8320e-01,  2.3047e-01,  2.9883e-01,  ..., -1.0254e-01,
           1.3594e+00, -2.2559e-01]],

        [[-6.6833e-03,  9.1553e-03, -9.5215e-03,  ...,  9.5825e-03,
          -1.5747e-02,  8.4229e-03],
         [ 1.8164e-01,  5.7861e-02,  8.6719e-01,  ..., -2.4414e-01,
           3.2812e-01,  5.7422e-01],
         [ 1.6113e-01, -7.6953e-01,  4.9414e-01,  ..., -5.0391e-01,
          -2.1191e-01, -4.2578e-01],
         ...,
         [ 2.1191e-01, -1.1182e-01,  1.4160e-02,  ..., -2.5781e-01,
           6.4062e-01,  5.7422e-01],
         [ 1.1621e-01, -3.2031e-01, -2.4048e-02,  ..., -3.6133e-01,
           7.3730e-02, -2.4219e-01],
         [-2.8320e-01,  2.3047e-01,  2.9883e-01,  ..., -1.0254e-01,
           1.3594e+00, -2.2559e-01]],

        [[-6.6833e-03,  9.1553e-03, -9.5215e-03,  ...,  9.5825e-03,
          -1.5747e-02,  8.4229e-03],
         [ 1.8164e-01,  5.7861e-02,  8.6719e-01,  ..., -2.4414e-01,
           3.2812e-01,  5.7422e-01],
         [ 1.6113e-01, -7.6953e-01,  4.9414e-01,  ..., -5.0391e-01,
          -2.1191e-01, -4.2578e-01],
         ...,
         [ 2.1191e-01, -1.1182e-01,  1.4160e-02,  ..., -2.5781e-01,
           6.4062e-01,  5.7422e-01],
         [ 1.1621e-01, -3.2031e-01, -2.4048e-02,  ..., -3.6133e-01,
           7.3730e-02, -2.4219e-01],
         [-2.8320e-01,  2.3047e-01,  2.9883e-01,  ..., -1.0254e-01,
           1.3594e+00, -2.2559e-01]],

        ...,

        [[ 7.4463e-03,  4.1504e-03, -7.1716e-03,  ...,  4.1199e-03,
           1.5625e-02, -1.1841e-02],
         [ 2.1289e-01,  2.9102e-01, -2.2363e-01,  ...,  1.0010e-01,
           2.1973e-01,  3.2812e-01],
         [-1.3672e-01,  5.5859e-01,  4.2534e-04,  ..., -3.8672e-01,
          -6.1328e-01, -5.4688e-01],
         ...,
         [ 1.4355e-01,  3.9844e-01, -5.7031e-01,  ..., -7.0703e-01,
          -1.1230e-01,  2.7930e-01],
         [-2.4512e-01, -9.1797e-02, -1.2402e-01,  ..., -5.3516e-01,
          -8.7402e-02,  2.3926e-01],
         [ 3.9844e-01,  9.9219e-01,  7.5195e-02,  ...,  3.6914e-01,
          -2.3535e-01, -8.0078e-02]],

        [[ 7.4463e-03,  4.1504e-03, -7.1716e-03,  ...,  4.1199e-03,
           1.5625e-02, -1.1841e-02],
         [ 2.1289e-01,  2.9102e-01, -2.2363e-01,  ...,  1.0010e-01,
           2.1973e-01,  3.2812e-01],
         [-1.3672e-01,  5.5859e-01,  4.2534e-04,  ..., -3.8672e-01,
          -6.1328e-01, -5.4688e-01],
         ...,
         [ 1.4355e-01,  3.9844e-01, -5.7031e-01,  ..., -7.0703e-01,
          -1.1230e-01,  2.7930e-01],
         [-2.4512e-01, -9.1797e-02, -1.2402e-01,  ..., -5.3516e-01,
          -8.7402e-02,  2.3926e-01],
         [ 3.9844e-01,  9.9219e-01,  7.5195e-02,  ...,  3.6914e-01,
          -2.3535e-01, -8.0078e-02]],

        [[ 7.4463e-03,  4.1504e-03, -7.1716e-03,  ...,  4.1199e-03,
           1.5625e-02, -1.1841e-02],
         [ 2.1289e-01,  2.9102e-01, -2.2363e-01,  ...,  1.0010e-01,
           2.1973e-01,  3.2812e-01],
         [-1.3672e-01,  5.5859e-01,  4.2534e-04,  ..., -3.8672e-01,
          -6.1328e-01, -5.4688e-01],
         ...,
         [ 1.4355e-01,  3.9844e-01, -5.7031e-01,  ..., -7.0703e-01,
          -1.1230e-01,  2.7930e-01],
         [-2.4512e-01, -9.1797e-02, -1.2402e-01,  ..., -5.3516e-01,
          -8.7402e-02,  2.3926e-01],
         [ 3.9844e-01,  9.9219e-01,  7.5195e-02,  ...,  3.6914e-01,
          -2.3535e-01, -8.0078e-02]]], dtype=torch.bfloat16)), (tensor([[[ 0.0128,  0.0156, -0.0175,  ...,  0.0996, -0.0767, -0.1533],
         [-1.1406, -0.2266, -0.6562,  ...,  1.5859,  0.4961,  0.7188],
         [-1.5938,  0.2578,  0.3047,  ...,  0.1641,  1.9609, -1.6484],
         ...,
         [ 0.5195, -0.9219, -0.9688,  ...,  0.1719,  1.0391, -2.8281],
         [ 0.2344, -0.1104,  0.5625,  ..., -2.2656,  2.6875,  1.8594],
         [ 2.0938,  1.2656,  0.6406,  ..., -0.5234, -0.3047, -0.0488]],

        [[ 0.0128,  0.0156, -0.0175,  ...,  0.0996, -0.0767, -0.1533],
         [-1.1406, -0.2266, -0.6562,  ...,  1.5859,  0.4961,  0.7188],
         [-1.5938,  0.2578,  0.3047,  ...,  0.1641,  1.9609, -1.6484],
         ...,
         [ 0.5195, -0.9219, -0.9688,  ...,  0.1719,  1.0391, -2.8281],
         [ 0.2344, -0.1104,  0.5625,  ..., -2.2656,  2.6875,  1.8594],
         [ 2.0938,  1.2656,  0.6406,  ..., -0.5234, -0.3047, -0.0488]],

        [[ 0.0128,  0.0156, -0.0175,  ...,  0.0996, -0.0767, -0.1533],
         [-1.1406, -0.2266, -0.6562,  ...,  1.5859,  0.4961,  0.7188],
         [-1.5938,  0.2578,  0.3047,  ...,  0.1641,  1.9609, -1.6484],
         ...,
         [ 0.5195, -0.9219, -0.9688,  ...,  0.1719,  1.0391, -2.8281],
         [ 0.2344, -0.1104,  0.5625,  ..., -2.2656,  2.6875,  1.8594],
         [ 2.0938,  1.2656,  0.6406,  ..., -0.5234, -0.3047, -0.0488]],

        ...,

        [[-0.0070, -0.0035, -0.0055,  ..., -0.0588, -0.0309,  0.0552],
         [ 1.9453,  1.6719,  0.0781,  ...,  1.7734, -0.5156,  1.0938],
         [ 2.8125,  1.0000,  0.5859,  ...,  1.2578,  0.5547,  1.7344],
         ...,
         [-0.3086, -1.2266, -0.3398,  ...,  0.7422,  0.0845, -0.8867],
         [-0.1670, -1.1250, -0.3242,  ...,  0.3008,  1.4531, -0.9492],
         [-1.3906, -0.5273, -1.0625,  ..., -0.7695, -0.3340, -2.1250]],

        [[-0.0070, -0.0035, -0.0055,  ..., -0.0588, -0.0309,  0.0552],
         [ 1.9453,  1.6719,  0.0781,  ...,  1.7734, -0.5156,  1.0938],
         [ 2.8125,  1.0000,  0.5859,  ...,  1.2578,  0.5547,  1.7344],
         ...,
         [-0.3086, -1.2266, -0.3398,  ...,  0.7422,  0.0845, -0.8867],
         [-0.1670, -1.1250, -0.3242,  ...,  0.3008,  1.4531, -0.9492],
         [-1.3906, -0.5273, -1.0625,  ..., -0.7695, -0.3340, -2.1250]],

        [[-0.0070, -0.0035, -0.0055,  ..., -0.0588, -0.0309,  0.0552],
         [ 1.9453,  1.6719,  0.0781,  ...,  1.7734, -0.5156,  1.0938],
         [ 2.8125,  1.0000,  0.5859,  ...,  1.2578,  0.5547,  1.7344],
         ...,
         [-0.3086, -1.2266, -0.3398,  ...,  0.7422,  0.0845, -0.8867],
         [-0.1670, -1.1250, -0.3242,  ...,  0.3008,  1.4531, -0.9492],
         [-1.3906, -0.5273, -1.0625,  ..., -0.7695, -0.3340, -2.1250]]],
       dtype=torch.bfloat16), tensor([[[-0.0125,  0.0048, -0.0069,  ...,  0.0050, -0.0082, -0.0070],
         [ 0.1729, -0.1914,  0.0452,  ...,  0.5156,  0.0698,  0.3691],
         [-1.0156,  0.7773,  0.3066,  ..., -0.1240,  0.8320,  0.8203],
         ...,
         [-0.4062,  0.7812, -0.6992,  ...,  0.2148, -0.0635, -0.7734],
         [-0.4746, -0.3242,  0.3184,  ...,  0.2715, -0.0131, -0.3809],
         [ 0.0299,  0.4707, -0.2832,  ...,  0.1592, -0.0796, -0.3047]],

        [[-0.0125,  0.0048, -0.0069,  ...,  0.0050, -0.0082, -0.0070],
         [ 0.1729, -0.1914,  0.0452,  ...,  0.5156,  0.0698,  0.3691],
         [-1.0156,  0.7773,  0.3066,  ..., -0.1240,  0.8320,  0.8203],
         ...,
         [-0.4062,  0.7812, -0.6992,  ...,  0.2148, -0.0635, -0.7734],
         [-0.4746, -0.3242,  0.3184,  ...,  0.2715, -0.0131, -0.3809],
         [ 0.0299,  0.4707, -0.2832,  ...,  0.1592, -0.0796, -0.3047]],

        [[-0.0125,  0.0048, -0.0069,  ...,  0.0050, -0.0082, -0.0070],
         [ 0.1729, -0.1914,  0.0452,  ...,  0.5156,  0.0698,  0.3691],
         [-1.0156,  0.7773,  0.3066,  ..., -0.1240,  0.8320,  0.8203],
         ...,
         [-0.4062,  0.7812, -0.6992,  ...,  0.2148, -0.0635, -0.7734],
         [-0.4746, -0.3242,  0.3184,  ...,  0.2715, -0.0131, -0.3809],
         [ 0.0299,  0.4707, -0.2832,  ...,  0.1592, -0.0796, -0.3047]],

        ...,

        [[-0.0022, -0.0028, -0.0028,  ..., -0.0060,  0.0087, -0.0074],
         [ 0.0129,  0.1001,  0.2002,  ..., -0.2734,  0.4023,  0.2168],
         [ 0.3848,  0.5156,  0.6914,  ..., -0.5273,  0.2832,  0.9570],
         ...,
         [ 0.2891,  0.0508, -0.5742,  ...,  0.3242, -0.5742, -0.2012],
         [-0.1445,  0.5586, -0.2139,  ..., -0.3965, -0.1055,  0.3711],
         [ 0.3828,  0.5156,  0.1934,  ...,  0.0635,  0.2578,  0.0045]],

        [[-0.0022, -0.0028, -0.0028,  ..., -0.0060,  0.0087, -0.0074],
         [ 0.0129,  0.1001,  0.2002,  ..., -0.2734,  0.4023,  0.2168],
         [ 0.3848,  0.5156,  0.6914,  ..., -0.5273,  0.2832,  0.9570],
         ...,
         [ 0.2891,  0.0508, -0.5742,  ...,  0.3242, -0.5742, -0.2012],
         [-0.1445,  0.5586, -0.2139,  ..., -0.3965, -0.1055,  0.3711],
         [ 0.3828,  0.5156,  0.1934,  ...,  0.0635,  0.2578,  0.0045]],

        [[-0.0022, -0.0028, -0.0028,  ..., -0.0060,  0.0087, -0.0074],
         [ 0.0129,  0.1001,  0.2002,  ..., -0.2734,  0.4023,  0.2168],
         [ 0.3848,  0.5156,  0.6914,  ..., -0.5273,  0.2832,  0.9570],
         ...,
         [ 0.2891,  0.0508, -0.5742,  ...,  0.3242, -0.5742, -0.2012],
         [-0.1445,  0.5586, -0.2139,  ..., -0.3965, -0.1055,  0.3711],
         [ 0.3828,  0.5156,  0.1934,  ...,  0.0635,  0.2578,  0.0045]]],
       dtype=torch.bfloat16)), (tensor([[[-1.5015e-02, -4.2725e-03, -8.5831e-04,  ..., -5.9375e-01,
          -1.5723e-01,  2.8076e-02],
         [-1.6641e+00,  2.8125e-01, -1.2188e+00,  ...,  2.2031e+00,
           6.6797e-01,  4.2188e-01],
         [-1.6406e-01,  1.0391e+00, -1.2695e-01,  ..., -1.5442e-02,
           2.4531e+00,  1.2109e+00],
         ...,
         [ 1.2656e+00, -9.4531e-01, -1.6562e+00,  ...,  1.2891e+00,
           2.1562e+00, -1.0234e+00],
         [ 5.5664e-02, -4.8828e-01, -8.7500e-01,  ...,  9.2578e-01,
           1.2500e+00, -6.9336e-02],
         [ 5.7812e-01, -1.1406e+00, -1.0000e+00,  ...,  2.0156e+00,
           3.9258e-01,  6.5625e-01]],

        [[-1.5015e-02, -4.2725e-03, -8.5831e-04,  ..., -5.9375e-01,
          -1.5723e-01,  2.8076e-02],
         [-1.6641e+00,  2.8125e-01, -1.2188e+00,  ...,  2.2031e+00,
           6.6797e-01,  4.2188e-01],
         [-1.6406e-01,  1.0391e+00, -1.2695e-01,  ..., -1.5442e-02,
           2.4531e+00,  1.2109e+00],
         ...,
         [ 1.2656e+00, -9.4531e-01, -1.6562e+00,  ...,  1.2891e+00,
           2.1562e+00, -1.0234e+00],
         [ 5.5664e-02, -4.8828e-01, -8.7500e-01,  ...,  9.2578e-01,
           1.2500e+00, -6.9336e-02],
         [ 5.7812e-01, -1.1406e+00, -1.0000e+00,  ...,  2.0156e+00,
           3.9258e-01,  6.5625e-01]],

        [[-1.5015e-02, -4.2725e-03, -8.5831e-04,  ..., -5.9375e-01,
          -1.5723e-01,  2.8076e-02],
         [-1.6641e+00,  2.8125e-01, -1.2188e+00,  ...,  2.2031e+00,
           6.6797e-01,  4.2188e-01],
         [-1.6406e-01,  1.0391e+00, -1.2695e-01,  ..., -1.5442e-02,
           2.4531e+00,  1.2109e+00],
         ...,
         [ 1.2656e+00, -9.4531e-01, -1.6562e+00,  ...,  1.2891e+00,
           2.1562e+00, -1.0234e+00],
         [ 5.5664e-02, -4.8828e-01, -8.7500e-01,  ...,  9.2578e-01,
           1.2500e+00, -6.9336e-02],
         [ 5.7812e-01, -1.1406e+00, -1.0000e+00,  ...,  2.0156e+00,
           3.9258e-01,  6.5625e-01]],

        ...,

        [[ 3.1433e-03,  8.4839e-03,  4.5967e-04,  ...,  1.1914e-01,
           5.0781e-02, -9.8267e-03],
         [ 1.4688e+00, -8.5547e-01, -1.2734e+00,  ..., -4.9805e-01,
          -4.1504e-02,  1.1016e+00],
         [-1.2422e+00, -9.2969e-01, -9.9219e-01,  ...,  5.4297e-01,
          -1.6328e+00,  5.5078e-01],
         ...,
         [-2.5156e+00,  3.3398e-01, -1.7422e+00,  ...,  3.5156e-01,
           3.1641e-01,  1.6797e+00],
         [ 2.0508e-01,  9.4531e-01, -4.4531e-01,  ...,  1.0859e+00,
          -6.9531e-01,  2.2031e+00],
         [-1.0547e-01,  7.7344e-01, -8.3203e-01,  ..., -4.9023e-01,
          -5.7373e-02,  1.0234e+00]],

        [[ 3.1433e-03,  8.4839e-03,  4.5967e-04,  ...,  1.1914e-01,
           5.0781e-02, -9.8267e-03],
         [ 1.4688e+00, -8.5547e-01, -1.2734e+00,  ..., -4.9805e-01,
          -4.1504e-02,  1.1016e+00],
         [-1.2422e+00, -9.2969e-01, -9.9219e-01,  ...,  5.4297e-01,
          -1.6328e+00,  5.5078e-01],
         ...,
         [-2.5156e+00,  3.3398e-01, -1.7422e+00,  ...,  3.5156e-01,
           3.1641e-01,  1.6797e+00],
         [ 2.0508e-01,  9.4531e-01, -4.4531e-01,  ...,  1.0859e+00,
          -6.9531e-01,  2.2031e+00],
         [-1.0547e-01,  7.7344e-01, -8.3203e-01,  ..., -4.9023e-01,
          -5.7373e-02,  1.0234e+00]],

        [[ 3.1433e-03,  8.4839e-03,  4.5967e-04,  ...,  1.1914e-01,
           5.0781e-02, -9.8267e-03],
         [ 1.4688e+00, -8.5547e-01, -1.2734e+00,  ..., -4.9805e-01,
          -4.1504e-02,  1.1016e+00],
         [-1.2422e+00, -9.2969e-01, -9.9219e-01,  ...,  5.4297e-01,
          -1.6328e+00,  5.5078e-01],
         ...,
         [-2.5156e+00,  3.3398e-01, -1.7422e+00,  ...,  3.5156e-01,
           3.1641e-01,  1.6797e+00],
         [ 2.0508e-01,  9.4531e-01, -4.4531e-01,  ...,  1.0859e+00,
          -6.9531e-01,  2.2031e+00],
         [-1.0547e-01,  7.7344e-01, -8.3203e-01,  ..., -4.9023e-01,
          -5.7373e-02,  1.0234e+00]]], dtype=torch.bfloat16), tensor([[[-2.3651e-03, -5.9204e-03,  4.3030e-03,  ...,  5.5313e-04,
           9.9182e-04, -5.1880e-03],
         [ 4.3750e-01,  5.5859e-01,  6.6016e-01,  ...,  1.4160e-01,
          -8.3594e-01, -5.1514e-02],
         [ 5.1953e-01, -4.2725e-02,  1.4062e-01,  ..., -4.8828e-01,
           2.5586e-01,  2.6367e-01],
         ...,
         [-3.1641e-01,  2.7734e-01, -3.7354e-02,  ...,  1.3281e-01,
          -4.4336e-01,  1.0156e+00],
         [ 1.6895e-01,  3.0469e-01, -4.8633e-01,  ..., -3.6914e-01,
           8.3008e-02,  7.0312e-01],
         [-8.0566e-02,  7.8516e-01, -3.4668e-02,  ..., -4.6143e-02,
          -8.0469e-01,  1.6504e-01]],

        [[-2.3651e-03, -5.9204e-03,  4.3030e-03,  ...,  5.5313e-04,
           9.9182e-04, -5.1880e-03],
         [ 4.3750e-01,  5.5859e-01,  6.6016e-01,  ...,  1.4160e-01,
          -8.3594e-01, -5.1514e-02],
         [ 5.1953e-01, -4.2725e-02,  1.4062e-01,  ..., -4.8828e-01,
           2.5586e-01,  2.6367e-01],
         ...,
         [-3.1641e-01,  2.7734e-01, -3.7354e-02,  ...,  1.3281e-01,
          -4.4336e-01,  1.0156e+00],
         [ 1.6895e-01,  3.0469e-01, -4.8633e-01,  ..., -3.6914e-01,
           8.3008e-02,  7.0312e-01],
         [-8.0566e-02,  7.8516e-01, -3.4668e-02,  ..., -4.6143e-02,
          -8.0469e-01,  1.6504e-01]],

        [[-2.3651e-03, -5.9204e-03,  4.3030e-03,  ...,  5.5313e-04,
           9.9182e-04, -5.1880e-03],
         [ 4.3750e-01,  5.5859e-01,  6.6016e-01,  ...,  1.4160e-01,
          -8.3594e-01, -5.1514e-02],
         [ 5.1953e-01, -4.2725e-02,  1.4062e-01,  ..., -4.8828e-01,
           2.5586e-01,  2.6367e-01],
         ...,
         [-3.1641e-01,  2.7734e-01, -3.7354e-02,  ...,  1.3281e-01,
          -4.4336e-01,  1.0156e+00],
         [ 1.6895e-01,  3.0469e-01, -4.8633e-01,  ..., -3.6914e-01,
           8.3008e-02,  7.0312e-01],
         [-8.0566e-02,  7.8516e-01, -3.4668e-02,  ..., -4.6143e-02,
          -8.0469e-01,  1.6504e-01]],

        ...,

        [[-4.8065e-04,  9.7275e-04,  5.5176e-02,  ..., -3.0518e-02,
          -1.1108e-02,  2.5635e-02],
         [-2.1680e-01, -1.3750e+00, -3.3984e-01,  ...,  4.8438e-01,
          -1.5039e-01,  7.4219e-01],
         [-1.3965e-01, -8.9453e-01, -3.2617e-01,  ...,  3.6914e-01,
          -7.4609e-01,  1.6699e-01],
         ...,
         [-8.6328e-01, -1.7383e-01,  1.3965e-01,  ...,  1.3125e+00,
          -2.1289e-01,  1.3672e+00],
         [-9.9219e-01,  2.3633e-01,  2.4609e-01,  ...,  6.0156e-01,
          -4.3750e-01,  4.5117e-01],
         [ 1.0234e+00, -7.7344e-01, -5.0781e-01,  ..., -1.6504e-01,
           1.6309e-01, -5.8203e-01]],

        [[-4.8065e-04,  9.7275e-04,  5.5176e-02,  ..., -3.0518e-02,
          -1.1108e-02,  2.5635e-02],
         [-2.1680e-01, -1.3750e+00, -3.3984e-01,  ...,  4.8438e-01,
          -1.5039e-01,  7.4219e-01],
         [-1.3965e-01, -8.9453e-01, -3.2617e-01,  ...,  3.6914e-01,
          -7.4609e-01,  1.6699e-01],
         ...,
         [-8.6328e-01, -1.7383e-01,  1.3965e-01,  ...,  1.3125e+00,
          -2.1289e-01,  1.3672e+00],
         [-9.9219e-01,  2.3633e-01,  2.4609e-01,  ...,  6.0156e-01,
          -4.3750e-01,  4.5117e-01],
         [ 1.0234e+00, -7.7344e-01, -5.0781e-01,  ..., -1.6504e-01,
           1.6309e-01, -5.8203e-01]],

        [[-4.8065e-04,  9.7275e-04,  5.5176e-02,  ..., -3.0518e-02,
          -1.1108e-02,  2.5635e-02],
         [-2.1680e-01, -1.3750e+00, -3.3984e-01,  ...,  4.8438e-01,
          -1.5039e-01,  7.4219e-01],
         [-1.3965e-01, -8.9453e-01, -3.2617e-01,  ...,  3.6914e-01,
          -7.4609e-01,  1.6699e-01],
         ...,
         [-8.6328e-01, -1.7383e-01,  1.3965e-01,  ...,  1.3125e+00,
          -2.1289e-01,  1.3672e+00],
         [-9.9219e-01,  2.3633e-01,  2.4609e-01,  ...,  6.0156e-01,
          -4.3750e-01,  4.5117e-01],
         [ 1.0234e+00, -7.7344e-01, -5.0781e-01,  ..., -1.6504e-01,
           1.6309e-01, -5.8203e-01]]], dtype=torch.bfloat16)), (tensor([[[ 1.1780e-02, -3.5858e-03, -3.2196e-03,  ..., -3.6328e-01,
          -1.3733e-02, -4.6875e-02],
         [-1.5234e-01,  1.4062e-01, -9.8438e-01,  ..., -8.7500e-01,
           1.2969e+00,  1.2969e+00],
         [ 8.9062e-01, -1.4062e+00, -6.4062e-01,  ...,  1.8828e+00,
           2.2656e+00,  1.3438e+00],
         ...,
         [-3.5400e-02,  8.9844e-01, -4.5312e-01,  ..., -8.9844e-02,
          -9.1016e-01,  2.4780e-02],
         [-6.1328e-01,  4.3359e-01, -4.3750e-01,  ...,  2.2344e+00,
          -2.4375e+00, -1.1562e+00],
         [-1.3047e+00,  1.6484e+00, -1.5312e+00,  ...,  1.7266e+00,
           2.4375e+00,  2.0625e+00]],

        [[ 1.1780e-02, -3.5858e-03, -3.2196e-03,  ..., -3.6328e-01,
          -1.3733e-02, -4.6875e-02],
         [-1.5234e-01,  1.4062e-01, -9.8438e-01,  ..., -8.7500e-01,
           1.2969e+00,  1.2969e+00],
         [ 8.9062e-01, -1.4062e+00, -6.4062e-01,  ...,  1.8828e+00,
           2.2656e+00,  1.3438e+00],
         ...,
         [-3.5400e-02,  8.9844e-01, -4.5312e-01,  ..., -8.9844e-02,
          -9.1016e-01,  2.4780e-02],
         [-6.1328e-01,  4.3359e-01, -4.3750e-01,  ...,  2.2344e+00,
          -2.4375e+00, -1.1562e+00],
         [-1.3047e+00,  1.6484e+00, -1.5312e+00,  ...,  1.7266e+00,
           2.4375e+00,  2.0625e+00]],

        [[ 1.1780e-02, -3.5858e-03, -3.2196e-03,  ..., -3.6328e-01,
          -1.3733e-02, -4.6875e-02],
         [-1.5234e-01,  1.4062e-01, -9.8438e-01,  ..., -8.7500e-01,
           1.2969e+00,  1.2969e+00],
         [ 8.9062e-01, -1.4062e+00, -6.4062e-01,  ...,  1.8828e+00,
           2.2656e+00,  1.3438e+00],
         ...,
         [-3.5400e-02,  8.9844e-01, -4.5312e-01,  ..., -8.9844e-02,
          -9.1016e-01,  2.4780e-02],
         [-6.1328e-01,  4.3359e-01, -4.3750e-01,  ...,  2.2344e+00,
          -2.4375e+00, -1.1562e+00],
         [-1.3047e+00,  1.6484e+00, -1.5312e+00,  ...,  1.7266e+00,
           2.4375e+00,  2.0625e+00]],

        ...,

        [[-6.5613e-04,  1.4099e-02, -2.3956e-03,  ..., -2.5977e-01,
          -1.6406e-01, -4.7607e-03],
         [-1.0156e+00,  3.3203e-01,  2.2656e+00,  ...,  1.3516e+00,
          -9.7266e-01, -3.9844e-01],
         [-2.8750e+00,  5.5469e-01,  2.0625e+00,  ...,  1.0391e+00,
          -9.7168e-02, -1.7109e+00],
         ...,
         [-1.2695e-01,  1.6406e+00,  1.6016e+00,  ..., -1.0703e+00,
           7.0312e-01, -4.1797e-01],
         [ 2.7344e-01,  2.4414e-02,  3.6719e-01,  ..., -1.3906e+00,
          -7.1094e-01,  5.7812e-01],
         [ 1.9219e+00, -8.5938e-01,  1.3750e+00,  ..., -3.5352e-01,
           4.5898e-01, -1.3477e-01]],

        [[-6.5613e-04,  1.4099e-02, -2.3956e-03,  ..., -2.5977e-01,
          -1.6406e-01, -4.7607e-03],
         [-1.0156e+00,  3.3203e-01,  2.2656e+00,  ...,  1.3516e+00,
          -9.7266e-01, -3.9844e-01],
         [-2.8750e+00,  5.5469e-01,  2.0625e+00,  ...,  1.0391e+00,
          -9.7168e-02, -1.7109e+00],
         ...,
         [-1.2695e-01,  1.6406e+00,  1.6016e+00,  ..., -1.0703e+00,
           7.0312e-01, -4.1797e-01],
         [ 2.7344e-01,  2.4414e-02,  3.6719e-01,  ..., -1.3906e+00,
          -7.1094e-01,  5.7812e-01],
         [ 1.9219e+00, -8.5938e-01,  1.3750e+00,  ..., -3.5352e-01,
           4.5898e-01, -1.3477e-01]],

        [[-6.5613e-04,  1.4099e-02, -2.3956e-03,  ..., -2.5977e-01,
          -1.6406e-01, -4.7607e-03],
         [-1.0156e+00,  3.3203e-01,  2.2656e+00,  ...,  1.3516e+00,
          -9.7266e-01, -3.9844e-01],
         [-2.8750e+00,  5.5469e-01,  2.0625e+00,  ...,  1.0391e+00,
          -9.7168e-02, -1.7109e+00],
         ...,
         [-1.2695e-01,  1.6406e+00,  1.6016e+00,  ..., -1.0703e+00,
           7.0312e-01, -4.1797e-01],
         [ 2.7344e-01,  2.4414e-02,  3.6719e-01,  ..., -1.3906e+00,
          -7.1094e-01,  5.7812e-01],
         [ 1.9219e+00, -8.5938e-01,  1.3750e+00,  ..., -3.5352e-01,
           4.5898e-01, -1.3477e-01]]], dtype=torch.bfloat16), tensor([[[ 3.7842e-02, -7.5073e-03, -7.4768e-04,  ...,  8.5449e-02,
           3.0762e-02,  1.2695e-02],
         [-6.9141e-01,  4.3359e-01, -2.8320e-02,  ...,  3.4570e-01,
          -2.2949e-01, -1.6699e-01],
         [-4.7070e-01,  7.6172e-02, -4.6631e-02,  ...,  1.8750e-01,
          -2.4219e-01, -5.5859e-01],
         ...,
         [-1.9775e-02,  4.9609e-01,  4.1992e-01,  ..., -4.6289e-01,
          -6.9531e-01, -3.0469e-01],
         [-6.0156e-01,  3.9844e-01,  6.2891e-01,  ..., -2.0605e-01,
          -7.5000e-01,  3.0664e-01],
         [ 3.6133e-01,  5.5469e-01, -7.0801e-02,  ...,  3.1641e-01,
          -4.0820e-01, -9.3359e-01]],

        [[ 3.7842e-02, -7.5073e-03, -7.4768e-04,  ...,  8.5449e-02,
           3.0762e-02,  1.2695e-02],
         [-6.9141e-01,  4.3359e-01, -2.8320e-02,  ...,  3.4570e-01,
          -2.2949e-01, -1.6699e-01],
         [-4.7070e-01,  7.6172e-02, -4.6631e-02,  ...,  1.8750e-01,
          -2.4219e-01, -5.5859e-01],
         ...,
         [-1.9775e-02,  4.9609e-01,  4.1992e-01,  ..., -4.6289e-01,
          -6.9531e-01, -3.0469e-01],
         [-6.0156e-01,  3.9844e-01,  6.2891e-01,  ..., -2.0605e-01,
          -7.5000e-01,  3.0664e-01],
         [ 3.6133e-01,  5.5469e-01, -7.0801e-02,  ...,  3.1641e-01,
          -4.0820e-01, -9.3359e-01]],

        [[ 3.7842e-02, -7.5073e-03, -7.4768e-04,  ...,  8.5449e-02,
           3.0762e-02,  1.2695e-02],
         [-6.9141e-01,  4.3359e-01, -2.8320e-02,  ...,  3.4570e-01,
          -2.2949e-01, -1.6699e-01],
         [-4.7070e-01,  7.6172e-02, -4.6631e-02,  ...,  1.8750e-01,
          -2.4219e-01, -5.5859e-01],
         ...,
         [-1.9775e-02,  4.9609e-01,  4.1992e-01,  ..., -4.6289e-01,
          -6.9531e-01, -3.0469e-01],
         [-6.0156e-01,  3.9844e-01,  6.2891e-01,  ..., -2.0605e-01,
          -7.5000e-01,  3.0664e-01],
         [ 3.6133e-01,  5.5469e-01, -7.0801e-02,  ...,  3.1641e-01,
          -4.0820e-01, -9.3359e-01]],

        ...,

        [[ 1.3580e-03, -1.8082e-03, -3.6812e-04,  ...,  1.6861e-03,
          -4.3945e-03,  9.2163e-03],
         [ 6.5625e-01, -1.4954e-03, -7.1484e-01,  ...,  2.5000e-01,
           8.4375e-01,  1.0781e+00],
         [-1.8555e-01, -9.1309e-02, -5.0000e-01,  ..., -4.0430e-01,
           1.7188e-01,  5.8203e-01],
         ...,
         [ 2.4512e-01,  2.2168e-01,  6.9922e-01,  ...,  2.1973e-02,
           2.6172e-01, -1.5039e-01],
         [-4.6631e-02, -2.8687e-02,  1.0078e+00,  ...,  5.0781e-01,
          -8.2422e-01, -2.4805e-01],
         [-1.1172e+00, -4.8242e-01,  5.8203e-01,  ..., -6.2988e-02,
           5.7031e-01, -1.7676e-01]],

        [[ 1.3580e-03, -1.8082e-03, -3.6812e-04,  ...,  1.6861e-03,
          -4.3945e-03,  9.2163e-03],
         [ 6.5625e-01, -1.4954e-03, -7.1484e-01,  ...,  2.5000e-01,
           8.4375e-01,  1.0781e+00],
         [-1.8555e-01, -9.1309e-02, -5.0000e-01,  ..., -4.0430e-01,
           1.7188e-01,  5.8203e-01],
         ...,
         [ 2.4512e-01,  2.2168e-01,  6.9922e-01,  ...,  2.1973e-02,
           2.6172e-01, -1.5039e-01],
         [-4.6631e-02, -2.8687e-02,  1.0078e+00,  ...,  5.0781e-01,
          -8.2422e-01, -2.4805e-01],
         [-1.1172e+00, -4.8242e-01,  5.8203e-01,  ..., -6.2988e-02,
           5.7031e-01, -1.7676e-01]],

        [[ 1.3580e-03, -1.8082e-03, -3.6812e-04,  ...,  1.6861e-03,
          -4.3945e-03,  9.2163e-03],
         [ 6.5625e-01, -1.4954e-03, -7.1484e-01,  ...,  2.5000e-01,
           8.4375e-01,  1.0781e+00],
         [-1.8555e-01, -9.1309e-02, -5.0000e-01,  ..., -4.0430e-01,
           1.7188e-01,  5.8203e-01],
         ...,
         [ 2.4512e-01,  2.2168e-01,  6.9922e-01,  ...,  2.1973e-02,
           2.6172e-01, -1.5039e-01],
         [-4.6631e-02, -2.8687e-02,  1.0078e+00,  ...,  5.0781e-01,
          -8.2422e-01, -2.4805e-01],
         [-1.1172e+00, -4.8242e-01,  5.8203e-01,  ..., -6.2988e-02,
           5.7031e-01, -1.7676e-01]]], dtype=torch.bfloat16)), (tensor([[[-1.1719e-02,  1.8799e-02,  3.8605e-03,  ..., -1.9629e-01,
          -1.7090e-01, -1.7578e-01],
         [ 3.2344e+00, -3.0781e+00,  3.3203e-01,  ...,  5.3516e-01,
          -2.3125e+00,  9.1016e-01],
         [ 3.6914e-01,  1.4160e-01,  1.7031e+00,  ...,  9.1016e-01,
           3.7500e-01, -1.6602e-01],
         ...,
         [-2.1406e+00,  1.2793e-01,  1.1641e+00,  ..., -9.4531e-01,
          -1.1328e+00,  2.2656e+00],
         [ 2.6953e-01,  3.5352e-01, -1.2188e+00,  ...,  1.6504e-01,
          -9.4531e-01,  1.9297e+00],
         [ 0.0000e+00,  1.1484e+00, -3.9062e-03,  ...,  8.5156e-01,
           2.5156e+00, -1.6406e+00]],

        [[-1.1719e-02,  1.8799e-02,  3.8605e-03,  ..., -1.9629e-01,
          -1.7090e-01, -1.7578e-01],
         [ 3.2344e+00, -3.0781e+00,  3.3203e-01,  ...,  5.3516e-01,
          -2.3125e+00,  9.1016e-01],
         [ 3.6914e-01,  1.4160e-01,  1.7031e+00,  ...,  9.1016e-01,
           3.7500e-01, -1.6602e-01],
         ...,
         [-2.1406e+00,  1.2793e-01,  1.1641e+00,  ..., -9.4531e-01,
          -1.1328e+00,  2.2656e+00],
         [ 2.6953e-01,  3.5352e-01, -1.2188e+00,  ...,  1.6504e-01,
          -9.4531e-01,  1.9297e+00],
         [ 0.0000e+00,  1.1484e+00, -3.9062e-03,  ...,  8.5156e-01,
           2.5156e+00, -1.6406e+00]],

        [[-1.1719e-02,  1.8799e-02,  3.8605e-03,  ..., -1.9629e-01,
          -1.7090e-01, -1.7578e-01],
         [ 3.2344e+00, -3.0781e+00,  3.3203e-01,  ...,  5.3516e-01,
          -2.3125e+00,  9.1016e-01],
         [ 3.6914e-01,  1.4160e-01,  1.7031e+00,  ...,  9.1016e-01,
           3.7500e-01, -1.6602e-01],
         ...,
         [-2.1406e+00,  1.2793e-01,  1.1641e+00,  ..., -9.4531e-01,
          -1.1328e+00,  2.2656e+00],
         [ 2.6953e-01,  3.5352e-01, -1.2188e+00,  ...,  1.6504e-01,
          -9.4531e-01,  1.9297e+00],
         [ 0.0000e+00,  1.1484e+00, -3.9062e-03,  ...,  8.5156e-01,
           2.5156e+00, -1.6406e+00]],

        ...,

        [[ 6.1798e-04,  1.2024e-02, -1.5076e-02,  ...,  1.9141e-01,
           4.7656e-01,  4.5703e-01],
         [-3.6406e+00, -1.2109e-01,  1.8203e+00,  ...,  1.6797e+00,
          -2.1719e+00,  1.7109e+00],
         [-1.5938e+00,  1.7383e-01,  1.0625e+00,  ...,  1.1250e+00,
          -4.1875e+00,  7.7637e-02],
         ...,
         [ 1.2188e+00,  1.4297e+00, -6.4844e-01,  ..., -4.9609e-01,
          -3.7031e+00,  1.6016e-01],
         [ 7.7344e-01,  7.6172e-01, -4.4531e-01,  ..., -7.7148e-02,
          -3.6094e+00, -1.8906e+00],
         [ 1.7344e+00, -2.1094e+00,  1.2578e+00,  ...,  2.2031e+00,
          -1.5312e+00,  9.9609e-01]],

        [[ 6.1798e-04,  1.2024e-02, -1.5076e-02,  ...,  1.9141e-01,
           4.7656e-01,  4.5703e-01],
         [-3.6406e+00, -1.2109e-01,  1.8203e+00,  ...,  1.6797e+00,
          -2.1719e+00,  1.7109e+00],
         [-1.5938e+00,  1.7383e-01,  1.0625e+00,  ...,  1.1250e+00,
          -4.1875e+00,  7.7637e-02],
         ...,
         [ 1.2188e+00,  1.4297e+00, -6.4844e-01,  ..., -4.9609e-01,
          -3.7031e+00,  1.6016e-01],
         [ 7.7344e-01,  7.6172e-01, -4.4531e-01,  ..., -7.7148e-02,
          -3.6094e+00, -1.8906e+00],
         [ 1.7344e+00, -2.1094e+00,  1.2578e+00,  ...,  2.2031e+00,
          -1.5312e+00,  9.9609e-01]],

        [[ 6.1798e-04,  1.2024e-02, -1.5076e-02,  ...,  1.9141e-01,
           4.7656e-01,  4.5703e-01],
         [-3.6406e+00, -1.2109e-01,  1.8203e+00,  ...,  1.6797e+00,
          -2.1719e+00,  1.7109e+00],
         [-1.5938e+00,  1.7383e-01,  1.0625e+00,  ...,  1.1250e+00,
          -4.1875e+00,  7.7637e-02],
         ...,
         [ 1.2188e+00,  1.4297e+00, -6.4844e-01,  ..., -4.9609e-01,
          -3.7031e+00,  1.6016e-01],
         [ 7.7344e-01,  7.6172e-01, -4.4531e-01,  ..., -7.7148e-02,
          -3.6094e+00, -1.8906e+00],
         [ 1.7344e+00, -2.1094e+00,  1.2578e+00,  ...,  2.2031e+00,
          -1.5312e+00,  9.9609e-01]]], dtype=torch.bfloat16), tensor([[[-3.8300e-03,  1.1749e-03, -5.0964e-03,  ...,  1.0193e-02,
           1.3489e-02, -1.4221e-02],
         [ 2.2559e-01, -1.7212e-02, -2.6562e-01,  ...,  2.5586e-01,
           1.8457e-01,  7.3828e-01],
         [-4.7656e-01,  3.0664e-01,  2.2559e-01,  ..., -1.1084e-01,
          -3.3594e-01, -2.4023e-01],
         ...,
         [ 8.0859e-01,  3.3203e-01, -1.1953e+00,  ..., -3.5547e-01,
           8.0078e-02, -1.6992e-01],
         [ 2.3242e-01,  1.1377e-01, -5.1953e-01,  ..., -7.8613e-02,
          -1.3086e-01, -4.4922e-01],
         [ 9.6680e-02, -5.1562e-01,  6.2109e-01,  ..., -1.8359e-01,
           3.6133e-01,  4.0820e-01]],

        [[-3.8300e-03,  1.1749e-03, -5.0964e-03,  ...,  1.0193e-02,
           1.3489e-02, -1.4221e-02],
         [ 2.2559e-01, -1.7212e-02, -2.6562e-01,  ...,  2.5586e-01,
           1.8457e-01,  7.3828e-01],
         [-4.7656e-01,  3.0664e-01,  2.2559e-01,  ..., -1.1084e-01,
          -3.3594e-01, -2.4023e-01],
         ...,
         [ 8.0859e-01,  3.3203e-01, -1.1953e+00,  ..., -3.5547e-01,
           8.0078e-02, -1.6992e-01],
         [ 2.3242e-01,  1.1377e-01, -5.1953e-01,  ..., -7.8613e-02,
          -1.3086e-01, -4.4922e-01],
         [ 9.6680e-02, -5.1562e-01,  6.2109e-01,  ..., -1.8359e-01,
           3.6133e-01,  4.0820e-01]],

        [[-3.8300e-03,  1.1749e-03, -5.0964e-03,  ...,  1.0193e-02,
           1.3489e-02, -1.4221e-02],
         [ 2.2559e-01, -1.7212e-02, -2.6562e-01,  ...,  2.5586e-01,
           1.8457e-01,  7.3828e-01],
         [-4.7656e-01,  3.0664e-01,  2.2559e-01,  ..., -1.1084e-01,
          -3.3594e-01, -2.4023e-01],
         ...,
         [ 8.0859e-01,  3.3203e-01, -1.1953e+00,  ..., -3.5547e-01,
           8.0078e-02, -1.6992e-01],
         [ 2.3242e-01,  1.1377e-01, -5.1953e-01,  ..., -7.8613e-02,
          -1.3086e-01, -4.4922e-01],
         [ 9.6680e-02, -5.1562e-01,  6.2109e-01,  ..., -1.8359e-01,
           3.6133e-01,  4.0820e-01]],

        ...,

        [[-8.8501e-03, -3.8452e-03, -2.5940e-03,  ...,  2.5635e-03,
           3.6469e-03, -3.3875e-03],
         [ 1.6504e-01, -4.8242e-01, -6.4453e-02,  ..., -2.0508e-01,
          -2.4902e-01, -2.7344e-01],
         [-7.1484e-01,  8.3203e-01,  2.2656e-01,  ..., -7.3242e-02,
          -9.1797e-01,  1.1016e+00],
         ...,
         [-3.3594e-01,  2.6367e-01, -4.7852e-01,  ..., -1.7578e-01,
          -3.4766e-01, -3.2422e-01],
         [-1.2891e-01,  4.1992e-01, -3.9258e-01,  ...,  3.5742e-01,
          -5.8594e-02, -1.7090e-01],
         [ 9.7656e-01, -3.3203e-02, -8.0078e-01,  ..., -7.3730e-02,
           1.7383e-01,  6.6406e-01]],

        [[-8.8501e-03, -3.8452e-03, -2.5940e-03,  ...,  2.5635e-03,
           3.6469e-03, -3.3875e-03],
         [ 1.6504e-01, -4.8242e-01, -6.4453e-02,  ..., -2.0508e-01,
          -2.4902e-01, -2.7344e-01],
         [-7.1484e-01,  8.3203e-01,  2.2656e-01,  ..., -7.3242e-02,
          -9.1797e-01,  1.1016e+00],
         ...,
         [-3.3594e-01,  2.6367e-01, -4.7852e-01,  ..., -1.7578e-01,
          -3.4766e-01, -3.2422e-01],
         [-1.2891e-01,  4.1992e-01, -3.9258e-01,  ...,  3.5742e-01,
          -5.8594e-02, -1.7090e-01],
         [ 9.7656e-01, -3.3203e-02, -8.0078e-01,  ..., -7.3730e-02,
           1.7383e-01,  6.6406e-01]],

        [[-8.8501e-03, -3.8452e-03, -2.5940e-03,  ...,  2.5635e-03,
           3.6469e-03, -3.3875e-03],
         [ 1.6504e-01, -4.8242e-01, -6.4453e-02,  ..., -2.0508e-01,
          -2.4902e-01, -2.7344e-01],
         [-7.1484e-01,  8.3203e-01,  2.2656e-01,  ..., -7.3242e-02,
          -9.1797e-01,  1.1016e+00],
         ...,
         [-3.3594e-01,  2.6367e-01, -4.7852e-01,  ..., -1.7578e-01,
          -3.4766e-01, -3.2422e-01],
         [-1.2891e-01,  4.1992e-01, -3.9258e-01,  ...,  3.5742e-01,
          -5.8594e-02, -1.7090e-01],
         [ 9.7656e-01, -3.3203e-02, -8.0078e-01,  ..., -7.3730e-02,
           1.7383e-01,  6.6406e-01]]], dtype=torch.bfloat16)), (tensor([[[ 3.0060e-03,  8.0872e-04, -1.0437e-02,  ..., -2.7344e-02,
          -3.0518e-03, -2.2461e-01],
         [-1.8594e+00,  2.4688e+00,  2.5000e+00,  ...,  1.6328e+00,
          -5.8350e-02,  1.6328e+00],
         [ 1.2500e+00, -1.9336e-01,  1.2656e+00,  ...,  1.1562e+00,
          -5.8203e-01,  1.1562e+00],
         ...,
         [ 1.4219e+00,  2.2266e-01,  3.5938e-01,  ...,  2.0781e+00,
          -4.6875e-01, -3.9258e-01],
         [ 2.9883e-01, -2.2461e-02,  1.1963e-01,  ...,  2.1406e+00,
           4.1504e-02, -2.5781e+00],
         [-1.2812e+00,  1.1797e+00,  1.4844e+00,  ...,  8.4766e-01,
          -5.6250e-01, -8.7109e-01]],

        [[ 3.0060e-03,  8.0872e-04, -1.0437e-02,  ..., -2.7344e-02,
          -3.0518e-03, -2.2461e-01],
         [-1.8594e+00,  2.4688e+00,  2.5000e+00,  ...,  1.6328e+00,
          -5.8350e-02,  1.6328e+00],
         [ 1.2500e+00, -1.9336e-01,  1.2656e+00,  ...,  1.1562e+00,
          -5.8203e-01,  1.1562e+00],
         ...,
         [ 1.4219e+00,  2.2266e-01,  3.5938e-01,  ...,  2.0781e+00,
          -4.6875e-01, -3.9258e-01],
         [ 2.9883e-01, -2.2461e-02,  1.1963e-01,  ...,  2.1406e+00,
           4.1504e-02, -2.5781e+00],
         [-1.2812e+00,  1.1797e+00,  1.4844e+00,  ...,  8.4766e-01,
          -5.6250e-01, -8.7109e-01]],

        [[ 3.0060e-03,  8.0872e-04, -1.0437e-02,  ..., -2.7344e-02,
          -3.0518e-03, -2.2461e-01],
         [-1.8594e+00,  2.4688e+00,  2.5000e+00,  ...,  1.6328e+00,
          -5.8350e-02,  1.6328e+00],
         [ 1.2500e+00, -1.9336e-01,  1.2656e+00,  ...,  1.1562e+00,
          -5.8203e-01,  1.1562e+00],
         ...,
         [ 1.4219e+00,  2.2266e-01,  3.5938e-01,  ...,  2.0781e+00,
          -4.6875e-01, -3.9258e-01],
         [ 2.9883e-01, -2.2461e-02,  1.1963e-01,  ...,  2.1406e+00,
           4.1504e-02, -2.5781e+00],
         [-1.2812e+00,  1.1797e+00,  1.4844e+00,  ...,  8.4766e-01,
          -5.6250e-01, -8.7109e-01]],

        ...,

        [[-8.2779e-04,  5.0964e-03, -1.3123e-03,  ...,  2.4609e-01,
           7.2656e-01,  5.5078e-01],
         [-1.8203e+00,  7.3047e-01, -2.6875e+00,  ...,  4.5312e-01,
          -3.4219e+00,  3.5156e-01],
         [ 1.9824e-01,  3.9453e-01, -7.2266e-01,  ...,  5.5859e-01,
          -6.7500e+00,  4.4141e-01],
         ...,
         [ 5.8984e-01,  9.4922e-01, -6.9141e-01,  ...,  1.9922e+00,
          -4.2812e+00,  2.5625e+00],
         [ 4.2188e-01,  6.3281e-01,  2.0508e-02,  ...,  9.2773e-02,
          -6.5312e+00,  3.1562e+00],
         [ 5.8594e-01,  5.3516e-01, -6.2500e-01,  ..., -8.4375e-01,
          -2.8594e+00,  2.3906e+00]],

        [[-8.2779e-04,  5.0964e-03, -1.3123e-03,  ...,  2.4609e-01,
           7.2656e-01,  5.5078e-01],
         [-1.8203e+00,  7.3047e-01, -2.6875e+00,  ...,  4.5312e-01,
          -3.4219e+00,  3.5156e-01],
         [ 1.9824e-01,  3.9453e-01, -7.2266e-01,  ...,  5.5859e-01,
          -6.7500e+00,  4.4141e-01],
         ...,
         [ 5.8984e-01,  9.4922e-01, -6.9141e-01,  ...,  1.9922e+00,
          -4.2812e+00,  2.5625e+00],
         [ 4.2188e-01,  6.3281e-01,  2.0508e-02,  ...,  9.2773e-02,
          -6.5312e+00,  3.1562e+00],
         [ 5.8594e-01,  5.3516e-01, -6.2500e-01,  ..., -8.4375e-01,
          -2.8594e+00,  2.3906e+00]],

        [[-8.2779e-04,  5.0964e-03, -1.3123e-03,  ...,  2.4609e-01,
           7.2656e-01,  5.5078e-01],
         [-1.8203e+00,  7.3047e-01, -2.6875e+00,  ...,  4.5312e-01,
          -3.4219e+00,  3.5156e-01],
         [ 1.9824e-01,  3.9453e-01, -7.2266e-01,  ...,  5.5859e-01,
          -6.7500e+00,  4.4141e-01],
         ...,
         [ 5.8984e-01,  9.4922e-01, -6.9141e-01,  ...,  1.9922e+00,
          -4.2812e+00,  2.5625e+00],
         [ 4.2188e-01,  6.3281e-01,  2.0508e-02,  ...,  9.2773e-02,
          -6.5312e+00,  3.1562e+00],
         [ 5.8594e-01,  5.3516e-01, -6.2500e-01,  ..., -8.4375e-01,
          -2.8594e+00,  2.3906e+00]]], dtype=torch.bfloat16), tensor([[[-1.3046e-03,  9.9487e-03, -1.2756e-02,  ..., -1.3000e-02,
           3.1281e-03, -3.1891e-03],
         [ 3.8672e-01, -2.0215e-01, -6.1035e-02,  ...,  8.5449e-02,
           4.7656e-01,  5.9570e-02],
         [ 1.0391e+00, -6.9336e-02, -2.5391e-01,  ..., -3.4375e-01,
          -4.6387e-02, -2.8906e-01],
         ...,
         [-6.8359e-01,  3.2617e-01,  2.4512e-01,  ..., -1.0742e-01,
          -2.1484e-01, -4.5117e-01],
         [-1.0791e-01,  4.7656e-01, -4.2773e-01,  ...,  8.7891e-01,
           1.3379e-01,  7.1777e-02],
         [ 7.0801e-02,  3.0029e-02, -8.9355e-02,  ..., -6.2500e-02,
          -4.1211e-01, -4.1016e-01]],

        [[-1.3046e-03,  9.9487e-03, -1.2756e-02,  ..., -1.3000e-02,
           3.1281e-03, -3.1891e-03],
         [ 3.8672e-01, -2.0215e-01, -6.1035e-02,  ...,  8.5449e-02,
           4.7656e-01,  5.9570e-02],
         [ 1.0391e+00, -6.9336e-02, -2.5391e-01,  ..., -3.4375e-01,
          -4.6387e-02, -2.8906e-01],
         ...,
         [-6.8359e-01,  3.2617e-01,  2.4512e-01,  ..., -1.0742e-01,
          -2.1484e-01, -4.5117e-01],
         [-1.0791e-01,  4.7656e-01, -4.2773e-01,  ...,  8.7891e-01,
           1.3379e-01,  7.1777e-02],
         [ 7.0801e-02,  3.0029e-02, -8.9355e-02,  ..., -6.2500e-02,
          -4.1211e-01, -4.1016e-01]],

        [[-1.3046e-03,  9.9487e-03, -1.2756e-02,  ..., -1.3000e-02,
           3.1281e-03, -3.1891e-03],
         [ 3.8672e-01, -2.0215e-01, -6.1035e-02,  ...,  8.5449e-02,
           4.7656e-01,  5.9570e-02],
         [ 1.0391e+00, -6.9336e-02, -2.5391e-01,  ..., -3.4375e-01,
          -4.6387e-02, -2.8906e-01],
         ...,
         [-6.8359e-01,  3.2617e-01,  2.4512e-01,  ..., -1.0742e-01,
          -2.1484e-01, -4.5117e-01],
         [-1.0791e-01,  4.7656e-01, -4.2773e-01,  ...,  8.7891e-01,
           1.3379e-01,  7.1777e-02],
         [ 7.0801e-02,  3.0029e-02, -8.9355e-02,  ..., -6.2500e-02,
          -4.1211e-01, -4.1016e-01]],

        ...,

        [[ 2.9785e-02, -4.5166e-02,  2.2949e-02,  ...,  2.6001e-02,
           3.2715e-02,  8.3618e-03],
         [-6.6016e-01, -8.2520e-02, -2.0410e-01,  ...,  9.5825e-03,
          -3.3008e-01, -7.3730e-02],
         [-8.3203e-01,  3.6719e-01, -1.0391e+00,  ..., -5.0781e-01,
           1.6211e-01, -2.3242e-01],
         ...,
         [-5.4688e-01,  3.5742e-01, -8.1250e-01,  ...,  1.0234e+00,
          -6.2109e-01,  5.7031e-01],
         [-2.0215e-01, -7.9102e-02, -1.5781e+00,  ...,  3.5938e-01,
          -2.9883e-01, -2.0142e-02],
         [-6.3281e-01, -3.1836e-01, -4.6484e-01,  ..., -2.8320e-01,
          -1.5430e-01, -1.0742e-01]],

        [[ 2.9785e-02, -4.5166e-02,  2.2949e-02,  ...,  2.6001e-02,
           3.2715e-02,  8.3618e-03],
         [-6.6016e-01, -8.2520e-02, -2.0410e-01,  ...,  9.5825e-03,
          -3.3008e-01, -7.3730e-02],
         [-8.3203e-01,  3.6719e-01, -1.0391e+00,  ..., -5.0781e-01,
           1.6211e-01, -2.3242e-01],
         ...,
         [-5.4688e-01,  3.5742e-01, -8.1250e-01,  ...,  1.0234e+00,
          -6.2109e-01,  5.7031e-01],
         [-2.0215e-01, -7.9102e-02, -1.5781e+00,  ...,  3.5938e-01,
          -2.9883e-01, -2.0142e-02],
         [-6.3281e-01, -3.1836e-01, -4.6484e-01,  ..., -2.8320e-01,
          -1.5430e-01, -1.0742e-01]],

        [[ 2.9785e-02, -4.5166e-02,  2.2949e-02,  ...,  2.6001e-02,
           3.2715e-02,  8.3618e-03],
         [-6.6016e-01, -8.2520e-02, -2.0410e-01,  ...,  9.5825e-03,
          -3.3008e-01, -7.3730e-02],
         [-8.3203e-01,  3.6719e-01, -1.0391e+00,  ..., -5.0781e-01,
           1.6211e-01, -2.3242e-01],
         ...,
         [-5.4688e-01,  3.5742e-01, -8.1250e-01,  ...,  1.0234e+00,
          -6.2109e-01,  5.7031e-01],
         [-2.0215e-01, -7.9102e-02, -1.5781e+00,  ...,  3.5938e-01,
          -2.9883e-01, -2.0142e-02],
         [-6.3281e-01, -3.1836e-01, -4.6484e-01,  ..., -2.8320e-01,
          -1.5430e-01, -1.0742e-01]]], dtype=torch.bfloat16)), (tensor([[[-1.6602e-02,  7.8735e-03, -1.5869e-02,  ..., -4.2969e-01,
           3.4570e-01,  9.5215e-02],
         [-2.8281e+00, -1.3750e+00,  1.6641e+00,  ...,  1.5625e+00,
           3.0273e-02,  7.7637e-02],
         [ 5.7812e-01, -7.7344e-01,  1.5781e+00,  ...,  2.2969e+00,
           1.6641e+00, -5.8984e-01],
         ...,
         [ 1.8906e+00,  5.3906e-01,  5.9766e-01,  ...,  1.4453e+00,
          -1.3672e+00,  5.4297e-01],
         [ 2.9492e-01,  1.1875e+00, -7.7148e-02,  ...,  1.1797e+00,
          -1.1172e+00,  9.2188e-01],
         [-5.3906e-01,  2.0625e+00,  2.3750e+00,  ...,  7.0312e-01,
           5.9082e-02, -1.9297e+00]],

        [[-1.6602e-02,  7.8735e-03, -1.5869e-02,  ..., -4.2969e-01,
           3.4570e-01,  9.5215e-02],
         [-2.8281e+00, -1.3750e+00,  1.6641e+00,  ...,  1.5625e+00,
           3.0273e-02,  7.7637e-02],
         [ 5.7812e-01, -7.7344e-01,  1.5781e+00,  ...,  2.2969e+00,
           1.6641e+00, -5.8984e-01],
         ...,
         [ 1.8906e+00,  5.3906e-01,  5.9766e-01,  ...,  1.4453e+00,
          -1.3672e+00,  5.4297e-01],
         [ 2.9492e-01,  1.1875e+00, -7.7148e-02,  ...,  1.1797e+00,
          -1.1172e+00,  9.2188e-01],
         [-5.3906e-01,  2.0625e+00,  2.3750e+00,  ...,  7.0312e-01,
           5.9082e-02, -1.9297e+00]],

        [[-1.6602e-02,  7.8735e-03, -1.5869e-02,  ..., -4.2969e-01,
           3.4570e-01,  9.5215e-02],
         [-2.8281e+00, -1.3750e+00,  1.6641e+00,  ...,  1.5625e+00,
           3.0273e-02,  7.7637e-02],
         [ 5.7812e-01, -7.7344e-01,  1.5781e+00,  ...,  2.2969e+00,
           1.6641e+00, -5.8984e-01],
         ...,
         [ 1.8906e+00,  5.3906e-01,  5.9766e-01,  ...,  1.4453e+00,
          -1.3672e+00,  5.4297e-01],
         [ 2.9492e-01,  1.1875e+00, -7.7148e-02,  ...,  1.1797e+00,
          -1.1172e+00,  9.2188e-01],
         [-5.3906e-01,  2.0625e+00,  2.3750e+00,  ...,  7.0312e-01,
           5.9082e-02, -1.9297e+00]],

        ...,

        [[-2.3315e-02, -1.9043e-02, -2.8687e-03,  ...,  7.2754e-02,
          -2.9219e+00, -4.1602e-01],
         [ 3.3203e-01, -5.1953e-01, -1.9922e-01,  ...,  5.7422e-01,
           8.7500e+00,  5.1875e+00],
         [-2.8516e-01, -3.2422e-01, -1.7578e-01,  ...,  8.8281e-01,
           8.1875e+00,  5.2500e+00],
         ...,
         [-3.2617e-01,  5.9375e-01,  4.4727e-01,  ...,  4.8828e-01,
           8.1875e+00, -9.4922e-01],
         [ 3.9648e-01, -1.2988e-01, -4.2480e-02,  ...,  1.1562e+00,
           5.9375e+00, -2.1562e+00],
         [ 3.2227e-01,  1.4551e-01, -6.5625e-01,  ..., -1.8047e+00,
           7.4062e+00,  2.4844e+00]],

        [[-2.3315e-02, -1.9043e-02, -2.8687e-03,  ...,  7.2754e-02,
          -2.9219e+00, -4.1602e-01],
         [ 3.3203e-01, -5.1953e-01, -1.9922e-01,  ...,  5.7422e-01,
           8.7500e+00,  5.1875e+00],
         [-2.8516e-01, -3.2422e-01, -1.7578e-01,  ...,  8.8281e-01,
           8.1875e+00,  5.2500e+00],
         ...,
         [-3.2617e-01,  5.9375e-01,  4.4727e-01,  ...,  4.8828e-01,
           8.1875e+00, -9.4922e-01],
         [ 3.9648e-01, -1.2988e-01, -4.2480e-02,  ...,  1.1562e+00,
           5.9375e+00, -2.1562e+00],
         [ 3.2227e-01,  1.4551e-01, -6.5625e-01,  ..., -1.8047e+00,
           7.4062e+00,  2.4844e+00]],

        [[-2.3315e-02, -1.9043e-02, -2.8687e-03,  ...,  7.2754e-02,
          -2.9219e+00, -4.1602e-01],
         [ 3.3203e-01, -5.1953e-01, -1.9922e-01,  ...,  5.7422e-01,
           8.7500e+00,  5.1875e+00],
         [-2.8516e-01, -3.2422e-01, -1.7578e-01,  ...,  8.8281e-01,
           8.1875e+00,  5.2500e+00],
         ...,
         [-3.2617e-01,  5.9375e-01,  4.4727e-01,  ...,  4.8828e-01,
           8.1875e+00, -9.4922e-01],
         [ 3.9648e-01, -1.2988e-01, -4.2480e-02,  ...,  1.1562e+00,
           5.9375e+00, -2.1562e+00],
         [ 3.2227e-01,  1.4551e-01, -6.5625e-01,  ..., -1.8047e+00,
           7.4062e+00,  2.4844e+00]]], dtype=torch.bfloat16), tensor([[[-2.7832e-02, -1.5869e-02, -4.3701e-02,  ...,  5.1270e-03,
           9.5215e-03,  1.0620e-02],
         [ 5.9375e-01, -4.9414e-01,  4.2969e-01,  ...,  1.3086e-01,
          -1.9434e-01,  3.2812e-01],
         [ 2.7734e-01, -1.6504e-01,  7.1484e-01,  ...,  1.1797e+00,
           2.8906e-01,  2.6245e-02],
         ...,
         [ 2.7148e-01, -1.4531e+00, -3.6133e-01,  ...,  2.7148e-01,
           5.0000e-01,  3.4424e-02],
         [ 2.4512e-01, -1.9297e+00, -9.1016e-01,  ...,  1.0781e+00,
           3.9844e-01,  8.2031e-01],
         [ 5.4297e-01,  4.6387e-02,  3.7842e-02,  ..., -1.7090e-01,
           7.4609e-01,  1.2109e-01]],

        [[-2.7832e-02, -1.5869e-02, -4.3701e-02,  ...,  5.1270e-03,
           9.5215e-03,  1.0620e-02],
         [ 5.9375e-01, -4.9414e-01,  4.2969e-01,  ...,  1.3086e-01,
          -1.9434e-01,  3.2812e-01],
         [ 2.7734e-01, -1.6504e-01,  7.1484e-01,  ...,  1.1797e+00,
           2.8906e-01,  2.6245e-02],
         ...,
         [ 2.7148e-01, -1.4531e+00, -3.6133e-01,  ...,  2.7148e-01,
           5.0000e-01,  3.4424e-02],
         [ 2.4512e-01, -1.9297e+00, -9.1016e-01,  ...,  1.0781e+00,
           3.9844e-01,  8.2031e-01],
         [ 5.4297e-01,  4.6387e-02,  3.7842e-02,  ..., -1.7090e-01,
           7.4609e-01,  1.2109e-01]],

        [[-2.7832e-02, -1.5869e-02, -4.3701e-02,  ...,  5.1270e-03,
           9.5215e-03,  1.0620e-02],
         [ 5.9375e-01, -4.9414e-01,  4.2969e-01,  ...,  1.3086e-01,
          -1.9434e-01,  3.2812e-01],
         [ 2.7734e-01, -1.6504e-01,  7.1484e-01,  ...,  1.1797e+00,
           2.8906e-01,  2.6245e-02],
         ...,
         [ 2.7148e-01, -1.4531e+00, -3.6133e-01,  ...,  2.7148e-01,
           5.0000e-01,  3.4424e-02],
         [ 2.4512e-01, -1.9297e+00, -9.1016e-01,  ...,  1.0781e+00,
           3.9844e-01,  8.2031e-01],
         [ 5.4297e-01,  4.6387e-02,  3.7842e-02,  ..., -1.7090e-01,
           7.4609e-01,  1.2109e-01]],

        ...,

        [[-2.3438e-02,  2.6978e-02,  1.3062e-02,  ...,  7.7438e-04,
           1.4954e-02,  1.2283e-03],
         [ 6.6406e-01, -8.2031e-01, -3.6523e-01,  ..., -5.2795e-03,
          -1.8359e-01,  2.6367e-01],
         [-1.0498e-01, -6.3281e-01, -3.3789e-01,  ..., -1.1250e+00,
          -7.9688e-01,  8.9844e-02],
         ...,
         [ 2.0625e+00, -4.9609e-01,  2.4805e-01,  ..., -3.1445e-01,
          -1.8516e+00, -1.2188e+00],
         [ 6.5234e-01,  1.9238e-01,  3.2422e-01,  ..., -2.5757e-02,
          -1.3359e+00, -8.3984e-01],
         [-2.1289e-01, -6.4062e-01, -8.8672e-01,  ..., -2.0703e-01,
           4.8340e-02,  7.3828e-01]],

        [[-2.3438e-02,  2.6978e-02,  1.3062e-02,  ...,  7.7438e-04,
           1.4954e-02,  1.2283e-03],
         [ 6.6406e-01, -8.2031e-01, -3.6523e-01,  ..., -5.2795e-03,
          -1.8359e-01,  2.6367e-01],
         [-1.0498e-01, -6.3281e-01, -3.3789e-01,  ..., -1.1250e+00,
          -7.9688e-01,  8.9844e-02],
         ...,
         [ 2.0625e+00, -4.9609e-01,  2.4805e-01,  ..., -3.1445e-01,
          -1.8516e+00, -1.2188e+00],
         [ 6.5234e-01,  1.9238e-01,  3.2422e-01,  ..., -2.5757e-02,
          -1.3359e+00, -8.3984e-01],
         [-2.1289e-01, -6.4062e-01, -8.8672e-01,  ..., -2.0703e-01,
           4.8340e-02,  7.3828e-01]],

        [[-2.3438e-02,  2.6978e-02,  1.3062e-02,  ...,  7.7438e-04,
           1.4954e-02,  1.2283e-03],
         [ 6.6406e-01, -8.2031e-01, -3.6523e-01,  ..., -5.2795e-03,
          -1.8359e-01,  2.6367e-01],
         [-1.0498e-01, -6.3281e-01, -3.3789e-01,  ..., -1.1250e+00,
          -7.9688e-01,  8.9844e-02],
         ...,
         [ 2.0625e+00, -4.9609e-01,  2.4805e-01,  ..., -3.1445e-01,
          -1.8516e+00, -1.2188e+00],
         [ 6.5234e-01,  1.9238e-01,  3.2422e-01,  ..., -2.5757e-02,
          -1.3359e+00, -8.3984e-01],
         [-2.1289e-01, -6.4062e-01, -8.8672e-01,  ..., -2.0703e-01,
           4.8340e-02,  7.3828e-01]]], dtype=torch.bfloat16)), (tensor([[[ 7.9727e-04,  1.0925e-02, -3.7842e-03,  ..., -1.3184e-01,
          -2.9297e-01, -9.7168e-02],
         [ 2.0000e+00,  7.6172e-01, -2.5938e+00,  ...,  9.4238e-02,
           6.7188e-01, -9.1406e-01],
         [ 6.0156e-01, -1.2344e+00, -1.1250e+00,  ...,  8.0078e-02,
           1.8672e+00, -4.7461e-01],
         ...,
         [-8.2422e-01,  7.7344e-01, -9.0820e-02,  ...,  1.4688e+00,
          -6.8750e-01, -7.1094e-01],
         [ 4.2969e-01,  3.3203e-01,  4.6875e-01,  ...,  3.3594e-01,
          -3.2617e-01, -1.0469e+00],
         [ 2.6562e-01,  1.7344e+00, -1.6250e+00,  ..., -2.1562e+00,
          -2.4219e+00, -2.5625e+00]],

        [[ 7.9727e-04,  1.0925e-02, -3.7842e-03,  ..., -1.3184e-01,
          -2.9297e-01, -9.7168e-02],
         [ 2.0000e+00,  7.6172e-01, -2.5938e+00,  ...,  9.4238e-02,
           6.7188e-01, -9.1406e-01],
         [ 6.0156e-01, -1.2344e+00, -1.1250e+00,  ...,  8.0078e-02,
           1.8672e+00, -4.7461e-01],
         ...,
         [-8.2422e-01,  7.7344e-01, -9.0820e-02,  ...,  1.4688e+00,
          -6.8750e-01, -7.1094e-01],
         [ 4.2969e-01,  3.3203e-01,  4.6875e-01,  ...,  3.3594e-01,
          -3.2617e-01, -1.0469e+00],
         [ 2.6562e-01,  1.7344e+00, -1.6250e+00,  ..., -2.1562e+00,
          -2.4219e+00, -2.5625e+00]],

        [[ 7.9727e-04,  1.0925e-02, -3.7842e-03,  ..., -1.3184e-01,
          -2.9297e-01, -9.7168e-02],
         [ 2.0000e+00,  7.6172e-01, -2.5938e+00,  ...,  9.4238e-02,
           6.7188e-01, -9.1406e-01],
         [ 6.0156e-01, -1.2344e+00, -1.1250e+00,  ...,  8.0078e-02,
           1.8672e+00, -4.7461e-01],
         ...,
         [-8.2422e-01,  7.7344e-01, -9.0820e-02,  ...,  1.4688e+00,
          -6.8750e-01, -7.1094e-01],
         [ 4.2969e-01,  3.3203e-01,  4.6875e-01,  ...,  3.3594e-01,
          -3.2617e-01, -1.0469e+00],
         [ 2.6562e-01,  1.7344e+00, -1.6250e+00,  ..., -2.1562e+00,
          -2.4219e+00, -2.5625e+00]],

        ...,

        [[ 3.6469e-03,  1.0742e-02,  3.5706e-03,  ..., -2.5156e+00,
          -1.3672e-02,  1.0498e-01],
         [ 1.4219e+00,  1.2500e-01,  5.7812e-01,  ...,  8.2500e+00,
          -4.3945e-01, -1.3047e+00],
         [-5.0781e-01,  6.7969e-01, -1.4766e+00,  ...,  8.4375e+00,
           1.2422e+00, -1.2578e+00],
         ...,
         [-2.9492e-01, -2.5977e-01,  3.3203e-01,  ...,  7.0938e+00,
          -1.1875e+00,  8.1250e-01],
         [ 3.6719e-01, -1.1562e+00,  1.0400e-01,  ...,  5.2812e+00,
          -8.9355e-02,  8.0469e-01],
         [ 1.1172e+00, -4.6875e-01,  1.1719e-02,  ...,  7.2812e+00,
          -2.0781e+00, -9.6484e-01]],

        [[ 3.6469e-03,  1.0742e-02,  3.5706e-03,  ..., -2.5156e+00,
          -1.3672e-02,  1.0498e-01],
         [ 1.4219e+00,  1.2500e-01,  5.7812e-01,  ...,  8.2500e+00,
          -4.3945e-01, -1.3047e+00],
         [-5.0781e-01,  6.7969e-01, -1.4766e+00,  ...,  8.4375e+00,
           1.2422e+00, -1.2578e+00],
         ...,
         [-2.9492e-01, -2.5977e-01,  3.3203e-01,  ...,  7.0938e+00,
          -1.1875e+00,  8.1250e-01],
         [ 3.6719e-01, -1.1562e+00,  1.0400e-01,  ...,  5.2812e+00,
          -8.9355e-02,  8.0469e-01],
         [ 1.1172e+00, -4.6875e-01,  1.1719e-02,  ...,  7.2812e+00,
          -2.0781e+00, -9.6484e-01]],

        [[ 3.6469e-03,  1.0742e-02,  3.5706e-03,  ..., -2.5156e+00,
          -1.3672e-02,  1.0498e-01],
         [ 1.4219e+00,  1.2500e-01,  5.7812e-01,  ...,  8.2500e+00,
          -4.3945e-01, -1.3047e+00],
         [-5.0781e-01,  6.7969e-01, -1.4766e+00,  ...,  8.4375e+00,
           1.2422e+00, -1.2578e+00],
         ...,
         [-2.9492e-01, -2.5977e-01,  3.3203e-01,  ...,  7.0938e+00,
          -1.1875e+00,  8.1250e-01],
         [ 3.6719e-01, -1.1562e+00,  1.0400e-01,  ...,  5.2812e+00,
          -8.9355e-02,  8.0469e-01],
         [ 1.1172e+00, -4.6875e-01,  1.1719e-02,  ...,  7.2812e+00,
          -2.0781e+00, -9.6484e-01]]], dtype=torch.bfloat16), tensor([[[-0.0193,  0.0225, -0.0356,  ..., -0.0254, -0.0118, -0.0156],
         [ 0.4336,  0.2480, -0.4023,  ...,  0.2812, -0.6602, -0.6875],
         [-0.0229, -0.0767, -0.1055,  ..., -0.5859, -0.3262,  0.0864],
         ...,
         [-0.3730, -0.6250,  0.5508,  ..., -0.0092, -1.2188, -0.0742],
         [ 0.1465, -0.7773, -0.0559,  ...,  0.0938, -0.7266,  0.3613],
         [ 0.7227,  1.0469, -0.1816,  ...,  0.7930,  0.0243, -0.2773]],

        [[-0.0193,  0.0225, -0.0356,  ..., -0.0254, -0.0118, -0.0156],
         [ 0.4336,  0.2480, -0.4023,  ...,  0.2812, -0.6602, -0.6875],
         [-0.0229, -0.0767, -0.1055,  ..., -0.5859, -0.3262,  0.0864],
         ...,
         [-0.3730, -0.6250,  0.5508,  ..., -0.0092, -1.2188, -0.0742],
         [ 0.1465, -0.7773, -0.0559,  ...,  0.0938, -0.7266,  0.3613],
         [ 0.7227,  1.0469, -0.1816,  ...,  0.7930,  0.0243, -0.2773]],

        [[-0.0193,  0.0225, -0.0356,  ..., -0.0254, -0.0118, -0.0156],
         [ 0.4336,  0.2480, -0.4023,  ...,  0.2812, -0.6602, -0.6875],
         [-0.0229, -0.0767, -0.1055,  ..., -0.5859, -0.3262,  0.0864],
         ...,
         [-0.3730, -0.6250,  0.5508,  ..., -0.0092, -1.2188, -0.0742],
         [ 0.1465, -0.7773, -0.0559,  ...,  0.0938, -0.7266,  0.3613],
         [ 0.7227,  1.0469, -0.1816,  ...,  0.7930,  0.0243, -0.2773]],

        ...,

        [[ 0.0197,  0.0050, -0.0060,  ...,  0.0374, -0.0171,  0.0016],
         [-0.0884, -0.1924,  0.2832,  ...,  0.3457, -0.0591,  0.0559],
         [ 0.3125, -0.8359,  0.3438,  ...,  0.0593,  0.3125, -0.3301],
         ...,
         [-0.7734,  0.5352,  0.1973,  ..., -0.1650, -0.4180,  0.0596],
         [-0.0215,  0.0084,  0.1279,  ..., -0.2109, -0.2266,  0.3047],
         [-0.1123,  0.0981, -0.0510,  ..., -0.6797, -1.0938,  0.0583]],

        [[ 0.0197,  0.0050, -0.0060,  ...,  0.0374, -0.0171,  0.0016],
         [-0.0884, -0.1924,  0.2832,  ...,  0.3457, -0.0591,  0.0559],
         [ 0.3125, -0.8359,  0.3438,  ...,  0.0593,  0.3125, -0.3301],
         ...,
         [-0.7734,  0.5352,  0.1973,  ..., -0.1650, -0.4180,  0.0596],
         [-0.0215,  0.0084,  0.1279,  ..., -0.2109, -0.2266,  0.3047],
         [-0.1123,  0.0981, -0.0510,  ..., -0.6797, -1.0938,  0.0583]],

        [[ 0.0197,  0.0050, -0.0060,  ...,  0.0374, -0.0171,  0.0016],
         [-0.0884, -0.1924,  0.2832,  ...,  0.3457, -0.0591,  0.0559],
         [ 0.3125, -0.8359,  0.3438,  ...,  0.0593,  0.3125, -0.3301],
         ...,
         [-0.7734,  0.5352,  0.1973,  ..., -0.1650, -0.4180,  0.0596],
         [-0.0215,  0.0084,  0.1279,  ..., -0.2109, -0.2266,  0.3047],
         [-0.1123,  0.0981, -0.0510,  ..., -0.6797, -1.0938,  0.0583]]],
       dtype=torch.bfloat16)), (tensor([[[-1.7700e-02, -1.2329e-02,  1.7456e-02,  ...,  7.7637e-02,
           5.0537e-02, -1.0107e-01],
         [ 2.0000e+00,  1.4551e-01,  2.0469e+00,  ..., -4.6289e-01,
          -6.8359e-01, -7.5391e-01],
         [ 1.6562e+00,  1.2734e+00, -2.5977e-01,  ...,  2.4512e-01,
          -2.8594e+00,  1.6562e+00],
         ...,
         [-6.8359e-01, -1.5156e+00, -2.2852e-01,  ..., -1.6484e+00,
           5.1562e+00, -8.9453e-01],
         [ 2.0117e-01, -6.7578e-01, -4.4922e-02,  ...,  7.9102e-02,
           1.9434e-01, -2.6719e+00],
         [-1.0938e+00, -1.5469e+00,  1.4844e-01,  ..., -5.3906e-01,
          -9.2163e-03,  1.0703e+00]],

        [[-1.7700e-02, -1.2329e-02,  1.7456e-02,  ...,  7.7637e-02,
           5.0537e-02, -1.0107e-01],
         [ 2.0000e+00,  1.4551e-01,  2.0469e+00,  ..., -4.6289e-01,
          -6.8359e-01, -7.5391e-01],
         [ 1.6562e+00,  1.2734e+00, -2.5977e-01,  ...,  2.4512e-01,
          -2.8594e+00,  1.6562e+00],
         ...,
         [-6.8359e-01, -1.5156e+00, -2.2852e-01,  ..., -1.6484e+00,
           5.1562e+00, -8.9453e-01],
         [ 2.0117e-01, -6.7578e-01, -4.4922e-02,  ...,  7.9102e-02,
           1.9434e-01, -2.6719e+00],
         [-1.0938e+00, -1.5469e+00,  1.4844e-01,  ..., -5.3906e-01,
          -9.2163e-03,  1.0703e+00]],

        [[-1.7700e-02, -1.2329e-02,  1.7456e-02,  ...,  7.7637e-02,
           5.0537e-02, -1.0107e-01],
         [ 2.0000e+00,  1.4551e-01,  2.0469e+00,  ..., -4.6289e-01,
          -6.8359e-01, -7.5391e-01],
         [ 1.6562e+00,  1.2734e+00, -2.5977e-01,  ...,  2.4512e-01,
          -2.8594e+00,  1.6562e+00],
         ...,
         [-6.8359e-01, -1.5156e+00, -2.2852e-01,  ..., -1.6484e+00,
           5.1562e+00, -8.9453e-01],
         [ 2.0117e-01, -6.7578e-01, -4.4922e-02,  ...,  7.9102e-02,
           1.9434e-01, -2.6719e+00],
         [-1.0938e+00, -1.5469e+00,  1.4844e-01,  ..., -5.3906e-01,
          -9.2163e-03,  1.0703e+00]],

        ...,

        [[-3.5706e-03,  6.9275e-03,  1.4221e-02,  ..., -1.7285e-01,
           1.7773e-01, -1.6968e-02],
         [ 2.4062e+00,  1.9062e+00,  2.1777e-01,  ..., -1.5156e+00,
          -1.3750e+00, -9.2578e-01],
         [ 9.4922e-01,  1.1484e+00,  3.0273e-01,  ..., -3.0938e+00,
          -2.1250e+00, -2.0938e+00],
         ...,
         [-2.1406e+00, -1.0547e+00, -2.0625e+00,  ...,  5.8984e-01,
           5.7422e-01,  1.9766e+00],
         [-3.4570e-01, -3.5156e-01, -7.7734e-01,  ...,  9.7656e-02,
           7.9688e-01,  2.7031e+00],
         [-6.9531e-01,  2.8516e-01, -1.9336e-01,  ..., -5.1953e-01,
          -1.6484e+00,  1.3047e+00]],

        [[-3.5706e-03,  6.9275e-03,  1.4221e-02,  ..., -1.7285e-01,
           1.7773e-01, -1.6968e-02],
         [ 2.4062e+00,  1.9062e+00,  2.1777e-01,  ..., -1.5156e+00,
          -1.3750e+00, -9.2578e-01],
         [ 9.4922e-01,  1.1484e+00,  3.0273e-01,  ..., -3.0938e+00,
          -2.1250e+00, -2.0938e+00],
         ...,
         [-2.1406e+00, -1.0547e+00, -2.0625e+00,  ...,  5.8984e-01,
           5.7422e-01,  1.9766e+00],
         [-3.4570e-01, -3.5156e-01, -7.7734e-01,  ...,  9.7656e-02,
           7.9688e-01,  2.7031e+00],
         [-6.9531e-01,  2.8516e-01, -1.9336e-01,  ..., -5.1953e-01,
          -1.6484e+00,  1.3047e+00]],

        [[-3.5706e-03,  6.9275e-03,  1.4221e-02,  ..., -1.7285e-01,
           1.7773e-01, -1.6968e-02],
         [ 2.4062e+00,  1.9062e+00,  2.1777e-01,  ..., -1.5156e+00,
          -1.3750e+00, -9.2578e-01],
         [ 9.4922e-01,  1.1484e+00,  3.0273e-01,  ..., -3.0938e+00,
          -2.1250e+00, -2.0938e+00],
         ...,
         [-2.1406e+00, -1.0547e+00, -2.0625e+00,  ...,  5.8984e-01,
           5.7422e-01,  1.9766e+00],
         [-3.4570e-01, -3.5156e-01, -7.7734e-01,  ...,  9.7656e-02,
           7.9688e-01,  2.7031e+00],
         [-6.9531e-01,  2.8516e-01, -1.9336e-01,  ..., -5.1953e-01,
          -1.6484e+00,  1.3047e+00]]], dtype=torch.bfloat16), tensor([[[ 0.0537, -0.0067, -0.0131,  ...,  0.0227, -0.0175,  0.0188],
         [-0.4023, -0.3379, -0.4727,  ..., -0.3730, -0.2930,  0.3906],
         [-0.2852, -0.7734,  0.0374,  ..., -0.4922, -0.1074,  0.3262],
         ...,
         [ 0.7656, -0.0801, -0.0025,  ...,  0.2656,  0.7109, -0.1611],
         [ 0.2539, -0.4785, -0.9531,  ..., -0.4551,  0.8828, -0.2119],
         [-0.3438, -0.1406, -0.0776,  ...,  0.0947,  0.0938,  0.3672]],

        [[ 0.0537, -0.0067, -0.0131,  ...,  0.0227, -0.0175,  0.0188],
         [-0.4023, -0.3379, -0.4727,  ..., -0.3730, -0.2930,  0.3906],
         [-0.2852, -0.7734,  0.0374,  ..., -0.4922, -0.1074,  0.3262],
         ...,
         [ 0.7656, -0.0801, -0.0025,  ...,  0.2656,  0.7109, -0.1611],
         [ 0.2539, -0.4785, -0.9531,  ..., -0.4551,  0.8828, -0.2119],
         [-0.3438, -0.1406, -0.0776,  ...,  0.0947,  0.0938,  0.3672]],

        [[ 0.0537, -0.0067, -0.0131,  ...,  0.0227, -0.0175,  0.0188],
         [-0.4023, -0.3379, -0.4727,  ..., -0.3730, -0.2930,  0.3906],
         [-0.2852, -0.7734,  0.0374,  ..., -0.4922, -0.1074,  0.3262],
         ...,
         [ 0.7656, -0.0801, -0.0025,  ...,  0.2656,  0.7109, -0.1611],
         [ 0.2539, -0.4785, -0.9531,  ..., -0.4551,  0.8828, -0.2119],
         [-0.3438, -0.1406, -0.0776,  ...,  0.0947,  0.0938,  0.3672]],

        ...,

        [[-0.0300, -0.0415, -0.0488,  ..., -0.0352,  0.0204,  0.0442],
         [-0.2051,  0.1680,  0.2500,  ..., -0.3320, -0.2949,  0.1670],
         [ 0.0203, -0.3398, -0.0371,  ..., -0.2354, -0.0206,  0.0171],
         ...,
         [ 0.3496, -0.0040,  0.5586,  ...,  0.6992,  0.4531,  0.1025],
         [ 0.0669, -0.3203,  0.1123,  ...,  0.5859, -1.0703, -0.2539],
         [ 0.2539,  0.2891,  0.6719,  ...,  0.6758,  0.2256,  0.3047]],

        [[-0.0300, -0.0415, -0.0488,  ..., -0.0352,  0.0204,  0.0442],
         [-0.2051,  0.1680,  0.2500,  ..., -0.3320, -0.2949,  0.1670],
         [ 0.0203, -0.3398, -0.0371,  ..., -0.2354, -0.0206,  0.0171],
         ...,
         [ 0.3496, -0.0040,  0.5586,  ...,  0.6992,  0.4531,  0.1025],
         [ 0.0669, -0.3203,  0.1123,  ...,  0.5859, -1.0703, -0.2539],
         [ 0.2539,  0.2891,  0.6719,  ...,  0.6758,  0.2256,  0.3047]],

        [[-0.0300, -0.0415, -0.0488,  ..., -0.0352,  0.0204,  0.0442],
         [-0.2051,  0.1680,  0.2500,  ..., -0.3320, -0.2949,  0.1670],
         [ 0.0203, -0.3398, -0.0371,  ..., -0.2354, -0.0206,  0.0171],
         ...,
         [ 0.3496, -0.0040,  0.5586,  ...,  0.6992,  0.4531,  0.1025],
         [ 0.0669, -0.3203,  0.1123,  ...,  0.5859, -1.0703, -0.2539],
         [ 0.2539,  0.2891,  0.6719,  ...,  0.6758,  0.2256,  0.3047]]],
       dtype=torch.bfloat16)), (tensor([[[-7.9956e-03, -3.5553e-03, -3.8147e-03,  ..., -1.2500e+00,
           1.5747e-02,  6.2012e-02],
         [ 7.0312e-01,  3.6719e-01,  1.5820e-01,  ...,  2.6172e-01,
          -5.8594e-01, -1.5625e-01],
         [-6.3281e-01,  7.7734e-01, -8.4229e-03,  ...,  1.1719e+00,
          -1.7109e+00, -2.7500e+00],
         ...,
         [-8.8867e-02, -3.4570e-01, -1.0859e+00,  ..., -1.3984e+00,
          -2.6719e+00,  7.9297e-01],
         [ 5.1172e-01, -3.7109e-01, -7.5781e-01,  ...,  4.8438e-01,
          -1.8438e+00, -8.3203e-01],
         [ 9.4141e-01, -8.5938e-01, -6.0938e-01,  ...,  2.9297e-01,
           1.5938e+00, -1.7578e+00]],

        [[-7.9956e-03, -3.5553e-03, -3.8147e-03,  ..., -1.2500e+00,
           1.5747e-02,  6.2012e-02],
         [ 7.0312e-01,  3.6719e-01,  1.5820e-01,  ...,  2.6172e-01,
          -5.8594e-01, -1.5625e-01],
         [-6.3281e-01,  7.7734e-01, -8.4229e-03,  ...,  1.1719e+00,
          -1.7109e+00, -2.7500e+00],
         ...,
         [-8.8867e-02, -3.4570e-01, -1.0859e+00,  ..., -1.3984e+00,
          -2.6719e+00,  7.9297e-01],
         [ 5.1172e-01, -3.7109e-01, -7.5781e-01,  ...,  4.8438e-01,
          -1.8438e+00, -8.3203e-01],
         [ 9.4141e-01, -8.5938e-01, -6.0938e-01,  ...,  2.9297e-01,
           1.5938e+00, -1.7578e+00]],

        [[-7.9956e-03, -3.5553e-03, -3.8147e-03,  ..., -1.2500e+00,
           1.5747e-02,  6.2012e-02],
         [ 7.0312e-01,  3.6719e-01,  1.5820e-01,  ...,  2.6172e-01,
          -5.8594e-01, -1.5625e-01],
         [-6.3281e-01,  7.7734e-01, -8.4229e-03,  ...,  1.1719e+00,
          -1.7109e+00, -2.7500e+00],
         ...,
         [-8.8867e-02, -3.4570e-01, -1.0859e+00,  ..., -1.3984e+00,
          -2.6719e+00,  7.9297e-01],
         [ 5.1172e-01, -3.7109e-01, -7.5781e-01,  ...,  4.8438e-01,
          -1.8438e+00, -8.3203e-01],
         [ 9.4141e-01, -8.5938e-01, -6.0938e-01,  ...,  2.9297e-01,
           1.5938e+00, -1.7578e+00]],

        ...,

        [[-1.8692e-03,  1.0300e-03, -7.0496e-03,  ...,  1.4375e+00,
          -2.7969e+00, -3.5156e-01],
         [ 1.7188e-01, -2.3281e+00, -9.9219e-01,  ..., -1.9219e+00,
           6.9375e+00,  1.6641e+00],
         [ 1.5527e-01, -3.7109e-01, -6.2500e-01,  ..., -2.3906e+00,
           7.8438e+00,  4.6562e+00],
         ...,
         [-1.1641e+00,  8.4375e-01, -1.6211e-01,  ..., -1.2891e+00,
           6.0938e+00, -4.3438e+00],
         [ 6.0547e-02, -1.7676e-01, -4.6875e-01,  ..., -9.0234e-01,
           4.4688e+00, -1.2344e+00],
         [-1.2266e+00,  1.5625e-02, -9.3750e-01,  ..., -3.4180e-01,
           5.8438e+00,  2.3281e+00]],

        [[-1.8692e-03,  1.0300e-03, -7.0496e-03,  ...,  1.4375e+00,
          -2.7969e+00, -3.5156e-01],
         [ 1.7188e-01, -2.3281e+00, -9.9219e-01,  ..., -1.9219e+00,
           6.9375e+00,  1.6641e+00],
         [ 1.5527e-01, -3.7109e-01, -6.2500e-01,  ..., -2.3906e+00,
           7.8438e+00,  4.6562e+00],
         ...,
         [-1.1641e+00,  8.4375e-01, -1.6211e-01,  ..., -1.2891e+00,
           6.0938e+00, -4.3438e+00],
         [ 6.0547e-02, -1.7676e-01, -4.6875e-01,  ..., -9.0234e-01,
           4.4688e+00, -1.2344e+00],
         [-1.2266e+00,  1.5625e-02, -9.3750e-01,  ..., -3.4180e-01,
           5.8438e+00,  2.3281e+00]],

        [[-1.8692e-03,  1.0300e-03, -7.0496e-03,  ...,  1.4375e+00,
          -2.7969e+00, -3.5156e-01],
         [ 1.7188e-01, -2.3281e+00, -9.9219e-01,  ..., -1.9219e+00,
           6.9375e+00,  1.6641e+00],
         [ 1.5527e-01, -3.7109e-01, -6.2500e-01,  ..., -2.3906e+00,
           7.8438e+00,  4.6562e+00],
         ...,
         [-1.1641e+00,  8.4375e-01, -1.6211e-01,  ..., -1.2891e+00,
           6.0938e+00, -4.3438e+00],
         [ 6.0547e-02, -1.7676e-01, -4.6875e-01,  ..., -9.0234e-01,
           4.4688e+00, -1.2344e+00],
         [-1.2266e+00,  1.5625e-02, -9.3750e-01,  ..., -3.4180e-01,
           5.8438e+00,  2.3281e+00]]], dtype=torch.bfloat16), tensor([[[ 0.0623,  0.0913, -0.0081,  ...,  0.0413, -0.0123,  0.0043],
         [-0.5742,  0.0535, -0.3262,  ...,  0.3320,  0.4824,  0.2266],
         [-0.7031, -0.4531,  0.2227,  ..., -0.0579,  0.4180,  0.1084],
         ...,
         [ 0.4590, -0.3926,  0.3984,  ...,  0.2559, -0.1855,  0.7461],
         [-0.0698, -0.4395,  0.8242,  ..., -0.0913,  0.0625,  0.7383],
         [ 0.2676, -0.3828, -1.4062,  ...,  0.6211,  0.8164, -0.4180]],

        [[ 0.0623,  0.0913, -0.0081,  ...,  0.0413, -0.0123,  0.0043],
         [-0.5742,  0.0535, -0.3262,  ...,  0.3320,  0.4824,  0.2266],
         [-0.7031, -0.4531,  0.2227,  ..., -0.0579,  0.4180,  0.1084],
         ...,
         [ 0.4590, -0.3926,  0.3984,  ...,  0.2559, -0.1855,  0.7461],
         [-0.0698, -0.4395,  0.8242,  ..., -0.0913,  0.0625,  0.7383],
         [ 0.2676, -0.3828, -1.4062,  ...,  0.6211,  0.8164, -0.4180]],

        [[ 0.0623,  0.0913, -0.0081,  ...,  0.0413, -0.0123,  0.0043],
         [-0.5742,  0.0535, -0.3262,  ...,  0.3320,  0.4824,  0.2266],
         [-0.7031, -0.4531,  0.2227,  ..., -0.0579,  0.4180,  0.1084],
         ...,
         [ 0.4590, -0.3926,  0.3984,  ...,  0.2559, -0.1855,  0.7461],
         [-0.0698, -0.4395,  0.8242,  ..., -0.0913,  0.0625,  0.7383],
         [ 0.2676, -0.3828, -1.4062,  ...,  0.6211,  0.8164, -0.4180]],

        ...,

        [[-0.0090,  0.0481,  0.0194,  ...,  0.0170,  0.0413,  0.0540],
         [-0.3809, -0.5859,  0.0077,  ...,  0.2715,  0.6562, -0.1631],
         [-0.8516, -0.4805, -0.6523,  ..., -0.4141,  0.1099, -0.4941],
         ...,
         [ 0.7930, -0.0747, -0.0684,  ..., -0.5898,  0.0684, -0.4922],
         [ 1.3281,  0.2334, -0.2139,  ..., -0.4824,  0.0820, -0.4980],
         [-0.2598, -0.6133, -0.6797,  ...,  0.3457, -0.5742,  0.2832]],

        [[-0.0090,  0.0481,  0.0194,  ...,  0.0170,  0.0413,  0.0540],
         [-0.3809, -0.5859,  0.0077,  ...,  0.2715,  0.6562, -0.1631],
         [-0.8516, -0.4805, -0.6523,  ..., -0.4141,  0.1099, -0.4941],
         ...,
         [ 0.7930, -0.0747, -0.0684,  ..., -0.5898,  0.0684, -0.4922],
         [ 1.3281,  0.2334, -0.2139,  ..., -0.4824,  0.0820, -0.4980],
         [-0.2598, -0.6133, -0.6797,  ...,  0.3457, -0.5742,  0.2832]],

        [[-0.0090,  0.0481,  0.0194,  ...,  0.0170,  0.0413,  0.0540],
         [-0.3809, -0.5859,  0.0077,  ...,  0.2715,  0.6562, -0.1631],
         [-0.8516, -0.4805, -0.6523,  ..., -0.4141,  0.1099, -0.4941],
         ...,
         [ 0.7930, -0.0747, -0.0684,  ..., -0.5898,  0.0684, -0.4922],
         [ 1.3281,  0.2334, -0.2139,  ..., -0.4824,  0.0820, -0.4980],
         [-0.2598, -0.6133, -0.6797,  ...,  0.3457, -0.5742,  0.2832]]],
       dtype=torch.bfloat16)), (tensor([[[-7.8125e-03,  3.9673e-03, -4.2419e-03,  ..., -6.7188e-01,
          -1.7090e-01,  9.5215e-02],
         [ 1.6406e+00,  9.9609e-01,  1.7422e+00,  ..., -1.0781e+00,
           1.0469e+00,  1.4141e+00],
         [ 9.2578e-01,  9.8047e-01,  1.8828e+00,  ..., -2.4062e+00,
          -1.1328e+00,  2.3750e+00],
         ...,
         [-5.9375e-01,  8.7891e-02,  1.4219e+00,  ...,  4.4189e-02,
           5.5469e-01,  8.5938e-01],
         [-2.6562e-01, -4.3555e-01,  4.1406e-01,  ..., -3.9453e-01,
          -1.3359e+00,  1.3359e+00],
         [-1.7500e+00, -2.9492e-01,  1.0000e+00,  ..., -1.2158e-01,
           1.3516e+00,  7.1484e-01]],

        [[-7.8125e-03,  3.9673e-03, -4.2419e-03,  ..., -6.7188e-01,
          -1.7090e-01,  9.5215e-02],
         [ 1.6406e+00,  9.9609e-01,  1.7422e+00,  ..., -1.0781e+00,
           1.0469e+00,  1.4141e+00],
         [ 9.2578e-01,  9.8047e-01,  1.8828e+00,  ..., -2.4062e+00,
          -1.1328e+00,  2.3750e+00],
         ...,
         [-5.9375e-01,  8.7891e-02,  1.4219e+00,  ...,  4.4189e-02,
           5.5469e-01,  8.5938e-01],
         [-2.6562e-01, -4.3555e-01,  4.1406e-01,  ..., -3.9453e-01,
          -1.3359e+00,  1.3359e+00],
         [-1.7500e+00, -2.9492e-01,  1.0000e+00,  ..., -1.2158e-01,
           1.3516e+00,  7.1484e-01]],

        [[-7.8125e-03,  3.9673e-03, -4.2419e-03,  ..., -6.7188e-01,
          -1.7090e-01,  9.5215e-02],
         [ 1.6406e+00,  9.9609e-01,  1.7422e+00,  ..., -1.0781e+00,
           1.0469e+00,  1.4141e+00],
         [ 9.2578e-01,  9.8047e-01,  1.8828e+00,  ..., -2.4062e+00,
          -1.1328e+00,  2.3750e+00],
         ...,
         [-5.9375e-01,  8.7891e-02,  1.4219e+00,  ...,  4.4189e-02,
           5.5469e-01,  8.5938e-01],
         [-2.6562e-01, -4.3555e-01,  4.1406e-01,  ..., -3.9453e-01,
          -1.3359e+00,  1.3359e+00],
         [-1.7500e+00, -2.9492e-01,  1.0000e+00,  ..., -1.2158e-01,
           1.3516e+00,  7.1484e-01]],

        ...,

        [[ 1.5198e-02,  1.0864e-02, -6.3477e-03,  ...,  1.2146e-02,
           1.2451e-02,  2.3281e+00],
         [ 8.3984e-01,  2.2344e+00,  1.0391e+00,  ...,  3.3398e-01,
          -3.8438e+00, -7.8438e+00],
         [-2.1484e-01,  1.6562e+00,  1.2695e-01,  ...,  5.5859e-01,
          -1.9219e+00, -1.0062e+01],
         ...,
         [-7.3438e-01,  1.1719e-01, -9.9609e-01,  ...,  9.6875e-01,
           1.6211e-01, -4.9062e+00],
         [ 9.7656e-03,  4.2383e-01,  1.4258e-01,  ...,  1.5312e+00,
          -6.5234e-01, -5.3438e+00],
         [ 1.2969e+00,  2.6953e-01,  1.4531e+00,  ...,  3.4570e-01,
          -2.7656e+00, -5.7500e+00]],

        [[ 1.5198e-02,  1.0864e-02, -6.3477e-03,  ...,  1.2146e-02,
           1.2451e-02,  2.3281e+00],
         [ 8.3984e-01,  2.2344e+00,  1.0391e+00,  ...,  3.3398e-01,
          -3.8438e+00, -7.8438e+00],
         [-2.1484e-01,  1.6562e+00,  1.2695e-01,  ...,  5.5859e-01,
          -1.9219e+00, -1.0062e+01],
         ...,
         [-7.3438e-01,  1.1719e-01, -9.9609e-01,  ...,  9.6875e-01,
           1.6211e-01, -4.9062e+00],
         [ 9.7656e-03,  4.2383e-01,  1.4258e-01,  ...,  1.5312e+00,
          -6.5234e-01, -5.3438e+00],
         [ 1.2969e+00,  2.6953e-01,  1.4531e+00,  ...,  3.4570e-01,
          -2.7656e+00, -5.7500e+00]],

        [[ 1.5198e-02,  1.0864e-02, -6.3477e-03,  ...,  1.2146e-02,
           1.2451e-02,  2.3281e+00],
         [ 8.3984e-01,  2.2344e+00,  1.0391e+00,  ...,  3.3398e-01,
          -3.8438e+00, -7.8438e+00],
         [-2.1484e-01,  1.6562e+00,  1.2695e-01,  ...,  5.5859e-01,
          -1.9219e+00, -1.0062e+01],
         ...,
         [-7.3438e-01,  1.1719e-01, -9.9609e-01,  ...,  9.6875e-01,
           1.6211e-01, -4.9062e+00],
         [ 9.7656e-03,  4.2383e-01,  1.4258e-01,  ...,  1.5312e+00,
          -6.5234e-01, -5.3438e+00],
         [ 1.2969e+00,  2.6953e-01,  1.4531e+00,  ...,  3.4570e-01,
          -2.7656e+00, -5.7500e+00]]], dtype=torch.bfloat16), tensor([[[-1.0437e-02,  4.6730e-04,  1.5869e-02,  ...,  4.7607e-02,
          -4.8584e-02, -2.1118e-02],
         [-4.8438e-01,  6.2256e-02,  3.7500e-01,  ..., -1.2329e-02,
           6.8054e-03,  4.7461e-01],
         [ 2.6367e-01, -3.1641e-01,  3.7109e-01,  ..., -4.0039e-02,
           4.7461e-01,  4.8242e-01],
         ...,
         [ 1.3184e-01,  4.9023e-01,  6.1328e-01,  ..., -6.9141e-01,
          -6.6797e-01,  6.2256e-02],
         [ 1.6504e-01,  2.4048e-02, -4.0820e-01,  ..., -2.3242e-01,
           8.3203e-01, -3.0664e-01],
         [-9.4238e-02,  1.7188e-01, -1.5820e-01,  ...,  2.5586e-01,
          -3.6914e-01,  1.0596e-01]],

        [[-1.0437e-02,  4.6730e-04,  1.5869e-02,  ...,  4.7607e-02,
          -4.8584e-02, -2.1118e-02],
         [-4.8438e-01,  6.2256e-02,  3.7500e-01,  ..., -1.2329e-02,
           6.8054e-03,  4.7461e-01],
         [ 2.6367e-01, -3.1641e-01,  3.7109e-01,  ..., -4.0039e-02,
           4.7461e-01,  4.8242e-01],
         ...,
         [ 1.3184e-01,  4.9023e-01,  6.1328e-01,  ..., -6.9141e-01,
          -6.6797e-01,  6.2256e-02],
         [ 1.6504e-01,  2.4048e-02, -4.0820e-01,  ..., -2.3242e-01,
           8.3203e-01, -3.0664e-01],
         [-9.4238e-02,  1.7188e-01, -1.5820e-01,  ...,  2.5586e-01,
          -3.6914e-01,  1.0596e-01]],

        [[-1.0437e-02,  4.6730e-04,  1.5869e-02,  ...,  4.7607e-02,
          -4.8584e-02, -2.1118e-02],
         [-4.8438e-01,  6.2256e-02,  3.7500e-01,  ..., -1.2329e-02,
           6.8054e-03,  4.7461e-01],
         [ 2.6367e-01, -3.1641e-01,  3.7109e-01,  ..., -4.0039e-02,
           4.7461e-01,  4.8242e-01],
         ...,
         [ 1.3184e-01,  4.9023e-01,  6.1328e-01,  ..., -6.9141e-01,
          -6.6797e-01,  6.2256e-02],
         [ 1.6504e-01,  2.4048e-02, -4.0820e-01,  ..., -2.3242e-01,
           8.3203e-01, -3.0664e-01],
         [-9.4238e-02,  1.7188e-01, -1.5820e-01,  ...,  2.5586e-01,
          -3.6914e-01,  1.0596e-01]],

        ...,

        [[-1.8799e-02,  1.3245e-02, -1.6602e-02,  ..., -1.1292e-02,
           1.3489e-02,  2.1820e-03],
         [ 2.6758e-01, -1.4062e-01, -9.5215e-02,  ..., -1.2500e-01,
           4.9219e-01,  2.7930e-01],
         [-3.2422e-01,  1.4453e-01, -4.1211e-01,  ...,  3.0078e-01,
           3.6328e-01, -1.3672e-01],
         ...,
         [ 3.3008e-01,  3.2031e-01,  9.5703e-02,  ..., -3.2617e-01,
           6.5234e-01, -8.2031e-01],
         [ 1.0400e-01, -5.0781e-01, -5.6250e-01,  ...,  1.3086e-01,
          -1.4941e-01, -3.8818e-02],
         [ 7.1777e-02, -4.2969e-01, -5.5859e-01,  ..., -3.1250e-02,
          -4.4531e-01,  2.2852e-01]],

        [[-1.8799e-02,  1.3245e-02, -1.6602e-02,  ..., -1.1292e-02,
           1.3489e-02,  2.1820e-03],
         [ 2.6758e-01, -1.4062e-01, -9.5215e-02,  ..., -1.2500e-01,
           4.9219e-01,  2.7930e-01],
         [-3.2422e-01,  1.4453e-01, -4.1211e-01,  ...,  3.0078e-01,
           3.6328e-01, -1.3672e-01],
         ...,
         [ 3.3008e-01,  3.2031e-01,  9.5703e-02,  ..., -3.2617e-01,
           6.5234e-01, -8.2031e-01],
         [ 1.0400e-01, -5.0781e-01, -5.6250e-01,  ...,  1.3086e-01,
          -1.4941e-01, -3.8818e-02],
         [ 7.1777e-02, -4.2969e-01, -5.5859e-01,  ..., -3.1250e-02,
          -4.4531e-01,  2.2852e-01]],

        [[-1.8799e-02,  1.3245e-02, -1.6602e-02,  ..., -1.1292e-02,
           1.3489e-02,  2.1820e-03],
         [ 2.6758e-01, -1.4062e-01, -9.5215e-02,  ..., -1.2500e-01,
           4.9219e-01,  2.7930e-01],
         [-3.2422e-01,  1.4453e-01, -4.1211e-01,  ...,  3.0078e-01,
           3.6328e-01, -1.3672e-01],
         ...,
         [ 3.3008e-01,  3.2031e-01,  9.5703e-02,  ..., -3.2617e-01,
           6.5234e-01, -8.2031e-01],
         [ 1.0400e-01, -5.0781e-01, -5.6250e-01,  ...,  1.3086e-01,
          -1.4941e-01, -3.8818e-02],
         [ 7.1777e-02, -4.2969e-01, -5.5859e-01,  ..., -3.1250e-02,
          -4.4531e-01,  2.2852e-01]]], dtype=torch.bfloat16)), (tensor([[[-2.4567e-03, -9.7046e-03,  2.9449e-03,  ..., -2.8711e-01,
          -1.3770e-01, -1.1230e-01],
         [-1.0078e+00,  2.8711e-01, -1.0469e+00,  ..., -1.1562e+00,
           3.9688e+00, -3.5156e-01],
         [ 3.0469e-01,  6.7578e-01, -3.7109e-01,  ...,  1.1230e-01,
           7.0703e-01,  2.5625e+00],
         ...,
         [ 1.1641e+00,  6.8750e-01, -3.0664e-01,  ..., -3.0078e-01,
          -1.2969e+00, -1.3594e+00],
         [ 4.4922e-02,  7.8125e-01,  4.9805e-02,  ..., -6.8359e-01,
           9.1016e-01,  2.6758e-01],
         [-7.7734e-01,  1.1406e+00, -8.6719e-01,  ..., -3.7344e+00,
           1.8594e+00,  1.8906e+00]],

        [[-2.4567e-03, -9.7046e-03,  2.9449e-03,  ..., -2.8711e-01,
          -1.3770e-01, -1.1230e-01],
         [-1.0078e+00,  2.8711e-01, -1.0469e+00,  ..., -1.1562e+00,
           3.9688e+00, -3.5156e-01],
         [ 3.0469e-01,  6.7578e-01, -3.7109e-01,  ...,  1.1230e-01,
           7.0703e-01,  2.5625e+00],
         ...,
         [ 1.1641e+00,  6.8750e-01, -3.0664e-01,  ..., -3.0078e-01,
          -1.2969e+00, -1.3594e+00],
         [ 4.4922e-02,  7.8125e-01,  4.9805e-02,  ..., -6.8359e-01,
           9.1016e-01,  2.6758e-01],
         [-7.7734e-01,  1.1406e+00, -8.6719e-01,  ..., -3.7344e+00,
           1.8594e+00,  1.8906e+00]],

        [[-2.4567e-03, -9.7046e-03,  2.9449e-03,  ..., -2.8711e-01,
          -1.3770e-01, -1.1230e-01],
         [-1.0078e+00,  2.8711e-01, -1.0469e+00,  ..., -1.1562e+00,
           3.9688e+00, -3.5156e-01],
         [ 3.0469e-01,  6.7578e-01, -3.7109e-01,  ...,  1.1230e-01,
           7.0703e-01,  2.5625e+00],
         ...,
         [ 1.1641e+00,  6.8750e-01, -3.0664e-01,  ..., -3.0078e-01,
          -1.2969e+00, -1.3594e+00],
         [ 4.4922e-02,  7.8125e-01,  4.9805e-02,  ..., -6.8359e-01,
           9.1016e-01,  2.6758e-01],
         [-7.7734e-01,  1.1406e+00, -8.6719e-01,  ..., -3.7344e+00,
           1.8594e+00,  1.8906e+00]],

        ...,

        [[ 3.6163e-03, -1.3794e-02,  1.1597e-02,  ...,  3.6914e-01,
           7.5684e-02,  1.3184e-01],
         [ 7.3047e-01,  1.1328e+00,  1.2109e+00,  ..., -3.8672e-01,
          -5.5469e-01, -7.9688e-01],
         [ 6.7578e-01, -7.4707e-02,  8.4766e-01,  ...,  3.9844e+00,
          -1.8125e+00,  1.0000e+00],
         ...,
         [-1.2969e+00, -1.0938e+00,  1.0000e+00,  ..., -2.0781e+00,
           2.3750e+00,  1.3672e+00],
         [ 8.5938e-02, -3.5156e-01, -2.2070e-01,  ..., -2.5625e+00,
           8.8672e-01,  1.5234e+00],
         [-1.0000e+00,  1.3594e+00,  2.1875e-01,  ..., -4.6875e-01,
           2.0801e-01, -9.8438e-01]],

        [[ 3.6163e-03, -1.3794e-02,  1.1597e-02,  ...,  3.6914e-01,
           7.5684e-02,  1.3184e-01],
         [ 7.3047e-01,  1.1328e+00,  1.2109e+00,  ..., -3.8672e-01,
          -5.5469e-01, -7.9688e-01],
         [ 6.7578e-01, -7.4707e-02,  8.4766e-01,  ...,  3.9844e+00,
          -1.8125e+00,  1.0000e+00],
         ...,
         [-1.2969e+00, -1.0938e+00,  1.0000e+00,  ..., -2.0781e+00,
           2.3750e+00,  1.3672e+00],
         [ 8.5938e-02, -3.5156e-01, -2.2070e-01,  ..., -2.5625e+00,
           8.8672e-01,  1.5234e+00],
         [-1.0000e+00,  1.3594e+00,  2.1875e-01,  ..., -4.6875e-01,
           2.0801e-01, -9.8438e-01]],

        [[ 3.6163e-03, -1.3794e-02,  1.1597e-02,  ...,  3.6914e-01,
           7.5684e-02,  1.3184e-01],
         [ 7.3047e-01,  1.1328e+00,  1.2109e+00,  ..., -3.8672e-01,
          -5.5469e-01, -7.9688e-01],
         [ 6.7578e-01, -7.4707e-02,  8.4766e-01,  ...,  3.9844e+00,
          -1.8125e+00,  1.0000e+00],
         ...,
         [-1.2969e+00, -1.0938e+00,  1.0000e+00,  ..., -2.0781e+00,
           2.3750e+00,  1.3672e+00],
         [ 8.5938e-02, -3.5156e-01, -2.2070e-01,  ..., -2.5625e+00,
           8.8672e-01,  1.5234e+00],
         [-1.0000e+00,  1.3594e+00,  2.1875e-01,  ..., -4.6875e-01,
           2.0801e-01, -9.8438e-01]]], dtype=torch.bfloat16), tensor([[[-1.9775e-02,  3.9551e-02, -2.7344e-02,  ..., -4.4434e-02,
          -7.5684e-02,  1.7944e-02],
         [ 2.5781e-01, -4.1797e-01,  1.6406e-01,  ...,  6.4453e-01,
          -3.4766e-01, -4.9414e-01],
         [-5.0781e-01,  4.7070e-01, -2.5781e-01,  ...,  7.9297e-01,
          -6.0547e-01,  4.2969e-02],
         ...,
         [-6.4844e-01,  9.4727e-02, -5.2979e-02,  ..., -6.2109e-01,
          -1.9043e-01, -6.0547e-01],
         [-8.9844e-01, -3.6914e-01, -2.8906e-01,  ..., -3.0078e-01,
          -1.9922e-01, -9.9609e-01],
         [-2.8125e-01,  1.6406e-01, -5.0391e-01,  ...,  1.8457e-01,
          -7.3242e-02,  1.6406e-01]],

        [[-1.9775e-02,  3.9551e-02, -2.7344e-02,  ..., -4.4434e-02,
          -7.5684e-02,  1.7944e-02],
         [ 2.5781e-01, -4.1797e-01,  1.6406e-01,  ...,  6.4453e-01,
          -3.4766e-01, -4.9414e-01],
         [-5.0781e-01,  4.7070e-01, -2.5781e-01,  ...,  7.9297e-01,
          -6.0547e-01,  4.2969e-02],
         ...,
         [-6.4844e-01,  9.4727e-02, -5.2979e-02,  ..., -6.2109e-01,
          -1.9043e-01, -6.0547e-01],
         [-8.9844e-01, -3.6914e-01, -2.8906e-01,  ..., -3.0078e-01,
          -1.9922e-01, -9.9609e-01],
         [-2.8125e-01,  1.6406e-01, -5.0391e-01,  ...,  1.8457e-01,
          -7.3242e-02,  1.6406e-01]],

        [[-1.9775e-02,  3.9551e-02, -2.7344e-02,  ..., -4.4434e-02,
          -7.5684e-02,  1.7944e-02],
         [ 2.5781e-01, -4.1797e-01,  1.6406e-01,  ...,  6.4453e-01,
          -3.4766e-01, -4.9414e-01],
         [-5.0781e-01,  4.7070e-01, -2.5781e-01,  ...,  7.9297e-01,
          -6.0547e-01,  4.2969e-02],
         ...,
         [-6.4844e-01,  9.4727e-02, -5.2979e-02,  ..., -6.2109e-01,
          -1.9043e-01, -6.0547e-01],
         [-8.9844e-01, -3.6914e-01, -2.8906e-01,  ..., -3.0078e-01,
          -1.9922e-01, -9.9609e-01],
         [-2.8125e-01,  1.6406e-01, -5.0391e-01,  ...,  1.8457e-01,
          -7.3242e-02,  1.6406e-01]],

        ...,

        [[ 4.2725e-03, -1.0498e-02, -1.6357e-02,  ..., -2.4261e-03,
          -2.6489e-02,  8.6975e-04],
         [-3.2031e-01, -3.2617e-01,  8.8379e-02,  ..., -8.1250e-01,
           1.6504e-01,  4.5508e-01],
         [-2.8711e-01, -3.9453e-01,  1.5137e-01,  ...,  5.1514e-02,
           6.3965e-02,  1.4746e-01],
         ...,
         [-3.2812e-01,  9.0332e-02,  2.4536e-02,  ...,  7.3828e-01,
           5.6641e-01,  1.2891e-01],
         [-7.0312e-01,  1.7383e-01,  4.1602e-01,  ..., -5.8594e-02,
           8.0469e-01,  2.6758e-01],
         [ 2.3315e-02, -4.7607e-03, -2.8125e-01,  ...,  1.9238e-01,
           6.2500e-01,  9.6680e-02]],

        [[ 4.2725e-03, -1.0498e-02, -1.6357e-02,  ..., -2.4261e-03,
          -2.6489e-02,  8.6975e-04],
         [-3.2031e-01, -3.2617e-01,  8.8379e-02,  ..., -8.1250e-01,
           1.6504e-01,  4.5508e-01],
         [-2.8711e-01, -3.9453e-01,  1.5137e-01,  ...,  5.1514e-02,
           6.3965e-02,  1.4746e-01],
         ...,
         [-3.2812e-01,  9.0332e-02,  2.4536e-02,  ...,  7.3828e-01,
           5.6641e-01,  1.2891e-01],
         [-7.0312e-01,  1.7383e-01,  4.1602e-01,  ..., -5.8594e-02,
           8.0469e-01,  2.6758e-01],
         [ 2.3315e-02, -4.7607e-03, -2.8125e-01,  ...,  1.9238e-01,
           6.2500e-01,  9.6680e-02]],

        [[ 4.2725e-03, -1.0498e-02, -1.6357e-02,  ..., -2.4261e-03,
          -2.6489e-02,  8.6975e-04],
         [-3.2031e-01, -3.2617e-01,  8.8379e-02,  ..., -8.1250e-01,
           1.6504e-01,  4.5508e-01],
         [-2.8711e-01, -3.9453e-01,  1.5137e-01,  ...,  5.1514e-02,
           6.3965e-02,  1.4746e-01],
         ...,
         [-3.2812e-01,  9.0332e-02,  2.4536e-02,  ...,  7.3828e-01,
           5.6641e-01,  1.2891e-01],
         [-7.0312e-01,  1.7383e-01,  4.1602e-01,  ..., -5.8594e-02,
           8.0469e-01,  2.6758e-01],
         [ 2.3315e-02, -4.7607e-03, -2.8125e-01,  ...,  1.9238e-01,
           6.2500e-01,  9.6680e-02]]], dtype=torch.bfloat16)), (tensor([[[ 2.0504e-04, -5.9509e-04, -8.3618e-03,  ...,  3.5706e-03,
           7.7637e-02, -1.4258e-01],
         [-1.5547e+00, -6.5234e-01, -7.3047e-01,  ...,  2.0312e+00,
          -6.8750e-01,  6.7578e-01],
         [ 6.4844e-01, -6.2500e-01, -6.1328e-01,  ..., -7.3047e-01,
          -1.5078e+00, -6.0547e-01],
         ...,
         [ 1.2031e+00,  7.3438e-01, -6.9531e-01,  ...,  9.8828e-01,
           2.0000e+00,  2.5312e+00],
         [-3.1250e-01,  2.0605e-01, -3.0469e-01,  ..., -6.2988e-02,
           4.4141e-01,  1.4453e+00],
         [-8.6719e-01,  3.8477e-01, -8.6719e-01,  ...,  1.1016e+00,
          -1.4922e+00,  8.3984e-01]],

        [[ 2.0504e-04, -5.9509e-04, -8.3618e-03,  ...,  3.5706e-03,
           7.7637e-02, -1.4258e-01],
         [-1.5547e+00, -6.5234e-01, -7.3047e-01,  ...,  2.0312e+00,
          -6.8750e-01,  6.7578e-01],
         [ 6.4844e-01, -6.2500e-01, -6.1328e-01,  ..., -7.3047e-01,
          -1.5078e+00, -6.0547e-01],
         ...,
         [ 1.2031e+00,  7.3438e-01, -6.9531e-01,  ...,  9.8828e-01,
           2.0000e+00,  2.5312e+00],
         [-3.1250e-01,  2.0605e-01, -3.0469e-01,  ..., -6.2988e-02,
           4.4141e-01,  1.4453e+00],
         [-8.6719e-01,  3.8477e-01, -8.6719e-01,  ...,  1.1016e+00,
          -1.4922e+00,  8.3984e-01]],

        [[ 2.0504e-04, -5.9509e-04, -8.3618e-03,  ...,  3.5706e-03,
           7.7637e-02, -1.4258e-01],
         [-1.5547e+00, -6.5234e-01, -7.3047e-01,  ...,  2.0312e+00,
          -6.8750e-01,  6.7578e-01],
         [ 6.4844e-01, -6.2500e-01, -6.1328e-01,  ..., -7.3047e-01,
          -1.5078e+00, -6.0547e-01],
         ...,
         [ 1.2031e+00,  7.3438e-01, -6.9531e-01,  ...,  9.8828e-01,
           2.0000e+00,  2.5312e+00],
         [-3.1250e-01,  2.0605e-01, -3.0469e-01,  ..., -6.2988e-02,
           4.4141e-01,  1.4453e+00],
         [-8.6719e-01,  3.8477e-01, -8.6719e-01,  ...,  1.1016e+00,
          -1.4922e+00,  8.3984e-01]],

        ...,

        [[-8.8882e-04, -6.9275e-03, -3.6926e-03,  ..., -2.1094e-01,
          -4.8584e-02, -6.6406e-02],
         [ 7.1875e-01,  4.1016e-01, -6.3281e-01,  ...,  2.8438e+00,
          -3.8906e+00, -7.0703e-01],
         [ 1.3086e-01, -2.4805e-01, -8.0859e-01,  ...,  2.0000e+00,
          -4.7812e+00, -3.4531e+00],
         ...,
         [ 5.3711e-02,  2.3438e-01, -3.9062e-02,  ..., -3.8750e+00,
          -1.5000e+00, -1.4453e-01],
         [-2.4219e-01, -2.3145e-01, -2.9102e-01,  ..., -3.5156e+00,
           2.3281e+00, -1.8203e+00],
         [-3.3691e-02, -2.7832e-02, -2.0898e-01,  ...,  4.0312e+00,
           1.5391e+00,  1.8359e-01]],

        [[-8.8882e-04, -6.9275e-03, -3.6926e-03,  ..., -2.1094e-01,
          -4.8584e-02, -6.6406e-02],
         [ 7.1875e-01,  4.1016e-01, -6.3281e-01,  ...,  2.8438e+00,
          -3.8906e+00, -7.0703e-01],
         [ 1.3086e-01, -2.4805e-01, -8.0859e-01,  ...,  2.0000e+00,
          -4.7812e+00, -3.4531e+00],
         ...,
         [ 5.3711e-02,  2.3438e-01, -3.9062e-02,  ..., -3.8750e+00,
          -1.5000e+00, -1.4453e-01],
         [-2.4219e-01, -2.3145e-01, -2.9102e-01,  ..., -3.5156e+00,
           2.3281e+00, -1.8203e+00],
         [-3.3691e-02, -2.7832e-02, -2.0898e-01,  ...,  4.0312e+00,
           1.5391e+00,  1.8359e-01]],

        [[-8.8882e-04, -6.9275e-03, -3.6926e-03,  ..., -2.1094e-01,
          -4.8584e-02, -6.6406e-02],
         [ 7.1875e-01,  4.1016e-01, -6.3281e-01,  ...,  2.8438e+00,
          -3.8906e+00, -7.0703e-01],
         [ 1.3086e-01, -2.4805e-01, -8.0859e-01,  ...,  2.0000e+00,
          -4.7812e+00, -3.4531e+00],
         ...,
         [ 5.3711e-02,  2.3438e-01, -3.9062e-02,  ..., -3.8750e+00,
          -1.5000e+00, -1.4453e-01],
         [-2.4219e-01, -2.3145e-01, -2.9102e-01,  ..., -3.5156e+00,
           2.3281e+00, -1.8203e+00],
         [-3.3691e-02, -2.7832e-02, -2.0898e-01,  ...,  4.0312e+00,
           1.5391e+00,  1.8359e-01]]], dtype=torch.bfloat16), tensor([[[-4.9210e-04, -2.6398e-03,  5.1270e-03,  ...,  1.7929e-03,
           4.4556e-03,  9.5825e-03],
         [-7.7637e-02, -2.2949e-01, -3.5742e-01,  ...,  1.1597e-02,
          -8.4766e-01, -5.3906e-01],
         [-3.6719e-01,  3.7500e-01,  4.3164e-01,  ..., -1.7188e-01,
          -4.4922e-01,  4.8438e-01],
         ...,
         [-4.8242e-01,  2.4023e-01, -1.5430e-01,  ...,  5.6250e-01,
           4.3945e-01, -5.1953e-01],
         [-3.6133e-01, -6.0059e-02,  1.7480e-01,  ...,  6.4062e-01,
          -2.8906e-01, -3.7500e-01],
         [-5.4688e-01,  2.2363e-01,  1.1572e-01,  ..., -4.1992e-01,
           4.5312e-01,  2.9688e-01]],

        [[-4.9210e-04, -2.6398e-03,  5.1270e-03,  ...,  1.7929e-03,
           4.4556e-03,  9.5825e-03],
         [-7.7637e-02, -2.2949e-01, -3.5742e-01,  ...,  1.1597e-02,
          -8.4766e-01, -5.3906e-01],
         [-3.6719e-01,  3.7500e-01,  4.3164e-01,  ..., -1.7188e-01,
          -4.4922e-01,  4.8438e-01],
         ...,
         [-4.8242e-01,  2.4023e-01, -1.5430e-01,  ...,  5.6250e-01,
           4.3945e-01, -5.1953e-01],
         [-3.6133e-01, -6.0059e-02,  1.7480e-01,  ...,  6.4062e-01,
          -2.8906e-01, -3.7500e-01],
         [-5.4688e-01,  2.2363e-01,  1.1572e-01,  ..., -4.1992e-01,
           4.5312e-01,  2.9688e-01]],

        [[-4.9210e-04, -2.6398e-03,  5.1270e-03,  ...,  1.7929e-03,
           4.4556e-03,  9.5825e-03],
         [-7.7637e-02, -2.2949e-01, -3.5742e-01,  ...,  1.1597e-02,
          -8.4766e-01, -5.3906e-01],
         [-3.6719e-01,  3.7500e-01,  4.3164e-01,  ..., -1.7188e-01,
          -4.4922e-01,  4.8438e-01],
         ...,
         [-4.8242e-01,  2.4023e-01, -1.5430e-01,  ...,  5.6250e-01,
           4.3945e-01, -5.1953e-01],
         [-3.6133e-01, -6.0059e-02,  1.7480e-01,  ...,  6.4062e-01,
          -2.8906e-01, -3.7500e-01],
         [-5.4688e-01,  2.2363e-01,  1.1572e-01,  ..., -4.1992e-01,
           4.5312e-01,  2.9688e-01]],

        ...,

        [[-1.7456e-02,  5.1514e-02, -1.7262e-04,  ...,  2.8968e-05,
           2.8687e-03,  1.8066e-02],
         [-1.3125e+00,  4.2969e-02,  1.1719e-01,  ...,  7.4219e-01,
          -8.2812e-01, -4.1602e-01],
         [-8.3984e-01, -2.1406e+00, -1.8828e+00,  ..., -1.0312e+00,
          -6.6797e-01,  1.4355e-01],
         ...,
         [ 6.3672e-01, -1.2344e+00, -4.1211e-01,  ...,  3.9368e-03,
           5.5176e-02,  5.3125e-01],
         [ 2.6172e-01, -4.1260e-02,  7.6660e-02,  ...,  3.4961e-01,
           6.6797e-01,  2.0630e-02],
         [-4.4531e-01, -5.5859e-01, -1.0938e+00,  ...,  7.1875e-01,
          -3.8672e-01, -2.6758e-01]],

        [[-1.7456e-02,  5.1514e-02, -1.7262e-04,  ...,  2.8968e-05,
           2.8687e-03,  1.8066e-02],
         [-1.3125e+00,  4.2969e-02,  1.1719e-01,  ...,  7.4219e-01,
          -8.2812e-01, -4.1602e-01],
         [-8.3984e-01, -2.1406e+00, -1.8828e+00,  ..., -1.0312e+00,
          -6.6797e-01,  1.4355e-01],
         ...,
         [ 6.3672e-01, -1.2344e+00, -4.1211e-01,  ...,  3.9368e-03,
           5.5176e-02,  5.3125e-01],
         [ 2.6172e-01, -4.1260e-02,  7.6660e-02,  ...,  3.4961e-01,
           6.6797e-01,  2.0630e-02],
         [-4.4531e-01, -5.5859e-01, -1.0938e+00,  ...,  7.1875e-01,
          -3.8672e-01, -2.6758e-01]],

        [[-1.7456e-02,  5.1514e-02, -1.7262e-04,  ...,  2.8968e-05,
           2.8687e-03,  1.8066e-02],
         [-1.3125e+00,  4.2969e-02,  1.1719e-01,  ...,  7.4219e-01,
          -8.2812e-01, -4.1602e-01],
         [-8.3984e-01, -2.1406e+00, -1.8828e+00,  ..., -1.0312e+00,
          -6.6797e-01,  1.4355e-01],
         ...,
         [ 6.3672e-01, -1.2344e+00, -4.1211e-01,  ...,  3.9368e-03,
           5.5176e-02,  5.3125e-01],
         [ 2.6172e-01, -4.1260e-02,  7.6660e-02,  ...,  3.4961e-01,
           6.6797e-01,  2.0630e-02],
         [-4.4531e-01, -5.5859e-01, -1.0938e+00,  ...,  7.1875e-01,
          -3.8672e-01, -2.6758e-01]]], dtype=torch.bfloat16)), (tensor([[[ 6.1646e-03, -5.6763e-03,  5.2490e-03,  ..., -9.8633e-02,
          -2.2656e-01, -2.7954e-02],
         [-1.6406e+00, -1.4844e-01,  2.5391e-01,  ..., -2.9688e-01,
          -2.1562e+00,  4.9072e-02],
         [-2.0625e+00,  9.0234e-01, -6.2109e-01,  ...,  1.6875e+00,
          -1.6250e+00, -1.6562e+00],
         ...,
         [ 6.8359e-02, -9.4531e-01,  2.7344e-02,  ..., -2.6094e+00,
           1.7656e+00, -1.4297e+00],
         [ 3.0078e-01,  1.1475e-01, -6.1719e-01,  ..., -1.2422e+00,
           5.9375e-01, -2.5781e+00],
         [ 1.4219e+00,  2.0156e+00,  3.4570e-01,  ..., -1.0312e+00,
          -3.3281e+00,  3.0938e+00]],

        [[ 6.1646e-03, -5.6763e-03,  5.2490e-03,  ..., -9.8633e-02,
          -2.2656e-01, -2.7954e-02],
         [-1.6406e+00, -1.4844e-01,  2.5391e-01,  ..., -2.9688e-01,
          -2.1562e+00,  4.9072e-02],
         [-2.0625e+00,  9.0234e-01, -6.2109e-01,  ...,  1.6875e+00,
          -1.6250e+00, -1.6562e+00],
         ...,
         [ 6.8359e-02, -9.4531e-01,  2.7344e-02,  ..., -2.6094e+00,
           1.7656e+00, -1.4297e+00],
         [ 3.0078e-01,  1.1475e-01, -6.1719e-01,  ..., -1.2422e+00,
           5.9375e-01, -2.5781e+00],
         [ 1.4219e+00,  2.0156e+00,  3.4570e-01,  ..., -1.0312e+00,
          -3.3281e+00,  3.0938e+00]],

        [[ 6.1646e-03, -5.6763e-03,  5.2490e-03,  ..., -9.8633e-02,
          -2.2656e-01, -2.7954e-02],
         [-1.6406e+00, -1.4844e-01,  2.5391e-01,  ..., -2.9688e-01,
          -2.1562e+00,  4.9072e-02],
         [-2.0625e+00,  9.0234e-01, -6.2109e-01,  ...,  1.6875e+00,
          -1.6250e+00, -1.6562e+00],
         ...,
         [ 6.8359e-02, -9.4531e-01,  2.7344e-02,  ..., -2.6094e+00,
           1.7656e+00, -1.4297e+00],
         [ 3.0078e-01,  1.1475e-01, -6.1719e-01,  ..., -1.2422e+00,
           5.9375e-01, -2.5781e+00],
         [ 1.4219e+00,  2.0156e+00,  3.4570e-01,  ..., -1.0312e+00,
          -3.3281e+00,  3.0938e+00]],

        ...,

        [[ 2.3193e-03,  1.8921e-02,  1.1902e-02,  ...,  3.8574e-02,
           1.0620e-02,  1.2207e-01],
         [ 9.7656e-01, -4.1016e-02,  1.2031e+00,  ...,  4.9805e-01,
           4.1211e-01, -8.9453e-01],
         [ 9.9609e-01,  1.0938e+00,  2.0781e+00,  ...,  2.4062e+00,
          -1.6357e-02, -6.0156e-01],
         ...,
         [-6.6016e-01, -1.5938e+00,  6.1719e-01,  ..., -9.5312e-01,
           1.6172e+00, -2.0781e+00],
         [-6.6016e-01, -2.3047e-01,  3.8281e-01,  ...,  1.3906e+00,
           2.0938e+00, -1.7188e+00],
         [-2.1562e+00,  5.9375e-01,  2.1875e+00,  ...,  1.6406e+00,
           6.1328e-01, -4.8047e-01]],

        [[ 2.3193e-03,  1.8921e-02,  1.1902e-02,  ...,  3.8574e-02,
           1.0620e-02,  1.2207e-01],
         [ 9.7656e-01, -4.1016e-02,  1.2031e+00,  ...,  4.9805e-01,
           4.1211e-01, -8.9453e-01],
         [ 9.9609e-01,  1.0938e+00,  2.0781e+00,  ...,  2.4062e+00,
          -1.6357e-02, -6.0156e-01],
         ...,
         [-6.6016e-01, -1.5938e+00,  6.1719e-01,  ..., -9.5312e-01,
           1.6172e+00, -2.0781e+00],
         [-6.6016e-01, -2.3047e-01,  3.8281e-01,  ...,  1.3906e+00,
           2.0938e+00, -1.7188e+00],
         [-2.1562e+00,  5.9375e-01,  2.1875e+00,  ...,  1.6406e+00,
           6.1328e-01, -4.8047e-01]],

        [[ 2.3193e-03,  1.8921e-02,  1.1902e-02,  ...,  3.8574e-02,
           1.0620e-02,  1.2207e-01],
         [ 9.7656e-01, -4.1016e-02,  1.2031e+00,  ...,  4.9805e-01,
           4.1211e-01, -8.9453e-01],
         [ 9.9609e-01,  1.0938e+00,  2.0781e+00,  ...,  2.4062e+00,
          -1.6357e-02, -6.0156e-01],
         ...,
         [-6.6016e-01, -1.5938e+00,  6.1719e-01,  ..., -9.5312e-01,
           1.6172e+00, -2.0781e+00],
         [-6.6016e-01, -2.3047e-01,  3.8281e-01,  ...,  1.3906e+00,
           2.0938e+00, -1.7188e+00],
         [-2.1562e+00,  5.9375e-01,  2.1875e+00,  ...,  1.6406e+00,
           6.1328e-01, -4.8047e-01]]], dtype=torch.bfloat16), tensor([[[-0.0067, -0.0065, -0.0085,  ..., -0.0104, -0.0156, -0.0349],
         [-0.1035,  0.9180,  0.8984,  ..., -0.1216,  0.4629,  0.6289],
         [-0.7266,  1.1094,  0.8711,  ...,  0.2344, -0.6406,  1.3047],
         ...,
         [ 1.2578,  0.3184, -0.1357,  ...,  0.3320,  0.8125, -1.1094],
         [-0.2871, -0.1533, -0.4902,  ...,  0.5195,  0.3477,  0.1377],
         [ 0.6992, -1.3984,  0.2617,  ..., -0.1904,  1.1797,  0.1328]],

        [[-0.0067, -0.0065, -0.0085,  ..., -0.0104, -0.0156, -0.0349],
         [-0.1035,  0.9180,  0.8984,  ..., -0.1216,  0.4629,  0.6289],
         [-0.7266,  1.1094,  0.8711,  ...,  0.2344, -0.6406,  1.3047],
         ...,
         [ 1.2578,  0.3184, -0.1357,  ...,  0.3320,  0.8125, -1.1094],
         [-0.2871, -0.1533, -0.4902,  ...,  0.5195,  0.3477,  0.1377],
         [ 0.6992, -1.3984,  0.2617,  ..., -0.1904,  1.1797,  0.1328]],

        [[-0.0067, -0.0065, -0.0085,  ..., -0.0104, -0.0156, -0.0349],
         [-0.1035,  0.9180,  0.8984,  ..., -0.1216,  0.4629,  0.6289],
         [-0.7266,  1.1094,  0.8711,  ...,  0.2344, -0.6406,  1.3047],
         ...,
         [ 1.2578,  0.3184, -0.1357,  ...,  0.3320,  0.8125, -1.1094],
         [-0.2871, -0.1533, -0.4902,  ...,  0.5195,  0.3477,  0.1377],
         [ 0.6992, -1.3984,  0.2617,  ..., -0.1904,  1.1797,  0.1328]],

        ...,

        [[ 0.0139, -0.0273,  0.0087,  ..., -0.0417,  0.0271,  0.0422],
         [ 0.2275,  0.4062, -0.0045,  ...,  0.7266, -0.3867, -0.6211],
         [-0.2422, -0.1992,  1.3438,  ...,  0.9805, -0.5234, -0.2324],
         ...,
         [ 0.6680, -1.5312, -0.2910,  ...,  0.2676,  0.0425,  0.1152],
         [ 0.5078, -0.8750, -0.6211,  ...,  0.1143,  0.4062,  0.1768],
         [-0.1030,  0.5312, -0.5469,  ..., -0.6172, -1.0312, -0.5117]],

        [[ 0.0139, -0.0273,  0.0087,  ..., -0.0417,  0.0271,  0.0422],
         [ 0.2275,  0.4062, -0.0045,  ...,  0.7266, -0.3867, -0.6211],
         [-0.2422, -0.1992,  1.3438,  ...,  0.9805, -0.5234, -0.2324],
         ...,
         [ 0.6680, -1.5312, -0.2910,  ...,  0.2676,  0.0425,  0.1152],
         [ 0.5078, -0.8750, -0.6211,  ...,  0.1143,  0.4062,  0.1768],
         [-0.1030,  0.5312, -0.5469,  ..., -0.6172, -1.0312, -0.5117]],

        [[ 0.0139, -0.0273,  0.0087,  ..., -0.0417,  0.0271,  0.0422],
         [ 0.2275,  0.4062, -0.0045,  ...,  0.7266, -0.3867, -0.6211],
         [-0.2422, -0.1992,  1.3438,  ...,  0.9805, -0.5234, -0.2324],
         ...,
         [ 0.6680, -1.5312, -0.2910,  ...,  0.2676,  0.0425,  0.1152],
         [ 0.5078, -0.8750, -0.6211,  ...,  0.1143,  0.4062,  0.1768],
         [-0.1030,  0.5312, -0.5469,  ..., -0.6172, -1.0312, -0.5117]]],
       dtype=torch.bfloat16)), (tensor([[[ 6.0425e-03, -9.7656e-03,  6.5002e-03,  ...,  1.8164e-01,
          -2.0938e+00, -4.0039e-02],
         [ 2.4023e-01,  2.0605e-01,  1.4297e+00,  ..., -9.3359e-01,
           6.4688e+00, -7.0312e-01],
         [-8.1543e-02,  7.2266e-01,  9.5312e-01,  ..., -9.1406e-01,
           7.5000e+00, -3.5938e-01],
         ...,
         [ 3.4570e-01, -3.7109e-01,  8.2422e-01,  ...,  6.2891e-01,
           4.9688e+00,  1.4688e+00],
         [-5.6250e-01, -2.4316e-01,  9.3750e-02,  ..., -8.0859e-01,
           4.9062e+00,  2.0938e+00],
         [-2.0156e+00, -1.1484e+00,  1.1094e+00,  ...,  1.5547e+00,
           6.3750e+00, -3.4531e+00]],

        [[ 6.0425e-03, -9.7656e-03,  6.5002e-03,  ...,  1.8164e-01,
          -2.0938e+00, -4.0039e-02],
         [ 2.4023e-01,  2.0605e-01,  1.4297e+00,  ..., -9.3359e-01,
           6.4688e+00, -7.0312e-01],
         [-8.1543e-02,  7.2266e-01,  9.5312e-01,  ..., -9.1406e-01,
           7.5000e+00, -3.5938e-01],
         ...,
         [ 3.4570e-01, -3.7109e-01,  8.2422e-01,  ...,  6.2891e-01,
           4.9688e+00,  1.4688e+00],
         [-5.6250e-01, -2.4316e-01,  9.3750e-02,  ..., -8.0859e-01,
           4.9062e+00,  2.0938e+00],
         [-2.0156e+00, -1.1484e+00,  1.1094e+00,  ...,  1.5547e+00,
           6.3750e+00, -3.4531e+00]],

        [[ 6.0425e-03, -9.7656e-03,  6.5002e-03,  ...,  1.8164e-01,
          -2.0938e+00, -4.0039e-02],
         [ 2.4023e-01,  2.0605e-01,  1.4297e+00,  ..., -9.3359e-01,
           6.4688e+00, -7.0312e-01],
         [-8.1543e-02,  7.2266e-01,  9.5312e-01,  ..., -9.1406e-01,
           7.5000e+00, -3.5938e-01],
         ...,
         [ 3.4570e-01, -3.7109e-01,  8.2422e-01,  ...,  6.2891e-01,
           4.9688e+00,  1.4688e+00],
         [-5.6250e-01, -2.4316e-01,  9.3750e-02,  ..., -8.0859e-01,
           4.9062e+00,  2.0938e+00],
         [-2.0156e+00, -1.1484e+00,  1.1094e+00,  ...,  1.5547e+00,
           6.3750e+00, -3.4531e+00]],

        ...,

        [[-4.9744e-03, -7.0190e-04, -9.7656e-03,  ...,  3.0859e-01,
          -4.3750e-01, -6.5430e-02],
         [ 1.8906e+00, -1.2500e+00,  9.7656e-01,  ..., -7.7344e-01,
           1.6250e+00, -5.1172e-01],
         [-1.9531e-01,  5.0781e-01, -6.7188e-01,  ..., -1.9531e+00,
           1.1719e+00,  3.0859e-01],
         ...,
         [-1.0781e+00, -4.9414e-01,  8.0078e-01,  ..., -1.6172e+00,
           2.1680e-01,  5.0537e-02],
         [-6.1328e-01, -1.0391e+00,  5.9766e-01,  ..., -2.0156e+00,
          -4.0039e-01, -4.5166e-02],
         [-3.2812e-01, -2.2188e+00,  2.3242e-01,  ..., -1.7344e+00,
           1.0938e+00, -1.6211e-01]],

        [[-4.9744e-03, -7.0190e-04, -9.7656e-03,  ...,  3.0859e-01,
          -4.3750e-01, -6.5430e-02],
         [ 1.8906e+00, -1.2500e+00,  9.7656e-01,  ..., -7.7344e-01,
           1.6250e+00, -5.1172e-01],
         [-1.9531e-01,  5.0781e-01, -6.7188e-01,  ..., -1.9531e+00,
           1.1719e+00,  3.0859e-01],
         ...,
         [-1.0781e+00, -4.9414e-01,  8.0078e-01,  ..., -1.6172e+00,
           2.1680e-01,  5.0537e-02],
         [-6.1328e-01, -1.0391e+00,  5.9766e-01,  ..., -2.0156e+00,
          -4.0039e-01, -4.5166e-02],
         [-3.2812e-01, -2.2188e+00,  2.3242e-01,  ..., -1.7344e+00,
           1.0938e+00, -1.6211e-01]],

        [[-4.9744e-03, -7.0190e-04, -9.7656e-03,  ...,  3.0859e-01,
          -4.3750e-01, -6.5430e-02],
         [ 1.8906e+00, -1.2500e+00,  9.7656e-01,  ..., -7.7344e-01,
           1.6250e+00, -5.1172e-01],
         [-1.9531e-01,  5.0781e-01, -6.7188e-01,  ..., -1.9531e+00,
           1.1719e+00,  3.0859e-01],
         ...,
         [-1.0781e+00, -4.9414e-01,  8.0078e-01,  ..., -1.6172e+00,
           2.1680e-01,  5.0537e-02],
         [-6.1328e-01, -1.0391e+00,  5.9766e-01,  ..., -2.0156e+00,
          -4.0039e-01, -4.5166e-02],
         [-3.2812e-01, -2.2188e+00,  2.3242e-01,  ..., -1.7344e+00,
           1.0938e+00, -1.6211e-01]]], dtype=torch.bfloat16), tensor([[[ 0.0245,  0.0625, -0.0684,  ..., -0.0576, -0.0496,  0.0322],
         [-0.6133,  0.3223, -0.2158,  ...,  0.2930,  0.2852, -0.3613],
         [ 0.2656,  0.1973,  0.1157,  ..., -0.3672, -0.1279,  0.3633],
         ...,
         [ 0.4941, -0.6875,  0.2812,  ...,  0.3164,  0.3203, -0.3379],
         [-0.2754, -0.9141, -0.3926,  ...,  0.4922,  0.0208, -0.1631],
         [ 0.2891,  0.6836, -0.7148,  ...,  0.1895, -0.1787,  0.4590]],

        [[ 0.0245,  0.0625, -0.0684,  ..., -0.0576, -0.0496,  0.0322],
         [-0.6133,  0.3223, -0.2158,  ...,  0.2930,  0.2852, -0.3613],
         [ 0.2656,  0.1973,  0.1157,  ..., -0.3672, -0.1279,  0.3633],
         ...,
         [ 0.4941, -0.6875,  0.2812,  ...,  0.3164,  0.3203, -0.3379],
         [-0.2754, -0.9141, -0.3926,  ...,  0.4922,  0.0208, -0.1631],
         [ 0.2891,  0.6836, -0.7148,  ...,  0.1895, -0.1787,  0.4590]],

        [[ 0.0245,  0.0625, -0.0684,  ..., -0.0576, -0.0496,  0.0322],
         [-0.6133,  0.3223, -0.2158,  ...,  0.2930,  0.2852, -0.3613],
         [ 0.2656,  0.1973,  0.1157,  ..., -0.3672, -0.1279,  0.3633],
         ...,
         [ 0.4941, -0.6875,  0.2812,  ...,  0.3164,  0.3203, -0.3379],
         [-0.2754, -0.9141, -0.3926,  ...,  0.4922,  0.0208, -0.1631],
         [ 0.2891,  0.6836, -0.7148,  ...,  0.1895, -0.1787,  0.4590]],

        ...,

        [[ 0.0066, -0.0017,  0.0037,  ...,  0.0084,  0.0132,  0.0036],
         [-0.4062, -1.0234,  0.8516,  ...,  0.2617,  1.2031,  0.4395],
         [-0.5898,  0.8008,  0.1118,  ...,  0.6875,  1.5156, -0.1387],
         ...,
         [ 0.5977,  1.6406,  0.4316,  ...,  0.0053, -1.3516, -0.8086],
         [ 0.8164, -0.5195,  0.6992,  ...,  1.5391,  0.2559,  0.5469],
         [-0.4102, -0.0165,  0.1055,  ..., -0.2617,  0.9688,  0.5547]],

        [[ 0.0066, -0.0017,  0.0037,  ...,  0.0084,  0.0132,  0.0036],
         [-0.4062, -1.0234,  0.8516,  ...,  0.2617,  1.2031,  0.4395],
         [-0.5898,  0.8008,  0.1118,  ...,  0.6875,  1.5156, -0.1387],
         ...,
         [ 0.5977,  1.6406,  0.4316,  ...,  0.0053, -1.3516, -0.8086],
         [ 0.8164, -0.5195,  0.6992,  ...,  1.5391,  0.2559,  0.5469],
         [-0.4102, -0.0165,  0.1055,  ..., -0.2617,  0.9688,  0.5547]],

        [[ 0.0066, -0.0017,  0.0037,  ...,  0.0084,  0.0132,  0.0036],
         [-0.4062, -1.0234,  0.8516,  ...,  0.2617,  1.2031,  0.4395],
         [-0.5898,  0.8008,  0.1118,  ...,  0.6875,  1.5156, -0.1387],
         ...,
         [ 0.5977,  1.6406,  0.4316,  ...,  0.0053, -1.3516, -0.8086],
         [ 0.8164, -0.5195,  0.6992,  ...,  1.5391,  0.2559,  0.5469],
         [-0.4102, -0.0165,  0.1055,  ..., -0.2617,  0.9688,  0.5547]]],
       dtype=torch.bfloat16)), (tensor([[[-3.7537e-03, -2.6245e-03,  1.2329e-02,  ...,  1.1768e-01,
          -4.6387e-02,  2.0508e-01],
         [-1.5781e+00, -1.7578e-02, -8.6719e-01,  ..., -1.2500e+00,
           7.1875e-01, -3.4961e-01],
         [-7.1875e-01, -1.1182e-01,  2.5195e-01,  ..., -1.3438e+00,
           3.3594e+00, -1.1172e+00],
         ...,
         [ 5.2344e-01,  1.4453e-01, -9.5703e-01,  ..., -1.8203e+00,
          -3.0664e-01,  2.5000e-01],
         [ 6.6406e-02, -5.5469e-01, -4.1406e-01,  ..., -6.2891e-01,
           3.4375e-01, -8.7402e-02],
         [ 1.0781e+00, -8.1641e-01, -5.1172e-01,  ...,  1.6094e+00,
           2.8125e+00, -2.4844e+00]],

        [[-3.7537e-03, -2.6245e-03,  1.2329e-02,  ...,  1.1768e-01,
          -4.6387e-02,  2.0508e-01],
         [-1.5781e+00, -1.7578e-02, -8.6719e-01,  ..., -1.2500e+00,
           7.1875e-01, -3.4961e-01],
         [-7.1875e-01, -1.1182e-01,  2.5195e-01,  ..., -1.3438e+00,
           3.3594e+00, -1.1172e+00],
         ...,
         [ 5.2344e-01,  1.4453e-01, -9.5703e-01,  ..., -1.8203e+00,
          -3.0664e-01,  2.5000e-01],
         [ 6.6406e-02, -5.5469e-01, -4.1406e-01,  ..., -6.2891e-01,
           3.4375e-01, -8.7402e-02],
         [ 1.0781e+00, -8.1641e-01, -5.1172e-01,  ...,  1.6094e+00,
           2.8125e+00, -2.4844e+00]],

        [[-3.7537e-03, -2.6245e-03,  1.2329e-02,  ...,  1.1768e-01,
          -4.6387e-02,  2.0508e-01],
         [-1.5781e+00, -1.7578e-02, -8.6719e-01,  ..., -1.2500e+00,
           7.1875e-01, -3.4961e-01],
         [-7.1875e-01, -1.1182e-01,  2.5195e-01,  ..., -1.3438e+00,
           3.3594e+00, -1.1172e+00],
         ...,
         [ 5.2344e-01,  1.4453e-01, -9.5703e-01,  ..., -1.8203e+00,
          -3.0664e-01,  2.5000e-01],
         [ 6.6406e-02, -5.5469e-01, -4.1406e-01,  ..., -6.2891e-01,
           3.4375e-01, -8.7402e-02],
         [ 1.0781e+00, -8.1641e-01, -5.1172e-01,  ...,  1.6094e+00,
           2.8125e+00, -2.4844e+00]],

        ...,

        [[-1.0315e-02,  6.5804e-05,  1.0681e-02,  ...,  5.1172e-01,
           1.9043e-02, -2.9907e-02],
         [ 9.0625e-01, -7.8125e-02, -1.0156e+00,  ..., -3.0000e+00,
           4.2578e-01, -9.8047e-01],
         [ 6.7188e-01,  3.4180e-01,  4.0625e-01,  ..., -1.3594e+00,
           1.4844e-01, -1.0312e+00],
         ...,
         [-7.0312e-01, -9.0625e-01, -5.6641e-01,  ..., -1.2109e+00,
          -3.1094e+00, -8.7500e-01],
         [ 2.1680e-01, -1.4648e-01,  4.7852e-02,  ..., -2.1362e-03,
          -9.1016e-01,  3.8867e-01],
         [ 8.3203e-01, -1.8594e+00, -1.6875e+00,  ...,  2.7148e-01,
          -4.1992e-01,  4.6680e-01]],

        [[-1.0315e-02,  6.5804e-05,  1.0681e-02,  ...,  5.1172e-01,
           1.9043e-02, -2.9907e-02],
         [ 9.0625e-01, -7.8125e-02, -1.0156e+00,  ..., -3.0000e+00,
           4.2578e-01, -9.8047e-01],
         [ 6.7188e-01,  3.4180e-01,  4.0625e-01,  ..., -1.3594e+00,
           1.4844e-01, -1.0312e+00],
         ...,
         [-7.0312e-01, -9.0625e-01, -5.6641e-01,  ..., -1.2109e+00,
          -3.1094e+00, -8.7500e-01],
         [ 2.1680e-01, -1.4648e-01,  4.7852e-02,  ..., -2.1362e-03,
          -9.1016e-01,  3.8867e-01],
         [ 8.3203e-01, -1.8594e+00, -1.6875e+00,  ...,  2.7148e-01,
          -4.1992e-01,  4.6680e-01]],

        [[-1.0315e-02,  6.5804e-05,  1.0681e-02,  ...,  5.1172e-01,
           1.9043e-02, -2.9907e-02],
         [ 9.0625e-01, -7.8125e-02, -1.0156e+00,  ..., -3.0000e+00,
           4.2578e-01, -9.8047e-01],
         [ 6.7188e-01,  3.4180e-01,  4.0625e-01,  ..., -1.3594e+00,
           1.4844e-01, -1.0312e+00],
         ...,
         [-7.0312e-01, -9.0625e-01, -5.6641e-01,  ..., -1.2109e+00,
          -3.1094e+00, -8.7500e-01],
         [ 2.1680e-01, -1.4648e-01,  4.7852e-02,  ..., -2.1362e-03,
          -9.1016e-01,  3.8867e-01],
         [ 8.3203e-01, -1.8594e+00, -1.6875e+00,  ...,  2.7148e-01,
          -4.1992e-01,  4.6680e-01]]], dtype=torch.bfloat16), tensor([[[-2.2339e-02, -2.2583e-02, -1.3672e-01,  ...,  2.0386e-02,
           5.4016e-03, -2.3193e-02],
         [-7.5000e-01, -3.5547e-01,  6.2500e-01,  ...,  3.8281e-01,
           3.4961e-01,  3.8477e-01],
         [ 2.2266e-01, -3.2617e-01,  4.4727e-01,  ..., -1.6504e-01,
           5.8203e-01,  2.3047e-01],
         ...,
         [-1.2109e-01, -1.7676e-01, -8.2422e-01,  ...,  4.3945e-01,
           2.3242e-01,  1.0312e+00],
         [-8.5938e-01,  1.5547e+00,  2.0801e-01,  ...,  9.2578e-01,
           4.7266e-01,  7.3438e-01],
         [ 9.5703e-01, -1.5918e-01, -5.7812e-01,  ..., -2.7930e-01,
           4.8242e-01,  4.6094e-01]],

        [[-2.2339e-02, -2.2583e-02, -1.3672e-01,  ...,  2.0386e-02,
           5.4016e-03, -2.3193e-02],
         [-7.5000e-01, -3.5547e-01,  6.2500e-01,  ...,  3.8281e-01,
           3.4961e-01,  3.8477e-01],
         [ 2.2266e-01, -3.2617e-01,  4.4727e-01,  ..., -1.6504e-01,
           5.8203e-01,  2.3047e-01],
         ...,
         [-1.2109e-01, -1.7676e-01, -8.2422e-01,  ...,  4.3945e-01,
           2.3242e-01,  1.0312e+00],
         [-8.5938e-01,  1.5547e+00,  2.0801e-01,  ...,  9.2578e-01,
           4.7266e-01,  7.3438e-01],
         [ 9.5703e-01, -1.5918e-01, -5.7812e-01,  ..., -2.7930e-01,
           4.8242e-01,  4.6094e-01]],

        [[-2.2339e-02, -2.2583e-02, -1.3672e-01,  ...,  2.0386e-02,
           5.4016e-03, -2.3193e-02],
         [-7.5000e-01, -3.5547e-01,  6.2500e-01,  ...,  3.8281e-01,
           3.4961e-01,  3.8477e-01],
         [ 2.2266e-01, -3.2617e-01,  4.4727e-01,  ..., -1.6504e-01,
           5.8203e-01,  2.3047e-01],
         ...,
         [-1.2109e-01, -1.7676e-01, -8.2422e-01,  ...,  4.3945e-01,
           2.3242e-01,  1.0312e+00],
         [-8.5938e-01,  1.5547e+00,  2.0801e-01,  ...,  9.2578e-01,
           4.7266e-01,  7.3438e-01],
         [ 9.5703e-01, -1.5918e-01, -5.7812e-01,  ..., -2.7930e-01,
           4.8242e-01,  4.6094e-01]],

        ...,

        [[-1.1963e-02,  1.4801e-03, -6.3171e-03,  ...,  3.9062e-03,
          -2.3315e-02, -1.6327e-03],
         [-4.7852e-01, -4.9219e-01,  1.7285e-01,  ..., -3.8086e-01,
           7.8516e-01,  5.4297e-01],
         [-1.1953e+00, -3.8281e-01, -4.9609e-01,  ..., -6.0938e-01,
           5.3516e-01,  4.7266e-01],
         ...,
         [-2.1289e-01,  2.1484e-01,  7.0312e-01,  ...,  5.2344e-01,
          -5.7031e-01, -6.7188e-01],
         [-1.8359e-01,  2.7100e-02,  2.2754e-01,  ...,  7.3047e-01,
           6.4941e-02,  1.7871e-01],
         [-1.3477e-01,  1.2329e-02,  1.0156e+00,  ...,  6.4062e-01,
           1.2817e-02,  1.2878e-02]],

        [[-1.1963e-02,  1.4801e-03, -6.3171e-03,  ...,  3.9062e-03,
          -2.3315e-02, -1.6327e-03],
         [-4.7852e-01, -4.9219e-01,  1.7285e-01,  ..., -3.8086e-01,
           7.8516e-01,  5.4297e-01],
         [-1.1953e+00, -3.8281e-01, -4.9609e-01,  ..., -6.0938e-01,
           5.3516e-01,  4.7266e-01],
         ...,
         [-2.1289e-01,  2.1484e-01,  7.0312e-01,  ...,  5.2344e-01,
          -5.7031e-01, -6.7188e-01],
         [-1.8359e-01,  2.7100e-02,  2.2754e-01,  ...,  7.3047e-01,
           6.4941e-02,  1.7871e-01],
         [-1.3477e-01,  1.2329e-02,  1.0156e+00,  ...,  6.4062e-01,
           1.2817e-02,  1.2878e-02]],

        [[-1.1963e-02,  1.4801e-03, -6.3171e-03,  ...,  3.9062e-03,
          -2.3315e-02, -1.6327e-03],
         [-4.7852e-01, -4.9219e-01,  1.7285e-01,  ..., -3.8086e-01,
           7.8516e-01,  5.4297e-01],
         [-1.1953e+00, -3.8281e-01, -4.9609e-01,  ..., -6.0938e-01,
           5.3516e-01,  4.7266e-01],
         ...,
         [-2.1289e-01,  2.1484e-01,  7.0312e-01,  ...,  5.2344e-01,
          -5.7031e-01, -6.7188e-01],
         [-1.8359e-01,  2.7100e-02,  2.2754e-01,  ...,  7.3047e-01,
           6.4941e-02,  1.7871e-01],
         [-1.3477e-01,  1.2329e-02,  1.0156e+00,  ...,  6.4062e-01,
           1.2817e-02,  1.2878e-02]]], dtype=torch.bfloat16)), (tensor([[[ 7.9956e-03,  4.7302e-03,  9.1553e-05,  ...,  4.0234e-01,
          -7.4707e-02, -5.3516e-01],
         [ 1.8750e-01, -2.0312e+00,  1.7109e+00,  ..., -2.7344e+00,
           9.1406e-01,  3.5938e-01],
         [-6.9922e-01,  2.0508e-01,  7.5781e-01,  ..., -2.6875e+00,
           4.7070e-01, -4.1504e-02],
         ...,
         [-9.1797e-02,  1.2812e+00,  9.9609e-01,  ..., -3.6406e+00,
          -3.3203e-02,  1.7891e+00],
         [ 7.5781e-01,  3.4570e-01,  9.4141e-01,  ..., -3.6875e+00,
          -2.3750e+00,  3.2656e+00],
         [ 2.7656e+00, -6.6797e-01,  1.9688e+00,  ..., -4.1250e+00,
           2.7188e+00,  1.1250e+00]],

        [[ 7.9956e-03,  4.7302e-03,  9.1553e-05,  ...,  4.0234e-01,
          -7.4707e-02, -5.3516e-01],
         [ 1.8750e-01, -2.0312e+00,  1.7109e+00,  ..., -2.7344e+00,
           9.1406e-01,  3.5938e-01],
         [-6.9922e-01,  2.0508e-01,  7.5781e-01,  ..., -2.6875e+00,
           4.7070e-01, -4.1504e-02],
         ...,
         [-9.1797e-02,  1.2812e+00,  9.9609e-01,  ..., -3.6406e+00,
          -3.3203e-02,  1.7891e+00],
         [ 7.5781e-01,  3.4570e-01,  9.4141e-01,  ..., -3.6875e+00,
          -2.3750e+00,  3.2656e+00],
         [ 2.7656e+00, -6.6797e-01,  1.9688e+00,  ..., -4.1250e+00,
           2.7188e+00,  1.1250e+00]],

        [[ 7.9956e-03,  4.7302e-03,  9.1553e-05,  ...,  4.0234e-01,
          -7.4707e-02, -5.3516e-01],
         [ 1.8750e-01, -2.0312e+00,  1.7109e+00,  ..., -2.7344e+00,
           9.1406e-01,  3.5938e-01],
         [-6.9922e-01,  2.0508e-01,  7.5781e-01,  ..., -2.6875e+00,
           4.7070e-01, -4.1504e-02],
         ...,
         [-9.1797e-02,  1.2812e+00,  9.9609e-01,  ..., -3.6406e+00,
          -3.3203e-02,  1.7891e+00],
         [ 7.5781e-01,  3.4570e-01,  9.4141e-01,  ..., -3.6875e+00,
          -2.3750e+00,  3.2656e+00],
         [ 2.7656e+00, -6.6797e-01,  1.9688e+00,  ..., -4.1250e+00,
           2.7188e+00,  1.1250e+00]],

        ...,

        [[ 3.9368e-03,  7.5378e-03, -1.5137e-02,  ...,  9.0820e-02,
          -4.9805e-01, -1.0469e+00],
         [ 1.6719e+00, -2.3633e-01,  4.3359e-01,  ..., -6.9531e-01,
           4.8438e+00,  3.9531e+00],
         [ 2.9297e-01,  6.9922e-01,  4.2383e-01,  ..., -1.3359e+00,
           3.0156e+00,  5.8125e+00],
         ...,
         [-1.7109e+00,  3.1250e-02, -6.0938e-01,  ..., -2.2070e-01,
          -5.8984e-01,  3.6562e+00],
         [-2.3242e-01,  1.3086e-01,  5.7422e-01,  ...,  1.1094e+00,
           2.0386e-02,  4.0625e+00],
         [ 7.1875e-01, -1.0625e+00,  1.0312e+00,  ..., -3.7344e+00,
           3.1719e+00,  5.3750e+00]],

        [[ 3.9368e-03,  7.5378e-03, -1.5137e-02,  ...,  9.0820e-02,
          -4.9805e-01, -1.0469e+00],
         [ 1.6719e+00, -2.3633e-01,  4.3359e-01,  ..., -6.9531e-01,
           4.8438e+00,  3.9531e+00],
         [ 2.9297e-01,  6.9922e-01,  4.2383e-01,  ..., -1.3359e+00,
           3.0156e+00,  5.8125e+00],
         ...,
         [-1.7109e+00,  3.1250e-02, -6.0938e-01,  ..., -2.2070e-01,
          -5.8984e-01,  3.6562e+00],
         [-2.3242e-01,  1.3086e-01,  5.7422e-01,  ...,  1.1094e+00,
           2.0386e-02,  4.0625e+00],
         [ 7.1875e-01, -1.0625e+00,  1.0312e+00,  ..., -3.7344e+00,
           3.1719e+00,  5.3750e+00]],

        [[ 3.9368e-03,  7.5378e-03, -1.5137e-02,  ...,  9.0820e-02,
          -4.9805e-01, -1.0469e+00],
         [ 1.6719e+00, -2.3633e-01,  4.3359e-01,  ..., -6.9531e-01,
           4.8438e+00,  3.9531e+00],
         [ 2.9297e-01,  6.9922e-01,  4.2383e-01,  ..., -1.3359e+00,
           3.0156e+00,  5.8125e+00],
         ...,
         [-1.7109e+00,  3.1250e-02, -6.0938e-01,  ..., -2.2070e-01,
          -5.8984e-01,  3.6562e+00],
         [-2.3242e-01,  1.3086e-01,  5.7422e-01,  ...,  1.1094e+00,
           2.0386e-02,  4.0625e+00],
         [ 7.1875e-01, -1.0625e+00,  1.0312e+00,  ..., -3.7344e+00,
           3.1719e+00,  5.3750e+00]]], dtype=torch.bfloat16), tensor([[[-7.3547e-03, -4.1504e-03, -1.7334e-02,  ...,  1.0132e-02,
          -1.4725e-03,  4.8828e-03],
         [-9.4531e-01,  1.7188e-01,  2.5781e+00,  ..., -4.9219e-01,
           1.5234e+00,  1.3594e+00],
         [-2.6562e-01, -1.1484e+00,  1.6406e+00,  ..., -1.5078e+00,
           9.1016e-01,  1.6406e+00],
         ...,
         [-5.7031e-01,  3.6865e-02, -3.2227e-01,  ...,  3.5742e-01,
          -2.7930e-01, -3.2422e-01],
         [-1.4258e-01, -4.2578e-01, -6.1719e-01,  ...,  1.3770e-01,
          -2.7344e-01,  4.9023e-01],
         [-1.2969e+00, -4.5508e-01,  6.1719e-01,  ...,  1.6602e-01,
          -1.2512e-02, -3.4766e-01]],

        [[-7.3547e-03, -4.1504e-03, -1.7334e-02,  ...,  1.0132e-02,
          -1.4725e-03,  4.8828e-03],
         [-9.4531e-01,  1.7188e-01,  2.5781e+00,  ..., -4.9219e-01,
           1.5234e+00,  1.3594e+00],
         [-2.6562e-01, -1.1484e+00,  1.6406e+00,  ..., -1.5078e+00,
           9.1016e-01,  1.6406e+00],
         ...,
         [-5.7031e-01,  3.6865e-02, -3.2227e-01,  ...,  3.5742e-01,
          -2.7930e-01, -3.2422e-01],
         [-1.4258e-01, -4.2578e-01, -6.1719e-01,  ...,  1.3770e-01,
          -2.7344e-01,  4.9023e-01],
         [-1.2969e+00, -4.5508e-01,  6.1719e-01,  ...,  1.6602e-01,
          -1.2512e-02, -3.4766e-01]],

        [[-7.3547e-03, -4.1504e-03, -1.7334e-02,  ...,  1.0132e-02,
          -1.4725e-03,  4.8828e-03],
         [-9.4531e-01,  1.7188e-01,  2.5781e+00,  ..., -4.9219e-01,
           1.5234e+00,  1.3594e+00],
         [-2.6562e-01, -1.1484e+00,  1.6406e+00,  ..., -1.5078e+00,
           9.1016e-01,  1.6406e+00],
         ...,
         [-5.7031e-01,  3.6865e-02, -3.2227e-01,  ...,  3.5742e-01,
          -2.7930e-01, -3.2422e-01],
         [-1.4258e-01, -4.2578e-01, -6.1719e-01,  ...,  1.3770e-01,
          -2.7344e-01,  4.9023e-01],
         [-1.2969e+00, -4.5508e-01,  6.1719e-01,  ...,  1.6602e-01,
          -1.2512e-02, -3.4766e-01]],

        ...,

        [[-7.8735e-03,  1.6357e-02,  8.4686e-04,  ..., -9.6436e-03,
          -1.5747e-02,  1.2512e-02],
         [-1.0840e-01, -1.2031e+00, -2.6953e-01,  ..., -1.9653e-02,
          -2.6953e-01,  1.5991e-02],
         [ 5.1172e-01,  1.1328e+00, -1.2402e-01,  ...,  1.3281e-01,
           4.2773e-01, -2.9297e-01],
         ...,
         [-9.0625e-01, -5.1172e-01,  2.8076e-02,  ...,  5.0000e-01,
          -1.0469e+00, -1.0559e-02],
         [ 7.3242e-02, -2.7539e-01,  1.1963e-01,  ..., -6.3281e-01,
           6.6406e-01, -1.3086e-01],
         [-1.1172e+00,  1.8281e+00,  1.7031e+00,  ..., -6.9141e-01,
           9.6094e-01, -5.0391e-01]],

        [[-7.8735e-03,  1.6357e-02,  8.4686e-04,  ..., -9.6436e-03,
          -1.5747e-02,  1.2512e-02],
         [-1.0840e-01, -1.2031e+00, -2.6953e-01,  ..., -1.9653e-02,
          -2.6953e-01,  1.5991e-02],
         [ 5.1172e-01,  1.1328e+00, -1.2402e-01,  ...,  1.3281e-01,
           4.2773e-01, -2.9297e-01],
         ...,
         [-9.0625e-01, -5.1172e-01,  2.8076e-02,  ...,  5.0000e-01,
          -1.0469e+00, -1.0559e-02],
         [ 7.3242e-02, -2.7539e-01,  1.1963e-01,  ..., -6.3281e-01,
           6.6406e-01, -1.3086e-01],
         [-1.1172e+00,  1.8281e+00,  1.7031e+00,  ..., -6.9141e-01,
           9.6094e-01, -5.0391e-01]],

        [[-7.8735e-03,  1.6357e-02,  8.4686e-04,  ..., -9.6436e-03,
          -1.5747e-02,  1.2512e-02],
         [-1.0840e-01, -1.2031e+00, -2.6953e-01,  ..., -1.9653e-02,
          -2.6953e-01,  1.5991e-02],
         [ 5.1172e-01,  1.1328e+00, -1.2402e-01,  ...,  1.3281e-01,
           4.2773e-01, -2.9297e-01],
         ...,
         [-9.0625e-01, -5.1172e-01,  2.8076e-02,  ...,  5.0000e-01,
          -1.0469e+00, -1.0559e-02],
         [ 7.3242e-02, -2.7539e-01,  1.1963e-01,  ..., -6.3281e-01,
           6.6406e-01, -1.3086e-01],
         [-1.1172e+00,  1.8281e+00,  1.7031e+00,  ..., -6.9141e-01,
           9.6094e-01, -5.0391e-01]]], dtype=torch.bfloat16)), (tensor([[[ 3.4332e-04,  1.2634e-02, -1.8433e-02,  ...,  3.8330e-02,
           1.4941e-01, -1.8828e+00],
         [ 9.5312e-01, -1.9219e+00, -1.5820e-01,  ..., -1.0391e+00,
          -6.2891e-01,  7.4688e+00],
         [-3.3594e-01, -6.1719e-01, -3.4180e-01,  ...,  6.3281e-01,
           2.0156e+00,  8.3750e+00],
         ...,
         [-1.1621e-01,  3.0469e-01, -4.6094e-01,  ..., -2.9688e+00,
          -3.5312e+00,  5.2812e+00],
         [ 3.3203e-01,  1.1621e-01,  2.7148e-01,  ..., -2.9531e+00,
          -1.5938e+00,  5.8438e+00],
         [ 1.4844e+00, -7.9297e-01, -1.1719e-02,  ..., -8.4961e-02,
           2.4062e+00,  7.5000e+00]],

        [[ 3.4332e-04,  1.2634e-02, -1.8433e-02,  ...,  3.8330e-02,
           1.4941e-01, -1.8828e+00],
         [ 9.5312e-01, -1.9219e+00, -1.5820e-01,  ..., -1.0391e+00,
          -6.2891e-01,  7.4688e+00],
         [-3.3594e-01, -6.1719e-01, -3.4180e-01,  ...,  6.3281e-01,
           2.0156e+00,  8.3750e+00],
         ...,
         [-1.1621e-01,  3.0469e-01, -4.6094e-01,  ..., -2.9688e+00,
          -3.5312e+00,  5.2812e+00],
         [ 3.3203e-01,  1.1621e-01,  2.7148e-01,  ..., -2.9531e+00,
          -1.5938e+00,  5.8438e+00],
         [ 1.4844e+00, -7.9297e-01, -1.1719e-02,  ..., -8.4961e-02,
           2.4062e+00,  7.5000e+00]],

        [[ 3.4332e-04,  1.2634e-02, -1.8433e-02,  ...,  3.8330e-02,
           1.4941e-01, -1.8828e+00],
         [ 9.5312e-01, -1.9219e+00, -1.5820e-01,  ..., -1.0391e+00,
          -6.2891e-01,  7.4688e+00],
         [-3.3594e-01, -6.1719e-01, -3.4180e-01,  ...,  6.3281e-01,
           2.0156e+00,  8.3750e+00],
         ...,
         [-1.1621e-01,  3.0469e-01, -4.6094e-01,  ..., -2.9688e+00,
          -3.5312e+00,  5.2812e+00],
         [ 3.3203e-01,  1.1621e-01,  2.7148e-01,  ..., -2.9531e+00,
          -1.5938e+00,  5.8438e+00],
         [ 1.4844e+00, -7.9297e-01, -1.1719e-02,  ..., -8.4961e-02,
           2.4062e+00,  7.5000e+00]],

        ...,

        [[ 4.5776e-03, -9.3994e-03, -8.3618e-03,  ..., -1.5156e+00,
          -4.8828e-01,  8.8379e-02],
         [-2.8438e+00,  2.7188e+00,  1.3594e+00,  ...,  8.3750e+00,
          -1.0234e+00, -5.2344e-01],
         [-7.5000e-01,  4.7266e-01,  3.7305e-01,  ...,  9.1875e+00,
          -2.5000e+00, -2.6406e+00],
         ...,
         [ 9.3359e-01, -7.5391e-01, -8.5547e-01,  ...,  7.8438e+00,
          -3.4375e-01,  8.0859e-01],
         [ 7.2266e-01, -9.7168e-02,  3.1836e-01,  ...,  7.0938e+00,
           1.1719e+00,  6.3672e-01],
         [ 1.2422e+00,  5.5078e-01,  5.7812e-01,  ...,  7.6250e+00,
           7.0312e-02, -2.3750e+00]],

        [[ 4.5776e-03, -9.3994e-03, -8.3618e-03,  ..., -1.5156e+00,
          -4.8828e-01,  8.8379e-02],
         [-2.8438e+00,  2.7188e+00,  1.3594e+00,  ...,  8.3750e+00,
          -1.0234e+00, -5.2344e-01],
         [-7.5000e-01,  4.7266e-01,  3.7305e-01,  ...,  9.1875e+00,
          -2.5000e+00, -2.6406e+00],
         ...,
         [ 9.3359e-01, -7.5391e-01, -8.5547e-01,  ...,  7.8438e+00,
          -3.4375e-01,  8.0859e-01],
         [ 7.2266e-01, -9.7168e-02,  3.1836e-01,  ...,  7.0938e+00,
           1.1719e+00,  6.3672e-01],
         [ 1.2422e+00,  5.5078e-01,  5.7812e-01,  ...,  7.6250e+00,
           7.0312e-02, -2.3750e+00]],

        [[ 4.5776e-03, -9.3994e-03, -8.3618e-03,  ..., -1.5156e+00,
          -4.8828e-01,  8.8379e-02],
         [-2.8438e+00,  2.7188e+00,  1.3594e+00,  ...,  8.3750e+00,
          -1.0234e+00, -5.2344e-01],
         [-7.5000e-01,  4.7266e-01,  3.7305e-01,  ...,  9.1875e+00,
          -2.5000e+00, -2.6406e+00],
         ...,
         [ 9.3359e-01, -7.5391e-01, -8.5547e-01,  ...,  7.8438e+00,
          -3.4375e-01,  8.0859e-01],
         [ 7.2266e-01, -9.7168e-02,  3.1836e-01,  ...,  7.0938e+00,
           1.1719e+00,  6.3672e-01],
         [ 1.2422e+00,  5.5078e-01,  5.7812e-01,  ...,  7.6250e+00,
           7.0312e-02, -2.3750e+00]]], dtype=torch.bfloat16), tensor([[[ 8.1543e-02,  4.4556e-03,  5.4199e-02,  ..., -3.0212e-03,
          -8.1787e-03, -3.4668e-02],
         [-3.1641e-01, -2.2461e-01, -2.6758e-01,  ...,  4.6289e-01,
          -2.1191e-01, -3.3789e-01],
         [-8.7402e-02,  3.1250e-01,  4.8828e-01,  ...,  7.9297e-01,
           9.8633e-02,  1.7090e-01],
         ...,
         [ 6.1719e-01,  9.8047e-01,  8.5938e-02,  ...,  3.6377e-02,
           4.3945e-01, -5.7031e-01],
         [-9.4238e-02,  4.3359e-01,  2.4707e-01,  ...,  8.7891e-01,
           7.1875e-01, -8.4375e-01],
         [-6.8359e-02,  3.2422e-01, -3.3398e-01,  ...,  3.8672e-01,
           5.4688e-01, -3.1055e-01]],

        [[ 8.1543e-02,  4.4556e-03,  5.4199e-02,  ..., -3.0212e-03,
          -8.1787e-03, -3.4668e-02],
         [-3.1641e-01, -2.2461e-01, -2.6758e-01,  ...,  4.6289e-01,
          -2.1191e-01, -3.3789e-01],
         [-8.7402e-02,  3.1250e-01,  4.8828e-01,  ...,  7.9297e-01,
           9.8633e-02,  1.7090e-01],
         ...,
         [ 6.1719e-01,  9.8047e-01,  8.5938e-02,  ...,  3.6377e-02,
           4.3945e-01, -5.7031e-01],
         [-9.4238e-02,  4.3359e-01,  2.4707e-01,  ...,  8.7891e-01,
           7.1875e-01, -8.4375e-01],
         [-6.8359e-02,  3.2422e-01, -3.3398e-01,  ...,  3.8672e-01,
           5.4688e-01, -3.1055e-01]],

        [[ 8.1543e-02,  4.4556e-03,  5.4199e-02,  ..., -3.0212e-03,
          -8.1787e-03, -3.4668e-02],
         [-3.1641e-01, -2.2461e-01, -2.6758e-01,  ...,  4.6289e-01,
          -2.1191e-01, -3.3789e-01],
         [-8.7402e-02,  3.1250e-01,  4.8828e-01,  ...,  7.9297e-01,
           9.8633e-02,  1.7090e-01],
         ...,
         [ 6.1719e-01,  9.8047e-01,  8.5938e-02,  ...,  3.6377e-02,
           4.3945e-01, -5.7031e-01],
         [-9.4238e-02,  4.3359e-01,  2.4707e-01,  ...,  8.7891e-01,
           7.1875e-01, -8.4375e-01],
         [-6.8359e-02,  3.2422e-01, -3.3398e-01,  ...,  3.8672e-01,
           5.4688e-01, -3.1055e-01]],

        ...,

        [[ 5.6763e-03, -2.8534e-03,  6.3705e-04,  ...,  3.8300e-03,
          -9.3994e-03,  3.9673e-03],
         [-2.5938e+00,  3.1406e+00, -4.9023e-01,  ..., -5.6250e-01,
          -1.3281e+00,  2.0142e-02],
         [-2.4844e+00,  1.4297e+00, -1.9922e+00,  ...,  5.3516e-01,
          -4.1797e-01, -2.4609e-01],
         ...,
         [ 5.4688e-01, -5.9766e-01,  8.5938e-01,  ...,  6.7969e-01,
          -7.6172e-01,  4.8438e-01],
         [ 3.1445e-01, -3.9795e-02,  5.6250e-01,  ..., -7.1289e-02,
          -6.4453e-02,  3.7598e-02],
         [-1.8262e-01, -1.9141e-01, -5.1562e-01,  ...,  2.3828e-01,
           2.0781e+00, -9.9219e-01]],

        [[ 5.6763e-03, -2.8534e-03,  6.3705e-04,  ...,  3.8300e-03,
          -9.3994e-03,  3.9673e-03],
         [-2.5938e+00,  3.1406e+00, -4.9023e-01,  ..., -5.6250e-01,
          -1.3281e+00,  2.0142e-02],
         [-2.4844e+00,  1.4297e+00, -1.9922e+00,  ...,  5.3516e-01,
          -4.1797e-01, -2.4609e-01],
         ...,
         [ 5.4688e-01, -5.9766e-01,  8.5938e-01,  ...,  6.7969e-01,
          -7.6172e-01,  4.8438e-01],
         [ 3.1445e-01, -3.9795e-02,  5.6250e-01,  ..., -7.1289e-02,
          -6.4453e-02,  3.7598e-02],
         [-1.8262e-01, -1.9141e-01, -5.1562e-01,  ...,  2.3828e-01,
           2.0781e+00, -9.9219e-01]],

        [[ 5.6763e-03, -2.8534e-03,  6.3705e-04,  ...,  3.8300e-03,
          -9.3994e-03,  3.9673e-03],
         [-2.5938e+00,  3.1406e+00, -4.9023e-01,  ..., -5.6250e-01,
          -1.3281e+00,  2.0142e-02],
         [-2.4844e+00,  1.4297e+00, -1.9922e+00,  ...,  5.3516e-01,
          -4.1797e-01, -2.4609e-01],
         ...,
         [ 5.4688e-01, -5.9766e-01,  8.5938e-01,  ...,  6.7969e-01,
          -7.6172e-01,  4.8438e-01],
         [ 3.1445e-01, -3.9795e-02,  5.6250e-01,  ..., -7.1289e-02,
          -6.4453e-02,  3.7598e-02],
         [-1.8262e-01, -1.9141e-01, -5.1562e-01,  ...,  2.3828e-01,
           2.0781e+00, -9.9219e-01]]], dtype=torch.bfloat16)), (tensor([[[-4.6921e-04,  1.2268e-02,  1.6113e-02,  ...,  2.0703e-01,
           7.5684e-02,  7.1716e-03],
         [-1.4531e+00, -7.7344e-01, -3.1055e-01,  ..., -5.7812e-01,
          -1.6968e-02,  5.7068e-03],
         [-9.0234e-01,  3.5938e-01, -5.3125e-01,  ..., -1.9922e+00,
           2.1250e+00, -2.1406e+00],
         ...,
         [ 2.5391e-01,  3.5156e-01, -2.4023e-01,  ...,  1.4609e+00,
          -3.7031e+00, -1.8594e+00],
         [-2.1484e-01,  5.2344e-01,  4.8828e-04,  ...,  6.0156e-01,
          -5.0781e-01, -2.0938e+00],
         [ 2.3125e+00,  9.6875e-01, -1.5859e+00,  ..., -1.2500e+00,
           8.1250e-01, -9.9609e-01]],

        [[-4.6921e-04,  1.2268e-02,  1.6113e-02,  ...,  2.0703e-01,
           7.5684e-02,  7.1716e-03],
         [-1.4531e+00, -7.7344e-01, -3.1055e-01,  ..., -5.7812e-01,
          -1.6968e-02,  5.7068e-03],
         [-9.0234e-01,  3.5938e-01, -5.3125e-01,  ..., -1.9922e+00,
           2.1250e+00, -2.1406e+00],
         ...,
         [ 2.5391e-01,  3.5156e-01, -2.4023e-01,  ...,  1.4609e+00,
          -3.7031e+00, -1.8594e+00],
         [-2.1484e-01,  5.2344e-01,  4.8828e-04,  ...,  6.0156e-01,
          -5.0781e-01, -2.0938e+00],
         [ 2.3125e+00,  9.6875e-01, -1.5859e+00,  ..., -1.2500e+00,
           8.1250e-01, -9.9609e-01]],

        [[-4.6921e-04,  1.2268e-02,  1.6113e-02,  ...,  2.0703e-01,
           7.5684e-02,  7.1716e-03],
         [-1.4531e+00, -7.7344e-01, -3.1055e-01,  ..., -5.7812e-01,
          -1.6968e-02,  5.7068e-03],
         [-9.0234e-01,  3.5938e-01, -5.3125e-01,  ..., -1.9922e+00,
           2.1250e+00, -2.1406e+00],
         ...,
         [ 2.5391e-01,  3.5156e-01, -2.4023e-01,  ...,  1.4609e+00,
          -3.7031e+00, -1.8594e+00],
         [-2.1484e-01,  5.2344e-01,  4.8828e-04,  ...,  6.0156e-01,
          -5.0781e-01, -2.0938e+00],
         [ 2.3125e+00,  9.6875e-01, -1.5859e+00,  ..., -1.2500e+00,
           8.1250e-01, -9.9609e-01]],

        ...,

        [[-1.4221e-02,  1.0681e-02,  2.6398e-03,  ...,  5.9082e-02,
          -9.4727e-02, -1.1328e-01],
         [ 9.3359e-01, -9.9219e-01,  3.9062e-01,  ..., -4.4727e-01,
          -3.0625e+00, -1.9727e-01],
         [-1.7969e-01, -1.2266e+00,  8.1641e-01,  ...,  1.0234e+00,
          -3.5938e-01,  8.9062e-01],
         ...,
         [-9.2578e-01,  2.4219e-01, -1.3281e+00,  ...,  1.7344e+00,
           3.0273e-01,  1.4766e+00],
         [-2.3633e-01,  3.8672e-01, -3.7695e-01,  ...,  7.6953e-01,
           4.2969e-01,  1.3281e+00],
         [ 4.1016e-01,  1.4375e+00,  2.0703e-01,  ...,  1.9922e+00,
           1.0938e+00, -4.3750e+00]],

        [[-1.4221e-02,  1.0681e-02,  2.6398e-03,  ...,  5.9082e-02,
          -9.4727e-02, -1.1328e-01],
         [ 9.3359e-01, -9.9219e-01,  3.9062e-01,  ..., -4.4727e-01,
          -3.0625e+00, -1.9727e-01],
         [-1.7969e-01, -1.2266e+00,  8.1641e-01,  ...,  1.0234e+00,
          -3.5938e-01,  8.9062e-01],
         ...,
         [-9.2578e-01,  2.4219e-01, -1.3281e+00,  ...,  1.7344e+00,
           3.0273e-01,  1.4766e+00],
         [-2.3633e-01,  3.8672e-01, -3.7695e-01,  ...,  7.6953e-01,
           4.2969e-01,  1.3281e+00],
         [ 4.1016e-01,  1.4375e+00,  2.0703e-01,  ...,  1.9922e+00,
           1.0938e+00, -4.3750e+00]],

        [[-1.4221e-02,  1.0681e-02,  2.6398e-03,  ...,  5.9082e-02,
          -9.4727e-02, -1.1328e-01],
         [ 9.3359e-01, -9.9219e-01,  3.9062e-01,  ..., -4.4727e-01,
          -3.0625e+00, -1.9727e-01],
         [-1.7969e-01, -1.2266e+00,  8.1641e-01,  ...,  1.0234e+00,
          -3.5938e-01,  8.9062e-01],
         ...,
         [-9.2578e-01,  2.4219e-01, -1.3281e+00,  ...,  1.7344e+00,
           3.0273e-01,  1.4766e+00],
         [-2.3633e-01,  3.8672e-01, -3.7695e-01,  ...,  7.6953e-01,
           4.2969e-01,  1.3281e+00],
         [ 4.1016e-01,  1.4375e+00,  2.0703e-01,  ...,  1.9922e+00,
           1.0938e+00, -4.3750e+00]]], dtype=torch.bfloat16), tensor([[[ 9.8877e-03,  2.3193e-03,  7.3547e-03,  ..., -1.6357e-02,
           8.9722e-03, -5.0781e-02],
         [ 2.2461e-01, -3.2422e-01, -6.4062e-01,  ...,  3.5938e-01,
          -6.1328e-01, -2.9053e-02],
         [ 1.3086e-01, -2.3633e-01,  6.6406e-01,  ...,  6.1719e-01,
          -1.0625e+00,  4.8438e-01],
         ...,
         [ 1.2109e+00,  1.9684e-03, -2.0801e-01,  ...,  5.1953e-01,
          -6.2891e-01,  1.4160e-01],
         [ 1.0234e+00,  2.7148e-01, -6.6797e-01,  ...,  1.7456e-02,
          -6.4062e-01,  1.2422e+00],
         [ 4.0430e-01,  5.1172e-01, -5.2734e-01,  ..., -1.9922e-01,
          -7.1484e-01,  1.4453e+00]],

        [[ 9.8877e-03,  2.3193e-03,  7.3547e-03,  ..., -1.6357e-02,
           8.9722e-03, -5.0781e-02],
         [ 2.2461e-01, -3.2422e-01, -6.4062e-01,  ...,  3.5938e-01,
          -6.1328e-01, -2.9053e-02],
         [ 1.3086e-01, -2.3633e-01,  6.6406e-01,  ...,  6.1719e-01,
          -1.0625e+00,  4.8438e-01],
         ...,
         [ 1.2109e+00,  1.9684e-03, -2.0801e-01,  ...,  5.1953e-01,
          -6.2891e-01,  1.4160e-01],
         [ 1.0234e+00,  2.7148e-01, -6.6797e-01,  ...,  1.7456e-02,
          -6.4062e-01,  1.2422e+00],
         [ 4.0430e-01,  5.1172e-01, -5.2734e-01,  ..., -1.9922e-01,
          -7.1484e-01,  1.4453e+00]],

        [[ 9.8877e-03,  2.3193e-03,  7.3547e-03,  ..., -1.6357e-02,
           8.9722e-03, -5.0781e-02],
         [ 2.2461e-01, -3.2422e-01, -6.4062e-01,  ...,  3.5938e-01,
          -6.1328e-01, -2.9053e-02],
         [ 1.3086e-01, -2.3633e-01,  6.6406e-01,  ...,  6.1719e-01,
          -1.0625e+00,  4.8438e-01],
         ...,
         [ 1.2109e+00,  1.9684e-03, -2.0801e-01,  ...,  5.1953e-01,
          -6.2891e-01,  1.4160e-01],
         [ 1.0234e+00,  2.7148e-01, -6.6797e-01,  ...,  1.7456e-02,
          -6.4062e-01,  1.2422e+00],
         [ 4.0430e-01,  5.1172e-01, -5.2734e-01,  ..., -1.9922e-01,
          -7.1484e-01,  1.4453e+00]],

        ...,

        [[ 3.3569e-03, -6.7139e-03, -8.9264e-04,  ...,  1.3046e-03,
           2.6703e-03,  5.0659e-03],
         [-2.7812e+00, -3.8867e-01, -9.0234e-01,  ..., -2.7969e+00,
          -6.2109e-01, -2.7500e+00],
         [-1.2578e+00,  3.4961e-01, -1.9531e-01,  ..., -3.0078e-01,
           8.3984e-01, -1.6484e+00],
         ...,
         [ 3.9307e-02,  7.9346e-03, -2.7539e-01,  ...,  3.5938e-01,
           3.8281e-01, -6.0156e-01],
         [-3.7598e-02,  4.1602e-01, -2.5195e-01,  ..., -1.2988e-01,
          -4.0283e-02, -7.0312e-01],
         [-7.5391e-01,  3.6377e-02,  1.1641e+00,  ...,  9.5703e-01,
           1.5781e+00,  2.9883e-01]],

        [[ 3.3569e-03, -6.7139e-03, -8.9264e-04,  ...,  1.3046e-03,
           2.6703e-03,  5.0659e-03],
         [-2.7812e+00, -3.8867e-01, -9.0234e-01,  ..., -2.7969e+00,
          -6.2109e-01, -2.7500e+00],
         [-1.2578e+00,  3.4961e-01, -1.9531e-01,  ..., -3.0078e-01,
           8.3984e-01, -1.6484e+00],
         ...,
         [ 3.9307e-02,  7.9346e-03, -2.7539e-01,  ...,  3.5938e-01,
           3.8281e-01, -6.0156e-01],
         [-3.7598e-02,  4.1602e-01, -2.5195e-01,  ..., -1.2988e-01,
          -4.0283e-02, -7.0312e-01],
         [-7.5391e-01,  3.6377e-02,  1.1641e+00,  ...,  9.5703e-01,
           1.5781e+00,  2.9883e-01]],

        [[ 3.3569e-03, -6.7139e-03, -8.9264e-04,  ...,  1.3046e-03,
           2.6703e-03,  5.0659e-03],
         [-2.7812e+00, -3.8867e-01, -9.0234e-01,  ..., -2.7969e+00,
          -6.2109e-01, -2.7500e+00],
         [-1.2578e+00,  3.4961e-01, -1.9531e-01,  ..., -3.0078e-01,
           8.3984e-01, -1.6484e+00],
         ...,
         [ 3.9307e-02,  7.9346e-03, -2.7539e-01,  ...,  3.5938e-01,
           3.8281e-01, -6.0156e-01],
         [-3.7598e-02,  4.1602e-01, -2.5195e-01,  ..., -1.2988e-01,
          -4.0283e-02, -7.0312e-01],
         [-7.5391e-01,  3.6377e-02,  1.1641e+00,  ...,  9.5703e-01,
           1.5781e+00,  2.9883e-01]]], dtype=torch.bfloat16)), (tensor([[[-1.6724e-02, -4.3030e-03,  4.7607e-03,  ...,  1.0352e-01,
           8.0566e-02,  3.9062e-01],
         [-4.0625e-01, -7.0312e-01, -8.1250e-01,  ..., -2.0781e+00,
          -1.4922e+00, -6.7812e+00],
         [-1.4688e+00, -1.0312e+00, -9.8828e-01,  ..., -9.4531e-01,
          -2.5469e+00, -6.0625e+00],
         ...,
         [-3.4375e-01,  7.0312e-01,  4.0234e-01,  ..., -1.2656e+00,
          -2.8125e+00, -3.6094e+00],
         [ 1.1719e-02,  1.0156e+00, -1.0938e-01,  ..., -1.5156e+00,
          -2.9375e+00, -5.5000e+00],
         [ 3.3594e-01,  1.7266e+00, -1.5234e+00,  ..., -1.0000e+00,
          -3.4375e+00, -5.5938e+00]],

        [[-1.6724e-02, -4.3030e-03,  4.7607e-03,  ...,  1.0352e-01,
           8.0566e-02,  3.9062e-01],
         [-4.0625e-01, -7.0312e-01, -8.1250e-01,  ..., -2.0781e+00,
          -1.4922e+00, -6.7812e+00],
         [-1.4688e+00, -1.0312e+00, -9.8828e-01,  ..., -9.4531e-01,
          -2.5469e+00, -6.0625e+00],
         ...,
         [-3.4375e-01,  7.0312e-01,  4.0234e-01,  ..., -1.2656e+00,
          -2.8125e+00, -3.6094e+00],
         [ 1.1719e-02,  1.0156e+00, -1.0938e-01,  ..., -1.5156e+00,
          -2.9375e+00, -5.5000e+00],
         [ 3.3594e-01,  1.7266e+00, -1.5234e+00,  ..., -1.0000e+00,
          -3.4375e+00, -5.5938e+00]],

        [[-1.6724e-02, -4.3030e-03,  4.7607e-03,  ...,  1.0352e-01,
           8.0566e-02,  3.9062e-01],
         [-4.0625e-01, -7.0312e-01, -8.1250e-01,  ..., -2.0781e+00,
          -1.4922e+00, -6.7812e+00],
         [-1.4688e+00, -1.0312e+00, -9.8828e-01,  ..., -9.4531e-01,
          -2.5469e+00, -6.0625e+00],
         ...,
         [-3.4375e-01,  7.0312e-01,  4.0234e-01,  ..., -1.2656e+00,
          -2.8125e+00, -3.6094e+00],
         [ 1.1719e-02,  1.0156e+00, -1.0938e-01,  ..., -1.5156e+00,
          -2.9375e+00, -5.5000e+00],
         [ 3.3594e-01,  1.7266e+00, -1.5234e+00,  ..., -1.0000e+00,
          -3.4375e+00, -5.5938e+00]],

        ...,

        [[-5.0354e-04,  4.9744e-03,  1.7471e-03,  ...,  4.1504e-02,
          -3.3203e-02,  4.9316e-02],
         [ 9.1016e-01, -8.8672e-01, -8.7891e-01,  ..., -1.8828e+00,
           1.9297e+00,  2.3047e-01],
         [-1.1406e+00,  1.6895e-01, -6.6797e-01,  ..., -1.1094e+00,
           4.7852e-01,  2.0312e+00],
         ...,
         [-9.3359e-01,  7.5391e-01,  9.0820e-02,  ...,  1.0156e+00,
           1.3047e+00,  7.8613e-02],
         [-3.0469e-01,  2.6367e-01, -7.8125e-01,  ..., -6.3672e-01,
           9.7266e-01, -8.7500e-01],
         [ 1.0625e+00,  2.0938e+00, -1.1484e+00,  ..., -9.0820e-02,
           1.6562e+00,  2.3594e+00]],

        [[-5.0354e-04,  4.9744e-03,  1.7471e-03,  ...,  4.1504e-02,
          -3.3203e-02,  4.9316e-02],
         [ 9.1016e-01, -8.8672e-01, -8.7891e-01,  ..., -1.8828e+00,
           1.9297e+00,  2.3047e-01],
         [-1.1406e+00,  1.6895e-01, -6.6797e-01,  ..., -1.1094e+00,
           4.7852e-01,  2.0312e+00],
         ...,
         [-9.3359e-01,  7.5391e-01,  9.0820e-02,  ...,  1.0156e+00,
           1.3047e+00,  7.8613e-02],
         [-3.0469e-01,  2.6367e-01, -7.8125e-01,  ..., -6.3672e-01,
           9.7266e-01, -8.7500e-01],
         [ 1.0625e+00,  2.0938e+00, -1.1484e+00,  ..., -9.0820e-02,
           1.6562e+00,  2.3594e+00]],

        [[-5.0354e-04,  4.9744e-03,  1.7471e-03,  ...,  4.1504e-02,
          -3.3203e-02,  4.9316e-02],
         [ 9.1016e-01, -8.8672e-01, -8.7891e-01,  ..., -1.8828e+00,
           1.9297e+00,  2.3047e-01],
         [-1.1406e+00,  1.6895e-01, -6.6797e-01,  ..., -1.1094e+00,
           4.7852e-01,  2.0312e+00],
         ...,
         [-9.3359e-01,  7.5391e-01,  9.0820e-02,  ...,  1.0156e+00,
           1.3047e+00,  7.8613e-02],
         [-3.0469e-01,  2.6367e-01, -7.8125e-01,  ..., -6.3672e-01,
           9.7266e-01, -8.7500e-01],
         [ 1.0625e+00,  2.0938e+00, -1.1484e+00,  ..., -9.0820e-02,
           1.6562e+00,  2.3594e+00]]], dtype=torch.bfloat16), tensor([[[ 1.1902e-03, -6.4392e-03,  8.5449e-03,  ...,  5.4626e-03,
           7.1716e-03,  7.1106e-03],
         [ 5.1562e-01,  1.0625e+00, -7.3438e-01,  ...,  1.2266e+00,
          -1.7422e+00,  6.5625e-01],
         [ 3.5156e-01, -6.4844e-01, -9.4531e-01,  ...,  4.4727e-01,
          -1.4531e+00,  6.0938e-01],
         ...,
         [ 2.8076e-02,  1.6406e+00, -2.3926e-01,  ..., -1.0703e+00,
          -1.5234e+00,  4.5508e-01],
         [ 3.6523e-01,  1.3184e-01,  4.9316e-02,  ...,  1.6211e-01,
          -3.7305e-01,  5.3516e-01],
         [ 1.1875e+00,  1.4648e-01, -1.1406e+00,  ...,  7.0801e-02,
           4.4189e-02, -7.7148e-02]],

        [[ 1.1902e-03, -6.4392e-03,  8.5449e-03,  ...,  5.4626e-03,
           7.1716e-03,  7.1106e-03],
         [ 5.1562e-01,  1.0625e+00, -7.3438e-01,  ...,  1.2266e+00,
          -1.7422e+00,  6.5625e-01],
         [ 3.5156e-01, -6.4844e-01, -9.4531e-01,  ...,  4.4727e-01,
          -1.4531e+00,  6.0938e-01],
         ...,
         [ 2.8076e-02,  1.6406e+00, -2.3926e-01,  ..., -1.0703e+00,
          -1.5234e+00,  4.5508e-01],
         [ 3.6523e-01,  1.3184e-01,  4.9316e-02,  ...,  1.6211e-01,
          -3.7305e-01,  5.3516e-01],
         [ 1.1875e+00,  1.4648e-01, -1.1406e+00,  ...,  7.0801e-02,
           4.4189e-02, -7.7148e-02]],

        [[ 1.1902e-03, -6.4392e-03,  8.5449e-03,  ...,  5.4626e-03,
           7.1716e-03,  7.1106e-03],
         [ 5.1562e-01,  1.0625e+00, -7.3438e-01,  ...,  1.2266e+00,
          -1.7422e+00,  6.5625e-01],
         [ 3.5156e-01, -6.4844e-01, -9.4531e-01,  ...,  4.4727e-01,
          -1.4531e+00,  6.0938e-01],
         ...,
         [ 2.8076e-02,  1.6406e+00, -2.3926e-01,  ..., -1.0703e+00,
          -1.5234e+00,  4.5508e-01],
         [ 3.6523e-01,  1.3184e-01,  4.9316e-02,  ...,  1.6211e-01,
          -3.7305e-01,  5.3516e-01],
         [ 1.1875e+00,  1.4648e-01, -1.1406e+00,  ...,  7.0801e-02,
           4.4189e-02, -7.7148e-02]],

        ...,

        [[-5.5847e-03, -3.2501e-03,  4.8523e-03,  ...,  1.0010e-02,
          -5.5542e-03,  7.4768e-03],
         [ 5.9375e-01, -1.4160e-01, -1.7383e-01,  ...,  2.6758e-01,
          -4.1992e-01,  2.2949e-01],
         [-2.5781e-01,  1.6562e+00,  1.6172e+00,  ...,  1.5781e+00,
           7.8516e-01,  3.4375e-01],
         ...,
         [-1.1719e+00,  9.8438e-01, -1.4375e+00,  ..., -8.8281e-01,
           1.0078e+00, -5.7031e-01],
         [-2.3730e-01,  2.6367e-01, -3.3398e-01,  ...,  2.0605e-01,
           8.3594e-01,  4.0234e-01],
         [-3.6719e-01, -1.4609e+00,  2.3242e-01,  ..., -1.0469e+00,
           3.7695e-01, -2.0703e-01]],

        [[-5.5847e-03, -3.2501e-03,  4.8523e-03,  ...,  1.0010e-02,
          -5.5542e-03,  7.4768e-03],
         [ 5.9375e-01, -1.4160e-01, -1.7383e-01,  ...,  2.6758e-01,
          -4.1992e-01,  2.2949e-01],
         [-2.5781e-01,  1.6562e+00,  1.6172e+00,  ...,  1.5781e+00,
           7.8516e-01,  3.4375e-01],
         ...,
         [-1.1719e+00,  9.8438e-01, -1.4375e+00,  ..., -8.8281e-01,
           1.0078e+00, -5.7031e-01],
         [-2.3730e-01,  2.6367e-01, -3.3398e-01,  ...,  2.0605e-01,
           8.3594e-01,  4.0234e-01],
         [-3.6719e-01, -1.4609e+00,  2.3242e-01,  ..., -1.0469e+00,
           3.7695e-01, -2.0703e-01]],

        [[-5.5847e-03, -3.2501e-03,  4.8523e-03,  ...,  1.0010e-02,
          -5.5542e-03,  7.4768e-03],
         [ 5.9375e-01, -1.4160e-01, -1.7383e-01,  ...,  2.6758e-01,
          -4.1992e-01,  2.2949e-01],
         [-2.5781e-01,  1.6562e+00,  1.6172e+00,  ...,  1.5781e+00,
           7.8516e-01,  3.4375e-01],
         ...,
         [-1.1719e+00,  9.8438e-01, -1.4375e+00,  ..., -8.8281e-01,
           1.0078e+00, -5.7031e-01],
         [-2.3730e-01,  2.6367e-01, -3.3398e-01,  ...,  2.0605e-01,
           8.3594e-01,  4.0234e-01],
         [-3.6719e-01, -1.4609e+00,  2.3242e-01,  ..., -1.0469e+00,
           3.7695e-01, -2.0703e-01]]], dtype=torch.bfloat16)), (tensor([[[ 2.4414e-03,  9.9487e-03,  9.2773e-03,  ..., -2.6367e-01,
           4.9316e-02,  1.9043e-02],
         [-8.3594e-01, -7.3828e-01,  3.5352e-01,  ...,  4.2812e+00,
          -3.0469e-01, -2.1875e+00],
         [ 6.0156e-01,  1.6406e-01,  8.0078e-02,  ...,  6.3125e+00,
          -2.9883e-01,  1.1658e-02],
         ...,
         [ 1.0547e+00, -8.3594e-01, -9.4922e-01,  ...,  2.2656e+00,
           1.3281e-01,  8.7891e-01],
         [-7.4219e-02, -5.0781e-01, -8.7500e-01,  ...,  3.4688e+00,
           3.3398e-01,  8.4473e-02],
         [-1.1172e+00, -8.5156e-01, -6.6406e-01,  ...,  3.9375e+00,
          -9.5703e-01, -3.9062e-01]],

        [[ 2.4414e-03,  9.9487e-03,  9.2773e-03,  ..., -2.6367e-01,
           4.9316e-02,  1.9043e-02],
         [-8.3594e-01, -7.3828e-01,  3.5352e-01,  ...,  4.2812e+00,
          -3.0469e-01, -2.1875e+00],
         [ 6.0156e-01,  1.6406e-01,  8.0078e-02,  ...,  6.3125e+00,
          -2.9883e-01,  1.1658e-02],
         ...,
         [ 1.0547e+00, -8.3594e-01, -9.4922e-01,  ...,  2.2656e+00,
           1.3281e-01,  8.7891e-01],
         [-7.4219e-02, -5.0781e-01, -8.7500e-01,  ...,  3.4688e+00,
           3.3398e-01,  8.4473e-02],
         [-1.1172e+00, -8.5156e-01, -6.6406e-01,  ...,  3.9375e+00,
          -9.5703e-01, -3.9062e-01]],

        [[ 2.4414e-03,  9.9487e-03,  9.2773e-03,  ..., -2.6367e-01,
           4.9316e-02,  1.9043e-02],
         [-8.3594e-01, -7.3828e-01,  3.5352e-01,  ...,  4.2812e+00,
          -3.0469e-01, -2.1875e+00],
         [ 6.0156e-01,  1.6406e-01,  8.0078e-02,  ...,  6.3125e+00,
          -2.9883e-01,  1.1658e-02],
         ...,
         [ 1.0547e+00, -8.3594e-01, -9.4922e-01,  ...,  2.2656e+00,
           1.3281e-01,  8.7891e-01],
         [-7.4219e-02, -5.0781e-01, -8.7500e-01,  ...,  3.4688e+00,
           3.3398e-01,  8.4473e-02],
         [-1.1172e+00, -8.5156e-01, -6.6406e-01,  ...,  3.9375e+00,
          -9.5703e-01, -3.9062e-01]],

        ...,

        [[-6.4850e-04,  4.4250e-03, -1.9684e-03,  ...,  9.6680e-02,
          -5.8984e-01, -5.2246e-02],
         [-2.4609e-01, -6.4844e-01, -1.1133e-01,  ...,  2.4062e+00,
           8.1250e+00,  3.2188e+00],
         [-8.2031e-02,  6.2500e-01, -2.9492e-01,  ..., -3.3750e+00,
           1.0375e+01,  4.1406e-01],
         ...,
         [-7.4609e-01, -3.5547e-01, -2.1875e-01,  ...,  3.2812e-01,
           8.5000e+00,  5.8125e+00],
         [-9.0234e-01, -2.9492e-01, -6.7969e-01,  ..., -1.2109e+00,
           7.6562e+00,  1.9141e+00],
         [ 5.4688e-01, -8.9844e-01, -3.2227e-01,  ...,  3.7305e-01,
           9.4375e+00, -5.3750e+00]],

        [[-6.4850e-04,  4.4250e-03, -1.9684e-03,  ...,  9.6680e-02,
          -5.8984e-01, -5.2246e-02],
         [-2.4609e-01, -6.4844e-01, -1.1133e-01,  ...,  2.4062e+00,
           8.1250e+00,  3.2188e+00],
         [-8.2031e-02,  6.2500e-01, -2.9492e-01,  ..., -3.3750e+00,
           1.0375e+01,  4.1406e-01],
         ...,
         [-7.4609e-01, -3.5547e-01, -2.1875e-01,  ...,  3.2812e-01,
           8.5000e+00,  5.8125e+00],
         [-9.0234e-01, -2.9492e-01, -6.7969e-01,  ..., -1.2109e+00,
           7.6562e+00,  1.9141e+00],
         [ 5.4688e-01, -8.9844e-01, -3.2227e-01,  ...,  3.7305e-01,
           9.4375e+00, -5.3750e+00]],

        [[-6.4850e-04,  4.4250e-03, -1.9684e-03,  ...,  9.6680e-02,
          -5.8984e-01, -5.2246e-02],
         [-2.4609e-01, -6.4844e-01, -1.1133e-01,  ...,  2.4062e+00,
           8.1250e+00,  3.2188e+00],
         [-8.2031e-02,  6.2500e-01, -2.9492e-01,  ..., -3.3750e+00,
           1.0375e+01,  4.1406e-01],
         ...,
         [-7.4609e-01, -3.5547e-01, -2.1875e-01,  ...,  3.2812e-01,
           8.5000e+00,  5.8125e+00],
         [-9.0234e-01, -2.9492e-01, -6.7969e-01,  ..., -1.2109e+00,
           7.6562e+00,  1.9141e+00],
         [ 5.4688e-01, -8.9844e-01, -3.2227e-01,  ...,  3.7305e-01,
           9.4375e+00, -5.3750e+00]]], dtype=torch.bfloat16), tensor([[[-4.2114e-03,  1.7700e-03,  8.1787e-03,  ...,  2.0599e-03,
          -3.2616e-04, -2.4414e-03],
         [ 2.2031e+00, -1.1377e-01, -2.0156e+00,  ...,  1.6719e+00,
          -4.4375e+00, -2.3594e+00],
         [ 5.7812e-01, -7.9297e-01, -7.2266e-01,  ...,  6.5430e-02,
           8.1543e-02, -1.7891e+00],
         ...,
         [ 9.1406e-01, -1.8672e+00, -6.3281e-01,  ..., -3.0469e-01,
          -6.7383e-02, -3.3203e-01],
         [ 1.0938e+00,  3.7891e-01, -1.0498e-01,  ..., -3.2422e-01,
           1.5137e-01, -4.5300e-05],
         [ 4.9219e-01,  2.0312e+00,  2.0469e+00,  ...,  1.5391e+00,
          -2.3730e-01, -2.8320e-01]],

        [[-4.2114e-03,  1.7700e-03,  8.1787e-03,  ...,  2.0599e-03,
          -3.2616e-04, -2.4414e-03],
         [ 2.2031e+00, -1.1377e-01, -2.0156e+00,  ...,  1.6719e+00,
          -4.4375e+00, -2.3594e+00],
         [ 5.7812e-01, -7.9297e-01, -7.2266e-01,  ...,  6.5430e-02,
           8.1543e-02, -1.7891e+00],
         ...,
         [ 9.1406e-01, -1.8672e+00, -6.3281e-01,  ..., -3.0469e-01,
          -6.7383e-02, -3.3203e-01],
         [ 1.0938e+00,  3.7891e-01, -1.0498e-01,  ..., -3.2422e-01,
           1.5137e-01, -4.5300e-05],
         [ 4.9219e-01,  2.0312e+00,  2.0469e+00,  ...,  1.5391e+00,
          -2.3730e-01, -2.8320e-01]],

        [[-4.2114e-03,  1.7700e-03,  8.1787e-03,  ...,  2.0599e-03,
          -3.2616e-04, -2.4414e-03],
         [ 2.2031e+00, -1.1377e-01, -2.0156e+00,  ...,  1.6719e+00,
          -4.4375e+00, -2.3594e+00],
         [ 5.7812e-01, -7.9297e-01, -7.2266e-01,  ...,  6.5430e-02,
           8.1543e-02, -1.7891e+00],
         ...,
         [ 9.1406e-01, -1.8672e+00, -6.3281e-01,  ..., -3.0469e-01,
          -6.7383e-02, -3.3203e-01],
         [ 1.0938e+00,  3.7891e-01, -1.0498e-01,  ..., -3.2422e-01,
           1.5137e-01, -4.5300e-05],
         [ 4.9219e-01,  2.0312e+00,  2.0469e+00,  ...,  1.5391e+00,
          -2.3730e-01, -2.8320e-01]],

        ...,

        [[ 1.8311e-02, -2.3499e-03,  1.9455e-04,  ..., -1.7334e-02,
           4.1389e-04,  4.0283e-03],
         [ 2.8125e-01,  3.6914e-01, -3.3398e-01,  ...,  1.3672e+00,
           4.8047e-01,  6.4062e-01],
         [-1.0391e+00, -1.2695e-01,  4.0039e-01,  ...,  5.3467e-02,
           1.9043e-01,  2.7148e-01],
         ...,
         [ 4.4727e-01,  2.8320e-01,  5.4688e-01,  ..., -5.1562e-01,
          -4.1211e-01,  5.3516e-01],
         [-1.3770e-01,  5.0391e-01,  1.8945e-01,  ...,  2.1289e-01,
           6.9885e-03, -1.1084e-01],
         [-7.6172e-01,  1.3438e+00, -1.7700e-02,  ..., -1.0469e+00,
           1.3828e+00,  2.2949e-02]],

        [[ 1.8311e-02, -2.3499e-03,  1.9455e-04,  ..., -1.7334e-02,
           4.1389e-04,  4.0283e-03],
         [ 2.8125e-01,  3.6914e-01, -3.3398e-01,  ...,  1.3672e+00,
           4.8047e-01,  6.4062e-01],
         [-1.0391e+00, -1.2695e-01,  4.0039e-01,  ...,  5.3467e-02,
           1.9043e-01,  2.7148e-01],
         ...,
         [ 4.4727e-01,  2.8320e-01,  5.4688e-01,  ..., -5.1562e-01,
          -4.1211e-01,  5.3516e-01],
         [-1.3770e-01,  5.0391e-01,  1.8945e-01,  ...,  2.1289e-01,
           6.9885e-03, -1.1084e-01],
         [-7.6172e-01,  1.3438e+00, -1.7700e-02,  ..., -1.0469e+00,
           1.3828e+00,  2.2949e-02]],

        [[ 1.8311e-02, -2.3499e-03,  1.9455e-04,  ..., -1.7334e-02,
           4.1389e-04,  4.0283e-03],
         [ 2.8125e-01,  3.6914e-01, -3.3398e-01,  ...,  1.3672e+00,
           4.8047e-01,  6.4062e-01],
         [-1.0391e+00, -1.2695e-01,  4.0039e-01,  ...,  5.3467e-02,
           1.9043e-01,  2.7148e-01],
         ...,
         [ 4.4727e-01,  2.8320e-01,  5.4688e-01,  ..., -5.1562e-01,
          -4.1211e-01,  5.3516e-01],
         [-1.3770e-01,  5.0391e-01,  1.8945e-01,  ...,  2.1289e-01,
           6.9885e-03, -1.1084e-01],
         [-7.6172e-01,  1.3438e+00, -1.7700e-02,  ..., -1.0469e+00,
           1.3828e+00,  2.2949e-02]]], dtype=torch.bfloat16)), (tensor([[[-3.4943e-03,  2.4414e-02,  5.6839e-04,  ...,  5.4932e-02,
          -1.2354e-01, -2.4023e-01],
         [-1.8359e-01, -8.9844e-02, -1.6953e+00,  ..., -1.0156e+00,
           3.7188e+00,  4.8047e-01],
         [-1.4062e+00, -1.9219e+00, -7.4609e-01,  ...,  1.8672e+00,
          -7.1716e-03,  3.0625e+00],
         ...,
         [ 1.8164e-01, -3.8086e-02, -7.7148e-02,  ..., -1.6328e+00,
           7.0703e-01,  1.2344e+00],
         [ 2.1094e-01,  3.1250e-01, -5.7031e-01,  ..., -6.0156e-01,
           7.6953e-01,  2.1875e+00],
         [ 7.4219e-01,  8.6719e-01, -5.8594e-01,  ...,  1.5527e-01,
           8.8281e-01,  8.6914e-02]],

        [[-3.4943e-03,  2.4414e-02,  5.6839e-04,  ...,  5.4932e-02,
          -1.2354e-01, -2.4023e-01],
         [-1.8359e-01, -8.9844e-02, -1.6953e+00,  ..., -1.0156e+00,
           3.7188e+00,  4.8047e-01],
         [-1.4062e+00, -1.9219e+00, -7.4609e-01,  ...,  1.8672e+00,
          -7.1716e-03,  3.0625e+00],
         ...,
         [ 1.8164e-01, -3.8086e-02, -7.7148e-02,  ..., -1.6328e+00,
           7.0703e-01,  1.2344e+00],
         [ 2.1094e-01,  3.1250e-01, -5.7031e-01,  ..., -6.0156e-01,
           7.6953e-01,  2.1875e+00],
         [ 7.4219e-01,  8.6719e-01, -5.8594e-01,  ...,  1.5527e-01,
           8.8281e-01,  8.6914e-02]],

        [[-3.4943e-03,  2.4414e-02,  5.6839e-04,  ...,  5.4932e-02,
          -1.2354e-01, -2.4023e-01],
         [-1.8359e-01, -8.9844e-02, -1.6953e+00,  ..., -1.0156e+00,
           3.7188e+00,  4.8047e-01],
         [-1.4062e+00, -1.9219e+00, -7.4609e-01,  ...,  1.8672e+00,
          -7.1716e-03,  3.0625e+00],
         ...,
         [ 1.8164e-01, -3.8086e-02, -7.7148e-02,  ..., -1.6328e+00,
           7.0703e-01,  1.2344e+00],
         [ 2.1094e-01,  3.1250e-01, -5.7031e-01,  ..., -6.0156e-01,
           7.6953e-01,  2.1875e+00],
         [ 7.4219e-01,  8.6719e-01, -5.8594e-01,  ...,  1.5527e-01,
           8.8281e-01,  8.6914e-02]],

        ...,

        [[ 1.0729e-04,  8.3008e-03, -1.4099e-02,  ...,  5.6763e-03,
          -5.2979e-02,  1.0864e-02],
         [-3.2031e-01, -1.0498e-01,  6.7188e-01,  ..., -9.3750e-01,
           4.4688e+00, -1.2344e+00],
         [-5.8203e-01, -7.8906e-01,  1.0156e+00,  ...,  1.8516e+00,
           2.5312e+00, -1.0547e+00],
         ...,
         [-2.4414e-01,  3.3008e-01,  3.3398e-01,  ...,  6.2012e-02,
           1.0859e+00, -2.2656e+00],
         [ 2.3438e-01,  3.1836e-01,  3.5547e-01,  ...,  3.4961e-01,
           2.1250e+00, -1.1621e-01],
         [ 8.1250e-01,  8.7891e-01,  1.1406e+00,  ..., -1.0703e+00,
           2.3242e-01, -5.5078e-01]],

        [[ 1.0729e-04,  8.3008e-03, -1.4099e-02,  ...,  5.6763e-03,
          -5.2979e-02,  1.0864e-02],
         [-3.2031e-01, -1.0498e-01,  6.7188e-01,  ..., -9.3750e-01,
           4.4688e+00, -1.2344e+00],
         [-5.8203e-01, -7.8906e-01,  1.0156e+00,  ...,  1.8516e+00,
           2.5312e+00, -1.0547e+00],
         ...,
         [-2.4414e-01,  3.3008e-01,  3.3398e-01,  ...,  6.2012e-02,
           1.0859e+00, -2.2656e+00],
         [ 2.3438e-01,  3.1836e-01,  3.5547e-01,  ...,  3.4961e-01,
           2.1250e+00, -1.1621e-01],
         [ 8.1250e-01,  8.7891e-01,  1.1406e+00,  ..., -1.0703e+00,
           2.3242e-01, -5.5078e-01]],

        [[ 1.0729e-04,  8.3008e-03, -1.4099e-02,  ...,  5.6763e-03,
          -5.2979e-02,  1.0864e-02],
         [-3.2031e-01, -1.0498e-01,  6.7188e-01,  ..., -9.3750e-01,
           4.4688e+00, -1.2344e+00],
         [-5.8203e-01, -7.8906e-01,  1.0156e+00,  ...,  1.8516e+00,
           2.5312e+00, -1.0547e+00],
         ...,
         [-2.4414e-01,  3.3008e-01,  3.3398e-01,  ...,  6.2012e-02,
           1.0859e+00, -2.2656e+00],
         [ 2.3438e-01,  3.1836e-01,  3.5547e-01,  ...,  3.4961e-01,
           2.1250e+00, -1.1621e-01],
         [ 8.1250e-01,  8.7891e-01,  1.1406e+00,  ..., -1.0703e+00,
           2.3242e-01, -5.5078e-01]]], dtype=torch.bfloat16), tensor([[[ 4.5776e-03,  1.1414e-02,  6.9275e-03,  ...,  3.7994e-03,
          -2.7161e-03, -2.9755e-03],
         [ 1.0469e+00, -4.7812e+00, -1.7734e+00,  ..., -2.7188e+00,
           1.1641e+00,  6.6016e-01],
         [ 6.2891e-01, -1.8359e+00, -9.1797e-01,  ...,  1.0010e-02,
           1.4453e+00,  1.7266e+00],
         ...,
         [-3.3984e-01, -1.1250e+00, -1.0078e+00,  ...,  1.2891e+00,
           1.0391e+00,  1.1562e+00],
         [ 1.3477e-01, -2.9053e-02, -8.8281e-01,  ...,  6.7188e-01,
           8.5547e-01,  5.8594e-01],
         [-6.9922e-01,  2.9102e-01,  1.0312e+00,  ..., -4.5508e-01,
           4.6094e-01,  1.3125e+00]],

        [[ 4.5776e-03,  1.1414e-02,  6.9275e-03,  ...,  3.7994e-03,
          -2.7161e-03, -2.9755e-03],
         [ 1.0469e+00, -4.7812e+00, -1.7734e+00,  ..., -2.7188e+00,
           1.1641e+00,  6.6016e-01],
         [ 6.2891e-01, -1.8359e+00, -9.1797e-01,  ...,  1.0010e-02,
           1.4453e+00,  1.7266e+00],
         ...,
         [-3.3984e-01, -1.1250e+00, -1.0078e+00,  ...,  1.2891e+00,
           1.0391e+00,  1.1562e+00],
         [ 1.3477e-01, -2.9053e-02, -8.8281e-01,  ...,  6.7188e-01,
           8.5547e-01,  5.8594e-01],
         [-6.9922e-01,  2.9102e-01,  1.0312e+00,  ..., -4.5508e-01,
           4.6094e-01,  1.3125e+00]],

        [[ 4.5776e-03,  1.1414e-02,  6.9275e-03,  ...,  3.7994e-03,
          -2.7161e-03, -2.9755e-03],
         [ 1.0469e+00, -4.7812e+00, -1.7734e+00,  ..., -2.7188e+00,
           1.1641e+00,  6.6016e-01],
         [ 6.2891e-01, -1.8359e+00, -9.1797e-01,  ...,  1.0010e-02,
           1.4453e+00,  1.7266e+00],
         ...,
         [-3.3984e-01, -1.1250e+00, -1.0078e+00,  ...,  1.2891e+00,
           1.0391e+00,  1.1562e+00],
         [ 1.3477e-01, -2.9053e-02, -8.8281e-01,  ...,  6.7188e-01,
           8.5547e-01,  5.8594e-01],
         [-6.9922e-01,  2.9102e-01,  1.0312e+00,  ..., -4.5508e-01,
           4.6094e-01,  1.3125e+00]],

        ...,

        [[-3.8757e-03,  1.6327e-03,  7.3242e-03,  ..., -1.7822e-02,
          -6.9336e-02,  4.8447e-04],
         [-1.7344e+00,  4.6094e-01, -3.1875e+00,  ..., -2.4609e-01,
           2.3125e+00, -3.9688e+00],
         [ 8.3984e-01,  1.3281e+00, -1.0938e+00,  ...,  2.6953e-01,
          -2.3145e-01, -3.3203e-01],
         ...,
         [ 6.2500e-01, -6.8359e-01,  5.4297e-01,  ..., -7.5391e-01,
          -7.6172e-01,  9.2578e-01],
         [ 7.4219e-01,  4.6289e-01, -2.3535e-01,  ...,  4.4922e-01,
           1.9453e+00,  4.7070e-01],
         [ 7.8516e-01,  3.6914e-01, -2.0781e+00,  ...,  1.9375e+00,
           1.0234e+00,  1.1797e+00]],

        [[-3.8757e-03,  1.6327e-03,  7.3242e-03,  ..., -1.7822e-02,
          -6.9336e-02,  4.8447e-04],
         [-1.7344e+00,  4.6094e-01, -3.1875e+00,  ..., -2.4609e-01,
           2.3125e+00, -3.9688e+00],
         [ 8.3984e-01,  1.3281e+00, -1.0938e+00,  ...,  2.6953e-01,
          -2.3145e-01, -3.3203e-01],
         ...,
         [ 6.2500e-01, -6.8359e-01,  5.4297e-01,  ..., -7.5391e-01,
          -7.6172e-01,  9.2578e-01],
         [ 7.4219e-01,  4.6289e-01, -2.3535e-01,  ...,  4.4922e-01,
           1.9453e+00,  4.7070e-01],
         [ 7.8516e-01,  3.6914e-01, -2.0781e+00,  ...,  1.9375e+00,
           1.0234e+00,  1.1797e+00]],

        [[-3.8757e-03,  1.6327e-03,  7.3242e-03,  ..., -1.7822e-02,
          -6.9336e-02,  4.8447e-04],
         [-1.7344e+00,  4.6094e-01, -3.1875e+00,  ..., -2.4609e-01,
           2.3125e+00, -3.9688e+00],
         [ 8.3984e-01,  1.3281e+00, -1.0938e+00,  ...,  2.6953e-01,
          -2.3145e-01, -3.3203e-01],
         ...,
         [ 6.2500e-01, -6.8359e-01,  5.4297e-01,  ..., -7.5391e-01,
          -7.6172e-01,  9.2578e-01],
         [ 7.4219e-01,  4.6289e-01, -2.3535e-01,  ...,  4.4922e-01,
           1.9453e+00,  4.7070e-01],
         [ 7.8516e-01,  3.6914e-01, -2.0781e+00,  ...,  1.9375e+00,
           1.0234e+00,  1.1797e+00]]], dtype=torch.bfloat16)), (tensor([[[-0.0067,  0.0189, -0.0097,  ...,  0.1387,  0.0334,  0.0076],
         [-0.2676, -0.6953,  1.0625,  ..., -3.4531, -0.6523,  1.7344],
         [-0.5508,  1.1797,  0.9883,  ..., -4.2188,  0.6836,  0.8125],
         ...,
         [ 0.1611,  0.0879,  1.6562,  ..., -4.1250, -1.5078, -1.5156],
         [ 0.4375, -0.8086,  1.0938,  ..., -4.0000, -1.8438, -0.5469],
         [ 0.2734, -1.9844,  1.8047,  ..., -3.6250, -0.5430,  0.0093]],

        [[-0.0067,  0.0189, -0.0097,  ...,  0.1387,  0.0334,  0.0076],
         [-0.2676, -0.6953,  1.0625,  ..., -3.4531, -0.6523,  1.7344],
         [-0.5508,  1.1797,  0.9883,  ..., -4.2188,  0.6836,  0.8125],
         ...,
         [ 0.1611,  0.0879,  1.6562,  ..., -4.1250, -1.5078, -1.5156],
         [ 0.4375, -0.8086,  1.0938,  ..., -4.0000, -1.8438, -0.5469],
         [ 0.2734, -1.9844,  1.8047,  ..., -3.6250, -0.5430,  0.0093]],

        [[-0.0067,  0.0189, -0.0097,  ...,  0.1387,  0.0334,  0.0076],
         [-0.2676, -0.6953,  1.0625,  ..., -3.4531, -0.6523,  1.7344],
         [-0.5508,  1.1797,  0.9883,  ..., -4.2188,  0.6836,  0.8125],
         ...,
         [ 0.1611,  0.0879,  1.6562,  ..., -4.1250, -1.5078, -1.5156],
         [ 0.4375, -0.8086,  1.0938,  ..., -4.0000, -1.8438, -0.5469],
         [ 0.2734, -1.9844,  1.8047,  ..., -3.6250, -0.5430,  0.0093]],

        ...,

        [[-0.0256,  0.0177,  0.0053,  ...,  0.0048, -0.0129, -0.0820],
         [ 1.0859, -0.6836, -0.5078,  ...,  0.0684, -0.6016,  1.6953],
         [-1.2891, -1.2500,  0.5430,  ...,  1.9766,  0.4082,  2.3125],
         ...,
         [-1.8438,  1.1328, -0.5898,  ..., -0.2715, -0.8984,  0.7969],
         [ 0.0156,  0.4453, -0.3672,  ...,  0.0508, -0.8281,  0.9844],
         [ 1.5625, -1.1094, -0.3203,  ..., -1.5312, -0.0349,  0.5547]],

        [[-0.0256,  0.0177,  0.0053,  ...,  0.0048, -0.0129, -0.0820],
         [ 1.0859, -0.6836, -0.5078,  ...,  0.0684, -0.6016,  1.6953],
         [-1.2891, -1.2500,  0.5430,  ...,  1.9766,  0.4082,  2.3125],
         ...,
         [-1.8438,  1.1328, -0.5898,  ..., -0.2715, -0.8984,  0.7969],
         [ 0.0156,  0.4453, -0.3672,  ...,  0.0508, -0.8281,  0.9844],
         [ 1.5625, -1.1094, -0.3203,  ..., -1.5312, -0.0349,  0.5547]],

        [[-0.0256,  0.0177,  0.0053,  ...,  0.0048, -0.0129, -0.0820],
         [ 1.0859, -0.6836, -0.5078,  ...,  0.0684, -0.6016,  1.6953],
         [-1.2891, -1.2500,  0.5430,  ...,  1.9766,  0.4082,  2.3125],
         ...,
         [-1.8438,  1.1328, -0.5898,  ..., -0.2715, -0.8984,  0.7969],
         [ 0.0156,  0.4453, -0.3672,  ...,  0.0508, -0.8281,  0.9844],
         [ 1.5625, -1.1094, -0.3203,  ..., -1.5312, -0.0349,  0.5547]]],
       dtype=torch.bfloat16), tensor([[[-5.5542e-03, -1.3123e-03, -4.8218e-03,  ...,  5.2185e-03,
          -5.2185e-03, -7.4768e-04],
         [-1.1562e+00,  5.5469e-01,  4.4922e-01,  ..., -1.0078e+00,
          -3.0859e-01,  2.7734e-01],
         [ 5.1172e-01, -2.8711e-01,  3.6865e-02,  ...,  5.1953e-01,
           8.7109e-01, -9.5312e-01],
         ...,
         [ 4.5508e-01,  4.0430e-01, -2.4292e-02,  ..., -5.8350e-02,
          -1.7031e+00, -7.5781e-01],
         [ 1.9141e-01,  3.3936e-02,  1.7969e-01,  ..., -1.0400e-01,
          -1.0781e+00, -7.8906e-01],
         [ 2.2461e-01,  9.8828e-01, -3.6133e-01,  ...,  4.0625e-01,
          -5.4297e-01,  2.0312e-01]],

        [[-5.5542e-03, -1.3123e-03, -4.8218e-03,  ...,  5.2185e-03,
          -5.2185e-03, -7.4768e-04],
         [-1.1562e+00,  5.5469e-01,  4.4922e-01,  ..., -1.0078e+00,
          -3.0859e-01,  2.7734e-01],
         [ 5.1172e-01, -2.8711e-01,  3.6865e-02,  ...,  5.1953e-01,
           8.7109e-01, -9.5312e-01],
         ...,
         [ 4.5508e-01,  4.0430e-01, -2.4292e-02,  ..., -5.8350e-02,
          -1.7031e+00, -7.5781e-01],
         [ 1.9141e-01,  3.3936e-02,  1.7969e-01,  ..., -1.0400e-01,
          -1.0781e+00, -7.8906e-01],
         [ 2.2461e-01,  9.8828e-01, -3.6133e-01,  ...,  4.0625e-01,
          -5.4297e-01,  2.0312e-01]],

        [[-5.5542e-03, -1.3123e-03, -4.8218e-03,  ...,  5.2185e-03,
          -5.2185e-03, -7.4768e-04],
         [-1.1562e+00,  5.5469e-01,  4.4922e-01,  ..., -1.0078e+00,
          -3.0859e-01,  2.7734e-01],
         [ 5.1172e-01, -2.8711e-01,  3.6865e-02,  ...,  5.1953e-01,
           8.7109e-01, -9.5312e-01],
         ...,
         [ 4.5508e-01,  4.0430e-01, -2.4292e-02,  ..., -5.8350e-02,
          -1.7031e+00, -7.5781e-01],
         [ 1.9141e-01,  3.3936e-02,  1.7969e-01,  ..., -1.0400e-01,
          -1.0781e+00, -7.8906e-01],
         [ 2.2461e-01,  9.8828e-01, -3.6133e-01,  ...,  4.0625e-01,
          -5.4297e-01,  2.0312e-01]],

        ...,

        [[-7.5378e-03, -4.8161e-05,  1.0132e-02,  ..., -6.2256e-03,
          -5.2185e-03, -4.3640e-03],
         [ 1.0156e+00, -2.6367e-01, -4.8242e-01,  ..., -1.2969e+00,
           5.2734e-01,  2.0996e-01],
         [-4.5898e-02,  1.1719e+00,  9.1797e-01,  ..., -1.0000e+00,
           8.6719e-01, -4.1211e-01],
         ...,
         [-2.7930e-01, -1.7734e+00, -4.6484e-01,  ..., -1.1484e+00,
          -7.7344e-01,  1.9434e-01],
         [-2.1484e-02, -4.3164e-01, -3.5547e-01,  ..., -4.2578e-01,
          -1.4844e+00, -4.3750e-01],
         [-1.2188e+00, -5.0781e-01,  4.6875e-02,  ..., -7.5195e-02,
           7.5781e-01,  2.1387e-01]],

        [[-7.5378e-03, -4.8161e-05,  1.0132e-02,  ..., -6.2256e-03,
          -5.2185e-03, -4.3640e-03],
         [ 1.0156e+00, -2.6367e-01, -4.8242e-01,  ..., -1.2969e+00,
           5.2734e-01,  2.0996e-01],
         [-4.5898e-02,  1.1719e+00,  9.1797e-01,  ..., -1.0000e+00,
           8.6719e-01, -4.1211e-01],
         ...,
         [-2.7930e-01, -1.7734e+00, -4.6484e-01,  ..., -1.1484e+00,
          -7.7344e-01,  1.9434e-01],
         [-2.1484e-02, -4.3164e-01, -3.5547e-01,  ..., -4.2578e-01,
          -1.4844e+00, -4.3750e-01],
         [-1.2188e+00, -5.0781e-01,  4.6875e-02,  ..., -7.5195e-02,
           7.5781e-01,  2.1387e-01]],

        [[-7.5378e-03, -4.8161e-05,  1.0132e-02,  ..., -6.2256e-03,
          -5.2185e-03, -4.3640e-03],
         [ 1.0156e+00, -2.6367e-01, -4.8242e-01,  ..., -1.2969e+00,
           5.2734e-01,  2.0996e-01],
         [-4.5898e-02,  1.1719e+00,  9.1797e-01,  ..., -1.0000e+00,
           8.6719e-01, -4.1211e-01],
         ...,
         [-2.7930e-01, -1.7734e+00, -4.6484e-01,  ..., -1.1484e+00,
          -7.7344e-01,  1.9434e-01],
         [-2.1484e-02, -4.3164e-01, -3.5547e-01,  ..., -4.2578e-01,
          -1.4844e+00, -4.3750e-01],
         [-1.2188e+00, -5.0781e-01,  4.6875e-02,  ..., -7.5195e-02,
           7.5781e-01,  2.1387e-01]]], dtype=torch.bfloat16)), (tensor([[[ 1.5198e-02,  4.3030e-03,  1.2131e-03,  ...,  3.6865e-02,
          -2.1387e-01, -1.9287e-02],
         [-1.3828e+00,  2.1875e-01, -4.2480e-02,  ...,  1.6562e+00,
           4.1406e-01,  3.2344e+00],
         [-9.0234e-01,  2.6001e-02, -2.6172e-01,  ...,  2.2656e+00,
           1.5625e+00,  2.1719e+00],
         ...,
         [-4.4336e-01, -5.3906e-01,  3.9453e-01,  ..., -3.8086e-02,
          -2.5469e+00,  2.5000e+00],
         [-8.5938e-02, -1.6895e-01,  7.3242e-03,  ..., -2.3242e-01,
          -1.8203e+00, -7.3828e-01],
         [ 4.6289e-01,  8.2031e-01, -6.4844e-01,  ..., -3.5469e+00,
          -3.0469e+00,  2.0000e+00]],

        [[ 1.5198e-02,  4.3030e-03,  1.2131e-03,  ...,  3.6865e-02,
          -2.1387e-01, -1.9287e-02],
         [-1.3828e+00,  2.1875e-01, -4.2480e-02,  ...,  1.6562e+00,
           4.1406e-01,  3.2344e+00],
         [-9.0234e-01,  2.6001e-02, -2.6172e-01,  ...,  2.2656e+00,
           1.5625e+00,  2.1719e+00],
         ...,
         [-4.4336e-01, -5.3906e-01,  3.9453e-01,  ..., -3.8086e-02,
          -2.5469e+00,  2.5000e+00],
         [-8.5938e-02, -1.6895e-01,  7.3242e-03,  ..., -2.3242e-01,
          -1.8203e+00, -7.3828e-01],
         [ 4.6289e-01,  8.2031e-01, -6.4844e-01,  ..., -3.5469e+00,
          -3.0469e+00,  2.0000e+00]],

        [[ 1.5198e-02,  4.3030e-03,  1.2131e-03,  ...,  3.6865e-02,
          -2.1387e-01, -1.9287e-02],
         [-1.3828e+00,  2.1875e-01, -4.2480e-02,  ...,  1.6562e+00,
           4.1406e-01,  3.2344e+00],
         [-9.0234e-01,  2.6001e-02, -2.6172e-01,  ...,  2.2656e+00,
           1.5625e+00,  2.1719e+00],
         ...,
         [-4.4336e-01, -5.3906e-01,  3.9453e-01,  ..., -3.8086e-02,
          -2.5469e+00,  2.5000e+00],
         [-8.5938e-02, -1.6895e-01,  7.3242e-03,  ..., -2.3242e-01,
          -1.8203e+00, -7.3828e-01],
         [ 4.6289e-01,  8.2031e-01, -6.4844e-01,  ..., -3.5469e+00,
          -3.0469e+00,  2.0000e+00]],

        ...,

        [[-1.1536e-02, -6.7139e-03,  3.7003e-04,  ...,  1.7944e-02,
           6.1279e-02,  1.0938e+00],
         [-4.4922e-02, -1.0078e+00,  9.3750e-02,  ..., -1.9609e+00,
           3.4219e+00, -8.7500e+00],
         [-1.6602e-01, -6.3672e-01,  5.1270e-03,  ...,  7.3828e-01,
          -2.9375e+00, -1.0000e+01],
         ...,
         [-3.1641e-01,  2.7148e-01, -2.7734e-01,  ...,  2.8906e+00,
           6.4844e-01, -6.7188e+00],
         [ 1.6992e-01,  3.1641e-01, -1.6211e-01,  ..., -5.4688e-01,
          -1.4609e+00, -7.9062e+00],
         [ 2.4902e-01,  7.0312e-01, -8.1250e-01,  ..., -5.4297e-01,
          -2.4531e+00, -7.5000e+00]],

        [[-1.1536e-02, -6.7139e-03,  3.7003e-04,  ...,  1.7944e-02,
           6.1279e-02,  1.0938e+00],
         [-4.4922e-02, -1.0078e+00,  9.3750e-02,  ..., -1.9609e+00,
           3.4219e+00, -8.7500e+00],
         [-1.6602e-01, -6.3672e-01,  5.1270e-03,  ...,  7.3828e-01,
          -2.9375e+00, -1.0000e+01],
         ...,
         [-3.1641e-01,  2.7148e-01, -2.7734e-01,  ...,  2.8906e+00,
           6.4844e-01, -6.7188e+00],
         [ 1.6992e-01,  3.1641e-01, -1.6211e-01,  ..., -5.4688e-01,
          -1.4609e+00, -7.9062e+00],
         [ 2.4902e-01,  7.0312e-01, -8.1250e-01,  ..., -5.4297e-01,
          -2.4531e+00, -7.5000e+00]],

        [[-1.1536e-02, -6.7139e-03,  3.7003e-04,  ...,  1.7944e-02,
           6.1279e-02,  1.0938e+00],
         [-4.4922e-02, -1.0078e+00,  9.3750e-02,  ..., -1.9609e+00,
           3.4219e+00, -8.7500e+00],
         [-1.6602e-01, -6.3672e-01,  5.1270e-03,  ...,  7.3828e-01,
          -2.9375e+00, -1.0000e+01],
         ...,
         [-3.1641e-01,  2.7148e-01, -2.7734e-01,  ...,  2.8906e+00,
           6.4844e-01, -6.7188e+00],
         [ 1.6992e-01,  3.1641e-01, -1.6211e-01,  ..., -5.4688e-01,
          -1.4609e+00, -7.9062e+00],
         [ 2.4902e-01,  7.0312e-01, -8.1250e-01,  ..., -5.4297e-01,
          -2.4531e+00, -7.5000e+00]]], dtype=torch.bfloat16), tensor([[[ 2.9755e-04,  3.8757e-03,  6.9427e-04,  ...,  5.5237e-03,
           1.3828e-04, -6.5918e-03],
         [ 1.3594e+00,  5.7422e-01, -1.0938e+00,  ..., -1.8828e+00,
           5.3516e-01,  5.5859e-01],
         [ 2.3315e-02, -2.5938e+00,  8.9844e-01,  ..., -9.7266e-01,
           3.0664e-01,  1.2266e+00],
         ...,
         [ 1.4062e+00, -5.8594e-01,  2.3438e-01,  ..., -6.3281e-01,
          -1.9043e-01, -1.8457e-01],
         [ 1.1875e+00, -9.2969e-01, -3.5156e-01,  ..., -2.8516e-01,
          -3.4766e-01,  5.4443e-02],
         [ 6.3281e-01,  3.4375e-01, -4.7266e-01,  ...,  1.0000e+00,
          -8.9355e-02,  3.7354e-02]],

        [[ 2.9755e-04,  3.8757e-03,  6.9427e-04,  ...,  5.5237e-03,
           1.3828e-04, -6.5918e-03],
         [ 1.3594e+00,  5.7422e-01, -1.0938e+00,  ..., -1.8828e+00,
           5.3516e-01,  5.5859e-01],
         [ 2.3315e-02, -2.5938e+00,  8.9844e-01,  ..., -9.7266e-01,
           3.0664e-01,  1.2266e+00],
         ...,
         [ 1.4062e+00, -5.8594e-01,  2.3438e-01,  ..., -6.3281e-01,
          -1.9043e-01, -1.8457e-01],
         [ 1.1875e+00, -9.2969e-01, -3.5156e-01,  ..., -2.8516e-01,
          -3.4766e-01,  5.4443e-02],
         [ 6.3281e-01,  3.4375e-01, -4.7266e-01,  ...,  1.0000e+00,
          -8.9355e-02,  3.7354e-02]],

        [[ 2.9755e-04,  3.8757e-03,  6.9427e-04,  ...,  5.5237e-03,
           1.3828e-04, -6.5918e-03],
         [ 1.3594e+00,  5.7422e-01, -1.0938e+00,  ..., -1.8828e+00,
           5.3516e-01,  5.5859e-01],
         [ 2.3315e-02, -2.5938e+00,  8.9844e-01,  ..., -9.7266e-01,
           3.0664e-01,  1.2266e+00],
         ...,
         [ 1.4062e+00, -5.8594e-01,  2.3438e-01,  ..., -6.3281e-01,
          -1.9043e-01, -1.8457e-01],
         [ 1.1875e+00, -9.2969e-01, -3.5156e-01,  ..., -2.8516e-01,
          -3.4766e-01,  5.4443e-02],
         [ 6.3281e-01,  3.4375e-01, -4.7266e-01,  ...,  1.0000e+00,
          -8.9355e-02,  3.7354e-02]],

        ...,

        [[ 2.6703e-03, -6.5918e-03, -9.8267e-03,  ..., -7.3242e-03,
          -2.3460e-04, -4.3945e-03],
         [-1.8906e+00,  7.5391e-01, -9.6191e-02,  ...,  2.6367e-01,
           1.5156e+00, -1.2031e+00],
         [ 1.4531e+00, -1.3359e+00,  9.0820e-02,  ..., -2.6406e+00,
           4.2188e+00,  3.1875e+00],
         ...,
         [ 5.2734e-01,  1.1562e+00, -9.4922e-01,  ...,  2.3071e-02,
          -2.3828e-01, -8.9844e-01],
         [-3.6719e-01, -1.5918e-01, -5.4297e-01,  ...,  2.0312e-01,
           1.0205e-01,  7.2266e-02],
         [ 3.8867e-01,  2.6758e-01,  2.2344e+00,  ...,  1.9531e+00,
          -3.7500e-01,  6.6895e-02]],

        [[ 2.6703e-03, -6.5918e-03, -9.8267e-03,  ..., -7.3242e-03,
          -2.3460e-04, -4.3945e-03],
         [-1.8906e+00,  7.5391e-01, -9.6191e-02,  ...,  2.6367e-01,
           1.5156e+00, -1.2031e+00],
         [ 1.4531e+00, -1.3359e+00,  9.0820e-02,  ..., -2.6406e+00,
           4.2188e+00,  3.1875e+00],
         ...,
         [ 5.2734e-01,  1.1562e+00, -9.4922e-01,  ...,  2.3071e-02,
          -2.3828e-01, -8.9844e-01],
         [-3.6719e-01, -1.5918e-01, -5.4297e-01,  ...,  2.0312e-01,
           1.0205e-01,  7.2266e-02],
         [ 3.8867e-01,  2.6758e-01,  2.2344e+00,  ...,  1.9531e+00,
          -3.7500e-01,  6.6895e-02]],

        [[ 2.6703e-03, -6.5918e-03, -9.8267e-03,  ..., -7.3242e-03,
          -2.3460e-04, -4.3945e-03],
         [-1.8906e+00,  7.5391e-01, -9.6191e-02,  ...,  2.6367e-01,
           1.5156e+00, -1.2031e+00],
         [ 1.4531e+00, -1.3359e+00,  9.0820e-02,  ..., -2.6406e+00,
           4.2188e+00,  3.1875e+00],
         ...,
         [ 5.2734e-01,  1.1562e+00, -9.4922e-01,  ...,  2.3071e-02,
          -2.3828e-01, -8.9844e-01],
         [-3.6719e-01, -1.5918e-01, -5.4297e-01,  ...,  2.0312e-01,
           1.0205e-01,  7.2266e-02],
         [ 3.8867e-01,  2.6758e-01,  2.2344e+00,  ...,  1.9531e+00,
          -3.7500e-01,  6.6895e-02]]], dtype=torch.bfloat16)), (tensor([[[-3.6926e-03,  3.5858e-03, -1.4191e-03,  ...,  2.9907e-02,
          -6.6406e-02,  5.3516e-01],
         [ 2.0938e+00,  1.7031e+00, -1.5156e+00,  ..., -2.1875e+00,
           1.8594e+00, -5.1875e+00],
         [ 7.6172e-01,  1.2109e+00, -1.4062e-01,  ..., -1.0938e+00,
          -1.8047e+00, -3.8281e+00],
         ...,
         [-8.5547e-01, -1.2812e+00, -2.6562e-01,  ..., -4.0820e-01,
           4.6562e+00, -5.0938e+00],
         [ 1.1230e-01, -3.1445e-01,  5.0000e-01,  ...,  8.0469e-01,
           5.0781e-01, -4.2812e+00],
         [ 5.5078e-01, -2.0312e-01, -6.4062e-01,  ..., -1.8984e+00,
          -5.2344e-01, -4.5938e+00]],

        [[-3.6926e-03,  3.5858e-03, -1.4191e-03,  ...,  2.9907e-02,
          -6.6406e-02,  5.3516e-01],
         [ 2.0938e+00,  1.7031e+00, -1.5156e+00,  ..., -2.1875e+00,
           1.8594e+00, -5.1875e+00],
         [ 7.6172e-01,  1.2109e+00, -1.4062e-01,  ..., -1.0938e+00,
          -1.8047e+00, -3.8281e+00],
         ...,
         [-8.5547e-01, -1.2812e+00, -2.6562e-01,  ..., -4.0820e-01,
           4.6562e+00, -5.0938e+00],
         [ 1.1230e-01, -3.1445e-01,  5.0000e-01,  ...,  8.0469e-01,
           5.0781e-01, -4.2812e+00],
         [ 5.5078e-01, -2.0312e-01, -6.4062e-01,  ..., -1.8984e+00,
          -5.2344e-01, -4.5938e+00]],

        [[-3.6926e-03,  3.5858e-03, -1.4191e-03,  ...,  2.9907e-02,
          -6.6406e-02,  5.3516e-01],
         [ 2.0938e+00,  1.7031e+00, -1.5156e+00,  ..., -2.1875e+00,
           1.8594e+00, -5.1875e+00],
         [ 7.6172e-01,  1.2109e+00, -1.4062e-01,  ..., -1.0938e+00,
          -1.8047e+00, -3.8281e+00],
         ...,
         [-8.5547e-01, -1.2812e+00, -2.6562e-01,  ..., -4.0820e-01,
           4.6562e+00, -5.0938e+00],
         [ 1.1230e-01, -3.1445e-01,  5.0000e-01,  ...,  8.0469e-01,
           5.0781e-01, -4.2812e+00],
         [ 5.5078e-01, -2.0312e-01, -6.4062e-01,  ..., -1.8984e+00,
          -5.2344e-01, -4.5938e+00]],

        ...,

        [[-1.0834e-03,  1.1597e-02,  6.6223e-03,  ...,  1.1426e-01,
           5.9082e-02,  7.5195e-02],
         [ 1.6719e+00, -5.2344e-01, -8.1641e-01,  ..., -6.9922e-01,
          -5.8594e-01, -1.0781e+00],
         [ 1.2344e+00, -1.4453e+00, -9.9219e-01,  ..., -6.4844e-01,
          -4.3945e-01, -6.7969e-01],
         ...,
         [-1.8945e-01,  1.4453e-01,  4.4141e-01,  ..., -1.1875e+00,
          -1.1953e+00, -1.0391e+00],
         [-4.3945e-01,  1.3203e+00, -2.2070e-01,  ..., -5.1172e-01,
          -1.0107e-01, -1.1523e-01],
         [-4.1992e-01,  3.9062e-01, -7.5000e-01,  ..., -9.4141e-01,
           4.8340e-02, -1.2422e+00]],

        [[-1.0834e-03,  1.1597e-02,  6.6223e-03,  ...,  1.1426e-01,
           5.9082e-02,  7.5195e-02],
         [ 1.6719e+00, -5.2344e-01, -8.1641e-01,  ..., -6.9922e-01,
          -5.8594e-01, -1.0781e+00],
         [ 1.2344e+00, -1.4453e+00, -9.9219e-01,  ..., -6.4844e-01,
          -4.3945e-01, -6.7969e-01],
         ...,
         [-1.8945e-01,  1.4453e-01,  4.4141e-01,  ..., -1.1875e+00,
          -1.1953e+00, -1.0391e+00],
         [-4.3945e-01,  1.3203e+00, -2.2070e-01,  ..., -5.1172e-01,
          -1.0107e-01, -1.1523e-01],
         [-4.1992e-01,  3.9062e-01, -7.5000e-01,  ..., -9.4141e-01,
           4.8340e-02, -1.2422e+00]],

        [[-1.0834e-03,  1.1597e-02,  6.6223e-03,  ...,  1.1426e-01,
           5.9082e-02,  7.5195e-02],
         [ 1.6719e+00, -5.2344e-01, -8.1641e-01,  ..., -6.9922e-01,
          -5.8594e-01, -1.0781e+00],
         [ 1.2344e+00, -1.4453e+00, -9.9219e-01,  ..., -6.4844e-01,
          -4.3945e-01, -6.7969e-01],
         ...,
         [-1.8945e-01,  1.4453e-01,  4.4141e-01,  ..., -1.1875e+00,
          -1.1953e+00, -1.0391e+00],
         [-4.3945e-01,  1.3203e+00, -2.2070e-01,  ..., -5.1172e-01,
          -1.0107e-01, -1.1523e-01],
         [-4.1992e-01,  3.9062e-01, -7.5000e-01,  ..., -9.4141e-01,
           4.8340e-02, -1.2422e+00]]], dtype=torch.bfloat16), tensor([[[-1.1215e-03, -1.3184e-02,  2.7313e-03,  ...,  1.4160e-02,
           2.8038e-04,  7.0496e-03],
         [ 6.5625e-01,  3.0078e-01,  3.0664e-01,  ..., -2.2754e-01,
          -8.6719e-01, -2.0703e-01],
         [ 3.8672e-01, -2.0020e-01,  3.7305e-01,  ...,  1.3477e-01,
           7.9297e-01,  3.1250e-01],
         ...,
         [-1.5938e+00,  2.8125e-01,  1.6357e-02,  ..., -7.7637e-02,
           2.5781e-01, -4.3359e-01],
         [-1.1719e+00, -1.6406e-01,  2.2583e-02,  ...,  4.1211e-01,
          -2.0117e-01, -1.8066e-01],
         [-7.1094e-01, -8.0566e-03,  7.1484e-01,  ...,  4.2578e-01,
          -1.9727e-01, -8.1641e-01]],

        [[-1.1215e-03, -1.3184e-02,  2.7313e-03,  ...,  1.4160e-02,
           2.8038e-04,  7.0496e-03],
         [ 6.5625e-01,  3.0078e-01,  3.0664e-01,  ..., -2.2754e-01,
          -8.6719e-01, -2.0703e-01],
         [ 3.8672e-01, -2.0020e-01,  3.7305e-01,  ...,  1.3477e-01,
           7.9297e-01,  3.1250e-01],
         ...,
         [-1.5938e+00,  2.8125e-01,  1.6357e-02,  ..., -7.7637e-02,
           2.5781e-01, -4.3359e-01],
         [-1.1719e+00, -1.6406e-01,  2.2583e-02,  ...,  4.1211e-01,
          -2.0117e-01, -1.8066e-01],
         [-7.1094e-01, -8.0566e-03,  7.1484e-01,  ...,  4.2578e-01,
          -1.9727e-01, -8.1641e-01]],

        [[-1.1215e-03, -1.3184e-02,  2.7313e-03,  ...,  1.4160e-02,
           2.8038e-04,  7.0496e-03],
         [ 6.5625e-01,  3.0078e-01,  3.0664e-01,  ..., -2.2754e-01,
          -8.6719e-01, -2.0703e-01],
         [ 3.8672e-01, -2.0020e-01,  3.7305e-01,  ...,  1.3477e-01,
           7.9297e-01,  3.1250e-01],
         ...,
         [-1.5938e+00,  2.8125e-01,  1.6357e-02,  ..., -7.7637e-02,
           2.5781e-01, -4.3359e-01],
         [-1.1719e+00, -1.6406e-01,  2.2583e-02,  ...,  4.1211e-01,
          -2.0117e-01, -1.8066e-01],
         [-7.1094e-01, -8.0566e-03,  7.1484e-01,  ...,  4.2578e-01,
          -1.9727e-01, -8.1641e-01]],

        ...,

        [[-9.1553e-03, -9.7656e-03, -6.1951e-03,  ...,  5.8899e-03,
           6.4697e-03, -7.2098e-04],
         [-5.0781e-01,  1.3281e+00, -1.4766e+00,  ..., -7.6172e-01,
          -1.5430e-01, -1.7656e+00],
         [-4.9414e-01,  2.2949e-01,  1.8203e+00,  ..., -3.0273e-01,
          -2.4805e-01, -6.5625e-01],
         ...,
         [-5.9814e-02,  2.3535e-01,  1.6250e+00,  ...,  1.0703e+00,
          -1.2656e+00,  2.3340e-01],
         [-2.2168e-01,  4.7461e-01,  8.6719e-01,  ...,  1.3672e-01,
           5.6885e-02, -7.2021e-03],
         [-9.2188e-01, -1.1328e+00,  1.0312e+00,  ...,  1.2969e+00,
           1.3203e+00, -7.2656e-01]],

        [[-9.1553e-03, -9.7656e-03, -6.1951e-03,  ...,  5.8899e-03,
           6.4697e-03, -7.2098e-04],
         [-5.0781e-01,  1.3281e+00, -1.4766e+00,  ..., -7.6172e-01,
          -1.5430e-01, -1.7656e+00],
         [-4.9414e-01,  2.2949e-01,  1.8203e+00,  ..., -3.0273e-01,
          -2.4805e-01, -6.5625e-01],
         ...,
         [-5.9814e-02,  2.3535e-01,  1.6250e+00,  ...,  1.0703e+00,
          -1.2656e+00,  2.3340e-01],
         [-2.2168e-01,  4.7461e-01,  8.6719e-01,  ...,  1.3672e-01,
           5.6885e-02, -7.2021e-03],
         [-9.2188e-01, -1.1328e+00,  1.0312e+00,  ...,  1.2969e+00,
           1.3203e+00, -7.2656e-01]],

        [[-9.1553e-03, -9.7656e-03, -6.1951e-03,  ...,  5.8899e-03,
           6.4697e-03, -7.2098e-04],
         [-5.0781e-01,  1.3281e+00, -1.4766e+00,  ..., -7.6172e-01,
          -1.5430e-01, -1.7656e+00],
         [-4.9414e-01,  2.2949e-01,  1.8203e+00,  ..., -3.0273e-01,
          -2.4805e-01, -6.5625e-01],
         ...,
         [-5.9814e-02,  2.3535e-01,  1.6250e+00,  ...,  1.0703e+00,
          -1.2656e+00,  2.3340e-01],
         [-2.2168e-01,  4.7461e-01,  8.6719e-01,  ...,  1.3672e-01,
           5.6885e-02, -7.2021e-03],
         [-9.2188e-01, -1.1328e+00,  1.0312e+00,  ...,  1.2969e+00,
           1.3203e+00, -7.2656e-01]]], dtype=torch.bfloat16)), (tensor([[[-2.0386e-02,  2.8801e-04,  1.3504e-03,  ..., -6.3281e-01,
          -2.7222e-02, -1.0645e-01],
         [-3.9062e-02,  2.2969e+00,  1.6562e+00,  ...,  1.0688e+01,
          -2.3750e+00,  5.7031e-01],
         [-1.7656e+00,  8.3594e-01,  4.0234e-01,  ...,  1.1438e+01,
          -7.1094e-01,  8.7109e-01],
         ...,
         [-7.3438e-01, -1.1875e+00,  7.0312e-01,  ...,  9.6875e+00,
           1.3047e+00,  1.0391e+00],
         [ 2.1484e-02, -1.3828e+00,  8.9844e-01,  ...,  8.6250e+00,
           6.4844e-01,  4.7656e-01],
         [ 1.7578e+00, -4.8047e-01,  3.2031e-01,  ...,  9.4375e+00,
           1.2305e-01, -1.7188e-01]],

        [[-2.0386e-02,  2.8801e-04,  1.3504e-03,  ..., -6.3281e-01,
          -2.7222e-02, -1.0645e-01],
         [-3.9062e-02,  2.2969e+00,  1.6562e+00,  ...,  1.0688e+01,
          -2.3750e+00,  5.7031e-01],
         [-1.7656e+00,  8.3594e-01,  4.0234e-01,  ...,  1.1438e+01,
          -7.1094e-01,  8.7109e-01],
         ...,
         [-7.3438e-01, -1.1875e+00,  7.0312e-01,  ...,  9.6875e+00,
           1.3047e+00,  1.0391e+00],
         [ 2.1484e-02, -1.3828e+00,  8.9844e-01,  ...,  8.6250e+00,
           6.4844e-01,  4.7656e-01],
         [ 1.7578e+00, -4.8047e-01,  3.2031e-01,  ...,  9.4375e+00,
           1.2305e-01, -1.7188e-01]],

        [[-2.0386e-02,  2.8801e-04,  1.3504e-03,  ..., -6.3281e-01,
          -2.7222e-02, -1.0645e-01],
         [-3.9062e-02,  2.2969e+00,  1.6562e+00,  ...,  1.0688e+01,
          -2.3750e+00,  5.7031e-01],
         [-1.7656e+00,  8.3594e-01,  4.0234e-01,  ...,  1.1438e+01,
          -7.1094e-01,  8.7109e-01],
         ...,
         [-7.3438e-01, -1.1875e+00,  7.0312e-01,  ...,  9.6875e+00,
           1.3047e+00,  1.0391e+00],
         [ 2.1484e-02, -1.3828e+00,  8.9844e-01,  ...,  8.6250e+00,
           6.4844e-01,  4.7656e-01],
         [ 1.7578e+00, -4.8047e-01,  3.2031e-01,  ...,  9.4375e+00,
           1.2305e-01, -1.7188e-01]],

        ...,

        [[ 4.9744e-03,  5.5237e-03,  7.2937e-03,  ..., -2.2949e-02,
           8.1787e-03,  1.5723e-01],
         [-1.3750e+00, -8.7891e-01, -8.5156e-01,  ...,  1.4453e+00,
          -2.7734e-01, -3.7812e+00],
         [-8.2812e-01, -5.9375e-01,  1.8652e-01,  ...,  3.1875e+00,
           2.5977e-01,  1.8555e-01],
         ...,
         [ 1.1328e+00,  5.0000e-01, -9.5703e-01,  ..., -1.5078e+00,
          -2.6250e+00, -4.0000e+00],
         [ 2.1094e-01, -8.7402e-02, -7.1484e-01,  ...,  5.3125e-01,
           4.9219e-01, -7.1094e-01],
         [ 1.1484e+00, -1.6016e+00, -7.5391e-01,  ...,  2.5000e+00,
          -1.0859e+00, -2.0938e+00]],

        [[ 4.9744e-03,  5.5237e-03,  7.2937e-03,  ..., -2.2949e-02,
           8.1787e-03,  1.5723e-01],
         [-1.3750e+00, -8.7891e-01, -8.5156e-01,  ...,  1.4453e+00,
          -2.7734e-01, -3.7812e+00],
         [-8.2812e-01, -5.9375e-01,  1.8652e-01,  ...,  3.1875e+00,
           2.5977e-01,  1.8555e-01],
         ...,
         [ 1.1328e+00,  5.0000e-01, -9.5703e-01,  ..., -1.5078e+00,
          -2.6250e+00, -4.0000e+00],
         [ 2.1094e-01, -8.7402e-02, -7.1484e-01,  ...,  5.3125e-01,
           4.9219e-01, -7.1094e-01],
         [ 1.1484e+00, -1.6016e+00, -7.5391e-01,  ...,  2.5000e+00,
          -1.0859e+00, -2.0938e+00]],

        [[ 4.9744e-03,  5.5237e-03,  7.2937e-03,  ..., -2.2949e-02,
           8.1787e-03,  1.5723e-01],
         [-1.3750e+00, -8.7891e-01, -8.5156e-01,  ...,  1.4453e+00,
          -2.7734e-01, -3.7812e+00],
         [-8.2812e-01, -5.9375e-01,  1.8652e-01,  ...,  3.1875e+00,
           2.5977e-01,  1.8555e-01],
         ...,
         [ 1.1328e+00,  5.0000e-01, -9.5703e-01,  ..., -1.5078e+00,
          -2.6250e+00, -4.0000e+00],
         [ 2.1094e-01, -8.7402e-02, -7.1484e-01,  ...,  5.3125e-01,
           4.9219e-01, -7.1094e-01],
         [ 1.1484e+00, -1.6016e+00, -7.5391e-01,  ...,  2.5000e+00,
          -1.0859e+00, -2.0938e+00]]], dtype=torch.bfloat16), tensor([[[-0.0063, -0.0132, -0.0074,  ...,  0.0058, -0.0055, -0.0167],
         [-1.8672, -0.9336, -1.5859,  ..., -1.7891, -0.7305, -0.5938],
         [-1.5156,  0.0366,  0.4688,  ..., -0.1069,  0.0542,  0.1040],
         ...,
         [ 0.3672,  0.3418, -1.2734,  ...,  0.5977, -0.7734, -0.1543],
         [ 0.6211, -0.0996, -0.0452,  ..., -0.3730, -0.6523,  0.3340],
         [-1.1875,  0.2227,  1.3828,  ...,  0.5156,  0.0330,  0.7070]],

        [[-0.0063, -0.0132, -0.0074,  ...,  0.0058, -0.0055, -0.0167],
         [-1.8672, -0.9336, -1.5859,  ..., -1.7891, -0.7305, -0.5938],
         [-1.5156,  0.0366,  0.4688,  ..., -0.1069,  0.0542,  0.1040],
         ...,
         [ 0.3672,  0.3418, -1.2734,  ...,  0.5977, -0.7734, -0.1543],
         [ 0.6211, -0.0996, -0.0452,  ..., -0.3730, -0.6523,  0.3340],
         [-1.1875,  0.2227,  1.3828,  ...,  0.5156,  0.0330,  0.7070]],

        [[-0.0063, -0.0132, -0.0074,  ...,  0.0058, -0.0055, -0.0167],
         [-1.8672, -0.9336, -1.5859,  ..., -1.7891, -0.7305, -0.5938],
         [-1.5156,  0.0366,  0.4688,  ..., -0.1069,  0.0542,  0.1040],
         ...,
         [ 0.3672,  0.3418, -1.2734,  ...,  0.5977, -0.7734, -0.1543],
         [ 0.6211, -0.0996, -0.0452,  ..., -0.3730, -0.6523,  0.3340],
         [-1.1875,  0.2227,  1.3828,  ...,  0.5156,  0.0330,  0.7070]],

        ...,

        [[ 0.0068,  0.0091,  0.0122,  ..., -0.0236, -0.0041,  0.0050],
         [ 0.5156,  0.7461,  1.4219,  ...,  1.0859,  1.1172,  0.5859],
         [-0.3027,  0.3945,  1.0859,  ...,  0.7852,  0.3652, -0.1992],
         ...,
         [-0.2930, -0.4766, -0.0162,  ...,  0.1172, -1.1328,  0.0571],
         [ 0.2930, -0.0217, -0.4297,  ...,  0.1543,  0.2285,  0.2598],
         [-0.9648, -0.8828,  0.5430,  ...,  0.3125, -0.2676,  0.3633]],

        [[ 0.0068,  0.0091,  0.0122,  ..., -0.0236, -0.0041,  0.0050],
         [ 0.5156,  0.7461,  1.4219,  ...,  1.0859,  1.1172,  0.5859],
         [-0.3027,  0.3945,  1.0859,  ...,  0.7852,  0.3652, -0.1992],
         ...,
         [-0.2930, -0.4766, -0.0162,  ...,  0.1172, -1.1328,  0.0571],
         [ 0.2930, -0.0217, -0.4297,  ...,  0.1543,  0.2285,  0.2598],
         [-0.9648, -0.8828,  0.5430,  ...,  0.3125, -0.2676,  0.3633]],

        [[ 0.0068,  0.0091,  0.0122,  ..., -0.0236, -0.0041,  0.0050],
         [ 0.5156,  0.7461,  1.4219,  ...,  1.0859,  1.1172,  0.5859],
         [-0.3027,  0.3945,  1.0859,  ...,  0.7852,  0.3652, -0.1992],
         ...,
         [-0.2930, -0.4766, -0.0162,  ...,  0.1172, -1.1328,  0.0571],
         [ 0.2930, -0.0217, -0.4297,  ...,  0.1543,  0.2285,  0.2598],
         [-0.9648, -0.8828,  0.5430,  ...,  0.3125, -0.2676,  0.3633]]],
       dtype=torch.bfloat16)), (tensor([[[ 2.4261e-03,  2.3804e-03,  2.8076e-03,  ..., -3.6377e-02,
          -2.4219e-01,  8.2812e-01],
         [-4.5312e-01, -1.1094e+00,  3.9844e-01,  ...,  6.1328e-01,
          -1.2354e-01, -7.6875e+00],
         [-8.2422e-01, -1.3516e+00,  9.1406e-01,  ..., -1.0859e+00,
           2.6875e+00, -9.0625e+00],
         ...,
         [-1.3574e-01,  3.9648e-01, -7.7734e-01,  ...,  3.4531e+00,
          -5.5625e+00, -7.4688e+00],
         [-2.7344e-01, -4.7852e-01, -3.4961e-01,  ...,  7.0312e-01,
          -2.4375e+00, -7.1562e+00],
         [ 1.2031e+00, -1.5234e-01,  6.7188e-01,  ..., -4.6875e-01,
           1.8203e+00, -8.0625e+00]],

        [[ 2.4261e-03,  2.3804e-03,  2.8076e-03,  ..., -3.6377e-02,
          -2.4219e-01,  8.2812e-01],
         [-4.5312e-01, -1.1094e+00,  3.9844e-01,  ...,  6.1328e-01,
          -1.2354e-01, -7.6875e+00],
         [-8.2422e-01, -1.3516e+00,  9.1406e-01,  ..., -1.0859e+00,
           2.6875e+00, -9.0625e+00],
         ...,
         [-1.3574e-01,  3.9648e-01, -7.7734e-01,  ...,  3.4531e+00,
          -5.5625e+00, -7.4688e+00],
         [-2.7344e-01, -4.7852e-01, -3.4961e-01,  ...,  7.0312e-01,
          -2.4375e+00, -7.1562e+00],
         [ 1.2031e+00, -1.5234e-01,  6.7188e-01,  ..., -4.6875e-01,
           1.8203e+00, -8.0625e+00]],

        [[ 2.4261e-03,  2.3804e-03,  2.8076e-03,  ..., -3.6377e-02,
          -2.4219e-01,  8.2812e-01],
         [-4.5312e-01, -1.1094e+00,  3.9844e-01,  ...,  6.1328e-01,
          -1.2354e-01, -7.6875e+00],
         [-8.2422e-01, -1.3516e+00,  9.1406e-01,  ..., -1.0859e+00,
           2.6875e+00, -9.0625e+00],
         ...,
         [-1.3574e-01,  3.9648e-01, -7.7734e-01,  ...,  3.4531e+00,
          -5.5625e+00, -7.4688e+00],
         [-2.7344e-01, -4.7852e-01, -3.4961e-01,  ...,  7.0312e-01,
          -2.4375e+00, -7.1562e+00],
         [ 1.2031e+00, -1.5234e-01,  6.7188e-01,  ..., -4.6875e-01,
           1.8203e+00, -8.0625e+00]],

        ...,

        [[ 1.5747e-02, -5.4016e-03,  8.7891e-03,  ..., -1.3184e-02,
          -5.9326e-02,  3.2812e-01],
         [-2.4688e+00,  2.1875e-01, -9.1797e-01,  ..., -9.3750e-01,
           1.6250e+00, -2.6875e+00],
         [-2.8125e-01,  2.2461e-01, -4.1406e-01,  ..., -2.9844e+00,
           1.8555e-01, -2.4531e+00],
         ...,
         [-1.4941e-01, -6.0547e-01,  1.0059e-01,  ..., -1.7109e+00,
          -3.5742e-01, -3.2969e+00],
         [ 5.1562e-01, -2.5781e-01, -1.9922e-01,  ..., -3.3750e+00,
          -1.2109e+00, -1.0156e+00],
         [ 9.1797e-01, -6.1719e-01, -1.4766e+00,  ..., -3.3281e+00,
           2.3438e+00, -1.8047e+00]],

        [[ 1.5747e-02, -5.4016e-03,  8.7891e-03,  ..., -1.3184e-02,
          -5.9326e-02,  3.2812e-01],
         [-2.4688e+00,  2.1875e-01, -9.1797e-01,  ..., -9.3750e-01,
           1.6250e+00, -2.6875e+00],
         [-2.8125e-01,  2.2461e-01, -4.1406e-01,  ..., -2.9844e+00,
           1.8555e-01, -2.4531e+00],
         ...,
         [-1.4941e-01, -6.0547e-01,  1.0059e-01,  ..., -1.7109e+00,
          -3.5742e-01, -3.2969e+00],
         [ 5.1562e-01, -2.5781e-01, -1.9922e-01,  ..., -3.3750e+00,
          -1.2109e+00, -1.0156e+00],
         [ 9.1797e-01, -6.1719e-01, -1.4766e+00,  ..., -3.3281e+00,
           2.3438e+00, -1.8047e+00]],

        [[ 1.5747e-02, -5.4016e-03,  8.7891e-03,  ..., -1.3184e-02,
          -5.9326e-02,  3.2812e-01],
         [-2.4688e+00,  2.1875e-01, -9.1797e-01,  ..., -9.3750e-01,
           1.6250e+00, -2.6875e+00],
         [-2.8125e-01,  2.2461e-01, -4.1406e-01,  ..., -2.9844e+00,
           1.8555e-01, -2.4531e+00],
         ...,
         [-1.4941e-01, -6.0547e-01,  1.0059e-01,  ..., -1.7109e+00,
          -3.5742e-01, -3.2969e+00],
         [ 5.1562e-01, -2.5781e-01, -1.9922e-01,  ..., -3.3750e+00,
          -1.2109e+00, -1.0156e+00],
         [ 9.1797e-01, -6.1719e-01, -1.4766e+00,  ..., -3.3281e+00,
           2.3438e+00, -1.8047e+00]]], dtype=torch.bfloat16), tensor([[[-5.2795e-03,  2.1648e-04, -1.0498e-02,  ..., -1.2054e-03,
           1.2665e-03,  1.0254e-02],
         [ 1.1016e+00, -6.9922e-01,  2.5391e-01,  ..., -3.6914e-01,
           2.3438e-01, -1.0791e-01],
         [ 1.4551e-01, -2.8198e-02,  4.4727e-01,  ..., -1.1719e+00,
           6.7188e-01, -6.7188e-01],
         ...,
         [ 1.4297e+00,  1.7734e+00,  2.4062e+00,  ...,  2.2188e+00,
          -4.1562e+00,  3.7344e+00],
         [ 1.1250e+00,  8.6719e-01,  1.9824e-01,  ...,  1.0498e-01,
          -5.9766e-01,  9.2578e-01],
         [-1.2793e-01,  2.3438e-01,  4.6680e-01,  ..., -9.1406e-01,
           3.9258e-01, -3.1055e-01]],

        [[-5.2795e-03,  2.1648e-04, -1.0498e-02,  ..., -1.2054e-03,
           1.2665e-03,  1.0254e-02],
         [ 1.1016e+00, -6.9922e-01,  2.5391e-01,  ..., -3.6914e-01,
           2.3438e-01, -1.0791e-01],
         [ 1.4551e-01, -2.8198e-02,  4.4727e-01,  ..., -1.1719e+00,
           6.7188e-01, -6.7188e-01],
         ...,
         [ 1.4297e+00,  1.7734e+00,  2.4062e+00,  ...,  2.2188e+00,
          -4.1562e+00,  3.7344e+00],
         [ 1.1250e+00,  8.6719e-01,  1.9824e-01,  ...,  1.0498e-01,
          -5.9766e-01,  9.2578e-01],
         [-1.2793e-01,  2.3438e-01,  4.6680e-01,  ..., -9.1406e-01,
           3.9258e-01, -3.1055e-01]],

        [[-5.2795e-03,  2.1648e-04, -1.0498e-02,  ..., -1.2054e-03,
           1.2665e-03,  1.0254e-02],
         [ 1.1016e+00, -6.9922e-01,  2.5391e-01,  ..., -3.6914e-01,
           2.3438e-01, -1.0791e-01],
         [ 1.4551e-01, -2.8198e-02,  4.4727e-01,  ..., -1.1719e+00,
           6.7188e-01, -6.7188e-01],
         ...,
         [ 1.4297e+00,  1.7734e+00,  2.4062e+00,  ...,  2.2188e+00,
          -4.1562e+00,  3.7344e+00],
         [ 1.1250e+00,  8.6719e-01,  1.9824e-01,  ...,  1.0498e-01,
          -5.9766e-01,  9.2578e-01],
         [-1.2793e-01,  2.3438e-01,  4.6680e-01,  ..., -9.1406e-01,
           3.9258e-01, -3.1055e-01]],

        ...,

        [[-1.2054e-03,  1.3611e-02, -3.5553e-03,  ..., -1.2817e-03,
          -1.3123e-03,  2.8839e-03],
         [-1.1377e-01, -1.1875e+00,  1.1719e+00,  ..., -1.6250e+00,
           4.6680e-01, -6.9531e-01],
         [ 1.9453e+00,  1.0840e-01,  1.4038e-02,  ...,  8.9844e-02,
          -5.5078e-01,  4.6484e-01],
         ...,
         [-3.6719e-01, -7.7344e-01,  5.8984e-01,  ..., -6.6895e-02,
           4.2773e-01, -7.6562e-01],
         [ 7.6172e-02, -5.0781e-01,  1.3965e-01,  ...,  7.7344e-01,
           6.1719e-01, -5.9766e-01],
         [-3.9062e-01,  1.7334e-02, -5.6250e-01,  ...,  1.3438e+00,
           4.4336e-01,  1.1768e-01]],

        [[-1.2054e-03,  1.3611e-02, -3.5553e-03,  ..., -1.2817e-03,
          -1.3123e-03,  2.8839e-03],
         [-1.1377e-01, -1.1875e+00,  1.1719e+00,  ..., -1.6250e+00,
           4.6680e-01, -6.9531e-01],
         [ 1.9453e+00,  1.0840e-01,  1.4038e-02,  ...,  8.9844e-02,
          -5.5078e-01,  4.6484e-01],
         ...,
         [-3.6719e-01, -7.7344e-01,  5.8984e-01,  ..., -6.6895e-02,
           4.2773e-01, -7.6562e-01],
         [ 7.6172e-02, -5.0781e-01,  1.3965e-01,  ...,  7.7344e-01,
           6.1719e-01, -5.9766e-01],
         [-3.9062e-01,  1.7334e-02, -5.6250e-01,  ...,  1.3438e+00,
           4.4336e-01,  1.1768e-01]],

        [[-1.2054e-03,  1.3611e-02, -3.5553e-03,  ..., -1.2817e-03,
          -1.3123e-03,  2.8839e-03],
         [-1.1377e-01, -1.1875e+00,  1.1719e+00,  ..., -1.6250e+00,
           4.6680e-01, -6.9531e-01],
         [ 1.9453e+00,  1.0840e-01,  1.4038e-02,  ...,  8.9844e-02,
          -5.5078e-01,  4.6484e-01],
         ...,
         [-3.6719e-01, -7.7344e-01,  5.8984e-01,  ..., -6.6895e-02,
           4.2773e-01, -7.6562e-01],
         [ 7.6172e-02, -5.0781e-01,  1.3965e-01,  ...,  7.7344e-01,
           6.1719e-01, -5.9766e-01],
         [-3.9062e-01,  1.7334e-02, -5.6250e-01,  ...,  1.3438e+00,
           4.4336e-01,  1.1768e-01]]], dtype=torch.bfloat16)), (tensor([[[ 1.3977e-02, -1.9302e-03,  5.9509e-03,  ..., -5.0537e-02,
          -9.7656e-02, -5.2734e-01],
         [-1.3203e+00, -7.7344e-01, -8.3984e-01,  ..., -8.1641e-01,
          -6.2109e-01,  8.3750e+00],
         [ 3.7500e-01,  5.3125e-01, -1.0938e+00,  ..., -1.4062e-01,
           9.2285e-02,  8.3750e+00],
         ...,
         [ 1.4766e+00, -5.0391e-01,  1.0156e+00,  ...,  2.3594e+00,
           8.2812e-01,  7.3438e+00],
         [ 7.5000e-01, -1.3984e+00, -9.0625e-01,  ...,  1.5781e+00,
           1.3047e+00,  7.1250e+00],
         [-7.2656e-01, -2.0469e+00, -1.5156e+00,  ...,  1.2969e+00,
           1.9688e+00,  7.9062e+00]],

        [[ 1.3977e-02, -1.9302e-03,  5.9509e-03,  ..., -5.0537e-02,
          -9.7656e-02, -5.2734e-01],
         [-1.3203e+00, -7.7344e-01, -8.3984e-01,  ..., -8.1641e-01,
          -6.2109e-01,  8.3750e+00],
         [ 3.7500e-01,  5.3125e-01, -1.0938e+00,  ..., -1.4062e-01,
           9.2285e-02,  8.3750e+00],
         ...,
         [ 1.4766e+00, -5.0391e-01,  1.0156e+00,  ...,  2.3594e+00,
           8.2812e-01,  7.3438e+00],
         [ 7.5000e-01, -1.3984e+00, -9.0625e-01,  ...,  1.5781e+00,
           1.3047e+00,  7.1250e+00],
         [-7.2656e-01, -2.0469e+00, -1.5156e+00,  ...,  1.2969e+00,
           1.9688e+00,  7.9062e+00]],

        [[ 1.3977e-02, -1.9302e-03,  5.9509e-03,  ..., -5.0537e-02,
          -9.7656e-02, -5.2734e-01],
         [-1.3203e+00, -7.7344e-01, -8.3984e-01,  ..., -8.1641e-01,
          -6.2109e-01,  8.3750e+00],
         [ 3.7500e-01,  5.3125e-01, -1.0938e+00,  ..., -1.4062e-01,
           9.2285e-02,  8.3750e+00],
         ...,
         [ 1.4766e+00, -5.0391e-01,  1.0156e+00,  ...,  2.3594e+00,
           8.2812e-01,  7.3438e+00],
         [ 7.5000e-01, -1.3984e+00, -9.0625e-01,  ...,  1.5781e+00,
           1.3047e+00,  7.1250e+00],
         [-7.2656e-01, -2.0469e+00, -1.5156e+00,  ...,  1.2969e+00,
           1.9688e+00,  7.9062e+00]],

        ...,

        [[ 3.8719e-04, -5.5237e-03,  4.5013e-04,  ...,  4.7852e-02,
           4.2114e-03,  6.7383e-02],
         [-2.9062e+00,  1.9297e+00,  1.4219e+00,  ...,  3.7305e-01,
          -9.2578e-01,  1.9766e+00],
         [-1.4062e+00,  7.8906e-01,  4.8047e-01,  ..., -1.0840e-01,
           1.9922e+00,  8.5938e-01],
         ...,
         [ 6.9531e-01,  6.0938e-01,  1.4297e+00,  ...,  8.3203e-01,
           7.2656e-01, -1.8848e-01],
         [ 2.4414e-01,  8.5938e-02,  7.0801e-02,  ...,  3.3203e-02,
           1.9219e+00, -1.5938e+00],
         [ 1.6953e+00,  1.1328e+00,  2.1875e+00,  ..., -2.9883e-01,
          -7.6953e-01, -5.1953e-01]],

        [[ 3.8719e-04, -5.5237e-03,  4.5013e-04,  ...,  4.7852e-02,
           4.2114e-03,  6.7383e-02],
         [-2.9062e+00,  1.9297e+00,  1.4219e+00,  ...,  3.7305e-01,
          -9.2578e-01,  1.9766e+00],
         [-1.4062e+00,  7.8906e-01,  4.8047e-01,  ..., -1.0840e-01,
           1.9922e+00,  8.5938e-01],
         ...,
         [ 6.9531e-01,  6.0938e-01,  1.4297e+00,  ...,  8.3203e-01,
           7.2656e-01, -1.8848e-01],
         [ 2.4414e-01,  8.5938e-02,  7.0801e-02,  ...,  3.3203e-02,
           1.9219e+00, -1.5938e+00],
         [ 1.6953e+00,  1.1328e+00,  2.1875e+00,  ..., -2.9883e-01,
          -7.6953e-01, -5.1953e-01]],

        [[ 3.8719e-04, -5.5237e-03,  4.5013e-04,  ...,  4.7852e-02,
           4.2114e-03,  6.7383e-02],
         [-2.9062e+00,  1.9297e+00,  1.4219e+00,  ...,  3.7305e-01,
          -9.2578e-01,  1.9766e+00],
         [-1.4062e+00,  7.8906e-01,  4.8047e-01,  ..., -1.0840e-01,
           1.9922e+00,  8.5938e-01],
         ...,
         [ 6.9531e-01,  6.0938e-01,  1.4297e+00,  ...,  8.3203e-01,
           7.2656e-01, -1.8848e-01],
         [ 2.4414e-01,  8.5938e-02,  7.0801e-02,  ...,  3.3203e-02,
           1.9219e+00, -1.5938e+00],
         [ 1.6953e+00,  1.1328e+00,  2.1875e+00,  ..., -2.9883e-01,
          -7.6953e-01, -5.1953e-01]]], dtype=torch.bfloat16), tensor([[[ 6.1035e-03,  4.6997e-03,  3.7384e-04,  ..., -1.2817e-03,
           8.5449e-03,  3.8757e-03],
         [-1.0234e+00, -6.3672e-01, -7.3047e-01,  ...,  4.5312e-01,
          -1.2598e-01,  1.3750e+00],
         [ 2.8516e-01, -2.5312e+00, -1.6309e-01,  ...,  1.1953e+00,
           8.6328e-01,  5.8594e-01],
         ...,
         [ 2.3125e+00,  4.6094e-01, -1.4688e+00,  ..., -6.4062e-01,
           2.8906e-01, -4.1602e-01],
         [ 8.6328e-01, -4.2773e-01, -1.3516e+00,  ..., -4.7852e-01,
           9.4922e-01, -4.8242e-01],
         [ 9.8438e-01, -1.0156e+00, -1.1406e+00,  ..., -1.1172e+00,
           1.4941e-01, -3.0859e-01]],

        [[ 6.1035e-03,  4.6997e-03,  3.7384e-04,  ..., -1.2817e-03,
           8.5449e-03,  3.8757e-03],
         [-1.0234e+00, -6.3672e-01, -7.3047e-01,  ...,  4.5312e-01,
          -1.2598e-01,  1.3750e+00],
         [ 2.8516e-01, -2.5312e+00, -1.6309e-01,  ...,  1.1953e+00,
           8.6328e-01,  5.8594e-01],
         ...,
         [ 2.3125e+00,  4.6094e-01, -1.4688e+00,  ..., -6.4062e-01,
           2.8906e-01, -4.1602e-01],
         [ 8.6328e-01, -4.2773e-01, -1.3516e+00,  ..., -4.7852e-01,
           9.4922e-01, -4.8242e-01],
         [ 9.8438e-01, -1.0156e+00, -1.1406e+00,  ..., -1.1172e+00,
           1.4941e-01, -3.0859e-01]],

        [[ 6.1035e-03,  4.6997e-03,  3.7384e-04,  ..., -1.2817e-03,
           8.5449e-03,  3.8757e-03],
         [-1.0234e+00, -6.3672e-01, -7.3047e-01,  ...,  4.5312e-01,
          -1.2598e-01,  1.3750e+00],
         [ 2.8516e-01, -2.5312e+00, -1.6309e-01,  ...,  1.1953e+00,
           8.6328e-01,  5.8594e-01],
         ...,
         [ 2.3125e+00,  4.6094e-01, -1.4688e+00,  ..., -6.4062e-01,
           2.8906e-01, -4.1602e-01],
         [ 8.6328e-01, -4.2773e-01, -1.3516e+00,  ..., -4.7852e-01,
           9.4922e-01, -4.8242e-01],
         [ 9.8438e-01, -1.0156e+00, -1.1406e+00,  ..., -1.1172e+00,
           1.4941e-01, -3.0859e-01]],

        ...,

        [[-3.0136e-04, -1.9379e-03, -3.3569e-03,  ..., -4.6692e-03,
          -7.5989e-03, -9.1934e-04],
         [-2.0117e-01, -1.7969e+00,  1.0791e-01,  ...,  2.6406e+00,
          -6.2500e-01, -4.7656e-01],
         [-1.2734e+00,  4.7070e-01, -9.9609e-01,  ..., -2.3438e+00,
           1.2344e+00,  9.4141e-01],
         ...,
         [-1.7773e-01, -8.2031e-01, -7.7148e-02,  ..., -4.6875e-01,
           1.4355e-01, -1.6309e-01],
         [-4.1797e-01,  1.2188e+00, -3.3984e-01,  ..., -1.3672e+00,
           1.1484e+00,  4.8633e-01],
         [-1.6406e-01,  7.0801e-02, -1.2451e-01,  ..., -2.0801e-01,
           8.0859e-01,  1.4038e-02]],

        [[-3.0136e-04, -1.9379e-03, -3.3569e-03,  ..., -4.6692e-03,
          -7.5989e-03, -9.1934e-04],
         [-2.0117e-01, -1.7969e+00,  1.0791e-01,  ...,  2.6406e+00,
          -6.2500e-01, -4.7656e-01],
         [-1.2734e+00,  4.7070e-01, -9.9609e-01,  ..., -2.3438e+00,
           1.2344e+00,  9.4141e-01],
         ...,
         [-1.7773e-01, -8.2031e-01, -7.7148e-02,  ..., -4.6875e-01,
           1.4355e-01, -1.6309e-01],
         [-4.1797e-01,  1.2188e+00, -3.3984e-01,  ..., -1.3672e+00,
           1.1484e+00,  4.8633e-01],
         [-1.6406e-01,  7.0801e-02, -1.2451e-01,  ..., -2.0801e-01,
           8.0859e-01,  1.4038e-02]],

        [[-3.0136e-04, -1.9379e-03, -3.3569e-03,  ..., -4.6692e-03,
          -7.5989e-03, -9.1934e-04],
         [-2.0117e-01, -1.7969e+00,  1.0791e-01,  ...,  2.6406e+00,
          -6.2500e-01, -4.7656e-01],
         [-1.2734e+00,  4.7070e-01, -9.9609e-01,  ..., -2.3438e+00,
           1.2344e+00,  9.4141e-01],
         ...,
         [-1.7773e-01, -8.2031e-01, -7.7148e-02,  ..., -4.6875e-01,
           1.4355e-01, -1.6309e-01],
         [-4.1797e-01,  1.2188e+00, -3.3984e-01,  ..., -1.3672e+00,
           1.1484e+00,  4.8633e-01],
         [-1.6406e-01,  7.0801e-02, -1.2451e-01,  ..., -2.0801e-01,
           8.0859e-01,  1.4038e-02]]], dtype=torch.bfloat16)), (tensor([[[-1.1215e-03,  1.2283e-03,  4.0588e-03,  ...,  6.5430e-02,
           7.5684e-03, -4.1748e-02],
         [-5.1953e-01, -8.5938e-01, -1.2812e+00,  ..., -8.3203e-01,
           1.4141e+00,  1.0859e+00],
         [-1.6719e+00, -1.0156e+00, -2.2500e+00,  ..., -1.5859e+00,
          -1.4297e+00,  2.0781e+00],
         ...,
         [-3.5938e-01,  2.8516e-01,  2.2070e-01,  ..., -2.7344e+00,
          -8.9844e-02,  2.7812e+00],
         [-2.6758e-01,  4.6680e-01,  2.0508e-02,  ..., -4.0625e+00,
           6.5234e-01,  2.2812e+00],
         [ 1.6953e+00, -2.8906e-01, -9.8828e-01,  ..., -2.5938e+00,
          -2.1484e-01,  1.8281e+00]],

        [[-1.1215e-03,  1.2283e-03,  4.0588e-03,  ...,  6.5430e-02,
           7.5684e-03, -4.1748e-02],
         [-5.1953e-01, -8.5938e-01, -1.2812e+00,  ..., -8.3203e-01,
           1.4141e+00,  1.0859e+00],
         [-1.6719e+00, -1.0156e+00, -2.2500e+00,  ..., -1.5859e+00,
          -1.4297e+00,  2.0781e+00],
         ...,
         [-3.5938e-01,  2.8516e-01,  2.2070e-01,  ..., -2.7344e+00,
          -8.9844e-02,  2.7812e+00],
         [-2.6758e-01,  4.6680e-01,  2.0508e-02,  ..., -4.0625e+00,
           6.5234e-01,  2.2812e+00],
         [ 1.6953e+00, -2.8906e-01, -9.8828e-01,  ..., -2.5938e+00,
          -2.1484e-01,  1.8281e+00]],

        [[-1.1215e-03,  1.2283e-03,  4.0588e-03,  ...,  6.5430e-02,
           7.5684e-03, -4.1748e-02],
         [-5.1953e-01, -8.5938e-01, -1.2812e+00,  ..., -8.3203e-01,
           1.4141e+00,  1.0859e+00],
         [-1.6719e+00, -1.0156e+00, -2.2500e+00,  ..., -1.5859e+00,
          -1.4297e+00,  2.0781e+00],
         ...,
         [-3.5938e-01,  2.8516e-01,  2.2070e-01,  ..., -2.7344e+00,
          -8.9844e-02,  2.7812e+00],
         [-2.6758e-01,  4.6680e-01,  2.0508e-02,  ..., -4.0625e+00,
           6.5234e-01,  2.2812e+00],
         [ 1.6953e+00, -2.8906e-01, -9.8828e-01,  ..., -2.5938e+00,
          -2.1484e-01,  1.8281e+00]],

        ...,

        [[-2.6245e-03,  7.7820e-03, -2.7924e-03,  ...,  9.4604e-03,
           4.3640e-03, -3.2227e-01],
         [-5.1172e-01, -1.6953e+00, -1.0156e+00,  ..., -1.5156e+00,
           2.1250e+00,  1.0250e+01],
         [-1.2812e+00, -1.2109e+00, -8.0078e-01,  ..., -3.9648e-01,
           2.2559e-01,  1.0125e+01],
         ...,
         [ 3.4375e-01,  9.1406e-01,  4.7656e-01,  ..., -2.1484e-02,
           2.7031e+00,  7.5625e+00],
         [ 6.3672e-01,  4.1406e-01, -7.1875e-01,  ..., -9.0625e-01,
           1.1250e+00,  8.0000e+00],
         [ 1.3516e+00, -1.4844e-01, -8.2812e-01,  ..., -1.6562e+00,
           4.6484e-01,  8.8125e+00]],

        [[-2.6245e-03,  7.7820e-03, -2.7924e-03,  ...,  9.4604e-03,
           4.3640e-03, -3.2227e-01],
         [-5.1172e-01, -1.6953e+00, -1.0156e+00,  ..., -1.5156e+00,
           2.1250e+00,  1.0250e+01],
         [-1.2812e+00, -1.2109e+00, -8.0078e-01,  ..., -3.9648e-01,
           2.2559e-01,  1.0125e+01],
         ...,
         [ 3.4375e-01,  9.1406e-01,  4.7656e-01,  ..., -2.1484e-02,
           2.7031e+00,  7.5625e+00],
         [ 6.3672e-01,  4.1406e-01, -7.1875e-01,  ..., -9.0625e-01,
           1.1250e+00,  8.0000e+00],
         [ 1.3516e+00, -1.4844e-01, -8.2812e-01,  ..., -1.6562e+00,
           4.6484e-01,  8.8125e+00]],

        [[-2.6245e-03,  7.7820e-03, -2.7924e-03,  ...,  9.4604e-03,
           4.3640e-03, -3.2227e-01],
         [-5.1172e-01, -1.6953e+00, -1.0156e+00,  ..., -1.5156e+00,
           2.1250e+00,  1.0250e+01],
         [-1.2812e+00, -1.2109e+00, -8.0078e-01,  ..., -3.9648e-01,
           2.2559e-01,  1.0125e+01],
         ...,
         [ 3.4375e-01,  9.1406e-01,  4.7656e-01,  ..., -2.1484e-02,
           2.7031e+00,  7.5625e+00],
         [ 6.3672e-01,  4.1406e-01, -7.1875e-01,  ..., -9.0625e-01,
           1.1250e+00,  8.0000e+00],
         [ 1.3516e+00, -1.4844e-01, -8.2812e-01,  ..., -1.6562e+00,
           4.6484e-01,  8.8125e+00]]], dtype=torch.bfloat16), tensor([[[ 5.8289e-03, -2.0905e-03, -2.3041e-03,  ...,  6.0654e-04,
           6.7139e-03, -1.5717e-03],
         [ 5.4688e-01,  5.8899e-03,  2.3145e-01,  ...,  1.2695e-01,
           4.5508e-01,  6.8750e-01],
         [ 9.8828e-01, -4.2969e-01,  9.2969e-01,  ..., -1.8828e+00,
           1.0547e+00, -1.2578e+00],
         ...,
         [ 9.4922e-01, -6.6797e-01,  1.3477e-01,  ..., -6.0156e-01,
           6.4941e-02,  2.5391e-01],
         [ 2.7539e-01,  6.7969e-01,  1.0791e-01,  ..., -8.9062e-01,
          -8.8281e-01,  1.0547e+00],
         [ 2.5000e-01,  4.5312e-01,  3.7109e-01,  ...,  6.5918e-02,
           7.6953e-01, -2.4023e-01]],

        [[ 5.8289e-03, -2.0905e-03, -2.3041e-03,  ...,  6.0654e-04,
           6.7139e-03, -1.5717e-03],
         [ 5.4688e-01,  5.8899e-03,  2.3145e-01,  ...,  1.2695e-01,
           4.5508e-01,  6.8750e-01],
         [ 9.8828e-01, -4.2969e-01,  9.2969e-01,  ..., -1.8828e+00,
           1.0547e+00, -1.2578e+00],
         ...,
         [ 9.4922e-01, -6.6797e-01,  1.3477e-01,  ..., -6.0156e-01,
           6.4941e-02,  2.5391e-01],
         [ 2.7539e-01,  6.7969e-01,  1.0791e-01,  ..., -8.9062e-01,
          -8.8281e-01,  1.0547e+00],
         [ 2.5000e-01,  4.5312e-01,  3.7109e-01,  ...,  6.5918e-02,
           7.6953e-01, -2.4023e-01]],

        [[ 5.8289e-03, -2.0905e-03, -2.3041e-03,  ...,  6.0654e-04,
           6.7139e-03, -1.5717e-03],
         [ 5.4688e-01,  5.8899e-03,  2.3145e-01,  ...,  1.2695e-01,
           4.5508e-01,  6.8750e-01],
         [ 9.8828e-01, -4.2969e-01,  9.2969e-01,  ..., -1.8828e+00,
           1.0547e+00, -1.2578e+00],
         ...,
         [ 9.4922e-01, -6.6797e-01,  1.3477e-01,  ..., -6.0156e-01,
           6.4941e-02,  2.5391e-01],
         [ 2.7539e-01,  6.7969e-01,  1.0791e-01,  ..., -8.9062e-01,
          -8.8281e-01,  1.0547e+00],
         [ 2.5000e-01,  4.5312e-01,  3.7109e-01,  ...,  6.5918e-02,
           7.6953e-01, -2.4023e-01]],

        ...,

        [[-3.1128e-03, -7.8583e-04,  5.3406e-03,  ..., -7.8125e-03,
           1.1963e-02, -1.4420e-03],
         [ 8.1250e-01, -4.7461e-01, -8.2812e-01,  ...,  4.3555e-01,
          -1.8125e+00, -8.7891e-01],
         [-1.0391e+00, -8.7500e-01, -4.8242e-01,  ..., -1.1963e-01,
          -1.3984e+00, -1.8359e-01],
         ...,
         [ 2.8320e-01,  3.9844e-01,  4.3555e-01,  ...,  9.3750e-01,
          -6.4453e-01, -2.7539e-01],
         [ 8.2812e-01, -1.3770e-01,  5.3125e-01,  ...,  5.3516e-01,
          -3.3789e-01,  4.6484e-01],
         [ 9.4141e-01, -6.2109e-01,  8.3984e-01,  ..., -5.2344e-01,
          -1.8262e-01, -1.5918e-01]],

        [[-3.1128e-03, -7.8583e-04,  5.3406e-03,  ..., -7.8125e-03,
           1.1963e-02, -1.4420e-03],
         [ 8.1250e-01, -4.7461e-01, -8.2812e-01,  ...,  4.3555e-01,
          -1.8125e+00, -8.7891e-01],
         [-1.0391e+00, -8.7500e-01, -4.8242e-01,  ..., -1.1963e-01,
          -1.3984e+00, -1.8359e-01],
         ...,
         [ 2.8320e-01,  3.9844e-01,  4.3555e-01,  ...,  9.3750e-01,
          -6.4453e-01, -2.7539e-01],
         [ 8.2812e-01, -1.3770e-01,  5.3125e-01,  ...,  5.3516e-01,
          -3.3789e-01,  4.6484e-01],
         [ 9.4141e-01, -6.2109e-01,  8.3984e-01,  ..., -5.2344e-01,
          -1.8262e-01, -1.5918e-01]],

        [[-3.1128e-03, -7.8583e-04,  5.3406e-03,  ..., -7.8125e-03,
           1.1963e-02, -1.4420e-03],
         [ 8.1250e-01, -4.7461e-01, -8.2812e-01,  ...,  4.3555e-01,
          -1.8125e+00, -8.7891e-01],
         [-1.0391e+00, -8.7500e-01, -4.8242e-01,  ..., -1.1963e-01,
          -1.3984e+00, -1.8359e-01],
         ...,
         [ 2.8320e-01,  3.9844e-01,  4.3555e-01,  ...,  9.3750e-01,
          -6.4453e-01, -2.7539e-01],
         [ 8.2812e-01, -1.3770e-01,  5.3125e-01,  ...,  5.3516e-01,
          -3.3789e-01,  4.6484e-01],
         [ 9.4141e-01, -6.2109e-01,  8.3984e-01,  ..., -5.2344e-01,
          -1.8262e-01, -1.5918e-01]]], dtype=torch.bfloat16)), (tensor([[[-2.0508e-02,  6.1340e-03, -8.1787e-03,  ..., -1.7285e-01,
          -1.0010e-01,  9.5703e-02],
         [ 2.2656e+00, -1.7500e+00,  1.6562e+00,  ...,  1.0625e+00,
           8.9453e-01, -1.7969e+00],
         [ 1.6172e+00, -1.5234e+00,  1.7031e+00,  ...,  7.7344e-01,
           8.0078e-02,  5.8984e-01],
         ...,
         [-6.4844e-01, -1.1328e-01, -3.1250e-01,  ...,  1.3750e+00,
           1.0703e+00, -1.3984e+00],
         [-9.0625e-01,  7.6562e-01,  9.9609e-01,  ...,  1.2422e+00,
           1.3906e+00, -6.6406e-01],
         [-1.0938e+00,  1.5156e+00,  6.9141e-01,  ...,  1.1562e+00,
           3.0781e+00, -8.7109e-01]],

        [[-2.0508e-02,  6.1340e-03, -8.1787e-03,  ..., -1.7285e-01,
          -1.0010e-01,  9.5703e-02],
         [ 2.2656e+00, -1.7500e+00,  1.6562e+00,  ...,  1.0625e+00,
           8.9453e-01, -1.7969e+00],
         [ 1.6172e+00, -1.5234e+00,  1.7031e+00,  ...,  7.7344e-01,
           8.0078e-02,  5.8984e-01],
         ...,
         [-6.4844e-01, -1.1328e-01, -3.1250e-01,  ...,  1.3750e+00,
           1.0703e+00, -1.3984e+00],
         [-9.0625e-01,  7.6562e-01,  9.9609e-01,  ...,  1.2422e+00,
           1.3906e+00, -6.6406e-01],
         [-1.0938e+00,  1.5156e+00,  6.9141e-01,  ...,  1.1562e+00,
           3.0781e+00, -8.7109e-01]],

        [[-2.0508e-02,  6.1340e-03, -8.1787e-03,  ..., -1.7285e-01,
          -1.0010e-01,  9.5703e-02],
         [ 2.2656e+00, -1.7500e+00,  1.6562e+00,  ...,  1.0625e+00,
           8.9453e-01, -1.7969e+00],
         [ 1.6172e+00, -1.5234e+00,  1.7031e+00,  ...,  7.7344e-01,
           8.0078e-02,  5.8984e-01],
         ...,
         [-6.4844e-01, -1.1328e-01, -3.1250e-01,  ...,  1.3750e+00,
           1.0703e+00, -1.3984e+00],
         [-9.0625e-01,  7.6562e-01,  9.9609e-01,  ...,  1.2422e+00,
           1.3906e+00, -6.6406e-01],
         [-1.0938e+00,  1.5156e+00,  6.9141e-01,  ...,  1.1562e+00,
           3.0781e+00, -8.7109e-01]],

        ...,

        [[-7.2021e-03, -4.3335e-03,  4.3335e-03,  ..., -1.3245e-02,
          -7.4609e-01,  1.4941e-01],
         [-1.0156e+00,  2.8125e+00,  1.6016e-01,  ...,  1.0391e+00,
           7.6562e+00, -8.5156e-01],
         [-2.3750e+00,  1.4844e+00,  1.1094e+00,  ...,  5.0781e-01,
           8.9375e+00, -2.2812e+00],
         ...,
         [-5.1953e-01, -7.2266e-01, -1.1797e+00,  ...,  5.1953e-01,
           7.0938e+00, -3.4570e-01],
         [-7.8125e-03, -1.4922e+00, -1.1328e+00,  ..., -5.7812e-01,
           7.4688e+00, -7.3047e-01],
         [ 1.5469e+00, -1.4062e-01, -4.6875e-01,  ...,  7.4219e-01,
           7.5938e+00, -6.0156e-01]],

        [[-7.2021e-03, -4.3335e-03,  4.3335e-03,  ..., -1.3245e-02,
          -7.4609e-01,  1.4941e-01],
         [-1.0156e+00,  2.8125e+00,  1.6016e-01,  ...,  1.0391e+00,
           7.6562e+00, -8.5156e-01],
         [-2.3750e+00,  1.4844e+00,  1.1094e+00,  ...,  5.0781e-01,
           8.9375e+00, -2.2812e+00],
         ...,
         [-5.1953e-01, -7.2266e-01, -1.1797e+00,  ...,  5.1953e-01,
           7.0938e+00, -3.4570e-01],
         [-7.8125e-03, -1.4922e+00, -1.1328e+00,  ..., -5.7812e-01,
           7.4688e+00, -7.3047e-01],
         [ 1.5469e+00, -1.4062e-01, -4.6875e-01,  ...,  7.4219e-01,
           7.5938e+00, -6.0156e-01]],

        [[-7.2021e-03, -4.3335e-03,  4.3335e-03,  ..., -1.3245e-02,
          -7.4609e-01,  1.4941e-01],
         [-1.0156e+00,  2.8125e+00,  1.6016e-01,  ...,  1.0391e+00,
           7.6562e+00, -8.5156e-01],
         [-2.3750e+00,  1.4844e+00,  1.1094e+00,  ...,  5.0781e-01,
           8.9375e+00, -2.2812e+00],
         ...,
         [-5.1953e-01, -7.2266e-01, -1.1797e+00,  ...,  5.1953e-01,
           7.0938e+00, -3.4570e-01],
         [-7.8125e-03, -1.4922e+00, -1.1328e+00,  ..., -5.7812e-01,
           7.4688e+00, -7.3047e-01],
         [ 1.5469e+00, -1.4062e-01, -4.6875e-01,  ...,  7.4219e-01,
           7.5938e+00, -6.0156e-01]]], dtype=torch.bfloat16), tensor([[[ 2.0752e-03, -2.1210e-03, -6.5002e-03,  ...,  5.2185e-03,
          -5.4626e-03, -2.9144e-03],
         [ 8.2397e-03, -2.9102e-01,  2.8516e-01,  ..., -9.3750e-01,
          -5.7812e-01,  6.2109e-01],
         [ 3.7500e-01, -6.1328e-01,  1.0547e-01,  ..., -1.8438e+00,
           4.2969e-01, -1.8281e+00],
         ...,
         [ 2.6367e-01, -1.3047e+00, -4.6680e-01,  ...,  6.3965e-02,
          -9.1406e-01,  1.2891e-01],
         [ 1.0156e-01, -1.0000e+00, -1.1094e+00,  ...,  8.4766e-01,
          -1.0938e+00, -1.3984e+00],
         [ 9.5312e-01, -6.4062e-01, -1.6406e-01,  ...,  7.2266e-01,
          -1.3359e+00, -6.6016e-01]],

        [[ 2.0752e-03, -2.1210e-03, -6.5002e-03,  ...,  5.2185e-03,
          -5.4626e-03, -2.9144e-03],
         [ 8.2397e-03, -2.9102e-01,  2.8516e-01,  ..., -9.3750e-01,
          -5.7812e-01,  6.2109e-01],
         [ 3.7500e-01, -6.1328e-01,  1.0547e-01,  ..., -1.8438e+00,
           4.2969e-01, -1.8281e+00],
         ...,
         [ 2.6367e-01, -1.3047e+00, -4.6680e-01,  ...,  6.3965e-02,
          -9.1406e-01,  1.2891e-01],
         [ 1.0156e-01, -1.0000e+00, -1.1094e+00,  ...,  8.4766e-01,
          -1.0938e+00, -1.3984e+00],
         [ 9.5312e-01, -6.4062e-01, -1.6406e-01,  ...,  7.2266e-01,
          -1.3359e+00, -6.6016e-01]],

        [[ 2.0752e-03, -2.1210e-03, -6.5002e-03,  ...,  5.2185e-03,
          -5.4626e-03, -2.9144e-03],
         [ 8.2397e-03, -2.9102e-01,  2.8516e-01,  ..., -9.3750e-01,
          -5.7812e-01,  6.2109e-01],
         [ 3.7500e-01, -6.1328e-01,  1.0547e-01,  ..., -1.8438e+00,
           4.2969e-01, -1.8281e+00],
         ...,
         [ 2.6367e-01, -1.3047e+00, -4.6680e-01,  ...,  6.3965e-02,
          -9.1406e-01,  1.2891e-01],
         [ 1.0156e-01, -1.0000e+00, -1.1094e+00,  ...,  8.4766e-01,
          -1.0938e+00, -1.3984e+00],
         [ 9.5312e-01, -6.4062e-01, -1.6406e-01,  ...,  7.2266e-01,
          -1.3359e+00, -6.6016e-01]],

        ...,

        [[ 4.7302e-03,  9.9182e-04, -4.4556e-03,  ...,  1.8921e-03,
          -6.1340e-03,  2.1820e-03],
         [ 5.3906e-01,  8.6426e-02, -3.4424e-02,  ...,  9.8438e-01,
           8.1250e-01,  7.6172e-01],
         [-3.0859e-01,  3.6719e-01,  1.8906e+00,  ...,  1.7871e-01,
           6.7871e-02, -2.1562e+00],
         ...,
         [-5.6250e-01, -1.3867e-01, -5.7373e-02,  ...,  1.0449e-01,
           1.4355e-01,  1.0547e-01],
         [ 1.5723e-01,  6.2109e-01, -7.5391e-01,  ..., -1.1250e+00,
          -1.0010e-01, -8.0859e-01],
         [-4.2578e-01, -2.8711e-01,  8.2422e-01,  ..., -1.1523e-01,
          -3.2617e-01,  4.1797e-01]],

        [[ 4.7302e-03,  9.9182e-04, -4.4556e-03,  ...,  1.8921e-03,
          -6.1340e-03,  2.1820e-03],
         [ 5.3906e-01,  8.6426e-02, -3.4424e-02,  ...,  9.8438e-01,
           8.1250e-01,  7.6172e-01],
         [-3.0859e-01,  3.6719e-01,  1.8906e+00,  ...,  1.7871e-01,
           6.7871e-02, -2.1562e+00],
         ...,
         [-5.6250e-01, -1.3867e-01, -5.7373e-02,  ...,  1.0449e-01,
           1.4355e-01,  1.0547e-01],
         [ 1.5723e-01,  6.2109e-01, -7.5391e-01,  ..., -1.1250e+00,
          -1.0010e-01, -8.0859e-01],
         [-4.2578e-01, -2.8711e-01,  8.2422e-01,  ..., -1.1523e-01,
          -3.2617e-01,  4.1797e-01]],

        [[ 4.7302e-03,  9.9182e-04, -4.4556e-03,  ...,  1.8921e-03,
          -6.1340e-03,  2.1820e-03],
         [ 5.3906e-01,  8.6426e-02, -3.4424e-02,  ...,  9.8438e-01,
           8.1250e-01,  7.6172e-01],
         [-3.0859e-01,  3.6719e-01,  1.8906e+00,  ...,  1.7871e-01,
           6.7871e-02, -2.1562e+00],
         ...,
         [-5.6250e-01, -1.3867e-01, -5.7373e-02,  ...,  1.0449e-01,
           1.4355e-01,  1.0547e-01],
         [ 1.5723e-01,  6.2109e-01, -7.5391e-01,  ..., -1.1250e+00,
          -1.0010e-01, -8.0859e-01],
         [-4.2578e-01, -2.8711e-01,  8.2422e-01,  ..., -1.1523e-01,
          -3.2617e-01,  4.1797e-01]]], dtype=torch.bfloat16)), (tensor([[[ 3.7994e-03, -2.5146e-02, -1.2360e-03,  ..., -2.3633e-01,
          -9.6680e-02,  1.0059e-01],
         [ 2.5000e-01, -3.6328e-01,  3.5156e-01,  ..., -8.7402e-02,
          -9.6484e-01, -1.3516e+00],
         [ 1.1797e+00, -1.0859e+00, -7.4609e-01,  ..., -5.4062e+00,
          -4.2500e+00, -3.3750e+00],
         ...,
         [ 1.1250e+00,  4.6289e-01,  4.1602e-01,  ..., -4.9062e+00,
          -5.9766e-01, -1.3359e+00],
         [-2.5977e-01, -3.4180e-01,  6.4062e-01,  ..., -4.8750e+00,
           1.2793e-01,  8.6719e-01],
         [-6.8359e-01,  4.2578e-01, -1.1406e+00,  ..., -3.5938e+00,
          -2.6719e+00, -1.1641e+00]],

        [[ 3.7994e-03, -2.5146e-02, -1.2360e-03,  ..., -2.3633e-01,
          -9.6680e-02,  1.0059e-01],
         [ 2.5000e-01, -3.6328e-01,  3.5156e-01,  ..., -8.7402e-02,
          -9.6484e-01, -1.3516e+00],
         [ 1.1797e+00, -1.0859e+00, -7.4609e-01,  ..., -5.4062e+00,
          -4.2500e+00, -3.3750e+00],
         ...,
         [ 1.1250e+00,  4.6289e-01,  4.1602e-01,  ..., -4.9062e+00,
          -5.9766e-01, -1.3359e+00],
         [-2.5977e-01, -3.4180e-01,  6.4062e-01,  ..., -4.8750e+00,
           1.2793e-01,  8.6719e-01],
         [-6.8359e-01,  4.2578e-01, -1.1406e+00,  ..., -3.5938e+00,
          -2.6719e+00, -1.1641e+00]],

        [[ 3.7994e-03, -2.5146e-02, -1.2360e-03,  ..., -2.3633e-01,
          -9.6680e-02,  1.0059e-01],
         [ 2.5000e-01, -3.6328e-01,  3.5156e-01,  ..., -8.7402e-02,
          -9.6484e-01, -1.3516e+00],
         [ 1.1797e+00, -1.0859e+00, -7.4609e-01,  ..., -5.4062e+00,
          -4.2500e+00, -3.3750e+00],
         ...,
         [ 1.1250e+00,  4.6289e-01,  4.1602e-01,  ..., -4.9062e+00,
          -5.9766e-01, -1.3359e+00],
         [-2.5977e-01, -3.4180e-01,  6.4062e-01,  ..., -4.8750e+00,
           1.2793e-01,  8.6719e-01],
         [-6.8359e-01,  4.2578e-01, -1.1406e+00,  ..., -3.5938e+00,
          -2.6719e+00, -1.1641e+00]],

        ...,

        [[-8.1177e-03, -2.3041e-03, -8.8501e-04,  ...,  4.3213e-02,
           1.1426e-01,  1.2305e-01],
         [ 1.2969e+00,  1.7734e+00,  1.0703e+00,  ...,  2.6406e+00,
           9.8828e-01,  1.5723e-01],
         [ 1.7188e-01,  6.2500e-01,  2.2344e+00,  ...,  1.0625e+00,
          -7.8906e-01,  2.1484e-01],
         ...,
         [ 6.8359e-01, -8.3984e-01,  2.1719e+00,  ...,  5.4297e-01,
           5.5078e-01, -2.9688e+00],
         [ 2.4023e-01, -6.1719e-01,  1.1328e+00,  ...,  2.5977e-01,
           7.6172e-02, -2.4023e-01],
         [-1.1719e+00,  9.5703e-02,  1.2422e+00,  ...,  2.5156e+00,
           1.2578e+00, -1.1797e+00]],

        [[-8.1177e-03, -2.3041e-03, -8.8501e-04,  ...,  4.3213e-02,
           1.1426e-01,  1.2305e-01],
         [ 1.2969e+00,  1.7734e+00,  1.0703e+00,  ...,  2.6406e+00,
           9.8828e-01,  1.5723e-01],
         [ 1.7188e-01,  6.2500e-01,  2.2344e+00,  ...,  1.0625e+00,
          -7.8906e-01,  2.1484e-01],
         ...,
         [ 6.8359e-01, -8.3984e-01,  2.1719e+00,  ...,  5.4297e-01,
           5.5078e-01, -2.9688e+00],
         [ 2.4023e-01, -6.1719e-01,  1.1328e+00,  ...,  2.5977e-01,
           7.6172e-02, -2.4023e-01],
         [-1.1719e+00,  9.5703e-02,  1.2422e+00,  ...,  2.5156e+00,
           1.2578e+00, -1.1797e+00]],

        [[-8.1177e-03, -2.3041e-03, -8.8501e-04,  ...,  4.3213e-02,
           1.1426e-01,  1.2305e-01],
         [ 1.2969e+00,  1.7734e+00,  1.0703e+00,  ...,  2.6406e+00,
           9.8828e-01,  1.5723e-01],
         [ 1.7188e-01,  6.2500e-01,  2.2344e+00,  ...,  1.0625e+00,
          -7.8906e-01,  2.1484e-01],
         ...,
         [ 6.8359e-01, -8.3984e-01,  2.1719e+00,  ...,  5.4297e-01,
           5.5078e-01, -2.9688e+00],
         [ 2.4023e-01, -6.1719e-01,  1.1328e+00,  ...,  2.5977e-01,
           7.6172e-02, -2.4023e-01],
         [-1.1719e+00,  9.5703e-02,  1.2422e+00,  ...,  2.5156e+00,
           1.2578e+00, -1.1797e+00]]], dtype=torch.bfloat16), tensor([[[ 7.2327e-03, -4.8218e-03, -5.1575e-03,  ...,  1.7090e-03,
           1.6479e-03,  5.5847e-03],
         [-1.0781e+00, -9.1406e-01,  2.9688e+00,  ..., -1.7031e+00,
          -5.6641e-01, -1.1016e+00],
         [-1.9219e+00, -4.1562e+00,  4.4688e+00,  ..., -7.5781e-01,
          -1.3359e+00, -9.6094e-01],
         ...,
         [-1.5469e+00, -5.3516e-01, -2.2461e-01,  ...,  1.4062e+00,
           9.1016e-01,  1.1406e+00],
         [-4.6094e-01, -2.2656e+00,  1.8750e+00,  ..., -7.8125e-01,
           5.0781e-01,  1.2891e+00],
         [-7.1875e-01, -2.1875e+00,  2.0938e+00,  ...,  1.0938e-01,
          -2.0996e-01,  9.2188e-01]],

        [[ 7.2327e-03, -4.8218e-03, -5.1575e-03,  ...,  1.7090e-03,
           1.6479e-03,  5.5847e-03],
         [-1.0781e+00, -9.1406e-01,  2.9688e+00,  ..., -1.7031e+00,
          -5.6641e-01, -1.1016e+00],
         [-1.9219e+00, -4.1562e+00,  4.4688e+00,  ..., -7.5781e-01,
          -1.3359e+00, -9.6094e-01],
         ...,
         [-1.5469e+00, -5.3516e-01, -2.2461e-01,  ...,  1.4062e+00,
           9.1016e-01,  1.1406e+00],
         [-4.6094e-01, -2.2656e+00,  1.8750e+00,  ..., -7.8125e-01,
           5.0781e-01,  1.2891e+00],
         [-7.1875e-01, -2.1875e+00,  2.0938e+00,  ...,  1.0938e-01,
          -2.0996e-01,  9.2188e-01]],

        [[ 7.2327e-03, -4.8218e-03, -5.1575e-03,  ...,  1.7090e-03,
           1.6479e-03,  5.5847e-03],
         [-1.0781e+00, -9.1406e-01,  2.9688e+00,  ..., -1.7031e+00,
          -5.6641e-01, -1.1016e+00],
         [-1.9219e+00, -4.1562e+00,  4.4688e+00,  ..., -7.5781e-01,
          -1.3359e+00, -9.6094e-01],
         ...,
         [-1.5469e+00, -5.3516e-01, -2.2461e-01,  ...,  1.4062e+00,
           9.1016e-01,  1.1406e+00],
         [-4.6094e-01, -2.2656e+00,  1.8750e+00,  ..., -7.8125e-01,
           5.0781e-01,  1.2891e+00],
         [-7.1875e-01, -2.1875e+00,  2.0938e+00,  ...,  1.0938e-01,
          -2.0996e-01,  9.2188e-01]],

        ...,

        [[ 5.9509e-04,  1.7090e-02, -3.9368e-03,  ...,  3.6011e-03,
           3.3569e-03, -2.1362e-03],
         [-1.0547e-01,  1.1963e-01,  1.7266e+00,  ..., -1.3359e+00,
          -6.9141e-01, -1.7422e+00],
         [-4.1809e-03, -8.7402e-02,  5.8594e-01,  ..., -8.2422e-01,
          -4.3701e-02, -3.7598e-02],
         ...,
         [-1.6953e+00, -2.0508e-01,  2.5977e-01,  ..., -1.7266e+00,
          -2.5156e+00, -1.0391e+00],
         [-9.4922e-01, -5.0000e-01, -5.4688e-01,  ..., -7.5391e-01,
          -1.5781e+00, -1.8359e+00],
         [-1.2109e+00,  8.8281e-01,  3.9453e-01,  ..., -2.0469e+00,
          -7.4219e-01,  1.3984e+00]],

        [[ 5.9509e-04,  1.7090e-02, -3.9368e-03,  ...,  3.6011e-03,
           3.3569e-03, -2.1362e-03],
         [-1.0547e-01,  1.1963e-01,  1.7266e+00,  ..., -1.3359e+00,
          -6.9141e-01, -1.7422e+00],
         [-4.1809e-03, -8.7402e-02,  5.8594e-01,  ..., -8.2422e-01,
          -4.3701e-02, -3.7598e-02],
         ...,
         [-1.6953e+00, -2.0508e-01,  2.5977e-01,  ..., -1.7266e+00,
          -2.5156e+00, -1.0391e+00],
         [-9.4922e-01, -5.0000e-01, -5.4688e-01,  ..., -7.5391e-01,
          -1.5781e+00, -1.8359e+00],
         [-1.2109e+00,  8.8281e-01,  3.9453e-01,  ..., -2.0469e+00,
          -7.4219e-01,  1.3984e+00]],

        [[ 5.9509e-04,  1.7090e-02, -3.9368e-03,  ...,  3.6011e-03,
           3.3569e-03, -2.1362e-03],
         [-1.0547e-01,  1.1963e-01,  1.7266e+00,  ..., -1.3359e+00,
          -6.9141e-01, -1.7422e+00],
         [-4.1809e-03, -8.7402e-02,  5.8594e-01,  ..., -8.2422e-01,
          -4.3701e-02, -3.7598e-02],
         ...,
         [-1.6953e+00, -2.0508e-01,  2.5977e-01,  ..., -1.7266e+00,
          -2.5156e+00, -1.0391e+00],
         [-9.4922e-01, -5.0000e-01, -5.4688e-01,  ..., -7.5391e-01,
          -1.5781e+00, -1.8359e+00],
         [-1.2109e+00,  8.8281e-01,  3.9453e-01,  ..., -2.0469e+00,
          -7.4219e-01,  1.3984e+00]]], dtype=torch.bfloat16)), (tensor([[[-1.2024e-02,  5.1270e-03,  5.8289e-03,  ..., -1.3594e+00,
           1.7090e-03,  2.0215e-01],
         [ 3.5547e-01,  1.9609e+00, -6.4844e-01,  ...,  9.5625e+00,
           1.6113e-01, -2.9062e+00],
         [-9.1016e-01, -1.1621e-01,  2.7734e-01,  ...,  1.0000e+01,
           1.7188e+00, -6.8359e-01],
         ...,
         [-6.9141e-01, -1.5918e-01, -6.6406e-01,  ...,  9.1875e+00,
           6.7383e-02, -1.1641e+00],
         [-7.6172e-02,  1.3672e-01, -4.2773e-01,  ...,  8.8750e+00,
           5.4297e-01, -1.1250e+00],
         [ 1.9062e+00,  1.2188e+00, -7.5391e-01,  ...,  9.0625e+00,
          -5.1953e-01, -2.5000e+00]],

        [[-1.2024e-02,  5.1270e-03,  5.8289e-03,  ..., -1.3594e+00,
           1.7090e-03,  2.0215e-01],
         [ 3.5547e-01,  1.9609e+00, -6.4844e-01,  ...,  9.5625e+00,
           1.6113e-01, -2.9062e+00],
         [-9.1016e-01, -1.1621e-01,  2.7734e-01,  ...,  1.0000e+01,
           1.7188e+00, -6.8359e-01],
         ...,
         [-6.9141e-01, -1.5918e-01, -6.6406e-01,  ...,  9.1875e+00,
           6.7383e-02, -1.1641e+00],
         [-7.6172e-02,  1.3672e-01, -4.2773e-01,  ...,  8.8750e+00,
           5.4297e-01, -1.1250e+00],
         [ 1.9062e+00,  1.2188e+00, -7.5391e-01,  ...,  9.0625e+00,
          -5.1953e-01, -2.5000e+00]],

        [[-1.2024e-02,  5.1270e-03,  5.8289e-03,  ..., -1.3594e+00,
           1.7090e-03,  2.0215e-01],
         [ 3.5547e-01,  1.9609e+00, -6.4844e-01,  ...,  9.5625e+00,
           1.6113e-01, -2.9062e+00],
         [-9.1016e-01, -1.1621e-01,  2.7734e-01,  ...,  1.0000e+01,
           1.7188e+00, -6.8359e-01],
         ...,
         [-6.9141e-01, -1.5918e-01, -6.6406e-01,  ...,  9.1875e+00,
           6.7383e-02, -1.1641e+00],
         [-7.6172e-02,  1.3672e-01, -4.2773e-01,  ...,  8.8750e+00,
           5.4297e-01, -1.1250e+00],
         [ 1.9062e+00,  1.2188e+00, -7.5391e-01,  ...,  9.0625e+00,
          -5.1953e-01, -2.5000e+00]],

        ...,

        [[ 9.3384e-03,  1.5198e-02, -1.2146e-02,  ...,  7.6660e-02,
          -3.2422e-01,  1.1523e-01],
         [-1.5078e+00,  2.6406e+00,  1.2656e+00,  ...,  2.8711e-01,
           2.0469e+00,  1.1523e-01],
         [-7.0312e-01,  9.1309e-02,  5.9375e-01,  ...,  4.6875e+00,
          -4.1562e+00,  1.5234e+00],
         ...,
         [ 2.0312e+00, -7.4219e-02,  8.0469e-01,  ...,  1.7266e+00,
          -3.5625e+00,  2.6406e+00],
         [ 4.2969e-01, -5.1562e-01,  7.5000e-01,  ...,  8.2812e-01,
          -3.5625e+00,  2.9062e+00],
         [-1.1953e+00,  2.3438e-01,  1.4844e+00,  ...,  6.1719e-01,
          -2.1875e+00,  1.4688e+00]],

        [[ 9.3384e-03,  1.5198e-02, -1.2146e-02,  ...,  7.6660e-02,
          -3.2422e-01,  1.1523e-01],
         [-1.5078e+00,  2.6406e+00,  1.2656e+00,  ...,  2.8711e-01,
           2.0469e+00,  1.1523e-01],
         [-7.0312e-01,  9.1309e-02,  5.9375e-01,  ...,  4.6875e+00,
          -4.1562e+00,  1.5234e+00],
         ...,
         [ 2.0312e+00, -7.4219e-02,  8.0469e-01,  ...,  1.7266e+00,
          -3.5625e+00,  2.6406e+00],
         [ 4.2969e-01, -5.1562e-01,  7.5000e-01,  ...,  8.2812e-01,
          -3.5625e+00,  2.9062e+00],
         [-1.1953e+00,  2.3438e-01,  1.4844e+00,  ...,  6.1719e-01,
          -2.1875e+00,  1.4688e+00]],

        [[ 9.3384e-03,  1.5198e-02, -1.2146e-02,  ...,  7.6660e-02,
          -3.2422e-01,  1.1523e-01],
         [-1.5078e+00,  2.6406e+00,  1.2656e+00,  ...,  2.8711e-01,
           2.0469e+00,  1.1523e-01],
         [-7.0312e-01,  9.1309e-02,  5.9375e-01,  ...,  4.6875e+00,
          -4.1562e+00,  1.5234e+00],
         ...,
         [ 2.0312e+00, -7.4219e-02,  8.0469e-01,  ...,  1.7266e+00,
          -3.5625e+00,  2.6406e+00],
         [ 4.2969e-01, -5.1562e-01,  7.5000e-01,  ...,  8.2812e-01,
          -3.5625e+00,  2.9062e+00],
         [-1.1953e+00,  2.3438e-01,  1.4844e+00,  ...,  6.1719e-01,
          -2.1875e+00,  1.4688e+00]]], dtype=torch.bfloat16), tensor([[[-9.0942e-03,  5.8289e-03, -4.0894e-03,  ..., -1.3184e-02,
          -4.6997e-03,  2.1172e-04],
         [ 1.4062e-01, -7.3047e-01, -4.3164e-01,  ..., -9.0625e-01,
           6.6797e-01, -9.2969e-01],
         [-9.3359e-01, -2.1719e+00,  3.2959e-02,  ...,  9.1406e-01,
           5.8984e-01,  5.2734e-01],
         ...,
         [-7.1094e-01, -7.1289e-02,  1.1094e+00,  ...,  5.1953e-01,
           5.9375e-01,  3.6719e-01],
         [-8.4766e-01,  1.0547e+00, -5.2734e-01,  ...,  1.2578e+00,
           9.4922e-01,  8.6914e-02],
         [-1.9043e-01,  1.2734e+00,  1.0781e+00,  ...,  1.0391e+00,
          -5.1172e-01, -2.2461e-01]],

        [[-9.0942e-03,  5.8289e-03, -4.0894e-03,  ..., -1.3184e-02,
          -4.6997e-03,  2.1172e-04],
         [ 1.4062e-01, -7.3047e-01, -4.3164e-01,  ..., -9.0625e-01,
           6.6797e-01, -9.2969e-01],
         [-9.3359e-01, -2.1719e+00,  3.2959e-02,  ...,  9.1406e-01,
           5.8984e-01,  5.2734e-01],
         ...,
         [-7.1094e-01, -7.1289e-02,  1.1094e+00,  ...,  5.1953e-01,
           5.9375e-01,  3.6719e-01],
         [-8.4766e-01,  1.0547e+00, -5.2734e-01,  ...,  1.2578e+00,
           9.4922e-01,  8.6914e-02],
         [-1.9043e-01,  1.2734e+00,  1.0781e+00,  ...,  1.0391e+00,
          -5.1172e-01, -2.2461e-01]],

        [[-9.0942e-03,  5.8289e-03, -4.0894e-03,  ..., -1.3184e-02,
          -4.6997e-03,  2.1172e-04],
         [ 1.4062e-01, -7.3047e-01, -4.3164e-01,  ..., -9.0625e-01,
           6.6797e-01, -9.2969e-01],
         [-9.3359e-01, -2.1719e+00,  3.2959e-02,  ...,  9.1406e-01,
           5.8984e-01,  5.2734e-01],
         ...,
         [-7.1094e-01, -7.1289e-02,  1.1094e+00,  ...,  5.1953e-01,
           5.9375e-01,  3.6719e-01],
         [-8.4766e-01,  1.0547e+00, -5.2734e-01,  ...,  1.2578e+00,
           9.4922e-01,  8.6914e-02],
         [-1.9043e-01,  1.2734e+00,  1.0781e+00,  ...,  1.0391e+00,
          -5.1172e-01, -2.2461e-01]],

        ...,

        [[ 3.6955e-05, -7.8735e-03,  3.1853e-04,  ...,  1.9531e-02,
           2.4986e-04, -6.5308e-03],
         [-1.0449e-01, -9.7656e-02,  5.8984e-01,  ...,  8.9844e-01,
          -1.6094e+00,  1.9434e-01],
         [-1.0781e+00,  3.4219e+00,  2.5000e+00,  ..., -6.2500e-01,
          -1.0859e+00,  1.2812e+00],
         ...,
         [ 3.5938e-01, -1.6309e-01,  1.0703e+00,  ...,  7.0801e-02,
          -1.9688e+00,  3.2031e-01],
         [-8.0078e-01,  1.4062e+00,  1.7656e+00,  ..., -6.3672e-01,
          -3.0312e+00, -3.1445e-01],
         [-3.6328e-01,  1.1094e+00,  2.4219e+00,  ...,  2.8711e-01,
          -1.6797e+00,  6.4062e-01]],

        [[ 3.6955e-05, -7.8735e-03,  3.1853e-04,  ...,  1.9531e-02,
           2.4986e-04, -6.5308e-03],
         [-1.0449e-01, -9.7656e-02,  5.8984e-01,  ...,  8.9844e-01,
          -1.6094e+00,  1.9434e-01],
         [-1.0781e+00,  3.4219e+00,  2.5000e+00,  ..., -6.2500e-01,
          -1.0859e+00,  1.2812e+00],
         ...,
         [ 3.5938e-01, -1.6309e-01,  1.0703e+00,  ...,  7.0801e-02,
          -1.9688e+00,  3.2031e-01],
         [-8.0078e-01,  1.4062e+00,  1.7656e+00,  ..., -6.3672e-01,
          -3.0312e+00, -3.1445e-01],
         [-3.6328e-01,  1.1094e+00,  2.4219e+00,  ...,  2.8711e-01,
          -1.6797e+00,  6.4062e-01]],

        [[ 3.6955e-05, -7.8735e-03,  3.1853e-04,  ...,  1.9531e-02,
           2.4986e-04, -6.5308e-03],
         [-1.0449e-01, -9.7656e-02,  5.8984e-01,  ...,  8.9844e-01,
          -1.6094e+00,  1.9434e-01],
         [-1.0781e+00,  3.4219e+00,  2.5000e+00,  ..., -6.2500e-01,
          -1.0859e+00,  1.2812e+00],
         ...,
         [ 3.5938e-01, -1.6309e-01,  1.0703e+00,  ...,  7.0801e-02,
          -1.9688e+00,  3.2031e-01],
         [-8.0078e-01,  1.4062e+00,  1.7656e+00,  ..., -6.3672e-01,
          -3.0312e+00, -3.1445e-01],
         [-3.6328e-01,  1.1094e+00,  2.4219e+00,  ...,  2.8711e-01,
          -1.6797e+00,  6.4062e-01]]], dtype=torch.bfloat16)), (tensor([[[ 2.8076e-03, -3.0212e-03,  3.6774e-03,  ...,  4.7266e-01,
           5.4443e-02, -6.8848e-02],
         [-4.5625e+00,  8.1250e-01,  1.7500e+00,  ..., -5.6641e-01,
          -1.9141e-01,  8.0078e-02],
         [-1.1719e+00,  2.5938e+00,  1.4844e+00,  ..., -6.6406e-01,
           4.4727e-01,  4.9609e-01],
         ...,
         [ 6.5625e-01, -1.5312e+00, -1.4531e+00,  ...,  9.7656e-01,
           4.0820e-01,  1.3203e+00],
         [ 1.6250e+00, -1.4922e+00,  2.1484e-01,  ...,  5.0000e-01,
          -6.2500e-02, -4.8047e-01],
         [ 1.1719e+00, -1.3984e+00,  2.0312e+00,  ...,  1.6406e+00,
          -2.4688e+00, -1.5391e+00]],

        [[ 2.8076e-03, -3.0212e-03,  3.6774e-03,  ...,  4.7266e-01,
           5.4443e-02, -6.8848e-02],
         [-4.5625e+00,  8.1250e-01,  1.7500e+00,  ..., -5.6641e-01,
          -1.9141e-01,  8.0078e-02],
         [-1.1719e+00,  2.5938e+00,  1.4844e+00,  ..., -6.6406e-01,
           4.4727e-01,  4.9609e-01],
         ...,
         [ 6.5625e-01, -1.5312e+00, -1.4531e+00,  ...,  9.7656e-01,
           4.0820e-01,  1.3203e+00],
         [ 1.6250e+00, -1.4922e+00,  2.1484e-01,  ...,  5.0000e-01,
          -6.2500e-02, -4.8047e-01],
         [ 1.1719e+00, -1.3984e+00,  2.0312e+00,  ...,  1.6406e+00,
          -2.4688e+00, -1.5391e+00]],

        [[ 2.8076e-03, -3.0212e-03,  3.6774e-03,  ...,  4.7266e-01,
           5.4443e-02, -6.8848e-02],
         [-4.5625e+00,  8.1250e-01,  1.7500e+00,  ..., -5.6641e-01,
          -1.9141e-01,  8.0078e-02],
         [-1.1719e+00,  2.5938e+00,  1.4844e+00,  ..., -6.6406e-01,
           4.4727e-01,  4.9609e-01],
         ...,
         [ 6.5625e-01, -1.5312e+00, -1.4531e+00,  ...,  9.7656e-01,
           4.0820e-01,  1.3203e+00],
         [ 1.6250e+00, -1.4922e+00,  2.1484e-01,  ...,  5.0000e-01,
          -6.2500e-02, -4.8047e-01],
         [ 1.1719e+00, -1.3984e+00,  2.0312e+00,  ...,  1.6406e+00,
          -2.4688e+00, -1.5391e+00]],

        ...,

        [[ 1.2390e-02,  3.0670e-03, -8.9111e-03,  ...,  2.3438e-01,
          -5.9326e-02,  6.6895e-02],
         [-1.9043e-01, -1.5625e-02, -2.5000e-01,  ..., -2.9688e-01,
           6.7188e-01, -3.9258e-01],
         [-3.3789e-01,  2.4316e-01,  9.7168e-02,  ..., -1.0703e+00,
           5.2734e-01, -1.5000e+00],
         ...,
         [ 4.2578e-01, -5.8984e-01, -1.1719e+00,  ...,  5.2734e-01,
          -7.6953e-01, -7.0703e-01],
         [ 9.4727e-02, -2.1387e-01, -2.8516e-01,  ...,  1.1094e+00,
          -8.3496e-02,  1.2793e-01],
         [-3.1445e-01, -8.5938e-01,  6.6406e-02,  ...,  1.1406e+00,
           5.6250e-01,  2.9102e-01]],

        [[ 1.2390e-02,  3.0670e-03, -8.9111e-03,  ...,  2.3438e-01,
          -5.9326e-02,  6.6895e-02],
         [-1.9043e-01, -1.5625e-02, -2.5000e-01,  ..., -2.9688e-01,
           6.7188e-01, -3.9258e-01],
         [-3.3789e-01,  2.4316e-01,  9.7168e-02,  ..., -1.0703e+00,
           5.2734e-01, -1.5000e+00],
         ...,
         [ 4.2578e-01, -5.8984e-01, -1.1719e+00,  ...,  5.2734e-01,
          -7.6953e-01, -7.0703e-01],
         [ 9.4727e-02, -2.1387e-01, -2.8516e-01,  ...,  1.1094e+00,
          -8.3496e-02,  1.2793e-01],
         [-3.1445e-01, -8.5938e-01,  6.6406e-02,  ...,  1.1406e+00,
           5.6250e-01,  2.9102e-01]],

        [[ 1.2390e-02,  3.0670e-03, -8.9111e-03,  ...,  2.3438e-01,
          -5.9326e-02,  6.6895e-02],
         [-1.9043e-01, -1.5625e-02, -2.5000e-01,  ..., -2.9688e-01,
           6.7188e-01, -3.9258e-01],
         [-3.3789e-01,  2.4316e-01,  9.7168e-02,  ..., -1.0703e+00,
           5.2734e-01, -1.5000e+00],
         ...,
         [ 4.2578e-01, -5.8984e-01, -1.1719e+00,  ...,  5.2734e-01,
          -7.6953e-01, -7.0703e-01],
         [ 9.4727e-02, -2.1387e-01, -2.8516e-01,  ...,  1.1094e+00,
          -8.3496e-02,  1.2793e-01],
         [-3.1445e-01, -8.5938e-01,  6.6406e-02,  ...,  1.1406e+00,
           5.6250e-01,  2.9102e-01]]], dtype=torch.bfloat16), tensor([[[-3.1891e-03,  2.6703e-03, -3.2196e-03,  ..., -9.7656e-03,
          -9.3994e-03,  2.1667e-03],
         [-3.2422e-01,  1.2891e-01,  5.2344e-01,  ...,  4.4922e-01,
           1.2598e-01, -9.6484e-01],
         [ 8.7500e-01, -7.2656e-01,  8.9355e-02,  ...,  5.7812e-01,
          -7.8516e-01, -1.2793e-01],
         ...,
         [-2.9883e-01, -1.0078e+00, -1.0254e-01,  ...,  1.4531e+00,
           1.0000e+00, -2.3340e-01],
         [ 1.0156e+00, -5.3125e-01, -1.0312e+00,  ...,  2.6172e-01,
          -6.4844e-01, -2.7148e-01],
         [-1.4453e+00, -1.5391e+00, -6.8750e-01,  ...,  2.1562e+00,
           2.5781e+00,  1.0391e+00]],

        [[-3.1891e-03,  2.6703e-03, -3.2196e-03,  ..., -9.7656e-03,
          -9.3994e-03,  2.1667e-03],
         [-3.2422e-01,  1.2891e-01,  5.2344e-01,  ...,  4.4922e-01,
           1.2598e-01, -9.6484e-01],
         [ 8.7500e-01, -7.2656e-01,  8.9355e-02,  ...,  5.7812e-01,
          -7.8516e-01, -1.2793e-01],
         ...,
         [-2.9883e-01, -1.0078e+00, -1.0254e-01,  ...,  1.4531e+00,
           1.0000e+00, -2.3340e-01],
         [ 1.0156e+00, -5.3125e-01, -1.0312e+00,  ...,  2.6172e-01,
          -6.4844e-01, -2.7148e-01],
         [-1.4453e+00, -1.5391e+00, -6.8750e-01,  ...,  2.1562e+00,
           2.5781e+00,  1.0391e+00]],

        [[-3.1891e-03,  2.6703e-03, -3.2196e-03,  ..., -9.7656e-03,
          -9.3994e-03,  2.1667e-03],
         [-3.2422e-01,  1.2891e-01,  5.2344e-01,  ...,  4.4922e-01,
           1.2598e-01, -9.6484e-01],
         [ 8.7500e-01, -7.2656e-01,  8.9355e-02,  ...,  5.7812e-01,
          -7.8516e-01, -1.2793e-01],
         ...,
         [-2.9883e-01, -1.0078e+00, -1.0254e-01,  ...,  1.4531e+00,
           1.0000e+00, -2.3340e-01],
         [ 1.0156e+00, -5.3125e-01, -1.0312e+00,  ...,  2.6172e-01,
          -6.4844e-01, -2.7148e-01],
         [-1.4453e+00, -1.5391e+00, -6.8750e-01,  ...,  2.1562e+00,
           2.5781e+00,  1.0391e+00]],

        ...,

        [[-4.9133e-03, -5.0049e-03, -8.3008e-03,  ..., -8.8501e-04,
          -2.7618e-03,  1.4587e-02],
         [-5.5469e-01, -1.0859e+00, -3.4570e-01,  ..., -1.1641e+00,
           5.1172e-01,  1.9629e-01],
         [ 2.1875e-01, -9.4531e-01,  2.5781e-01,  ..., -7.8516e-01,
           5.6641e-01,  1.0000e+00],
         ...,
         [-1.1865e-01,  7.1875e-01,  1.5234e+00,  ..., -1.6250e+00,
           8.0469e-01, -4.2383e-01],
         [ 1.2578e+00,  2.4414e-02,  4.9219e-01,  ..., -1.1641e+00,
          -1.0703e+00, -2.7930e-01],
         [ 1.9336e-01, -4.8242e-01,  1.4531e+00,  ..., -9.9609e-01,
           1.9766e+00, -1.4219e+00]],

        [[-4.9133e-03, -5.0049e-03, -8.3008e-03,  ..., -8.8501e-04,
          -2.7618e-03,  1.4587e-02],
         [-5.5469e-01, -1.0859e+00, -3.4570e-01,  ..., -1.1641e+00,
           5.1172e-01,  1.9629e-01],
         [ 2.1875e-01, -9.4531e-01,  2.5781e-01,  ..., -7.8516e-01,
           5.6641e-01,  1.0000e+00],
         ...,
         [-1.1865e-01,  7.1875e-01,  1.5234e+00,  ..., -1.6250e+00,
           8.0469e-01, -4.2383e-01],
         [ 1.2578e+00,  2.4414e-02,  4.9219e-01,  ..., -1.1641e+00,
          -1.0703e+00, -2.7930e-01],
         [ 1.9336e-01, -4.8242e-01,  1.4531e+00,  ..., -9.9609e-01,
           1.9766e+00, -1.4219e+00]],

        [[-4.9133e-03, -5.0049e-03, -8.3008e-03,  ..., -8.8501e-04,
          -2.7618e-03,  1.4587e-02],
         [-5.5469e-01, -1.0859e+00, -3.4570e-01,  ..., -1.1641e+00,
           5.1172e-01,  1.9629e-01],
         [ 2.1875e-01, -9.4531e-01,  2.5781e-01,  ..., -7.8516e-01,
           5.6641e-01,  1.0000e+00],
         ...,
         [-1.1865e-01,  7.1875e-01,  1.5234e+00,  ..., -1.6250e+00,
           8.0469e-01, -4.2383e-01],
         [ 1.2578e+00,  2.4414e-02,  4.9219e-01,  ..., -1.1641e+00,
          -1.0703e+00, -2.7930e-01],
         [ 1.9336e-01, -4.8242e-01,  1.4531e+00,  ..., -9.9609e-01,
           1.9766e+00, -1.4219e+00]]], dtype=torch.bfloat16)), (tensor([[[-1.6724e-02, -4.4861e-03, -2.8076e-03,  ...,  1.1523e-01,
           1.3574e-01, -1.5747e-02],
         [ 4.4141e-01,  3.1250e-01, -9.4141e-01,  ...,  7.8906e-01,
           4.4531e-01,  4.8438e-01],
         [-1.5469e+00,  1.9727e-01, -2.2188e+00,  ...,  6.8359e-01,
          -1.1426e-01,  2.2500e+00],
         ...,
         [-5.7812e-01, -7.6172e-02,  3.9062e-01,  ..., -4.9316e-02,
          -1.6484e+00, -2.3828e-01],
         [-2.4219e-01, -7.0312e-01, -2.3047e-01,  ...,  5.1172e-01,
          -4.8438e-01, -5.2246e-02],
         [ 1.7031e+00,  1.5156e+00, -1.1797e+00,  ..., -1.5703e+00,
          -4.8633e-01,  6.5625e-01]],

        [[-1.6724e-02, -4.4861e-03, -2.8076e-03,  ...,  1.1523e-01,
           1.3574e-01, -1.5747e-02],
         [ 4.4141e-01,  3.1250e-01, -9.4141e-01,  ...,  7.8906e-01,
           4.4531e-01,  4.8438e-01],
         [-1.5469e+00,  1.9727e-01, -2.2188e+00,  ...,  6.8359e-01,
          -1.1426e-01,  2.2500e+00],
         ...,
         [-5.7812e-01, -7.6172e-02,  3.9062e-01,  ..., -4.9316e-02,
          -1.6484e+00, -2.3828e-01],
         [-2.4219e-01, -7.0312e-01, -2.3047e-01,  ...,  5.1172e-01,
          -4.8438e-01, -5.2246e-02],
         [ 1.7031e+00,  1.5156e+00, -1.1797e+00,  ..., -1.5703e+00,
          -4.8633e-01,  6.5625e-01]],

        [[-1.6724e-02, -4.4861e-03, -2.8076e-03,  ...,  1.1523e-01,
           1.3574e-01, -1.5747e-02],
         [ 4.4141e-01,  3.1250e-01, -9.4141e-01,  ...,  7.8906e-01,
           4.4531e-01,  4.8438e-01],
         [-1.5469e+00,  1.9727e-01, -2.2188e+00,  ...,  6.8359e-01,
          -1.1426e-01,  2.2500e+00],
         ...,
         [-5.7812e-01, -7.6172e-02,  3.9062e-01,  ..., -4.9316e-02,
          -1.6484e+00, -2.3828e-01],
         [-2.4219e-01, -7.0312e-01, -2.3047e-01,  ...,  5.1172e-01,
          -4.8438e-01, -5.2246e-02],
         [ 1.7031e+00,  1.5156e+00, -1.1797e+00,  ..., -1.5703e+00,
          -4.8633e-01,  6.5625e-01]],

        ...,

        [[ 2.2125e-03,  1.2573e-02,  1.5335e-03,  ...,  9.7168e-02,
          -4.8523e-03, -1.6504e-01],
         [ 1.8203e+00,  8.9062e-01, -2.1484e-01,  ..., -2.0801e-01,
          -4.8242e-01,  1.1562e+00],
         [ 1.4375e+00,  2.5156e+00, -8.4375e-01,  ...,  2.4805e-01,
          -2.2949e-01,  2.6406e+00],
         ...,
         [ 8.1250e-01, -1.9531e-02, -2.6758e-01,  ..., -1.6328e+00,
          -2.3633e-01,  3.0273e-01],
         [ 5.3906e-01, -8.6328e-01, -1.3516e+00,  ...,  1.2031e+00,
           3.2422e-01,  1.1641e+00],
         [-1.4688e+00, -2.4375e+00, -4.5508e-01,  ..., -9.2188e-01,
          -1.4282e-02,  2.6094e+00]],

        [[ 2.2125e-03,  1.2573e-02,  1.5335e-03,  ...,  9.7168e-02,
          -4.8523e-03, -1.6504e-01],
         [ 1.8203e+00,  8.9062e-01, -2.1484e-01,  ..., -2.0801e-01,
          -4.8242e-01,  1.1562e+00],
         [ 1.4375e+00,  2.5156e+00, -8.4375e-01,  ...,  2.4805e-01,
          -2.2949e-01,  2.6406e+00],
         ...,
         [ 8.1250e-01, -1.9531e-02, -2.6758e-01,  ..., -1.6328e+00,
          -2.3633e-01,  3.0273e-01],
         [ 5.3906e-01, -8.6328e-01, -1.3516e+00,  ...,  1.2031e+00,
           3.2422e-01,  1.1641e+00],
         [-1.4688e+00, -2.4375e+00, -4.5508e-01,  ..., -9.2188e-01,
          -1.4282e-02,  2.6094e+00]],

        [[ 2.2125e-03,  1.2573e-02,  1.5335e-03,  ...,  9.7168e-02,
          -4.8523e-03, -1.6504e-01],
         [ 1.8203e+00,  8.9062e-01, -2.1484e-01,  ..., -2.0801e-01,
          -4.8242e-01,  1.1562e+00],
         [ 1.4375e+00,  2.5156e+00, -8.4375e-01,  ...,  2.4805e-01,
          -2.2949e-01,  2.6406e+00],
         ...,
         [ 8.1250e-01, -1.9531e-02, -2.6758e-01,  ..., -1.6328e+00,
          -2.3633e-01,  3.0273e-01],
         [ 5.3906e-01, -8.6328e-01, -1.3516e+00,  ...,  1.2031e+00,
           3.2422e-01,  1.1641e+00],
         [-1.4688e+00, -2.4375e+00, -4.5508e-01,  ..., -9.2188e-01,
          -1.4282e-02,  2.6094e+00]]], dtype=torch.bfloat16), tensor([[[ 0.0064,  0.0118,  0.0106,  ...,  0.0147,  0.0085, -0.0025],
         [-1.1484, -1.8438,  0.8281,  ...,  1.4219, -1.5938, -0.3262],
         [-0.4883, -0.5859, -0.4316,  ...,  0.0128, -0.0199,  1.2422],
         ...,
         [ 0.1182,  0.7305, -0.0317,  ..., -0.6523,  0.8633,  0.6562],
         [ 0.4688,  0.2949, -0.9844,  ...,  0.5078,  0.6172,  0.7578],
         [ 0.8516, -1.6094, -0.2246,  ..., -0.1001,  2.1562,  1.0547]],

        [[ 0.0064,  0.0118,  0.0106,  ...,  0.0147,  0.0085, -0.0025],
         [-1.1484, -1.8438,  0.8281,  ...,  1.4219, -1.5938, -0.3262],
         [-0.4883, -0.5859, -0.4316,  ...,  0.0128, -0.0199,  1.2422],
         ...,
         [ 0.1182,  0.7305, -0.0317,  ..., -0.6523,  0.8633,  0.6562],
         [ 0.4688,  0.2949, -0.9844,  ...,  0.5078,  0.6172,  0.7578],
         [ 0.8516, -1.6094, -0.2246,  ..., -0.1001,  2.1562,  1.0547]],

        [[ 0.0064,  0.0118,  0.0106,  ...,  0.0147,  0.0085, -0.0025],
         [-1.1484, -1.8438,  0.8281,  ...,  1.4219, -1.5938, -0.3262],
         [-0.4883, -0.5859, -0.4316,  ...,  0.0128, -0.0199,  1.2422],
         ...,
         [ 0.1182,  0.7305, -0.0317,  ..., -0.6523,  0.8633,  0.6562],
         [ 0.4688,  0.2949, -0.9844,  ...,  0.5078,  0.6172,  0.7578],
         [ 0.8516, -1.6094, -0.2246,  ..., -0.1001,  2.1562,  1.0547]],

        ...,

        [[-0.0048,  0.0029, -0.0053,  ..., -0.0062, -0.0024,  0.0045],
         [ 0.4160, -0.6328,  0.5195,  ...,  0.7461, -0.2754,  0.5273],
         [-1.4375,  1.0703,  0.4629,  ..., -0.9336, -1.2734, -0.4043],
         ...,
         [-0.8711, -0.2246,  0.1895,  ..., -0.2451, -0.4746, -0.7812],
         [-2.2500, -0.7812, -1.2812,  ..., -0.8164, -1.1484, -0.1445],
         [ 1.1094,  0.5508,  0.0986,  ...,  0.4297, -0.0275, -0.0444]],

        [[-0.0048,  0.0029, -0.0053,  ..., -0.0062, -0.0024,  0.0045],
         [ 0.4160, -0.6328,  0.5195,  ...,  0.7461, -0.2754,  0.5273],
         [-1.4375,  1.0703,  0.4629,  ..., -0.9336, -1.2734, -0.4043],
         ...,
         [-0.8711, -0.2246,  0.1895,  ..., -0.2451, -0.4746, -0.7812],
         [-2.2500, -0.7812, -1.2812,  ..., -0.8164, -1.1484, -0.1445],
         [ 1.1094,  0.5508,  0.0986,  ...,  0.4297, -0.0275, -0.0444]],

        [[-0.0048,  0.0029, -0.0053,  ..., -0.0062, -0.0024,  0.0045],
         [ 0.4160, -0.6328,  0.5195,  ...,  0.7461, -0.2754,  0.5273],
         [-1.4375,  1.0703,  0.4629,  ..., -0.9336, -1.2734, -0.4043],
         ...,
         [-0.8711, -0.2246,  0.1895,  ..., -0.2451, -0.4746, -0.7812],
         [-2.2500, -0.7812, -1.2812,  ..., -0.8164, -1.1484, -0.1445],
         [ 1.1094,  0.5508,  0.0986,  ...,  0.4297, -0.0275, -0.0444]]],
       dtype=torch.bfloat16)), (tensor([[[-7.9346e-04,  2.2430e-03,  1.5717e-03,  ..., -7.9956e-03,
          -1.6406e+00,  8.7402e-02],
         [ 3.9453e-01,  1.5625e-01, -1.2266e+00,  ..., -1.7285e-01,
           6.5938e+00, -1.5527e-01],
         [ 4.2188e-01, -3.2812e-01,  1.5430e-01,  ..., -2.4375e+00,
           9.6875e+00,  1.8281e+00],
         ...,
         [-2.2266e-01,  1.0205e-01, -5.9766e-01,  ...,  4.7812e+00,
           7.6875e+00, -4.1875e+00],
         [-1.8555e-01,  6.9336e-02, -3.0078e-01,  ...,  1.3594e+00,
           5.8125e+00, -2.3594e+00],
         [-7.2656e-01,  9.2969e-01, -1.8945e-01,  ..., -2.3281e+00,
           9.1250e+00, -1.1484e+00]],

        [[-7.9346e-04,  2.2430e-03,  1.5717e-03,  ..., -7.9956e-03,
          -1.6406e+00,  8.7402e-02],
         [ 3.9453e-01,  1.5625e-01, -1.2266e+00,  ..., -1.7285e-01,
           6.5938e+00, -1.5527e-01],
         [ 4.2188e-01, -3.2812e-01,  1.5430e-01,  ..., -2.4375e+00,
           9.6875e+00,  1.8281e+00],
         ...,
         [-2.2266e-01,  1.0205e-01, -5.9766e-01,  ...,  4.7812e+00,
           7.6875e+00, -4.1875e+00],
         [-1.8555e-01,  6.9336e-02, -3.0078e-01,  ...,  1.3594e+00,
           5.8125e+00, -2.3594e+00],
         [-7.2656e-01,  9.2969e-01, -1.8945e-01,  ..., -2.3281e+00,
           9.1250e+00, -1.1484e+00]],

        [[-7.9346e-04,  2.2430e-03,  1.5717e-03,  ..., -7.9956e-03,
          -1.6406e+00,  8.7402e-02],
         [ 3.9453e-01,  1.5625e-01, -1.2266e+00,  ..., -1.7285e-01,
           6.5938e+00, -1.5527e-01],
         [ 4.2188e-01, -3.2812e-01,  1.5430e-01,  ..., -2.4375e+00,
           9.6875e+00,  1.8281e+00],
         ...,
         [-2.2266e-01,  1.0205e-01, -5.9766e-01,  ...,  4.7812e+00,
           7.6875e+00, -4.1875e+00],
         [-1.8555e-01,  6.9336e-02, -3.0078e-01,  ...,  1.3594e+00,
           5.8125e+00, -2.3594e+00],
         [-7.2656e-01,  9.2969e-01, -1.8945e-01,  ..., -2.3281e+00,
           9.1250e+00, -1.1484e+00]],

        ...,

        [[-3.6774e-03, -1.3000e-02, -1.9989e-03,  ..., -5.5908e-02,
           2.0386e-02, -7.2754e-02],
         [ 1.8594e+00, -1.2344e+00, -7.7344e-01,  ...,  5.4297e-01,
          -1.0703e+00,  1.8359e+00],
         [ 1.0000e+00, -7.2656e-01, -4.6875e-01,  ...,  7.3438e-01,
           3.0273e-01, -2.5781e-01],
         ...,
         [ 1.0781e+00,  4.1016e-01,  1.3594e+00,  ..., -8.6719e-01,
           2.6719e+00,  8.8672e-01],
         [ 5.7422e-01,  7.3438e-01,  5.8594e-03,  ..., -1.0547e+00,
           1.6602e-01, -1.6406e+00],
         [-1.5156e+00, -5.6250e-01, -7.2656e-01,  ..., -1.2891e+00,
           1.2734e+00,  7.7344e-01]],

        [[-3.6774e-03, -1.3000e-02, -1.9989e-03,  ..., -5.5908e-02,
           2.0386e-02, -7.2754e-02],
         [ 1.8594e+00, -1.2344e+00, -7.7344e-01,  ...,  5.4297e-01,
          -1.0703e+00,  1.8359e+00],
         [ 1.0000e+00, -7.2656e-01, -4.6875e-01,  ...,  7.3438e-01,
           3.0273e-01, -2.5781e-01],
         ...,
         [ 1.0781e+00,  4.1016e-01,  1.3594e+00,  ..., -8.6719e-01,
           2.6719e+00,  8.8672e-01],
         [ 5.7422e-01,  7.3438e-01,  5.8594e-03,  ..., -1.0547e+00,
           1.6602e-01, -1.6406e+00],
         [-1.5156e+00, -5.6250e-01, -7.2656e-01,  ..., -1.2891e+00,
           1.2734e+00,  7.7344e-01]],

        [[-3.6774e-03, -1.3000e-02, -1.9989e-03,  ..., -5.5908e-02,
           2.0386e-02, -7.2754e-02],
         [ 1.8594e+00, -1.2344e+00, -7.7344e-01,  ...,  5.4297e-01,
          -1.0703e+00,  1.8359e+00],
         [ 1.0000e+00, -7.2656e-01, -4.6875e-01,  ...,  7.3438e-01,
           3.0273e-01, -2.5781e-01],
         ...,
         [ 1.0781e+00,  4.1016e-01,  1.3594e+00,  ..., -8.6719e-01,
           2.6719e+00,  8.8672e-01],
         [ 5.7422e-01,  7.3438e-01,  5.8594e-03,  ..., -1.0547e+00,
           1.6602e-01, -1.6406e+00],
         [-1.5156e+00, -5.6250e-01, -7.2656e-01,  ..., -1.2891e+00,
           1.2734e+00,  7.7344e-01]]], dtype=torch.bfloat16), tensor([[[ 4.6692e-03,  2.2095e-02, -3.2043e-03,  ..., -1.2268e-02,
           1.4160e-02,  3.9307e-02],
         [ 5.5078e-01, -2.9102e-01, -6.2500e-01,  ..., -8.1635e-04,
          -1.5078e+00, -1.0781e+00],
         [-2.6367e-02,  6.0156e-01,  2.3047e-01,  ..., -8.0469e-01,
           4.0820e-01, -1.7578e-01],
         ...,
         [-9.0234e-01, -2.0410e-01, -1.0859e+00,  ...,  6.3672e-01,
          -9.4238e-02,  1.5938e+00],
         [-8.0859e-01,  8.3008e-03, -9.4141e-01,  ..., -5.5078e-01,
           4.8438e-01,  1.5234e+00],
         [ 1.0781e+00, -1.1328e+00, -1.7422e+00,  ..., -1.2578e+00,
          -2.3750e+00,  5.4688e-01]],

        [[ 4.6692e-03,  2.2095e-02, -3.2043e-03,  ..., -1.2268e-02,
           1.4160e-02,  3.9307e-02],
         [ 5.5078e-01, -2.9102e-01, -6.2500e-01,  ..., -8.1635e-04,
          -1.5078e+00, -1.0781e+00],
         [-2.6367e-02,  6.0156e-01,  2.3047e-01,  ..., -8.0469e-01,
           4.0820e-01, -1.7578e-01],
         ...,
         [-9.0234e-01, -2.0410e-01, -1.0859e+00,  ...,  6.3672e-01,
          -9.4238e-02,  1.5938e+00],
         [-8.0859e-01,  8.3008e-03, -9.4141e-01,  ..., -5.5078e-01,
           4.8438e-01,  1.5234e+00],
         [ 1.0781e+00, -1.1328e+00, -1.7422e+00,  ..., -1.2578e+00,
          -2.3750e+00,  5.4688e-01]],

        [[ 4.6692e-03,  2.2095e-02, -3.2043e-03,  ..., -1.2268e-02,
           1.4160e-02,  3.9307e-02],
         [ 5.5078e-01, -2.9102e-01, -6.2500e-01,  ..., -8.1635e-04,
          -1.5078e+00, -1.0781e+00],
         [-2.6367e-02,  6.0156e-01,  2.3047e-01,  ..., -8.0469e-01,
           4.0820e-01, -1.7578e-01],
         ...,
         [-9.0234e-01, -2.0410e-01, -1.0859e+00,  ...,  6.3672e-01,
          -9.4238e-02,  1.5938e+00],
         [-8.0859e-01,  8.3008e-03, -9.4141e-01,  ..., -5.5078e-01,
           4.8438e-01,  1.5234e+00],
         [ 1.0781e+00, -1.1328e+00, -1.7422e+00,  ..., -1.2578e+00,
          -2.3750e+00,  5.4688e-01]],

        ...,

        [[ 1.7319e-03,  1.0872e-04, -7.7209e-03,  ..., -1.3489e-02,
          -8.0566e-03,  6.9580e-03],
         [-7.5000e-01,  3.8086e-01, -8.7500e-01,  ..., -4.7070e-01,
          -1.2578e+00,  4.1602e-01],
         [-2.5000e-01,  2.1289e-01,  5.5469e-01,  ...,  8.9844e-01,
          -1.0469e+00,  1.2266e+00],
         ...,
         [-1.3867e-01,  9.3750e-02,  3.3203e-01,  ..., -1.2598e-01,
          -2.4414e-01, -4.9072e-02],
         [ 3.9258e-01, -6.9336e-02,  3.7500e-01,  ...,  1.1250e+00,
           1.5078e+00, -1.6504e-01],
         [ 3.3789e-01,  9.3750e-01,  1.0781e+00,  ...,  4.0625e-01,
           1.9455e-03,  3.0469e-01]],

        [[ 1.7319e-03,  1.0872e-04, -7.7209e-03,  ..., -1.3489e-02,
          -8.0566e-03,  6.9580e-03],
         [-7.5000e-01,  3.8086e-01, -8.7500e-01,  ..., -4.7070e-01,
          -1.2578e+00,  4.1602e-01],
         [-2.5000e-01,  2.1289e-01,  5.5469e-01,  ...,  8.9844e-01,
          -1.0469e+00,  1.2266e+00],
         ...,
         [-1.3867e-01,  9.3750e-02,  3.3203e-01,  ..., -1.2598e-01,
          -2.4414e-01, -4.9072e-02],
         [ 3.9258e-01, -6.9336e-02,  3.7500e-01,  ...,  1.1250e+00,
           1.5078e+00, -1.6504e-01],
         [ 3.3789e-01,  9.3750e-01,  1.0781e+00,  ...,  4.0625e-01,
           1.9455e-03,  3.0469e-01]],

        [[ 1.7319e-03,  1.0872e-04, -7.7209e-03,  ..., -1.3489e-02,
          -8.0566e-03,  6.9580e-03],
         [-7.5000e-01,  3.8086e-01, -8.7500e-01,  ..., -4.7070e-01,
          -1.2578e+00,  4.1602e-01],
         [-2.5000e-01,  2.1289e-01,  5.5469e-01,  ...,  8.9844e-01,
          -1.0469e+00,  1.2266e+00],
         ...,
         [-1.3867e-01,  9.3750e-02,  3.3203e-01,  ..., -1.2598e-01,
          -2.4414e-01, -4.9072e-02],
         [ 3.9258e-01, -6.9336e-02,  3.7500e-01,  ...,  1.1250e+00,
           1.5078e+00, -1.6504e-01],
         [ 3.3789e-01,  9.3750e-01,  1.0781e+00,  ...,  4.0625e-01,
           1.9455e-03,  3.0469e-01]]], dtype=torch.bfloat16)), (tensor([[[ 1.2695e-02,  1.6098e-03,  1.1658e-02,  ...,  2.1680e-01,
           1.3672e-01, -1.2812e+00],
         [ 7.8516e-01,  4.4531e-01,  3.4375e-01,  ...,  1.6406e+00,
           8.2031e-01,  1.7891e+00],
         [ 2.5781e-01, -4.2969e-01, -2.9785e-02,  ...,  1.4766e+00,
          -2.0781e+00,  2.2656e+00],
         ...,
         [ 4.5117e-01,  2.4609e-01,  3.5547e-01,  ..., -3.2188e+00,
           8.6328e-01,  3.9531e+00],
         [-5.4297e-01, -1.9824e-01,  1.5625e-01,  ..., -1.3672e+00,
          -1.2109e+00,  2.7031e+00],
         [-2.0703e-01,  3.3008e-01,  2.9297e-01,  ...,  1.1953e+00,
          -1.2891e+00,  4.5625e+00]],

        [[ 1.2695e-02,  1.6098e-03,  1.1658e-02,  ...,  2.1680e-01,
           1.3672e-01, -1.2812e+00],
         [ 7.8516e-01,  4.4531e-01,  3.4375e-01,  ...,  1.6406e+00,
           8.2031e-01,  1.7891e+00],
         [ 2.5781e-01, -4.2969e-01, -2.9785e-02,  ...,  1.4766e+00,
          -2.0781e+00,  2.2656e+00],
         ...,
         [ 4.5117e-01,  2.4609e-01,  3.5547e-01,  ..., -3.2188e+00,
           8.6328e-01,  3.9531e+00],
         [-5.4297e-01, -1.9824e-01,  1.5625e-01,  ..., -1.3672e+00,
          -1.2109e+00,  2.7031e+00],
         [-2.0703e-01,  3.3008e-01,  2.9297e-01,  ...,  1.1953e+00,
          -1.2891e+00,  4.5625e+00]],

        [[ 1.2695e-02,  1.6098e-03,  1.1658e-02,  ...,  2.1680e-01,
           1.3672e-01, -1.2812e+00],
         [ 7.8516e-01,  4.4531e-01,  3.4375e-01,  ...,  1.6406e+00,
           8.2031e-01,  1.7891e+00],
         [ 2.5781e-01, -4.2969e-01, -2.9785e-02,  ...,  1.4766e+00,
          -2.0781e+00,  2.2656e+00],
         ...,
         [ 4.5117e-01,  2.4609e-01,  3.5547e-01,  ..., -3.2188e+00,
           8.6328e-01,  3.9531e+00],
         [-5.4297e-01, -1.9824e-01,  1.5625e-01,  ..., -1.3672e+00,
          -1.2109e+00,  2.7031e+00],
         [-2.0703e-01,  3.3008e-01,  2.9297e-01,  ...,  1.1953e+00,
          -1.2891e+00,  4.5625e+00]],

        ...,

        [[ 9.3384e-03, -2.3438e-02, -4.6692e-03,  ...,  5.3516e-01,
          -2.9297e-01, -8.1543e-02],
         [ 4.1211e-01,  9.6875e-01, -4.1992e-02,  ..., -1.7188e+00,
           9.1016e-01, -3.4062e+00],
         [ 2.2852e-01,  7.6904e-03,  3.0469e-01,  ..., -2.0703e-01,
           8.8281e-01, -1.5000e+00],
         ...,
         [-3.8574e-02, -5.1172e-01, -8.4961e-02,  ..., -1.6641e+00,
          -1.5000e+00, -2.7500e+00],
         [-4.8096e-02, -3.9453e-01, -1.1328e-01,  ...,  1.5547e+00,
          -8.3594e-01,  3.3008e-01],
         [-8.6719e-01, -1.0547e-01,  9.9219e-01,  ...,  1.0547e+00,
           8.8281e-01, -6.6250e+00]],

        [[ 9.3384e-03, -2.3438e-02, -4.6692e-03,  ...,  5.3516e-01,
          -2.9297e-01, -8.1543e-02],
         [ 4.1211e-01,  9.6875e-01, -4.1992e-02,  ..., -1.7188e+00,
           9.1016e-01, -3.4062e+00],
         [ 2.2852e-01,  7.6904e-03,  3.0469e-01,  ..., -2.0703e-01,
           8.8281e-01, -1.5000e+00],
         ...,
         [-3.8574e-02, -5.1172e-01, -8.4961e-02,  ..., -1.6641e+00,
          -1.5000e+00, -2.7500e+00],
         [-4.8096e-02, -3.9453e-01, -1.1328e-01,  ...,  1.5547e+00,
          -8.3594e-01,  3.3008e-01],
         [-8.6719e-01, -1.0547e-01,  9.9219e-01,  ...,  1.0547e+00,
           8.8281e-01, -6.6250e+00]],

        [[ 9.3384e-03, -2.3438e-02, -4.6692e-03,  ...,  5.3516e-01,
          -2.9297e-01, -8.1543e-02],
         [ 4.1211e-01,  9.6875e-01, -4.1992e-02,  ..., -1.7188e+00,
           9.1016e-01, -3.4062e+00],
         [ 2.2852e-01,  7.6904e-03,  3.0469e-01,  ..., -2.0703e-01,
           8.8281e-01, -1.5000e+00],
         ...,
         [-3.8574e-02, -5.1172e-01, -8.4961e-02,  ..., -1.6641e+00,
          -1.5000e+00, -2.7500e+00],
         [-4.8096e-02, -3.9453e-01, -1.1328e-01,  ...,  1.5547e+00,
          -8.3594e-01,  3.3008e-01],
         [-8.6719e-01, -1.0547e-01,  9.9219e-01,  ...,  1.0547e+00,
           8.8281e-01, -6.6250e+00]]], dtype=torch.bfloat16), tensor([[[-5.3406e-03,  1.2398e-04,  1.0864e-02,  ..., -7.3547e-03,
           1.1475e-02,  7.3547e-03],
         [ 2.5156e+00, -1.6172e+00, -5.4688e-01,  ...,  1.7773e-01,
           7.1094e-01,  2.1406e+00],
         [ 2.7930e-01, -1.3125e+00, -1.1172e+00,  ...,  2.2754e-01,
           3.7109e-01,  7.1094e-01],
         ...,
         [-9.7656e-02, -3.7500e-01,  6.1328e-01,  ...,  4.0820e-01,
          -1.4258e-01,  9.9609e-01],
         [-1.8359e-01,  1.9922e-01, -7.2266e-01,  ...,  4.2578e-01,
           8.9111e-03,  8.5156e-01],
         [ 1.8906e+00, -2.2852e-01,  3.4961e-01,  ...,  1.7344e+00,
          -1.5469e+00,  1.1797e+00]],

        [[-5.3406e-03,  1.2398e-04,  1.0864e-02,  ..., -7.3547e-03,
           1.1475e-02,  7.3547e-03],
         [ 2.5156e+00, -1.6172e+00, -5.4688e-01,  ...,  1.7773e-01,
           7.1094e-01,  2.1406e+00],
         [ 2.7930e-01, -1.3125e+00, -1.1172e+00,  ...,  2.2754e-01,
           3.7109e-01,  7.1094e-01],
         ...,
         [-9.7656e-02, -3.7500e-01,  6.1328e-01,  ...,  4.0820e-01,
          -1.4258e-01,  9.9609e-01],
         [-1.8359e-01,  1.9922e-01, -7.2266e-01,  ...,  4.2578e-01,
           8.9111e-03,  8.5156e-01],
         [ 1.8906e+00, -2.2852e-01,  3.4961e-01,  ...,  1.7344e+00,
          -1.5469e+00,  1.1797e+00]],

        [[-5.3406e-03,  1.2398e-04,  1.0864e-02,  ..., -7.3547e-03,
           1.1475e-02,  7.3547e-03],
         [ 2.5156e+00, -1.6172e+00, -5.4688e-01,  ...,  1.7773e-01,
           7.1094e-01,  2.1406e+00],
         [ 2.7930e-01, -1.3125e+00, -1.1172e+00,  ...,  2.2754e-01,
           3.7109e-01,  7.1094e-01],
         ...,
         [-9.7656e-02, -3.7500e-01,  6.1328e-01,  ...,  4.0820e-01,
          -1.4258e-01,  9.9609e-01],
         [-1.8359e-01,  1.9922e-01, -7.2266e-01,  ...,  4.2578e-01,
           8.9111e-03,  8.5156e-01],
         [ 1.8906e+00, -2.2852e-01,  3.4961e-01,  ...,  1.7344e+00,
          -1.5469e+00,  1.1797e+00]],

        ...,

        [[ 5.8289e-03,  1.7822e-02, -2.7222e-02,  ..., -1.2360e-03,
           1.0071e-02,  9.7656e-04],
         [-6.9531e-01, -1.4062e+00,  4.2773e-01,  ...,  6.6016e-01,
           1.1719e-01,  2.8438e+00],
         [-3.5352e-01, -1.9141e+00,  2.0996e-01,  ..., -3.6328e-01,
           1.7383e-01, -1.1562e+00],
         ...,
         [-3.3594e-01, -7.4609e-01, -1.2988e-01,  ...,  3.0078e-01,
          -9.7656e-02, -2.6172e-01],
         [-1.0781e+00, -3.1641e-01, -2.8125e-01,  ...,  3.0469e-01,
           1.0547e+00,  6.4941e-02],
         [ 1.7266e+00,  4.6143e-02,  2.1719e+00,  ...,  1.1670e-01,
           9.8828e-01, -1.0703e+00]],

        [[ 5.8289e-03,  1.7822e-02, -2.7222e-02,  ..., -1.2360e-03,
           1.0071e-02,  9.7656e-04],
         [-6.9531e-01, -1.4062e+00,  4.2773e-01,  ...,  6.6016e-01,
           1.1719e-01,  2.8438e+00],
         [-3.5352e-01, -1.9141e+00,  2.0996e-01,  ..., -3.6328e-01,
           1.7383e-01, -1.1562e+00],
         ...,
         [-3.3594e-01, -7.4609e-01, -1.2988e-01,  ...,  3.0078e-01,
          -9.7656e-02, -2.6172e-01],
         [-1.0781e+00, -3.1641e-01, -2.8125e-01,  ...,  3.0469e-01,
           1.0547e+00,  6.4941e-02],
         [ 1.7266e+00,  4.6143e-02,  2.1719e+00,  ...,  1.1670e-01,
           9.8828e-01, -1.0703e+00]],

        [[ 5.8289e-03,  1.7822e-02, -2.7222e-02,  ..., -1.2360e-03,
           1.0071e-02,  9.7656e-04],
         [-6.9531e-01, -1.4062e+00,  4.2773e-01,  ...,  6.6016e-01,
           1.1719e-01,  2.8438e+00],
         [-3.5352e-01, -1.9141e+00,  2.0996e-01,  ..., -3.6328e-01,
           1.7383e-01, -1.1562e+00],
         ...,
         [-3.3594e-01, -7.4609e-01, -1.2988e-01,  ...,  3.0078e-01,
          -9.7656e-02, -2.6172e-01],
         [-1.0781e+00, -3.1641e-01, -2.8125e-01,  ...,  3.0469e-01,
           1.0547e+00,  6.4941e-02],
         [ 1.7266e+00,  4.6143e-02,  2.1719e+00,  ...,  1.1670e-01,
           9.8828e-01, -1.0703e+00]]], dtype=torch.bfloat16)), (tensor([[[ 8.9722e-03,  1.1139e-03, -4.6997e-03,  ...,  2.6611e-02,
          -1.4531e+00,  2.0020e-01],
         [-3.1875e+00,  1.5625e-01, -1.3438e+00,  ..., -1.1797e+00,
           7.4688e+00, -3.3750e+00],
         [-7.5000e-01,  7.6172e-01, -5.3125e-01,  ...,  7.8516e-01,
           8.7500e+00, -2.4844e+00],
         ...,
         [ 4.1406e-01, -4.9609e-01, -2.8125e-01,  ..., -5.5859e-01,
           5.4375e+00, -3.8750e+00],
         [ 8.1250e-01, -3.7305e-01, -7.9688e-01,  ...,  1.6211e-01,
           6.7500e+00, -2.1719e+00],
         [ 6.2500e-01, -1.1797e+00, -1.3750e+00,  ..., -8.6719e-01,
           7.7500e+00, -1.3359e+00]],

        [[ 8.9722e-03,  1.1139e-03, -4.6997e-03,  ...,  2.6611e-02,
          -1.4531e+00,  2.0020e-01],
         [-3.1875e+00,  1.5625e-01, -1.3438e+00,  ..., -1.1797e+00,
           7.4688e+00, -3.3750e+00],
         [-7.5000e-01,  7.6172e-01, -5.3125e-01,  ...,  7.8516e-01,
           8.7500e+00, -2.4844e+00],
         ...,
         [ 4.1406e-01, -4.9609e-01, -2.8125e-01,  ..., -5.5859e-01,
           5.4375e+00, -3.8750e+00],
         [ 8.1250e-01, -3.7305e-01, -7.9688e-01,  ...,  1.6211e-01,
           6.7500e+00, -2.1719e+00],
         [ 6.2500e-01, -1.1797e+00, -1.3750e+00,  ..., -8.6719e-01,
           7.7500e+00, -1.3359e+00]],

        [[ 8.9722e-03,  1.1139e-03, -4.6997e-03,  ...,  2.6611e-02,
          -1.4531e+00,  2.0020e-01],
         [-3.1875e+00,  1.5625e-01, -1.3438e+00,  ..., -1.1797e+00,
           7.4688e+00, -3.3750e+00],
         [-7.5000e-01,  7.6172e-01, -5.3125e-01,  ...,  7.8516e-01,
           8.7500e+00, -2.4844e+00],
         ...,
         [ 4.1406e-01, -4.9609e-01, -2.8125e-01,  ..., -5.5859e-01,
           5.4375e+00, -3.8750e+00],
         [ 8.1250e-01, -3.7305e-01, -7.9688e-01,  ...,  1.6211e-01,
           6.7500e+00, -2.1719e+00],
         [ 6.2500e-01, -1.1797e+00, -1.3750e+00,  ..., -8.6719e-01,
           7.7500e+00, -1.3359e+00]],

        ...,

        [[-6.3782e-03, -8.7280e-03, -9.7046e-03,  ...,  8.0078e-02,
           8.1055e-02, -2.4707e-01],
         [ 3.1562e+00,  3.8672e-01,  2.2266e-01,  ..., -4.9805e-01,
          -1.1484e+00,  1.0645e-01],
         [ 9.3359e-01,  1.5391e+00,  1.1328e-01,  ...,  1.3203e+00,
           2.0938e+00,  7.9688e-01],
         ...,
         [-3.0469e-01,  6.2891e-01,  6.9922e-01,  ..., -2.6875e+00,
           2.8320e-01,  1.9922e+00],
         [-7.3047e-01, -3.1836e-01,  5.7422e-01,  ..., -1.9766e+00,
           2.2344e+00,  1.0078e+00],
         [-6.1719e-01, -1.3438e+00,  1.0703e+00,  ...,  4.0430e-01,
           1.1250e+00, -3.8086e-01]],

        [[-6.3782e-03, -8.7280e-03, -9.7046e-03,  ...,  8.0078e-02,
           8.1055e-02, -2.4707e-01],
         [ 3.1562e+00,  3.8672e-01,  2.2266e-01,  ..., -4.9805e-01,
          -1.1484e+00,  1.0645e-01],
         [ 9.3359e-01,  1.5391e+00,  1.1328e-01,  ...,  1.3203e+00,
           2.0938e+00,  7.9688e-01],
         ...,
         [-3.0469e-01,  6.2891e-01,  6.9922e-01,  ..., -2.6875e+00,
           2.8320e-01,  1.9922e+00],
         [-7.3047e-01, -3.1836e-01,  5.7422e-01,  ..., -1.9766e+00,
           2.2344e+00,  1.0078e+00],
         [-6.1719e-01, -1.3438e+00,  1.0703e+00,  ...,  4.0430e-01,
           1.1250e+00, -3.8086e-01]],

        [[-6.3782e-03, -8.7280e-03, -9.7046e-03,  ...,  8.0078e-02,
           8.1055e-02, -2.4707e-01],
         [ 3.1562e+00,  3.8672e-01,  2.2266e-01,  ..., -4.9805e-01,
          -1.1484e+00,  1.0645e-01],
         [ 9.3359e-01,  1.5391e+00,  1.1328e-01,  ...,  1.3203e+00,
           2.0938e+00,  7.9688e-01],
         ...,
         [-3.0469e-01,  6.2891e-01,  6.9922e-01,  ..., -2.6875e+00,
           2.8320e-01,  1.9922e+00],
         [-7.3047e-01, -3.1836e-01,  5.7422e-01,  ..., -1.9766e+00,
           2.2344e+00,  1.0078e+00],
         [-6.1719e-01, -1.3438e+00,  1.0703e+00,  ...,  4.0430e-01,
           1.1250e+00, -3.8086e-01]]], dtype=torch.bfloat16), tensor([[[ 7.6294e-03, -6.6528e-03, -5.7678e-03,  ...,  6.5613e-03,
          -5.0354e-03,  2.7313e-03],
         [ 9.2969e-01,  1.1641e+00, -3.2812e-01,  ...,  1.3281e+00,
           2.6367e-01,  2.0625e+00],
         [-4.9609e-01, -8.0078e-01,  4.6680e-01,  ...,  5.9766e-01,
          -2.7734e-01,  8.7500e-01],
         ...,
         [ 1.1562e+00, -5.1562e-01, -1.0234e+00,  ...,  4.6484e-01,
           1.3438e+00, -6.7578e-01],
         [ 5.3125e-01, -7.6953e-01,  1.4746e-01,  ...,  8.7891e-01,
           4.4531e-01, -3.4180e-01],
         [ 1.8281e+00, -4.3750e-01,  1.8828e+00,  ..., -1.3750e+00,
          -1.1406e+00,  1.0469e+00]],

        [[ 7.6294e-03, -6.6528e-03, -5.7678e-03,  ...,  6.5613e-03,
          -5.0354e-03,  2.7313e-03],
         [ 9.2969e-01,  1.1641e+00, -3.2812e-01,  ...,  1.3281e+00,
           2.6367e-01,  2.0625e+00],
         [-4.9609e-01, -8.0078e-01,  4.6680e-01,  ...,  5.9766e-01,
          -2.7734e-01,  8.7500e-01],
         ...,
         [ 1.1562e+00, -5.1562e-01, -1.0234e+00,  ...,  4.6484e-01,
           1.3438e+00, -6.7578e-01],
         [ 5.3125e-01, -7.6953e-01,  1.4746e-01,  ...,  8.7891e-01,
           4.4531e-01, -3.4180e-01],
         [ 1.8281e+00, -4.3750e-01,  1.8828e+00,  ..., -1.3750e+00,
          -1.1406e+00,  1.0469e+00]],

        [[ 7.6294e-03, -6.6528e-03, -5.7678e-03,  ...,  6.5613e-03,
          -5.0354e-03,  2.7313e-03],
         [ 9.2969e-01,  1.1641e+00, -3.2812e-01,  ...,  1.3281e+00,
           2.6367e-01,  2.0625e+00],
         [-4.9609e-01, -8.0078e-01,  4.6680e-01,  ...,  5.9766e-01,
          -2.7734e-01,  8.7500e-01],
         ...,
         [ 1.1562e+00, -5.1562e-01, -1.0234e+00,  ...,  4.6484e-01,
           1.3438e+00, -6.7578e-01],
         [ 5.3125e-01, -7.6953e-01,  1.4746e-01,  ...,  8.7891e-01,
           4.4531e-01, -3.4180e-01],
         [ 1.8281e+00, -4.3750e-01,  1.8828e+00,  ..., -1.3750e+00,
          -1.1406e+00,  1.0469e+00]],

        ...,

        [[-6.7234e-05,  2.8381e-03, -5.2795e-03,  ...,  2.2736e-03,
          -6.7444e-03,  6.7749e-03],
         [-7.3242e-02,  5.4688e-01,  8.6328e-01,  ...,  7.6953e-01,
           4.8633e-01,  3.1055e-01],
         [-6.3672e-01,  3.1445e-01,  2.5977e-01,  ..., -3.6865e-02,
           1.9165e-02, -8.3594e-01],
         ...,
         [-2.8320e-01,  8.5156e-01,  2.9688e-01,  ..., -9.0625e-01,
           6.4844e-01,  5.4688e-01],
         [-1.9238e-01, -7.3047e-01,  3.6133e-02,  ...,  6.4844e-01,
           1.8125e+00, -1.9336e-01],
         [ 1.1953e+00,  1.9141e-01,  1.0312e+00,  ..., -1.3580e-03,
           9.6094e-01,  1.2422e+00]],

        [[-6.7234e-05,  2.8381e-03, -5.2795e-03,  ...,  2.2736e-03,
          -6.7444e-03,  6.7749e-03],
         [-7.3242e-02,  5.4688e-01,  8.6328e-01,  ...,  7.6953e-01,
           4.8633e-01,  3.1055e-01],
         [-6.3672e-01,  3.1445e-01,  2.5977e-01,  ..., -3.6865e-02,
           1.9165e-02, -8.3594e-01],
         ...,
         [-2.8320e-01,  8.5156e-01,  2.9688e-01,  ..., -9.0625e-01,
           6.4844e-01,  5.4688e-01],
         [-1.9238e-01, -7.3047e-01,  3.6133e-02,  ...,  6.4844e-01,
           1.8125e+00, -1.9336e-01],
         [ 1.1953e+00,  1.9141e-01,  1.0312e+00,  ..., -1.3580e-03,
           9.6094e-01,  1.2422e+00]],

        [[-6.7234e-05,  2.8381e-03, -5.2795e-03,  ...,  2.2736e-03,
          -6.7444e-03,  6.7749e-03],
         [-7.3242e-02,  5.4688e-01,  8.6328e-01,  ...,  7.6953e-01,
           4.8633e-01,  3.1055e-01],
         [-6.3672e-01,  3.1445e-01,  2.5977e-01,  ..., -3.6865e-02,
           1.9165e-02, -8.3594e-01],
         ...,
         [-2.8320e-01,  8.5156e-01,  2.9688e-01,  ..., -9.0625e-01,
           6.4844e-01,  5.4688e-01],
         [-1.9238e-01, -7.3047e-01,  3.6133e-02,  ...,  6.4844e-01,
           1.8125e+00, -1.9336e-01],
         [ 1.1953e+00,  1.9141e-01,  1.0312e+00,  ..., -1.3580e-03,
           9.6094e-01,  1.2422e+00]]], dtype=torch.bfloat16)), (tensor([[[ 1.3245e-02, -6.1035e-04,  2.7618e-03,  ..., -1.0303e-01,
          -1.4941e-01,  1.9727e-01],
         [-1.1875e+00, -4.7070e-01,  2.4531e+00,  ...,  9.4141e-01,
           2.5195e-01,  1.8125e+00],
         [ 1.2451e-01,  2.1387e-01,  2.1973e-01,  ...,  1.1719e+00,
          -3.3398e-01, -2.8125e+00],
         ...,
         [-4.5508e-01,  7.6172e-01,  9.5703e-01,  ...,  5.8203e-01,
           1.1719e+00, -4.7266e-01],
         [-4.4922e-01,  3.0762e-02, -4.8828e-02,  ..., -1.3965e-01,
           7.4219e-02, -1.6406e+00],
         [ 5.5859e-01, -1.0625e+00,  5.0391e-01,  ...,  1.5156e+00,
           5.2344e-01,  6.2109e-01]],

        [[ 1.3245e-02, -6.1035e-04,  2.7618e-03,  ..., -1.0303e-01,
          -1.4941e-01,  1.9727e-01],
         [-1.1875e+00, -4.7070e-01,  2.4531e+00,  ...,  9.4141e-01,
           2.5195e-01,  1.8125e+00],
         [ 1.2451e-01,  2.1387e-01,  2.1973e-01,  ...,  1.1719e+00,
          -3.3398e-01, -2.8125e+00],
         ...,
         [-4.5508e-01,  7.6172e-01,  9.5703e-01,  ...,  5.8203e-01,
           1.1719e+00, -4.7266e-01],
         [-4.4922e-01,  3.0762e-02, -4.8828e-02,  ..., -1.3965e-01,
           7.4219e-02, -1.6406e+00],
         [ 5.5859e-01, -1.0625e+00,  5.0391e-01,  ...,  1.5156e+00,
           5.2344e-01,  6.2109e-01]],

        [[ 1.3245e-02, -6.1035e-04,  2.7618e-03,  ..., -1.0303e-01,
          -1.4941e-01,  1.9727e-01],
         [-1.1875e+00, -4.7070e-01,  2.4531e+00,  ...,  9.4141e-01,
           2.5195e-01,  1.8125e+00],
         [ 1.2451e-01,  2.1387e-01,  2.1973e-01,  ...,  1.1719e+00,
          -3.3398e-01, -2.8125e+00],
         ...,
         [-4.5508e-01,  7.6172e-01,  9.5703e-01,  ...,  5.8203e-01,
           1.1719e+00, -4.7266e-01],
         [-4.4922e-01,  3.0762e-02, -4.8828e-02,  ..., -1.3965e-01,
           7.4219e-02, -1.6406e+00],
         [ 5.5859e-01, -1.0625e+00,  5.0391e-01,  ...,  1.5156e+00,
           5.2344e-01,  6.2109e-01]],

        ...,

        [[ 7.3242e-03,  2.6398e-03,  3.4027e-03,  ..., -1.0352e-01,
           8.5938e-01, -3.6914e-01],
         [-1.3281e-01,  6.6016e-01,  1.0059e-01,  ...,  2.9531e+00,
          -3.9062e+00, -5.7031e-01],
         [-2.3828e-01,  6.2012e-02, -2.3535e-01,  ...,  1.9609e+00,
          -5.2500e+00,  1.8516e+00],
         ...,
         [ 6.1768e-02,  1.2891e-01,  1.5820e-01,  ..., -1.9609e+00,
           7.5684e-02,  1.1250e+00],
         [ 3.3203e-01,  5.9814e-02,  2.6172e-01,  ..., -7.5781e-01,
          -2.4844e+00,  1.6328e+00],
         [ 8.4961e-02, -4.3555e-01,  1.2109e-01,  ...,  1.8281e+00,
          -3.2812e+00,  4.8828e-01]],

        [[ 7.3242e-03,  2.6398e-03,  3.4027e-03,  ..., -1.0352e-01,
           8.5938e-01, -3.6914e-01],
         [-1.3281e-01,  6.6016e-01,  1.0059e-01,  ...,  2.9531e+00,
          -3.9062e+00, -5.7031e-01],
         [-2.3828e-01,  6.2012e-02, -2.3535e-01,  ...,  1.9609e+00,
          -5.2500e+00,  1.8516e+00],
         ...,
         [ 6.1768e-02,  1.2891e-01,  1.5820e-01,  ..., -1.9609e+00,
           7.5684e-02,  1.1250e+00],
         [ 3.3203e-01,  5.9814e-02,  2.6172e-01,  ..., -7.5781e-01,
          -2.4844e+00,  1.6328e+00],
         [ 8.4961e-02, -4.3555e-01,  1.2109e-01,  ...,  1.8281e+00,
          -3.2812e+00,  4.8828e-01]],

        [[ 7.3242e-03,  2.6398e-03,  3.4027e-03,  ..., -1.0352e-01,
           8.5938e-01, -3.6914e-01],
         [-1.3281e-01,  6.6016e-01,  1.0059e-01,  ...,  2.9531e+00,
          -3.9062e+00, -5.7031e-01],
         [-2.3828e-01,  6.2012e-02, -2.3535e-01,  ...,  1.9609e+00,
          -5.2500e+00,  1.8516e+00],
         ...,
         [ 6.1768e-02,  1.2891e-01,  1.5820e-01,  ..., -1.9609e+00,
           7.5684e-02,  1.1250e+00],
         [ 3.3203e-01,  5.9814e-02,  2.6172e-01,  ..., -7.5781e-01,
          -2.4844e+00,  1.6328e+00],
         [ 8.4961e-02, -4.3555e-01,  1.2109e-01,  ...,  1.8281e+00,
          -3.2812e+00,  4.8828e-01]]], dtype=torch.bfloat16), tensor([[[-7.5378e-03,  4.1504e-03,  5.4932e-03,  ..., -3.3569e-03,
           1.2024e-02,  3.7384e-03],
         [-1.2656e+00, -1.1250e+00, -8.9453e-01,  ...,  3.3008e-01,
          -6.0156e-01, -7.5391e-01],
         [-2.0156e+00, -9.3750e-01, -2.6245e-03,  ..., -3.0000e+00,
           9.2578e-01, -1.0234e+00],
         ...,
         [ 2.3047e-01, -1.2656e+00, -8.7109e-01,  ..., -3.1250e+00,
           1.0312e+00,  6.4453e-01],
         [ 4.9414e-01, -7.6562e-01, -1.8828e+00,  ..., -5.6562e+00,
           2.2344e+00, -1.2422e+00],
         [-5.0000e-01, -5.0391e-01, -1.3438e+00,  ..., -2.5938e+00,
          -1.1016e+00, -2.6367e-01]],

        [[-7.5378e-03,  4.1504e-03,  5.4932e-03,  ..., -3.3569e-03,
           1.2024e-02,  3.7384e-03],
         [-1.2656e+00, -1.1250e+00, -8.9453e-01,  ...,  3.3008e-01,
          -6.0156e-01, -7.5391e-01],
         [-2.0156e+00, -9.3750e-01, -2.6245e-03,  ..., -3.0000e+00,
           9.2578e-01, -1.0234e+00],
         ...,
         [ 2.3047e-01, -1.2656e+00, -8.7109e-01,  ..., -3.1250e+00,
           1.0312e+00,  6.4453e-01],
         [ 4.9414e-01, -7.6562e-01, -1.8828e+00,  ..., -5.6562e+00,
           2.2344e+00, -1.2422e+00],
         [-5.0000e-01, -5.0391e-01, -1.3438e+00,  ..., -2.5938e+00,
          -1.1016e+00, -2.6367e-01]],

        [[-7.5378e-03,  4.1504e-03,  5.4932e-03,  ..., -3.3569e-03,
           1.2024e-02,  3.7384e-03],
         [-1.2656e+00, -1.1250e+00, -8.9453e-01,  ...,  3.3008e-01,
          -6.0156e-01, -7.5391e-01],
         [-2.0156e+00, -9.3750e-01, -2.6245e-03,  ..., -3.0000e+00,
           9.2578e-01, -1.0234e+00],
         ...,
         [ 2.3047e-01, -1.2656e+00, -8.7109e-01,  ..., -3.1250e+00,
           1.0312e+00,  6.4453e-01],
         [ 4.9414e-01, -7.6562e-01, -1.8828e+00,  ..., -5.6562e+00,
           2.2344e+00, -1.2422e+00],
         [-5.0000e-01, -5.0391e-01, -1.3438e+00,  ..., -2.5938e+00,
          -1.1016e+00, -2.6367e-01]],

        ...,

        [[ 7.0190e-03, -1.6556e-03, -3.8605e-03,  ..., -3.7384e-03,
           2.1973e-03,  1.5564e-03],
         [-8.9844e-01, -1.7422e+00,  4.0312e+00,  ..., -1.5469e+00,
           9.5312e-01,  2.0625e+00],
         [ 1.4954e-03, -8.6719e-01,  2.4375e+00,  ..., -1.3672e+00,
          -6.4062e-01, -1.4453e-01],
         ...,
         [-3.0664e-01,  1.4062e+00, -1.9062e+00,  ..., -3.3447e-02,
           3.8867e-01,  1.0391e+00],
         [ 7.2754e-02,  7.9688e-01, -5.7031e-01,  ...,  2.6367e-01,
           2.5977e-01,  6.9531e-01],
         [-1.5078e+00, -7.9297e-01,  9.1406e-01,  ..., -1.8516e+00,
           8.2812e-01,  4.5703e-01]],

        [[ 7.0190e-03, -1.6556e-03, -3.8605e-03,  ..., -3.7384e-03,
           2.1973e-03,  1.5564e-03],
         [-8.9844e-01, -1.7422e+00,  4.0312e+00,  ..., -1.5469e+00,
           9.5312e-01,  2.0625e+00],
         [ 1.4954e-03, -8.6719e-01,  2.4375e+00,  ..., -1.3672e+00,
          -6.4062e-01, -1.4453e-01],
         ...,
         [-3.0664e-01,  1.4062e+00, -1.9062e+00,  ..., -3.3447e-02,
           3.8867e-01,  1.0391e+00],
         [ 7.2754e-02,  7.9688e-01, -5.7031e-01,  ...,  2.6367e-01,
           2.5977e-01,  6.9531e-01],
         [-1.5078e+00, -7.9297e-01,  9.1406e-01,  ..., -1.8516e+00,
           8.2812e-01,  4.5703e-01]],

        [[ 7.0190e-03, -1.6556e-03, -3.8605e-03,  ..., -3.7384e-03,
           2.1973e-03,  1.5564e-03],
         [-8.9844e-01, -1.7422e+00,  4.0312e+00,  ..., -1.5469e+00,
           9.5312e-01,  2.0625e+00],
         [ 1.4954e-03, -8.6719e-01,  2.4375e+00,  ..., -1.3672e+00,
          -6.4062e-01, -1.4453e-01],
         ...,
         [-3.0664e-01,  1.4062e+00, -1.9062e+00,  ..., -3.3447e-02,
           3.8867e-01,  1.0391e+00],
         [ 7.2754e-02,  7.9688e-01, -5.7031e-01,  ...,  2.6367e-01,
           2.5977e-01,  6.9531e-01],
         [-1.5078e+00, -7.9297e-01,  9.1406e-01,  ..., -1.8516e+00,
           8.2812e-01,  4.5703e-01]]], dtype=torch.bfloat16)), (tensor([[[ 8.5449e-03,  1.2741e-03, -8.4229e-03,  ..., -1.3379e-01,
          -6.2866e-03,  7.2327e-03],
         [-1.0781e+00, -6.3281e-01, -2.5391e-01,  ...,  5.6641e-01,
          -1.0625e+00,  2.5156e+00],
         [ 1.4258e-01,  2.4609e-01, -9.9219e-01,  ...,  1.7344e+00,
           2.8125e-01,  3.3906e+00],
         ...,
         [ 5.6641e-01,  5.5859e-01, -9.5703e-02,  ..., -2.7812e+00,
           1.7822e-02,  4.1992e-01],
         [ 7.6562e-01,  8.0078e-01, -7.8906e-01,  ...,  1.9844e+00,
          -1.1016e+00,  1.7109e+00],
         [-4.3750e-01, -1.1094e+00, -1.1250e+00,  ..., -8.3984e-02,
          -2.2656e-01,  3.4844e+00]],

        [[ 8.5449e-03,  1.2741e-03, -8.4229e-03,  ..., -1.3379e-01,
          -6.2866e-03,  7.2327e-03],
         [-1.0781e+00, -6.3281e-01, -2.5391e-01,  ...,  5.6641e-01,
          -1.0625e+00,  2.5156e+00],
         [ 1.4258e-01,  2.4609e-01, -9.9219e-01,  ...,  1.7344e+00,
           2.8125e-01,  3.3906e+00],
         ...,
         [ 5.6641e-01,  5.5859e-01, -9.5703e-02,  ..., -2.7812e+00,
           1.7822e-02,  4.1992e-01],
         [ 7.6562e-01,  8.0078e-01, -7.8906e-01,  ...,  1.9844e+00,
          -1.1016e+00,  1.7109e+00],
         [-4.3750e-01, -1.1094e+00, -1.1250e+00,  ..., -8.3984e-02,
          -2.2656e-01,  3.4844e+00]],

        [[ 8.5449e-03,  1.2741e-03, -8.4229e-03,  ..., -1.3379e-01,
          -6.2866e-03,  7.2327e-03],
         [-1.0781e+00, -6.3281e-01, -2.5391e-01,  ...,  5.6641e-01,
          -1.0625e+00,  2.5156e+00],
         [ 1.4258e-01,  2.4609e-01, -9.9219e-01,  ...,  1.7344e+00,
           2.8125e-01,  3.3906e+00],
         ...,
         [ 5.6641e-01,  5.5859e-01, -9.5703e-02,  ..., -2.7812e+00,
           1.7822e-02,  4.1992e-01],
         [ 7.6562e-01,  8.0078e-01, -7.8906e-01,  ...,  1.9844e+00,
          -1.1016e+00,  1.7109e+00],
         [-4.3750e-01, -1.1094e+00, -1.1250e+00,  ..., -8.3984e-02,
          -2.2656e-01,  3.4844e+00]],

        ...,

        [[ 4.4861e-03, -7.5073e-03, -7.8735e-03,  ..., -7.3242e-02,
          -5.5908e-02,  5.2490e-02],
         [ 1.3125e+00, -1.4844e-01,  1.0781e+00,  ..., -9.2969e-01,
           2.9883e-01,  5.6641e-01],
         [-1.1484e+00,  1.1875e+00,  1.0840e-01,  ...,  8.4766e-01,
           3.3789e-01,  1.4219e+00],
         ...,
         [-4.1602e-01, -2.0703e-01, -6.2891e-01,  ..., -3.0312e+00,
          -1.6250e+00, -3.7656e+00],
         [-6.7188e-01, -6.3672e-01,  2.2461e-02,  ..., -3.0625e+00,
          -2.2031e+00, -2.2500e+00],
         [ 9.5312e-01, -1.6953e+00,  1.5938e+00,  ..., -4.2188e-01,
           5.0781e-01,  3.1406e+00]],

        [[ 4.4861e-03, -7.5073e-03, -7.8735e-03,  ..., -7.3242e-02,
          -5.5908e-02,  5.2490e-02],
         [ 1.3125e+00, -1.4844e-01,  1.0781e+00,  ..., -9.2969e-01,
           2.9883e-01,  5.6641e-01],
         [-1.1484e+00,  1.1875e+00,  1.0840e-01,  ...,  8.4766e-01,
           3.3789e-01,  1.4219e+00],
         ...,
         [-4.1602e-01, -2.0703e-01, -6.2891e-01,  ..., -3.0312e+00,
          -1.6250e+00, -3.7656e+00],
         [-6.7188e-01, -6.3672e-01,  2.2461e-02,  ..., -3.0625e+00,
          -2.2031e+00, -2.2500e+00],
         [ 9.5312e-01, -1.6953e+00,  1.5938e+00,  ..., -4.2188e-01,
           5.0781e-01,  3.1406e+00]],

        [[ 4.4861e-03, -7.5073e-03, -7.8735e-03,  ..., -7.3242e-02,
          -5.5908e-02,  5.2490e-02],
         [ 1.3125e+00, -1.4844e-01,  1.0781e+00,  ..., -9.2969e-01,
           2.9883e-01,  5.6641e-01],
         [-1.1484e+00,  1.1875e+00,  1.0840e-01,  ...,  8.4766e-01,
           3.3789e-01,  1.4219e+00],
         ...,
         [-4.1602e-01, -2.0703e-01, -6.2891e-01,  ..., -3.0312e+00,
          -1.6250e+00, -3.7656e+00],
         [-6.7188e-01, -6.3672e-01,  2.2461e-02,  ..., -3.0625e+00,
          -2.2031e+00, -2.2500e+00],
         [ 9.5312e-01, -1.6953e+00,  1.5938e+00,  ..., -4.2188e-01,
           5.0781e-01,  3.1406e+00]]], dtype=torch.bfloat16), tensor([[[ 5.7068e-03,  9.1171e-04, -8.8501e-03,  ..., -4.3869e-04,
          -9.8877e-03, -9.2316e-04],
         [-5.3906e-01, -3.3008e-01,  1.1035e-01,  ...,  1.0791e-01,
          -5.3125e-01, -3.2422e-01],
         [-1.3245e-02, -2.9492e-01, -7.5000e-01,  ..., -4.2773e-01,
          -9.1016e-01,  1.9824e-01],
         ...,
         [ 9.8438e-01,  2.3535e-01, -9.6875e-01,  ...,  8.7891e-02,
          -1.1562e+00, -1.0547e+00],
         [-3.1738e-02,  4.5508e-01, -1.8555e-01,  ..., -3.3936e-02,
          -9.8047e-01,  7.0703e-01],
         [-5.8203e-01,  2.0469e+00, -3.2617e-01,  ..., -6.2891e-01,
           3.3594e-01,  2.2363e-01]],

        [[ 5.7068e-03,  9.1171e-04, -8.8501e-03,  ..., -4.3869e-04,
          -9.8877e-03, -9.2316e-04],
         [-5.3906e-01, -3.3008e-01,  1.1035e-01,  ...,  1.0791e-01,
          -5.3125e-01, -3.2422e-01],
         [-1.3245e-02, -2.9492e-01, -7.5000e-01,  ..., -4.2773e-01,
          -9.1016e-01,  1.9824e-01],
         ...,
         [ 9.8438e-01,  2.3535e-01, -9.6875e-01,  ...,  8.7891e-02,
          -1.1562e+00, -1.0547e+00],
         [-3.1738e-02,  4.5508e-01, -1.8555e-01,  ..., -3.3936e-02,
          -9.8047e-01,  7.0703e-01],
         [-5.8203e-01,  2.0469e+00, -3.2617e-01,  ..., -6.2891e-01,
           3.3594e-01,  2.2363e-01]],

        [[ 5.7068e-03,  9.1171e-04, -8.8501e-03,  ..., -4.3869e-04,
          -9.8877e-03, -9.2316e-04],
         [-5.3906e-01, -3.3008e-01,  1.1035e-01,  ...,  1.0791e-01,
          -5.3125e-01, -3.2422e-01],
         [-1.3245e-02, -2.9492e-01, -7.5000e-01,  ..., -4.2773e-01,
          -9.1016e-01,  1.9824e-01],
         ...,
         [ 9.8438e-01,  2.3535e-01, -9.6875e-01,  ...,  8.7891e-02,
          -1.1562e+00, -1.0547e+00],
         [-3.1738e-02,  4.5508e-01, -1.8555e-01,  ..., -3.3936e-02,
          -9.8047e-01,  7.0703e-01],
         [-5.8203e-01,  2.0469e+00, -3.2617e-01,  ..., -6.2891e-01,
           3.3594e-01,  2.2363e-01]],

        ...,

        [[ 3.2227e-02,  2.3682e-02,  8.5449e-02,  ...,  9.9487e-03,
          -4.2236e-02,  3.6377e-02],
         [-3.1836e-01, -1.2695e-01,  9.5312e-01,  ..., -5.3516e-01,
           2.8711e-01, -3.7305e-01],
         [ 2.5000e-01, -4.2725e-02,  1.6797e-01,  ...,  1.0254e-01,
          -1.0078e+00,  5.4199e-02],
         ...,
         [-1.2500e-01, -1.3125e+00,  1.9688e+00,  ...,  5.1953e-01,
           2.6562e-01, -6.4453e-01],
         [-5.0781e-01, -8.6719e-01,  4.6289e-01,  ..., -1.4551e-01,
           3.2031e-01, -1.0859e+00],
         [ 5.2734e-01, -6.6797e-01,  1.1523e-01,  ...,  1.0703e+00,
           1.4141e+00,  6.3672e-01]],

        [[ 3.2227e-02,  2.3682e-02,  8.5449e-02,  ...,  9.9487e-03,
          -4.2236e-02,  3.6377e-02],
         [-3.1836e-01, -1.2695e-01,  9.5312e-01,  ..., -5.3516e-01,
           2.8711e-01, -3.7305e-01],
         [ 2.5000e-01, -4.2725e-02,  1.6797e-01,  ...,  1.0254e-01,
          -1.0078e+00,  5.4199e-02],
         ...,
         [-1.2500e-01, -1.3125e+00,  1.9688e+00,  ...,  5.1953e-01,
           2.6562e-01, -6.4453e-01],
         [-5.0781e-01, -8.6719e-01,  4.6289e-01,  ..., -1.4551e-01,
           3.2031e-01, -1.0859e+00],
         [ 5.2734e-01, -6.6797e-01,  1.1523e-01,  ...,  1.0703e+00,
           1.4141e+00,  6.3672e-01]],

        [[ 3.2227e-02,  2.3682e-02,  8.5449e-02,  ...,  9.9487e-03,
          -4.2236e-02,  3.6377e-02],
         [-3.1836e-01, -1.2695e-01,  9.5312e-01,  ..., -5.3516e-01,
           2.8711e-01, -3.7305e-01],
         [ 2.5000e-01, -4.2725e-02,  1.6797e-01,  ...,  1.0254e-01,
          -1.0078e+00,  5.4199e-02],
         ...,
         [-1.2500e-01, -1.3125e+00,  1.9688e+00,  ...,  5.1953e-01,
           2.6562e-01, -6.4453e-01],
         [-5.0781e-01, -8.6719e-01,  4.6289e-01,  ..., -1.4551e-01,
           3.2031e-01, -1.0859e+00],
         [ 5.2734e-01, -6.6797e-01,  1.1523e-01,  ...,  1.0703e+00,
           1.4141e+00,  6.3672e-01]]], dtype=torch.bfloat16)), (tensor([[[ 1.4893e-02,  5.1270e-03,  1.1780e-02,  ...,  1.9375e+00,
          -9.0820e-02,  3.8086e-02],
         [ 2.3438e-02,  2.0508e-01,  1.1572e-01,  ..., -6.1250e+00,
          -3.9648e-01,  3.4766e-01],
         [ 6.9336e-02, -4.5508e-01, -1.8555e-02,  ..., -7.1875e+00,
           5.3516e-01,  2.1406e+00],
         ...,
         [-3.8867e-01, -1.0156e-01, -4.8828e-03,  ..., -4.1875e+00,
           1.5391e+00, -4.2500e+00],
         [-2.5977e-01,  1.8164e-01,  3.3594e-01,  ..., -5.7500e+00,
           1.2188e+00, -6.2891e-01],
         [-4.0430e-01,  2.1289e-01,  2.4805e-01,  ..., -6.2812e+00,
          -3.4570e-01,  1.2061e-01]],

        [[ 1.4893e-02,  5.1270e-03,  1.1780e-02,  ...,  1.9375e+00,
          -9.0820e-02,  3.8086e-02],
         [ 2.3438e-02,  2.0508e-01,  1.1572e-01,  ..., -6.1250e+00,
          -3.9648e-01,  3.4766e-01],
         [ 6.9336e-02, -4.5508e-01, -1.8555e-02,  ..., -7.1875e+00,
           5.3516e-01,  2.1406e+00],
         ...,
         [-3.8867e-01, -1.0156e-01, -4.8828e-03,  ..., -4.1875e+00,
           1.5391e+00, -4.2500e+00],
         [-2.5977e-01,  1.8164e-01,  3.3594e-01,  ..., -5.7500e+00,
           1.2188e+00, -6.2891e-01],
         [-4.0430e-01,  2.1289e-01,  2.4805e-01,  ..., -6.2812e+00,
          -3.4570e-01,  1.2061e-01]],

        [[ 1.4893e-02,  5.1270e-03,  1.1780e-02,  ...,  1.9375e+00,
          -9.0820e-02,  3.8086e-02],
         [ 2.3438e-02,  2.0508e-01,  1.1572e-01,  ..., -6.1250e+00,
          -3.9648e-01,  3.4766e-01],
         [ 6.9336e-02, -4.5508e-01, -1.8555e-02,  ..., -7.1875e+00,
           5.3516e-01,  2.1406e+00],
         ...,
         [-3.8867e-01, -1.0156e-01, -4.8828e-03,  ..., -4.1875e+00,
           1.5391e+00, -4.2500e+00],
         [-2.5977e-01,  1.8164e-01,  3.3594e-01,  ..., -5.7500e+00,
           1.2188e+00, -6.2891e-01],
         [-4.0430e-01,  2.1289e-01,  2.4805e-01,  ..., -6.2812e+00,
          -3.4570e-01,  1.2061e-01]],

        ...,

        [[ 2.5787e-03,  3.2501e-03,  1.0620e-02,  ...,  6.8359e-02,
           9.8047e-01, -1.3281e-01],
         [-3.9062e-03,  4.6484e-01, -1.1250e+00,  ...,  4.6631e-02,
          -1.1953e+00, -3.4766e-01],
         [-4.2188e-01, -9.4531e-01, -8.1641e-01,  ...,  1.5332e-01,
          -4.5938e+00,  2.9688e-01],
         ...,
         [-1.1641e+00,  8.3594e-01, -1.0254e-01,  ...,  2.6489e-02,
          -2.1094e+00, -1.4062e-01],
         [-3.0469e-01,  2.9297e-01, -2.6562e-01,  ...,  2.4219e-01,
          -2.7812e+00, -1.8457e-01],
         [ 1.5938e+00,  1.6094e+00, -1.6641e+00,  ..., -5.5469e-01,
          -2.8594e+00,  5.2344e-01]],

        [[ 2.5787e-03,  3.2501e-03,  1.0620e-02,  ...,  6.8359e-02,
           9.8047e-01, -1.3281e-01],
         [-3.9062e-03,  4.6484e-01, -1.1250e+00,  ...,  4.6631e-02,
          -1.1953e+00, -3.4766e-01],
         [-4.2188e-01, -9.4531e-01, -8.1641e-01,  ...,  1.5332e-01,
          -4.5938e+00,  2.9688e-01],
         ...,
         [-1.1641e+00,  8.3594e-01, -1.0254e-01,  ...,  2.6489e-02,
          -2.1094e+00, -1.4062e-01],
         [-3.0469e-01,  2.9297e-01, -2.6562e-01,  ...,  2.4219e-01,
          -2.7812e+00, -1.8457e-01],
         [ 1.5938e+00,  1.6094e+00, -1.6641e+00,  ..., -5.5469e-01,
          -2.8594e+00,  5.2344e-01]],

        [[ 2.5787e-03,  3.2501e-03,  1.0620e-02,  ...,  6.8359e-02,
           9.8047e-01, -1.3281e-01],
         [-3.9062e-03,  4.6484e-01, -1.1250e+00,  ...,  4.6631e-02,
          -1.1953e+00, -3.4766e-01],
         [-4.2188e-01, -9.4531e-01, -8.1641e-01,  ...,  1.5332e-01,
          -4.5938e+00,  2.9688e-01],
         ...,
         [-1.1641e+00,  8.3594e-01, -1.0254e-01,  ...,  2.6489e-02,
          -2.1094e+00, -1.4062e-01],
         [-3.0469e-01,  2.9297e-01, -2.6562e-01,  ...,  2.4219e-01,
          -2.7812e+00, -1.8457e-01],
         [ 1.5938e+00,  1.6094e+00, -1.6641e+00,  ..., -5.5469e-01,
          -2.8594e+00,  5.2344e-01]]], dtype=torch.bfloat16), tensor([[[-0.0053,  0.0099, -0.0089,  ...,  0.0029, -0.0044, -0.0042],
         [ 2.7031, -0.5781,  0.9375,  ...,  2.6094, -1.6562, -0.2637],
         [ 0.5664, -1.2891,  0.1299,  ...,  0.7930, -0.0786, -1.3047],
         ...,
         [-0.9766, -1.0312, -0.3301,  ...,  0.1187,  0.3691, -0.3945],
         [-1.4297,  0.5469, -0.8711,  ..., -0.2305, -0.8359, -0.7109],
         [ 1.0469,  0.2207, -0.1836,  ...,  0.0776,  0.9375,  0.6406]],

        [[-0.0053,  0.0099, -0.0089,  ...,  0.0029, -0.0044, -0.0042],
         [ 2.7031, -0.5781,  0.9375,  ...,  2.6094, -1.6562, -0.2637],
         [ 0.5664, -1.2891,  0.1299,  ...,  0.7930, -0.0786, -1.3047],
         ...,
         [-0.9766, -1.0312, -0.3301,  ...,  0.1187,  0.3691, -0.3945],
         [-1.4297,  0.5469, -0.8711,  ..., -0.2305, -0.8359, -0.7109],
         [ 1.0469,  0.2207, -0.1836,  ...,  0.0776,  0.9375,  0.6406]],

        [[-0.0053,  0.0099, -0.0089,  ...,  0.0029, -0.0044, -0.0042],
         [ 2.7031, -0.5781,  0.9375,  ...,  2.6094, -1.6562, -0.2637],
         [ 0.5664, -1.2891,  0.1299,  ...,  0.7930, -0.0786, -1.3047],
         ...,
         [-0.9766, -1.0312, -0.3301,  ...,  0.1187,  0.3691, -0.3945],
         [-1.4297,  0.5469, -0.8711,  ..., -0.2305, -0.8359, -0.7109],
         [ 1.0469,  0.2207, -0.1836,  ...,  0.0776,  0.9375,  0.6406]],

        ...,

        [[-0.0347,  0.0099,  0.0165,  ..., -0.0503, -0.0281,  0.0081],
         [ 0.5117,  0.2988,  0.5977,  ...,  0.2715,  0.8438, -0.1855],
         [ 1.2891,  0.1729,  0.1484,  ...,  0.2363,  0.6758, -0.1934],
         ...,
         [ 1.0312,  0.2168,  0.1875,  ...,  0.7891, -0.2988, -0.1855],
         [ 1.6641,  0.8945,  0.0201,  ..., -0.1406, -0.2793, -0.1436],
         [ 0.1865, -0.0791, -0.4160,  ..., -0.0747, -0.1514, -0.4941]],

        [[-0.0347,  0.0099,  0.0165,  ..., -0.0503, -0.0281,  0.0081],
         [ 0.5117,  0.2988,  0.5977,  ...,  0.2715,  0.8438, -0.1855],
         [ 1.2891,  0.1729,  0.1484,  ...,  0.2363,  0.6758, -0.1934],
         ...,
         [ 1.0312,  0.2168,  0.1875,  ...,  0.7891, -0.2988, -0.1855],
         [ 1.6641,  0.8945,  0.0201,  ..., -0.1406, -0.2793, -0.1436],
         [ 0.1865, -0.0791, -0.4160,  ..., -0.0747, -0.1514, -0.4941]],

        [[-0.0347,  0.0099,  0.0165,  ..., -0.0503, -0.0281,  0.0081],
         [ 0.5117,  0.2988,  0.5977,  ...,  0.2715,  0.8438, -0.1855],
         [ 1.2891,  0.1729,  0.1484,  ...,  0.2363,  0.6758, -0.1934],
         ...,
         [ 1.0312,  0.2168,  0.1875,  ...,  0.7891, -0.2988, -0.1855],
         [ 1.6641,  0.8945,  0.0201,  ..., -0.1406, -0.2793, -0.1436],
         [ 0.1865, -0.0791, -0.4160,  ..., -0.0747, -0.1514, -0.4941]]],
       dtype=torch.bfloat16)), (tensor([[[ 1.6113e-02,  5.1498e-04, -1.2268e-02,  ...,  2.3047e-01,
           6.2500e-02,  7.8906e-01],
         [ 1.0156e-01, -8.4766e-01,  2.4219e-01,  ...,  5.6641e-01,
          -2.0781e+00, -3.1562e+00],
         [ 1.3672e-01,  1.2793e-01,  9.2188e-01,  ..., -3.5938e-01,
          -1.2109e+00, -4.1875e+00],
         ...,
         [-3.0469e-01,  2.6367e-02, -4.1504e-02,  ..., -5.5859e-01,
          -4.3750e-01,  1.1572e-01],
         [-6.3672e-01,  7.0312e-01,  1.0938e+00,  ..., -6.6797e-01,
           6.6016e-01, -1.3984e+00],
         [-2.0508e-02,  5.9570e-02,  5.3906e-01,  ..., -8.5938e-01,
           1.6172e+00, -2.9375e+00]],

        [[ 1.6113e-02,  5.1498e-04, -1.2268e-02,  ...,  2.3047e-01,
           6.2500e-02,  7.8906e-01],
         [ 1.0156e-01, -8.4766e-01,  2.4219e-01,  ...,  5.6641e-01,
          -2.0781e+00, -3.1562e+00],
         [ 1.3672e-01,  1.2793e-01,  9.2188e-01,  ..., -3.5938e-01,
          -1.2109e+00, -4.1875e+00],
         ...,
         [-3.0469e-01,  2.6367e-02, -4.1504e-02,  ..., -5.5859e-01,
          -4.3750e-01,  1.1572e-01],
         [-6.3672e-01,  7.0312e-01,  1.0938e+00,  ..., -6.6797e-01,
           6.6016e-01, -1.3984e+00],
         [-2.0508e-02,  5.9570e-02,  5.3906e-01,  ..., -8.5938e-01,
           1.6172e+00, -2.9375e+00]],

        [[ 1.6113e-02,  5.1498e-04, -1.2268e-02,  ...,  2.3047e-01,
           6.2500e-02,  7.8906e-01],
         [ 1.0156e-01, -8.4766e-01,  2.4219e-01,  ...,  5.6641e-01,
          -2.0781e+00, -3.1562e+00],
         [ 1.3672e-01,  1.2793e-01,  9.2188e-01,  ..., -3.5938e-01,
          -1.2109e+00, -4.1875e+00],
         ...,
         [-3.0469e-01,  2.6367e-02, -4.1504e-02,  ..., -5.5859e-01,
          -4.3750e-01,  1.1572e-01],
         [-6.3672e-01,  7.0312e-01,  1.0938e+00,  ..., -6.6797e-01,
           6.6016e-01, -1.3984e+00],
         [-2.0508e-02,  5.9570e-02,  5.3906e-01,  ..., -8.5938e-01,
           1.6172e+00, -2.9375e+00]],

        ...,

        [[-6.3171e-03,  1.1536e-02, -4.2152e-04,  ..., -9.0820e-02,
          -2.0605e-01,  7.3730e-02],
         [ 5.0781e-02,  6.1523e-02,  7.6172e-02,  ...,  2.1562e+00,
           1.2988e-01, -1.7812e+00],
         [-2.1484e-01,  3.0273e-01,  4.9219e-01,  ...,  3.0938e+00,
          -1.2266e+00, -3.8594e+00],
         ...,
         [ 3.1250e-01,  2.9688e-01, -1.5820e-01,  ..., -1.8906e+00,
          -2.1719e+00,  3.1250e-01],
         [ 4.3359e-01, -3.0078e-01,  2.0508e-02,  ..., -3.2471e-02,
           3.0078e-01, -1.3184e-01],
         [-4.3555e-01, -4.3945e-02,  2.0508e-01,  ...,  2.4531e+00,
           3.5156e-01,  9.6875e-01]],

        [[-6.3171e-03,  1.1536e-02, -4.2152e-04,  ..., -9.0820e-02,
          -2.0605e-01,  7.3730e-02],
         [ 5.0781e-02,  6.1523e-02,  7.6172e-02,  ...,  2.1562e+00,
           1.2988e-01, -1.7812e+00],
         [-2.1484e-01,  3.0273e-01,  4.9219e-01,  ...,  3.0938e+00,
          -1.2266e+00, -3.8594e+00],
         ...,
         [ 3.1250e-01,  2.9688e-01, -1.5820e-01,  ..., -1.8906e+00,
          -2.1719e+00,  3.1250e-01],
         [ 4.3359e-01, -3.0078e-01,  2.0508e-02,  ..., -3.2471e-02,
           3.0078e-01, -1.3184e-01],
         [-4.3555e-01, -4.3945e-02,  2.0508e-01,  ...,  2.4531e+00,
           3.5156e-01,  9.6875e-01]],

        [[-6.3171e-03,  1.1536e-02, -4.2152e-04,  ..., -9.0820e-02,
          -2.0605e-01,  7.3730e-02],
         [ 5.0781e-02,  6.1523e-02,  7.6172e-02,  ...,  2.1562e+00,
           1.2988e-01, -1.7812e+00],
         [-2.1484e-01,  3.0273e-01,  4.9219e-01,  ...,  3.0938e+00,
          -1.2266e+00, -3.8594e+00],
         ...,
         [ 3.1250e-01,  2.9688e-01, -1.5820e-01,  ..., -1.8906e+00,
          -2.1719e+00,  3.1250e-01],
         [ 4.3359e-01, -3.0078e-01,  2.0508e-02,  ..., -3.2471e-02,
           3.0078e-01, -1.3184e-01],
         [-4.3555e-01, -4.3945e-02,  2.0508e-01,  ...,  2.4531e+00,
           3.5156e-01,  9.6875e-01]]], dtype=torch.bfloat16), tensor([[[ 6.7444e-03, -1.3672e-02,  2.2278e-03,  ...,  3.4180e-03,
          -8.3008e-03, -3.9062e-03],
         [ 9.1309e-02,  9.2773e-02, -8.5156e-01,  ..., -1.7734e+00,
          -1.3203e+00,  1.1094e+00],
         [ 8.8281e-01, -3.5352e-01, -1.4258e-01,  ..., -5.0391e-01,
           6.4941e-02,  2.2656e-01],
         ...,
         [ 1.3574e-01, -4.7070e-01, -5.9375e-01,  ..., -1.8262e-01,
           1.1406e+00,  1.4355e-01],
         [-9.4922e-01,  1.7969e-01,  4.5312e-01,  ...,  4.4141e-01,
          -6.1719e-01, -1.5137e-01],
         [-5.2734e-01, -1.3906e+00, -1.7285e-01,  ...,  1.8652e-01,
           1.2188e+00,  1.0791e-01]],

        [[ 6.7444e-03, -1.3672e-02,  2.2278e-03,  ...,  3.4180e-03,
          -8.3008e-03, -3.9062e-03],
         [ 9.1309e-02,  9.2773e-02, -8.5156e-01,  ..., -1.7734e+00,
          -1.3203e+00,  1.1094e+00],
         [ 8.8281e-01, -3.5352e-01, -1.4258e-01,  ..., -5.0391e-01,
           6.4941e-02,  2.2656e-01],
         ...,
         [ 1.3574e-01, -4.7070e-01, -5.9375e-01,  ..., -1.8262e-01,
           1.1406e+00,  1.4355e-01],
         [-9.4922e-01,  1.7969e-01,  4.5312e-01,  ...,  4.4141e-01,
          -6.1719e-01, -1.5137e-01],
         [-5.2734e-01, -1.3906e+00, -1.7285e-01,  ...,  1.8652e-01,
           1.2188e+00,  1.0791e-01]],

        [[ 6.7444e-03, -1.3672e-02,  2.2278e-03,  ...,  3.4180e-03,
          -8.3008e-03, -3.9062e-03],
         [ 9.1309e-02,  9.2773e-02, -8.5156e-01,  ..., -1.7734e+00,
          -1.3203e+00,  1.1094e+00],
         [ 8.8281e-01, -3.5352e-01, -1.4258e-01,  ..., -5.0391e-01,
           6.4941e-02,  2.2656e-01],
         ...,
         [ 1.3574e-01, -4.7070e-01, -5.9375e-01,  ..., -1.8262e-01,
           1.1406e+00,  1.4355e-01],
         [-9.4922e-01,  1.7969e-01,  4.5312e-01,  ...,  4.4141e-01,
          -6.1719e-01, -1.5137e-01],
         [-5.2734e-01, -1.3906e+00, -1.7285e-01,  ...,  1.8652e-01,
           1.2188e+00,  1.0791e-01]],

        ...,

        [[-2.0142e-02,  5.4626e-03,  2.3499e-03,  ..., -1.0147e-03,
          -1.6113e-02,  8.1177e-03],
         [ 1.0078e+00,  2.0801e-01, -1.6250e+00,  ...,  2.9297e-01,
           7.2656e-01,  2.6094e+00],
         [ 8.9453e-01,  3.0156e+00, -7.5000e-01,  ...,  6.4453e-01,
           1.6602e-01,  2.5625e+00],
         ...,
         [-3.8086e-02,  7.1094e-01, -3.0664e-01,  ..., -9.5703e-02,
          -2.1289e-01,  1.6113e-01],
         [-5.7031e-01,  8.2812e-01, -2.5781e-01,  ...,  5.5469e-01,
           6.4453e-01, -2.0605e-01],
         [-7.8906e-01, -9.3750e-01, -1.6406e+00,  ..., -6.4453e-01,
           1.4141e+00, -2.1250e+00]],

        [[-2.0142e-02,  5.4626e-03,  2.3499e-03,  ..., -1.0147e-03,
          -1.6113e-02,  8.1177e-03],
         [ 1.0078e+00,  2.0801e-01, -1.6250e+00,  ...,  2.9297e-01,
           7.2656e-01,  2.6094e+00],
         [ 8.9453e-01,  3.0156e+00, -7.5000e-01,  ...,  6.4453e-01,
           1.6602e-01,  2.5625e+00],
         ...,
         [-3.8086e-02,  7.1094e-01, -3.0664e-01,  ..., -9.5703e-02,
          -2.1289e-01,  1.6113e-01],
         [-5.7031e-01,  8.2812e-01, -2.5781e-01,  ...,  5.5469e-01,
           6.4453e-01, -2.0605e-01],
         [-7.8906e-01, -9.3750e-01, -1.6406e+00,  ..., -6.4453e-01,
           1.4141e+00, -2.1250e+00]],

        [[-2.0142e-02,  5.4626e-03,  2.3499e-03,  ..., -1.0147e-03,
          -1.6113e-02,  8.1177e-03],
         [ 1.0078e+00,  2.0801e-01, -1.6250e+00,  ...,  2.9297e-01,
           7.2656e-01,  2.6094e+00],
         [ 8.9453e-01,  3.0156e+00, -7.5000e-01,  ...,  6.4453e-01,
           1.6602e-01,  2.5625e+00],
         ...,
         [-3.8086e-02,  7.1094e-01, -3.0664e-01,  ..., -9.5703e-02,
          -2.1289e-01,  1.6113e-01],
         [-5.7031e-01,  8.2812e-01, -2.5781e-01,  ...,  5.5469e-01,
           6.4453e-01, -2.0605e-01],
         [-7.8906e-01, -9.3750e-01, -1.6406e+00,  ..., -6.4453e-01,
           1.4141e+00, -2.1250e+00]]], dtype=torch.bfloat16)), (tensor([[[ 2.2583e-03,  1.4648e-03, -5.1270e-03,  ...,  8.6426e-02,
          -2.5586e-01, -1.3062e-02],
         [-2.4531e+00, -2.8516e-01, -9.7266e-01,  ..., -4.1797e-01,
          -6.9141e-01, -7.8516e-01],
         [-2.7031e+00, -9.8828e-01, -1.5234e+00,  ...,  1.9141e+00,
          -1.3672e+00,  1.1797e+00],
         ...,
         [-4.6484e-01,  9.7656e-03, -1.3984e+00,  ..., -1.7266e+00,
          -6.2891e-01, -5.4688e-01],
         [ 4.6094e-01, -4.6094e-01, -9.1406e-01,  ..., -1.8438e+00,
          -1.3125e+00,  5.9375e-01],
         [ 1.9844e+00,  1.8594e+00, -6.4844e-01,  ..., -1.7734e+00,
          -1.3594e+00, -2.2656e+00]],

        [[ 2.2583e-03,  1.4648e-03, -5.1270e-03,  ...,  8.6426e-02,
          -2.5586e-01, -1.3062e-02],
         [-2.4531e+00, -2.8516e-01, -9.7266e-01,  ..., -4.1797e-01,
          -6.9141e-01, -7.8516e-01],
         [-2.7031e+00, -9.8828e-01, -1.5234e+00,  ...,  1.9141e+00,
          -1.3672e+00,  1.1797e+00],
         ...,
         [-4.6484e-01,  9.7656e-03, -1.3984e+00,  ..., -1.7266e+00,
          -6.2891e-01, -5.4688e-01],
         [ 4.6094e-01, -4.6094e-01, -9.1406e-01,  ..., -1.8438e+00,
          -1.3125e+00,  5.9375e-01],
         [ 1.9844e+00,  1.8594e+00, -6.4844e-01,  ..., -1.7734e+00,
          -1.3594e+00, -2.2656e+00]],

        [[ 2.2583e-03,  1.4648e-03, -5.1270e-03,  ...,  8.6426e-02,
          -2.5586e-01, -1.3062e-02],
         [-2.4531e+00, -2.8516e-01, -9.7266e-01,  ..., -4.1797e-01,
          -6.9141e-01, -7.8516e-01],
         [-2.7031e+00, -9.8828e-01, -1.5234e+00,  ...,  1.9141e+00,
          -1.3672e+00,  1.1797e+00],
         ...,
         [-4.6484e-01,  9.7656e-03, -1.3984e+00,  ..., -1.7266e+00,
          -6.2891e-01, -5.4688e-01],
         [ 4.6094e-01, -4.6094e-01, -9.1406e-01,  ..., -1.8438e+00,
          -1.3125e+00,  5.9375e-01],
         [ 1.9844e+00,  1.8594e+00, -6.4844e-01,  ..., -1.7734e+00,
          -1.3594e+00, -2.2656e+00]],

        ...,

        [[-3.2349e-03,  5.8594e-03,  7.5378e-03,  ..., -8.3008e-02,
           6.6406e-02, -8.7891e-02],
         [-3.1250e+00,  9.4141e-01, -2.5781e-01,  ...,  6.3281e-01,
          -1.2969e+00,  9.1797e-02],
         [ 1.6406e-01, -5.3906e-01, -2.2188e+00,  ..., -1.6602e-01,
          -2.3906e+00,  1.1328e-01],
         ...,
         [ 2.0156e+00,  1.0781e+00,  5.5469e-01,  ...,  3.1250e-01,
           6.3281e-01, -1.3906e+00],
         [ 1.5469e+00, -5.1172e-01,  1.2500e-01,  ...,  1.1182e-01,
           1.3672e-02, -1.5859e+00],
         [-2.1094e-01, -1.8125e+00, -1.2031e+00,  ..., -4.7461e-01,
          -8.0469e-01, -4.3359e-01]],

        [[-3.2349e-03,  5.8594e-03,  7.5378e-03,  ..., -8.3008e-02,
           6.6406e-02, -8.7891e-02],
         [-3.1250e+00,  9.4141e-01, -2.5781e-01,  ...,  6.3281e-01,
          -1.2969e+00,  9.1797e-02],
         [ 1.6406e-01, -5.3906e-01, -2.2188e+00,  ..., -1.6602e-01,
          -2.3906e+00,  1.1328e-01],
         ...,
         [ 2.0156e+00,  1.0781e+00,  5.5469e-01,  ...,  3.1250e-01,
           6.3281e-01, -1.3906e+00],
         [ 1.5469e+00, -5.1172e-01,  1.2500e-01,  ...,  1.1182e-01,
           1.3672e-02, -1.5859e+00],
         [-2.1094e-01, -1.8125e+00, -1.2031e+00,  ..., -4.7461e-01,
          -8.0469e-01, -4.3359e-01]],

        [[-3.2349e-03,  5.8594e-03,  7.5378e-03,  ..., -8.3008e-02,
           6.6406e-02, -8.7891e-02],
         [-3.1250e+00,  9.4141e-01, -2.5781e-01,  ...,  6.3281e-01,
          -1.2969e+00,  9.1797e-02],
         [ 1.6406e-01, -5.3906e-01, -2.2188e+00,  ..., -1.6602e-01,
          -2.3906e+00,  1.1328e-01],
         ...,
         [ 2.0156e+00,  1.0781e+00,  5.5469e-01,  ...,  3.1250e-01,
           6.3281e-01, -1.3906e+00],
         [ 1.5469e+00, -5.1172e-01,  1.2500e-01,  ...,  1.1182e-01,
           1.3672e-02, -1.5859e+00],
         [-2.1094e-01, -1.8125e+00, -1.2031e+00,  ..., -4.7461e-01,
          -8.0469e-01, -4.3359e-01]]], dtype=torch.bfloat16), tensor([[[ 5.5847e-03, -3.1738e-03,  1.5137e-02,  ...,  1.8692e-03,
           1.3733e-02,  1.8433e-02],
         [ 1.6992e-01, -1.4609e+00,  1.2012e-01,  ..., -3.7305e-01,
           1.3359e+00, -2.1191e-01],
         [ 3.2617e-01, -1.7090e-01,  1.1172e+00,  ...,  1.2656e+00,
          -4.8633e-01, -4.1602e-01],
         ...,
         [-1.1621e-01,  2.3926e-01,  7.4219e-01,  ..., -3.0640e-02,
          -1.5430e-01, -2.7344e-01],
         [ 5.2734e-01,  4.9023e-01, -2.8320e-02,  ...,  4.4727e-01,
           7.2656e-01, -4.6875e-01],
         [-3.6914e-01, -2.9492e-01, -7.0312e-01,  ..., -5.3711e-02,
          -6.3281e-01, -1.4062e+00]],

        [[ 5.5847e-03, -3.1738e-03,  1.5137e-02,  ...,  1.8692e-03,
           1.3733e-02,  1.8433e-02],
         [ 1.6992e-01, -1.4609e+00,  1.2012e-01,  ..., -3.7305e-01,
           1.3359e+00, -2.1191e-01],
         [ 3.2617e-01, -1.7090e-01,  1.1172e+00,  ...,  1.2656e+00,
          -4.8633e-01, -4.1602e-01],
         ...,
         [-1.1621e-01,  2.3926e-01,  7.4219e-01,  ..., -3.0640e-02,
          -1.5430e-01, -2.7344e-01],
         [ 5.2734e-01,  4.9023e-01, -2.8320e-02,  ...,  4.4727e-01,
           7.2656e-01, -4.6875e-01],
         [-3.6914e-01, -2.9492e-01, -7.0312e-01,  ..., -5.3711e-02,
          -6.3281e-01, -1.4062e+00]],

        [[ 5.5847e-03, -3.1738e-03,  1.5137e-02,  ...,  1.8692e-03,
           1.3733e-02,  1.8433e-02],
         [ 1.6992e-01, -1.4609e+00,  1.2012e-01,  ..., -3.7305e-01,
           1.3359e+00, -2.1191e-01],
         [ 3.2617e-01, -1.7090e-01,  1.1172e+00,  ...,  1.2656e+00,
          -4.8633e-01, -4.1602e-01],
         ...,
         [-1.1621e-01,  2.3926e-01,  7.4219e-01,  ..., -3.0640e-02,
          -1.5430e-01, -2.7344e-01],
         [ 5.2734e-01,  4.9023e-01, -2.8320e-02,  ...,  4.4727e-01,
           7.2656e-01, -4.6875e-01],
         [-3.6914e-01, -2.9492e-01, -7.0312e-01,  ..., -5.3711e-02,
          -6.3281e-01, -1.4062e+00]],

        ...,

        [[ 3.8910e-03, -1.4572e-03, -4.1809e-03,  ...,  1.4282e-02,
          -5.4321e-03,  6.4392e-03],
         [ 1.1250e+00, -6.4453e-01, -1.4551e-01,  ..., -6.5430e-02,
           6.2891e-01, -6.4453e-01],
         [-3.3398e-01, -1.8457e-01,  5.4688e-01,  ...,  8.7500e-01,
           2.1191e-01,  1.0781e+00],
         ...,
         [-2.2168e-01, -6.6797e-01,  1.1953e+00,  ...,  6.6406e-01,
           1.6250e+00, -1.9824e-01],
         [ 1.8652e-01,  5.7031e-01, -5.4297e-01,  ..., -6.9141e-01,
          -1.0547e+00, -7.7344e-01],
         [ 1.4609e+00, -6.9531e-01, -3.5889e-02,  ...,  6.1719e-01,
          -6.0547e-01,  2.4121e-01]],

        [[ 3.8910e-03, -1.4572e-03, -4.1809e-03,  ...,  1.4282e-02,
          -5.4321e-03,  6.4392e-03],
         [ 1.1250e+00, -6.4453e-01, -1.4551e-01,  ..., -6.5430e-02,
           6.2891e-01, -6.4453e-01],
         [-3.3398e-01, -1.8457e-01,  5.4688e-01,  ...,  8.7500e-01,
           2.1191e-01,  1.0781e+00],
         ...,
         [-2.2168e-01, -6.6797e-01,  1.1953e+00,  ...,  6.6406e-01,
           1.6250e+00, -1.9824e-01],
         [ 1.8652e-01,  5.7031e-01, -5.4297e-01,  ..., -6.9141e-01,
          -1.0547e+00, -7.7344e-01],
         [ 1.4609e+00, -6.9531e-01, -3.5889e-02,  ...,  6.1719e-01,
          -6.0547e-01,  2.4121e-01]],

        [[ 3.8910e-03, -1.4572e-03, -4.1809e-03,  ...,  1.4282e-02,
          -5.4321e-03,  6.4392e-03],
         [ 1.1250e+00, -6.4453e-01, -1.4551e-01,  ..., -6.5430e-02,
           6.2891e-01, -6.4453e-01],
         [-3.3398e-01, -1.8457e-01,  5.4688e-01,  ...,  8.7500e-01,
           2.1191e-01,  1.0781e+00],
         ...,
         [-2.2168e-01, -6.6797e-01,  1.1953e+00,  ...,  6.6406e-01,
           1.6250e+00, -1.9824e-01],
         [ 1.8652e-01,  5.7031e-01, -5.4297e-01,  ..., -6.9141e-01,
          -1.0547e+00, -7.7344e-01],
         [ 1.4609e+00, -6.9531e-01, -3.5889e-02,  ...,  6.1719e-01,
          -6.0547e-01,  2.4121e-01]]], dtype=torch.bfloat16)), (tensor([[[ 9.3994e-03, -1.4465e-02, -9.5825e-03,  ..., -1.4551e-01,
           5.5469e-01,  2.6562e+00],
         [ 9.4922e-01, -3.2812e-01,  2.4414e-02,  ..., -9.5312e-01,
          -3.2422e-01, -6.4062e+00],
         [-4.3750e-01, -2.2070e-01,  1.6016e-01,  ...,  1.9453e+00,
          -2.6978e-02, -9.3125e+00],
         ...,
         [ 2.5977e-01, -3.3203e-02,  1.1865e-01,  ...,  5.2344e-01,
          -1.6992e-01, -6.2500e+00],
         [-3.5938e-01,  4.0625e-01,  1.2793e-01,  ...,  3.2031e-01,
          -1.1953e+00, -4.8750e+00],
         [-2.5781e-01,  8.5547e-01,  1.1875e+00,  ...,  2.5625e+00,
          -3.0078e-01, -7.2812e+00]],

        [[ 9.3994e-03, -1.4465e-02, -9.5825e-03,  ..., -1.4551e-01,
           5.5469e-01,  2.6562e+00],
         [ 9.4922e-01, -3.2812e-01,  2.4414e-02,  ..., -9.5312e-01,
          -3.2422e-01, -6.4062e+00],
         [-4.3750e-01, -2.2070e-01,  1.6016e-01,  ...,  1.9453e+00,
          -2.6978e-02, -9.3125e+00],
         ...,
         [ 2.5977e-01, -3.3203e-02,  1.1865e-01,  ...,  5.2344e-01,
          -1.6992e-01, -6.2500e+00],
         [-3.5938e-01,  4.0625e-01,  1.2793e-01,  ...,  3.2031e-01,
          -1.1953e+00, -4.8750e+00],
         [-2.5781e-01,  8.5547e-01,  1.1875e+00,  ...,  2.5625e+00,
          -3.0078e-01, -7.2812e+00]],

        [[ 9.3994e-03, -1.4465e-02, -9.5825e-03,  ..., -1.4551e-01,
           5.5469e-01,  2.6562e+00],
         [ 9.4922e-01, -3.2812e-01,  2.4414e-02,  ..., -9.5312e-01,
          -3.2422e-01, -6.4062e+00],
         [-4.3750e-01, -2.2070e-01,  1.6016e-01,  ...,  1.9453e+00,
          -2.6978e-02, -9.3125e+00],
         ...,
         [ 2.5977e-01, -3.3203e-02,  1.1865e-01,  ...,  5.2344e-01,
          -1.6992e-01, -6.2500e+00],
         [-3.5938e-01,  4.0625e-01,  1.2793e-01,  ...,  3.2031e-01,
          -1.1953e+00, -4.8750e+00],
         [-2.5781e-01,  8.5547e-01,  1.1875e+00,  ...,  2.5625e+00,
          -3.0078e-01, -7.2812e+00]],

        ...,

        [[-2.3926e-02, -3.6240e-04, -4.7607e-03,  ...,  2.4048e-02,
           2.1875e+00,  3.8086e-01],
         [ 1.6250e+00, -7.9297e-01, -4.3359e-01,  ..., -3.8086e-01,
          -4.6250e+00,  2.2852e-01],
         [ 7.5195e-02, -5.9766e-01,  7.3828e-01,  ...,  9.8047e-01,
          -3.7500e+00,  1.2256e-01],
         ...,
         [-1.8555e-01, -4.6875e-01,  1.9922e-01,  ...,  2.1562e+00,
          -3.7344e+00,  9.8438e-01],
         [ 1.5625e-02,  2.0312e-01, -8.4961e-02,  ...,  1.0625e+00,
          -4.6875e+00,  8.7891e-01],
         [ 1.0156e+00,  7.4609e-01, -2.8711e-01,  ...,  1.9922e+00,
          -3.7188e+00,  1.3672e+00]],

        [[-2.3926e-02, -3.6240e-04, -4.7607e-03,  ...,  2.4048e-02,
           2.1875e+00,  3.8086e-01],
         [ 1.6250e+00, -7.9297e-01, -4.3359e-01,  ..., -3.8086e-01,
          -4.6250e+00,  2.2852e-01],
         [ 7.5195e-02, -5.9766e-01,  7.3828e-01,  ...,  9.8047e-01,
          -3.7500e+00,  1.2256e-01],
         ...,
         [-1.8555e-01, -4.6875e-01,  1.9922e-01,  ...,  2.1562e+00,
          -3.7344e+00,  9.8438e-01],
         [ 1.5625e-02,  2.0312e-01, -8.4961e-02,  ...,  1.0625e+00,
          -4.6875e+00,  8.7891e-01],
         [ 1.0156e+00,  7.4609e-01, -2.8711e-01,  ...,  1.9922e+00,
          -3.7188e+00,  1.3672e+00]],

        [[-2.3926e-02, -3.6240e-04, -4.7607e-03,  ...,  2.4048e-02,
           2.1875e+00,  3.8086e-01],
         [ 1.6250e+00, -7.9297e-01, -4.3359e-01,  ..., -3.8086e-01,
          -4.6250e+00,  2.2852e-01],
         [ 7.5195e-02, -5.9766e-01,  7.3828e-01,  ...,  9.8047e-01,
          -3.7500e+00,  1.2256e-01],
         ...,
         [-1.8555e-01, -4.6875e-01,  1.9922e-01,  ...,  2.1562e+00,
          -3.7344e+00,  9.8438e-01],
         [ 1.5625e-02,  2.0312e-01, -8.4961e-02,  ...,  1.0625e+00,
          -4.6875e+00,  8.7891e-01],
         [ 1.0156e+00,  7.4609e-01, -2.8711e-01,  ...,  1.9922e+00,
          -3.7188e+00,  1.3672e+00]]], dtype=torch.bfloat16), tensor([[[ 1.2268e-02, -1.1047e-02, -2.2949e-02,  ...,  1.1963e-02,
          -4.4861e-03,  2.0508e-02],
         [-3.2031e-01, -8.4766e-01,  5.0391e-01,  ..., -1.3184e-01,
           2.0801e-01, -1.2188e+00],
         [-5.3711e-02,  1.8281e+00,  2.7656e+00,  ..., -1.2256e-01,
           2.0000e+00,  1.0234e+00],
         ...,
         [-1.8750e+00,  1.3438e+00,  3.4180e-01,  ...,  1.5391e+00,
           1.1484e+00,  2.1719e+00],
         [-3.2617e-01,  3.3594e-01,  9.2578e-01,  ..., -6.3965e-02,
           1.5156e+00,  1.9922e+00],
         [-3.3984e-01, -5.0781e-01, -1.0391e+00,  ..., -1.4531e+00,
           2.8906e-01, -5.8984e-01]],

        [[ 1.2268e-02, -1.1047e-02, -2.2949e-02,  ...,  1.1963e-02,
          -4.4861e-03,  2.0508e-02],
         [-3.2031e-01, -8.4766e-01,  5.0391e-01,  ..., -1.3184e-01,
           2.0801e-01, -1.2188e+00],
         [-5.3711e-02,  1.8281e+00,  2.7656e+00,  ..., -1.2256e-01,
           2.0000e+00,  1.0234e+00],
         ...,
         [-1.8750e+00,  1.3438e+00,  3.4180e-01,  ...,  1.5391e+00,
           1.1484e+00,  2.1719e+00],
         [-3.2617e-01,  3.3594e-01,  9.2578e-01,  ..., -6.3965e-02,
           1.5156e+00,  1.9922e+00],
         [-3.3984e-01, -5.0781e-01, -1.0391e+00,  ..., -1.4531e+00,
           2.8906e-01, -5.8984e-01]],

        [[ 1.2268e-02, -1.1047e-02, -2.2949e-02,  ...,  1.1963e-02,
          -4.4861e-03,  2.0508e-02],
         [-3.2031e-01, -8.4766e-01,  5.0391e-01,  ..., -1.3184e-01,
           2.0801e-01, -1.2188e+00],
         [-5.3711e-02,  1.8281e+00,  2.7656e+00,  ..., -1.2256e-01,
           2.0000e+00,  1.0234e+00],
         ...,
         [-1.8750e+00,  1.3438e+00,  3.4180e-01,  ...,  1.5391e+00,
           1.1484e+00,  2.1719e+00],
         [-3.2617e-01,  3.3594e-01,  9.2578e-01,  ..., -6.3965e-02,
           1.5156e+00,  1.9922e+00],
         [-3.3984e-01, -5.0781e-01, -1.0391e+00,  ..., -1.4531e+00,
           2.8906e-01, -5.8984e-01]],

        ...,

        [[ 4.6082e-03, -1.6022e-04, -2.8687e-02,  ..., -4.7607e-03,
          -8.6060e-03, -6.3705e-04],
         [-1.5938e+00, -1.2305e-01,  2.7148e-01,  ..., -1.8555e-02,
          -4.1016e-01,  4.1406e-01],
         [-1.4062e+00,  6.1328e-01, -2.0605e-01,  ..., -3.7305e-01,
           5.9375e-01,  4.0625e-01],
         ...,
         [-2.7148e-01,  1.1406e+00,  7.8516e-01,  ..., -1.3086e-01,
          -5.7031e-01, -1.9297e+00],
         [ 5.4688e-01,  4.3701e-02,  7.5391e-01,  ..., -7.3438e-01,
           8.0078e-01,  1.9727e-01],
         [ 8.7500e-01,  2.2559e-01,  5.5469e-01,  ..., -9.8047e-01,
          -1.9238e-01, -1.0391e+00]],

        [[ 4.6082e-03, -1.6022e-04, -2.8687e-02,  ..., -4.7607e-03,
          -8.6060e-03, -6.3705e-04],
         [-1.5938e+00, -1.2305e-01,  2.7148e-01,  ..., -1.8555e-02,
          -4.1016e-01,  4.1406e-01],
         [-1.4062e+00,  6.1328e-01, -2.0605e-01,  ..., -3.7305e-01,
           5.9375e-01,  4.0625e-01],
         ...,
         [-2.7148e-01,  1.1406e+00,  7.8516e-01,  ..., -1.3086e-01,
          -5.7031e-01, -1.9297e+00],
         [ 5.4688e-01,  4.3701e-02,  7.5391e-01,  ..., -7.3438e-01,
           8.0078e-01,  1.9727e-01],
         [ 8.7500e-01,  2.2559e-01,  5.5469e-01,  ..., -9.8047e-01,
          -1.9238e-01, -1.0391e+00]],

        [[ 4.6082e-03, -1.6022e-04, -2.8687e-02,  ..., -4.7607e-03,
          -8.6060e-03, -6.3705e-04],
         [-1.5938e+00, -1.2305e-01,  2.7148e-01,  ..., -1.8555e-02,
          -4.1016e-01,  4.1406e-01],
         [-1.4062e+00,  6.1328e-01, -2.0605e-01,  ..., -3.7305e-01,
           5.9375e-01,  4.0625e-01],
         ...,
         [-2.7148e-01,  1.1406e+00,  7.8516e-01,  ..., -1.3086e-01,
          -5.7031e-01, -1.9297e+00],
         [ 5.4688e-01,  4.3701e-02,  7.5391e-01,  ..., -7.3438e-01,
           8.0078e-01,  1.9727e-01],
         [ 8.7500e-01,  2.2559e-01,  5.5469e-01,  ..., -9.8047e-01,
          -1.9238e-01, -1.0391e+00]]], dtype=torch.bfloat16)), (tensor([[[ 7.6904e-03, -4.5166e-03,  4.1809e-03,  ...,  7.6599e-03,
          -2.2754e-01,  7.9102e-02],
         [-2.2031e+00,  9.3750e-02, -2.4512e-01,  ...,  1.4609e+00,
           7.2266e-01, -8.1641e-01],
         [-1.9062e+00, -5.0781e-01,  7.1289e-02,  ...,  1.2031e+00,
           1.2344e+00, -8.0078e-01],
         ...,
         [-2.8516e-01, -2.5195e-01, -1.5312e+00,  ...,  3.8594e+00,
          -2.4707e-01,  1.6172e+00],
         [ 5.8594e-01, -6.5625e-01,  1.3281e-01,  ...,  4.0625e+00,
          -9.5312e-01,  5.3516e-01],
         [ 1.7578e+00,  3.3398e-01, -5.1953e-01,  ...,  1.0391e+00,
           2.8906e+00, -3.9368e-03]],

        [[ 7.6904e-03, -4.5166e-03,  4.1809e-03,  ...,  7.6599e-03,
          -2.2754e-01,  7.9102e-02],
         [-2.2031e+00,  9.3750e-02, -2.4512e-01,  ...,  1.4609e+00,
           7.2266e-01, -8.1641e-01],
         [-1.9062e+00, -5.0781e-01,  7.1289e-02,  ...,  1.2031e+00,
           1.2344e+00, -8.0078e-01],
         ...,
         [-2.8516e-01, -2.5195e-01, -1.5312e+00,  ...,  3.8594e+00,
          -2.4707e-01,  1.6172e+00],
         [ 5.8594e-01, -6.5625e-01,  1.3281e-01,  ...,  4.0625e+00,
          -9.5312e-01,  5.3516e-01],
         [ 1.7578e+00,  3.3398e-01, -5.1953e-01,  ...,  1.0391e+00,
           2.8906e+00, -3.9368e-03]],

        [[ 7.6904e-03, -4.5166e-03,  4.1809e-03,  ...,  7.6599e-03,
          -2.2754e-01,  7.9102e-02],
         [-2.2031e+00,  9.3750e-02, -2.4512e-01,  ...,  1.4609e+00,
           7.2266e-01, -8.1641e-01],
         [-1.9062e+00, -5.0781e-01,  7.1289e-02,  ...,  1.2031e+00,
           1.2344e+00, -8.0078e-01],
         ...,
         [-2.8516e-01, -2.5195e-01, -1.5312e+00,  ...,  3.8594e+00,
          -2.4707e-01,  1.6172e+00],
         [ 5.8594e-01, -6.5625e-01,  1.3281e-01,  ...,  4.0625e+00,
          -9.5312e-01,  5.3516e-01],
         [ 1.7578e+00,  3.3398e-01, -5.1953e-01,  ...,  1.0391e+00,
           2.8906e+00, -3.9368e-03]],

        ...,

        [[ 4.9744e-03, -6.3171e-03,  2.1076e-04,  ..., -7.6172e-02,
           2.6562e-01,  1.2354e-01],
         [ 1.2969e+00,  1.5625e+00,  1.3906e+00,  ...,  1.0703e+00,
           5.5078e-01,  1.7812e+00],
         [ 4.7266e-01, -8.3984e-01,  1.8750e+00,  ..., -9.1406e-01,
           1.4453e+00,  4.1992e-01],
         ...,
         [-1.3203e+00, -6.0156e-01, -3.9453e-01,  ..., -2.9102e-01,
           2.1250e+00, -1.2344e+00],
         [-2.4707e-01, -3.7109e-01,  5.2344e-01,  ..., -2.1719e+00,
           1.7578e+00, -2.6250e+00],
         [-2.5586e-01, -1.3086e-01,  6.1719e-01,  ..., -5.3125e-01,
           1.4062e+00, -7.2266e-01]],

        [[ 4.9744e-03, -6.3171e-03,  2.1076e-04,  ..., -7.6172e-02,
           2.6562e-01,  1.2354e-01],
         [ 1.2969e+00,  1.5625e+00,  1.3906e+00,  ...,  1.0703e+00,
           5.5078e-01,  1.7812e+00],
         [ 4.7266e-01, -8.3984e-01,  1.8750e+00,  ..., -9.1406e-01,
           1.4453e+00,  4.1992e-01],
         ...,
         [-1.3203e+00, -6.0156e-01, -3.9453e-01,  ..., -2.9102e-01,
           2.1250e+00, -1.2344e+00],
         [-2.4707e-01, -3.7109e-01,  5.2344e-01,  ..., -2.1719e+00,
           1.7578e+00, -2.6250e+00],
         [-2.5586e-01, -1.3086e-01,  6.1719e-01,  ..., -5.3125e-01,
           1.4062e+00, -7.2266e-01]],

        [[ 4.9744e-03, -6.3171e-03,  2.1076e-04,  ..., -7.6172e-02,
           2.6562e-01,  1.2354e-01],
         [ 1.2969e+00,  1.5625e+00,  1.3906e+00,  ...,  1.0703e+00,
           5.5078e-01,  1.7812e+00],
         [ 4.7266e-01, -8.3984e-01,  1.8750e+00,  ..., -9.1406e-01,
           1.4453e+00,  4.1992e-01],
         ...,
         [-1.3203e+00, -6.0156e-01, -3.9453e-01,  ..., -2.9102e-01,
           2.1250e+00, -1.2344e+00],
         [-2.4707e-01, -3.7109e-01,  5.2344e-01,  ..., -2.1719e+00,
           1.7578e+00, -2.6250e+00],
         [-2.5586e-01, -1.3086e-01,  6.1719e-01,  ..., -5.3125e-01,
           1.4062e+00, -7.2266e-01]]], dtype=torch.bfloat16), tensor([[[-4.2725e-03,  5.8594e-03,  1.6937e-03,  ...,  5.0049e-03,
          -1.9775e-02, -1.0620e-02],
         [ 3.3203e-01,  6.7969e-01,  8.1641e-01,  ...,  3.4375e-01,
          -1.2207e-01, -5.8984e-01],
         [-2.0215e-01,  7.8516e-01, -8.2812e-01,  ...,  1.2969e+00,
          -4.8047e-01, -6.5234e-01],
         ...,
         [-1.2500e+00,  1.3203e+00,  1.0449e-01,  ...,  2.0312e+00,
           3.2812e+00, -5.7422e-01],
         [-1.3828e+00,  5.7422e-01,  1.5547e+00,  ...,  1.3359e+00,
           7.7734e-01, -1.7969e+00],
         [-4.0820e-01,  1.7090e-01, -1.6250e+00,  ..., -1.4844e+00,
           2.6953e-01, -2.9102e-01]],

        [[-4.2725e-03,  5.8594e-03,  1.6937e-03,  ...,  5.0049e-03,
          -1.9775e-02, -1.0620e-02],
         [ 3.3203e-01,  6.7969e-01,  8.1641e-01,  ...,  3.4375e-01,
          -1.2207e-01, -5.8984e-01],
         [-2.0215e-01,  7.8516e-01, -8.2812e-01,  ...,  1.2969e+00,
          -4.8047e-01, -6.5234e-01],
         ...,
         [-1.2500e+00,  1.3203e+00,  1.0449e-01,  ...,  2.0312e+00,
           3.2812e+00, -5.7422e-01],
         [-1.3828e+00,  5.7422e-01,  1.5547e+00,  ...,  1.3359e+00,
           7.7734e-01, -1.7969e+00],
         [-4.0820e-01,  1.7090e-01, -1.6250e+00,  ..., -1.4844e+00,
           2.6953e-01, -2.9102e-01]],

        [[-4.2725e-03,  5.8594e-03,  1.6937e-03,  ...,  5.0049e-03,
          -1.9775e-02, -1.0620e-02],
         [ 3.3203e-01,  6.7969e-01,  8.1641e-01,  ...,  3.4375e-01,
          -1.2207e-01, -5.8984e-01],
         [-2.0215e-01,  7.8516e-01, -8.2812e-01,  ...,  1.2969e+00,
          -4.8047e-01, -6.5234e-01],
         ...,
         [-1.2500e+00,  1.3203e+00,  1.0449e-01,  ...,  2.0312e+00,
           3.2812e+00, -5.7422e-01],
         [-1.3828e+00,  5.7422e-01,  1.5547e+00,  ...,  1.3359e+00,
           7.7734e-01, -1.7969e+00],
         [-4.0820e-01,  1.7090e-01, -1.6250e+00,  ..., -1.4844e+00,
           2.6953e-01, -2.9102e-01]],

        ...,

        [[ 2.7588e-02, -9.9487e-03, -1.7700e-02,  ..., -6.4087e-03,
           1.4038e-03,  5.3711e-03],
         [ 2.7930e-01, -2.5977e-01,  7.3047e-01,  ...,  8.4961e-02,
           3.4570e-01, -1.5332e-01],
         [-8.1177e-03,  1.7578e+00, -4.1016e-01,  ...,  3.9062e-01,
           3.3398e-01,  4.3750e-01],
         ...,
         [-4.6680e-01, -3.0469e-01,  4.7656e-01,  ...,  3.2617e-01,
          -8.2422e-01,  3.4180e-01],
         [ 8.0859e-01, -1.5234e+00,  1.1406e+00,  ...,  3.2969e+00,
           1.2734e+00, -5.5469e-01],
         [ 3.6914e-01, -7.6172e-01,  1.7969e-01,  ..., -1.8750e-01,
          -4.5508e-01,  1.8457e-01]],

        [[ 2.7588e-02, -9.9487e-03, -1.7700e-02,  ..., -6.4087e-03,
           1.4038e-03,  5.3711e-03],
         [ 2.7930e-01, -2.5977e-01,  7.3047e-01,  ...,  8.4961e-02,
           3.4570e-01, -1.5332e-01],
         [-8.1177e-03,  1.7578e+00, -4.1016e-01,  ...,  3.9062e-01,
           3.3398e-01,  4.3750e-01],
         ...,
         [-4.6680e-01, -3.0469e-01,  4.7656e-01,  ...,  3.2617e-01,
          -8.2422e-01,  3.4180e-01],
         [ 8.0859e-01, -1.5234e+00,  1.1406e+00,  ...,  3.2969e+00,
           1.2734e+00, -5.5469e-01],
         [ 3.6914e-01, -7.6172e-01,  1.7969e-01,  ..., -1.8750e-01,
          -4.5508e-01,  1.8457e-01]],

        [[ 2.7588e-02, -9.9487e-03, -1.7700e-02,  ..., -6.4087e-03,
           1.4038e-03,  5.3711e-03],
         [ 2.7930e-01, -2.5977e-01,  7.3047e-01,  ...,  8.4961e-02,
           3.4570e-01, -1.5332e-01],
         [-8.1177e-03,  1.7578e+00, -4.1016e-01,  ...,  3.9062e-01,
           3.3398e-01,  4.3750e-01],
         ...,
         [-4.6680e-01, -3.0469e-01,  4.7656e-01,  ...,  3.2617e-01,
          -8.2422e-01,  3.4180e-01],
         [ 8.0859e-01, -1.5234e+00,  1.1406e+00,  ...,  3.2969e+00,
           1.2734e+00, -5.5469e-01],
         [ 3.6914e-01, -7.6172e-01,  1.7969e-01,  ..., -1.8750e-01,
          -4.5508e-01,  1.8457e-01]]], dtype=torch.bfloat16)), (tensor([[[-2.3346e-03,  9.5215e-03, -9.2163e-03,  ..., -1.6602e-01,
          -1.3672e-01,  2.4414e-01],
         [ 2.0312e+00, -6.0156e-01, -2.6953e-01,  ...,  1.7734e+00,
           7.8516e-01, -8.5547e-01],
         [ 1.6016e+00, -2.7734e-01, -2.1289e-01,  ...,  2.1406e+00,
           1.4551e-01, -5.0000e-01],
         ...,
         [ 9.3262e-02,  4.5312e-01,  2.9102e-01,  ...,  6.5918e-02,
          -1.6016e-01,  3.4961e-01],
         [-2.9688e-01,  3.9258e-01, -4.5117e-01,  ...,  4.1992e-02,
           1.5781e+00,  1.1016e+00],
         [-1.9688e+00,  1.1523e-01, -5.0781e-01,  ...,  5.5859e-01,
           3.7305e-01, -2.4414e-01]],

        [[-2.3346e-03,  9.5215e-03, -9.2163e-03,  ..., -1.6602e-01,
          -1.3672e-01,  2.4414e-01],
         [ 2.0312e+00, -6.0156e-01, -2.6953e-01,  ...,  1.7734e+00,
           7.8516e-01, -8.5547e-01],
         [ 1.6016e+00, -2.7734e-01, -2.1289e-01,  ...,  2.1406e+00,
           1.4551e-01, -5.0000e-01],
         ...,
         [ 9.3262e-02,  4.5312e-01,  2.9102e-01,  ...,  6.5918e-02,
          -1.6016e-01,  3.4961e-01],
         [-2.9688e-01,  3.9258e-01, -4.5117e-01,  ...,  4.1992e-02,
           1.5781e+00,  1.1016e+00],
         [-1.9688e+00,  1.1523e-01, -5.0781e-01,  ...,  5.5859e-01,
           3.7305e-01, -2.4414e-01]],

        [[-2.3346e-03,  9.5215e-03, -9.2163e-03,  ..., -1.6602e-01,
          -1.3672e-01,  2.4414e-01],
         [ 2.0312e+00, -6.0156e-01, -2.6953e-01,  ...,  1.7734e+00,
           7.8516e-01, -8.5547e-01],
         [ 1.6016e+00, -2.7734e-01, -2.1289e-01,  ...,  2.1406e+00,
           1.4551e-01, -5.0000e-01],
         ...,
         [ 9.3262e-02,  4.5312e-01,  2.9102e-01,  ...,  6.5918e-02,
          -1.6016e-01,  3.4961e-01],
         [-2.9688e-01,  3.9258e-01, -4.5117e-01,  ...,  4.1992e-02,
           1.5781e+00,  1.1016e+00],
         [-1.9688e+00,  1.1523e-01, -5.0781e-01,  ...,  5.5859e-01,
           3.7305e-01, -2.4414e-01]],

        ...,

        [[ 1.5137e-02, -7.3547e-03, -4.0894e-03,  ..., -2.2852e-01,
           8.4375e-01, -1.9141e-01],
         [-1.2422e+00, -1.0703e+00,  6.9922e-01,  ..., -3.3008e-01,
           1.2266e+00,  3.3125e+00],
         [-9.8438e-01, -7.5391e-01,  7.5000e-01,  ..., -1.3184e-01,
           2.2266e-01, -2.5938e+00],
         ...,
         [ 4.0625e-01,  1.1094e+00, -1.2012e-01,  ...,  3.0156e+00,
          -5.7422e-01,  7.6953e-01],
         [ 8.5547e-01,  8.7891e-01, -1.2695e-01,  ...,  4.1211e-01,
          -1.7188e-01,  4.9805e-01],
         [ 1.0078e+00,  1.0938e-01, -1.7344e+00,  ..., -6.4844e-01,
          -1.9629e-01, -1.8906e+00]],

        [[ 1.5137e-02, -7.3547e-03, -4.0894e-03,  ..., -2.2852e-01,
           8.4375e-01, -1.9141e-01],
         [-1.2422e+00, -1.0703e+00,  6.9922e-01,  ..., -3.3008e-01,
           1.2266e+00,  3.3125e+00],
         [-9.8438e-01, -7.5391e-01,  7.5000e-01,  ..., -1.3184e-01,
           2.2266e-01, -2.5938e+00],
         ...,
         [ 4.0625e-01,  1.1094e+00, -1.2012e-01,  ...,  3.0156e+00,
          -5.7422e-01,  7.6953e-01],
         [ 8.5547e-01,  8.7891e-01, -1.2695e-01,  ...,  4.1211e-01,
          -1.7188e-01,  4.9805e-01],
         [ 1.0078e+00,  1.0938e-01, -1.7344e+00,  ..., -6.4844e-01,
          -1.9629e-01, -1.8906e+00]],

        [[ 1.5137e-02, -7.3547e-03, -4.0894e-03,  ..., -2.2852e-01,
           8.4375e-01, -1.9141e-01],
         [-1.2422e+00, -1.0703e+00,  6.9922e-01,  ..., -3.3008e-01,
           1.2266e+00,  3.3125e+00],
         [-9.8438e-01, -7.5391e-01,  7.5000e-01,  ..., -1.3184e-01,
           2.2266e-01, -2.5938e+00],
         ...,
         [ 4.0625e-01,  1.1094e+00, -1.2012e-01,  ...,  3.0156e+00,
          -5.7422e-01,  7.6953e-01],
         [ 8.5547e-01,  8.7891e-01, -1.2695e-01,  ...,  4.1211e-01,
          -1.7188e-01,  4.9805e-01],
         [ 1.0078e+00,  1.0938e-01, -1.7344e+00,  ..., -6.4844e-01,
          -1.9629e-01, -1.8906e+00]]], dtype=torch.bfloat16), tensor([[[-7.1106e-03, -1.7944e-02,  7.7820e-03,  ...,  1.2756e-02,
          -1.6724e-02,  4.2725e-03],
         [-1.6797e-01, -6.9922e-01,  4.2969e-02,  ...,  7.4707e-02,
          -1.8066e-01, -3.6133e-02],
         [ 2.7539e-01,  1.1562e+00, -1.1875e+00,  ..., -6.7383e-02,
          -3.1055e-01,  2.6611e-02],
         ...,
         [ 1.6895e-01,  1.6211e-01, -1.1572e-01,  ...,  2.7734e-01,
          -5.3906e-01,  3.3789e-01],
         [-5.9766e-01,  5.2344e-01,  6.7969e-01,  ..., -1.1406e+00,
          -4.1602e-01, -5.8984e-01],
         [-2.8125e-01,  2.0938e+00,  1.3125e+00,  ...,  7.5000e-01,
           1.1797e+00,  2.0469e+00]],

        [[-7.1106e-03, -1.7944e-02,  7.7820e-03,  ...,  1.2756e-02,
          -1.6724e-02,  4.2725e-03],
         [-1.6797e-01, -6.9922e-01,  4.2969e-02,  ...,  7.4707e-02,
          -1.8066e-01, -3.6133e-02],
         [ 2.7539e-01,  1.1562e+00, -1.1875e+00,  ..., -6.7383e-02,
          -3.1055e-01,  2.6611e-02],
         ...,
         [ 1.6895e-01,  1.6211e-01, -1.1572e-01,  ...,  2.7734e-01,
          -5.3906e-01,  3.3789e-01],
         [-5.9766e-01,  5.2344e-01,  6.7969e-01,  ..., -1.1406e+00,
          -4.1602e-01, -5.8984e-01],
         [-2.8125e-01,  2.0938e+00,  1.3125e+00,  ...,  7.5000e-01,
           1.1797e+00,  2.0469e+00]],

        [[-7.1106e-03, -1.7944e-02,  7.7820e-03,  ...,  1.2756e-02,
          -1.6724e-02,  4.2725e-03],
         [-1.6797e-01, -6.9922e-01,  4.2969e-02,  ...,  7.4707e-02,
          -1.8066e-01, -3.6133e-02],
         [ 2.7539e-01,  1.1562e+00, -1.1875e+00,  ..., -6.7383e-02,
          -3.1055e-01,  2.6611e-02],
         ...,
         [ 1.6895e-01,  1.6211e-01, -1.1572e-01,  ...,  2.7734e-01,
          -5.3906e-01,  3.3789e-01],
         [-5.9766e-01,  5.2344e-01,  6.7969e-01,  ..., -1.1406e+00,
          -4.1602e-01, -5.8984e-01],
         [-2.8125e-01,  2.0938e+00,  1.3125e+00,  ...,  7.5000e-01,
           1.1797e+00,  2.0469e+00]],

        ...,

        [[-3.6469e-03, -8.3618e-03,  3.2349e-03,  ...,  2.0504e-04,
          -3.1128e-03,  1.9836e-03],
         [-2.4844e+00,  9.6484e-01, -8.0469e-01,  ...,  1.8047e+00,
           2.6562e+00, -2.4688e+00],
         [-8.5938e-01, -1.7344e+00, -9.7266e-01,  ...,  6.6797e-01,
           5.0781e-01, -2.2500e+00],
         ...,
         [ 1.1797e+00, -8.7891e-01, -5.4297e-01,  ...,  6.0156e-01,
          -9.7656e-01,  5.5469e-01],
         [-4.7656e-01, -2.6562e-01, -1.3672e+00,  ...,  1.3516e+00,
           1.1865e-01, -8.7109e-01],
         [-7.6562e-01, -3.3789e-01, -2.7344e+00,  ...,  1.2031e+00,
          -1.1797e+00,  7.7637e-02]],

        [[-3.6469e-03, -8.3618e-03,  3.2349e-03,  ...,  2.0504e-04,
          -3.1128e-03,  1.9836e-03],
         [-2.4844e+00,  9.6484e-01, -8.0469e-01,  ...,  1.8047e+00,
           2.6562e+00, -2.4688e+00],
         [-8.5938e-01, -1.7344e+00, -9.7266e-01,  ...,  6.6797e-01,
           5.0781e-01, -2.2500e+00],
         ...,
         [ 1.1797e+00, -8.7891e-01, -5.4297e-01,  ...,  6.0156e-01,
          -9.7656e-01,  5.5469e-01],
         [-4.7656e-01, -2.6562e-01, -1.3672e+00,  ...,  1.3516e+00,
           1.1865e-01, -8.7109e-01],
         [-7.6562e-01, -3.3789e-01, -2.7344e+00,  ...,  1.2031e+00,
          -1.1797e+00,  7.7637e-02]],

        [[-3.6469e-03, -8.3618e-03,  3.2349e-03,  ...,  2.0504e-04,
          -3.1128e-03,  1.9836e-03],
         [-2.4844e+00,  9.6484e-01, -8.0469e-01,  ...,  1.8047e+00,
           2.6562e+00, -2.4688e+00],
         [-8.5938e-01, -1.7344e+00, -9.7266e-01,  ...,  6.6797e-01,
           5.0781e-01, -2.2500e+00],
         ...,
         [ 1.1797e+00, -8.7891e-01, -5.4297e-01,  ...,  6.0156e-01,
          -9.7656e-01,  5.5469e-01],
         [-4.7656e-01, -2.6562e-01, -1.3672e+00,  ...,  1.3516e+00,
           1.1865e-01, -8.7109e-01],
         [-7.6562e-01, -3.3789e-01, -2.7344e+00,  ...,  1.2031e+00,
          -1.1797e+00,  7.7637e-02]]], dtype=torch.bfloat16)), (tensor([[[ 6.3324e-04,  3.0518e-03, -2.5635e-03,  ..., -1.9043e-01,
          -1.2695e-01,  1.2598e-01],
         [ 0.0000e+00,  9.1406e-01, -2.5000e+00,  ...,  1.4062e-01,
           6.0938e-01, -8.7891e-01],
         [-1.0078e+00,  1.2500e+00, -2.9492e-01,  ..., -2.1680e-01,
          -4.3555e-01,  2.7734e-01],
         ...,
         [-3.6719e-01, -8.8672e-01,  2.0801e-01,  ..., -2.3633e-01,
          -5.4688e-01, -6.9922e-01],
         [ 2.1973e-01, -9.0234e-01, -9.9609e-02,  ..., -1.8672e+00,
          -4.6289e-01,  7.0312e-01],
         [ 1.6328e+00, -1.4453e+00, -4.0820e-01,  ...,  1.0498e-01,
          -1.6797e+00, -1.1562e+00]],

        [[ 6.3324e-04,  3.0518e-03, -2.5635e-03,  ..., -1.9043e-01,
          -1.2695e-01,  1.2598e-01],
         [ 0.0000e+00,  9.1406e-01, -2.5000e+00,  ...,  1.4062e-01,
           6.0938e-01, -8.7891e-01],
         [-1.0078e+00,  1.2500e+00, -2.9492e-01,  ..., -2.1680e-01,
          -4.3555e-01,  2.7734e-01],
         ...,
         [-3.6719e-01, -8.8672e-01,  2.0801e-01,  ..., -2.3633e-01,
          -5.4688e-01, -6.9922e-01],
         [ 2.1973e-01, -9.0234e-01, -9.9609e-02,  ..., -1.8672e+00,
          -4.6289e-01,  7.0312e-01],
         [ 1.6328e+00, -1.4453e+00, -4.0820e-01,  ...,  1.0498e-01,
          -1.6797e+00, -1.1562e+00]],

        [[ 6.3324e-04,  3.0518e-03, -2.5635e-03,  ..., -1.9043e-01,
          -1.2695e-01,  1.2598e-01],
         [ 0.0000e+00,  9.1406e-01, -2.5000e+00,  ...,  1.4062e-01,
           6.0938e-01, -8.7891e-01],
         [-1.0078e+00,  1.2500e+00, -2.9492e-01,  ..., -2.1680e-01,
          -4.3555e-01,  2.7734e-01],
         ...,
         [-3.6719e-01, -8.8672e-01,  2.0801e-01,  ..., -2.3633e-01,
          -5.4688e-01, -6.9922e-01],
         [ 2.1973e-01, -9.0234e-01, -9.9609e-02,  ..., -1.8672e+00,
          -4.6289e-01,  7.0312e-01],
         [ 1.6328e+00, -1.4453e+00, -4.0820e-01,  ...,  1.0498e-01,
          -1.6797e+00, -1.1562e+00]],

        ...,

        [[-1.1978e-03, -1.1841e-02, -8.0566e-03,  ...,  1.3965e-01,
           5.9082e-02,  1.1865e-01],
         [-2.5781e+00, -1.0078e+00,  1.7578e+00,  ..., -2.5156e+00,
           2.7734e-01, -6.7188e-01],
         [-1.5938e+00, -1.5312e+00,  1.4609e+00,  ..., -3.6719e-01,
           1.8125e+00, -1.7188e+00],
         ...,
         [-1.1719e+00,  1.3594e+00,  1.4609e+00,  ...,  7.3047e-01,
          -2.5469e+00,  2.1562e+00],
         [ 8.1250e-01,  6.4844e-01,  1.2031e+00,  ...,  1.8281e+00,
          -3.8672e-01,  2.3594e+00],
         [ 2.1875e+00, -5.3125e-01,  1.3203e+00,  ..., -3.3984e-01,
           1.5332e-01, -8.5449e-03]],

        [[-1.1978e-03, -1.1841e-02, -8.0566e-03,  ...,  1.3965e-01,
           5.9082e-02,  1.1865e-01],
         [-2.5781e+00, -1.0078e+00,  1.7578e+00,  ..., -2.5156e+00,
           2.7734e-01, -6.7188e-01],
         [-1.5938e+00, -1.5312e+00,  1.4609e+00,  ..., -3.6719e-01,
           1.8125e+00, -1.7188e+00],
         ...,
         [-1.1719e+00,  1.3594e+00,  1.4609e+00,  ...,  7.3047e-01,
          -2.5469e+00,  2.1562e+00],
         [ 8.1250e-01,  6.4844e-01,  1.2031e+00,  ...,  1.8281e+00,
          -3.8672e-01,  2.3594e+00],
         [ 2.1875e+00, -5.3125e-01,  1.3203e+00,  ..., -3.3984e-01,
           1.5332e-01, -8.5449e-03]],

        [[-1.1978e-03, -1.1841e-02, -8.0566e-03,  ...,  1.3965e-01,
           5.9082e-02,  1.1865e-01],
         [-2.5781e+00, -1.0078e+00,  1.7578e+00,  ..., -2.5156e+00,
           2.7734e-01, -6.7188e-01],
         [-1.5938e+00, -1.5312e+00,  1.4609e+00,  ..., -3.6719e-01,
           1.8125e+00, -1.7188e+00],
         ...,
         [-1.1719e+00,  1.3594e+00,  1.4609e+00,  ...,  7.3047e-01,
          -2.5469e+00,  2.1562e+00],
         [ 8.1250e-01,  6.4844e-01,  1.2031e+00,  ...,  1.8281e+00,
          -3.8672e-01,  2.3594e+00],
         [ 2.1875e+00, -5.3125e-01,  1.3203e+00,  ..., -3.3984e-01,
           1.5332e-01, -8.5449e-03]]], dtype=torch.bfloat16), tensor([[[ 2.4658e-02, -2.5558e-04, -4.5166e-03,  ..., -9.0942e-03,
           1.7822e-02, -9.8877e-03],
         [-3.9648e-01,  2.2656e-01,  2.7344e-01,  ..., -8.3984e-01,
           4.1602e-01, -3.8086e-01],
         [-1.7383e-01,  4.6094e-01, -1.5859e+00,  ...,  3.0469e-01,
          -1.4375e+00, -8.2812e-01],
         ...,
         [ 8.8281e-01,  2.2344e+00,  2.2656e+00,  ...,  2.0117e-01,
           1.2578e+00,  9.7656e-01],
         [ 1.3047e+00,  1.3203e+00,  2.3730e-01,  ..., -4.3945e-01,
           7.7344e-01, -8.5449e-02],
         [-8.9453e-01, -9.7168e-02, -4.8242e-01,  ...,  3.5156e-01,
           1.3438e+00, -3.8086e-01]],

        [[ 2.4658e-02, -2.5558e-04, -4.5166e-03,  ..., -9.0942e-03,
           1.7822e-02, -9.8877e-03],
         [-3.9648e-01,  2.2656e-01,  2.7344e-01,  ..., -8.3984e-01,
           4.1602e-01, -3.8086e-01],
         [-1.7383e-01,  4.6094e-01, -1.5859e+00,  ...,  3.0469e-01,
          -1.4375e+00, -8.2812e-01],
         ...,
         [ 8.8281e-01,  2.2344e+00,  2.2656e+00,  ...,  2.0117e-01,
           1.2578e+00,  9.7656e-01],
         [ 1.3047e+00,  1.3203e+00,  2.3730e-01,  ..., -4.3945e-01,
           7.7344e-01, -8.5449e-02],
         [-8.9453e-01, -9.7168e-02, -4.8242e-01,  ...,  3.5156e-01,
           1.3438e+00, -3.8086e-01]],

        [[ 2.4658e-02, -2.5558e-04, -4.5166e-03,  ..., -9.0942e-03,
           1.7822e-02, -9.8877e-03],
         [-3.9648e-01,  2.2656e-01,  2.7344e-01,  ..., -8.3984e-01,
           4.1602e-01, -3.8086e-01],
         [-1.7383e-01,  4.6094e-01, -1.5859e+00,  ...,  3.0469e-01,
          -1.4375e+00, -8.2812e-01],
         ...,
         [ 8.8281e-01,  2.2344e+00,  2.2656e+00,  ...,  2.0117e-01,
           1.2578e+00,  9.7656e-01],
         [ 1.3047e+00,  1.3203e+00,  2.3730e-01,  ..., -4.3945e-01,
           7.7344e-01, -8.5449e-02],
         [-8.9453e-01, -9.7168e-02, -4.8242e-01,  ...,  3.5156e-01,
           1.3438e+00, -3.8086e-01]],

        ...,

        [[-1.2085e-02, -7.4463e-03, -1.9836e-03,  ..., -1.0315e-02,
          -3.5858e-03, -1.2756e-02],
         [-3.5742e-01, -1.9336e-01, -2.2656e-01,  ..., -8.3984e-01,
           5.3906e-01, -1.9141e+00],
         [-9.7656e-03, -3.3594e-01,  3.7109e-02,  ...,  8.7109e-01,
          -7.7734e-01, -7.7734e-01],
         ...,
         [ 4.0625e-01,  1.0781e+00, -8.5938e-01,  ..., -9.2578e-01,
           1.7422e+00,  4.0234e-01],
         [ 3.3398e-01, -1.2500e-01, -1.4531e+00,  ...,  3.7305e-01,
           8.7500e-01,  1.2256e-01],
         [-9.5703e-01, -7.2266e-01, -2.3730e-01,  ..., -5.7031e-01,
           3.2031e-01,  2.2070e-01]],

        [[-1.2085e-02, -7.4463e-03, -1.9836e-03,  ..., -1.0315e-02,
          -3.5858e-03, -1.2756e-02],
         [-3.5742e-01, -1.9336e-01, -2.2656e-01,  ..., -8.3984e-01,
           5.3906e-01, -1.9141e+00],
         [-9.7656e-03, -3.3594e-01,  3.7109e-02,  ...,  8.7109e-01,
          -7.7734e-01, -7.7734e-01],
         ...,
         [ 4.0625e-01,  1.0781e+00, -8.5938e-01,  ..., -9.2578e-01,
           1.7422e+00,  4.0234e-01],
         [ 3.3398e-01, -1.2500e-01, -1.4531e+00,  ...,  3.7305e-01,
           8.7500e-01,  1.2256e-01],
         [-9.5703e-01, -7.2266e-01, -2.3730e-01,  ..., -5.7031e-01,
           3.2031e-01,  2.2070e-01]],

        [[-1.2085e-02, -7.4463e-03, -1.9836e-03,  ..., -1.0315e-02,
          -3.5858e-03, -1.2756e-02],
         [-3.5742e-01, -1.9336e-01, -2.2656e-01,  ..., -8.3984e-01,
           5.3906e-01, -1.9141e+00],
         [-9.7656e-03, -3.3594e-01,  3.7109e-02,  ...,  8.7109e-01,
          -7.7734e-01, -7.7734e-01],
         ...,
         [ 4.0625e-01,  1.0781e+00, -8.5938e-01,  ..., -9.2578e-01,
           1.7422e+00,  4.0234e-01],
         [ 3.3398e-01, -1.2500e-01, -1.4531e+00,  ...,  3.7305e-01,
           8.7500e-01,  1.2256e-01],
         [-9.5703e-01, -7.2266e-01, -2.3730e-01,  ..., -5.7031e-01,
           3.2031e-01,  2.2070e-01]]], dtype=torch.bfloat16)), (tensor([[[-3.1433e-03,  4.5776e-03,  1.3306e-02,  ..., -1.6016e-01,
           2.3535e-01,  3.0273e-01],
         [-3.1719e+00, -1.2656e+00, -1.2812e+00,  ...,  2.9883e-01,
          -2.4805e-01, -5.2734e-01],
         [-1.9141e-01, -8.1250e-01, -4.2969e-01,  ..., -3.5889e-02,
          -1.2031e+00, -1.0234e+00],
         ...,
         [ 1.4609e+00,  1.1875e+00, -1.1172e+00,  ...,  9.9609e-01,
          -6.0156e-01, -8.4375e-01],
         [ 1.9062e+00,  4.4727e-01, -1.7188e+00,  ...,  9.2969e-01,
          -5.7812e-01, -2.2266e-01],
         [ 2.6562e-01, -1.5078e+00, -5.0391e-01,  ...,  4.1602e-01,
           1.3281e+00, -1.7344e+00]],

        [[-3.1433e-03,  4.5776e-03,  1.3306e-02,  ..., -1.6016e-01,
           2.3535e-01,  3.0273e-01],
         [-3.1719e+00, -1.2656e+00, -1.2812e+00,  ...,  2.9883e-01,
          -2.4805e-01, -5.2734e-01],
         [-1.9141e-01, -8.1250e-01, -4.2969e-01,  ..., -3.5889e-02,
          -1.2031e+00, -1.0234e+00],
         ...,
         [ 1.4609e+00,  1.1875e+00, -1.1172e+00,  ...,  9.9609e-01,
          -6.0156e-01, -8.4375e-01],
         [ 1.9062e+00,  4.4727e-01, -1.7188e+00,  ...,  9.2969e-01,
          -5.7812e-01, -2.2266e-01],
         [ 2.6562e-01, -1.5078e+00, -5.0391e-01,  ...,  4.1602e-01,
           1.3281e+00, -1.7344e+00]],

        [[-3.1433e-03,  4.5776e-03,  1.3306e-02,  ..., -1.6016e-01,
           2.3535e-01,  3.0273e-01],
         [-3.1719e+00, -1.2656e+00, -1.2812e+00,  ...,  2.9883e-01,
          -2.4805e-01, -5.2734e-01],
         [-1.9141e-01, -8.1250e-01, -4.2969e-01,  ..., -3.5889e-02,
          -1.2031e+00, -1.0234e+00],
         ...,
         [ 1.4609e+00,  1.1875e+00, -1.1172e+00,  ...,  9.9609e-01,
          -6.0156e-01, -8.4375e-01],
         [ 1.9062e+00,  4.4727e-01, -1.7188e+00,  ...,  9.2969e-01,
          -5.7812e-01, -2.2266e-01],
         [ 2.6562e-01, -1.5078e+00, -5.0391e-01,  ...,  4.1602e-01,
           1.3281e+00, -1.7344e+00]],

        ...,

        [[ 6.1798e-04,  9.0027e-04, -2.0996e-02,  ...,  1.0681e-03,
          -4.8828e-02, -1.9336e-01],
         [-3.7031e+00, -1.8047e+00, -6.6406e-02,  ...,  2.6562e-01,
           1.8594e+00,  1.1914e-01],
         [-7.6953e-01, -1.3828e+00,  2.4414e-02,  ...,  5.9375e-01,
          -1.5391e+00, -1.1484e+00],
         ...,
         [ 1.7500e+00,  5.2344e-01, -1.0391e+00,  ...,  1.0312e+00,
           8.5547e-01, -4.9805e-01],
         [ 9.5312e-01,  3.4375e-01,  1.3281e-01,  ...,  3.4961e-01,
          -2.2344e+00, -6.3672e-01],
         [ 8.1641e-01, -8.9844e-01,  4.1016e-01,  ...,  1.2969e+00,
           3.2031e-01, -1.2734e+00]],

        [[ 6.1798e-04,  9.0027e-04, -2.0996e-02,  ...,  1.0681e-03,
          -4.8828e-02, -1.9336e-01],
         [-3.7031e+00, -1.8047e+00, -6.6406e-02,  ...,  2.6562e-01,
           1.8594e+00,  1.1914e-01],
         [-7.6953e-01, -1.3828e+00,  2.4414e-02,  ...,  5.9375e-01,
          -1.5391e+00, -1.1484e+00],
         ...,
         [ 1.7500e+00,  5.2344e-01, -1.0391e+00,  ...,  1.0312e+00,
           8.5547e-01, -4.9805e-01],
         [ 9.5312e-01,  3.4375e-01,  1.3281e-01,  ...,  3.4961e-01,
          -2.2344e+00, -6.3672e-01],
         [ 8.1641e-01, -8.9844e-01,  4.1016e-01,  ...,  1.2969e+00,
           3.2031e-01, -1.2734e+00]],

        [[ 6.1798e-04,  9.0027e-04, -2.0996e-02,  ...,  1.0681e-03,
          -4.8828e-02, -1.9336e-01],
         [-3.7031e+00, -1.8047e+00, -6.6406e-02,  ...,  2.6562e-01,
           1.8594e+00,  1.1914e-01],
         [-7.6953e-01, -1.3828e+00,  2.4414e-02,  ...,  5.9375e-01,
          -1.5391e+00, -1.1484e+00],
         ...,
         [ 1.7500e+00,  5.2344e-01, -1.0391e+00,  ...,  1.0312e+00,
           8.5547e-01, -4.9805e-01],
         [ 9.5312e-01,  3.4375e-01,  1.3281e-01,  ...,  3.4961e-01,
          -2.2344e+00, -6.3672e-01],
         [ 8.1641e-01, -8.9844e-01,  4.1016e-01,  ...,  1.2969e+00,
           3.2031e-01, -1.2734e+00]]], dtype=torch.bfloat16), tensor([[[ 1.0925e-02, -2.8229e-03, -1.0071e-02,  ..., -1.0803e-02,
           2.3193e-02, -4.3030e-03],
         [-3.6328e-01,  2.4512e-01, -1.2031e+00,  ..., -1.0498e-02,
           2.5781e-01, -9.3359e-01],
         [-7.7637e-02,  4.9219e-01, -2.6367e-01,  ...,  1.1016e+00,
          -1.1562e+00, -8.4473e-02],
         ...,
         [-9.2578e-01,  8.5938e-01, -5.4297e-01,  ...,  4.4727e-01,
          -5.5469e-01, -4.5312e-01],
         [-2.0996e-02, -5.0391e-01, -8.0566e-02,  ...,  7.0312e-01,
          -1.3906e+00,  1.2812e+00],
         [ 2.0020e-01,  5.0781e-01, -1.0547e+00,  ...,  5.4688e-01,
          -9.6484e-01,  7.9688e-01]],

        [[ 1.0925e-02, -2.8229e-03, -1.0071e-02,  ..., -1.0803e-02,
           2.3193e-02, -4.3030e-03],
         [-3.6328e-01,  2.4512e-01, -1.2031e+00,  ..., -1.0498e-02,
           2.5781e-01, -9.3359e-01],
         [-7.7637e-02,  4.9219e-01, -2.6367e-01,  ...,  1.1016e+00,
          -1.1562e+00, -8.4473e-02],
         ...,
         [-9.2578e-01,  8.5938e-01, -5.4297e-01,  ...,  4.4727e-01,
          -5.5469e-01, -4.5312e-01],
         [-2.0996e-02, -5.0391e-01, -8.0566e-02,  ...,  7.0312e-01,
          -1.3906e+00,  1.2812e+00],
         [ 2.0020e-01,  5.0781e-01, -1.0547e+00,  ...,  5.4688e-01,
          -9.6484e-01,  7.9688e-01]],

        [[ 1.0925e-02, -2.8229e-03, -1.0071e-02,  ..., -1.0803e-02,
           2.3193e-02, -4.3030e-03],
         [-3.6328e-01,  2.4512e-01, -1.2031e+00,  ..., -1.0498e-02,
           2.5781e-01, -9.3359e-01],
         [-7.7637e-02,  4.9219e-01, -2.6367e-01,  ...,  1.1016e+00,
          -1.1562e+00, -8.4473e-02],
         ...,
         [-9.2578e-01,  8.5938e-01, -5.4297e-01,  ...,  4.4727e-01,
          -5.5469e-01, -4.5312e-01],
         [-2.0996e-02, -5.0391e-01, -8.0566e-02,  ...,  7.0312e-01,
          -1.3906e+00,  1.2812e+00],
         [ 2.0020e-01,  5.0781e-01, -1.0547e+00,  ...,  5.4688e-01,
          -9.6484e-01,  7.9688e-01]],

        ...,

        [[-6.8283e-04, -3.7079e-03,  7.7820e-04,  ...,  2.1118e-02,
           6.7444e-03,  5.3406e-03],
         [-2.0508e-01, -2.8516e-01,  4.9609e-01,  ...,  1.5938e+00,
           1.1719e+00,  5.5078e-01],
         [-2.8125e-01, -1.6992e-01, -5.9766e-01,  ...,  4.9414e-01,
           1.4160e-02,  4.6289e-01],
         ...,
         [ 1.8164e-01, -1.5312e+00, -2.0117e-01,  ...,  1.2188e+00,
          -2.3594e+00,  3.1055e-01],
         [ 7.1777e-02,  7.4609e-01,  1.0156e+00,  ..., -2.9297e-01,
          -8.0859e-01,  1.7031e+00],
         [-5.6250e-01,  8.3496e-02,  2.1973e-01,  ..., -5.0391e-01,
           2.4121e-01, -2.0142e-02]],

        [[-6.8283e-04, -3.7079e-03,  7.7820e-04,  ...,  2.1118e-02,
           6.7444e-03,  5.3406e-03],
         [-2.0508e-01, -2.8516e-01,  4.9609e-01,  ...,  1.5938e+00,
           1.1719e+00,  5.5078e-01],
         [-2.8125e-01, -1.6992e-01, -5.9766e-01,  ...,  4.9414e-01,
           1.4160e-02,  4.6289e-01],
         ...,
         [ 1.8164e-01, -1.5312e+00, -2.0117e-01,  ...,  1.2188e+00,
          -2.3594e+00,  3.1055e-01],
         [ 7.1777e-02,  7.4609e-01,  1.0156e+00,  ..., -2.9297e-01,
          -8.0859e-01,  1.7031e+00],
         [-5.6250e-01,  8.3496e-02,  2.1973e-01,  ..., -5.0391e-01,
           2.4121e-01, -2.0142e-02]],

        [[-6.8283e-04, -3.7079e-03,  7.7820e-04,  ...,  2.1118e-02,
           6.7444e-03,  5.3406e-03],
         [-2.0508e-01, -2.8516e-01,  4.9609e-01,  ...,  1.5938e+00,
           1.1719e+00,  5.5078e-01],
         [-2.8125e-01, -1.6992e-01, -5.9766e-01,  ...,  4.9414e-01,
           1.4160e-02,  4.6289e-01],
         ...,
         [ 1.8164e-01, -1.5312e+00, -2.0117e-01,  ...,  1.2188e+00,
          -2.3594e+00,  3.1055e-01],
         [ 7.1777e-02,  7.4609e-01,  1.0156e+00,  ..., -2.9297e-01,
          -8.0859e-01,  1.7031e+00],
         [-5.6250e-01,  8.3496e-02,  2.1973e-01,  ..., -5.0391e-01,
           2.4121e-01, -2.0142e-02]]], dtype=torch.bfloat16)), (tensor([[[ 1.0437e-02, -6.3782e-03, -7.9346e-03,  ..., -1.8750e+00,
           8.4473e-02,  7.5781e-01],
         [-4.1016e-01,  6.1719e-01,  5.3906e-01,  ...,  1.1016e+00,
           1.3281e-01, -5.6562e+00],
         [ 4.3945e-01, -1.9336e-01,  2.6562e-01,  ...,  4.3438e+00,
          -3.1250e+00, -2.8594e+00],
         ...,
         [-1.1816e-01, -6.2109e-01, -6.2891e-01,  ...,  3.6562e+00,
          -1.0469e+00, -5.0391e-01],
         [-3.8281e-01,  2.4609e-01, -2.2461e-02,  ...,  4.1562e+00,
           1.1914e-01, -2.3281e+00],
         [ 4.9805e-01, -4.1016e-01,  5.5078e-01,  ...,  3.8125e+00,
           3.5352e-01, -2.0781e+00]],

        [[ 1.0437e-02, -6.3782e-03, -7.9346e-03,  ..., -1.8750e+00,
           8.4473e-02,  7.5781e-01],
         [-4.1016e-01,  6.1719e-01,  5.3906e-01,  ...,  1.1016e+00,
           1.3281e-01, -5.6562e+00],
         [ 4.3945e-01, -1.9336e-01,  2.6562e-01,  ...,  4.3438e+00,
          -3.1250e+00, -2.8594e+00],
         ...,
         [-1.1816e-01, -6.2109e-01, -6.2891e-01,  ...,  3.6562e+00,
          -1.0469e+00, -5.0391e-01],
         [-3.8281e-01,  2.4609e-01, -2.2461e-02,  ...,  4.1562e+00,
           1.1914e-01, -2.3281e+00],
         [ 4.9805e-01, -4.1016e-01,  5.5078e-01,  ...,  3.8125e+00,
           3.5352e-01, -2.0781e+00]],

        [[ 1.0437e-02, -6.3782e-03, -7.9346e-03,  ..., -1.8750e+00,
           8.4473e-02,  7.5781e-01],
         [-4.1016e-01,  6.1719e-01,  5.3906e-01,  ...,  1.1016e+00,
           1.3281e-01, -5.6562e+00],
         [ 4.3945e-01, -1.9336e-01,  2.6562e-01,  ...,  4.3438e+00,
          -3.1250e+00, -2.8594e+00],
         ...,
         [-1.1816e-01, -6.2109e-01, -6.2891e-01,  ...,  3.6562e+00,
          -1.0469e+00, -5.0391e-01],
         [-3.8281e-01,  2.4609e-01, -2.2461e-02,  ...,  4.1562e+00,
           1.1914e-01, -2.3281e+00],
         [ 4.9805e-01, -4.1016e-01,  5.5078e-01,  ...,  3.8125e+00,
           3.5352e-01, -2.0781e+00]],

        ...,

        [[-6.0730e-03, -9.9487e-03, -2.1057e-03,  ...,  1.2158e-01,
           2.1484e-01,  6.1646e-03],
         [-1.2656e+00, -9.0625e-01,  1.5625e+00,  ...,  1.6641e+00,
           8.0469e-01, -7.3438e-01],
         [-1.6211e-01,  7.6172e-02,  1.0312e+00,  ...,  5.5625e+00,
           3.5312e+00,  3.0664e-01],
         ...,
         [ 1.1670e-01,  7.7344e-01,  5.9375e-01,  ...,  3.0312e+00,
           3.0000e+00,  2.7969e+00],
         [-1.6211e-01,  8.7891e-01,  1.2109e+00,  ...,  4.0625e+00,
           4.0312e+00,  7.2188e+00],
         [ 1.2734e+00,  9.9219e-01,  6.0156e-01,  ...,  4.4062e+00,
           2.1875e+00,  2.8594e+00]],

        [[-6.0730e-03, -9.9487e-03, -2.1057e-03,  ...,  1.2158e-01,
           2.1484e-01,  6.1646e-03],
         [-1.2656e+00, -9.0625e-01,  1.5625e+00,  ...,  1.6641e+00,
           8.0469e-01, -7.3438e-01],
         [-1.6211e-01,  7.6172e-02,  1.0312e+00,  ...,  5.5625e+00,
           3.5312e+00,  3.0664e-01],
         ...,
         [ 1.1670e-01,  7.7344e-01,  5.9375e-01,  ...,  3.0312e+00,
           3.0000e+00,  2.7969e+00],
         [-1.6211e-01,  8.7891e-01,  1.2109e+00,  ...,  4.0625e+00,
           4.0312e+00,  7.2188e+00],
         [ 1.2734e+00,  9.9219e-01,  6.0156e-01,  ...,  4.4062e+00,
           2.1875e+00,  2.8594e+00]],

        [[-6.0730e-03, -9.9487e-03, -2.1057e-03,  ...,  1.2158e-01,
           2.1484e-01,  6.1646e-03],
         [-1.2656e+00, -9.0625e-01,  1.5625e+00,  ...,  1.6641e+00,
           8.0469e-01, -7.3438e-01],
         [-1.6211e-01,  7.6172e-02,  1.0312e+00,  ...,  5.5625e+00,
           3.5312e+00,  3.0664e-01],
         ...,
         [ 1.1670e-01,  7.7344e-01,  5.9375e-01,  ...,  3.0312e+00,
           3.0000e+00,  2.7969e+00],
         [-1.6211e-01,  8.7891e-01,  1.2109e+00,  ...,  4.0625e+00,
           4.0312e+00,  7.2188e+00],
         [ 1.2734e+00,  9.9219e-01,  6.0156e-01,  ...,  4.4062e+00,
           2.1875e+00,  2.8594e+00]]], dtype=torch.bfloat16), tensor([[[-1.1658e-02,  3.4332e-03,  1.4099e-02,  ...,  2.7954e-02,
           8.4229e-03, -1.0010e-02],
         [-8.8672e-01,  2.0156e+00, -1.7500e+00,  ..., -9.3750e-01,
           2.1719e+00,  2.9531e+00],
         [ 1.2969e+00, -3.3750e+00,  9.1406e-01,  ...,  8.5156e-01,
           4.4727e-01,  2.2031e+00],
         ...,
         [ 7.4609e-01,  1.5469e+00, -8.4375e-01,  ..., -7.7344e-01,
           7.9297e-01, -1.8433e-02],
         [ 4.6484e-01, -8.1250e-01, -7.1289e-02,  ...,  4.9219e-01,
          -4.0039e-01,  2.4316e-01],
         [ 2.3438e+00, -8.5156e-01,  4.4727e-01,  ..., -3.8867e-01,
          -8.7891e-01, -5.0293e-02]],

        [[-1.1658e-02,  3.4332e-03,  1.4099e-02,  ...,  2.7954e-02,
           8.4229e-03, -1.0010e-02],
         [-8.8672e-01,  2.0156e+00, -1.7500e+00,  ..., -9.3750e-01,
           2.1719e+00,  2.9531e+00],
         [ 1.2969e+00, -3.3750e+00,  9.1406e-01,  ...,  8.5156e-01,
           4.4727e-01,  2.2031e+00],
         ...,
         [ 7.4609e-01,  1.5469e+00, -8.4375e-01,  ..., -7.7344e-01,
           7.9297e-01, -1.8433e-02],
         [ 4.6484e-01, -8.1250e-01, -7.1289e-02,  ...,  4.9219e-01,
          -4.0039e-01,  2.4316e-01],
         [ 2.3438e+00, -8.5156e-01,  4.4727e-01,  ..., -3.8867e-01,
          -8.7891e-01, -5.0293e-02]],

        [[-1.1658e-02,  3.4332e-03,  1.4099e-02,  ...,  2.7954e-02,
           8.4229e-03, -1.0010e-02],
         [-8.8672e-01,  2.0156e+00, -1.7500e+00,  ..., -9.3750e-01,
           2.1719e+00,  2.9531e+00],
         [ 1.2969e+00, -3.3750e+00,  9.1406e-01,  ...,  8.5156e-01,
           4.4727e-01,  2.2031e+00],
         ...,
         [ 7.4609e-01,  1.5469e+00, -8.4375e-01,  ..., -7.7344e-01,
           7.9297e-01, -1.8433e-02],
         [ 4.6484e-01, -8.1250e-01, -7.1289e-02,  ...,  4.9219e-01,
          -4.0039e-01,  2.4316e-01],
         [ 2.3438e+00, -8.5156e-01,  4.4727e-01,  ..., -3.8867e-01,
          -8.7891e-01, -5.0293e-02]],

        ...,

        [[ 2.0599e-03,  1.2146e-02, -2.2095e-02,  ...,  1.0620e-02,
          -2.1057e-03,  3.3264e-03],
         [ 2.0000e+00,  6.7578e-01,  7.5391e-01,  ...,  1.5469e+00,
           4.0527e-02, -1.3438e+00],
         [ 4.1797e-01,  1.6406e+00,  1.8672e+00,  ...,  1.5234e+00,
          -2.3281e+00, -2.5625e+00],
         ...,
         [ 4.7500e+00,  3.3594e+00,  2.2188e+00,  ...,  1.7656e+00,
          -4.9688e+00, -2.3281e+00],
         [ 6.2812e+00,  3.4844e+00,  3.2188e+00,  ...,  4.0000e+00,
          -5.7500e+00, -4.1250e+00],
         [ 4.0000e+00,  1.9141e+00,  1.7109e+00,  ...,  3.0000e+00,
          -4.2500e+00, -2.1094e+00]],

        [[ 2.0599e-03,  1.2146e-02, -2.2095e-02,  ...,  1.0620e-02,
          -2.1057e-03,  3.3264e-03],
         [ 2.0000e+00,  6.7578e-01,  7.5391e-01,  ...,  1.5469e+00,
           4.0527e-02, -1.3438e+00],
         [ 4.1797e-01,  1.6406e+00,  1.8672e+00,  ...,  1.5234e+00,
          -2.3281e+00, -2.5625e+00],
         ...,
         [ 4.7500e+00,  3.3594e+00,  2.2188e+00,  ...,  1.7656e+00,
          -4.9688e+00, -2.3281e+00],
         [ 6.2812e+00,  3.4844e+00,  3.2188e+00,  ...,  4.0000e+00,
          -5.7500e+00, -4.1250e+00],
         [ 4.0000e+00,  1.9141e+00,  1.7109e+00,  ...,  3.0000e+00,
          -4.2500e+00, -2.1094e+00]],

        [[ 2.0599e-03,  1.2146e-02, -2.2095e-02,  ...,  1.0620e-02,
          -2.1057e-03,  3.3264e-03],
         [ 2.0000e+00,  6.7578e-01,  7.5391e-01,  ...,  1.5469e+00,
           4.0527e-02, -1.3438e+00],
         [ 4.1797e-01,  1.6406e+00,  1.8672e+00,  ...,  1.5234e+00,
          -2.3281e+00, -2.5625e+00],
         ...,
         [ 4.7500e+00,  3.3594e+00,  2.2188e+00,  ...,  1.7656e+00,
          -4.9688e+00, -2.3281e+00],
         [ 6.2812e+00,  3.4844e+00,  3.2188e+00,  ...,  4.0000e+00,
          -5.7500e+00, -4.1250e+00],
         [ 4.0000e+00,  1.9141e+00,  1.7109e+00,  ...,  3.0000e+00,
          -4.2500e+00, -2.1094e+00]]], dtype=torch.bfloat16)), (tensor([[[ 8.5831e-04, -4.5166e-03,  2.7008e-03,  ...,  4.5166e-02,
           3.1055e-01, -8.4961e-02],
         [ 3.1250e+00,  2.2656e+00, -6.4062e-01,  ...,  6.3281e-01,
          -5.9375e-01,  1.2500e+00],
         [ 1.9375e+00,  1.9688e+00,  1.0625e+00,  ...,  5.1953e-01,
           5.8105e-02,  9.5312e-01],
         ...,
         [-2.2363e-01, -1.3984e+00, -1.7578e-01,  ...,  8.3984e-02,
           8.5547e-01,  1.1182e-01],
         [-9.4922e-01,  4.0430e-01, -1.0547e+00,  ...,  2.3340e-01,
          -5.6396e-02, -5.5859e-01],
         [-2.4375e+00,  7.4219e-02, -1.4844e+00,  ...,  9.2188e-01,
          -8.8281e-01, -3.1055e-01]],

        [[ 8.5831e-04, -4.5166e-03,  2.7008e-03,  ...,  4.5166e-02,
           3.1055e-01, -8.4961e-02],
         [ 3.1250e+00,  2.2656e+00, -6.4062e-01,  ...,  6.3281e-01,
          -5.9375e-01,  1.2500e+00],
         [ 1.9375e+00,  1.9688e+00,  1.0625e+00,  ...,  5.1953e-01,
           5.8105e-02,  9.5312e-01],
         ...,
         [-2.2363e-01, -1.3984e+00, -1.7578e-01,  ...,  8.3984e-02,
           8.5547e-01,  1.1182e-01],
         [-9.4922e-01,  4.0430e-01, -1.0547e+00,  ...,  2.3340e-01,
          -5.6396e-02, -5.5859e-01],
         [-2.4375e+00,  7.4219e-02, -1.4844e+00,  ...,  9.2188e-01,
          -8.8281e-01, -3.1055e-01]],

        [[ 8.5831e-04, -4.5166e-03,  2.7008e-03,  ...,  4.5166e-02,
           3.1055e-01, -8.4961e-02],
         [ 3.1250e+00,  2.2656e+00, -6.4062e-01,  ...,  6.3281e-01,
          -5.9375e-01,  1.2500e+00],
         [ 1.9375e+00,  1.9688e+00,  1.0625e+00,  ...,  5.1953e-01,
           5.8105e-02,  9.5312e-01],
         ...,
         [-2.2363e-01, -1.3984e+00, -1.7578e-01,  ...,  8.3984e-02,
           8.5547e-01,  1.1182e-01],
         [-9.4922e-01,  4.0430e-01, -1.0547e+00,  ...,  2.3340e-01,
          -5.6396e-02, -5.5859e-01],
         [-2.4375e+00,  7.4219e-02, -1.4844e+00,  ...,  9.2188e-01,
          -8.8281e-01, -3.1055e-01]],

        ...,

        [[ 3.2959e-03, -2.9755e-03, -6.4392e-03,  ...,  2.6562e-01,
           4.3945e-01,  1.4551e-01],
         [-3.2031e+00,  1.4766e+00, -1.5938e+00,  ..., -3.1641e-01,
          -1.6484e+00, -4.7852e-01],
         [ 2.7344e-02,  1.8262e-01, -1.2031e+00,  ..., -1.3359e+00,
          -2.0781e+00,  1.5938e+00],
         ...,
         [ 1.8281e+00, -2.1562e+00, -1.8203e+00,  ..., -9.1797e-01,
          -1.2500e+00,  4.3750e-01],
         [ 9.6094e-01,  2.9297e-01, -8.2422e-01,  ...,  7.1484e-01,
          -2.2812e+00,  1.5859e+00],
         [ 3.1250e-01,  1.4219e+00, -1.6250e+00,  ..., -1.0547e+00,
          -2.0156e+00,  2.5391e-01]],

        [[ 3.2959e-03, -2.9755e-03, -6.4392e-03,  ...,  2.6562e-01,
           4.3945e-01,  1.4551e-01],
         [-3.2031e+00,  1.4766e+00, -1.5938e+00,  ..., -3.1641e-01,
          -1.6484e+00, -4.7852e-01],
         [ 2.7344e-02,  1.8262e-01, -1.2031e+00,  ..., -1.3359e+00,
          -2.0781e+00,  1.5938e+00],
         ...,
         [ 1.8281e+00, -2.1562e+00, -1.8203e+00,  ..., -9.1797e-01,
          -1.2500e+00,  4.3750e-01],
         [ 9.6094e-01,  2.9297e-01, -8.2422e-01,  ...,  7.1484e-01,
          -2.2812e+00,  1.5859e+00],
         [ 3.1250e-01,  1.4219e+00, -1.6250e+00,  ..., -1.0547e+00,
          -2.0156e+00,  2.5391e-01]],

        [[ 3.2959e-03, -2.9755e-03, -6.4392e-03,  ...,  2.6562e-01,
           4.3945e-01,  1.4551e-01],
         [-3.2031e+00,  1.4766e+00, -1.5938e+00,  ..., -3.1641e-01,
          -1.6484e+00, -4.7852e-01],
         [ 2.7344e-02,  1.8262e-01, -1.2031e+00,  ..., -1.3359e+00,
          -2.0781e+00,  1.5938e+00],
         ...,
         [ 1.8281e+00, -2.1562e+00, -1.8203e+00,  ..., -9.1797e-01,
          -1.2500e+00,  4.3750e-01],
         [ 9.6094e-01,  2.9297e-01, -8.2422e-01,  ...,  7.1484e-01,
          -2.2812e+00,  1.5859e+00],
         [ 3.1250e-01,  1.4219e+00, -1.6250e+00,  ..., -1.0547e+00,
          -2.0156e+00,  2.5391e-01]]], dtype=torch.bfloat16), tensor([[[-2.0264e-02,  7.4005e-04, -1.5869e-02,  ..., -1.4893e-02,
           4.4250e-03, -2.7771e-03],
         [ 5.3125e-01,  7.6172e-01,  1.8203e+00,  ..., -3.2812e-01,
          -6.6016e-01, -5.2734e-01],
         [ 1.6328e+00,  1.3438e+00,  4.5898e-01,  ...,  7.7344e-01,
          -7.0703e-01,  9.9219e-01],
         ...,
         [ 3.2812e-01, -8.9062e-01,  1.7031e+00,  ...,  2.9102e-01,
           1.0078e+00,  5.1953e-01],
         [-2.7148e-01, -9.5703e-01, -7.4219e-01,  ..., -3.8281e-01,
           6.9531e-01, -4.5410e-02],
         [ 6.9141e-01,  8.4766e-01, -4.0039e-01,  ..., -1.0000e+00,
           7.8906e-01, -1.0781e+00]],

        [[-2.0264e-02,  7.4005e-04, -1.5869e-02,  ..., -1.4893e-02,
           4.4250e-03, -2.7771e-03],
         [ 5.3125e-01,  7.6172e-01,  1.8203e+00,  ..., -3.2812e-01,
          -6.6016e-01, -5.2734e-01],
         [ 1.6328e+00,  1.3438e+00,  4.5898e-01,  ...,  7.7344e-01,
          -7.0703e-01,  9.9219e-01],
         ...,
         [ 3.2812e-01, -8.9062e-01,  1.7031e+00,  ...,  2.9102e-01,
           1.0078e+00,  5.1953e-01],
         [-2.7148e-01, -9.5703e-01, -7.4219e-01,  ..., -3.8281e-01,
           6.9531e-01, -4.5410e-02],
         [ 6.9141e-01,  8.4766e-01, -4.0039e-01,  ..., -1.0000e+00,
           7.8906e-01, -1.0781e+00]],

        [[-2.0264e-02,  7.4005e-04, -1.5869e-02,  ..., -1.4893e-02,
           4.4250e-03, -2.7771e-03],
         [ 5.3125e-01,  7.6172e-01,  1.8203e+00,  ..., -3.2812e-01,
          -6.6016e-01, -5.2734e-01],
         [ 1.6328e+00,  1.3438e+00,  4.5898e-01,  ...,  7.7344e-01,
          -7.0703e-01,  9.9219e-01],
         ...,
         [ 3.2812e-01, -8.9062e-01,  1.7031e+00,  ...,  2.9102e-01,
           1.0078e+00,  5.1953e-01],
         [-2.7148e-01, -9.5703e-01, -7.4219e-01,  ..., -3.8281e-01,
           6.9531e-01, -4.5410e-02],
         [ 6.9141e-01,  8.4766e-01, -4.0039e-01,  ..., -1.0000e+00,
           7.8906e-01, -1.0781e+00]],

        ...,

        [[ 6.8665e-04, -4.1809e-03, -1.6968e-02,  ...,  1.6235e-02,
           5.7068e-03,  1.3062e-02],
         [ 2.1191e-01, -4.6289e-01,  7.0703e-01,  ..., -3.9648e-01,
          -3.3984e-01,  4.9805e-01],
         [-1.2793e-01, -1.2500e-01,  1.2734e+00,  ...,  2.5391e-01,
           2.8906e-01, -1.5469e+00],
         ...,
         [-1.2734e+00,  1.2734e+00,  1.4453e+00,  ...,  1.0938e+00,
          -9.8047e-01, -9.0625e-01],
         [ 4.0430e-01, -1.2891e-01,  1.4141e+00,  ..., -4.1562e+00,
           3.5938e-01, -1.0703e+00],
         [ 8.5938e-01,  7.6562e-01,  4.2969e-01,  ...,  5.6641e-02,
          -6.9141e-01, -9.3750e-01]],

        [[ 6.8665e-04, -4.1809e-03, -1.6968e-02,  ...,  1.6235e-02,
           5.7068e-03,  1.3062e-02],
         [ 2.1191e-01, -4.6289e-01,  7.0703e-01,  ..., -3.9648e-01,
          -3.3984e-01,  4.9805e-01],
         [-1.2793e-01, -1.2500e-01,  1.2734e+00,  ...,  2.5391e-01,
           2.8906e-01, -1.5469e+00],
         ...,
         [-1.2734e+00,  1.2734e+00,  1.4453e+00,  ...,  1.0938e+00,
          -9.8047e-01, -9.0625e-01],
         [ 4.0430e-01, -1.2891e-01,  1.4141e+00,  ..., -4.1562e+00,
           3.5938e-01, -1.0703e+00],
         [ 8.5938e-01,  7.6562e-01,  4.2969e-01,  ...,  5.6641e-02,
          -6.9141e-01, -9.3750e-01]],

        [[ 6.8665e-04, -4.1809e-03, -1.6968e-02,  ...,  1.6235e-02,
           5.7068e-03,  1.3062e-02],
         [ 2.1191e-01, -4.6289e-01,  7.0703e-01,  ..., -3.9648e-01,
          -3.3984e-01,  4.9805e-01],
         [-1.2793e-01, -1.2500e-01,  1.2734e+00,  ...,  2.5391e-01,
           2.8906e-01, -1.5469e+00],
         ...,
         [-1.2734e+00,  1.2734e+00,  1.4453e+00,  ...,  1.0938e+00,
          -9.8047e-01, -9.0625e-01],
         [ 4.0430e-01, -1.2891e-01,  1.4141e+00,  ..., -4.1562e+00,
           3.5938e-01, -1.0703e+00],
         [ 8.5938e-01,  7.6562e-01,  4.2969e-01,  ...,  5.6641e-02,
          -6.9141e-01, -9.3750e-01]]], dtype=torch.bfloat16)), (tensor([[[ 6.1340e-03, -1.6724e-02,  5.8365e-04,  ...,  3.6328e-01,
          -4.5898e-02, -1.3086e-01],
         [ 1.1641e+00,  6.7969e-01, -1.1875e+00,  ..., -9.2188e-01,
          -5.7373e-02,  9.6094e-01],
         [ 1.6875e+00,  1.3438e+00, -6.3281e-01,  ..., -1.5078e+00,
          -4.3164e-01,  1.9375e+00],
         ...,
         [ 1.7031e+00, -1.3438e+00, -4.6094e-01,  ..., -4.1016e-01,
          -1.0938e+00,  1.2344e+00],
         [ 5.5469e-01, -5.4297e-01, -1.8438e+00,  ..., -9.7656e-01,
          -8.5938e-02,  2.0938e+00],
         [-2.5000e+00,  1.5156e+00, -3.6133e-01,  ..., -8.9062e-01,
          -3.5742e-01,  1.1250e+00]],

        [[ 6.1340e-03, -1.6724e-02,  5.8365e-04,  ...,  3.6328e-01,
          -4.5898e-02, -1.3086e-01],
         [ 1.1641e+00,  6.7969e-01, -1.1875e+00,  ..., -9.2188e-01,
          -5.7373e-02,  9.6094e-01],
         [ 1.6875e+00,  1.3438e+00, -6.3281e-01,  ..., -1.5078e+00,
          -4.3164e-01,  1.9375e+00],
         ...,
         [ 1.7031e+00, -1.3438e+00, -4.6094e-01,  ..., -4.1016e-01,
          -1.0938e+00,  1.2344e+00],
         [ 5.5469e-01, -5.4297e-01, -1.8438e+00,  ..., -9.7656e-01,
          -8.5938e-02,  2.0938e+00],
         [-2.5000e+00,  1.5156e+00, -3.6133e-01,  ..., -8.9062e-01,
          -3.5742e-01,  1.1250e+00]],

        [[ 6.1340e-03, -1.6724e-02,  5.8365e-04,  ...,  3.6328e-01,
          -4.5898e-02, -1.3086e-01],
         [ 1.1641e+00,  6.7969e-01, -1.1875e+00,  ..., -9.2188e-01,
          -5.7373e-02,  9.6094e-01],
         [ 1.6875e+00,  1.3438e+00, -6.3281e-01,  ..., -1.5078e+00,
          -4.3164e-01,  1.9375e+00],
         ...,
         [ 1.7031e+00, -1.3438e+00, -4.6094e-01,  ..., -4.1016e-01,
          -1.0938e+00,  1.2344e+00],
         [ 5.5469e-01, -5.4297e-01, -1.8438e+00,  ..., -9.7656e-01,
          -8.5938e-02,  2.0938e+00],
         [-2.5000e+00,  1.5156e+00, -3.6133e-01,  ..., -8.9062e-01,
          -3.5742e-01,  1.1250e+00]],

        ...,

        [[ 1.5076e-02,  1.1780e-02,  4.5471e-03,  ..., -1.2012e-01,
          -1.2024e-02, -1.5391e+00],
         [ 3.3125e+00, -1.0156e-01, -9.9219e-01,  ..., -3.9062e-01,
           5.7812e-01,  6.6250e+00],
         [ 2.0781e+00,  1.8438e+00, -9.1406e-01,  ..., -1.8359e+00,
           4.6875e-02,  8.0625e+00],
         ...,
         [ 4.1406e-01, -7.4219e-01, -2.2461e-01,  ...,  7.3438e-01,
          -1.3984e+00,  7.0938e+00],
         [-1.5859e+00, -2.0781e+00,  7.0312e-02,  ...,  1.1875e+00,
          -3.4961e-01,  7.5938e+00],
         [-2.8281e+00, -1.8594e+00, -4.6094e-01,  ...,  9.6094e-01,
           1.1172e+00,  7.6562e+00]],

        [[ 1.5076e-02,  1.1780e-02,  4.5471e-03,  ..., -1.2012e-01,
          -1.2024e-02, -1.5391e+00],
         [ 3.3125e+00, -1.0156e-01, -9.9219e-01,  ..., -3.9062e-01,
           5.7812e-01,  6.6250e+00],
         [ 2.0781e+00,  1.8438e+00, -9.1406e-01,  ..., -1.8359e+00,
           4.6875e-02,  8.0625e+00],
         ...,
         [ 4.1406e-01, -7.4219e-01, -2.2461e-01,  ...,  7.3438e-01,
          -1.3984e+00,  7.0938e+00],
         [-1.5859e+00, -2.0781e+00,  7.0312e-02,  ...,  1.1875e+00,
          -3.4961e-01,  7.5938e+00],
         [-2.8281e+00, -1.8594e+00, -4.6094e-01,  ...,  9.6094e-01,
           1.1172e+00,  7.6562e+00]],

        [[ 1.5076e-02,  1.1780e-02,  4.5471e-03,  ..., -1.2012e-01,
          -1.2024e-02, -1.5391e+00],
         [ 3.3125e+00, -1.0156e-01, -9.9219e-01,  ..., -3.9062e-01,
           5.7812e-01,  6.6250e+00],
         [ 2.0781e+00,  1.8438e+00, -9.1406e-01,  ..., -1.8359e+00,
           4.6875e-02,  8.0625e+00],
         ...,
         [ 4.1406e-01, -7.4219e-01, -2.2461e-01,  ...,  7.3438e-01,
          -1.3984e+00,  7.0938e+00],
         [-1.5859e+00, -2.0781e+00,  7.0312e-02,  ...,  1.1875e+00,
          -3.4961e-01,  7.5938e+00],
         [-2.8281e+00, -1.8594e+00, -4.6094e-01,  ...,  9.6094e-01,
           1.1172e+00,  7.6562e+00]]], dtype=torch.bfloat16), tensor([[[ 4.6387e-03,  8.9722e-03, -1.2268e-02,  ...,  2.8687e-02,
           2.5024e-03,  4.5776e-03],
         [-2.1094e-01,  3.4180e-02,  6.0938e-01,  ...,  7.0312e-01,
           3.2031e-01,  3.0078e-01],
         [ 7.0312e-01,  9.6094e-01, -2.3730e-01,  ..., -2.8125e-01,
           9.0625e-01, -6.3281e-01],
         ...,
         [-8.1250e-01,  1.5391e+00, -8.0078e-02,  ...,  2.1719e+00,
           8.6328e-01,  2.9688e-01],
         [ 6.1328e-01, -1.0938e+00,  1.8281e+00,  ..., -1.4609e+00,
          -8.0078e-02, -3.3594e-01],
         [ 3.5938e-01, -1.4062e+00,  1.0391e+00,  ..., -1.8047e+00,
           1.4062e+00,  9.7656e-01]],

        [[ 4.6387e-03,  8.9722e-03, -1.2268e-02,  ...,  2.8687e-02,
           2.5024e-03,  4.5776e-03],
         [-2.1094e-01,  3.4180e-02,  6.0938e-01,  ...,  7.0312e-01,
           3.2031e-01,  3.0078e-01],
         [ 7.0312e-01,  9.6094e-01, -2.3730e-01,  ..., -2.8125e-01,
           9.0625e-01, -6.3281e-01],
         ...,
         [-8.1250e-01,  1.5391e+00, -8.0078e-02,  ...,  2.1719e+00,
           8.6328e-01,  2.9688e-01],
         [ 6.1328e-01, -1.0938e+00,  1.8281e+00,  ..., -1.4609e+00,
          -8.0078e-02, -3.3594e-01],
         [ 3.5938e-01, -1.4062e+00,  1.0391e+00,  ..., -1.8047e+00,
           1.4062e+00,  9.7656e-01]],

        [[ 4.6387e-03,  8.9722e-03, -1.2268e-02,  ...,  2.8687e-02,
           2.5024e-03,  4.5776e-03],
         [-2.1094e-01,  3.4180e-02,  6.0938e-01,  ...,  7.0312e-01,
           3.2031e-01,  3.0078e-01],
         [ 7.0312e-01,  9.6094e-01, -2.3730e-01,  ..., -2.8125e-01,
           9.0625e-01, -6.3281e-01],
         ...,
         [-8.1250e-01,  1.5391e+00, -8.0078e-02,  ...,  2.1719e+00,
           8.6328e-01,  2.9688e-01],
         [ 6.1328e-01, -1.0938e+00,  1.8281e+00,  ..., -1.4609e+00,
          -8.0078e-02, -3.3594e-01],
         [ 3.5938e-01, -1.4062e+00,  1.0391e+00,  ..., -1.8047e+00,
           1.4062e+00,  9.7656e-01]],

        ...,

        [[-7.5684e-03,  1.5869e-02, -2.8198e-02,  ...,  1.9531e-02,
           6.4697e-03, -8.6975e-04],
         [ 1.1914e-01,  4.0283e-02,  1.3438e+00,  ..., -8.8672e-01,
           1.7031e+00, -2.0996e-01],
         [ 5.4297e-01, -9.4922e-01,  1.1484e+00,  ..., -2.2656e+00,
          -1.2266e+00, -2.0781e+00],
         ...,
         [ 2.7656e+00,  1.3984e+00,  3.0859e-01,  ..., -9.2188e-01,
           4.2773e-01, -3.7656e+00],
         [ 1.3672e+00, -1.2500e-01, -5.1172e-01,  ..., -7.6953e-01,
          -1.7656e+00, -1.7422e+00],
         [ 3.9688e+00, -7.2266e-01, -3.8477e-01,  ..., -2.0156e+00,
           9.6484e-01, -2.2344e+00]],

        [[-7.5684e-03,  1.5869e-02, -2.8198e-02,  ...,  1.9531e-02,
           6.4697e-03, -8.6975e-04],
         [ 1.1914e-01,  4.0283e-02,  1.3438e+00,  ..., -8.8672e-01,
           1.7031e+00, -2.0996e-01],
         [ 5.4297e-01, -9.4922e-01,  1.1484e+00,  ..., -2.2656e+00,
          -1.2266e+00, -2.0781e+00],
         ...,
         [ 2.7656e+00,  1.3984e+00,  3.0859e-01,  ..., -9.2188e-01,
           4.2773e-01, -3.7656e+00],
         [ 1.3672e+00, -1.2500e-01, -5.1172e-01,  ..., -7.6953e-01,
          -1.7656e+00, -1.7422e+00],
         [ 3.9688e+00, -7.2266e-01, -3.8477e-01,  ..., -2.0156e+00,
           9.6484e-01, -2.2344e+00]],

        [[-7.5684e-03,  1.5869e-02, -2.8198e-02,  ...,  1.9531e-02,
           6.4697e-03, -8.6975e-04],
         [ 1.1914e-01,  4.0283e-02,  1.3438e+00,  ..., -8.8672e-01,
           1.7031e+00, -2.0996e-01],
         [ 5.4297e-01, -9.4922e-01,  1.1484e+00,  ..., -2.2656e+00,
          -1.2266e+00, -2.0781e+00],
         ...,
         [ 2.7656e+00,  1.3984e+00,  3.0859e-01,  ..., -9.2188e-01,
           4.2773e-01, -3.7656e+00],
         [ 1.3672e+00, -1.2500e-01, -5.1172e-01,  ..., -7.6953e-01,
          -1.7656e+00, -1.7422e+00],
         [ 3.9688e+00, -7.2266e-01, -3.8477e-01,  ..., -2.0156e+00,
           9.6484e-01, -2.2344e+00]]], dtype=torch.bfloat16)), (tensor([[[-1.7090e-02,  5.7983e-03,  1.4267e-03,  ..., -4.4189e-02,
          -2.8076e-03, -2.6001e-02],
         [ 2.1250e+00, -6.4844e-01, -5.2734e-01,  ...,  1.3359e+00,
          -8.1641e-01,  7.1484e-01],
         [-1.7383e-01,  3.6328e-01,  2.1406e+00,  ...,  1.6953e+00,
          -2.8750e+00, -4.3555e-01],
         ...,
         [-1.9219e+00,  6.0156e-01,  8.5938e-01,  ..., -3.5000e+00,
           6.7578e-01,  1.9629e-01],
         [-1.6406e-01,  6.8750e-01,  1.3359e+00,  ..., -2.1719e+00,
          -1.7344e+00,  1.4375e+00],
         [ 4.7852e-01,  2.8564e-02,  7.4219e-01,  ..., -1.1641e+00,
          -1.3594e+00, -5.1562e-01]],

        [[-1.7090e-02,  5.7983e-03,  1.4267e-03,  ..., -4.4189e-02,
          -2.8076e-03, -2.6001e-02],
         [ 2.1250e+00, -6.4844e-01, -5.2734e-01,  ...,  1.3359e+00,
          -8.1641e-01,  7.1484e-01],
         [-1.7383e-01,  3.6328e-01,  2.1406e+00,  ...,  1.6953e+00,
          -2.8750e+00, -4.3555e-01],
         ...,
         [-1.9219e+00,  6.0156e-01,  8.5938e-01,  ..., -3.5000e+00,
           6.7578e-01,  1.9629e-01],
         [-1.6406e-01,  6.8750e-01,  1.3359e+00,  ..., -2.1719e+00,
          -1.7344e+00,  1.4375e+00],
         [ 4.7852e-01,  2.8564e-02,  7.4219e-01,  ..., -1.1641e+00,
          -1.3594e+00, -5.1562e-01]],

        [[-1.7090e-02,  5.7983e-03,  1.4267e-03,  ..., -4.4189e-02,
          -2.8076e-03, -2.6001e-02],
         [ 2.1250e+00, -6.4844e-01, -5.2734e-01,  ...,  1.3359e+00,
          -8.1641e-01,  7.1484e-01],
         [-1.7383e-01,  3.6328e-01,  2.1406e+00,  ...,  1.6953e+00,
          -2.8750e+00, -4.3555e-01],
         ...,
         [-1.9219e+00,  6.0156e-01,  8.5938e-01,  ..., -3.5000e+00,
           6.7578e-01,  1.9629e-01],
         [-1.6406e-01,  6.8750e-01,  1.3359e+00,  ..., -2.1719e+00,
          -1.7344e+00,  1.4375e+00],
         [ 4.7852e-01,  2.8564e-02,  7.4219e-01,  ..., -1.1641e+00,
          -1.3594e+00, -5.1562e-01]],

        ...,

        [[-1.4221e-02,  2.1240e-02, -1.3855e-02,  ...,  1.4258e-01,
           8.3984e-02, -1.8906e+00],
         [ 2.9297e-02, -2.5391e-01, -9.0625e-01,  ...,  4.1016e-01,
          -1.7344e+00,  5.9375e+00],
         [ 4.1406e-01,  3.7842e-02,  1.7578e-01,  ..., -5.4297e-01,
          -2.0469e+00,  6.5312e+00],
         ...,
         [ 6.7578e-01,  1.6797e-01, -3.6523e-01,  ..., -9.3750e-01,
           1.4375e+00,  5.0938e+00],
         [-3.5547e-01, -2.9102e-01,  3.9453e-01,  ..., -2.9375e+00,
          -5.1562e-01,  5.3125e+00],
         [-1.1250e+00, -6.1719e-01, -5.2344e-01,  ..., -2.2656e+00,
          -1.3594e+00,  4.8750e+00]],

        [[-1.4221e-02,  2.1240e-02, -1.3855e-02,  ...,  1.4258e-01,
           8.3984e-02, -1.8906e+00],
         [ 2.9297e-02, -2.5391e-01, -9.0625e-01,  ...,  4.1016e-01,
          -1.7344e+00,  5.9375e+00],
         [ 4.1406e-01,  3.7842e-02,  1.7578e-01,  ..., -5.4297e-01,
          -2.0469e+00,  6.5312e+00],
         ...,
         [ 6.7578e-01,  1.6797e-01, -3.6523e-01,  ..., -9.3750e-01,
           1.4375e+00,  5.0938e+00],
         [-3.5547e-01, -2.9102e-01,  3.9453e-01,  ..., -2.9375e+00,
          -5.1562e-01,  5.3125e+00],
         [-1.1250e+00, -6.1719e-01, -5.2344e-01,  ..., -2.2656e+00,
          -1.3594e+00,  4.8750e+00]],

        [[-1.4221e-02,  2.1240e-02, -1.3855e-02,  ...,  1.4258e-01,
           8.3984e-02, -1.8906e+00],
         [ 2.9297e-02, -2.5391e-01, -9.0625e-01,  ...,  4.1016e-01,
          -1.7344e+00,  5.9375e+00],
         [ 4.1406e-01,  3.7842e-02,  1.7578e-01,  ..., -5.4297e-01,
          -2.0469e+00,  6.5312e+00],
         ...,
         [ 6.7578e-01,  1.6797e-01, -3.6523e-01,  ..., -9.3750e-01,
           1.4375e+00,  5.0938e+00],
         [-3.5547e-01, -2.9102e-01,  3.9453e-01,  ..., -2.9375e+00,
          -5.1562e-01,  5.3125e+00],
         [-1.1250e+00, -6.1719e-01, -5.2344e-01,  ..., -2.2656e+00,
          -1.3594e+00,  4.8750e+00]]], dtype=torch.bfloat16), tensor([[[-4.0039e-02, -1.9409e-02, -2.3804e-02,  ...,  3.9795e-02,
           6.1035e-02, -2.9785e-02],
         [-1.0938e+00,  4.3750e-01, -2.5781e-01,  ..., -3.9258e-01,
           7.9590e-02,  1.3828e+00],
         [-7.3828e-01, -2.5391e-01,  6.8359e-01,  ..., -4.2578e-01,
           6.7969e-01,  4.2578e-01],
         ...,
         [-1.3125e+00,  1.0312e+00, -2.1094e+00,  ...,  4.8242e-01,
           1.7422e+00,  1.3125e+00],
         [-6.2109e-01, -8.0469e-01,  2.7148e-01,  ..., -1.7812e+00,
           1.7383e-01, -7.6953e-01],
         [-2.8809e-02, -1.3594e+00, -6.3672e-01,  ..., -2.1387e-01,
           2.0469e+00, -5.4688e-01]],

        [[-4.0039e-02, -1.9409e-02, -2.3804e-02,  ...,  3.9795e-02,
           6.1035e-02, -2.9785e-02],
         [-1.0938e+00,  4.3750e-01, -2.5781e-01,  ..., -3.9258e-01,
           7.9590e-02,  1.3828e+00],
         [-7.3828e-01, -2.5391e-01,  6.8359e-01,  ..., -4.2578e-01,
           6.7969e-01,  4.2578e-01],
         ...,
         [-1.3125e+00,  1.0312e+00, -2.1094e+00,  ...,  4.8242e-01,
           1.7422e+00,  1.3125e+00],
         [-6.2109e-01, -8.0469e-01,  2.7148e-01,  ..., -1.7812e+00,
           1.7383e-01, -7.6953e-01],
         [-2.8809e-02, -1.3594e+00, -6.3672e-01,  ..., -2.1387e-01,
           2.0469e+00, -5.4688e-01]],

        [[-4.0039e-02, -1.9409e-02, -2.3804e-02,  ...,  3.9795e-02,
           6.1035e-02, -2.9785e-02],
         [-1.0938e+00,  4.3750e-01, -2.5781e-01,  ..., -3.9258e-01,
           7.9590e-02,  1.3828e+00],
         [-7.3828e-01, -2.5391e-01,  6.8359e-01,  ..., -4.2578e-01,
           6.7969e-01,  4.2578e-01],
         ...,
         [-1.3125e+00,  1.0312e+00, -2.1094e+00,  ...,  4.8242e-01,
           1.7422e+00,  1.3125e+00],
         [-6.2109e-01, -8.0469e-01,  2.7148e-01,  ..., -1.7812e+00,
           1.7383e-01, -7.6953e-01],
         [-2.8809e-02, -1.3594e+00, -6.3672e-01,  ..., -2.1387e-01,
           2.0469e+00, -5.4688e-01]],

        ...,

        [[ 6.0730e-03, -1.4343e-02, -1.2146e-02,  ...,  1.8692e-03,
          -4.9744e-03,  3.8757e-03],
         [-4.6875e-01,  2.4414e-01, -9.5312e-01,  ...,  5.5859e-01,
          -7.8125e-01,  4.4727e-01],
         [-8.5547e-01, -5.8984e-01,  4.5898e-01,  ..., -2.5625e+00,
          -5.1270e-02,  3.3008e-01],
         ...,
         [-6.4844e-01, -3.5938e-01,  9.7656e-01,  ..., -3.0859e-01,
          -1.8164e-01,  6.2500e-01],
         [ 9.0332e-02, -4.6875e-01, -7.1875e-01,  ..., -7.6953e-01,
          -2.5586e-01,  2.7930e-01],
         [-2.2812e+00, -8.9453e-01,  5.0391e-01,  ..., -1.5859e+00,
          -5.3125e-01, -9.1406e-01]],

        [[ 6.0730e-03, -1.4343e-02, -1.2146e-02,  ...,  1.8692e-03,
          -4.9744e-03,  3.8757e-03],
         [-4.6875e-01,  2.4414e-01, -9.5312e-01,  ...,  5.5859e-01,
          -7.8125e-01,  4.4727e-01],
         [-8.5547e-01, -5.8984e-01,  4.5898e-01,  ..., -2.5625e+00,
          -5.1270e-02,  3.3008e-01],
         ...,
         [-6.4844e-01, -3.5938e-01,  9.7656e-01,  ..., -3.0859e-01,
          -1.8164e-01,  6.2500e-01],
         [ 9.0332e-02, -4.6875e-01, -7.1875e-01,  ..., -7.6953e-01,
          -2.5586e-01,  2.7930e-01],
         [-2.2812e+00, -8.9453e-01,  5.0391e-01,  ..., -1.5859e+00,
          -5.3125e-01, -9.1406e-01]],

        [[ 6.0730e-03, -1.4343e-02, -1.2146e-02,  ...,  1.8692e-03,
          -4.9744e-03,  3.8757e-03],
         [-4.6875e-01,  2.4414e-01, -9.5312e-01,  ...,  5.5859e-01,
          -7.8125e-01,  4.4727e-01],
         [-8.5547e-01, -5.8984e-01,  4.5898e-01,  ..., -2.5625e+00,
          -5.1270e-02,  3.3008e-01],
         ...,
         [-6.4844e-01, -3.5938e-01,  9.7656e-01,  ..., -3.0859e-01,
          -1.8164e-01,  6.2500e-01],
         [ 9.0332e-02, -4.6875e-01, -7.1875e-01,  ..., -7.6953e-01,
          -2.5586e-01,  2.7930e-01],
         [-2.2812e+00, -8.9453e-01,  5.0391e-01,  ..., -1.5859e+00,
          -5.3125e-01, -9.1406e-01]]], dtype=torch.bfloat16)), (tensor([[[-0.0069,  0.0071, -0.0229,  ..., -0.2656, -0.3203,  0.5664],
         [ 1.9844, -0.9141,  1.1562,  ..., -2.0938, -0.4727,  0.7070],
         [ 1.4766, -0.3887,  0.7656,  ..., -2.5625, -0.1885, -0.8750],
         ...,
         [ 0.2070,  0.9141,  0.1455,  ..., -1.2422,  0.0520, -2.4375],
         [-0.1982,  0.2090,  0.4023,  ..., -2.4219,  1.5391, -3.0469],
         [-1.8281, -0.8906,  0.7656,  ..., -2.2500, -2.0156, -1.9375]],

        [[-0.0069,  0.0071, -0.0229,  ..., -0.2656, -0.3203,  0.5664],
         [ 1.9844, -0.9141,  1.1562,  ..., -2.0938, -0.4727,  0.7070],
         [ 1.4766, -0.3887,  0.7656,  ..., -2.5625, -0.1885, -0.8750],
         ...,
         [ 0.2070,  0.9141,  0.1455,  ..., -1.2422,  0.0520, -2.4375],
         [-0.1982,  0.2090,  0.4023,  ..., -2.4219,  1.5391, -3.0469],
         [-1.8281, -0.8906,  0.7656,  ..., -2.2500, -2.0156, -1.9375]],

        [[-0.0069,  0.0071, -0.0229,  ..., -0.2656, -0.3203,  0.5664],
         [ 1.9844, -0.9141,  1.1562,  ..., -2.0938, -0.4727,  0.7070],
         [ 1.4766, -0.3887,  0.7656,  ..., -2.5625, -0.1885, -0.8750],
         ...,
         [ 0.2070,  0.9141,  0.1455,  ..., -1.2422,  0.0520, -2.4375],
         [-0.1982,  0.2090,  0.4023,  ..., -2.4219,  1.5391, -3.0469],
         [-1.8281, -0.8906,  0.7656,  ..., -2.2500, -2.0156, -1.9375]],

        ...,

        [[-0.0102, -0.0106,  0.0152,  ..., -0.0306,  0.2354,  0.2734],
         [ 0.5859,  0.5820,  1.4844,  ...,  0.5117, -0.1738, -0.3418],
         [-0.3145,  2.4375,  0.7930,  ..., -0.0059,  0.1729, -0.4980],
         ...,
         [-1.7578, -1.1875, -0.5273,  ..., -0.6836, -0.0330, -0.4785],
         [ 0.4180, -0.4766,  0.5195,  ...,  1.0938, -1.3750,  1.7812],
         [ 1.3281, -0.9102,  1.0312,  ..., -0.0315,  0.2432,  1.0547]],

        [[-0.0102, -0.0106,  0.0152,  ..., -0.0306,  0.2354,  0.2734],
         [ 0.5859,  0.5820,  1.4844,  ...,  0.5117, -0.1738, -0.3418],
         [-0.3145,  2.4375,  0.7930,  ..., -0.0059,  0.1729, -0.4980],
         ...,
         [-1.7578, -1.1875, -0.5273,  ..., -0.6836, -0.0330, -0.4785],
         [ 0.4180, -0.4766,  0.5195,  ...,  1.0938, -1.3750,  1.7812],
         [ 1.3281, -0.9102,  1.0312,  ..., -0.0315,  0.2432,  1.0547]],

        [[-0.0102, -0.0106,  0.0152,  ..., -0.0306,  0.2354,  0.2734],
         [ 0.5859,  0.5820,  1.4844,  ...,  0.5117, -0.1738, -0.3418],
         [-0.3145,  2.4375,  0.7930,  ..., -0.0059,  0.1729, -0.4980],
         ...,
         [-1.7578, -1.1875, -0.5273,  ..., -0.6836, -0.0330, -0.4785],
         [ 0.4180, -0.4766,  0.5195,  ...,  1.0938, -1.3750,  1.7812],
         [ 1.3281, -0.9102,  1.0312,  ..., -0.0315,  0.2432,  1.0547]]],
       dtype=torch.bfloat16), tensor([[[ 0.0066,  0.0254,  0.0081,  ...,  0.0171,  0.0244, -0.0117],
         [ 0.2949, -0.6914,  0.5703,  ...,  0.0679,  0.4258, -0.1797],
         [-0.5859, -1.0625, -0.5781,  ...,  0.0036,  0.9961, -1.3984],
         ...,
         [-0.1953,  0.5703, -0.2031,  ..., -0.9609,  0.7617,  0.6719],
         [ 0.1289,  0.4941,  1.1328,  ..., -1.1484,  0.3223,  0.4355],
         [-0.1719,  0.3711,  0.7031,  ...,  0.4648,  0.0403, -0.8555]],

        [[ 0.0066,  0.0254,  0.0081,  ...,  0.0171,  0.0244, -0.0117],
         [ 0.2949, -0.6914,  0.5703,  ...,  0.0679,  0.4258, -0.1797],
         [-0.5859, -1.0625, -0.5781,  ...,  0.0036,  0.9961, -1.3984],
         ...,
         [-0.1953,  0.5703, -0.2031,  ..., -0.9609,  0.7617,  0.6719],
         [ 0.1289,  0.4941,  1.1328,  ..., -1.1484,  0.3223,  0.4355],
         [-0.1719,  0.3711,  0.7031,  ...,  0.4648,  0.0403, -0.8555]],

        [[ 0.0066,  0.0254,  0.0081,  ...,  0.0171,  0.0244, -0.0117],
         [ 0.2949, -0.6914,  0.5703,  ...,  0.0679,  0.4258, -0.1797],
         [-0.5859, -1.0625, -0.5781,  ...,  0.0036,  0.9961, -1.3984],
         ...,
         [-0.1953,  0.5703, -0.2031,  ..., -0.9609,  0.7617,  0.6719],
         [ 0.1289,  0.4941,  1.1328,  ..., -1.1484,  0.3223,  0.4355],
         [-0.1719,  0.3711,  0.7031,  ...,  0.4648,  0.0403, -0.8555]],

        ...,

        [[-0.0140, -0.0067,  0.0165,  ..., -0.0031, -0.0269,  0.0166],
         [ 0.7578,  0.6719, -0.7812,  ..., -0.4512, -1.5156, -0.4316],
         [ 1.1406,  1.6641, -0.1426,  ..., -0.7422,  0.7422, -2.5156],
         ...,
         [-1.2188, -1.5234,  0.0674,  ..., -0.2891, -0.5664,  0.3047],
         [-1.3750, -0.3125, -0.2129,  ...,  1.0547, -0.4668, -1.4531],
         [-1.1328, -0.2930, -0.3711,  ..., -1.0469, -0.6289,  0.2949]],

        [[-0.0140, -0.0067,  0.0165,  ..., -0.0031, -0.0269,  0.0166],
         [ 0.7578,  0.6719, -0.7812,  ..., -0.4512, -1.5156, -0.4316],
         [ 1.1406,  1.6641, -0.1426,  ..., -0.7422,  0.7422, -2.5156],
         ...,
         [-1.2188, -1.5234,  0.0674,  ..., -0.2891, -0.5664,  0.3047],
         [-1.3750, -0.3125, -0.2129,  ...,  1.0547, -0.4668, -1.4531],
         [-1.1328, -0.2930, -0.3711,  ..., -1.0469, -0.6289,  0.2949]],

        [[-0.0140, -0.0067,  0.0165,  ..., -0.0031, -0.0269,  0.0166],
         [ 0.7578,  0.6719, -0.7812,  ..., -0.4512, -1.5156, -0.4316],
         [ 1.1406,  1.6641, -0.1426,  ..., -0.7422,  0.7422, -2.5156],
         ...,
         [-1.2188, -1.5234,  0.0674,  ..., -0.2891, -0.5664,  0.3047],
         [-1.3750, -0.3125, -0.2129,  ...,  1.0547, -0.4668, -1.4531],
         [-1.1328, -0.2930, -0.3711,  ..., -1.0469, -0.6289,  0.2949]]],
       dtype=torch.bfloat16)), (tensor([[[ 4.0894e-03,  4.7607e-03, -2.6245e-03,  ..., -6.8848e-02,
           4.7852e-02,  1.8262e-01],
         [-3.7500e+00,  1.0625e+00, -1.1094e+00,  ..., -1.0303e-01,
           3.3203e-01, -1.0547e+00],
         [-1.4766e+00,  3.1562e+00, -1.0625e+00,  ..., -1.7578e-02,
          -9.4531e-01, -2.1250e+00],
         ...,
         [ 1.4375e+00, -7.1484e-01, -1.8555e-01,  ...,  5.4297e-01,
           1.7109e+00, -9.3359e-01],
         [ 1.2578e+00, -1.4531e+00, -1.1094e+00,  ...,  7.8516e-01,
          -1.2891e+00,  1.8359e+00],
         [ 1.6094e+00, -2.8438e+00, -1.7969e+00,  ...,  2.5391e-01,
           1.4062e-01, -5.1025e-02]],

        [[ 4.0894e-03,  4.7607e-03, -2.6245e-03,  ..., -6.8848e-02,
           4.7852e-02,  1.8262e-01],
         [-3.7500e+00,  1.0625e+00, -1.1094e+00,  ..., -1.0303e-01,
           3.3203e-01, -1.0547e+00],
         [-1.4766e+00,  3.1562e+00, -1.0625e+00,  ..., -1.7578e-02,
          -9.4531e-01, -2.1250e+00],
         ...,
         [ 1.4375e+00, -7.1484e-01, -1.8555e-01,  ...,  5.4297e-01,
           1.7109e+00, -9.3359e-01],
         [ 1.2578e+00, -1.4531e+00, -1.1094e+00,  ...,  7.8516e-01,
          -1.2891e+00,  1.8359e+00],
         [ 1.6094e+00, -2.8438e+00, -1.7969e+00,  ...,  2.5391e-01,
           1.4062e-01, -5.1025e-02]],

        [[ 4.0894e-03,  4.7607e-03, -2.6245e-03,  ..., -6.8848e-02,
           4.7852e-02,  1.8262e-01],
         [-3.7500e+00,  1.0625e+00, -1.1094e+00,  ..., -1.0303e-01,
           3.3203e-01, -1.0547e+00],
         [-1.4766e+00,  3.1562e+00, -1.0625e+00,  ..., -1.7578e-02,
          -9.4531e-01, -2.1250e+00],
         ...,
         [ 1.4375e+00, -7.1484e-01, -1.8555e-01,  ...,  5.4297e-01,
           1.7109e+00, -9.3359e-01],
         [ 1.2578e+00, -1.4531e+00, -1.1094e+00,  ...,  7.8516e-01,
          -1.2891e+00,  1.8359e+00],
         [ 1.6094e+00, -2.8438e+00, -1.7969e+00,  ...,  2.5391e-01,
           1.4062e-01, -5.1025e-02]],

        ...,

        [[ 8.1787e-03,  1.9165e-02, -1.0620e-02,  ...,  1.6699e-01,
          -1.3965e-01, -1.6016e-01],
         [-1.8906e+00,  1.1406e+00,  4.0039e-01,  ...,  1.2305e-01,
           4.7852e-01, -2.8906e-01],
         [ 5.7031e-01, -1.0469e+00, -5.1172e-01,  ..., -6.7578e-01,
          -2.7734e-01, -1.0781e+00],
         ...,
         [ 1.2656e+00, -6.4844e-01, -4.3359e-01,  ...,  5.9766e-01,
          -2.4219e-01, -4.4678e-02],
         [ 1.1875e+00,  1.7188e-01, -1.2695e-01,  ...,  9.5703e-01,
          -5.8594e-01,  9.5703e-01],
         [-5.6641e-01,  1.3047e+00,  5.9375e-01,  ...,  6.3281e-01,
           9.4531e-01, -9.1797e-02]],

        [[ 8.1787e-03,  1.9165e-02, -1.0620e-02,  ...,  1.6699e-01,
          -1.3965e-01, -1.6016e-01],
         [-1.8906e+00,  1.1406e+00,  4.0039e-01,  ...,  1.2305e-01,
           4.7852e-01, -2.8906e-01],
         [ 5.7031e-01, -1.0469e+00, -5.1172e-01,  ..., -6.7578e-01,
          -2.7734e-01, -1.0781e+00],
         ...,
         [ 1.2656e+00, -6.4844e-01, -4.3359e-01,  ...,  5.9766e-01,
          -2.4219e-01, -4.4678e-02],
         [ 1.1875e+00,  1.7188e-01, -1.2695e-01,  ...,  9.5703e-01,
          -5.8594e-01,  9.5703e-01],
         [-5.6641e-01,  1.3047e+00,  5.9375e-01,  ...,  6.3281e-01,
           9.4531e-01, -9.1797e-02]],

        [[ 8.1787e-03,  1.9165e-02, -1.0620e-02,  ...,  1.6699e-01,
          -1.3965e-01, -1.6016e-01],
         [-1.8906e+00,  1.1406e+00,  4.0039e-01,  ...,  1.2305e-01,
           4.7852e-01, -2.8906e-01],
         [ 5.7031e-01, -1.0469e+00, -5.1172e-01,  ..., -6.7578e-01,
          -2.7734e-01, -1.0781e+00],
         ...,
         [ 1.2656e+00, -6.4844e-01, -4.3359e-01,  ...,  5.9766e-01,
          -2.4219e-01, -4.4678e-02],
         [ 1.1875e+00,  1.7188e-01, -1.2695e-01,  ...,  9.5703e-01,
          -5.8594e-01,  9.5703e-01],
         [-5.6641e-01,  1.3047e+00,  5.9375e-01,  ...,  6.3281e-01,
           9.4531e-01, -9.1797e-02]]], dtype=torch.bfloat16), tensor([[[-0.0254, -0.0189,  0.0183,  ..., -0.0042,  0.1099, -0.0179],
         [-0.3301,  0.7344,  1.5312,  ..., -0.8516, -0.1631,  0.8555],
         [ 0.0043,  0.4824, -0.2041,  ..., -0.1660,  0.5938,  0.3242],
         ...,
         [-0.4434,  0.4688, -0.1040,  ..., -1.4453,  0.7422,  1.6016],
         [-0.4941, -0.0325,  0.3105,  ..., -1.4297, -1.3906,  1.2891],
         [-0.0786,  0.7031,  0.8906,  ..., -0.1040,  0.0566,  0.0618]],

        [[-0.0254, -0.0189,  0.0183,  ..., -0.0042,  0.1099, -0.0179],
         [-0.3301,  0.7344,  1.5312,  ..., -0.8516, -0.1631,  0.8555],
         [ 0.0043,  0.4824, -0.2041,  ..., -0.1660,  0.5938,  0.3242],
         ...,
         [-0.4434,  0.4688, -0.1040,  ..., -1.4453,  0.7422,  1.6016],
         [-0.4941, -0.0325,  0.3105,  ..., -1.4297, -1.3906,  1.2891],
         [-0.0786,  0.7031,  0.8906,  ..., -0.1040,  0.0566,  0.0618]],

        [[-0.0254, -0.0189,  0.0183,  ..., -0.0042,  0.1099, -0.0179],
         [-0.3301,  0.7344,  1.5312,  ..., -0.8516, -0.1631,  0.8555],
         [ 0.0043,  0.4824, -0.2041,  ..., -0.1660,  0.5938,  0.3242],
         ...,
         [-0.4434,  0.4688, -0.1040,  ..., -1.4453,  0.7422,  1.6016],
         [-0.4941, -0.0325,  0.3105,  ..., -1.4297, -1.3906,  1.2891],
         [-0.0786,  0.7031,  0.8906,  ..., -0.1040,  0.0566,  0.0618]],

        ...,

        [[-0.0245,  0.0131,  0.0088,  ...,  0.0496, -0.0391, -0.0024],
         [-0.2373, -0.2256,  0.3340,  ..., -0.0225, -0.1021,  0.9844],
         [-0.2637, -1.0156, -0.0444,  ..., -0.5508, -0.2002, -0.1064],
         ...,
         [ 0.2617,  0.0302,  1.6172,  ..., -2.2812,  0.2930,  0.2539],
         [ 0.1191, -0.2949, -0.1289,  ..., -1.3281,  0.3242,  0.9062],
         [ 0.2090,  0.1377, -0.5859,  ...,  0.2871,  0.3379, -0.2178]],

        [[-0.0245,  0.0131,  0.0088,  ...,  0.0496, -0.0391, -0.0024],
         [-0.2373, -0.2256,  0.3340,  ..., -0.0225, -0.1021,  0.9844],
         [-0.2637, -1.0156, -0.0444,  ..., -0.5508, -0.2002, -0.1064],
         ...,
         [ 0.2617,  0.0302,  1.6172,  ..., -2.2812,  0.2930,  0.2539],
         [ 0.1191, -0.2949, -0.1289,  ..., -1.3281,  0.3242,  0.9062],
         [ 0.2090,  0.1377, -0.5859,  ...,  0.2871,  0.3379, -0.2178]],

        [[-0.0245,  0.0131,  0.0088,  ...,  0.0496, -0.0391, -0.0024],
         [-0.2373, -0.2256,  0.3340,  ..., -0.0225, -0.1021,  0.9844],
         [-0.2637, -1.0156, -0.0444,  ..., -0.5508, -0.2002, -0.1064],
         ...,
         [ 0.2617,  0.0302,  1.6172,  ..., -2.2812,  0.2930,  0.2539],
         [ 0.1191, -0.2949, -0.1289,  ..., -1.3281,  0.3242,  0.9062],
         [ 0.2090,  0.1377, -0.5859,  ...,  0.2871,  0.3379, -0.2178]]],
       dtype=torch.bfloat16)), (tensor([[[ 2.4048e-02,  1.5381e-02, -2.6489e-02,  ...,  1.5442e-02,
           1.6094e+00, -2.5586e-01],
         [-1.9297e+00,  1.5312e+00,  3.8281e-01,  ..., -1.4844e+00,
          -4.0000e+00,  7.5000e-01],
         [-4.2188e-01,  1.6562e+00,  1.6953e+00,  ...,  1.9453e+00,
          -5.9062e+00,  1.8555e-01],
         ...,
         [ 1.1406e+00, -1.7422e+00, -1.0625e+00,  ..., -3.6562e+00,
          -5.2188e+00,  5.5469e-01],
         [ 9.0234e-01, -9.3750e-01, -5.2734e-01,  ..., -1.5078e+00,
          -6.4688e+00,  9.8438e-01],
         [-5.7812e-01, -8.3594e-01,  1.4062e+00,  ..., -1.6719e+00,
          -4.2812e+00,  3.0518e-02]],

        [[ 2.4048e-02,  1.5381e-02, -2.6489e-02,  ...,  1.5442e-02,
           1.6094e+00, -2.5586e-01],
         [-1.9297e+00,  1.5312e+00,  3.8281e-01,  ..., -1.4844e+00,
          -4.0000e+00,  7.5000e-01],
         [-4.2188e-01,  1.6562e+00,  1.6953e+00,  ...,  1.9453e+00,
          -5.9062e+00,  1.8555e-01],
         ...,
         [ 1.1406e+00, -1.7422e+00, -1.0625e+00,  ..., -3.6562e+00,
          -5.2188e+00,  5.5469e-01],
         [ 9.0234e-01, -9.3750e-01, -5.2734e-01,  ..., -1.5078e+00,
          -6.4688e+00,  9.8438e-01],
         [-5.7812e-01, -8.3594e-01,  1.4062e+00,  ..., -1.6719e+00,
          -4.2812e+00,  3.0518e-02]],

        [[ 2.4048e-02,  1.5381e-02, -2.6489e-02,  ...,  1.5442e-02,
           1.6094e+00, -2.5586e-01],
         [-1.9297e+00,  1.5312e+00,  3.8281e-01,  ..., -1.4844e+00,
          -4.0000e+00,  7.5000e-01],
         [-4.2188e-01,  1.6562e+00,  1.6953e+00,  ...,  1.9453e+00,
          -5.9062e+00,  1.8555e-01],
         ...,
         [ 1.1406e+00, -1.7422e+00, -1.0625e+00,  ..., -3.6562e+00,
          -5.2188e+00,  5.5469e-01],
         [ 9.0234e-01, -9.3750e-01, -5.2734e-01,  ..., -1.5078e+00,
          -6.4688e+00,  9.8438e-01],
         [-5.7812e-01, -8.3594e-01,  1.4062e+00,  ..., -1.6719e+00,
          -4.2812e+00,  3.0518e-02]],

        ...,

        [[-1.1475e-02,  4.6997e-03, -4.4861e-03,  ...,  4.9438e-03,
           5.2002e-02, -1.9844e+00],
         [-9.8047e-01, -8.5547e-01,  8.7500e-01,  ...,  9.5312e-01,
           1.0681e-02,  5.5000e+00],
         [-2.0312e+00, -2.3594e+00, -6.3672e-01,  ..., -8.6328e-01,
          -3.6719e-01,  7.7500e+00],
         ...,
         [-2.6953e-01, -1.3750e+00, -7.6953e-01,  ...,  6.3672e-01,
           4.2578e-01,  4.8438e+00],
         [ 5.2344e-01, -8.3496e-02,  8.9062e-01,  ...,  1.3047e+00,
           1.5703e+00,  6.6250e+00],
         [ 2.0469e+00,  1.0859e+00,  8.3594e-01,  ...,  2.9531e+00,
           3.6328e-01,  4.9375e+00]],

        [[-1.1475e-02,  4.6997e-03, -4.4861e-03,  ...,  4.9438e-03,
           5.2002e-02, -1.9844e+00],
         [-9.8047e-01, -8.5547e-01,  8.7500e-01,  ...,  9.5312e-01,
           1.0681e-02,  5.5000e+00],
         [-2.0312e+00, -2.3594e+00, -6.3672e-01,  ..., -8.6328e-01,
          -3.6719e-01,  7.7500e+00],
         ...,
         [-2.6953e-01, -1.3750e+00, -7.6953e-01,  ...,  6.3672e-01,
           4.2578e-01,  4.8438e+00],
         [ 5.2344e-01, -8.3496e-02,  8.9062e-01,  ...,  1.3047e+00,
           1.5703e+00,  6.6250e+00],
         [ 2.0469e+00,  1.0859e+00,  8.3594e-01,  ...,  2.9531e+00,
           3.6328e-01,  4.9375e+00]],

        [[-1.1475e-02,  4.6997e-03, -4.4861e-03,  ...,  4.9438e-03,
           5.2002e-02, -1.9844e+00],
         [-9.8047e-01, -8.5547e-01,  8.7500e-01,  ...,  9.5312e-01,
           1.0681e-02,  5.5000e+00],
         [-2.0312e+00, -2.3594e+00, -6.3672e-01,  ..., -8.6328e-01,
          -3.6719e-01,  7.7500e+00],
         ...,
         [-2.6953e-01, -1.3750e+00, -7.6953e-01,  ...,  6.3672e-01,
           4.2578e-01,  4.8438e+00],
         [ 5.2344e-01, -8.3496e-02,  8.9062e-01,  ...,  1.3047e+00,
           1.5703e+00,  6.6250e+00],
         [ 2.0469e+00,  1.0859e+00,  8.3594e-01,  ...,  2.9531e+00,
           3.6328e-01,  4.9375e+00]]], dtype=torch.bfloat16), tensor([[[ 9.2163e-03,  1.6724e-02,  7.9346e-03,  ..., -2.8442e-02,
          -1.2573e-02,  1.4038e-02],
         [ 3.3984e-01,  3.1836e-01,  3.7305e-01,  ...,  4.7852e-01,
          -5.5078e-01,  4.5703e-01],
         [-2.9883e-01, -6.1328e-01,  1.1094e+00,  ..., -9.4531e-01,
           6.4844e-01, -4.6680e-01],
         ...,
         [ 2.8906e-01,  7.3828e-01, -1.6562e+00,  ...,  3.6719e-01,
           1.8828e+00,  1.8945e-01],
         [ 1.2031e+00,  1.4531e+00, -1.9409e-02,  ...,  1.5625e+00,
           3.8750e+00, -3.4180e-01],
         [ 1.2969e+00,  8.7500e-01,  6.4062e-01,  ...,  1.0859e+00,
           1.3594e+00, -3.1641e-01]],

        [[ 9.2163e-03,  1.6724e-02,  7.9346e-03,  ..., -2.8442e-02,
          -1.2573e-02,  1.4038e-02],
         [ 3.3984e-01,  3.1836e-01,  3.7305e-01,  ...,  4.7852e-01,
          -5.5078e-01,  4.5703e-01],
         [-2.9883e-01, -6.1328e-01,  1.1094e+00,  ..., -9.4531e-01,
           6.4844e-01, -4.6680e-01],
         ...,
         [ 2.8906e-01,  7.3828e-01, -1.6562e+00,  ...,  3.6719e-01,
           1.8828e+00,  1.8945e-01],
         [ 1.2031e+00,  1.4531e+00, -1.9409e-02,  ...,  1.5625e+00,
           3.8750e+00, -3.4180e-01],
         [ 1.2969e+00,  8.7500e-01,  6.4062e-01,  ...,  1.0859e+00,
           1.3594e+00, -3.1641e-01]],

        [[ 9.2163e-03,  1.6724e-02,  7.9346e-03,  ..., -2.8442e-02,
          -1.2573e-02,  1.4038e-02],
         [ 3.3984e-01,  3.1836e-01,  3.7305e-01,  ...,  4.7852e-01,
          -5.5078e-01,  4.5703e-01],
         [-2.9883e-01, -6.1328e-01,  1.1094e+00,  ..., -9.4531e-01,
           6.4844e-01, -4.6680e-01],
         ...,
         [ 2.8906e-01,  7.3828e-01, -1.6562e+00,  ...,  3.6719e-01,
           1.8828e+00,  1.8945e-01],
         [ 1.2031e+00,  1.4531e+00, -1.9409e-02,  ...,  1.5625e+00,
           3.8750e+00, -3.4180e-01],
         [ 1.2969e+00,  8.7500e-01,  6.4062e-01,  ...,  1.0859e+00,
           1.3594e+00, -3.1641e-01]],

        ...,

        [[ 1.2817e-02,  1.9897e-02, -2.3682e-02,  ...,  7.8201e-05,
          -3.9795e-02,  1.3306e-02],
         [ 1.4922e+00, -5.1172e-01, -7.1094e-01,  ...,  1.6016e+00,
          -3.3594e-01,  2.4707e-01],
         [ 2.9297e-01, -3.4062e+00, -3.9648e-01,  ..., -7.1094e-01,
          -1.3828e+00, -2.6875e+00],
         ...,
         [-8.3594e-01, -4.2188e-01, -3.3789e-01,  ...,  4.0625e-01,
          -1.1641e+00,  1.8984e+00],
         [-1.2578e+00, -1.3281e+00,  4.6875e-01,  ..., -1.2656e+00,
           5.3516e-01, -1.3438e+00],
         [ 4.7656e-01,  2.2500e+00, -4.8633e-01,  ...,  3.6719e-01,
          -4.5117e-01,  2.6367e-01]],

        [[ 1.2817e-02,  1.9897e-02, -2.3682e-02,  ...,  7.8201e-05,
          -3.9795e-02,  1.3306e-02],
         [ 1.4922e+00, -5.1172e-01, -7.1094e-01,  ...,  1.6016e+00,
          -3.3594e-01,  2.4707e-01],
         [ 2.9297e-01, -3.4062e+00, -3.9648e-01,  ..., -7.1094e-01,
          -1.3828e+00, -2.6875e+00],
         ...,
         [-8.3594e-01, -4.2188e-01, -3.3789e-01,  ...,  4.0625e-01,
          -1.1641e+00,  1.8984e+00],
         [-1.2578e+00, -1.3281e+00,  4.6875e-01,  ..., -1.2656e+00,
           5.3516e-01, -1.3438e+00],
         [ 4.7656e-01,  2.2500e+00, -4.8633e-01,  ...,  3.6719e-01,
          -4.5117e-01,  2.6367e-01]],

        [[ 1.2817e-02,  1.9897e-02, -2.3682e-02,  ...,  7.8201e-05,
          -3.9795e-02,  1.3306e-02],
         [ 1.4922e+00, -5.1172e-01, -7.1094e-01,  ...,  1.6016e+00,
          -3.3594e-01,  2.4707e-01],
         [ 2.9297e-01, -3.4062e+00, -3.9648e-01,  ..., -7.1094e-01,
          -1.3828e+00, -2.6875e+00],
         ...,
         [-8.3594e-01, -4.2188e-01, -3.3789e-01,  ...,  4.0625e-01,
          -1.1641e+00,  1.8984e+00],
         [-1.2578e+00, -1.3281e+00,  4.6875e-01,  ..., -1.2656e+00,
           5.3516e-01, -1.3438e+00],
         [ 4.7656e-01,  2.2500e+00, -4.8633e-01,  ...,  3.6719e-01,
          -4.5117e-01,  2.6367e-01]]], dtype=torch.bfloat16)), (tensor([[[-0.2656, -0.6289, -0.9727,  ...,  0.3301,  0.3516, -1.3516],
         [ 0.3125,  0.0537,  0.1123,  ...,  0.2930,  0.2480,  0.2793],
         [-0.2490, -0.2793, -0.0850,  ...,  0.4082, -2.8750, -1.3750],
         ...,
         [-0.5508, -0.0742,  0.0664,  ...,  3.8750, -3.5156, -2.2500],
         [-0.4023, -1.2109,  0.1289,  ...,  1.0547, -1.4141, -0.2832],
         [-0.3340, -1.6406, -1.1797,  ...,  2.3906, -1.9688,  0.0256]],

        [[-0.2656, -0.6289, -0.9727,  ...,  0.3301,  0.3516, -1.3516],
         [ 0.3125,  0.0537,  0.1123,  ...,  0.2930,  0.2480,  0.2793],
         [-0.2490, -0.2793, -0.0850,  ...,  0.4082, -2.8750, -1.3750],
         ...,
         [-0.5508, -0.0742,  0.0664,  ...,  3.8750, -3.5156, -2.2500],
         [-0.4023, -1.2109,  0.1289,  ...,  1.0547, -1.4141, -0.2832],
         [-0.3340, -1.6406, -1.1797,  ...,  2.3906, -1.9688,  0.0256]],

        [[-0.2656, -0.6289, -0.9727,  ...,  0.3301,  0.3516, -1.3516],
         [ 0.3125,  0.0537,  0.1123,  ...,  0.2930,  0.2480,  0.2793],
         [-0.2490, -0.2793, -0.0850,  ...,  0.4082, -2.8750, -1.3750],
         ...,
         [-0.5508, -0.0742,  0.0664,  ...,  3.8750, -3.5156, -2.2500],
         [-0.4023, -1.2109,  0.1289,  ...,  1.0547, -1.4141, -0.2832],
         [-0.3340, -1.6406, -1.1797,  ...,  2.3906, -1.9688,  0.0256]],

        ...,

        [[-0.0134,  0.9180, -0.1533,  ...,  2.3906, -3.7188,  1.6250],
         [ 0.9141, -0.7188, -0.4258,  ..., -0.6445, -0.3828,  1.3047],
         [-0.3672,  0.5781,  0.1060,  ...,  0.2412,  0.2207, -0.4941],
         ...,
         [-0.8477,  0.4551,  0.4844,  ..., -3.3906,  0.4434, -2.7500],
         [-0.4531,  0.4902,  0.3691,  ..., -2.9375, -4.0312, -0.2344],
         [ 0.0352, -0.5000, -0.5938,  ..., -2.5312,  1.0547, -2.2500]],

        [[-0.0134,  0.9180, -0.1533,  ...,  2.3906, -3.7188,  1.6250],
         [ 0.9141, -0.7188, -0.4258,  ..., -0.6445, -0.3828,  1.3047],
         [-0.3672,  0.5781,  0.1060,  ...,  0.2412,  0.2207, -0.4941],
         ...,
         [-0.8477,  0.4551,  0.4844,  ..., -3.3906,  0.4434, -2.7500],
         [-0.4531,  0.4902,  0.3691,  ..., -2.9375, -4.0312, -0.2344],
         [ 0.0352, -0.5000, -0.5938,  ..., -2.5312,  1.0547, -2.2500]],

        [[-0.0134,  0.9180, -0.1533,  ...,  2.3906, -3.7188,  1.6250],
         [ 0.9141, -0.7188, -0.4258,  ..., -0.6445, -0.3828,  1.3047],
         [-0.3672,  0.5781,  0.1060,  ...,  0.2412,  0.2207, -0.4941],
         ...,
         [-0.8477,  0.4551,  0.4844,  ..., -3.3906,  0.4434, -2.7500],
         [-0.4531,  0.4902,  0.3691,  ..., -2.9375, -4.0312, -0.2344],
         [ 0.0352, -0.5000, -0.5938,  ..., -2.5312,  1.0547, -2.2500]]],
       dtype=torch.bfloat16), tensor([[[ 0.4805, -0.1099, -0.3379,  ..., -0.0300,  0.0986,  0.2188],
         [-0.0181, -0.2637, -0.2461,  ..., -0.2080, -0.1777, -0.1348],
         [-0.4180, -0.0654,  0.0275,  ...,  0.0275, -0.0159,  0.0723],
         ...,
         [ 0.2324,  0.0518,  0.4141,  ..., -0.6719,  0.6680, -0.7969],
         [ 0.2227,  0.3457, -0.1729,  ..., -0.1299,  0.3965, -0.1816],
         [ 0.2949,  0.3594,  0.3672,  ...,  0.3105,  0.5742, -0.4805]],

        [[ 0.4805, -0.1099, -0.3379,  ..., -0.0300,  0.0986,  0.2188],
         [-0.0181, -0.2637, -0.2461,  ..., -0.2080, -0.1777, -0.1348],
         [-0.4180, -0.0654,  0.0275,  ...,  0.0275, -0.0159,  0.0723],
         ...,
         [ 0.2324,  0.0518,  0.4141,  ..., -0.6719,  0.6680, -0.7969],
         [ 0.2227,  0.3457, -0.1729,  ..., -0.1299,  0.3965, -0.1816],
         [ 0.2949,  0.3594,  0.3672,  ...,  0.3105,  0.5742, -0.4805]],

        [[ 0.4805, -0.1099, -0.3379,  ..., -0.0300,  0.0986,  0.2188],
         [-0.0181, -0.2637, -0.2461,  ..., -0.2080, -0.1777, -0.1348],
         [-0.4180, -0.0654,  0.0275,  ...,  0.0275, -0.0159,  0.0723],
         ...,
         [ 0.2324,  0.0518,  0.4141,  ..., -0.6719,  0.6680, -0.7969],
         [ 0.2227,  0.3457, -0.1729,  ..., -0.1299,  0.3965, -0.1816],
         [ 0.2949,  0.3594,  0.3672,  ...,  0.3105,  0.5742, -0.4805]],

        ...,

        [[ 0.3242, -0.4062, -0.1709,  ...,  0.2695, -0.5391, -0.3926],
         [ 0.2695, -0.2197,  0.2227,  ..., -0.2734, -0.5898,  0.0435],
         [ 0.8477, -0.6094, -0.0427,  ..., -0.1914, -0.5508, -0.7031],
         ...,
         [-0.9922,  1.2812, -0.1426,  ...,  0.1475,  0.4668,  0.4688],
         [-1.4141,  1.1797, -0.3125,  ...,  0.4570,  0.0796,  0.2695],
         [-0.1904, -0.0030, -0.6992,  ...,  0.3398, -0.4238,  0.1924]],

        [[ 0.3242, -0.4062, -0.1709,  ...,  0.2695, -0.5391, -0.3926],
         [ 0.2695, -0.2197,  0.2227,  ..., -0.2734, -0.5898,  0.0435],
         [ 0.8477, -0.6094, -0.0427,  ..., -0.1914, -0.5508, -0.7031],
         ...,
         [-0.9922,  1.2812, -0.1426,  ...,  0.1475,  0.4668,  0.4688],
         [-1.4141,  1.1797, -0.3125,  ...,  0.4570,  0.0796,  0.2695],
         [-0.1904, -0.0030, -0.6992,  ...,  0.3398, -0.4238,  0.1924]],

        [[ 0.3242, -0.4062, -0.1709,  ...,  0.2695, -0.5391, -0.3926],
         [ 0.2695, -0.2197,  0.2227,  ..., -0.2734, -0.5898,  0.0435],
         [ 0.8477, -0.6094, -0.0427,  ..., -0.1914, -0.5508, -0.7031],
         ...,
         [-0.9922,  1.2812, -0.1426,  ...,  0.1475,  0.4668,  0.4688],
         [-1.4141,  1.1797, -0.3125,  ...,  0.4570,  0.0796,  0.2695],
         [-0.1904, -0.0030, -0.6992,  ...,  0.3398, -0.4238,  0.1924]]],
       dtype=torch.bfloat16))), hidden_states=None, attentions=None, cross_attentions=None)
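The printout above is the end of the model's forward-pass output. It is a tuple holding one `(key, value)` pair of bfloat16 tensors per layer, presumably the `past_key_values` cache, followed by the `hidden_states`, `attentions` and `cross_attentions` fields, all `None`. Printing every value is hard to read. Within each tensor, consecutive blocks are identical, possibly because key/value heads are shared across several query heads, but the printout alone cannot confirm this. The sketch below reports only the per-layer shapes. It assumes each cache entry is a two-element `(key, value)` pair of objects with a `.shape` attribute, as `torch.Tensor` has. The helper names `summarize_past_key_values` and `print_cache_summary` are hypothetical and are not part of `transformers`.

```python
def summarize_past_key_values(past_key_values):
    """Return (layer_index, key_shape, value_shape) for each cached layer.

    Assumes `past_key_values` is an iterable of (key, value) pairs whose
    elements expose a `.shape` attribute, e.g. torch.Tensor objects.
    """
    summary = []
    for layer_idx, (key, value) in enumerate(past_key_values):
        # tuple() turns a torch.Size (or any shape sequence) into a plain tuple
        summary.append((layer_idx, tuple(key.shape), tuple(value.shape)))
    return summary


def print_cache_summary(past_key_values):
    """Print one compact line per layer instead of the full tensor values."""
    for layer_idx, key_shape, value_shape in summarize_past_key_values(past_key_values):
        print(f"layer {layer_idx:2d}: key {key_shape}, value {value_shape}")
```

With a real output object (hypothetically named `outputs`), the call would be `print_cache_summary(outputs.past_key_values)`. This prints one line per layer instead of tens of thousands of lines of values.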