# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for saving/loading function for keras Model."""

import collections

import keras

# Record returned by every model factory below: the constructed model together
# with the input shape(s) and target shape(s) that go with it. Multi-input or
# multi-output models use a list of shape tuples instead of a single tuple.
ModelFn = collections.namedtuple(
    "ModelFn", ["model", "input_shape", "target_shape"]
)


def basic_sequential():
    """Basic sequential model.

    Returns:
        A `ModelFn` holding a two-layer `Sequential` model whose first layer
        declares `input_shape=(3,)`.
    """
    model = keras.Sequential(
        [
            keras.layers.Dense(3, activation="relu", input_shape=(3,)),
            keras.layers.Dense(2, activation="softmax"),
        ]
    )
    return ModelFn(model, (None, 3), (None, 2))


def basic_sequential_deferred():
    """Sequential model with deferred input shape.

    Same layers as `basic_sequential`, but no layer is given an
    `input_shape`.

    Returns:
        A `ModelFn` holding the `Sequential` model.
    """
    model = keras.Sequential(
        [
            keras.layers.Dense(3, activation="relu"),
            keras.layers.Dense(2, activation="softmax"),
        ]
    )
    return ModelFn(model, (None, 3), (None, 2))


def stacked_rnn():
    """Stacked RNN model.

    Returns:
        A `ModelFn` holding a functional model made of one `RNN` layer that
        wraps three `LSTMCell(2)` cells, followed by a `Dense(2)` layer.
    """
    inputs = keras.Input((None, 3))
    layer = keras.layers.RNN([keras.layers.LSTMCell(2) for _ in range(3)])
    x = layer(inputs)
    outputs = keras.layers.Dense(2)(x)
    model = keras.Model(inputs, outputs)
    return ModelFn(model, (None, 4, 3), (None, 2))


def lstm():
    """LSTM model.

    Returns:
        A `ModelFn` holding a functional model with three chained `LSTM`
        layers followed by a `Dense(2)` layer.
    """
    inputs = keras.Input((None, 3))
    x = keras.layers.LSTM(4, return_sequences=True)(inputs)
    x = keras.layers.LSTM(3, return_sequences=True)(x)
    # Only the last LSTM drops the sequence axis before the Dense head.
    x = keras.layers.LSTM(2, return_sequences=False)(x)
    outputs = keras.layers.Dense(2)(x)
    model = keras.Model(inputs, outputs)
    return ModelFn(model, (None, 4, 3), (None, 2))


def multi_input_multi_output():
    """Multi-input Multi-output model.

    Returns:
        A `ModelFn` holding a functional model with inputs named "body" and
        "tags" and outputs named "priority" and "department". Its
        `input_shape` and `target_shape` are lists with one entry per
        input/output.
    """
    body_input = keras.Input(shape=(None,), name="body")
    tags_input = keras.Input(shape=(2,), name="tags")

    x = keras.layers.Embedding(10, 4)(body_input)
    body_features = keras.layers.LSTM(5)(x)
    x = keras.layers.concatenate([body_features, tags_input])

    pred_1 = keras.layers.Dense(2, activation="sigmoid", name="priority")(x)
    pred_2 = keras.layers.Dense(3, activation="softmax", name="department")(x)

    model = keras.Model(
        inputs=[body_input, tags_input], outputs=[pred_1, pred_2]
    )
    return ModelFn(model, [(None, 1), (None, 2)], [(None, 2), (None, 3)])


def nested_sequential_in_functional():
    """A sequential model nested in a functional model.

    Returns:
        A `ModelFn` holding a functional model that calls an inner
        `Sequential` model as one of its layers.
    """
    inner_model = keras.Sequential(
        [
            keras.layers.Dense(3, activation="relu", input_shape=(3,)),
            keras.layers.Dense(2, activation="relu"),
        ]
    )
    inputs = keras.Input(shape=(3,))
    x = inner_model(inputs)
    outputs = keras.layers.Dense(2, activation="softmax")(x)
    model = keras.Model(inputs, outputs)
    return ModelFn(model, (None, 3), (None, 2))


def seq_to_seq():
    """Sequence to sequence model.

    Returns:
        A `ModelFn` holding a functional encoder/decoder model with two
        inputs (encoder and decoder sequences) and one output.
    """
    num_encoder_tokens = 3
    num_decoder_tokens = 3
    latent_dim = 2

    encoder_inputs = keras.Input(shape=(None, num_encoder_tokens))
    encoder = keras.layers.LSTM(latent_dim, return_state=True)
    # The encoder's output sequence is discarded; only its final states are
    # kept and handed to the decoder as its initial state.
    _, state_h, state_c = encoder(encoder_inputs)
    encoder_states = [state_h, state_c]

    decoder_inputs = keras.Input(shape=(None, num_decoder_tokens))
    decoder_lstm = keras.layers.LSTM(
        latent_dim, return_sequences=True, return_state=True
    )
    decoder_outputs, _, _ = decoder_lstm(
        decoder_inputs, initial_state=encoder_states
    )
    decoder_dense = keras.layers.Dense(num_decoder_tokens, activation="softmax")
    decoder_outputs = decoder_dense(decoder_outputs)

    model = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)
    return ModelFn(
        model,
        [(None, 2, num_encoder_tokens), (None, 2, num_decoder_tokens)],
        (None, 2, num_decoder_tokens),
    )


