504 lines
20 KiB
Python
504 lines
20 KiB
Python
|
import torch
|
||
|
import torch._six
|
||
|
from typing import Optional, List, DefaultDict
|
||
|
import warnings
|
||
|
from collections import defaultdict
|
||
|
import sys
|
||
|
import traceback
|
||
|
|
||
|
|
||
|
|
||
|
def _type(self, dtype=None, non_blocking=False, **kwargs):
    """Return the type name when ``dtype`` is omitted; otherwise cast to ``dtype``.

    When the object already has the requested type it is returned as-is and no
    copy is made.

    Args:
        dtype (type or string): The desired type. A string is resolved as a
            dotted import path (e.g. ``'torch.DoubleTensor'``).
        non_blocking (bool): If ``True``, and the source is in pinned memory
            and destination is on the GPU or vice versa, the copy is performed
            asynchronously with respect to the host. Otherwise, the argument
            has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument. The ``async`` arg is deprecated.

    Returns:
        ``"<module>.<ClassName>"`` when ``dtype`` is ``None``; otherwise an
        object of type ``dtype`` (``self`` if no conversion was needed).

    Raises:
        RuntimeError: when asked to cast between a sparse and a dense type.
    """
    non_blocking = _get_async_or_non_blocking('type', non_blocking, kwargs)
    if dtype is None:
        return self.__module__ + '.' + self.__class__.__name__

    target = _import_dotted_name(dtype) if isinstance(dtype, str) else dtype
    if target == type(self):
        return self

    if not self.is_sparse:
        if target.is_sparse:
            raise RuntimeError("Cannot cast dense tensor to sparse tensor")
        # Build a new ``target`` object of the same size and copy ``self`` into it.
        return target(self.size()).copy_(self, non_blocking)

    if not target.is_sparse:
        raise RuntimeError("Cannot cast sparse tensor to dense tensor")
    # Strip '.sparse' from the target's module to name the dense types used for
    # the components: values keep the target class name, indices become Long.
    dense_module = target.__module__.replace('.sparse', '')
    values = torch.Tensor._values(self).type(dense_module + '.' + target.__name__, non_blocking)
    indices = torch.Tensor._indices(self).type(dense_module + '.LongTensor', non_blocking)
    return target(indices, values, self.size())
|
||
|
|
||
|
|
||
|
def _cuda(self, device=None, non_blocking=False, **kwargs):
    """Return a copy of this object in CUDA memory.

    If the object already lives in CUDA memory on the requested device, no copy
    is made and ``self`` is returned.

    Args:
        device (int): Index of the destination GPU. Defaults to the current device.
        non_blocking (bool): If ``True`` and the source is in pinned memory,
            the copy will be asynchronous with respect to the host. Otherwise,
            the argument has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument.
    """
    non_blocking = _get_async_or_non_blocking('cuda', non_blocking, kwargs)
    if self.is_cuda:
        target = torch.cuda.current_device() if device is None else device
        # Already on the requested GPU: nothing to copy.
        if self.get_device() == target:
            return self
    else:
        # -1 is passed through to ``torch.cuda.device`` / ``.cuda`` when no id is given.
        target = -1 if device is None else device

    with torch.cuda.device(target):
        class_name = self.__class__.__name__
        if self.is_sparse:
            sparse_cls = getattr(torch.cuda.sparse, class_name)
            idx = torch.Tensor._indices(self).cuda(target, non_blocking)
            vals = torch.Tensor._values(self).cuda(target, non_blocking)
            return sparse_cls(idx, vals, self.size())
        dense_cls = getattr(torch.cuda, class_name)
        return dense_cls(self.size()).copy_(self, non_blocking)
