Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/experimental/jax2tf/tests/model_harness.py
2023-06-19 00:49:18 +02:00

418 lines
14 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All the models to convert."""
import dataclasses
import functools
from typing import Any, Callable, Dict, Optional, Sequence, Union
import re
import numpy as np
import jraph
from jax.experimental.jax2tf.tests.flax_models import actor_critic
from jax.experimental.jax2tf.tests.flax_models import bilstm_classifier
from jax.experimental.jax2tf.tests.flax_models import cnn
from jax.experimental.jax2tf.tests.flax_models import gnn
from jax.experimental.jax2tf.tests.flax_models import resnet
from jax.experimental.jax2tf.tests.flax_models import seq2seq_lstm
from jax.experimental.jax2tf.tests.flax_models import transformer_lm1b as lm1b
from jax.experimental.jax2tf.tests.flax_models import transformer_nlp_seq as nlp_seq
from jax.experimental.jax2tf.tests.flax_models import transformer_wmt as wmt
from jax.experimental.jax2tf.tests.flax_models import vae
import jax
from jax import random
import tensorflow as tf
@dataclasses.dataclass
class ModelHarness:
  """A model's apply function together with its variables and example inputs.

  Attributes:
    name: Name of the harness.
    apply: Called as `apply(variables, *args, **kwargs)` by `apply_with_vars`.
    variables: Model variables, passed as the first argument of `apply`.
    inputs: Example positional inputs. Typed as arrays, but entries may also
      be pytrees of arrays (e.g. a `jraph.GraphsTuple`); `tf_input_signature`
      maps over them with `jax.tree_util.tree_map`.
    rtol: Relative tolerance. Not read by this class -- presumably used by the
      tests that compare converted outputs against JAX outputs.
    polymorphic_shapes: Optional polymorphic shape (a string, or None) per
      input. Must be provided if and only if `tensor_spec` is provided.
    tensor_spec: Optional TF input signature, returned by `tf_input_signature`
      instead of one derived from `inputs`. Must be provided if and only if
      `polymorphic_shapes` is provided.
  """
  name: str
  apply: Callable[..., Any]
  variables: Dict[str, Any]
  inputs: Sequence[np.ndarray]
  rtol: float = 1e-4
  polymorphic_shapes: Optional[Sequence[Union[str, None]]] = None
  tensor_spec: Optional[Sequence[tf.TensorSpec]] = None

  def __post_init__(self):
    # Polymorphic shapes and explicit TensorSpecs go together (in both
    # directions). This is an explicit check rather than `assert`, so it also
    # runs under `python -O`. AssertionError is kept so the exception type
    # seen by callers is unchanged.
    if bool(self.polymorphic_shapes) != bool(self.tensor_spec):
      raise AssertionError(
          "polymorphic_shapes and tensor_spec must be provided together; got "
          f"polymorphic_shapes={self.polymorphic_shapes!r}, "
          f"tensor_spec={self.tensor_spec!r}")

  @property
  def tf_input_signature(self):
    """Returns `tensor_spec` if set, else TensorSpecs built from `inputs`."""
    if self.tensor_spec:
      return self.tensor_spec
    # One TensorSpec per array leaf, with the leaf's exact shape and dtype.
    return jax.tree_util.tree_map(
        lambda x: tf.TensorSpec(x.shape, tf.dtypes.as_dtype(x.dtype)),
        self.inputs)

  def apply_with_vars(self, *args, **kwargs):
    """Calls `apply` with `variables` prepended to the given arguments."""
    return self.apply(self.variables, *args, **kwargs)
##### All harnesses in this file.
# Maps harness name -> zero-argument factory that builds the ModelHarness.
ALL_HARNESSES: Dict[str, Callable[[], ModelHarness]] = {}


def _make_harness(harness_fn, name, poly_shapes=None, tensor_specs=None):
  """Registers `harness_fn` in `ALL_HARNESSES` without calling it yet.

  The harness builder is wrapped in a `functools.partial` with every argument
  bound, so model variables are only created when the registered entry is
  called (with no arguments).

  Args:
    harness_fn: Harness builder accepting `name`, `polymorphic_shapes` and
      `tensor_spec` keyword arguments.
    name: Base name of the harness. When `poly_shapes` is given, "_" followed
      by `str(poly_shapes)` with quotes and commas stripped is appended, so the
      resulting name is easier to pass through the command line.
    poly_shapes: Optional polymorphic shapes, one per harness input.
    tensor_specs: Optional `(shape, dtype)` pairs, converted here to
      `tf.TensorSpec`s.

  Raises:
    ValueError: If a harness with the resulting name is already registered.
  """
  if poly_shapes:
    name += "_" + re.sub(r"(?:'|\"|,)", "", str(poly_shapes))
  if tensor_specs:
    tensor_specs = [tf.TensorSpec(spec, dtype) for spec, dtype in tensor_specs]
  partial_fn = functools.partial(
      harness_fn,
      name=name,
      polymorphic_shapes=poly_shapes,
      tensor_spec=tensor_specs)
  if name in ALL_HARNESSES:
    raise ValueError(f"Harness {name} exists already")
  ALL_HARNESSES[name] = partial_fn
######################## Model Harness Definitions #############################
def _actor_critic_harness(name, **kwargs):
  """ActorCritic with 8 outputs, fed one all-zeros (1, 84, 84, 4) input."""
  observation = np.zeros((1, 84, 84, 4), np.float32)
  actor_critic_model = actor_critic.ActorCritic(num_outputs=8)
  model_vars = actor_critic_model.init(random.PRNGKey(0), observation)
  return ModelHarness(
      name, actor_critic_model.apply, model_vars, [observation], **kwargs)
def _bilstm_harness(name, **kwargs):
  """BiLSTM text classifier on two 3-token int32 sequences, dropout disabled."""
  classifier = bilstm_classifier.TextClassifier(
      # TODO(marcvanzee): This fails when
      # `embedding_size != hidden_size`. I suppose some arrays are
      # concatenated with incompatible shapes, which could mean
      # something is going wrong in the translation.
      # NOTE(review): the sizes below differ (3 vs. 1), so the TODO above may
      # be stale or inverted -- confirm.
      embedding_size=3,
      hidden_size=1,
      vocab_size=13,
      output_size=1,
      dropout_rate=0.,
      word_dropout_rate=0.)
  token_ids = np.array([[2, 4, 3], [2, 6, 3]], np.int32)
  seq_lengths = np.array([2, 3], np.int32)
  model_vars = classifier.init(
      random.PRNGKey(0), token_ids, seq_lengths, deterministic=True)
  # Both init and apply run with deterministic=True.
  apply_fn = functools.partial(classifier.apply, deterministic=True)
  return ModelHarness(
      name, apply_fn, model_vars, [token_ids, seq_lengths], **kwargs)
def _cnn_harness(name, **kwargs):
  """Default CNN on a single all-zeros (1, 28, 28, 1) float32 input."""
  image = np.zeros((1, 28, 28, 1), np.float32)
  cnn_model = cnn.CNN()
  model_vars = cnn_model.init(random.PRNGKey(0), image)
  return ModelHarness(name, cnn_model.apply, model_vars, [image], **kwargs)
def _get_gnn_graphs():
  """Builds a batch of 8 toy graphs as one `jraph.GraphsTuple`.

  Graph i (i = 0..7) has 3 + i nodes and 4 + i edges. Every edge has sender
  index 0 and receiver index 1. Node features are ones; edge and global
  features are zeros. All feature arrays have width 10 and NumPy's default
  float dtype.
  """
  feature_dim = 10
  nodes_per_graph = np.arange(3, 11)
  edges_per_graph = np.arange(4, 12)
  num_nodes = np.sum(nodes_per_graph)
  num_edges = np.sum(edges_per_graph)
  num_graphs = nodes_per_graph.shape[0]
  return jraph.GraphsTuple(
      n_node=nodes_per_graph,
      n_edge=edges_per_graph,
      senders=np.zeros(num_edges, dtype=np.int32),
      receivers=np.ones(num_edges, dtype=np.int32),
      nodes=np.ones((num_nodes, feature_dim)),
      edges=np.zeros((num_edges, feature_dim)),
      globals=np.zeros((num_graphs, feature_dim)),
  )
def _gnn_harness(name, **kwargs):
  """GraphNet (with edge model) on the toy graphs; uses a looser rtol."""
  # Setting taken from flax/examples/ogbg_molpcba/models_test.py.
  graph_net = gnn.GraphNet(
      latent_size=5,
      num_mlp_layers=2,
      message_passing_steps=2,
      output_globals_size=15,
      use_edge_model=True)
  graphs = _get_gnn_graphs()
  init_rngs = dict(params=random.PRNGKey(0), dropout=random.PRNGKey(1))
  model_vars = graph_net.init(init_rngs, graphs)
  return ModelHarness(
      name, graph_net.apply, model_vars, [graphs], rtol=2e-4, **kwargs)
def _gnn_conv_harness(name, **kwargs):
  """GraphConvNet on the toy graphs from `_get_gnn_graphs`."""
  # Setting taken from flax/examples/ogbg_molpcba/models_test.py.
  graph_conv_net = gnn.GraphConvNet(
      latent_size=5,
      num_mlp_layers=2,
      message_passing_steps=2,
      output_globals_size=5)
  graphs = _get_gnn_graphs()
  init_rngs = dict(params=random.PRNGKey(0), dropout=random.PRNGKey(1))
  model_vars = graph_conv_net.init(init_rngs, graphs)
  return ModelHarness(
      name, graph_conv_net.apply, model_vars, [graphs], **kwargs)
def _resnet50_harness(name, **kwargs):
  """ResNet50 (2 classes, float32) on all-zeros (8, 16, 16, 3) input.

  Applied with `train=False` and `mutable=False`.
  """
  images = np.zeros((8, 16, 16, 3), np.float32)
  resnet_model = resnet.ResNet50(num_classes=2, dtype=np.float32)
  model_vars = resnet_model.init(random.PRNGKey(0), images)
  eval_apply = functools.partial(resnet_model.apply, train=False, mutable=False)
  return ModelHarness(name, eval_apply, model_vars, [images], **kwargs)
def _seq2seq_lstm_harness(name, **kwargs):
  """Teacher-forced seq2seq LSTM on all-zeros encoder/decoder inputs."""
  seq2seq = seq2seq_lstm.Seq2seq(teacher_force=True, hidden_size=2, vocab_size=4)
  model_inputs = [
      np.zeros((1, 2, 4), np.float32),  # encoder: [batch, inp_len, vocab]
      np.zeros((1, 3, 4), np.float32),  # decoder: [batch, outp_len, vocab]
  ]
  init_rngs = {'params': random.PRNGKey(0), 'lstm': random.PRNGKey(1)}
  model_vars = seq2seq.init(init_rngs, *model_inputs)
  # Apply uses a different, fixed 'lstm' key from the one used at init.
  apply_fn = functools.partial(seq2seq.apply, rngs={'lstm': random.PRNGKey(2)})
  return ModelHarness(name, apply_fn, model_vars, model_inputs, **kwargs)
def _min_transformer_kwargs():
  """Returns the smallest transformer hyper-parameters used here (no dropout).

  A fresh dict is built on every call, so callers may mutate the result.
  """
  return {
      "vocab_size": 8,
      "output_vocab_size": 8,
      "emb_dim": 4,
      "num_heads": 1,
      "num_layers": 1,
      "qkv_dim": 2,
      "mlp_dim": 2,
      "max_len": 2,
      "dropout_rate": 0.,
      "attention_dropout_rate": 0.,
  }
def _full_transformer_kwargs():
  """Returns `_min_transformer_kwargs()` plus decoding/embedding flags.

  Sets `decode` and `deterministic` to True and `logits_via_embedding` and
  `share_embeddings` to False. Entries from `_min_transformer_kwargs()` are
  merged last, so they win on any key collision.
  """
  merged = dict(
      decode=True,
      deterministic=True,
      logits_via_embedding=False,
      share_embeddings=False)
  merged.update(_min_transformer_kwargs())
  return merged
def _transformer_lm1b_harness(name, **kwargs):
  """lm1b TransformerLM (full config) on all-zeros (2, 1) float32 input."""
  lm = lm1b.TransformerLM(
      config=lm1b.TransformerConfig(**_full_transformer_kwargs()))
  tokens = np.zeros((2, 1), np.float32)
  init_rng, cache_rng = random.split(random.PRNGKey(0))
  model_vars = lm.init(init_rng, tokens)

  def apply_without_cache(*args):
    # With mutable=['cache'], apply returns (output, updated_state). Only the
    # output is returned; the updated cache state is dropped.
    out, _unused_state = lm.apply(
        *args, rngs={'cache': cache_rng}, mutable=['cache'])
    return out

  return ModelHarness(name, apply_without_cache, model_vars, [tokens], **kwargs)
def _transformer_nlp_seq_harness(name, **kwargs):
  """nlp_seq Transformer (minimal config) on all-zeros (2, 1) input."""
  transformer = nlp_seq.Transformer(
      config=nlp_seq.TransformerConfig(**_min_transformer_kwargs()))
  tokens = np.zeros((2, 1), np.float32)
  model_vars = transformer.init(random.PRNGKey(0), tokens, train=False)
  eval_apply = functools.partial(transformer.apply, train=False)
  return ModelHarness(name, eval_apply, model_vars, [tokens], **kwargs)
def _transformer_wmt_harness(name, **kwargs):
  """wmt Transformer (full config); the same (2, 1) zeros array is passed twice."""
  transformer = wmt.Transformer(
      config=wmt.TransformerConfig(**_full_transformer_kwargs()))
  tokens = np.zeros((2, 1), np.float32)
  model_vars = transformer.init(random.PRNGKey(0), tokens, tokens)

  def apply_without_cache(*args):
    # With mutable=['cache'], apply returns (output, updated_state). Only the
    # output is returned.
    out, _unused_state = transformer.apply(*args, mutable=['cache'])
    return out

  return ModelHarness(
      name, apply_without_cache, model_vars, [tokens, tokens], **kwargs)
def _vae_harness(name, **kwargs):
  """VAE (3 latents); the harness applies `generate` to a (1, 8, 8, 3) input."""
  vae_model = vae.VAE(latents=3)
  example = np.zeros((1, 8, 8, 3), np.float32)
  init_rng, model_rng = random.split(random.PRNGKey(0))
  model_vars = vae_model.init(init_rng, example, model_rng)

  def generate(variables, inp):
    return vae_model.apply(variables, inp, method=vae_model.generate)

  return ModelHarness(name, generate, model_vars, [example], **kwargs)
####################### Model Harness Construction #############################
# actor_critic input spec: [((1, 84, 84, 4), np.float32)].
# Register one actor_critic harness per (poly_shapes, tensor_specs) variant.
for poly_shapes, tensor_specs in (
    (None, None),  # No polymorphism.
    # Batch polymorphism.
    (["(b, ...)"], [((None, 84, 84, 4), tf.float32)]),
    # Dependent shapes for spatial dims.
    # TODO(marcvanzee): Figure out the right multiple for these dimensions.
    (["(_, 4*b, 4*b, _)"], [((1, None, None, 4), tf.float32)]),
):
  _make_harness(_actor_critic_harness, "flax/actor_critic",
                poly_shapes=poly_shapes, tensor_specs=tensor_specs)
# bilstm input specs: [((2, 3), np.int32), ((2,), np.int32)] = [inputs, lengths]
# Register one bilstm harness per (poly_shapes, tensor_specs) variant.
for poly_shapes, tensor_specs in (  # type: ignore
    (None, None),
    # Batch polymorphism.
    (["(b, _)", "(_,)"], [((None, 3), tf.int32), ((2,), tf.int32)]),
    # Dynamic input lengths.
    (["(_, _)", "(b,)"], [((2, 3), tf.int32), ((None,), tf.int32)]),
):
  _make_harness(_bilstm_harness, "flax/bilstm",
                poly_shapes=poly_shapes, tensor_specs=tensor_specs)
# cnn input spec: [((1, 28, 28, 1), np.float32)].
# Register one cnn harness per (poly_shapes, tensor_specs) variant.
for poly_shapes, tensor_specs in (
    (None, None),  # No polymorphism.
    # Batch polymorphism.
    (["(b, ...)"], [((None, 28, 28, 1), tf.float32)]),
    # Dependent shapes for spatial dims.
    # TODO(marcvanzee): Figure out the right multiple for these dimensions.
    (["(_, b, b, _)"], [((1, None, None, 1), tf.float32)]),
):
  _make_harness(_cnn_harness, "flax/cnn",
                poly_shapes=poly_shapes, tensor_specs=tensor_specs)
# Polymorphism is not supported for the GNN examples, since they take a
# jraph.GraphsTuple as input rather than plain arrays.
# Registered without polymorphic shapes or TensorSpecs.
_make_harness(_gnn_harness, "flax/gnn")
_make_harness(_gnn_conv_harness, "flax/gnn_conv")
# resnet50 input spec: [((8, 16, 16, 3), np.float32)]
# Register one resnet50 harness per (poly_shapes, tensor_specs) variant.
for poly_shapes, tensor_specs in (
    (None, None),  # No polymorphism.
    # Batch polymorphism.
    (["(b, ...)"], [((None, 16, 16, 3), tf.float32)]),
    # Dependent shapes for spatial dims.
    # TODO(marcvanzee): Figure out the right multiple for these dimensions.
    (["(_, 4*b, 4*b, _)"], [((8, None, None, 3), tf.float32)]),
):
  _make_harness(_resnet50_harness, "flax/resnet50",
                poly_shapes=poly_shapes, tensor_specs=tensor_specs)
# seq2seq input specs (input and output sequence lengths differ: 2 vs. 3):
# [
#   ((1, 2, 4), np.float32),  # encoder_inp: [batch, max_input_len, vocab_size]
#   ((1, 3, 4), np.float32),  # decoder_inp: [batch, max_output_len, vocab_size]
# ]
# Register one seq2seq_lstm harness per (poly_shapes, tensor_specs) variant.
for poly_shapes, tensor_specs in (  # type: ignore
    (None, None),
    # Batch polymorphism.
    (["(b, _, _)", "(b, _, _)"],
     [((None, 2, 4), tf.float32), ((None, 3, 4), tf.float32)]),
    # Dynamic input lengths.
    (["(_, b, _)", "(_, _, _)"],
     [((1, None, 4), tf.float32), ((1, 3, 4), tf.float32)]),
    # Dynamic output lengths.
    (["(_, _, _)", "(_, b, _)"],
     [((1, 2, 4), tf.float32), ((1, None, 4), tf.float32)]),
):
  _make_harness(_seq2seq_lstm_harness, "flax/seq2seq_lstm",
                poly_shapes=poly_shapes, tensor_specs=tensor_specs)
# lm1b/nlp_seq input spec: [((2, 1), np.float32)] [batch, seq_len]
# For each variant, register both the lm1b and the nlp_seq harness.
for poly_shapes, tensor_specs in (  # type: ignore
    (None, None),
    # Batch polymorphism.
    (["(b, _)"], [((None, 1), tf.float32)]),
):
  for name, harness_fn in (
      ("flax/lm1b", _transformer_lm1b_harness),
      ("flax/nlp_seq", _transformer_nlp_seq_harness),
  ):
    _make_harness(harness_fn, name,
                  poly_shapes=poly_shapes, tensor_specs=tensor_specs)
# wmt input spec (both inputs have the same shape):
# [
#   ((2, 1), np.float32),  # inputs: [batch, input_len]
#   ((2, 1), np.float32),  # targets: [batch, target_len]
# ]
# Register one wmt harness per (poly_shapes, tensor_specs) variant. The same
# spec applies to both inputs. `_transformer_wmt_harness` builds each input
# with shape (2, 1), so every fixed dimension in a TensorSpec below must agree
# with that shape.
for poly_shapes, tensor_specs in [  # type: ignore
    (None, None),
    # Batch polymorphism.
    (["(b, _)"] * 2, [((None, 1), tf.float32)] * 2),
    # Dynamic lengths. The batch dim is fixed ("_"), so its TensorSpec size
    # must be the real batch size 2. It was previously 1, which contradicted
    # the (2, 1) harness inputs.
    (["(_, b)"] * 2, [((2, None), tf.float32)] * 2),
]:
  _make_harness(
      harness_fn=_transformer_wmt_harness,
      name="flax/wmt",
      poly_shapes=poly_shapes,
      tensor_specs=tensor_specs)
# vae input spec: [((1, 8, 8, 3), np.float32)].
# Register one vae harness per (poly_shapes, tensor_specs) variant.
for poly_shapes, tensor_specs in (
    (None, None),  # No polymorphism.
    # Batch polymorphism.
    (["(b, ...)"], [((None, 8, 8, 3), tf.float32)]),
    # Dependent shapes for spatial dims.
    # TODO(marcvanzee): Figure out the right multiple for these dimensions.
    (["(_, b, b, _)"], [((1, None, None, 3), tf.float32)]),
):
  _make_harness(_vae_harness, "flax/vae",
                poly_shapes=poly_shapes, tensor_specs=tensor_specs)