# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
|
|
|
|
|
|
"""ResNet-RS models for Keras.
|
|
|
|
Reference:
|
|
- [Revisiting ResNets: Improved Training and Scaling Strategies](
|
|
https://arxiv.org/pdf/2103.07579.pdf)
|
|
"""
|
|
import sys
|
|
from typing import Callable
|
|
from typing import Dict
|
|
from typing import List
|
|
from typing import Union
|
|
|
|
import tensorflow.compat.v2 as tf
|
|
|
|
from keras import backend
|
|
from keras import layers
|
|
from keras.applications import imagenet_utils
|
|
from keras.engine import training
|
|
from keras.utils import data_utils
|
|
from keras.utils import layer_utils
|
|
|
|
# isort: off
|
|
from tensorflow.python.util.tf_export import keras_export
|
|
|
|
# Remote directory holding the pretrained weight files; a weight file name is
# appended to this prefix to form the download URL.
BASE_WEIGHTS_URL = (
    "https://storage.googleapis.com/"
    "tensorflow/keras-applications/resnet_rs/"
)
|
|
|
|
# Maps each downloadable weight file name to the hash that is passed as
# `file_hash` when the file is fetched. File names follow the pattern
# `resnet-rs-<depth>-i<input size>[_notop].h5`; the `_notop` files are the
# ones requested when the model is built without its classification head.
# NOTE(review): the 32-hex-digit values look like MD5 digests - confirm.
WEIGHT_HASHES = {
    "resnet-rs-101-i160.h5": "544b3434d00efc199d66e9058c7f3379",
    "resnet-rs-101-i160_notop.h5": "82d5b90c5ce9d710da639d6216d0f979",
    "resnet-rs-101-i192.h5": "eb285be29ab42cf4835ff20a5e3b5d23",
    "resnet-rs-101-i192_notop.h5": "f9a0f6b85faa9c3db2b6e233c4eebb5b",
    "resnet-rs-152-i192.h5": "8d72a301ed8a6f11a47c4ced4396e338",
    "resnet-rs-152-i192_notop.h5": "5fbf7ac2155cb4d5a6180ee9e3aa8704",
    "resnet-rs-152-i224.h5": "31a46a92ab21b84193d0d71dd8c3d03b",
    "resnet-rs-152-i224_notop.h5": "dc8b2cba2005552eafa3167f00dc2133",
    "resnet-rs-152-i256.h5": "ba6271b99bdeb4e7a9b15c05964ef4ad",
    "resnet-rs-152-i256_notop.h5": "fa79794252dbe47c89130f65349d654a",
    "resnet-rs-200-i256.h5": "a76930b741884e09ce90fa7450747d5f",
    "resnet-rs-200-i256_notop.h5": "bbdb3994718dfc0d1cd45d7eff3f3d9c",
    "resnet-rs-270-i256.h5": "20d575825ba26176b03cb51012a367a8",
    "resnet-rs-270-i256_notop.h5": "2c42ecb22e35f3e23d2f70babce0a2aa",
    "resnet-rs-350-i256.h5": "f4a039dc3c421321b7fc240494574a68",
    "resnet-rs-350-i256_notop.h5": "6e44b55025bbdff8f51692a023143d66",
    "resnet-rs-350-i320.h5": "7ccb858cc738305e8ceb3c0140bee393",
    "resnet-rs-350-i320_notop.h5": "ab0c1f9079d2f85a9facbd2c88aa6079",
    "resnet-rs-420-i320.h5": "ae0eb9bed39e64fc8d7e0db4018dc7e8",
    "resnet-rs-420-i320_notop.h5": "fe6217c32be8305b1889657172b98884",
    "resnet-rs-50-i160.h5": "69d9d925319f00a8bdd4af23c04e4102",
    "resnet-rs-50-i160_notop.h5": "90daa68cd26c95aa6c5d25451e095529",
}
|
|
|
|
# Maps a network depth to the input sizes for which pretrained ImageNet
# weights are listed. `ResNetRS` turns each size into an accepted `weights`
# value of the form `imagenet-i<size>` and picks the largest size when plain
# `"imagenet"` is requested.
DEPTH_TO_WEIGHT_VARIANTS = {
    50: [160],
    101: [160, 192],
    152: [192, 224, 256],
    200: [256],
    270: [256],
    350: [256, 320],
    420: [320],
}
|
|
# Value stored under "input_filters" for each of the four block groups; it is
# the same for every supported depth.
_GROUP_FILTERS = (64, 128, 256, 512)

