Projekt_AI-Automatyczny_saper/venv/Lib/site-packages/caffe2/python/dataio_test.py
2021-06-01 17:38:31 +02:00

446 lines
17 KiB
Python

from caffe2.python.dataio import (
CompositeReader,
CompositeReaderBuilder,
ReaderBuilder,
ReaderWithDelay,
ReaderWithLimit,
ReaderWithTimeLimit,
)
from caffe2.python.dataset import Dataset
from caffe2.python.db_file_reader import DBFileReader
from caffe2.python.pipeline import pipe
from caffe2.python.schema import Struct, NewRecord, FeedRecord
from caffe2.python.session import LocalSession
from caffe2.python.task import TaskGroup, final_output, WorkspaceType
from caffe2.python.test_util import TestCase
from caffe2.python.cached_reader import CachedReader
from caffe2.python import core, workspace, schema
from caffe2.python.net_builder import ops
import numpy as np
import numpy.testing as npt
import os
import shutil
import unittest
import tempfile
def make_source_dataset(ws, size=100, offset=0, name=None):
    """Create a single-field ('label') source Dataset and populate it in `ws`.

    The 'label' column holds ``np.array(range(offset, offset + size))``.

    Args:
        ws: workspace that receives the fed values and runs the init net.
        size: number of records to generate.
        offset: first label value.
        name: dataset name, also used as the blob name scope; a falsy value
            falls back to "src".

    Returns:
        The populated Dataset.
    """
    if not name:
        name = "src"
    init_net = core.Net("{}_init".format(name))
    with core.NameScope(name):
        labels = np.array(range(offset, offset + size))
        values = Struct(('label', labels))
        blobs = NewRecord(init_net, values)
        dataset = Dataset(blobs, name=name)
        FeedRecord(blobs, values, ws)
    ws.run(init_net)
    return dataset
def make_destination_dataset(ws, schema, name=None):
    """Create a Dataset with `schema`, initialized via ``init_empty``.

    Args:
        ws: workspace that runs the init net.
        schema: record schema for the new dataset. Note that this parameter
            shadows the module-level ``schema`` import inside this function.
        name: dataset name, also used as the blob name scope; a falsy value
            falls back to 'dst'.

    Returns:
        The initialized Dataset.
    """
    if not name:
        name = 'dst'
    init_net = core.Net('{}_init'.format(name))
    with core.NameScope(name):
        dataset = Dataset(schema, name=name)
        dataset.init_empty(init_net)
    ws.run(init_net)
    return dataset
class TestReaderBuilder(ReaderBuilder):
    """ReaderBuilder test double backed by a single-field ('label') dataset.

    ``setup(ws)`` materializes the source dataset in the given workspace.
    ``new_reader()`` then returns that Dataset object itself, not a Reader.
    NOTE(review): this relies on the consumer accepting a Dataset; confirm
    against CompositeReaderBuilder if that ever changes.
    """

    # The "Test" prefix makes pytest try to collect this helper as a test
    # class (and warn, since it defines __init__). Opt out explicitly.
    __test__ = False

    def __init__(self, name, size, offset):
        self._schema = schema.Struct(
            ('label', schema.Scalar()),
        )
        self._name = name
        self._size = size
        self._offset = offset
        self._src_ds = None  # Populated by setup().

    def schema(self):
        """Return the record schema: a single scalar 'label' field."""
        return self._schema

    def setup(self, ws):
        """Create and populate the source dataset in `ws`.

        Returns an empty dict.
        """
        self._src_ds = make_source_dataset(
            ws, offset=self._offset, size=self._size, name=self._name)
        return {}

    def new_reader(self, **kwargs):
        """Return the Dataset created by ``setup()``.

        Raises:
            RuntimeError: if ``setup()`` has not been called yet. Previously,
                ``None`` was returned and the failure surfaced later.
        """
        if self._src_ds is None:
            raise RuntimeError('setup() must be called before new_reader()')
        return self._src_ds
