roberta custom head test

This commit is contained in:
s444501 2023-02-12 19:03:42 +01:00
parent 3037991562
commit 256441a729
2 changed files with 24 additions and 1 deletions

View File

@ -7,6 +7,28 @@ from transformers import RobertaForSequenceClassification, RobertaModel
from transformers.modeling_outputs import SequenceClassifierOutput from transformers.modeling_outputs import SequenceClassifierOutput
class LeakyHeadCustom(nn.Module):
    """Classification head with two hidden layers and Leaky ReLU activations.

    The head reads the first position of the sequence dimension, expands it
    to ``2 * hidden_size``, projects it back to ``hidden_size`` and finally
    maps it to ``config.num_labels`` outputs.

    Args:
        config: Object exposing ``hidden_size`` and ``num_labels``.
    """

    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        self.dense_1 = nn.Linear(hidden_size, 2 * hidden_size)
        self.dense_2 = nn.Linear(2 * hidden_size, hidden_size)
        # Parameter-free module, so adding it does not change the state dict.
        self.activation = nn.LeakyReLU()
        self.out_proj = nn.Linear(hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        """Compute the head output for ``features``.

        Args:
            features: Rank-3 tensor indexed as ``[batch, position, hidden]``;
                only position 0 is used.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Tensor of shape ``(batch, config.num_labels)``.
        """
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dense_1(x)
        # Apply the activation to the tensor. Calling ``torch.nn.LeakyReLU(x)``
        # would construct a new module (with ``x`` as its negative slope)
        # instead of transforming ``x``.
        x = self.activation(x)
        x = self.dense_2(x)
        x = self.activation(x)
        x = self.out_proj(x)
        return x
# Simple version # # Simple version #
class RobertaClassificationHeadCustomSimple(nn.Module): class RobertaClassificationHeadCustomSimple(nn.Module):

View File

@ -49,10 +49,11 @@ from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from roberta import RobertaForSequenceClassificationCustomSimple, RobertaForSequenceClassificationCustom, RobertaForSequenceClassificationCustomAlternative from roberta import LeakyHeadCustom, RobertaForSequenceClassificationCustomSimple, RobertaForSequenceClassificationCustom, RobertaForSequenceClassificationCustomAlternative
from gpt2 import GPT2ForSequenceClassificationCustomSimple, GPT2ForSequenceClassificationCustom from gpt2 import GPT2ForSequenceClassificationCustomSimple, GPT2ForSequenceClassificationCustom
MODEL_NAME_TO_CLASS = { MODEL_NAME_TO_CLASS = {
'roberta_leaky': LeakyHeadCustom,
'roberta_simple': RobertaForSequenceClassificationCustomSimple, 'roberta_simple': RobertaForSequenceClassificationCustomSimple,
'roberta_hidden': RobertaForSequenceClassificationCustom, 'roberta_hidden': RobertaForSequenceClassificationCustom,
'roberta_hidden_v2': RobertaForSequenceClassificationCustomAlternative, 'roberta_hidden_v2': RobertaForSequenceClassificationCustomAlternative,