# Value stored under "num_repeats" for each block group, keyed by depth.
_GROUP_REPEATS = {
    50: (3, 4, 6, 3),
    101: (3, 4, 23, 3),
    152: (3, 8, 36, 3),
    200: (3, 24, 36, 3),
    270: (4, 29, 53, 4),
    350: (4, 36, 72, 4),
    420: (4, 44, 87, 4),
}

# Default per-depth block group configuration: one dict per block group, in
# network order, consumed by `ResNetRS` when no `block_args` is supplied.
BLOCK_ARGS = {
    depth: [
        {"input_filters": n_filters, "num_repeats": n_repeats}
        for n_filters, n_repeats in zip(_GROUP_FILTERS, repeats)
    ]
    for depth, repeats in _GROUP_REPEATS.items()
}
|
|
# Serialized initializer config passed as `kernel_initializer` to every
# `Conv2D` layer created in this file (fixed-padding convolutions and the
# squeeze-and-excitation convolutions).
CONV_KERNEL_INITIALIZER = {
    "class_name": "VarianceScaling",
    "config": {
        "scale": 2.0,
        "mode": "fan_out",
        "distribution": "truncated_normal",
    },
}
|
|
|
|
BASE_DOCSTRING = """Instantiates the {name} architecture.
|
|
|
|
Reference:
|
|
[Revisiting ResNets: Improved Training and Scaling Strategies](
|
|
https://arxiv.org/pdf/2103.07579.pdf)
|
|
|
|
For image classification use cases, see
|
|
[this page for detailed examples](
|
|
https://keras.io/api/applications/#usage-examples-for-image-classification-models).
|
|
|
|
For transfer learning use cases, make sure to read the
|
|
[guide to transfer learning & fine-tuning](
|
|
https://keras.io/guides/transfer_learning/).
|
|
|
|
Note: each Keras Application expects a specific kind of input preprocessing.
|
|
For ResNetRs, by default input preprocessing is included as a part of the
|
|
model (as a `Rescaling` layer), and thus
|
|
`tf.keras.applications.resnet_rs.preprocess_input` is actually a
|
|
pass-through function. In this use case, ResNetRS models expect their inputs
|
|
to be float tensors of pixels with values in the [0-255] range.
|
|
At the same time, preprocessing as a part of the model (i.e. `Rescaling`
|
|
layer) can be disabled by setting `include_preprocessing` argument to False.
|
|
With preprocessing disabled ResNetRS models expect their inputs to be float
|
|
tensors of pixels with values in the [-1, 1] range.
|
|
|
|
Args:
|
|
depth: Depth of ResNet network.
|
|
input_shape: optional shape tuple. It should have exactly 3 inputs
|
|
channels, and width and height should be no smaller than 32.
|
|
E.g. (200, 200, 3) would be one valid value.
|
|
bn_momentum: Momentum parameter for Batch Normalization layers.
|
|
bn_epsilon: Epsilon parameter for Batch Normalization layers.
|
|
activation: activation function.
|
|
se_ratio: Squeeze and Excitation layer ratio.
|
|
dropout_rate: dropout rate before final classifier layer.
|
|
drop_connect_rate: dropout rate at skip connections.
|
|
include_top: whether to include the fully-connected layer at the top of
|
|
the network.
|
|
block_args: list of dicts, parameters to construct block modules.
|
|
model_name: name of the model.
|
|
pooling: optional pooling mode for feature extraction when `include_top`
|
|
is `False`.
|
|
- `None` means that the output of the model will be
|
|
the 4D tensor output of the
|
|
last convolutional layer.
|
|
- `avg` means that global average pooling
|
|
will be applied to the output of the
|
|
last convolutional layer, and thus
|
|
the output of the model will be a 2D tensor.
|
|
- `max` means that global max pooling will
|
|
be applied.
|
|
weights: one of `None` (random initialization), `'imagenet'`
|
|
(pre-training on ImageNet), or the path to the weights file to be
|
|
loaded. Note: one model can have multiple imagenet variants
|
|
depending on input shape it was trained with. For input_shape
|
|
224x224 pass `imagenet-i224` as argument. By default, highest input
|
|
shape weights are downloaded.
|
|
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to
|
|
use as image input for the model.
|
|
classes: optional number of classes to classify images into, only to be
|
|
specified if `include_top` is True, and if no `weights` argument is
|
|
specified.
|
|
classifier_activation: A `str` or callable. The activation function to
|
|
use on the "top" layer. Ignored unless `include_top=True`. Set
|
|
`classifier_activation=None` to return the logits of the "top"
|
|
layer.
|
|
include_preprocessing: Boolean, whether to include the preprocessing
|
|
layer (`Rescaling`) at the bottom of the network. Defaults to
|
|
`True`. Note: Input image is normalized by ImageNet mean and
|
|
standard deviation.
|
|
|
|
Returns:
|
|
A `keras.Model` instance.
|
|
"""
|
|
|
|
|
|
def Conv2DFixedPadding(filters, kernel_size, strides, name=None):
    """Create a bias-free `Conv2D` whose padding is independent of input size.

    Args:
        filters: number of output filters of the convolution.
        kernel_size: kernel size forwarded to `Conv2D`; also used to size the
            explicit zero padding applied when `strides > 1`.
        strides: stride forwarded to `Conv2D`. It is compared against 1, so a
            scalar is expected.
        name: optional layer name. When omitted, a unique `conv_<uid>` name is
            generated.

    Returns:
        A callable that maps an input tensor to the convolution output.
    """
    if name is None:
        name = f"conv_{backend.get_uid('conv_')}"

    def apply(inputs):
        # Strided convolutions are zero-padded explicitly and then run with
        # "valid" padding; unit-stride convolutions rely on "same" padding.
        conv_input = inputs
        if strides > 1:
            conv_input = fixed_padding(inputs, kernel_size)

        padding_mode = "same" if strides == 1 else "valid"
        conv = layers.Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding_mode,
            use_bias=False,
            kernel_initializer=CONV_KERNEL_INITIALIZER,
            name=name,
        )
        return conv(conv_input)

    return apply
