gpt test2

This commit is contained in:
s444501 2023-02-12 23:27:40 +01:00
parent 7ce408fb11
commit 5acc5cd96a

View File

@ -67,11 +67,11 @@ class GPT2ClassificationHeadCustom(nn.Module):
def forward(self, x, **kwargs):
if 'hidden_states' in kwargs and kwargs['hidden_states'] is not None:
# Get hidden states from second from the end
logger.info('Hidden states found!')
logger.info(len(kwargs['hidden_states']))
print('Hidden states found!')
print(len(kwargs['hidden_states']))
hidden = kwargs['hidden_states'][-2]
else:
logger.info('no hidden states :(')
print('no hidden states :(')
hidden = torch.zeros(x.size(), dtype=x.dtype, device=x.device)
x = self.dense_1_input(x)