773 lines
24 KiB
Python
773 lines
24 KiB
Python
|
## @package workspace
|
||
|
# Module caffe2.python.workspace
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
import collections
|
||
|
import contextlib
|
||
|
from google.protobuf.message import Message
|
||
|
from multiprocessing import Process
|
||
|
import os
|
||
|
from collections import defaultdict
|
||
|
import logging
|
||
|
import numpy as np
|
||
|
from past.builtins import basestring
|
||
|
import shutil
|
||
|
import socket
|
||
|
import tempfile
|
||
|
|
||
|
from caffe2.proto import caffe2_pb2
|
||
|
from caffe2.python import scope, utils
|
||
|
from caffe2.python.lazy import TriggerLazyImport
|
||
|
|
||
|
import caffe2.python._import_c_extension as C
|
||
|
|
||
|
# Module-level logger used for the warnings emitted in this file.
logger = logging.getLogger(__name__)

# Recorded python tracebacks of operators. CallWithExceptionIntercept looks
# entries up by net name first and by operator id second.
operator_tracebacks = defaultdict(dict)

# Public CamelCase aliases for functions of the C extension.
# -- workspace management
Workspaces = C.workspaces
CurrentWorkspace = C.current_workspace
SwitchWorkspace = C.switch_workspace
RootFolder = C.root_folder
GlobalInit = C.global_init
RegisteredOperators = C.registered_operators
# -- blob handling
Blobs = C.blobs
HasBlob = C.has_blob
CreateBlob = C.create_blob
ResetBlob = C.reset_blob
SerializeBlob = C.serialize_blob
DeserializeBlob = C.deserialize_blob
CreateOfflineTensor = C.create_offline_tensor
# -- benchmarking and statistics
BenchmarkNet = C.benchmark_net
BenchmarkNetOnce = C.benchmark_net_once
GetStats = C.get_stats

# Build capability values re-exported from the C extension.
is_asan = C.is_asan
has_gpu_support = C.has_gpu_support
has_cuda_support = C.has_cuda_support
has_hip_support = C.has_hip_support
|
||
|
# GPU related bindings. The three conditionals below run in this order, so a
# later branch overrides any name bound by an earlier one.
if has_cuda_support:
    GpuDeviceType = caffe2_pb2.CUDA
    NumCudaDevices = C.num_cuda_devices
    # NumGpuDevices duplicates NumCudaDevices; NumCudaDevices is to be
    # removed once it has been replaced everywhere in the code.
    NumGpuDevices = C.num_cuda_devices
    GetCUDAVersion = C.get_cuda_version
    GetCuDNNVersion = C.get_cudnn_version

    def GetGpuPeerAccessPattern():
        """Return C.get_cuda_peer_access_pattern() as a numpy array."""
        pattern = C.get_cuda_peer_access_pattern()
        return np.asarray(pattern)

    GetDeviceProperties = C.get_device_properties
    GetGPUMemoryInfo = C.get_gpu_memory_info
else:
    # Without CUDA support the CUDA specific queries all report 0.
    NumCudaDevices = lambda: 0 # noqa
    GetCUDAVersion = lambda: 0 # noqa
    GetCuDNNVersion = lambda: 0 # noqa

if has_hip_support:
    GpuDeviceType = caffe2_pb2.HIP
    NumGpuDevices = C.num_hip_devices
    GetHIPVersion = C.get_hip_version

    def GetGpuPeerAccessPattern():
        """Return C.get_hip_peer_access_pattern() as a numpy array."""
        pattern = C.get_hip_peer_access_pattern()
        return np.asarray(pattern)

    GetDeviceProperties = C.get_device_properties
    GetGPUMemoryInfo = C.get_gpu_memory_info

if not has_gpu_support:
    # CUDA stays the default GpuDeviceType because some tests (core and
    # scope tests, for instance) use GpuDeviceType even without gpu support.
    GpuDeviceType = caffe2_pb2.CUDA
    NumGpuDevices = lambda: 0 # noqa
    GetDeviceProperties = lambda x: None # noqa
    GetGpuPeerAccessPattern = lambda: np.array([]) # noqa
    GetGPUMemoryInfo = lambda: None # noqa
|
||
|
|
||
|
# NUMA and blob size queries, re-exported unchanged from the C extension.
GetBlobSizeBytes = C.get_blob_size_bytes
GetBlobNUMANode = C.get_blob_numa_node
GetNumNUMANodes = C.get_num_numa_nodes
IsNUMAEnabled = C.is_numa_enabled
|
||
|
|
||
|
|
||
|
def FillRandomNetworkInputs(net, input_dims, input_types):
    """Hand a net and its input description to C.fill_random_network_inputs.

    Inputs:
      net: an object with a Proto() function; the proto is serialized to a
          string before it is passed on.
      input_dims: passed through unchanged to the C extension.
      input_types: passed through unchanged to the C extension.
    """
    serialized_net = net.Proto().SerializeToString()
    C.fill_random_network_inputs(serialized_net, input_dims, input_types)
