1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utility object to handler partial batches for TPUStrategy.""" 16# pylint: disable=protected-access 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22import six 23 24from tensorflow.python.framework import tensor_util 25from tensorflow.python.keras import backend as K 26from tensorflow.python.ops import array_ops 27from tensorflow.python.util import nest 28 29 30class PartialBatchPaddingHandler(object): 31 """A container that holds info about partial batches for `predict()`.""" 32 33 def __init__(self, output_shape): 34 self.padded_batch_size = 0 35 self.padding_mask = array_ops.zeros(0) 36 self.output_shape = output_shape 37 38 def get_real_batch_size(self, dataset_batch): 39 """Returns the number of elements in a potentially partial batch.""" 40 if isinstance(dataset_batch, (tuple, list)): 41 dataset_batch = dataset_batch[0] 42 43 assert nest.flatten(dataset_batch) 44 45 def _find_any_tensor(batch_features): 46 tensors = [ 47 x for x in nest.flatten(batch_features) if tensor_util.is_tf_type(x) 48 ] 49 if not tensors: 50 raise ValueError('Cannot find any Tensor in features dict.') 51 return tensors[0] 52 53 return K.cast(K.shape(_find_any_tensor(dataset_batch))[0], 54 dtype='int64') 55 56 def update_mask(self, padding_mask, dataset_batch): 57 """Calculate and cache the amount of padding required for a batch.""" 58 original_batch_size = self.get_real_batch_size(dataset_batch) 59 missing_count = self.padded_batch_size - original_batch_size 60 mask = K.concatenate([array_ops.ones(original_batch_size), 61 array_ops.zeros(missing_count)], axis=0) 62 return K.concatenate([padding_mask, mask], axis=0) 63 64 def pad_batch(self, *dataset_batch_elements): 65 """Pads out the batch dimension of a tensor to the complete batch size.""" 66 def _pad(batch): 67 """Helper function to pad nested data within each batch elements.""" 68 padded_dict_batch = {} 69 if isinstance(batch, dict): 70 for key, value in six.iteritems(batch): 71 padded_dict_batch[key] = _pad(value) 72 return padded_dict_batch 73 74 rank = len(batch.shape) 75 assert rank > 0 76 missing_count = (self.padded_batch_size - 77 self.get_real_batch_size(batch)) 78 padding = K.stack([[0, missing_count]] + [[0, 0]] * (rank - 1)) 79 return array_ops.pad(batch, padding, 'constant') 80 81 if len(dataset_batch_elements) == 1: 82 return _pad(dataset_batch_elements[0]) 83 84 batch_elements = [] 85 for batch_element in dataset_batch_elements: 86 batch_elements.append(_pad(batch_element)) 87 return tuple(batch_elements) 88 89 def apply_mask(self, prediction_result): 90 """Removes prediction output that corresponds to padded input.""" 91 padding_mask = K.get_value(self.padding_mask) 92 assert len(padding_mask.shape) == 1 93 94 if len(self.output_shape) == 1: 95 prediction = np.take(prediction_result, 96 np.nonzero( 97 padding_mask[:len(prediction_result)]), 98 axis=0) 99 if prediction.shape[0] == 1: 100 prediction = np.squeeze(prediction, axis=0) 101 return prediction 102 103 else: 104 predictions = [] 105 for i in range(len(self.output_shape)): 106 prediction = prediction_result[i] 107 prediction = np.take(prediction, np.nonzero( 108 padding_mask[:len(prediction)]), axis=0) 109 predictions.append(np.squeeze(prediction)) 110 111 return predictions 112