1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Utilities related to loss functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util.tf_export import keras_export
30
31
@keras_export('keras.losses.Reduction', v1=[])
class ReductionV2(object):
  """Types of loss reduction.

  Contains the following values:

  * `AUTO`: Indicates that the reduction option will be determined by the usage
     context. For almost all cases this defaults to `SUM_OVER_BATCH_SIZE`. When
     used with `tf.distribute.Strategy`, outside of built-in training loops such
     as `tf.keras` `compile` and `fit`, we expect reduction value to be
     `SUM` or `NONE`. Using `AUTO` in that case will raise an error.
  * `NONE`: Weighted losses with one dimension reduced (axis=-1, or axis
     specified by loss function). When this reduction type is used with
     built-in Keras training loops like `fit`/`evaluate`, the unreduced vector
     loss is passed to the optimizer but the reported loss will be a scalar
     value.
  * `SUM`: Scalar sum of weighted losses.
  * `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses.
     This reduction type is not supported when used with
     `tf.distribute.Strategy` outside of built-in training loops like `tf.keras`
     `compile`/`fit`.

     You can implement 'SUM_OVER_BATCH_SIZE' using global batch size like:
     ```
     with strategy.scope():
       loss_obj = tf.keras.losses.CategoricalCrossentropy(
           reduction=tf.keras.losses.Reduction.NONE)
       ...
       loss = (tf.reduce_sum(loss_obj(labels, predictions)) *
               (1. / global_batch_size))
     ```

  Please see the [custom training guide](
  https://www.tensorflow.org/tutorials/distribute/custom_training) for more
  details on this.
  """

  AUTO = 'auto'
  NONE = 'none'
  SUM = 'sum'
  SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'

  @classmethod
  def all(cls):
    """Returns a tuple containing every valid reduction key."""
    return (cls.AUTO, cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE)

  @classmethod
  def validate(cls, key):
    """Checks that `key` is one of the values returned by `all()`.

    Args:
      key: The reduction key to check.

    Raises:
      ValueError: If `key` is not a valid reduction key.
    """
    if key not in cls.all():
      # Wrap `key` in a 1-tuple so that a tuple-valued key is formatted as a
      # single value instead of being spread over the format string (which
      # would raise `TypeError` rather than the intended `ValueError`).
      raise ValueError('Invalid Reduction Key %s.' % (key,))
81
82
def remove_squeezable_dimensions(
    labels, predictions, expected_rank_diff=0, name=None):
  """Squeeze last dim if ranks differ from expected by exactly 1.

  In the common case where we expect shapes to match, `expected_rank_diff`
  defaults to 0, and we squeeze the last dimension of the larger rank if they
  differ by 1.

  But, for example, if `labels` contains class IDs and `predictions` contains 1
  probability per class, we expect `predictions` to have 1 more dimension than
  `labels`, so `expected_rank_diff` would be 1. In this case, we'd squeeze
  `labels` if `rank(predictions) - rank(labels) == 0`, and
  `predictions` if `rank(predictions) - rank(labels) == 2`.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    labels: Label values, a `Tensor` (or `RaggedTensor`) whose dimensions match
      `predictions`.
    predictions: Predicted values, a `Tensor` (or `RaggedTensor`) of arbitrary
      dimensions.
    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
    name: Name of the op.

  Returns:
    Tuple of `labels` and `predictions`, possibly with last dim squeezed. A
    tensor whose static rank is known to be 0 is never squeezed.
  """
  with K.name_scope(name or 'remove_squeezable_dimensions'):
    # Ragged inputs are kept as-is; anything else is converted to a `Tensor`.
    if not isinstance(predictions, ragged_tensor.RaggedTensor):
      predictions = ops.convert_to_tensor_v2_with_dispatch(predictions)
    if not isinstance(labels, ragged_tensor.RaggedTensor):
      labels = ops.convert_to_tensor_v2_with_dispatch(labels)
    predictions_shape = predictions.shape
    predictions_rank = predictions_shape.ndims
    labels_shape = labels.shape
    labels_rank = labels_shape.ndims
    if (labels_rank is not None) and (predictions_rank is not None):
      # Use static rank. The `rank > 0` checks guard `dims[-1]`, which raises
      # `IndexError` on a scalar (its `dims` list is empty).
      rank_diff = predictions_rank - labels_rank
      if (rank_diff == expected_rank_diff + 1 and predictions_rank > 0 and
          predictions_shape.dims[-1].is_compatible_with(1)):
        predictions = array_ops.squeeze(predictions, [-1])
      elif (rank_diff == expected_rank_diff - 1 and labels_rank > 0 and
            labels_shape.dims[-1].is_compatible_with(1)):
        labels = array_ops.squeeze(labels, [-1])
      return labels, predictions

    # Use dynamic rank. At least one static rank is unknown here; a tensor whose
    # rank *is* known to be 0 has no last dimension and is left untouched.
    rank_diff = array_ops.rank(predictions) - array_ops.rank(labels)
    if (predictions_rank is None) or (
        predictions_rank > 0 and
        predictions_shape.dims[-1].is_compatible_with(1)):
      predictions = control_flow_ops.cond(
          math_ops.equal(expected_rank_diff + 1, rank_diff),
          lambda: array_ops.squeeze(predictions, [-1]),
          lambda: predictions)
    if (labels_rank is None) or (
        labels_rank > 0 and labels_shape.dims[-1].is_compatible_with(1)):
      labels = control_flow_ops.cond(
          math_ops.equal(expected_rank_diff - 1, rank_diff),
          lambda: array_ops.squeeze(labels, [-1]),
          lambda: labels)
    return labels, predictions
