# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for GPU collective operations implemented using NVIDIA nccl."""

import threading

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import device
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_nccl_ops


_module_lock = threading.Lock()

_shared_name_counter = 0


def all_sum(tensors):
  """Returns a list of tensors with the all-reduce sum across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to sum; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the sum of the input tensors, where tensor i has
    the same device as `tensors[i]`.
  """
  return _apply_all_reduce('sum', tensors)
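
# A minimal usage sketch, assuming `import tensorflow as tf` and two visible
# GPUs; the device strings are illustrative. Each input lives on its own
# device and the matching output comes back on that same device.
#
#   with tf.device('/gpu:0'):
#     a = tf.ones([4])
#   with tf.device('/gpu:1'):
#     b = tf.ones([4])
#   s0, s1 = all_sum([a, b])  # s0 on /gpu:0, s1 on /gpu:1, both equal a + b
#
# All returned tensors must be evaluated in the same step; fetching only one
# of them leaves the NCCL kernels blocked. `all_prod`, `all_min` and `all_max`
# below follow the same pattern.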


@ops.RegisterGradient('NcclAllReduce')
def _all_sum_grad(op, grad):
  """The gradients for `all_sum`.

  Args:
    op: The `all_sum` `Operation` that we are differentiating.
    grad: Gradient with respect to the output of the `all_sum` op.

  Returns:
    The gradient with respect to the input of `all_sum`.

  Raises:
    LookupError: If `reduction` is not `sum`.
  """
  if op.get_attr('reduction') != b'sum':
    raise LookupError('No gradient defined for NcclAllReduce except for '
                      'reduction="sum".')

  _check_device(grad, expected=op.device)
  num_devices = op.get_attr('num_devices')
  shared_name = op.get_attr('shared_name') + b'_grad'

  with ops.device(op.device):
    return gen_nccl_ops.nccl_all_reduce(
        input=grad,
        reduction='sum',
        num_devices=num_devices,
        shared_name=shared_name)


def all_prod(tensors):
  """Returns a list of tensors with the all-reduce product across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to multiply; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the product of the input tensors, where tensor i
    has the same device as `tensors[i]`.
  """
  return _apply_all_reduce('prod', tensors)


def all_min(tensors):
  """Returns a list of tensors with the all-reduce min across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to reduce; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the minimum of the input tensors, where tensor i
    has the same device as `tensors[i]`.
  """
  return _apply_all_reduce('min', tensors)


def all_max(tensors):
  """Returns a list of tensors with the all-reduce max across `tensors`.

  The computation is done with an all-reduce operation, so if only some of the
  returned tensors are evaluated then the computation will hang.

  Args:
    tensors: The input tensors across which to reduce; must be assigned
      to GPU devices.

  Returns:
    List of tensors, each with the maximum of the input tensors, where tensor i
    has the same device as `tensors[i]`.
  """
  return _apply_all_reduce('max', tensors)


def reduce_sum(tensors):
  """Returns a tensor with the reduce sum across `tensors`.

  The computation is done with a reduce operation, so only one tensor is
  returned.

  Args:
    tensors: The input tensors across which to sum; must be assigned
      to GPU devices.

  Returns:
    A tensor containing the sum of the input tensors.

  Raises:
    ValueError: If no input tensor is assigned to the current device.
  """
  return _apply_reduce('sum', tensors)
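
# A minimal usage sketch, assuming `import tensorflow as tf`, two visible GPUs,
# and inputs `a` on /gpu:0 and `b` on /gpu:1 as in the `all_sum` sketch above:
#
#   with tf.device('/gpu:0'):
#     total = reduce_sum([a, b])  # a single tensor, placed on /gpu:0
#
# Unlike `all_sum`, only one output is produced; the enclosing device scope
# must match the device of one of the inputs, otherwise `_apply_reduce` below
# raises a ValueError.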


@ops.RegisterGradient('NcclReduce')
def _reduce_sum_grad(op, grad):
  """The gradients for the input `Operation` of `reduce_sum`.

  Args:
    op: The `sum send` `Operation` that we are differentiating.
    grad: Gradient with respect to the output of the `reduce_sum` op.

  Returns:
    The gradient with respect to the input of the `reduce_sum` op.

  Raises:
    LookupError: If the reduction attribute of op is not `sum`.
  """
  if op.get_attr('reduction') != b'sum':
    raise LookupError('No gradient defined for NcclReduce except for '
                      'reduction="sum".')
  _check_device(grad, expected=op.device)

  with ops.device(op.device):
    result = gen_nccl_ops.nccl_broadcast(input=grad, shape=grad.shape)

  return [result] * len(op.inputs)


def broadcast(tensor):
  """Returns a tensor that can be efficiently transferred to other devices.

  Args:
    tensor: The tensor to send; must be assigned to a GPU device.

  Returns:
    A tensor with the value of `tensor`, which can be used as input to
    ops on other GPU devices.
  """
  _check_device(tensor)

  with ops.device(tensor.device):
    return gen_nccl_ops.nccl_broadcast(input=tensor, shape=tensor.shape)
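
# A minimal usage sketch, assuming `import tensorflow as tf` and two visible
# GPUs; the device strings are illustrative:
#
#   with tf.device('/gpu:0'):
#     src = tf.constant([1.0, 2.0])
#     sent = broadcast(src)
#   with tf.device('/gpu:1'):
#     received = sent + 0.0  # consuming `sent` on /gpu:1 performs the transfer
#
# Gradients flowing back through `broadcast` are combined with an NCCL reduce
# by `_broadcast_grad` below.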


@ops.RegisterGradient('NcclBroadcast')
def _broadcast_grad(op, accumulated_grad):
  """The gradients for the input `Operation` of `broadcast`.

  Args:
    op: The `broadcast send` `Operation` that we are differentiating.
    accumulated_grad: Accumulated gradients with respect to the output of the
      `broadcast` op.

  Returns:
    Gradients with respect to the input of `broadcast`.
  """
  # Grab inputs of accumulated_grad and replace accumulation with reduce_sum.
  grads = [t for t in accumulated_grad.op.inputs]
  for t in grads:
    _check_device(t)

  with ops.device(op.device):
    return gen_nccl_ops.nccl_reduce(input=grads, reduction='sum')


def _apply_all_reduce(reduction, tensors):
  """Helper function for all_* functions."""
  if not tensors:
    raise ValueError('Must pass >0 tensors to all reduce operations')

  shared_name = _get_shared_name()

  def _all_reduce():
    """Call nccl allreduce."""
    res = []
    for t in tensors:
      _check_device(t)
      with ops.device(t.device):
        res.append(
            gen_nccl_ops.nccl_all_reduce(
                input=t,
                reduction=reduction,
                num_devices=len(tensors),
                shared_name=shared_name))
    return res

  if context.executing_eagerly():
    # Nccl ops will block unless they are executed concurrently such as in a
    # graph or a defun.
    return def_function.function(_all_reduce)()
  else:
    return _all_reduce()
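
# Note on how the pieces fit together: every NcclAllReduce op created by a
# single call shares one `shared_name` (e.g. 'c0') and num_devices equal to
# len(tensors); the runtime uses that pairing to group the per-device ops into
# one NCCL collective, which is why evaluating only a subset of the outputs
# hangs. In eager mode the loop is wrapped in a def_function so all kernels
# are launched together.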


def _apply_reduce(reduction, tensors):
  """Helper function for reduce_* functions."""
  if not tensors:
    raise ValueError('Must pass >0 tensors to reduce operations')

  for t in tensors:
    _check_device(t)
  result = gen_nccl_ops.nccl_reduce(input=tensors, reduction=reduction)
  try:
    next(t for t in tensors if t.device == result.device)
  except StopIteration:
    raise ValueError('One input tensor must be assigned to current device')
  return result


def _get_shared_name():
  """Returns a unique shared_name for a group of nccl ops, e.g. 'c0', 'c1'."""
  global _shared_name_counter

  with _module_lock:
    val = _shared_name_counter
    _shared_name_counter += 1
  return 'c%s' % val
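
# Example of the generated names (illustrative): successive calls return 'c0',
# 'c1', 'c2', ...; the lock keeps concurrent graph builders from reusing a
# name, since a reused shared_name could pair unrelated NcclAllReduce ops
# together.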


def _check_device(tensor, expected=None):
  """Checks `tensor` has a device assignment, matching `expected` if given."""
  if not device.canonical_name(tensor.device):
    raise ValueError(f'Device assignment for tensor={tensor} required for nccl '
                     'collective ops')
  if expected and expected != tensor.device:
    raise ValueError(f'Expected device {expected}, got {tensor.device} for '
                     f'tensor={tensor}.')