# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to manipulate TensorProtos."""

import numpy as np

from tensorboard.compat.proto import tensor_pb2
from tensorboard.compat.tensorflow_stub import dtypes, compat, tensor_shape


def ExtractBitsFromFloat16(x):
    return np.asarray(x, dtype=np.float16).view(np.uint16).item()
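
# Illustration (editorial note, not part of the original API surface): the
# function above reinterprets the IEEE half-precision bit pattern as an
# unsigned 16-bit integer, e.g. ExtractBitsFromFloat16(1.0) == 0x3C00 == 15360,
# which is the integer that TensorProto.half_val stores.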


def SlowAppendFloat16ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.half_val.extend(
        [ExtractBitsFromFloat16(x) for x in proto_values]
    )


def ExtractBitsFromBFloat16(x):
    return (
        np.asarray(x, dtype=dtypes.bfloat16.as_numpy_dtype)
        .view(np.uint16)
        .item()
    )


def SlowAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.half_val.extend(
        [ExtractBitsFromBFloat16(x) for x in proto_values]
    )


def SlowAppendFloat32ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.float_val.extend([x.item() for x in proto_values])


def SlowAppendFloat64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.double_val.extend([x.item() for x in proto_values])


def SlowAppendIntArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.int_val.extend([x.item() for x in proto_values])


def SlowAppendInt64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.int64_val.extend([x.item() for x in proto_values])


def SlowAppendQIntArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.int_val.extend([x[0].item() for x in proto_values])


def SlowAppendUInt32ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.uint32_val.extend([x.item() for x in proto_values])


def SlowAppendUInt64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.uint64_val.extend([x.item() for x in proto_values])


def SlowAppendComplex64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.scomplex_val.extend(
        [v.item() for x in proto_values for v in [x.real, x.imag]]
    )


def SlowAppendComplex128ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.dcomplex_val.extend(
        [v.item() for x in proto_values for v in [x.real, x.imag]]
    )


def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values])


def SlowAppendBoolArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.bool_val.extend([x.item() for x in proto_values])


_NP_TO_APPEND_FN = {
    np.float16: SlowAppendFloat16ArrayToTensorProto,
    np.float32: SlowAppendFloat32ArrayToTensorProto,
    np.float64: SlowAppendFloat64ArrayToTensorProto,
    np.int32: SlowAppendIntArrayToTensorProto,
    np.int64: SlowAppendInt64ArrayToTensorProto,
    np.uint8: SlowAppendIntArrayToTensorProto,
    np.uint16: SlowAppendIntArrayToTensorProto,
    np.uint32: SlowAppendUInt32ArrayToTensorProto,
    np.uint64: SlowAppendUInt64ArrayToTensorProto,
    np.int8: SlowAppendIntArrayToTensorProto,
    np.int16: SlowAppendIntArrayToTensorProto,
    np.complex64: SlowAppendComplex64ArrayToTensorProto,
    np.complex128: SlowAppendComplex128ArrayToTensorProto,
    np.object_: SlowAppendObjectArrayToTensorProto,
    np.bool_: SlowAppendBoolArrayToTensorProto,
    dtypes.qint8.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
    dtypes.quint8.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
    dtypes.qint16.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
    dtypes.quint16.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
    dtypes.qint32.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
    # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16.
}

BACKUP_DICT = {
    dtypes.bfloat16.as_numpy_dtype: SlowAppendBFloat16ArrayToTensorProto
}


def GetFromNumpyDTypeDict(dtype_dict, dtype):
    # NOTE: dtype_dict.get(dtype) always returns None.
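    # (Most likely because the dict keys are numpy scalar types while `dtype`
    # is an np.dtype instance: the two compare equal but need not hash the
    # same, so we fall back to a linear scan that compares by equality.)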
    for key, val in dtype_dict.items():
        if key == dtype:
            return val
    for key, val in BACKUP_DICT.items():
        if key == dtype:
            return val
    return None


def GetNumpyAppendFn(dtype):
    # numpy dtypes for strings are variable length. We cannot compare them
    # against a single constant (np.string does not exist) to decide whether
    # a dtype is a string type, so we compare dtype.type instead.
    if dtype.type == np.string_ or dtype.type == np.unicode_:
        return SlowAppendObjectArrayToTensorProto
    return GetFromNumpyDTypeDict(_NP_TO_APPEND_FN, dtype)
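
# For reference (illustrative, derived from the tables above):
# GetNumpyAppendFn(np.dtype(np.float32)) resolves to
# SlowAppendFloat32ArrayToTensorProto, while a string dtype such as
# np.dtype("<U3") short-circuits to SlowAppendObjectArrayToTensorProto.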


def _GetDenseDimensions(list_of_lists):
    """Returns the inferred dense dimensions of a list of lists."""
    if not isinstance(list_of_lists, (list, tuple)):
        return []
    elif not list_of_lists:
        return [0]
    else:
        return [len(list_of_lists)] + _GetDenseDimensions(list_of_lists[0])
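
# For example, _GetDenseDimensions([[1, 2, 3], [4, 5, 6]]) returns [2, 3];
# a scalar (non-list) argument returns [] and an empty list returns [0].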


def _FlattenToStrings(nested_strings):
    if isinstance(nested_strings, (list, tuple)):
        for inner in nested_strings:
            for flattened_string in _FlattenToStrings(inner):
                yield flattened_string
    else:
        yield nested_strings
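
# For example, list(_FlattenToStrings([["a"], ["b", "c"]])) yields
# ["a", "b", "c"]; a bare (non-nested) string is yielded unchanged.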


_TENSOR_CONTENT_TYPES = frozenset(
    [
        dtypes.float32,
        dtypes.float64,
        dtypes.int32,
        dtypes.uint8,
        dtypes.int16,
        dtypes.int8,
        dtypes.int64,
        dtypes.qint8,
        dtypes.quint8,
        dtypes.qint16,
        dtypes.quint16,
        dtypes.qint32,
        dtypes.uint32,
        dtypes.uint64,
    ]
)


class _Message:
    def __init__(self, message):
        self._message = message

    def __repr__(self):
        return self._message


def _FirstNotNone(l):
    for x in l:
        if x is not None:
            return x
    return None


def _NotNone(v):
    if v is None:
        return _Message("None")
    else:
        return v


def _FilterTuple(v):
    if not isinstance(v, (list, tuple)):
        return v
    if isinstance(v, tuple):
        if not any(isinstance(x, (list, tuple)) for x in v):
            return None
    if isinstance(v, list):
        if not any(isinstance(x, (list, tuple)) for x in v):
            return _FirstNotNone(
                [None if isinstance(x, (list, tuple)) else x for x in v]
            )
    return _FirstNotNone([_FilterTuple(x) for x in v])


def _FilterInt(v):
    if isinstance(v, (list, tuple)):
        return _FirstNotNone([_FilterInt(x) for x in v])
    return (
        None
        if isinstance(v, (compat.integral_types, tensor_shape.Dimension))
        else _NotNone(v)
    )


def _FilterFloat(v):
    if isinstance(v, (list, tuple)):
        return _FirstNotNone([_FilterFloat(x) for x in v])
    return None if isinstance(v, compat.real_types) else _NotNone(v)


def _FilterComplex(v):
    if isinstance(v, (list, tuple)):
        return _FirstNotNone([_FilterComplex(x) for x in v])
    return None if isinstance(v, compat.complex_types) else _NotNone(v)


def _FilterStr(v):
    if isinstance(v, (list, tuple)):
        return _FirstNotNone([_FilterStr(x) for x in v])
    if isinstance(v, compat.bytes_or_text_types):
        return None
    else:
        return _NotNone(v)


def _FilterBool(v):
    if isinstance(v, (list, tuple)):
        return _FirstNotNone([_FilterBool(x) for x in v])
    return None if isinstance(v, bool) else _NotNone(v)


_TF_TO_IS_OK = {
    dtypes.bool: [_FilterBool],
    dtypes.complex128: [_FilterComplex],
    dtypes.complex64: [_FilterComplex],
    dtypes.float16: [_FilterFloat],
    dtypes.float32: [_FilterFloat],
    dtypes.float64: [_FilterFloat],
    dtypes.int16: [_FilterInt],
    dtypes.int32: [_FilterInt],
    dtypes.int64: [_FilterInt],
    dtypes.int8: [_FilterInt],
    dtypes.qint16: [_FilterInt, _FilterTuple],
    dtypes.qint32: [_FilterInt, _FilterTuple],
    dtypes.qint8: [_FilterInt, _FilterTuple],
    dtypes.quint16: [_FilterInt, _FilterTuple],
    dtypes.quint8: [_FilterInt, _FilterTuple],
    dtypes.string: [_FilterStr],
    dtypes.uint16: [_FilterInt],
    dtypes.uint8: [_FilterInt],
}


def _Assertconvertible(values, dtype):
    # If dtype is None or not recognized, assume it's convertible.
    if dtype is None or dtype not in _TF_TO_IS_OK:
        return
    fn_list = _TF_TO_IS_OK.get(dtype)
    mismatch = _FirstNotNone([fn(values) for fn in fn_list])
    if mismatch is not None:
        raise TypeError(
            "Expected %s, got %s of type '%s' instead."
            % (dtype.name, repr(mismatch), type(mismatch).__name__)
        )
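
# Illustrative behaviour (not exhaustive): the filters surface the first
# offending element, e.g. _Assertconvertible(["a", 1], dtypes.int32) raises
# TypeError("Expected int32, got 'a' of type 'str' instead.").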


def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False):
    """Create a TensorProto.

    Args:
      values: Values to put in the TensorProto.
      dtype: Optional tensor_pb2 DataType value.
      shape: List of integers representing the dimensions of the tensor.
      verify_shape: Boolean that enables verification of the shape of values.

    Returns:
      A `TensorProto`. Depending on the type, it may contain data in the
      "tensor_content" attribute, which is not directly useful to Python
      programs. To access the values you should convert the proto back to a
      numpy ndarray with `make_ndarray(proto)`.

      If `values` is a `TensorProto`, it is immediately returned; `dtype` and
      `shape` are ignored.

    Raises:
      TypeError: if unsupported types are provided.
      ValueError: if arguments have inappropriate values or if verify_shape is
        True and the shape of values is not equal to the given shape.

    make_tensor_proto accepts "values" as a python scalar, a python list, a
    numpy ndarray, or a numpy scalar.

    If "values" is a python scalar or a python list, make_tensor_proto
    first converts it to a numpy ndarray. If dtype is None, the
    conversion tries its best to infer the right numpy data
    type. Otherwise, the resulting numpy array must have a data type
    convertible to the given dtype.

    In either case above, the numpy ndarray (either the caller provided
    or the auto-converted one) must have a type convertible to dtype.

    make_tensor_proto then converts the numpy array to a tensor proto.

    If "shape" is None, the resulting tensor proto represents the numpy
    array precisely.

    Otherwise, "shape" specifies the tensor's shape and the numpy array
    cannot have more elements than "shape" specifies.
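
    For example (illustrative):

      make_tensor_proto(np.array([[1, 2], [3, 4]], dtype=np.int32))

    produces an int32 proto of shape [2, 2] whose data is packed into
    `tensor_content`, while make_tensor_proto(1.0) produces a float32 scalar
    proto whose value is stored in `float_val`.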
    """
    if isinstance(values, tensor_pb2.TensorProto):
        return values

    if dtype:
        dtype = dtypes.as_dtype(dtype)

    is_quantized = dtype in [
        dtypes.qint8,
        dtypes.quint8,
        dtypes.qint16,
        dtypes.quint16,
        dtypes.qint32,
    ]

    # We first convert value to a numpy array or scalar.
    if isinstance(values, (np.ndarray, np.generic)):
        if dtype:
            nparray = values.astype(dtype.as_numpy_dtype)
        else:
            nparray = values
    elif callable(getattr(values, "__array__", None)) or isinstance(
        getattr(values, "__array_interface__", None), dict
    ):
        # If a class has the __array__ method, or the __array_interface__ dict,
        # then it is possible to convert it to a numpy array.
        nparray = np.asarray(values, dtype=dtype)

        # This is the preferred way to create an array from the object, so replace
        # the `values` with the array so that _FlattenToStrings is not run.
        values = nparray
    else:
        if values is None:
            raise ValueError("None values not supported.")
        # If dtype is provided, force the numpy array to be of the type
        # provided, if possible.
        if dtype and dtype.is_numpy_compatible:
            np_dt = dtype.as_numpy_dtype
        else:
            np_dt = None
        # If shape is None, numpy.prod returns None when dtype is not set, but
        # raises an exception when dtype is set to np.int64.
        if shape is not None and np.prod(shape, dtype=np.int64) == 0:
            nparray = np.empty(shape, dtype=np_dt)
        else:
            _Assertconvertible(values, dtype)
            nparray = np.array(values, dtype=np_dt)
            # We need to pass in quantized values as tuples, so don't apply
            # the shape check to them.
            if (
                list(nparray.shape) != _GetDenseDimensions(values)
                and not is_quantized
            ):
                raise ValueError(
                    """Argument must be a dense tensor: %s"""
                    """ - got shape %s, but wanted %s."""
                    % (values, list(nparray.shape), _GetDenseDimensions(values))
                )

        # python/numpy default float type is float64. We prefer float32 instead.
        if (nparray.dtype == np.float64) and dtype is None:
            nparray = nparray.astype(np.float32)
        # python/numpy default int type is int64. We prefer int32 instead.
        elif (nparray.dtype == np.int64) and dtype is None:
            downcasted_array = nparray.astype(np.int32)
            # Do not down cast if it leads to precision loss.
            if np.array_equal(downcasted_array, nparray):
                nparray = downcasted_array

    # if dtype is provided, it must be convertible with what numpy
    # conversion says.
    numpy_dtype = dtypes.as_dtype(nparray.dtype)
    if numpy_dtype is None:
        raise TypeError("Unrecognized data type: %s" % nparray.dtype)

    # If dtype was specified and is a quantized type, we convert
    # numpy_dtype back into the quantized version.
    if is_quantized:
        numpy_dtype = dtype

    if dtype is not None and (
        not hasattr(dtype, "base_dtype")
        or dtype.base_dtype != numpy_dtype.base_dtype
    ):
        raise TypeError(
            "Inconvertible types: %s vs. %s. Value is %s"
            % (dtype, nparray.dtype, values)
        )

    # If shape is not given, get the shape from the numpy array.
    if shape is None:
        shape = nparray.shape
        is_same_size = True
        shape_size = nparray.size
    else:
        shape = [int(dim) for dim in shape]
        shape_size = np.prod(shape, dtype=np.int64)
        is_same_size = shape_size == nparray.size

        if verify_shape:
            if not nparray.shape == tuple(shape):
                raise TypeError(
                    "Expected Tensor's shape: %s, got %s."
                    % (tuple(shape), nparray.shape)
                )

        if nparray.size > shape_size:
            raise ValueError(
                "Too many elements provided. Needed at most %d, but received %d"
                % (shape_size, nparray.size)
            )

    tensor_proto = tensor_pb2.TensorProto(
        dtype=numpy_dtype.as_datatype_enum,
        tensor_shape=tensor_shape.as_shape(shape).as_proto(),
    )

    if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1:
        if nparray.size * nparray.itemsize >= (1 << 31):
            raise ValueError(
                "Cannot create a tensor proto whose content is larger than 2GB."
            )
        tensor_proto.tensor_content = nparray.tobytes()
        return tensor_proto

    # If we were not given values as a numpy array, compute the proto_values
    # from the given values directly, to avoid numpy trimming nulls from the
    # strings. Since values could be a list of strings, or a multi-dimensional
    # list of lists that might or might not correspond to the given shape,
    # we flatten it conservatively.
    if numpy_dtype == dtypes.string and not isinstance(values, np.ndarray):
        proto_values = _FlattenToStrings(values)

        # At this point, values may be a list of objects that we could not
        # identify a common type for (hence it was inferred as
        # np.object/dtypes.string). If we are unable to convert it to a
        # string, we raise a more helpful error message.
        #
        # Ideally, we'd be able to convert the elements of the list to a
        # common type, but this type inference requires some thinking and
        # so we defer it for now.
        try:
            str_values = [compat.as_bytes(x) for x in proto_values]
        except TypeError:
            raise TypeError(
                "Failed to convert object of type %s to Tensor. "
                "Contents: %s. Consider casting elements to a "
                "supported type." % (type(values), values)
            )
        tensor_proto.string_val.extend(str_values)
        return tensor_proto

    # TensorFlow expects C order (a.k.a., eigen row major).
    proto_values = nparray.ravel()

    append_fn = GetNumpyAppendFn(proto_values.dtype)
    if append_fn is None:
        raise TypeError(
            "Element type not supported in TensorProto: %s" % numpy_dtype.name
        )
    append_fn(tensor_proto, proto_values)

    return tensor_proto


def make_ndarray(tensor):
    """Create a numpy ndarray from a tensor.

    Create a numpy ndarray with the same shape and data as the tensor.

    Args:
      tensor: A TensorProto.

    Returns:
      A numpy array with the tensor contents.

    Raises:
      TypeError: if tensor has unsupported type.
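
    For example (illustrative), round-tripping through `make_tensor_proto`:

      proto = make_tensor_proto(np.array([1.0, 2.0], dtype=np.float32))
      make_ndarray(proto)  # -> array([1., 2.], dtype=float32)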
    """
    shape = [d.size for d in tensor.tensor_shape.dim]
    num_elements = np.prod(shape, dtype=np.int64)
    tensor_dtype = dtypes.as_dtype(tensor.dtype)
    dtype = tensor_dtype.as_numpy_dtype

    if tensor.tensor_content:
        return (
            np.frombuffer(tensor.tensor_content, dtype=dtype)
            .copy()
            .reshape(shape)
        )
    elif tensor_dtype == dtypes.float16 or tensor_dtype == dtypes.bfloat16:
        # The half_val field of the TensorProto stores the raw 16-bit pattern
        # of the fp16/bfloat16 values; reinterpret it as the proper 16-bit
        # float type.
        if len(tensor.half_val) == 1:
            tmp = np.array(tensor.half_val[0], dtype=np.uint16)
            tmp.dtype = tensor_dtype.as_numpy_dtype
            return np.repeat(tmp, num_elements).reshape(shape)
        else:
            tmp = np.fromiter(tensor.half_val, dtype=np.uint16)
            tmp.dtype = tensor_dtype.as_numpy_dtype
            return tmp.reshape(shape)
    elif tensor_dtype == dtypes.float32:
        if len(tensor.float_val) == 1:
            return np.repeat(
                np.array(tensor.float_val[0], dtype=dtype), num_elements
            ).reshape(shape)
        else:
            return np.fromiter(tensor.float_val, dtype=dtype).reshape(shape)
    elif tensor_dtype == dtypes.float64:
        if len(tensor.double_val) == 1:
            return np.repeat(
                np.array(tensor.double_val[0], dtype=dtype), num_elements
            ).reshape(shape)
        else:
            return np.fromiter(tensor.double_val, dtype=dtype).reshape(shape)
    elif tensor_dtype in [
        dtypes.int32,
        dtypes.uint8,
        dtypes.uint16,
        dtypes.int16,
        dtypes.int8,
        dtypes.qint32,
        dtypes.quint8,
        dtypes.qint8,
        dtypes.qint16,
        dtypes.quint16,
    ]:
        if len(tensor.int_val) == 1:
            return np.repeat(
                np.array(tensor.int_val[0], dtype=dtype), num_elements
            ).reshape(shape)
        else:
            return np.fromiter(tensor.int_val, dtype=dtype).reshape(shape)
    elif tensor_dtype == dtypes.int64:
        if len(tensor.int64_val) == 1:
            return np.repeat(
                np.array(tensor.int64_val[0], dtype=dtype), num_elements
            ).reshape(shape)
        else:
            return np.fromiter(tensor.int64_val, dtype=dtype).reshape(shape)
    elif tensor_dtype == dtypes.string:
        if len(tensor.string_val) == 1:
            return np.repeat(
                np.array(tensor.string_val[0], dtype=dtype), num_elements
            ).reshape(shape)
        else:
            return np.array(list(tensor.string_val), dtype=dtype).reshape(shape)
    elif tensor_dtype == dtypes.complex64:
        it = iter(tensor.scomplex_val)
        if len(tensor.scomplex_val) == 2:
            return np.repeat(
                np.array(
                    complex(tensor.scomplex_val[0], tensor.scomplex_val[1]),
                    dtype=dtype,
                ),
                num_elements,
            ).reshape(shape)
        else:
            return np.array(
                [complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype
            ).reshape(shape)
    elif tensor_dtype == dtypes.complex128:
        it = iter(tensor.dcomplex_val)
        if len(tensor.dcomplex_val) == 2:
            return np.repeat(
                np.array(
                    complex(tensor.dcomplex_val[0], tensor.dcomplex_val[1]),
                    dtype=dtype,
                ),
                num_elements,
            ).reshape(shape)
        else:
            return np.array(
                [complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype
            ).reshape(shape)
    elif tensor_dtype == dtypes.bool:
        if len(tensor.bool_val) == 1:
            return np.repeat(
                np.array(tensor.bool_val[0], dtype=dtype), num_elements
            ).reshape(shape)
        else:
            return np.fromiter(tensor.bool_val, dtype=dtype).reshape(shape)
    else:
        raise TypeError("Unsupported tensor type: %s" % tensor.dtype)