139 lines
4.6 KiB
Python
139 lines
4.6 KiB
Python
## @package model_helper_api
|
|
# Module caffe2.python.model_helper_api
|
|
|
|
|
|
|
|
|
|
|
|
import sys
|
|
import copy
|
|
import inspect
|
|
from past.builtins import basestring
|
|
from caffe2.python.model_helper import ModelHelper
|
|
|
|
# flake8: noqa
|
|
from caffe2.python.helpers.algebra import *
|
|
from caffe2.python.helpers.arg_scope import *
|
|
from caffe2.python.helpers.array_helpers import *
|
|
from caffe2.python.helpers.control_ops import *
|
|
from caffe2.python.helpers.conv import *
|
|
from caffe2.python.helpers.db_input import *
|
|
from caffe2.python.helpers.dropout import *
|
|
from caffe2.python.helpers.elementwise_linear import *
|
|
from caffe2.python.helpers.fc import *
|
|
from caffe2.python.helpers.nonlinearity import *
|
|
from caffe2.python.helpers.normalization import *
|
|
from caffe2.python.helpers.pooling import *
|
|
from caffe2.python.helpers.quantization import *
|
|
from caffe2.python.helpers.tools import *
|
|
from caffe2.python.helpers.train import *
|
|
|
|
|
|
class HelperWrapper(object):
    """Stand-in for this module that serves registered helpers as attributes.

    An instance of this class is installed in ``sys.modules`` in place of
    this module, so attribute access such as ``model_helper_api.fc`` is
    resolved by :meth:`__getattr__`. Each registered helper is handed back
    wrapped in a function that pre-fills keyword arguments before calling it:
    first from ``model.arg_scope``, then from the entry for this helper in
    ``get_current_scope()``, and finally from the caller's own kwargs (later
    sources win).
    """

    # Maps the attribute name callers use (e.g. ``model_helper_api.fc``) to
    # the helper function it dispatches to. This dict is a class attribute,
    # so Register() additions are shared by every HelperWrapper instance.
    _registry = {
        'arg_scope': arg_scope,
        'fc': fc,
        'packed_fc': packed_fc,
        'fc_decomp': fc_decomp,
        'fc_sparse': fc_sparse,
        'fc_prune': fc_prune,
        'dropout': dropout,
        'max_pool': max_pool,
        'average_pool': average_pool,
        'max_pool_with_index': max_pool_with_index,
        'lrn': lrn,
        'softmax': softmax,
        'instance_norm': instance_norm,
        'spatial_bn': spatial_bn,
        'spatial_gn': spatial_gn,
        'moments_with_running_stats': moments_with_running_stats,
        'relu': relu,
        'prelu': prelu,
        'tanh': tanh,
        'concat': concat,
        'depth_concat': depth_concat,
        'sum': sum,
        'reduce_sum': reduce_sum,
        'sub': sub,
        'arg_min': arg_min,
        'transpose': transpose,
        'iter': iter,
        'accuracy': accuracy,
        'conv': conv,
        'conv_nd': conv_nd,
        'conv_transpose': conv_transpose,
        'group_conv': group_conv,
        'group_conv_deprecated': group_conv_deprecated,
        'image_input': image_input,
        'video_input': video_input,
        'add_weight_decay': add_weight_decay,
        'elementwise_linear': elementwise_linear,
        'layer_norm': layer_norm,
        'mat_mul': mat_mul,
        'batch_mat_mul': batch_mat_mul,
        'cond': cond,
        'loop': loop,
        'db_input': db_input,
        'fused_8bit_rowwise_quantized_to_float': fused_8bit_rowwise_quantized_to_float,
        'sparse_lengths_sum_4bit_rowwise_sparse': sparse_lengths_sum_4bit_rowwise_sparse,
    }

    def __init__(self, wrapped):
        """Remember the module object this wrapper replaces.

        Args:
            wrapped: the original module object taken from ``sys.modules``.
        """
        # NOTE(review): nothing in this class reads ``self.wrapped``. It is
        # presumably kept only to hold a reference to the original module so
        # its globals (the helpers referenced above) stay alive after
        # ``sys.modules`` points at this wrapper instead -- confirm.
        self.wrapped = wrapped

    @staticmethod
    def _helper_kwarg_names(func):
        """Describe which keyword arguments ``func`` can accept.

        ``inspect.getargspec`` was removed in Python 3.11 and, on earlier
        Python 3 releases, raises ValueError for functions with keyword-only
        parameters or annotations. ``inspect.getfullargspec`` is therefore
        preferred. ``getargspec`` is used only where ``getfullargspec`` does
        not exist (Python 2).

        Args:
            func: the helper callable to inspect.

        Returns:
            A ``(names, takes_var_kwargs)`` pair. ``names`` lists every
            parameter that may be passed by keyword (positional-or-keyword
            plus keyword-only). ``takes_var_kwargs`` is True when ``func``
            declares ``**kwargs``.
        """
        getfullargspec = getattr(inspect, 'getfullargspec', None)
        if getfullargspec is None:
            spec = inspect.getargspec(func)
            return list(spec.args), spec.keywords is not None
        spec = getfullargspec(func)
        return list(spec.args) + list(spec.kwonlyargs), spec.varkw is not None

    def __getattr__(self, helper_name):
        """Return registered helper ``helper_name`` wrapped with scoped kwargs.

        Python only calls this for names that normal attribute lookup does
        not find (instance attributes and class members take precedence).

        Args:
            helper_name: the attribute name being looked up.

        Returns:
            A function named ``helper_name`` that merges scoped keyword
            arguments into the call before invoking the registered helper.

        Raises:
            AttributeError: if ``helper_name`` is not in the registry.
        """
        if helper_name not in self._registry:
            raise AttributeError(
                "Helper function {} not "
                "registered.".format(helper_name)
            )

        def scope_wrapper(*args, **kwargs):
            new_kwargs = {}
            # arg_scope itself does not take a model, so it receives only
            # the current-scope entry and the caller's kwargs.
            if helper_name != 'arg_scope':
                if len(args) > 0 and isinstance(args[0], ModelHelper):
                    model = args[0]
                elif 'model' in kwargs:
                    model = kwargs['model']
                else:
                    raise RuntimeError(
                        "The first input of helper function should be model. "
                        "Or you can provide it in kwargs as model=<your_model>.")
                # Deep copy so the update() calls below (and anything the
                # helper does with its kwargs) never touch model.arg_scope.
                new_kwargs = copy.deepcopy(model.arg_scope)
            func = self._registry[helper_name]
            var_names, takes_var_kwargs = self._helper_kwarg_names(func)
            if not takes_var_kwargs:
                # This helper does not take arbitrary **kwargs, so keep only
                # the model-level defaults it actually declares; passing any
                # other key would make the call fail.
                new_kwargs = {
                    var_name: new_kwargs[var_name]
                    for var_name in var_names if var_name in new_kwargs
                }

            # Precedence (lowest to highest): model.arg_scope, the current
            # scope's entry for this helper, then explicit call kwargs.
            cur_scope = get_current_scope()
            new_kwargs.update(cur_scope.get(helper_name, {}))
            new_kwargs.update(kwargs)
            return func(*args, **new_kwargs)

        scope_wrapper.__name__ = helper_name
        return scope_wrapper

    def Register(self, helper):
        """Add ``helper`` to the shared registry under ``helper.__name__``.

        Args:
            helper: a callable; its ``__name__`` becomes the attribute name.

        Raises:
            AttributeError: if a helper with the same name is already
                registered.
        """
        name = helper.__name__
        if name in self._registry:
            raise AttributeError(
                "Helper {} already exists. Please change your "
                "helper name.".format(name)
            )
        self._registry[name] = helper

    def has_helper(self, helper_or_helper_name):
        """Report whether a helper is registered.

        Args:
            helper_or_helper_name: either a helper name (a string) or a
                callable whose ``__name__`` is looked up.

        Returns:
            True if that name is present in the registry.
        """
        helper_name = (
            helper_or_helper_name
            if isinstance(helper_or_helper_name, basestring) else
            helper_or_helper_name.__name__
        )
        return helper_name in self._registry
# Swap this module's entry in sys.modules for a HelperWrapper around it, so
# that later attribute lookups on the module name resolve through
# HelperWrapper.__getattr__. The temporary name is removed again so the
# module's globals end up exactly as before.
_original_module = sys.modules[__name__]
sys.modules[__name__] = HelperWrapper(_original_module)
del _original_module