144
145
def squeeze_or_expand_dimensions(y_pred, y_true=None, sample_weight=None):
  """Squeeze or expand last dimension if needed.

  1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1
  (using `remove_squeezable_dimensions`).
  2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
  from the new rank of `y_pred` (i.e. the rank after step 1).
  If `sample_weight` is scalar, it is kept scalar.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
    y_true: Optional label `Tensor` whose dimensions match `y_pred`.
    sample_weight: Optional weight scalar or `Tensor` whose dimensions match
      `y_pred`.

  Returns:
    Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has
    the last dimension squeezed; `sample_weight` could instead be expanded by
    one dimension.
    If `sample_weight` is None, (y_pred, y_true) is returned.
  """
  y_pred_shape = y_pred.shape
  y_pred_rank = y_pred_shape.ndims
  if y_true is not None:

    # If sparse labels (class indices) are provided as `y_true`, the last
    # dimension in `y_pred` may be > 1. Eg: y_true = [0, 1, 2] (shape=(3,)),
    # y_pred = [[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]] (shape=(3, 3))
    # In this case, we should not try to remove squeezable dimension.
    y_true_shape = y_true.shape
    y_true_rank = y_true_shape.ndims
    if (y_true_rank is not None) and (y_pred_rank is not None):
      # Use static rank for `y_true` and `y_pred`.
      if (y_pred_rank - y_true_rank != 1) or y_pred_shape[-1] == 1:
        y_true, y_pred = remove_squeezable_dimensions(
            y_true, y_pred)
    else:
      # Use dynamic rank.
      rank_diff = array_ops.rank(y_pred) - array_ops.rank(y_true)
      squeeze_dims = lambda: remove_squeezable_dimensions(  # pylint: disable=g-long-lambda
          y_true, y_pred)
      # NOTE(review): this is evaluated outside the `cond`, so a rank-0
      # `y_pred` whose static rank is unknown would presumably fail here at
      # runtime -- confirm.
      is_last_dim_1 = math_ops.equal(1, array_ops.shape(y_pred)[-1])
      maybe_squeeze_dims = lambda: control_flow_ops.cond(  # pylint: disable=g-long-lambda
          is_last_dim_1, squeeze_dims, lambda: (y_true, y_pred))
      y_true, y_pred = control_flow_ops.cond(
          math_ops.equal(1, rank_diff), maybe_squeeze_dims, squeeze_dims)
    # `y_pred` may have been squeezed above, so refresh its static rank: the
    # `sample_weight` adjustment below must compare against the new rank, as
    # the dynamic-rank path already does via `array_ops.rank(y_pred)`.
    y_pred_rank = y_pred.shape.ndims

  if sample_weight is None:
    return y_pred, y_true

  weights_shape = sample_weight.shape
  weights_rank = weights_shape.ndims
  if weights_rank == 0:  # If weights is scalar, do nothing.
    return y_pred, y_true, sample_weight

  if (y_pred_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - y_pred_rank == 1:
      sample_weight = array_ops.squeeze(sample_weight, [-1])
    elif y_pred_rank - weights_rank == 1:
      sample_weight = array_ops.expand_dims(sample_weight, [-1])
    return y_pred, y_true, sample_weight

  # Use dynamic rank.
  weights_rank_tensor = array_ops.rank(sample_weight)
  rank_diff = weights_rank_tensor - array_ops.rank(y_pred)
  maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1])

  def _maybe_expand_weights():
    # Expand only when `sample_weight` is exactly one rank below `y_pred`.
    expand_weights = lambda: array_ops.expand_dims(sample_weight, [-1])
    return control_flow_ops.cond(
        math_ops.equal(rank_diff, -1), expand_weights, lambda: sample_weight)

  def _maybe_adjust_weights():
    # Squeeze when one rank above `y_pred`, otherwise consider expanding.
    return control_flow_ops.cond(
        math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
        _maybe_expand_weights)

  # squeeze or expand last dim of `sample_weight` if its rank differs by 1
  # from the new rank of `y_pred`; a runtime-scalar weight is left unchanged.
  sample_weight = control_flow_ops.cond(
      math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight,
      _maybe_adjust_weights)
  return y_pred, y_true, sample_weight
