from typing import Optional, Union, Tuple

import torch
from torch import nn
from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss
from transformers import RobertaForSequenceClassification, RobertaModel
from transformers.modeling_outputs import SequenceClassifierOutput


# Simple version #

class RobertaClassificationHeadCustomSimple(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        self.dense_1 = nn.Linear(hidden_size, 4 * hidden_size)
        self.dense_2 = nn.Linear(4 * hidden_size, hidden_size)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.out_proj = nn.Linear(hidden_size, config.num_labels)
        self.activation = nn.GELU()

    def forward(self, features, **kwargs):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])

        x = self.dense_1(x)
        x = self.activation(x)
        x = self.dropout(x)

        x = self.dense_2(x)
        x = self.activation(x)
        x = self.dropout(x)

        x = self.out_proj(x)
        return x
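
# ---------------------------------------------------------------------------
# Hedged usage sketch, not part of the original file: a quick shape check for
# the head in isolation. The config values below are illustrative assumptions
# (RobertaConfig defaults, an assumed num_labels=3), not the author's choices.
# ---------------------------------------------------------------------------
def _demo_head_shapes():
    from transformers import RobertaConfig

    config = RobertaConfig(num_labels=3)  # assumed label count
    head = RobertaClassificationHeadCustomSimple(config)
    head.eval()  # disable dropout for a deterministic pass
    features = torch.randn(2, 10, config.hidden_size)  # [batch, seq, hidden]
    logits = head(features)  # <s>-token pooling -> [batch, num_labels]
    assert logits.shape == (2, 3)

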
class RobertaForSequenceClassificationCustomSimple(RobertaForSequenceClassification):
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # add_pooling_layer=False: the custom head pools the <s> token itself
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.classifier = RobertaClassificationHeadCustomSimple(config)

        # Initialize weights and apply final processing
        self.post_init()
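
# ---------------------------------------------------------------------------
# Hedged usage sketch, not part of the original file: loading the custom model
# from a pretrained checkpoint. "roberta-base" and num_labels=2 are
# illustrative assumptions; the encoder weights come from the checkpoint,
# while the custom head is freshly initialized by post_init().
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = RobertaForSequenceClassificationCustomSimple.from_pretrained(
        "roberta-base", num_labels=2
    )
    model.eval()
    input_ids = torch.randint(0, model.config.vocab_size, (1, 16))
    with torch.no_grad():
        outputs = model(input_ids=input_ids)
    print(outputs.logits.shape)  # expected: torch.Size([1, 2])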