## @package model_helper_api
# Module caffe2.python.model_helper_api
"""Attribute-style access to caffe2 model helpers with arg_scope support.

At import time this module replaces itself in ``sys.modules`` with a
``HelperWrapper`` instance (see the last line of the file). As a result,
``model_helper_api.<name>`` returns a wrapper around the helper registered
under ``<name>``.
"""

import sys
import copy
import inspect
from past.builtins import basestring
from caffe2.python.model_helper import ModelHelper

# flake8: noqa
from caffe2.python.helpers.algebra import *
from caffe2.python.helpers.arg_scope import *
from caffe2.python.helpers.array_helpers import *
from caffe2.python.helpers.control_ops import *
from caffe2.python.helpers.conv import *
from caffe2.python.helpers.db_input import *
from caffe2.python.helpers.dropout import *
from caffe2.python.helpers.elementwise_linear import *
from caffe2.python.helpers.fc import *
from caffe2.python.helpers.nonlinearity import *
from caffe2.python.helpers.normalization import *
from caffe2.python.helpers.pooling import *
from caffe2.python.helpers.quantization import *
from caffe2.python.helpers.tools import *
from caffe2.python.helpers.train import *


class HelperWrapper(object):
    """Module proxy that exposes registered helper functions by name.

    Attribute access goes through ``__getattr__``. It returns a wrapper that
    merges keyword arguments from the model's ``arg_scope`` and the currently
    active ``arg_scope`` context before calling the underlying helper.
    """

    # Public helper name -> helper function. This is a class attribute, so
    # helpers added through Register() are shared by every instance.
    _registry = {
        'arg_scope': arg_scope,
        'fc': fc,
        'packed_fc': packed_fc,
        'fc_decomp': fc_decomp,
        'fc_sparse': fc_sparse,
        'fc_prune': fc_prune,
        'dropout': dropout,
        'max_pool': max_pool,
        'average_pool': average_pool,
        'max_pool_with_index' : max_pool_with_index,
        'lrn': lrn,
        'softmax': softmax,
        'instance_norm': instance_norm,
        'spatial_bn': spatial_bn,
        'spatial_gn': spatial_gn,
        'moments_with_running_stats': moments_with_running_stats,
        'relu': relu,
        'prelu': prelu,
        'tanh': tanh,
        'concat': concat,
        'depth_concat': depth_concat,
        'sum': sum,
        'reduce_sum': reduce_sum,
        'sub': sub,
        'arg_min': arg_min,
        'transpose': transpose,
        'iter': iter,
        'accuracy': accuracy,
        'conv': conv,
        'conv_nd': conv_nd,
        'conv_transpose': conv_transpose,
        'group_conv': group_conv,
        'group_conv_deprecated': group_conv_deprecated,
        'image_input': image_input,
        'video_input': video_input,
        'add_weight_decay': add_weight_decay,
        'elementwise_linear': elementwise_linear,
        'layer_norm': layer_norm,
        'mat_mul' : mat_mul,
        'batch_mat_mul' : batch_mat_mul,
        'cond' : cond,
        'loop' : loop,
        'db_input' : db_input,
        'fused_8bit_rowwise_quantized_to_float' : fused_8bit_rowwise_quantized_to_float,
        'sparse_lengths_sum_4bit_rowwise_sparse': sparse_lengths_sum_4bit_rowwise_sparse,
    }

    def __init__(self, wrapped):
        """Store the original module object this instance stands in for.

        Args:
            wrapped: the module object previously held in ``sys.modules``.
        """
        # NOTE(review): presumably kept so the replaced module object (whose
        # globals the helpers rely on) stays referenced - confirm.
        self.wrapped = wrapped

    def __getattr__(self, helper_name):
        """Return a callable that invokes the registered helper ``helper_name``.

        Keyword arguments passed to the helper are merged in increasing
        order of precedence:

        1. a deep copy of ``model.arg_scope``, except for the ``arg_scope``
           helper itself;
        2. the entry for this helper in the current ``arg_scope`` context;
        3. the caller's explicit keyword arguments.

        Sources 1 and 2 are filtered to the helper's named parameters unless
        the helper accepts ``**kwargs``. The caller's explicit kwargs in
        source 3 are always passed through.

        Raises:
            AttributeError: if ``helper_name`` is not registered.
        """
        if helper_name not in self._registry:
            raise AttributeError(
                "Helper function {} not "
                "registered.".format(helper_name)
            )
        # Register() never replaces an existing entry, so resolving the
        # helper once here is equivalent to resolving it on every call.
        func = self._registry[helper_name]

        def scope_wrapper(*args, **kwargs):
            new_kwargs = {}
            if helper_name != 'arg_scope':
                # The model may be passed positionally (first arg) or as
                # the `model=` keyword.
                if len(args) > 0 and isinstance(args[0], ModelHelper):
                    model = args[0]
                elif 'model' in kwargs:
                    model = kwargs['model']
                else:
                    raise RuntimeError(
                        "The first input of helper function should be model. "
                        "Or you can provide it in kwargs as model=.")
                # Deep copy so the helper cannot mutate the model's defaults.
                new_kwargs = copy.deepcopy(model.arg_scope)

            # inspect.getargspec was removed in Python 3.11 (and raised
            # ValueError for keyword-only parameters); getfullargspec is the
            # supported replacement.
            spec = inspect.getfullargspec(func)
            if spec.varkw is None:
                # The helper does not take arbitrary **kwargs: only forward
                # scoped defaults it actually declares as parameters.
                accepted_names = list(spec.args) + list(spec.kwonlyargs)
                new_kwargs = {
                    var_name: new_kwargs[var_name]
                    for var_name in accepted_names
                    if var_name in new_kwargs
                }

            # get_current_scope() comes from the arg_scope helpers module.
            cur_scope = get_current_scope()
            new_kwargs.update(cur_scope.get(helper_name, {}))
            new_kwargs.update(kwargs)
            return func(*args, **new_kwargs)

        scope_wrapper.__name__ = helper_name
        # Forward the helper's documentation so help()/introspection works.
        scope_wrapper.__doc__ = func.__doc__
        return scope_wrapper

    def Register(self, helper):
        """Register ``helper`` under its ``__name__``.

        Raises:
            AttributeError: if a helper with the same name already exists.
        """
        name = helper.__name__
        if name in self._registry:
            raise AttributeError(
                "Helper {} already exists. Please change your "
                "helper name.".format(name)
            )
        self._registry[name] = helper

    def has_helper(self, helper_or_helper_name):
        """Return True if a helper with this name (or this function's
        ``__name__``) is registered."""
        helper_name = (
            helper_or_helper_name
            if isinstance(helper_or_helper_name, basestring) else
            helper_or_helper_name.__name__
        )
        return helper_name in self._registry


# Replace this module with the wrapper so `model_helper_api.<helper>` is
# resolved through HelperWrapper.__getattr__.
sys.modules[__name__] = HelperWrapper(sys.modules[__name__])