diff --git a/run_translation_freezing.py b/run_translation_freezing.py index bdeb26d..86f7266 100644 --- a/run_translation_freezing.py +++ b/run_translation_freezing.py @@ -261,7 +261,7 @@ def freeze_model_weights(model: torch.nn.Module) -> None: count = 0 for param in model.parameters(): count += 1 - if count < 20: + if count <= 20: logger.info(f'Freezing layer {count}') param.requires_grad = False else: