137 lines
4.4 KiB
Python
137 lines
4.4 KiB
Python
## @package queue_util
|
|
# Module caffe2.python.queue_util
|
|
|
|
|
|
|
|
|
|
|
|
from caffe2.python import core, dataio
|
|
from caffe2.python.task import TaskGroup
|
|
|
|
import logging
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class _QueueReader(dataio.Reader):
|
|
def __init__(self, wrapper, num_dequeue_records=1):
|
|
assert wrapper.schema is not None, (
|
|
'Queue needs a schema in order to be read from.')
|
|
dataio.Reader.__init__(self, wrapper.schema())
|
|
self._wrapper = wrapper
|
|
self._num_dequeue_records = num_dequeue_records
|
|
|
|
def setup_ex(self, init_net, exit_net):
|
|
exit_net.CloseBlobsQueue([self._wrapper.queue()], 0)
|
|
|
|
def read_ex(self, local_init_net, local_finish_net):
|
|
self._wrapper._new_reader(local_init_net)
|
|
dequeue_net = core.Net('dequeue')
|
|
fields, status_blob = dequeue(
|
|
dequeue_net,
|
|
self._wrapper.queue(),
|
|
len(self.schema().field_names()),
|
|
field_names=self.schema().field_names(),
|
|
num_records=self._num_dequeue_records)
|
|
return [dequeue_net], status_blob, fields
|
|
|
|
def read(self, net):
|
|
net, _, fields = self.read_ex(net, None)
|
|
return net, fields
|
|
|
|
|
|
class _QueueWriter(dataio.Writer):
|
|
def __init__(self, wrapper):
|
|
self._wrapper = wrapper
|
|
|
|
def setup_ex(self, init_net, exit_net):
|
|
exit_net.CloseBlobsQueue([self._wrapper.queue()], 0)
|
|
|
|
def write_ex(self, fields, local_init_net, local_finish_net, status):
|
|
self._wrapper._new_writer(self.schema(), local_init_net)
|
|
enqueue_net = core.Net('enqueue')
|
|
enqueue(enqueue_net, self._wrapper.queue(), fields, status)
|
|
return [enqueue_net]
|
|
|
|
|
|
class QueueWrapper(dataio.Pipe):
|
|
def __init__(self, handler, schema=None, num_dequeue_records=1):
|
|
dataio.Pipe.__init__(self, schema, TaskGroup.LOCAL_SETUP)
|
|
self._queue = handler
|
|
self._num_dequeue_records = num_dequeue_records
|
|
|
|
def reader(self):
|
|
return _QueueReader(
|
|
self, num_dequeue_records=self._num_dequeue_records)
|
|
|
|
def writer(self):
|
|
return _QueueWriter(self)
|
|
|
|
def queue(self):
|
|
return self._queue
|
|
|
|
|
|
class Queue(QueueWrapper):
|
|
def __init__(self, capacity, schema=None, name='queue',
|
|
num_dequeue_records=1):
|
|
# find a unique blob name for the queue
|
|
net = core.Net(name)
|
|
queue_blob = net.AddExternalInput(net.NextName('handler'))
|
|
QueueWrapper.__init__(
|
|
self, queue_blob, schema, num_dequeue_records=num_dequeue_records)
|
|
self.capacity = capacity
|
|
self._setup_done = False
|
|
|
|
def setup(self, global_init_net):
|
|
assert self._schema, 'This queue does not have a schema.'
|
|
self._setup_done = True
|
|
global_init_net.CreateBlobsQueue(
|
|
[],
|
|
[self._queue],
|
|
capacity=self.capacity,
|
|
num_blobs=len(self._schema.field_names()),
|
|
field_names=self._schema.field_names())
|
|
|
|
|
|
def enqueue(net, queue, data_blobs, status=None):
|
|
if status is None:
|
|
status = net.NextName('status')
|
|
# Enqueueing moved the data into the queue;
|
|
# duplication will result in data corruption
|
|
queue_blobs = []
|
|
for blob in data_blobs:
|
|
if blob not in queue_blobs:
|
|
queue_blobs.append(blob)
|
|
else:
|
|
logger.warning("Need to copy blob {} to enqueue".format(blob))
|
|
queue_blobs.append(net.Copy(blob))
|
|
results = net.SafeEnqueueBlobs([queue] + queue_blobs, queue_blobs + [status])
|
|
return results[-1]
|
|
|
|
|
|
def dequeue(net, queue, num_blobs, status=None, field_names=None,
|
|
num_records=1):
|
|
if field_names is not None:
|
|
assert len(field_names) == num_blobs
|
|
data_names = [net.NextName(name) for name in field_names]
|
|
else:
|
|
data_names = [net.NextName('data', i) for i in range(num_blobs)]
|
|
if status is None:
|
|
status = net.NextName('status')
|
|
results = net.SafeDequeueBlobs(
|
|
queue, data_names + [status], num_records=num_records)
|
|
results = list(results)
|
|
status_blob = results.pop(-1)
|
|
return results, status_blob
|
|
|
|
|
|
def close_queue(step, *queues):
|
|
close_net = core.Net("close_queue_net")
|
|
for queue in queues:
|
|
close_net.CloseBlobsQueue([queue], 0)
|
|
close_step = core.execution_step("%s_step" % str(close_net), close_net)
|
|
return core.execution_step(
|
|
"%s_wraper_step" % str(close_net),
|
|
[step, close_step])
|