2023-06-19 00:49:18 +02:00

168 lines
5.5 KiB

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras Utilities for DTensor related API."""
import inspect
import tensorflow.compat.v2 as tf
from keras.dtensor import dtensor_api as dtensor
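# All the variable names in the default keras layers. We will use those to map
# against the args in the __init__ method to find corresponding layout args.
# See allow_initializer_layout() for more details.
# NOTE(review): the KERAS_VARIABLE_NAMES list this comment describes is not
# defined in this view, yet allow_initializer_layout() reads it — confirm it
# exists at module level.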
def allow_initializer_layout(init_method):
    """A decorator for injecting layout information to layer.__init__.

    Layout will be a new param for any of the weights for all the keras layers.
    Adding the param to every __init__ method would be a large amount of
    duplicated work. This decorator is designed to reduce that duplication and
    make it easy to add/remove the dtensor feature if needed.

    Sample usage:
    ```python
    class Dense(tf.keras.layers.Layer):

      @allow_initializer_layout
      def __init__(self, units,
                   kernel_initializer='zeros',
                   bias_initializer='zeros',
                   **kwargs):
        super().__init__(**kwargs)

    d = Dense(units=8, kernel_layout=layout1, bias_layout=layout2)
    d.kernel_layout == layout1
    d.bias_layout == layout2
    ```

    By adding this annotation, it will:

    1. Filter out the kwargs based on some keywords, e.g. if
       'kernel_initializer' appears in the method signature, then it will try
       to pop 'kernel_layout' if it is present. Same for the other names in
       `KERAS_VARIABLE_NAMES`. This makes sure the layout related param is not
       passed to `BaseLayer.__init__`, which would raise an error about
       unexpected keyword args.
    2. Set the `self.<name>_layout` attribute (e.g. `kernel_layout`) after the
       `__init__` method is called. The Keras framework will use those fields
       to create weights down the stream.

    Args:
      init_method: the `__init__` method of the Keras layer to annotate.

    Returns:
      the annotated __init__ method.
    """
    # The wrapped method's signature is fixed, so inspect it once at
    # decoration time rather than on every layer construction.
    parameter_names = inspect.signature(init_method).parameters

    def _wrap_function(layer_instance, *args, **kwargs):
        layout_args = {}
        # For each known variable name (e.g. 'kernel'), if the wrapped
        # __init__ accepts '<name>_initializer', pop '<name>_layout' from
        # kwargs so it is not forwarded to __init__.
        for variable_name in KERAS_VARIABLE_NAMES:
            if variable_name + "_initializer" in parameter_names:
                layout_name = variable_name + "_layout"
                layout = kwargs.pop(layout_name, None)
                if layout is not None:
                    layout_args[layout_name] = layout

        init_method(layer_instance, *args, **kwargs)

        # Inject the layout attributes after the invocation of __init__().
        for layout_param_name, layout in layout_args.items():
            setattr(layer_instance, layout_param_name, layout)

    return tf.__internal__.decorator.make_decorator(
        target=init_method, decorator_func=_wrap_function
    )
def inject_mesh(init_method):
    """Inject DTensor mesh information to an object.

    This is useful for keras objects like `Metric` and `Optimizer`, which need
    a DTensor mesh to create their weights but should not change the current
    public API interface.

    This is for temporary usage; eventually the mesh/layout information will
    be public arguments in the `__init__` method.

    Sample usage:
    ```python
    class Accuracy(tf.keras.metrics.Metric):

      @inject_mesh
      def __init__(self, name='accuracy', dtype=None):
        super().__init__(name=name, dtype=dtype)

    acc = Accuracy(mesh=mesh)
    assert acc._mesh == mesh
    ```

    Args:
      init_method: the `__init__` method of the Keras class to annotate.

    Returns:
      the annotated __init__ method.
    """

    def _wrap_function(instance, *args, **kwargs):
        # Remove `mesh` so the wrapped __init__ never sees an unexpected
        # keyword argument.
        mesh = kwargs.pop("mesh", None)
        # Note that the injection of _mesh needs to happen before the
        # invocation of __init__, since the class might need the mesh to
        # create weights in the __init__.
        if mesh is not None:
            instance._mesh = mesh
        init_method(instance, *args, **kwargs)

    return tf.__internal__.decorator.make_decorator(
        target=init_method, decorator_func=_wrap_function
    )
def call_with_layout(fn, layout, *args, **kwargs):
    """Invoke the function with inputs and relayout the result.

    Args:
      fn: the function to invoke.
      layout: if not None, the output of the fn will be relayout with this.
      *args: positional arguments to be called with fn.
      **kwargs: keyword arguments to be called with fn.

    Returns:
      The output of fn, with potential relayout with the layout specified.
    """
    if layout is not None:
        # Run fn under the layout's mesh, then relayout its output to match
        # the requested layout.
        with dtensor.run_on(layout.mesh):
            result = fn(*args, **kwargs)
            return dtensor.relayout(result, layout)
    return fn(*args, **kwargs)