Intelegentny_Pszczelarz/.venv/Lib/site-packages/tensorflow/python/framework/c_api_util.py

285 lines
8.4 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for using the TensorFlow C API."""
import contextlib
from tensorflow.core.framework import api_def_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib
class AlreadyGarbageCollectedError(Exception):
  """Signals use of a wrapped C-API object after it has been released.

  Args:
    name: Name of the wrapper; interpolated into the error message.
    obj_type: Type description of the wrapped object; interpolated into the
      error message.
  """

  def __init__(self, name, obj_type):
    message = (f"{name} of type {obj_type} has already been garbage "
               "collected and cannot be called.")
    super().__init__(message)
# FIXME(b/235488206): Convert all Scoped objects to the context manager
# to protect against deletion during use when the object is attached to
# an attribute.
class UniquePtr(object):
  """Wrapper around single-ownership C-API objects that handles deletion.

  The wrapped object is passed to `deleter` at most once, from `__del__`.
  After that, `get()` raises `AlreadyGarbageCollectedError`.
  """

  __slots__ = ["_obj", "deleter", "name", "type_name"]

  def __init__(self, name, obj, deleter):
    # '_' prefix marks _obj private, but unclear if it is required also to
    # maintain a special CPython destruction order.
    self._obj = obj
    self.name = name
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we may have already deleted other modules. By capturing the
    # DeleteGraph function here, we retain the ability to cleanly destroy the
    # graph at shutdown, which satisfies leak checkers.
    self.deleter = deleter
    self.type_name = str(type(obj))

  @contextlib.contextmanager
  def get(self):
    """Yields the managed C-API Object, guaranteeing aliveness.

    This is a context manager. Inside the context the C-API object is
    guaranteed to be alive.

    Raises:
      AlreadyGarbageCollectedError: if the object is already deleted.
    """
    # Thread-safety: self.__del__ never runs during the call of this function
    # because there is a reference to self from the argument list.
    if self._obj is None:
      raise AlreadyGarbageCollectedError(self.name, self.type_name)
    yield self._obj

  def __del__(self):
    # Use getattr: a subclass may fail while computing the arguments to
    # UniquePtr.__init__ (e.g. ScopedTFGraph calls c_api.TF_NewGraph() first),
    # leaving an instance whose `_obj` slot was never assigned. Reading an
    # unassigned slot raises AttributeError, which would be reported as an
    # ignored exception during finalization.
    obj = getattr(self, "_obj", None)
    if obj is not None:
      # Clear first so a repeated __del__ call cannot delete twice.
      self._obj = None
      self.deleter(obj)
class ScopedTFStatus(object):
  """Owns a TF_Status created via the C API and frees it on collection."""

  __slots__ = ["status"]

  def __init__(self):
    self.status = c_api.TF_NewStatus()

  def __del__(self):
    # While the process is terminating, other modules (including c_api or its
    # deleter) may already have been torn down; skip cleanup in that case.
    if c_api is None or c_api.TF_DeleteStatus is None:
      return
    c_api.TF_DeleteStatus(self.status)
class ScopedTFGraph(UniquePtr):
  """Owns a freshly created TF_Graph; TF_DeleteGraph frees it on collection."""

  def __init__(self, name):
    new_graph = c_api.TF_NewGraph()
    super().__init__(name, obj=new_graph, deleter=c_api.TF_DeleteGraph)
class ScopedTFImportGraphDefOptions(object):
  """Owns a TF_ImportGraphDefOptions handle and frees it on collection."""

  __slots__ = ["options"]

  def __init__(self):
    self.options = c_api.TF_NewImportGraphDefOptions()

  def __del__(self):
    # Module globals may already be gone at process shutdown, so only free
    # the handle when the C API and its deleter are still reachable.
    api = c_api
    if api is None or api.TF_DeleteImportGraphDefOptions is None:
      return
    api.TF_DeleteImportGraphDefOptions(self.options)
class ScopedTFImportGraphDefResults(object):
  """Wrapper around TF_ImportGraphDefResults that handles deletion.

  Ownership of `results` passes to this wrapper: it is released with
  `c_api.TF_DeleteImportGraphDefResults` when the wrapper is collected.

  Args (constructor):
    results: the results handle to own (presumably returned by a graph-import
      C API call; confirm at call sites).
  """

  __slots__ = ["results"]

  def __init__(self, results):
    self.results = results

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteImportGraphDefResults is not None:
      c_api.TF_DeleteImportGraphDefResults(self.results)
class ScopedTFFunction(UniquePtr):
  """Takes ownership of an existing TF_Function and deletes it on collection.

  Args (constructor):
    func: the TF_Function handle to own.
    name: name reported if the function is used after deletion.
  """

  def __init__(self, func, name):
    deleter = c_api.TF_DeleteFunction
    super().__init__(name=name, obj=func, deleter=deleter)
class ScopedTFBuffer(object):
  """An internal class to help manage the TF_Buffer lifetime.

  Args (constructor):
    buf_string: contents of the buffer; converted with `compat.as_bytes`
      before being copied into a new TF_Buffer.
  """

  __slots__ = ["buffer"]

  def __init__(self, buf_string):
    self.buffer = c_api.TF_NewBufferFromString(compat.as_bytes(buf_string))

  def __del__(self):
    # If __init__ raised (e.g. if `compat.as_bytes` or the C call failed), the
    # `buffer` slot was never assigned and there is nothing to free. Reading an
    # unassigned slot directly would raise AttributeError during finalization.
    buf = getattr(self, "buffer", None)
    if buf is None:
      return
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules. This is the same
    # guard the other Scoped* wrappers in this module use.
    if c_api is not None and c_api.TF_DeleteBuffer is not None:
      c_api.TF_DeleteBuffer(buf)
class ApiDefMap(object):
  """Wrapper around Tf_ApiDefMap that handles querying and deletion.

  The OpDef protos are also stored in this class so that they could
  be queried by op name.
  """

  __slots__ = ["_api_def_map", "_op_per_name"]

  def __init__(self):
    op_def_proto = op_def_pb2.OpList()
    buf = c_api.TF_GetAllOpList()
    try:
      op_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
      self._api_def_map = c_api.TF_NewApiDefMap(buf)
    finally:
      # The op-list buffer is freed even if parsing or map creation raised.
      c_api.TF_DeleteBuffer(buf)
    # Later duplicates (if any) overwrite earlier ones, as with a plain loop.
    self._op_per_name = {op.name: op for op in op_def_proto.op}

  def __del__(self):
    # If __init__ raised before `_api_def_map` was assigned, the slot is unset;
    # reading it directly would raise AttributeError during finalization.
    api_def_map = getattr(self, "_api_def_map", None)
    if api_def_map is None:
      return
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteApiDefMap is not None:
      c_api.TF_DeleteApiDefMap(api_def_map)

  def put_api_def(self, text):
    """Adds ApiDef text to the underlying map via `c_api.TF_ApiDefMapPut`."""
    # NOTE(review): len() counts characters for `str` input. If the binding
    # encodes str to UTF-8, non-ASCII text would pass a length smaller than
    # the byte count; confirm callers only pass ASCII or bytes.
    c_api.TF_ApiDefMapPut(self._api_def_map, text, len(text))

  def get_api_def(self, op_name):
    """Returns the `api_def_pb2.ApiDef` stored for `op_name`.

    The serialized result buffer from the C API is always freed, even if
    parsing fails.
    """
    api_def_proto = api_def_pb2.ApiDef()
    buf = c_api.TF_ApiDefMapGet(self._api_def_map, op_name, len(op_name))
    try:
      api_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
    finally:
      c_api.TF_DeleteBuffer(buf)
    return api_def_proto

  def get_op_def(self, op_name):
    """Returns the OpDef proto registered under `op_name`.

    Raises:
      ValueError: if no op with that name was found at construction time.
    """
    try:
      return self._op_per_name[op_name]
    except KeyError:
      # `from None` keeps the traceback identical to a plain raise.
      raise ValueError(f"No op_def found for op name {op_name}.") from None

  def op_names(self):
    """Returns a view of all known op names."""
    return self._op_per_name.keys()
@tf_contextlib.contextmanager
def tf_buffer(data=None):
  """Provides a TF_Buffer that lives only for the duration of a `with` block.

  Example usage:
    with tf_buffer() as buf:
      # get serialized graph def into buf
      ...
      proto_data = c_api.TF_GetBuffer(buf)
      graph_def.ParseFromString(compat.as_bytes(proto_data))
    # buf has been deleted

    with tf_buffer(some_string) as buf:
      c_api.TF_SomeFunction(buf)
    # buf has been deleted

  Args:
    data: An optional `bytes`, `str`, or `unicode` object. When truthy, the
      buffer is created from `compat.as_bytes(data)`; otherwise (None or
      empty) an empty buffer is created with `TF_NewBuffer`.

  Yields:
    Created TF_Buffer
  """
  if not data:
    buffer_handle = c_api.TF_NewBuffer()
  else:
    buffer_handle = c_api.TF_NewBufferFromString(compat.as_bytes(data))
  try:
    yield buffer_handle
  finally:
    # Freed on normal exit and when the `with` body raises.
    c_api.TF_DeleteBuffer(buffer_handle)
def tf_output(c_op, index):
  """Builds a TF_Output referring to output `index` of operation `c_op`.

  Args:
    c_op: wrapped TF_Operation
    index: integer

  Returns:
    Wrapped TF_Output
  """
  output = c_api.TF_Output()
  output.oper, output.index = c_op, index
  return output
def tf_operations(graph):
  """Iterates over every TF_Operation contained in `graph`.

  The graph's C handle is held open (via `graph._c_graph.get()`) for as long
  as the generator is being consumed.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # pylint: disable=protected-access
  with graph._c_graph.get() as c_graph:
    cursor = 0
    while True:
      # TF_GraphNextOperation returns (op, next_cursor); None marks the end.
      current, cursor = c_api.TF_GraphNextOperation(c_graph, cursor)
      if current is None:
        break
      yield current
  # pylint: enable=protected-access
def new_tf_operations(graph):
  """Iterates over TF_Operations in `graph` that have no Python Operation yet.

  An operation counts as "new" when `graph._get_operation_by_tf_operation`
  raises KeyError for it. This is useful for processing nodes added by the
  C API.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # TODO(b/69679162): do this more efficiently

  def _is_unregistered(c_op):
    # True when the graph has no Python-side Operation for `c_op`.
    try:
      graph._get_operation_by_tf_operation(c_op)  # pylint: disable=protected-access
    except KeyError:
      return True
    return False

  yield from filter(_is_unregistered, tf_operations(graph))