added freeze weights for GPT-2
This commit is contained in:
parent
a0a10fe18e
commit
1ea86c3ee2
23
run_glue.py
23
run_glue.py
@ -201,6 +201,24 @@ class ModelArguments:
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
|
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
freeze_weights: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Freeze encoder weights"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def freeze_model_weights(model: torch.nn.Module) -> None:
    """Disable gradients on the first 10 parameter tensors of *model*.

    Iterates ``model.parameters()`` in registration order and sets
    ``requires_grad = False`` on tensors 1..10; all remaining tensors are
    left trainable (and logged as ignored).

    NOTE(review): despite the call site's "Freezing encoder weights"
    message, only the first 10 parameter *tensors* are frozen, not a whole
    encoder — confirm the cutoff of 10 is intentional for GPT-2.

    Args:
        model: Any ``torch.nn.Module`` whose leading parameters should be
            frozen. Mutated in place; nothing is returned.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Materialize once: parameters() is a generator, so the original
    # `len(model.parameters())` would raise TypeError (and the original
    # line also had an unbalanced parenthesis — a SyntaxError).
    params = list(model.parameters())
    logger.info("Model exposes %d parameter tensors", len(params))

    for count, param in enumerate(params, start=1):
        if count <= 10:
            logger.info("Freezing layer %d", count)
            param.requires_grad = False
        else:
            logger.info("Ignoring layer %d", count)
||||||
def main():
|
def main():
|
||||||
@ -375,6 +393,11 @@ def main():
|
|||||||
use_auth_token=True if model_args.use_auth_token else None,
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
|
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if model_args.freeze_weights:
|
||||||
|
logger.info("Freezing encoder weights")
|
||||||
|
freeze_model_weights(model.decoder)
|
||||||
|
|
||||||
|
|
||||||
# Preprocessing the raw_datasets
|
# Preprocessing the raw_datasets
|
||||||
if data_args.task_name is not None:
|
if data_args.task_name is not None:
|
||||||
|
Loading…
Reference in New Issue
Block a user