fix for real

This commit is contained in:
s444501 2023-02-13 00:08:30 +01:00
parent e4cdb95c3d
commit 24409ffb1b

View File

@@ -59,7 +59,7 @@ class GPT2ClassificationHeadCustom(nn.Module):
         super().__init__()
         hidden_size = config.n_embd
         self.dense_1_input = nn.Linear(hidden_size, 2 * hidden_size)
-        self.dense_1_hidden = nn.Linear(hidden_size, 2 * hidden_size)
+        self.dense_1_hidden = nn.Linear(5 * hidden_size, 2 * hidden_size)
         self.dense_2 = nn.Linear(4 * hidden_size, 2 * hidden_size)
         self.dense_3 = nn.Linear(2 * hidden_size, hidden_size)
         self.dropout = nn.Dropout(config.resid_pdrop)
@@ -68,7 +68,7 @@ class GPT2ClassificationHeadCustom(nn.Module):
     def forward(self, x, **kwargs):
         if 'hidden_states' in kwargs and kwargs['hidden_states'] is not None:
             # Get last 5 hidden states from the end
-            hidden = torch.cat(kwargs['hidden_states'][-5:])
+            hidden = torch.cat(kwargs['hidden_states'][-5:], dim=2)
         else:
             hidden = torch.zeros(x.size(), dtype=x.dtype, device=x.device)