|
|
|
|
|
|
def STEM(
    bn_momentum: float = 0.0,
    bn_epsilon: float = 1e-5,
    activation: str = "relu",
    name=None,
):
    """ResNet-D style stem: four 3x3 conv -> batch norm -> activation stages.

    Args:
        bn_momentum: momentum of every batch normalization layer.
        bn_epsilon: epsilon of every batch normalization layer.
        activation: activation applied after each batch normalization.
        name: optional name prefix. When omitted, a unique `stem_<uid>` prefix
            is generated.

    Returns:
        A callable that maps an input tensor to the stem output.
    """
    if name is None:
        name = f"stem_{backend.get_uid('stem_')}"

    # (filters, strides) for stages 1..4. The strided fourth stage stands in
    # for the max pooling of the classic ResNet stem.
    stage_specs = ((32, 2), (32, 1), (64, 1), (64, 2))

    def apply(inputs):
        bn_axis = 3 if backend.image_data_format() == "channels_last" else 1

        x = inputs
        for stage, (n_filters, stride) in enumerate(stage_specs, start=1):
            x = Conv2DFixedPadding(
                filters=n_filters,
                kernel_size=3,
                strides=stride,
                name=name + f"_stem_conv_{stage}",
            )(x)
            x = layers.BatchNormalization(
                axis=bn_axis,
                momentum=bn_momentum,
                epsilon=bn_epsilon,
                name=name + f"_stem_batch_norm_{stage}",
            )(x)
            x = layers.Activation(
                activation, name=name + f"_stem_act_{stage}"
            )(x)
        return x

    return apply
|
|
|
|
|
|
def SE(
    in_filters: int, se_ratio: float = 0.25, expand_ratio: int = 1, name=None
):
    """Squeeze-and-Excitation block that rescales the channels of its input.

    Args:
        in_filters: base filter count. The gate produces
            `4 * in_filters * expand_ratio` channels and the bottleneck has
            `max(1, int(in_filters * 4 * se_ratio))` channels.
        se_ratio: reduction ratio of the bottleneck convolution.
        expand_ratio: multiplier for the gate's channel count.
        name: optional name prefix. When omitted, a unique `se_<uid>` prefix is
            generated.

    Returns:
        A callable that maps an input tensor to the gated tensor.
    """
    channels_first = backend.image_data_format() != "channels_last"
    if name is None:
        name = f"se_{backend.get_uid('se_')}"

    def pointwise_conv(n_filters, act, suffix):
        # 1x1 convolution with bias, shared shape for the reduce/expand steps.
        return layers.Conv2D(
            filters=n_filters,
            kernel_size=[1, 1],
            strides=[1, 1],
            kernel_initializer=CONV_KERNEL_INITIALIZER,
            padding="same",
            use_bias=True,
            activation=act,
            name=name + suffix,
        )

    def apply(inputs):
        # Squeeze: collapse the spatial dimensions, then restore two singleton
        # spatial axes so that 2D convolutions can be applied.
        pooled = layers.GlobalAveragePooling2D(name=name + "_se_squeeze")(
            inputs
        )
        n_channels = pooled.shape[-1]
        if channels_first:
            target_shape = (n_channels, 1, 1)
        else:
            target_shape = (1, 1, n_channels)
        pooled = layers.Reshape(target_shape, name=name + "_se_reshape")(
            pooled
        )

        # Excite: bottleneck followed by a sigmoid gate.
        num_reduced_filters = max(1, int(in_filters * 4 * se_ratio))
        reduced = pointwise_conv(num_reduced_filters, "relu", "_se_reduce")(
            pooled
        )
        # Expand ratio is 1 by default
        gate = pointwise_conv(
            4 * in_filters * expand_ratio, "sigmoid", "_se_expand"
        )(reduced)

        return layers.multiply([inputs, gate], name=name + "_se_excite")

    return apply
