From 256441a7292dcc0937806c0178f415491ad10ddf Mon Sep 17 00:00:00 2001 From: s444501 Date: Sun, 12 Feb 2023 19:03:42 +0100 Subject: [PATCH] roberta custom head test --- roberta.py | 22 ++++++++++++++++++++++ run_glue.py | 3 ++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/roberta.py b/roberta.py index d58e15a..8b1e325 100644 --- a/roberta.py +++ b/roberta.py @@ -7,6 +7,28 @@ from transformers import RobertaForSequenceClassification, RobertaModel from transformers.modeling_outputs import SequenceClassifierOutput +class LeakyHeadCustom(nn.Module): + """Incorporates Leaky ReLU""" + + def __init__(self, config): + super().__init__() + hidden_size = config.hidden_size + self.dense_1 = nn.Linear(hidden_size, 2 * hidden_size) + self.dense_2 = nn.Linear(2 * hidden_size, hidden_size) + self.out_proj = nn.Linear(hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + + x = self.dense_1(x) + x = torch.nn.LeakyReLU(x) + + x = self.dense_2(x) + x = torch.nn.LeakyReLU(x) + + x = self.out_proj(x) + return x + # Simple version # class RobertaClassificationHeadCustomSimple(nn.Module): diff --git a/run_glue.py b/run_glue.py index 9cc5cff..69bee14 100644 --- a/run_glue.py +++ b/run_glue.py @@ -49,10 +49,11 @@ from transformers.utils import check_min_version, send_example_telemetry from transformers.utils.versions import require_version -from roberta import RobertaForSequenceClassificationCustomSimple, RobertaForSequenceClassificationCustom, RobertaForSequenceClassificationCustomAlternative +from roberta import LeakyHeadCustom, RobertaForSequenceClassificationCustomSimple, RobertaForSequenceClassificationCustom, RobertaForSequenceClassificationCustomAlternative from gpt2 import GPT2ForSequenceClassificationCustomSimple, GPT2ForSequenceClassificationCustom MODEL_NAME_TO_CLASS = { + 'roberta_leaky': LeakyHeadCustom, 'roberta_simple': RobertaForSequenceClassificationCustomSimple, 'roberta_hidden': RobertaForSequenceClassificationCustom, 'roberta_hidden_v2': RobertaForSequenceClassificationCustomAlternative,