fix?
This commit is contained in:
parent
b49de54ba0
commit
b4227f7732
2
gpt2.py
2
gpt2.py
@ -58,7 +58,7 @@ class GPT2ClassificationHeadCustom(nn.Module):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
hidden_size = config.n_embd
|
hidden_size = config.n_embd
|
||||||
self.dense_1_input = nn.Linear(5 * hidden_size, 2 * hidden_size)
|
self.dense_1_input = nn.Linear(hidden_size, 2 * hidden_size)
|
||||||
self.dense_1_hidden = nn.Linear(hidden_size, 2 * hidden_size)
|
self.dense_1_hidden = nn.Linear(hidden_size, 2 * hidden_size)
|
||||||
self.dense_2 = nn.Linear(4 * hidden_size, 2 * hidden_size)
|
self.dense_2 = nn.Linear(4 * hidden_size, 2 * hidden_size)
|
||||||
self.dense_3 = nn.Linear(2 * hidden_size, hidden_size)
|
self.dense_3 = nn.Linear(2 * hidden_size, hidden_size)
|
||||||
|
Loading…
Reference in New Issue
Block a user