update run_glue
This commit is contained in:
parent
a0a10fe18e
commit
64a97b5fbb
67
run_glue.py
67
run_glue.py
@ -20,6 +20,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@ -47,6 +48,17 @@ from transformers.utils import check_min_version, send_example_telemetry
|
|||||||
from transformers.utils.versions import require_version
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
|
||||||
|
from roberta import RobertaForSequenceClassificationCustomSimple, RobertaForSequenceClassificationCustom, RobertaForSequenceClassificationCustomAlternative
|
||||||
|
from gpt2 import GPT2ForSequenceClassificationCustomSimple, GPT2ForSequenceClassificationCustom
|
||||||
|
|
||||||
|
MODEL_NAME_TO_CLASS = {
|
||||||
|
'roberta_simple': RobertaForSequenceClassificationCustomSimple,
|
||||||
|
'roberta_hidden': RobertaForSequenceClassificationCustom,
|
||||||
|
'roberta_hidden_v2': RobertaForSequenceClassificationCustomAlternative,
|
||||||
|
'gpt2_simple': GPT2ForSequenceClassificationCustomSimple,
|
||||||
|
'gpt2_hidden': GPT2ForSequenceClassificationCustom,
|
||||||
|
}
|
||||||
|
|
||||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||||
check_min_version("4.23.0")
|
check_min_version("4.23.0")
|
||||||
|
|
||||||
@ -201,6 +213,13 @@ class ModelArguments:
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
|
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
|
||||||
)
|
)
|
||||||
|
custom_model: str = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Use custom implementation from available list",
|
||||||
|
"choices": list(MODEL_NAME_TO_CLASS.keys()),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@ -366,7 +385,25 @@ def main():
|
|||||||
revision=model_args.model_revision,
|
revision=model_args.model_revision,
|
||||||
use_auth_token=True if model_args.use_auth_token else None,
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
)
|
)
|
||||||
model = AutoModelForSequenceClassification.from_pretrained(
|
|
||||||
|
custom_model = model_args.custom_model
|
||||||
|
if custom_model is not None:
|
||||||
|
# Check model and implementation is the same
|
||||||
|
if 'roberta' in custom_model and 'roberta' not in model_args.model_name_or_path:
|
||||||
|
raise RuntimeError('Model and custom implementation should be the same type: RoBERTa')
|
||||||
|
elif 'gpt2' in custom_model and 'gpt2' not in model_args.model_name_or_path:
|
||||||
|
raise RuntimeError('Model and custom implementation should be the same type: GPT-2')
|
||||||
|
|
||||||
|
# Set custom configuration in model configuration
|
||||||
|
config.use_hidden_states = 'hidden' in custom_model
|
||||||
|
logger.info(f'Using hidden states in model: {config.use_hidden_states}')
|
||||||
|
|
||||||
|
# Get class to initialize model
|
||||||
|
model_cls = MODEL_NAME_TO_CLASS[custom_model]
|
||||||
|
else:
|
||||||
|
model_cls = AutoModelForSequenceClassification
|
||||||
|
logger.info(f'Using implementation from class: {model_cls.__name__}')
|
||||||
|
model = model_cls.from_pretrained(
|
||||||
model_args.model_name_or_path,
|
model_args.model_name_or_path,
|
||||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||||
config=config,
|
config=config,
|
||||||
@ -376,6 +413,11 @@ def main():
|
|||||||
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
|
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if 'gpt2' in tokenizer.name_or_path and tokenizer.pad_token is None:
|
||||||
|
logger.info(f'Set PAD token to EOS: {tokenizer.eos_token}')
|
||||||
|
tokenizer._pad_token = tokenizer.eos_token
|
||||||
|
model.config.pad_token_id = model.config.eos_token_id
|
||||||
|
|
||||||
# Preprocessing the raw_datasets
|
# Preprocessing the raw_datasets
|
||||||
if data_args.task_name is not None:
|
if data_args.task_name is not None:
|
||||||
sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
|
sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
|
||||||
@ -436,11 +478,6 @@ def main():
|
|||||||
args = (
|
args = (
|
||||||
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
|
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
|
||||||
)
|
)
|
||||||
if 'gpt2' in tokenizer.name_or_path and tokenizer.pad_token is None:
|
|
||||||
logger.info(f'Set PAD token to EOS: {tokenizer.eos_token}')
|
|
||||||
tokenizer._pad_token = tokenizer.eos_token
|
|
||||||
model.config.pad_token_id = model.config.eos_token_id
|
|
||||||
|
|
||||||
result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
|
result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
|
||||||
|
|
||||||
# Map labels to IDs (not necessary for GLUE tasks)
|
# Map labels to IDs (not necessary for GLUE tasks)
|
||||||
@ -469,7 +506,16 @@ def main():
|
|||||||
eval_dataset = raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
|
eval_dataset = raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
|
||||||
if data_args.max_eval_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
||||||
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
label_to_indexes = defaultdict(list)
|
||||||
|
for index, eval_sample in enumerate(eval_dataset):
|
||||||
|
label_to_indexes[eval_sample['label']].append(index)
|
||||||
|
max_samples_per_label = int(max_eval_samples / len(label_to_indexes))
|
||||||
|
eval_sample_indexes = []
|
||||||
|
for label, indexes in label_to_indexes.items():
|
||||||
|
eval_sample_indexes.extend(indexes[:max_samples_per_label])
|
||||||
|
logger.info(f"Set {max_samples_per_label} samples for {label}-class")
|
||||||
|
eval_sample_indexes.sort()
|
||||||
|
eval_dataset = eval_dataset.select(eval_sample_indexes)
|
||||||
|
|
||||||
if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
|
if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
|
||||||
if "test" not in raw_datasets and "test_matched" not in raw_datasets:
|
if "test" not in raw_datasets and "test_matched" not in raw_datasets:
|
||||||
@ -526,13 +572,14 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
|
ignore_keys_for_eval = ['hidden_states', 'attentions', 'past_key_values']
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
checkpoint = None
|
checkpoint = None
|
||||||
if training_args.resume_from_checkpoint is not None:
|
if training_args.resume_from_checkpoint is not None:
|
||||||
checkpoint = training_args.resume_from_checkpoint
|
checkpoint = training_args.resume_from_checkpoint
|
||||||
elif last_checkpoint is not None:
|
elif last_checkpoint is not None:
|
||||||
checkpoint = last_checkpoint
|
checkpoint = last_checkpoint
|
||||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
train_result = trainer.train(resume_from_checkpoint=checkpoint, ignore_keys_for_eval=ignore_keys_for_eval)
|
||||||
metrics = train_result.metrics
|
metrics = train_result.metrics
|
||||||
max_train_samples = (
|
max_train_samples = (
|
||||||
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
|
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
|
||||||
@ -562,7 +609,7 @@ def main():
|
|||||||
combined = {}
|
combined = {}
|
||||||
|
|
||||||
for eval_dataset, task in zip(eval_datasets, tasks):
|
for eval_dataset, task in zip(eval_datasets, tasks):
|
||||||
metrics = trainer.evaluate(eval_dataset=eval_dataset)
|
metrics = trainer.evaluate(eval_dataset=eval_dataset, ignore_keys=ignore_keys_for_eval)
|
||||||
|
|
||||||
max_eval_samples = (
|
max_eval_samples = (
|
||||||
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||||
@ -590,7 +637,7 @@ def main():
|
|||||||
for predict_dataset, task in zip(predict_datasets, tasks):
|
for predict_dataset, task in zip(predict_datasets, tasks):
|
||||||
# Removing the `label` columns because it contains -1 and Trainer won't like that.
|
# Removing the `label` columns because it contains -1 and Trainer won't like that.
|
||||||
predict_dataset = predict_dataset.remove_columns("label")
|
predict_dataset = predict_dataset.remove_columns("label")
|
||||||
predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
|
predictions = trainer.predict(predict_dataset, metric_key_prefix="predict", ignore_keys=ignore_keys_for_eval).predictions
|
||||||
predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
|
predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
|
||||||
|
|
||||||
output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt")
|
output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt")
|
||||||
|
Loading…
Reference in New Issue
Block a user