# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Demonstrates training models and saving the result as a SavedModel.

By default, uses a pure JAX implementation of MNIST. There are flags to choose
a Flax CNN version of MNIST, or to skip the training and just test a
previously saved SavedModel. It is possible to save a batch-polymorphic
version of the model, or a model prepared for specific batch sizes.

Try --help to see all flags. This file is used both as an executable and as a
library in two other examples. See the discussion in README.md.
"""

import logging
import os

from absl import app
from absl import flags

from jax.experimental.jax2tf.examples import mnist_lib  # type: ignore
from jax.experimental.jax2tf.examples import saved_model_lib  # type: ignore

import numpy as np
import tensorflow as tf  # type: ignore
import tensorflow_datasets as tfds  # type: ignore

flags.DEFINE_enum("model", "mnist_flax", ["mnist_flax", "mnist_pure_jax"],
                  "Which model to use.")
flags.DEFINE_boolean("model_classifier_layer", True,
                     ("Whether the model should include the classifier layer, "
                      "or just the last layer of logits. Set this to False "
                      "when you want to reuse the classifier-less model in a "
                      "larger model. See keras_reuse_main.py and README.md."))
flags.DEFINE_string("model_path", "/tmp/jax2tf/saved_models",
                    "Path under which to save the SavedModel.")
flags.DEFINE_integer("model_version", 1,
                     ("The version number for the SavedModel. Needed for "
                      "serving; larger versions take precedence."),
                     lower_bound=1)
flags.DEFINE_integer(
    "serving_batch_size", 1,
    "For what batch size to prepare the serving signature. "
    "Use -1 for converting and saving with batch polymorphism.")
flags.register_validator(
    "serving_batch_size",
    lambda serving_batch_size: (serving_batch_size > 0 or
                                serving_batch_size == -1),
    message="--serving_batch_size must be either -1 or a positive integer.")
flags.DEFINE_integer("num_epochs", 3, "For how many epochs to train.",
                     lower_bound=1)
flags.DEFINE_boolean(
    "generate_model", True,
    "Train and save a new model. Otherwise, use an existing SavedModel.")
flags.DEFINE_boolean(
    "compile_model", True,
    "Enable the TensorFlow jit_compiler for the SavedModel. This is "
    "necessary if you want to use the model for TensorFlow serving.")
flags.DEFINE_boolean("show_model", True,
                     "Show details of the saved SavedModel.")
flags.DEFINE_boolean(
    "show_images", False,
    "Plot some sample images with labels and inference results.")
flags.DEFINE_boolean(
    "test_savedmodel", True,
    "Test TensorFlow inference using the SavedModel w.r.t. the JAX model.")
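# A minimal sketch of how a client of the SavedModel produced by this script
# might load it and run inference, assuming the default flags above
# (classifier layer included, --serving_batch_size=1, and the default
# --model_path and --model_version, which together yield the directory
# /tmp/jax2tf/saved_models/mnist_flax/1). This mirrors the test code in
# train_and_save() below:
#
#   restored = tf.saved_model.load("/tmp/jax2tf/saved_models/mnist_flax/1")
#   images = np.zeros((1,) + mnist_lib.input_shape, dtype=np.float32)
#   predictions = restored(tf.convert_to_tensor(images))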
the JAX model.") FLAGS = flags.FLAGS def train_and_save(): logging.info("Loading the MNIST TensorFlow dataset") train_ds = mnist_lib.load_mnist( tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size) test_ds = mnist_lib.load_mnist( tfds.Split.TEST, batch_size=mnist_lib.test_batch_size) if FLAGS.show_images: mnist_lib.plot_images(train_ds, 1, 5, "Training images", inference_fn=None) the_model_class = pick_model_class() model_dir = savedmodel_dir(with_version=True) if FLAGS.generate_model: model_descr = model_description() logging.info("Generating model for %s", model_descr) (predict_fn, predict_params) = the_model_class.train( train_ds, test_ds, FLAGS.num_epochs, with_classifier=FLAGS.model_classifier_layer) if FLAGS.serving_batch_size == -1: # Batch-polymorphic SavedModel input_signatures = [ tf.TensorSpec((None,) + mnist_lib.input_shape, tf.float32), ] polymorphic_shapes = "(batch, ...)" else: input_signatures = [ # The first one will be the serving signature tf.TensorSpec((FLAGS.serving_batch_size,) + mnist_lib.input_shape, tf.float32), tf.TensorSpec((mnist_lib.train_batch_size,) + mnist_lib.input_shape, tf.float32), tf.TensorSpec((mnist_lib.test_batch_size,) + mnist_lib.input_shape, tf.float32), ] polymorphic_shapes = None logging.info("Saving model for %s", model_descr) saved_model_lib.convert_and_save_model( predict_fn, predict_params, model_dir, with_gradient=True, input_signatures=input_signatures, polymorphic_shapes=polymorphic_shapes, compile_model=FLAGS.compile_model) if FLAGS.test_savedmodel: tf_accelerator, tolerances = tf_accelerator_and_tolerances() with tf.device(tf_accelerator): logging.info("Testing savedmodel") pure_restored_model = tf.saved_model.load(model_dir) if FLAGS.show_images and FLAGS.model_classifier_layer: mnist_lib.plot_images( test_ds, 1, 5, f"Inference results for {model_descr}", inference_fn=pure_restored_model) test_input = np.ones( (mnist_lib.test_batch_size,) + mnist_lib.input_shape, dtype=np.float32) np.testing.assert_allclose( pure_restored_model(tf.convert_to_tensor(test_input)), predict_fn(predict_params, test_input), **tolerances) if FLAGS.show_model: def print_model(model_dir: str): cmd = f"saved_model_cli show --all --dir {model_dir}" print(cmd) os.system(cmd) print_model(model_dir) def pick_model_class(): """Picks one of PureJaxMNIST or FlaxMNIST.""" if FLAGS.model == "mnist_pure_jax": return mnist_lib.PureJaxMNIST elif FLAGS.model == "mnist_flax": return mnist_lib.FlaxMNIST else: raise ValueError(f"Unrecognized model: {FLAGS.model}") def model_description() -> str: """A short description of the picked model.""" res = pick_model_class().name if not FLAGS.model_classifier_layer: res += " (features_only)" return res def savedmodel_dir(with_version: bool = True) -> str: """The directory where we save the SavedModel.""" model_dir = os.path.join( FLAGS.model_path, FLAGS.model + ('' if FLAGS.model_classifier_layer else '_features') ) if with_version: model_dir = os.path.join(model_dir, str(FLAGS.model_version)) return model_dir def tf_accelerator_and_tolerances(): """Picks the TF accelerator to use and the tolerances for numerical checks.""" tf_accelerator = (tf.config.list_logical_devices("TPU") + tf.config.list_logical_devices("GPU") + tf.config.list_logical_devices("CPU"))[0] logging.info("Using tf_accelerator = %s", tf_accelerator) if tf_accelerator.device_type == "TPU": tolerances = dict(atol=1e-6, rtol=1e-6) elif tf_accelerator.device_type == "GPU": tolerances = dict(atol=1e-6, rtol=1e-4) elif tf_accelerator.device_type == "CPU": tolerances = 
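# A minimal sketch, in the spirit of keras_reuse_main.py, of reusing the
# classifier-less SavedModel (saved with --model_classifier_layer=false) as a
# feature extractor inside a larger Keras model. The tensorflow_hub layer and
# the Dense classification head are illustrative assumptions, not part of
# this file:
#
#   import tensorflow_hub as hub
#   feature_extractor = hub.KerasLayer(savedmodel_dir(), trainable=True)
#   model = tf.keras.Sequential([
#       tf.keras.layers.InputLayer(input_shape=mnist_lib.input_shape),
#       feature_extractor,
#       tf.keras.layers.Dense(10, activation="softmax"),
#   ])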
def tf_accelerator_and_tolerances():
  """Picks the TF accelerator to use and the tolerances for numerical checks."""
  tf_accelerator = (tf.config.list_logical_devices("TPU") +
                    tf.config.list_logical_devices("GPU") +
                    tf.config.list_logical_devices("CPU"))[0]
  logging.info("Using tf_accelerator = %s", tf_accelerator)
  if tf_accelerator.device_type == "TPU":
    tolerances = dict(atol=1e-6, rtol=1e-6)
  elif tf_accelerator.device_type == "GPU":
    tolerances = dict(atol=1e-6, rtol=1e-4)
  elif tf_accelerator.device_type == "CPU":
    tolerances = dict(atol=1e-5, rtol=1e-5)
  logging.info("Using tolerances %s", tolerances)
  return tf_accelerator, tolerances


if __name__ == "__main__":
  app.run(lambda _: train_and_save())
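# Example invocations (a sketch, assuming this file is run as
# saved_model_main.py; the flags are the ones defined above):
#
#   python saved_model_main.py --model=mnist_pure_jax --num_epochs=1
#   python saved_model_main.py --serving_batch_size=-1      # batch-polymorphic
#   python saved_model_main.py --nomodel_classifier_layer   # features only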