233
234
def _safe_mean(losses, num_present):
  """Returns the mean of `losses`, or zero when nothing is present.

  Args:
    losses: `Tensor` whose elements are individual loss measurements.
    num_present: The number of measurable elements in `losses`.

  Returns:
    A scalar equal to `sum(losses) / num_present`. The division uses
      `div_no_nan`, so a zero `num_present` yields zero rather than NaN/Inf.
  """
  return math_ops.div_no_nan(
      math_ops.reduce_sum(losses), num_present, name='value')
248
249
def _num_elements(losses):
  """Returns the element count of `losses`, cast to `losses.dtype`."""
  with K.name_scope('num_elements') as scope:
    # The op is named after the enclosing scope.
    element_count = array_ops.size(losses, name=scope)
    return math_ops.cast(element_count, dtype=losses.dtype)
254
255
def reduce_weighted_loss(weighted_losses,
                         reduction=ReductionV2.SUM_OVER_BATCH_SIZE):
  """Reduces the individual weighted loss measurements.

  Args:
    weighted_losses: `Tensor` of individual (already weighted) losses.
    reduction: A `ReductionV2` key. `NONE` returns `weighted_losses` unchanged;
      `SUM_OVER_BATCH_SIZE` returns the sum divided by the number of elements
      (zero when there are no elements); any other value (`SUM`, and also
      `AUTO`, which is not resolved here) returns the plain sum.

  Returns:
    `weighted_losses` itself when `reduction` is `NONE`, otherwise a scalar.
  """
  if reduction == ReductionV2.NONE:
    return weighted_losses
  if reduction == ReductionV2.SUM_OVER_BATCH_SIZE:
    # `_safe_mean` sums the individual losses itself, so no separate
    # `reduce_sum` is needed on this path.
    return _safe_mean(weighted_losses, _num_elements(weighted_losses))
  return math_ops.reduce_sum(weighted_losses)
