1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Utility object to handle partial batches for TPUStrategy."""
16# pylint: disable=protected-access
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22import six
23
24from tensorflow.python.framework import tensor_util
25from tensorflow.python.keras import backend as K
26from tensorflow.python.ops import array_ops
27from tensorflow.python.util import nest
28
29
class PartialBatchPaddingHandler(object):
  """A container that holds info about partial batches for `predict()`.

  Attributes:
    padded_batch_size: Batch size that partial batches are padded up to.
      Initialised to 0; presumably assigned by the owner of this object
      before `update_mask` or `pad_batch` is used.
    padding_mask: Mask read by `apply_mask`, where non-zero entries mark rows
      to keep. Starts out empty. `update_mask` returns an extended mask but
      does not store it here, so the caller is presumably responsible for
      assigning the result back.
    output_shape: Value supplied at construction. Within this class only its
      length is used, to tell single-output from multi-output predictions.
  """

  def __init__(self, output_shape):
    """Creates a handler with an empty mask and a padded batch size of 0.

    Args:
      output_shape: Stored as-is; `apply_mask` uses `len(output_shape)` to
        choose between the single-output and the multi-output code path.
    """
    self.padded_batch_size = 0
    self.padding_mask = array_ops.zeros(0)
    self.output_shape = output_shape

  def get_real_batch_size(self, dataset_batch):
    """Returns the number of elements in a potentially partial batch.

    Args:
      dataset_batch: A tensor or a (possibly nested) structure of tensors.
        When a tuple or list is given only its first element is inspected
        (presumably the features part of the batch).

    Returns:
      A scalar int64 tensor holding the size of the first dimension of the
      first tensor found in the structure.

    Raises:
      ValueError: If the inspected structure does not contain any tensor.
    """
    if isinstance(dataset_batch, (tuple, list)):
      dataset_batch = dataset_batch[0]

    assert nest.flatten(dataset_batch), 'dataset_batch must not be empty.'

    def _find_any_tensor(batch_features):
      """Returns the first tensor of the flattened `batch_features`."""
      tensors = [
          x for x in nest.flatten(batch_features) if tensor_util.is_tensor(x)
      ]
      if not tensors:
        raise ValueError('Cannot find any Tensor in features dict.')
      return tensors[0]

    return K.cast(K.shape(_find_any_tensor(dataset_batch))[0],
                  dtype='int64')

  def update_mask(self, padding_mask, dataset_batch):
    """Extends `padding_mask` with the mask entries for one more batch.

    The entries appended for `dataset_batch` are ones for every real row,
    followed by zeros for every row that is missing up to
    `self.padded_batch_size`.

    Note that this method does not modify `self.padding_mask`; the extended
    mask is only returned.

    Args:
      padding_mask: The mask accumulated so far.
      dataset_batch: The batch whose real size determines the new entries.

    Returns:
      `padding_mask` concatenated along axis 0 with the entries for
      `dataset_batch`.
    """
    original_batch_size = self.get_real_batch_size(dataset_batch)
    missing_count = self.padded_batch_size - original_batch_size
    mask = K.concatenate([array_ops.ones(original_batch_size),
                          array_ops.zeros(missing_count)], axis=0)
    return K.concatenate([padding_mask, mask], axis=0)

  def pad_batch(self, *dataset_batch_elements):
    """Pads out the batch dimension of a tensor to the complete batch size.

    Args:
      *dataset_batch_elements: One or more batch elements. Each element is
        either a tensor of rank >= 1 or a dict (possibly nested) whose leaf
        values are such tensors.

    Returns:
      The padded element itself when exactly one element was passed,
      otherwise a tuple with one padded element per input. Dict inputs yield
      dicts with the same keys.
    """
    def _pad(batch):
      """Helper function to pad nested data within each batch elements."""
      if isinstance(batch, dict):
        return {key: _pad(value) for key, value in six.iteritems(batch)}

      rank = len(batch.shape)
      assert rank > 0, 'Cannot pad a scalar; a batch dimension is required.'
      missing_count = (self.padded_batch_size -
                       self.get_real_batch_size(batch))
      # Only the end of the first (batch) dimension is padded; every other
      # dimension gets a [0, 0] entry, i.e. is left untouched.
      padding = K.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))
      return array_ops.pad(batch, padding, 'constant')

    if len(dataset_batch_elements) == 1:
      return _pad(dataset_batch_elements[0])

    return tuple(_pad(element) for element in dataset_batch_elements)

  @staticmethod
  def _strip_padding(prediction, padding_mask):
    """Keeps only the rows of `prediction` whose mask entry is non-zero.

    Args:
      prediction: Array-like whose first dimension is the batch dimension.
      padding_mask: 1-D numpy array. It is truncated to `len(prediction)`
        before being applied.

    Returns:
      A numpy array holding the selected rows. The batch dimension and all
      remaining dimensions are preserved, including dimensions of size 1.
    """
    # `np.nonzero` returns a tuple with one index array per dimension. Taking
    # element 0 yields plain 1-D row indices, so `np.take` does not add an
    # extra leading axis that would have to be squeezed away afterwards.
    kept_rows = np.nonzero(padding_mask[:len(prediction)])[0]
    return np.take(prediction, kept_rows, axis=0)

  def apply_mask(self, prediction_result):
    """Removes prediction output that corresponds to padded input.

    Args:
      prediction_result: A single array when `self.output_shape` has length
        1, otherwise a sequence indexable by output position with one array
        per model output.

    Returns:
      The prediction with all padded rows removed: a single numpy array in
      the single-output case, otherwise a list with one numpy array per
      output. Only rows are dropped; every output keeps its batch dimension
      and its feature dimensions (size-1 dimensions are not squeezed away).
    """
    padding_mask = K.get_value(self.padding_mask)
    assert len(padding_mask.shape) == 1, 'padding_mask must be 1-D.'

    if len(self.output_shape) == 1:
      return self._strip_padding(prediction_result, padding_mask)

    # Each output is filtered exactly like the single-output case, so that
    # e.g. a (batch, 1) output or a single remaining row keeps its shape.
    return [
        self._strip_padding(prediction_result[i], padding_mask)
        for i in range(len(self.output_shape))
    ]
112