def shared_layer_functional():
    """Shared layer in a functional model.

    Returns:
        A `ModelFn` holding a functional model with inputs "main_input" and
        "aux_input" and outputs "main_output" and "aux_output". The LSTM
        output feeds both the auxiliary output and the main branch.
    """
    main_input = keras.Input(shape=(10,), dtype="int32", name="main_input")
    x = keras.layers.Embedding(output_dim=5, input_dim=4, input_length=10)(
        main_input
    )
    lstm_out = keras.layers.LSTM(3)(x)
    auxiliary_output = keras.layers.Dense(
        1, activation="sigmoid", name="aux_output"
    )(lstm_out)

    auxiliary_input = keras.Input(shape=(5,), name="aux_input")
    x = keras.layers.concatenate([lstm_out, auxiliary_input])
    x = keras.layers.Dense(2, activation="relu")(x)
    main_output = keras.layers.Dense(
        1, activation="sigmoid", name="main_output"
    )(x)

    model = keras.Model(
        inputs=[main_input, auxiliary_input],
        outputs=[main_output, auxiliary_output],
    )
    return ModelFn(model, [(None, 10), (None, 5)], [(None, 1), (None, 1)])


def shared_sequential():
    """Shared sequential model in a functional model.

    Returns:
        A `ModelFn` holding a functional model in which the same inner
        `Sequential` model is applied to two separate inputs.
    """
    inner_model = keras.Sequential(
        [
            keras.layers.Conv2D(2, 3, activation="relu"),
            keras.layers.Conv2D(2, 3, activation="relu"),
        ]
    )
    inputs_1 = keras.Input((5, 5, 3))
    inputs_2 = keras.Input((5, 5, 3))
    # The same `inner_model` instance is called on both inputs.
    x1 = inner_model(inputs_1)
    x2 = inner_model(inputs_2)
    x = keras.layers.concatenate([x1, x2])
    outputs = keras.layers.GlobalAveragePooling2D()(x)
    model = keras.Model([inputs_1, inputs_2], outputs)
    return ModelFn(model, [(None, 5, 5, 3), (None, 5, 5, 3)], (None, 4))


class MySubclassModel(keras.Model):
    """A subclass model."""

    def __init__(self, input_dim=3):
        """Create the layers and remember the constructor arguments.

        Args:
            input_dim: Stored in the config returned by `get_config`; it is
                not otherwise used when building the layers here.
        """
        super().__init__(name="my_subclass_model")
        self._config = {"input_dim": input_dim}
        self.dense1 = keras.layers.Dense(8, activation="relu")
        self.dense2 = keras.layers.Dense(2, activation="softmax")
        self.bn = keras.layers.BatchNormalization()
        self.dp = keras.layers.Dropout(0.5)

    def call(self, inputs, **kwargs):
        """Apply dense1 -> dropout -> batch norm -> dense2 to `inputs`."""
        x = self.dense1(inputs)
        x = self.dp(x)
        x = self.bn(x)
        return self.dense2(x)

    def get_config(self):
        """Return the dict of constructor arguments saved in `__init__`."""
        return self._config

    @classmethod
    def from_config(cls, config):
        """Build an instance by passing `config` as keyword arguments."""
        return cls(**config)


def nested_subclassed_model():
    """A subclass model nested in another subclass model.

    Returns:
        A `ModelFn` holding a subclassed model that owns a
        `MySubclassModel` instance and calls it between its own layers.
    """

    class NestedSubclassModel(keras.Model):
        """A nested subclass model."""

        def __init__(self):
            super().__init__()
            self.dense1 = keras.layers.Dense(4, activation="relu")
            self.dense2 = keras.layers.Dense(2, activation="relu")
            self.bn = keras.layers.BatchNormalization()
            self.inner_subclass_model = MySubclassModel()

        def call(self, inputs):
            x = self.dense1(inputs)
            x = self.bn(x)
            x = self.inner_subclass_model(x)
            return self.dense2(x)

    return ModelFn(NestedSubclassModel(), (None, 3), (None, 2))