266
267
def compute_weighted_loss(losses,
                          sample_weight=None,
                          reduction=ReductionV2.SUM_OVER_BATCH_SIZE,
                          name=None):
  """Computes the weighted loss.

  `float64` losses are weighted and reduced in `float64`; every other dtype is
  computed in `float32`. The result is cast back to the dtype of `losses`.

  Args:
    losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
    sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as
      `losses`, or which is broadcastable to `losses`. Defaults to `1.0`.
    reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to loss.
      Default value is `SUM_OVER_BATCH_SIZE`; `AUTO` is treated as
      `SUM_OVER_BATCH_SIZE`.
    name: Optional name for the op.

  Raises:
    ValueError: If `reduction` is not a valid `ReductionV2` key, or if the
      shape of `sample_weight` is not compatible with `losses`.

  Returns:
    Weighted loss `Tensor` of the same type as `losses`. If `reduction` is
    `NONE`, this has the same shape as `losses`; otherwise, it is scalar.
  """
  ReductionV2.validate(reduction)

  # If this function is called directly, then we just default 'AUTO' to
  # 'SUM_OVER_BATCH_SIZE'. Eg. Canned estimator use cases.
  if reduction == ReductionV2.AUTO:
    reduction = ReductionV2.SUM_OVER_BATCH_SIZE
  if sample_weight is None:
    sample_weight = 1.0
  with K.name_scope(name or 'weighted_loss'):
    # Save the `reduction` argument for loss normalization when distributing
    # to multiple replicas. Used only for estimator + v1 optimizer flow.
    ops.get_default_graph()._last_loss_reduction = reduction  # pylint: disable=protected-access

    if not isinstance(losses,
                      (keras_tensor.KerasTensor, ragged_tensor.RaggedTensor)):
      losses = ops.convert_to_tensor_v2_with_dispatch(losses)
    input_dtype = losses.dtype

    if not isinstance(sample_weight, keras_tensor.KerasTensor):
      sample_weight = ops.convert_to_tensor_v2_with_dispatch(sample_weight)

    # Keep float64 losses in float64 so no precision is lost; all other dtypes
    # (integer, bool, half precision, float32) are computed in float32.
    compute_dtype = 'float64' if input_dtype == 'float64' else 'float32'
    losses = math_ops.cast(losses, compute_dtype)
    sample_weight = math_ops.cast(sample_weight, compute_dtype)
    # Update dimensions of `sample_weight` to match with `losses` if possible.
    losses, _, sample_weight = squeeze_or_expand_dimensions(  # pylint: disable=unbalanced-tuple-unpacking
        losses, None, sample_weight)
    weighted_losses = math_ops.multiply(losses, sample_weight)

    # Apply reduction function to the individual weighted losses.
    loss = reduce_weighted_loss(weighted_losses, reduction)
    # Convert the result back to the input type.
    loss = math_ops.cast(loss, input_dtype)
    return loss
324
325
def scale_loss_for_distribution(loss_value):
  """Scales `loss_value` by `1 / num_replicas_in_sync` of the current strategy.

  Args:
    loss_value: The loss value to scale.

  Returns:
    `loss_value` unchanged when at most one replica is in sync, otherwise
    `loss_value * (1. / num_replicas_in_sync)`.
  """
  strategy = distribution_strategy_context.get_strategy()
  replica_count = strategy.num_replicas_in_sync
  if replica_count <= 1:
    return loss_value
  loss_value *= (1. / replica_count)
  return loss_value
333
334
def cast_losses_to_common_dtype(losses):
  """Cast a list of losses to a common dtype.

  If any loss is floating-point, they will all be cast to the most-precise
  (largest) floating-point loss dtype. A mix of `bfloat16` and `float16`, which
  have the same size, is promoted to `float32`. Otherwise the losses are not
  cast. We also skip casting losses if there are any complex losses.

  Args:
    losses: A list of losses.

  Returns:
    `losses`, but they have been cast to a common dtype. The input list object
    itself is returned when no cast is performed.
  """
  highest_float = None
  for loss in losses:
    if loss.dtype.is_floating:
      if highest_float is None or loss.dtype.size > highest_float.size:
        highest_float = loss.dtype
      # Compare against `DType` objects, not strings: `DType` hashes by its
      # type enum, so a set of DTypes never equals a set of dtype names.
      elif {loss.dtype, highest_float} == {dtypes.bfloat16, dtypes.float16}:
        # Keep `highest_float` a `DType` so the `.size` comparison above still
        # works for any subsequent losses.
        highest_float = dtypes.float32
    if loss.dtype.is_complex:
      return losses  # If we find any complex losses, do not cast any losses
  if highest_float is not None:
    losses = [math_ops.cast(loss, highest_float) for loss in losses]
  return losses
360