1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Wrappers for primitive Neural Net (NN) Operations."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import numbers
23import os
24
25import numpy as np
26
27from tensorflow.python.eager import context
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import errors_impl
31from tensorflow.python.framework import graph_util
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import random_seed
34from tensorflow.python.framework import tensor_shape
35from tensorflow.python.framework import tensor_util
36from tensorflow.python.ops import array_ops
37from tensorflow.python.ops import check_ops
38from tensorflow.python.ops import gen_math_ops
39from tensorflow.python.ops import gen_nn_ops
40from tensorflow.python.ops import math_ops
41from tensorflow.python.ops import random_ops
42from tensorflow.python.ops import variables as variables_lib
43# go/tf-wildcard-import
44# pylint: disable=wildcard-import
45from tensorflow.python.ops.gen_nn_ops import *
46# pylint: enable=wildcard-import
47from tensorflow.python.platform import device_context
48from tensorflow.python.util import deprecation
49from tensorflow.python.util import dispatch
50from tensorflow.python.util.compat import collections_abc
51from tensorflow.python.util.deprecation import deprecated_args
52from tensorflow.python.util.deprecation import deprecated_argument_lookup
53
54from tensorflow.python.util.tf_export import tf_export
55
# Aliases for some automatically-generated names.
local_response_normalization = gen_nn_ops.lrn

# pylint: disable=protected-access

# Data formats whose channel dimension is the last one. The spatial letters
# (H, W, D) may appear in any of the orders listed below.
_CHANNELS_LAST_FORMATS = frozenset([
    # One spatial dimension.
    "NWC",
    "NHC",
    # Two spatial dimensions.
    "NHWC",
    "NWHC",
    # Three spatial dimensions.
    "NDHWC",
    "NDWHC",
    "NHDWC",
    "NHWDC",
    "NWDHC",
    "NWHDC",
])
66
67
def _get_sequence(value, n, channel_index, name):
  """Expands `value` into a full-rank list suitable for gen_nn_ops attributes.

  Args:
    value: None, a scalar, or a sized sequence whose length is 1, `n` or
      `n + 2`.
    n: Number of spatial dimensions.
    channel_index: 1 when the channel dimension directly follows the batch
      dimension; any other value puts the channel dimension last.
    name: Argument name used in the error message.

  Returns:
    A list of length `n + 2`. `None` yields all ones; a length-`n + 2` input
    is returned as a list unchanged (the same object when `value` is a list);
    otherwise the spatial values (a length-1 value is broadcast to all `n`
    spatial dimensions) are surrounded by 1s for the batch and channel
    dimensions.

  Raises:
    ValueError: if `value` does not have length 1, `n` or `n + 2`.
  """
  if value is None:
    return [1] * (n + 2)

  # Normalize to a list. Lists are kept as-is, ints and unsized objects are
  # wrapped, and any other sized object (tuples included) is copied to a list.
  if isinstance(value, list):
    seq = value
  elif isinstance(value, int) or not isinstance(value, collections_abc.Sized):
    seq = [value]
  else:
    seq = list(value)

  count = len(seq)
  if count == n + 2:
    # Already includes the batch and channel dimensions.
    return seq

  if count == 1:
    seq = seq * n  # Broadcast the single value to every spatial dimension.
  elif count != n:
    raise ValueError("{} should be of length 1, {} or {} but was {}".format(
        name, n, n + 2, count))

  # The batch and channel entries are always 1.
  return [1, 1] + seq if channel_index == 1 else [1] + seq + [1]
105
106
def _non_atrous_convolution(
    input,  # pylint: disable=redefined-builtin
    filter,  # pylint: disable=redefined-builtin
    padding,
    data_format=None,
    strides=None,
    name=None):
  """Computes sums of N-D convolutions (actually cross correlation).

  Only 1 <= N <= 3 is supported.

  This is the building block of the more general `convolution` function,
  which adds a `dilation_rate` parameter on top of this interface.

  Args:
    input: Rank N+2 tensor of type T. Its shape is
      `[batch_size] + input_spatial_shape + [in_channels]` unless
      `data_format` starts with `"NC"`, in which case it is
      `[batch_size, in_channels] + input_spatial_shape`.
    filter: Rank N+2 tensor of type T of shape
      `filter_spatial_shape + [in_channels, out_channels]`. The rank of at
      least one of `input` and `filter` must be statically known.
    padding: Padding method to use, must be either "VALID" or "SAME".
    data_format: A string or None. Selects whether the channel dimension of
      `input` and of the output is the last dimension (the default, or any
      format not starting with "NC") or the second one (formats starting with
      "NC"). For N=1 the valid values are "NWC" (default) and "NCW"; for N=2
      "NHWC" (default) and "NCHW"; for N=3 "NDHWC" (default) and "NCDHW".
    strides: Sequence of N positive integers; defaults to `[1] * N`.
    name: Name prefix to use.

  Returns:
    Rank N+2 tensor of type T laid out like `input` (channels-last shown):
    `[batch_size] + output_spatial_shape + [out_channels]`, where, for unit
    strides,
    if padding == "SAME":
      output_spatial_shape = input_spatial_shape
    if padding == "VALID":
      output_spatial_shape = input_spatial_shape - filter_spatial_shape + 1.

  Raises:
    ValueError: if ranks are incompatible.
  """
  with ops.name_scope(name, "non_atrous_convolution", [input, filter]) as scope:
    input_tensor = ops.convert_to_tensor(input, name="input")
    filter_tensor = ops.convert_to_tensor(filter, name="filter")
    # All shape validation and op selection happens in the helper class; it
    # only needs the static shapes here.
    conv = _NonAtrousConvolution(
        input_tensor.shape,
        filter_shape=filter_tensor.shape,
        padding=padding,
        data_format=data_format,
        strides=strides,
        name=scope)
    return conv(input_tensor, filter_tensor)
166
167
class _NonAtrousConvolution(object):
  """Helper class for _non_atrous_convolution.

  Note that this class assumes that shapes of input and filter passed to
  `__call__` are compatible with `input_shape` and filter_shape passed to the
  constructor.

  Args:
    input_shape: static input shape, i.e. input.shape.
    filter_shape: static filter shape, i.e. filter.shape.
    padding: see _non_atrous_convolution.
    data_format: see _non_atrous_convolution.
    strides: see _non_atrous_convolution.
    name: see _non_atrous_convolution.
    num_batch_dims: (Optional.)  The number of batch dimensions in the input;
     if not provided, the default of `1` is used.

  Raises:
    ValueError: if the input rank cannot be determined, if the number of
      spatial dimensions is not 1, 2 or 3, if `len(strides)` does not match
      the number of spatial dimensions, or if `data_format` is unsupported.
  """

  def __init__(
      self,
      input_shape,
      filter_shape,
      padding,
      data_format=None,
      strides=None,
      name=None,
      num_batch_dims=1):
    # filter shape is always rank num_spatial_dims + 2
    # and num_spatial_dims == input_shape.ndims - num_batch_dims - 1
    if input_shape.ndims is not None:
      filter_shape = filter_shape.with_rank(
          input_shape.ndims - num_batch_dims + 1)
    self.padding = padding
    self.name = name
    # input shape is == num_spatial_dims + num_batch_dims + 1
    # and filter_shape is always rank num_spatial_dims + 2
    if filter_shape.ndims is not None:
      input_shape = input_shape.with_rank(
          filter_shape.ndims + num_batch_dims - 1)
    if input_shape.ndims is None:
      raise ValueError(
          "Rank of convolution must be known, but saw input_shape.ndims == {}"
          .format(input_shape.ndims))
    conv_dims = input_shape.ndims - num_batch_dims - 1
    # `conv_dims + 2 == input_shape.ndims - num_batch_dims + 1`, so this is the
    # check the error message describes. Checking `conv_dims` (not only
    # `input_shape.ndims`) also rejects inputs with several batch dimensions
    # but no spatial dimension; before, no branch below matched for them,
    # leaving `self.conv_op` unset until `__call__` failed with AttributeError.
    if input_shape.ndims < 3 or not 1 <= conv_dims <= 3:
      raise ValueError(
          "`input_shape.ndims - num_batch_dims + 1` must be at least 3 and at "
          "most 5 but saw `input_shape.ndims == {}` and `num_batch_dims == {}`"
          .format(input_shape.ndims, num_batch_dims))
    if strides is None:
      strides = [1] * conv_dims
    elif len(strides) != conv_dims:
      raise ValueError("len(strides)=%d, but should be %d" % (len(strides),
                                                              conv_dims))
    if conv_dims == 1:
      # conv1d uses the 2-d data format names
      if data_format is None:
        data_format = "NWC"
      elif data_format not in {"NCW", "NWC", "NCHW", "NHWC"}:
        raise ValueError("data_format must be \"NWC\" or \"NCW\".")
      # conv1d takes a scalar stride.
      self.strides = strides[0]
      self.data_format = data_format
      self.conv_op = self._conv1d
    elif conv_dims == 2:
      if data_format is None or data_format == "NHWC":
        data_format = "NHWC"
        strides = [1] + list(strides) + [1]
      elif data_format == "NCHW":
        strides = [1, 1] + list(strides)
      else:
        raise ValueError("data_format must be \"NHWC\" or \"NCHW\".")
      self.strides = strides
      self.data_format = data_format
      self.conv_op = conv2d
    else:  # conv_dims == 3, guaranteed by the range check above.
      if data_format is None or data_format == "NDHWC":
        # Resolve the default explicitly, as the 1-D and 2-D branches do.
        data_format = "NDHWC"
        strides = [1] + list(strides) + [1]
      elif data_format == "NCDHW":
        strides = [1, 1] + list(strides)
      else:
        raise ValueError("data_format must be \"NDHWC\" or \"NCDHW\". Have: %s"
                         % data_format)
      self.strides = strides
      self.data_format = data_format
      self.conv_op = _conv3d_expanded_batch

  # Note that we need this adapter since argument names for conv1d don't match
  # those for gen_nn_ops.conv2d and gen_nn_ops.conv3d.
  # pylint: disable=redefined-builtin
  def _conv1d(self, input, filter, strides, padding, data_format, name):
    return conv1d(
        value=input,
        filters=filter,
        stride=strides,
        padding=padding,
        data_format=data_format,
        name=name)
  # pylint: enable=redefined-builtin

  def __call__(self, inp, filter):  # pylint: disable=redefined-builtin
    """Runs the selected convolution op on `inp` and `filter`."""
    return self.conv_op(
        input=inp,
        filter=filter,
        strides=self.strides,
        padding=self.padding,
        data_format=self.data_format,
        name=self.name)
275
276
def squeeze_batch_dims(inp, op, inner_rank, name=None):
  """Returns `unsqueeze_batch(op(squeeze_batch(inp)))`.

  `squeeze_batch` collapses all leading (batch) dimensions of `inp` into one,
  reshaping it to `[prod(inp.shape[:-inner_rank])] + inp.shape[-inner_rank:]`;
  `unsqueeze_batch` performs the reverse reshape on the output of `op`.

  Args:
    inp: A tensor with dims `batch_shape + inner_shape` where `inner_shape`
      is length `inner_rank`.
    op: A callable that takes a single input tensor and returns a single
      output tensor.
    inner_rank: A python integer.
    name: A string.

  Returns:
    `unsqueeze_batch(op(squeeze_batch(inp)))`.
  """
  with ops.name_scope(name, "squeeze_batch_dims", [inp]):
    inp = ops.convert_to_tensor(inp, name="input")
    static_shape = inp.shape

    # Prefer the static shapes; fall back to the dynamic shape tensor when a
    # part of the shape is not fully known.
    inner_dims = static_shape[-inner_rank:]
    if not inner_dims.is_fully_defined():
      inner_dims = array_ops.shape(inp)[-inner_rank:]

    batch_dims = static_shape[:-inner_rank]
    if not batch_dims.is_fully_defined():
      batch_dims = array_ops.shape(inp)[:-inner_rank]

    if isinstance(inner_dims, tensor_shape.TensorShape):
      flat_shape = [-1] + inner_dims.as_list()
    else:
      flat_shape = array_ops.concat(([-1], inner_dims), axis=-1)
    flat_out = op(array_ops.reshape(inp, flat_shape))

    out_inner_dims = flat_out.shape[-inner_rank:]
    if not out_inner_dims.is_fully_defined():
      out_inner_dims = array_ops.shape(flat_out)[-inner_rank:]

    out = array_ops.reshape(
        flat_out, array_ops.concat((batch_dims, out_inner_dims), axis=-1))

    # Re-assert the statically known batch dims together with the output's
    # known inner dims, which the dynamic reshape may not infer by itself.
    out.set_shape(inp.shape[:-inner_rank] + out.shape[-inner_rank:])
    return out
324
325
@tf_export("nn.dilation2d", v1=[])
@dispatch.add_dispatch_support
def dilation2d_v2(
    input,   # pylint: disable=redefined-builtin
    filters,  # pylint: disable=redefined-builtin
    strides,
    padding,
    data_format,
    dilations,
    name=None):
  """Computes the grayscale dilation of 4-D `input` and 3-D `filters` tensors.

  `input` has shape `[batch, in_height, in_width, depth]` and `filters` has
  shape `[filter_height, filter_width, depth]`: every input channel is
  processed independently, with its own structuring function. The output has
  shape `[batch, out_height, out_width, depth]`, where the spatial dimensions
  depend on the `padding` algorithm. Only the default "NHWC" `data_format` is
  currently supported.

  In detail, the grayscale morphological 2-D dilation is the max-sum
  correlation (for consistency with `conv2d`, we use unmirrored filters):

      output[b, y, x, c] =
         max_{dy, dx} input[b,
                            strides[1] * y + dilations[1] * dy,
                            strides[2] * x + dilations[2] * dx,
                            c] +
                      filters[dy, dx, c]

  Max-pooling is a special case when the filter has size equal to the pooling
  kernel size and contains all zeros.

  Note on duality: The dilation of `input` by the `filters` is equal to the
  negation of the erosion of `-input` by the reflected `filters`.

  Args:
    input: A `Tensor`. Must be one of the following types: `float32`, `float64`,
      `int32`, `uint8`, `int16`, `int8`, `int64`, `bfloat16`, `uint16`, `half`,
      `uint32`, `uint64`.
      4-D with shape `[batch, in_height, in_width, depth]`.
    filters: A `Tensor`. Must have the same type as `input`.
      3-D with shape `[filter_height, filter_width, depth]`.
    strides: A list of `ints` that has length `>= 4`.
      The stride of the sliding window for each dimension of the input
      tensor. Must be: `[1, stride_height, stride_width, 1]`.
    padding: A `string` from: `"SAME", "VALID"`.
      The type of padding algorithm to use.
    data_format: A `string`, only `"NHWC"` is currently supported.
    dilations: A list of `ints` that has length `>= 4`.
      The input stride for atrous morphological dilation. Must be:
      `[1, rate_height, rate_width, 1]`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.

  Raises:
    ValueError: if `data_format` is not "NHWC".
  """
  if data_format == "NHWC":
    # The generated op names the filter `filter` and the dilations `rates`.
    return gen_nn_ops.dilation2d(
        input=input,
        filter=filters,
        strides=strides,
        rates=dilations,
        padding=padding,
        name=name)
  raise ValueError("Data formats other than NHWC are not yet supported")
392
393
@tf_export(v1=["nn.dilation2d"])
@dispatch.add_dispatch_support
def dilation2d_v1(  # pylint: disable=missing-docstring
    input,  # pylint: disable=redefined-builtin
    filter=None,  # pylint: disable=redefined-builtin
    strides=None,
    rates=None,
    padding=None,
    name=None,
    filters=None,
    dilations=None):
  # The docstring is copied from the generated op after the definition.
  # Resolve the new (`filters`, `dilations`) and legacy (`filter`, `rates`)
  # spellings of the same arguments.
  resolved_filter = deprecated_argument_lookup(
      "filters", filters, "filter", filter)
  resolved_rates = deprecated_argument_lookup(
      "dilations", dilations, "rates", rates)
  return gen_nn_ops.dilation2d(
      input=input,
      filter=resolved_filter,
      strides=strides,
      rates=resolved_rates,
      padding=padding,
      name=name)


dilation2d_v1.__doc__ = gen_nn_ops.dilation2d.__doc__
411
412
@tf_export("nn.with_space_to_batch")
@dispatch.add_dispatch_support
def with_space_to_batch(
    input,  # pylint: disable=redefined-builtin
    dilation_rate,
    padding,
    op,
    filter_shape=None,
    spatial_dims=None,
    data_format=None):
  """Performs `op` on the space-to-batch representation of `input`.

  This turns a sliding-window operation into the corresponding "atrous"
  operation, in which the input is sampled at the given `dilation_rate`.

  When `dilation_rate` is uniformly 1, this simply returns:

    op(input, num_spatial_dims, padding)

  Otherwise, it returns:

    batch_to_space_nd(
      op(space_to_batch_nd(input, adjusted_dilation_rate, adjusted_paddings),
         num_spatial_dims,
         "VALID"),
      adjusted_dilation_rate,
      adjusted_crops)

  where:

    adjusted_dilation_rate is an integer tensor of shape [max(spatial_dims)],
    adjusted_{paddings,crops} are integer tensors of shape
      [max(spatial_dims), 2]

  defined as follows:

  First, integer tensors `paddings` and `crops` of shape `[num_spatial_dims, 2]`
  are derived from the value of `padding` and the spatial dimensions of
  `input`.

  If `padding = "VALID"`, then:

    paddings, crops = required_space_to_batch_paddings(
      input_shape[spatial_dims],
      dilation_rate)

  If `padding = "SAME"`, then:

    dilated_filter_shape =
      filter_shape + (filter_shape - 1) * (dilation_rate - 1)

    paddings, crops = required_space_to_batch_paddings(
      input_shape[spatial_dims],
      dilation_rate,
      [(dilated_filter_shape - 1) // 2,
       dilated_filter_shape - 1 - (dilated_filter_shape - 1) // 2])

  `space_to_batch_nd` and `batch_to_space_nd` assume that the spatial
  dimensions are contiguous starting at the second dimension, but the given
  `spatial_dims` need not be, so `dilation_rate`, `paddings` and `crops` have
  to be adjusted before use. A dimension whose block size is 1 and whose
  starting and ending padding and crop amounts are all 0 is effectively left
  alone by `space_to_batch_nd`, which is exactly what dimensions outside
  `spatial_dims` need; both ops also handle any number of such leading and
  trailing dimensions efficiently.

  For 0 <= i < len(spatial_dims), we assign:

    adjusted_dilation_rate[spatial_dims[i] - 1] = dilation_rate[i]
    adjusted_paddings[spatial_dims[i] - 1, :] = paddings[i, :]
    adjusted_crops[spatial_dims[i] - 1, :] = crops[i, :]

  All unassigned values of `adjusted_dilation_rate` default to 1, while all
  unassigned values of `adjusted_paddings` and `adjusted_crops` default to 0.

  Note in the case that `dilation_rate` is not uniformly 1, specifying "VALID"
  padding is equivalent to specifying `padding = "SAME"` with a filter_shape of
  `[1]*N`.

  Advanced usage. Note the following optimization: A sequence of
  `with_space_to_batch` operations with identical (not uniformly 1)
  `dilation_rate` parameters and "VALID" padding

    net = with_space_to_batch(net, dilation_rate, "VALID", op_1)
    ...
    net = with_space_to_batch(net, dilation_rate, "VALID", op_k)

  can be combined into a single `with_space_to_batch` operation as follows:

    def combined_op(converted_input, num_spatial_dims, _):
      result = op_1(converted_input, num_spatial_dims, "VALID")
      ...
      result = op_k(result, num_spatial_dims, "VALID")
      return result

    net = with_space_to_batch(net, dilation_rate, "VALID", combined_op)

  This eliminates the overhead of `k-1` calls to `space_to_batch_nd` and
  `batch_to_space_nd`.

  Similarly, a sequence of `with_space_to_batch` operations with identical (not
  uniformly 1) `dilation_rate` parameters, "SAME" padding, and odd filter
  dimensions

    net = with_space_to_batch(net, dilation_rate, "SAME", op_1, filter_shape_1)
    ...
    net = with_space_to_batch(net, dilation_rate, "SAME", op_k, filter_shape_k)

  can be combined into a single `with_space_to_batch` operation as follows:

    def combined_op(converted_input, num_spatial_dims, _):
      result = op_1(converted_input, num_spatial_dims, "SAME")
      ...
      result = op_k(result, num_spatial_dims, "SAME")
      return result

    net = with_space_to_batch(net, dilation_rate, "VALID", combined_op)

  Args:
    input: Tensor of rank > max(spatial_dims).
    dilation_rate: int32 Tensor of *known* shape [num_spatial_dims].
    padding: str constant equal to "VALID" or "SAME"
    op: Function that maps (input, num_spatial_dims, padding) -> output
    filter_shape: If padding = "SAME", specifies the shape of the convolution
      kernel/pooling window as an integer Tensor of shape [>=num_spatial_dims].
      If padding = "VALID", filter_shape is ignored and need not be specified.
    spatial_dims: Monotonically increasing sequence of `num_spatial_dims`
      integers (which are >= 1) specifying the spatial dimensions of `input`
      and output.  Defaults to `range(1, num_spatial_dims+1)`, or to
      `range(2, num_spatial_dims+2)` when `data_format` starts with "NC".
    data_format: A string or None.  Specifies whether the channel dimension of
      the `input` and output is the last dimension (default, or if `data_format`
      does not start with "NC"), or the second dimension (if `data_format`
      starts with "NC").  For N=1, the valid values are "NWC" (default) and
      "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW".
      For N=3, the valid values are "NDHWC" (default) and "NCDHW".

  Returns:
    The output Tensor as described above, dimensions will vary based on the op
    provided.

  Raises:
    ValueError: if `padding` is invalid or the arguments are incompatible.
    ValueError: if `spatial_dims` are invalid.

  """
  input = ops.convert_to_tensor(input, name="input")  # pylint: disable=redefined-builtin

  def _make_op(num_spatial_dims, op_padding):
    # The helper invokes the built op as (input, filter); the user-supplied
    # `op` takes no filter, so the second argument is discarded.
    def _run(inp, unused_filter):
      return op(inp, num_spatial_dims, op_padding)
    return _run

  helper = _WithSpaceToBatch(
      input.shape,
      dilation_rate,
      padding,
      _make_op,
      filter_shape=filter_shape,
      spatial_dims=spatial_dims,
      data_format=data_format)
  return helper(input, None)
572
573
class _WithSpaceToBatch(object):
  """Helper class for with_space_to_batch.

  Note that this class assumes that shapes of input and filter passed to
  `__call__` are compatible with `input_shape`, `filter_shape`, and
  `spatial_dims` passed to the constructor.

  Args:
    input_shape: static shape of input. i.e. input.shape.
    dilation_rate: see `with_space_to_batch`.
    padding: see `with_space_to_batch`.
    build_op: Function that maps (num_spatial_dims, paddings) -> (function that
      maps (input, filter) -> output).
    filter_shape: see `with_space_to_batch`.
    spatial_dims: see `with_space_to_batch`.
    data_format: see `with_space_to_batch`.
    num_batch_dims: (Optional).  Number of batch dims in `input_shape`.
  """

  def __init__(self,
               input_shape,
               dilation_rate,
               padding,
               build_op,
               filter_shape=None,
               spatial_dims=None,
               data_format=None,
               num_batch_dims=1):
    """Helper class for _with_space_to_batch."""
    dilation_rate = ops.convert_to_tensor(
        dilation_rate, dtypes.int32, name="dilation_rate")
    if dilation_rate.shape.ndims not in (None, 1):
      raise ValueError(
          "rate must be rank 1 but saw {}".format(dilation_rate.shape.ndims))

    if not dilation_rate.shape.is_fully_defined():
      raise ValueError("rate must have known shape, but saw {}"
                       .format(dilation_rate.shape))

    num_spatial_dims = dilation_rate.shape.dims[0].value

    if data_format is not None and data_format.startswith("NC"):
      starting_spatial_dim = num_batch_dims + 1
    else:
      starting_spatial_dim = num_batch_dims

    if spatial_dims is None:
      spatial_dims = range(starting_spatial_dim,
                           num_spatial_dims + starting_spatial_dim)
    orig_spatial_dims = list(spatial_dims)
    spatial_dims = sorted(set(int(x) for x in orig_spatial_dims))
    if spatial_dims != orig_spatial_dims or any(x < 1 for x in spatial_dims):
      raise ValueError(
          "spatial_dims must be a monotonically increasing sequence of "
          "positive integers, but saw: {}".format(orig_spatial_dims))

    # Every spatial dimension is used as an index into the input shape (see
    # `_with_space_to_batch_call`), so the rank must exceed the largest one
    # whatever the data format. The previous "NC" value of
    # `spatial_dims[-1]` admitted inputs that later failed with IndexError.
    expected_input_rank = spatial_dims[-1] + 1

    try:
      input_shape.with_rank_at_least(expected_input_rank)
    except ValueError:
      raise ValueError(
          "input tensor must have rank at least {}, but saw rank {}"
          .format(expected_input_rank, input_shape.ndims))

    const_rate = tensor_util.constant_value(dilation_rate)
    rate_or_const_rate = dilation_rate
    if const_rate is not None:
      rate_or_const_rate = const_rate
      if np.any(const_rate < 1):
        raise ValueError("dilation_rate must be positive, but saw: {}"
                         .format(const_rate))
      if np.all(const_rate == 1):
        # No dilation: run the op directly, skipping space-to-batch.
        self.call = build_op(num_spatial_dims, padding)
        return

    padding, explicit_paddings = convert_padding(padding)

    # We have two padding contributions. The first is used for converting "SAME"
    # to "VALID". The second is required so that the height and width of the
    # zero-padded value tensor are multiples of rate.

    # Padding required to reduce to "VALID" convolution
    if padding == "SAME":
      if filter_shape is None:
        raise ValueError("filter_shape must be specified for SAME padding")
      filter_shape = ops.convert_to_tensor(filter_shape, name="filter_shape")
      const_filter_shape = tensor_util.constant_value(filter_shape)
      if const_filter_shape is not None:
        self.base_paddings = _with_space_to_batch_base_paddings(
            const_filter_shape, num_spatial_dims, rate_or_const_rate)
      else:
        # Deferred to call time. Keep the dynamic filter shape: the public
        # `with_space_to_batch` calls this helper without a `filter`, so the
        # shape cannot always be recovered from the filter itself.
        self.num_spatial_dims = num_spatial_dims
        self.rate_or_const_rate = rate_or_const_rate
        self.filter_shape = filter_shape
        self.base_paddings = None
    elif padding == "VALID":
      self.base_paddings = np.zeros([num_spatial_dims, 2], np.int32)
    elif padding == "EXPLICIT":
      base_paddings = (np.array(explicit_paddings)
                       .reshape([num_spatial_dims + 2, 2]))
      # Remove batch and channel dimensions
      if data_format is not None and data_format.startswith("NC"):
        self.base_paddings = base_paddings[2:]
      else:
        self.base_paddings = base_paddings[1:-1]
    else:
      raise ValueError("Invalid padding method %r" % padding)

    self.input_shape = input_shape
    self.spatial_dims = spatial_dims
    self.dilation_rate = dilation_rate
    self.data_format = data_format
    self.op = build_op(num_spatial_dims, "VALID")
    self.call = self._with_space_to_batch_call

  def _with_space_to_batch_call(self, inp, filter):  # pylint: disable=redefined-builtin
    """Call functionality for with_space_to_batch."""
    # Handle input whose shape is unknown during graph creation.
    input_spatial_shape = None
    input_shape = self.input_shape
    spatial_dims = self.spatial_dims
    if input_shape.ndims is not None:
      input_shape_list = input_shape.as_list()
      input_spatial_shape = [input_shape_list[i] for i in spatial_dims]
    if input_spatial_shape is None or None in input_spatial_shape:
      input_shape_tensor = array_ops.shape(inp)
      input_spatial_shape = array_ops.stack(
          [input_shape_tensor[i] for i in spatial_dims])

    base_paddings = self.base_paddings
    if base_paddings is None:
      # base_paddings could not be computed at build time since static filter
      # shape was not fully defined.
      if filter is not None:
        filter_shape = array_ops.shape(filter)
      else:
        # No filter was passed (the public `with_space_to_batch` path); use the
        # `filter_shape` tensor given to the constructor, cast to int32 to
        # match `dilation_rate`.
        filter_shape = math_ops.cast(self.filter_shape, dtypes.int32)
      base_paddings = _with_space_to_batch_base_paddings(
          filter_shape, self.num_spatial_dims, self.rate_or_const_rate)

    paddings, crops = array_ops.required_space_to_batch_paddings(
        input_shape=input_spatial_shape,
        base_paddings=base_paddings,
        block_shape=self.dilation_rate)

    dilation_rate = _with_space_to_batch_adjust(self.dilation_rate, 1,
                                                spatial_dims)
    paddings = _with_space_to_batch_adjust(paddings, 0, spatial_dims)
    crops = _with_space_to_batch_adjust(crops, 0, spatial_dims)
    input_converted = array_ops.space_to_batch_nd(
        input=inp, block_shape=dilation_rate, paddings=paddings)

    result = self.op(input_converted, filter)

    result_converted = array_ops.batch_to_space_nd(
        input=result, block_shape=dilation_rate, crops=crops)

    # Recover channel information for output shape if channels are not last.
    # Skip when the result rank is unknown: `shape.dims` is then None.
    if (self.data_format is not None and self.data_format.startswith("NC") and
        filter is not None and result_converted.shape.ndims is not None and
        not result_converted.shape.dims[1].value):
      output_shape = result_converted.shape.as_list()
      output_shape[1] = filter.shape[-1]
      result_converted.set_shape(output_shape)

    return result_converted

  def __call__(self, inp, filter):  # pylint: disable=redefined-builtin
    return self.call(inp, filter)
743
744
def _with_space_to_batch_base_paddings(filter_shape, num_spatial_dims,
                                       rate_or_const_rate):
  """Computes the "SAME" base paddings for a dilated filter.

  Args:
    filter_shape: Filter shape, as a numpy value or an integer tensor, whose
      first `num_spatial_dims` entries are the spatial filter sizes.
    num_spatial_dims: Python int, the number of spatial dimensions.
    rate_or_const_rate: Dilation rate per spatial dimension, either a constant
      numpy value or a tensor.

  Returns:
    A stacked `[num_spatial_dims, 2]` value of `[pad_start, pad_end]` pairs.
  """
  spatial_filter = filter_shape[:num_spatial_dims]
  # A filter of size k dilated by rate r spans (k - 1) * r + 1 elements, so
  # (k - 1) * r elements of padding in total keep the output size unchanged.
  total_pad = (spatial_filter - 1) * rate_or_const_rate

  # An odd total puts the extra element at the end, as conv2d does.
  pad_before = total_pad // 2
  pad_after = total_pad - pad_before

  pairs = []
  for dim in range(num_spatial_dims):
    pairs.append([pad_before[dim], pad_after[dim]])
  return array_ops.stack(pairs)