|
||
|
|
||
|
|
||
|
def _GetFreeFlaskPort():
    """Get a free flask port.

    Port 5000 is preferred: it is returned when nothing accepts a TCP
    connection on 127.0.0.1:5000. Otherwise the operating system is asked
    for an unused port.

    Returns:
      port: an integer port number.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # connect_ex() returns 0 when the connection succeeds, i.e. when a
        # server is already listening on the port and it is therefore taken.
        port_in_use = probe.connect_ex(('127.0.0.1', 5000)) == 0
    finally:
        # Always release the probe socket, whatever the outcome.
        probe.close()
    if not port_in_use:
        return 5000

    # Binding to port 0 lets the operating system pick an unused port.
    s = socket.socket()
    try:
        s.bind(('', 0))
        port = s.getsockname()[1]
    finally:
        s.close()
    # Race condition: between the interval we close the socket and actually
    # start a mint process, another process might have occupied the port. We
    # don't do much here as this is mostly for convenience in research
    # rather than 24x7 service.
    return port
|
||
|
|
||
|
def StartMint(root_folder=None, port=None):
    """Launch a mint instance in a separate process.

    TODO(Yangqing): this does not work well under ipython yet. According to
        https://github.com/ipython/ipython/issues/5862
    writing up some fix is a todo item.

    Inputs:
      root_folder (optional): folder handed to mint through '-r'. When
          omitted, C.root_folder() is used.
      port (optional): port handed to mint through '-p'. When omitted,
          _GetFreeFlaskPort() picks one.
    Returns:
      The started multiprocessing Process object.
    """
    from caffe2.python.mint import app
    # Get the root folder from the current workspace when none is given.
    root_folder = C.root_folder() if root_folder is None else root_folder
    port = _GetFreeFlaskPort() if port is None else port
    mint_argv = ['-p', str(port), '-r', root_folder]
    mint_process = Process(target=app.main, args=(mint_argv,))
    mint_process.start()
    print('Mint running at http://{}:{}'.format(socket.getfqdn(), port))
    return mint_process
|
||
|
|
||
|
|
||
|
def StringifyProto(obj):
    """Turn a protocol buffer object into its serialized string.

    Inputs:
      obj: a string (returned unchanged), a protocol buffer object, or a
          Pycaffe2 object that has a Proto() function.
    Outputs:
      string: the output protobuf string.
    Raises:
      ValueError: if obj is none of the accepted kinds.
    """
    if isinstance(obj, basestring):
        # Already a string: nothing to serialize.
        return obj
    if isinstance(obj, Message):
        # A protocol buffer serializes itself directly.
        return obj.SerializeToString()
    if hasattr(obj, 'Proto'):
        # Wrapper objects expose their protocol buffer through Proto().
        return obj.Proto().SerializeToString()
    raise ValueError("Unexpected argument to StringifyProto of type " +
                     type(obj).__name__)
|
||
|
|
||
|
|
||
|
def ResetWorkspace(root_folder=None):
    """Reset the workspace through C.reset_workspace.

    Inputs:
      root_folder (optional): root folder to reset the workspace with. It is
          created when it does not exist yet. When omitted, the current root
          folder setting (C.root_folder()) is kept.
    Returns:
      Whatever C.reset_workspace returns.
    """
    if root_folder is not None:
        if not os.path.exists(root_folder):
            os.makedirs(root_folder)
        return C.reset_workspace(root_folder)
    # Reset the workspace, but keep the current root folder setting.
    return C.reset_workspace(C.root_folder())
|
||
|
|
||
|
|
||
|
def CreateNet(net, overwrite=False, input_blobs=None):
    """Create a net through C.create_net, with traceback logging on failure.

    Inputs:
      net: the net to create; accepted forms are those understood by
          GetNetName and StringifyProto.
      overwrite: passed through to C.create_net.
      input_blobs (optional): blob names that are created with C.create_blob
          before the net itself is created.
    Returns:
      Whatever C.create_net returns.
    """
    TriggerLazyImport()
    for blob_name in ([] if input_blobs is None else input_blobs):
        C.create_blob(blob_name)
    failed_op_fetcher = C.Workspace.current._last_failed_op_net_position
    net_name = GetNetName(net)
    serialized_net = StringifyProto(net)
    return CallWithExceptionIntercept(
        C.create_net, failed_op_fetcher, net_name, serialized_net, overwrite)
|
||
|
|
||
|
|
||
|
def Predictor(init_net, predict_net):
    """Build a C.Predictor from an init net and a predict net.

    Both arguments are converted with StringifyProto before being passed on.
    """
    init_str = StringifyProto(init_net)
    predict_str = StringifyProto(predict_net)
    return C.Predictor(init_str, predict_str)
|
||
|
|
||
|
|
||
|
def GetOperatorCost(operator, blobs):
    """Return C.get_operator_cost for the stringified operator and blobs."""
    serialized_op = StringifyProto(operator)
    return C.get_operator_cost(serialized_op, blobs)
|
||
|
|
||
|
|
||
|
def RunOperatorOnce(operator):
    """Stringify the operator and pass it to C.run_operator_once."""
    serialized_op = StringifyProto(operator)
    return C.run_operator_once(serialized_op)
|
||
|
|
||
|
|
||
|
def RunOperatorMultiple(operator, num_runs):
    """Stringify the operator and pass it with num_runs to
    C.run_operator_multiple."""
    serialized_op = StringifyProto(operator)
    return C.run_operator_multiple(serialized_op, num_runs)
|
||
|
|
||
|
|
||
|
def RunOperatorsOnce(operators):
    """Run the operators one by one with RunOperatorOnce.

    Stops at the first operator whose run reports failure.

    Returns:
      False if some operator failed, True otherwise.
    """
    # all() short-circuits, so operators after a failing one are not run.
    return all(RunOperatorOnce(op) for op in operators)
|
||
|
|
||
|
|
||
|
def ClearGlobalNetObserver():
    """Forward to C.clear_global_net_observer and return its result."""
    outcome = C.clear_global_net_observer()
    return outcome
|
||
|
|
||
|
|
||
|
def CallWithExceptionIntercept(func, op_id_fetcher, net_name, *args, **kwargs):
    """Call func and log the recorded python traceback of the op if it raises.

    Inputs:
      func: the callable to invoke with *args and **kwargs.
      op_id_fetcher: zero-argument callable; called only after a failure to
          obtain the id of the operator to report.
      net_name: key used to look up tracebacks in operator_tracebacks.
    Returns:
      Whatever func returns.
    Raises:
      The exception raised by func, re-raised after the warnings are logged.
    """
    try:
        return func(*args, **kwargs)
    except Exception:
        failed_op = op_id_fetcher()
        recorded = operator_tracebacks.get(net_name, None)
        header = (
            'Original python traceback for operator `{}` in network '
            '`{}` in exception above (most recent call last):'
        ).format(failed_op, net_name)
        logger.warning(header)
        if recorded and failed_op in recorded:
            # Entries are logged in reverse of their stored order.
            for frame in reversed(recorded[failed_op]):
                logger.warning(' File "{}", line {}, in {}'.format(
                    frame[0], frame[1], frame[2]))
        raise
|
||
|
|
||
|
|
||
|
def RunNetOnce(net):
    """Run a net through C.run_net_once, with traceback logging on failure.

    Inputs:
      net: the net to run; accepted forms are those understood by GetNetName
          and StringifyProto.
    Returns:
      Whatever C.run_net_once returns.
    """
    failed_op_fetcher = C.Workspace.current._last_failed_op_net_position
    net_name = GetNetName(net)
    serialized_net = StringifyProto(net)
    return CallWithExceptionIntercept(
        C.run_net_once, failed_op_fetcher, net_name, serialized_net)
|
||
|
|
||
|
|
||
|
def RunNet(name, num_iter=1, allow_fail=False):
    """Run an already created net.

    Inputs:
      name: the name of the net, or a reference to the net.
      num_iter: number of iterations to run
      allow_fail: if True, does not assert on net exec failure but returns
          False
    Returns:
      True or an exception.
    """
    failed_op_fetcher = C.Workspace.current._last_failed_op_net_position
    return CallWithExceptionIntercept(
        C.run_net,
        failed_op_fetcher,
        GetNetName(name),
        StringifyNetName(name),
        num_iter,
        allow_fail,
    )
|
||
|
|
||
|
|
||
|
def RunPlan(plan_or_step):
    """Run a plan through C.run_plan.

    Inputs:
      plan_or_step: a plan, or a core.ExecutionStep, which is wrapped in a
          core.Plan first.
    Returns:
      Whatever C.run_plan returns.
    """
    # TODO(jiayq): refactor core.py/workspace.py to avoid circular deps
    import caffe2.python.core as core
    plan = plan_or_step
    if isinstance(plan, core.ExecutionStep):
        plan = core.Plan(plan)
    return C.run_plan(StringifyProto(plan))
|
||
|
|
||
|
|
||
|
def RunPlanInBackground(plan_or_step):
    """Run a plan through C.run_plan_in_background.

    Inputs:
      plan_or_step: a plan, or a core.ExecutionStep, which is wrapped in a
          core.Plan first.
    Returns:
      Whatever C.run_plan_in_background returns.
    """
    # TODO(jiayq): refactor core.py/workspace.py to avoid circular deps
    import caffe2.python.core as core
    plan = plan_or_step
    if isinstance(plan, core.ExecutionStep):
        plan = core.Plan(plan)
    return C.run_plan_in_background(StringifyProto(plan))
|
||
|
|
||
|
|
||
|
def InferShapesAndTypes(nets, blob_dimensions=None, nets_proto=False,
                        blob_types=None):
    """Infer the shapes and types of the blobs of the given nets.

    Inputs:
      nets: the list of nets
      blob_dimensions (optional): a dictionary of blobs and their dimensions.
          If not specified, the workspace blobs are used.
      nets_proto (optional): a boolean flag indicating whether the
          protobuffer representation is passed to the routine.
      blob_types (optional): passed to the C extension together with
          blob_dimensions; must be None when blob_dimensions is None.
    Returns:
      A tuple of (shapes, types) dictionaries keyed by blob name. Entries
      flagged with unknown_shape are left out of both dictionaries.
    """
    serialized_nets = [
        StringifyProto(net if nets_proto else net.Proto()) for net in nets
    ]
    if blob_dimensions is None:
        assert blob_types is None
        raw_desc = C.infer_shapes_and_types_from_workspace(serialized_nets)
    else:
        # blob_types is only forwarded when the caller supplied it.
        map_args = [serialized_nets, blob_dimensions]
        if blob_types is not None:
            map_args.append(blob_types)
        raw_desc = C.infer_shapes_and_types_from_map(*map_args)

    tensor_shapes = caffe2_pb2.TensorShapes()
    tensor_shapes.ParseFromString(raw_desc)
    shapes = {}
    types = {}
    for entry in tensor_shapes.shapes:
        if entry.unknown_shape:
            continue
        shapes[entry.name] = list(entry.dims)
        types[entry.name] = entry.data_type

    return (shapes, types)
|
||
|
|
||
|
|
||
|
def _StringifyName(name, expected_type):
    """Return name as a string.

    Inputs:
      name: a string, or an object whose class is named expected_type.
      expected_type: the class name accepted for non-string names.
    Returns:
      name itself when it is a string, str(name) otherwise.
    """
    if not isinstance(name, basestring):
        assert type(name).__name__ == expected_type, \
            "Expected a string or %s" % expected_type
        return str(name)
    return name
|
||
|
|
||
|
|
||
|
def StringifyBlobName(name):
    """Return the blob name as a string; accepts a string or BlobReference."""
    return _StringifyName(name, expected_type="BlobReference")
|
||
|
|
||
|
|
||
|
def StringifyNetName(name):
    """Return the net name as a string; accepts a string or a Net."""
    return _StringifyName(name, expected_type="Net")
|
||
|
|
||
|
|
||
|
def GetNetName(net):
    """Return the name of a net.

    Inputs:
      net: a string (returned unchanged), an object whose class is named
          Net or NetWithShapeInference, or a caffe2_pb2.NetDef.
    Raises:
      Exception: if net is none of the accepted kinds.
    """
    if isinstance(net, basestring):
        return net
    # Wrapper classes are matched by class name rather than by isinstance.
    if type(net).__name__ in ("Net", "NetWithShapeInference"):
        return net.Name()
    if isinstance(net, caffe2_pb2.NetDef):
        return net.name
    raise Exception("Not a Net object: {}".format(str(net)))
|
||
|
|
||
|
|
||
|
def FeedBlob(name, arr, device_option=None):
    """Feed a blob into the workspace C.Workspace.current.

    Inputs:
      name: the name of the blob.
      arr: either a TensorProto object or a numpy array object to be fed
          into the workspace.
      device_option (optional): the device option to feed the data with.
    Returns:
      True or False, stating whether the feed is successful.
    """
    return _Workspace_feed_blob(
        C.Workspace.current, name, arr, device_option)
|
||
|
|
||
|
|
||
|
def FetchBlobs(names):
    """Fetch several blobs from the workspace, one FetchBlob call per name.

    Inputs:
      names: list of names of blobs - strings or BlobReferences
    Returns:
      list of fetched blobs, in the order of names
    """
    fetched = []
    for blob_name in names:
        fetched.append(FetchBlob(blob_name))
    return fetched
|
||
|
|
||
|
|
||
|
def FetchBlob(name):
    """Fetch a single blob from the workspace.

    Inputs:
      name: the name of the blob - a string or a BlobReference
    Returns:
      Fetched blob (numpy array or string) if successful
    Raises:
      TypeError: if the fetch result is a tuple, which is what an Int8 blob
          yields; use FetchInt8Blob for those.
    """
    blob_name = StringifyBlobName(name)
    fetched = C.fetch_blob(blob_name)
    if not isinstance(fetched, tuple):
        return fetched
    raise TypeError(
        "Use FetchInt8Blob to fetch Int8 Blob {}".format(blob_name))
|
||
|
|
||
|
|
||
|
def FetchTorch(name):
    """Return blobs[name] of C.Workspace.current converted with to_torch()."""
    blob = C.Workspace.current.blobs[name]
    return blob.to_torch()
|
||
|
|
||
|
|
||
|
# Result record for Int8 blobs: the data together with its quantization
# scale and zero point.
Int8Tensor = collections.namedtuple(
    typename='Int8Tensor',
    field_names=['data', 'scale', 'zero_point'],
)
|
||
|
|
||
|
|
||
|
def FetchInt8Blob(name):
    """Fetch an Int8 blob from the workspace.

    It shares the backend implementation with FetchBlob but is the
    recommended call when fetching Int8 blobs.

    Inputs:
      name: the name of the Int8 blob - a string or a BlobReference
    Returns:
      An Int8Tensor with the fields
        data: int8 numpy array, data
        scale: float, fake quantization scale
        zero_point: int, fake quantization offset
    """
    blob_name = StringifyBlobName(name)
    fetched = C.fetch_blob(blob_name)
    assert isinstance(fetched, tuple), \
        'You are not fetching an Int8Blob {}. Please use FetchBlob'.format(
            blob_name)
    return Int8Tensor(*fetched)
|
||
|
|
||
|
|
||
|
def FetchInt8BlobRealVal(name):
    """Fetch an Int8 blob and return its real value representation.

    The result is computed as (data - zero_point) * scale, with the
    subtraction done in int32 and the product in float32.

    Inputs:
      name: the name of the Int8 blob - a string or a BlobReference
    Returns:
      real value representation of int8 numpy array
    """
    blob_name = StringifyBlobName(name)
    fetched = C.fetch_blob(blob_name)
    assert isinstance(fetched, tuple), \
        'You are not fetching an Int8Blob {}. Please use FetchBlob'.format(
            blob_name)
    quantized = Int8Tensor(*fetched)
    centered = quantized.data.astype(np.int32) - int(quantized.zero_point)
    return centered.astype(np.float32) * quantized.scale
|
||
|
|
||
|
|
||
|
def _Workspace_fetch_int8_blob(ws, name):
    """Fetch an Int8 blob from the workspace ws.

    It shares the backend implementation with fetch_blob but is the
    recommended call when fetching Int8 blobs.

    Inputs:
      ws: the workspace whose fetch_blob is used.
      name: the name of the Int8 blob - a string or a BlobReference
    Returns:
      An Int8Tensor with the fields
        data: int8 numpy array, data
        scale: float, fake quantization scale
        zero_point: int, fake quantization offset
    """
    fetched = ws.fetch_blob(name)
    # The name is only stringified when the failure message is needed.
    assert isinstance(fetched, tuple), \
        'You are not fetching an Int8Blob {}. Please use fetch_blob'.format(
            StringifyBlobName(name))
    return Int8Tensor(*fetched)


# Make the helper available as C.Workspace.fetch_int8_blob.
C.Workspace.fetch_int8_blob = _Workspace_fetch_int8_blob
|
||
|
|
||
|
|
||
|
def ApplyTransform(transform_key, net):
    """Run a Transform over a NetDef protobuf object and return the result.

    Inputs:
      transform_key: the name of the transform, as it is stored in the
          registry
      net: a NetDef protobuf object
    Returns:
      Transformed NetDef protobuf object.
    """
    encoded_key = str(transform_key).encode('utf-8')
    transformed_str = C.apply_transform(encoded_key, net.SerializeToString())
    result = caffe2_pb2.NetDef()
    result.ParseFromString(transformed_str)
    return result
|
||
|
|
||
|
|
||
|
def ApplyTransformIfFaster(transform_key, net, init_net, **kwargs):
    """Run a Transform over a NetDef protobuf object and return the new
    NetDef only if it runs faster than the original.

    The runs are performed on the current active workspace (gWorkspace).
    You should initialize that workspace before making a call to this function.

    Inputs:
      transform_key: the name of the transform, as it is stored in the
          registry
      net: a NetDef protobuf object
      init_net: The net to initialize the workspace.
      warmup_runs (optional):
        Determines how many times the net is run before testing.
        Will be 5 by default.
      main_runs (optional):
        Determines how many times the net is run during testing.
        Will be 10 by default.
      improvement_threshold (optional):
        Determines the factor which the new net needs to be faster
        in order to replace the old. Will be 1.01 by default.

    Returns:
      Either a Transformed NetDef protobuf object, or the original netdef.
    """
    # Only the three documented keyword arguments are read from kwargs.
    warmup_runs = kwargs.get('warmup_runs', 5)
    main_runs = kwargs.get('main_runs', 10)
    improvement_threshold = kwargs.get('improvement_threshold', 1.01)

    encoded_key = str(transform_key).encode('utf-8')
    transformed_str = C.apply_transform_if_faster(
        encoded_key,
        net.SerializeToString(),
        init_net.SerializeToString(),
        warmup_runs,
        main_runs,
        float(improvement_threshold),
    )
    result = caffe2_pb2.NetDef()
    result.ParseFromString(transformed_str)
    return result
|
||
|
|
||
|
|
||
|
def GetNameScope():
    """Return the current namescope string. To be used to fetch blobs"""
    current_scope = scope.CurrentNameScope()
    return current_scope
|
||
|
|
||
|
|
||
|
class _BlobDict(object):
    """Dict-like facade for fetching and feeding workspace blobs.

    Reading an item fetches the blob (FetchBlob); assigning an item feeds it
    (FeedBlob). Length, iteration and membership are answered by the C
    extension (C.blobs and C.has_blob).
    """

    def __contains__(self, item):
        return C.has_blob(item)

    def __iter__(self):
        return iter(C.blobs())

    def __len__(self):
        return len(C.blobs())

    def __getitem__(self, key):
        return FetchBlob(key)

    def __setitem__(self, key, value):
        return FeedBlob(key, value)


# Module-level instance, used as workspace.blobs[...].
blobs = _BlobDict()
|
||
|
|
||
|
|
||
|
################################################################################
|
||
|
# Utilities for immediate mode
|
||
|
#
|
||
|
# Caffe2's immediate mode implements the following behavior: between the two
|
||
|
# function calls StartImmediate() and StopImmediate(), for any operator that is
|
||
|
# called through CreateOperator(), we will also run that operator in a workspace
|
||
|
# that is specific to the immediate mode. The user is explicitly expected to
|
||
|
# make sure that these ops have proper inputs and outputs, i.e. one should not
|
||
|
# run an op where an external input is not created or fed.
|
||
|
#
|
||
|
# Users can use FeedImmediate() and FetchImmediate() to interact with blobs
|
||
|
# in the immediate workspace.
|
||
|
#
|
||
|
# Once StopImmediate() is called, all contents in the immediate workspace are
# freed up so one can continue using normal runs.
|
||
|
#
|
||
|
# The immediate mode is solely for debugging purposes and support will be very
|
||
|
# sparse.
|
||
|
################################################################################
|
||
|
|
||
|
# State of the immediate mode.
# Name of the workspace reserved for immediate mode.
_immediate_workspace_name = "_CAFFE2_IMMEDIATE"
# Temporary root folder of the immediate workspace; '' while inactive.
_immediate_root_folder = ''
# True between StartImmediate() and StopImmediate().
_immediate_mode = False
|
||
|
|
||
|
|
||
|
def IsImmediate():
    """Return the module-level flag telling whether immediate mode is on."""
    active = _immediate_mode
    return active
|
||
|
|
||
|
|
||
|
@contextlib.contextmanager
def WorkspaceGuard(workspace_name):
    """Context manager that temporarily switches to another workspace.

    The workspace that is current on entry is remembered and switched back
    to on exit, including when the body of the with statement raises.

    Inputs:
      workspace_name: the name of the workspace to switch to.
    """
    current = CurrentWorkspace()
    # NOTE(review): the second argument presumably asks for the workspace to
    # be created when missing - confirm against C.switch_workspace.
    SwitchWorkspace(workspace_name, True)
    try:
        yield
    finally:
        # Without try/finally an exception in the body would leave
        # workspace_name active for all later calls.
        SwitchWorkspace(current)
|
||
|
|
||
|
|
||
|
def StartImmediate(i_know=False):
    """Start an immediate mode run.

    If immediate mode is already active it is stopped first. A temporary
    folder is then created and the immediate workspace is reset with it as
    root folder.

    Inputs:
      i_know: if True, the warning about the caveats of immediate mode is
          not printed.
    """
    global _immediate_mode
    global _immediate_root_folder
    if IsImmediate():
        # already in immediate mode. We will kill the previous one
        # and start from fresh.
        StopImmediate()
    _immediate_mode = True
    with WorkspaceGuard(_immediate_workspace_name):
        _immediate_root_folder = tempfile.mkdtemp()
        ResetWorkspace(_immediate_root_folder)
    if i_know:
        # if the user doesn't want to see the warning message, sure...
        return
    # The message names StopImmediate(), the function this module defines
    # for leaving immediate mode.
    print("""
