Merge branch 'master' of https://git.wmi.amu.edu.pl/s444465/projekt-glebokie
This commit is contained in:
commit
74117e2a3f
67
run_glue.py
67
run_glue.py
@ -20,6 +20,7 @@ import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
@ -47,6 +48,17 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
from roberta import RobertaForSequenceClassificationCustomSimple, RobertaForSequenceClassificationCustom, RobertaForSequenceClassificationCustomAlternative
|
||||
from gpt2 import GPT2ForSequenceClassificationCustomSimple, GPT2ForSequenceClassificationCustom
|
||||
|
||||
MODEL_NAME_TO_CLASS = {
|
||||
'roberta_simple': RobertaForSequenceClassificationCustomSimple,
|
||||
'roberta_hidden': RobertaForSequenceClassificationCustom,
|
||||
'roberta_hidden_v2': RobertaForSequenceClassificationCustomAlternative,
|
||||
'gpt2_simple': GPT2ForSequenceClassificationCustomSimple,
|
||||
'gpt2_hidden': GPT2ForSequenceClassificationCustom,
|
||||
}
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.23.0")
|
||||
|
||||
@ -207,6 +219,13 @@ class ModelArguments:
|
||||
metadata={"help": "Freeze encoder weights"},
|
||||
)
|
||||
|
||||
custom_model: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Use custom implementation from available list",
|
||||
"choices": list(MODEL_NAME_TO_CLASS.keys()),
|
||||
},
|
||||
)
|
||||
|
||||
def freeze_model_weights(model: torch.nn.Module) -> None:
|
||||
count = 0
|
||||
@ -384,7 +403,25 @@ def main():
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
|
||||
custom_model = model_args.custom_model
|
||||
if custom_model is not None:
|
||||
# Check model and implementation is the same
|
||||
if 'roberta' in custom_model and 'roberta' not in model_args.model_name_or_path:
|
||||
raise RuntimeError('Model and custom implementation should be the same type: RoBERTa')
|
||||
elif 'gpt2' in custom_model and 'gpt2' not in model_args.model_name_or_path:
|
||||
raise RuntimeError('Model and custom implementation should be the same type: GPT-2')
|
||||
|
||||
# Set custom configuration in model configuration
|
||||
config.use_hidden_states = 'hidden' in custom_model
|
||||
logger.info(f'Using hidden states in model: {config.use_hidden_states}')
|
||||
|
||||
# Get class to initialize model
|
||||
model_cls = MODEL_NAME_TO_CLASS[custom_model]
|
||||
else:
|
||||
model_cls = AutoModelForSequenceClassification
|
||||
logger.info(f'Using implementation from class: {model_cls.__name__}')
|
||||
model = model_cls.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
@ -399,6 +436,11 @@ def main():
|
||||
freeze_model_weights(model.decoder)
|
||||
|
||||
|
||||
if 'gpt2' in tokenizer.name_or_path and tokenizer.pad_token is None:
|
||||
logger.info(f'Set PAD token to EOS: {tokenizer.eos_token}')
|
||||
tokenizer._pad_token = tokenizer.eos_token
|
||||
model.config.pad_token_id = model.config.eos_token_id
|
||||
|
||||
# Preprocessing the raw_datasets
|
||||
if data_args.task_name is not None:
|
||||
sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
|
||||
@ -459,11 +501,6 @@ def main():
|
||||
args = (
|
||||
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
|
||||
)
|
||||
if 'gpt2' in tokenizer.name_or_path and tokenizer.pad_token is None:
|
||||
logger.info(f'Set PAD token to EOS: {tokenizer.eos_token}')
|
||||
tokenizer._pad_token = tokenizer.eos_token
|
||||
model.config.pad_token_id = model.config.eos_token_id
|
||||
|
||||
result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
|
||||
|
||||
# Map labels to IDs (not necessary for GLUE tasks)
|
||||
@ -492,7 +529,16 @@ def main():
|
||||
eval_dataset = raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
|
||||
if data_args.max_eval_samples is not None:
|
||||
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
||||
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
||||
label_to_indexes = defaultdict(list)
|
||||
for index, eval_sample in enumerate(eval_dataset):
|
||||
label_to_indexes[eval_sample['label']].append(index)
|
||||
max_samples_per_label = int(max_eval_samples / len(label_to_indexes))
|
||||
eval_sample_indexes = []
|
||||
for label, indexes in label_to_indexes.items():
|
||||
eval_sample_indexes.extend(indexes[:max_samples_per_label])
|
||||
logger.info(f"Set {max_samples_per_label} samples for {label}-class")
|
||||
eval_sample_indexes.sort()
|
||||
eval_dataset = eval_dataset.select(eval_sample_indexes)
|
||||
|
||||
if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
|
||||
if "test" not in raw_datasets and "test_matched" not in raw_datasets:
|
||||
@ -549,13 +595,14 @@ def main():
|
||||
)
|
||||
|
||||
# Training
|
||||
ignore_keys_for_eval = ['hidden_states', 'attentions', 'past_key_values']
|
||||
if training_args.do_train:
|
||||
checkpoint = None
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
checkpoint = training_args.resume_from_checkpoint
|
||||
elif last_checkpoint is not None:
|
||||
checkpoint = last_checkpoint
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint, ignore_keys_for_eval=ignore_keys_for_eval)
|
||||
metrics = train_result.metrics
|
||||
max_train_samples = (
|
||||
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
|
||||
@ -585,7 +632,7 @@ def main():
|
||||
combined = {}
|
||||
|
||||
for eval_dataset, task in zip(eval_datasets, tasks):
|
||||
metrics = trainer.evaluate(eval_dataset=eval_dataset)
|
||||
metrics = trainer.evaluate(eval_dataset=eval_dataset, ignore_keys=ignore_keys_for_eval)
|
||||
|
||||
max_eval_samples = (
|
||||
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
@ -613,7 +660,7 @@ def main():
|
||||
for predict_dataset, task in zip(predict_datasets, tasks):
|
||||
# Removing the `label` columns because it contains -1 and Trainer won't like that.
|
||||
predict_dataset = predict_dataset.remove_columns("label")
|
||||
predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
|
||||
predictions = trainer.predict(predict_dataset, metric_key_prefix="predict", ignore_keys=ignore_keys_for_eval).predictions
|
||||
predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
|
||||
|
||||
output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt")
|
||||
|
Loading…
Reference in New Issue
Block a user