2023-06-19 00:49:18 +02:00

146 lines
5.4 KiB

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FuncGraph and related functionality."""
import collections as py_collections
import dataclasses
import inspect
from typing import Any, Callable, Hashable, Mapping
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import type_spec
from tensorflow.python.types import core
from tensorflow.python.util import nest
# The dataclass decorator generates the positional constructor that
# `FunctionCaptures.capture_by_ref` relies on. `frozen=True` makes a recorded
# capture immutable once created.
# NOTE(review): confirm no caller outside this file mutates a container.
@dataclasses.dataclass(frozen=True)
class CaptureContainer():
  """A container for both by-reference and by-value captures.

  Attributes:
    external: Used to record the tensor external to the func_graph.
      For by-value captures, it would be the original tensor.
      For by-reference captures, it would be the lambda function, which will
      be called later to get the capture's runtime value.
    internal: An internal placeholder for the capture, or a constant tensor.
      The external value of the capture will be fed to this internal
      placeholder when executing the func_graph as a side input.
    idf: A Hashable identifier for the capture.
    is_by_ref: A bool indicating whether the capture is by reference (True)
      or by value (False). This flag determines how
      `CaptureContainer.internal` is used.
  """
  external: Any
  internal: core.Tensor
  idf: Hashable
  is_by_ref: bool = False
class FunctionCaptures(object):
  """A container for all capture usages within FuncGraph.

  By-reference captures live in `_by_ref` and by-value captures in `_by_val`.
  Both are ordered dicts mapping a capture identifier to its
  `CaptureContainer`.
  """

  def __init__(self):
    # Dict that maps capture identifier -> CaptureContainer
    self._by_ref = py_collections.OrderedDict()
    self._by_val = py_collections.OrderedDict()

  def capture_by_val(self,
                     value: Any,
                     idf: Hashable = None):
    """Create a by-value capture if not exists.

    Args:
      value: The value to capture.
      idf: Optional hashable identifier for the capture.

    Raises:
      NotImplementedError: Always; by-value capture is not implemented here.
    """
    raise NotImplementedError()

  def capture_by_ref(self,
                     lam: Callable[[], Any],
                     idf: Hashable = None):
    """Create a by-reference capture if not exists.

    Args:
      lam: Zero-argument callable that returns the capture's current value.
      idf: Optional hashable identifier. When None, the current number of
        by-ref captures is used as the identifier.

    Returns:
      The `internal` value of the already-registered capture when `idf` is
      known; otherwise the `internal` value of the newly created capture.
      When executing eagerly, `lam()` is returned directly and nothing is
      recorded.
    """
    # Reuse the existing capture when the identifier is already registered.
    if idf is not None and idf in self._by_ref:
      return self._by_ref[idf].internal
    if idf is None:
      idf = len(self._by_ref)

    # No placeholder is needed in eager mode: evaluate the lambda right away.
    if context.executing_eagerly():
      return lam()

    placeholder = self._create_capture_placeholder(lam)
    capture = CaptureContainer(lam, placeholder, idf, is_by_ref=True)
    self._by_ref[idf] = capture
    return capture.internal

  def merge_by_ref_with(self, other: "FunctionCaptures"):
    """Add by-ref captures from `other` to `self` if not exist.

    Captures whose identifier is already present in `self` are left untouched.

    Args:
      other: The `FunctionCaptures` whose by-ref captures are merged in.
    """
    assert isinstance(other, FunctionCaptures)
    for key, capture in other.by_ref_captures.items():
      self._by_ref.setdefault(key, capture)

  def get_by_ref_snapshot(self) -> Mapping[Hashable, Any]:
    """Get a snapshot of current values of by-ref captures.

    Returns:
      A dict mapping each by-ref capture identifier to the value returned by
      calling that capture's `external` callable now.
    """
    return {key: capture.external() for key, capture in self._by_ref.items()}

  # TODO(panzf): Use FunctionType/TraceType to create placeholder here.
  def _create_capture_placeholder(self, func: Callable[[], Any]) -> Any:
    """Create placeholder if the input is tensor.

    Calls `func`, flattens its result, and swaps every tensor leaf for a
    placeholder obtained from the default graph. Non-tensor leaves are kept
    as they are.

    Args:
      func: Zero-argument callable returning the (possibly nested) side input.

    Returns:
      A structure shaped like `func()`'s result, with each tensor leaf
      replaced by its placeholder.

    Raises:
      NotImplementedError: If a leaf of the result is a `set` or `frozenset`.
    """
    values_nest = func()
    values_flat = nest.flatten(values_nest)
    # Return values in flat format. It consists of placeholders and non-tensor
    # values.
    return_flat = []
    tensor_spec_flat = []
    # Positions of tensor leaves inside `return_flat`. Tracking them
    # explicitly (rather than scanning for None afterwards) keeps a genuine
    # non-tensor `None` leaf from being mistaken for a tensor slot.
    tensor_positions = []
    for position, value in enumerate(values_flat):
      if isinstance(value, core.Tensor):
        tensor_positions.append(position)
        return_flat.append(None)  # Filled in with a placeholder below.
        tensor_spec_flat.append(type_spec.type_spec_from_value(value))
      elif isinstance(value, (set, frozenset)):
        raise NotImplementedError(
            (f"Side input returned by '{inspect.getsource(func).strip()}' "
             f"has element of {type(value)} type, which is currently not "
             "supported by tf.function."))
      else:
        return_flat.append(value)

    if tensor_spec_flat:

      def tensor_func():
        # Re-evaluate the side input and keep only its tensor leaves.
        values = nest.flatten(func())
        return [value for value in values if isinstance(value, core.Tensor)]

      # TODO(panzf): remove get_default_graph after moving
      # capture_call_time_value to this class.
      graph = ops.get_default_graph()
      # NOTE(review): assumes one placeholder is returned per spec, in the
      # same order as `tensor_spec_flat` -- confirm against the graph API.
      placeholder_flat = graph.capture_call_time_value(
          tensor_func, tensor_spec_flat)
      # Replace the slots that represent tensors with placeholders.
      for position, placeholder in zip(tensor_positions, placeholder_flat):
        return_flat[position] = placeholder

    return nest.pack_sequence_as(values_nest, return_flat)

  @property
  def by_ref_captures(self):
    """The internal dict of by-ref captures (identifier -> container)."""
    return self._by_ref

  @property
  def by_val_captures(self):
    """The internal dict of by-value captures (identifier -> container)."""
    return self._by_val