1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Implementation of Neural Net (NN) functions."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import math
22
23from tensorflow.python.distribute import distribution_strategy_context as ds
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import candidate_sampling_ops
29from tensorflow.python.ops import control_flow_ops
30from tensorflow.python.ops import custom_gradient
31from tensorflow.python.ops import embedding_ops
32from tensorflow.python.ops import gen_array_ops  # pylint: disable=unused-import
33from tensorflow.python.ops import gen_nn_ops
34from tensorflow.python.ops import gen_sparse_ops
35from tensorflow.python.ops import linalg_ops
36from tensorflow.python.ops import math_ops
37from tensorflow.python.ops import nn_ops
38from tensorflow.python.ops import variables
39from tensorflow.python.ops.losses import util as losses_util
40from tensorflow.python.platform import device_context
41from tensorflow.python.util import dispatch
42from tensorflow.python.util.deprecation import deprecated_args
43from tensorflow.python.util.deprecation import deprecated_argument_lookup
44from tensorflow.python.util.tf_export import tf_export
45
46
47@tf_export("nn.log_poisson_loss")
48@dispatch.add_dispatch_support
49def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
50  """Computes log Poisson loss given `log_input`.
51
52  Gives the log-likelihood loss between the prediction and the target under the
53  assumption that the target has a Poisson distribution.
54  Caveat: By default, this is not the exact loss, but the loss minus a
55    constant term [log(z!)]. That has no effect for optimization, but
56    does not play well with relative loss comparisons. To compute an
57    approximation of the log factorial term, specify
58    compute_full_loss=True to enable Stirling's Approximation.
59
60  For brevity, let `c = log(x) = log_input`, `z = targets`.  The log Poisson
61  loss is
62
63        -log(exp(-x) * (x^z) / z!)
64      = -log(exp(-x) * (x^z)) + log(z!)
65      ~ -log(exp(-x)) - log(x^z) [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
66          [ Note the second term is the Stirling's Approximation for log(z!).
67            It is invariant to x and does not affect optimization, though
68            important for correct relative loss comparisons. It is only
69            computed when compute_full_loss == True. ]
70      = x - z * log(x) [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
71      = exp(c) - z * c [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
72
73  Args:
74    targets: A `Tensor` of the same type and shape as `log_input`.
75    log_input: A `Tensor` of type `float32` or `float64`.
76    compute_full_loss: whether to compute the full loss. If false, a constant
77      term is dropped in favor of more efficient optimization.
78    name: A name for the operation (optional).
79
80  Returns:
81    A `Tensor` of the same shape as `log_input` with the componentwise
82    logistic losses.
83
84  Raises:
85    ValueError: If `log_input` and `targets` do not have the same shape.
86  """
87  with ops.name_scope(name, "log_poisson_loss", [log_input, targets]) as name:
88    log_input = ops.convert_to_tensor(log_input, name="log_input")
89    targets = ops.convert_to_tensor(targets, name="targets")
90    try:
91      targets.get_shape().assert_is_compatible_with(log_input.get_shape())
92    except ValueError:
93      raise ValueError(
94          "log_input and targets must have the same shape (%s vs %s)" %
95          (log_input.get_shape(), targets.get_shape()))
96
97    result = math_ops.exp(log_input) - log_input * targets
98    if compute_full_loss:
99      # need to create constant tensors here so that their dtypes can be matched
100      # to that of the targets.
101      point_five = constant_op.constant(0.5, dtype=targets.dtype)
102      two_pi = constant_op.constant(2 * math.pi, dtype=targets.dtype)
103
104      stirling_approx = (targets * math_ops.log(targets)) - targets + (
105          point_five * math_ops.log(two_pi * targets))
106      zeros = array_ops.zeros_like(targets, dtype=targets.dtype)
107      ones = array_ops.ones_like(targets, dtype=targets.dtype)
108      cond = math_ops.logical_and(targets >= zeros, targets <= ones)
109      result += array_ops.where(cond, zeros, stirling_approx)
110    return result
111
112
113@tf_export(v1=["nn.sigmoid_cross_entropy_with_logits"])
114@dispatch.add_dispatch_support
115def sigmoid_cross_entropy_with_logits(  # pylint: disable=invalid-name
116    _sentinel=None,
117    labels=None,
118    logits=None,
119    name=None):
120  """Computes sigmoid cross entropy given `logits`.
121
122  Measures the probability error in discrete classification tasks in which each
123  class is independent and not mutually exclusive.  For instance, one could
124  perform multilabel classification where a picture can contain both an elephant
125  and a dog at the same time.
126
127  For brevity, let `x = logits`, `z = labels`.  The logistic loss is
128
129        z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
130      = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
131      = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
132      = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
133      = (1 - z) * x + log(1 + exp(-x))
134      = x - x * z + log(1 + exp(-x))
135
136  For x < 0, to avoid overflow in exp(-x), we reformulate the above
137
138        x - x * z + log(1 + exp(-x))
139      = log(exp(x)) - x * z + log(1 + exp(-x))
140      = - x * z + log(1 + exp(x))
141
142  Hence, to ensure stability and avoid overflow, the implementation uses this
143  equivalent formulation
144
145      max(x, 0) - x * z + log(1 + exp(-abs(x)))
146
147  `logits` and `labels` must have the same type and shape.
148
149  Args:
150    _sentinel: Used to prevent positional parameters. Internal, do not use.
151    labels: A `Tensor` of the same type and shape as `logits`.
152    logits: A `Tensor` of type `float32` or `float64`.
153    name: A name for the operation (optional).
154
155  Returns:
156    A `Tensor` of the same shape as `logits` with the componentwise
157    logistic losses.
158
159  Raises:
160    ValueError: If `logits` and `labels` do not have the same shape.
161  """
162  # pylint: disable=protected-access
163  nn_ops._ensure_xent_args("sigmoid_cross_entropy_with_logits", _sentinel,
164                           labels, logits)
165  # pylint: enable=protected-access
166
167  with ops.name_scope(name, "logistic_loss", [logits, labels]) as name:
168    logits = ops.convert_to_tensor(logits, name="logits")
169    labels = ops.convert_to_tensor(labels, name="labels")
170    try:
171      labels.get_shape().assert_is_compatible_with(logits.get_shape())
172    except ValueError:
173      raise ValueError("logits and labels must have the same shape (%s vs %s)" %
174                       (logits.get_shape(), labels.get_shape()))
175
176    # The logistic loss formula from above is
177    #   x - x * z + log(1 + exp(-x))
178    # For x < 0, a more numerically stable formula is
179    #   -x * z + log(1 + exp(x))
180    # Note that these two expressions can be combined into the following:
181    #   max(x, 0) - x * z + log(1 + exp(-abs(x)))
182    # To allow computing gradients at zero, we define custom versions of max and
183    # abs functions.
184    zeros = array_ops.zeros_like(logits, dtype=logits.dtype)
185    cond = (logits >= zeros)
186    relu_logits = array_ops.where(cond, logits, zeros)
187    neg_abs_logits = array_ops.where(cond, -logits, logits)
188    return math_ops.add(
189        relu_logits - logits * labels,
190        math_ops.log1p(math_ops.exp(neg_abs_logits)),
191        name=name)
192
193
194# Note: intentionally calling this v2 to not allow existing code with indirect
195# imports to ignore the sentinel behavior.
196@tf_export("nn.sigmoid_cross_entropy_with_logits", v1=[])
197@dispatch.add_dispatch_support
198def sigmoid_cross_entropy_with_logits_v2(  # pylint: disable=invalid-name
199    labels=None,
200    logits=None,
201    name=None):
202  """Computes sigmoid cross entropy given `logits`.
203
204  Measures the probability error in discrete classification tasks in which each
205  class is independent and not mutually exclusive.  For instance, one could
206  perform multilabel classification where a picture can contain both an elephant
207  and a dog at the same time.
208
209  For brevity, let `x = logits`, `z = labels`.  The logistic loss is
210
211        z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
212      = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
213      = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
214      = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
215      = (1 - z) * x + log(1 + exp(-x))
216      = x - x * z + log(1 + exp(-x))
217
218  For x < 0, to avoid overflow in exp(-x), we reformulate the above
219
220        x - x * z + log(1 + exp(-x))
221      = log(exp(x)) - x * z + log(1 + exp(-x))
222      = - x * z + log(1 + exp(x))
223
224  Hence, to ensure stability and avoid overflow, the implementation uses this
225  equivalent formulation
226
227      max(x, 0) - x * z + log(1 + exp(-abs(x)))
228
229  `logits` and `labels` must have the same type and shape.
230
231  Args:
232    labels: A `Tensor` of the same type and shape as `logits`.
233    logits: A `Tensor` of type `float32` or `float64`.
234    name: A name for the operation (optional).
235
236  Returns:
237    A `Tensor` of the same shape as `logits` with the componentwise
238    logistic losses.
239
240  Raises:
241    ValueError: If `logits` and `labels` do not have the same shape.
242  """
243  return sigmoid_cross_entropy_with_logits(
244      logits=logits, labels=labels, name=name)
245
246
247@tf_export("nn.weighted_cross_entropy_with_logits", v1=[])
248@dispatch.add_dispatch_support
249def weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight,
250                                          name=None):
251  """Computes a weighted cross entropy.
252
253  This is like `sigmoid_cross_entropy_with_logits()` except that `pos_weight`,
254  allows one to trade off recall and precision by up- or down-weighting the
255  cost of a positive error relative to a negative error.
256
257  The usual cross-entropy cost is defined as:
258
259      labels * -log(sigmoid(logits)) +
260          (1 - labels) * -log(1 - sigmoid(logits))
261
262  A value `pos_weight > 1` decreases the false negative count, hence increasing
263  the recall.
264  Conversely setting `pos_weight < 1` decreases the false positive count and
265  increases the precision.
266  This can be seen from the fact that `pos_weight` is introduced as a
267  multiplicative coefficient for the positive labels term
268  in the loss expression:
269
270      labels * -log(sigmoid(logits)) * pos_weight +
271          (1 - labels) * -log(1 - sigmoid(logits))
272
273  For brevity, let `x = logits`, `z = labels`, `q = pos_weight`.
274  The loss is:
275
276        qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
277      = qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
278      = qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
279      = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
280      = (1 - z) * x + (qz +  1 - z) * log(1 + exp(-x))
281      = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
282
283  Setting `l = (1 + (q - 1) * z)`, to ensure stability and avoid overflow,
284  the implementation uses
285
286      (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
287
288  `logits` and `labels` must have the same type and shape.
289
290  Args:
291    labels: A `Tensor` of the same type and shape as `logits`.
292    logits: A `Tensor` of type `float32` or `float64`.
293    pos_weight: A coefficient to use on the positive examples.
294    name: A name for the operation (optional).
295
296  Returns:
297    A `Tensor` of the same shape as `logits` with the componentwise
298    weighted logistic losses.
299
300  Raises:
301    ValueError: If `logits` and `labels` do not have the same shape.
302  """
303  with ops.name_scope(name, "logistic_loss", [logits, labels]) as name:
304    logits = ops.convert_to_tensor(logits, name="logits")
305    labels = ops.convert_to_tensor(labels, name="labels")
306    try:
307      labels.get_shape().assert_is_compatible_with(logits.get_shape())
308    except ValueError:
309      raise ValueError("logits and labels must have the same shape (%s vs %s)" %
310                       (logits.get_shape(), labels.get_shape()))
311
312    # The logistic loss formula from above is
313    #   (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
314    # For x < 0, a more numerically stable formula is
315    #   (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(x)) - l * x
316    # To avoid branching, we use the combined version
317    #   (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
318    log_weight = 1 + (pos_weight - 1) * labels
319    return math_ops.add(
320        (1 - labels) * logits,
321        log_weight * (math_ops.log1p(math_ops.exp(-math_ops.abs(logits))) +
322                      nn_ops.relu(-logits)),
323        name=name)
324
325
326@tf_export(v1=["nn.weighted_cross_entropy_with_logits"])
327@dispatch.add_dispatch_support
328@deprecated_args(None, "targets is deprecated, use labels instead", "targets")
329def weighted_cross_entropy_with_logits(labels=None,
330                                       logits=None,
331                                       pos_weight=None,
332                                       name=None,
333                                       targets=None):
334  """Computes a weighted cross entropy.
335
336  This is like `sigmoid_cross_entropy_with_logits()` except that `pos_weight`,
337  allows one to trade off recall and precision by up- or down-weighting the
338  cost of a positive error relative to a negative error.
339
340  The usual cross-entropy cost is defined as:
341
342      labels * -log(sigmoid(logits)) +
343          (1 - labels) * -log(1 - sigmoid(logits))
344
345  A value `pos_weight > 1` decreases the false negative count, hence increasing
346  the recall.
347  Conversely setting `pos_weight < 1` decreases the false positive count and
348  increases the precision.
349  This can be seen from the fact that `pos_weight` is introduced as a
350  multiplicative coefficient for the positive labels term
351  in the loss expression:
352
353      labels * -log(sigmoid(logits)) * pos_weight +
354          (1 - labels) * -log(1 - sigmoid(logits))
355
356  For brevity, let `x = logits`, `z = labels`, `q = pos_weight`.
357  The loss is:
358
359        qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
360      = qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
361      = qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
362      = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
363      = (1 - z) * x + (qz +  1 - z) * log(1 + exp(-x))
364      = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
365
366  Setting `l = (1 + (q - 1) * z)`, to ensure stability and avoid overflow,
367  the implementation uses
368
369      (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
370
371  `logits` and `labels` must have the same type and shape.
372
373  Args:
374    labels: A `Tensor` of the same type and shape as `logits`.
375    logits: A `Tensor` of type `float32` or `float64`.
376    pos_weight: A coefficient to use on the positive examples.
377    name: A name for the operation (optional).
378    targets: Deprecated alias for labels.
379
380  Returns:
381    A `Tensor` of the same shape as `logits` with the componentwise
382    weighted logistic losses.
383
384  Raises:
385    ValueError: If `logits` and `labels` do not have the same shape.
386  """
387  labels = deprecated_argument_lookup("labels", labels, "targets", targets)
388  return weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight, name)
389
390
391@tf_export("nn.compute_average_loss")
392@dispatch.add_dispatch_support
393def compute_average_loss(per_example_loss,
394                         sample_weight=None,
395                         global_batch_size=None):
396  """Scales per-example losses with sample_weights and computes their average.
397
398  Usage with distribution strategy and custom training loop:
399
400  ```python
401  with strategy.scope():
402    def compute_loss(labels, predictions, sample_weight=None):
403
404      # If you are using a `Loss` class instead, set reduction to `NONE` so that
405      # we can do the reduction afterwards and divide by global batch size.
406      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
407          labels, predictions)
408
409      # Compute loss that is scaled by sample_weight and by global batch size.
410      return tf.nn.compute_average_loss(
411          per_example_loss,
412          sample_weight=sample_weight,
413          global_batch_size=GLOBAL_BATCH_SIZE)
414  ```
415
416  Args:
417    per_example_loss: Per-example loss.
418    sample_weight: Optional weighting for each example.
419    global_batch_size: Optional global batch size value. Defaults to (size of
420      first dimension of `losses`) * (number of replicas).
421
422  Returns:
423    Scalar loss value.
424  """  # pylint: disable=g-doc-exception
425  per_example_loss = ops.convert_to_tensor(per_example_loss)
426  input_dtype = per_example_loss.dtype
427
428  with losses_util.check_per_example_loss_rank(per_example_loss):
429    if sample_weight is not None:
430      sample_weight = ops.convert_to_tensor(sample_weight)
431      per_example_loss = losses_util.scale_losses_by_sample_weight(
432          per_example_loss, sample_weight)
433    per_example_loss = math_ops.cast(per_example_loss, input_dtype)
434
435    if global_batch_size is None:
436      if ds.has_strategy() and ds.in_cross_replica_context():
437        raise RuntimeError(
438            "You are calling `compute_average_loss` in cross replica context, "
439            "while it was expected to be called in replica context.")
440
441      num_replicas = ds.get_strategy().num_replicas_in_sync
442      per_replica_batch_size = array_ops.shape_v2(per_example_loss)[0]
443      global_batch_size = per_replica_batch_size * num_replicas
444      global_batch_size = math_ops.cast(global_batch_size, input_dtype)
445
446    return math_ops.reduce_sum(per_example_loss) / global_batch_size
447
448
449@tf_export("nn.scale_regularization_loss")
450@dispatch.add_dispatch_support
451def scale_regularization_loss(regularization_loss):
452  """Scales the sum of the given regularization losses by number of replicas.
453
454  Usage with distribution strategy and custom training loop:
455
456  ```python
457  with strategy.scope():
458    def compute_loss(self, label, predictions):
459      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
460          labels, predictions)
461
462      # Compute loss that is scaled by sample_weight and by global batch size.
463      loss = tf.nn.compute_average_loss(
464          per_example_loss,
465          sample_weight=sample_weight,
466          global_batch_size=GLOBAL_BATCH_SIZE)
467
468      # Add scaled regularization losses.
469      loss += tf.nn.scale_regularization_loss(tf.nn.l2_loss(weights))
470      return loss
471  ```
472
473  Args:
474    regularization_loss: Regularization loss.
475
476  Returns:
477    Scalar loss value.
478  """  # pylint: disable=g-doc-exception
479  if ds.has_strategy() and ds.in_cross_replica_context():
480    raise RuntimeError(
481        "You are calling `scale_regularization_loss` in cross replica context, "
482        "while it was expected to be called in replica context.")
483
484  num_replicas = ds.get_strategy().num_replicas_in_sync
485  return math_ops.reduce_sum(regularization_loss) / num_replicas
486
487
488@tf_export(v1=["nn.relu_layer"])
489@dispatch.add_dispatch_support
490def relu_layer(x, weights, biases, name=None):
491  """Computes Relu(x * weight + biases).
492
493  Args:
494    x: a 2D tensor.  Dimensions typically: batch, in_units
495    weights: a 2D tensor.  Dimensions typically: in_units, out_units
496    biases: a 1D tensor.  Dimensions: out_units
497    name: A name for the operation (optional).  If not specified
498      "nn_relu_layer" is used.
499
500  Returns:
501    A 2-D Tensor computing relu(matmul(x, weights) + biases).
502    Dimensions typically: batch, out_units.
503  """
504  with ops.name_scope(name, "relu_layer", [x, weights, biases]) as name:
505    x = ops.convert_to_tensor(x, name="x")
506    weights = ops.convert_to_tensor(weights, name="weights")
507    biases = ops.convert_to_tensor(biases, name="biases")
508    xw_plus_b = nn_ops.bias_add(math_ops.matmul(x, weights), biases)
509    return nn_ops.relu(xw_plus_b, name=name)
510
511
512@tf_export("nn.silu", "nn.swish")
513@dispatch.add_dispatch_support
514@custom_gradient.custom_gradient
515def swish(features):
516  # pylint: disable=g-doc-args
517  """Computes the SiLU or Swish activation function: `x * sigmoid(x)`.
518
519  The SiLU activation function was introduced in "Gaussian Error Linear Units
520  (GELUs)" [Hendrycks et al. 2016](https://arxiv.org/abs/1606.08415) and
521  "Sigmoid-Weighted Linear Units for Neural Network Function Approximation in
522  Reinforcement Learning"
523  [Elfwing et al. 2017](https://arxiv.org/abs/1702.03118) and was independently
524  discovered (and called swish) in "Searching for Activation Functions"
525  [Ramachandran et al. 2017](https://arxiv.org/abs/1710.05941)
526
527  Args:
528    features: A `Tensor` representing preactivation values.
529
530  Returns:
531    The activation value.
532  """
533  # pylint: enable=g-doc-args
534  features = ops.convert_to_tensor(features, name="features")
535
536  def grad(dy):
537    """Gradient for the Swish activation function"""
538    # Naively, x * tf.nn.sigmoid(x) requires keeping both x and sigmoid(x)
539    # around for backprop, effectively doubling the tensor's memory consumption.
540    # We use a control dependency here so that sigmoid(features) is re-computed
541    # during backprop (the control dep prevents it being de-duped with the
542    # forward pass) and we can free the sigmoid(features) expression immediately
543    # after use during the forward pass.
544    with ops.control_dependencies([dy]):
545      sigmoid_features = math_ops.sigmoid(features)
546    activation_grad = (
547        sigmoid_features * (1.0 + features * (1.0 - sigmoid_features)))
548    return dy * activation_grad
549
550  return features * math_ops.sigmoid(features), grad
551
552
553# pylint: disable=redefined-builtin
554@tf_export("linalg.normalize")
555@dispatch.add_dispatch_support
556def normalize(tensor, ord="euclidean", axis=None, name=None):
557  """Normalizes `tensor` along dimension `axis` using specified norm.
558
559  This uses `tf.linalg.norm` to compute the norm along `axis`.
560
561  This function can compute several different vector norms (the 1-norm, the
562  Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0) and
563  matrix norms (Frobenius, 1-norm, 2-norm and inf-norm).
564
565  Args:
566    tensor: `Tensor` of types `float32`, `float64`, `complex64`, `complex128`
567    ord: Order of the norm. Supported values are `'fro'`, `'euclidean'`, `1`,
568      `2`, `np.inf` and any positive real number yielding the corresponding
569      p-norm. Default is `'euclidean'` which is equivalent to Frobenius norm if
570      `tensor` is a matrix and equivalent to 2-norm for vectors.
571      Some restrictions apply: a) The Frobenius norm `'fro'` is not defined for
572        vectors, b) If axis is a 2-tuple (matrix norm), only `'euclidean'`,
573        '`fro'`, `1`, `2`, `np.inf` are supported. See the description of `axis`
574        on how to compute norms for a batch of vectors or matrices stored in a
575        tensor.
576    axis: If `axis` is `None` (the default), the input is considered a vector
577      and a single vector norm is computed over the entire set of values in the
578      tensor, i.e. `norm(tensor, ord=ord)` is equivalent to
579      `norm(reshape(tensor, [-1]), ord=ord)`. If `axis` is a Python integer, the
580      input is considered a batch of vectors, and `axis` determines the axis in
581      `tensor` over which to compute vector norms. If `axis` is a 2-tuple of
582      Python integers it is considered a batch of matrices and `axis` determines
583      the axes in `tensor` over which to compute a matrix norm.
584      Negative indices are supported. Example: If you are passing a tensor that
585        can be either a matrix or a batch of matrices at runtime, pass
586        `axis=[-2,-1]` instead of `axis=None` to make sure that matrix norms are
587        computed.
588    name: The name of the op.
589
590  Returns:
591    normalized: A normalized `Tensor` with the same shape as `tensor`.
592    norm: The computed norms with the same shape and dtype `tensor` but the
593      final axis is 1 instead. Same as running
594      `tf.cast(tf.linalg.norm(tensor, ord, axis keepdims=True), tensor.dtype)`.
595
596  Raises:
597    ValueError: If `ord` or `axis` is invalid.
598  """
599  with ops.name_scope(name, "normalize", [tensor]) as name:
600    tensor = ops.convert_to_tensor(tensor)
601    norm = linalg_ops.norm(tensor, ord, axis, keepdims=True)
602    norm = math_ops.cast(norm, tensor.dtype)
603    normalized = tensor / norm
604    return normalized, norm
605
606
607@tf_export(v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"])
608@dispatch.add_dispatch_support
609@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
610def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
611  """Normalizes along dimension `axis` using an L2 norm.
612
613  For a 1-D tensor with `axis = 0`, computes
614
615      output = x / sqrt(max(sum(x**2), epsilon))
616
617  For `x` with more dimensions, independently normalizes each 1-D slice along
618  dimension `axis`.
619
620  Args:
621    x: A `Tensor`.
622    axis: Dimension along which to normalize.  A scalar or a vector of
623      integers.
624    epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
625      divisor if `norm < sqrt(epsilon)`.
626    name: A name for this operation (optional).
627    dim: Deprecated alias for axis.
628
629  Returns:
630    A `Tensor` with the same shape as `x`.
631  """
632  axis = deprecated_argument_lookup("axis", axis, "dim", dim)
633  return l2_normalize_v2(x, axis, epsilon, name)
634
635
636@tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize", v1=[])
637@dispatch.add_dispatch_support
638def l2_normalize_v2(x, axis=None, epsilon=1e-12, name=None):
639  """Normalizes along dimension `axis` using an L2 norm.
640
641  For a 1-D tensor with `axis = 0`, computes
642
643      output = x / sqrt(max(sum(x**2), epsilon))
644
645  For `x` with more dimensions, independently normalizes each 1-D slice along
646  dimension `axis`.
647
648  * 1-D tensor example:
649  >>> x = tf.constant([3.0, 4.0])
650  >>> tf.math.l2_normalize(x).numpy()
651  array([0.6, 0.8], dtype=float32)
652
653  * 2-D tensor example:
654  >>> x = tf.constant([[3.0], [4.0]])
655  >>> tf.math.l2_normalize(x, 0).numpy()
656  array([[0.6],
657       [0.8]], dtype=float32)
658
659  >>> x = tf.constant([[3.0], [4.0]])
660  >>> tf.math.l2_normalize(x, 1).numpy()
661  array([[1.],
662       [1.]], dtype=float32)
663
664  Args:
665    x: A `Tensor`.
666    axis: Dimension along which to normalize.  A scalar or a vector of
667      integers.
668    epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
669      divisor if `norm < sqrt(epsilon)`.
670    name: A name for this operation (optional).
671
672  Returns:
673    A `Tensor` with the same shape as `x`.
674  """
675  with ops.name_scope(name, "l2_normalize", [x]) as name:
676    x = ops.convert_to_tensor(x, name="x")
677    if x.dtype.is_complex:
678      square_real = math_ops.square(math_ops.real(x))
679      square_imag = math_ops.square(math_ops.imag(x))
680      square_sum = math_ops.real(
681          math_ops.reduce_sum(square_real + square_imag, axis, keepdims=True))
682      x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
683      norm_real = math_ops.multiply(math_ops.real(x), x_inv_norm)
684      norm_imag = math_ops.multiply(math_ops.imag(x), x_inv_norm)
685      return math_ops.complex(norm_real, norm_imag, name=name)
686    square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True)
687    x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
688    return math_ops.multiply(x, x_inv_norm, name=name)
689
690
691def _count_nonzero(input_tensor, dtype=dtypes.int64):
692  """Same as math_ops.count_nonzero.
693
694  The reduction is done in dtype, which can be faster for 32-bit dtypes.
695
696  Args:
697      input_tensor: numeric tensor
698      dtype: reduction dtype
699
700  Returns:
701      number of nonzero values with type dtype
702  """
703  with ops.name_scope("count_nonzero", values=[input_tensor]):
704    zero = array_ops.zeros([], dtype=input_tensor.dtype)
705    nonzero_count = math_ops.reduce_sum(
706        math_ops.cast(
707            math_ops.not_equal(input_tensor, zero),
708            dtype=dtype), name="nonzero_count")
709    return nonzero_count
710
711
712@tf_export("math.zero_fraction", "nn.zero_fraction")
713@dispatch.add_dispatch_support
714def zero_fraction(value, name=None):
715  """Returns the fraction of zeros in `value`.
716
717  If `value` is empty, the result is `nan`.
718
719  This is useful in summaries to measure and report sparsity.  For example,
720
721  ```python
722      z = tf.nn.relu(...)
723      summ = tf.compat.v1.summary.scalar('sparsity', tf.nn.zero_fraction(z))
724  ```
725
726  Args:
727    value: A tensor of numeric type.
728    name: A name for the operation (optional).
729
730  Returns:
731    The fraction of zeros in `value`, with type `float32`.
732  """
733  with ops.name_scope(name, "zero_fraction", [value]):
734    value = ops.convert_to_tensor(value, name="value")
735    size = array_ops.size(value, out_type=dtypes.int64)
736    # If the count is small, we can save memory/CPU with an int32 reduction.
737    num_nonzero = control_flow_ops.cond(
738        size <= dtypes.int32.max,
739        # pylint: disable=g-long-lambda
740        true_fn=lambda: math_ops.cast(
741            _count_nonzero(value, dtype=dtypes.int32),
742            dtype=dtypes.int64),
743        false_fn=lambda: _count_nonzero(value, dtype=dtypes.int64))
744
745    with ops.name_scope("counts_to_fraction"):
746      num_zero = size - num_nonzero
747      num_zero_float32 = math_ops.cast(num_zero, dtype=dtypes.float32)
748      size_float32 = math_ops.cast(size, dtype=dtypes.float32)
749      zero_fraction_float32 = num_zero_float32 / size_float32
750
751    return array_ops.identity(zero_fraction_float32, "fraction")
752
753
754# pylint: disable=redefined-builtin
755@tf_export(v1=["nn.depthwise_conv2d"])
756@dispatch.add_dispatch_support
757def depthwise_conv2d(input,
758                     filter,
759                     strides,
760                     padding,
761                     rate=None,
762                     name=None,
763                     data_format=None,
764                     dilations=None):
765  """Depthwise 2-D convolution.
766
767  Given a 4D input tensor ('NHWC' or 'NCHW' data formats)
768  and a filter tensor of shape
769  `[filter_height, filter_width, in_channels, channel_multiplier]`
770  containing `in_channels` convolutional filters of depth 1, `depthwise_conv2d`
771  applies a different filter to each input channel (expanding from 1 channel
772  to `channel_multiplier` channels for each), then concatenates the results
773  together.  The output has `in_channels * channel_multiplier` channels.
774
775  In detail, with the default NHWC format,
776
777      output[b, i, j, k * channel_multiplier + q] = sum_{di, dj}
778           filter[di, dj, k, q] * input[b, strides[1] * i + rate[0] * di,
779                                           strides[2] * j + rate[1] * dj, k]
780
781  Must have `strides[0] = strides[3] = 1`.  For the most common case of the
782  same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
783  If any value in `rate` is greater than 1, we perform atrous depthwise
784  convolution, in which case all values in the `strides` tensor must be equal
785  to 1.
786
787  Usage Example:
788
789  >>> x = np.array([
790  ...     [1., 2.],
791  ...     [3., 4.],
792  ...     [5., 6.]
793  ... ], dtype=np.float32).reshape((1, 3, 2, 1))
794  >>> kernel = np.array([
795  ...     [1., 2.],
796  ...     [3., 4]
797  ... ], dtype=np.float32).reshape((2, 1, 1, 2))
798  >>> tf.compat.v1.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
799  ...                                  padding='VALID').numpy()
800    array([[[[10., 14.],
801             [14., 20.]],
802            [[18., 26.],
803             [22., 32.]]]], dtype=float32)
804
805  >>> tf.compat.v1.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
806  ...                                  padding=[[0, 0], [1, 0], [1, 0], [0, 0]]
807  ...                                 ).numpy()
808    array([[[[ 0.,  0.],
809             [ 3.,  4.],
810             [ 6.,  8.]],
811            [[ 0.,  0.],
812             [10., 14.],
813             [14., 20.]],
814            [[ 0.,  0.],
815             [18., 26.],
816             [22., 32.]]]], dtype=float32)
817
818  Args:
819    input: 4-D with shape according to `data_format`.
820    filter: 4-D with shape
821      `[filter_height, filter_width, in_channels, channel_multiplier]`.
822    strides: 1-D of size 4.  The stride of the sliding window for each
823      dimension of `input`.
824    padding: Controls how to pad the image before applying the convolution. Can
825      be the string `"SAME"` or `"VALID"` indicating the type of padding
826      algorithm to use, or a list indicating the explicit paddings at the start
827      and end of each dimension. When explicit padding is used and data_format
828      is `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
829      [pad_left, pad_right], [0, 0]]`. When explicit padding used and
830      data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
831      [pad_top, pad_bottom], [pad_left, pad_right]]`.
832    rate: 1-D of size 2. The dilation rate in which we sample input values
833      across the `height` and `width` dimensions in atrous convolution. If it is
834      greater than 1, then all values of strides must be 1.
835    name: A name for this operation (optional).
836    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
837    dilations: Alias of rate.
838
839  Returns:
840    A 4-D `Tensor` with shape according to `data_format`.  E.g., for
841    "NHWC" format, shape is
842    `[batch, out_height, out_width, in_channels * channel_multiplier].`
843  """
844  rate = deprecated_argument_lookup("dilations", dilations, "rate", rate)
845  with ops.name_scope(name, "depthwise", [input, filter]) as name:
846    input = ops.convert_to_tensor(input, name="tensor_in")
847    filter = ops.convert_to_tensor(filter, name="filter_in")
848    if rate is None:
849      rate = [1, 1]
850
851    # Use depthwise_conv2d_native if executing on TPU.
852    if device_context.enclosing_tpu_context() is not None:
853      if data_format == "NCHW":
854        dilations = [1, 1, rate[0], rate[1]]
855      else:
856        dilations = [1, rate[0], rate[1], 1]
857      return nn_ops.depthwise_conv2d_native(
858          input=input,
859          filter=filter,
860          strides=strides,
861          padding=padding,
862          data_format=data_format,
863          dilations=dilations,
864          name=name)
865
866    def op(input_converted, _, padding):
867      return nn_ops.depthwise_conv2d_native(
868          input=input_converted,
869          filter=filter,
870          strides=strides,
871          padding=padding,
872          data_format=data_format,
873          name=name)
874
875    return nn_ops.with_space_to_batch(
876        input=input,
877        filter_shape=array_ops.shape(filter),
878        dilation_rate=rate,
879        padding=padding,
880        data_format=data_format,
881        op=op)
882
883
884@tf_export("nn.depthwise_conv2d", v1=[])
885@dispatch.add_dispatch_support
886def depthwise_conv2d_v2(input,
887                        filter,
888                        strides,
889                        padding,
890                        data_format=None,
891                        dilations=None,
892                        name=None):
893  """Depthwise 2-D convolution.
894
895  Given a 4D input tensor ('NHWC' or 'NCHW' data formats)
896  and a filter tensor of shape
897  `[filter_height, filter_width, in_channels, channel_multiplier]`
898  containing `in_channels` convolutional filters of depth 1, `depthwise_conv2d`
899  applies a different filter to each input channel (expanding from 1 channel
900  to `channel_multiplier` channels for each), then concatenates the results
901  together.  The output has `in_channels * channel_multiplier` channels.
902
903  In detail, with the default NHWC format,
904
905      output[b, i, j, k * channel_multiplier + q] = sum_{di, dj}
906           filter[di, dj, k, q] * input[b, strides[1] * i + rate[0] * di,
907                                           strides[2] * j + rate[1] * dj, k]
908
909  Must have `strides[0] = strides[3] = 1`.  For the most common case of the
910  same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
911  If any value in `rate` is greater than 1, we perform atrous depthwise
912  convolution, in which case all values in the `strides` tensor must be equal
913  to 1.
914
915  Usage Example:
916
917  >>> x = np.array([
918  ...     [1., 2.],
919  ...     [3., 4.],
920  ...     [5., 6.]
921  ... ], dtype=np.float32).reshape((1, 3, 2, 1))
922  >>> kernel = np.array([
923  ...     [1., 2.],
924  ...     [3., 4]
925  ... ], dtype=np.float32).reshape((2, 1, 1, 2))
926  >>> tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
927  ...                        padding='VALID').numpy()
928    array([[[[10., 14.],
929             [14., 20.]],
930            [[18., 26.],
931             [22., 32.]]]], dtype=float32)
932
933  >>> tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
934  ...                        padding=[[0, 0], [1, 0], [1, 0], [0, 0]]).numpy()
935    array([[[[ 0.,  0.],
936             [ 3.,  4.],
937             [ 6.,  8.]],
938            [[ 0.,  0.],
939             [10., 14.],
940             [14., 20.]],
941            [[ 0.,  0.],
942             [18., 26.],
943             [22., 32.]]]], dtype=float32)
944
945  Args:
946    input: 4-D with shape according to `data_format`.
947    filter: 4-D with shape
948      `[filter_height, filter_width, in_channels, channel_multiplier]`.
949    strides: 1-D of size 4.  The stride of the sliding window for each
950      dimension of `input`.
951    padding: Controls how to pad the image before applying the convolution. Can
952      be the string `"SAME"` or `"VALID"` indicating the type of padding
953      algorithm to use, or a list indicating the explicit paddings at the start
954      and end of each dimension. When explicit padding is used and data_format
955      is `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
956      [pad_left, pad_right], [0, 0]]`. When explicit padding used and
957      data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
958      [pad_top, pad_bottom], [pad_left, pad_right]]`.
959    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
960    dilations: 1-D of size 2. The dilation rate in which we sample input values
961      across the `height` and `width` dimensions in atrous convolution. If it is
962      greater than 1, then all values of strides must be 1.
963    name: A name for this operation (optional).
964
965  Returns:
966    A 4-D `Tensor` with shape according to `data_format`.  E.g., for
967    "NHWC" format, shape is
968    `[batch, out_height, out_width, in_channels * channel_multiplier].`
969  """
970  return depthwise_conv2d(input=input,
971                          filter=filter,
972                          strides=strides,
973                          padding=padding,
974                          rate=dilations,
975                          name=name,
976                          data_format=data_format)
977
978# pylint: enable=redefined-builtin
979
980
981# pylint: disable=redefined-builtin,line-too-long
982@tf_export(v1=["nn.separable_conv2d"])
983@dispatch.add_dispatch_support
984def separable_conv2d(input,
985                     depthwise_filter,
986                     pointwise_filter,
987                     strides,
988                     padding,
989                     rate=None,
990                     name=None,
991                     data_format=None,
992                     dilations=None):
993  """2-D convolution with separable filters.
994
995  Performs a depthwise convolution that acts separately on channels followed by
996  a pointwise convolution that mixes channels.  Note that this is separability
997  between dimensions `[1, 2]` and `3`, not spatial separability between
998  dimensions `1` and `2`.
999
1000  In detail, with the default NHWC format,
1001
1002      output[b, i, j, k] = sum_{di, dj, q, r}
1003          input[b, strides[1] * i + di, strides[2] * j + dj, q] *
1004          depthwise_filter[di, dj, q, r] *
1005          pointwise_filter[0, 0, q * channel_multiplier + r, k]
1006
1007  `strides` controls the strides for the depthwise convolution only, since
1008  the pointwise convolution has implicit strides of `[1, 1, 1, 1]`.  Must have
1009  `strides[0] = strides[3] = 1`.  For the most common case of the same
1010  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
1011  If any value in `rate` is greater than 1, we perform atrous depthwise
1012  convolution, in which case all values in the `strides` tensor must be equal
1013  to 1.
1014
1015  Args:
1016    input: 4-D `Tensor` with shape according to `data_format`.
1017    depthwise_filter: 4-D `Tensor` with shape
1018      `[filter_height, filter_width, in_channels, channel_multiplier]`.
1019      Contains `in_channels` convolutional filters of depth 1.
1020    pointwise_filter: 4-D `Tensor` with shape
1021      `[1, 1, channel_multiplier * in_channels, out_channels]`.  Pointwise
1022      filter to mix channels after `depthwise_filter` has convolved spatially.
1023    strides: 1-D of size 4.  The strides for the depthwise convolution for
1024      each dimension of `input`.
1025    padding: Controls how to pad the image before applying the depthwise
1026      convolution. Can be the string `"SAME"` or `"VALID"` indicating the type
1027      of padding algorithm to use, or a Python list indicating the explicit
1028      paddings at the start and end of each dimension. When explicit padding is
1029      used and data_format is `"NHWC"`, this should be in the form `[[0, 0],
1030      [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit
1031      padding used and data_format is `"NCHW"`, this should be in the form
1032      `[[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]`.
1033    rate: 1-D of size 2. The dilation rate in which we sample input values
1034      across the `height` and `width` dimensions in atrous convolution. If it is
1035      greater than 1, then all values of strides must be 1.
1036    name: A name for this operation (optional).
1037    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
1038    dilations: Alias of rate.
1039
1040  Returns:
1041    A 4-D `Tensor` with shape according to 'data_format'. For
1042      example, with data_format="NHWC", shape is [batch, out_height,
1043      out_width, out_channels].
1044  """
1045  rate = deprecated_argument_lookup("dilations", dilations, "rate", rate)
1046  with ops.name_scope(name, "separable_conv2d",
1047                      [input, depthwise_filter, pointwise_filter]) as name:
1048    input = ops.convert_to_tensor(input, name="tensor_in")
1049    depthwise_filter = ops.convert_to_tensor(
1050        depthwise_filter, name="depthwise_filter")
1051    pointwise_filter = ops.convert_to_tensor(
1052        pointwise_filter, name="pointwise_filter")
1053
1054    pointwise_filter_shape = pointwise_filter.get_shape().with_rank(4)
1055    pointwise_filter_shape.dims[0].assert_is_compatible_with(1)
1056    pointwise_filter_shape.dims[1].assert_is_compatible_with(1)
1057
1058    if rate is None:
1059      rate = [1, 1]
1060
1061    # The layout of the ops in the graph are expected to be as follows:
1062    # depthwise_conv2d  // Conv2D op corresponding to native depthwise conv.
1063    # separable_conv2d  // Conv2D op corresponding to the pointwise conv.
1064
1065    def op(input_converted, _, padding):
1066      return nn_ops.depthwise_conv2d_native(
1067          input=input_converted,
1068          filter=depthwise_filter,
1069          strides=strides,
1070          padding=padding,
1071          data_format=data_format,
1072          name="depthwise")
1073
1074    depthwise = nn_ops.with_space_to_batch(
1075        input=input,
1076        filter_shape=array_ops.shape(depthwise_filter),
1077        dilation_rate=rate,
1078        padding=padding,
1079        data_format=data_format,
1080        op=op)
1081
1082    return nn_ops.conv2d(
1083        depthwise,
1084        pointwise_filter, [1, 1, 1, 1],
1085        padding="VALID",
1086        data_format=data_format,
1087        name=name)
1088
1089
1090@tf_export("nn.separable_conv2d", v1=[])
1091@dispatch.add_dispatch_support
1092def separable_conv2d_v2(
1093    input,
1094    depthwise_filter,
1095    pointwise_filter,
1096    strides,
1097    padding,
1098    data_format=None,
1099    dilations=None,
1100    name=None,
1101):
1102  """2-D convolution with separable filters.
1103
1104  Performs a depthwise convolution that acts separately on channels followed by
1105  a pointwise convolution that mixes channels.  Note that this is separability
1106  between dimensions `[1, 2]` and `3`, not spatial separability between
1107  dimensions `1` and `2`.
1108
1109  In detail, with the default NHWC format,
1110
1111      output[b, i, j, k] = sum_{di, dj, q, r}
1112          input[b, strides[1] * i + di, strides[2] * j + dj, q] *
1113          depthwise_filter[di, dj, q, r] *
1114          pointwise_filter[0, 0, q * channel_multiplier + r, k]
1115
1116  `strides` controls the strides for the depthwise convolution only, since
1117  the pointwise convolution has implicit strides of `[1, 1, 1, 1]`.  Must have
1118  `strides[0] = strides[3] = 1`.  For the most common case of the same
1119  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
1120  If any value in `rate` is greater than 1, we perform atrous depthwise
1121  convolution, in which case all values in the `strides` tensor must be equal
1122  to 1.
1123
1124  Args:
1125    input: 4-D `Tensor` with shape according to `data_format`.
1126    depthwise_filter: 4-D `Tensor` with shape `[filter_height, filter_width,
1127      in_channels, channel_multiplier]`. Contains `in_channels` convolutional
1128      filters of depth 1.
1129    pointwise_filter: 4-D `Tensor` with shape `[1, 1, channel_multiplier *
1130      in_channels, out_channels]`.  Pointwise filter to mix channels after
1131      `depthwise_filter` has convolved spatially.
1132    strides: 1-D of size 4.  The strides for the depthwise convolution for each
1133      dimension of `input`.
1134    padding: Controls how to pad the image before applying the depthwise
1135      convolution. Can be the string `"SAME"` or `"VALID"` indicating the type
1136      of padding algorithm to use, or a Python list indicating the explicit
1137      paddings at the start and end of each dimension. When explicit padding is
1138      used and data_format is `"NHWC"`, this should be in the form `[[0, 0],
1139      [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit
1140      padding used and data_format is `"NCHW"`, this should be in the form
1141      `[[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]`.
1142    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
1143    dilations: 1-D of size 2. The dilation rate in which we sample input values
1144      across the `height` and `width` dimensions in atrous convolution. If it is
1145      greater than 1, then all values of strides must be 1.
1146    name: A name for this operation (optional).
1147
1148  Returns:
1149    A 4-D `Tensor` with shape according to 'data_format'. For
1150      example, with data_format="NHWC", shape is [batch, out_height,
1151      out_width, out_channels].
1152  """
1153  return separable_conv2d(
1154      input,
1155      depthwise_filter,
1156      pointwise_filter,
1157      strides,
1158      padding,
1159      rate=dilations,
1160      name=name,
1161      data_format=data_format)
1162
1163# pylint: enable=redefined-builtin,line-too-long
1164
1165
1166@tf_export(v1=["nn.sufficient_statistics"])
1167@dispatch.add_dispatch_support
1168def sufficient_statistics(x, axes, shift=None, keep_dims=None, name=None,
1169                          keepdims=None):
1170  """Calculate the sufficient statistics for the mean and variance of `x`.
1171
1172  These sufficient statistics are computed using the one pass algorithm on
1173  an input that's optionally shifted. See:
1174  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data
1175
1176  For example:
1177  >>> t = [[1, 2, 3], [4, 5, 6]]
1178  >>> sufficient_statistics(t, [1])
1179  (<tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(2,),
1180  dtype=int32, numpy=array([ 6, 15], dtype=int32)>, <tf.Tensor: shape=(2,),
1181  dtype=int32, numpy=array([14, 77], dtype=int32)>, None)
1182  >>> sufficient_statistics(t, [-1])
1183  (<tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(2,),
1184  dtype=int32, numpy=array([ 6, 15], dtype=int32)>, <tf.Tensor: shape=(2,),
1185  dtype=int32, numpy=array([14, 77], dtype=int32)>, None)
1186
1187  Args:
1188    x: A `Tensor`.
1189    axes: Array of ints. Axes along which to compute mean and variance. As in
1190      Python, the axes can also be negative numbers. A negative axis is
1191      interpreted as counting from the end of the rank, i.e., axis +
1192      rank(values)-th dimension.
1193    shift: A `Tensor` containing the value by which to shift the data for
1194      numerical stability, or `None` if no shift is to be performed. A shift
1195      close to the true mean provides the most numerically stable results.
1196    keep_dims: produce statistics with the same dimensionality as the input.
1197    name: Name used to scope the operations that compute the sufficient stats.
1198    keepdims: Alias for keep_dims.
1199
1200  Returns:
1201    Four `Tensor` objects of the same type as `x`:
1202
1203    * the count (number of elements to average over).
1204    * the (possibly shifted) sum of the elements in the array.
1205    * the (possibly shifted) sum of squares of the elements in the array.
1206    * the shift by which the mean must be corrected or None if `shift` is None.
1207  """
1208  axes = list(set(axes))
1209  keep_dims = deprecated_argument_lookup(
1210      "keepdims", keepdims, "keep_dims", keep_dims)
1211  if keep_dims is None:
1212    keep_dims = False
1213  with ops.name_scope(name, "sufficient_statistics", [x, shift]):
1214    x = ops.convert_to_tensor(x, name="x")
1215    x_shape = x.get_shape()
1216    if x_shape.rank is not None and all(
1217        x_shape.dims[d].value is not None for d in axes):
1218      counts = 1
1219      for d in axes:
1220        counts *= x_shape.dims[d].value
1221      counts = constant_op.constant(counts, dtype=x.dtype)
1222    else:  # shape needs to be inferred at runtime.
1223      # Normalize axes to be positive. Required for gather.
1224      rank = array_ops.rank(x)
1225      positive_axes = [axis + rank if axis < 0 else axis for axis in axes]
1226      x_dims = array_ops.gather(
1227          math_ops.cast(array_ops.shape(x), x.dtype), positive_axes)
1228      counts = math_ops.reduce_prod(x_dims, name="count")
1229    if shift is not None:
1230      shift = ops.convert_to_tensor(shift, name="shift")
1231      m_ss = math_ops.subtract(x, shift)
1232      v_ss = math_ops.squared_difference(x, shift)
1233    else:  # no shift.
1234      m_ss = x
1235      v_ss = math_ops.square(x)
1236    m_ss = math_ops.reduce_sum(m_ss, axes, keepdims=keep_dims, name="mean_ss")
1237    v_ss = math_ops.reduce_sum(v_ss, axes, keepdims=keep_dims, name="var_ss")
1238  return counts, m_ss, v_ss, shift
1239
1240
1241@tf_export("nn.sufficient_statistics", v1=[])
1242@dispatch.add_dispatch_support
1243def sufficient_statistics_v2(x, axes, shift=None, keepdims=False, name=None):
1244  """Calculate the sufficient statistics for the mean and variance of `x`.
1245
1246  These sufficient statistics are computed using the one pass algorithm on
1247  an input that's optionally shifted. See:
1248  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data
1249
1250  Args:
1251    x: A `Tensor`.
1252    axes: Array of ints. Axes along which to compute mean and variance.
1253    shift: A `Tensor` containing the value by which to shift the data for
1254      numerical stability, or `None` if no shift is to be performed. A shift
1255      close to the true mean provides the most numerically stable results.
1256    keepdims: produce statistics with the same dimensionality as the input.
1257    name: Name used to scope the operations that compute the sufficient stats.
1258
1259  Returns:
1260    Four `Tensor` objects of the same type as `x`:
1261
1262    * the count (number of elements to average over).
1263    * the (possibly shifted) sum of the elements in the array.
1264    * the (possibly shifted) sum of squares of the elements in the array.
1265    * the shift by which the mean must be corrected or None if `shift` is None.
1266  """
1267  return sufficient_statistics(
1268      x=x, axes=axes, shift=shift, keep_dims=keepdims, name=name)
1269
1270
1271@tf_export("nn.normalize_moments")
1272@dispatch.add_dispatch_support
1273def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
1274  """Calculate the mean and variance of based on the sufficient statistics.
1275
1276  Args:
1277    counts: A `Tensor` containing the total count of the data (one value).
1278    mean_ss: A `Tensor` containing the mean sufficient statistics: the (possibly
1279      shifted) sum of the elements to average over.
1280    variance_ss: A `Tensor` containing the variance sufficient statistics: the
1281      (possibly shifted) squared sum of the data to compute the variance over.
1282    shift: A `Tensor` containing the value by which the data is shifted for
1283      numerical stability, or `None` if no shift was performed.
1284    name: Name used to scope the operations that compute the moments.
1285
1286  Returns:
1287    Two `Tensor` objects: `mean` and `variance`.
1288  """
1289  with ops.name_scope(name, "normalize", [counts, mean_ss, variance_ss, shift]):
1290    divisor = math_ops.reciprocal(counts, name="divisor")
1291    if shift is not None:
1292      shifted_mean = math_ops.multiply(mean_ss, divisor, name="shifted_mean")
1293      mean = math_ops.add(shifted_mean, shift, name="mean")
1294    else:  # no shift.
1295      shifted_mean = math_ops.multiply(mean_ss, divisor, name="mean")
1296      mean = shifted_mean
1297    variance = math_ops.subtract(
1298        math_ops.multiply(variance_ss, divisor),
1299        math_ops.square(shifted_mean),
1300        name="variance")
1301  return (mean, variance)
1302
1303
1304@tf_export(v1=["nn.moments"])
1305@dispatch.add_dispatch_support
1306def moments(
1307    x,
1308    axes,
1309    shift=None,  # pylint: disable=unused-argument
1310    name=None,
1311    keep_dims=None,
1312    keepdims=None):
1313  """Calculate the mean and variance of `x`.
1314
1315  The mean and variance are calculated by aggregating the contents of `x`
1316  across `axes`.  If `x` is 1-D and `axes = [0]` this is just the mean
1317  and variance of a vector.
1318
1319  Note: shift is currently not used; the true mean is computed and used.
1320
1321  When using these moments for batch normalization (see
1322  `tf.nn.batch_normalization`):
1323
1324   * for so-called "global normalization", used with convolutional filters with
1325     shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`.
1326   * for simple batch normalization pass `axes=[0]` (batch only).
1327
1328  Args:
1329    x: A `Tensor`.
1330    axes: Array of ints.  Axes along which to compute mean and
1331      variance.
1332    shift: Not used in the current implementation
1333    name: Name used to scope the operations that compute the moments.
1334    keep_dims: produce moments with the same dimensionality as the input.
1335    keepdims: Alias to keep_dims.
1336
1337  Returns:
1338    Two `Tensor` objects: `mean` and `variance`.
1339  """
1340  keep_dims = deprecated_argument_lookup(
1341      "keepdims", keepdims, "keep_dims", keep_dims)
1342  if keep_dims is None:
1343    keep_dims = False
1344  with ops.name_scope(name, "moments", [x, axes]):
1345    # The dynamic range of fp16 is too limited to support the collection of
1346    # sufficient statistics. As a workaround we simply perform the operations
1347    # on 32-bit floats before converting the mean and variance back to fp16
1348    y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
1349    # Compute true mean while keeping the dims for proper broadcasting.
1350    mean = math_ops.reduce_mean(y, axes, keepdims=True, name="mean")
1351    # sample variance, not unbiased variance
1352    # Note: stop_gradient does not change the gradient that gets
1353    #       backpropagated to the mean from the variance calculation,
1354    #       because that gradient is zero
1355    variance = math_ops.reduce_mean(
1356        math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
1357        axes,
1358        keepdims=True,
1359        name="variance")
1360    if not keep_dims:
1361      mean = array_ops.squeeze(mean, axes)
1362      variance = array_ops.squeeze(variance, axes)
1363    if x.dtype == dtypes.float16:
1364      return (math_ops.cast(mean, dtypes.float16),
1365              math_ops.cast(variance, dtypes.float16))
1366    else:
1367      return (mean, variance)
1368
1369
1370@tf_export("nn.moments", v1=[])
1371@dispatch.add_dispatch_support
1372def moments_v2(
1373    x,
1374    axes,
1375    shift=None,
1376    keepdims=False,
1377    name=None):
1378  """Calculates the mean and variance of `x`.
1379
1380  The mean and variance are calculated by aggregating the contents of `x`
1381  across `axes`.  If `x` is 1-D and `axes = [0]` this is just the mean
1382  and variance of a vector.
1383
1384  Note: shift is currently not used; the true mean is computed and used.
1385
1386  When using these moments for batch normalization (see
1387  `tf.nn.batch_normalization`):
1388
1389   * for so-called "global normalization", used with convolutional filters with
1390     shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`.
1391   * for simple batch normalization pass `axes=[0]` (batch only).
1392
1393  Args:
1394    x: A `Tensor`.
1395    axes: Array of ints.  Axes along which to compute mean and
1396      variance.
1397    shift: Not used in the current implementation.
1398    keepdims: produce moments with the same dimensionality as the input.
1399    name: Name used to scope the operations that compute the moments.
1400
1401  Returns:
1402    Two `Tensor` objects: `mean` and `variance`.
1403  """
1404  return moments(x=x, axes=axes, shift=shift, name=name, keep_dims=keepdims)
1405
1406
1407@tf_export(v1=["nn.weighted_moments"])
1408@dispatch.add_dispatch_support
1409def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=None,
1410                     keepdims=None):
1411  """Returns the frequency-weighted mean and variance of `x`.
1412
1413  Args:
1414    x: A tensor.
1415    axes: 1-d tensor of int32 values; these are the axes along which
1416      to compute mean and variance.
1417    frequency_weights: A tensor of positive weights which can be
1418      broadcast with x.
1419    name: Name used to scope the operation.
1420    keep_dims: Produce moments with the same dimensionality as the input.
1421    keepdims: Alias of keep_dims.
1422
1423  Returns:
1424    Two tensors: `weighted_mean` and `weighted_variance`.
1425  """
1426  keep_dims = deprecated_argument_lookup(
1427      "keepdims", keepdims, "keep_dims", keep_dims)
1428  if keep_dims is None:
1429    keep_dims = False
1430  with ops.name_scope(name, "weighted_moments", [x, frequency_weights, axes]):
1431    x = ops.convert_to_tensor(x, name="x")
1432    frequency_weights = ops.convert_to_tensor(
1433        frequency_weights, name="frequency_weights")
1434
1435    # Unlike moments(), this just uses a simpler two-pass method.
1436
1437    # See comment in moments() WRT precision; it applies here too.
1438    needs_cast = x.dtype == dtypes.float16
1439    if needs_cast:
1440      x = math_ops.cast(x, dtypes.float32)
1441
1442    if frequency_weights.dtype != x.dtype:
1443      frequency_weights = math_ops.cast(frequency_weights, x.dtype)
1444
1445    # Note that we use keep_dims=True for our reductions regardless of the arg;
1446    # this is so that the results remain broadcast-compatible with the inputs.
1447    weighted_input_sum = math_ops.reduce_sum(
1448        frequency_weights * x, axes, name="weighted_input_sum", keepdims=True)
1449
1450    # The shape of the weights isn't necessarily the same as x's
1451    # shape, just broadcast-compatible with it -- so this expression
1452    # performs broadcasting to give a per-item weight, with the same
1453    # shape as (frequency_weights * x). This avoids having to reason
1454    # through all the broadcast logic to compute a correct
1455    # sum_of_weights.
1456    broadcasted_weights = frequency_weights + array_ops.zeros_like(x)
1457
1458    sum_of_weights = math_ops.reduce_sum(
1459        broadcasted_weights, axes, name="sum_of_weights", keepdims=True)
1460
1461    divisor = math_ops.reciprocal(sum_of_weights, name="inv_weight_sum")
1462
1463    weighted_mean = math_ops.multiply(weighted_input_sum, divisor)
1464
1465    # Have the weighted mean; now on to variance:
1466    weighted_distsq = math_ops.reduce_sum(
1467        frequency_weights * math_ops.squared_difference(x, weighted_mean),
1468        axes,
1469        name="weighted_distsq",
1470        keepdims=True)
1471
1472    weighted_variance = math_ops.multiply(weighted_distsq, divisor)
1473
1474    if not keep_dims:
1475      weighted_mean = array_ops.squeeze(weighted_mean, axis=axes)
1476      weighted_variance = array_ops.squeeze(
1477          weighted_variance, axis=axes)
1478
1479    if needs_cast:
1480      weighted_mean = math_ops.cast(weighted_mean, dtypes.float16)
1481      weighted_variance = math_ops.cast(weighted_variance, dtypes.float16)
1482
1483    return weighted_mean, weighted_variance
1484
1485
1486@tf_export("nn.weighted_moments", v1=[])
1487@dispatch.add_dispatch_support
1488def weighted_moments_v2(x, axes, frequency_weights, keepdims=False, name=None):
1489  """Returns the frequency-weighted mean and variance of `x`.
1490
1491  Args:
1492    x: A tensor.
1493    axes: 1-d tensor of int32 values; these are the axes along which
1494      to compute mean and variance.
1495    frequency_weights: A tensor of positive weights which can be
1496      broadcast with x.
1497    keepdims: Produce moments with the same dimensionality as the input.
1498    name: Name used to scope the operation.
1499
1500  Returns:
1501    Two tensors: `weighted_mean` and `weighted_variance`.
1502  """
1503  return weighted_moments(
1504      x=x,
1505      axes=axes,
1506      frequency_weights=frequency_weights,
1507      name=name,
1508      keep_dims=keepdims)
1509
1510
1511@tf_export("nn.batch_normalization")
1512@dispatch.add_dispatch_support
1513def batch_normalization(x,
1514                        mean,
1515                        variance,
1516                        offset,
1517                        scale,
1518                        variance_epsilon,
1519                        name=None):
1520  r"""Batch normalization.
1521
1522  Normalizes a tensor by `mean` and `variance`, and applies (optionally) a
1523  `scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\):
1524
1525  \\(\frac{\gamma(x-\mu)}{\sigma}+\beta\\)
1526
1527  `mean`, `variance`, `offset` and `scale` are all expected to be of one of two
1528  shapes:
1529
1530    * In all generality, they can have the same number of dimensions as the
1531      input `x`, with identical sizes as `x` for the dimensions that are not
1532      normalized over (the 'depth' dimension(s)), and dimension 1 for the
1533      others which are being normalized over.
1534      `mean` and `variance` in this case would typically be the outputs of
1535      `tf.nn.moments(..., keepdims=True)` during training, or running averages
1536      thereof during inference.
1537    * In the common case where the 'depth' dimension is the last dimension in
1538      the input tensor `x`, they may be one dimensional tensors of the same
1539      size as the 'depth' dimension.
1540      This is the case for example for the common `[batch, depth]` layout of
1541      fully-connected layers, and `[batch, height, width, depth]` for
1542      convolutions.
1543      `mean` and `variance` in this case would typically be the outputs of
1544      `tf.nn.moments(..., keepdims=False)` during training, or running averages
1545      thereof during inference.
1546
1547  See equation 11 in Algorithm 2 of source:
1548  [Batch Normalization: Accelerating Deep Network Training by
1549  Reducing Internal Covariate Shift; S. Ioffe, C. Szegedy]
1550  (http://arxiv.org/abs/1502.03167).
1551
1552  Args:
1553    x: Input `Tensor` of arbitrary dimensionality.
1554    mean: A mean `Tensor`.
1555    variance: A variance `Tensor`.
1556    offset: An offset `Tensor`, often denoted \\(\beta\\) in equations, or
1557      None. If present, will be added to the normalized tensor.
1558    scale: A scale `Tensor`, often denoted \\(\gamma\\) in equations, or
1559      `None`. If present, the scale is applied to the normalized tensor.
1560    variance_epsilon: A small float number to avoid dividing by 0.
1561    name: A name for this operation (optional).
1562
1563  Returns:
1564    the normalized, scaled, offset tensor.
1565
1566  References:
1567    Batch Normalization - Accelerating Deep Network Training by Reducing
1568    Internal Covariate Shift:
1569      [Ioffe et al., 2015](http://arxiv.org/abs/1502.03167)
1570      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
1571  """
1572  with ops.name_scope(name, "batchnorm", [x, mean, variance, scale, offset]):
1573    inv = math_ops.rsqrt(variance + variance_epsilon)
1574    if scale is not None:
1575      inv *= scale
1576    # Note: tensorflow/contrib/quantize/python/fold_batch_norms.py depends on
1577    # the precise order of ops that are generated by the expression below.
1578    return x * math_ops.cast(inv, x.dtype) + math_ops.cast(
1579        offset - mean * inv if offset is not None else -mean * inv, x.dtype)
1580
1581
1582@tf_export(v1=["nn.fused_batch_norm"])
1583@dispatch.add_dispatch_support
1584def fused_batch_norm(
1585    x,
1586    scale,
1587    offset,  # pylint: disable=invalid-name
1588    mean=None,
1589    variance=None,
1590    epsilon=0.001,
1591    data_format="NHWC",
1592    is_training=True,
1593    name=None,
1594    exponential_avg_factor=1.0):
1595  r"""Batch normalization.
1596
1597
1598  See Source: [Batch Normalization: Accelerating Deep Network Training by
1599  Reducing Internal Covariate Shift; S. Ioffe, C. Szegedy]
1600  (http://arxiv.org/abs/1502.03167).
1601
1602  Args:
1603    x: Input `Tensor` of 4 or 5 dimensions.
1604    scale: A `Tensor` of 1 dimension for scaling.
1605    offset: A `Tensor` of 1 dimension for bias.
1606    mean: A `Tensor` of 1 dimension for population mean. The shape and meaning
1607          of this argument depends on the value of is_training and
1608          exponential_avg_factor as follows:
1609          is_training==False (inference):
1610            Mean must be a `Tensor` of the same shape as scale containing the
1611            estimated population mean computed during training.
1612          is_training==True and exponential_avg_factor == 1.0:
1613            Mean must be None.
1614          is_training==True and exponential_avg_factor != 1.0:
1615            Mean must be a `Tensor` of the same shape as scale containing the
1616            exponential running mean.
1617    variance: A `Tensor` of 1 dimension for population variance. The shape and
1618          meaning of this argument depends on the value of is_training and
1619          exponential_avg_factor as follows:
1620          is_training==False (inference):
1621            Variance must be a `Tensor` of the same shape as scale containing
1622            the estimated population variance computed during training.
1623          is_training==True and exponential_avg_factor == 1.0:
1624            Variance must be None.
1625          is_training==True and exponential_avg_factor != 1.0:
1626            Variance must be a `Tensor` of the same shape as scale containing
1627            the exponential running variance.
1628    epsilon: A small float number added to the variance of x.
1629    data_format: The data format for x. Support "NHWC" (default) or "NCHW" for
1630                 4D tenors and "NDHWC" or "NCDHW" for 5D tensors.
1631    is_training: A bool value to specify if the operation is used for
1632                 training or inference.
1633    name: A name for this operation (optional).
1634    exponential_avg_factor: A float number (usually between 0 and 1) used
1635                            for controlling the decay of the running
1636                            population average of mean and variance.
1637                            If set to 1.0, the current batch average is
1638                            returned.
1639
1640  Returns:
1641    y: A 4D or 5D Tensor for the normalized, scaled, offsetted x.
1642    running_mean: A 1D Tensor for the exponential running mean of x.
1643                  The output value is (1 - exponential_avg_factor) * mean +
1644                  exponential_avg_factor * batch_mean), where batch_mean
1645                  is the mean of the current batch in x.
1646    running_var: A 1D Tensor for the exponential running variance
1647                 The output value is (1 - exponential_avg_factor) * variance +
1648                 exponential_avg_factor * batch_variance), where batch_variance
1649                 is the variance of the current batch in x.
1650
1651  References:
1652    Batch Normalization - Accelerating Deep Network Training by Reducing
1653    Internal Covariate Shift:
1654      [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
1655      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
1656  """
1657  if (not is_training or exponential_avg_factor != 1.0) and (
1658      (mean is None) or (variance is None)):
1659    raise ValueError("Both 'mean' and 'variance' must be a 1D tensor when "
1660                     "is_training is False or "
1661                     "exponential_avg_factor != 1.0.")
1662  x = ops.convert_to_tensor(x, name="input")
1663  scale = ops.convert_to_tensor(scale, name="scale")
1664  offset = ops.convert_to_tensor(offset, name="offset")
1665  if mean is None:
1666    mean = constant_op.constant([])
1667  if variance is None:
1668    variance = constant_op.constant([])
1669
1670  # Set a minimum epsilon to 1.001e-5, which is a requirement by CUDNN to
1671  # prevent exception (see cudnn.h).
1672  min_epsilon = 1.001e-5
1673  epsilon = epsilon if epsilon > min_epsilon else min_epsilon
1674
1675  y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
1676      x,
1677      scale,
1678      offset,
1679      mean,
1680      variance,
1681      epsilon=epsilon,
1682      exponential_avg_factor=exponential_avg_factor,
1683      data_format=data_format,
1684      is_training=is_training,
1685      name=name)
1686  return y, running_mean, running_var
1687
1688
1689@tf_export(v1=["nn.batch_norm_with_global_normalization"])
1690@dispatch.add_dispatch_support
1691def batch_norm_with_global_normalization(t=None,
1692                                         m=None,
1693                                         v=None,
1694                                         beta=None,
1695                                         gamma=None,
1696                                         variance_epsilon=None,
1697                                         scale_after_normalization=None,
1698                                         name=None,
1699                                         input=None,  # pylint: disable=redefined-builtin
1700                                         mean=None,
1701                                         variance=None):
1702  """Batch normalization.
1703
1704  This op is deprecated. See `tf.nn.batch_normalization`.
1705
1706  Args:
1707    t: A 4D input Tensor.
1708    m: A 1D mean Tensor with size matching the last dimension of t.
1709      This is the first output from tf.nn.moments,
1710      or a saved moving average thereof.
1711    v: A 1D variance Tensor with size matching the last dimension of t.
1712      This is the second output from tf.nn.moments,
1713      or a saved moving average thereof.
1714    beta: A 1D beta Tensor with size matching the last dimension of t.
1715      An offset to be added to the normalized tensor.
1716    gamma: A 1D gamma Tensor with size matching the last dimension of t.
1717      If "scale_after_normalization" is true, this tensor will be multiplied
1718      with the normalized tensor.
1719    variance_epsilon: A small float number to avoid dividing by 0.
1720    scale_after_normalization: A bool indicating whether the resulted tensor
1721      needs to be multiplied with gamma.
1722    name: A name for this operation (optional).
1723    input: Alias for t.
1724    mean: Alias for m.
1725    variance: Alias for v.
1726
1727  Returns:
1728     A batch-normalized `t`.
1729
1730  References:
1731    Batch Normalization - Accelerating Deep Network Training by Reducing
1732    Internal Covariate Shift:
1733      [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
1734      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
1735  """
1736  t = deprecated_argument_lookup("input", input, "t", t)
1737  m = deprecated_argument_lookup("mean", mean, "m", m)
1738  v = deprecated_argument_lookup("variance", variance, "v", v)
1739  return batch_normalization(t, m, v, beta, gamma if scale_after_normalization
1740                             else None, variance_epsilon, name)
1741
1742
1743# pylint: disable=redefined-builtin,line-too-long
1744@tf_export("nn.batch_norm_with_global_normalization", v1=[])
1745@dispatch.add_dispatch_support
1746def batch_norm_with_global_normalization_v2(input,
1747                                            mean,
1748                                            variance,
1749                                            beta,
1750                                            gamma,
1751                                            variance_epsilon,
1752                                            scale_after_normalization,
1753                                            name=None):
1754  """Batch normalization.
1755
1756  This op is deprecated. See `tf.nn.batch_normalization`.
1757
1758  Args:
1759    input: A 4D input Tensor.
1760    mean: A 1D mean Tensor with size matching the last dimension of t.
1761      This is the first output from tf.nn.moments,
1762      or a saved moving average thereof.
1763    variance: A 1D variance Tensor with size matching the last dimension of t.
1764      This is the second output from tf.nn.moments,
1765      or a saved moving average thereof.
1766    beta: A 1D beta Tensor with size matching the last dimension of t.
1767      An offset to be added to the normalized tensor.
1768    gamma: A 1D gamma Tensor with size matching the last dimension of t.
1769      If "scale_after_normalization" is true, this tensor will be multiplied
1770      with the normalized tensor.
1771    variance_epsilon: A small float number to avoid dividing by 0.
1772    scale_after_normalization: A bool indicating whether the resulted tensor
1773      needs to be multiplied with gamma.
1774    name: A name for this operation (optional).
1775
1776  Returns:
1777     A batch-normalized `t`.
1778
1779  References:
1780    Batch Normalization - Accelerating Deep Network Training by Reducing Internal Covariate Shift:
1781      [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
1782      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
1783  """
1784  return batch_norm_with_global_normalization(t=input,
1785                                              m=mean,
1786                                              v=variance,
1787                                              beta=beta,
1788                                              gamma=gamma,
1789                                              variance_epsilon=variance_epsilon,
1790                                              scale_after_normalization=scale_after_normalization,
1791                                              name=name)
1792
1793# pylint: enable=redefined-builtin,line-too-long
1794
1795
1796def _sum_rows(x):
1797  """Returns a vector summing up each row of the matrix x."""
1798  # _sum_rows(x) is equivalent to math_ops.reduce_sum(x, 1) when x is
1799  # a matrix.  The gradient of _sum_rows(x) is more efficient than
1800  # reduce_sum(x, 1)'s gradient in today's implementation. Therefore,
1801  # we use _sum_rows(x) in the nce_loss() computation since the loss
1802  # is mostly used for training.
1803  cols = array_ops.shape(x)[1]
1804  ones_shape = array_ops.stack([cols, 1])
1805  ones = array_ops.ones(ones_shape, x.dtype)
1806  return array_ops.reshape(math_ops.matmul(x, ones), [-1])
1807
1808
1809def _compute_sampled_logits(weights,
1810                            biases,
1811                            labels,
1812                            inputs,
1813                            num_sampled,
1814                            num_classes,
1815                            num_true=1,
1816                            sampled_values=None,
1817                            subtract_log_q=True,
1818                            remove_accidental_hits=False,
1819                            partition_strategy="mod",
1820                            name=None,
1821                            seed=None):
1822  """Helper function for nce_loss and sampled_softmax_loss functions.
1823
1824  Computes sampled output training logits and labels suitable for implementing
1825  e.g. noise-contrastive estimation (see nce_loss) or sampled softmax (see
1826  sampled_softmax_loss).
1827
1828  Note: In the case where num_true > 1, we assign to each target class
1829  the target probability 1 / num_true so that the target probabilities
1830  sum to 1 per-example.
1831
1832  Args:
1833    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
1834        objects whose concatenation along dimension 0 has shape
1835        `[num_classes, dim]`.  The (possibly-partitioned) class embeddings.
1836    biases: A `Tensor` of shape `[num_classes]`.  The (possibly-partitioned)
1837        class biases.
1838    labels: A `Tensor` of type `int64` and shape `[batch_size,
1839        num_true]`. The target classes.  Note that this format differs from
1840        the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
1841    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward
1842        activations of the input network.
1843    num_sampled: An `int`.  The number of classes to randomly sample per batch.
1844    num_classes: An `int`. The number of possible classes.
1845    num_true: An `int`.  The number of target classes per training example.
1846    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
1847        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
1848        (if None, we default to `log_uniform_candidate_sampler`)
1849    subtract_log_q: A `bool`.  whether to subtract the log expected count of
1850        the labels in the sample to get the logits of the true labels.
1851        Default is True.  Turn off for Negative Sampling.
1852    remove_accidental_hits:  A `bool`.  whether to remove "accidental hits"
1853        where a sampled class equals one of the target classes.  Default is
1854        False.
1855    partition_strategy: A string specifying the partitioning strategy, relevant
1856        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
1857        Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
1858    name: A name for the operation (optional).
1859    seed: random seed for candidate sampling. Default to None, which doesn't set
1860        the op-level random seed for candidate sampling.
1861  Returns:
1862    out_logits: `Tensor` object with shape
1863        `[batch_size, num_true + num_sampled]`, for passing to either
1864        `nn.sigmoid_cross_entropy_with_logits` (NCE) or
1865        `nn.softmax_cross_entropy_with_logits` (sampled softmax).
1866    out_labels: A Tensor object with the same shape as `out_logits`.
1867  """
1868
1869  if isinstance(weights, variables.PartitionedVariable):
1870    weights = list(weights)
1871  if not isinstance(weights, list):
1872    weights = [weights]
1873
1874  with ops.name_scope(name, "compute_sampled_logits",
1875                      weights + [biases, inputs, labels]):
1876    if labels.dtype != dtypes.int64:
1877      labels = math_ops.cast(labels, dtypes.int64)
1878    labels_flat = array_ops.reshape(labels, [-1])
1879
1880    # Sample the negative labels.
1881    #   sampled shape: [num_sampled] tensor
1882    #   true_expected_count shape = [batch_size, 1] tensor
1883    #   sampled_expected_count shape = [num_sampled] tensor
1884    if sampled_values is None:
1885      sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
1886          true_classes=labels,
1887          num_true=num_true,
1888          num_sampled=num_sampled,
1889          unique=True,
1890          range_max=num_classes,
1891          seed=seed)
1892    # NOTE: pylint cannot tell that 'sampled_values' is a sequence
1893    # pylint: disable=unpacking-non-sequence
1894    sampled, true_expected_count, sampled_expected_count = (
1895        array_ops.stop_gradient(s) for s in sampled_values)
1896    # pylint: enable=unpacking-non-sequence
1897    sampled = math_ops.cast(sampled, dtypes.int64)
1898
1899    # labels_flat is a [batch_size * num_true] tensor
1900    # sampled is a [num_sampled] int tensor
1901    all_ids = array_ops.concat([labels_flat, sampled], 0)
1902
1903    # Retrieve the true weights and the logits of the sampled weights.
1904
1905    # weights shape is [num_classes, dim]
1906    all_w = embedding_ops.embedding_lookup(
1907        weights, all_ids, partition_strategy=partition_strategy)
1908    if all_w.dtype != inputs.dtype:
1909      all_w = math_ops.cast(all_w, inputs.dtype)
1910
1911    # true_w shape is [batch_size * num_true, dim]
1912    true_w = array_ops.slice(all_w, [0, 0],
1913                             array_ops.stack(
1914                                 [array_ops.shape(labels_flat)[0], -1]))
1915
1916    sampled_w = array_ops.slice(
1917        all_w, array_ops.stack([array_ops.shape(labels_flat)[0], 0]), [-1, -1])
1918    # inputs has shape [batch_size, dim]
1919    # sampled_w has shape [num_sampled, dim]
1920    # Apply X*W', which yields [batch_size, num_sampled]
1921    sampled_logits = math_ops.matmul(inputs, sampled_w, transpose_b=True)
1922
1923    # Retrieve the true and sampled biases, compute the true logits, and
1924    # add the biases to the true and sampled logits.
1925    all_b = embedding_ops.embedding_lookup(
1926        biases, all_ids, partition_strategy=partition_strategy)
1927    if all_b.dtype != inputs.dtype:
1928      all_b = math_ops.cast(all_b, inputs.dtype)
1929    # true_b is a [batch_size * num_true] tensor
1930    # sampled_b is a [num_sampled] float tensor
1931    true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat))
1932    sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1])
1933
1934    # inputs shape is [batch_size, dim]
1935    # true_w shape is [batch_size * num_true, dim]
1936    # row_wise_dots is [batch_size, num_true, dim]
1937    dim = array_ops.shape(true_w)[1:2]
1938    new_true_w_shape = array_ops.concat([[-1, num_true], dim], 0)
1939    row_wise_dots = math_ops.multiply(
1940        array_ops.expand_dims(inputs, 1),
1941        array_ops.reshape(true_w, new_true_w_shape))
1942    # We want the row-wise dot plus biases which yields a
1943    # [batch_size, num_true] tensor of true_logits.
1944    dots_as_matrix = array_ops.reshape(row_wise_dots,
1945                                       array_ops.concat([[-1], dim], 0))
1946    true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true])
1947    true_b = array_ops.reshape(true_b, [-1, num_true])
1948    true_logits += true_b
1949    sampled_logits += sampled_b
1950
1951    if remove_accidental_hits:
1952      acc_hits = candidate_sampling_ops.compute_accidental_hits(
1953          labels, sampled, num_true=num_true)
1954      acc_indices, acc_ids, acc_weights = acc_hits
1955
1956      # This is how SparseToDense expects the indices.
1957      acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
1958      acc_ids_2d_int32 = array_ops.reshape(
1959          math_ops.cast(acc_ids, dtypes.int32), [-1, 1])
1960      sparse_indices = array_ops.concat([acc_indices_2d, acc_ids_2d_int32], 1,
1961                                        "sparse_indices")
1962      # Create sampled_logits_shape = [batch_size, num_sampled]
1963      sampled_logits_shape = array_ops.concat(
1964          [array_ops.shape(labels)[:1],
1965           array_ops.expand_dims(num_sampled, 0)], 0)
1966      if sampled_logits.dtype != acc_weights.dtype:
1967        acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype)
1968      sampled_logits += gen_sparse_ops.sparse_to_dense(
1969          sparse_indices,
1970          sampled_logits_shape,
1971          acc_weights,
1972          default_value=0.0,
1973          validate_indices=False)
1974
1975    if subtract_log_q:
1976      # Subtract log of Q(l), prior probability that l appears in sampled.
1977      true_logits -= math_ops.log(true_expected_count)
1978      sampled_logits -= math_ops.log(sampled_expected_count)
1979
1980    # Construct output logits and labels. The true labels/logits start at col 0.
1981    out_logits = array_ops.concat([true_logits, sampled_logits], 1)
1982
1983    # true_logits is a float tensor, ones_like(true_logits) is a float
1984    # tensor of ones. We then divide by num_true to ensure the per-example
1985    # labels sum to 1.0, i.e. form a proper probability distribution.
1986    out_labels = array_ops.concat([
1987        array_ops.ones_like(true_logits) / num_true,
1988        array_ops.zeros_like(sampled_logits)
1989    ], 1)
1990
1991    return out_logits, out_labels
1992
1993
1994@tf_export("nn.nce_loss", v1=[])
1995@dispatch.add_dispatch_support
1996def nce_loss_v2(weights,
1997                biases,
1998                labels,
1999                inputs,
2000                num_sampled,
2001                num_classes,
2002                num_true=1,
2003                sampled_values=None,
2004                remove_accidental_hits=False,
2005                name="nce_loss"):
2006  """Computes and returns the noise-contrastive estimation training loss.
2007
2008  See [Noise-contrastive estimation: A new estimation principle for
2009  unnormalized statistical
2010  models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
2011  Also see our [Candidate Sampling Algorithms
2012  Reference](https://www.tensorflow.org/extras/candidate_sampling.pdf)
2013
2014  A common use case is to use this method for training, and calculate the full
2015  sigmoid loss for evaluation or inference as in the following example:
2016
2017  ```python
2018  if mode == "train":
2019    loss = tf.nn.nce_loss(
2020        weights=weights,
2021        biases=biases,
2022        labels=labels,
2023        inputs=inputs,
2024        ...)
2025  elif mode == "eval":
2026    logits = tf.matmul(inputs, tf.transpose(weights))
2027    logits = tf.nn.bias_add(logits, biases)
2028    labels_one_hot = tf.one_hot(labels, n_classes)
2029    loss = tf.nn.sigmoid_cross_entropy_with_logits(
2030        labels=labels_one_hot,
2031        logits=logits)
2032    loss = tf.reduce_sum(loss, axis=1)
2033  ```
2034
2035  Note: when doing embedding lookup on `weights` and `bias`, "div" partition
2036  strategy will be used. Support for other partition strategy will be added
2037  later.
2038
2039  Note: By default this uses a log-uniform (Zipfian) distribution for sampling,
2040  so your labels must be sorted in order of decreasing frequency to achieve
2041  good results.  For more details, see
2042  `tf.random.log_uniform_candidate_sampler`.
2043
2044  Note: In the case where `num_true` > 1, we assign to each target class
2045  the target probability 1 / `num_true` so that the target probabilities
2046  sum to 1 per-example.
2047
2048  Note: It would be useful to allow a variable number of target classes per
2049  example.  We hope to provide this functionality in a future release.
2050  For now, if you have a variable number of target classes, you can pad them
2051  out to a constant number by either repeating them or by padding
2052  with an otherwise unused class.
2053
2054  Args:
2055    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
2056      objects whose concatenation along dimension 0 has shape [num_classes,
2057      dim].  The (possibly-partitioned) class embeddings.
2058    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
2059    labels: A `Tensor` of type `int64` and shape `[batch_size, num_true]`. The
2060      target classes.
2061    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward activations of
2062      the input network.
2063    num_sampled: An `int`.  The number of negative classes to randomly sample
2064      per batch. This single sample of negative classes is evaluated for each
2065      element in the batch.
2066    num_classes: An `int`. The number of possible classes.
2067    num_true: An `int`.  The number of target classes per training example.
2068    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
2069      `sampled_expected_count`) returned by a `*_candidate_sampler` function.
2070      (if None, we default to `log_uniform_candidate_sampler`)
2071    remove_accidental_hits:  A `bool`.  Whether to remove "accidental hits"
2072      where a sampled class equals one of the target classes.  If set to `True`,
2073      this is a "Sampled Logistic" loss instead of NCE, and we are learning to
2074      generate log-odds instead of log probabilities.  See our [Candidate
2075      Sampling Algorithms Reference]
2076        (https://www.tensorflow.org/extras/candidate_sampling.pdf). Default is
2077          False.
2078    name: A name for the operation (optional).
2079
2080  Returns:
2081    A `batch_size` 1-D tensor of per-example NCE losses.
2082  """
2083  # TODO(yuefengz): get partition_strategy from either variables or distribution
2084  # strategies.
2085  return nce_loss(
2086      weights,
2087      biases,
2088      labels,
2089      inputs,
2090      num_sampled,
2091      num_classes,
2092      num_true=num_true,
2093      sampled_values=sampled_values,
2094      remove_accidental_hits=remove_accidental_hits,
2095      partition_strategy="div",
2096      name=name)
2097
2098
2099@tf_export(v1=["nn.nce_loss"])
2100@dispatch.add_dispatch_support
2101def nce_loss(weights,
2102             biases,
2103             labels,
2104             inputs,
2105             num_sampled,
2106             num_classes,
2107             num_true=1,
2108             sampled_values=None,
2109             remove_accidental_hits=False,
2110             partition_strategy="mod",
2111             name="nce_loss"):
2112  """Computes and returns the noise-contrastive estimation training loss.
2113
2114  A common use case is to use this method for training, and calculate the full
2115  sigmoid loss for evaluation or inference. In this case, you must set
2116  `partition_strategy="div"` for the two losses to be consistent, as in the
2117  following example:
2118
2119  ```python
2120  if mode == "train":
2121    loss = tf.nn.nce_loss(
2122        weights=weights,
2123        biases=biases,
2124        labels=labels,
2125        inputs=inputs,
2126        ...,
2127        partition_strategy="div")
2128  elif mode == "eval":
2129    logits = tf.matmul(inputs, tf.transpose(weights))
2130    logits = tf.nn.bias_add(logits, biases)
2131    labels_one_hot = tf.one_hot(labels, n_classes)
2132    loss = tf.nn.sigmoid_cross_entropy_with_logits(
2133        labels=labels_one_hot,
2134        logits=logits)
2135    loss = tf.reduce_sum(loss, axis=1)
2136  ```
2137
2138  Note: By default this uses a log-uniform (Zipfian) distribution for sampling,
2139  so your labels must be sorted in order of decreasing frequency to achieve
2140  good results.  For more details, see
2141  `tf.random.log_uniform_candidate_sampler`.
2142
2143  Note: In the case where `num_true` > 1, we assign to each target class
2144  the target probability 1 / `num_true` so that the target probabilities
2145  sum to 1 per-example.
2146
2147  Note: It would be useful to allow a variable number of target classes per
2148  example.  We hope to provide this functionality in a future release.
2149  For now, if you have a variable number of target classes, you can pad them
2150  out to a constant number by either repeating them or by padding
2151  with an otherwise unused class.
2152
2153  Args:
2154    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
2155        objects whose concatenation along dimension 0 has shape
2156        [num_classes, dim].  The (possibly-partitioned) class embeddings.
2157    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
2158    labels: A `Tensor` of type `int64` and shape `[batch_size,
2159        num_true]`. The target classes.
2160    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward
2161        activations of the input network.
2162    num_sampled: An `int`.  The number of negative classes to randomly sample
2163        per batch. This single sample of negative classes is evaluated for each
2164        element in the batch.
2165    num_classes: An `int`. The number of possible classes.
2166    num_true: An `int`.  The number of target classes per training example.
2167    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
2168        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
2169        (if None, we default to `log_uniform_candidate_sampler`)
2170    remove_accidental_hits:  A `bool`.  Whether to remove "accidental hits"
2171        where a sampled class equals one of the target classes.  If set to
2172        `True`, this is a "Sampled Logistic" loss instead of NCE, and we are
2173        learning to generate log-odds instead of log probabilities. See
2174        our Candidate Sampling Algorithms Reference
2175        ([pdf](https://www.tensorflow.org/extras/candidate_sampling.pdf)).
2176        Default is False.
2177    partition_strategy: A string specifying the partitioning strategy, relevant
2178        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
2179        Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
2180    name: A name for the operation (optional).
2181
2182  Returns:
2183    A `batch_size` 1-D tensor of per-example NCE losses.
2184
2185  References:
2186    Noise-contrastive estimation - A new estimation principle for unnormalized
2187    statistical models:
2188      [Gutmann et al., 2010](http://proceedings.mlr.press/v9/gutmann10a)
2189      ([pdf](http://proceedings.mlr.press/v9/gutmann10a/gutmann10a.pdf))
2190  """
2191  logits, labels = _compute_sampled_logits(
2192      weights=weights,
2193      biases=biases,
2194      labels=labels,
2195      inputs=inputs,
2196      num_sampled=num_sampled,
2197      num_classes=num_classes,
2198      num_true=num_true,
2199      sampled_values=sampled_values,
2200      subtract_log_q=True,
2201      remove_accidental_hits=remove_accidental_hits,
2202      partition_strategy=partition_strategy,
2203      name=name)
2204  sampled_losses = sigmoid_cross_entropy_with_logits(
2205      labels=labels, logits=logits, name="sampled_losses")
2206  # sampled_losses is batch_size x {true_loss, sampled_losses...}
2207  # We sum out true and sampled losses.
2208  return _sum_rows(sampled_losses)
2209
2210
2211@tf_export("nn.sampled_softmax_loss", v1=[])
2212@dispatch.add_dispatch_support
2213def sampled_softmax_loss_v2(weights,
2214                            biases,
2215                            labels,
2216                            inputs,
2217                            num_sampled,
2218                            num_classes,
2219                            num_true=1,
2220                            sampled_values=None,
2221                            remove_accidental_hits=True,
2222                            seed=None,
2223                            name="sampled_softmax_loss"):
2224  """Computes and returns the sampled softmax training loss.
2225
2226  This is a faster way to train a softmax classifier over a huge number of
2227  classes.
2228
2229  This operation is for training only.  It is generally an underestimate of
2230  the full softmax loss.
2231
2232  A common use case is to use this method for training, and calculate the full
2233  sigmoid loss for evaluation or inference as in the following example:
2234
2235  ```python
2236  if mode == "train":
2237    loss = tf.nn.sampled_softmax_loss(
2238        weights=weights,
2239        biases=biases,
2240        labels=labels,
2241        inputs=inputs,
2242        ...)
2243  elif mode == "eval":
2244    logits = tf.matmul(inputs, tf.transpose(weights))
2245    logits = tf.nn.bias_add(logits, biases)
2246    labels_one_hot = tf.one_hot(labels, n_classes)
2247    loss = tf.nn.softmax_cross_entropy_with_logits(
2248        labels=labels_one_hot,
2249        logits=logits)
2250  ```
2251
2252  See our [Candidate Sampling Algorithms Reference]
2253  (https://www.tensorflow.org/extras/candidate_sampling.pdf)
2254
2255  Also see Section 3 of [Jean et al., 2014](http://arxiv.org/abs/1412.2007)
2256  ([pdf](http://arxiv.org/pdf/1412.2007.pdf)) for the math.
2257
2258  Note: when doing embedding lookup on `weights` and `bias`, "div" partition
2259  strategy will be used. Support for other partition strategy will be added
2260  later.
2261
2262  Args:
2263    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
2264      objects whose concatenation along dimension 0 has shape [num_classes,
2265      dim].  The (possibly-sharded) class embeddings.
2266    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
2267    labels: A `Tensor` of type `int64` and shape `[batch_size, num_true]`. The
2268      target classes.  Note that this format differs from the `labels` argument
2269      of `nn.softmax_cross_entropy_with_logits`.
2270    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward activations of
2271      the input network.
2272    num_sampled: An `int`.  The number of classes to randomly sample per batch.
2273    num_classes: An `int`. The number of possible classes.
2274    num_true: An `int`.  The number of target classes per training example.
2275    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
2276      `sampled_expected_count`) returned by a `*_candidate_sampler` function.
2277      (if None, we default to `log_uniform_candidate_sampler`)
2278    remove_accidental_hits:  A `bool`.  whether to remove "accidental hits"
2279      where a sampled class equals one of the target classes.  Default is True.
2280    seed: random seed for candidate sampling. Default to None, which doesn't set
2281      the op-level random seed for candidate sampling.
2282    name: A name for the operation (optional).
2283
2284  Returns:
2285    A `batch_size` 1-D tensor of per-example sampled softmax losses.
2286
2287  """
2288  return sampled_softmax_loss(
2289      weights,
2290      biases,
2291      labels,
2292      inputs,
2293      num_sampled,
2294      num_classes,
2295      num_true=num_true,
2296      sampled_values=sampled_values,
2297      remove_accidental_hits=remove_accidental_hits,
2298      partition_strategy="div",
2299      name=name,
2300      seed=seed)
2301
2302
2303@tf_export(v1=["nn.sampled_softmax_loss"])
2304@dispatch.add_dispatch_support
2305def sampled_softmax_loss(weights,
2306                         biases,
2307                         labels,
2308                         inputs,
2309                         num_sampled,
2310                         num_classes,
2311                         num_true=1,
2312                         sampled_values=None,
2313                         remove_accidental_hits=True,
2314                         partition_strategy="mod",
2315                         name="sampled_softmax_loss",
2316                         seed=None):
2317  """Computes and returns the sampled softmax training loss.
2318
2319  This is a faster way to train a softmax classifier over a huge number of
2320  classes.
2321
2322  This operation is for training only.  It is generally an underestimate of
2323  the full softmax loss.
2324
2325  A common use case is to use this method for training, and calculate the full
2326  softmax loss for evaluation or inference. In this case, you must set
2327  `partition_strategy="div"` for the two losses to be consistent, as in the
2328  following example:
2329
2330  ```python
2331  if mode == "train":
2332    loss = tf.nn.sampled_softmax_loss(
2333        weights=weights,
2334        biases=biases,
2335        labels=labels,
2336        inputs=inputs,
2337        ...,
2338        partition_strategy="div")
2339  elif mode == "eval":
2340    logits = tf.matmul(inputs, tf.transpose(weights))
2341    logits = tf.nn.bias_add(logits, biases)
2342    labels_one_hot = tf.one_hot(labels, n_classes)
2343    loss = tf.nn.softmax_cross_entropy_with_logits(
2344        labels=labels_one_hot,
2345        logits=logits)
2346  ```
2347
2348  See our Candidate Sampling Algorithms Reference
2349  ([pdf](https://www.tensorflow.org/extras/candidate_sampling.pdf)).
2350  Also see Section 3 of (Jean et al., 2014) for the math.
2351
2352  Args:
2353    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
2354        objects whose concatenation along dimension 0 has shape
2355        [num_classes, dim].  The (possibly-sharded) class embeddings.
2356    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
2357    labels: A `Tensor` of type `int64` and shape `[batch_size,
2358        num_true]`. The target classes.  Note that this format differs from
2359        the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
2360    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward
2361        activations of the input network.
2362    num_sampled: An `int`.  The number of classes to randomly sample per batch.
2363    num_classes: An `int`. The number of possible classes.
2364    num_true: An `int`.  The number of target classes per training example.
2365    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
2366        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
2367        (if None, we default to `log_uniform_candidate_sampler`)
2368    remove_accidental_hits:  A `bool`.  whether to remove "accidental hits"
2369        where a sampled class equals one of the target classes.  Default is
2370        True.
2371    partition_strategy: A string specifying the partitioning strategy, relevant
2372        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
2373        Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
2374    name: A name for the operation (optional).
2375    seed: random seed for candidate sampling. Default to None, which doesn't set
2376        the op-level random seed for candidate sampling.
2377
2378  Returns:
2379    A `batch_size` 1-D tensor of per-example sampled softmax losses.
2380
2381  References:
2382    On Using Very Large Target Vocabulary for Neural Machine Translation:
2383      [Jean et al., 2014]
2384      (https://aclanthology.coli.uni-saarland.de/papers/P15-1001/p15-1001)
2385      ([pdf](http://aclweb.org/anthology/P15-1001))
2386  """
2387  logits, labels = _compute_sampled_logits(
2388      weights=weights,
2389      biases=biases,
2390      labels=labels,
2391      inputs=inputs,
2392      num_sampled=num_sampled,
2393      num_classes=num_classes,
2394      num_true=num_true,
2395      sampled_values=sampled_values,
2396      subtract_log_q=True,
2397      remove_accidental_hits=remove_accidental_hits,
2398      partition_strategy=partition_strategy,
2399      name=name,
2400      seed=seed)
2401  labels = array_ops.stop_gradient(labels, name="labels_stop_gradient")
2402  sampled_losses = nn_ops.softmax_cross_entropy_with_logits_v2(
2403      labels=labels, logits=logits)
2404  # sampled_losses is a [batch_size] tensor.
2405  return sampled_losses
2406