|
||
|
|
||
|
|
||
|
def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
    """Resolve the effective ``non_blocking`` flag, honoring the legacy ``async`` kwarg.

    Args:
        function_name (str): name of the public caller, used in the error message.
        non_blocking (bool): value of the caller's ``non_blocking`` argument.
        kwargs (dict): the caller's extra keyword arguments.

    Returns:
        ``non_blocking`` when ``kwargs`` is empty; otherwise ``kwargs['async']``
        (after emitting a deprecation warning).

    Raises:
        TypeError: if ``kwargs`` contains any key other than ``async``.
    """
    if not kwargs:
        return non_blocking
    if len(kwargs) != 1 or 'async' not in kwargs:
        message = "{}() got an unexpected keyword argument '{}'"
        # Name a key that is genuinely unexpected. Taking the last key (as a
        # plain ``pop()`` would) can report ``async`` itself when it was passed
        # together with a bad key. At least one non-``async`` key is guaranteed
        # here, so ``next`` cannot run dry.
        argument = next(key for key in kwargs if key != 'async')
        raise TypeError(message.format(function_name, argument))
    warnings.warn("'async' is deprecated; use 'non_blocking'")
    return kwargs['async']
|
||
|
|
||
|
|
||
|
# Note [Don't serialize hooks]
|
||
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||
|
# Since time immemorial, we have serialized the backward hooks associated with
|
||
|
# variables. This kind of half-worked--Python can pickle global functions
|
||
|
# (but not closures!)--but there were problems.
|
||
|
#
|
||
|
# - It's fragile. If you serialize a backward hook into a saved
|
||
|
# model, and then you rename the function associated with the hook,
|
||
|
# now your saved model is broken and you can't load it anymore.
|
||
|
#
|
||
|
# - It's not actually used. The standard recommendation is to
|
||
|
# serialize the *state_dict* of a model, not the model itself
|
||
|
# (since this is more stable to code changes affecting the model
|
||
|
# serialization), and the state dict saves "data" only, thus
|
||
|
# stripping the backward hooks. In some cases, hooks are
|
||
|
# essential to the well-functioning of a model (e.g., DDP),
|
||
|
# but DDP already manages re-adding the hooks!
|
||
|
#
|
||
|
# - We didn't serialize them in many cases. Prior to #10220, we
|
||
|
# were dropping backward hooks in ForkingPickler. We "fixed" this
|
||
|
# to be convenient with other serialization sites, but lack of
|
||
|
# serializing backward hooks wasn't actually the root cause of
|
||
|
# the bug.
|
||
|
#
|
||
|
# With these cases in mind, we have decided that a better strategy
|
||
|
# is to just NOT serialize hooks at all.
|
||
|
#
|
||
|
# Since this is a BC-breaking change, we should warn when we previously
|
||
|
# serialized a hook, but no longer do so. This will be done by adding a special
|
||
|
# sentinel property to hooks, which will be used to suppress this warning. If a hook
|
||
|
# has the property _torch_serialize_ignore, we will not emit a warning if we
|
||
|
# attempt to serialize a Tensor with this hook attached to it.
|
||
|
#
|
||
|
# By the way, when _backward_hooks is skipped, we must give an EMPTY
|
||
|
# OrderedDict(); if you pass a None you'll run afoul of #12219.
|
||
|
|
||
|
|
||
|
def _rebuild_tensor(storage, storage_offset, size, stride):
    """Recreate a tensor that views ``storage`` with the given geometry.

    Args:
        storage: storage to view; its ``dtype`` and ``device`` determine the
            resulting tensor's dtype and device.
        storage_offset (int): element offset into ``storage``.
        size: shape of the resulting tensor.
        stride: strides of the resulting tensor.
    """
    dtype, device = storage.dtype, storage.device
    # Start from an empty tensor of the right dtype/device, then point it at the storage.
    shell = torch.tensor([], dtype=dtype, device=device)
    return shell.set_(storage, storage_offset, size, stride)
|
||
|
|
||
|
|
||
|
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
    """Rebuild a tensor via ``_rebuild_tensor`` and restore its autograd flags.

    Args:
        storage, storage_offset, size, stride: forwarded to ``_rebuild_tensor``.
        requires_grad (bool): value assigned to the result's ``requires_grad``.
        backward_hooks: assigned to the result's ``_backward_hooks``.
    """
    rebuilt = _rebuild_tensor(storage, storage_offset, size, stride)
    rebuilt.requires_grad = requires_grad
    # Kept only for backwards compatibility: ``backward_hooks`` is normally an
    # empty OrderedDict. See Note [Don't serialize hooks]
    rebuilt._backward_hooks = backward_hooks
    return rebuilt
