Intelegentny_Pszczelarz/.venv/Lib/site-packages/tensorflow/python/tools/selective_registration_header_lib.py

249 lines
8.2 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Computes a header file to be used with SELECTIVE_REGISTRATION.
See the executable wrapper, print_selective_registration_header.py, for more
information.
"""
import json
import os
import sys
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging
from tensorflow.python.util import _pywrap_kernel_registry
# Usually, we use each graph node to induce registration of an op and
# corresponding kernel; nodes without a corresponding kernel (perhaps due to
# attr types) generate a warning but are otherwise ignored. Ops in this set are
# registered even if there's no corresponding kernel.
OPS_WITHOUT_KERNEL_ALLOWLIST = frozenset([
# AccumulateNV2 is rewritten away by AccumulateNV2RemovePass; see
# core/common_runtime/accumulate_n_optimizer.cc.
'AccumulateNV2'
])
FLEX_PREFIX = b'Flex'
FLEX_PREFIX_LENGTH = len(FLEX_PREFIX)
def _get_ops_from_ops_list(input_file):
"""Gets the ops and kernels needed from the ops list file."""
ops = set()
ops_list_str = gfile.GFile(input_file, 'r').read()
if not ops_list_str:
raise Exception('Input file should not be empty')
ops_list = json.loads(ops_list_str)
for op, kernel in ops_list:
op_and_kernel = (op, kernel if kernel else None)
ops.add(op_and_kernel)
return ops
def _get_ops_from_graphdef(graph_def):
"""Gets the ops and kernels needed from the tensorflow model."""
ops = set()
ops.update(_get_ops_from_nodedefs(graph_def.node))
for function in graph_def.library.function:
ops.update(_get_ops_from_nodedefs(function.node_def))
return ops
def get_ops_from_nodedef(node_def):
"""Gets the op and kernel needed from the given NodeDef.
Args:
node_def: TF NodeDef to get op/kernel information.
Returns:
A tuple of (op_name, kernel_name). If the op is not in the allowlist of ops
without kernel and there is no kernel found, then return None.
"""
if not node_def.device:
node_def.device = '/cpu:0'
kernel_class = _pywrap_kernel_registry.TryFindKernelClass(
node_def.SerializeToString())
op = str(node_def.op)
if kernel_class or op in OPS_WITHOUT_KERNEL_ALLOWLIST:
return (op, str(kernel_class.decode('utf-8')) if kernel_class else None)
else:
tf_logging.warning('Warning: no kernel found for op %s', op)
return None
def _get_ops_from_nodedefs(node_defs):
"""Gets the ops and kernels needed from the list of NodeDef.
If a NodeDef's op is not in the allowlist of ops without kernel and there is
no kernel found for this NodeDef, then skip that NodeDef and proceed to the
next one.
Args:
node_defs: list of NodeDef's to get op/kernel information.
Returns:
A set of (op_name, kernel_name) tuples.
"""
ops = set()
for node_def in node_defs:
op_and_kernel = get_ops_from_nodedef(node_def)
if op_and_kernel:
ops.add(op_and_kernel)
return ops
def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):
"""Gets the ops and kernels needed from the model files."""
ops = set()
for proto_file in proto_files:
tf_logging.info('Loading proto file %s', proto_file)
# Load ops list file.
if proto_fileformat == 'ops_list':
ops = ops.union(_get_ops_from_ops_list(proto_file))
continue
# Load GraphDef.
file_data = gfile.GFile(proto_file, 'rb').read()
if proto_fileformat == 'rawproto':
graph_def = graph_pb2.GraphDef.FromString(file_data)
else:
assert proto_fileformat == 'textproto'
graph_def = text_format.Parse(file_data, graph_pb2.GraphDef())
ops = ops.union(_get_ops_from_graphdef(graph_def))
# Add default ops.
if default_ops_str and default_ops_str != 'all':
for s in default_ops_str.split(','):
op, kernel = s.split(':')
op_and_kernel = (op, kernel)
if op_and_kernel not in ops:
ops.add(op_and_kernel)
return sorted(ops)
def get_header_from_ops_and_kernels(ops_and_kernels,
include_all_ops_and_kernels):
"""Returns a header for use with tensorflow SELECTIVE_REGISTRATION.
Args:
ops_and_kernels: a set of (op_name, kernel_class_name) pairs to include.
include_all_ops_and_kernels: if True, ops_and_kernels is ignored and all op
kernels are included.
Returns:
the string of the header that should be written as ops_to_register.h.
"""
ops_and_kernels = sorted(ops_and_kernels)
ops = set(op for op, _ in ops_and_kernels)
result_list = []
def append(s):
result_list.append(s)
_, script_name = os.path.split(sys.argv[0])
append('// This file was autogenerated by %s' % script_name)
append('#ifndef OPS_TO_REGISTER')
append('#define OPS_TO_REGISTER')
if include_all_ops_and_kernels:
append('#define SHOULD_REGISTER_OP(op) true')
append('#define SHOULD_REGISTER_OP_KERNEL(clz) true')
append('#define SHOULD_REGISTER_OP_GRADIENT true')
else:
line = """
namespace {
constexpr const char* skip(const char* x) {
return (*x) ? (*x == ' ' ? skip(x + 1) : x) : x;
}
constexpr bool isequal(const char* x, const char* y) {
return (*skip(x) && *skip(y))
? (*skip(x) == *skip(y) && isequal(skip(x) + 1, skip(y) + 1))
: (!*skip(x) && !*skip(y));
}
template<int N>
struct find_in {
static constexpr bool f(const char* x, const char* const y[N]) {
return isequal(x, y[0]) || find_in<N - 1>::f(x, y + 1);
}
};
template<>
struct find_in<0> {
static constexpr bool f(const char* x, const char* const y[]) {
return false;
}
};
} // end namespace
"""
line += 'constexpr const char* kNecessaryOpKernelClasses[] = {\n'
for _, kernel_class in ops_and_kernels:
if kernel_class is None:
continue
line += '"%s",\n' % kernel_class
line += '};'
append(line)
append('#define SHOULD_REGISTER_OP_KERNEL(clz) '
'(find_in<sizeof(kNecessaryOpKernelClasses) '
'/ sizeof(*kNecessaryOpKernelClasses)>::f(clz, '
'kNecessaryOpKernelClasses))')
append('')
append('constexpr inline bool ShouldRegisterOp(const char op[]) {')
append(' return false')
for op in sorted(ops):
append(' || isequal(op, "%s")' % op)
append(' ;')
append('}')
append('#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)')
append('')
append('#define SHOULD_REGISTER_OP_GRADIENT ' +
('true' if 'SymbolicGradient' in ops else 'false'))
append('#endif')
return '\n'.join(result_list)
def get_header(graphs,
proto_fileformat='rawproto',
default_ops='NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'):
"""Computes a header for use with tensorflow SELECTIVE_REGISTRATION.
Args:
graphs: a list of paths to GraphDef files to include.
proto_fileformat: optional format of proto file, either 'textproto',
'rawproto' (default) or ops_list. The ops_list is the file contain the
list of ops in JSON format, Ex: "[["Transpose", "TransposeCpuOp"]]".
default_ops: optional comma-separated string of operator:kernel pairs to
always include implementation for. Pass 'all' to have all operators and
kernels included. Default: 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'.
Returns:
the string of the header that should be written as ops_to_register.h.
"""
ops_and_kernels = get_ops_and_kernels(proto_fileformat, graphs, default_ops)
if not ops_and_kernels:
print('Error reading graph!')
return 1
return get_header_from_ops_and_kernels(ops_and_kernels, default_ops == 'all')