GPT-2 classification head: use the last 5 hidden states and add a new dense layer
This commit is contained in:
parent
5acc5cd96a
commit
858ef8fc58
14
gpt2.py
14
gpt2.py
@ -60,18 +60,16 @@ class GPT2ClassificationHeadCustom(nn.Module):
|
|||||||
hidden_size = config.n_embd
|
hidden_size = config.n_embd
|
||||||
self.dense_1_input = nn.Linear(hidden_size, 2 * hidden_size)
|
self.dense_1_input = nn.Linear(hidden_size, 2 * hidden_size)
|
||||||
self.dense_1_hidden = nn.Linear(hidden_size, 2 * hidden_size)
|
self.dense_1_hidden = nn.Linear(hidden_size, 2 * hidden_size)
|
||||||
self.dense_2 = nn.Linear(4 * hidden_size, hidden_size)
|
self.dense_2 = nn.Linear(4 * hidden_size, 2 * hidden_size)
|
||||||
|
self.dense_3 = nn.Linear(2 * hidden_size, hidden_size)
|
||||||
self.dropout = nn.Dropout(config.resid_pdrop)
|
self.dropout = nn.Dropout(config.resid_pdrop)
|
||||||
self.out_proj = nn.Linear(hidden_size, config.num_labels, bias=False)
|
self.out_proj = nn.Linear(hidden_size, config.num_labels, bias=False)
|
||||||
|
|
||||||
def forward(self, x, **kwargs):
|
def forward(self, x, **kwargs):
|
||||||
if 'hidden_states' in kwargs and kwargs['hidden_states'] is not None:
|
if 'hidden_states' in kwargs and kwargs['hidden_states'] is not None:
|
||||||
# Get hidden states from second from the end
|
# Get last 5 hidden states from the end
|
||||||
print('Hidden states found!')
|
hidden = kwargs['hidden_states'][-5:]
|
||||||
print(len(kwargs['hidden_states']))
|
|
||||||
hidden = kwargs['hidden_states'][-2]
|
|
||||||
else:
|
else:
|
||||||
print('no hidden states :(')
|
|
||||||
hidden = torch.zeros(x.size(), dtype=x.dtype, device=x.device)
|
hidden = torch.zeros(x.size(), dtype=x.dtype, device=x.device)
|
||||||
|
|
||||||
x = self.dense_1_input(x)
|
x = self.dense_1_input(x)
|
||||||
@ -87,9 +85,7 @@ class GPT2ClassificationHeadCustom(nn.Module):
|
|||||||
x = torch.relu(x)
|
x = torch.relu(x)
|
||||||
x = self.dropout(x)
|
x = self.dropout(x)
|
||||||
|
|
||||||
x = torch.relu(x)
|
x = self.dense_3(x)
|
||||||
x = self.dropout(x)
|
|
||||||
|
|
||||||
x = torch.relu(x)
|
x = torch.relu(x)
|
||||||
x = self.dropout(x)
|
x = self.dropout(x)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user