|
||
|
|
||
|
|
||
|
# Sparse tensors created by ``_rebuild_sparse_tensor`` that have not been
# validated yet; drained by ``_validate_loaded_sparse_tensors``.
_sparse_tensors_to_validate: List["torch.Tensor"] = []


# In _legacy_load() in serialization.py we unpickle storages after the sparse
# tensors have already been unpickled. Those storages contain data necessary for
# validating sparse tensors: indices and values. That's why sparse tensors are
# first unpickled without any validation, and then this function is called just
# before _legacy_load() returns, so that all the sparse tensors can be validated
# in bulk.
#
# The same procedure must be followed by _load() in serialization.py because due
# to Pickler semantics, we have to use the same (non-validating) function for
# unpickling sparse tensors, regardless of the caller.
def _validate_loaded_sparse_tensors():
    """Validate every queued sparse COO tensor, then empty the queue.

    The queue is emptied even if a validation check raises.
    """
    try:
        for pending in _sparse_tensors_to_validate:
            indices, values, size = pending._indices(), pending._values(), pending.size()
            torch._validate_sparse_coo_tensor_args(indices, values, size)
    finally:
        del _sparse_tensors_to_validate[:]
|
||
|
|
||
|
def _rebuild_sparse_tensor(layout, data):
    """Recreate a sparse tensor from pickled parts, deferring validation.

    Args:
        layout: sparse layout of the tensor; only ``torch.sparse_coo`` is handled.
        data (tuple): ``(indices, values, size)`` for a COO tensor.

    Raises:
        NotImplementedError: for any other layout.
    """
    if layout != torch.sparse_coo:
        raise NotImplementedError("rebuilding sparse tensor for layout %s" % (layout))
    indices, values, size = data
    # Build without checks; the tensor is queued and validated later by
    # ``_validate_loaded_sparse_tensors``.
    rebuilt = torch._sparse_coo_tensor_unsafe(indices, values, size)
    _sparse_tensors_to_validate.append(rebuilt)
    return rebuilt
|
||
|
|
||
|
|
||
|
def _rebuild_xla_tensor(data, dtype, device, requires_grad):
    """Rebuild a tensor from host data, then move it to ``dtype``/``device``.

    Args:
        data: array passed to ``torch.from_numpy`` (presumably a NumPy ndarray).
        dtype: target dtype.
        device: target device.
        requires_grad (bool): value assigned to the result's ``requires_grad``.
    """
    host = torch.from_numpy(data)
    result = host.to(dtype=dtype, device=device)
    result.requires_grad = requires_grad
    return result
