# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Utility object to handler partial batches for TPUStrategy.""" import numpy as np import tensorflow.compat.v2 as tf from keras import backend class PartialBatchPaddingHandler: """A container that holds info about partial batches for `predict()`.""" def __init__(self, output_shape): self.padded_batch_size = 0 self.padding_mask = tf.zeros(0) self.output_shape = output_shape def get_real_batch_size(self, dataset_batch): """Returns the number of elements in a potentially partial batch.""" if isinstance(dataset_batch, (tuple, list)): dataset_batch = dataset_batch[0] assert tf.nest.flatten(dataset_batch) def _find_any_tensor(batch_features): tensors = [ x for x in tf.nest.flatten(batch_features) if tf.is_tensor(x) ] if not tensors: raise ValueError("Cannot find any Tensor in features dict.") return tensors[0] return backend.cast( backend.shape(_find_any_tensor(dataset_batch))[0], dtype="int64" ) def update_mask(self, padding_mask, dataset_batch): """Calculate and cache the amount of padding required for a batch.""" original_batch_size = self.get_real_batch_size(dataset_batch) missing_count = self.padded_batch_size - original_batch_size mask = backend.concatenate( [tf.ones(original_batch_size), tf.zeros(missing_count)], axis=0 ) return backend.concatenate([padding_mask, mask], axis=0) def pad_batch(self, *dataset_batch_elements): """Pads the batch dimension of a tensor to the complete batch size.""" def _pad(batch): """Helper function to pad nested data within each batch elements.""" padded_dict_batch = {} if isinstance(batch, dict): for key, value in batch.items(): padded_dict_batch[key] = _pad(value) return padded_dict_batch rank = len(batch.shape) assert rank > 0 missing_count = self.padded_batch_size - self.get_real_batch_size( batch ) padding = backend.stack( [[0, missing_count]] + [[0, 0]] * (rank - 1) ) return tf.pad(batch, padding, "constant") if len(dataset_batch_elements) == 1: return _pad(dataset_batch_elements[0]) batch_elements = [] for batch_element in dataset_batch_elements: batch_elements.append(_pad(batch_element)) return tuple(batch_elements) def apply_mask(self, prediction_result): """Removes prediction output that corresponds to padded input.""" padding_mask = backend.get_value(self.padding_mask) assert len(padding_mask.shape) == 1 if len(self.output_shape) == 1: prediction = np.take( prediction_result, np.nonzero(padding_mask[: len(prediction_result)]), axis=0, ) if prediction.shape[0] == 1: prediction = np.squeeze(prediction, axis=0) return prediction else: predictions = [] for i in range(len(self.output_shape)): prediction = prediction_result[i] prediction = np.take( prediction, np.nonzero(padding_mask[: len(prediction)]), axis=0, ) predictions.append(np.squeeze(prediction)) return predictions