def nested_subclassed_in_functional_model():
    """A subclass model nested in a functional model.

    Returns:
        A `ModelFn` holding a functional model whose first layer is a
        `MySubclassModel` instance.
    """
    inner_subclass_model = MySubclassModel()
    inputs = keras.Input(shape=(3,))
    x = inner_subclass_model(inputs)
    x = keras.layers.BatchNormalization()(x)
    outputs = keras.layers.Dense(2, activation="softmax")(x)
    model = keras.Model(inputs, outputs)
    return ModelFn(model, (None, 3), (None, 2))


def nested_functional_in_subclassed_model():
    """A functional model nested in a subclass model.

    Returns:
        A `ModelFn` holding a subclassed model that owns a functional model
        and calls it between its own two `Dense` layers.
    """

    def get_functional_model():
        """Build the inner functional model taking inputs of shape (4,)."""
        inputs = keras.Input(shape=(4,))
        x = keras.layers.Dense(4, activation="relu")(inputs)
        x = keras.layers.BatchNormalization()(x)
        outputs = keras.layers.Dense(2)(x)
        return keras.Model(inputs, outputs)

    class NestedFunctionalInSubclassModel(keras.Model):
        """A functional nested in subclass model."""

        def __init__(self):
            super().__init__(name="nested_functional_in_subclassed_model")
            self.dense1 = keras.layers.Dense(4, activation="relu")
            self.dense2 = keras.layers.Dense(2, activation="relu")
            self.inner_functional_model = get_functional_model()

        def call(self, inputs):
            x = self.dense1(inputs)
            x = self.inner_functional_model(x)
            return self.dense2(x)

    return ModelFn(NestedFunctionalInSubclassModel(), (None, 3), (None, 2))


def shared_layer_subclassed_model():
    """Shared layer in a subclass model.

    Returns:
        A `ModelFn` holding a subclassed model that applies the same `Dense`
        layer at both the start and the end of its `call`.
    """

    class SharedLayerSubclassModel(keras.Model):
        """A subclass model with shared layers."""

        def __init__(self):
            super().__init__(name="shared_layer_subclass_model")
            self.dense = keras.layers.Dense(3, activation="relu")
            self.dp = keras.layers.Dropout(0.5)
            self.bn = keras.layers.BatchNormalization()

        def call(self, inputs):
            x = self.dense(inputs)
            x = self.dp(x)
            x = self.bn(x)
            # `self.dense` is deliberately reused here: it is the shared layer.
            return self.dense(x)

    return ModelFn(SharedLayerSubclassModel(), (None, 3), (None, 3))


def functional_with_keyword_args():
    """A functional model with keyword args.

    Returns:
        A `ModelFn` holding a functional model constructed with the extra
        keyword arguments `name="m"` and `trainable=False`.
    """
    inputs = keras.Input(shape=(3,))
    x = keras.layers.Dense(4)(inputs)
    x = keras.layers.BatchNormalization()(x)
    outputs = keras.layers.Dense(2)(x)
    model = keras.Model(inputs, outputs, name="m", trainable=False)
    return ModelFn(model, (None, 3), (None, 2))


# (name, factory) pairs for every model defined in this module. Each factory
# takes no arguments and returns a `ModelFn`.
ALL_MODELS = [
    ("basic_sequential", basic_sequential),
    ("basic_sequential_deferred", basic_sequential_deferred),
    ("stacked_rnn", stacked_rnn),
    ("lstm", lstm),
    ("multi_input_multi_output", multi_input_multi_output),
    ("nested_sequential_in_functional", nested_sequential_in_functional),
    ("seq_to_seq", seq_to_seq),
    ("shared_layer_functional", shared_layer_functional),
    ("shared_sequential", shared_sequential),
    ("nested_subclassed_model", nested_subclassed_model),
    (
        "nested_subclassed_in_functional_model",
        nested_subclassed_in_functional_model,
    ),
    (
        "nested_functional_in_subclassed_model",
        nested_functional_in_subclassed_model,
    ),
    ("shared_layer_subclassed_model", shared_layer_subclassed_model),
    ("functional_with_keyword_args", functional_with_keyword_args),
]


def get_models(exclude_models=None):
    """Get all models excluding the specified ones.

    Args:
        exclude_models: Optional container of model names (the first element
            of each `ALL_MODELS` entry) to leave out. `None`, the default,
            excludes nothing.

    Returns:
        A new list of `(name, factory)` pairs from `ALL_MODELS`, in their
        original order, whose name is not in `exclude_models`.
    """
    # `name not in None` raises TypeError, so the default must be handled
    # before the membership test rather than passed straight through.
    if exclude_models is None:
        return list(ALL_MODELS)
    return [entry for entry in ALL_MODELS if entry[0] not in exclude_models]