|
||
|
|
||
|
|
||
|
def _rebuild_qtensor(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks):
    """Rebuild a quantized tensor from its storage and quantizer parameters.

    Args:
        storage: storage holding the quantized data. Its ``dtype`` and
            ``device`` determine the result's dtype and device.
        storage_offset (int): element offset into ``storage``.
        size: shape of the result.
        stride: strides of the result.
        quantizer_params (tuple): ``(qscheme, ...)``. For
            ``torch.per_tensor_affine`` the rest is ``(scale, zero_point)``.
            For ``torch.per_channel_affine`` and
            ``torch.per_channel_affine_float_qparams`` it is
            ``(scales, zero_points, axis)``, where scales/zero_points may be
            Python lists (converted to tensors here) or tensors.
        requires_grad (bool): value assigned to the result's ``requires_grad``.
        backward_hooks: assigned to the result's ``_backward_hooks``.

    Raises:
        RuntimeError: if the qscheme is not one of the above.
    """
    qscheme = quantizer_params[0]
    # Every tensor below is allocated on ``storage.device`` (as ``_rebuild_tensor``
    # does). Otherwise the shell would land on the default device, and ``set_``
    # would mix devices for non-CPU storage.
    if qscheme == torch.per_tensor_affine:
        _, scale, zero_point = quantizer_params
        tensor = torch._empty_affine_quantized(
            size, scale=scale, zero_point=zero_point, dtype=storage.dtype, device=storage.device)
    elif qscheme in (torch.per_channel_affine, torch.per_channel_affine_float_qparams):
        _, scales, zero_points, axis = quantizer_params
        if type(scales) is list and type(zero_points) is list:
            if qscheme == torch.per_channel_affine:
                scales = torch.tensor(scales, dtype=torch.double, device=storage.device)
                zero_points = torch.tensor(zero_points, dtype=torch.long, device=storage.device)
            else:
                # float_qparams stores floating-point zero points
                scales = torch.tensor(scales, dtype=torch.float, device=storage.device)
                zero_points = torch.tensor(zero_points, dtype=torch.float, device=storage.device)
        tensor = torch._empty_per_channel_affine_quantized(
            size, scales=scales, zero_points=zero_points, axis=axis,
            dtype=storage.dtype, device=storage.device)
    else:
        raise RuntimeError("Can't deserialize quantized tensor with qscheme {}".format(qscheme))
    tensor.set_(storage, storage_offset, size, stride)
    tensor.requires_grad = requires_grad
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict. See Note [Don't serialize hooks]
    tensor._backward_hooks = backward_hooks
    return tensor
|
||
|
|
||
|
def _rebuild_parameter(data, requires_grad, backward_hooks):
    """Wrap ``data`` in a ``torch.nn.Parameter`` and restore its hooks.

    Args:
        data: tensor to wrap.
        requires_grad (bool): forwarded to the ``Parameter`` constructor.
        backward_hooks: assigned to the result's ``_backward_hooks``.
    """
    parameter = torch.nn.Parameter(data, requires_grad=requires_grad)
    # Kept only for backwards compatibility: ``backward_hooks`` is normally an
    # empty OrderedDict. See Note [Don't serialize hooks]
    parameter._backward_hooks = backward_hooks
    return parameter
|
||
|
|
||
|
|
||
|
def _import_dotted_name(name):
    """Resolve a dotted path such as ``'torch.cuda.FloatTensor'`` to an object.

    The first component is imported with ``__import__``; each remaining
    component is looked up as an attribute of the previous object.
    """
    root, *attributes = name.split('.')
    resolved = __import__(root)
    for attribute in attributes:
        resolved = getattr(resolved, attribute)
    return resolved
|
||
|
|
||
|
|
||
|
# Running-totals helper; same semantics as the ``itertools.accumulate`` recipe
# from the Python 3.5 docs.
def _accumulate(iterable, fn=lambda x, y: x + y):
    """Return running totals of ``iterable`` combined with ``fn``.

    Lazily yields ``x0, fn(x0, x1), fn(fn(x0, x1), x2), ...`` and yields
    nothing for an empty iterable.

    Examples:
        _accumulate([1,2,3,4,5]) --> 1 3 6 10 15
        _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
    """
    # Delegate to the standard-library implementation instead of hand-rolling
    # it. ``yield from`` keeps this a generator, so consuming the input still
    # starts only on first iteration, as before.
    import itertools
    yield from itertools.accumulate(iterable, fn)
|
||
|
|
||
|
|
||
|
def _flatten_dense_tensors(tensors):
    """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
    same dense type.

    Since inputs are dense, the resulting tensor will be a concatenated 1D
    buffer. Element-wise operation on this buffer will be equivalent to
    operating individually.

    Args:
        tensors (Iterable[Tensor]): dense tensors to flatten. Any iterable
            (including a generator) is accepted and consumed once.

    Returns:
        A contiguous 1D buffer containing input tensors. With exactly one
        input tensor, the result is a flat view of it (no concatenation).
    """
    # Materialize first: ``len`` and indexing below need a sequence, while the
    # documented contract is an arbitrary iterable.
    tensors = list(tensors)
    if len(tensors) == 1:
        # Single tensor: return a flat view instead of copying through ``cat``.
        return tensors[0].contiguous().view(-1)
    return torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)
