1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Utilities related to loss functions."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.distribute import distribution_strategy_context
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.keras import backend as K
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import confusion_matrix
27from tensorflow.python.ops import control_flow_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import weights_broadcast_ops
30from tensorflow.python.util.tf_export import keras_export
31
32
@keras_export('keras.losses.Reduction', v1=[])
class ReductionV2(object):
  """Enumerates the ways a collection of per-element losses can be reduced.

  Supported values:

  * `NONE`: no reduction; the weighted losses keep the shape of the input.
  * `SUM`: the scalar sum of the weighted losses.
  * `SUM_OVER_BATCH_SIZE`: the scalar `SUM` divided by the number of elements
     in the losses. When running under `tf.distribute.Strategy`, this divisor
     is the global batch size, i.e. it counts elements across every replica
     contributing to a single step.
  """

  NONE = 'none'
  SUM = 'sum'
  SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'

  @classmethod
  def all(cls):
    """Returns a tuple containing every supported reduction key."""
    supported_keys = (cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE)
    return supported_keys

  @classmethod
  def validate(cls, key):
    """Checks that `key` is one of the values returned by `all()`.

    Args:
      key: The reduction key to check.

    Raises:
      ValueError: If `key` is not a supported reduction key.
    """
    if key in cls.all():
      return
    raise ValueError('Invalid Reduction Key %s.' % key)
58
59
def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
  """Squeeze or expand last dimension if needed.

  1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1
  (using `confusion_matrix.remove_squeezable_dimensions`).
  2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
  from the new rank of `y_pred` (i.e. the rank after step 1).
  If `sample_weight` is scalar, it is kept scalar.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
    y_true: Optional label `Tensor` whose dimensions match `y_pred`.
    sample_weight: Optional weight scalar or `Tensor` whose dimensions match
      `y_pred`. When not `None` it is converted with `ops.convert_to_tensor`.

  Returns:
    Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has
    the last dimension squeezed, and `sample_weight` could be extended by one
    dimension. When `sample_weight` is `None`, the third element is `None`.
  """
  y_pred_shape = y_pred.get_shape()
  y_pred_rank = y_pred_shape.ndims
  if y_true is not None:

    # If sparse matrix is provided as `y_true`, the last dimension in `y_pred`
    # may be > 1. Eg: y_true = [0, 1, 2] (shape=(3,)),
    # y_pred = [[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]] (shape=(3, 3))
    # In this case, we should not try to remove squeezable dimension.
    y_true_shape = y_true.get_shape()
    y_true_rank = y_true_shape.ndims
    if (y_true_rank is not None) and (y_pred_rank is not None):
      # Use static rank for `y_true` and `y_pred`.
      if (y_pred_rank - y_true_rank != 1) or y_pred_shape[-1] == 1:
        y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(
            y_true, y_pred)
    else:
      # Use dynamic rank.
      rank_diff = array_ops.rank(y_pred) - array_ops.rank(y_true)
      squeeze_dims = lambda: confusion_matrix.remove_squeezable_dimensions(  # pylint: disable=g-long-lambda
          y_true, y_pred)
      is_last_dim_1 = math_ops.equal(1, array_ops.shape(y_pred)[-1])
      maybe_squeeze_dims = lambda: control_flow_ops.cond(  # pylint: disable=g-long-lambda
          is_last_dim_1, squeeze_dims, lambda: (y_true, y_pred))
      y_true, y_pred = control_flow_ops.cond(
          math_ops.equal(1, rank_diff), maybe_squeeze_dims, squeeze_dims)

    # `y_pred` may have just been squeezed. Refresh its static rank so the
    # static `sample_weight` adjustment below compares against the *new* rank
    # (as documented, and as the dynamic path already does through
    # `array_ops.rank(y_pred)`). If static shape information was lost this is
    # `None`, and the dynamic path below is used instead.
    y_pred_rank = y_pred.get_shape().ndims

  if sample_weight is None:
    return y_pred, y_true, None

  sample_weight = ops.convert_to_tensor(sample_weight)
  weights_shape = sample_weight.get_shape()
  weights_rank = weights_shape.ndims
  if weights_rank == 0:  # If weights is scalar, do nothing.
    return y_pred, y_true, sample_weight

  if (y_pred_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - y_pred_rank == 1:
      sample_weight = array_ops.squeeze(sample_weight, [-1])
    elif y_pred_rank - weights_rank == 1:
      sample_weight = array_ops.expand_dims(sample_weight, [-1])
    return y_pred, y_true, sample_weight

  # Use dynamic rank.
  weights_rank_tensor = array_ops.rank(sample_weight)
  rank_diff = weights_rank_tensor - array_ops.rank(y_pred)
  maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1])

  def _maybe_expand_weights():
    # Expand the last dim only when `sample_weight` is one rank short.
    return control_flow_ops.cond(
        math_ops.equal(rank_diff,
                       -1), lambda: array_ops.expand_dims(sample_weight, [-1]),
        lambda: sample_weight)

  def _maybe_adjust_weights():
    # Squeeze when one rank too many, otherwise consider expanding.
    return control_flow_ops.cond(
        math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
        _maybe_expand_weights)

  # squeeze or expand last dim of `sample_weight` if its rank differs by 1
  # from the new rank of `y_pred`; a rank-0 (scalar) weight is left as is.
  sample_weight = control_flow_ops.cond(
      math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight,
      _maybe_adjust_weights)
  return y_pred, y_true, sample_weight