Enabling immediate mode in caffe2 python is an EXTREMELY EXPERIMENTAL
feature and may very easily go wrong. This is because Caffe2 uses a
declarative way of defining operators and models, which is essentially
not meant to run things in an interactive way. Read the following carefully
to make sure that you understand the caveats.

(1) You need to make sure that the sequences of operators you create are
actually runnable sequentially. For example, if you create an op that takes
an input X, somewhere earlier you should have already created X.

(2) Caffe2 immediate uses one single workspace, so if the set of operators
you run are intended to be under different workspaces, they will not run.
To create boundaries between such use cases, you can call StopImmediate()
and StartImmediate() manually to flush out everything no longer needed.

(3) Underlying objects held by the immediate mode may interfere with your
normal run. For example, if there is a leveldb that you opened in immediate
mode and did not close, your main run will fail because leveldb does not
support double opening. Immediate mode may also occupy a lot of memory esp.
on GPUs. Call StopImmediate() as soon as possible when you no longer
need it.

(4) Immediate is designed to be slow. Every immediate call implicitly
creates a temp operator object, runs it, and destroys the operator. This
slow-speed run is by design to discourage abuse. For most use cases other
than debugging, do NOT turn on immediate mode.

(5) If there is anything FATAL happening in the underlying C++ code, the
immediate mode will immediately (pun intended) cause the runtime to crash.