class TestCompositeReader(TestCase):
    """Tests for CompositeReader and CompositeReaderBuilder."""

    @staticmethod
    def _names_and_offsets(num_srcs, size):
        """Return per-source names and non-overlapping label offsets."""
        names = ["src_{}".format(i) for i in range(num_srcs)]
        offsets = [i * size for i in range(num_srcs)]
        return names, offsets

    @unittest.skipIf(os.environ.get('JENKINS_URL'), 'Flaky test on Jenkins')
    def test_composite_reader(self):
        """Pipe a CompositeReader over several sources into one dataset."""
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        num_srcs = 3
        size = 100
        names, offsets = self._names_and_offsets(num_srcs, size)
        src_dses = [make_source_dataset(ws, offset=offset, size=size, name=name)
                    for (name, offset) in zip(names, offsets)]

        data = [ws.fetch_blob(str(src.field_blobs[0])) for src in src_dses]
        # Sanity check we didn't overwrite anything
        for d, offset in zip(data, offsets):
            npt.assert_array_equal(d, range(offset, offset + size))

        # Make an empty destination dataset with one sub-struct per source
        dst_ds_schema = schema.Struct(
            *[
                (name, src_ds.content().clone_schema())
                for name, src_ds in zip(names, src_dses)
            ]
        )
        dst_ds = make_destination_dataset(ws, dst_ds_schema)

        with TaskGroup() as tg:
            reader = CompositeReader(names,
                                     [src_ds.reader() for src_ds in src_dses])
            pipe(reader, dst_ds.writer(), num_runtime_threads=3)
        session.run(tg)

        # Several threads write concurrently, so compare order-insensitively.
        for i, (name, expected) in enumerate(zip(names, data)):
            written_data = sorted(
                ws.fetch_blob(str(dst_ds.content()[name].label())))
            npt.assert_array_equal(expected, written_data, "i: {}".format(i))

    @unittest.skipIf(os.environ.get('JENKINS_URL'), 'Flaky test on Jenkins')
    def test_composite_reader_builder(self):
        """Same as above, but build the reader via CompositeReaderBuilder."""
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        num_srcs = 3
        size = 100
        names, offsets = self._names_and_offsets(num_srcs, size)
        src_ds_builders = [
            TestReaderBuilder(offset=offset, size=size, name=name)
            for (name, offset) in zip(names, offsets)
        ]

        # Make an empty destination dataset with one sub-struct per source
        dst_ds_schema = schema.Struct(
            *[
                (name, src_ds_builder.schema())
                for name, src_ds_builder in zip(names, src_ds_builders)
            ]
        )
        dst_ds = make_destination_dataset(ws, dst_ds_schema)

        with TaskGroup() as tg:
            reader_builder = CompositeReaderBuilder(
                names, src_ds_builders)
            reader_builder.setup(ws=ws)
            pipe(reader_builder.new_reader(), dst_ds.writer(),
                 num_runtime_threads=3)
        session.run(tg)

        # Several threads write concurrently, so compare order-insensitively.
        for name, offset in zip(names, offsets):
            written_data = sorted(
                ws.fetch_blob(str(dst_ds.content()[name].label())))
            npt.assert_array_equal(range(offset, offset + size), written_data,
                                   "name: {}".format(name))