148
149
def _safe_mean(losses, num_present):
  """Divides the total of `losses` by `num_present`, guarding against zero.

  Args:
    losses: `Tensor` whose elements are individual loss measurements.
    num_present: The number of measurable elements in `losses`.

  Returns:
    A scalar equal to the sum of `losses` divided by `num_present`, or zero
    when `num_present` is zero (handled by `math_ops.div_no_nan`).
  """
  return math_ops.div_no_nan(
      math_ops.reduce_sum(losses), num_present, name='value')
163
164
def _num_elements(losses):
  """Returns the element count of `losses`, cast to `losses.dtype`."""
  with ops.name_scope(None, 'num_elements', values=[losses]) as scope:
    element_count = array_ops.size(losses, name=scope)
    return math_ops.cast(element_count, dtype=losses.dtype)
169
170
def reduce_weighted_loss(weighted_losses,
                         reduction=ReductionV2.SUM_OVER_BATCH_SIZE):
  """Reduces the individual weighted loss measurements.

  Args:
    weighted_losses: `Tensor` of already-weighted per-element losses.
    reduction: A `ReductionV2` key. `NONE` returns `weighted_losses`
      unchanged; `SUM_OVER_BATCH_SIZE` divides the sum by the element count
      multiplied by the number of in-sync replicas; any other value yields
      the plain sum.

  Returns:
    `weighted_losses` itself for `NONE`, otherwise a scalar `Tensor`.
  """
  if reduction == ReductionV2.NONE:
    return weighted_losses

  total = math_ops.reduce_sum(weighted_losses)
  if reduction != ReductionV2.SUM_OVER_BATCH_SIZE:
    return total

  # Multiplying by the replica count converts the local element count into
  # the global batch size.
  strategy = distribution_strategy_context.get_strategy()
  global_count = strategy.num_replicas_in_sync * _num_elements(weighted_losses)
  return _safe_mean(total, global_count)
183
184
def compute_weighted_loss(losses,
                          sample_weight=None,
                          reduction=ReductionV2.SUM_OVER_BATCH_SIZE,
                          name=None):
  """Computes the weighted loss.

  Args:
    losses: `Tensor` (or a value convertible to one, such as a Python list or
      NumPy array) of shape `[batch_size, d1, ... dN]`.
    sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as
      `losses`, or be broadcastable to `losses`. Defaults to `1.0`.
    reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to loss.
      Default value is `SUM_OVER_BATCH_SIZE`.
    name: Optional name for the op.

  Raises:
    ValueError: If `reduction` is not a valid `ReductionV2` key, or if the
      shape of `sample_weight` is not compatible with `losses`.

  Returns:
    Weighted loss `Tensor` of the same type as `losses`. If `reduction` is
    `NONE`, this has the same shape as `losses`; otherwise, it is scalar.
  """
  ReductionV2.validate(reduction)
  if sample_weight is None:
    sample_weight = 1.0
  with ops.name_scope(name, 'weighted_loss', (losses, sample_weight)):
    # Convert before anything else: `squeeze_or_expand_dimensions` calls
    # `get_shape()` on its first argument, which lists and NumPy arrays lack.
    losses = ops.convert_to_tensor(losses)
    input_dtype = losses.dtype
    # Update dimensions of `sample_weight` to match with `losses` if possible.
    losses, _, sample_weight = squeeze_or_expand_dimensions(
        losses, None, sample_weight)
    # NOTE: computation happens in float32; the result is cast back to
    # `input_dtype` at the end.
    losses = math_ops.cast(losses, dtypes.float32)
    sample_weight = math_ops.cast(sample_weight, dtypes.float32)

    try:
      # Broadcast weights if possible.
      sample_weight = weights_broadcast_ops.broadcast_weights(
          sample_weight, losses)
    except ValueError:
      # Reduce values to same ndim as weight array.
      ndim = K.ndim(losses)
      weight_ndim = K.ndim(sample_weight)
      losses = K.mean(losses, axis=list(range(weight_ndim, ndim)))

    sample_weight.get_shape().assert_is_compatible_with(losses.get_shape())
    weighted_losses = math_ops.multiply(losses, sample_weight)
    # Apply reduction function to the individual weighted losses.
    loss = reduce_weighted_loss(weighted_losses, reduction)
    # Convert the result back to the input type.
    loss = math_ops.cast(loss, input_dtype)
    return loss
235
236
def scale_loss_for_distribution(loss_value):
  """Divides `loss_value` by the current strategy's in-sync replica count.

  With one replica (or fewer) the value is returned untouched.
  """
  replica_count = (
      distribution_strategy_context.get_strategy().num_replicas_in_sync)
  if replica_count <= 1:
    return loss_value
  loss_value *= (1. / replica_count)
  return loss_value
244