This commit is contained in:
s444501 2023-02-13 00:01:30 +01:00
parent b49de54ba0
commit b4227f7732

View File

@ -58,7 +58,7 @@ class GPT2ClassificationHeadCustom(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
hidden_size = config.n_embd hidden_size = config.n_embd
self.dense_1_input = nn.Linear(5 * hidden_size, 2 * hidden_size) self.dense_1_input = nn.Linear(hidden_size, 2 * hidden_size)
self.dense_1_hidden = nn.Linear(hidden_size, 2 * hidden_size) self.dense_1_hidden = nn.Linear(hidden_size, 2 * hidden_size)
self.dense_2 = nn.Linear(4 * hidden_size, 2 * hidden_size) self.dense_2 = nn.Linear(4 * hidden_size, 2 * hidden_size)
self.dense_3 = nn.Linear(2 * hidden_size, hidden_size) self.dense_3 = nn.Linear(2 * hidden_size, hidden_size)