|
||
|
|
||
|
|
||
|
def _flatten_sparse_tensors(tensors):
    """Flatten sparse tensors into two contiguous 1D buffers, one of indices and
    one of values. Assume tensors are of same sparse type.

    Args:
        tensors (Iterable[Tensor]): sparse tensors to flatten. Any iterable
            (including a generator) is accepted and consumed once.

    Returns:
        A tuple of two contiguous 1D buffers, one containing input tensors'
        indices and the other containing the values.
    """
    # ``tensors`` is walked twice (indices, then values); materialize it so a
    # one-shot iterable is not exhausted after the first pass.
    tensors = list(tensors)
    flat_indices = _flatten_dense_tensors([torch.Tensor._indices(t) for t in tensors])
    flat_values = _flatten_dense_tensors([torch.Tensor._values(t) for t in tensors])
    return flat_indices, flat_values
|
||
|
|
||
|
|
||
|
def _unflatten_dense_tensors(flat, tensors):
    """Carve ``flat`` into views shaped like each tensor in ``tensors``.

    Assumes ``tensors`` share one dense type and that ``flat`` came from
    ``_flatten_dense_tensors`` on tensors with the same sizes. Consecutive
    ``numel()``-sized slices of ``flat`` are taken in order; no data is copied.

    Args:
        flat (Tensor): flattened dense tensors to unflatten.
        tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
            unflatten flat.

    Returns:
        Tuple of views into ``flat``, each shaped like the matching input tensor.
    """
    views = []
    start = 0
    for reference in tensors:
        length = reference.numel()
        chunk = flat.narrow(0, start, length)
        views.append(chunk.view_as(reference))
        start = start + length
    return tuple(views)
|
||
|
|
||
|
|
||
|
def _unflatten_sparse_tensors(flat, tensors):
    """View flat buffer (containing indices and values) using the sizes of
    tensors. Assume that tensors are of same sparse type, and that flat is given
    by _flatten_sparse_tensors.

    Args:
        flat (tuple(Tensor, Tensor)): flattened indices and values of sparse
            tensors to unflatten.
        tensors (Iterable[Tensor]): sparse tensors whose sizes will be used to
            unflatten flat. Any iterable (including a generator) is accepted
            and consumed once.

    Returns:
        Unflattened sparse tensors with sizes same as tensors and values from
        flat.
    """
    # ``tensors`` is walked three times (indices, values, and the final zip);
    # materialize it so a one-shot iterable is not exhausted after the first pass.
    tensors = list(tensors)
    flat_indices, flat_values = flat
    indices = _unflatten_dense_tensors(flat_indices, [torch.Tensor._indices(t) for t in tensors])
    values = _unflatten_dense_tensors(flat_values, [torch.Tensor._values(t) for t in tensors])
    return tuple(t.new(i, v, t.size()) for t, i, v in zip(tensors, indices, values))
|
||
|
|
||
|
|
||
|
def _reorder_tensors_as(tensors, ordered_tensors):
    """Put ``tensors`` back into the order of ``ordered_tensors``.

    Both sequences are grouped by ``tensor.type()``. Within each type,
    ``tensors`` is assumed to already follow the order of ``ordered_tensors``
    (as produced by ``_take_tensors``). The result interleaves the types the
    way ``ordered_tensors`` does.

    Args:
        tensors (Iterable[Tensor]): tensors to be reordered. They should be of
            the same order as ordered_tensors within their own types.
        ordered_tensors (Iterable[Tensor]): tensors whose order will be the
            reference.

    Returns:
        Tuple holding the objects from ``tensors``, arranged by the order of
        ``ordered_tensors``.
    """
    buckets = defaultdict(list)
    for item in tensors:
        buckets[item.type()].append(item)
    # One cursor per type; each reference tensor consumes the next item of its type.
    cursors = {key: iter(group) for key, group in buckets.items()}
    return tuple(next(cursors[reference.type()]) for reference in ordered_tensors)