760
761
def _with_space_to_batch_adjust(orig, fill_value, spatial_dims):
  """Scatters the rows of `orig` onto the positions given by `spatial_dims`.

  Builds a value of the same dtype as `orig` with shape
  `[max(spatial_dims), ...]` such that

    adjusted[spatial_dims[i] - 1, ...] = orig[i, ...]

  for 0 <= i < len(spatial_dims), while every other row is `fill_value`.

  If `orig` is a constant value, the result is a constant (numpy) value.

  Args:
    orig: Tensor whose first dimension has size `len(spatial_dims)`. Its
      trailing dimensions size the numpy fill blocks, so they presumably must
      be statically known.
    fill_value: Numpy scalar (of the same data type as `orig`) used for the
      non-spatial dimensions.
    spatial_dims: See with_space_to_batch.

  Returns:
    The `adjusted` value described above.
  """
  trailing_dims = orig.get_shape().as_list()[1:]
  np_dtype = orig.dtype.as_numpy_dtype
  static_orig = tensor_util.constant_value(orig)
  source = orig if static_orig is None else static_orig

  pieces = []
  last_emitted_dim = 0  # 0 means no spatial dimension has been emitted yet.
  num_dims = len(spatial_dims)
  run_start = 0
  while run_start < num_dims:
    first_dim = spatial_dims[run_start]
    if first_dim > 1:
      # Fill the gap between the previous run (or dimension 1) and this run.
      pieces.append(
          np.full([first_dim - 1 - last_emitted_dim] + trailing_dims,
                  fill_value,
                  dtype=np_dtype))
    # Extend the run while the spatial dimensions stay contiguous.
    run_end = run_start
    while (run_end + 1 < num_dims and
           spatial_dims[run_end + 1] == spatial_dims[run_end] + 1):
      run_end += 1
    pieces.append(source[run_start:run_end + 1])
    last_emitted_dim = spatial_dims[run_end]
    run_start = run_end + 1

  if static_orig is None:
    return array_ops.concat(pieces, 0)
  return np.concatenate(pieces)
819
820
def _get_strides_and_dilation_rate(num_spatial_dims, strides, dilation_rate):
  """Validates and normalizes the `strides` and `dilation_rate` arguments.

  Shared by `convolution` and `pool`.

  Args:
    num_spatial_dims: int, the number of spatial dimensions N.
    strides: Optional sequence of N ints >= 1; `None` means [1]*N.  If any
      value is > 1, then every value of `dilation_rate` must be 1.
    dilation_rate: Optional sequence of N ints >= 1; `None` means [1]*N.  If
      any value is > 1, then every value of `strides` must be 1.

  Returns:
    A `(strides, dilation_rate)` pair of int32 numpy arrays, each of shape
    [num_spatial_dims].

  Raises:
    ValueError: if either sequence has the wrong length or contains a value
      < 1, or if strides > 1 and dilation_rate > 1 are both requested.
  """

  def _as_positive_int32_array(values, arg_name):
    """Defaults `values` to ones, checks length and positivity, converts."""
    if values is None:
      values = [1] * num_spatial_dims
    elif len(values) != num_spatial_dims:
      raise ValueError("len(%s)=%d but should be %d" %
                       (arg_name, len(values), num_spatial_dims))
    values = np.array(values, dtype=np.int32)
    if np.any(values < 1):
      raise ValueError("all values of %s must be positive" % arg_name)
    return values

  # dilation_rate is checked first, so its errors take precedence.
  dilation_rate = _as_positive_int32_array(dilation_rate, "dilation_rate")
  strides = _as_positive_int32_array(strides, "strides")

  both_nontrivial = np.any(strides > 1) and np.any(dilation_rate > 1)
  if both_nontrivial:
    raise ValueError(
        "strides > 1 not supported in conjunction with dilation_rate > 1")
  return strides, dilation_rate
862
863
@tf_export(v1=["nn.convolution"])
@dispatch.add_dispatch_support
def convolution(
    input,  # pylint: disable=redefined-builtin
    filter,  # pylint: disable=redefined-builtin
    padding,
    strides=None,
    dilation_rate=None,
    name=None,
    data_format=None,
    filters=None,
    dilations=None):  # pylint: disable=g-doc-args
  """Computes sums of N-D convolutions (actually cross-correlation).

  This also supports either output striding via the optional `strides` parameter
  or atrous convolution (also known as convolution with holes or dilated
  convolution, based on the French word "trous" meaning holes in English) via
  the optional `dilation_rate` parameter.  Currently, however, output striding
  is not supported for atrous convolutions.

  Specifically, in the case that `data_format` does not start with "NC", given
  a rank (N+2) `input` Tensor of shape

    [num_batches,
     input_spatial_shape[0],
     ...,
     input_spatial_shape[N-1],
     num_input_channels],

  a rank (N+2) `filter` Tensor of shape

    [spatial_filter_shape[0],
     ...,
     spatial_filter_shape[N-1],
     num_input_channels,
     num_output_channels],

  an optional `dilation_rate` tensor of shape [N] (defaulting to [1]*N)
  specifying the filter upsampling/input downsampling rate, and an optional list
  of N `strides` (defaulting to [1]*N), this computes for each N-D spatial
  output position (x[0], ..., x[N-1]):

  ```
    output[b, x[0], ..., x[N-1], k] =
        sum_{z[0], ..., z[N-1], q}
            filter[z[0], ..., z[N-1], q, k] *
            padded_input[b,
                         x[0]*strides[0] + dilation_rate[0]*z[0],
                         ...,
                         x[N-1]*strides[N-1] + dilation_rate[N-1]*z[N-1],
                         q]
  ```
  where b is the index into the batch, k is the output channel number, q is the
  input channel number, and z is the N-D spatial offset within the filter. Here,
  `padded_input` is obtained by zero padding the input using an effective
  spatial filter shape of `(spatial_filter_shape-1) * dilation_rate + 1` and
  output striding `strides`.

  In the case that `data_format` does start with `"NC"`, the `input` and output
  (but not the `filter`) are simply transposed as follows:

    convolution(input, data_format, **kwargs) =
      tf.transpose(convolution(tf.transpose(input, [0] + range(2,N+2) + [1]),
                               **kwargs),
                   [0, N+1] + range(1, N+1))

  It is required that 1 <= N <= 3.

  Args:
    input: An (N+2)-D `Tensor` of type `T`, of shape
      `[batch_size] + input_spatial_shape + [in_channels]` if data_format does
      not start with "NC" (default), or
      `[batch_size, in_channels] + input_spatial_shape` if data_format starts
      with "NC".
    filter: An (N+2)-D `Tensor` with the same type as `input` and shape
      `spatial_filter_shape + [in_channels, out_channels]`.
    padding: A string, either `"VALID"` or `"SAME"`. The padding algorithm.
      `"VALID"` means no padding. `"SAME"` results in padding evenly to
      the left/right or up/down of the input such that output has the same
      height/width dimension as the input.
    strides: Optional.  Sequence of N ints >= 1.  Specifies the output stride.
      Defaults to [1]*N.  If any value of strides is > 1, then all values of
      dilation_rate must be 1.
    dilation_rate: Optional.  Sequence of N ints >= 1.  Specifies the filter
      upsampling/input downsampling rate.  In the literature, the same parameter
      is sometimes called `input stride` or `dilation`.  The effective filter
      size used for the convolution will be `spatial_filter_shape +
      (spatial_filter_shape - 1) * (rate - 1)`, obtained by inserting
      (dilation_rate[i]-1) zeros between consecutive elements of the original
      filter in each spatial dimension i.  If any value of dilation_rate is > 1,
      then all values of strides must be 1.
    name: Optional name for the returned tensor.
    data_format: A string or None.  Specifies whether the channel dimension of
      the `input` and output is the last dimension (default, or if `data_format`
      does not start with "NC"), or the second dimension (if `data_format`
      starts with "NC").  For N=1, the valid values are "NWC" (default) and
      "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW".
      For N=3, the valid values are "NDHWC" (default) and "NCDHW".

  Returns:
    A `Tensor` with the same type as `input` of shape

        `[batch_size] + output_spatial_shape + [out_channels]`

    if data_format is None or does not start with "NC", or

        `[batch_size, out_channels] + output_spatial_shape`

    if data_format starts with "NC",
    where `output_spatial_shape` depends on the value of `padding`.

    If padding == "SAME":
      output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])

    If padding == "VALID":
      output_spatial_shape[i] =
        ceil((input_spatial_shape[i] -
              (spatial_filter_shape[i]-1) * dilation_rate[i])
             / strides[i]).

  Raises:
    ValueError: If input/output depth does not match `filter` shape, if padding
      is other than `"VALID"` or `"SAME"`, or if data_format is invalid.

  """
  # `filters` and `dilations` are the spellings used by `convolution_v2`.
  # Each pair is resolved to one value through deprecated_argument_lookup,
  # filter pair first, then dilation pair.
  resolved_filter = deprecated_argument_lookup(
      "filters", filters, "filter", filter)
  resolved_dilations = deprecated_argument_lookup(
      "dilations", dilations, "dilation_rate", dilation_rate)
  return convolution_internal(
      input,
      resolved_filter,
      strides=strides,
      padding=padding,
      data_format=data_format,
      dilations=resolved_dilations,
      name=name)
1000
1001
@tf_export("nn.convolution", v1=[])
@dispatch.add_dispatch_support
def convolution_v2(  # pylint: disable=missing-docstring
    input,  # pylint: disable=redefined-builtin
    filters,
    strides=None,
    padding="VALID",
    data_format=None,
    dilations=None,
    name=None):
  # No docstring here on purpose: `__doc__` is assigned right after this
  # definition from `convolution.__doc__` with the arguments renamed.
  return convolution_internal(
      input=input,
      filters=filters,
      strides=strides,
      padding=padding,
      data_format=data_format,
      dilations=dilations,
      name=name)
1020
1021
# Reuse the v1 `convolution` docstring for `convolution_v2`, renaming the
# arguments whose names differ in the v2 signature:
# `dilation_rate` -> `dilations` and `filter` -> `filters`.
convolution_v2.__doc__ = deprecation.rewrite_argument_docstring(
    deprecation.rewrite_argument_docstring(
        convolution.__doc__, "dilation_rate", "dilations"),
    "filter", "filters")
1026
1027
def convolution_internal(
    input,  # pylint: disable=redefined-builtin
    filters,
    strides=None,
    padding="VALID",
    data_format=None,
    dilations=None,
    name=None,
    call_from_convolution=True,
    num_spatial_dims=None):
  """Internal function which performs rank agnostic convolution.

  Args:
    input: See `convolution`.
    filters: See `convolution`.
    strides: See `convolution`.
    padding: See `convolution`.
    data_format: See `convolution`.
    dilations: See `convolution`.
    name: See `convolution`.
    call_from_convolution: Python bool.  Only affects the default name scope
      used when `name` is not given.  If True, or if there is no enclosing
      TPU context, the scope is named "convolution".  Otherwise it is named
      after the underlying op ("Conv2D", "Conv3D" or "conv1d").
      `Convolution.__call__` passes False when it re-dispatches here under a
      TPU context.
    num_spatial_dims: (Optional.)  An integer describing the rank of the
      spatial dimensions.  For `1-D`, `2-D` and `3-D` convolutions, the value
      of `num_spatial_dims` is `1`, `2`, and `3`, respectively.
      This argument is only required to disambiguate the rank of `batch_shape`
      when `filter_shape.ndims is None` and `len(batch_shape) > 1`.  For
      backwards compatibility, if `num_spatial_dims is None` and
      `filter_shape.ndims is None`, then `len(batch_shape)` is assumed to be
      `1` (i.e., the input is expected to be
      `[batch_size, num_channels] + input_spatial_shape`
      or `[batch_size] + input_spatial_shape + [num_channels]`).

  Returns:
    A tensor of shape and dtype matching that of `input`.

  Raises:
    ValueError: If input and filter both have unknown shapes, or if
      `num_spatial_dims` is provided and incompatible with the value
      estimated from `filters.shape`.
  """
  if (not isinstance(filters, variables_lib.Variable) and
      not tensor_util.is_tf_type(filters)):
    with ops.name_scope("convolution_internal", None, [filters, input]):
      filters = ops.convert_to_tensor(filters, name="filters")
  if (not isinstance(input, ops.Tensor) and not tensor_util.is_tf_type(input)):
    with ops.name_scope("convolution_internal", None, [filters, input]):
      input = ops.convert_to_tensor(input, name="input")

  filters_rank = filters.shape.rank
  inputs_rank = input.shape.rank
  if num_spatial_dims is None:
    # Prefer the filter rank (always N+2); fall back to the input rank under
    # the single-batch-dimension assumption.
    if filters_rank:
      num_spatial_dims = filters_rank - 2
    elif inputs_rank:
      num_spatial_dims = inputs_rank - 2
    else:
      raise ValueError("rank of input or filter must be known")
  elif filters_rank and filters_rank - 2 != num_spatial_dims:
    # Report the *estimate* (filters_rank - 2), as the message describes,
    # rather than the raw filter rank.
    raise ValueError(
        "inconsistent estimate of spatial dims ({}) vs. actual passed "
        "num_spatial_dims ({}).  n was estimated as len(filters.shape) - 2, "
        "but filters shape is: {}".format(filters_rank - 2, num_spatial_dims,
                                          filters.shape))

  if inputs_rank:
    num_batch_dims = inputs_rank - num_spatial_dims - 1  # Channel dimension.
  else:
    num_batch_dims = 1  # By default, assume single batch dimension.

  if num_spatial_dims not in {1, 2, 3}:
    raise ValueError(
        "num_spatial_dims (input.shape.ndims - num_batch_dims - 1) must be one "
        "of 1, 2 or 3 but saw {}.  num_batch_dims: {}.".format(
            num_spatial_dims, num_batch_dims))

  if data_format is None or data_format in _CHANNELS_LAST_FORMATS:
    channel_index = num_batch_dims + num_spatial_dims
  else:
    channel_index = num_batch_dims

  # An omitted `dilations` never selects the dilated path; an explicit one
  # does so only if some entry differs from 1.
  dilations_given = dilations is not None
  dilations = _get_sequence(dilations, num_spatial_dims, channel_index,
                            "dilations")
  is_dilated_conv = dilations_given and any(i != 1 for i in dilations)

  strides = _get_sequence(strides, num_spatial_dims, channel_index, "strides")
  has_tpu_context = device_context.enclosing_tpu_context() is not None

  if name:
    default_name = None
  elif not has_tpu_context or call_from_convolution:
    default_name = "convolution"
  elif num_spatial_dims == 2:  # Most common case.
    default_name = "Conv2D"
  elif num_spatial_dims == 3:
    default_name = "Conv3D"
  else:
    default_name = "conv1d"

  with ops.name_scope(name, default_name, [input, filters]) as name:
    # Fast path for TPU or if no dilation, as gradient only supported on TPU
    # for dilations.
    if not is_dilated_conv or has_tpu_context:
      if num_spatial_dims == 2:  # Most common case.
        op = _conv2d_expanded_batch
      elif num_spatial_dims == 3:
        op = _conv3d_expanded_batch
      else:
        op = conv1d

      return op(
          input,
          filters,
          strides,
          padding=padding,
          data_format=data_format,
          dilations=dilations,
          name=name)
    else:
      # `Convolution` expects per-spatial-dimension strides/dilations, so
      # drop the batch and channel entries added by `_get_sequence`.
      if channel_index == 1:
        strides = strides[2:]
        dilations = dilations[2:]
      else:
        strides = strides[1:-1]
        dilations = dilations[1:-1]

      op = Convolution(
          tensor_shape.as_shape(input.shape),
          tensor_shape.as_shape(filters.shape),
          padding,
          strides=strides,
          dilation_rate=dilations,
          name=name,
          data_format=data_format,
          num_spatial_dims=num_spatial_dims)
      return op(input, filters)
1168
1169
class Convolution(object):
  """Helper class for convolution.

  Note that this class assumes that shapes of input and filter passed to
  `__call__` are compatible with `input_shape`, `filter_shape`, and
  `num_spatial_dims` passed to the constructor.

  Args:
    input_shape: static shape of input. i.e. input.shape.  It is
      `batch_shape + input_spatial_shape + [num_channels]` if `data_format`
      does not start with `NC`, or
      `batch_shape + [num_channels] + input_spatial_shape` if `data_format`
      starts with `NC`.
    filter_shape: static shape of the filter. i.e. filter.shape.
    padding: The padding algorithm, must be "SAME" or "VALID".
    strides: see convolution.
    dilation_rate: see convolution.
    name: see convolution.
    data_format: A string or `None`.  Specifies whether the channel dimension of
      the `input` and output is the last dimension (if `data_format` is `None`
      or does not start with `NC`), or the first post-batch dimension (i.e. if
      `data_format` starts with `NC`).
    num_spatial_dims: (Usually optional.)  Python integer, the rank of the
      spatial dimensions.  For `1-D`, `2-D` and `3-D` convolutions,
      the value of `num_spatial_dims` is `1`, `2`, and `3`, respectively.
      This argument is only required to disambiguate the rank of `batch_shape`
      when `filter_shape.ndims is None` and `len(batch_shape) > 1`.  For
      backwards compatibility, if `num_spatial_dims is None` and
      `filter_shape.ndims is None`, then `len(batch_shape)` is assumed to be
      `1` (i.e., the input is expected to be
      `[batch_size, num_channels] + input_spatial_shape`
      or `[batch_size] + input_spatial_shape + [num_channels]`).

  Raises:
    ValueError: (from the constructor) if the ranks are inconsistent, if
      `num_spatial_dims` cannot be determined, if fewer than one batch
      dimension remains, or if the input channels are not divisible by the
      filter's input-channel dimension.
  """

  def __init__(self,
               input_shape,
               filter_shape,
               padding,
               strides=None,
               dilation_rate=None,
               name=None,
               data_format=None,
               num_spatial_dims=None):
    """Helper function for convolution."""
    num_batch_dims = None
    filter_shape = tensor_shape.as_shape(filter_shape)
    input_shape = tensor_shape.as_shape(input_shape)

    # The filter rank, when known, pins num_spatial_dims (filter is N+2-D).
    if filter_shape.ndims is not None:
      if (num_spatial_dims is not None and
          filter_shape.ndims != num_spatial_dims + 2):
        raise ValueError(
            "Expected filter_shape.ndims == num_spatial_dims + 2, "
            "but saw filter_shape.ndims == {} and num_spatial_dims == {}"
            .format(filter_shape.ndims, num_spatial_dims))
      else:
        num_spatial_dims = filter_shape.ndims - 2

    if input_shape.ndims is not None and num_spatial_dims is not None:
      num_batch_dims = input_shape.ndims - num_spatial_dims - 1

    if num_spatial_dims is None:
      # Fall back to the single-batch-dimension convention.  If the input
      # rank is unknown as well, leave num_spatial_dims as None so that the
      # ValueError below is raised (previously `None - 2` raised TypeError).
      if input_shape.ndims is not None:
        num_spatial_dims = input_shape.ndims - 2
    elif input_shape.ndims is not None:
      if input_shape.ndims < num_spatial_dims + 2:
        raise ValueError(
            "Expected input_shape.ndims >= num_spatial_dims + 2, but saw "
            "input_shape.ndims == {} and num_spatial_dims == {}"
            .format(input_shape.ndims, num_spatial_dims))
      if num_batch_dims is None:
        num_batch_dims = input_shape.ndims - num_spatial_dims - 1

    if num_spatial_dims is None:
      raise ValueError(
          "Cannot estimate num_spatial_dims since input_shape.ndims is None, "
          "filter_shape.ndims is None, and argument num_spatial_dims is also "
          "None.")

    if num_batch_dims is None:
      num_batch_dims = 1

    if num_batch_dims < 1:
      raise ValueError(
          "num_batch_dims should be >= 1, but saw {}.  num_batch_dims was "
          "estimated as `input_shape.ndims - num_spatial_dims - 1` and "
          "num_spatial_dims was either provided or estimated as "
          "`filter_shape.ndims - 2`.  input_shape.ndims: {}, "
          "num_spatial_dims: {}, filter_shape.ndims: {}"
          .format(num_batch_dims, input_shape.ndims, num_spatial_dims,
                  filter_shape.ndims))

    if data_format is None or not data_format.startswith("NC"):
      input_channels_dim = tensor_shape.dimension_at_index(
          input_shape, num_spatial_dims + num_batch_dims)
      spatial_dims = range(num_batch_dims, num_spatial_dims + num_batch_dims)
    else:
      input_channels_dim = tensor_shape.dimension_at_index(
          input_shape, num_batch_dims)
      spatial_dims = range(
          num_batch_dims + 1, num_spatial_dims + num_batch_dims + 1)

    # Index num_spatial_dims of the filter is its input-channel dimension;
    # the input channel count must be a multiple of it.
    filter_dim = tensor_shape.dimension_at_index(filter_shape, num_spatial_dims)
    if not (input_channels_dim % filter_dim).is_compatible_with(0):
      raise ValueError("The number of input channels is not divisible by the "
                       "corresponding number of output filters. Received: "
                       "input channels={}, output filters={}".format(
                           input_channels_dim, filter_dim))

    strides, dilation_rate = _get_strides_and_dilation_rate(
        num_spatial_dims, strides, dilation_rate)

    self.input_shape = input_shape
    self.filter_shape = filter_shape
    self.data_format = data_format
    self.strides = strides
    self.padding = padding
    self.name = name
    self.dilation_rate = dilation_rate
    self.num_batch_dims = num_batch_dims
    self.num_spatial_dims = num_spatial_dims
    self.conv_op = _WithSpaceToBatch(
        input_shape,
        dilation_rate=dilation_rate,
        padding=padding,
        build_op=self._build_op,
        filter_shape=filter_shape,
        spatial_dims=spatial_dims,
        data_format=data_format,
        num_batch_dims=num_batch_dims)

  def _build_op(self, _, padding):
    """Builds the non-atrous convolution used inside `_WithSpaceToBatch`."""
    return _NonAtrousConvolution(
        self.input_shape,
        filter_shape=self.filter_shape,
        padding=padding,
        data_format=self.data_format,
        strides=self.strides,
        name=self.name,
        num_batch_dims=self.num_batch_dims)

  def __call__(self, inp, filter):  # pylint: disable=redefined-builtin
    """Applies the convolution to `inp` with `filter`."""
    # TPU convolution supports dilations greater than 1.
    if device_context.enclosing_tpu_context() is not None:
      return convolution_internal(
          inp,
          filter,
          strides=self.strides,
          padding=self.padding,
          data_format=self.data_format,
          dilations=self.dilation_rate,
          name=self.name,
          call_from_convolution=False,
          num_spatial_dims=self.num_spatial_dims)
    else:
      return self.conv_op(inp, filter)
1327
1328
@tf_export(v1=["nn.pool"])
@dispatch.add_dispatch_support
def pool(
    input,  # pylint: disable=redefined-builtin
    window_shape,
    pooling_type,
    padding,
    dilation_rate=None,
    strides=None,
    name=None,
    data_format=None,
    dilations=None):
  """Performs an N-D pooling operation.

  In the case that `data_format` does not start with "NC", computes for
      0 <= b < batch_size,
      0 <= x[i] < output_spatial_shape[i],
      0 <= c < num_channels:

  ```
    output[b, x[0], ..., x[N-1], c] =
      REDUCE_{z[0], ..., z[N-1]}
        input[b,
              x[0] * strides[0] - pad_before[0] + dilation_rate[0]*z[0],
              ...
              x[N-1]*strides[N-1] - pad_before[N-1] + dilation_rate[N-1]*z[N-1],
              c],
  ```

  where the reduction function REDUCE depends on the value of `pooling_type`,
  and pad_before is defined based on the value of `padding`, as described in
  the "returns" section of `tf.nn.convolution`.
  The reduction never includes out-of-bounds positions.

  In the case that `data_format` starts with `"NC"`, the `input` and output are
  simply transposed as follows:

  ```
    pool(input, data_format, **kwargs) =
      tf.transpose(pool(tf.transpose(input, [0] + range(2,N+2) + [1]),
                        **kwargs),
                   [0, N+1] + range(1, N+1))
  ```

  Args:
    input: Tensor of rank N+2, of shape
      `[batch_size] + input_spatial_shape + [num_channels]` if data_format does
      not start with "NC" (default), or
      `[batch_size, num_channels] + input_spatial_shape` if data_format starts
      with "NC".  Pooling happens over the spatial dimensions only.
    window_shape: Sequence of N ints >= 1.
    pooling_type: Specifies pooling operation, must be "AVG" or "MAX".
    padding: The padding algorithm, must be "SAME" or "VALID".
      See the "returns" section of `tf.nn.convolution` for details.
    dilation_rate: Optional.  Dilation rate.  List of N ints >= 1.
      Defaults to [1]*N.  If any value of dilation_rate is > 1, then all values
      of strides must be 1.
    strides: Optional.  Sequence of N ints >= 1.  Defaults to [1]*N.
      If any value of strides is > 1, then all values of dilation_rate must be
      1.
    name: Optional. Name of the op.
    data_format: A string or None.  Specifies whether the channel dimension of
      the `input` and output is the last dimension (default, or if `data_format`
      does not start with "NC"), or the second dimension (if `data_format`
      starts with "NC").  For N=1, the valid values are "NWC" (default) and
      "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW".
      For N=3, the valid values are "NDHWC" (default) and "NCDHW".
    dilations: Alias for dilation_rate.

  Returns:
    Tensor of rank N+2, of shape
      [batch_size] + output_spatial_shape + [num_channels]

    if data_format is None or does not start with "NC", or

      [batch_size, num_channels] + output_spatial_shape

    if data_format starts with "NC",
    where `output_spatial_shape` depends on the value of padding:

    If padding = "SAME":
      output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])

    If padding = "VALID":
      output_spatial_shape[i] =
        ceil((input_spatial_shape[i] - (window_shape[i] - 1) * dilation_rate[i])
             / strides[i]).

  Raises:
    ValueError: if arguments are invalid, including: N outside [1, 3], SAME
      padding combined with dilation_rate > 1, strides > window_shape, an
      unsupported `pooling_type`, or an invalid 1-D `data_format`.

  """
  dilation_rate = deprecated_argument_lookup(
      "dilations", dilations, "dilation_rate", dilation_rate)
  # pylint: enable=line-too-long
  with ops.name_scope(name, "%s_pool" % (pooling_type.lower()),
                      [input]) as scope:
    input = ops.convert_to_tensor(input, name="input")  # pylint: disable=redefined-builtin

    # N is taken from window_shape; the input's static rank is then checked
    # against N + 2.
    num_spatial_dims = len(window_shape)
    if num_spatial_dims < 1 or num_spatial_dims > 3:
      raise ValueError("It is required that 1 <= num_spatial_dims <= 3.")

    input.get_shape().with_rank(num_spatial_dims + 2)

    strides, dilation_rate = _get_strides_and_dilation_rate(
        num_spatial_dims, strides, dilation_rate)

    if padding == "SAME" and np.any(dilation_rate > 1):
      raise ValueError(
          "pooling with SAME padding is not implemented for dilation_rate > 1")

    if np.any(strides > window_shape):
      raise ValueError(
          "strides > window_shape not supported due to inconsistency between "
          "CPU and GPU implementations")

    # 1-D pooling reuses the 2-D kernels (see the expand_dims/squeeze in `op`).
    pooling_ops = {
        ("MAX", 1): max_pool,
        ("MAX", 2): max_pool,
        ("MAX", 3): max_pool3d,  # pylint: disable=undefined-variable
        ("AVG", 1): avg_pool,
        ("AVG", 2): avg_pool,
        ("AVG", 3): avg_pool3d,  # pylint: disable=undefined-variable
    }
    op_key = (pooling_type, num_spatial_dims)
    if op_key not in pooling_ops:
      raise ValueError("%d-D %s pooling is not supported." % (op_key[1],
                                                              op_key[0]))

    # Expand window/strides to full rank with 1s in the batch and channel
    # positions, placed according to the channel layout.
    if data_format is None or not data_format.startswith("NC"):
      adjusted_window_shape = [1] + list(window_shape) + [1]
      adjusted_strides = [1] + list(strides) + [1]
      spatial_dims = range(1, num_spatial_dims + 1)
    else:
      adjusted_window_shape = [1, 1] + list(window_shape)
      adjusted_strides = [1, 1] + list(strides)
      spatial_dims = range(2, num_spatial_dims + 2)

    # For N=1, a size-1 spatial dimension is inserted at spatial_dims[0] so
    # the 2-D pooling ops can be used; NWC/NCW map to NHWC/NCHW and window and
    # strides get a leading extra 1.
    if num_spatial_dims == 1:
      if data_format is None or data_format == "NWC":
        data_format_kwargs = dict(data_format="NHWC")
      elif data_format == "NCW":
        data_format_kwargs = dict(data_format="NCHW")
      else:
        raise ValueError("data_format must be either \"NWC\" or \"NCW\".")
      adjusted_window_shape = [1] + adjusted_window_shape
      adjusted_strides = [1] + adjusted_strides
    else:
      data_format_kwargs = dict(data_format=data_format)

    # Callback for with_space_to_batch: pools the (space-to-batch converted)
    # input with the padding it supplies. The middle argument is ignored here.
    def op(converted_input, _, converted_padding):  # pylint: disable=missing-docstring
      if num_spatial_dims == 1:
        converted_input = array_ops.expand_dims(converted_input,
                                                spatial_dims[0])
      result = pooling_ops[op_key](
          converted_input,
          adjusted_window_shape,
          adjusted_strides,
          converted_padding,
          name=scope,
          **data_format_kwargs)
      if num_spatial_dims == 1:
        result = array_ops.squeeze(result, [spatial_dims[0]])
      return result

    return with_space_to_batch(
        input=input,
        dilation_rate=dilation_rate,
        padding=padding,
        op=op,
        spatial_dims=spatial_dims,
        filter_shape=window_shape)