|
|
|
|
|
|
def BottleneckBlock(
    filters: int,
    strides: int,
    use_projection: bool,
    bn_momentum: float = 0.0,
    bn_epsilon: float = 1e-5,
    activation: str = "relu",
    se_ratio: float = 0.25,
    survival_probability: float = 0.8,
    name=None,
):
    """Residual bottleneck block (1x1 -> 3x3 -> 1x1 convolutions) with BN.

    Args:
        filters: filter count of the first two convolutions; the third
            convolution and the projection shortcut output `4 * filters`.
        strides: stride of the 3x3 convolution and of the shortcut path.
        use_projection: whether the shortcut is projected (convolution plus
            batch normalization) instead of being the untouched input.
        bn_momentum: momentum of every batch normalization layer.
        bn_epsilon: epsilon of every batch normalization layer.
        activation: activation used inside the block and on its output.
        se_ratio: squeeze-and-excitation ratio; the SE block is only added
            when `0 < se_ratio < 1`.
        survival_probability: value passed as the `rate` of a `Dropout` layer
            with noise shape `(None, 1, 1, 1)` on the residual branch. The
            layer is skipped when this value is falsy.
        name: optional name prefix. When omitted, a unique `block_0_<uid>`
            prefix is generated.

    Returns:
        A callable that maps an input tensor to the block output.
    """
    if name is None:
        name = f"block_0_{backend.get_uid('block_0_')}"

    def apply(inputs):
        bn_axis = 3 if backend.image_data_format() == "channels_last" else 1

        def batch_norm(suffix):
            return layers.BatchNormalization(
                axis=bn_axis,
                momentum=bn_momentum,
                epsilon=bn_epsilon,
                name=name + suffix,
            )

        # Shortcut path.
        shortcut = inputs
        if use_projection:
            projection_input = inputs
            projection_strides = strides
            if strides == 2:
                # Downsample with average pooling so that the projection
                # convolution itself can run with unit stride.
                projection_input = layers.AveragePooling2D(
                    pool_size=(2, 2),
                    strides=(2, 2),
                    padding="same",
                    name=name + "_projection_pooling",
                )(inputs)
                projection_strides = 1

            shortcut = Conv2DFixedPadding(
                filters=filters * 4,
                kernel_size=1,
                strides=projection_strides,
                name=name + "_projection_conv",
            )(projection_input)
            shortcut = batch_norm("_projection_batch_norm")(shortcut)

        # First conv layer:
        x = Conv2DFixedPadding(
            filters=filters, kernel_size=1, strides=1, name=name + "_conv_1"
        )(inputs)
        # NOTE(review): unlike its siblings this suffix has no leading
        # underscore; it is left untouched because changing it would rename
        # the layer.
        x = batch_norm("batch_norm_1")(x)
        x = layers.Activation(activation, name=name + "_act_1")(x)

        # Second conv layer:
        x = Conv2DFixedPadding(
            filters=filters,
            kernel_size=3,
            strides=strides,
            name=name + "_conv_2",
        )(x)
        x = batch_norm("_batch_norm_2")(x)
        x = layers.Activation(activation, name=name + "_act_2")(x)

        # Third conv layer:
        x = Conv2DFixedPadding(
            filters=filters * 4, kernel_size=1, strides=1, name=name + "_conv_3"
        )(x)
        x = batch_norm("_batch_norm_3")(x)

        if 0 < se_ratio < 1:
            x = SE(filters, se_ratio=se_ratio, name=name + "_se")(x)

        # Drop connect
        if survival_probability:
            x = layers.Dropout(
                survival_probability,
                noise_shape=(None, 1, 1, 1),
                name=name + "_drop",
            )(x)

        merged = layers.Add()([x, shortcut])
        return layers.Activation(activation, name=name + "_output_act")(merged)

    return apply
