## @package schema
# Module caffe2.python.schema
"""
Defines a minimal set of data types for representing datasets with
arbitrary nested structure, including objects of variable length, such as
maps and lists.

This defines a columnar storage format for such datasets on top of caffe2
tensors. In terms of capacity of representation, it can represent most of
the data types supported by the Parquet, ORC, and DWRF file formats.

See comments in operator_test/dataset_ops_test.py for an example and
walkthrough on how to use schema to store and iterate through a structured
in-memory dataset.
"""

import logging
import numpy as np
from caffe2.python import core
from caffe2.python import workspace
from caffe2.python.core import BlobReference
from collections import OrderedDict, namedtuple
from past.builtins import basestring
from future.utils import viewitems, viewkeys, viewvalues
from itertools import islice
from six import StringIO

logger = logging.getLogger(__name__)

FIELD_SEPARATOR = ':'


def _join_field_name(prefix, suffix):
    if prefix and suffix:
        return '{}{}{}'.format(prefix, FIELD_SEPARATOR, suffix)
    elif prefix:
        return prefix
    elif suffix:
        return suffix
    else:
        return ''


def _normalize_field(field_or_type_or_blob, keep_blobs=True):
    """Clones/normalizes a field before adding it to a container."""
    if isinstance(field_or_type_or_blob, Field):
        return field_or_type_or_blob.clone(keep_blobs=keep_blobs)
    elif type(field_or_type_or_blob) in (type, np.dtype):
        return Scalar(dtype=field_or_type_or_blob)
    else:
        return Scalar(blob=field_or_type_or_blob)


FeatureSpec = namedtuple(
    'FeatureSpec',
    [
        'feature_type',
        'feature_names',
        'feature_ids',
        'feature_is_request_only',
        'desired_hash_size',
        'feature_to_index',
    ]
)

FeatureSpec.__new__.__defaults__ = (None, None, None, None, None, None)


class Metadata(
    namedtuple(
        'Metadata', ['categorical_limit', 'expected_value', 'feature_specs']
    )
):
    """Represents additional information associated with a scalar in schema.

    `categorical_limit` - for fields of integral type that are guaranteed to
    be non-negative, specifies the maximum possible value plus one. It's
    often used as the size of an embedding table.

    `expected_value` - anticipated average value of elements in the field.
    Usually makes sense for the length fields of lists.

    `feature_specs` - information about the features contained in this field.
    For example, if the field holds more than one feature, it can carry the
    list of feature names contained in it."""
    __slots__ = ()


Metadata.__new__.__defaults__ = (None, None, None)
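
# Illustrative sketch (not executed on import): `categorical_limit` is
# typically attached to an integral id field so that downstream layers can
# size an embedding table. The limit value below is hypothetical.
#
#   id_field = Scalar(np.int64)
#   id_field.set_metadata(Metadata(categorical_limit=100000))
#   id_field.metadata.categorical_limit   # 100000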


class Field(object):
    """Represents an abstract field type in a dataset.
    """

    __slots__ = ("_parent", "_field_offsets")

    def __init__(self, children):
        """Derived classes must call this after their initialization."""
        self._parent = (None, 0)
        offset = 0
        self._field_offsets = []
        for child in children:
            self._field_offsets.append(offset)
            offset += len(child.field_names())
        self._field_offsets.append(offset)

    def clone_schema(self):
        return self.clone(keep_blobs=False)

    def field_names(self):
        """Return the children field names for this field."""
        raise NotImplementedError('Field is an abstract class.')

    def field_types(self):
        """Return the numpy.dtype for each of the children fields."""
        raise NotImplementedError('Field is an abstract class.')

    def field_metadata(self):
        """Return the Metadata for each of the children fields."""
        raise NotImplementedError('Field is an abstract class.')

    def field_blobs(self):
        """Return the list of blobs with contents for this Field.
        Values can either be all numpy.ndarray or BlobReference.
        If any of the fields doesn't have a blob, throws.
        """
        raise NotImplementedError('Field is an abstract class.')

    def all_scalars(self):
        """Return the list of all Scalar instances in the Field.
        The order is the same as for field_names() or field_blobs()."""
        raise NotImplementedError('Field is an abstract class.')

    def has_blobs(self):
        """Return True if every scalar of this field has blobs."""
        raise NotImplementedError('Field is an abstract class.')

    def clone(self, keep_blobs=True):
        """Clone this Field along with its children."""
        raise NotImplementedError('Field is an abstract class.')

    def _set_parent(self, parent, relative_id):
        self._parent = (parent, relative_id)

    def slice(self):
        """
        Returns a slice representing the range of field ids that belong to
        this field. This slice can be used to index a list of fields.

        E.g.:

        >>> s = Struct(
        >>>     ('a', Scalar()),
        >>>     ('b', Struct(
        >>>         ('b1', Scalar()),
        >>>         ('b2', Scalar()),
        >>>     )),
        >>>     ('c', Scalar()),
        >>> )
        >>> field_data = ['da', 'db1', 'db2', 'dc']
        >>> field_data[s.b.slice()]
        ['db1', 'db2']
        """
        base_id = self._child_base_id()
        return slice(base_id, base_id + len(self.field_names()))

    def _child_base_id(self, child_index=None):
        """Get the base id of the given child."""
        p, i = self._parent
        pos = 0 if child_index is None else self._field_offsets[child_index]
        if p:
            pos += p._child_base_id(i)
        return pos

    def __eq__(self, other):
        """Equivalence of two schemas."""
        return (
            (self.field_names() == other.field_names()) and
            (self.field_types() == other.field_types()) and
            (self.field_metadata() == other.field_metadata())
        )

    def _pprint_impl(self, indent, str_buffer):
        raise NotImplementedError('Field is an abstract class.')

    def __repr__(self):
        str_buffer = StringIO()
        self._pprint_impl(0, str_buffer)
        contents = str_buffer.getvalue()
        str_buffer.close()
        return contents
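
# Illustrative sketch (not executed on import): a Field flattens into an
# ordered list of leaf columns, and slice() indexes into that flattened
# view. Uses the Struct and Scalar types defined below.
#
#   s = Struct(
#       ('a', Scalar(np.int32)),
#       ('b', Struct(('b1', Scalar()), ('b2', Scalar()))),
#   )
#   s.field_names()   # ['a', 'b:b1', 'b:b2']
#   s.b.slice()       # slice(1, 3): the positions of b's leaves
#   s.b.b1.id()       # 1, the zero-indexed position of that leaf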


class List(Field):
    """Represents a variable-length list.

    Values of a list can also be complex fields such as Lists and Structs.
    In addition to the fields exposed by its `values` field, a List exposes
    an additional `lengths` field, which will contain the size of each list
    under the parent domain.
    """

    __slots__ = ("lengths", "_items")

    def __init__(self, values, lengths_blob=None):
        if isinstance(lengths_blob, Field):
            assert isinstance(lengths_blob, Scalar)
            self.lengths = _normalize_field(lengths_blob)
        else:
            self.lengths = Scalar(np.int32, lengths_blob)
        self._items = _normalize_field(values)
        self.lengths._set_parent(self, 0)
        self._items._set_parent(self, 1)
        super(List, self).__init__([self.lengths, self._items])

    def field_names(self):
        value_fields = self._items.field_names()
        return (
            ['lengths'] + [_join_field_name('values', v) for v in value_fields]
        )

    def field_types(self):
        return self.lengths.field_types() + self._items.field_types()

    def field_metadata(self):
        return self.lengths.field_metadata() + self._items.field_metadata()

    def field_blobs(self):
        return self.lengths.field_blobs() + self._items.field_blobs()

    def all_scalars(self):
        return self.lengths.all_scalars() + self._items.all_scalars()

    def has_blobs(self):
        return self.lengths.has_blobs() and self._items.has_blobs()

    def clone(self, keep_blobs=True):
        return type(self)(
            _normalize_field(self._items, keep_blobs=keep_blobs),
            _normalize_field(self.lengths, keep_blobs=keep_blobs)
        )

    def _pprint_impl(self, indent, str_buffer):
        str_buffer.write('  ' * indent + "List(\n")
        str_buffer.write('  ' * (indent + 1) + "lengths=\n")
        self.lengths._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write('  ' * (indent + 1) + "_items=\n")
        self._items._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write('  ' * indent + ")\n")

    def __getattr__(self, item):
        """If the value of this list is a struct,
        allow introspecting directly into its fields."""
        if item.startswith('__'):
            raise AttributeError(item)
        if isinstance(self._items, Struct):
            return getattr(self._items, item)
        elif item == 'value' or item == 'items':
            return self._items
        else:
            raise AttributeError('Field not found in list: %s.' % item)

    def __getitem__(self, item):
        names = item.split(FIELD_SEPARATOR, 1)

        if len(names) == 1:
            if item == 'lengths':
                return self.lengths
            elif item == 'values':
                return self._items
        else:
            if names[0] == 'values':
                return self._items[names[1]]
        raise KeyError('Field not found in list: %s.' % item)
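
# Illustrative sketch (not executed on import): a List stores its lengths
# and values as separate flattened columns.
#
#   l = List(Scalar(np.float32))
#   l.field_names()                         # ['lengths', 'values']
#   Struct(('clicks', l)).field_names()     # ['clicks:lengths', 'clicks:values']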


class ListWithEvicted(List):
    """
    Similar to List, but with an extra field `_evicted_values` for
    LRU hashing.
    """

    __slots__ = ("_evicted_values",)

    def __init__(self, values, lengths_blob=None, evicted_values=None):
        if isinstance(evicted_values, Field):
            assert isinstance(evicted_values, Scalar)
            self._evicted_values = _normalize_field(evicted_values)
        else:
            self._evicted_values = Scalar(np.int64, evicted_values)
        super(ListWithEvicted, self).__init__(values, lengths_blob=lengths_blob)

    def field_names(self):
        value_fields = self._items.field_names()
        return (
            ['lengths'] +
            [_join_field_name('values', v) for v in value_fields] +
            ["_evicted_values"]
        )

    def field_types(self):
        return (self.lengths.field_types() + self._items.field_types() +
                self._evicted_values.field_types())

    def field_metadata(self):
        return (self.lengths.field_metadata() + self._items.field_metadata() +
                self._evicted_values.field_metadata())

    def field_blobs(self):
        return (self.lengths.field_blobs() + self._items.field_blobs() +
                self._evicted_values.field_blobs())

    def all_scalars(self):
        return (self.lengths.all_scalars() + self._items.all_scalars() +
                self._evicted_values.all_scalars())

    def has_blobs(self):
        return (self.lengths.has_blobs() and self._items.has_blobs() and
                self._evicted_values.has_blobs())

    def clone(self, keep_blobs=True):
        return type(self)(
            _normalize_field(self._items, keep_blobs=keep_blobs),
            _normalize_field(self.lengths, keep_blobs=keep_blobs),
            _normalize_field(self._evicted_values, keep_blobs=keep_blobs)
        )

    def _pprint_impl(self, indent, str_buffer):
        str_buffer.write('  ' * indent + "ListWithEvicted(\n")
        str_buffer.write('  ' * (indent + 1) + "lengths=\n")
        self.lengths._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write('  ' * (indent + 1) + "_items=\n")
        self._items._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write('  ' * (indent + 1) + "_evicted_values=\n")
        self._evicted_values._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write('  ' * indent + ")\n")

    def __getattr__(self, item):
        """If the value of this list is a struct,
        allow introspecting directly into its fields."""
        if item.startswith('__'):
            raise AttributeError(item)
        if item == "_evicted_values":
            return self._evicted_values
        if isinstance(self._items, Struct):
            return getattr(self._items, item)
        elif item == 'value' or item == 'items':
            return self._items
        else:
            raise AttributeError('Field not found in list: %s.' % item)

    def __getitem__(self, item):
        names = item.split(FIELD_SEPARATOR, 1)

        if len(names) == 1:
            if item == 'lengths':
                return self.lengths
            elif item == 'values':
                return self._items
            elif item == '_evicted_values':
                return self._evicted_values
        else:
            if names[0] == 'values':
                return self._items[names[1]]
        raise KeyError('Field not found in list: %s.' % item)


class Struct(Field):
    """Represents a named list of fields sharing the same domain.
    """

    __slots__ = ("fields", "_frozen")

    def __init__(self, *fields):
        """fields is a list of tuples in the format of (name, field). The
        name is a string with a possibly nested name, e.g., `a`, `a:b`,
        `a:b:c`. For example

        Struct(
            ('a', Scalar()),
            ('b:c', Scalar()),
            ('b:d:e', Scalar()),
            ('b', Struct(
                ('f', Scalar()),
            )),
        )

        is equal to

        Struct(
            ('a', Scalar()),
            ('b', Struct(
                ('c', Scalar()),
                ('d', Struct(('e', Scalar()))),
                ('f', Scalar()),
            )),
        )
        """
        for field in fields:
            assert len(field) == 2
            assert field[0], 'Field names cannot be empty'
            assert field[0] != 'lengths', (
                'Struct cannot contain a field named `lengths`.'
            )
        fields = [(name, _normalize_field(field)) for name, field in fields]
        self.fields = OrderedDict()
        for name, field in fields:
            if FIELD_SEPARATOR in name:
                name, field = self._struct_from_nested_name(name, field)
            if name not in self.fields:
                self.fields[name] = field
                continue
            if (
                not isinstance(field, Struct) or
                not isinstance(self.fields[name], Struct)
            ):
                raise ValueError('Duplicate field name: %s' % name)
            self.fields[name] = self.fields[name] + field
        for id, (_, field) in enumerate(viewitems(self.fields)):
            field._set_parent(self, id)
        super(Struct, self).__init__(viewvalues(self.fields))
        self._frozen = True

    def _struct_from_nested_name(self, nested_name, field):
        def create_internal(nested_name, field):
            names = nested_name.split(FIELD_SEPARATOR, 1)
            if len(names) == 1:
                added_field = field
            else:
                added_field = create_internal(names[1], field)
            return Struct((names[0], added_field))

        names = nested_name.split(FIELD_SEPARATOR, 1)
        assert len(names) >= 2
        return names[0], create_internal(names[1], field)

    def get_children(self):
        return list(viewitems(self.fields))

    def field_names(self):
        names = []
        for name, field in viewitems(self.fields):
            names += [_join_field_name(name, f) for f in field.field_names()]
        return names

    def field_types(self):
        types = []
        for _, field in viewitems(self.fields):
            types += field.field_types()
        return types

    def field_metadata(self):
        metadata = []
        for _, field in viewitems(self.fields):
            metadata += field.field_metadata()
        return metadata

    def field_blobs(self):
        blobs = []
        for _, field in viewitems(self.fields):
            blobs += field.field_blobs()
        return blobs

    def all_scalars(self):
        scalars = []
        for _, field in viewitems(self.fields):
            scalars += field.all_scalars()
        return scalars

    def has_blobs(self):
        return all(field.has_blobs() for field in viewvalues(self.fields))

    def clone(self, keep_blobs=True):
        normalized_fields = [
            (k, _normalize_field(v, keep_blobs=keep_blobs))
            for k, v in viewitems(self.fields)
        ]
        return type(self)(*normalized_fields)

    def _get_field_by_nested_name(self, nested_name):
        names = nested_name.split(FIELD_SEPARATOR, 1)
        field = self.fields.get(names[0], None)

        if field is None:
            return None

        if len(names) == 1:
            return field

        try:
            return field[names[1]]
        except (KeyError, TypeError):
            return None

    def _pprint_impl(self, indent, str_buffer):
        str_buffer.write('  ' * indent + "Struct( \n")
        for name, field in viewitems(self.fields):
            str_buffer.write('  ' * (indent + 1) + "{}=".format(name) + "\n")
            field._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write('  ' * indent + ") \n")

    def __contains__(self, item):
        field = self._get_field_by_nested_name(item)
        return field is not None

    def __len__(self):
        return len(self.fields)

    def __getitem__(self, item):
        """
        item can be a tuple or list of ints or strings, or a single
        int or string. A string item is a nested field name, e.g., "a",
        "a:b", "a:b:c". An int item is the index of a field at the first
        level of the Struct.
        """
        if isinstance(item, list) or isinstance(item, tuple):
            keys = list(viewkeys(self.fields))
            return Struct(
                * [
                    (
                        keys[k]
                        if isinstance(k, int) else k, self[k]
                    ) for k in item
                ]
            )
        elif isinstance(item, int):
            return next(islice(viewvalues(self.fields), item, None))
        else:
            field = self._get_field_by_nested_name(item)
            if field is None:
                raise KeyError('field "%s" not found' % (item))
            return field

    def get(self, item, default_value):
        """
        Similar to python's dictionary get method: return the field for
        `item` if found (i.e. self.item is valid), otherwise return
        `default_value`.

        It's syntactic sugar for python's builtin getattr method.
        """
        return getattr(self, item, default_value)

    def __getattr__(self, item):
        if item.startswith('__'):
            raise AttributeError(item)
        try:
            return super(Struct, self).__getattribute__("fields")[item]
        except KeyError:
            raise AttributeError(item)

    def __setattr__(self, key, value):
        # Disable setting attributes after initialization to prevent the false
        # impression of being able to overwrite a field.
        # Setting internal states is still allowed, mainly so that _parent can
        # be set post initialization.
        if getattr(self, '_frozen', None) and not key.startswith('_'):
            raise TypeError('Struct.__setattr__() is disabled after __init__()')
        super(Struct, self).__setattr__(key, value)

    def __add__(self, other):
        """
        Allows merging the fields of two schema.Struct using the '+'
        operator. If the two Structs have common field names, the merge is
        conducted recursively. Here are examples:

        Example 1
        s1 = Struct(('a', Scalar()))
        s2 = Struct(('b', Scalar()))
        s1 + s2 == Struct(
            ('a', Scalar()),
            ('b', Scalar()),
        )

        Example 2
        s1 = Struct(
            ('a', Scalar()),
            ('b', Struct(('c', Scalar()))),
        )
        s2 = Struct(('b', Struct(('d', Scalar()))))
        s1 + s2 == Struct(
            ('a', Scalar()),
            ('b', Struct(
                ('c', Scalar()),
                ('d', Scalar()),
            )),
        )
        """
        if not isinstance(other, Struct):
            return NotImplemented

        children = OrderedDict(self.get_children())
        for name, right_field in other.get_children():
            if name not in children:
                children[name] = right_field
                continue
            left_field = children[name]
            if not (isinstance(left_field, Struct) and isinstance(right_field, Struct)):
                raise TypeError(
                    "Type of left_field, " + str(type(left_field)) +
                    ", and type of right_field, " +
                    str(type(right_field)) +
                    ", must both be Struct to allow merging of the field, " + name)
            children[name] = left_field + right_field

        return Struct(*(viewitems(children)))

    def __sub__(self, other):
        """
        Allows removing the common fields of two schema.Struct from self
        using the '-' operator. If the two Structs have common field names,
        the removal is conducted recursively. If a child struct has no
        fields left inside it, it will be removed from its parent. Here are
        examples:

        Example 1
        s1 = Struct(
            ('a', Scalar()),
            ('b', Scalar()),
        )
        s2 = Struct(('a', Scalar()))
        s1 - s2 == Struct(('b', Scalar()))

        Example 2
        s1 = Struct(
            ('b', Struct(
                ('c', Scalar()),
                ('d', Scalar()),
            ))
        )
        s2 = Struct(
            ('b', Struct(('c', Scalar()))),
        )
        s1 - s2 == Struct(
            ('b', Struct(
                ('d', Scalar()),
            )),
        )

        Example 3
        s1 = Struct(
            ('a', Scalar()),
            ('b', Struct(
                ('d', Scalar()),
            ))
        )
        s2 = Struct(
            ('b', Struct(
                ('c', Scalar()),
                ('d', Scalar()),
            )),
        )
        s1 - s2 == Struct(
            ('a', Scalar()),
        )
        """
        if not isinstance(other, Struct):
            return NotImplemented

        children = OrderedDict(self.get_children())
        for name, right_field in other.get_children():
            if name in children:
                left_field = children[name]
                if type(left_field) == type(right_field):
                    if isinstance(left_field, Struct):
                        child = left_field - right_field
                        if child.get_children():
                            children[name] = child
                            continue
                    children.pop(name)
                else:
                    raise TypeError(
                        "Type of left_field, " + str(type(left_field)) +
                        ", is not the same as that of right_field, " +
                        str(type(right_field)) +
                        ", yet they have the same field name, " + name)
        return Struct(*(children.items()))
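
# Illustrative sketch (not executed on import): nested names expand into
# sub-Structs, and '+'/'-' merge and subtract schemas recursively.
#
#   s1 = Struct(('a:b', Scalar()), ('a:c', Scalar()))
#   s1.field_names()                              # ['a:b', 'a:c']
#   s2 = Struct(('a', Struct(('c', Scalar()))))
#   (s1 - s2).field_names()                       # ['a:b']
#   (s2 + Struct(('d', Scalar()))).field_names()  # ['a:c', 'd']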


class Scalar(Field):
    """Represents a typed scalar or tensor of fixed shape.

    A Scalar is a leaf in a schema tree, translating to exactly one tensor in
    the dataset's underlying storage.

    Usually, the tensor storing the actual values of this field is a 1D
    tensor, representing a series of values in its domain. It is possible,
    however, to have higher-rank values stored as a Scalar, as long as all
    entries have the same shape.

    E.g.:

        Scalar(np.float64)

            Scalar field of type float64. Caffe2 will expect readers and
            datasets to expose it as a 1D tensor of doubles (vector), where
            the size of the vector is determined by this field's domain.

        Scalar((np.int32, 5))

            Tensor field of type int32. Caffe2 will expect readers and
            datasets to implement it as a 2D tensor (matrix) of shape (L, 5),
            where L is determined by this field's domain.

        Scalar((str, (10, 20)))

            Tensor field of type str. Caffe2 will expect readers and
            datasets to implement it as a 3D tensor of shape (L, 10, 20),
            where L is determined by this field's domain.

    If the field type is unknown at construction time, call Scalar(), which
    will default to np.void as its dtype.

    It is an error to pass a structured dtype to Scalar, since it would
    contain more than one field. Instead, use from_dtype, which will
    construct a nested `Struct` field reflecting the given dtype's structure.

    A Scalar can also contain a blob, which represents the value of this
    Scalar. A blob can be either a numpy.ndarray, in which case it contains
    the actual contents of the Scalar, or a BlobReference, which represents a
    blob living in a caffe2 Workspace. If a blob of a different type is
    passed, a conversion to numpy.ndarray is attempted.
    """

    __slots__ = ("_metadata", "dtype", "_original_dtype", "_blob")

    def __init__(self, dtype=None, blob=None, metadata=None):
        self._metadata = None
        self.set(dtype, blob, metadata, unsafe=True)
        super(Scalar, self).__init__([])

    def field_names(self):
        return ['']

    def field_type(self):
        return self.dtype

    def field_types(self):
        return [self.dtype]

    def field_metadata(self):
        return [self._metadata]

    def has_blobs(self):
        return self._blob is not None

    def field_blobs(self):
        assert self._blob is not None, 'Value is not set for this field.'
        return [self._blob]

    def all_scalars(self):
        return [self]

    def clone(self, keep_blobs=True):
        return Scalar(
            dtype=self._original_dtype,
            blob=self._blob if keep_blobs else None,
            metadata=self._metadata
        )

    def get(self):
        """Gets the current blob of this Scalar field."""
        assert self._blob is not None, 'Value is not set for this field.'
        return self._blob

    def __call__(self):
        """Shortcut for self.get()."""
        return self.get()

    @property
    def metadata(self):
        return self._metadata

    def set_metadata(self, value):
        assert isinstance(value, Metadata), \
            'metadata must be Metadata, got {}'.format(type(value))
        self._metadata = value
        self._validate_metadata()

    def _validate_metadata(self):
        if self._metadata is None:
            return
        if (self._metadata.categorical_limit is not None and
                self.dtype is not None):
            assert np.issubdtype(self.dtype, np.integer), \
                "`categorical_limit` can be specified only in integral " + \
                "fields but got {}".format(self.dtype)

    def set_value(self, blob, throw_on_type_mismatch=False, unsafe=False):
        """Sets only the blob field, still validating the existing dtype."""
        if self.dtype.base != np.void and throw_on_type_mismatch:
            assert isinstance(blob, np.ndarray), "Got {!r}".format(blob)
            assert blob.dtype.base == self.dtype.base, (
                "Expected {}, got {}".format(self.dtype.base, blob.dtype.base))
        self.set(dtype=self._original_dtype, blob=blob, unsafe=unsafe)

    def set(self, dtype=None, blob=None, metadata=None, unsafe=False):
        """Set the type and/or blob of this scalar. See __init__ for details.

        Args:
            dtype: can be any numpy type. If not provided and `blob` is
                   provided, it will be inferred. If no argument is provided,
                   this Scalar will be of type np.void.
            blob:  if provided, can be either a BlobReference or a
                   numpy.ndarray. If a value of a different type is passed,
                   a conversion to numpy.ndarray is attempted. Strings aren't
                   accepted, since they can be ambiguous. If you want to pass
                   a string, use either BlobReference(blob) or np.array(blob).
            metadata: optional instance of Metadata; if provided, overrides
                   the metadata information of the scalar.
        """
        if not unsafe:
            logger.warning(
                "Scalar should be considered immutable. Only call Scalar.set() "
                "on newly created Scalar with unsafe=True. This will become an "
                "error soon."
            )
        if blob is not None and isinstance(blob, basestring):
            raise ValueError(
                'Passing str blob to Scalar.set() is ambiguous. '
                'Do either set(blob=np.array(blob)) or '
                'set(blob=BlobReference(blob))'
            )

        self._original_dtype = dtype
        # Numpy will collapse a shape of 1 into an unindexed data array (shape = ()),
        # which betrays the docstring of this class (which expects shape = (1,)).
        # >>> import numpy as np
        # >>> np.dtype((np.int32, 1))
        # dtype('int32')
        # >>> np.dtype((np.int32, 5))
        # dtype(('<i4', (5,)))
        if dtype is not None and isinstance(dtype, tuple) and dtype[1] == 1:
            dtype = (dtype[0], (1,))
        if dtype is not None:
            if isinstance(dtype, tuple) and dtype[0] == np.void:
                raise TypeError(
                    "Cannot set the Scalar with type {} for blob {}. "
                    "If this blob is the output of some operation, "
                    "please verify the input of that operation has "
                    "proper type.".format(dtype, blob)
                )
            dtype = np.dtype(dtype)
        # If blob is not None and it is not a BlobReference, we assume that
        # it is actual tensor data, so we will try to cast it to a numpy array.
        if blob is not None and not isinstance(blob, BlobReference):
            preserve_shape = isinstance(blob, np.ndarray)
            if dtype is not None and dtype != np.void:
                blob = np.array(blob, dtype=dtype.base)
                # if the array is empty we may need to reshape a little
                if blob.size == 0 and not preserve_shape:
                    blob = blob.reshape((0, ) + dtype.shape)
            else:
                assert isinstance(blob, np.ndarray), (
                    'Invalid blob type: %s' % str(type(blob)))

            # reshape scalars into 1D arrays
            # TODO(azzolini): figure out better way of representing this
            if len(blob.shape) == 0 and not preserve_shape:
                blob = blob.reshape((1, ))

            # infer inner shape from the blob given
            # TODO(dzhulgakov): tweak this to make it work with PackedStruct
            if (len(blob.shape) > 1 and dtype is not None and
                    dtype.base != np.void):
                dtype = np.dtype((dtype.base, blob.shape[1:]))
        # if we were still unable to infer the dtype
        if dtype is None:
            dtype = np.dtype(np.void)
        assert not dtype.fields, (
            'Cannot create Scalar with a structured dtype. ' +
            'Use from_dtype instead.'
        )
        self.dtype = dtype
        self._blob = blob
        if metadata is not None:
            self.set_metadata(metadata)
        self._validate_metadata()

    def set_type(self, dtype):
        self._original_dtype = dtype
        if dtype is not None:
            self.dtype = np.dtype(dtype)
        else:
            self.dtype = np.dtype(np.void)
        self._validate_metadata()

    def _pprint_impl(self, indent, str_buffer):
        str_buffer.write('  ' * (indent) +
                         'Scalar({!r}, {!r}, {!r})'.format(
                             self.dtype, self._blob, self._metadata) + "\n")

    def id(self):
        """
        Return the zero-indexed position of this scalar field in its schema.
        Used in order to index into the field_blob list returned by readers
        or accepted by writers.
        """
        return self._child_base_id()
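
# Illustrative sketch (not executed on import): a Scalar's dtype can carry
# an inner shape, and a blob passed alongside a dtype is cast to it.
#
#   Scalar(np.float64).field_types()      # [dtype('float64')]
#   Scalar((np.int32, 5)).field_types()   # [dtype(('<i4', (5,)))]
#   s = Scalar(np.int64, np.array([1, 2, 3]))
#   s.get().shape                         # (3,); the blob was cast to int64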


def Map(
    keys,
    values,
    keys_name='keys',
    values_name='values',
    lengths_blob=None
):
    """A Map is a List of Structs containing `keys` and `values` fields.
    Optionally, you can provide custom names for the key and value fields.
    """
    return List(
        Struct((keys_name, keys), (values_name, values)),
        lengths_blob=lengths_blob
    )
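
# Illustrative sketch (not executed on import): a Map flattens into
# lengths, keys and values columns.
#
#   m = Map(Scalar(np.int32), Scalar(np.float32))
#   m.field_names()   # ['lengths', 'values:keys', 'values:values']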


def MapWithEvicted(
    keys,
    values,
    keys_name='keys',
    values_name='values',
    lengths_blob=None,
    evicted_values=None
):
    """A Map with an extra field `evicted_values`.
    """
    return ListWithEvicted(
        Struct((keys_name, keys), (values_name, values)),
        lengths_blob=lengths_blob,
        evicted_values=evicted_values
    )


def NamedTuple(name_prefix, *fields):
    return Struct(* [('%s_%d' % (name_prefix, i), field)
                     for i, field in enumerate(fields)])


def Tuple(*fields):
    """
    Creates a Struct with default, sequential field names of the given types.
    """
    return NamedTuple('field', *fields)


def RawTuple(num_fields, name_prefix='field'):
    """
    Creates a tuple of `num_fields` untyped scalars.
    """
    assert isinstance(num_fields, int)
    assert num_fields >= 0
    return NamedTuple(name_prefix, *([np.void] * num_fields))
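
# Illustrative sketch (not executed on import): Tuple and RawTuple assign
# sequential default names.
#
#   Tuple(np.int32, np.float32).field_names()   # ['field_0', 'field_1']
#   RawTuple(2).field_names()                   # ['field_0', 'field_1'], untyped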


def from_dtype(dtype, _outer_shape=()):
    """Constructs a Caffe2 schema from the given numpy dtype.

    Numpy supports scalar, array-like and structured datatypes, as long as
    all the shapes are fixed. This function breaks down the given dtype into
    a Caffe2 schema containing `Struct` and `Scalar` types.

    Fields containing byte offsets are not currently supported.
    """
    if not isinstance(dtype, np.dtype):
        # wrap into a np.dtype
        shape = _outer_shape
        dtype = np.dtype((dtype, _outer_shape))
    else:
        # concatenate shapes if necessary
        shape = _outer_shape + dtype.shape
        if shape != dtype.shape:
            dtype = np.dtype((dtype.base, shape))

    if not dtype.fields:
        return Scalar(dtype)

    struct_fields = []
    for name, (fdtype, offset) in viewitems(dtype.fields):
        assert offset == 0, ('Fields with byte offsets are not supported.')
        struct_fields.append((name, from_dtype(fdtype, _outer_shape=shape)))
    return Struct(*struct_fields)
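
# Illustrative sketch (not executed on import): a structured numpy dtype
# unrolls into a Struct of Scalars. Note the zero-offset restriction above,
# so only dtypes whose fields all start at byte offset 0 are accepted.
#
#   dt = np.dtype([('x', (np.float32, 2))])
#   from_dtype(dt).field_names()   # ['x']
#   from_dtype(dt).field_types()   # [dtype(('<f4', (2,)))]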


class _SchemaNode(object):
    """This is a private class used to represent a Schema Node."""

    __slots__ = ("name", "children", "type_str", "field")

    def __init__(self, name, type_str=''):
        self.name = name
        self.children = []
        self.type_str = type_str
        self.field = None

    def add_child(self, name, type_str=''):
        for child in self.children:
            if child.name == name and child.type_str == type_str:
                return child
        child = _SchemaNode(name, type_str)
        self.children.append(child)
        return child

    def get_field(self):

        list_names = ['lengths', 'values']
        map_names = ['lengths', 'keys', 'values']

        if len(self.children) == 0 or self.field is not None:
            if self.field is None:
                return Struct()
            else:
                return self.field

        child_names = []
        for child in self.children:
            child_names.append(child.name)

        if (set(child_names) == set(list_names)):
            for child in self.children:
                if child.name == 'values':
                    values_field = child.get_field()
                else:
                    lengths_field = child.get_field()
            self.field = List(
                values_field,
                lengths_blob=lengths_field
            )
            self.type_str = "List"
            return self.field
        elif (set(child_names) == set(map_names)):
            for child in self.children:
                if child.name == 'keys':
                    key_field = child.get_field()
                elif child.name == 'values':
                    values_field = child.get_field()
                else:
                    lengths_field = child.get_field()
            self.field = Map(
                key_field,
                values_field,
                lengths_blob=lengths_field
            )
            self.type_str = "Map"
            return self.field

        else:
            struct_fields = []
            for child in self.children:
                struct_fields.append((child.name, child.get_field()))

            self.field = Struct(*struct_fields)
            self.type_str = "Struct"
            return self.field

    def print_recursively(self):
        for child in self.children:
            child.print_recursively()
        logger.info("Printing node: Name and type")
        logger.info(self.name)
        logger.info(self.type_str)


def from_column_list(
    col_names, col_types=None,
    col_blobs=None, col_metadata=None
):
    """
    Given a list of names, types, and optionally values, construct a Schema.
    """
    if col_types is None:
        col_types = [None] * len(col_names)
    if col_metadata is None:
        col_metadata = [None] * len(col_names)
    if col_blobs is None:
        col_blobs = [None] * len(col_names)
    assert len(col_names) == len(col_types), (
        'col_names and col_types must have the same length.'
    )
    assert len(col_names) == len(col_metadata), (
        'col_names and col_metadata must have the same length.'
    )
    assert len(col_names) == len(col_blobs), (
        'col_names and col_blobs must have the same length.'
    )
    root = _SchemaNode('root', 'Struct')
    for col_name, col_type, col_blob, col_metadata in zip(
        col_names, col_types, col_blobs, col_metadata
    ):
        columns = col_name.split(FIELD_SEPARATOR)
        current = root
        for i in range(len(columns)):
            name = columns[i]
            type_str = ''
            field = None
            if i == len(columns) - 1:
                type_str = col_type
                field = Scalar(
                    dtype=col_type,
                    blob=col_blob,
                    metadata=col_metadata
                )
            next = current.add_child(name, type_str)
            if field is not None:
                next.field = field
            current = next

    return root.get_field()
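
# Illustrative sketch (not executed on import): rebuilding a nested schema
# from flattened column names. A lengths/values pair of children folds back
# into a List.
#
#   s = from_column_list(['doc:lengths', 'doc:values'])
#   isinstance(s.doc, List)   # True
#   s.field_names()           # ['doc:lengths', 'doc:values']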


def from_blob_list(schema, values, throw_on_type_mismatch=False):
    """
    Create a schema that clones the given schema, but containing the given
    list of values.
    """
    assert isinstance(schema, Field), 'Argument `schema` must be a Field.'
    if isinstance(values, BlobReference):
        values = [values]
    record = schema.clone_schema()
    scalars = record.all_scalars()
    assert len(scalars) == len(values), (
        'Values must have %d elements, got %d.' % (len(scalars), len(values))
    )
    for scalar, value in zip(scalars, values):
        scalar.set_value(value, throw_on_type_mismatch, unsafe=True)
    return record


def as_record(value):
    if isinstance(value, Field):
        return value
    elif isinstance(value, list) or isinstance(value, tuple):
        is_field_list = all(
            isinstance(f, tuple) and len(f) == 2 and isinstance(f[0], basestring)
            for f in value
        )
        if is_field_list:
            return Struct(* [(k, as_record(v)) for k, v in value])
        else:
            return Tuple(* [as_record(f) for f in value])
    elif isinstance(value, dict):
        return Struct(* [(k, as_record(v)) for k, v in viewitems(value)])
    else:
        return _normalize_field(value)
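
# Illustrative sketch (not executed on import): plain python containers
# convert into schema types.
#
#   r = as_record({'x': np.int32, 'y': np.float32})
#   sorted(r.field_names())   # ['x', 'y']
#   as_record([('x', np.int32)]).field_names()   # ['x'], a Struct
#   as_record([np.int32, np.int32]).field_names()   # ['field_0', 'field_1'], a Tuple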


def FetchRecord(blob_record, ws=None, throw_on_type_mismatch=False):
    """
    Given a record containing BlobReferences, return a new record with the
    same schema, containing numpy arrays fetched from the current active
    workspace.
    """

    def fetch(v):
        if ws is None:
            return workspace.FetchBlob(str(v))
        else:
            return ws.blobs[str(v)].fetch()

    assert isinstance(blob_record, Field)
    field_blobs = blob_record.field_blobs()
    assert all(isinstance(v, BlobReference) for v in field_blobs)
    field_arrays = [fetch(value) for value in field_blobs]
    return from_blob_list(blob_record, field_arrays, throw_on_type_mismatch)


def FeedRecord(blob_record, arrays, ws=None):
    """
    Given a Record containing blob_references and arrays, which is either
    a list of numpy arrays or a Record containing numpy arrays, feeds the
    record to the current workspace.
    """

    def feed(b, v):
        if ws is None:
            workspace.FeedBlob(str(b), v)
        else:
            ws.create_blob(str(b))
            ws.blobs[str(b)].feed(v)

    assert isinstance(blob_record, Field)
    field_blobs = blob_record.field_blobs()
    assert all(isinstance(v, BlobReference) for v in field_blobs)
    if isinstance(arrays, Field):
        # TODO: check schema
        arrays = arrays.field_blobs()
    assert len(arrays) == len(field_blobs), (
        'Values must contain exactly %d ndarrays.' % len(field_blobs)
    )
    for blob, array in zip(field_blobs, arrays):
        feed(blob, array)
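
# Illustrative sketch (assumes a live caffe2 workspace; not executed on
# import): a typical round trip creates blobs for a schema, feeds arrays,
# and fetches them back.
#
#   net = core.Net('example')
#   record = NewRecord(net, Struct(('x', Scalar(np.float32))))
#   FeedRecord(record, [np.array([1.0], dtype=np.float32)])
#   fetched = FetchRecord(record)
#   fetched.x.get()   # array([1.], dtype=float32)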


def NewRecord(net, schema):
    """
    Given a record of np.arrays, create a BlobReference for each one of them,
    returning a record containing BlobReferences. The name of each returned
    blob is NextScopedBlob(field_name), which guarantees a unique name in the
    current net. Use NameScope explicitly to avoid name conflicts between
    different nets.
    """
    if isinstance(schema, Scalar):
        result = schema.clone()
        result.set_value(
            blob=net.NextScopedBlob('unnamed_scalar'),
            unsafe=True,
        )
        return result

    assert isinstance(schema, Field), 'Record must be a schema.Field instance.'
    blob_refs = [
        net.NextScopedBlob(prefix=name)
        for name in schema.field_names()
    ]
    return from_blob_list(schema, blob_refs)


def ConstRecord(net, array_record):
    """
    Given a record of arrays, returns a record of blobs,
    initialized with net.Const.
    """
    blob_record = NewRecord(net, array_record)
    for blob, array in zip(
        blob_record.field_blobs(), array_record.field_blobs()
    ):
        net.Const(array, blob)
    return blob_record


def InitEmptyRecord(net, schema_or_record, enforce_types=False):
    if not schema_or_record.has_blobs():
        record = NewRecord(net, schema_or_record)
    else:
        record = schema_or_record

    for blob_type, blob in zip(record.field_types(), record.field_blobs()):
        try:
            data_type = data_type_for_dtype(blob_type)
            shape = [0] + list(blob_type.shape)
            net.ConstantFill([], blob, shape=shape, dtype=data_type)
        except TypeError:
            logger.warning("Blob {} has type error".format(blob))
            # If data_type_for_dtype doesn't know how to resolve a given
            # numpy type to core.DataType, that function can throw a type
            # error (for example, that would happen for unknown types such
            # as np.void). This is not a problem when the record is going to
            # be overwritten by some operator later, though it might be an
            # issue for type/shape inference.
            if enforce_types:
                raise
            # If we don't enforce types for all items we'll create a blob
            # with the default ConstantFill (FLOAT, no shape)
            net.ConstantFill([], blob, shape=[0])

    return record


_DATA_TYPE_FOR_DTYPE = [
    (np.str, core.DataType.STRING),
    (np.float16, core.DataType.FLOAT16),
    (np.float32, core.DataType.FLOAT),
    (np.float64, core.DataType.DOUBLE),
    (np.bool, core.DataType.BOOL),
    (np.int8, core.DataType.INT8),
    (np.int16, core.DataType.INT16),
    (np.int32, core.DataType.INT32),
    (np.int64, core.DataType.INT64),
    (np.uint8, core.DataType.UINT8),
    (np.uint16, core.DataType.UINT16),
]


def is_schema_subset(schema, original_schema):
    # TODO: add more checks
    return set(schema.field_names()).issubset(
        set(original_schema.field_names()))


def equal_schemas(schema,
                  original_schema,
                  check_field_names=True,
                  check_field_types=True,
                  check_field_metas=False):
    assert isinstance(schema, Field)
    assert isinstance(original_schema, Field)

    if check_field_names and (
            schema.field_names() != original_schema.field_names()):
        return False
    if check_field_types and (
            schema.field_types() != original_schema.field_types()):
        return False
    if check_field_metas and (
            schema.field_metadata() != original_schema.field_metadata()):
        return False

    return True


def schema_check(schema, previous=None):
    record = as_record(schema)
    if previous is not None:
        assert equal_schemas(schema, previous)
    return record


def data_type_for_dtype(dtype):
    for np_type, dt in _DATA_TYPE_FOR_DTYPE:
        if dtype.base == np_type:
            return dt
    raise TypeError('Unknown dtype: ' + str(dtype.base))


def dtype_for_core_type(core_type):
    for np_type, dt in _DATA_TYPE_FOR_DTYPE:
        if dt == core_type:
            return np_type
    raise TypeError('Unknown core type: ' + str(core_type))
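
# Illustrative sketch (not executed on import): the two helpers above map
# between numpy dtypes and caffe2 core.DataType values in both directions.
#
#   data_type_for_dtype(np.dtype(np.float32))   # core.DataType.FLOAT
#   dtype_for_core_type(core.DataType.INT64)    # np.int64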


def attach_metadata_to_scalars(field, metadata):
    for f in field.all_scalars():
        f.set_metadata(metadata)