1502
1503
@tf_export("nn.pool", v1=[])
@dispatch.add_dispatch_support
def pool_v2(
    input,  # pylint: disable=redefined-builtin
    window_shape,
    pooling_type,
    strides=None,
    padding="VALID",
    data_format=None,
    dilations=None,
    name=None):
  # pylint: disable=line-too-long
  """Performs an N-D pooling operation.

  In the case that `data_format` does not start with "NC", computes for
      0 <= b < batch_size,
      0 <= x[i] < output_spatial_shape[i],
      0 <= c < num_channels:

  ```
    output[b, x[0], ..., x[N-1], c] =
      REDUCE_{z[0], ..., z[N-1]}
        input[b,
              x[0] * strides[0] - pad_before[0] + dilation_rate[0]*z[0],
              ...
              x[N-1]*strides[N-1] - pad_before[N-1] + dilation_rate[N-1]*z[N-1],
              c],
  ```

  where the reduction function REDUCE depends on the value of `pooling_type`,
  `dilation_rate` denotes the value of the `dilations` argument, and
  `pad_before` is defined based on the value of `padding` as described in the
  "returns" section of `tf.nn.convolution`.
  The reduction never includes out-of-bounds positions.

  In the case that `data_format` starts with `"NC"`, the `input` and output are
  simply transposed as follows:

  ```
    pool(input, data_format, **kwargs) =
      tf.transpose(pool(tf.transpose(input, [0] + list(range(2, N+2)) + [1]),
                        **kwargs),
                   [0, N+1] + list(range(1, N+1)))
  ```

  Args:
    input: Tensor of rank N+2, of shape `[batch_size] + input_spatial_shape +
      [num_channels]` if data_format does not start with "NC" (default), or
      `[batch_size, num_channels] + input_spatial_shape` if data_format starts
      with "NC".  Pooling happens over the spatial dimensions only.
    window_shape: Sequence of N ints >= 1.
    pooling_type: Specifies pooling operation, must be "AVG" or "MAX".
    strides: Optional. Sequence of N ints >= 1.  Defaults to [1]*N. If any value of
      strides is > 1, then all values of `dilations` must be 1.
    padding: The padding algorithm, must be "SAME" or "VALID". Defaults to "VALID".
      See the "returns" section of `tf.nn.convolution` for details.
    data_format: A string or None.  Specifies whether the channel dimension of
      the `input` and output is the last dimension (default, or if `data_format`
      does not start with "NC"), or the second dimension (if `data_format`
      starts with "NC").  For N=1, the valid values are "NWC" (default) and
      "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW". For
      N=3, the valid values are "NDHWC" (default) and "NCDHW".
    dilations: Optional.  Dilation rate.  List of N ints >= 1. Defaults to
      [1]*N.  If any value of `dilations` is > 1, then all values of strides
      must be 1.
    name: Optional. Name of the op.

  Returns:
    Tensor of rank N+2, of shape
      [batch_size] + output_spatial_shape + [num_channels]

    if data_format is None or does not start with "NC", or

      [batch_size, num_channels] + output_spatial_shape

    if data_format starts with "NC",
    where `output_spatial_shape` depends on the value of padding:

    If padding = "SAME":
      output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])

    If padding = "VALID":
      output_spatial_shape[i] =
        ceil((input_spatial_shape[i] - (window_shape[i] - 1) * dilation_rate[i])
             / strides[i]).

  Raises:
    ValueError: if arguments are invalid.

  """
  # pylint: enable=line-too-long
  # The v2 API names the dilation argument `dilations`; the shared v1
  # implementation calls it `dilation_rate`.
  return pool(
      input=input,
      window_shape=window_shape,
      pooling_type=pooling_type,
      padding=padding,
      dilation_rate=dilations,
      strides=strides,
      name=name,
      data_format=data_format)
1602
1603
@tf_export("nn.atrous_conv2d")
@dispatch.add_dispatch_support
def atrous_conv2d(value, filters, rate, padding, name=None):
  """Atrous convolution (a.k.a. convolution with holes or dilated convolution).

  This function is a simpler wrapper around the more general
  `tf.nn.convolution`, and exists only for backwards compatibility. You can
  use `tf.nn.convolution` to perform 1-D, 2-D, or 3-D atrous convolution.


  Computes a 2-D atrous convolution, also known as convolution with holes or
  dilated convolution, given 4-D `value` and `filters` tensors. If the `rate`
  parameter is equal to one, it performs regular 2-D convolution. If the `rate`
  parameter is greater than one, it performs convolution with holes, sampling
  the input values every `rate` pixels in the `height` and `width` dimensions.
  This is equivalent to convolving the input with a set of upsampled filters,
  produced by inserting `rate - 1` zeros between two consecutive values of the
  filters along the `height` and `width` dimensions, hence the name atrous
  convolution or convolution with holes (the French word trous means holes in
  English).

  More specifically:

  ```
  output[batch, height, width, out_channel] =
      sum_{dheight, dwidth, in_channel} (
          filters[dheight, dwidth, in_channel, out_channel] *
          value[batch, height + rate*dheight, width + rate*dwidth, in_channel]
      )
  ```

  Atrous convolution allows us to explicitly control how densely to compute
  feature responses in fully convolutional networks. Used in conjunction with
  bilinear interpolation, it offers an alternative to `conv2d_transpose` in
  dense prediction tasks such as semantic image segmentation, optical flow
  computation, or depth estimation. It also allows us to effectively enlarge
  the field of view of filters without increasing the number of parameters or
  the amount of computation.

  For a description of atrous convolution and how it can be used for dense
  feature extraction, please see: (Chen et al., 2015). The same operation is
  investigated further in (Yu et al., 2016). Previous works that effectively
  use atrous convolution in different ways are, among others,
  (Sermanet et al., 2014) and (Giusti et al., 2013).
  Atrous convolution is also closely related to the so-called noble identities
  in multi-rate signal processing.

  There are many different ways to implement atrous convolution (see the refs
  above). The implementation here reduces

  ```python
      atrous_conv2d(value, filters, rate, padding=padding)
  ```

  to the following three operations:

  ```python
      paddings = ...
      net = space_to_batch(value, paddings, block_size=rate)
      net = conv2d(net, filters, strides=[1, 1, 1, 1], padding="VALID")
      crops = ...
      net = batch_to_space(net, crops, block_size=rate)
  ```

  Advanced usage. Note the following optimization: A sequence of `atrous_conv2d`
  operations with identical `rate` parameters, 'SAME' `padding`, and filters
  with odd heights/widths:

  ```python
      net = atrous_conv2d(net, filters1, rate, padding="SAME")
      net = atrous_conv2d(net, filters2, rate, padding="SAME")
      ...
      net = atrous_conv2d(net, filtersK, rate, padding="SAME")
  ```

  can be equivalently performed cheaper in terms of computation and memory as:

  ```python
      pad = ...  # padding so that the input dims are multiples of rate
      net = space_to_batch(net, paddings=pad, block_size=rate)
      net = conv2d(net, filters1, strides=[1, 1, 1, 1], padding="SAME")
      net = conv2d(net, filters2, strides=[1, 1, 1, 1], padding="SAME")
      ...
      net = conv2d(net, filtersK, strides=[1, 1, 1, 1], padding="SAME")
      net = batch_to_space(net, crops=pad, block_size=rate)
  ```

  because a pair of consecutive `space_to_batch` and `batch_to_space` ops with
  the same `block_size` cancel out when their respective `paddings` and `crops`
  inputs are identical.

  Args:
    value: A 4-D `Tensor` of type `float`. It needs to be in the default "NHWC"
      format. Its shape is `[batch, in_height, in_width, in_channels]`.
    filters: A 4-D `Tensor` with the same type as `value` and shape
      `[filter_height, filter_width, in_channels, out_channels]`. `filters`'
      `in_channels` dimension must match that of `value`. Atrous convolution is
      equivalent to standard convolution with upsampled filters with effective
      height `filter_height + (filter_height - 1) * (rate - 1)` and effective
      width `filter_width + (filter_width - 1) * (rate - 1)`, produced by
      inserting `rate - 1` zeros along consecutive elements across the
      `filters`' spatial dimensions.
    rate: A positive int32. The stride with which we sample input values across
      the `height` and `width` dimensions. Equivalently, the rate by which we
      upsample the filter values by inserting zeros across the `height` and
      `width` dimensions. In the literature, the same parameter is sometimes
      called `input stride` or `dilation`.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
    name: Optional name for the returned tensor.

  Returns:
    A `Tensor` with the same type as `value`.
    Output shape with `'VALID'` padding is (the input size minus the effective
    filter size plus one, per spatial dimension):

        [batch, height - rate * (filter_height - 1),
         width - rate * (filter_width - 1), out_channels].

    Output shape with `'SAME'` padding is:

        [batch, height, width, out_channels].

  Raises:
    ValueError: If input/output depth does not match `filters`' shape, or if
      padding is other than `'VALID'` or `'SAME'`.

  References:
    Multi-Scale Context Aggregation by Dilated Convolutions:
      [Yu et al., 2016](https://arxiv.org/abs/1511.07122)
      ([pdf](https://arxiv.org/pdf/1511.07122.pdf))
    Semantic Image Segmentation with Deep Convolutional Nets and Fully
    Connected CRFs:
      [Chen et al., 2015](http://arxiv.org/abs/1412.7062)
      ([pdf](https://arxiv.org/pdf/1412.7062))
    OverFeat - Integrated Recognition, Localization and Detection using
    Convolutional Networks:
      [Sermanet et al., 2014](https://arxiv.org/abs/1312.6229)
      ([pdf](https://arxiv.org/pdf/1312.6229.pdf))
    Fast Image Scanning with Deep Max-Pooling Convolutional Neural Networks:
      [Giusti et al., 2013]
      (https://ieeexplore.ieee.org/abstract/document/6738831)
      ([pdf](https://arxiv.org/pdf/1302.1700.pdf))
  """
  # The same scalar `rate` is applied to both the height and width dimensions.
  return convolution(
      input=value,
      filter=filters,
      padding=padding,
      dilation_rate=np.broadcast_to(rate, (2,)),
      name=name)
1752
1753
def convert_padding(padding, expected_length=4):
  """Translates a Python-level `padding` argument into C++ op attributes.

  String paddings such as "SAME" or "VALID" are passed through unchanged. A
  list/tuple of `[before, after]` pairs (one pair per dimension) is flattened
  into a single list, and the padding mode becomes "EXPLICIT".

  Args:
    padding: Either a padding-algorithm string, or a list/tuple whose elements
      are 2-element lists/tuples giving explicit paddings.
    expected_length: Number of pairs required when `padding` is a list/tuple.

  Returns:
    A `(padding, explicit_paddings)` tuple meant to be passed as attributes to
    a C++ op. `explicit_paddings` is an empty list unless `padding` was given
    as a list/tuple.

  Raises:
    ValueError: If `padding` is the string "EXPLICIT", if an element of a
      list/tuple `padding` is not a list/tuple of size 2, or if a list/tuple
      `padding` does not contain exactly `expected_length` elements.
  """
  if padding == "EXPLICIT":
    # Callers must supply the pairs themselves; reject the bare mode string
    # with a clearer message.
    raise ValueError('"EXPLICIT" is not a valid value for the padding '
                     "parameter. To use explicit padding, the padding "
                     "parameter must be a list.")

  # Any non-sequence value (normally an algorithm name) is forwarded as-is.
  if not isinstance(padding, (list, tuple)):
    return padding, []

  flattened = []
  for index, pair in enumerate(padding):
    if not isinstance(pair, (list, tuple)):
      raise ValueError("When padding is a list, each element of padding must "
                       "be a list/tuple of size 2. Element with index %d of "
                       "padding is not a list/tuple" % index)
    if len(pair) != 2:
      raise ValueError("When padding is a list, each element of padding must "
                       "be a list/tuple of size 2. Element with index %d of "
                       "padding has size %d" % (index, len(pair)))
    flattened += pair

  # The length check runs after the per-element checks, so malformed
  # elements are reported first.
  if len(padding) != expected_length:
    raise ValueError("When padding is a list, it must be of size %d. Got "
                     "padding of size: %d" % (expected_length, len(padding)))
  return "EXPLICIT", flattened
1792
1793
@tf_export(v1=["nn.conv1d"])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values(
    None,
    "`NCHW` for data_format is deprecated, use `NCW` instead",
    warn_once=True,
    data_format="NCHW")
@deprecation.deprecated_arg_values(
    None,
    "`NHWC` for data_format is deprecated, use `NWC` instead",
    warn_once=True,
    data_format="NHWC")
def conv1d(
    value=None,
    filters=None,
    stride=None,
    padding=None,
    use_cudnn_on_gpu=None,
    data_format=None,
    name=None,
    input=None,  # pylint: disable=redefined-builtin
    dilations=None):
  r"""Computes a 1-D convolution of input with rank `>=3` and a `3-D` filter.

  Given an input tensor of shape
    `batch_shape + [in_width, in_channels]`
  if `data_format` is `"NWC"`, or
    `batch_shape + [in_channels, in_width]`
  if `data_format` is `"NCW"`,
  and a filter / kernel tensor of shape
  `[filter_width, in_channels, out_channels]`, this op reshapes
  the arguments to pass them to `conv2d` to perform the equivalent
  convolution operation.

  Internally, this op reshapes the input tensors and invokes `tf.nn.conv2d`.
  For example, if `data_format` does not start with "NC", a tensor of shape
    `batch_shape + [in_width, in_channels]`
  is reshaped to
    `batch_shape + [1, in_width, in_channels]`,
  and the filter is reshaped to
    `[1, filter_width, in_channels, out_channels]`.
  The result is then reshaped back to
    `batch_shape + [out_width, out_channels]`
  \(where out_width is a function of the stride and padding as in conv2d\) and
  returned to the caller.

  Args:
    value: A Tensor of rank at least 3. Must be of type `float16`, `float32`, or
      `float64`.
    filters: A 3-D Tensor of shape `[filter_width, in_channels, out_channels]`.
      Must have the same type as `value`.
    stride: An int or list of `ints` that has length `1` or `3`.  The number of
      entries by which the filter is moved right at each step.
    padding: 'SAME' or 'VALID'
    use_cudnn_on_gpu: An optional `bool`.  Defaults to `True`.
    data_format: An optional `string` from `"NWC", "NCW"`.  Defaults to `"NWC"`,
      the data is stored in the order of `batch_shape + [in_width,
      in_channels]`.  The `"NCW"` format stores data as `batch_shape +
      [in_channels, in_width]`. The deprecated aliases `"NHWC"` and `"NCHW"`
      are also accepted.
    name: A name for the operation (optional).
    input: Alias for value.
    dilations: An int or list of `ints` that has length `1` or `3` which
      defaults to 1. The dilation factor for each dimension of input. If set to
      k > 1, there will be k-1 skipped cells between each filter element on that
      dimension. Dilations in the batch and depth dimensions must be 1.

  Returns:
    A `Tensor`.  Has the same type as input.

  Raises:
    ValueError: if `data_format` is invalid.
  """
  value = deprecation.deprecated_argument_lookup("input", input, "value", value)
  with ops.name_scope(name, "conv1d", [value, filters]) as name:
    # Map the 1-D layout onto the equivalent 2-D layout. A dummy spatial axis
    # of size 1 is inserted at `spatial_start_dim` (counted from the end), so
    # NWC input becomes batch_shape + [1, in_width, in_channels] and NCW input
    # becomes batch_shape + [in_channels, 1, in_width].
    if data_format is None or data_format == "NHWC" or data_format == "NWC":
      data_format = "NHWC"
      spatial_start_dim = -3
      channel_index = 2
    elif data_format == "NCHW" or data_format == "NCW":
      data_format = "NCHW"
      spatial_start_dim = -2
      channel_index = 1
    else:
      raise ValueError("data_format must be \"NWC\" or \"NCW\".")
    # Prepend a stride/dilation of 1 for the dummy spatial axis.
    strides = [1] + _get_sequence(stride, 1, channel_index, "stride")
    dilations = [1] + _get_sequence(dilations, 1, channel_index, "dilations")

    value = array_ops.expand_dims(value, spatial_start_dim)
    # The filter becomes [1, filter_width, in_channels, out_channels].
    filters = array_ops.expand_dims(filters, 0)
    if value.shape.ndims in (4, 3, 2, 1, 0, None):
      result = gen_nn_ops.conv2d(
          value,
          filters,
          strides,
          padding,
          use_cudnn_on_gpu=use_cudnn_on_gpu,
          data_format=data_format,
          dilations=dilations,
          name=name)
    else:
      # More than one batch dimension: let squeeze_batch_dims handle the
      # leading dimensions so conv2d sees a 4-D tensor.
      result = squeeze_batch_dims(
          value,
          functools.partial(
              gen_nn_ops.conv2d,
              filter=filters,
              strides=strides,
              padding=padding,
              use_cudnn_on_gpu=use_cudnn_on_gpu,
              data_format=data_format,
              dilations=dilations,
          ),
          inner_rank=3,
          name=name)
    # Remove the dummy spatial axis inserted above.
    return array_ops.squeeze(result, [spatial_start_dim])
1908
1909
@tf_export("nn.conv1d", v1=[])
@dispatch.add_dispatch_support
def conv1d_v2(
    input,  # pylint: disable=redefined-builtin
    filters,
    stride,
    padding,
    data_format="NWC",
    dilations=None,
    name=None):
  r"""Computes a 1-D convolution of input with rank `>=3` and a `3-D` filter.

  Given an input tensor of shape
    `batch_shape + [in_width, in_channels]`
  if `data_format` is `"NWC"`, or
    `batch_shape + [in_channels, in_width]`
  if `data_format` is `"NCW"`,
  and a filter / kernel tensor of shape
  `[filter_width, in_channels, out_channels]`, this op reshapes
  the arguments to pass them to `conv2d` to perform the equivalent
  convolution operation.

  Internally, this op reshapes the input tensors and invokes `tf.nn.conv2d`.
  For example, if `data_format` does not start with `"NC"`, a tensor of shape
    `batch_shape + [in_width, in_channels]`
  is reshaped to
    `batch_shape + [1, in_width, in_channels]`,
  and the filter is reshaped to
    `[1, filter_width, in_channels, out_channels]`.
  The result is then reshaped back to
    `batch_shape + [out_width, out_channels]`
  \(where out_width is a function of the stride and padding as in conv2d\) and
  returned to the caller.

  Args:
    input: A Tensor of rank at least 3. Must be of type `float16`, `float32`, or
      `float64`.
    filters: A 3-D Tensor of shape `[filter_width, in_channels, out_channels]`.
      Must have the same type as `input`.
    stride: An int or list of `ints` that has length `1` or `3`.  The number of
      entries by which the filter is moved right at each step.
    padding: 'SAME' or 'VALID'
    data_format: An optional `string` from `"NWC", "NCW"`.  Defaults to `"NWC"`,
      the data is stored in the order of
      `batch_shape + [in_width, in_channels]`.  The `"NCW"` format stores data
      as `batch_shape + [in_channels, in_width]`.
    dilations: An int or list of `ints` that has length `1` or `3` which
      defaults to 1. The dilation factor for each dimension of input. If set to
      k > 1, there will be k-1 skipped cells between each filter element on that
      dimension. Dilations in the batch and depth dimensions must be 1.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`.  Has the same type as input.

  Raises:
    ValueError: if `data_format` is invalid.
  """
  # The v2 API drops `use_cudnn_on_gpu`; always request cuDNN from the v1 op.
  return conv1d(
      input,  # pylint: disable=redefined-builtin
      filters,
      stride,
      padding,
      use_cudnn_on_gpu=True,
      data_format=data_format,
      name=name,
      dilations=dilations)
1976
1977
@tf_export("nn.conv1d_transpose")
@dispatch.add_dispatch_support
def conv1d_transpose(
    input,  # pylint: disable=redefined-builtin
    filters,
    output_shape,
    strides,
    padding="SAME",
    data_format="NWC",
    dilations=None,
    name=None):
  """The transpose of `conv1d`.

  This operation is sometimes called "deconvolution" after
  (Zeiler et al., 2010), but is actually the transpose (gradient) of `conv1d`
  rather than an actual deconvolution.

  Args:
    input: A 3-D `Tensor` of type `float` and shape
      `[batch, in_width, in_channels]` for `NWC` data format or
      `[batch, in_channels, in_width]` for `NCW` data format.
    filters: A 3-D `Tensor` with the same type as `input` and shape
      `[filter_width, output_channels, in_channels]`.  `filter`'s
      `in_channels` dimension must match that of `input`.
    output_shape: A 1-D `Tensor`, containing three elements, representing the
      output shape of the deconvolution op.
    strides: An int or list of `ints` that has length `1` or `3`.  The number of
      entries by which the filter is moved right at each step.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
      See the "returns" section of `tf.nn.convolution` for details.
    data_format: A string. `'NWC'` and `'NCW'` are supported.
    dilations: An int or list of `ints` that has length `1` or `3` which
      defaults to 1. The dilation factor for each dimension of input. If set to
      k > 1, there will be k-1 skipped cells between each filter element on that
      dimension. Dilations in the batch and depth dimensions must be 1.
    name: Optional name for the returned tensor.

  Returns:
    A `Tensor` with the same type as `input`.

  Raises:
    ValueError: If input/output depth does not match `filter`'s shape, if
      `output_shape` is not a 3-element vector, if `padding` is other than
      `'VALID'` or `'SAME'`, or if `data_format` is invalid.

  References:
    Deconvolutional Networks:
      [Zeiler et al., 2010]
      (https://ieeexplore.ieee.org/abstract/document/5539957)
      ([pdf]
      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
  """
  with ops.name_scope(name, "conv1d_transpose",
                      [input, filters, output_shape]) as name:
    # The format could be either NWC or NCW, map to NHWC or NCHW
    if data_format is None or data_format == "NWC":
      data_format = "NHWC"
      spatial_start_dim = 1
      channel_index = 2
    elif data_format == "NCW":
      data_format = "NCHW"
      spatial_start_dim = 2
      channel_index = 1
    else:
      raise ValueError("data_format must be \"NWC\" or \"NCW\".")

    # Prepend a stride/dilation of 1 for the dummy height axis added below.
    strides = [1] + _get_sequence(strides, 1, channel_index, "stride")
    dilations = [1] + _get_sequence(dilations, 1, channel_index, "dilations")

    # Insert a size-1 height axis at `spatial_start_dim`:
    # NWC -> [batch, 1, in_width, in_channels],
    # NCW -> [batch, in_channels, 1, in_width].
    input = array_ops.expand_dims(input, spatial_start_dim)
    filters = array_ops.expand_dims(filters, 0)
    # Insert the matching size-1 entry into the requested output shape.
    output_shape = list(output_shape) if not isinstance(
        output_shape, ops.Tensor) else output_shape
    output_shape = array_ops.concat([output_shape[: spatial_start_dim], [1],
                                     output_shape[spatial_start_dim:]], 0)

    result = gen_nn_ops.conv2d_backprop_input(
        input_sizes=output_shape,
        filter=filters,
        out_backprop=input,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        name=name)
    # Drop the dummy height axis again.
    return array_ops.squeeze(result, spatial_start_dim)
2065
2066
@tf_export("nn.conv2d", v1=[])
@dispatch.add_dispatch_support
def conv2d_v2(input,  # pylint: disable=redefined-builtin
              filters,
              strides,
              padding,
              data_format="NHWC",
              dilations=None,
              name=None):
  # pylint: disable=line-too-long
  r"""Computes a 2-D convolution given `input` and 4-D `filters` tensors.

  The `input` tensor may have rank `4` or higher, where shape dimensions `[:-3]`
  are considered batch dimensions (`batch_shape`).

  Given an input tensor of shape
  `batch_shape + [in_height, in_width, in_channels]` and a filter / kernel
  tensor of shape `[filter_height, filter_width, in_channels, out_channels]`,
  this op performs the following:

  1. Flattens the filter to a 2-D matrix with shape
     `[filter_height * filter_width * in_channels, output_channels]`.
  2. Extracts image patches from the input tensor to form a *virtual*
     tensor of shape `[batch, out_height, out_width,
     filter_height * filter_width * in_channels]`.
  3. For each patch, right-multiplies the filter matrix and the image patch
     vector.

  In detail, with the default NHWC format,

      output[b, i, j, k] =
          sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] *
                          filter[di, dj, q, k]

  Must have `strides[0] = strides[3] = 1`.  For the most common case of the same
  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.

  Usage Example:

  >>> x_in = np.array([[
  ...   [[2], [1], [2], [0], [1]],
  ...   [[1], [3], [2], [2], [3]],
  ...   [[1], [1], [3], [3], [0]],
  ...   [[2], [2], [0], [1], [1]],
  ...   [[0], [0], [3], [1], [2]], ]])
  >>> kernel_in = np.array([
  ...  [ [[2, 0.1]], [[3, 0.2]] ],
  ...  [ [[0, 0.3]],[[1, 0.4]] ], ])
  >>> x = tf.constant(x_in, dtype=tf.float32)
  >>> kernel = tf.constant(kernel_in, dtype=tf.float32)
  >>> tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding='VALID')
  <tf.Tensor: shape=(1, 4, 4, 2), dtype=float32, numpy=..., dtype=float32)>

  Args:
    input: A `Tensor`. Must be one of the following types:
      `half`, `bfloat16`, `float32`, `float64`.
      A Tensor of rank at least 4. The dimension order is interpreted according
      to the value of `data_format`; with the all-but-inner-3 dimensions acting
      as batch dimensions. See below for details.
    filters: A `Tensor`. Must have the same type as `input`.
      A 4-D tensor of shape
      `[filter_height, filter_width, in_channels, out_channels]`
    strides: An int or list of `ints` that has length `1`, `2` or `4`.  The
      stride of the sliding window for each dimension of `input`. If a single
      value is given it is replicated in the `H` and `W` dimension. By default
      the `N` and `C` dimensions are set to 1. The dimension order is determined
      by the value of `data_format`, see below for details.
    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
      padding algorithm to use, or a list indicating the explicit paddings at
      the start and end of each dimension. When explicit padding is used and
      data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
      pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding used
      and data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
      [pad_top, pad_bottom], [pad_left, pad_right]]`.
    data_format: An optional `string` from: `"NHWC", "NCHW"`.
      Defaults to `"NHWC"`.
      Specify the data format of the input and output data. With the
      default format "NHWC", the data is stored in the order of:
          `batch_shape + [height, width, channels]`.
      Alternatively, the format could be "NCHW", the data storage order of:
          `batch_shape + [channels, height, width]`.
    dilations: An int or list of `ints` that has length `1`, `2` or `4`,
      defaults to 1. The dilation factor for each dimension of `input`. If a
      single value is given it is replicated in the `H` and `W` dimension. By
      default the `N` and `C` dimensions are set to 1. If set to k > 1, there
      will be k-1 skipped cells between each filter element on that dimension.
      The dimension order is determined by the value of `data_format`, see above
      for details. Dilations in the batch and depth dimensions of a 4-D tensor
      must be 1.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input` and the same outer batch shape.
  """
  # pylint: enable=line-too-long
  # The v2 API drops `use_cudnn_on_gpu`; always request cuDNN from the v1 op.
  return conv2d(input,  # pylint: disable=redefined-builtin
                filters,
                strides,
                padding,
                use_cudnn_on_gpu=True,
                data_format=data_format,
                dilations=dilations,
                name=name)
