This commit is contained in:
s444501 2023-02-12 23:05:40 +01:00
parent 5c6144dae0
commit 0c3af3bede

View File

@ -67,8 +67,11 @@ class GPT2ClassificationHeadCustom(nn.Module):
def forward(self, x, **kwargs):
if 'hidden_states' in kwargs and kwargs['hidden_states'] is not None:
# Get hidden states from second from the end
logger.info('Hidden states found!')
logger.info(len(kwargs['hidden_states']))
hidden = kwargs['hidden_states'][-2]
else:
logger.info('no hidden states :(')
hidden = torch.zeros(x.size(), dtype=x.dtype, device=x.device)
x = self.dense_1_input(x)