121 lines
4.3 KiB
Python
121 lines
4.3 KiB
Python
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
|
|
"""Keras initializers for TF 1."""
|
|
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.ops import init_ops
|
|
from tensorflow.python.util.tf_export import keras_export
|
|
|
|
|
|
# Each `init_ops` initializer class below is bound to a private module-level
# alias, and that alias is then passed to `keras_export` together with the
# `keras.initializers.*` v1 name(s) it is exported under.

_v1_zeros_initializer = init_ops.Zeros
keras_export(
    v1=['keras.initializers.Zeros', 'keras.initializers.zeros']
)(_v1_zeros_initializer)

_v1_ones_initializer = init_ops.Ones
keras_export(
    v1=['keras.initializers.Ones', 'keras.initializers.ones']
)(_v1_ones_initializer)

_v1_constant_initializer = init_ops.Constant
keras_export(
    v1=['keras.initializers.Constant', 'keras.initializers.constant']
)(_v1_constant_initializer)

# Exported under the CamelCase name only.
_v1_variance_scaling_initializer = init_ops.VarianceScaling
keras_export(
    v1=['keras.initializers.VarianceScaling']
)(_v1_variance_scaling_initializer)

_v1_orthogonal_initializer = init_ops.Orthogonal
keras_export(
    v1=['keras.initializers.Orthogonal', 'keras.initializers.orthogonal']
)(_v1_orthogonal_initializer)

_v1_identity = init_ops.Identity
keras_export(
    v1=['keras.initializers.Identity', 'keras.initializers.identity']
)(_v1_identity)

# The Glorot initializers are exported under their snake_case names only.
_v1_glorot_uniform_initializer = init_ops.GlorotUniform
keras_export(
    v1=['keras.initializers.glorot_uniform']
)(_v1_glorot_uniform_initializer)

_v1_glorot_normal_initializer = init_ops.GlorotNormal
keras_export(
    v1=['keras.initializers.glorot_normal']
)(_v1_glorot_normal_initializer)
|
|
|
|
|
|
@keras_export(
    v1=[
        'keras.initializers.RandomNormal',
        'keras.initializers.random_normal',
        'keras.initializers.normal',
    ]
)
class RandomNormal(init_ops.RandomNormal):
  # Thin subclass of `init_ops.RandomNormal`: it only supplies its own
  # constructor defaults (mean=0.0, stddev=0.05, seed=None, float32 dtype)
  # and hands every argument to the parent constructor unchanged.

  def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32):
    # All four arguments are forwarded by keyword to the parent class.
    super().__init__(mean=mean, stddev=stddev, seed=seed, dtype=dtype)
|
|
|
|
|
|
@keras_export(
    v1=[
        'keras.initializers.RandomUniform',
        'keras.initializers.random_uniform',
        'keras.initializers.uniform',
    ]
)
class RandomUniform(init_ops.RandomUniform):
  # Thin subclass of `init_ops.RandomUniform`: it only supplies its own
  # constructor defaults (minval=-0.05, maxval=0.05, seed=None, float32
  # dtype) and hands every argument to the parent constructor unchanged.

  def __init__(
      self, minval=-0.05, maxval=0.05, seed=None, dtype=dtypes.float32
  ):
    # All four arguments are forwarded by keyword to the parent class.
    super().__init__(minval=minval, maxval=maxval, seed=seed, dtype=dtype)
|
|
|
|
|
|
@keras_export(
    v1=[
        'keras.initializers.TruncatedNormal',
        'keras.initializers.truncated_normal',
    ]
)
class TruncatedNormal(init_ops.TruncatedNormal):
  # Thin subclass of `init_ops.TruncatedNormal`: it only supplies its own
  # constructor defaults (mean=0.0, stddev=0.05, seed=None, float32 dtype)
  # and hands every argument to the parent constructor unchanged.

  def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32):
    # All four arguments are forwarded by keyword to the parent class.
    super().__init__(mean=mean, stddev=stddev, seed=seed, dtype=dtype)
|
|
|
|
|
|
@keras_export(v1=['keras.initializers.lecun_normal'])
class LecunNormal(init_ops.VarianceScaling):
  # `init_ops.VarianceScaling` pinned to scale=1., mode='fan_in' and
  # distribution='truncated_normal'. Only `seed` is left configurable.

  def __init__(self, seed=None):
    # Every argument except `seed` is fixed here.
    super().__init__(
        scale=1.,
        mode='fan_in',
        distribution='truncated_normal',
        seed=seed,
    )

  def get_config(self):
    # `seed` is the only argument `__init__` accepts, so it is the only key
    # placed in the returned config dict.
    config = {'seed': self.seed}
    return config
|
|
|
|
|
|
@keras_export(v1=['keras.initializers.lecun_uniform'])
class LecunUniform(init_ops.VarianceScaling):
  # `init_ops.VarianceScaling` pinned to scale=1., mode='fan_in' and
  # distribution='uniform'. Only `seed` is left configurable.

  def __init__(self, seed=None):
    # Every argument except `seed` is fixed here.
    super().__init__(
        scale=1.,
        mode='fan_in',
        distribution='uniform',
        seed=seed,
    )

  def get_config(self):
    # `seed` is the only argument `__init__` accepts, so it is the only key
    # placed in the returned config dict.
    config = {'seed': self.seed}
    return config
|
|
|
|
|
|
@keras_export(v1=['keras.initializers.he_normal'])
class HeNormal(init_ops.VarianceScaling):
  # `init_ops.VarianceScaling` pinned to scale=2., mode='fan_in' and
  # distribution='truncated_normal'. Only `seed` is left configurable.

  def __init__(self, seed=None):
    # Every argument except `seed` is fixed here.
    super().__init__(
        scale=2.,
        mode='fan_in',
        distribution='truncated_normal',
        seed=seed,
    )

  def get_config(self):
    # `seed` is the only argument `__init__` accepts, so it is the only key
    # placed in the returned config dict.
    config = {'seed': self.seed}
    return config
|
|
|
|
|
|
@keras_export(v1=['keras.initializers.he_uniform'])
class HeUniform(init_ops.VarianceScaling):
  # `init_ops.VarianceScaling` pinned to scale=2., mode='fan_in' and
  # distribution='uniform'. Only `seed` is left configurable.

  def __init__(self, seed=None):
    # Every argument except `seed` is fixed here.
    super().__init__(
        scale=2.,
        mode='fan_in',
        distribution='uniform',
        seed=seed,
    )

  def get_config(self):
    # `seed` is the only argument `__init__` accepts, so it is the only key
    # placed in the returned config dict.
    config = {'seed': self.seed}
    return config
|