Added optional weight freezing (first N encoder parameter tensors) for GPT-2
This commit is contained in:
parent
a0a10fe18e
commit
1ea86c3ee2
23
run_glue.py
23
run_glue.py
@ -202,6 +202,24 @@ class ModelArguments:
|
||||
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
|
||||
)
|
||||
|
||||
freeze_weights: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Freeze encoder weights"},
|
||||
)
|
||||
|
||||
|
||||
def freeze_model_weights(model: torch.nn.Module, max_frozen: int = 10) -> None:
    """Freeze the leading parameter tensors of *model*.

    Sets ``requires_grad = False`` on the first ``max_frozen`` parameter
    tensors (in ``model.parameters()`` iteration order). Any parameters
    beyond the cutoff are left trainable and merely logged as ignored.

    Args:
        model: The (sub)module whose parameters should be partially frozen.
        max_frozen: Number of leading parameter tensors to freeze.
            Defaults to 10, preserving the original hard-coded cutoff.

    Returns:
        None. The model is modified in place.
    """
    # NOTE(review): the original had `print(len(model.parameters())` — an
    # unbalanced parenthesis (SyntaxError), and `parameters()` returns a
    # generator so `len()` would raise TypeError regardless. The debug
    # print is dropped rather than fixed.
    for count, param in enumerate(model.parameters(), start=1):
        if count <= max_frozen:
            logger.info(f'Freezing layer {count}')
            param.requires_grad = False
        else:
            logger.info(f'Ignoring layer {count}')
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
@ -376,6 +394,11 @@ def main():
|
||||
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
|
||||
)
|
||||
|
||||
if model_args.freeze_weights:
|
||||
logger.info("Freezing encoder weights")
|
||||
freeze_model_weights(model.decoder)
|
||||
|
||||
|
||||
# Preprocessing the raw_datasets
|
||||
if data_args.task_name is not None:
|
||||
sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
|
||||
|
Loading…
Reference in New Issue
Block a user