241 lines
7.4 KiB
Python
241 lines
7.4 KiB
Python
## @package cnn
|
|
# Module caffe2.python.cnn
|
|
|
|
|
|
|
|
|
|
|
|
from caffe2.python import brew, workspace
|
|
from caffe2.python.model_helper import ModelHelper
|
|
from caffe2.proto import caffe2_pb2
|
|
import logging
|
|
|
|
|
|
class CNNModelHelper(ModelHelper):
    """Deprecated convenience wrapper for building CNN models.

    Each operator method forwards to the matching ``brew`` helper, passing
    this object as the model and injecting the helper-level settings
    (``order``, ``use_cudnn``, ``cudnn_exhaustive_search``,
    ``ws_nbytes_limit``) chosen at construction time, so callers do not have
    to repeat them on every layer.
    """

    def __init__(self, order="NCHW", name=None,
                 use_cudnn=True, cudnn_exhaustive_search=False,
                 ws_nbytes_limit=None, init_params=True,
                 skip_sparse_optim=False,
                 param_model=None):
        """Create the helper and record the CNN-wide operator settings.

        Args:
            order: storage order; must be "NCHW" or "NHWC".
            name: model name passed to ModelHelper; "CNN" when None.
            use_cudnn: stored and forwarded to helpers that accept it.
            cudnn_exhaustive_search: stored and forwarded to the conv helpers.
            ws_nbytes_limit: stored and forwarded to the conv helpers; only
                added to the arg scope when truthy.
            init_params: forwarded to ModelHelper.
            skip_sparse_optim: forwarded to ModelHelper.
            param_model: forwarded to ModelHelper.

        Raises:
            ValueError: if ``order`` is neither "NHWC" nor "NCHW".
        """
        logging.warning(
            "[====DEPRECATE WARNING====]: you are creating an "
            "object from CNNModelHelper class which will be deprecated soon. "
            "Please use ModelHelper object with brew module. For more "
            "information, please refer to caffe2.ai and python/brew.py, "
            "python/brew_test.py for more information."
        )

        # Reject an unsupported order before building the arg scope and
        # running the base-class constructor, so a bad argument fails fast.
        if order not in ("NHWC", "NCHW"):
            raise ValueError(
                "Cannot understand the CNN storage order %s." % order
            )

        cnn_arg_scope = {
            'order': order,
            'use_cudnn': use_cudnn,
            'cudnn_exhaustive_search': cudnn_exhaustive_search,
        }
        # A falsy limit (None or 0) is left out of the arg scope entirely.
        if ws_nbytes_limit:
            cnn_arg_scope['ws_nbytes_limit'] = ws_nbytes_limit

        super(CNNModelHelper, self).__init__(
            skip_sparse_optim=skip_sparse_optim,
            name="CNN" if name is None else name,
            init_params=init_params,
            param_model=param_model,
            arg_scope=cnn_arg_scope,
        )

        self.order = order
        self.use_cudnn = use_cudnn
        self.cudnn_exhaustive_search = cudnn_exhaustive_search
        self.ws_nbytes_limit = ws_nbytes_limit

    def ImageInput(self, blob_in, blob_out, use_gpu_transform=False, **kwargs):
        """Forward to brew.image_input using this helper's storage order."""
        return brew.image_input(
            self,
            blob_in,
            blob_out,
            order=self.order,
            use_gpu_transform=use_gpu_transform,
            **kwargs
        )

    def VideoInput(self, blob_in, blob_out, **kwargs):
        """Forward to brew.video_input; no helper-level settings are added."""
        return brew.video_input(
            self,
            blob_in,
            blob_out,
            **kwargs
        )

    def PadImage(self, blob_in, blob_out, **kwargs):
        """Add a PadImage operator to this model's net and return its result.

        The result of ``self.net.PadImage`` is returned so this helper
        behaves like the other operator helpers on this class.
        """
        # TODO(wyiming): remove this dummy helper later
        return self.net.PadImage(blob_in, blob_out, **kwargs)

    def ConvNd(self, *args, **kwargs):
        """Forward to brew.conv_nd with this helper's cuDNN/order settings."""
        return brew.conv_nd(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def Conv(self, *args, **kwargs):
        """Forward to brew.conv with this helper's cuDNN/order settings."""
        return brew.conv(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def ConvTranspose(self, *args, **kwargs):
        """Forward to brew.conv_transpose with this helper's settings."""
        return brew.conv_transpose(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def GroupConv(self, *args, **kwargs):
        """Forward to brew.group_conv with this helper's settings."""
        return brew.group_conv(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def GroupConv_Deprecated(self, *args, **kwargs):
        """Forward to brew.group_conv_deprecated with this helper's settings."""
        return brew.group_conv_deprecated(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def FC(self, *args, **kwargs):
        """Forward to brew.fc."""
        return brew.fc(self, *args, **kwargs)

    def PackedFC(self, *args, **kwargs):
        """Forward to brew.packed_fc."""
        return brew.packed_fc(self, *args, **kwargs)

    def FC_Prune(self, *args, **kwargs):
        """Forward to brew.fc_prune."""
        return brew.fc_prune(self, *args, **kwargs)

    def FC_Decomp(self, *args, **kwargs):
        """Forward to brew.fc_decomp."""
        return brew.fc_decomp(self, *args, **kwargs)

    def FC_Sparse(self, *args, **kwargs):
        """Forward to brew.fc_sparse."""
        return brew.fc_sparse(self, *args, **kwargs)

    def Dropout(self, *args, **kwargs):
        """Forward to brew.dropout with this helper's order and use_cudnn."""
        return brew.dropout(
            self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
        )

    def LRN(self, *args, **kwargs):
        """Forward to brew.lrn with this helper's order and use_cudnn."""
        return brew.lrn(
            self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
        )

    def Softmax(self, *args, **kwargs):
        """Forward to brew.softmax with this helper's use_cudnn."""
        return brew.softmax(self, *args, use_cudnn=self.use_cudnn, **kwargs)

    def SpatialBN(self, *args, **kwargs):
        """Forward to brew.spatial_bn with this helper's order."""
        return brew.spatial_bn(self, *args, order=self.order, **kwargs)

    def SpatialGN(self, *args, **kwargs):
        """Forward to brew.spatial_gn with this helper's order."""
        return brew.spatial_gn(self, *args, order=self.order, **kwargs)

    def InstanceNorm(self, *args, **kwargs):
        """Forward to brew.instance_norm with this helper's order."""
        return brew.instance_norm(self, *args, order=self.order, **kwargs)

    def Relu(self, *args, **kwargs):
        """Forward to brew.relu with this helper's order and use_cudnn."""
        return brew.relu(
            self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
        )

    def PRelu(self, *args, **kwargs):
        """Forward to brew.prelu."""
        return brew.prelu(self, *args, **kwargs)

    def Concat(self, *args, **kwargs):
        """Forward to brew.concat with this helper's order."""
        return brew.concat(self, *args, order=self.order, **kwargs)

    def DepthConcat(self, *args, **kwargs):
        """The old depth concat function - we should move to use concat.

        Prints a deprecation notice and delegates to ``Concat``.
        """
        print("DepthConcat is deprecated. use Concat instead.")
        return self.Concat(*args, **kwargs)

    def Sum(self, *args, **kwargs):
        """Forward to brew.sum."""
        return brew.sum(self, *args, **kwargs)

    def Transpose(self, *args, **kwargs):
        """Forward to brew.transpose with this helper's use_cudnn."""
        return brew.transpose(self, *args, use_cudnn=self.use_cudnn, **kwargs)

    def Iter(self, *args, **kwargs):
        """Forward to brew.iter."""
        return brew.iter(self, *args, **kwargs)

    def Accuracy(self, *args, **kwargs):
        """Forward to brew.accuracy."""
        return brew.accuracy(self, *args, **kwargs)

    def MaxPool(self, *args, **kwargs):
        """Forward to brew.max_pool with this helper's use_cudnn and order."""
        return brew.max_pool(
            self, *args, use_cudnn=self.use_cudnn, order=self.order, **kwargs
        )

    def MaxPoolWithIndex(self, *args, **kwargs):
        """Forward to brew.max_pool_with_index with this helper's order."""
        return brew.max_pool_with_index(self, *args, order=self.order, **kwargs)

    def AveragePool(self, *args, **kwargs):
        """Forward to brew.average_pool with this helper's use_cudnn and order."""
        return brew.average_pool(
            self, *args, use_cudnn=self.use_cudnn, order=self.order, **kwargs
        )

    @property
    def XavierInit(self):
        """Return the (operator name, kwargs) pair ('XavierFill', {})."""
        return ('XavierFill', {})

    def ConstantInit(self, value):
        """Return the ('ConstantFill', {'value': value}) initializer pair."""
        return ('ConstantFill', dict(value=value))

    @property
    def MSRAInit(self):
        """Return the (operator name, kwargs) pair ('MSRAFill', {})."""
        return ('MSRAFill', {})

    @property
    def ZeroInit(self):
        """Return ('ConstantFill', {}), i.e. ConstantFill with no explicit value."""
        return ('ConstantFill', {})

    def AddWeightDecay(self, weight_decay):
        """Forward to brew.add_weight_decay for this model."""
        return brew.add_weight_decay(self, weight_decay)

    @property
    def CPU(self):
        """Return a new DeviceOption whose device_type is caffe2_pb2.CPU."""
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.CPU
        return device_option

    @property
    def GPU(self, gpu_id=0):
        """Return a new DeviceOption for workspace.GpuDeviceType, device 0.

        NOTE(review): this is a property, so attribute access (``model.GPU``)
        can never supply ``gpu_id`` and the default of 0 is always used. The
        signature is kept unchanged for backward compatibility; build a
        DeviceOption directly to target another device.
        """
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = workspace.GpuDeviceType
        device_option.device_id = gpu_id
        return device_option
|