168 lines
5.5 KiB
Python
168 lines
5.5 KiB
Python
|
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Keras Utilities for DTensor related API."""
|
||
|
|
||
|
import inspect
|
||
|
|
||
|
import tensorflow.compat.v2 as tf
|
||
|
|
||
|
from keras.dtensor import dtensor_api as dtensor
|
||
|
|
||
|
# All the variable names in the default keras layers. We will use those to map
|
||
|
# against the args in the __init__ method to find corresponding layout args.
|
||
|
# See allow_initializer_layout() for more details.
|
||
|
# Each entry NAME is looked up as `NAME_initializer` in a layer's `__init__`
# signature; when it is found, a `NAME_layout` keyword argument is accepted.
KERAS_VARIABLE_NAMES = [
    "alpha",
    "beta",
    "bias",
    "depthwise",
    "embeddings",
    "gamma",
    "kernel",
    "moving_mean",
    "moving_variance",
    "pointwise",
    "recurrent",
]
|
||
|
|
||
|
|
||
|
def allow_initializer_layout(init_method):
    """A decorator for injecting layout information to layer.__init__.

    Layout will be a new param for any of the weights for all the keras
    layers. Adding the param to every `__init__` method would be a lot of
    duplicated work.

    This decorator is designed to reduce code duplication and make it easy to
    add/remove the dtensor feature if needed.

    Sample usage:
    ```python
    class Dense(tf.keras.layers.Layer):

        @allow_initializer_layout
        def __init__(self, units,
                     kernel_initializer='zeros',
                     bias_initializer='zeros',
                     **kwargs):
            super().__init__(**kwargs)

    d = Dense(units=8, kernel_layout=layout1, bias_layout=layout2)
    d.kernel_layout == layout1
    d.bias_layout == layout2
    ```

    By adding this annotation, it will:

    1. Filter out the kwargs based on some keywords, eg if
       'kernel_initializer' appears in the method signature, then it will pop
       'kernel_layout' from the kwargs if it is present. Same for "bias",
       "recurrent", and every other name in `KERAS_VARIABLE_NAMES`. This
       makes sure the layout related param is not passed to
       `BaseLayer.__init__`, which would raise an error about unexpected
       keyword args.
    2. Set the self.kernel/bias_layout attribute after the `__init__` method
       is called. Keras framework will use those fields to create weights
       down the stream. Only layouts that evaluate as true are set; a missing
       or falsy `*_layout` value is popped and then ignored.

    Args:
      init_method: the `__init__` method of the Keras layer to annotate.

    Returns:
      the annotated __init__ method.
    """
    # The accepted `*_layout` keywords depend only on the signature of the
    # decorated method, so work them out once here instead of on every
    # constructor call.
    init_params = inspect.signature(init_method).parameters
    layout_arg_names = tuple(
        variable_name + "_layout"
        for variable_name in KERAS_VARIABLE_NAMES
        if variable_name + "_initializer" in init_params
    )

    def _wrap_function(layer_instance, *args, **kwargs):
        layout_args = {}
        # Strip every recognised `*_layout` kwarg before `__init__` sees it,
        # remembering the ones that were actually provided.
        for layout_arg_name in layout_arg_names:
            layout = kwargs.pop(layout_arg_name, None)
            if layout:
                layout_args[layout_arg_name] = layout

        init_method(layer_instance, *args, **kwargs)

        # Inject the layout parameter after the invocation of __init__()
        for layout_arg_name, layout in layout_args.items():
            setattr(layer_instance, layout_arg_name, layout)

    return tf.__internal__.decorator.make_decorator(
        target=init_method, decorator_func=_wrap_function
    )
|
||
|
|
||
|
|
||
|
def inject_mesh(init_method):
    """Let a decorated `__init__` accept an extra `mesh` keyword argument.

    Keras objects such as `Metric` and `Optimizer` need a DTensor mesh to
    create their weights, but their public constructor signatures are not
    supposed to change. Decorating `__init__` with this function lets callers
    pass `mesh=...` anyway: the value is removed from the keyword arguments
    and, when it is not None, stored on the object as `_mesh`.

    This is a temporary measure; eventually the mesh/layout information will
    become public arguments of the `__init__` method.

    Sample usage:
    ```python
    class Accuracy(tf.keras.metrics.Metric):

        @inject_mesh
        def __init__(self, name='accuracy', dtype=None):
            super().__init__(name=name, dtype=dtype)

    acc = Accuracy(mesh=mesh)
    assert acc._mesh == mesh
    ```

    Args:
      init_method: the `__init__` method of the Keras class to annotate.

    Returns:
      the annotated __init__ method.
    """

    def _init_with_mesh(instance, *args, **kwargs):
        # `mesh` is always removed from kwargs so the wrapped `__init__`
        # never receives it, whether or not a mesh was supplied.
        injected_mesh = kwargs.pop("mesh", None)
        # `_mesh` has to be in place before `__init__` runs, because the
        # class may create its weights (and need the mesh) inside `__init__`.
        if injected_mesh is not None:
            instance._mesh = injected_mesh
        init_method(instance, *args, **kwargs)

    decorated_init = tf.__internal__.decorator.make_decorator(
        target=init_method, decorator_func=_init_with_mesh
    )
    return decorated_init
|
||
|
|
||
|
|
||
|
def call_with_layout(fn, layout, *args, **kwargs):
    """Call `fn` and, when a layout is given, relayout what it returns.

    Args:
      fn: the function to invoke.
      layout: when this evaluates as true, `fn` is run inside
        `dtensor.run_on(layout.mesh)` and its output is passed through
        `dtensor.relayout` with this layout. When it is None (or otherwise
        falsy), `fn` is simply called.
      *args: positional arguments forwarded to `fn`.
      **kwargs: keyword arguments forwarded to `fn`.

    Returns:
      The output of `fn`, relayouted with `layout` when one was supplied.
    """
    # Guard clause: without a layout there is nothing DTensor-specific to do.
    if not layout:
        return fn(*args, **kwargs)

    # Both the call and the relayout happen while the layout's mesh is active.
    with dtensor.run_on(layout.mesh):
        output = fn(*args, **kwargs)
        return dtensor.relayout(output, layout)
|