134 lines
4.3 KiB
Python
134 lines
4.3 KiB
Python
## @package cached_reader
|
|
# Module caffe2.python.cached_reader
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
from caffe2.python import core
|
|
from caffe2.python.db_file_reader import DBFileReader
|
|
from caffe2.python.pipeline import pipe
|
|
from caffe2.python.task import Cluster, TaskGroup
|
|
|
|
|
|
class CachedReader(DBFileReader):
    """Reader with persistent in-file cache.

    Example usage:

        cached_reader = CachedReader(
            reader,
            db_path='/tmp/cache.db',
            db_type='LevelDB',
        )
        build_cache_step = cached_reader.build_cache_step()
        with LocalSession() as session:
            session.run(build_cache_step)

    Every time a new CachedReader is created, it's expected that
    db_path exists before calling .setup_ex(...) and .read(...).

    If db_path doesn't exist, build_cache_step is expected to be called
    first to build a cache at db_path.

    build_cache_step will check the existence of the provided db_path and,
    in case it's missing, will initialize it by reading data from the
    original reader.
    All subsequent attempts to read will ignore the original reader
    (i.e. no additional data will be read from it).

    Args:
        original_reader: Reader.
            The original reader used to build the cache file.
            Must not be None.
        db_path: str.

    Optional Args:
        db_type: str. DB type of file. A db_type is registered by
            `REGISTER_CAFFE2_DB(<db_type>, <DB Class>)`.
            Default to 'LevelDB'.
        name: str or None. Name of CachedReader.
            Optional name to prepend to blobs that will store the data.
            Default to '<db_name>_<default_name_suffix>'.
        batch_size: int.
            How many examples are read for each time the read_net is run.
            Defaults to 100.
        loop_over: bool.
            If True given, will go through examples in random order endlessly.
            Defaults to False.
    """

    # Suffix used when composing the default reader name
    # (see the `name` argument in the class docstring).
    default_name_suffix = 'cached_reader'

    def __init__(
        self,
        original_reader,
        db_path,
        db_type='LevelDB',
        name=None,
        batch_size=100,
        loop_over=False,
    ):
        """Store the original reader and initialize the DB-file reader.

        Args:
            original_reader: Reader. Source of the data for the cache;
                must not be None.
            db_path: str. Path of the cache DB file.
            db_type: str. DB type of the cache file.
            name: str or None. Name of this reader.
            batch_size: int. Forwarded to the base class.
            loop_over: bool. Forwarded to the base class.

        Raises:
            AssertionError: if original_reader is None.
        """
        assert original_reader is not None, "original_reader can't be None"
        # Assigned before the base initializer runs: _init_reader_schema()
        # reads self.original_reader, and the base class presumably calls it
        # while initializing -- TODO confirm against DBFileReader.__init__.
        self.original_reader = original_reader

        super(CachedReader, self).__init__(
            db_path,
            db_type,
            name,
            batch_size,
            loop_over,
        )

    def _init_reader_schema(self, *args, **kwargs):
        """Prepare the reader schema.

        Since an original reader is given,
        use its schema as ground truth.
        Any positional or keyword arguments are accepted and ignored.

        Returns:
            schema: schema.Struct. Used in Reader.__init__(...).
        """
        return self.original_reader._schema

    def build_cache_step(self, overwrite=False):
        """Build a step for generating cache DB file.

        If self.db_path exists and not overwriting, build an empty step.
        Otherwise, build a step as follows.
        Pipe original reader to the _DatasetWriter,
        so that dataset field blobs are populated.
        Then save these blobs into a file.

        The existence check happens when this method is called (i.e. when
        the step is built), not when the returned step is run.

        Args:
            overwrite: bool. If true, ignore the existing file
                and build a new one overwriting the existing one anyway.

        Returns:
            build_cache_step: ExecutionStep.
                The step to be run for building a cache DB file.
        """
        if os.path.exists(self.db_path) and not overwrite:
            # cache already exists, no need to rebuild it
            return core.execution_step('build_step', [])

        # NOTE(review): with overwrite=True and an existing db_path, whether
        # the Save op replaces the existing DB depends on the db_type
        # backend -- confirm for the backends in use.
        init_net = core.Net('init')
        self._init_field_blobs_as_empty(init_net)
        with Cluster(), core.NameScope(self.name), TaskGroup() as copy_tg:
            pipe(self.original_reader, self.ds.writer(), num_threads=16)
            # NOTE(review): the copy step is extracted while the
            # Cluster/NameScope/TaskGroup contexts are still active --
            # confirm this matches the intended scoping.
            copy_step = copy_tg.to_task().get_step()
        save_net = core.Net('save')
        self._save_field_blobs_to_db_file(save_net)

        # Order matters: create empty field blobs, fill them from the
        # original reader, then persist them to the DB file.
        return core.execution_step('build_cache', [init_net, copy_step, save_net])

    def _save_field_blobs_to_db_file(self, net):
        """Save dataset field blobs to a DB file at db_path.

        Adds a Save operator to `net` that writes the blobs of self.ds to
        self.db_path (treated as an absolute path) using self.db_type,
        storing each blob under the corresponding dataset field name.

        Args:
            net: core.Net. The net the Save operator is added to.
        """
        net.Save(
            self.ds.get_blobs(),
            [],
            db=self.db_path,
            db_type=self.db_type,
            blob_name_overrides=self.ds.field_names(),
            absolute_path=True,
        )
|