339 lines
12 KiB
Python
339 lines
12 KiB
Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Base class for testing reader datasets."""
|
|
|
|
import os
|
|
|
|
from tensorflow.core.example import example_pb2
|
|
from tensorflow.core.example import feature_pb2
|
|
from tensorflow.python.data.experimental.ops import readers
|
|
from tensorflow.python.data.kernel_tests import test_base
|
|
from tensorflow.python.data.ops import readers as core_readers
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.lib.io import python_io
|
|
from tensorflow.python.ops import parsing_ops
|
|
from tensorflow.python.util import compat
|
|
|
|
|
|
class FeaturesTestBase(test_base.DatasetTestBase):
  """Base class for testing TFRecord-based features.

  `setUp` writes `_num_files` TFRecord files, each holding `_num_records`
  serialized `Example` protos with the features "file", "record", "keywords"
  and "label". The remaining helpers build a batched features dataset over
  such files and generate the batches that the dataset is expected to produce.

  NOTE(review): `_next_actual_batch` reads `self.outputs`, which is not
  assigned anywhere in this class; presumably tests set it to a callable that
  returns the next element before calling `_verify_records` -- confirm.
  """

  def setUp(self):
    """Creates the TFRecord files used by the tests."""
    super(FeaturesTestBase, self).setUp()
    self._num_files = 2
    self._num_records = 7
    self._filenames = self._createFiles()

  def make_batch_feature(self,
                         filenames,
                         num_epochs,
                         batch_size,
                         label_key=None,
                         reader_num_threads=1,
                         parser_num_threads=1,
                         shuffle=False,
                         shuffle_seed=None,
                         drop_final_batch=False):
    """Builds a batched features dataset over `filenames`.

    As a side effect, `filenames`, `num_epochs` and `batch_size` are stored on
    the instance as `self.filenames`, `self.num_epochs` and `self.batch_size`.

    Args:
      filenames: Forwarded as `file_pattern`.
      num_epochs: Forwarded as `num_epochs`.
      batch_size: Forwarded as `batch_size`.
      label_key: Forwarded as `label_key`.
      reader_num_threads: Forwarded as `reader_num_threads`.
      parser_num_threads: Forwarded as `parser_num_threads`.
      shuffle: Forwarded as `shuffle`.
      shuffle_seed: Forwarded as `shuffle_seed`.
      drop_final_batch: Forwarded as `drop_final_batch`.

    Returns:
      The result of `readers.make_batched_features_dataset` configured with
      `core_readers.TFRecordDataset` as the reader and a feature spec for the
      "file", "record", "keywords" and "label" features.
    """
    self.filenames = filenames
    self.num_epochs = num_epochs
    self.batch_size = batch_size

    return readers.make_batched_features_dataset(
        file_pattern=self.filenames,
        batch_size=self.batch_size,
        features={
            "file": parsing_ops.FixedLenFeature([], dtypes.int64),
            "record": parsing_ops.FixedLenFeature([], dtypes.int64),
            "keywords": parsing_ops.VarLenFeature(dtypes.string),
            "label": parsing_ops.FixedLenFeature([], dtypes.string),
        },
        label_key=label_key,
        reader=core_readers.TFRecordDataset,
        num_epochs=self.num_epochs,
        shuffle=shuffle,
        shuffle_seed=shuffle_seed,
        reader_num_threads=reader_num_threads,
        parser_num_threads=parser_num_threads,
        drop_final_batch=drop_final_batch)

  def _record(self, f, r, l):
    """Returns a serialized `Example` for record `r` of file `f`.

    Args:
      f: Value stored in the "file" int64 feature.
      r: Value stored in the "record" int64 feature.
      l: Label; converted with `compat.as_bytes` and stored in the "label"
        bytes feature.

    Returns:
      The serialized `Example` proto.
    """
    example = example_pb2.Example(
        features=feature_pb2.Features(
            feature={
                "file":
                    feature_pb2.Feature(
                        int64_list=feature_pb2.Int64List(value=[f])),
                "record":
                    feature_pb2.Feature(
                        int64_list=feature_pb2.Int64List(value=[r])),
                "keywords":
                    feature_pb2.Feature(
                        bytes_list=feature_pb2.BytesList(
                            value=self._get_keywords(f, r))),
                "label":
                    feature_pb2.Feature(
                        bytes_list=feature_pb2.BytesList(
                            value=[compat.as_bytes(l)]))
            }))
    return example.SerializeToString()

  def _get_keywords(self, f, r):
    """Returns the keyword list for record `r` of file `f`.

    The list holds one keyword when `f + r` is even and two when it is odd.
    """
    num_keywords = 1 + (f + r) % 2
    return [
        compat.as_bytes("keyword%d" % index) for index in range(num_keywords)
    ]

  def _sum_keywords(self, num_files):
    """Returns the total number of keywords across the first `num_files` files.

    Uses the same per-record keyword count as `_get_keywords`.
    """
    return sum(
        1 + (i + j) % 2
        for i in range(num_files)
        for j in range(self._num_records))

  def _createFiles(self):
    """Writes the test TFRecord files and returns their paths.

    Returns:
      A list of `self._num_files` file names, each file containing
      `self._num_records` records labelled "fake-label".
    """
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      # The context manager closes the writer even if a write raises, so a
      # failing test does not leak an open file handle.
      with python_io.TFRecordWriter(fn) as writer:
        for j in range(self._num_records):
          writer.write(self._record(i, j, "fake-label"))
    return filenames

  def _run_actual_batch(self, outputs, label_key_provided=False):
    """Evaluates `outputs()` and flattens the result into a list.

    Args:
      outputs: Callable whose result is passed to `self.evaluate`.
      label_key_provided: If True, the evaluated result is unpacked as a
        `(features, label)` pair; otherwise it is the features dict and the
        label is read from its "label" entry.

    Returns:
      `[file, keywords indices, keywords values, keywords dense_shape, record,
      label]`, in the same order as the batches generated by
      `_next_expected_batch`.
    """
    if label_key_provided:
      # outputs would be a tuple of (feature dict, label)
      features, label = self.evaluate(outputs())
    else:
      features = self.evaluate(outputs())
      label = features["label"]
    keywords = features["keywords"]
    return [
        features["file"], keywords.indices, keywords.values,
        keywords.dense_shape, features["record"], label
    ]

  def _next_actual_batch(self, label_key_provided=False):
    """Returns the next actual batch obtained from `self.outputs`."""
    return self._run_actual_batch(self.outputs, label_key_provided)

  def _interleave(self, iterators, cycle_length):
    """Yields elements from `iterators` in round-robin order.

    At most `cycle_length` iterators are open at a time. Each pass takes one
    element from every open iterator in turn. When an open iterator is
    exhausted, its slot is handed to the next pending iterator (which starts
    producing on the following pass), or closed if none is left.

    Args:
      iterators: Iterable of iterators to interleave. It is not modified.
      cycle_length: Maximum number of iterators to read from concurrently.

    Yields:
      The elements of the input iterators, interleaved.
    """
    # Walk the input through an iterator instead of popping from it, so the
    # caller's list is left untouched.
    pending_iterators = iter(iterators)
    open_iterators = []
    for _ in range(cycle_length):
      next_iterator = next(pending_iterators, None)
      if next_iterator is None:
        break
      open_iterators.append(next_iterator)
    num_open = len(open_iterators)

    while num_open:
      for i in range(len(open_iterators)):
        if open_iterators[i] is None:
          continue
        try:
          yield next(open_iterators[i])
        except StopIteration:
          # `None` marks a slot that has no iterator left to read from.
          open_iterators[i] = next(pending_iterators, None)
          if open_iterators[i] is None:
            num_open -= 1

  def _next_expected_batch(self,
                           file_indices,
                           batch_size,
                           num_epochs,
                           cycle_length=1):
    """Generates the batches expected from reading `file_indices`.

    Args:
      file_indices: Indices of the files to read; iterated once per epoch.
      batch_size: Number of records per batch.
      num_epochs: Number of passes over `file_indices`.
      cycle_length: If 1, files are read sequentially; otherwise records are
        interleaved across files with `_interleave`.

    Yields:
      Lists of the form `[files, keywords indices, keywords values,
      [rows, max keywords per record], records, labels]`. Batches span epoch
      boundaries, and a final partial batch is yielded if any records remain.
    """

    def _next_record(file_indices):
      for j in file_indices:
        for i in range(self._num_records):
          yield j, i, compat.as_bytes("fake-label")

    def _next_record_interleaved(file_indices, cycle_length):
      return self._interleave([_next_record([i]) for i in file_indices],
                              cycle_length)

    file_batch = []
    keywords_batch_indices = []
    keywords_batch_values = []
    keywords_batch_max_len = 0
    record_batch = []
    label_batch = []
    for _ in range(num_epochs):
      if cycle_length == 1:
        next_records = _next_record(file_indices)
      else:
        next_records = _next_record_interleaved(file_indices, cycle_length)
      for f, r, label in next_records:
        # The row of this record within the batch equals the number of
        # records collected so far.
        batch_index = len(file_batch)
        label_batch.append(label)
        file_batch.append(f)
        record_batch.append(r)
        keywords = self._get_keywords(f, r)
        keywords_batch_values.extend(keywords)
        keywords_batch_indices.extend(
            [[batch_index, i] for i in range(len(keywords))])
        keywords_batch_max_len = max(keywords_batch_max_len, len(keywords))
        if len(file_batch) == batch_size:
          yield [
              file_batch, keywords_batch_indices, keywords_batch_values,
              [batch_size, keywords_batch_max_len], record_batch, label_batch
          ]
          # Start fresh lists; the yielded ones now belong to the consumer.
          file_batch = []
          keywords_batch_indices = []
          keywords_batch_values = []
          keywords_batch_max_len = 0
          record_batch = []
          label_batch = []
    if file_batch:
      yield [
          file_batch, keywords_batch_indices, keywords_batch_values,
          [len(file_batch), keywords_batch_max_len], record_batch, label_batch
      ]

  def _verify_records(self,
                      batch_size,
                      file_index=None,
                      num_epochs=1,
                      label_key_provided=False,
                      interleave_cycle_length=1):
    """Asserts that the actual batches match the expected ones.

    Args:
      batch_size: Number of records per batch.
      file_index: Index of the single file to verify, or None to verify all
        `self._num_files` files.
      num_epochs: Number of epochs to verify.
      label_key_provided: Forwarded to `_next_actual_batch`.
      interleave_cycle_length: Cycle length used to compute the expected
        record order.
    """
    if file_index is not None:
      file_indices = [file_index]
    else:
      file_indices = range(self._num_files)

    for expected_batch in self._next_expected_batch(
        file_indices,
        batch_size,
        num_epochs,
        cycle_length=interleave_cycle_length):
      actual_batch = self._next_actual_batch(
          label_key_provided=label_key_provided)
      # Index into `actual_batch` (rather than zipping) so that a missing
      # component fails loudly instead of being skipped.
      for i, expected_component in enumerate(expected_batch):
        self.assertAllEqual(expected_component, actual_batch[i])
|
|
|
|
|
|
class TFRecordTestBase(test_base.DatasetTestBase):
  """Base class for TFRecord-based tests.

  `setUp` writes `_num_files` TFRecord files, each holding `_num_records`
  records of the form b"Record <r> of file <f>". The remaining helpers
  generate the batches a dataset reading those files is expected to produce
  and compare them against actual outputs.
  """

  def setUp(self):
    """Creates the TFRecord files used by the tests."""
    super(TFRecordTestBase, self).setUp()
    self._num_files = 2
    self._num_records = 7
    self._filenames = self._createFiles()

  def _interleave(self, iterators, cycle_length):
    """Yields elements from `iterators` in round-robin order.

    At most `cycle_length` iterators are open at a time. Each pass takes one
    element from every open iterator in turn. When an open iterator is
    exhausted, its slot is handed to the next pending iterator (which starts
    producing on the following pass), or closed if none is left.

    Args:
      iterators: Iterable of iterators to interleave. It is not modified.
      cycle_length: Maximum number of iterators to read from concurrently.

    Yields:
      The elements of the input iterators, interleaved.
    """
    # Walk the input through an iterator instead of popping from it, so the
    # caller's list is left untouched.
    pending_iterators = iter(iterators)
    open_iterators = []
    for _ in range(cycle_length):
      next_iterator = next(pending_iterators, None)
      if next_iterator is None:
        break
      open_iterators.append(next_iterator)
    num_open = len(open_iterators)

    while num_open:
      for i in range(len(open_iterators)):
        if open_iterators[i] is None:
          continue
        try:
          yield next(open_iterators[i])
        except StopIteration:
          # `None` marks a slot that has no iterator left to read from.
          open_iterators[i] = next(pending_iterators, None)
          if open_iterators[i] is None:
            num_open -= 1

  def _next_expected_batch(self, file_indices, batch_size, num_epochs,
                           cycle_length, drop_final_batch, use_parser_fn):
    """Generates the record batches expected from reading `file_indices`.

    Args:
      file_indices: Indices of the files to read; iterated once per epoch.
      batch_size: Number of records per batch.
      num_epochs: Number of passes over `file_indices`.
      cycle_length: If 1, files are read sequentially; otherwise records are
        interleaved across files with `_interleave`.
      drop_final_batch: If True, a trailing partial batch is not yielded.
      use_parser_fn: If True, the first byte of every expected record is
        dropped (presumably mirroring the parser function applied by the
        test -- confirm against callers).

    Yields:
      Lists of at most `batch_size` expected records. Batches span epoch
      boundaries.
    """

    def _next_record(file_indices):
      for j in file_indices:
        for i in range(self._num_records):
          yield j, i

    def _next_record_interleaved(file_indices, cycle_length):
      return self._interleave([_next_record([i]) for i in file_indices],
                              cycle_length)

    record_batch = []
    for _ in range(num_epochs):
      if cycle_length == 1:
        next_records = _next_record(file_indices)
      else:
        next_records = _next_record_interleaved(file_indices, cycle_length)
      for f, r in next_records:
        record = self._record(f, r)
        if use_parser_fn:
          record = record[1:]
        record_batch.append(record)
        if len(record_batch) == batch_size:
          yield record_batch
          # Start a fresh list; the yielded one now belongs to the consumer.
          record_batch = []
    if record_batch and not drop_final_batch:
      yield record_batch

  def _verify_records(self, outputs, batch_size, file_index, num_epochs,
                      interleave_cycle_length, drop_final_batch, use_parser_fn):
    """Asserts that successive `outputs()` results match the expected batches.

    Args:
      outputs: Callable whose result is passed to `self.evaluate` to obtain
        the next actual batch.
      batch_size: Number of records per batch.
      file_index: A single file index, a list of file indices, or None to
        verify all `self._num_files` files.
      num_epochs: Number of epochs to verify.
      interleave_cycle_length: Cycle length used to compute the expected
        record order.
      drop_final_batch: Whether a trailing partial batch is expected to be
        dropped.
      use_parser_fn: Forwarded to `_next_expected_batch`.
    """
    if file_index is None:
      file_indices = range(self._num_files)
    elif isinstance(file_index, list):
      file_indices = file_index
    else:
      file_indices = [file_index]

    for expected_batch in self._next_expected_batch(
        file_indices, batch_size, num_epochs, interleave_cycle_length,
        drop_final_batch, use_parser_fn):
      actual_batch = self.evaluate(outputs())
      self.assertAllEqual(expected_batch, actual_batch)

  def _record(self, f, r):
    """Returns the bytes content of record `r` of file `f`."""
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _createFiles(self):
    """Writes the test TFRecord files and returns their paths.

    Returns:
      A list of `self._num_files` file names, each file containing
      `self._num_records` records produced by `_record`.
    """
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      # The context manager closes the writer even if a write raises, so a
      # failing test does not leak an open file handle.
      with python_io.TFRecordWriter(fn) as writer:
        for j in range(self._num_records):
          writer.write(self._record(i, j))
    return filenames

  def _writeFile(self, name, data):
    """Writes `data` to a TFRecord file called `name` in the temp directory.

    Args:
      name: File name, joined onto `self.get_temp_dir()`.
      data: Iterable of items; each is written as one record containing
        `compat.as_bytes(str(item))`.

    Returns:
      The full path of the written file.
    """
    filename = os.path.join(self.get_temp_dir(), name)
    # The context manager closes the writer even if a write raises.
    with python_io.TFRecordWriter(filename) as writer:
      for d in data:
        writer.write(compat.as_bytes(str(d)))
    return filename
|
|
|