roberta custom head test
This commit is contained in:
parent
3037991562
commit
256441a729
22
roberta.py
22
roberta.py
@ -7,6 +7,28 @@ from transformers import RobertaForSequenceClassification, RobertaModel
|
|||||||
from transformers.modeling_outputs import SequenceClassifierOutput
|
from transformers.modeling_outputs import SequenceClassifierOutput
|
||||||
|
|
||||||
|
|
||||||
|
class LeakyHeadCustom(nn.Module):
|
||||||
|
"""Incorporates Leaky ReLU"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
hidden_size = config.hidden_size
|
||||||
|
self.dense_1 = nn.Linear(hidden_size, 2 * hidden_size)
|
||||||
|
self.dense_2 = nn.Linear(2 * hidden_size, hidden_size)
|
||||||
|
self.out_proj = nn.Linear(hidden_size, config.num_labels)
|
||||||
|
|
||||||
|
def forward(self, features, **kwargs):
|
||||||
|
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
|
||||||
|
|
||||||
|
x = self.dense_1(x)
|
||||||
|
x = torch.nn.LeakyReLU(x)
|
||||||
|
|
||||||
|
x = self.dense_2(x)
|
||||||
|
x = torch.nn.LeakyReLU(x)
|
||||||
|
|
||||||
|
x = self.out_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
# Simple version #
|
# Simple version #
|
||||||
|
|
||||||
class RobertaClassificationHeadCustomSimple(nn.Module):
|
class RobertaClassificationHeadCustomSimple(nn.Module):
|
||||||
|
@ -49,10 +49,11 @@ from transformers.utils import check_min_version, send_example_telemetry
|
|||||||
from transformers.utils.versions import require_version
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
|
||||||
from roberta import RobertaForSequenceClassificationCustomSimple, RobertaForSequenceClassificationCustom, RobertaForSequenceClassificationCustomAlternative
|
from roberta import LeakyHeadCustom, RobertaForSequenceClassificationCustomSimple, RobertaForSequenceClassificationCustom, RobertaForSequenceClassificationCustomAlternative
|
||||||
from gpt2 import GPT2ForSequenceClassificationCustomSimple, GPT2ForSequenceClassificationCustom
|
from gpt2 import GPT2ForSequenceClassificationCustomSimple, GPT2ForSequenceClassificationCustom
|
||||||
|
|
||||||
MODEL_NAME_TO_CLASS = {
|
MODEL_NAME_TO_CLASS = {
|
||||||
|
'roberta_leaky': LeakyHeadCustom,
|
||||||
'roberta_simple': RobertaForSequenceClassificationCustomSimple,
|
'roberta_simple': RobertaForSequenceClassificationCustomSimple,
|
||||||
'roberta_hidden': RobertaForSequenceClassificationCustom,
|
'roberta_hidden': RobertaForSequenceClassificationCustom,
|
||||||
'roberta_hidden_v2': RobertaForSequenceClassificationCustomAlternative,
|
'roberta_hidden_v2': RobertaForSequenceClassificationCustomAlternative,
|
||||||
|
Loading…
Reference in New Issue
Block a user