|
||
|
|
||
|
|
||
|
def _take_tensors(tensors, size_limit):
    """Split ``tensors`` into per-type chunks bounded by a byte budget.

    Tensors are grouped by ``tensor.type()``. A non-empty chunk is emitted as
    soon as adding the next tensor of its type would push it past
    ``size_limit``. A single tensor larger than the limit still forms a chunk
    on its own. Sparse tensors are measured as the bytes of their indices plus
    their values. Remaining chunks are emitted at the end, in the order their
    type first appeared.

    Args:
        tensors (Sequence): tensors to split into chunks.
        size_limit (int): byte budget for each chunk.

    Yields:
        Lists of same-typed tensors; within a type, input order is preserved.
    """
    pending: DefaultDict[str, List] = defaultdict(list)
    used: DefaultDict[str, int] = defaultdict(int)
    for tensor in tensors:
        key = tensor.type()
        if tensor.is_sparse:
            idx = torch.Tensor._indices(tensor)
            vals = torch.Tensor._values(tensor)
            nbytes = idx.numel() * idx.element_size() + vals.numel() * vals.element_size()
        else:
            nbytes = tensor.numel() * tensor.element_size()
        if used[key] + nbytes > size_limit and used[key] > 0:
            # Flush the current chunk for this type and start a fresh one.
            yield pending[key]
            pending[key] = []
            used[key] = 0
        pending[key].append(tensor)
        used[key] += nbytes
    for chunk in pending.values():
        if chunk:
            yield chunk
|
||
|
|
||
|
|
||
|
# Decorator that attaches annotations explicitly via ``__annotations__``
# rather than through inline annotation syntax.
def annotate(ret, **kwargs):
    """Build a decorator that sets ``fun.__annotations__``.

    The decorated function gets the keyword annotations from ``kwargs`` plus
    ``'return': ret``, and is returned unchanged otherwise.
    """
    def dec(fun):
        annotations = dict(kwargs)
        annotations['return'] = ret
        fun.__annotations__ = annotations
        return fun
    return dec
|
||
|
|
||
|
|
||
|
# NOTE [ Python Traceback Reference Cycle Problem ]
|
||
|
#
|
||
|
# When using sys.exc_info(), it is important to **not** store the exc_info[2],
|
||
|
# which is the traceback, because otherwise you will run into the traceback
|
||
|
# reference cycle problem, i.e., the traceback holding reference to the frame,
|
||
|
# and the frame (which holds references to all the objects in its temporary scope)
|
||
|
# holding a reference to the traceback.
|
||
|
|
||
|
class KeyErrorMessage(str):
    r"""str subclass that returns itself in repr"""

    def __repr__(self):
        # KeyError renders its single argument with repr(); returning the text
        # itself keeps multi-line messages (e.g. tracebacks) readable.
        return self


class ExceptionWrapper(object):
    r"""Wraps an exception plus traceback to communicate across threads"""

    def __init__(self, exc_info=None, where="in background"):
        """Capture the type and formatted traceback of an exception.

        Args:
            exc_info (tuple, optional): ``(type, value, traceback)`` as returned
                by ``sys.exc_info()``. Defaults to the exception currently
                being handled.
            where (str): description of where the exception happened; used in
                the reraised message.
        """
        # It is important that we don't store exc_info, see
        # NOTE [ Python Traceback Reference Cycle Problem ]
        if exc_info is None:
            exc_info = sys.exc_info()
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))
        self.where = where

    def reraise(self):
        r"""Reraises the wrapped exception in the current thread

        The raised exception has the original type. Its message embeds the
        original formatted traceback. If that type cannot be constructed from a
        message (e.g. ``UnicodeDecodeError`` requires several arguments), a
        ``RuntimeError`` carrying the same message is raised instead, so the
        original traceback is not lost behind an unrelated ``TypeError``.
        """
        # Format a message such as: "Caught ValueError in DataLoader worker
        # process 2. Original Traceback:", followed by the traceback.
        msg = "Caught {} {}.\nOriginal {}".format(
            self.exc_type.__name__, self.where, self.exc_msg)
        use_message_kwarg = False
        if self.exc_type == KeyError:
            # KeyError calls repr() on its argument (usually a dict key). This
            # makes stack traces unreadable. It will not be changed in Python
            # (https://bugs.python.org/issue2651), so we work around it.
            msg = KeyErrorMessage(msg)
        elif getattr(self.exc_type, "message", None):
            # Some exceptions have first argument as non-str but explicitly
            # have message field
            use_message_kwarg = True
        try:
            if use_message_kwarg:
                exception = self.exc_type(message=msg)
            else:
                exception = self.exc_type(msg)
        except TypeError:
            # The exception type does not accept a lone message; we cannot
            # rebuild it faithfully, so surface the message another way.
            raise RuntimeError(msg) from None
        raise exception
