Commit message: "fix?"
Parent: 9987da1656
Commit: b49de54ba0
gpt2.py — 4 changed lines (2 additions, 2 deletions)
@@ -58,7 +58,7 @@ class GPT2ClassificationHeadCustom(nn.Module):
     def __init__(self, config):
         super().__init__()
         hidden_size = config.n_embd
-        self.dense_1_input = nn.Linear(hidden_size, 2 * hidden_size)
+        self.dense_1_input = nn.Linear(5 * hidden_size, 2 * hidden_size)
         self.dense_1_hidden = nn.Linear(hidden_size, 2 * hidden_size)
         self.dense_2 = nn.Linear(4 * hidden_size, 2 * hidden_size)
         self.dense_3 = nn.Linear(2 * hidden_size, hidden_size)
@@ -68,7 +68,7 @@ class GPT2ClassificationHeadCustom(nn.Module):
     def forward(self, x, **kwargs):
         if 'hidden_states' in kwargs and kwargs['hidden_states'] is not None:
             # Get last 5 hidden states from the end
-            hidden = kwargs['hidden_states'][-5:]
+            hidden = torch.cat(kwargs['hidden_states'][-5:], dim=2)
         else:
             hidden = torch.zeros(x.size(), dtype=x.dtype, device=x.device)
Loading…
Reference in New Issue
Block a user