gpt test2

This commit is contained in:
s444501 2023-02-12 23:27:40 +01:00
parent 7ce408fb11
commit 5acc5cd96a

View File

@ -67,11 +67,11 @@ class GPT2ClassificationHeadCustom(nn.Module):
def forward(self, x, **kwargs): def forward(self, x, **kwargs):
if 'hidden_states' in kwargs and kwargs['hidden_states'] is not None: if 'hidden_states' in kwargs and kwargs['hidden_states'] is not None:
# Get hidden states from second from the end # Get hidden states from second from the end
logger.info('Hidden states found!') print('Hidden states found!')
logger.info(len(kwargs['hidden_states'])) print(len(kwargs['hidden_states']))
hidden = kwargs['hidden_states'][-2] hidden = kwargs['hidden_states'][-2]
else: else:
logger.info('no hidden states :(') print('no hidden states :(')
hidden = torch.zeros(x.size(), dtype=x.dtype, device=x.device) hidden = torch.zeros(x.size(), dtype=x.dtype, device=x.device)
x = self.dense_1_input(x) x = self.dense_1_input(x)