From e4cdb95c3da924f00edcfd785438e76ebf09c7d1 Mon Sep 17 00:00:00 2001 From: s444501 Date: Mon, 13 Feb 2023 00:02:31 +0100 Subject: [PATCH] fix? --- gpt2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpt2.py b/gpt2.py index b72bd2c..462376c 100644 --- a/gpt2.py +++ b/gpt2.py @@ -68,7 +68,7 @@ class GPT2ClassificationHeadCustom(nn.Module): def forward(self, x, **kwargs): if 'hidden_states' in kwargs and kwargs['hidden_states'] is not None: # Get last 5 hidden states from the end - hidden = torch.cat(kwargs['hidden_states'][-5:], dim=2) + hidden = torch.cat(kwargs['hidden_states'][-5:]) else: hidden = torch.zeros(x.size(), dtype=x.dtype, device=x.device)