class TestReaderWithLimit(TestCase):
    """Tests for ReaderWithLimit / ReaderWithTimeLimit and pipe() task hooks."""

    def test_runtime_threads(self):
        """Check how often task_init/instance/exit hooks run inside pipe()."""
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        src_ds = make_source_dataset(ws)
        # Slots for the final_output() handles that proc() assigns in its
        # task_exit section.
        totals = [None] * 3

        def proc(rec):
            # executed once
            with ops.task_init():
                counter1 = ops.CreateCounter([], ['global_counter'])
                counter2 = ops.CreateCounter([], ['global_counter2'])
                counter3 = ops.CreateCounter([], ['global_counter3'])
            # executed once per thread
            with ops.task_instance_init():
                task_counter = ops.CreateCounter([], ['task_counter'])
            # executed on each iteration
            ops.CountUp(counter1)
            ops.CountUp(task_counter)
            # executed once per thread: counter2 accumulates every thread's
            # per-thread iteration count; counter3 counts threads.
            with ops.task_instance_exit():
                with ops.loop(ops.RetrieveCount(task_counter)):
                    ops.CountUp(counter2)
                ops.CountUp(counter3)
            # executed once
            with ops.task_exit():
                totals[0] = final_output(ops.RetrieveCount(counter1))
                totals[1] = final_output(ops.RetrieveCount(counter2))
                totals[2] = final_output(ops.RetrieveCount(counter3))
            return rec

        # Read full data set from original reader
        with TaskGroup() as tg:
            pipe(src_ds.reader(), num_runtime_threads=8, processor=proc)
        session.run(tg)
        self.assertEqual(totals[0].fetch(), 100)
        self.assertEqual(totals[1].fetch(), 100)
        self.assertEqual(totals[2].fetch(), 8)

        # Read with a count-limited reader
        with TaskGroup() as tg:
            q1 = pipe(src_ds.reader(), num_runtime_threads=2)
            q2 = pipe(
                ReaderWithLimit(q1.reader(), num_iter=25),
                num_runtime_threads=3)
            pipe(q2, processor=proc, num_runtime_threads=6)
        session.run(tg)
        self.assertEqual(totals[0].fetch(), 25)
        self.assertEqual(totals[1].fetch(), 25)
        self.assertEqual(totals[2].fetch(), 6)

    def _test_limit_reader_init_shared(self, size):
        """Return a fresh (ws, session, source dataset, empty destination)."""
        ws = workspace.C.Workspace()
        session = LocalSession(ws)

        # Make source dataset
        src_ds = make_source_dataset(ws, size=size)

        # Make an empty destination Dataset with the same schema
        dst_ds = make_destination_dataset(ws, src_ds.content().clone_schema())

        return ws, session, src_ds, dst_ds

    def _test_limit_reader_shared(self, reader_class, size, expected_read_len,
                                  expected_read_len_threshold,
                                  expected_finish, num_threads, read_delay,
                                  **limiter_args):
        """Pipe a limited reader into a destination and check what was read.

        Args:
            reader_class: limiter wrapper (e.g. ReaderWithLimit) constructed
                as ``reader_class(inner_reader, **limiter_args)``.
            size: number of records in the source dataset.
            expected_read_len: expected number of records written.
            expected_read_len_threshold: allowed +/- tolerance on the count.
            expected_finish: expected value of ``reader.data_finished()``.
            num_threads: pipe() runtime thread count.
            read_delay: if > 0, wrap the source in ReaderWithDelay(read_delay).
            **limiter_args: forwarded to ``reader_class``.
        """
        ws, session, src_ds, dst_ds = \
            self._test_limit_reader_init_shared(size)

        # Read through the limiter under test (optionally behind a delay).
        # WorkspaceType.GLOBAL is required because we are fetching
        # reader.data_finished() after the TaskGroup finishes.
        with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
            if read_delay > 0:
                reader = reader_class(ReaderWithDelay(src_ds.reader(),
                                                      read_delay),
                                      **limiter_args)
            else:
                reader = reader_class(src_ds.reader(), **limiter_args)
            pipe(reader, dst_ds.writer(), num_runtime_threads=num_threads)
        session.run(tg)

        # Fetch the written labels once; sorting is only needed for the
        # content comparison, not to get the length.
        written = ws.blobs[str(dst_ds.content().label())].fetch()
        read_len = len(written)

        # Do a fuzzy match (expected_read_len +/- expected_read_len_threshold)
        # to eliminate flakiness for time-limited tests
        self.assertGreaterEqual(
            read_len,
            expected_read_len - expected_read_len_threshold)
        self.assertLessEqual(
            read_len,
            expected_read_len + expected_read_len_threshold)
        # Whatever was read must be exactly the first read_len source labels.
        self.assertEqual(sorted(written), list(range(read_len)))
        self.assertEqual(ws.blobs[str(reader.data_finished())].fetch(),
                         expected_finish)

    def test_count_limit_reader_without_limit(self):
        # No iter count specified, should read all records.
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=100,
                                       expected_read_len_threshold=0,
                                       expected_finish=True,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=None)

    def test_count_limit_reader_with_zero_limit(self):
        # Zero iter count specified, should read 0 records.
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=0,
                                       expected_read_len_threshold=0,
                                       expected_finish=False,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=0)

    def test_count_limit_reader_with_low_limit(self):
        # Read with limit smaller than size of dataset
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=10,
                                       expected_read_len_threshold=0,
                                       expected_finish=False,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=10)

    def test_count_limit_reader_with_high_limit(self):
        # Read with limit larger than size of dataset
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=100,
                                       expected_read_len_threshold=0,
                                       expected_finish=True,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=110)

    def test_time_limit_reader_without_limit(self):
        # No duration specified, should read all records.
        self._test_limit_reader_shared(ReaderWithTimeLimit,
                                       size=100,
                                       expected_read_len=100,
                                       expected_read_len_threshold=0,
                                       expected_finish=True,
                                       num_threads=8,
                                       read_delay=0.1,
                                       duration=0)

    def test_time_limit_reader_with_short_limit(self):
        # Read with insufficient time limit
        size = 50
        num_threads = 4
        sleep_duration = 0.25
        duration = 1
        expected_read_len = int(round(num_threads * duration / sleep_duration))
        # Because the time limit check happens before the delay + read op,
        # subtract a little bit of time to ensure we don't get in an extra read
        duration = duration - 0.25 * sleep_duration

        # NOTE: `expected_read_len_threshold` was added because this test case
        # has significant execution variation under stress. Under stress, we may
        # read strictly less than the expected # of samples; anywhere from
        # [0,N] where N = expected_read_len.
        # Hence we set expected_read_len to N/2, plus or minus N/2.
        self._test_limit_reader_shared(ReaderWithTimeLimit,
                                       size=size,
                                       expected_read_len=expected_read_len / 2,
                                       expected_read_len_threshold=expected_read_len / 2,
                                       expected_finish=False,
                                       num_threads=num_threads,
                                       read_delay=sleep_duration,
                                       duration=duration)

    def test_time_limit_reader_with_long_limit(self):
        # Read with ample time limit
        # NOTE: we don't use `expected_read_len_threshold` because the duration,
        # read_delay, and # threads should be more than sufficient
        self._test_limit_reader_shared(ReaderWithTimeLimit,
                                       size=50,
                                       expected_read_len=50,
                                       expected_read_len_threshold=0,
                                       expected_finish=True,
                                       num_threads=4,
                                       read_delay=0.2,
                                       duration=10)
