# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras initializers for TF 1."""

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import init_ops
from tensorflow.python.util.tf_export import keras_export


# TF1 `init_ops` initializer classes re-exported unchanged under their Keras
# v1 API names. Each class is bound to a private module-level alias, and that
# alias is then registered with `keras_export`.
_v1_zeros_initializer = init_ops.Zeros
keras_export(v1=['keras.initializers.Zeros', 'keras.initializers.zeros'])(
    _v1_zeros_initializer)

_v1_ones_initializer = init_ops.Ones
keras_export(v1=['keras.initializers.Ones', 'keras.initializers.ones'])(
    _v1_ones_initializer)

_v1_constant_initializer = init_ops.Constant
keras_export(v1=['keras.initializers.Constant',
                 'keras.initializers.constant'])(_v1_constant_initializer)

_v1_variance_scaling_initializer = init_ops.VarianceScaling
keras_export(v1=['keras.initializers.VarianceScaling'])(
    _v1_variance_scaling_initializer)

_v1_orthogonal_initializer = init_ops.Orthogonal
keras_export(v1=['keras.initializers.Orthogonal',
                 'keras.initializers.orthogonal'])(_v1_orthogonal_initializer)

_v1_identity = init_ops.Identity
keras_export(v1=['keras.initializers.Identity',
                 'keras.initializers.identity'])(_v1_identity)

# The Glorot initializers are exported under lowercase names only.
_v1_glorot_uniform_initializer = init_ops.GlorotUniform
keras_export(v1=['keras.initializers.glorot_uniform'])(
    _v1_glorot_uniform_initializer)

_v1_glorot_normal_initializer = init_ops.GlorotNormal
keras_export(v1=['keras.initializers.glorot_normal'])(
    _v1_glorot_normal_initializer)


@keras_export(v1=['keras.initializers.RandomNormal',
                  'keras.initializers.random_normal',
                  'keras.initializers.normal'])
class RandomNormal(init_ops.RandomNormal):
  """Keras v1 `RandomNormal`: `init_ops.RandomNormal` with Keras defaults.

  Every argument is forwarded by keyword to `init_ops.RandomNormal`.

  Args:
    mean: Mean of the normal distribution. Defaults to 0.0.
    stddev: Standard deviation of the normal distribution. Defaults to 0.05.
    seed: Optional seed, passed through to the base initializer.
    dtype: Default dtype, passed through to the base initializer. Defaults to
      `dtypes.float32`.
  """

  def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32):
    super().__init__(mean=mean, stddev=stddev, seed=seed, dtype=dtype)


@keras_export(v1=['keras.initializers.RandomUniform',
                  'keras.initializers.random_uniform',
                  'keras.initializers.uniform'])
class RandomUniform(init_ops.RandomUniform):
  """Keras v1 `RandomUniform`: `init_ops.RandomUniform` with Keras defaults.

  Every argument is forwarded by keyword to `init_ops.RandomUniform`.

  Args:
    minval: Lower bound of the uniform range. Defaults to -0.05.
    maxval: Upper bound of the uniform range. Defaults to 0.05.
    seed: Optional seed, passed through to the base initializer.
    dtype: Default dtype, passed through to the base initializer. Defaults to
      `dtypes.float32`.
  """

  def __init__(self, minval=-0.05, maxval=0.05, seed=None,
               dtype=dtypes.float32):
    super().__init__(minval=minval, maxval=maxval, seed=seed, dtype=dtype)


@keras_export(v1=['keras.initializers.TruncatedNormal',
                  'keras.initializers.truncated_normal'])
class TruncatedNormal(init_ops.TruncatedNormal):
  """Keras v1 `TruncatedNormal`: `init_ops.TruncatedNormal`, Keras defaults.

  Every argument is forwarded by keyword to `init_ops.TruncatedNormal`.

  Args:
    mean: Mean of the (truncated) normal distribution. Defaults to 0.0.
    stddev: Standard deviation of the (truncated) normal distribution.
      Defaults to 0.05.
    seed: Optional seed, passed through to the base initializer.
    dtype: Default dtype, passed through to the base initializer. Defaults to
      `dtypes.float32`.
  """

  def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32):
    super().__init__(mean=mean, stddev=stddev, seed=seed, dtype=dtype)


@keras_export(v1=['keras.initializers.lecun_normal'])
class LecunNormal(init_ops.VarianceScaling):
  """`init_ops.VarianceScaling` preset: scale=1.0, fan_in, truncated_normal.

  Only `seed` is configurable; every other VarianceScaling argument is fixed.
  """

  def __init__(self, seed=None):
    super().__init__(
        scale=1.0, mode='fan_in', distribution='truncated_normal', seed=seed)

  def get_config(self):
    """Return the config dict; `seed` is the sole constructor argument."""
    return dict(seed=self.seed)


@keras_export(v1=['keras.initializers.lecun_uniform'])
class LecunUniform(init_ops.VarianceScaling):
  """`init_ops.VarianceScaling` preset: scale=1.0, fan_in, uniform.

  Only `seed` is configurable; every other VarianceScaling argument is fixed.
  """

  def __init__(self, seed=None):
    super().__init__(
        scale=1.0, mode='fan_in', distribution='uniform', seed=seed)

  def get_config(self):
    """Return the config dict; `seed` is the sole constructor argument."""
    return dict(seed=self.seed)


@keras_export(v1=['keras.initializers.he_normal'])
class HeNormal(init_ops.VarianceScaling):
  """`init_ops.VarianceScaling` preset: scale=2.0, fan_in, truncated_normal.

  Only `seed` is configurable; every other VarianceScaling argument is fixed.
  """

  def __init__(self, seed=None):
    super().__init__(
        scale=2.0, mode='fan_in', distribution='truncated_normal', seed=seed)

  def get_config(self):
    """Return the config dict; `seed` is the sole constructor argument."""
    return dict(seed=self.seed)


@keras_export(v1=['keras.initializers.he_uniform'])
class HeUniform(init_ops.VarianceScaling):
  """`init_ops.VarianceScaling` preset: scale=2.0, fan_in, uniform.

  Only `seed` is configurable; every other VarianceScaling argument is fixed.
  """

  def __init__(self, seed=None):
    super().__init__(
        scale=2.0, mode='fan_in', distribution='uniform', seed=seed)

  def get_config(self):
    """Return the config dict; `seed` is the sole constructor argument."""
    return dict(seed=self.seed)