|
|
|
|
|
|
def BlockGroup(
    filters,
    strides,
    num_repeats,
    se_ratio: float = 0.25,
    bn_epsilon: float = 1e-5,
    bn_momentum: float = 0.0,
    activation: str = "relu",
    survival_probability: float = 0.8,
    name=None,
):
    """Stack `num_repeats` bottleneck blocks into one group of the network.

    Args:
        filters: filter count forwarded to every `BottleneckBlock`.
        strides: stride of the first block; the remaining blocks use stride 1.
        num_repeats: total number of blocks in the group.
        se_ratio: squeeze-and-excitation ratio forwarded to every block.
        bn_epsilon: batch normalization epsilon forwarded to every block.
        bn_momentum: batch normalization momentum forwarded to every block.
        activation: activation forwarded to every block.
        survival_probability: forwarded unchanged to every block.
        name: optional name prefix. When omitted, a unique
            `block_group_<uid>` prefix is generated.

    Returns:
        A callable that maps an input tensor to the group output.
    """
    if name is None:
        name = f"block_group_{backend.get_uid('block_group_')}"

    def apply(inputs):
        # Settings that are identical for every block of the group.
        shared = dict(
            filters=filters,
            se_ratio=se_ratio,
            bn_epsilon=bn_epsilon,
            bn_momentum=bn_momentum,
            activation=activation,
            survival_probability=survival_probability,
        )

        # Only the first block per block_group uses projection shortcut and
        # strides.
        x = BottleneckBlock(
            strides=strides,
            use_projection=True,
            name=name + "_block_0_",
            **shared,
        )(inputs)

        for block_idx in range(1, num_repeats):
            x = BottleneckBlock(
                strides=1,
                use_projection=False,
                name=name + f"_block_{block_idx}_",
                **shared,
            )(x)
        return x

    return apply
|
|
|
|
|
|
def get_survival_probability(init_rate, block_num, total_blocks):
    """Scale `init_rate` linearly by the position of a block group.

    Returns `init_rate * block_num / total_blocks`, with `block_num` converted
    to float first. Note that, despite the name, `BottleneckBlock` passes this
    value to `Dropout` as its rate.
    """
    scaled_rate = init_rate * float(block_num)
    return scaled_rate / total_blocks
|
|
|
|
|
|
def allow_bigger_recursion(target_limit: int):
    """Raise the interpreter recursion limit to at least `target_limit`.

    The limit is never lowered: nothing happens when the current limit is
    already `target_limit` or higher.
    """
    if sys.getrecursionlimit() >= target_limit:
        return
    sys.setrecursionlimit(target_limit)
|
|
|
|
|
|
def fixed_padding(inputs, kernel_size):
    """Zero-pad both spatial dimensions by an amount set by `kernel_size`.

    A total of `kernel_size - 1` rows/columns is added per dimension; when
    that total is odd, the extra one goes to the end (bottom/right).
    """
    total = kernel_size - 1
    leading = total // 2
    trailing = total - leading
    spatial_pad = (leading, trailing)

    # Use ZeroPadding as to avoid TFOpLambda layer
    return layers.ZeroPadding2D(padding=(spatial_pad, spatial_pad))(inputs)