Thus you should use immediate mode with extra care. If you still would
like to, have fun [https://xkcd.com/149/].
""")
|
||
|
|
||
|
|
||
|
def StopImmediate():
    """Stop an immediate mode run.

    Does nothing when immediate mode is not active. Otherwise the immediate
    workspace is reset, its temporary root folder is removed and the
    module-level state is cleared.
    """
    # Phew, that was a dangerous ride.
    global _immediate_mode, _immediate_root_folder
    if IsImmediate():
        with WorkspaceGuard(_immediate_workspace_name):
            ResetWorkspace()
        shutil.rmtree(_immediate_root_folder)
        _immediate_root_folder = ''
        _immediate_mode = False
|
||
|
|
||
|
|
||
|
def ImmediateBlobs():
    """Return Blobs() evaluated inside the immediate workspace."""
    with WorkspaceGuard(_immediate_workspace_name):
        immediate_blobs = Blobs()
        return immediate_blobs
|
||
|
|
||
|
|
||
|
def RunOperatorImmediate(op):
    """Run op once inside the immediate workspace.

    The result of RunOperatorOnce is not returned.
    """
    guard = WorkspaceGuard(_immediate_workspace_name)
    with guard:
        RunOperatorOnce(op)
|
||
|
|
||
|
|
||
|
def FetchImmediate(*args, **kwargs):
    """Call FetchBlob with the given arguments inside the immediate
    workspace and return its result."""
    with WorkspaceGuard(_immediate_workspace_name):
        fetched = FetchBlob(*args, **kwargs)
        return fetched
|
||
|
|
||
|
|
||
|
def FeedImmediate(*args, **kwargs):
    """Call FeedBlob with the given arguments inside the immediate
    workspace and return its result."""
    with WorkspaceGuard(_immediate_workspace_name):
        fed = FeedBlob(*args, **kwargs)
        return fed
|
||
|
|
||
|
|
||
|
# C.Workspace methods.
|
||
|
|
||
|
def _Workspace_create_net_with_exception_intercept(ws, net, overwrite=False):
    """Create a net in ws through ws._create_net, with traceback logging on
    failure.

    Inputs:
      ws: the workspace to create the net in.
      net: the net to create; accepted forms are those understood by
          GetNetName and StringifyProto.
      overwrite: passed through to ws._create_net.
    Returns:
      Whatever ws._create_net returns.
    """
    net_name = GetNetName(net)
    serialized_net = StringifyProto(net)
    return CallWithExceptionIntercept(
        ws._create_net,
        ws._last_failed_op_net_position,
        net_name,
        serialized_net,
        overwrite,
    )
|
||
|
|
||
|
|
||
|
def _Workspace_run(ws, obj):
    """Run a plan, a net or an operator in the workspace ws.

    Inputs:
      ws: the workspace to run in.
      obj: a PlanDef, NetDef or OperatorDef, or an object whose Proto()
          function returns one of these.
    Returns:
      The result of ws._run_plan, ws._run_net or ws._run_operator.
    Raises:
      ValueError: if obj is none of the supported kinds.
    """
    proto = obj.Proto() if hasattr(obj, 'Proto') else obj
    if isinstance(proto, caffe2_pb2.PlanDef):
        return ws._run_plan(proto.SerializeToString())
    if isinstance(proto, caffe2_pb2.NetDef):
        # Only net runs go through the traceback-logging wrapper.
        return CallWithExceptionIntercept(
            ws._run_net,
            ws._last_failed_op_net_position,
            GetNetName(proto),
            proto.SerializeToString(),
        )
    if isinstance(proto, caffe2_pb2.OperatorDef):
        return ws._run_operator(proto.SerializeToString())
    raise ValueError(
        "Don't know how to do Workspace.run() on {}".format(type(proto)))
|
||
|
|
||
|
|
||
|
def _Workspace_feed_blob(ws, name, arr, device_option=None):
    """Feed a blob into the workspace ws.

    Inputs:
      ws: the workspace in which the blob is created and fed.
      name: the name of the blob - a string or a BlobReference.
      arr: the data to feed. A TensorProto is converted to a numpy array
          first, and numpy string arrays are converted to object arrays.
      device_option (optional): the device option to feed the data with.
          Defaults to scope.CurrentDeviceScope().
    Returns:
      Whatever the blob's feed() returns.
    """
    if type(arr) is caffe2_pb2.TensorProto:
        arr = utils.Caffe2TensorToNumpyArray(arr)
    if type(arr) is np.ndarray and arr.dtype.kind in 'SU':
        # Plain NumPy strings are weird, let's use objects instead.
        # The builtin `object` is used because the `np.object` alias no
        # longer exists in current NumPy releases.
        arr = arr.astype(object)

    if device_option is None:
        device_option = scope.CurrentDeviceScope()

    if device_option and device_option.device_type == caffe2_pb2.CUDA:
        if arr.dtype == np.dtype('float64'):
            logger.warning(
                "CUDA operators do not support 64-bit doubles, " +
                "please use arr.astype(np.float32) or np.int32 for ints." +
                " Blob: {}".format(name) +
                " type: {}".format(str(arr.dtype))
            )

    name = StringifyBlobName(name)
    # device_option can still be None here if no device scope is active.
    if device_option is not None:
        return ws.create_blob(name).feed(arr, device_option)
    else:
        return ws.create_blob(name).feed(arr)
|
||
|
|
||
|
|
||
|
def _Workspace_remove_blob(ws, blob):
    """Remove a blob from ws; blob is converted with str() first."""
    blob_name = str(blob)
    ws._remove_blob(blob_name)
|
||
|
|
||
|
|
||
|
# Re-export C.Workspace and attach the Python-side helpers defined in this
# file as methods on it.
Workspace = C.Workspace
Workspace.remove_blob = _Workspace_remove_blob
Workspace.feed_blob = _Workspace_feed_blob
Workspace.run = _Workspace_run
Workspace.create_net = _Workspace_create_net_with_exception_intercept
|
||
|
|
||
|
# C.Blob methods.
|
||
|
|
||
|
|
||
|
def _Blob_feed(blob, arg, device_option=None):
    """Feed arg into blob.

    Inputs:
      blob: the blob to feed.
      arg: the value to feed. A PyTorch tensor is wrapped through its raw
          tensor impl handle; anything else is passed to blob._feed.
      device_option (optional): stringified with StringifyProto and passed
          to blob._feed. Must be None for PyTorch tensors.
    Returns:
      True for PyTorch tensors, otherwise whatever blob._feed returns.
    """
    # conservative type check to avoid unnecessary import
    arg_type = type(arg)
    looks_like_torch = (
        arg_type.__name__ == 'Tensor' and arg_type.__module__ == 'torch')
    if looks_like_torch:
        import torch
        if isinstance(arg, torch.Tensor):
            assert device_option is None, \
                "device_option doesn't make sense with PyTorch tensors"
            blob._wrap_tensor_impl(torch._C._tensor_impl_raw_handle(arg))
            return True  # _feed() returns True for some reason
    serialized_option = device_option
    if serialized_option is not None:
        serialized_option = StringifyProto(serialized_option)
    return blob._feed(arg, serialized_option)


# Make the helper available as C.Blob.feed.
C.Blob.feed = _Blob_feed
|
||
|
|
||
|
|
||
|
def _Tensor_to_torch(tensor):
    """
    PyTorch tensor interop (TensorCPU methods)

    Wraps the raw tensor impl handle of the tensor with
    torch._C._wrap_tensor_impl and returns the result.

    Can be accessed as:
      workspace.Workspace.current.blobs['foo'].tensor().to_torch()
    """
    # avoiding circular dependency
    import torch
    return torch._C._wrap_tensor_impl(tensor._tensor_impl_raw_handle())


# Make the helper available as C.TensorCPU.to_torch.
C.TensorCPU.to_torch = _Tensor_to_torch
|
||
|
|
||
|
|
||
|
def _Blob_to_torch(blob):
    """Return the tensor held by blob converted with to_torch().

    Raises:
      RuntimeError: if blob.is_tensor() is false.
    """
    if blob.is_tensor():
        return blob.as_tensor().to_torch()
    raise RuntimeError("Blob has to be a tensor")


# Make the helper available as C.Blob.to_torch.
C.Blob.to_torch = _Blob_to_torch
|