From 0c3af3bedeef949caa83c6423e300105c90ad8b0 Mon Sep 17 00:00:00 2001 From: s444501 Date: Sun, 12 Feb 2023 23:05:40 +0100 Subject: [PATCH] gpt test --- gpt2.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/gpt2.py b/gpt2.py index 5dedd0f..b6cfef8 100644 --- a/gpt2.py +++ b/gpt2.py @@ -67,8 +67,11 @@ class GPT2ClassificationHeadCustom(nn.Module): def forward(self, x, **kwargs): if 'hidden_states' in kwargs and kwargs['hidden_states'] is not None: # Get hidden states from second from the end + logger.info('Hidden states found!') + logger.info(len(kwargs['hidden_states'])) hidden = kwargs['hidden_states'][-2] else: + logger.info('no hidden states :(') hidden = torch.zeros(x.size(), dtype=x.dtype, device=x.device) x = self.dense_1_input(x)