|
|
|
|
|
|
def ResNetRS(
    depth: int,
    input_shape=None,
    bn_momentum=0.0,
    bn_epsilon=1e-5,
    activation: str = "relu",
    se_ratio=0.25,
    dropout_rate=0.25,
    drop_connect_rate=0.2,
    include_top=True,
    block_args: List[Dict[str, int]] = None,
    model_name="resnet-rs",
    pooling=None,
    weights="imagenet",
    input_tensor=None,
    classes=1000,
    classifier_activation: Union[str, Callable] = "softmax",
    include_preprocessing=True,
):
    """Build a ResNet-RS model from the given hyperparameters.

    Args:
        depth: depth of the network. Must be a key of
            `DEPTH_TO_WEIGHT_VARIANTS`; it also selects the default
            `block_args`.
        input_shape: optional shape tuple. It should have exactly 3 input
            channels, and width and height should be no smaller than 32,
            e.g. `(200, 200, 3)`.
        bn_momentum: momentum of the batch normalization layers.
        bn_epsilon: epsilon of the batch normalization layers.
        activation: activation function used throughout the network.
        se_ratio: squeeze-and-excitation ratio.
        dropout_rate: dropout rate before the final classifier layer; the
            dropout layer is only added when the rate is positive.
        drop_connect_rate: base rate for the dropout applied on the residual
            branches; it is scaled per block group.
        include_top: whether to include the fully-connected layer at the top
            of the network.
        block_args: list of dicts with the keys `"input_filters"` and
            `"num_repeats"`, one per block group. Defaults to
            `BLOCK_ARGS[depth]`.
        model_name: name of the model; also the prefix of the pretrained
            weight file name.
        pooling: optional pooling mode for feature extraction when
            `include_top` is `False`.
            - `None`: the output is the 4D tensor of the last block group.
            - `avg`: global average pooling is applied, giving a 2D tensor.
            - `max`: global max pooling is applied.
        weights: one of `None` (random initialization), `'imagenet'`
            (pre-training on ImageNet), `'imagenet-i<size>'` for a specific
            pretrained input size (e.g. `imagenet-i224`), or the path to a
            weights file. Plain `'imagenet'` selects the largest input size
            available for `depth`.
        input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
            to use as image input for the model.
        classes: number of classes to classify images into. Must be 1000 when
            pretrained weights are used together with `include_top`.
        classifier_activation: a `str` or callable, the activation of the
            "top" layer. Ignored unless `include_top=True`. Set to `None` to
            return the logits of the "top" layer.
        include_preprocessing: whether to prepend the preprocessing layers: a
            `Rescaling` by 1/255 and, for 3-channel inputs, a `Normalization`
            with the ImageNet mean and standard deviation.

    Returns:
        A `tf.keras.Model` instance.

    Raises:
        KeyError: if `depth` is not a key of `DEPTH_TO_WEIGHT_VARIANTS`.
        ValueError: in case of invalid argument for `weights`, or invalid
            input shape.
        ValueError: if `classifier_activation` is not `softmax` or `None` when
            using a pretrained top layer.
    """
    # Validate parameters
    variant_sizes = DEPTH_TO_WEIGHT_VARIANTS[depth]
    if weights == "imagenet":
        # `imagenet` argument without explicit weights input size.
        # Picking weights trained with biggest available shape
        weights = f"{weights}-i{max(variant_sizes)}"

    weights_allow_list = [f"imagenet-i{size}" for size in variant_sizes]
    is_known_weights = weights in {*weights_allow_list, None}
    if not is_known_weights and not tf.io.gfile.exists(weights):
        raise ValueError(
            "The `weights` argument should be either "
            "`None` (random initialization), `'imagenet'` "
            "(pre-training on ImageNet, with highest available input shape),"
            " or the path to the weights file to be loaded. "
            f"For ResNetRS{depth} the following weight variants are "
            f"available {weights_allow_list} (default=highest)."
            f" Received weights={weights}"
        )

    uses_pretrained = weights in weights_allow_list
    if uses_pretrained and include_top and classes != 1000:
        raise ValueError(
            "If using `weights` as `'imagenet'` or any "
            f"of {weights_allow_list} "
            "with `include_top` as true, `classes` should be 1000. "
            f"Received classes={classes}"
        )

    data_format = backend.image_data_format()
    input_shape = imagenet_utils.obtain_input_shape(
        input_shape,
        default_size=224,
        min_size=32,
        data_format=data_format,
        require_flatten=include_top,
        weights=weights,
    )

    # Define input tensor
    if input_tensor is None:
        img_input = layers.Input(shape=input_shape)
    elif backend.is_keras_tensor(input_tensor):
        img_input = input_tensor
    else:
        img_input = layers.Input(tensor=input_tensor, shape=input_shape)

    bn_axis = 3 if data_format == "channels_last" else 1

    x = img_input

    if include_preprocessing:
        num_channels = input_shape[bn_axis - 1]
        x = layers.Rescaling(scale=1.0 / 255)(x)
        # The ImageNet statistics below only apply to 3-channel inputs.
        if num_channels == 3:
            x = layers.Normalization(
                mean=[0.485, 0.456, 0.406],
                variance=[0.229**2, 0.224**2, 0.225**2],
                axis=bn_axis,
            )(x)

    # Build stem
    x = STEM(
        bn_momentum=bn_momentum, bn_epsilon=bn_epsilon, activation=activation
    )(x)

    # Build blocks
    if block_args is None:
        block_args = BLOCK_ARGS[depth]

    total_blocks = len(block_args) + 1
    # Block groups are numbered from 2 in layer names and rate scaling.
    for group_num, group_args in enumerate(block_args, start=2):
        group_rate = get_survival_probability(
            init_rate=drop_connect_rate,
            block_num=group_num,
            total_blocks=total_blocks,
        )
        # Only the first group keeps the spatial resolution of the stem.
        group_strides = 1 if group_num == 2 else 2

        x = BlockGroup(
            filters=group_args["input_filters"],
            activation=activation,
            strides=group_strides,
            num_repeats=group_args["num_repeats"],
            se_ratio=se_ratio,
            bn_momentum=bn_momentum,
            bn_epsilon=bn_epsilon,
            survival_probability=group_rate,
            name=f"BlockGroup{group_num}_",
        )(x)

    # Build head:
    if include_top:
        x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
        if dropout_rate > 0:
            x = layers.Dropout(dropout_rate, name="top_dropout")(x)

        imagenet_utils.validate_activation(classifier_activation, weights)
        x = layers.Dense(
            classes, activation=classifier_activation, name="predictions"
        )(x)
    elif pooling == "avg":
        x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
    elif pooling == "max":
        x = layers.GlobalMaxPooling2D(name="max_pool")(x)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is None:
        inputs = img_input
    else:
        inputs = layer_utils.get_source_inputs(input_tensor)

    # Create model.
    model = training.Model(inputs, x, name=model_name)

    # Download weights
    if uses_pretrained:
        weights_input_shape = weights.split("-")[-1]  # e. g. "i160"
        top_suffix = "" if include_top else "_notop"
        filename = f"{model_name}-{weights_input_shape}{top_suffix}.h5"
        weights_path = data_utils.get_file(
            fname=filename,
            origin=BASE_WEIGHTS_URL + filename,
            cache_subdir="models",
            file_hash=WEIGHT_HASHES[filename],
        )
        model.load_weights(weights_path)
    elif weights is not None:
        model.load_weights(weights)

    return model
