# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Implementation of LoadDataset in Python.""" import multiprocessing import os from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import structured_function from tensorflow.python.eager import context from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops from tensorflow.python.platform import gfile from tensorflow.python.util import lazy_loader # TODO(b/238903802): Use TypeSpec serialization methods directly. 
# Bound through a LazyLoader, so the coder module is resolved when one of its
# attributes is first used instead of when this file is imported.
nested_structure_coder = lazy_loader.LazyLoader(
    "nested_structure_coder",
    globals(),
    "tensorflow.python.saved_model.nested_structure_coder",
)


def _load(path, element_spec, compression, reader_func):
  """Returns a `_LoadDataset` built from the given arguments."""
  return _LoadDataset(
      path=path,
      element_spec=element_spec,
      compression=compression,
      reader_func=reader_func)


class _LoadDataset(dataset_ops.DatasetSource):
  """Dataset source that reads back a dataset previously saved under `path`."""

  def __init__(self,
               path,
               element_spec=None,
               compression=None,
               reader_func=None):
    """Creates the dataset and the underlying `load_dataset` op.

    Args:
      path: Location of the saved dataset. When `element_spec` is omitted, the
        spec file `dataset_ops.DATASET_SPEC_FILENAME` is read from under it.
      element_spec: Optional structure of the dataset elements. When `None`,
        it is decoded from the saved spec file, which is only attempted when
        executing eagerly.
      compression: Passed through unchanged to the `load_dataset` op.
      reader_func: Optional callable applied to a dataset of datasets of
        elements. When `None`, the inner datasets are interleaved with a cycle
        length of `multiprocessing.cpu_count()` and autotuned parallelism.

    Raises:
      ValueError: If `element_spec` is `None` while not executing eagerly.
    """
    self._path = path
    self._compression = compression

    if element_spec is not None:
      self._element_spec = element_spec
    else:
      # Without an explicit spec the only source is the file on disk, and
      # that is read only under eager execution.
      if not context.executing_eagerly():
        raise ValueError(
            "In graph mode the `element_spec` argument must be provided.")
      spec_path = os.path.join(path, dataset_ops.DATASET_SPEC_FILENAME)
      with gfile.GFile(spec_path, "rb") as spec_file:
        serialized_spec = spec_file.read()
      spec_proto = nested_structure_coder.struct_pb2.StructuredValue()
      spec_proto.ParseFromString(serialized_spec)
      self._element_spec = nested_structure_coder.decode_proto(spec_proto)

    if reader_func is None:
      reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
          lambda x: x,
          cycle_length=multiprocessing.cpu_count(),
          num_parallel_calls=dataset_ops.AUTOTUNE)

    # The reader function consumes a dataset of datasets of input elements.
    reader_input_spec = dataset_ops.DatasetSpec(
        dataset_ops.DatasetSpec(self._element_spec))
    self._reader_func = structured_function.StructuredFunctionWrapper(
        reader_func, "load()", input_structure=reader_input_spec)

    # `self._flat_structure` is expanded last, after `self._element_spec` has
    # been assigned above.
    wrapped_reader = self._reader_func.function
    variant_tensor = ged_ops.load_dataset(
        path,
        reader_func_other_args=wrapped_reader.captured_inputs,
        compression=compression,
        reader_func=wrapped_reader,
        **self._flat_structure)
    super().__init__(variant_tensor)

  def _functions(self):
    """Returns a single-element list holding the wrapped reader function."""
    return [self._reader_func]

  @property
  def element_spec(self):
    """The element spec passed to the constructor or decoded from disk."""
    return self._element_spec