Merge branch 'master' of https://git.wmi.amu.edu.pl/s444465/projekt-glebokie
This commit is contained in:
commit
5b438a884c
22
roberta.py
22
roberta.py
@ -7,6 +7,28 @@ from transformers import RobertaForSequenceClassification, RobertaModel
|
||||
from transformers.modeling_outputs import SequenceClassifierOutput
|
||||
|
||||
|
||||
class LeakyHeadCustom(nn.Module):
|
||||
"""Incorporates Leaky ReLU"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.dense_1 = nn.Linear(hidden_size, 2 * hidden_size)
|
||||
self.dense_2 = nn.Linear(2 * hidden_size, hidden_size)
|
||||
self.out_proj = nn.Linear(hidden_size, config.num_labels)
|
||||
|
||||
def forward(self, features, **kwargs):
|
||||
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
|
||||
|
||||
x = self.dense_1(x)
|
||||
x = torch.nn.LeakyReLU(x)
|
||||
|
||||
x = self.dense_2(x)
|
||||
x = torch.nn.LeakyReLU(x)
|
||||
|
||||
x = self.out_proj(x)
|
||||
return x
|
||||
|
||||
# Simple version #
|
||||
|
||||
class RobertaClassificationHeadCustomSimple(nn.Module):
|
||||
|
@ -49,10 +49,11 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
from roberta import RobertaForSequenceClassificationCustomSimple, RobertaForSequenceClassificationCustom, RobertaForSequenceClassificationCustomAlternative
|
||||
from roberta import LeakyHeadCustom, RobertaForSequenceClassificationCustomSimple, RobertaForSequenceClassificationCustom, RobertaForSequenceClassificationCustomAlternative
|
||||
from gpt2 import GPT2ForSequenceClassificationCustomSimple, GPT2ForSequenceClassificationCustom
|
||||
|
||||
MODEL_NAME_TO_CLASS = {
|
||||
'roberta_leaky': LeakyHeadCustom,
|
||||
'roberta_simple': RobertaForSequenceClassificationCustomSimple,
|
||||
'roberta_hidden': RobertaForSequenceClassificationCustom,
|
||||
'roberta_hidden_v2': RobertaForSequenceClassificationCustomAlternative,
|
||||
|
Loading…
Reference in New Issue
Block a user