This commit is contained in:
s444501 2023-02-13 00:02:31 +01:00
parent b4227f7732
commit e4cdb95c3d

View File

@ -68,7 +68,7 @@ class GPT2ClassificationHeadCustom(nn.Module):
def forward(self, x, **kwargs):
if 'hidden_states' in kwargs and kwargs['hidden_states'] is not None:
# Get last 5 hidden states from the end
hidden = torch.cat(kwargs['hidden_states'][-5:], dim=2)
hidden = torch.cat(kwargs['hidden_states'][-5:])
else:
hidden = torch.zeros(x.size(), dtype=x.dtype, device=x.device)