# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras training and evaluation routines for eager execution."""
# pylint: disable=protected-access
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import numpy as np

from tensorflow.python.eager.backprop import GradientTape
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


def _eager_loss_fn(outputs, targets, loss_fn, output_name):
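  """Computes the loss for a single model output inside a named scope."""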
  with backend.name_scope(output_name + '_loss'):
    loss = loss_fn(targets, outputs)
  return loss


def _eager_metrics_fn(model,
                      outputs,
                      targets,
                      sample_weights=None,
                      masks=None):
  """Calculates the metrics for each output of the given model.

  Arguments:
      model: The model on which metrics are being calculated.
      outputs: The outputs of the given model.
      targets: The targets of the given model.
      sample_weights: Optional list of sample weights for each output.
      masks: Optional list of masks for each output.

  Returns:
      The metric results for each output of the model.
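
  Example:
      A minimal sketch, assuming a compiled two-output model `model` running
      eagerly; `x`, `y1` and `y2` are hypothetical eager tensors:

          outs = model(x)
          results = _eager_metrics_fn(model, outs, [y1, y2])
          # `results` holds one scalar mean per compiled metric.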
  """
  outputs = nest.flatten(outputs)
  targets = nest.flatten(targets)
  # TODO(psv): Consider supporting skip target indices in eager mode?
  metric_results = model._handle_metrics(
      outputs, targets=targets, sample_weights=sample_weights, masks=masks)
  return [backend.mean(t) for t in metric_results]


def _model_loss(model,
                inputs,
                targets,
                output_loss_metrics=None,
                sample_weights=None,
                training=False):
  """Calculates the loss for a given model.

  Arguments:
      model: The model on which the loss is being calculated.
      inputs: Either a dictionary of inputs to the model or a list of input
        arrays.
      targets: List of target arrays.
      output_loss_metrics: List of metrics that are used to aggregate output
        loss values.
      sample_weights: Optional list of sample weight arrays.
      training: Whether the model should be run in inference or training mode.

  Returns:
      The model outputs, the total loss, the per-output loss values calculated
      using the specified loss functions, and the masks for each output. The
      total loss includes regularization losses and applies masking and sample
      weighting to the loss values.
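
  Example:
      A minimal sketch, assuming a compiled single-output model `model`
      running eagerly; `x` and `y` are hypothetical eager tensors:

          outs, total_loss, output_losses, masks = _model_loss(
              model, [x], [y], training=False)
          # `total_loss` is a scalar tensor; `output_losses` is only
          # populated for multi-output models.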
  """
  # Used to keep track of the total loss value (stateless).
  # e.g., total_loss = loss_weight_1 * output_1_loss_fn(...) +
  #                    loss_weight_2 * output_2_loss_fn(...) +
  #                    layer losses.
  total_loss = 0
  kwargs = {}
  if model._expects_training_arg:
    kwargs['training'] = training
  if len(inputs) == 1 and not isinstance(inputs, dict):
    inputs = inputs[0]

  # Allow mixed `NumPy` and `EagerTensor` input here.
  if any(
      isinstance(input_t, (np.ndarray, float, int))
      for input_t in nest.flatten(inputs)):
    inputs = nest.map_structure(ops.convert_to_tensor, inputs)

  outs = model(inputs, **kwargs)

  outs = nest.flatten(outs)
  # `None` by default for `EagerTensors`.
  masks = [t._keras_mask for t in outs]
  targets = nest.flatten(targets)

  # Used to keep track of individual output losses.
  output_losses = []

  with backend.name_scope('loss'):
    for i, loss_fn in enumerate(model.loss_functions):
      weights = sample_weights[i] if sample_weights else None
      mask = masks[i]
      with backend.name_scope(model.output_names[i] + '_loss'):
        if mask is not None:
          mask = math_ops.cast(mask, outs[i].dtype)
          # Update weights with mask.
          if weights is None:
            weights = mask
          else:
            # Update dimensions of weights to match the mask if possible.
            mask, _, weights = (
                losses_utils.squeeze_or_expand_dimensions(mask, None, weights))
            weights *= mask

        # Reset reduction on the loss so that we can get the per-sample loss
        # value. We use this to get both the stateless and stateful loss
        # values without having to compute the underlying loss function
        # twice.
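        # Illustrative sketch (hypothetical, not executed here): with a
        # built-in loss such as
        #     mse = tf.keras.losses.MeanSquaredError()
        # setting `mse.reduction = losses_utils.ReductionV2.NONE` makes
        # `mse(y_true, y_pred)` return a per-sample loss vector, which
        # `losses_utils.reduce_weighted_loss` then averages to a scalar.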
        weighted_losses = None
        if hasattr(loss_fn, 'reduction'):
          current_loss_reduction = loss_fn.reduction
          loss_fn.reduction = losses_utils.ReductionV2.NONE
          weighted_losses = loss_fn(targets[i], outs[i], sample_weight=weights)
          loss_fn.reduction = current_loss_reduction

          # Compute the stateless loss value.
          output_loss = losses_utils.reduce_weighted_loss(weighted_losses)
        else:
          # Compute the stateless loss value for a custom loss class.
          # Here we assume that the class takes care of loss reduction,
          # because if the class returned a vector we could not distinguish
          # the case where a custom optimizer expects a vector loss value
          # from the case of an unreduced per-sample loss value.
          output_loss = loss_fn(targets[i], outs[i], sample_weight=weights)

      # If the number of outputs is 1 then we don't append the loss metric
      # associated with each model output. When there are multiple outputs
      # associated with a model, each output's loss is calculated and returned
      # as part of the loss_metrics.
      if len(model.outputs) > 1:
        # Compute the stateful loss value.
        if weighted_losses is not None:
          aggregated_output_loss = output_loss_metrics[i](weighted_losses)
        else:
          # Custom loss class.
          aggregated_output_loss = training_utils.call_metric_function(
              output_loss_metrics[i], targets[i], outs[i], weights=weights)
        # Keep track of the stateful output loss result.
        output_losses.append(aggregated_output_loss)

      total_loss += model.loss_weights_list[i] * output_loss

    total_loss = backend.mean(total_loss)
    # Add regularization losses
    custom_losses = model.losses
    if custom_losses:
      total_loss += losses_utils.scale_loss_for_distribution(
          math_ops.add_n(custom_losses))

  return outs, total_loss, output_losses, masks


def _process_single_batch(model,
                          inputs,
                          targets,
                          output_loss_metrics=None,
                          sample_weights=None,
                          training=False):
  """Calculates the loss and gradient for one input batch.

  The model weights are updated if training is set to True.

  Arguments:
      model: Model whose loss has to be calculated.
      inputs: List of input arrays.
      targets: List of target arrays.
      output_loss_metrics: List of metrics that are used to aggregate output
        loss values.
      sample_weights: Optional list of sample weight arrays.
      training: Whether the weights of the model should be updated. 'fit'
        methods set this to True while 'evaluate' methods set it to False.

  Returns:
      The output of the model, the total loss, and the loss and mask
      associated with each output.

  Raises:
      ValueError: If the model has no loss to optimize.
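
  Example:
      A minimal sketch, assuming a compiled model `model` with an optimizer;
      `x` and `y` are hypothetical eager tensors:

          outs, loss, output_losses, masks = _process_single_batch(
              model, [x], [y], training=True)
          # With `training=True`, gradients are computed and applied via
          # `model.optimizer` before the results are returned.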
  """
  with backend.eager_learning_phase_scope(1 if training else 0):
    with GradientTape() as tape:
      outs, total_loss, output_losses, masks = (
          _model_loss(
              model,
              inputs,
              targets,
              output_loss_metrics=output_loss_metrics,
              sample_weights=sample_weights,
              training=training))
      if total_loss is None:
        raise ValueError('The model cannot be run '
                         'because it has no loss to optimize.')
    if training:
      if not model.trainable_weights:
        logging.warning('The list of trainable weights is empty. Make sure that'
                        ' you are not setting model.trainable to False before '
                        'compiling the model.')
      else:
        grads = tape.gradient(total_loss, model.trainable_weights)
        model.optimizer.apply_gradients(zip(grads,
                                            model.trainable_weights))
    return outs, total_loss, output_losses, masks


def train_on_batch(model,
                   inputs,
                   targets,
                   sample_weights=None,
                   output_loss_metrics=None):
  """Calculates the loss and gradient updates for one input batch.

  Arguments:
      model: Model whose loss has to be calculated.
      inputs: Input batch data.
      targets: Target batch data.
      sample_weights: Sample weight batch data.
      output_loss_metrics: List of metrics that are used to aggregate output
        loss values.

  Returns:
      The total loss, and the loss and metrics associated with each output.
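
  Example:
      A minimal sketch, assuming a compiled model `model`; `x` and `y` are
      hypothetical NumPy batches:

          results = train_on_batch(model, [x], [y])
          # `results` is a flat list: the total loss, then per-output losses
          # (for multi-output models), then metric values.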
  """
  if isinstance(inputs, collections.Sequence):
    if len(inputs) and tensor_util.is_tensor(inputs[0]):
      inputs = training_utils.cast_if_floating_dtype(inputs)
      targets = training_utils.cast_if_floating_dtype(targets)
    else:
      inputs = training_utils.cast_if_floating_dtype(
          [ops.convert_to_tensor(val) for val in inputs])
      targets = training_utils.cast_if_floating_dtype(
          [ops.convert_to_tensor(val) for val in targets])
  if sample_weights:
    sample_weights = [
        training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))
        if val is not None else None for val in sample_weights
    ]

  outs, total_loss, output_losses, masks = (
      _process_single_batch(
          model,
          inputs,
          targets,
          sample_weights=sample_weights,
          training=True,
          output_loss_metrics=output_loss_metrics))
  if not isinstance(outs, list):
    outs = [outs]
  metrics_results = _eager_metrics_fn(
      model, outs, targets, sample_weights=sample_weights, masks=masks)
  total_loss = nest.flatten(total_loss)
  results = total_loss + output_losses + metrics_results

  return [tensor_util.constant_value(v) for v in results]


def test_on_batch(model,
                  inputs,
                  targets,
                  sample_weights=None,
                  output_loss_metrics=None):
  """Calculates the loss for one input batch.

  Arguments:
      model: Model whose loss has to be calculated.
      inputs: Input batch data.
      targets: Target batch data.
      sample_weights: Sample weight batch data.
      output_loss_metrics: List of metrics that are used to aggregate output
        loss values.

  Returns:
      The total loss, and the loss and metrics associated with each output.
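
  Example:
      A minimal sketch, assuming a compiled model `model`; `x` and `y` are
      hypothetical NumPy batches:

          results = test_on_batch(model, [x], [y])
          # Same result layout as `train_on_batch`, but no weights are
          # updated.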
  """
  if isinstance(inputs, collections.Sequence):
    if len(inputs) and tensor_util.is_tensor(inputs[0]):
      inputs = training_utils.cast_if_floating_dtype(inputs)
      targets = training_utils.cast_if_floating_dtype(targets)
    else:
      inputs = training_utils.cast_if_floating_dtype(
          [ops.convert_to_tensor(val) for val in inputs])
      targets = training_utils.cast_if_floating_dtype(
          [ops.convert_to_tensor(val) for val in targets])
  if sample_weights:
    sample_weights = [
        training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))
        if val is not None else None for val in sample_weights
    ]
  outs, total_loss, output_losses, masks = (
      _model_loss(
          model,
          inputs,
          targets,
          sample_weights=sample_weights,
          training=False,
          output_loss_metrics=output_loss_metrics))
  if not isinstance(outs, list):
    outs = [outs]
  metrics_results = _eager_metrics_fn(
      model, outs, targets, sample_weights=sample_weights, masks=masks)
  total_loss = nest.flatten(total_loss)
  results = total_loss + output_losses + metrics_results

  return [tensor_util.constant_value(v) for v in results]