class TestDBFileReader(TestCase):
    """Tests for CachedReader and DBFileReader backed by on-disk DB files."""

    def setUp(self):
        # Keep the base-class setup; overriding without super() skips it.
        super(TestDBFileReader, self).setUp()
        self.temp_paths = []

    def tearDown(self):
        # In case any test method fails, clean up temp paths.
        try:
            for path in self.temp_paths:
                self._delete_path(path)
        finally:
            super(TestDBFileReader, self).tearDown()

    @staticmethod
    def _delete_path(path):
        """Remove `path` if it is a file or directory; otherwise do nothing."""
        if os.path.isfile(path):
            os.remove(path)  # Remove file.
        elif os.path.isdir(path):
            shutil.rmtree(path)  # Remove dir recursively.

    def _make_temp_path(self):
        """Return a fresh, not-yet-existing path to use as a db_path.

        The path lives inside a private directory created by
        ``tempfile.mkdtemp``. That directory is registered for removal in
        ``tearDown``. The older approach of reusing the name of a closed (and
        therefore deleted) NamedTemporaryFile left a window in which another
        process could claim that name.
        """
        temp_dir = tempfile.mkdtemp()
        self.temp_paths.append(temp_dir)
        return os.path.join(temp_dir, 'db')

    @staticmethod
    def _build_source_reader(ws, size):
        """Return a reader over a fresh source dataset of `size` labels."""
        src_ds = make_source_dataset(ws, size)
        return src_ds.reader()

    @staticmethod
    def _read_all_data(ws, reader, session):
        """Pipe everything from `reader` into a new dataset; return labels."""
        dst_ds = make_destination_dataset(ws, reader.schema().clone_schema())

        with TaskGroup() as tg:
            pipe(reader, dst_ds.writer(), num_runtime_threads=8)
        session.run(tg)

        return ws.blobs[str(dst_ds.content().label())].fetch()

    @unittest.skipIf("LevelDB" not in core.C.registered_dbs(), "Need LevelDB")
    def test_cached_reader(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        db_path = self._make_temp_path()

        # Read data for the first time.
        cached_reader1 = CachedReader(
            self._build_source_reader(ws, 100), db_path, loop_over=False,
        )
        build_cache_step = cached_reader1.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader1, session)
        self.assertEqual(sorted(data), list(range(100)))

        # Read data from cache: the 200-record source must be ignored.
        cached_reader2 = CachedReader(
            self._build_source_reader(ws, 200), db_path,
        )
        build_cache_step = cached_reader2.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader2, session)
        self.assertEqual(sorted(data), list(range(100)))

        self._delete_path(db_path)

        # We removed cache so we expect to receive data from original reader.
        cached_reader3 = CachedReader(
            self._build_source_reader(ws, 300), db_path,
        )
        build_cache_step = cached_reader3.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader3, session)
        self.assertEqual(sorted(data), list(range(300)))

        self._delete_path(db_path)

    @unittest.skipIf("LevelDB" not in core.C.registered_dbs(), "Need LevelDB")
    def test_db_file_reader(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        db_path = self._make_temp_path()

        # Build a cache DB file.
        cached_reader = CachedReader(
            self._build_source_reader(ws, 100),
            db_path=db_path,
            db_type='LevelDB',
        )
        build_cache_step = cached_reader.build_cache_step()
        session.run(build_cache_step)

        # Read data from cache DB file.
        db_file_reader = DBFileReader(
            db_path=db_path,
            db_type='LevelDB',
        )
        data = self._read_all_data(ws, db_file_reader, session)
        self.assertEqual(sorted(data), list(range(100)))

        self._delete_path(db_path)