diff --git a/gpt2.py b/gpt2.py index b72bd2c..462376c 100644 --- a/gpt2.py +++ b/gpt2.py @@ -68,7 +68,7 @@ class GPT2ClassificationHeadCustom(nn.Module): def forward(self, x, **kwargs): if 'hidden_states' in kwargs and kwargs['hidden_states'] is not None: # Get last 5 hidden states from the end - hidden = torch.cat(kwargs['hidden_states'][-5:], dim=2) + hidden = torch.cat(kwargs['hidden_states'][-5:]) else: hidden = torch.zeros(x.size(), dtype=x.dtype, device=x.device)