|
||
|
|
||
|
|
||
|
def _get_available_device_type():
    """Return the name of the usable accelerator backend, or ``None``.

    Only CUDA is checked; ``"cuda"`` is returned when
    ``torch.cuda.is_available()`` is true.
    """
    # add more available device types here
    return "cuda" if torch.cuda.is_available() else None
|
||
|
|
||
|
|
||
|
def _get_device_attr(get_member):
    """Apply ``get_member`` to the runtime module of the available device type.

    Returns ``get_member(torch.cuda)`` when ``_get_available_device_type()``
    reports ``"cuda"`` (compared case-insensitively), and ``None`` otherwise.
    """
    device_type = _get_available_device_type()
    if not device_type:
        return None
    if device_type.lower() == "cuda":
        return get_member(torch.cuda)
    # add more available device types here
    return None
|
||
|
|
||
|
|
||
|
def _get_current_device_index():
    """Return ``current_device()`` of the available device runtime, or ``None``."""
    def _current(module):
        return module.current_device()
    return _get_device_attr(_current)
|
||
|
|
||
|
|
||
|
def _get_all_device_indices():
    """Return ``[0, ..., device_count() - 1]`` for the available device runtime, or ``None``."""
    def _indices(module):
        return list(range(module.device_count()))
    return _get_device_attr(_indices)
|
||
|
|
||
|
|
||
|
def _get_devices_properties(device_ids):
    """Return ``get_device_properties(i)`` for each ``i`` in ``device_ids``.

    Each entry comes from ``_get_device_attr`` and is therefore ``None`` when
    no supported device runtime is available.
    """
    # ``idx`` is bound as a default so each lambda carries its own device id.
    return [
        _get_device_attr(lambda module, idx=idx: module.get_device_properties(idx))
        for idx in device_ids
    ]
|
||
|
|
||
|
|
||
|
def _get_device_index(device, optional=False, allow_cpu=False) -> int:
    r"""Resolve :attr:`device` (a torch.device, a string, a Python integer, or
    ``None``) to a device index.

    - A string is first converted with ``torch.device``.
    - A torch.device yields its index. For a CPU device, ``-1`` is returned
      when :attr:`allow_cpu` is ``True``; otherwise ``ValueError`` is raised.
    - A Python integer is returned unchanged.
    - When no index can be determined (``None``, or a device without an index),
      the current default device of the supported runtime platform (CUDA) is
      returned if :attr:`optional` is ``True``; otherwise ``ValueError`` is raised.
    """
    if isinstance(device, str):
        device = torch.device(device)
    device_idx: Optional[int] = None
    if isinstance(device, torch.device):
        on_cpu = device.type == 'cpu'
        if on_cpu and not allow_cpu:
            raise ValueError('Expected a non cpu device, but got: {}'.format(device))
        device_idx = -1 if on_cpu else device.index
    elif isinstance(device, int):
        device_idx = device
    if device_idx is not None:
        return device_idx
    if not optional:
        raise ValueError('Expected a torch.device with a specified index '
                         'or an integer, but got:{}'.format(device))
    return _get_current_device_index()
|
||
|
|
||
|
|
||
|
def _handle_complex(tensor):
    """Return ``torch.view_as_real(tensor)`` for complex tensors, else ``tensor``.

    ``UninitializedParameter`` instances are returned untouched before any
    ``is_complex()`` call, because querying them is an error for a LazyModule.
    """
    if isinstance(tensor, torch.nn.UninitializedParameter):
        return tensor
    if tensor.is_complex():
        return torch.view_as_real(tensor)
    return tensor
|