gpt test2
This commit is contained in:
parent
7ce408fb11
commit
5acc5cd96a
6
gpt2.py
6
gpt2.py
@ -67,11 +67,11 @@ class GPT2ClassificationHeadCustom(nn.Module):
|
|||||||
def forward(self, x, **kwargs):
|
def forward(self, x, **kwargs):
|
||||||
if 'hidden_states' in kwargs and kwargs['hidden_states'] is not None:
|
if 'hidden_states' in kwargs and kwargs['hidden_states'] is not None:
|
||||||
# Get hidden states from second from the end
|
# Get hidden states from second from the end
|
||||||
logger.info('Hidden states found!')
|
print('Hidden states found!')
|
||||||
logger.info(len(kwargs['hidden_states']))
|
print(len(kwargs['hidden_states']))
|
||||||
hidden = kwargs['hidden_states'][-2]
|
hidden = kwargs['hidden_states'][-2]
|
||||||
else:
|
else:
|
||||||
logger.info('no hidden states :(')
|
print('no hidden states :(')
|
||||||
hidden = torch.zeros(x.size(), dtype=x.dtype, device=x.device)
|
hidden = torch.zeros(x.size(), dtype=x.dtype, device=x.device)
|
||||||
|
|
||||||
x = self.dense_1_input(x)
|
x = self.dense_1_input(x)
|
||||||
|
Loading…
Reference in New Issue
Block a user