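"""Tests for the caffe2.python.dataio readers and reader wrappers:
CompositeReader/CompositeReaderBuilder, ReaderWithLimit, ReaderWithTimeLimit,
CachedReader, and DBFileReader.
"""
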
from caffe2.python.dataio import (
    CompositeReader,
    CompositeReaderBuilder,
    ReaderBuilder,
    ReaderWithDelay,
    ReaderWithLimit,
    ReaderWithTimeLimit,
)
from caffe2.python.dataset import Dataset
from caffe2.python.db_file_reader import DBFileReader
from caffe2.python.pipeline import pipe
from caffe2.python.schema import Struct, NewRecord, FeedRecord
from caffe2.python.session import LocalSession
from caffe2.python.task import TaskGroup, final_output, WorkspaceType
from caffe2.python.test_util import TestCase
from caffe2.python.cached_reader import CachedReader
from caffe2.python import core, workspace, schema
from caffe2.python.net_builder import ops

import numpy as np
import numpy.testing as npt
import os
import shutil
import unittest
import tempfile


def make_source_dataset(ws, size=100, offset=0, name=None):
    """Create and fill a source Dataset whose 'label' field holds the
    integers [offset, offset + size)."""
    name = name or "src"
    src_init = core.Net("{}_init".format(name))
    with core.NameScope(name):
        src_values = Struct(('label', np.array(range(offset, offset + size))))
        src_blobs = NewRecord(src_init, src_values)
        src_ds = Dataset(src_blobs, name=name)
        FeedRecord(src_blobs, src_values, ws)
    ws.run(src_init)
    return src_ds


def make_destination_dataset(ws, schema, name=None):
    """Create an empty destination Dataset with the given schema, to be
    filled by a pipeline writer."""
    name = name or 'dst'
    dst_init = core.Net('{}_init'.format(name))
    with core.NameScope(name):
        dst_ds = Dataset(schema, name=name)
        dst_ds.init_empty(dst_init)
    ws.run(dst_init)
    return dst_ds


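# The two helpers above are combined throughout these tests roughly as follows
# (a sketch of the common pattern; `ws` is a workspace.C.Workspace() and
# `session` a LocalSession wrapping it):
#
#     src_ds = make_source_dataset(ws, size=100)
#     dst_ds = make_destination_dataset(ws, src_ds.content().clone_schema())
#     with TaskGroup() as tg:
#         pipe(src_ds.reader(), dst_ds.writer(), num_runtime_threads=8)
#     session.run(tg)

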
class TestReaderBuilder(ReaderBuilder):
    def __init__(self, name, size, offset):
        self._schema = schema.Struct(
            ('label', schema.Scalar()),
        )
        self._name = name
        self._size = size
        self._offset = offset
        self._src_ds = None

    def schema(self):
        return self._schema

    def setup(self, ws):
        self._src_ds = make_source_dataset(ws, offset=self._offset, size=self._size,
                                           name=self._name)
        return {}

    def new_reader(self, **kwargs):
        return self._src_ds


class TestCompositeReader(TestCase):
    @unittest.skipIf(os.environ.get('JENKINS_URL'), 'Flaky test on Jenkins')
    def test_composite_reader(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        num_srcs = 3
        names = ["src_{}".format(i) for i in range(num_srcs)]
        size = 100
        offsets = [i * size for i in range(num_srcs)]
        src_dses = [make_source_dataset(ws, offset=offset, size=size, name=name)
                    for (name, offset) in zip(names, offsets)]

        data = [ws.fetch_blob(str(src.field_blobs[0])) for src in src_dses]
        # Sanity check we didn't overwrite anything
        for d, offset in zip(data, offsets):
            npt.assert_array_equal(d, range(offset, offset + size))

        # Make an identically-sized empty destination dataset
        dst_ds_schema = schema.Struct(
            *[
                (name, src_ds.content().clone_schema())
                for name, src_ds in zip(names, src_dses)
            ]
        )
        dst_ds = make_destination_dataset(ws, dst_ds_schema)

        with TaskGroup() as tg:
            reader = CompositeReader(names,
                                     [src_ds.reader() for src_ds in src_dses])
            pipe(reader, dst_ds.writer(), num_runtime_threads=3)
        session.run(tg)

        for i in range(num_srcs):
            written_data = sorted(
                ws.fetch_blob(str(dst_ds.content()[names[i]].label())))
            npt.assert_array_equal(data[i], written_data, "i: {}".format(i))

    @unittest.skipIf(os.environ.get('JENKINS_URL'), 'Flaky test on Jenkins')
    def test_composite_reader_builder(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        num_srcs = 3
        names = ["src_{}".format(i) for i in range(num_srcs)]
        size = 100
        offsets = [i * size for i in range(num_srcs)]
        src_ds_builders = [
            TestReaderBuilder(offset=offset, size=size, name=name)
            for (name, offset) in zip(names, offsets)
        ]

        # Make an identically-sized empty destination dataset
        dst_ds_schema = schema.Struct(
            *[
                (name, src_ds_builder.schema())
                for name, src_ds_builder in zip(names, src_ds_builders)
            ]
        )
        dst_ds = make_destination_dataset(ws, dst_ds_schema)

        with TaskGroup() as tg:
            reader_builder = CompositeReaderBuilder(
                names, src_ds_builders)
            reader_builder.setup(ws=ws)
            pipe(reader_builder.new_reader(), dst_ds.writer(),
                 num_runtime_threads=3)
        session.run(tg)

        for name, offset in zip(names, offsets):
            written_data = sorted(
                ws.fetch_blob(str(dst_ds.content()[name].label())))
            npt.assert_array_equal(range(offset, offset + size), written_data,
                                   "name: {}".format(name))


class TestReaderWithLimit(TestCase):
    def test_runtime_threads(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        src_ds = make_source_dataset(ws)
        totals = [None] * 3

        def proc(rec):
            # executed once
            with ops.task_init():
                counter1 = ops.CreateCounter([], ['global_counter'])
                counter2 = ops.CreateCounter([], ['global_counter2'])
                counter3 = ops.CreateCounter([], ['global_counter3'])
            # executed once per thread
            with ops.task_instance_init():
                task_counter = ops.CreateCounter([], ['task_counter'])
            # executed on each iteration
            ops.CountUp(counter1)
            ops.CountUp(task_counter)
            # executed once per thread
            with ops.task_instance_exit():
                with ops.loop(ops.RetrieveCount(task_counter)):
                    ops.CountUp(counter2)
                ops.CountUp(counter3)
            # executed once
            with ops.task_exit():
                totals[0] = final_output(ops.RetrieveCount(counter1))
                totals[1] = final_output(ops.RetrieveCount(counter2))
                totals[2] = final_output(ops.RetrieveCount(counter3))
            return rec

        # Read full data set from original reader
        with TaskGroup() as tg:
            pipe(src_ds.reader(), num_runtime_threads=8, processor=proc)
        session.run(tg)
        self.assertEqual(totals[0].fetch(), 100)
        self.assertEqual(totals[1].fetch(), 100)
        self.assertEqual(totals[2].fetch(), 8)

        # Read with a count-limited reader
        with TaskGroup() as tg:
            q1 = pipe(src_ds.reader(), num_runtime_threads=2)
            q2 = pipe(
                ReaderWithLimit(q1.reader(), num_iter=25),
                num_runtime_threads=3)
            pipe(q2, processor=proc, num_runtime_threads=6)
        session.run(tg)
        self.assertEqual(totals[0].fetch(), 25)
        self.assertEqual(totals[1].fetch(), 25)
        self.assertEqual(totals[2].fetch(), 6)

    def _test_limit_reader_init_shared(self, size):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)

        # Make source dataset
        src_ds = make_source_dataset(ws, size=size)

        # Make an identically-sized empty destination Dataset
        dst_ds = make_destination_dataset(ws, src_ds.content().clone_schema())

        return ws, session, src_ds, dst_ds

    def _test_limit_reader_shared(self, reader_class, size, expected_read_len,
                                  expected_read_len_threshold,
                                  expected_finish, num_threads, read_delay,
                                  **limiter_args):
        ws, session, src_ds, dst_ds = \
            self._test_limit_reader_init_shared(size)

        # Read through the limiting reader class under test.
        # WorkspaceType.GLOBAL is required because we are fetching
        # reader.data_finished() after the TaskGroup finishes.
        with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
            if read_delay > 0:
                reader = reader_class(ReaderWithDelay(src_ds.reader(),
                                                      read_delay),
                                      **limiter_args)
            else:
                reader = reader_class(src_ds.reader(), **limiter_args)
            pipe(reader, dst_ds.writer(), num_runtime_threads=num_threads)
        session.run(tg)
        read_len = len(sorted(ws.blobs[str(dst_ds.content().label())].fetch()))

        # Do a fuzzy match (expected_read_len +/- expected_read_len_threshold)
        # to eliminate flakiness for time-limited tests
        self.assertGreaterEqual(
            read_len,
            expected_read_len - expected_read_len_threshold)
        self.assertLessEqual(
            read_len,
            expected_read_len + expected_read_len_threshold)
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(read_len))
        )
        self.assertEqual(ws.blobs[str(reader.data_finished())].fetch(),
                         expected_finish)

    def test_count_limit_reader_without_limit(self):
        # No iter count specified, should read all records.
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=100,
                                       expected_read_len_threshold=0,
                                       expected_finish=True,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=None)

    def test_count_limit_reader_with_zero_limit(self):
        # Zero iter count specified, should read 0 records.
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=0,
                                       expected_read_len_threshold=0,
                                       expected_finish=False,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=0)

    def test_count_limit_reader_with_low_limit(self):
        # Read with limit smaller than size of dataset
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=10,
                                       expected_read_len_threshold=0,
                                       expected_finish=False,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=10)

    def test_count_limit_reader_with_high_limit(self):
        # Read with limit larger than size of dataset
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=100,
                                       expected_read_len_threshold=0,
                                       expected_finish=True,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=110)

    def test_time_limit_reader_without_limit(self):
        # No duration specified, should read all records.
        self._test_limit_reader_shared(ReaderWithTimeLimit,
                                       size=100,
                                       expected_read_len=100,
                                       expected_read_len_threshold=0,
                                       expected_finish=True,
                                       num_threads=8,
                                       read_delay=0.1,
                                       duration=0)

    def test_time_limit_reader_with_short_limit(self):
        # Read with insufficient time limit
        size = 50
        num_threads = 4
        sleep_duration = 0.25
        duration = 1
        expected_read_len = int(round(num_threads * duration / sleep_duration))
        # Because the time limit check happens before the delay + read op,
        # subtract a little bit of time to ensure we don't get in an extra read
        duration = duration - 0.25 * sleep_duration

        # NOTE: `expected_read_len_threshold` was added because this test case
        # has significant execution variation under stress. Under stress, we may
        # read strictly less than the expected # of samples; anywhere from
        # [0, N] where N = expected_read_len.
        # Hence we set expected_read_len to N/2, plus or minus N/2.
        self._test_limit_reader_shared(ReaderWithTimeLimit,
                                       size=size,
                                       expected_read_len=expected_read_len / 2,
                                       expected_read_len_threshold=expected_read_len / 2,
                                       expected_finish=False,
                                       num_threads=num_threads,
                                       read_delay=sleep_duration,
                                       duration=duration)

    def test_time_limit_reader_with_long_limit(self):
        # Read with ample time limit
        # NOTE: we don't use `expected_read_len_threshold` because the duration,
        # read_delay, and # threads should be more than sufficient
        self._test_limit_reader_shared(ReaderWithTimeLimit,
                                       size=50,
                                       expected_read_len=50,
                                       expected_read_len_threshold=0,
                                       expected_finish=True,
                                       num_threads=4,
                                       read_delay=0.2,
                                       duration=10)


class TestDBFileReader(TestCase):
    def setUp(self):
        self.temp_paths = []

    def tearDown(self):
        # In case any test method fails, clean up temp paths.
        for path in self.temp_paths:
            self._delete_path(path)

    @staticmethod
    def _delete_path(path):
        if os.path.isfile(path):
            os.remove(path)  # Remove file.
        elif os.path.isdir(path):
            shutil.rmtree(path)  # Remove dir recursively.

    def _make_temp_path(self):
        # Make a temp path as db_path.
        with tempfile.NamedTemporaryFile() as f:
            temp_path = f.name
        self.temp_paths.append(temp_path)
        return temp_path

    @staticmethod
    def _build_source_reader(ws, size):
        src_ds = make_source_dataset(ws, size)
        return src_ds.reader()

    @staticmethod
    def _read_all_data(ws, reader, session):
        dst_ds = make_destination_dataset(ws, reader.schema().clone_schema())

        with TaskGroup() as tg:
            pipe(reader, dst_ds.writer(), num_runtime_threads=8)
        session.run(tg)

        return ws.blobs[str(dst_ds.content().label())].fetch()

@unittest.skipIf("LevelDB" not in core.C.registered_dbs(), "Need LevelDB")
|
|
def test_cached_reader(self):
|
|
ws = workspace.C.Workspace()
|
|
session = LocalSession(ws)
|
|
db_path = self._make_temp_path()
|
|
|
|
# Read data for the first time.
|
|
cached_reader1 = CachedReader(
|
|
self._build_source_reader(ws, 100), db_path, loop_over=False,
|
|
)
|
|
build_cache_step = cached_reader1.build_cache_step()
|
|
session.run(build_cache_step)
|
|
|
|
data = self._read_all_data(ws, cached_reader1, session)
|
|
self.assertEqual(sorted(data), list(range(100)))
|
|
|
|
# Read data from cache.
|
|
cached_reader2 = CachedReader(
|
|
self._build_source_reader(ws, 200), db_path,
|
|
)
|
|
build_cache_step = cached_reader2.build_cache_step()
|
|
session.run(build_cache_step)
|
|
|
|
data = self._read_all_data(ws, cached_reader2, session)
|
|
self.assertEqual(sorted(data), list(range(100)))
|
|
|
|
self._delete_path(db_path)
|
|
|
|
# We removed cache so we expect to receive data from original reader.
|
|
cached_reader3 = CachedReader(
|
|
self._build_source_reader(ws, 300), db_path,
|
|
)
|
|
build_cache_step = cached_reader3.build_cache_step()
|
|
session.run(build_cache_step)
|
|
|
|
data = self._read_all_data(ws, cached_reader3, session)
|
|
self.assertEqual(sorted(data), list(range(300)))
|
|
|
|
self._delete_path(db_path)
|
|
|
|
@unittest.skipIf("LevelDB" not in core.C.registered_dbs(), "Need LevelDB")
|
|
def test_db_file_reader(self):
|
|
ws = workspace.C.Workspace()
|
|
session = LocalSession(ws)
|
|
db_path = self._make_temp_path()
|
|
|
|
# Build a cache DB file.
|
|
cached_reader = CachedReader(
|
|
self._build_source_reader(ws, 100),
|
|
db_path=db_path,
|
|
db_type='LevelDB',
|
|
)
|
|
build_cache_step = cached_reader.build_cache_step()
|
|
session.run(build_cache_step)
|
|
|
|
# Read data from cache DB file.
|
|
db_file_reader = DBFileReader(
|
|
db_path=db_path,
|
|
db_type='LevelDB',
|
|
)
|
|
data = self._read_all_data(ws, db_file_reader, session)
|
|
self.assertEqual(sorted(data), list(range(100)))
|
|
|
|
self._delete_path(db_path)
|