This commit is contained in:
Andrzej Preibisz 2023-02-12 18:16:10 +01:00
parent ea52abf75a
commit 3037991562

View File

@ -230,10 +230,9 @@ class ModelArguments:
def freeze_model_weights(model: torch.nn.Module) -> None: def freeze_model_weights(model: torch.nn.Module) -> None:
count = 0 count = 0
print(len(model.parameters()))
for param in model.parameters(): for param in model.parameters():
count += 1 count += 1
if count <= 10: if count <= 40:
logger.info(f'Freezing layer {count}') logger.info(f'Freezing layer {count}')
param.requires_grad = False param.requires_grad = False
else: else:
@ -434,7 +433,7 @@ def main():
if model_args.freeze_weights: if model_args.freeze_weights:
logger.info("Freezing encoder weights") logger.info("Freezing encoder weights")
freeze_model_weights(model.decoder) freeze_model_weights(model)
if 'gpt2' in tokenizer.name_or_path and tokenizer.pad_token is None: if 'gpt2' in tokenizer.name_or_path and tokenizer.pad_token is None: