133 lines
5.5 KiB
Python
133 lines
5.5 KiB
Python
![]() |
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""The implementation of `tf.data.Dataset.group_by_window`."""
|
||
|
|
||
|
from tensorflow.python.data.ops import dataset_ops
|
||
|
from tensorflow.python.data.ops import structured_function
|
||
|
from tensorflow.python.framework import dtypes
|
||
|
from tensorflow.python.framework import ops
|
||
|
from tensorflow.python.framework import tensor_spec
|
||
|
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
|
||
|
|
||
|
|
||
|
def _group_by_window(input_dataset, # pylint: disable=unused-private-name
|
||
|
key_func,
|
||
|
reduce_func,
|
||
|
window_size=None,
|
||
|
window_size_func=None,
|
||
|
name=None):
|
||
|
"""See `Dataset.group_by_window()` for details."""
|
||
|
|
||
|
if (window_size is not None and window_size_func or
|
||
|
not (window_size is not None or window_size_func)):
|
||
|
raise ValueError("Either the `window_size` argument or the "
|
||
|
"`window_size_func` argument must be specified.")
|
||
|
|
||
|
if window_size is not None:
|
||
|
|
||
|
def constant_window_func(unused_key):
|
||
|
return ops.convert_to_tensor(window_size, dtype=dtypes.int64)
|
||
|
|
||
|
window_size_func = constant_window_func
|
||
|
|
||
|
assert window_size_func is not None
|
||
|
|
||
|
return _GroupByWindowDataset(
|
||
|
input_dataset, key_func, reduce_func, window_size_func, name=name)
|
||
|
|
||
|
|
||
|
class _GroupByWindowDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that groups its input and performs a windowed reduction.

  The three user callables (`key_func`, `reduce_func`, `window_size_func`) are
  each wrapped in a `StructuredFunctionWrapper`, validated, and then handed to
  the `group_by_window_dataset` op together with their captured inputs.
  """

  def __init__(self,
               input_dataset,
               key_func,
               reduce_func,
               window_size_func,
               name=None):
    """Wraps the user functions and creates the dataset variant tensor.

    Args:
      input_dataset: The dataset whose elements are grouped.
      key_func: Callable applied to input elements; its result is converted to
        a `tf.int64` tensor.
      reduce_func: Callable taking a scalar `tf.int64` key and a nested dataset
        with the input's element spec; it must return a dataset.
      window_size_func: Callable taking a scalar `tf.int64` key; its result is
        converted to a `tf.int64` tensor.
      name: Optional name stored on the instance as `_name`.

    Raises:
      ValueError: If `key_func` or `window_size_func` does not produce a
        structure compatible with a scalar `tf.int64` tensor.
      TypeError: If `reduce_func` does not produce a dataset structure.
    """
    self._input_dataset = input_dataset

    # Each helper stores its wrapped function on `self` and raises if the
    # function's output structure is not acceptable.
    self._make_key_func(key_func, input_dataset)
    self._make_reduce_func(reduce_func, input_dataset)
    self._make_window_size_func(window_size_func)

    # NOTE(review): assigned ahead of the `_common_args` lookup below, which
    # presumably reads `_name` -- confirm against the base class.
    self._name = name

    key_fn = self._key_func.function
    reduce_fn = self._reduce_func.function
    window_size_fn = self._window_size_func.function

    variant_tensor = ged_ops.group_by_window_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        key_fn.captured_inputs,
        reduce_fn.captured_inputs,
        window_size_fn.captured_inputs,
        key_func=key_fn,
        reduce_func=reduce_fn,
        window_size_func=window_size_fn,
        **self._common_args)
    super().__init__(input_dataset, variant_tensor)

  def _make_window_size_func(self, window_size_func):
    """Wraps `window_size_func` and stores it as `self._window_size_func`.

    The wrapper is built for a scalar `tf.int64` input and converts whatever
    `window_size_func` returns to a `tf.int64` tensor.

    Args:
      window_size_func: Callable taking the key and returning the window size.

    Raises:
      ValueError: If the wrapped function's output structure is not compatible
        with a scalar `tf.int64` `TensorSpec`.
    """
    scalar_int64 = tensor_spec.TensorSpec([], dtypes.int64)

    def window_size_func_wrapper(key):
      return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)

    wrapped = structured_function.StructuredFunctionWrapper(
        window_size_func_wrapper,
        self._transformation_name(),
        input_structure=scalar_int64)
    self._window_size_func = wrapped

    if wrapped.output_structure.is_compatible_with(scalar_int64):
      return
    raise ValueError(f"Invalid `window_size_func`. `window_size_func` must "
                     f"return a single `tf.int64` scalar tensor but its "
                     f"return type is "
                     f"{wrapped.output_structure}.")

  def _make_key_func(self, key_func, input_dataset):
    """Wraps `key_func` and stores it as `self._key_func`.

    The wrapper is built against `input_dataset` and converts whatever
    `key_func` returns to a `tf.int64` tensor.

    Args:
      key_func: Callable applied to the elements of `input_dataset`.
      input_dataset: Dataset passed to the wrapper as its `dataset` argument.

    Raises:
      ValueError: If the wrapped function's output structure is not compatible
        with a scalar `tf.int64` `TensorSpec`.
    """

    def key_func_wrapper(*args):
      return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)

    wrapped = structured_function.StructuredFunctionWrapper(
        key_func_wrapper, self._transformation_name(), dataset=input_dataset)
    self._key_func = wrapped

    expected_spec = tensor_spec.TensorSpec([], dtypes.int64)
    if wrapped.output_structure.is_compatible_with(expected_spec):
      return
    raise ValueError(f"Invalid `key_func`. `key_func` must return a single "
                     f"`tf.int64` scalar tensor but its return type is "
                     f"{wrapped.output_structure}.")

  def _make_reduce_func(self, reduce_func, input_dataset):
    """Wraps `reduce_func` and stores it as `self._reduce_func`.

    The wrapper is built for a `(key, window)` input pair: a scalar `tf.int64`
    key and a nested dataset with the element spec of `input_dataset`. On
    success the element spec of the returned dataset becomes this dataset's
    element spec.

    Args:
      reduce_func: Callable taking the key and the nested dataset.
      input_dataset: Dataset whose `element_spec` describes the nested dataset.

    Raises:
      TypeError: If the wrapped function's output structure is not a
        `DatasetSpec`.
    """
    window_spec = dataset_ops.DatasetSpec(input_dataset.element_spec)
    key_spec = tensor_spec.TensorSpec([], dtypes.int64)

    wrapped = structured_function.StructuredFunctionWrapper(
        reduce_func,
        self._transformation_name(),
        input_structure=(key_spec, window_spec))
    self._reduce_func = wrapped

    output_structure = wrapped.output_structure
    if not isinstance(output_structure, dataset_ops.DatasetSpec):
      raise TypeError(f"Invalid `reduce_func`. `reduce_func` must return a "
                      f"single `tf.data.Dataset` object but its return type "
                      f"is {output_structure}.")

    self._element_spec = output_structure._element_spec  # pylint: disable=protected-access

  @property
  def element_spec(self):
    """The element spec taken from the output of `reduce_func`."""
    return self._element_spec

  def _functions(self):
    """Returns the wrapped key, reduce and window-size functions, in order."""
    return [self._key_func, self._reduce_func, self._window_size_func]

  def _transformation_name(self):
    """Returns the transformation name passed to the function wrappers."""
    return "Dataset.group_by_window()"
|