This commit is contained in:
Andrzej Preibisz 2023-02-12 17:18:31 +01:00
commit 74117e2a3f

View File

@ -20,6 +20,7 @@ import logging
import os
import random
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Optional
@ -47,6 +48,17 @@ from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from roberta import RobertaForSequenceClassificationCustomSimple, RobertaForSequenceClassificationCustom, RobertaForSequenceClassificationCustomAlternative
from gpt2 import GPT2ForSequenceClassificationCustomSimple, GPT2ForSequenceClassificationCustom
# Registry of custom sequence-classification implementations, keyed by the
# value passed as the `custom_model` argument. main() looks the key up here
# to pick the class whose `from_pretrained` is called.
# The key text is meaningful to main(): it must contain 'roberta' or 'gpt2'
# (checked against `model_name_or_path`), and a key containing 'hidden'
# sets `config.use_hidden_states` to True.
MODEL_NAME_TO_CLASS = {
'roberta_simple': RobertaForSequenceClassificationCustomSimple,
'roberta_hidden': RobertaForSequenceClassificationCustom,
'roberta_hidden_v2': RobertaForSequenceClassificationCustomAlternative,
'gpt2_simple': GPT2ForSequenceClassificationCustomSimple,
'gpt2_hidden': GPT2ForSequenceClassificationCustom,
}
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.23.0")
@ -207,6 +219,13 @@ class ModelArguments:
metadata={"help": "Freeze encoder weights"},
)
# Optional key into MODEL_NAME_TO_CLASS selecting a custom model class.
# With the default of None, main() falls back to AutoModelForSequenceClassification.
# NOTE(review): the annotation is `str` although the default is None;
# `Optional[str]` would describe it more accurately — confirm the argument
# parser treats both the same before changing it.
custom_model: str = field(
default=None,
metadata={
"help": "Use custom implementation from available list",
"choices": list(MODEL_NAME_TO_CLASS.keys()),
},
)
def freeze_model_weights(model: torch.nn.Module) -> None:
count = 0
@ -384,7 +403,25 @@ def main():
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
model = AutoModelForSequenceClassification.from_pretrained(
custom_model = model_args.custom_model
if custom_model is not None:
# Check model and implementation is the same
if 'roberta' in custom_model and 'roberta' not in model_args.model_name_or_path:
raise RuntimeError('Model and custom implementation should be the same type: RoBERTa')
elif 'gpt2' in custom_model and 'gpt2' not in model_args.model_name_or_path:
raise RuntimeError('Model and custom implementation should be the same type: GPT-2')
# Set custom configuration in model configuration
config.use_hidden_states = 'hidden' in custom_model
logger.info(f'Using hidden states in model: {config.use_hidden_states}')
# Get class to initialize model
model_cls = MODEL_NAME_TO_CLASS[custom_model]
else:
model_cls = AutoModelForSequenceClassification
logger.info(f'Using implementation from class: {model_cls.__name__}')
model = model_cls.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
@ -399,6 +436,11 @@ def main():
freeze_model_weights(model.decoder)
if 'gpt2' in tokenizer.name_or_path and tokenizer.pad_token is None:
logger.info(f'Set PAD token to EOS: {tokenizer.eos_token}')
tokenizer._pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
# Preprocessing the raw_datasets
if data_args.task_name is not None:
sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
@ -459,11 +501,6 @@ def main():
args = (
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
)
if 'gpt2' in tokenizer.name_or_path and tokenizer.pad_token is None:
logger.info(f'Set PAD token to EOS: {tokenizer.eos_token}')
tokenizer._pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
# Map labels to IDs (not necessary for GLUE tasks)
@ -492,7 +529,16 @@ def main():
eval_dataset = raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
if data_args.max_eval_samples is not None:
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
label_to_indexes = defaultdict(list)
for index, eval_sample in enumerate(eval_dataset):
label_to_indexes[eval_sample['label']].append(index)
max_samples_per_label = int(max_eval_samples / len(label_to_indexes))
eval_sample_indexes = []
for label, indexes in label_to_indexes.items():
eval_sample_indexes.extend(indexes[:max_samples_per_label])
logger.info(f"Set {max_samples_per_label} samples for {label}-class")
eval_sample_indexes.sort()
eval_dataset = eval_dataset.select(eval_sample_indexes)
if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
if "test" not in raw_datasets and "test_matched" not in raw_datasets:
@ -549,13 +595,14 @@ def main():
)
# Training
ignore_keys_for_eval = ['hidden_states', 'attentions', 'past_key_values']
if training_args.do_train:
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
train_result = trainer.train(resume_from_checkpoint=checkpoint, ignore_keys_for_eval=ignore_keys_for_eval)
metrics = train_result.metrics
max_train_samples = (
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
@ -585,7 +632,7 @@ def main():
combined = {}
for eval_dataset, task in zip(eval_datasets, tasks):
metrics = trainer.evaluate(eval_dataset=eval_dataset)
metrics = trainer.evaluate(eval_dataset=eval_dataset, ignore_keys=ignore_keys_for_eval)
max_eval_samples = (
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
@ -613,7 +660,7 @@ def main():
for predict_dataset, task in zip(predict_datasets, tasks):
# Removing the `label` column because it contains -1 and Trainer won't like that.
predict_dataset = predict_dataset.remove_columns("label")
predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
predictions = trainer.predict(predict_dataset, metric_key_prefix="predict", ignore_keys=ignore_keys_for_eval).predictions
predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt")