1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training-related utilities."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.python.framework import tensor_shape
23from tensorflow.python.framework import tensor_util
24from tensorflow.python.keras.utils import generic_utils
25from tensorflow.python.ops import array_ops
26from tensorflow.python.util import nest
27
28
def slice_arrays(arrays, indices, contiguous=True):
  """Extracts a batch from one array or a list of arrays.

  This is a workaround for eager tensors. They do not support NumPy-style
  fancy indexing (they slice like symbolic TF tensors), so
  `generic_utils.slice_arrays` cannot be used on them. For tensor inputs, the
  batch is instead assembled from single-row slices joined with `concat`,
  which costs some performance.

  Args:
    arrays: A single array, or a list of arrays.
    indices: The positions (along the first axis) to gather into the batch.
    contiguous: If True and the inputs include tensors, take a single slice
      from `indices[0]` through `indices[-1]` inclusive. The indices are
      therefore expected to be ascending and gap-free. If False, gather each
      index separately and concatenate the pieces.

  Returns:
    The sliced data: a single array if a single array was given, otherwise a
    list of arrays.
  """
  was_single = not isinstance(arrays, list)
  if was_single:
    arrays = [arrays]

  has_tensor = any(tensor_util.is_tf_type(item) for item in arrays)
  if not has_tensor:
    sliced = generic_utils.slice_arrays(arrays, indices)
  elif contiguous:
    sliced = [item[indices[0]:indices[-1] + 1] for item in arrays]
  else:
    # Build every one-row piece first, then join the pieces per array.
    pieces = [[item[idx:idx + 1] for idx in indices] for item in arrays]
    sliced = [array_ops.concat(group, axis=0) for group in pieces]

  return sliced[0] if was_single else sliced
63
64
def handle_partial_sample_weights(outputs, sample_weights, sample_weight_modes,
                                  check_all_flat=False):
  """Fills in a weight of 1.0 for every output that lacks a sample weight.

  Args:
    outputs: List of model outputs.
    sample_weights: List of sample weight inputs (entries may be None), or
      None.
    sample_weight_modes: List of sample weight modes, or None. An entry equal
      to 'temporal' makes the generated weight 2-D (first two output dims).
    check_all_flat: If True, assert that `sample_weights`, `outputs` and
      `sample_weight_modes` are flat (non-nested) structures. The check has a
      cost, so it may be skipped when running eagerly every iteration.

  Returns:
    A 3-tuple `(weights, any_sample_weight, partial_sample_weight)`:
      - `weights` is None when no weight was supplied. It is the input
        `sample_weights` object unchanged when every output already has a
        weight. Otherwise it is a tuple with one weight per output, where
        missing entries are replaced by arrays of ones (NumPy for NumPy
        outputs, TF otherwise).
      - `any_sample_weight` tells whether at least one weight was supplied.
      - `partial_sample_weight` tells whether some but not all weights were
        supplied.
  """
  has_weights = (sample_weights is not None and
                 any(weight is not None for weight in sample_weights))
  has_missing = has_weights and any(
      weight is None for weight in sample_weights)

  if not has_weights:
    # No output carries a weight: nothing to fill in.
    return None, has_weights, has_missing
  if not has_missing:
    # Every output already has a weight: pass them through untouched.
    return sample_weights, has_weights, has_missing

  if check_all_flat:
    for structure in (sample_weights, outputs):
      nest.assert_same_structure(
          list_to_tuple(structure), list_to_tuple(nest.flatten(structure)))
    if sample_weight_modes is not None:
      nest.assert_same_structure(
          sample_weight_modes, nest.flatten(sample_weight_modes))

  def _default_weight(index):
    """Builds an all-ones weight matching the batch dims of output `index`."""
    output = outputs[index]
    is_numpy = isinstance(output, np.ndarray)
    shape = output.shape if is_numpy else array_ops.shape(output)
    temporal = (sample_weight_modes is not None and
                sample_weight_modes[index] == 'temporal')
    if temporal:
      weight_shape = (shape[0], shape[1])
    else:
      weight_shape = (shape[0],)
    if is_numpy:
      return np.ones(weight_shape)
    return array_ops.ones(weight_shape)

  filled = [
      _default_weight(index) if weight is None else weight
      for index, weight in enumerate(sample_weights)
  ]
  return list_to_tuple(filled), has_weights, has_missing
122
123
class RespectCompiledTrainableState(object):
  """Set and restore trainable state if it has changed since compile.

  The keras API guarantees that the value of each Layer's `trainable` property
  at `Model.compile` time will be used when training that model. In order to
  respect this requirement, it may be necessary to set the trainable value of
  layers to their compile time values before beginning a training endpoint and
  restore the values before returning from said endpoint. This scope checks if
  any layer's trainable state has changed since Model compile, and performs this
  set and un-set bookkeeping.

  However, the trainable state of a layer changes quite infrequently, if ever,
  for many kinds of workflows. Moreover, updating every layer in a model is an
  expensive operation. As a result, we will only explicitly set and unset the
  trainable state of a model if a trainable value has changed since compile.

  The decision is recomputed on every `__enter__`, so a single instance may be
  used as a scope more than once.
  """

  def __init__(self, model):
    # Model whose trainable state is managed by this scope.
    self._model = model
    # Trainable state captured on __enter__; restored on __exit__ if needed.
    self._current_trainable_state = None
    # Trainable state the model recorded at compile time.
    self._compiled_trainable_state = None
    # True when __enter__ overrode the state and __exit__ must undo it.
    self._should_set_trainable = False

  def __enter__(self):
    self._current_trainable_state = self._model._get_trainable_state()  # pylint: disable=protected-access
    self._compiled_trainable_state = self._model._compiled_trainable_state  # pylint: disable=protected-access

    # Check whether any layer's trainable state has changed since `compile`.
    # The flag is recomputed here, not only initialized in __init__, so that a
    # reused scope does not inherit a stale decision from an earlier entry.
    current = self._current_trainable_state
    self._should_set_trainable = any(
        layer in current and trainable != current[layer]
        for layer, trainable in self._compiled_trainable_state.items())

    # If so, restore the model to its compiled state.
    if self._should_set_trainable:
      self._model._set_trainable_state(self._compiled_trainable_state)  # pylint: disable=protected-access

  def __exit__(self, type_arg, value_arg, traceback_arg):
    # If we set the values to their compiled state in __enter__, we need to
    # restore the original values before leaving the scope.
    if self._should_set_trainable:
      self._model._set_trainable_state(self._current_trainable_state)  # pylint: disable=protected-access
    return False  # False values do not suppress exceptions
168
169
170# Allow use of methods not exposed to the user.
171# pylint: disable=protected-access
def get_input_shape_and_dtype(layer):
  """Returns the batch input shape and dtype of a layer, if it has them.

  Graph networks (objects with a truthy `_is_graph_network`) and objects whose
  class is named 'Sequential' are unwrapped repeatedly. Each step replaces the
  current object with its first entry in `layers`, until a non-graph layer is
  reached. Subclassed models may not be built yet, so they are not unwrapped.

  Args:
    layer: Layer (or model) instance.

  Returns:
    Tuple `(input_shape, input_dtype)`: the innermost layer's
    `_batch_input_shape` and `dtype`, or `(None, None)` if that shape is
    missing or falsy.

  Raises:
    ValueError: If a graph network or Sequential model with no layers is
      encountered while unwrapping.
  """

  def _is_graph_model(candidate):
    if getattr(candidate, '_is_graph_network', False):
      return True
    return candidate.__class__.__name__ == 'Sequential'

  current = layer
  while _is_graph_model(current):
    sublayers = current.layers
    if not sublayers:
      raise ValueError('An empty Model cannot be used as a Layer.')
    current = sublayers[0]

  batch_input_shape = getattr(current, '_batch_input_shape', None)
  if not batch_input_shape:
    return None, None
  return batch_input_shape, current.dtype
201
202
203# pylint: enable=protected-access
204
205
def get_static_batch_size(layer):
  """Returns the statically known batch size of a Layer.

  Args:
    layer: a `Layer` instance.

  Returns:
    The `.value` of the first (batch) dimension of the layer's batch input
    shape. Returns None if the layer has no defined input shape. The value
    may also be None if the batch dimension itself is unknown.
  """
  shape, _unused_dtype = get_input_shape_and_dtype(layer)
  if shape is None:
    return None
  return tensor_shape.Dimension(shape[0]).value
219
220
def list_to_tuple(maybe_list):
  """Converts a list (or list subclass) to a tuple; other values pass through.

  Datasets stack a list of tensors, so lists are switched to tuples.
  """
  return tuple(maybe_list) if isinstance(maybe_list, list) else maybe_list
226