185 lines
7.4 KiB
Python
185 lines
7.4 KiB
Python
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Utilities for slicing in to a `LinearOperator`."""
|
|
|
|
import collections
|
|
import functools
|
|
import numpy as np
|
|
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import tensor_util
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.util import nest
|
|
|
|
|
|
__all__ = ['batch_slice']
|
|
|
|
|
|
def _prefer_static_where(condition, x, y):
  """Like `tf.where`, but evaluates with NumPy when all args are static.

  Args:
    condition: `bool` scalar/`Tensor` selecting between `x` and `y`.
    x: Value used where `condition` is true.
    y: Value used where `condition` is false.

  Returns:
    The NumPy result of `np.where` if `condition`, `x` and `y` all have
    statically-known values; otherwise a `tf.where` `Tensor`.
  """
  maybe_static = [tensor_util.constant_value(a) for a in (condition, x, y)]
  if any(v is None for v in maybe_static):
    # At least one argument is only known at graph execution time; fall back
    # to the dynamic op.
    return array_ops.where(condition, x, y)
  # Everything is known statically; resolve the selection right now.
  static_condition, static_x, static_y = maybe_static
  return np.where(static_condition, static_x, static_y)
|
|
|
|
|
|
def _broadcast_parameter_with_batch_shape(
    param, param_ndims_to_matrix_ndims, batch_shape):
  """Broadcasts `param` with the given batch shape, recursively.

  Args:
    param: A `Tensor`, or a structured object exposing `batch_shape_tensor`
      (e.g. a `LinearOperator`) whose own parameters are broadcast in turn.
    param_ndims_to_matrix_ndims: `int` number of trailing dimensions of
      `param` that belong to the matrix (non-batch) shape.
    batch_shape: int `Tensor` batch shape to broadcast against.

  Returns:
    `param` broadcast so that its batch dimensions are compatible with
    `batch_shape`; same type as the input.
  """
  if hasattr(param, 'batch_shape_tensor'):
    # `param` is itself an operator-like object: rebuild it with each of its
    # constituent parameters broadcast against `batch_shape`.
    broadcast_fn = functools.partial(
        _broadcast_parameter_with_batch_shape, batch_shape=batch_shape)
    ndims_map = param._experimental_parameter_ndims_to_matrix_ndims  # pylint:disable=protected-access
    new_params = {}
    for name, ndims in ndims_map.items():
      sub_param = getattr(param, name)
      new_params[name] = nest.map_structure_up_to(
          sub_param, broadcast_fn, sub_param, ndims)
    return type(param)(**dict(param.parameters, **new_params))

  # Plain `Tensor` case: pad `batch_shape` with ones over the trailing matrix
  # dimensions, then broadcast `param` to the combined shape.
  matrix_dim_ones = array_ops.ones(
      [param_ndims_to_matrix_ndims], dtype=dtypes.int32)
  padded_batch_shape = array_ops.concat([batch_shape, matrix_dim_ones], axis=0)
  full_shape = array_ops.broadcast_dynamic_shape(
      padded_batch_shape, array_ops.shape(param))
  return array_ops.broadcast_to(param, full_shape)
|
|
|
|
|
|
def _sanitize_slices(slices, intended_shape, deficient_shape):
  """Restricts slices to avoid overflowing size-1 (broadcast) dimensions.

  A slice written against `intended_shape` may over-index a dimension that is
  only size 1 in `deficient_shape` (i.e. a dimension the parameter relies on
  broadcasting to fill). For each such dimension, the slice is clamped so it
  selects the single element that is actually there.

  Args:
    slices: iterable of slices received by `__getitem__`.
    intended_shape: int `Tensor` shape for which the slices were intended.
    deficient_shape: int `Tensor` shape to which the slices will be applied.
      Must have the same rank as `intended_shape`.

  Returns:
    sanitized_slices: Python `list` of slice objects.

  Raises:
    ValueError: if `slices` contains more than one `Ellipsis`.
  """
  sanitized_slices = []
  # `idx` tracks which dimension the current slice applies to. It counts up
  # from 0 until an `Ellipsis` is seen, then switches to counting up from a
  # negative offset (indexing from the right), mirroring NumPy semantics.
  idx = 0
  for slc in slices:
    if slc is Ellipsis:  # Switch over to negative indexing.
      # A negative `idx` means we already processed an `Ellipsis`.
      if idx < 0:
        raise ValueError('Found multiple `...` in slices {}'.format(slices))
      # `newaxis` entries don't consume a dimension of the input, so only the
      # remaining non-newaxis slices determine how far from the right end the
      # post-Ellipsis slices land.
      num_remaining_non_newaxis_slices = sum(
          s is not array_ops.newaxis for s in slices[
              slices.index(Ellipsis) + 1:])
      idx = -num_remaining_non_newaxis_slices
    elif slc is array_ops.newaxis:
      # Inserts a new dimension; does not advance `idx`.
      pass
    else:
      # Dimension is "broadcast" when the parameter is size 1 where the
      # intended shape is larger; any index into it must resolve to element 0.
      is_broadcast = intended_shape[idx] > deficient_shape[idx]
      if isinstance(slc, slice):
        # Slices are denoted by start:stop:step.
        start, stop, step = slc.start, slc.stop, slc.step
        # On a broadcast dimension, clamp to 0:1:1 (the only valid slice);
        # otherwise keep the user's bounds. `None` bounds already mean "whole
        # dimension" and need no clamping.
        if start is not None:
          start = _prefer_static_where(is_broadcast, 0, start)
        if stop is not None:
          stop = _prefer_static_where(is_broadcast, 1, stop)
        if step is not None:
          step = _prefer_static_where(is_broadcast, 1, step)
        slc = slice(start, stop, step)
      else:  # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
        # A scalar index into a broadcast dimension must become 0.
        slc = _prefer_static_where(is_broadcast, 0, slc)
      idx += 1
    sanitized_slices.append(slc)
  return sanitized_slices
|
|
|
|
|
|
def _slice_single_param(
    param, param_ndims_to_matrix_ndims, slices, batch_shape):
  """Slices into the batch shape of a single parameter.

  Args:
    param: The original parameter to slice; either a `Tensor` or an object
      with batch shape (LinearOperator).
    param_ndims_to_matrix_ndims: `int` number of right-most dimensions used for
      inferring matrix shape of the `LinearOperator`. For non-Tensor
      parameters, this is the number of this param's batch dimensions used by
      the matrix shape of the parent object.
    slices: iterable of slices received by `__getitem__`.
    batch_shape: The parameterized object's batch shape `Tensor`.

  Returns:
    new_param: Instance of the same type as `param`, batch-sliced according to
      `slices`.
  """
  # Broadcast the parameter to have full batch rank: a shape of all ones is
  # enough to pad the rank without materializing the full batch shape.
  param = _broadcast_parameter_with_batch_shape(
      param, param_ndims_to_matrix_ndims, array_ops.ones_like(batch_shape))

  # Derive the parameter's own batch shape: structured objects report it
  # directly; for Tensors, strip the trailing matrix dimensions.
  if hasattr(param, 'batch_shape_tensor'):
    param_batch_shape = param.batch_shape_tensor()
  else:
    param_batch_shape = array_ops.shape(param)
  # Truncate by param_ndims_to_matrix_ndims
  param_batch_rank = array_ops.size(param_batch_shape)
  param_batch_shape = param_batch_shape[
      :(param_batch_rank - param_ndims_to_matrix_ndims)]

  # At this point the param should have full batch rank, *unless* it's an
  # atomic object like `tfb.Identity()` incapable of having any batch rank.
  # Such params are unaffected by batch slicing, so return them unchanged.
  if (tensor_util.constant_value(array_ops.size(batch_shape)) != 0 and
      tensor_util.constant_value(array_ops.size(param_batch_shape)) == 0):
    return param
  # Clamp slices over dimensions where the param broadcasts (size 1) against
  # the parent's batch shape, so they don't over-index the param.
  param_slices = _sanitize_slices(
      slices, intended_shape=batch_shape, deficient_shape=param_batch_shape)

  # Extend `param_slices` (which represents slicing into the
  # parameter's batch shape) with the parameter's event ndims. For example, if
  # `params_ndims == 1`, then `[i, ..., j]` would become `[i, ..., j, :]`.
  if param_ndims_to_matrix_ndims > 0:
    # Only add an `Ellipsis` if the user didn't already supply one (Tensor
    # entries can't be identity-compared against Ellipsis, so skip them).
    if Ellipsis not in [
        slc for slc in slices if not tensor_util.is_tensor(slc)]:
      param_slices.append(Ellipsis)
    param_slices += [slice(None)] * param_ndims_to_matrix_ndims
  return param.__getitem__(tuple(param_slices))
|
|
|
|
|
|
def batch_slice(linop, params_overrides, slices):
  """Slices `linop` along its batch dimensions.

  Args:
    linop: A `LinearOperator` instance.
    params_overrides: A `dict` of parameter overrides.
    slices: A `slice` or `int` or `int` `Tensor` or `tf.newaxis` or `tuple`
      thereof. (e.g. the argument of a `__getitem__` method).

  Returns:
    new_linop: A batch-sliced `LinearOperator`.
  """
  # Normalize a single slice/int/newaxis into a one-element tuple.
  if not isinstance(slices, collections.abc.Sequence):
    slices = (slices,)
  overrides = {}
  # A lone `...` selects everything, so no parameter actually needs slicing.
  if not (len(slices) == 1 and slices[0] is Ellipsis):
    batch_shape = linop.batch_shape_tensor()
    slice_fn = functools.partial(
        _slice_single_param, slices=slices, batch_shape=batch_shape)
    ndims_map = linop._experimental_parameter_ndims_to_matrix_ndims  # pylint:disable=protected-access
    for param_name, param_ndims_to_matrix_ndims in ndims_map.items():
      param = getattr(linop, param_name)
      # These represent optional `Tensor` parameters.
      if param is None:
        continue
      overrides[param_name] = nest.map_structure_up_to(
          param, slice_fn, param, param_ndims_to_matrix_ndims)
  # Caller-supplied overrides win over sliced parameters; both win over the
  # operator's original constructor arguments.
  overrides.update(params_overrides)
  return type(linop)(**dict(linop.parameters, **overrides))
|