added freeze weights for GPT-2

Andrzej Preibisz 2023-02-12 17:15:32 +01:00
parent a0a10fe18e
commit 1ea86c3ee2


@@ -201,6 +201,24 @@ class ModelArguments:
        default=False,
        metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
    )
    freeze_weights: bool = field(
        default=False,
        metadata={"help": "Freeze decoder weights"},
    )
def freeze_model_weights(model: torch.nn.Module) -> None:
    # Freeze only the first 10 parameter tensors; leave the rest trainable.
    logger.info(f"Parameter tensors: {len(list(model.parameters()))}")
    for count, param in enumerate(model.parameters(), start=1):
        if count <= 10:
            logger.info(f"Freezing layer {count}")
            param.requires_grad = False
        else:
            logger.info(f"Ignoring layer {count}")
def main():
@@ -375,6 +393,11 @@ def main():
        use_auth_token=True if model_args.use_auth_token else None,
        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
    )
    if model_args.freeze_weights:
        logger.info("Freezing decoder weights")
        freeze_model_weights(model.decoder)

    # Preprocessing the raw_datasets
    if data_args.task_name is not None:
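For reference, a minimal standalone sketch of the freezing behavior this commit adds, assuming the transformers package is installed. GPT2LMHeadModel and named_parameters are standard transformers/PyTorch APIs; FREEZE_COUNT and the verification lines at the end are illustrative additions, not part of this commit.

# Standalone sketch of the freeze logic above; FREEZE_COUNT is illustrative.
from transformers import GPT2LMHeadModel

FREEZE_COUNT = 10  # same cutoff the commit hardcodes

model = GPT2LMHeadModel.from_pretrained("gpt2")
for count, (name, param) in enumerate(model.named_parameters(), start=1):
    if count <= FREEZE_COUNT:
        param.requires_grad = False  # tensor is excluded from gradient updates
        print(f"Freezing {count}: {name}")

# Only unfrozen tensors contribute trainable parameters.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable}/{total}")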