diff --git a/custom_gpt.py b/custom_gpt.py
new file mode 100644
index 0000000..287622d
--- /dev/null
+++ b/custom_gpt.py
@@ -0,0 +1,186 @@
+import logging
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss
+from transformers import GPT2Model, GPT2ForSequenceClassification
+from transformers.modeling_outputs import SequenceClassifierOutputWithPast
+
+
+logger = logging.getLogger(__name__)
+
+
+class GPT2ClassificationHeadCustomFIX(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        hidden_size = config.n_embd
+        self.dense_1_input = nn.Linear(hidden_size, 2 * hidden_size)
+        self.dense_1_hidden = nn.Linear(hidden_size, 2 * hidden_size)
+
+        self.dense_2 = nn.Linear(2 * hidden_size, 2 * hidden_size)
+        self.dense_2_hidden = nn.Linear(hidden_size, 2 * hidden_size)
+
+        self.dense_3 = nn.Linear(2 * hidden_size, 2 * hidden_size)
+        self.dense_3_hidden = nn.Linear(hidden_size, 2 * hidden_size)
+
+        self.dense_4 = nn.Linear(2 * hidden_size, hidden_size)
+
+        self.dropout = nn.Dropout(config.resid_pdrop)
+        self.out_proj = nn.Linear(hidden_size, config.num_labels, bias=False)
+
+    def forward(self, x, **kwargs):
+        if 'hidden_states' in kwargs and kwargs['hidden_states'] is not None:
+            hidden_a = torch.cat(kwargs['hidden_states'][-2:])  # stack the last two hidden-state tensors along dim 0
+        else:
+            hidden_a = torch.zeros(x.size(), dtype=x.dtype, device=x.device)
+
+        x = self.dense_1_input(x)
+        x = torch.relu(x)
+        x = self.dropout(x)
+
+        hidden = self.dense_1_hidden(hidden_a)
+        hidden = torch.relu(hidden)
+        hidden = self.dropout(hidden)
+
+        x = torch.cat((x, hidden))
+
+        x = self.dense_2(x)
+        x = torch.relu(x)
+        x = self.dropout(x)
+
+        hidden = self.dense_2_hidden(hidden_a)
+        hidden = torch.relu(hidden)
+        hidden = self.dropout(hidden)
+
+        x = torch.cat((x, hidden))
+
+        x = self.dense_3(x)
+        x = torch.relu(x)
+        x = self.dropout(x)
+
+        hidden = self.dense_3_hidden(hidden_a)
+        hidden = torch.relu(hidden)
+        hidden = self.dropout(hidden)
+
+        x = torch.cat((x, hidden))
+
+        x = self.dense_4(x)
+        x = torch.relu(x)
+        x = self.dropout(x)
+
+        x = self.out_proj(x)
+        return x
+
+
+class GPT2ForSequenceClassificationCustomFIX(GPT2ForSequenceClassification):
+    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.transformer = GPT2Model(config)
+        self.score = GPT2ClassificationHeadCustomFIX(config)
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states or self.config.use_hidden_states,  # requires the custom `use_hidden_states` flag on the config
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+        logits = self.score(hidden_states, hidden_states=transformer_outputs.hidden_states)
+
+        if input_ids is not None:
+            batch_size, sequence_length = input_ids.shape[:2]
+        else:
+            batch_size, sequence_length = inputs_embeds.shape[:2]
+
+        assert (
+            self.config.pad_token_id is not None or batch_size == 1
+        ), "Cannot handle batch sizes > 1 if no padding token is defined."
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+            else:
+                sequence_lengths = -1
+                logger.warning(
+                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+                )
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
diff --git a/custom_roberta.py b/custom_roberta.py
new file mode 100644
index 0000000..89d36fb
--- /dev/null
+++ b/custom_roberta.py
@@ -0,0 +1,145 @@
+from typing import Optional, Union, Tuple
+
+import torch
+from torch import nn
+from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss
+from transformers import RobertaForSequenceClassification, RobertaModel
+from transformers.modeling_outputs import SequenceClassifierOutput
+
+class RobertaClassificationHeadCustomFIX(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config):
+        super().__init__()
+        hidden_size = config.hidden_size
+
+        self.dense_1_input = nn.Linear(hidden_size, 8 * hidden_size)
+        self.dense_1_hidden = nn.Linear(hidden_size, 8 * hidden_size)
+        self.dense_2 = nn.Linear(16 * hidden_size, 8 * hidden_size)
+        self.dense_3 = nn.Linear(8 * hidden_size, 4 * hidden_size)
+        self.dense_4 = nn.Linear(4 * hidden_size, hidden_size)
+
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.relu = nn.LeakyReLU()
+        self.out_proj = nn.Linear(hidden_size, config.num_labels)
+
+    def forward(self, features, **kwargs):
+        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+        if 'hidden_states' in kwargs and kwargs['hidden_states'] is not None:
+            # take <s> token (equiv. to [CLS]) from the second-to-last layer's hidden states
+            hidden = kwargs['hidden_states'][-2][:, 0, :]
+        else:
+            hidden = torch.zeros(x.size(), dtype=x.dtype, device=x.device)
+
+        x = self.dense_1_input(x)
+        x = self.relu(x)
+        x = self.dropout(x)
+
+        hidden = self.dense_1_hidden(hidden)
+        hidden = self.relu(hidden)
+        hidden = self.dropout(hidden)
+
+        x = torch.cat((x, hidden), dim=1)
+        x = self.dense_2(x)
+        x = self.relu(x)
+        x = self.dropout(x)
+
+        x = self.dense_3(x)
+        x = self.relu(x)
+        x = self.dropout(x)
+
+        x = self.dense_4(x)
+        x = self.relu(x)
+        x = self.dropout(x)
+
+        x = self.out_proj(x)
+        return x
+
+
+class RobertaForSequenceClassificationCustomFIX(RobertaForSequenceClassification):
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+
+        self.roberta = RobertaModel(config, add_pooling_layer=False)
+        self.classifier = RobertaClassificationHeadCustomFIX(config)
+
+        self.init_weights()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states or self.config.use_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + if return_dict: + logits = self.classifier(sequence_output, hidden_states=outputs.hidden_states) + else: + raise NotImplemented('Not implemented for using non-dictionary object') + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/run_glue.py b/run_glue.py index 4e0b039..4af29f3 100644 --- a/run_glue.py +++ b/run_glue.py @@ -45,12 +45,12 @@ from transformers import ( from transformers.trainer_utils import get_last_checkpoint from transformers.utils import check_min_version, send_example_telemetry from transformers.utils.versions import require_version -from gpt2 import GPT2ForSequenceClassificationCustom -from roberta import RobertaForSequenceClassificationCustomAlternative +from custom_gpt import GPT2ForSequenceClassificationCustomFIX +from custom_roberta import RobertaForSequenceClassificationCustomFIX MODEL_NAME_TO_CLASS = { - 'roberta_custom': RobertaForSequenceClassificationCustomAlternative, - 'gpt2_custom': GPT2ForSequenceClassificationCustom + 'roberta_custom': RobertaForSequenceClassificationCustomFIX, + 'gpt2_custom': GPT2ForSequenceClassificationCustomFIX } # Will error if the minimal version of Transformers is not installed. Remove at your own risks. check_min_version("4.23.0")