2170
2171
2172@tf_export(v1=["nn.conv2d"])
2173@dispatch.add_dispatch_support
2174def conv2d(  # pylint: disable=redefined-builtin,dangerous-default-value
2175    input,
2176    filter=None,
2177    strides=None,
2178    padding=None,
2179    use_cudnn_on_gpu=True,
2180    data_format="NHWC",
2181    dilations=[1, 1, 1, 1],
2182    name=None,
2183    filters=None):
2184  r"""Computes a 2-D convolution given 4-D `input` and `filter` tensors.
2185
2186  Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
2187  and a filter / kernel tensor of shape
2188  `[filter_height, filter_width, in_channels, out_channels]`, this op
2189  performs the following:
2190
2191  1. Flattens the filter to a 2-D matrix with shape
2192     `[filter_height * filter_width * in_channels, output_channels]`.
2193  2. Extracts image patches from the input tensor to form a *virtual*
2194     tensor of shape `[batch, out_height, out_width,
2195     filter_height * filter_width * in_channels]`.
2196  3. For each patch, right-multiplies the filter matrix and the image patch
2197     vector.
2198
2199  In detail, with the default NHWC format,
2200
2201      output[b, i, j, k] =
2202          sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q]
2203                          * filter[di, dj, q, k]
2204
2205  Must have `strides[0] = strides[3] = 1`.  For the most common case of the same
2206  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
2207
2208  Args:
2209    input: A `Tensor`. Must be one of the following types:
2210      `half`, `bfloat16`, `float32`, `float64`.
2211      A 4-D tensor. The dimension order is interpreted according to the value
2212      of `data_format`, see below for details.
2213    filter: A `Tensor`. Must have the same type as `input`.
2214      A 4-D tensor of shape
2215      `[filter_height, filter_width, in_channels, out_channels]`
2216    strides: An int or list of `ints` that has length `1`, `2` or `4`.  The
2217      stride of the sliding window for each dimension of `input`. If a single
2218      value is given it is replicated in the `H` and `W` dimension. By default
2219      the `N` and `C` dimensions are set to 1. The dimension order is determined
2220      by the value of `data_format`, see below for details.
2221    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
2222      padding algorithm to use, or a list indicating the explicit paddings at
2223      the start and end of each dimension. When explicit padding is used and
2224      data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
2225      pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding used
2226      and data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
2227      [pad_top, pad_bottom], [pad_left, pad_right]]`.
2228    use_cudnn_on_gpu: An optional `bool`. Defaults to `True`.
2229    data_format: An optional `string` from: `"NHWC", "NCHW"`.
2230      Defaults to `"NHWC"`.
2231      Specify the data format of the input and output data. With the
2232      default format "NHWC", the data is stored in the order of:
2233          [batch, height, width, channels].
2234      Alternatively, the format could be "NCHW", the data storage order of:
2235          [batch, channels, height, width].
2236    dilations: An int or list of `ints` that has length `1`, `2` or `4`,
2237      defaults to 1. The dilation factor for each dimension of`input`. If a
2238      single value is given it is replicated in the `H` and `W` dimension. By
2239      default the `N` and `C` dimensions are set to 1. If set to k > 1, there
2240      will be k-1 skipped cells between each filter element on that dimension.
2241      The dimension order is determined by the value of `data_format`, see above
2242      for details. Dilations in the batch and depth dimensions if a 4-d tensor
2243      must be 1.
2244    name: A name for the operation (optional).
2245    filters: Alias for filter.
2246
2247  Returns:
2248    A `Tensor`. Has the same type as `input`.
2249  """
2250  filter = deprecation.deprecated_argument_lookup(
2251      "filters", filters, "filter", filter)
2252  padding, explicit_paddings = convert_padding(padding)
2253  if data_format is None:
2254    data_format = "NHWC"
2255  channel_index = 1 if data_format.startswith("NC") else 3
2256
2257  strides = _get_sequence(strides, 2, channel_index, "strides")
2258  dilations = _get_sequence(dilations, 2, channel_index, "dilations")
2259
2260  shape = input.shape
2261  # shape object may lack ndims, e.g., if input is an np.ndarray.  In that case,
2262  # we fall back to len(shape).
2263  ndims = getattr(shape, "ndims", -1)
2264  if ndims == -1:
2265    ndims = len(shape)
2266  if ndims in (4, 3, 2, 1, 0, None):
2267    # We avoid calling squeeze_batch_dims to reduce extra python function
2268    # call slowdown in eager mode.  This branch doesn't require reshapes.
2269    return gen_nn_ops.conv2d(
2270        input,
2271        filter=filter,
2272        strides=strides,
2273        padding=padding,
2274        use_cudnn_on_gpu=use_cudnn_on_gpu,
2275        explicit_paddings=explicit_paddings,
2276        data_format=data_format,
2277        dilations=dilations,
2278        name=name)
2279  return squeeze_batch_dims(
2280      input,
2281      functools.partial(
2282          gen_nn_ops.conv2d,
2283          filter=filter,
2284          strides=strides,
2285          padding=padding,
2286          use_cudnn_on_gpu=use_cudnn_on_gpu,
2287          explicit_paddings=explicit_paddings,
2288          data_format=data_format,
2289          dilations=dilations),
2290      inner_rank=3,
2291      name=name)
2292
2293
@tf_export(v1=["nn.conv2d_backprop_filter"])
@dispatch.add_dispatch_support
def conv2d_backprop_filter(  # pylint: disable=redefined-builtin,dangerous-default-value
    input,
    filter_sizes,
    out_backprop,
    strides,
    padding,
    use_cudnn_on_gpu=True,
    data_format="NHWC",
    dilations=[1, 1, 1, 1],
    name=None):
  r"""Computes the gradients of convolution with respect to the filter.

  `padding` is first normalized with `convert_padding` into a padding
  attribute plus a list of explicit paddings; both are then handed to
  `gen_nn_ops.conv2d_backprop_filter` together with the other arguments.

  Args:
    input: A `Tensor`. Must be one of the following types:
      `half`, `bfloat16`, `float32`, `float64`.
      4-D with shape `[batch, in_height, in_width, in_channels]`.
    filter_sizes: A `Tensor` of type `int32`. An integer vector holding the
      shape of `filter`, which is a 4-D
      `[filter_height, filter_width, in_channels, out_channels]` tensor.
    out_backprop: A `Tensor`. Must have the same type as `input`.
      4-D with shape `[batch, out_height, out_width, out_channels]`.
      Gradients w.r.t. the output of the convolution.
    strides: A list of `ints`. The stride of the sliding window for each
      dimension of the convolution's input, in the dimension order given by
      `data_format`.
    padding: Either the `string` `"SAME"` or `"VALID"` naming the padding
      algorithm, or a list of explicit paddings for the start and end of
      each dimension. With explicit padding and `"NHWC"` data, use
      `[[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]`;
      with `"NCHW"` data, use
      `[[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]`.
    use_cudnn_on_gpu: An optional `bool`. Defaults to `True`.
    data_format: An optional `string` from: `"NHWC", "NCHW"`.
      Defaults to `"NHWC"`. Layout of the input and output data:
      `"NHWC"` stores `[batch, in_height, in_width, in_channels]`, while
      `"NCHW"` stores `[batch, in_channels, in_height, in_width]`.
    dilations: An optional list of `ints`. Defaults to `[1, 1, 1, 1]`.
      1-D tensor of length 4 giving the dilation factor for each dimension
      of `input`. A factor k > 1 leaves k-1 skipped cells between filter
      elements on that dimension. Ordered according to `data_format`;
      the batch and depth entries must be 1.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  # Split `padding` into the op's padding attribute and its explicit
  # per-dimension padding list.
  padding_attr, explicit_pads = convert_padding(padding)
  return gen_nn_ops.conv2d_backprop_filter(
      input=input,
      filter_sizes=filter_sizes,
      out_backprop=out_backprop,
      strides=strides,
      padding=padding_attr,
      use_cudnn_on_gpu=use_cudnn_on_gpu,
      explicit_paddings=explicit_pads,
      data_format=data_format,
      dilations=dilations,
      name=name)
2353
2354
@tf_export(v1=["nn.conv2d_backprop_input"])
@dispatch.add_dispatch_support
def conv2d_backprop_input(  # pylint: disable=redefined-builtin,dangerous-default-value
    input_sizes,
    filter=None,
    out_backprop=None,
    strides=None,
    padding=None,
    use_cudnn_on_gpu=True,
    data_format="NHWC",
    dilations=[1, 1, 1, 1],
    name=None,
    filters=None):
  r"""Computes the gradients of convolution with respect to the input.

  The kernel may be supplied either as `filter` or via its alias `filters`.
  The two are reconciled with `deprecation.deprecated_argument_lookup`.
  `padding` is normalized with `convert_padding` before
  `gen_nn_ops.conv2d_backprop_input` is called.

  Args:
    input_sizes: A `Tensor` of type `int32`. An integer vector holding the
      shape of `input`, a 4-D `[batch, height, width, channels]` tensor.
    filter: A `Tensor`. Must be one of the following types:
      `half`, `bfloat16`, `float32`, `float64`. 4-D with shape
      `[filter_height, filter_width, in_channels, out_channels]`.
    out_backprop: A `Tensor`. Must have the same type as `filter`.
      4-D with shape `[batch, out_height, out_width, out_channels]`.
      Gradients w.r.t. the output of the convolution.
    strides: A list of `ints`. The stride of the sliding window for each
      dimension of the convolution's input, in the dimension order given by
      `data_format`.
    padding: Either the `string` `"SAME"` or `"VALID"` naming the padding
      algorithm, or a list of explicit paddings for the start and end of
      each dimension. With explicit padding and `"NHWC"` data, use
      `[[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]`;
      with `"NCHW"` data, use
      `[[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]`.
    use_cudnn_on_gpu: An optional `bool`. Defaults to `True`.
    data_format: An optional `string` from: `"NHWC", "NCHW"`.
      Defaults to `"NHWC"`. Layout of the input and output data:
      `"NHWC"` stores `[batch, in_height, in_width, in_channels]`, while
      `"NCHW"` stores `[batch, in_channels, in_height, in_width]`.
    dilations: An optional list of `ints`. Defaults to `[1, 1, 1, 1]`.
      1-D tensor of length 4 giving the dilation factor for each dimension
      of `input`. A factor k > 1 leaves k-1 skipped cells between filter
      elements on that dimension. Ordered according to `data_format`;
      the batch and depth entries must be 1.
    name: A name for the operation (optional).
    filters: Alias for filter.

  Returns:
    A `Tensor`. Has the same type as `filter`.
  """
  kernel = deprecation.deprecated_argument_lookup(
      "filters", filters, "filter", filter)
  # Split `padding` into the op's padding attribute and its explicit
  # per-dimension padding list.
  padding_attr, explicit_pads = convert_padding(padding)
  return gen_nn_ops.conv2d_backprop_input(
      input_sizes=input_sizes,
      filter=kernel,
      out_backprop=out_backprop,
      strides=strides,
      padding=padding_attr,
      use_cudnn_on_gpu=use_cudnn_on_gpu,
      explicit_paddings=explicit_pads,
      data_format=data_format,
      dilations=dilations,
      name=name)
2418
2419
@tf_export(v1=["nn.conv2d_transpose"])
@dispatch.add_dispatch_support
def conv2d_transpose(
    value=None,
    filter=None,  # pylint: disable=redefined-builtin
    output_shape=None,
    strides=None,
    padding="SAME",
    data_format="NHWC",
    name=None,
    input=None,  # pylint: disable=redefined-builtin
    filters=None,
    dilations=None):
  """The transpose of `conv2d`.

  This operation is sometimes called "deconvolution" after
  (Zeiler et al., 2010), but is really the transpose (gradient) of `conv2d`
  rather than an actual deconvolution.

  This v1 endpoint resolves the `input`/`value` and `filters`/`filter`
  aliases, then delegates to `conv2d_transpose_v2` inside a
  `conv2d_transpose` name scope.

  Args:
    value: A 4-D `Tensor` of type `float` and shape
      `[batch, height, width, in_channels]` for `NHWC` data format or
      `[batch, in_channels, height, width]` for `NCHW` data format.
    filter: A 4-D `Tensor` with the same type as `value` and shape
      `[height, width, output_channels, in_channels]`.  `filter`'s
      `in_channels` dimension must match that of `value`.
    output_shape: A 1-D `Tensor` representing the output shape of the
      deconvolution op.
    strides: An int or list of `ints` that has length `1`, `2` or `4`.  The
      stride of the sliding window for each dimension of `input`. If a single
      value is given it is replicated in the `H` and `W` dimension. By default
      the `N` and `C` dimensions are set to 1. The dimension order is determined
      by the value of `data_format`, see below for details.
    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
      padding algorithm to use, or a list of explicit paddings, in the form
      accepted by `tf.nn.conv2d_transpose` (this argument is forwarded to it
      unchanged). See the "returns" section of `tf.nn.convolution` for details
      of the string padding algorithms.
    data_format: A string. 'NHWC' and 'NCHW' are supported.
    name: Optional name for the returned tensor.
    input: Alias for value.
    filters: Alias for filter.
    dilations: An int or list of `ints` that has length `1`, `2` or `4`,
      defaults to 1. The dilation factor for each dimension of`input`. If a
      single value is given it is replicated in the `H` and `W` dimension. By
      default the `N` and `C` dimensions are set to 1. If set to k > 1, there
      will be k-1 skipped cells between each filter element on that dimension.
      The dimension order is determined by the value of `data_format`, see above
      for details. Dilations in the batch and depth dimensions if a 4-d tensor
      must be 1.

  Returns:
    A `Tensor` with the same type as `value`.

  Raises:
    ValueError: If input/output depth does not match `filter`'s shape, or if
      `padding` is invalid.

  References:
    Deconvolutional Networks:
      [Zeiler et al., 2010]
      (https://ieeexplore.ieee.org/abstract/document/5539957)
      ([pdf]
      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
  """
  # Resolve the deprecated aliases the same way the other v1 wrappers in this
  # module do.
  value = deprecation.deprecated_argument_lookup(
      "input", input, "value", value)
  filter = deprecation.deprecated_argument_lookup(
      "filters", filters, "filter", filter)
  with ops.name_scope(name, "conv2d_transpose",
                      [value, filter, output_shape]) as name:
    return conv2d_transpose_v2(
        value,
        filter,
        output_shape,
        strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        name=name)
2495
2496
@tf_export("nn.conv2d_transpose", v1=[])
@dispatch.add_dispatch_support
def conv2d_transpose_v2(
    input,  # pylint: disable=redefined-builtin
    filters,  # pylint: disable=redefined-builtin
    output_shape,
    strides,
    padding="SAME",
    data_format="NHWC",
    dilations=None,
    name=None):
  """The transpose of `conv2d`.

  This operation is sometimes called "deconvolution" after
  (Zeiler et al., 2010), but is really the transpose (gradient) of
  `conv2d` rather than an actual deconvolution.

  Implemented as `gen_nn_ops.conv2d_backprop_input`, with `input` acting as
  the incoming gradient (`out_backprop`) and `output_shape` as the size of
  the tensor being reconstructed (`input_sizes`).

  Args:
    input: A 4-D `Tensor` of type `float` and shape `[batch, height, width,
      in_channels]` for `NHWC` data format or `[batch, in_channels, height,
      width]` for `NCHW` data format.
    filters: A 4-D `Tensor` with the same type as `input` and shape `[height,
      width, output_channels, in_channels]`.  `filters`' `in_channels`
      dimension must match that of `input`.
    output_shape: A 1-D `Tensor` representing the output shape of the
      deconvolution op.
    strides: An int or list of `ints` that has length `1`, `2` or `4`.  The
      stride of the sliding window for each dimension of `input`. If a single
      value is given it is replicated in the `H` and `W` dimension. By default
      the `N` and `C` dimensions are set to 1. The dimension order is determined
      by the value of `data_format`, see below for details.
    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
      padding algorithm to use, or a list indicating the explicit paddings at
      the start and end of each dimension. When explicit padding is used and
      data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
      pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding used
      and data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
      [pad_top, pad_bottom], [pad_left, pad_right]]`.
    data_format: A string. 'NHWC' and 'NCHW' are supported.
    dilations: An int or list of `ints` that has length `1`, `2` or `4`,
      defaults to 1. The dilation factor for each dimension of`input`. If a
      single value is given it is replicated in the `H` and `W` dimension. By
      default the `N` and `C` dimensions are set to 1. If set to k > 1, there
      will be k-1 skipped cells between each filter element on that dimension.
      The dimension order is determined by the value of `data_format`, see above
      for details. Dilations in the batch and depth dimensions if a 4-d tensor
      must be 1.
    name: Optional name for the returned tensor.

  Returns:
    A `Tensor` with the same type as `input`.

  Raises:
    ValueError: If input/output depth does not match `filters`' shape, or if
      `padding` is invalid.

  References:
    Deconvolutional Networks:
      [Zeiler et al., 2010]
      (https://ieeexplore.ieee.org/abstract/document/5539957)
      ([pdf]
      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
  """
  # Pass the actual `filters` tensor to the name scope. The previous code
  # referenced `filter`, which here is the Python builtin, not the argument.
  with ops.name_scope(name, "conv2d_transpose",
                      [input, filters, output_shape]) as name:
    if data_format is None:
      data_format = "NHWC"
    channel_index = 1 if data_format.startswith("NC") else 3

    # Expand scalar / 2-element strides and dilations to full 4-element
    # lists laid out according to `data_format`.
    strides = _get_sequence(strides, 2, channel_index, "strides")
    dilations = _get_sequence(dilations, 2, channel_index, "dilations")
    padding, explicit_paddings = convert_padding(padding)

    return gen_nn_ops.conv2d_backprop_input(
        input_sizes=output_shape,
        filter=filters,
        out_backprop=input,
        strides=strides,
        padding=padding,
        explicit_paddings=explicit_paddings,
        data_format=data_format,
        dilations=dilations,
        name=name)
2580
2581
def _conv2d_expanded_batch(
    input,  # pylint: disable=redefined-builtin
    filters,
    strides,
    padding,
    data_format,
    dilations,
    name):
  """Runs `gen_nn_ops.conv2d`, folding extra leading batch dims if present.

  Used by `convolution_internal`. Inputs of unknown rank or rank below 5 are
  sent straight to the raw op. Inputs of rank 5 or more are routed through
  `squeeze_batch_dims`, with the convolution applied to the innermost 3
  dimensions.
  """
  rank = input.shape.rank
  if rank is not None and rank >= 5:
    conv_over_inner_dims = functools.partial(
        gen_nn_ops.conv2d,
        filter=filters,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations)
    return squeeze_batch_dims(
        input, conv_over_inner_dims, inner_rank=3, name=name)
  # Common case: call the op directly. This avoids squeeze_batch_dims and
  # the partial, which keeps the eager-mode path free of extra Python calls
  # and leaves legacy name scopes untouched.
  return gen_nn_ops.conv2d(
      input,
      filter=filters,
      strides=strides,
      padding=padding,
      data_format=data_format,
      dilations=dilations,
      name=name)
2615
2616
@tf_export("nn.atrous_conv2d_transpose")
@dispatch.add_dispatch_support
def atrous_conv2d_transpose(value,
                            filters,
                            output_shape,
                            rate,
                            padding,
                            name=None):
  """The transpose of `atrous_conv2d`.

  This operation is sometimes called "deconvolution" after
  (Zeiler et al., 2010), but is really the transpose (gradient) of
  `atrous_conv2d` rather than an actual deconvolution.

  For `rate == 1` this defers to `conv2d_transpose`. Otherwise the value is
  rearranged with `space_to_batch` (block size `rate`), run through a
  "VALID" `conv2d_backprop_input`, and rearranged back with `batch_to_space`.
  The crops in the last step remove both the "SAME" padding and the extra
  padding added to make spatial sizes divisible by `rate`.

  Args:
    value: A 4-D `Tensor` of type `float`. It needs to be in the default `NHWC`
      format. Its shape is `[batch, in_height, in_width, in_channels]`.
    filters: A 4-D `Tensor` with the same type as `value` and shape
      `[filter_height, filter_width, out_channels, in_channels]`. `filters`'
      `in_channels` dimension must match that of `value`. Atrous convolution is
      equivalent to standard convolution with upsampled filters with effective
      height `filter_height + (filter_height - 1) * (rate - 1)` and effective
      width `filter_width + (filter_width - 1) * (rate - 1)`, produced by
      inserting `rate - 1` zeros along consecutive elements across the
      `filters`' spatial dimensions.
    output_shape: A 1-D `Tensor` of shape representing the output shape of the
      deconvolution op.
    rate: A positive int32. The stride with which we sample input values across
      the `height` and `width` dimensions. Equivalently, the rate by which we
      upsample the filter values by inserting zeros across the `height` and
      `width` dimensions. In the literature, the same parameter is sometimes
      called `input stride` or `dilation`.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
    name: Optional name for the returned tensor.

  Returns:
    A `Tensor` with the same type as `value`.

  Raises:
    ValueError: If input/output depth does not match `filters`' shape, or if
      padding is other than `'VALID'` or `'SAME'`, or if the `rate` is less
      than one, or if the output_shape is not a tensor with 4 elements.

  References:
    Deconvolutional Networks:
      [Zeiler et al., 2010]
      (https://ieeexplore.ieee.org/abstract/document/5539957)
      ([pdf]
      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
  """
  with ops.name_scope(name, "atrous_conv2d_transpose",
                      [value, filters, output_shape]) as name:
    value = ops.convert_to_tensor(value, name="value")
    filters = ops.convert_to_tensor(filters, name="filters")
    # `dimension_at_index` yields an unknown Dimension when the rank is
    # unknown. Indexing `.dims` directly would raise TypeError on None.
    if not tensor_shape.dimension_at_index(
        value.get_shape(), 3).is_compatible_with(filters.get_shape()[3]):
      raise ValueError(
          "value's input channels does not match filters' input channels, "
          "{} != {}".format(value.get_shape()[3],
                            filters.get_shape()[3]))
    if rate < 1:
      raise ValueError("rate {} cannot be less than one".format(rate))

    # No dilation: this is an ordinary stride-1 transposed convolution.
    if rate == 1:
      return conv2d_transpose(
          value,
          filters,
          output_shape,
          strides=[1, 1, 1, 1],
          padding=padding,
          data_format="NHWC")

    output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape")
    if not output_shape_.get_shape().is_compatible_with(
        tensor_shape.TensorShape([4])):
      raise ValueError("output_shape must have shape (4,), got {}".format(
          output_shape_.get_shape()))

    if isinstance(output_shape, tuple):
      output_shape = list(output_shape)

    if isinstance(output_shape, (list, np.ndarray)):
      # output_shape's shape should be == [4] if reached this point.
      if not tensor_shape.dimension_at_index(
          filters.get_shape(), 2).is_compatible_with(output_shape[3]):
        raise ValueError(
            "output_shape does not match filter's output channels, "
            "{} != {}".format(output_shape[3],
                              filters.get_shape()[2]))

    # We have two padding contributions. The first is used for converting "SAME"
    # to "VALID". The second is required so that the height and width of the
    # zero-padded value tensor are multiples of rate.

    # Padding required to reduce to "VALID" convolution
    if padding == "SAME":
      # Handle filters whose shape is unknown during graph creation.
      if filters.get_shape().is_fully_defined():
        filter_shape = filters.get_shape().as_list()
      else:
        filter_shape = array_ops.shape(filters)
      filter_height, filter_width = filter_shape[0], filter_shape[1]

      # Spatial dimensions of the filters and the upsampled filters in which we
      # introduce (rate - 1) zeros between consecutive filter values.
      filter_height_up = filter_height + (filter_height - 1) * (rate - 1)
      filter_width_up = filter_width + (filter_width - 1) * (rate - 1)

      pad_height = filter_height_up - 1
      pad_width = filter_width_up - 1

      # When pad_height (pad_width) is odd, we pad more to bottom (right),
      # following the same convention as conv2d().
      pad_top = pad_height // 2
      pad_bottom = pad_height - pad_top
      pad_left = pad_width // 2
      pad_right = pad_width - pad_left
    elif padding == "VALID":
      pad_top = 0
      pad_bottom = 0
      pad_left = 0
      pad_right = 0
    else:
      raise ValueError("padding must be either VALID or SAME:"
                       " {}".format(padding))

    in_height = output_shape[1] + pad_top + pad_bottom
    in_width = output_shape[2] + pad_left + pad_right

    # More padding so that rate divides the height and width of the input.
    pad_bottom_extra = (rate - in_height % rate) % rate
    pad_right_extra = (rate - in_width % rate) % rate

    # The paddings argument to space_to_batch is just the extra padding
    # component.
    space_to_batch_pad = [[0, pad_bottom_extra], [0, pad_right_extra]]

    value = array_ops.space_to_batch(
        input=value, paddings=space_to_batch_pad, block_size=rate)

    # After space_to_batch the batch grows by rate * rate and each padded
    # spatial extent shrinks by a factor of rate.
    input_sizes = [
        rate * rate * output_shape[0], (in_height + pad_bottom_extra) // rate,
        (in_width + pad_right_extra) // rate, output_shape[3]
    ]

    value = gen_nn_ops.conv2d_backprop_input(
        input_sizes=input_sizes,
        filter=filters,
        out_backprop=value,
        strides=[1, 1, 1, 1],
        padding="VALID",
        data_format="NHWC")

    # The crops argument to batch_to_space includes both padding components.
    batch_to_space_crop = [[pad_top, pad_bottom + pad_bottom_extra],
                           [pad_left, pad_right + pad_right_extra]]

    return array_ops.batch_to_space(
        input=value, crops=batch_to_space_crop, block_size=rate)
2774
2775
@tf_export(v1=["nn.depthwise_conv2d_native"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("nn.depthwise_conv2d_native")
def depthwise_conv2d_native(  # pylint: disable=redefined-builtin,dangerous-default-value
    input,
    filter,
    strides,
    padding,
    data_format="NHWC",
    dilations=[1, 1, 1, 1],
    name=None):
  r"""Computes a 2-D depthwise convolution.

  Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
  and a filter / kernel tensor of shape
  `[filter_height, filter_width, in_channels, channel_multiplier]`, this
  applies a separate depth-1 filter to every input channel. Each channel is
  expanded into `channel_multiplier` channels, and the results are
  concatenated, so the output has `in_channels * channel_multiplier`
  channels.

  ```
  for k in 0..in_channels-1
    for q in 0..channel_multiplier-1
      output[b, i, j, k * channel_multiplier + q] =
        sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] *
                          filter[di, dj, k, q]
  ```

  `strides[0]` and `strides[3]` must both be 1. When the horizontal and
  vertical strides are equal, use `strides = [1, stride, stride, 1]`.

  Args:
    input: A `Tensor`. Must be one of the following types: `half`, `bfloat16`,
      `float32`, `float64`.
    filter: A `Tensor`. Must have the same type as `input`.
    strides: A list of `ints`. 1-D of length 4. The stride of the sliding
      window for each dimension of `input`.
    padding: How to pad the image before convolving. Either the string
      `"SAME"` or `"VALID"` naming the padding algorithm, or a list of
      explicit paddings for the start and end of each dimension. With
      explicit padding and `"NHWC"` data, use
      `[[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]`;
      with `"NCHW"` data, use
      `[[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]`.
    data_format: An optional `string` from: `"NHWC", "NCHW"`. Defaults to
      `"NHWC"`. Layout of the input and output data: `"NHWC"` stores
      `[batch, height, width, channels]`, while `"NCHW"` stores
      `[batch, channels, height, width]`.
    dilations: An optional list of `ints`. Defaults to `[1, 1, 1, 1]`. 1-D
      tensor of length 4 giving the dilation factor for each dimension of
      `input`. A factor k > 1 leaves k-1 skipped cells between filter
      elements on that dimension. Ordered according to `data_format`; the
      batch and depth entries must be 1.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  # Split `padding` into the op's padding attribute and its explicit
  # per-dimension padding list.
  padding_attr, explicit_pads = convert_padding(padding)
  return gen_nn_ops.depthwise_conv2d_native(
      input=input,
      filter=filter,
      strides=strides,
      padding=padding_attr,
      explicit_paddings=explicit_pads,
      data_format=data_format,
      dilations=dilations,
      name=name)
2849
2850
@tf_export(
    "nn.depthwise_conv2d_backprop_input",
    v1=[
        "nn.depthwise_conv2d_native_backprop_input",
        "nn.depthwise_conv2d_backprop_input"
    ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("nn.depthwise_conv2d_native_backprop_input")
def depthwise_conv2d_native_backprop_input(  # pylint: disable=redefined-builtin,dangerous-default-value
    input_sizes,
    filter,
    out_backprop,
    strides,
    padding,
    data_format="NHWC",
    dilations=[1, 1, 1, 1],
    name=None):
  r"""Computes the gradients of depthwise convolution with respect to the input.

  `padding` is normalized with `convert_padding` before
  `gen_nn_ops.depthwise_conv2d_native_backprop_input` is called.

  Args:
    input_sizes: A `Tensor` of type `int32`. An integer vector holding the
      shape of `input`, laid out per `data_format`. For example, with
      `'NHWC'` the `input` is a 4-D `[batch, height, width, channels]`
      tensor.
    filter: A `Tensor`. Must be one of the following types: `half`, `bfloat16`,
      `float32`, `float64`. 4-D with shape `[filter_height, filter_width,
      in_channels, depthwise_multiplier]`.
    out_backprop: A `Tensor`. Must have the same type as `filter`. 4-D with a
      shape that follows `data_format`. For example, with `'NHWC'` its shape
      is `[batch, out_height, out_width, out_channels]`. Gradients w.r.t. the
      output of the convolution.
    strides: A list of `ints`. The stride of the sliding window for each
      dimension of the convolution's input.
    padding: How to pad the image before convolving. Either the string
      `"SAME"` or `"VALID"` naming the padding algorithm, or a list of
      explicit paddings for the start and end of each dimension. With
      explicit padding and `"NHWC"` data, use
      `[[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]`;
      with `"NCHW"` data, use
      `[[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]`.
    data_format: An optional `string` from: `"NHWC", "NCHW"`. Defaults to
      `"NHWC"`. Layout of the input and output data: `"NHWC"` stores
      `[batch, height, width, channels]`, while `"NCHW"` stores
      `[batch, channels, height, width]`.
    dilations: An optional list of `ints`. Defaults to `[1, 1, 1, 1]`. 1-D
      tensor of length 4 giving the dilation factor for each dimension of
      `input`. A factor k > 1 leaves k-1 skipped cells between filter
      elements on that dimension. Ordered according to `data_format`; the
      batch and depth entries must be 1.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `filter`.
  """
  # Split `padding` into the op's padding attribute and its explicit
  # per-dimension padding list.
  padding_attr, explicit_pads = convert_padding(padding)
  return gen_nn_ops.depthwise_conv2d_native_backprop_input(
      input_sizes=input_sizes,
      filter=filter,
      out_backprop=out_backprop,
      strides=strides,
      padding=padding_attr,
      explicit_paddings=explicit_pads,
      data_format=data_format,
      dilations=dilations,
      name=name)
2919
2920
@tf_export(
    "nn.depthwise_conv2d_backprop_filter",
    v1=[
        "nn.depthwise_conv2d_native_backprop_filter",
        "nn.depthwise_conv2d_backprop_filter"
    ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("nn.depthwise_conv2d_native_backprop_filter")
def depthwise_conv2d_native_backprop_filter(  # pylint: disable=redefined-builtin,dangerous-default-value
    input,
    filter_sizes,
    out_backprop,
    strides,
    padding,
    data_format="NHWC",
    dilations=[1, 1, 1, 1],
    name=None):
  r"""Computes the gradients of depthwise convolution with respect to the filter.

  Args:
    input: A `Tensor`. Must be one of the following types: `half`, `bfloat16`,
      `float32`, `float64`. 4-D with shape based on `data_format`. For example,
      if `data_format` is 'NHWC' then `input` is a 4-D `[batch, in_height,
      in_width, in_channels]` tensor.
    filter_sizes: A `Tensor` of type `int32`. An integer vector representing the
      tensor shape of `filter`, where `filter` is a 4-D `[filter_height,
      filter_width, in_channels, depthwise_multiplier]` tensor.
    out_backprop: A `Tensor`. Must have the same type as `input`. 4-D with shape
      based on `data_format`. For example, if `data_format` is 'NHWC' then
      out_backprop shape is `[batch, out_height, out_width, out_channels]`.
      Gradients w.r.t. the output of the convolution.
    strides: A list of `ints`. The stride of the sliding window for each
      dimension of the input of the convolution.
    padding: Controls how to pad the image before applying the convolution. Can
      be the string `"SAME"` or `"VALID"` indicating the type of padding
      algorithm to use, or a list indicating the explicit paddings at the start
      and end of each dimension. When explicit padding is used and data_format
      is `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
      [pad_left, pad_right], [0, 0]]`. When explicit padding is used and
      data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
      [pad_top, pad_bottom], [pad_left, pad_right]]`.
    data_format: An optional `string` from: `"NHWC", "NCHW"`. Defaults to
      `"NHWC"`. Specifies the layout of the input and output data. With the
      default format "NHWC", the data is stored in the order
      `[batch, height, width, channels]`. With "NCHW", the data is stored in
      the order `[batch, channels, height, width]`.
    dilations: An optional list of `ints`. Defaults to `[1, 1, 1, 1]`. 1-D
      tensor of length 4. The dilation factor for each dimension of `input`. If
      set to k > 1, there will be k-1 skipped cells between each filter element
      on that dimension. The dimension order is determined by the value of
      `data_format`, see above for details. Dilations in the batch and depth
      dimensions must be 1.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  # `convert_padding` yields the value for the op's `padding` attribute plus
  # the `explicit_paddings` list (relevant when `padding` is given as a list).
  padding_attr, explicit_paddings = convert_padding(padding)
  return gen_nn_ops.depthwise_conv2d_native_backprop_filter(
      input=input,
      filter_sizes=filter_sizes,
      out_backprop=out_backprop,
      strides=strides,
      padding=padding_attr,
      explicit_paddings=explicit_paddings,
      data_format=data_format,
      dilations=dilations,
      name=name)
2990
2991
def _conv3d_expanded_batch(
    input,  # pylint: disable=redefined-builtin
    filter,  # pylint: disable=redefined-builtin
    strides,
    padding,
    data_format,
    dilations=None,
    name=None):
  """Helper for `conv3d` that also accepts inputs with extra batch dims.

  Inputs whose rank is 5 or less, or whose rank is unknown, are sent straight
  to `gen_nn_ops.conv3d`. Higher-rank inputs are handed to
  `squeeze_batch_dims` with `inner_rank=4` and a `conv3d` partial that has
  every argument except the input already bound.

  Args:
    input: The input; must expose a `.shape` attribute (a `TensorShape` or,
      e.g. for an `np.ndarray`, a plain tuple).
    filter: The convolution kernel, forwarded unchanged.
    strides: Forwarded unchanged to `conv3d`.
    padding: Forwarded unchanged to `conv3d`.
    data_format: Forwarded unchanged to `conv3d`.
    dilations: Forwarded unchanged to `conv3d`.
    name: Optional name for the resulting operation.

  Returns:
    The result of the 3-D convolution.
  """
  input_shape = input.shape
  # A plain tuple shape (e.g. from an np.ndarray) has no `ndims` attribute;
  # in that case fall back to its length.
  rank = getattr(input_shape, "ndims", -1)
  if rank == -1:
    rank = len(input_shape)

  if rank is not None and rank not in range(6):
    # Extra leading batch dimensions: let `squeeze_batch_dims` handle them.
    conv_fn = functools.partial(
        gen_nn_ops.conv3d,
        filter=filter,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations)
    return squeeze_batch_dims(input, conv_fn, inner_rank=4, name=name)

  # Fast path: calling the op directly avoids the extra Python call overhead
  # of `squeeze_batch_dims` in eager mode and needs no reshapes.
  return gen_nn_ops.conv3d(
      input,
      filter,
      strides,
      padding,
      data_format=data_format,
      dilations=dilations,
      name=name)
3030
3031
@tf_export("nn.conv3d", v1=[])
@dispatch.add_dispatch_support
def conv3d_v2(input,  # pylint: disable=redefined-builtin,missing-docstring
              filters,
              strides,
              padding,
              data_format="NDHWC",
              dilations=None,
              name=None):
  # The public docstring is assigned after the definition, from
  # `gen_nn_ops.conv3d` with `filter` renamed to `filters`.
  effective_dilations = [1, 1, 1, 1, 1] if dilations is None else dilations
  return _conv3d_expanded_batch(
      input, filters, strides, padding, data_format,
      dilations=effective_dilations, name=name)
3045
3046
@tf_export(v1=["nn.conv3d"])
@dispatch.add_dispatch_support
def conv3d_v1(  # pylint: disable=missing-docstring,dangerous-default-value
    input,  # pylint: disable=redefined-builtin
    filter=None,  # pylint: disable=redefined-builtin
    strides=None,
    padding=None,
    data_format="NDHWC",
    dilations=[1, 1, 1, 1, 1],
    name=None,
    filters=None):
  # `filters` is the newer spelling of `filter`; use whichever the caller
  # supplied (resolved by `deprecated_argument_lookup`).
  kernel = deprecated_argument_lookup("filters", filters, "filter", filter)
  return gen_nn_ops.conv3d(
      input,
      kernel,
      strides,
      padding,
      data_format=data_format,
      dilations=dilations,
      name=name)
3061
3062
# Both conv3d endpoints reuse the generated op's documentation; the v2 text
# has its `filter` argument renamed to `filters` to match its signature.
conv3d_v1.__doc__ = gen_nn_ops.conv3d.__doc__
conv3d_v2.__doc__ = deprecation.rewrite_argument_docstring(
    gen_nn_ops.conv3d.__doc__, "filter", "filters")
3066
3067
@tf_export(v1=["nn.conv3d_transpose"])
@dispatch.add_dispatch_support
def conv3d_transpose(
    value,
    filter=None,  # pylint: disable=redefined-builtin
    output_shape=None,
    strides=None,
    padding="SAME",
    data_format="NDHWC",
    name=None,
    input=None,  # pylint: disable=redefined-builtin
    filters=None,
    dilations=None):
  """The transpose of `conv3d`.

  This operation is sometimes called "deconvolution" after
  (Zeiler et al., 2010), but is really the transpose (gradient) of `conv3d`
  rather than an actual deconvolution.

  Args:
    value: A 5-D `Tensor` of type `float` and shape
      `[batch, depth, height, width, in_channels]`.
    filter: A 5-D `Tensor` with the same type as `value` and shape
      `[depth, height, width, output_channels, in_channels]`. `filter`'s
      `in_channels` dimension must match that of `value`.
    output_shape: A 1-D `Tensor` representing the output shape of the
      deconvolution op.
    strides: A list of ints. The stride of the sliding window for each
      dimension of the input tensor.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
      See the "returns" section of `tf.nn.convolution` for details.
    data_format: A string, either `'NDHWC'` or `'NCDHW'`, specifying the
      layout of the input and output tensors. Defaults to `'NDHWC'`.
    name: Optional name for the returned tensor.
    input: Alias of value.
    filters: Alias of filter.
    dilations: An int or list of `ints` that has length `1`, `3` or `5`,
      defaults to 1. The dilation factor for each dimension of `input`. If a
      single value is given it is replicated in the `D`, `H` and `W` dimension.
      By default the `N` and `C` dimensions are set to 1. If set to k > 1, there
      will be k-1 skipped cells between each filter element on that dimension.
      The dimension order is determined by the value of `data_format`, see above
      for details. Dilations in the batch and depth dimensions if a 5-d tensor
      must be 1.

  Returns:
    A `Tensor` with the same type as `value`.

  Raises:
    ValueError: If input/output depth does not match `filter`'s shape, or if
      padding is other than `'VALID'` or `'SAME'`.

  References:
    Deconvolutional Networks:
      [Zeiler et al., 2010]
      (https://ieeexplore.ieee.org/abstract/document/5539957)
      ([pdf]
      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
  """
  # Resolve the deprecated aliases (filter first, then value, as before) and
  # delegate to the v2 implementation.
  kernel = deprecated_argument_lookup("filters", filters, "filter", filter)
  tensor_in = deprecated_argument_lookup("input", input, "value", value)
  return conv3d_transpose_v2(
      tensor_in, kernel, output_shape, strides,
      padding=padding, data_format=data_format, dilations=dilations,
      name=name)
3138
3139
@tf_export("nn.conv3d_transpose", v1=[])
@dispatch.add_dispatch_support
def conv3d_transpose_v2(input,  # pylint: disable=redefined-builtin
                        filters,
                        output_shape,
                        strides,
                        padding="SAME",
                        data_format="NDHWC",
                        dilations=None,
                        name=None):
  """The transpose of `conv3d`.

  This operation is sometimes called "deconvolution" after
  (Zeiler et al., 2010), but is really the transpose (gradient) of `conv3d`
  rather than an actual deconvolution.

  Args:
    input: A 5-D `Tensor` of type `float` and shape `[batch, depth, height,
      width, in_channels]` for `NDHWC` data format or `[batch, in_channels,
      depth, height, width]` for `NCDHW` data format.
    filters: A 5-D `Tensor` with the same type as `input` and shape `[depth,
      height, width, output_channels, in_channels]`. `filters`'s `in_channels`
      dimension must match that of `input`.
    output_shape: A 1-D `Tensor` representing the output shape of the
      deconvolution op.
    strides: An int or list of `ints` that has length `1`, `3` or `5`. The
      stride of the sliding window for each dimension of `input`. If a single
      value is given it is replicated in the `D`, `H` and `W` dimension. By
      default the `N` and `C` dimensions are set to 1. The dimension order is
      determined by the value of `data_format`, see below for details.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
      the "returns" section of `tf.nn.convolution` for details.
    data_format: A string. 'NDHWC' and 'NCDHW' are supported. `None` is
      treated as 'NDHWC'.
    dilations: An int or list of `ints` that has length `1`, `3` or `5`,
      defaults to 1. The dilation factor for each dimension of `input`. If a
      single value is given it is replicated in the `D`, `H` and `W` dimension.
      By default the `N` and `C` dimensions are set to 1. If set to k > 1, there
      will be k-1 skipped cells between each filter element on that dimension.
      The dimension order is determined by the value of `data_format`, see above
      for details. Dilations in the batch and depth dimensions if a 5-d tensor
      must be 1.
    name: Optional name for the returned tensor.

  Returns:
    A `Tensor` with the same type as `input`.

  References:
    Deconvolutional Networks:
      [Zeiler et al., 2010]
      (https://ieeexplore.ieee.org/abstract/document/5539957)
      ([pdf]
      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
  """
  # The values list names the op's actual inputs: `filters`, not the Python
  # builtin `filter` (which the previous code passed by mistake).
  with ops.name_scope(name, "conv3d_transpose",
                      [input, filters, output_shape]) as name:
    if data_format is None:
      data_format = "NDHWC"
    channel_index = 1 if data_format.startswith("NC") else 4

    # Expand scalar / spatial-only strides and dilations to all 5 dimensions,
    # using the same rule for both.
    strides = _get_sequence(strides, 3, channel_index, "strides")
    dilations = _get_sequence(dilations, 3, channel_index, "dilations")

    # conv3d's transpose is its gradient w.r.t. its input: `input` plays the
    # role of the incoming gradient and `output_shape` the shape being
    # reconstructed.
    return gen_nn_ops.conv3d_backprop_input_v2(
        input_sizes=output_shape,
        filter=filters,
        out_backprop=input,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        name=name)
3211
3212
# Transposed-convolution implementations ordered by spatial rank (1-D, 2-D,
# 3-D); `conv_transpose` selects one as `CONV_TRANSPOSE_OPS[n - 1]`.
CONV_TRANSPOSE_OPS = (conv1d_transpose, conv2d_transpose_v2,
                      conv3d_transpose_v2)
3218
3219
@tf_export("nn.conv_transpose")
@dispatch.add_dispatch_support
def conv_transpose(input,  # pylint: disable=redefined-builtin
                   filters,
                   output_shape,
                   strides,
                   padding="SAME",
                   data_format=None,
                   dilations=None,
                   name=None):
  """The transpose of `convolution`.

  This operation is sometimes called "deconvolution" after
  (Zeiler et al., 2010), but is really the transpose (gradient) of
  `convolution` rather than an actual deconvolution.

  The spatial rank N (1, 2 or 3) is inferred from the length of
  `output_shape` (N+2), and the call is forwarded to the matching 1-D, 2-D or
  3-D transposed convolution.

  Args:
    input: An N+2 dimensional `Tensor` of shape
      `[batch_size] + input_spatial_shape + [in_channels]` if data_format does
      not start with "NC" (default), or
      `[batch_size, in_channels] + input_spatial_shape` if data_format starts
      with "NC". It must be one of the following types:
      `half`, `bfloat16`, `float32`, `float64`.
    filters: An N+2 dimensional `Tensor` with the same type as `input` and
      shape `spatial_filter_shape + [out_channels, in_channels]`, where
      `in_channels` must match the channel dimension of `input`.
    output_shape: A 1-D `Tensor` (with statically known length) or a sized
      collection representing the output shape of the deconvolution op. Its
      length must be 3, 4 or 5.
    strides: An int or list of `ints` that has length `1`, `N` or `N+2`. The
      stride of the sliding window for each dimension of `input`. If a single
      value is given it is replicated in the spatial dimensions. By default
      the `N` and `C` dimensions are set to 1. The dimension order is determined
      by the value of `data_format`, see below for details.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
      the "returns" section of `tf.nn.convolution` for details.
    data_format: A string or None. Specifies whether the channel dimension of
      the `input` and output is the last dimension (default, or if `data_format`
      does not start with "NC"), or the second dimension (if `data_format`
      starts with "NC"). For N=1, the valid values are "NWC" (default) and
      "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW".
      For N=3, the valid values are "NDHWC" (default) and "NCDHW".
    dilations: An int or list of `ints` that has length `1`, `N` or `N+2`,
      defaults to 1. The dilation factor for each dimension of `input`. If a
      single value is given it is replicated in the spatial dimensions. By
      default the `N` and `C` dimensions are set to 1. If set to k > 1, there
      will be k-1 skipped cells between each filter element on that dimension.
      The dimension order is determined by the value of `data_format`, see above
      for details.
    name: A name for the operation (optional). If not specified "conv_transpose"
      is used.

  Returns:
    A `Tensor` with the same type as `input`.

  Raises:
    ValueError: If `output_shape` is neither a tensor nor a sized collection,
      is a tensor whose length is not statically known, or does not have
      length 3, 4 or 5.

  References:
    Deconvolutional Networks:
      [Zeiler et al., 2010]
      (https://ieeexplore.ieee.org/abstract/document/5539957)
      ([pdf]
      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
  """
  # The values list names the op's actual inputs: `filters`, not the Python
  # builtin `filter` (which the previous code passed by mistake).
  with ops.name_scope(name, "conv_transpose",
                      [input, filters, output_shape]) as name:
    if tensor_util.is_tf_type(output_shape):
      # The length of the shape vector selects the implementation, so it
      # must be known statically. `dimension_value` yields an int, or None
      # when the length is unknown (instead of failing on `None - 2`).
      shape_length = tensor_shape.dimension_value(output_shape.shape[0])
      if shape_length is None:
        raise ValueError(
            "output_shape must have a statically known length, but got a "
            "tensor of shape {}.".format(output_shape.shape))
      n = shape_length - 2
    elif isinstance(output_shape, collections_abc.Sized):
      n = len(output_shape) - 2
    else:
      raise ValueError("output_shape must be a tensor or sized collection.")

    if not 1 <= n <= 3:
      raise ValueError(
          "output_shape must be of length 3, 4 or 5 but was {}.".format(n + 2))

    op = CONV_TRANSPOSE_OPS[n-1]
    return op(
        input,
        filters,
        output_shape,
        strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        name=name)
3303
3304
def _tf_deterministic_ops():
  """Returns whether the `TF_DETERMINISTIC_OPS` env var requests determinism.

  The variable is read on the first call only. Its value counts as enabled
  when, compared case-insensitively, it equals "true" or "1". An unset
  variable counts as disabled. The boolean result is cached on the function
  attribute `value`. Later changes to the environment are not observed unless
  `value` is reset to None.

  Returns:
    A Python bool.
  """
  cached = _tf_deterministic_ops.value
  if cached is None:
    raw_setting = os.environ.get("TF_DETERMINISTIC_OPS")
    cached = raw_setting is not None and raw_setting.lower() in ("true", "1")
    _tf_deterministic_ops.value = cached
  return cached


# Cache slot for `_tf_deterministic_ops`; None means "not read yet".
_tf_deterministic_ops.value = None
3316
3317
@tf_export("nn.bias_add")
@dispatch.add_dispatch_support
def bias_add(value, bias, data_format=None, name=None):
  """Adds `bias` to `value`.

  This is (mostly) a special case of `tf.add` where `bias` is restricted to 1-D.
  Broadcasting is supported, so `value` may have any number of dimensions.
  Unlike `tf.add`, the type of `bias` is allowed to differ from `value` in the
  case where both types are quantized.

  Args:
    value: A `Tensor` with type `float`, `double`, `int64`, `int32`, `uint8`,
      `int16`, `int8`, `complex64`, or `complex128`.
    bias: A 1-D `Tensor` with size matching the channel dimension of `value`.
      Must be the same type as `value` unless `value` is a quantized type,
      in which case a different quantized type may be used.
    data_format: A string. 'N...C' and 'NC...' are supported. If `None` (the
      default) is specified then 'N...C' is assumed.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as `value`.

  Raises:
    ValueError: If the data format is unrecognized, if `value` has fewer than
      two dimensions when `data_format` is 'N...C'/`None`, or fewer than three
      dimensions when `data_format` is 'NC...', if `bias` does not have
      exactly one dimension (is a vector), or if the size of `bias` does not
      match the size of the channel dimension of `value`.
  """
  with ops.name_scope(name, "BiasAdd", [value, bias]) as name:
    # Map every "NC..." spelling to "NCHW" and every "N...C" spelling to
    # "NHWC"; None is passed through unchanged.
    if data_format is not None:
      channels_first = data_format.startswith("NC")
      channels_last = data_format.startswith("N") and data_format.endswith("C")
      if not (channels_first or channels_last):
        raise ValueError("data_format must be of the form `N...C` or `NC...`")
      data_format = "NCHW" if channels_first else "NHWC"

    if not context.executing_eagerly():
      value = ops.convert_to_tensor(value, name="input")
      bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias")

    # TODO(duncanriach): Implement deterministic functionality at CUDA kernel
    #   level.
    if not _tf_deterministic_ops():
      return gen_nn_ops.bias_add(
          value, bias, data_format=data_format, name=name)

    # Deterministic fallback built from ordinary ops. Note that this code does
    # not implement the same error checks as the pre-existing C++ ops.
    if data_format != "NCHW":
      # "NHWC" or None: `bias` broadcasts along the last axis as-is.
      return math_ops.add(value, bias, name=name)

    # "NCHW": reshape `bias` to [1, C, 1, ..., 1] so it broadcasts along axis 1.
    leading_dims = [1, array_ops.size(bias)]
    trailing_dims = array_ops.ones(
        array_ops.rank(value) - 2, dtype=dtypes.int32)
    bias_shape = array_ops.concat([leading_dims, trailing_dims], 0)
    return math_ops.add(
        value, array_ops.reshape(bias, bias_shape), name=name)
3379
3380
def bias_add_v1(value, bias, name=None):
  """Adds `bias` to `value`.

  This is a deprecated version of bias_add and will soon be removed.

  This is (mostly) a special case of `tf.add` where `bias` is restricted to 1-D.
  Broadcasting is supported, so `value` may have any number of dimensions.
  Unlike `tf.add`, the type of `bias` is allowed to differ from `value` in the
  case where both types are quantized.

  Args:
    value: A `Tensor` with type `float`, `double`, `int64`, `int32`, `uint8`,
      `int16`, `int8`, `complex64`, or `complex128`.
    bias: A 1-D `Tensor` with size matching the last dimension of `value`.
      Must be the same type as `value` unless `value` is a quantized type,
      in which case a different quantized type may be used.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as `value`.
  """
  with ops.name_scope(name, "BiasAddV1", [value, bias]) as scope_name:
    value_tensor = ops.convert_to_tensor(value, name="input")
    # `bias` is converted to `value`'s dtype.
    bias_tensor = ops.convert_to_tensor(
        bias, dtype=value_tensor.dtype, name="bias")
    return gen_nn_ops.bias_add_v1(value_tensor, bias_tensor, name=scope_name)
3406
3407
@tf_export(v1=["nn.crelu"])
@dispatch.add_dispatch_support
def crelu(features, name=None, axis=-1):
  """Computes Concatenated ReLU.

  Concatenates a ReLU which selects only the positive part of the activation
  with a ReLU which selects only the *negative* part of the activation.
  Note that as a result this non-linearity doubles the depth of the activations.
  Source: [Understanding and Improving Convolutional Neural Networks via
  Concatenated Rectified Linear Units. W. Shang, et
  al.](https://arxiv.org/abs/1603.05201)

  Args:
    features: A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`,
      `int16`, or `int8`.
    name: A name for the operation (optional).
    axis: The axis that the output values are concatenated along. Default is -1.

  Returns:
    A `Tensor` with the same type as `features`.

  References:
    Understanding and Improving Convolutional Neural Networks via Concatenated
    Rectified Linear Units:
      [Shang et al., 2016](http://proceedings.mlr.press/v48/shang16)
      ([pdf](http://proceedings.mlr.press/v48/shang16.pdf))
  """
  with ops.name_scope(name, "CRelu", [features]) as scope_name:
    positive_part = ops.convert_to_tensor(features, name="features")
    negated = -positive_part
    # relu([x, -x]) keeps max(x, 0) and max(-x, 0) side by side along `axis`.
    stacked = array_ops.concat([positive_part, negated], axis, name=scope_name)
    return gen_nn_ops.relu(stacked)
3439
3440
@tf_export("nn.crelu", v1=[])
@dispatch.add_dispatch_support
def crelu_v2(features, axis=-1, name=None):
  # TF2 endpoint: identical to `crelu`, but with `axis` ahead of `name`.
  return crelu(features, axis=axis, name=name)


crelu_v2.__doc__ = crelu.__doc__
3446
3447
@tf_export("nn.relu6")
@dispatch.add_dispatch_support
def relu6(features, name=None):
  """Computes Rectified Linear 6: `min(max(features, 0), 6)`.

  Args:
    features: A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`,
      `int16`, or `int8`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as `features`.

  References:
    Convolutional Deep Belief Networks on CIFAR-10:
      Krizhevsky et al., 2010
      ([pdf](http://www.cs.utoronto.ca/~kriz/conv-cifar10-aug2010.pdf))
  """
  with ops.name_scope(name, "Relu6", [features]) as scope_name:
    features_tensor = ops.convert_to_tensor(features, name="features")
    return gen_nn_ops.relu6(features_tensor, name=scope_name)
3469
3470
@tf_export("nn.leaky_relu")
@dispatch.add_dispatch_support
def leaky_relu(features, alpha=0.2, name=None):
  """Compute the Leaky ReLU activation function.

  Source: [Rectifier Nonlinearities Improve Neural Network Acoustic Models.
  AL Maas, AY Hannun, AY Ng - Proc. ICML, 2013]
  (https://ai.stanford.edu/~amaas/papers/relu_hybrid_icml2013_final.pdf).

  Args:
    features: A `Tensor` representing preactivation values. Must be one of
      the following types: `float16`, `float32`, `float64`, `int32`, `int64`.
      Integer inputs are cast to `float32` before the op is applied.
    alpha: Slope of the activation function at x < 0. A NumPy array is
      converted to a Python scalar via `.item()`.
    name: A name for the operation (optional).

  Returns:
    The activation value.

  References:
    Rectifier Nonlinearities Improve Neural Network Acoustic Models:
      [Maas et al., 2013]
      (http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.693.1422)
      ([pdf]
      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.693.1422&rep=rep1&type=pdf))
  """
  with ops.name_scope(name, "LeakyRelu", [features, alpha]) as scope_name:
    # The op takes `alpha` as a scalar attribute, so unwrap NumPy arrays.
    slope = alpha.item() if isinstance(alpha, np.ndarray) else alpha
    inputs = ops.convert_to_tensor(features, name="features")
    if inputs.dtype.is_integer:
      inputs = math_ops.cast(inputs, dtypes.float32)
    return gen_nn_ops.leaky_relu(inputs, alpha=slope, name=scope_name)
3502
3503
@tf_export("nn.gelu", v1=[])
@dispatch.add_dispatch_support
def gelu(features, approximate=False, name=None):
  """Compute the Gaussian Error Linear Unit (GELU) activation function.

  Gaussian error linear unit (GELU) computes
  `x * P(X <= x)`, where `P(X) ~ N(0, 1)`.
  The (GELU) nonlinearity weights inputs by their value, rather than gates
  inputs by their sign as in ReLU.

  The exact form is `0.5 * x * (1 + erf(x / sqrt(2)))`. With
  `approximate=True` the tanh approximation
  `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))` is used instead.

  For example:

  >>> x = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0], dtype=tf.float32)
  >>> y = tf.nn.gelu(x)
  >>> y.numpy()
  array([-0.00404951, -0.15865529,  0.        ,  0.8413447 ,  2.9959507 ],
      dtype=float32)
  >>> y = tf.nn.gelu(x, approximate=True)
  >>> y.numpy()
  array([-0.00363752, -0.15880796,  0.        ,  0.841192  ,  2.9963627 ],
      dtype=float32)

  Args:
    features: A `Tensor` representing preactivation values.
    approximate: An optional `bool`. Defaults to `False`. Whether to enable
      approximation.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as `features`.

  References:
    [Gaussian Error Linear Units (GELUs)](https://arxiv.org/abs/1606.08415).
  """
  with ops.name_scope(name, "Gelu", [features]):
    x = ops.convert_to_tensor(features, name="features")
    half_x = 0.5 * x
    if approximate:
      cubic_coeff = math_ops.cast(0.044715, x.dtype)
      inner = x + cubic_coeff * math_ops.pow(x, 3)
      # 0.7978845608028654 == sqrt(2 / pi).
      return half_x * (1.0 + math_ops.tanh(0.7978845608028654 * inner))
    # 1.4142135623730951 == sqrt(2).
    sqrt_two = math_ops.cast(1.4142135623730951, x.dtype)
    return half_x * (1.0 + math_ops.erf(x / sqrt_two))
3548
3549
def _flatten_outer_dims(logits):
  """Reshapes `logits` to 2-D, merging every dimension except the last.

  The result has dynamic shape `[-1, last_dim]`. Outside eager mode, when the
  static shape of `logits` has known rank and all outer dimensions are known,
  the static shape `[product_of_outer_dims, last_dim]` is also set on the
  result.

  Args:
    logits: A `Tensor`.

  Returns:
    The reshaped 2-D `Tensor`.
  """
  rank = array_ops.rank(logits)
  dynamic_shape = array_ops.shape(logits)
  last_dim_size = array_ops.slice(
      dynamic_shape, [math_ops.subtract(rank, 1)], [1])
  flat_shape = array_ops.concat([[-1], last_dim_size], 0)
  output = array_ops.reshape(logits, flat_shape)

  # Restore static shape information where it can be computed.
  if context.executing_eagerly():
    return output
  static_shape = logits.get_shape()
  if static_shape is None or static_shape.dims is None:
    return output
  dims = static_shape.as_list()
  outer_dims = dims[:-1]
  if None in outer_dims:
    return output
  flat_size = 1
  for size in outer_dims:
    flat_size *= size
  output.set_shape([flat_size, dims[-1]])
  return output
3575
3576
def _wrap_2d_function(inputs, compute_op, dim=-1, name=None):
  """Applies `compute_op` along an arbitrary axis by moving it to the end.

  It transposes `inputs` so that axis `dim` becomes the last axis, invokes
  the given function, then transposes the output back. If the given function
  returns a tuple of tensors, each of them is transposed back. When `dim`
  already names the last axis, `compute_op` is called directly.

  Args:
    inputs: A non-empty `Tensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    compute_op: The function to wrap. Must accept the input tensor as its first
      argument, and a second keyword argument `name`.
    dim: The dimension `compute_op` should be applied along. The default is -1
      which indicates the last dimension. May be a Python int or a `Tensor`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same shape as inputs. If compute_op returns multiple
      tensors, each of them have the same shape as the input.
  Raises:
    InvalidArgumentError: if `dim` is statically known and beyond the last
      dimension of `inputs`. (Emptiness of `inputs` is not checked here;
      presumably `compute_op` reports it — confirm.)
  """

  def _swap_axis(input_tensor, dim_index, last_index, name=None):
    """Transposes `input_tensor` so axes `dim_index` and `last_index` swap.

    The permutation built here has `last_index + 1` entries, so `last_index`
    must be the final axis of `input_tensor` and `dim_index < last_index`.
    """
    return array_ops.transpose(
        input_tensor,
        array_ops.concat([
            math_ops.range(dim_index), [last_index],
            math_ops.range(dim_index + 1, last_index), [dim_index]
        ], 0),
        name=name)

  inputs = ops.convert_to_tensor(inputs)

  # We need its original shape for shape inference.
  shape = inputs.get_shape()
  # NOTE(review): if the static rank is unknown (shape.ndims is None) and dim
  # is not -1, `shape.ndims - 1` below raises TypeError — confirm whether
  # unknown-rank inputs need to be supported.
  is_last_dim = (dim == -1) or (dim == shape.ndims - 1)

  # Fast path: no transposes needed; `name` goes straight to compute_op.
  if is_last_dim:
    return compute_op(inputs, name=name)

  # Validate `dim` against the static rank when its value is known.
  dim_val = dim
  if isinstance(dim, ops.Tensor):
    dim_val = tensor_util.constant_value(dim)
  if dim_val is not None and not -shape.ndims <= dim_val < shape.ndims:
    raise errors_impl.InvalidArgumentError(
        None, None,
        "Dimension (%d) must be in the range [%d, %d) where %d is the number of"
        " dimensions in the input." % (dim_val, -shape.ndims, shape.ndims,
                                       shape.ndims))

  # If dim is not the last dimension, we have to do a transpose so that we can
  # still perform the op on its last dimension.

  # In case dim is negative (and is not last dimension -1), add shape.ndims
  # (note: for a Python int `dim`, adding the rank tensor turns `dim` into a
  # Tensor).
  ndims = array_ops.rank(inputs)
  if not isinstance(dim, ops.Tensor):
    if dim < 0:
      dim += ndims
  else:
    dim = array_ops.where(math_ops.less(dim, 0), dim + ndims, dim)

  # Swap logits' dimension of dim and its last dimension.
  input_rank = array_ops.rank(inputs)
  # Axis index in [0, ndims); the modulo also normalizes any negative value.
  dim_axis = dim % shape.ndims
  inputs = _swap_axis(inputs, dim_axis, math_ops.subtract(input_rank, 1))

  # Do the actual call on its last dimension.
  def fix_output(output):
    """Swaps the axes back and restores the original static shape."""
    # `name` is applied to the final transpose rather than to compute_op.
    output = _swap_axis(
        output, dim_axis, math_ops.subtract(input_rank, 1), name=name)

    # Make shape inference work since transpose may erase its static shape.
    output.set_shape(shape)
    return output

  outputs = compute_op(inputs)
  if isinstance(outputs, tuple):
    return tuple(fix_output(output) for output in outputs)
  else:
    return fix_output(outputs)
3661
3662
@tf_export(v1=["nn.softmax", "math.softmax"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def softmax(logits, axis=None, name=None, dim=None):
  """Computes softmax activations.

  This function performs the equivalent of

      softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis)

  See: https://en.wikipedia.org/wiki/Softmax_function

  Example usage:

  >>> tf.nn.softmax([-1, 0., 1.])
  <tf.Tensor: shape=(3,), dtype=float32,
  numpy=array([0.09003057, 0.24472848, 0.66524094], dtype=float32)>

  Args:
    logits: A non-empty `Tensor`, or an object whose type has a registered
      `Tensor` conversion function. Must be one of the following types:
      `half`,`float32`, `float64`. See also `convert_to_tensor`
    axis: The dimension softmax would be performed on. The default is -1 which
      indicates the last dimension.
    name: A name for the operation (optional).
    dim: Deprecated alias for `axis`.

  Returns:
    A `Tensor`. Has the same type and shape as `logits`.

  Raises:
    InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
      dimension of `logits`.
    TypeError: If no conversion function is registered for `logits` to
      Tensor.
    RuntimeError: If a registered conversion function returns an invalid
      value.

  """
  # Accept the deprecated `dim` spelling, then default to the last axis.
  resolved_axis = deprecation.deprecated_argument_lookup(
      "axis", axis, "dim", dim)
  return _wrap_2d_function(
      logits,
      gen_nn_ops.softmax,
      -1 if resolved_axis is None else resolved_axis,
      name)
3706
3707
@tf_export("nn.softmax", "math.softmax", v1=[])
@dispatch.add_dispatch_support
def softmax_v2(logits, axis=None, name=None):
  """Computes softmax activations.

  This function performs the equivalent of

      softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis)

  Args:
    logits: A non-empty `Tensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    axis: The dimension softmax is computed over. When omitted (or `None`),
      -1 is used, i.e. the last dimension.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type and shape as `logits`.

  Raises:
    InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
      dimension of `logits`.
  """
  return _wrap_2d_function(
      logits, gen_nn_ops.softmax, -1 if axis is None else axis, name)
3734
3735
@tf_export(v1=["nn.log_softmax", "math.log_softmax"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def log_softmax(logits, axis=None, name=None, dim=None):
  """Computes log softmax activations.

  Along the reduction axis this computes

      logsoftmax = logits - log(reduce_sum(exp(logits), axis))

  Args:
    logits: A non-empty `Tensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    axis: The dimension log-softmax is computed over. When omitted (or
      `None`), -1 is used, i.e. the last dimension.
    name: A name for the operation (optional).
    dim: Deprecated alias for `axis`.

  Returns:
    A `Tensor` with the same type and shape as `logits`.

  Raises:
    InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
      dimension of `logits`.
  """
  # Merge the deprecated `dim` spelling into `axis`.
  reduce_axis = deprecation.deprecated_argument_lookup(
      "axis", axis, "dim", dim)
  reduce_axis = -1 if reduce_axis is None else reduce_axis
  return _wrap_2d_function(logits, gen_nn_ops.log_softmax, reduce_axis, name)
3765
3766
@tf_export("nn.log_softmax", "math.log_softmax", v1=[])
@dispatch.add_dispatch_support
def log_softmax_v2(logits, axis=None, name=None):
  """Computes log softmax activations.

  Along the reduction axis this computes

      logsoftmax = logits - log(reduce_sum(exp(logits), axis))

  Args:
    logits: A non-empty `Tensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    axis: The dimension log-softmax is computed over. When omitted (or
      `None`), -1 is used, i.e. the last dimension.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type and shape as `logits`.

  Raises:
    InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
      dimension of `logits`.
  """
  return _wrap_2d_function(
      logits, gen_nn_ops.log_softmax, -1 if axis is None else axis, name)
3793
3794
3795def _ensure_xent_args(name, sentinel, labels, logits):
3796  # Make sure that all arguments were passed as named arguments.
3797  if sentinel is not None:
3798    raise ValueError("Only call `%s` with "
3799                     "named arguments (labels=..., logits=..., ...)" % name)
3800  if labels is None or logits is None:
3801    raise ValueError("Both labels and logits must be provided.")
3802
3803
@tf_export("nn.softmax_cross_entropy_with_logits", v1=[])
@dispatch.add_dispatch_support
def softmax_cross_entropy_with_logits_v2(labels, logits, axis=-1, name=None):
  """Computes softmax cross entropy between `logits` and `labels`.

  Measures the probability error in discrete classification tasks in which the
  classes are mutually exclusive (each entry is in exactly one class).  For
  example, each CIFAR-10 image is labeled with one and only one label: an image
  can be a dog or a truck, but not both.

  **NOTE:**  While the classes are mutually exclusive, their probabilities
  need not be.  All that is required is that each row of `labels` is
  a valid probability distribution.  If they are not, the computation of the
  gradient will be incorrect.

  If using exclusive `labels` (wherein one and only
  one class is true at a time), see `sparse_softmax_cross_entropy_with_logits`.

  Usage:

  >>> logits = [[4.0, 2.0, 1.0], [0.0, 5.0, 1.0]]
  >>> labels = [[1.0, 0.0, 0.0], [0.0, 0.8, 0.2]]
  >>> tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
  <tf.Tensor: shape=(2,), dtype=float32,
  numpy=array([0.16984604, 0.82474494], dtype=float32)>

  **WARNING:** This op expects unscaled logits, since it performs a `softmax`
  on `logits` internally for efficiency.  Do not call this op with the
  output of `softmax`, as it will produce incorrect results.

  A common use case is to have logits and labels of shape
  `[batch_size, num_classes]`, but higher dimensions are supported, with
  the `axis` argument specifying the class dimension.

  `logits` and `labels` must have the same dtype (either `float16`, `float32`,
  or `float64`).

  Backpropagation will happen into both `logits` and `labels`.  To disallow
  backpropagation into `labels`, pass label tensors through `tf.stop_gradient`
  before feeding it to this function.

  Args:
    labels: Each vector along the class dimension should hold a valid
      probability distribution e.g. for the case in which labels are of shape
      `[batch_size, num_classes]`, each row of `labels[i]` must be a valid
      probability distribution.
    logits: Per-label activations, typically a linear output. These activation
      energies are interpreted as unnormalized log probabilities.
    axis: The class dimension. Defaults to -1, the last dimension.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` that contains the softmax cross entropy loss. Its type is the
    same as `logits` and its shape is the shape of `labels` with the class
    dimension `axis` removed.
  """
  helper_kwargs = dict(labels=labels, logits=logits, axis=axis, name=name)
  return softmax_cross_entropy_with_logits_v2_helper(**helper_kwargs)
3865
3866
@tf_export(v1=["nn.softmax_cross_entropy_with_logits_v2"])
@dispatch.add_dispatch_support
@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def softmax_cross_entropy_with_logits_v2_helper(
    labels, logits, axis=None, name=None, dim=None):
  """Computes softmax cross entropy between `logits` and `labels`.

  Measures the probability error in discrete classification tasks in which the
  classes are mutually exclusive (each entry is in exactly one class).  For
  example, each CIFAR-10 image is labeled with one and only one label: an image
  can be a dog or a truck, but not both.

  **NOTE:**  While the classes are mutually exclusive, their probabilities
  need not be.  All that is required is that each row of `labels` is
  a valid probability distribution.  If they are not, the computation of the
  gradient will be incorrect.

  If using exclusive `labels` (wherein one and only
  one class is true at a time), see `sparse_softmax_cross_entropy_with_logits`.

  **WARNING:** This op expects unscaled logits, since it performs a `softmax`
  on `logits` internally for efficiency.  Do not call this op with the
  output of `softmax`, as it will produce incorrect results.

  A common use case is to have logits and labels of shape
  `[batch_size, num_classes]`, but higher dimensions are supported, with
  the `axis` argument specifying the class dimension.

  `logits` and `labels` must have the same dtype (either `float16`, `float32`,
  or `float64`).

  Backpropagation will happen into both `logits` and `labels`.  To disallow
  backpropagation into `labels`, pass label tensors through `tf.stop_gradient`
  before feeding it to this function.

  **Note that to avoid confusion, it is required to pass only named arguments to
  this function.**

  Args:
    labels: Each vector along the class dimension should hold a valid
      probability distribution e.g. for the case in which labels are of shape
      `[batch_size, num_classes]`, each row of `labels[i]` must be a valid
      probability distribution.
    logits: Unscaled log probabilities.
    axis: The class dimension. Defaults to -1, the last dimension. Negative
      values count from the end.
    name: A name for the operation (optional).
    dim: Deprecated alias for axis.

  Returns:
    A `Tensor` that contains the softmax cross entropy loss. Its type is the
    same as `logits` and its shape is the shape of `logits` with the class
    dimension `axis` removed.
  """
  # TODO(pcmurray) Raise an error when the labels do not sum to 1. Note: This
  # could break users who call this with bad labels, but disregard the bad
  # results.
  axis = deprecated_argument_lookup("axis", axis, "dim", dim)
  del dim
  if axis is None:
    axis = -1

  with ops.name_scope(name, "softmax_cross_entropy_with_logits",
                      [logits, labels]) as name:
    logits = ops.convert_to_tensor(logits, name="logits")
    labels = ops.convert_to_tensor(labels, name="labels")
    # Half-precision inputs are computed in float32 and cast back at the end.
    convert_to_float32 = (
        logits.dtype == dtypes.float16 or logits.dtype == dtypes.bfloat16)
    precise_logits = math_ops.cast(
        logits, dtypes.float32) if convert_to_float32 else logits
    # labels and logits must be of the same type
    labels = math_ops.cast(labels, precise_logits.dtype)
    input_rank = array_ops.rank(precise_logits)
    # For shape inference.
    shape = logits.get_shape()

    # Move the dim to the end if dim is not the last dimension.
    if axis != -1:

      def _move_dim_to_end(tensor, dim_index, rank):
        # `dim_index` must be non-negative: the permutation is built from
        # `range(dim_index)` and `range(dim_index + 1, rank)`.
        return array_ops.transpose(
            tensor,
            array_ops.concat([
                math_ops.range(dim_index),
                math_ops.range(dim_index + 1, rank), [dim_index]
            ], 0))

      # A negative axis such as -2 would otherwise produce a permutation of
      # the wrong length, so map it to its non-negative equivalent. Use the
      # static rank when known, and fall back to the dynamic rank otherwise.
      dim_index = axis
      if isinstance(axis, numbers.Integral) and axis < 0:
        if shape.ndims is not None:
          dim_index = axis + shape.ndims
        else:
          dim_index = axis + input_rank

      precise_logits = _move_dim_to_end(precise_logits, dim_index, input_rank)
      labels = _move_dim_to_end(labels, dim_index, input_rank)

    input_shape = array_ops.shape(precise_logits)

    # Make precise_logits and labels into matrices.
    precise_logits = _flatten_outer_dims(precise_logits)
    labels = _flatten_outer_dims(labels)

    # Do the actual op computation.
    # The second output tensor contains the gradients.  We use it in
    # CrossEntropyGrad() in nn_grad but not here.
    cost, unused_backprop = gen_nn_ops.softmax_cross_entropy_with_logits(
        precise_logits, labels, name=name)

    # The output cost shape should be the input minus axis.
    output_shape = array_ops.slice(input_shape, [0],
                                   [math_ops.subtract(input_rank, 1)])
    cost = array_ops.reshape(cost, output_shape)

    # Make shape inference work since reshape and transpose may erase its static
    # shape.
    if not context.executing_eagerly(
    ) and shape is not None and shape.dims is not None:
      shape = shape.as_list()
      # Python list deletion accepts negative indices directly.
      del shape[axis]
      cost.set_shape(shape)

    if convert_to_float32:
      return math_ops.cast(cost, logits.dtype)
    else:
      return cost
3985
3986
# Deprecation instructions passed to `@deprecation.deprecated` on the v1
# `softmax_cross_entropy_with_logits` wrapper, which stops gradients from
# flowing into `labels`.
_XENT_DEPRECATION = """
Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See `tf.nn.softmax_cross_entropy_with_logits_v2`.
"""
3993
3994
@tf_export(v1=["nn.softmax_cross_entropy_with_logits"])
@dispatch.add_dispatch_support
@deprecation.deprecated(date=None, instructions=_XENT_DEPRECATION)
def softmax_cross_entropy_with_logits(
    _sentinel=None,  # pylint: disable=invalid-name
    labels=None,
    logits=None,
    dim=-1,
    name=None,
    axis=None):
  """Computes softmax cross entropy between `logits` and `labels`.

  Measures the probability error in discrete classification tasks in which the
  classes are mutually exclusive (each entry is in exactly one class).  For
  example, each CIFAR-10 image is labeled with one and only one label: an image
  can be a dog or a truck, but not both.

  **NOTE:**  While the classes are mutually exclusive, their probabilities
  need not be.  All that is required is that each row of `labels` is
  a valid probability distribution.  If they are not, the computation of the
  gradient will be incorrect.

  If using exclusive `labels` (wherein one and only
  one class is true at a time), see `sparse_softmax_cross_entropy_with_logits`.

  **WARNING:** This op expects unscaled logits, since it performs a `softmax`
  on `logits` internally for efficiency.  Do not call this op with the
  output of `softmax`, as it will produce incorrect results.

  A common use case is to have logits and labels of shape
  `[batch_size, num_classes]`, but higher dimensions are supported, with
  the `dim` argument specifying the class dimension.

  Backpropagation will happen only into `logits`.  To calculate a cross entropy
  loss that allows backpropagation into both `logits` and `labels`, see
  `tf.nn.softmax_cross_entropy_with_logits_v2`.

  **Note that to avoid confusion, it is required to pass only named arguments to
  this function.**

  Args:
    _sentinel: Used to prevent positional parameters. Internal, do not use.
    labels: Each vector along the class dimension should hold a valid
      probability distribution e.g. for the case in which labels are of shape
      `[batch_size, num_classes]`, each row of `labels[i]` must be a valid
      probability distribution.
    logits: Per-label activations, typically a linear output. These activation
      energies are interpreted as unnormalized log probabilities.
    dim: The class dimension. Defaulted to -1 which is the last dimension.
    name: A name for the operation (optional).
    axis: Alias for dim. May be passed instead of `dim`; passing both with a
      non-default `dim` is an error.

  Returns:
    A `Tensor` that contains the softmax cross entropy loss. Its type is the
    same as `logits` and its shape is the same as `labels` except that it does
    not have the class dimension.
  """
  # `dim` defaults to -1 rather than None, and deprecated_argument_lookup
  # treats any non-None old value as explicitly passed. Without this, every
  # call that sets only `axis` would be rejected as specifying both. Treat the
  # default `dim` as "not given" whenever `axis` is supplied.
  if axis is not None and dim == -1:
    dim = None
  dim = deprecated_argument_lookup("axis", axis, "dim", dim)
  _ensure_xent_args("softmax_cross_entropy_with_logits", _sentinel, labels,
                    logits)

  with ops.name_scope(name, "softmax_cross_entropy_with_logits_sg",
                      [logits, labels]) as name:
    # The v1 API does not backpropagate into labels.
    labels = array_ops.stop_gradient(labels, name="labels_stop_gradient")

  return softmax_cross_entropy_with_logits_v2(
      labels=labels, logits=logits, axis=dim, name=name)
4062
4063
@tf_export(v1=["nn.sparse_softmax_cross_entropy_with_logits"])
@dispatch.add_dispatch_support
def sparse_softmax_cross_entropy_with_logits(
    _sentinel=None,  # pylint: disable=invalid-name
    labels=None,
    logits=None,
    name=None):
  """Computes sparse softmax cross entropy between `logits` and `labels`.

  Measures the probability error in discrete classification tasks in which the
  classes are mutually exclusive (each entry is in exactly one class).  For
  example, each CIFAR-10 image is labeled with one and only one label: an image
  can be a dog or a truck, but not both.

  **NOTE:**  For this operation, the probability of a given label is considered
  exclusive.  That is, soft classes are not allowed, and the `labels` vector
  must provide a single specific index for the true class for each row of
  `logits` (each minibatch entry).  For soft softmax classification with
  a probability distribution for each entry, see
  `softmax_cross_entropy_with_logits_v2`.

  **WARNING:** This op expects unscaled logits, since it performs a `softmax`
  on `logits` internally for efficiency.  Do not call this op with the
  output of `softmax`, as it will produce incorrect results.

  A common use case is to have logits of shape
  `[batch_size, num_classes]` and have labels of shape
  `[batch_size]`, but higher dimensions are supported, in which
  case the last dimension of `logits` is assumed to be of size `num_classes`.
  `logits` must have the dtype of `float16`, `float32`, or `float64`, and
  `labels` must have the dtype of `int32` or `int64`.

  **Note that to avoid confusion, it is required to pass only named arguments to
  this function.**

  Args:
    _sentinel: Used to prevent positional parameters. Internal, do not use.
    labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of
      `labels` and result) and dtype `int32` or `int64`. Each entry in `labels`
      must be an index in `[0, num_classes)`. Other values will raise an
      exception when this op is run on CPU, and return `NaN` for corresponding
      loss and gradient rows on GPU.
    logits: Per-label activations (typically a linear output) of shape
      `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float16`, `float32`, or
      `float64`. These activation energies are interpreted as unnormalized log
      probabilities.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of the same shape as `labels` and of the same type as `logits`
    with the softmax cross entropy loss.

  Raises:
    ValueError: If logits are scalars (need to have rank >= 1), if the rank
      of the labels is not equal to the rank of the logits minus one, or if
      fully-defined static shapes of `labels` and `logits[:-1]` disagree.
  """
  # Reject positional calls and missing labels/logits before building ops.
  _ensure_xent_args("sparse_softmax_cross_entropy_with_logits", _sentinel,
                    labels, logits)

  # TODO(pcmurray) Raise an error when the label is not an index in
  # [0, num_classes). Note: This could break users who call this with bad
  # labels, but disregard the bad results.

  # Reshape logits and labels to rank 2.
  with ops.name_scope(name, "SparseSoftmaxCrossEntropyWithLogits",
                      [labels, logits]):
    labels = ops.convert_to_tensor(labels)
    logits = ops.convert_to_tensor(logits)
    # float16 logits are upcast to float32 for the kernel; the cost is cast
    # back to float16 before it is returned (see both return paths below).
    precise_logits = math_ops.cast(logits, dtypes.float32) if (dtypes.as_dtype(
        logits.dtype) == dtypes.float16) else logits

    # Store label shape for result later.
    labels_static_shape = labels.get_shape()
    labels_shape = array_ops.shape(labels)
    # When both shapes are fully known, they are compared statically below;
    # otherwise a run-time assert is added before the reshapes.
    static_shapes_fully_defined = (
        labels_static_shape.is_fully_defined() and
        logits.get_shape()[:-1].is_fully_defined())
    if logits.get_shape().ndims is not None and logits.get_shape().ndims == 0:
      raise ValueError(
          "Logits cannot be scalars - received shape %s." % logits.get_shape())
    if logits.get_shape().ndims is not None and (
        labels_static_shape.ndims is not None and
        labels_static_shape.ndims != logits.get_shape().ndims - 1):
      raise ValueError("Rank mismatch: Rank of labels (received %s) should "
                       "equal rank of logits minus 1 (received %s)." %
                       (labels_static_shape.ndims, logits.get_shape().ndims))
    if (static_shapes_fully_defined and
        labels_static_shape != logits.get_shape()[:-1]):
      raise ValueError("Shape mismatch: The shape of labels (received %s) "
                       "should equal the shape of logits except for the last "
                       "dimension (received %s)." % (labels_static_shape,
                                                     logits.get_shape()))
    # Check if no reshapes are required.
    # NOTE(review): this rank-2 fast path returns before the dynamic shape
    # assert below is built; presumably the kernel validates shapes itself in
    # that case -- confirm.
    if logits.get_shape().ndims == 2:
      cost, _ = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
          precise_logits, labels, name=name)
      if logits.dtype == dtypes.float16:
        return math_ops.cast(cost, dtypes.float16)
      else:
        return cost

    # Perform a check of the dynamic shapes if the static shapes are not fully
    # defined.
    shape_checks = []
    if not static_shapes_fully_defined:
      shape_checks.append(
          check_ops.assert_equal(
              array_ops.shape(labels),
              array_ops.shape(logits)[:-1]))
    with ops.control_dependencies(shape_checks):
      # Reshape logits to 2 dim, labels to 1 dim.
      num_classes = array_ops.shape(logits)[array_ops.rank(logits) - 1]
      precise_logits = array_ops.reshape(precise_logits, [-1, num_classes])
      labels = array_ops.reshape(labels, [-1])
      # The second output tensor contains the gradients.  We use it in
      # _CrossEntropyGrad() in nn_grad but not here.
      cost, _ = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
          precise_logits, labels, name=name)
      # Restore the original (pre-flatten) label shape, dynamically and
      # statically.
      cost = array_ops.reshape(cost, labels_shape)
      cost.set_shape(labels_static_shape)
      if logits.dtype == dtypes.float16:
        return math_ops.cast(cost, dtypes.float16)
      else:
        return cost
4188
4189
@tf_export("nn.sparse_softmax_cross_entropy_with_logits", v1=[])
@dispatch.add_dispatch_support
def sparse_softmax_cross_entropy_with_logits_v2(labels, logits, name=None):
  """Computes sparse softmax cross entropy between `logits` and `labels`.

  Measures the probability error in discrete classification tasks in which the
  classes are mutually exclusive (each entry is in exactly one class).  For
  example, each CIFAR-10 image is labeled with one and only one label: an image
  can be a dog or a truck, but not both.

  **NOTE:**  For this operation, the probability of a given label is considered
  exclusive.  That is, soft classes are not allowed, and the `labels` vector
  must provide a single specific index for the true class for each row of
  `logits` (each minibatch entry).  For soft softmax classification with
  a probability distribution for each entry, see
  `softmax_cross_entropy_with_logits_v2`.

  **WARNING:** This op expects unscaled logits, since it performs a `softmax`
  on `logits` internally for efficiency.  Do not call this op with the
  output of `softmax`, as it will produce incorrect results.

  A common use case is to have logits of shape
  `[batch_size, num_classes]` and have labels of shape
  `[batch_size]`, but higher dimensions are supported, in which
  case the last dimension of `logits` is assumed to be of size `num_classes`.
  `logits` must have the dtype of `float16`, `float32`, or `float64`, and
  `labels` must have the dtype of `int32` or `int64`.

  Args:
    labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of
      `labels` and result) and dtype `int32` or `int64`. Each entry in `labels`
      must be an index in `[0, num_classes)`. Other values will raise an
      exception when this op is run on CPU, and return `NaN` for corresponding
      loss and gradient rows on GPU.
    logits: Unscaled log probabilities of shape `[d_0, d_1, ..., d_{r-1},
      num_classes]` and dtype `float16`, `float32`, or `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of the same shape as `labels` and of the same type as `logits`
    with the softmax cross entropy loss.

  Raises:
    ValueError: If logits are scalars (need to have rank >= 1) or if the rank
      of the labels is not equal to the rank of the logits minus one.
  """
  # The v1 implementation requires keyword arguments; forward them as such.
  v1_kwargs = dict(labels=labels, logits=logits, name=name)
  return sparse_softmax_cross_entropy_with_logits(**v1_kwargs)
4241
4242
@tf_export("nn.avg_pool", v1=["nn.avg_pool_v2"])
@dispatch.add_dispatch_support
def avg_pool_v2(input, ksize, strides, padding, data_format=None, name=None):  # pylint: disable=redefined-builtin
  """Performs the average pooling on the input.

  Each entry in `output` is the mean of the corresponding size `ksize`
  window in `input`.

  Args:
    input:  Tensor of rank N+2, of shape `[batch_size] + input_spatial_shape +
      [num_channels]` if `data_format` does not start with "NC" (default), or
      `[batch_size, num_channels] + input_spatial_shape` if data_format starts
      with "NC". Pooling happens over the spatial dimensions only.
    ksize: An int or list of `ints` that has length `1`, `N` or `N+2`. The size
      of the window for each dimension of the input tensor.
    strides: An int or list of `ints` that has length `1`, `N` or `N+2`. The
      stride of the sliding window for each dimension of the input tensor.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
      the "returns" section of `tf.nn.convolution` for details.
    data_format: A string. Specifies the channel dimension. For N=1 it can be
      either "NWC" (default) or "NCW", for N=2 it can be either "NHWC" (default)
      or "NCHW" and for N=3 either "NDHWC" (default) or "NCDHW". It is also
      used to infer N when the rank of `input` is not statically known.
    name: Optional name for the operation.

  Returns:
    A `Tensor` of format specified by `data_format`.
    The average pooled output tensor.

  Raises:
    ValueError: If N cannot be determined (unknown input rank and no
      `data_format`), or if the input rank is not 3, 4 or 5.
  """
  # Determine the static rank of `input`. A `TensorShape` of unknown rank
  # reports `ndims` as None, so the `data_format` fallback must also cover
  # that case, not only a missing `shape`. Shapes without `ndims` (e.g. a
  # numpy shape tuple) are measured with `len`.
  input_shape = input.shape
  if input_shape is None:
    input_rank = None
  elif hasattr(input_shape, "ndims"):
    input_rank = input_shape.ndims
  else:
    input_rank = len(input_shape)

  if input_rank is not None:
    n = input_rank - 2
  elif data_format is not None:
    n = len(data_format) - 2
  else:
    raise ValueError(
        "The input must have a rank or a data format must be given.")
  if not 1 <= n <= 3:
    raise ValueError(
        "Input tensor must be of rank 3, 4 or 5 but was {}.".format(n + 2))

  # Channels-last unless the format explicitly starts with "NC".
  if data_format is None:
    channel_index = n + 1
  else:
    channel_index = 1 if data_format.startswith("NC") else n + 1

  ksize = _get_sequence(ksize, n, channel_index, "ksize")
  strides = _get_sequence(strides, n, channel_index, "strides")

  avg_pooling_ops = {
      1: avg_pool1d,
      2: gen_nn_ops.avg_pool,
      3: gen_nn_ops.avg_pool3d
  }

  op = avg_pooling_ops[n]
  return op(
      input,
      ksize=ksize,
      strides=strides,
      padding=padding,
      data_format=data_format,
      name=name)
4304
4305
@tf_export(v1=["nn.avg_pool", "nn.avg_pool2d"])
@dispatch.add_dispatch_support
def avg_pool(value, ksize, strides, padding, data_format="NHWC",
             name=None, input=None):  # pylint: disable=redefined-builtin
  """Performs the average pooling on the input.

  Each entry in `output` is the mean of the corresponding size `ksize`
  window in `value`.

  Args:
    value: A 4-D `Tensor` of shape `[batch, height, width, channels]` and type
      `float32`, `float64`, `qint8`, `quint8`, or `qint32`.
    ksize: An int or list of `ints` that has length `1`, `2` or `4`. The size of
      the window for each dimension of the input tensor.
    strides: An int or list of `ints` that has length `1`, `2` or `4`. The
      stride of the sliding window for each dimension of the input tensor.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
      See the "returns" section of `tf.nn.convolution` for details.
    data_format: A string. 'NHWC' and 'NCHW' are supported. `None` is treated
      as 'NHWC'.
    name: Optional name for the operation.
    input: Alias for value.

  Returns:
    A `Tensor` with the same type as `value`.  The average pooled output tensor.
  """
  # Resolve the `input` alias before opening the name scope, so the scope's
  # values list holds the tensor actually being pooled rather than `None`
  # when the caller passes only `input=`.
  value = deprecation.deprecated_argument_lookup(
      "input", input, "value", value)

  with ops.name_scope(name, "AvgPool", [value]) as name:
    if data_format is None:
      data_format = "NHWC"
    channel_index = 1 if data_format.startswith("NC") else 3

    ksize = _get_sequence(ksize, 2, channel_index, "ksize")
    strides = _get_sequence(strides, 2, channel_index, "strides")

    return gen_nn_ops.avg_pool(
        value,
        ksize=ksize,
        strides=strides,
        padding=padding,
        data_format=data_format,
        name=name)
4349
4350
@tf_export("nn.avg_pool2d", v1=[])
@dispatch.add_dispatch_support
def avg_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None):  # pylint: disable=redefined-builtin
  """Performs the average pooling on the input.

  Each entry in `output` is the mean of the corresponding size `ksize`
  window in `input`.

  Args:
    input: A 4-D `Tensor` of shape `[batch, height, width, channels]` and type
      `float32`, `float64`, `qint8`, `quint8`, or `qint32`.
    ksize: An int or list of `ints` that has length `1`, `2` or `4`. The size of
      the window for each dimension of the input tensor.
    strides: An int or list of `ints` that has length `1`, `2` or `4`. The
      stride of the sliding window for each dimension of the input tensor.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
      See the "returns" section of `tf.nn.convolution` for details.
    data_format: A string. 'NHWC' and 'NCHW' are supported. `None` is treated
      as 'NHWC'.
    name: Optional name for the operation.

  Returns:
    A `Tensor` with the same type as `input`.  The average pooled output tensor.
  """
  with ops.name_scope(name, "AvgPool2D", [input]) as name:
    layout = "NHWC" if data_format is None else data_format
    # Channels-first formats keep channels at index 1; otherwise at index 3.
    channel_axis = 1 if layout.startswith("NC") else 3

    window = _get_sequence(ksize, 2, channel_axis, "ksize")
    stride = _get_sequence(strides, 2, channel_axis, "strides")

    return gen_nn_ops.avg_pool(
        input,
        ksize=window,
        strides=stride,
        padding=padding,
        data_format=layout,
        name=name)
4389
4390
@tf_export("nn.avg_pool1d")
@dispatch.add_dispatch_support
def avg_pool1d(input, ksize, strides, padding, data_format="NWC", name=None):  # pylint: disable=redefined-builtin
  """Performs the average pooling on the input.

  Each entry in `output` is the mean of the corresponding size `ksize`
  window in `input`.

  Note internally this op reshapes and uses the underlying 2d operation.

  Args:
    input: A 3-D `Tensor` of the format specified by `data_format`.
    ksize: An int or list of `ints` that has length `1` or `3`. The size of the
      window for each dimension of the input tensor.
    strides: An int or list of `ints` that has length `1` or `3`. The stride of
      the sliding window for each dimension of the input tensor.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
      the "returns" section of `tf.nn.convolution` for details.
    data_format: An optional string from: "NWC", "NCW". Defaults to "NWC".
      `None` is treated as "NWC".
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of format specified by `data_format`.
    The average pooled output tensor.
  """
  with ops.name_scope(name, "AvgPool1D", [input]) as name:
    if data_format is None:
      data_format = "NWC"
    channel_index = 1 if data_format.startswith("NC") else 2
    # The leading 1 turns the 3-element 1-D sequences into the 4-element
    # sequences the 2-D kernel expects for the expanded input below.
    ksize = [1] + _get_sequence(ksize, 1, channel_index, "ksize")
    strides = [1] + _get_sequence(strides, 1, channel_index, "strides")

    # Insert a size-1 spatial axis so the 2-D AvgPool kernel can be reused:
    # NWC -> NHWC (new axis 1), NCW -> NCHW (new axis 2).
    expanding_dim = 1 if data_format == "NWC" else 2
    data_format = "NHWC" if data_format == "NWC" else "NCHW"

    input = array_ops.expand_dims_v2(input, expanding_dim)
    result = gen_nn_ops.avg_pool(
        input,
        ksize=ksize,
        strides=strides,
        padding=padding,
        data_format=data_format,
        name=name)
    # Drop the helper axis again to return a 3-D result.
    return array_ops.squeeze(result, expanding_dim)
4435
4436
@tf_export("nn.avg_pool3d")
@dispatch.add_dispatch_support
def avg_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None):  # pylint: disable=redefined-builtin
  """Performs the average pooling on the input.

  Each entry in `output` is the mean of the corresponding size `ksize`
  window in `input`.

  Args:
    input: A 5-D `Tensor` of the format specified by `data_format`; with the
      default "NDHWC" its shape is `[batch, depth, height, width, channels]`.
    ksize: An int or list of `ints` that has length `1`, `3` or `5`. The size of
      the window for each dimension of the input tensor.
    strides: An int or list of `ints` that has length `1`, `3` or `5`. The
      stride of the sliding window for each dimension of the input tensor.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
      See the "returns" section of `tf.nn.convolution` for details.
    data_format: A string. 'NDHWC' and 'NCDHW' are supported. `None` is
      treated as 'NDHWC'.
    name: Optional name for the operation.

  Returns:
    A `Tensor` with the same type as `input`.  The average pooled output tensor.
  """
  with ops.name_scope(name, "AvgPool3D", [input]) as name:
    if data_format is None:
      data_format = "NDHWC"
    # In a 5-D "NDHWC" tensor the channel axis is the last one (4), matching
    # `max_pool3d`; "NC..." layouts keep channels at axis 1.
    channel_index = 1 if data_format.startswith("NC") else 4

    ksize = _get_sequence(ksize, 3, channel_index, "ksize")
    strides = _get_sequence(strides, 3, channel_index, "strides")

    return gen_nn_ops.avg_pool3d(
        input,
        ksize=ksize,
        strides=strides,
        padding=padding,
        data_format=data_format,
        name=name)
4475
4476
# pylint: disable=redefined-builtin
@tf_export("nn.max_pool", v1=["nn.max_pool_v2"])
@dispatch.add_dispatch_support
def max_pool_v2(input, ksize, strides, padding, data_format=None, name=None):
  """Performs the max pooling on the input.

  The number of spatial dimensions N is taken from the rank of `input` when
  that rank is known, and otherwise from the length of `data_format`. The call
  is then forwarded to the 1-D, 2-D or 3-D max pooling implementation.

  Args:
    input:  Tensor of rank N+2, of shape `[batch_size] + input_spatial_shape +
      [num_channels]` if `data_format` does not start with "NC" (default), or
      `[batch_size, num_channels] + input_spatial_shape` if data_format starts
      with "NC". Pooling happens over the spatial dimensions only.
    ksize: An int or list of `ints` that has length `1`, `N` or `N+2`. The size
      of the window for each dimension of the input tensor.
    strides: An int or list of `ints` that has length `1`, `N` or `N+2`. The
      stride of the sliding window for each dimension of the input tensor.
    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
      padding algorithm to use, or a list indicating the explicit paddings at
      the start and end of each dimension. When explicit padding is used and
      data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
      pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding used
      and data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
      [pad_top, pad_bottom], [pad_left, pad_right]]`. When using explicit
      padding, the size of the paddings cannot be greater than the sliding
      window size.
    data_format: A string. Specifies the channel dimension. For N=1 it can be
      either "NWC" (default) or "NCW", for N=2 it can be either "NHWC" (default)
      or "NCHW" and for N=3 either "NDHWC" (default) or "NCDHW".
    name: Optional name for the operation.

  Returns:
    A `Tensor` of format specified by `data_format`.
    The max pooled output tensor.

  Raises:
    ValueError: If neither the rank of `input` nor `data_format` is known, if
      the resulting rank is not 3, 4 or 5, or if explicit padding is combined
      with "NCHW_VECT_C" or with a rank-5 input.
  """
  # `len()` of a `TensorShape` whose rank is unknown raises, which made the
  # `data_format` fallback below unreachable for such tensors. Read the rank
  # explicitly instead; `as_shape` also accepts plain tuple shapes.
  input_rank = tensor_shape.as_shape(input.shape).rank
  if input_rank is not None:
    n = input_rank - 2
  elif data_format is not None:
    n = len(data_format) - 2
  else:
    raise ValueError(
        "The input must have a rank or a data format must be given.")
  if not 1 <= n <= 3:
    raise ValueError(
        "Input tensor must be of rank 3, 4 or 5 but was {}.".format(n + 2))

  if data_format is None:
    # Default layouts keep the channel dimension last.
    channel_index = n + 1
  else:
    channel_index = 1 if data_format.startswith("NC") else n + 1

  if isinstance(padding, (list, tuple)) and data_format == "NCHW_VECT_C":
    raise ValueError("Data formats NCHW_VECT_C is not yet supported with "
                     "explicit padding")

  ksize = _get_sequence(ksize, n, channel_index, "ksize")
  strides = _get_sequence(strides, n, channel_index, "strides")

  if (isinstance(padding, (list, tuple)) and n == 3):
    raise ValueError("Explicit padding is not yet supported with an input "
                     "tensor of rank 5")

  max_pooling_ops = {
      1: max_pool1d,
      2: max_pool2d,
      3: gen_nn_ops.max_pool3d
  }

  op = max_pooling_ops[n]
  return op(
      input,
      ksize=ksize,
      strides=strides,
      padding=padding,
      data_format=data_format,
      name=name)
# pylint: enable=redefined-builtin
4552
4553
@tf_export(v1=["nn.max_pool"])
@dispatch.add_dispatch_support
def max_pool(value,
             ksize,
             strides,
             padding,
             data_format="NHWC",
             name=None,
             input=None):  # pylint: disable=redefined-builtin
  """Performs the max pooling on the input.

  Args:
    value: A 4-D `Tensor` of the format specified by `data_format`.
    ksize: An int or list of `ints` that has length `1`, `2` or `4`.
      The size of the window for each dimension of the input tensor.
    strides: An int or list of `ints` that has length `1`, `2` or `4`.
      The stride of the sliding window for each dimension of the input tensor.
    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
      padding algorithm to use, or a list indicating the explicit paddings at
      the start and end of each dimension. When explicit padding is used and
      data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
      pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding used
      and data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
      [pad_top, pad_bottom], [pad_left, pad_right]]`. When using explicit
      padding, the size of the paddings cannot be greater than the sliding
      window size.
    data_format: A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.
      `None` is treated as 'NHWC'.
    name: Optional name for the operation.
    input: Alias for value.

  Returns:
    A `Tensor` of format specified by `data_format`.
    The max pooled output tensor.

  Raises:
    ValueError: If explicit padding is used with 'NCHW_VECT_C', or if the
      window size contains a zero.
  """
  # `input` is the newer spelling of `value`; resolve whichever was supplied.
  value = deprecation.deprecated_argument_lookup("input", input, "value", value)
  with ops.name_scope(name, "MaxPool", [value]) as scope_name:
    layout = data_format if data_format is not None else "NHWC"
    channel_axis = 1 if layout.startswith("NC") else 3

    window = _get_sequence(ksize, 2, channel_axis, "ksize")
    stride = _get_sequence(strides, 2, channel_axis, "strides")
    explicit_requested = isinstance(padding, (list, tuple))
    if explicit_requested and layout == "NCHW_VECT_C":
      raise ValueError("Data formats NCHW_VECT_C is not yet supported with "
                       "explicit padding")
    padding_mode, explicit_paddings = convert_padding(padding)

    # Reject zero-sized windows up front, whether given as a scalar or as a
    # sequence containing a zero entry.
    zero_scalar = np.isscalar(window) and window == 0
    if zero_scalar or (isinstance(window, (list, tuple, np.ndarray)) and
                       any(size == 0 for size in window)):
      raise ValueError("ksize cannot be zero.")

    return gen_nn_ops.max_pool(
        value,
        ksize=window,
        strides=stride,
        padding=padding_mode,
        explicit_paddings=explicit_paddings,
        data_format=layout,
        name=scope_name)
4613
4614
# pylint: disable=redefined-builtin
@tf_export("nn.max_pool1d")
@dispatch.add_dispatch_support
def max_pool1d(input, ksize, strides, padding, data_format="NWC", name=None):
  """Performs the max pooling on the input.

  Note internally this op reshapes and uses the underlying 2d operation: a
  size-1 axis is inserted, 2-D max pooling runs, and the axis is removed.

  Args:
    input: A 3-D `Tensor` of the format specified by `data_format`.
    ksize: An int or list of `ints` that has length `1` or `3`. The size of the
      window for each dimension of the input tensor.
    strides: An int or list of `ints` that has length `1` or `3`. The stride of
      the sliding window for each dimension of the input tensor.
    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
      padding algorithm to use, or a list indicating the explicit paddings at
      the start and end of each dimension. When explicit padding is used and
      data_format is `"NWC"`, this should be in the form `[[0, 0], [pad_left,
      pad_right], [0, 0]]`. When explicit padding used and data_format is
      `"NCW"`, this should be in the form `[[0, 0], [0, 0], [pad_left,
      pad_right]]`. When using explicit padding, the size of the paddings cannot
      be greater than the sliding window size.
    data_format: An optional string from: "NWC", "NCW". Defaults to "NWC".
      `None` is treated as "NWC".
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of format specified by `data_format`.
    The max pooled output tensor.
  """
  with ops.name_scope(name, "MaxPool1d", [input]) as scope_name:
    if data_format == "NCHW_VECT_C" and isinstance(padding, (list, tuple)):
      raise ValueError("Data formats NCHW_VECT_C is not yet supported with "
                       "explicit padding")
    layout = "NWC" if data_format is None else data_format
    channel_axis = 1 if layout.startswith("NC") else 2
    # The leading 1 extends the 3-element sequences to the 4 entries needed
    # by the 2-D kernel on the expanded input.
    window = [1] + _get_sequence(ksize, 1, channel_axis, "ksize")
    stride = [1] + _get_sequence(strides, 1, channel_axis, "strides")
    padding_mode, explicit_paddings = convert_padding(padding, 3)
    if padding_mode == "EXPLICIT":
      # Two extra zeros give one padding pair per dimension of the 4-D input.
      explicit_paddings = [0, 0] + explicit_paddings

    # NWC -> NHWC by inserting axis 1; NCW -> NCHW by inserting axis 2.
    channels_last = layout == "NWC"
    inserted_axis = 1 if channels_last else 2
    expanded = array_ops.expand_dims_v2(input, inserted_axis)
    pooled = gen_nn_ops.max_pool(
        expanded,
        ksize=window,
        strides=stride,
        padding=padding_mode,
        explicit_paddings=explicit_paddings,
        data_format="NHWC" if channels_last else "NCHW",
        name=scope_name)
    return array_ops.squeeze(pooled, inserted_axis)
# pylint: enable=redefined-builtin
4671
4672
# pylint: disable=redefined-builtin
@tf_export("nn.max_pool2d")
@dispatch.add_dispatch_support
def max_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None):
  """Performs max pooling over the two spatial dimensions of `input`.

  Args:
    input: A 4-D `Tensor` of the format specified by `data_format`.
    ksize: An int or list of `ints` that has length `1`, `2` or `4`. The size of
      the window for each dimension of the input tensor.
    strides: An int or list of `ints` that has length `1`, `2` or `4`. The
      stride of the sliding window for each dimension of the input tensor.
    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
      padding algorithm to use, or a list indicating the explicit paddings at
      the start and end of each dimension. When explicit padding is used and
      data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
      pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding used
      and data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
      [pad_top, pad_bottom], [pad_left, pad_right]]`. When using explicit
      padding, the size of the paddings cannot be greater than the sliding
      window size.
    data_format: A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.
      `None` is treated as 'NHWC'.
    name: Optional name for the operation.

  Returns:
    A `Tensor` of format specified by `data_format`.
    The max pooled output tensor.

  Raises:
    ValueError: If explicit padding is combined with 'NCHW_VECT_C'.
  """
  with ops.name_scope(name, "MaxPool2d", [input]) as scope_name:
    layout = "NHWC" if data_format is None else data_format
    channel_axis = 1 if layout.startswith("NC") else 3

    window = _get_sequence(ksize, 2, channel_axis, "ksize")
    stride = _get_sequence(strides, 2, channel_axis, "strides")
    if layout == "NCHW_VECT_C" and isinstance(padding, (list, tuple)):
      raise ValueError("Data formats NCHW_VECT_C is not yet supported with "
                       "explicit padding")
    padding_mode, explicit_paddings = convert_padding(padding)

    return gen_nn_ops.max_pool(
        input,
        ksize=window,
        strides=stride,
        padding=padding_mode,
        explicit_paddings=explicit_paddings,
        data_format=layout,
        name=scope_name)
# pylint: enable=redefined-builtin
4722
4723
# pylint: disable=redefined-builtin
@tf_export("nn.max_pool3d")
@dispatch.add_dispatch_support
def max_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None):
  """Performs max pooling over the three spatial dimensions of `input`.

  Args:
    input: A 5-D `Tensor` of the format specified by `data_format`.
    ksize: An int or list of `ints` that has length `1`, `3` or `5`. The size of
      the window for each dimension of the input tensor.
    strides: An int or list of `ints` that has length `1`, `3` or `5`. The
      stride of the sliding window for each dimension of the input tensor.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
      the "returns" section of `tf.nn.convolution` for details.
    data_format: An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC"
      (`None` is treated the same way). With "NDHWC" the data is stored in the
      order `[batch, in_depth, in_height, in_width, in_channels]`; with "NCDHW"
      it is `[batch, in_channels, in_depth, in_height, in_width]`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of format specified by `data_format`.
    The max pooled output tensor.
  """
  with ops.name_scope(name, "MaxPool3D", [input]) as scope_name:
    layout = "NDHWC" if data_format is None else data_format
    # Channels are at axis 1 for "NC..." layouts and last (axis 4) otherwise.
    if layout.startswith("NC"):
      channel_axis = 1
    else:
      channel_axis = 4

    window = _get_sequence(ksize, 3, channel_axis, "ksize")
    stride = _get_sequence(strides, 3, channel_axis, "strides")

    return gen_nn_ops.max_pool3d(
        input,
        ksize=window,
        strides=stride,
        padding=padding,
        data_format=layout,
        name=scope_name)
# pylint: enable=redefined-builtin
4766
4767
@tf_export("nn.max_pool_with_argmax", v1=[])
@dispatch.add_dispatch_support
def max_pool_with_argmax_v2(
    input,  # pylint: disable=redefined-builtin
    ksize,
    strides,
    padding,
    data_format="NHWC",
    output_dtype=dtypes.int64,
    include_batch_in_index=False,
    name=None):
  """Performs max pooling on the input and outputs both max values and indices.

  Each index in `argmax` is flattened. A maximum found at position
  `[b, y, x, c]` is reported as `(y * width + x) * channels + c` when
  `include_batch_in_index` is False, and as
  `((b * height + y) * width + x) * channels + c` when it is True.

  The indices returned are always in `[0, height) x [0, width)` before
  flattening, even if padding is involved and the mathematically correct answer
  is outside (either negative or too large). This is a bug, but fixing it is
  difficult to do in a safe backwards compatible way, especially due to
  flattening.

  Args:
    input: A `Tensor`. Must be one of the following types: `float32`, `float64`,
      `int32`, `uint8`, `int16`, `int8`, `int64`, `bfloat16`, `uint16`, `half`,
      `uint32`, `uint64`.
      4-D with shape `[batch, height, width, channels]`.  Input to pool over.
    ksize: An int or list of `ints` that has length `1`, `2` or `4`.
      The size of the window for each dimension of the input tensor.
    strides: An int or list of `ints` that has length `1`, `2` or `4`.
      The stride of the sliding window for each dimension of the
      input tensor.
    padding: A `string` from: `"SAME", "VALID"`.
      The type of padding algorithm to use.
    data_format: An optional `string`, must be set to `"NHWC"`. Defaults to
      `"NHWC"`.
      Specify the data format of the input and output data.
    output_dtype: An optional `tf.DType` from: `tf.int32, tf.int64`.
      Defaults to `tf.int64`.
      The dtype of the returned argmax tensor.
    include_batch_in_index: An optional `boolean`. Defaults to `False`.
      Whether to include batch dimension in flattened index of `argmax`.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (output, argmax).

    output: A `Tensor`. Has the same type as `input`.
    argmax: A `Tensor` of type `output_dtype`.

  Raises:
    ValueError: If `data_format` is anything other than "NHWC".
  """
  if data_format != "NHWC":
    raise ValueError("Data formats other than 'NHWC' are not yet supported")

  # Only NHWC is accepted, so the channel axis is always 3.
  window = _get_sequence(ksize, 2, 3, "ksize")
  stride = _get_sequence(strides, 2, 3, "strides")

  return gen_nn_ops.max_pool_with_argmax(
      input=input,
      ksize=window,
      strides=stride,
      padding=padding,
      Targmax=output_dtype,
      include_batch_in_index=include_batch_in_index,
      name=name)
4836
4837
@tf_export(v1=["nn.max_pool_with_argmax"])
@dispatch.add_dispatch_support
def max_pool_with_argmax_v1(  # pylint: disable=missing-docstring,invalid-name
    input,  # pylint: disable=redefined-builtin
    ksize,
    strides,
    padding,
    data_format="NHWC",
    Targmax=None,
    name=None,
    output_dtype=None,
    include_batch_in_index=False):
  # The public docstring is copied from the generated op right after this
  # definition. `Targmax` is the legacy spelling of `output_dtype`.
  if data_format != "NHWC":
    raise ValueError("Data formats other than 'NHWC' are not yet supported")

  Targmax = deprecated_argument_lookup(
      "output_dtype", output_dtype, "Targmax", Targmax)
  if Targmax is None:
    Targmax = dtypes.int64

  # Normalize ksize/strides exactly like max_pool_with_argmax_v2 so the v1
  # endpoint also accepts an int or a length-1/2 sequence; only NHWC is
  # supported, so the channel axis is 3.
  ksize = _get_sequence(ksize, 2, 3, "ksize")
  strides = _get_sequence(strides, 2, 3, "strides")
  return gen_nn_ops.max_pool_with_argmax(
      input=input,
      ksize=ksize,
      strides=strides,
      padding=padding,
      Targmax=Targmax,
      include_batch_in_index=include_batch_in_index,
      name=name)


max_pool_with_argmax_v1.__doc__ = gen_nn_ops.max_pool_with_argmax.__doc__
4868
4869
@ops.RegisterStatistics("Conv3D", "flops")
def _calc_conv3d_flops(graph, node):
  """Estimates the floating point operation count of a Conv3D node.

  Every output element accumulates one product per filter tap across the
  filter's time, height, width and input-depth extents; a multiply and an add
  per tap give the factor of 2.

  Args:
    graph: The graph containing `node`.
    node: The Conv3D `NodeDef`; `node.input[0]` is the input and
      `node.input[1]` the filter.

  Returns:
    An `ops.OpStats` of kind "flops".
  """
  # All three shapes must be static; the input shape is only checked.
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  kernel_shape = graph_util.tensor_shape_from_node_def_name(
      graph, node.input[1])
  kernel_shape.assert_is_fully_defined()
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()

  k_time, k_height, k_width, k_in_depth = (
      int(kernel_shape[axis]) for axis in range(4))
  n_outputs = np.prod(out_shape.as_list(), dtype=np.int64)
  flops = n_outputs * k_in_depth * k_time * k_height * k_width * 2
  return ops.OpStats("flops", flops)
4887
4888
@ops.RegisterStatistics("Conv2D", "flops")
def _calc_conv_flops(graph, node):
  """Estimates the floating point operation count of a Conv2D node.

  Every output element accumulates filter_height * filter_width *
  filter_in_depth products; a multiply and an add per product give the
  factor of 2.

  Args:
    graph: The graph containing `node`.
    node: The Conv2D `NodeDef`; `node.input[0]` is the input and
      `node.input[1]` the filter.

  Returns:
    An `ops.OpStats` of kind "flops".
  """
  # All three shapes must be static; the input shape is only checked.
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  kernel_shape = graph_util.tensor_shape_from_node_def_name(
      graph, node.input[1])
  kernel_shape.assert_is_fully_defined()
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()

  k_height, k_width, k_in_depth = (
      int(kernel_shape[axis]) for axis in range(3))
  n_outputs = np.prod(out_shape.as_list(), dtype=np.int64)
  flops = n_outputs * k_in_depth * k_height * k_width * 2
  return ops.OpStats("flops", flops)
4906
4907
@ops.RegisterStatistics("DepthwiseConv2dNative", "flops")
def _calc_depthwise_conv_flops(graph, node):
  """Estimates the floating point operation count of a depthwise conv node.

  The count is output elements * filter_height * filter_width * 2. The
  filter's depth dimensions are not part of the estimate (presumably because
  each depthwise output reads from a single input channel).

  Args:
    graph: The graph containing `node`.
    node: The DepthwiseConv2dNative `NodeDef`; `node.input[0]` is the input
      and `node.input[1]` the filter.

  Returns:
    An `ops.OpStats` of kind "flops".
  """
  # All three shapes must be static; the input shape is only checked.
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  kernel_shape = graph_util.tensor_shape_from_node_def_name(
      graph, node.input[1])
  kernel_shape.assert_is_fully_defined()
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()

  k_height, k_width = int(kernel_shape[0]), int(kernel_shape[1])
  n_outputs = np.prod(out_shape.as_list(), dtype=np.int64)
  flops = n_outputs * k_height * k_width * 2
  return ops.OpStats("flops", flops)
4922
4923
@ops.RegisterStatistics("BiasAdd", "flops")
def _calc_bias_add_flops(graph, node):
  """Calculates the computing needed for BiasAdd.

  One addition is counted per element of the input tensor.

  Args:
    graph: The graph containing `node`.
    node: The BiasAdd `NodeDef`; `node.input[0]` is the tensor being biased.

  Returns:
    An `ops.OpStats` of kind "flops".
  """
  input_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  input_shape.assert_is_fully_defined()
  # Use int64 like the other FLOP hooks: the platform default integer can be
  # 32-bit and overflow for large tensors, and np.prod of an empty (scalar)
  # shape would otherwise produce the float 1.0 instead of an integer.
  input_count = np.prod(input_shape.as_list(), dtype=np.int64)
  return ops.OpStats("flops", input_count)
4931
4932
@tf_export(v1=["nn.xw_plus_b"])
@dispatch.add_dispatch_support
def xw_plus_b(x, weights, biases, name=None):  # pylint: disable=invalid-name
  """Computes matmul(x, weights) + biases.

  Args:
    x: a 2D tensor.  Dimensions typically: batch, in_units
    weights: a 2D tensor.  Dimensions typically: in_units, out_units
    biases: a 1D tensor.  Dimensions: out_units
    name: A name for the operation (optional).  If not specified
      "xw_plus_b" is used.

  Returns:
    A 2-D Tensor computing matmul(x, weights) + biases.
    Dimensions typically: batch, out_units.
  """
  with ops.name_scope(name, "xw_plus_b", [x, weights, biases]) as scope_name:
    x_tensor = ops.convert_to_tensor(x, name="x")
    w_tensor = ops.convert_to_tensor(weights, name="weights")
    b_tensor = ops.convert_to_tensor(biases, name="biases")
    product = math_ops.matmul(x_tensor, w_tensor)
    return bias_add(product, b_tensor, name=scope_name)
4955
4956
def xw_plus_b_v1(x, weights, biases, name=None):
  """Computes matmul(x, weights) + biases using `bias_add_v1`.

  Deprecated legacy variant of `xw_plus_b` that will soon be removed.

  Args:
    x: a 2D tensor.  Dimensions typically: batch, in_units
    weights: a 2D tensor.  Dimensions typically: in_units, out_units
    biases: a 1D tensor.  Dimensions: out_units
    name: A name for the operation (optional).  If not specified
      "xw_plus_b_v1" is used.

  Returns:
    A 2-D Tensor computing matmul(x, weights) + biases.
    Dimensions typically: batch, out_units.
  """
  with ops.name_scope(name, "xw_plus_b_v1", [x, weights, biases]) as scope_name:
    x_tensor = ops.convert_to_tensor(x, name="x")
    w_tensor = ops.convert_to_tensor(weights, name="weights")
    b_tensor = ops.convert_to_tensor(biases, name="biases")
    product = math_ops.matmul(x_tensor, w_tensor)
    return bias_add_v1(product, b_tensor, name=scope_name)
4979
4980
def _get_noise_shape(x, noise_shape):
  """Resolves the noise shape to use for `x`.

  * `None` -> the dynamic shape of `x`.
  * A value that cannot be parsed as a static shape -> returned unchanged, so
    the op consuming it can validate it (in eager mode it raises there).
  * A static shape with the same rank as `x` -> a `TensorShape` in which every
    unknown entry is filled with the corresponding known dimension of `x`.
  * Anything else -> returned unchanged.
  """
  if noise_shape is None:
    return array_ops.shape(x)

  try:
    parsed = tensor_shape.as_shape(noise_shape)
  except (TypeError, ValueError):
    return noise_shape

  x_dims = x.shape.dims
  if x_dims is None or len(x_dims) != len(parsed.dims):
    return noise_shape

  # Prefer the requested size; fall back to x's size where it is unknown.
  merged = [
      noise_dim.value if noise_dim.value is not None else x_dim.value
      for x_dim, noise_dim in zip(x_dims, parsed.dims)
  ]
  return tensor_shape.TensorShape(merged)
5004
5005
@tf_export(v1=["nn.dropout"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, "Please use `rate` instead of `keep_prob`. "
                             "Rate should be set to `rate = 1 - keep_prob`.",
                             "keep_prob")
def dropout(x, keep_prob=None, noise_shape=None, seed=None, name=None,
            rate=None):
  """Computes dropout.

  For each element of `x`, with probability `rate`, outputs `0`, and otherwise
  scales up the input by `1 / (1-rate)`. The scaling is such that the expected
  sum is unchanged.

  By default, each element is kept or dropped independently.  If `noise_shape`
  is specified, it must be
  [broadcastable](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
  to the shape of `x`, and only dimensions with `noise_shape[i] == shape(x)[i]`
  will make independent decisions.  For example, if `shape(x) = [k, l, m, n]`
  and `noise_shape = [k, 1, 1, n]`, each batch and channel component will be
  kept independently and each row and column will be kept or not kept together.

  This v1 entry point only translates the deprecated `keep_prob` into `rate`
  and then delegates to `dropout_v2`.

  Args:
    x: A floating point tensor.
    keep_prob: (deprecated) A deprecated alias for `(1-rate)`.
    noise_shape: A 1-D `Tensor` of type `int32`, representing the
      shape for randomly generated keep/drop flags.
    seed: A Python integer. Used to create random seeds. See
      `tf.random.set_seed` for behavior.
    name: A name for this operation (optional).
    rate: A scalar `Tensor` with the same type as `x`. The probability that each
      element of `x` is discarded.

  Returns:
    A Tensor of the same shape of `x`.

  Raises:
    ValueError: If neither `rate` nor `keep_prob` is given, if `keep_prob`
      cannot be subtracted from 1, if `rate` is not in `[0, 1)`, or if `x` is
      not a floating point tensor.
  """
  # Convert the legacy keep probability into a drop rate.
  if keep_prob is None:
    rate_from_keep_prob = None
  else:
    try:
      rate_from_keep_prob = 1. - keep_prob
    except TypeError:
      raise ValueError("keep_prob must be a floating point number or Tensor "
                       "(got %r)" % keep_prob)

  rate = deprecation.deprecated_argument_lookup(
      "rate", rate, "keep_prob", rate_from_keep_prob)
  if rate is None:
    raise ValueError("You must provide a rate to dropout.")

  return dropout_v2(x, rate, noise_shape=noise_shape, seed=seed, name=name)
5059
5060
@tf_export("nn.dropout", v1=[])
@dispatch.add_dispatch_support
def dropout_v2(x, rate, noise_shape=None, seed=None, name=None):
  """Computes dropout: randomly sets elements to zero to prevent overfitting.

  Note: The behavior of dropout has changed between TensorFlow 1.x and 2.x.
  When converting 1.x code, please use named arguments to ensure behavior stays
  consistent.

  See also: `tf.keras.layers.Dropout` for a dropout layer.

  [Dropout](https://arxiv.org/abs/1207.0580) is useful for regularizing DNN
  models. Input elements are randomly set to zero (and the other elements are
  rescaled). This encourages each node to be independently useful, as it cannot
  rely on the output of other nodes.

  More precisely: With probability `rate` elements of `x` are set to `0`.
  The remaining elements are scaled up by `1.0 / (1 - rate)`, so that the
  expected value is preserved.

  >>> tf.random.set_seed(0)
  >>> x = tf.ones([3,5])
  >>> tf.nn.dropout(x, rate = 0.5, seed = 1).numpy()
  array([[2., 0., 0., 2., 2.],
       [2., 2., 2., 2., 2.],
       [2., 0., 2., 0., 2.]], dtype=float32)

  >>> tf.random.set_seed(0)
  >>> x = tf.ones([3,5])
  >>> tf.nn.dropout(x, rate = 0.8, seed = 1).numpy()
  array([[0., 0., 0., 5., 5.],
       [0., 5., 0., 5., 0.],
       [5., 0., 5., 0., 5.]], dtype=float32)

  >>> tf.nn.dropout(x, rate = 0.0) == x
  <tf.Tensor: shape=(3, 5), dtype=bool, numpy=
    array([[ True,  True,  True,  True,  True],
           [ True,  True,  True,  True,  True],
           [ True,  True,  True,  True,  True]])>


  By default, each element is kept or dropped independently.  If `noise_shape`
  is specified, it must be
  [broadcastable](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
  to the shape of `x`, and only dimensions with `noise_shape[i] == shape(x)[i]`
  will make independent decisions. This is useful for dropping whole
  channels from an image or sequence. For example:

  >>> tf.random.set_seed(0)
  >>> x = tf.ones([3,10])
  >>> tf.nn.dropout(x, rate = 2/3, noise_shape=[1,10], seed=1).numpy()
  array([[0., 0., 0., 3., 3., 0., 3., 3., 3., 0.],
       [0., 0., 0., 3., 3., 0., 3., 3., 3., 0.],
       [0., 0., 0., 3., 3., 0., 3., 3., 3., 0.]], dtype=float32)

  Args:
    x: A floating point tensor.
    rate: A Python real number in `[0, 1)`, or a scalar `Tensor` whose dtype is
      compatible with the dtype of `x` (it is cast to `x.dtype`). The
      probability that each element is dropped. For example, setting rate=0.1
      would drop 10% of input elements.
    noise_shape: A 1-D `Tensor` of type `int32`, representing the
      shape for randomly generated keep/drop flags.
    seed: A Python integer. Used to create random seeds. See
      `tf.random.set_seed` for behavior.
    name: A name for this operation (optional).

  Returns:
    A Tensor of the same shape of `x`.

  Raises:
    ValueError: If `rate` is not in `[0, 1)` or if `x` is not a floating point
      tensor. `rate=1` is disallowed, because the output would be all zeros,
      which is likely not what was intended. Also raised if `rate` is neither a
      Python number nor a tensor, or if a tensor `rate` has a dtype that is
      incompatible with `x`.
  """
  with ops.name_scope(name, "dropout", [x]) as name:
    is_rate_number = isinstance(rate, numbers.Real)
    if is_rate_number and (rate < 0 or rate >= 1):
      raise ValueError("rate must be a scalar tensor or a float in the "
                       "range [0, 1), got %g" % rate)
    x = ops.convert_to_tensor(x, name="x")
    x_dtype = x.dtype
    if not x_dtype.is_floating:
      raise ValueError("x has to be a floating point tensor since it's going "
                       "to be scaled. Got a %s tensor instead." % x_dtype)
    if is_rate_number and rate == 0:
      # Fast-path: Return the input immediately if rate is non-tensor & is `0`.
      # We trigger this after all error checking
      # and after `x` has been converted to a tensor, to prevent inconsistent
      # tensor conversions/error raising if rate is changed to/from 0.
      #
      # We also explicitly call `random_seed.get_seed` to make sure
      # we don't change the random number generation behavior of
      # stateful random ops by entering a fastpath,
      # despite not generating a random tensor in the fastpath
      random_seed.get_seed(seed)
      return x

    is_executing_eagerly = context.executing_eagerly()
    if not tensor_util.is_tf_type(rate):
      if is_rate_number:
        keep_prob = 1 - rate
        scale = 1 / keep_prob
        scale = ops.convert_to_tensor(scale, dtype=x_dtype)
        ret = gen_math_ops.mul(x, scale)
      else:
        # Wrap `rate` in a 1-tuple: if `rate` is itself a tuple, a bare
        # `% rate` would unpack it (or raise TypeError) instead of reporting
        # the offending value.
        raise ValueError("rate is neither scalar nor scalar tensor %r" %
                         (rate,))
    else:
      rate.get_shape().assert_has_rank(0)
      rate_dtype = rate.dtype
      if rate_dtype != x_dtype:
        if not rate_dtype.is_compatible_with(x_dtype):
          raise ValueError(
              "Tensor dtype %s is incompatible with Tensor dtype %s: %r" %
              (x_dtype.name, rate_dtype.name, rate))
        rate = gen_math_ops.cast(rate, x_dtype, name="rate")
      one_tensor = constant_op.constant(1, dtype=x_dtype)
      # Scale the kept elements by 1 / (1 - rate) to preserve the expectation.
      ret = gen_math_ops.real_div(x, gen_math_ops.sub(one_tensor, rate))

    noise_shape = _get_noise_shape(x, noise_shape)
    # Sample a uniform distribution on [0.0, 1.0) and select values larger
    # than rate.
    #
    # NOTE: Random uniform can only generate 2^23 floats on [1.0, 2.0)
    # and subtract 1.0.
    random_tensor = random_ops.random_uniform(
        noise_shape, seed=seed, dtype=x_dtype)
    # NOTE: if (1.0 + rate) - 1 is equal to rate, then that float is selected,
    # hence a >= comparison is used.
    keep_mask = random_tensor >= rate
    ret = gen_math_ops.mul(ret, gen_math_ops.cast(keep_mask, x_dtype))
    if not is_executing_eagerly:
      # Broadcasting against `noise_shape` may lose static shape information
      # in graph mode; restore it from `x`.
      ret.set_shape(x.get_shape())
    return ret
5194
5195
@tf_export("math.top_k", "nn.top_k")
@dispatch.add_dispatch_support
def top_k(input, k=1, sorted=True, name=None):  # pylint: disable=redefined-builtin
  """Finds values and indices of the `k` largest entries for the last dimension.

  If the input is a vector (rank=1), finds the `k` largest entries in the vector
  and outputs their values and indices as vectors.  Thus `values[j]` is the
  `j`-th largest entry in `input`, and its index is `indices[j]`.

  >>> result = tf.math.top_k([1, 2, 98, 1, 1, 99, 3, 1, 3, 96, 4, 1],
  ...                         k=3)
  >>> result.values.numpy()
  array([99, 98, 96], dtype=int32)
  >>> result.indices.numpy()
  array([5, 2, 9], dtype=int32)

  For matrices (resp. higher rank input), computes the top `k` entries in each
  row (resp. vector along the last dimension).  Thus,

  >>> input = tf.random.normal(shape=(3,4,5,6))
  >>> k = 2
  >>> values, indices  = tf.math.top_k(input, k=k)
  >>> values.shape.as_list()
  [3, 4, 5, 2]
  >>>
  >>> values.shape == indices.shape == input.shape[:-1] + [k]
  True

  The indices can be used to `gather` from a tensor whose shape matches
  `input`.

  >>> gathered_values = tf.gather(input, indices, batch_dims=-1)
  >>> assert tf.reduce_all(gathered_values == values)

  If two elements are equal, the lower-index element appears first.

  >>> result = tf.math.top_k([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0],
  ...                        k=3)
  >>> result.indices.numpy()
  array([0, 1, 3], dtype=int32)

  Args:
    input: 1-D or higher `Tensor` with last dimension at least `k`.
    k: 0-D `int32` `Tensor`.  Number of top elements to look for along the last
      dimension (along each row for matrices).
    sorted: If true the resulting `k` elements will be sorted by the values in
      descending order.
    name: Optional name for the operation.

  Returns:
    A tuple with two named fields:
    values: The `k` largest elements along each last dimensional slice.
    indices: The indices of `values` within the last dimension of `input`.
  """
  return gen_nn_ops.top_kv2(input, k=k, sorted=sorted, name=name)
5250
5251
def nth_element(input, n, reverse=False, name=None):  # pylint: disable=redefined-builtin
  r"""Finds values of the `n`-th smallest value for the last dimension.

  Note that n is zero-indexed.

  If the input is a vector (rank-1), finds the entry which is the n-th smallest
  value in the vector and outputs its value as a scalar tensor.

  For matrices (resp. higher rank input), computes the entries which are the
  n-th smallest values in each row (resp. vector along the last dimension).
  Thus,

      values.shape = input.shape[:-1]

  Args:
    input: 1-D or higher `Tensor` with last dimension at least `n+1`.
    n: A `Tensor` of type `int32`.
      0-D. Position of sorted vector to select along the last dimension (along
      each row for matrices). Valid range of n is `[0, input.shape[-1])`.
    reverse: An optional `bool`. Defaults to `False`.
      When set to True, find the n-th largest value in the vector instead of
      the n-th smallest.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
    The `n`-th order statistic along each last dimensional slice.
  """
  return gen_nn_ops.nth_element(input, n, reverse=reverse, name=name)
5280
5281
@tf_export(v1=["nn.fractional_max_pool"])
@dispatch.add_dispatch_support
@deprecation.deprecated(date=None, instructions="`seed2` and `deterministic` "
                        "args are deprecated.  Use fractional_max_pool_v2.")
def fractional_max_pool(value,
                        pooling_ratio,
                        pseudo_random=False,
                        overlapping=False,
                        deterministic=False,
                        seed=0,
                        seed2=0,
                        name=None):  # pylint: disable=redefined-builtin
  r"""Performs fractional max pooling on the input (deprecated TF1 endpoint).

  This is a deprecated version of `fractional_max_pool`; it still exposes the
  raw `deterministic` and `seed2` op attributes, which the v2 API hides.

  Ordinary max pooling shrinks the input by an integer factor N, taking the
  maximum over N x N windows (commonly 2x2). Fractional max pooling lifts the
  integer restriction: the overall reduction ratio N may be fractional.

  Pooling-region sizes are chosen randomly but stay close to uniform. Looking
  only at the height dimension, the boundaries of the pooling rows obey the
  constraints below.

  Definitions:

  1.  input_row_length : the number of rows from the input set
  2.  output_row_length : which will be smaller than the input
  3.  alpha = input_row_length / output_row_length : our reduction ratio
  4.  K = floor(alpha)
  5.  row_pooling_sequence : this is the result list of pool boundary rows

  Constraints on row_pooling_sequence:

  1.  a[0] = 0 : the first value of the sequence is 0
  2.  a[end] = input_row_length : the last value of the sequence is the size
  3.  K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size
  4.  length(row_pooling_sequence) = output_row_length+1

  Args:
    value: A `Tensor`. 4-D with shape `[batch, height, width, channels]`.
    pooling_ratio: A list of `floats` that has length >= 4.  Per-dimension
      pooling ratio for `value`; only the row and col dimensions may pool, and
      each entry should be >= 1.0. Example: [1.0, 1.44, 1.73, 1.0]. The first
      and last entries must be 1.0 because pooling over the batch and channel
      dimensions is not allowed; 1.44 and 1.73 are the height and width ratios.
    pseudo_random: An optional `bool`.  Defaults to `False`. When `True`, the
      pooling sequence is generated pseudorandomly; otherwise randomly. See
      (Graham, 2015) for the difference between the two.
    overlapping: An optional `bool`.  Defaults to `False`.  When `True`, values
      on the boundary between adjacent pooling cells are used by both cells.
      For example:
      `index  0  1  2  3  4`
      `value  20 5  16 3  7`
      With pooling sequence [0, 2, 4], the 16 at index 2 is used twice, and
      fractional max pooling yields [20, 16].
    deterministic: An optional `bool`.  Deprecated; use `fractional_max_pool_v2`
      instead.
    seed: An optional `int`.  Defaults to `0`.  A non-zero value seeds the
      random number generator with it; zero means a random seed is used.
    seed2: An optional `int`.  Deprecated; use `fractional_max_pool_v2` instead.
    name: A name for the operation (optional).

  Returns:
  A tuple of `Tensor` objects (`output`, `row_pooling_sequence`,
  `col_pooling_sequence`).
    output: Output `Tensor` after fractional max pooling.  Has the same type as
      `value`.
    row_pooling_sequence: A `Tensor` of type `int64`.
    col_pooling_sequence: A `Tensor` of type `int64`.

  References:
    Fractional Max-Pooling:
      [Graham, 2015](https://arxiv.org/abs/1412.6071)
      ([pdf](https://arxiv.org/pdf/1412.6071.pdf))
  """
  # Forward every argument unchanged to the generated op.
  return gen_nn_ops.fractional_max_pool(
      value=value,
      pooling_ratio=pooling_ratio,
      pseudo_random=pseudo_random,
      overlapping=overlapping,
      deterministic=deterministic,
      seed=seed,
      seed2=seed2,
      name=name)
5367
5368
@tf_export("nn.fractional_max_pool", v1=[])
@dispatch.add_dispatch_support
def fractional_max_pool_v2(value,
                           pooling_ratio,
                           pseudo_random=False,
                           overlapping=False,
                           seed=0,
                           name=None):  # pylint: disable=redefined-builtin
  r"""Performs fractional max pooling on the input.

  Fractional max pooling is slightly different than regular max pooling.  In
  regular max pooling, you downsize an input set by taking the maximum value of
  smaller N x N subsections of the set (often 2x2), and try to reduce the set by
  a factor of N, where N is an integer.  Fractional max pooling, as you might
  expect from the word "fractional", means that the overall reduction ratio N
  does not have to be an integer.

  The sizes of the pooling regions are generated randomly but are fairly
  uniform.  For example, let's look at the height dimension, and the constraints
  on the list of rows that will be pool boundaries.

  First we define the following:

  1.  input_row_length : the number of rows from the input set
  2.  output_row_length : which will be smaller than the input
  3.  alpha = input_row_length / output_row_length : our reduction ratio
  4.  K = floor(alpha)
  5.  row_pooling_sequence : this is the result list of pool boundary rows

  Then, row_pooling_sequence should satisfy:

  1.  a[0] = 0 : the first value of the sequence is 0
  2.  a[end] = input_row_length : the last value of the sequence is the size
  3.  K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size
  4.  length(row_pooling_sequence) = output_row_length+1

  Args:
    value: A `Tensor`. 4-D with shape `[batch, height, width, channels]`.
    pooling_ratio: An int or float, or a list/tuple of them with length `1`,
      `2` or `4`. Pooling ratio for each dimension of `value`, currently only
      supports row and col dimension and every entry should be >= 1.0. A
      scalar or length-1 value is used for both the row and col dimensions; a
      length-2 value gives the row and col ratios. For example, a valid
      length-4 pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. In the length-4
      form the first and last elements must be 1.0 because we don't allow
      pooling on batch and channels dimensions.  1.44 and 1.73 are pooling
      ratio on height and width dimensions respectively.
    pseudo_random: An optional `bool`.  Defaults to `False`. When set to `True`,
      generates the pooling sequence in a pseudorandom fashion, otherwise, in a
      random fashion. Check paper (Graham, 2015) for difference between
      pseudorandom and random.
    overlapping: An optional `bool`.  Defaults to `False`.  When set to `True`,
      it means when pooling, the values at the boundary of adjacent pooling
      cells are used by both cells. For example:
      `index  0  1  2  3  4`
      `value  20 5  16 3  7`
      If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used
      twice.  The result would be [20, 16] for fractional max pooling.
    seed: An optional `int`.  Defaults to `0`.  If set to be non-zero, the
      random number generator is seeded by the given seed.  Otherwise it is
      seeded by a random seed.
    name: A name for the operation (optional).

  Returns:
  A tuple of `Tensor` objects (`output`, `row_pooling_sequence`,
  `col_pooling_sequence`).
    output: Output `Tensor` after fractional max pooling.  Has the same type as
      `value`.
    row_pooling_sequence: A `Tensor` of type `int64`.
    col_pooling_sequence: A `Tensor` of type `int64`.

  Raises:
    ValueError: If `pooling_ratio` is not a number or list/tuple of numbers,
      has an entry < 1.0, has an unsupported length, or is a length-4 value
      whose first or last element is not 1.0.

  References:
    Fractional Max-Pooling:
      [Graham, 2015](https://arxiv.org/abs/1412.6071)
      ([pdf](https://arxiv.org/pdf/1412.6071.pdf))
  """
  if isinstance(pooling_ratio, (list, tuple)):
    # Only the fully specified (length-4) form includes the batch and channel
    # entries; the length-1 and length-2 forms are expanded by
    # `_get_sequence` below with 1 in those positions, so checking their
    # first/last elements would wrongly reject valid spatial ratios.
    if len(pooling_ratio) == 4 and (pooling_ratio[0] != 1.0 or
                                    pooling_ratio[-1] != 1.0):
      raise ValueError(
          "The first and last elements of pooling ratio must be 1.0.")
    if any(element < 1.0 for element in pooling_ratio):
      raise ValueError("pooling_ratio should be >= 1.0.")
  elif isinstance(pooling_ratio, (int, float)):
    if pooling_ratio < 1.0:
      raise ValueError("pooling_ratio should be >= 1.0.")
  else:
    raise ValueError("pooling_ratio should be an int or a list of ints.")

  # Normalizes to the 4-element [1, row, col, 1] form; raises ValueError for
  # unsupported lengths (including an empty list).
  pooling_ratio = _get_sequence(pooling_ratio, 2, 3, "pooling_ratio")

  if seed == 0:
    return gen_nn_ops.fractional_max_pool(value, pooling_ratio, pseudo_random,
                                          overlapping, deterministic=False,
                                          seed=0, seed2=0, name=name)
  else:
    seed1, seed2 = random_seed.get_seed(seed)
    return gen_nn_ops.fractional_max_pool(value, pooling_ratio, pseudo_random,
                                          overlapping, deterministic=True,
                                          seed=seed1, seed2=seed2, name=name)
5466
5467
@tf_export(v1=["nn.fractional_avg_pool"])
@dispatch.add_dispatch_support
@deprecation.deprecated(date=None, instructions="`seed2` and `deterministic` "
                        "args are deprecated.  Use fractional_avg_pool_v2.")
def fractional_avg_pool(value,
                        pooling_ratio,
                        pseudo_random=False,
                        overlapping=False,
                        deterministic=False,
                        seed=0,
                        seed2=0,
                        name=None):  # pylint: disable=redefined-builtin
  r"""Performs fractional average pooling on the input.

  This is a deprecated version of `fractional_avg_pool`.

  Fractional average pooling is similar to Fractional max pooling in the pooling
  region generation step. The only difference is that after pooling regions are
  generated, a mean operation is performed instead of a max operation in each
  pooling region.

  Args:
    value: A `Tensor`. 4-D with shape `[batch, height, width, channels]`.
    pooling_ratio: A list of `floats` that has length >= 4.  Pooling ratio for
      each dimension of `value`, currently only supports row and col dimension
      and should be >= 1.0. For example, a valid pooling ratio looks like [1.0,
      1.44, 1.73, 1.0]. The first and last elements must be 1.0 because we don't
      allow pooling on batch and channels dimensions.  1.44 and 1.73 are pooling
      ratio on height and width dimensions respectively.
    pseudo_random: An optional `bool`.  Defaults to `False`. When set to `True`,
      generates the pooling sequence in a pseudorandom fashion, otherwise, in a
      random fashion. Check paper (Graham, 2015) for difference between
      pseudorandom and random.
    overlapping: An optional `bool`.  Defaults to `False`.  When set to `True`,
      it means when pooling, the values at the boundary of adjacent pooling
      cells are used by both cells. For example:
      `index  0  1  2  3  4`
      `value  20 5  16 3  7`
      If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used
      twice.  The result would be [41/3, 26/3] for fractional avg pooling
      (the means of [20, 5, 16] and [16, 3, 7]).
    deterministic: An optional `bool`.  Deprecated; use `fractional_avg_pool_v2`
      instead.
    seed: An optional `int`.  Defaults to `0`.  If set to be non-zero, the
      random number generator is seeded by the given seed.  Otherwise it is
      seeded by a random seed.
    seed2: An optional `int`.  Deprecated; use `fractional_avg_pool_v2` instead.
    name: A name for the operation (optional).

  Returns:
  A tuple of `Tensor` objects (`output`, `row_pooling_sequence`,
  `col_pooling_sequence`).
    output: Output `Tensor` after fractional avg pooling.  Has the same type as
      `value`.
    row_pooling_sequence: A `Tensor` of type `int64`.
    col_pooling_sequence: A `Tensor` of type `int64`.

  References:
    Fractional Max-Pooling:
      [Graham, 2015](https://arxiv.org/abs/1412.6071)
      ([pdf](https://arxiv.org/pdf/1412.6071.pdf))
  """
  return gen_nn_ops.fractional_avg_pool(value, pooling_ratio, pseudo_random,
                                        overlapping, deterministic, seed, seed2,
                                        name=name)
5532
5533
@tf_export("nn.fractional_avg_pool", v1=[])
@dispatch.add_dispatch_support
def fractional_avg_pool_v2(value,
                           pooling_ratio,
                           pseudo_random=False,
                           overlapping=False,
                           seed=0,
                           name=None):  # pylint: disable=redefined-builtin
  r"""Performs fractional average pooling on the input.

  Fractional average pooling is similar to Fractional max pooling in the pooling
  region generation step. The only difference is that after pooling regions are
  generated, a mean operation is performed instead of a max operation in each
  pooling region.

  Args:
    value: A `Tensor`. 4-D with shape `[batch, height, width, channels]`.
    pooling_ratio: A list of `floats` that has length >= 4.  Pooling ratio for
      each dimension of `value`, currently only supports row and col dimension
      and should be >= 1.0. For example, a valid pooling ratio looks like [1.0,
      1.44, 1.73, 1.0]. The first and last elements must be 1.0 because we don't
      allow pooling on batch and channels dimensions.  1.44 and 1.73 are pooling
      ratio on height and width dimensions respectively.
    pseudo_random: An optional `bool`.  Defaults to `False`. When set to `True`,
      generates the pooling sequence in a pseudorandom fashion, otherwise, in a
      random fashion. Check paper (Graham, 2015) for difference between
      pseudorandom and random.
    overlapping: An optional `bool`.  Defaults to `False`.  When set to `True`,
      it means when pooling, the values at the boundary of adjacent pooling
      cells are used by both cells. For example:
      `index  0  1  2  3  4`
      `value  20 5  16 3  7`
      If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used
      twice.  The result would be [41/3, 26/3] for fractional avg pooling
      (the means of [20, 5, 16] and [16, 3, 7]).
    seed: An optional `int`.  Defaults to `0`.  If set to be non-zero, the
      random number generator is seeded by the given seed.  Otherwise it is
      seeded by a random seed.
    name: A name for the operation (optional).

  Returns:
  A tuple of `Tensor` objects (`output`, `row_pooling_sequence`,
  `col_pooling_sequence`).
    output: Output `Tensor` after fractional avg pooling.  Has the same type as
      `value`.
    row_pooling_sequence: A `Tensor` of type `int64`.
    col_pooling_sequence: A `Tensor` of type `int64`.

  References:
    Fractional Max-Pooling:
      [Graham, 2015](https://arxiv.org/abs/1412.6071)
      ([pdf](https://arxiv.org/pdf/1412.6071.pdf))
  """
  if seed == 0:
    # No user seed: let the op pick a random seed (non-deterministic).
    return gen_nn_ops.fractional_avg_pool(value, pooling_ratio, pseudo_random,
                                          overlapping, deterministic=False,
                                          seed=0, seed2=0, name=name)
  else:
    # Derive the op-level seed pair from the user seed so results repeat.
    seed1, seed2 = random_seed.get_seed(seed)
    return gen_nn_ops.fractional_avg_pool(value, pooling_ratio, pseudo_random,
                                          overlapping, deterministic=True,
                                          seed=seed1, seed2=seed2, name=name)
5595
5596
@ops.RegisterStatistics("Dilation2D", "flops")
def _calc_dilation2d_flops(graph, node):
  """Estimates the FLOP count of a Dilation2D node.

  The input, filter and output shapes must all be statically known. The
  estimate charges two operations for every (output element, filter tap) pair.

  Args:
    graph: The graph that contains `node`.
    node: The `NodeDef` of the Dilation2D op.

  Returns:
    An `ops.OpStats` of type "flops".
  """

  def _fully_defined_shape(tensor_name):
    # Looks up the static shape of `tensor_name` and insists it is complete.
    shape = graph_util.tensor_shape_from_node_def_name(graph, tensor_name)
    shape.assert_is_fully_defined()
    return shape

  # The input shape itself is unused, but it must still be fully defined.
  _fully_defined_shape(node.input[0])
  filter_shape = _fully_defined_shape(node.input[1])
  output_shape = _fully_defined_shape(node.name)

  kernel_rows, kernel_cols = int(filter_shape[0]), int(filter_shape[1])
  num_outputs = np.prod(output_shape.as_list(), dtype=np.int64)
  flops = num_outputs * kernel_rows * kernel_cols * 2
  return ops.OpStats("flops", flops)
5611
5612
@tf_export(v1=["nn.erosion2d"])
@dispatch.add_dispatch_support
def erosion2d(value, kernel, strides, rates, padding, name=None):
  """Computes the grayscale erosion of 4-D `value` and 3-D `kernel` tensors.

  `value` has shape `[batch, in_height, in_width, depth]` and `kernel` has
  shape `[kernel_height, kernel_width, depth]`: every input channel is eroded
  independently by its own structuring function. The result has shape
  `[batch, out_height, out_width, depth]`, where the spatial sizes depend on
  the `padding` algorithm. Only the default "NHWC" `data_format` is supported.

  The grayscale morphological 2-D erosion is defined as:

      output[b, y, x, c] =
         min_{dy, dx} value[b,
                            strides[1] * y - rates[1] * dy,
                            strides[2] * x - rates[2] * dx,
                            c] -
                      kernel[dy, dx, c]

  Duality: eroding `value` by `kernel` equals the negation of dilating
  `-value` by the reflected `kernel`; this function is implemented that way.

  Args:
    value: A `Tensor`. 4-D with shape `[batch, in_height, in_width, depth]`.
    kernel: A `Tensor`. Must have the same type as `value`.
      3-D with shape `[kernel_height, kernel_width, depth]`.
    strides: A list of `ints` that has length `>= 4`.
      1-D of length 4. Sliding-window stride for each input dimension.
      Must be: `[1, stride_height, stride_width, 1]`.
    rates: A list of `ints` that has length `>= 4`.
      1-D of length 4. Input stride for atrous morphological dilation.
      Must be: `[1, rate_height, rate_width, 1]`.
    padding: A `string` from: `"SAME", "VALID"`.
      The padding algorithm to use.
    name: A name for the operation (optional). Defaults to "erosion2d".

  Returns:
    A `Tensor`. Has the same type as `value`.
    4-D with shape `[batch, out_height, out_width, depth]`.

  Raises:
    ValueError: If the `value` depth does not match `kernel`' shape, or if
      padding is other than `'VALID'` or `'SAME'`.
  """
  with ops.name_scope(name, "erosion2d", [value, kernel]) as name:
    # erosion(value, kernel) == -dilation(-value, reflect(kernel)), where the
    # reflection flips both spatial axes of the kernel.
    negated_value = math_ops.negative(value)
    reflected_kernel = array_ops.reverse_v2(kernel, [0, 1])
    dilated = gen_nn_ops.dilation2d(
        input=negated_value,
        filter=reflected_kernel,
        strides=strides,
        rates=rates,
        padding=padding,
        name=name)
    return math_ops.negative(dilated)
5670
5671
@tf_export("nn.erosion2d", v1=[])
@dispatch.add_dispatch_support
def erosion2d_v2(value,
                 filters,
                 strides,
                 padding,
                 data_format,
                 dilations,
                 name=None):
  """Computes the grayscale erosion of 4-D `value` and 3-D `filters` tensors.

  `value` has shape `[batch, in_height, in_width, depth]` and `filters` has
  shape `[filters_height, filters_width, depth]`: every input channel is eroded
  independently by its own structuring function. The result has shape
  `[batch, out_height, out_width, depth]`, where the spatial sizes depend on
  the `padding` algorithm. Only the default "NHWC" `data_format` is supported.

  The grayscale morphological 2-D erosion is defined as:

      output[b, y, x, c] =
         min_{dy, dx} value[b,
                            strides[1] * y - dilations[1] * dy,
                            strides[2] * x - dilations[2] * dx,
                            c] -
                      filters[dy, dx, c]

  Duality: eroding `value` by `filters` equals the negation of dilating
  `-value` by the reflected `filters`; this function is implemented that way.

  Args:
    value: A `Tensor`. 4-D with shape `[batch, in_height, in_width, depth]`.
    filters: A `Tensor`. Must have the same type as `value`.
      3-D with shape `[filters_height, filters_width, depth]`.
    strides: A list of `ints` that has length `>= 4`.
      1-D of length 4. Sliding-window stride for each input dimension.
      Must be: `[1, stride_height, stride_width, 1]`.
    padding: A `string` from: `"SAME", "VALID"`.
      The padding algorithm to use.
    data_format: A `string`, only `"NHWC"` is currently supported.
    dilations: A list of `ints` that has length `>= 4`.
      1-D of length 4. Input stride for atrous morphological dilation.
      Must be: `[1, rate_height, rate_width, 1]`.
    name: A name for the operation (optional). Defaults to "erosion2d".

  Returns:
    A `Tensor`. Has the same type as `value`.
    4-D with shape `[batch, out_height, out_width, depth]`.

  Raises:
    ValueError: If the `value` depth does not match `filters`' shape, if
      padding is other than `'VALID'` or `'SAME'`, or if `data_format` is not
      `"NHWC"`.
  """
  supported_format = data_format == "NHWC"
  if not supported_format:
    raise ValueError("Data formats other than NHWC are not yet supported")

  with ops.name_scope(name, "erosion2d", [value, filters]) as name:
    # erosion(value, filters) == -dilation(-value, reflect(filters)), where the
    # reflection flips both spatial axes of the filters.
    negated_value = math_ops.negative(value)
    reflected_filters = array_ops.reverse_v2(filters, [0, 1])
    dilated = gen_nn_ops.dilation2d(
        input=negated_value,
        filter=reflected_filters,
        strides=strides,
        rates=dilations,
        padding=padding,
        name=name)
    return math_ops.negative(dilated)
5740
5741
@tf_export(v1=["math.in_top_k", "nn.in_top_k"])
@dispatch.add_dispatch_support
def in_top_k(predictions, targets, k, name=None):
  r"""Says whether the targets are in the top `K` predictions.

  Produces a bool array of length `batch_size`: entry `out[i]` is `true` when
  the prediction for the target class of example `i` is finite (not inf, -inf,
  or nan) and ranks among the top `k` predictions for that example. Unlike the
  `TopK` op, ties are inclusive: if several classes share a prediction value
  that straddles the top-`k` boundary, all of them count as being in the top
  `k`.

  More formally, let

    \\(predictions_i\\) be the predictions for all classes for example `i`,
    \\(targets_i\\) be the target class for example `i`,
    \\(out_i\\) be the output for example `i`,

  $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$

  Args:
    predictions: A `Tensor` of type `float32`.
      A `batch_size` x `classes` tensor.
    targets: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A `batch_size` vector of class ids.
    k: An `int`. Number of top elements to look at for computing precision.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `bool`. Computed Precision at `k` as a `bool Tensor`.
  """
  with ops.name_scope(name, "in_top_k"):
    return gen_nn_ops.in_top_kv2(
        predictions=predictions, targets=targets, k=k, name=name)
5775
5776
@tf_export("math.in_top_k", "nn.in_top_k", v1=[])
@dispatch.add_dispatch_support
def in_top_k_v2(targets, predictions, k, name=None):
  # TF2 endpoint: identical computation to `in_top_k`, but `targets` comes
  # before `predictions` positionally. Forward by keyword so the swap is
  # explicit.
  return in_top_k(predictions=predictions, targets=targets, k=k, name=name)


# The v2 endpoint reuses the v1 documentation verbatim.
in_top_k_v2.__doc__ = in_top_k.__doc__
5784
5785
# Register the generated quantized NN ops under the v1 API names listed below
# only (no v2 names are given), each wrapped with dispatch support.
tf_export(v1=["nn.quantized_avg_pool"])(
    dispatch.add_dispatch_support(gen_nn_ops.quantized_avg_pool))
tf_export(v1=["nn.quantized_conv2d"])(
    dispatch.add_dispatch_support(gen_nn_ops.quantized_conv2d))
tf_export(v1=["nn.quantized_relu_x"])(
    dispatch.add_dispatch_support(gen_nn_ops.quantized_relu_x))
tf_export(v1=["nn.quantized_max_pool"])(
    dispatch.add_dispatch_support(gen_nn_ops.quantized_max_pool))
5794
5795
@tf_export("nn.isotonic_regression", v1=[])
@dispatch.add_dispatch_support
def isotonic_regression(inputs, decreasing=True, axis=-1):
  r"""Solves isotonic regression problems along the given axis.

  For each vector x, the problem solved is

  $$\argmin_{y_1 >= y_2 >= ... >= y_n} \sum_i (x_i - y_i)^2.$$

  When `decreasing=False` the inequalities are reversed, so the solution is
  non-decreasing instead.

  As the solution is component-wise constant, a second tensor is returned that
  encodes the segments. The problems are solved over the given axis.

  Consider the following example, where we solve a batch of two problems. The
  first input is [3, 1, 2], while the second is [1, 3, 4] (as the axis is 1).

  >>> x = tf.constant([[3, 1, 2], [1, 3, 4]], dtype=tf.float32)
  >>> y, segments = tf.nn.isotonic_regression(x, axis=1)
  >>> y  # The solution.
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
  array([[3.       , 1.5      , 1.5      ],
         [2.6666667, 2.6666667, 2.6666667]], dtype=float32)>

  Note that the first solution has two blocks [3] and [1.5, 1.5]. The second
  solution is constant, and thus has a single segment. These segments are
  exactly what the second returned tensor encodes:

  >>> segments
  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
  array([[0, 1, 1],
         [0, 0, 0]], dtype=int32)>


  Args:
    inputs: A tensor holding the inputs.
    decreasing: If set to False, the inequalities in the optimization
      constraints are flipped, i.e. a non-decreasing solution is computed.
    axis: The axis along which the problems should be solved.

  Returns:
    output: The solutions, same shape as the input. Its dtype equals the input
      dtype for `float32`, `half` and `bfloat16`; `int8` and `int16` inputs
      produce `float32`; all other input dtypes produce `float64`.
    segments: An int32 tensor, same shape as the input indicating the segments
      that have the same value. Specifically, those positions that have the same
      value correspond to the same segment. These values start at zero, and are
      monotonically non-decreasing for each solution.
  """
  type_promotions = {
      # Float types get mapped to themselves, int8/16 to float32, rest to double
      dtypes.float32: dtypes.float32,
      dtypes.half: dtypes.half,
      dtypes.bfloat16: dtypes.bfloat16,
      dtypes.int8: dtypes.float32,
      dtypes.int16: dtypes.float32,
  }
  inputs = ops.convert_to_tensor(inputs)
  output_dtype = type_promotions.get(inputs.dtype, dtypes.float64)

  def compute_on_matrix(matrix, name=None):
    iso_fn = functools.partial(
        gen_nn_ops.isotonic_regression, output_dtype=output_dtype, name=name)
    if decreasing:
      return iso_fn(matrix)
    # Solve the non-decreasing problem as a decreasing one on the negated
    # input. Cast to the (floating-point) output dtype *before* negating:
    # negating the raw input would wrap around for unsigned integer dtypes and
    # overflow for the most negative signed integer (e.g. int8 -128).
    output, segments = iso_fn(-math_ops.cast(matrix, output_dtype))
    return -output, segments

  return _wrap_2d_function(inputs, compute_on_matrix, axis)
5869