Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/experimental/jax2tf/tests/savedmodel_test.py
2023-06-19 00:49:18 +02:00

330 lines
12 KiB
Python

# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import os
import jax
from jax import lax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf # type: ignore[import]
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util
from jax._src import test_util as jtu
from jax import config
config.parse_flags_with_absl()
class SavedModelTest(tf_test_util.JaxToTfTestCase):
def test_eval(self):
f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
model = tf.Module()
model.f = tf.function(jax2tf.convert(f_jax),
autograph=False,
input_signature=[tf.TensorSpec([], tf.float32)]
)
x = np.array(0.7, dtype=jnp.float32)
self.assertAllClose(model.f(x), f_jax(x))
restored_model = tf_test_util.SaveAndLoadModel(model)
self.assertAllClose(restored_model.f(x), f_jax(x))
def test_gradient(self):
"""Save and restore the custom gradient."""
@jax.custom_jvp
def f_jax(x):
return x * x
@f_jax.defjvp
def f_jax_jvp(primals, tangents):
# 3 * x * x_t
x, = primals
x_dot, = tangents
primal_out = f_jax(x)
tangent_out = x * x_dot * 3.
return primal_out, tangent_out
model = tf.Module()
model.f = tf.function(jax2tf.convert(f_jax, with_gradient=True),
autograph=False,
input_signature=[tf.TensorSpec([], tf.float32)])
x = np.array(0.7, dtype=jnp.float32)
self.assertAllClose(model.f(x), f_jax(x))
restored_model = tf_test_util.SaveAndLoadModel(model)
xv = tf.Variable(x)
self.assertAllClose(restored_model.f(x), f_jax(x))
with tf.GradientTape() as tape:
y = restored_model.f(xv)
self.assertAllClose(tape.gradient(y, xv).numpy(),
jax.grad(f_jax)(x))
def test_gradient_nested(self):
"""Save and restore the custom gradient, when combined with other TF code."""
@jax.custom_jvp
def f_jax(x):
return x * x
@f_jax.defjvp
def f_jax_jvp(primals, tangents):
# 3 * x * x_t
x, = primals
x_dot, = tangents
primal_out = f_jax(x)
tangent_out = x * x_dot * 3.
return primal_out, tangent_out
model = tf.Module()
# After conversion, we wrap with some pure TF code
model.f = tf.function(lambda x: tf.math.sin(jax2tf.convert(f_jax, with_gradient=True)(x)),
autograph=False,
input_signature=[tf.TensorSpec([], tf.float32)])
f_jax_equiv = lambda x: jnp.sin(f_jax(x))
x = np.array(0.7, dtype=jnp.float32)
self.assertAllClose(model.f(x), f_jax_equiv(x))
restored_model = tf_test_util.SaveAndLoadModel(model)
xv = tf.Variable(x)
self.assertAllClose(restored_model.f(x), f_jax_equiv(x))
with tf.GradientTape() as tape:
y = restored_model.f(xv)
self.assertAllClose(tape.gradient(y, xv).numpy(),
jax.grad(f_jax_equiv)(x))
def test_gradient_disabled(self):
f_jax = lambda x: x * x
model = tf.Module()
model.f = tf.function(jax2tf.convert(f_jax, with_gradient=False),
autograph=False,
input_signature=[tf.TensorSpec([], tf.float32)])
x = np.array(0.7, dtype=jnp.float32)
self.assertAllClose(model.f(x), f_jax(x))
restored_model = tf_test_util.SaveAndLoadModel(model)
xv = tf.Variable(0.7, dtype=jnp.float32)
self.assertAllClose(restored_model.f(x), f_jax(x))
with self.assertRaisesRegex(LookupError,
"Gradient explicitly disabled.*The jax2tf-converted function does not support gradients"):
with tf.GradientTape():
_ = restored_model.f(xv)
def test_save_without_gradients(self):
f_jax = lambda x: x * x
x = np.array(0.7, dtype=jnp.float32)
model = tf.Module()
model.f = tf.function(jax2tf.convert(f_jax, with_gradient=True),
autograph=False,
input_signature=[tf.TensorSpec(x.shape, x.dtype)])
self.assertAllClose(model.f(x), f_jax(x))
restored_model = tf_test_util.SaveAndLoadModel(model,
save_gradients=False)
self.assertAllClose(restored_model.f(x), f_jax(x))
xv = tf.Variable(x)
with tf.GradientTape():
_ = restored_model.f(xv)
# TODO: clean this up b/191117111: it should fail with a clear error
# The following results in a confusing error:
# TypeError: An op outside of the function building code is being passed
# a "Graph" tensor. It is possible to have Graph tensors
# leak out of the function building context by including a
# tf.init_scope in your function building code.
# For example, the following function will fail:
# @tf.function
# def has_init_scope():
# my_constant = tf.constant(1.)
# with tf.init_scope():
# added = my_constant * 2
# The graph tensor has name: args_0:0
# g = tape.gradient(res, xv)
#self.assertAllClose(g.numpy(), jax.grad(f_jax)(x))
def test_save_without_embedding_params(self):
def model_jax(params, inputs):
return params[0] + params[1] * inputs
params = (np.array(1.0, dtype=jnp.float32),
np.array(2.0, dtype=jnp.float32))
params_vars = tf.nest.map_structure(tf.Variable, params)
prediction_tf = lambda x: jax2tf.convert(model_jax)(params_vars, x)
model = tf.Module()
model._variables = tf.nest.flatten(params_vars)
model.f = tf.function(prediction_tf, jit_compile=True, autograph=False)
x = np.array(0.7, dtype=jnp.float32)
self.assertAllClose(model.f(x), model_jax(params, x))
restored_model = tf_test_util.SaveAndLoadModel(model,
save_gradients=False)
self.assertAllClose(restored_model.f(x), model_jax(params, x))
def test_save_grad_integers(self):
# https://github.com/google/jax/issues/7123
# In the end this is a test that does not involve JAX at all
batch_size = 5
state = np.array([1], dtype=np.int32) # Works if float32
params = np.ones((3, 3), dtype=np.float32)
# params: f32[3, 3], state: i32[1]
# returns f32[5, 2] constant, and the state
def tf_predict(params, state):
# state = tf.cast(state, tf.float32)
# Setup a custom-gradient, like jax2tf would
@tf.custom_gradient
def converted_fun_with_custom_gradient(params, state):
res_out = tf.zeros((batch_size, 2), dtype=tf.float32)
state_out = state # tf.zeros((4, 4), np.int32),
return ((res_out, state_out), converted_grad_fn)
def converted_grad_fn(res_out_ct, state_out_ct, variables=None):
# The gradients for params and the state
return tf.zeros(params.shape, dtype=params.dtype), state_out_ct
res, state_out = converted_fun_with_custom_gradient(params, state)
# state_out = tf.cast(state_out, tf.int32)
return res, state_out
# Compute the gradient before saving. This works!
params_v = tf.Variable(params)
with tf.GradientTape() as tape:
preds = tf_predict(params_v, state)[0]
loss = tf.reduce_mean(preds)
g = tape.gradient(loss, params_v)
self.assertAllClose(g.numpy(), np.zeros(params.shape, dtype=params.dtype))
# TF -> SavedModel
model = tf.Module()
model.fn = tf.function(tf_predict, autograph=False)
model.fn.get_concrete_function(
tf.TensorSpec(params.shape, params.dtype),
tf.TensorSpec(state.shape, state.dtype))
save_dir = os.path.join(absltest.get_default_test_tmpdir(), str(id(model)))
options = tf.saved_model.SaveOptions(experimental_custom_gradients=True)
_ = tf.saved_model.save(model, save_dir, options=options)
restored_module = tf.saved_model.load(save_dir)
# It seems that saving and reloading is important
restored_fn = restored_module.fn
# Compute the gradients after saving and restoring. Fails!
with tf.GradientTape() as tape:
preds = restored_fn(params_v, state)[0]
loss = tf.reduce_mean(preds)
g = tape.gradient(loss, params_v)
self.assertAllClose(g.numpy(), np.zeros(params.shape, dtype=params.dtype))
def _compare_with_saved_model(self, f_jax, *args):
# Certain ops are converted to ensure an XLA context, e.g.,
# tf.gather, so that the index-out-of-bounds behavior matches that of
# JAX. We check that this information is preserved through a savedmodel
f_tf = jax2tf.convert(f_jax)
res = f_tf(*args)
restored_f, _ = tf_test_util.SaveAndLoadFunction(f_tf, input_args=args)
res_restored = restored_f(*args)
self.assertAllClose(res, res_restored)
def test_pytree(self):
def f_jax(params, x):
# params is a dict
return x @ params["w"] + params["b"]
x = np.ones((2, 3), dtype=np.float32)
params = dict(w=np.ones((3, 4), dtype=np.float32),
b=np.ones((2, 4), dtype=np.float32))
res_jax = f_jax(params, x)
f_tf = jax2tf.convert(f_jax)
res_tf = f_tf(params, x)
self.assertAllClose(res_jax, res_tf.numpy())
restored_f, restored_model = tf_test_util.SaveAndLoadFunction(f_tf, input_args=(params, x),
save_gradients=True)
self.assertAllClose(restored_f(params, x).numpy(), res_tf.numpy())
# Gradients for the converted function
params_v = tf.nest.map_structure(tf.Variable, params)
with tf.GradientTape() as tape:
res = f_tf(params_v, x)
loss = tf.reduce_sum(res)
g_tf = tape.gradient(loss, params_v)
params_v = tf.nest.map_structure(tf.Variable, params)
with tf.GradientTape() as tape:
res = restored_f(params_v, x)
loss = tf.reduce_sum(res)
g_restored_f = tape.gradient(loss, params_v)
self.assertAllClose(g_tf["w"].numpy(), g_restored_f["w"].numpy())
self.assertAllClose(g_tf["b"].numpy(), g_restored_f["b"].numpy())
def test_xla_context_preserved_slice(self):
arr = np.arange(10, dtype=np.float32)
def f_jax(arr):
return lax.dynamic_slice(arr, [100], [1]) # out of bounds, should return the last element
self._compare_with_saved_model(f_jax, arr)
def test_xla_context_preserved_gather(self):
def f_jax(arr):
return arr[100] # out of bounds, should return the last element
arr = np.arange(10, dtype=np.float32)
self._compare_with_saved_model(f_jax, arr)
# Test does not work on GPU/TPU; would need something like TPU inference
# converter to separate the model on what needs to run on CPU or accelerator.
@jtu.skip_on_devices("gpu", "tpu")
def test_tf_mix_jax_with_uncompilableble(self):
"""Show how to combine TF-uncompilableble code with compiled JAX-converted code."""
def tf_fn(x_str, compute_tf_fn=lambda x: x):
# Some TF preprocessing code that cannot be compiled with XLA because it
# uses strings.
numbers_f32 = tf.strings.to_number(x_str, out_type=tf.float32)
numbers_f16 = tf.cast(numbers_f32, tf.float16)
return compute_tf_fn(numbers_f16)
x_str = np.array(["3.14", "2.78"])
# Test that we get an error if we try to TF-compile `tf_fn`
with self.assertRaisesRegex(
Exception,
"Detected unsupported operations when trying to compile graph"):
tf.function(tf_fn, jit_compile=True, autograph=False)(x_str)
# Plug in the TF-compiled JAX-converted `compute_jax_fn`.
composed_fn = lambda x_str: tf_fn(
x_str,
compute_tf_fn=tf.function(jax2tf.convert(jnp.sin),
autograph=False,
jit_compile=True))
res_tf = composed_fn(x_str)
self.assertAllClose(res_tf.numpy(),
jnp.sin(np.array([3.14, 2.78], dtype=np.float16)))
# Save and restore SavedModel
restored_f, _ = tf_test_util.SaveAndLoadFunction(composed_fn,
input_args=[x_str])
res_tf_restored = restored_f(x_str)
self.assertAllClose(res_tf_restored.numpy(), res_tf.numpy())
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())