Projekt_AI-Automatyczny_saper/venv/Lib/site-packages/caffe2/python/brew.py
2021-06-01 17:38:31 +02:00

139 lines
4.6 KiB
Python

## @package brew
# Module caffe2.python.brew
import sys
import copy
import inspect
from past.builtins import basestring
from caffe2.python.model_helper import ModelHelper
# flake8: noqa
from caffe2.python.helpers.algebra import *
from caffe2.python.helpers.arg_scope import *
from caffe2.python.helpers.array_helpers import *
from caffe2.python.helpers.control_ops import *
from caffe2.python.helpers.conv import *
from caffe2.python.helpers.db_input import *
from caffe2.python.helpers.dropout import *
from caffe2.python.helpers.elementwise_linear import *
from caffe2.python.helpers.fc import *
from caffe2.python.helpers.nonlinearity import *
from caffe2.python.helpers.normalization import *
from caffe2.python.helpers.pooling import *
from caffe2.python.helpers.quantization import *
from caffe2.python.helpers.tools import *
from caffe2.python.helpers.train import *
class HelperWrapper(object):
    """Module stand-in that dispatches attribute access to registered helpers.

    An instance of this class is installed in ``sys.modules`` in place of
    this module (see the last statement of the file), so an attribute lookup
    on the module goes through ``__getattr__`` and yields a wrapper around
    the matching entry of ``_registry``. The wrapper merges argument-scope
    defaults into the keyword arguments before calling the helper.
    """

    # Maps the public helper name to the helper callable. This is a class
    # attribute, so helpers added through ``Register`` are shared by every
    # HelperWrapper instance.
    _registry = {
        'arg_scope': arg_scope,
        'fc': fc,
        'packed_fc': packed_fc,
        'fc_decomp': fc_decomp,
        'fc_sparse': fc_sparse,
        'fc_prune': fc_prune,
        'dropout': dropout,
        'max_pool': max_pool,
        'average_pool': average_pool,
        'max_pool_with_index': max_pool_with_index,
        'lrn': lrn,
        'softmax': softmax,
        'instance_norm': instance_norm,
        'spatial_bn': spatial_bn,
        'spatial_gn': spatial_gn,
        'moments_with_running_stats': moments_with_running_stats,
        'relu': relu,
        'prelu': prelu,
        'tanh': tanh,
        'concat': concat,
        'depth_concat': depth_concat,
        'sum': sum,
        'reduce_sum': reduce_sum,
        'sub': sub,
        'arg_min': arg_min,
        'transpose': transpose,
        'iter': iter,
        'accuracy': accuracy,
        'conv': conv,
        'conv_nd': conv_nd,
        'conv_transpose': conv_transpose,
        'group_conv': group_conv,
        'group_conv_deprecated': group_conv_deprecated,
        'image_input': image_input,
        'video_input': video_input,
        'add_weight_decay': add_weight_decay,
        'elementwise_linear': elementwise_linear,
        'layer_norm': layer_norm,
        'mat_mul': mat_mul,
        'batch_mat_mul': batch_mat_mul,
        'cond': cond,
        'loop': loop,
        'db_input': db_input,
        'fused_8bit_rowwise_quantized_to_float':
            fused_8bit_rowwise_quantized_to_float,
        'sparse_lengths_sum_4bit_rowwise_sparse':
            sparse_lengths_sum_4bit_rowwise_sparse,
    }

    def __init__(self, wrapped):
        """Store ``wrapped`` (the object this wrapper stands in for)."""
        self.wrapped = wrapped

    def __getattr__(self, helper_name):
        """Return a scope-aware wrapper around the helper ``helper_name``.

        Raises:
            AttributeError: if ``helper_name`` is not in the registry.
        """
        if helper_name not in self._registry:
            raise AttributeError(
                "Helper function {} not "
                "registered.".format(helper_name)
            )

        def scope_wrapper(*args, **kwargs):
            """Call the helper with scope defaults merged into ``kwargs``.

            Precedence, lowest to highest: the model's ``arg_scope``, the
            entry for this helper in ``get_current_scope()``, and finally
            the keyword arguments given explicitly by the caller.

            Raises:
                RuntimeError: if the helper is not ``arg_scope`` and no
                    model is given as first positional argument or as the
                    ``model`` keyword argument.
            """
            new_kwargs = {}
            if helper_name != 'arg_scope':
                if len(args) > 0 and isinstance(args[0], ModelHelper):
                    model = args[0]
                elif 'model' in kwargs:
                    model = kwargs['model']
                else:
                    raise RuntimeError(
                        "The first input of helper function should be model. "
                        "Or you can provide it in kwargs as model=<your_model>.")
                # Deep copy so the model's own arg_scope is never mutated by
                # the updates below or by the helper itself.
                new_kwargs = copy.deepcopy(model.arg_scope)
            func = self._registry[helper_name]

            # inspect.getargspec was removed in Python 3.11 and rejects
            # functions with keyword-only parameters or annotations; prefer
            # getfullargspec and fall back only where it does not exist.
            argspec_fn = (
                getattr(inspect, 'getfullargspec', None) or inspect.getargspec
            )
            spec = argspec_fn(func)
            # Index access works for both result types: position 0 is the
            # positional parameter names, position 2 the **kwargs name.
            var_names = list(spec[0])
            varkw = spec[2]
            # Keyword-only parameters (getfullargspec only) can also be
            # supplied from the scope defaults.
            var_names.extend(getattr(spec, 'kwonlyargs', None) or [])
            if varkw is None:
                # this helper function does not take in random **kwargs
                new_kwargs = {
                    var_name: new_kwargs[var_name]
                    for var_name in var_names if var_name in new_kwargs
                }

            cur_scope = get_current_scope()
            new_kwargs.update(cur_scope.get(helper_name, {}))
            new_kwargs.update(kwargs)
            return func(*args, **new_kwargs)

        scope_wrapper.__name__ = helper_name
        return scope_wrapper

    def Register(self, helper):
        """Add ``helper`` to the registry under ``helper.__name__``.

        Raises:
            AttributeError: if a helper with the same name already exists.
        """
        name = helper.__name__
        if name in self._registry:
            raise AttributeError(
                "Helper {} already exists. Please change your "
                "helper name.".format(name)
            )
        self._registry[name] = helper

    def has_helper(self, helper_or_helper_name):
        """Return True if the given helper (or helper name) is registered.

        ``helper_or_helper_name`` is either a string name or an object whose
        ``__name__`` attribute is used as the name.
        """
        helper_name = (
            helper_or_helper_name
            if isinstance(helper_or_helper_name, basestring) else
            helper_or_helper_name.__name__
        )
        return helper_name in self._registry
# Swap this module's entry in sys.modules for a HelperWrapper instance so that
# attribute access on the imported module is routed through
# HelperWrapper.__getattr__. The original module object is handed to the
# wrapper, which stores it as `wrapped`.
sys.modules[__name__] = HelperWrapper(wrapped=sys.modules[__name__])