Added freeze_weights option for GPT-2

Andrzej Preibisz 2023-02-12 17:15:32 +01:00
parent a0a10fe18e
commit 1ea86c3ee2


@@ -202,6 +202,24 @@ class ModelArguments:
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."}, metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
) )
+    freeze_weights: bool = field(
+        default=False,
+        metadata={"help": "Freeze decoder weights"},
+    )
+
+
+def freeze_model_weights(model: torch.nn.Module) -> None:
+    # Freeze the first 10 parameter tensors; leave the rest trainable.
+    params = list(model.parameters())
+    logger.info(f"Freezing up to 10 of {len(params)} parameter tensors")
+    count = 0
+    for param in params:
+        count += 1
+        if count <= 10:
+            logger.info(f"Freezing layer {count}")
+            param.requires_grad = False
+        else:
+            logger.info(f"Ignoring layer {count}")
 
 
 def main():
     # See all possible arguments in src/transformers/training_args.py
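
For context, a minimal sketch (not part of the commit) of what freeze_model_weights does to a toy module; it assumes the function above and a configured module-level logger are in scope:

import logging
import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Eight Linear layers yield 16 parameter tensors (weight + bias each),
# so freeze_model_weights freezes the first 10 and leaves 6 trainable.
toy = torch.nn.Sequential(*[torch.nn.Linear(4, 4) for _ in range(8)])
freeze_model_weights(toy)

frozen = sum(1 for p in toy.parameters() if not p.requires_grad)
print(f"Frozen parameter tensors: {frozen}")  # -> 10
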
@@ -376,6 +394,11 @@ def main():
         ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
     )
 
+    if model_args.freeze_weights:
+        logger.info("Freezing decoder weights")
+        freeze_model_weights(model.decoder)
+
     # Preprocessing the raw_datasets
     if data_args.task_name is not None:
         sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
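
Since freeze_weights is a dataclass field on ModelArguments, HfArgumentParser exposes it as a command-line flag; for a bool field defaulting to False, passing the bare flag switches it on. A hypothetical invocation (the script name, task, and other arguments are assumptions, not taken from this commit):

python run_glue.py \
  --model_name_or_path gpt2 \
  --task_name sst2 \
  --freeze_weights \
  --do_train \
  --output_dir /tmp/sst2-frozen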