|
|
|
|
|
|
@keras_export(
    "keras.applications.resnet_rs.ResNetRS50", "keras.applications.ResNetRS50"
)
def ResNetRS50(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Build the ResNet-RS model of depth 50.

    All arguments are forwarded to `ResNetRS`, together with the settings
    fixed for this depth.
    """
    # Settings fixed for this depth.
    depth_config = dict(
        depth=50,
        model_name="resnet-rs-50",
        dropout_rate=0.25,
        drop_connect_rate=0.0,
    )
    return ResNetRS(
        input_shape=input_shape,
        include_top=include_top,
        pooling=pooling,
        weights=weights,
        input_tensor=input_tensor,
        classes=classes,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **depth_config,
    )
|
|
|
|
|
|
@keras_export(
    "keras.applications.resnet_rs.ResNetRS101", "keras.applications.ResNetRS101"
)
def ResNetRS101(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Build the ResNet-RS model of depth 101.

    All arguments are forwarded to `ResNetRS`, together with the settings
    fixed for this depth.
    """
    # Settings fixed for this depth.
    depth_config = dict(
        depth=101,
        model_name="resnet-rs-101",
        dropout_rate=0.25,
        drop_connect_rate=0.0,
    )
    return ResNetRS(
        input_shape=input_shape,
        include_top=include_top,
        pooling=pooling,
        weights=weights,
        input_tensor=input_tensor,
        classes=classes,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **depth_config,
    )
|
|
|
|
|
|
@keras_export(
    "keras.applications.resnet_rs.ResNetRS152", "keras.applications.ResNetRS152"
)
def ResNetRS152(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Build the ResNet-RS model of depth 152.

    All arguments are forwarded to `ResNetRS`, together with the settings
    fixed for this depth.
    """
    # Settings fixed for this depth.
    depth_config = dict(
        depth=152,
        model_name="resnet-rs-152",
        dropout_rate=0.25,
        drop_connect_rate=0.0,
    )
    return ResNetRS(
        input_shape=input_shape,
        include_top=include_top,
        pooling=pooling,
        weights=weights,
        input_tensor=input_tensor,
        classes=classes,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **depth_config,
    )
|
|
|
|
|
|
@keras_export(
    "keras.applications.resnet_rs.ResNetRS200", "keras.applications.ResNetRS200"
)
def ResNetRS200(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Build the ResNet-RS model of depth 200.

    All arguments are forwarded to `ResNetRS`, together with the settings
    fixed for this depth.
    """
    # Settings fixed for this depth.
    depth_config = dict(
        depth=200,
        model_name="resnet-rs-200",
        dropout_rate=0.25,
        drop_connect_rate=0.1,
    )
    return ResNetRS(
        input_shape=input_shape,
        include_top=include_top,
        pooling=pooling,
        weights=weights,
        input_tensor=input_tensor,
        classes=classes,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **depth_config,
    )
|
|
|
|
|
|
@keras_export(
    "keras.applications.resnet_rs.ResNetRS270", "keras.applications.ResNetRS270"
)
def ResNetRS270(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Build the ResNet-RS model of depth 270.

    The recursion limit is first raised to at least 1300; then all arguments
    are forwarded to `ResNetRS`, together with the settings fixed for this
    depth.
    """
    allow_bigger_recursion(1300)

    # Settings fixed for this depth.
    depth_config = dict(
        depth=270,
        model_name="resnet-rs-270",
        dropout_rate=0.25,
        drop_connect_rate=0.1,
    )
    return ResNetRS(
        input_shape=input_shape,
        include_top=include_top,
        pooling=pooling,
        weights=weights,
        input_tensor=input_tensor,
        classes=classes,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **depth_config,
    )
|
|
|
|
|
|
@keras_export(
    "keras.applications.resnet_rs.ResNetRS350", "keras.applications.ResNetRS350"
)
def ResNetRS350(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Build the ResNet-RS model of depth 350.

    The recursion limit is first raised to at least 1500; then all arguments
    are forwarded to `ResNetRS`, together with the settings fixed for this
    depth.
    """
    allow_bigger_recursion(1500)

    # Settings fixed for this depth.
    depth_config = dict(
        depth=350,
        model_name="resnet-rs-350",
        dropout_rate=0.4,
        drop_connect_rate=0.1,
    )
    return ResNetRS(
        input_shape=input_shape,
        include_top=include_top,
        pooling=pooling,
        weights=weights,
        input_tensor=input_tensor,
        classes=classes,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **depth_config,
    )
|
|
|
|
|
|
@keras_export(
    "keras.applications.resnet_rs.ResNetRS420", "keras.applications.ResNetRS420"
)
def ResNetRS420(
    include_top=True,
    weights="imagenet",
    classes=1000,
    input_shape=None,
    input_tensor=None,
    pooling=None,
    classifier_activation="softmax",
    include_preprocessing=True,
):
    """Build the ResNet-RS model of depth 420.

    The recursion limit is first raised to at least 1800; then all arguments
    are forwarded to `ResNetRS`, together with the settings fixed for this
    depth.
    """
    allow_bigger_recursion(1800)

    # Settings fixed for this depth.
    depth_config = dict(
        depth=420,
        model_name="resnet-rs-420",
        dropout_rate=0.4,
        drop_connect_rate=0.1,
    )
    return ResNetRS(
        input_shape=input_shape,
        include_top=include_top,
        pooling=pooling,
        weights=weights,
        input_tensor=input_tensor,
        classes=classes,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        **depth_config,
    )
|
|
|
|
|
|
@keras_export("keras.applications.resnet_rs.preprocess_input")
def preprocess_input(x, data_format=None):
    """A placeholder method for backward compatibility.

    The preprocessing logic has been included in the ResNetRS model
    implementation. Users are no longer required to call this method to
    normalize the input data. This method does nothing and is only kept as a
    placeholder to align the API surface between old and new versions of the
    model.

    Args:
        x: A floating point `numpy.array` or a `tf.Tensor`.
        data_format: Optional data format of the image tensor/array. It is
            not used by this function: `x` is returned as-is whatever its
            value.

    Returns:
        Unchanged `numpy.array` or `tf.Tensor`.
    """
    return x
|
|
|
|
|
|
@keras_export("keras.applications.resnet_rs.decode_predictions")
def decode_predictions(preds, top=5):
    # Thin wrapper: decoding is delegated entirely to the shared ImageNet
    # helper.
    decoded = imagenet_utils.decode_predictions(preds, top=top)
    return decoded
|
|
|
|
|
|
# The local wrapper reuses the documentation of the helper it delegates to.
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__

# Every depth-specific factory gets the shared docstring template,
# specialised with its own public name.
for _factory, _public_name in (
    (ResNetRS50, "ResNetRS50"),
    (ResNetRS101, "ResNetRS101"),
    (ResNetRS152, "ResNetRS152"),
    (ResNetRS200, "ResNetRS200"),
    (ResNetRS270, "ResNetRS270"),
    (ResNetRS350, "ResNetRS350"),
    (ResNetRS420, "ResNetRS420"),
):
    _factory.__doc__ = BASE_DOCSTRING.format(name=_public_name)
# Do not leave the loop variables behind as module attributes.
del _factory, _public_name
|