diff --git a/gpt2.py b/gpt2.py
index 462376c..e2d9307 100644
--- a/gpt2.py
+++ b/gpt2.py
@@ -59,7 +59,7 @@ class GPT2ClassificationHeadCustom(nn.Module):
         super().__init__()
         hidden_size = config.n_embd
         self.dense_1_input = nn.Linear(hidden_size, 2 * hidden_size)
-        self.dense_1_hidden = nn.Linear(hidden_size, 2 * hidden_size)
+        self.dense_1_hidden = nn.Linear(5 * hidden_size, 2 * hidden_size)
         self.dense_2 = nn.Linear(4 * hidden_size, 2 * hidden_size)
         self.dense_3 = nn.Linear(2 * hidden_size, hidden_size)
         self.dropout = nn.Dropout(config.resid_pdrop)
@@ -68,7 +68,7 @@ class GPT2ClassificationHeadCustom(nn.Module):
     def forward(self, x, **kwargs):
         if 'hidden_states' in kwargs and kwargs['hidden_states'] is not None:
             # Get last 5 hidden states from the end
-            hidden = torch.cat(kwargs['hidden_states'][-5:])
+            hidden = torch.cat(kwargs['hidden_states'][-5:], dim=2)
         else:
-            hidden = torch.zeros(x.size(), dtype=x.dtype, device=x.device)
+            hidden = torch.zeros(*x.size()[:-1], 5 * x.size(-1), dtype=x.dtype, device=x.device)