1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=g-short-docstring-punctuation
16"""Asserts and Boolean Checks."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23
24import numpy as np
25
26from tensorflow.python.eager import context
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import errors
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import sparse_tensor
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.framework import tensor_util
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import control_flow_ops
35from tensorflow.python.ops import math_ops
36from tensorflow.python.util import compat
37from tensorflow.python.util import deprecation
38from tensorflow.python.util import dispatch
39from tensorflow.python.util.tf_export import tf_export
40
# Dtypes this module treats as numeric (presumably consumed by
# is_numeric_tensor further down the file).
# NOTE(review): float16, bfloat16, complex128, uint16 and other dtypes are not
# members of this set -- confirm that is intended before relying on it.
NUMERIC_TYPES = frozenset((
    dtypes.float32,
    dtypes.float64,
    dtypes.int8,
    dtypes.int16,
    dtypes.int32,
    dtypes.int64,
    dtypes.uint8,
    dtypes.qint8,
    dtypes.qint32,
    dtypes.quint8,
    dtypes.complex64,
))

# Names exported by `from <this module> import *`.
__all__ = [
    # Sign checks and iterable validation.
    'assert_negative',
    'assert_positive',
    'assert_proper_iterable',
    'assert_non_negative',
    'assert_non_positive',
    # Element-wise comparisons.
    'assert_equal',
    'assert_none_equal',
    'assert_near',
    'assert_integer',
    'assert_less',
    'assert_less_equal',
    'assert_greater',
    'assert_greater_equal',
    # Rank, dtype and shape checks.
    'assert_rank',
    'assert_rank_at_least',
    'assert_rank_in',
    'assert_same_float_dtype',
    'assert_scalar',
    'assert_type',
    'assert_shapes',
    # Boolean predicates.
    'is_non_decreasing',
    'is_numeric_tensor',
    'is_strictly_increasing',
]
71
72
def _maybe_constant_value_string(t):
  """Returns a string describing `t`, preferring its statically known value.

  Args:
    t: One entry of the `data` list given to an assert_* op; may or may not be
      a `Tensor`.

  Returns:
    A `str`. For non-Tensors this is `str(t)`. For Tensors whose value
    `tensor_util.constant_value` can determine, it is the string form of that
    value. Otherwise it is `str(t)` of the Tensor itself.
  """
  if not isinstance(t, ops.Tensor):
    return str(t)
  const_t = tensor_util.constant_value(t)
  if const_t is not None:
    return str(const_t)
  # The value is not known statically. Still return a string: callers such as
  # _assert_static '\n'.join() these results, and joining a Tensor object
  # would raise TypeError instead of the intended InvalidArgumentError.
  return str(t)
80
81
def _assert_static(condition, data):
  """Raises InvalidArgumentError right away when a static check has failed.

  Args:
    condition: Result of the check evaluated on statically known values;
      truthy means the assertion held.
    data: Items describing the failure; each is rendered with
      _maybe_constant_value_string and placed on its own line of the message.

  Raises:
    errors.InvalidArgumentError: if `condition` is falsy.
  """
  if condition:
    return
  lines = [_maybe_constant_value_string(item) for item in data]
  raise errors.InvalidArgumentError(
      node_def=None, op=None, message='\n'.join(lines))
88
89
def _shape_and_dtype_str(tensor):
  """Describes `tensor` as 'shape=<shape> dtype=<dtype name>' for messages.

  Used in place of `tensor.name` when executing eagerly.
  """
  shape = tensor.shape
  dtype_name = tensor.dtype.name
  return 'shape=%s dtype=%s' % (shape, dtype_name)
93
94
def _unary_assert_doc(sym, sym_name):
  """Common docstring for assert_* ops that evaluate a unary predicate over every element of a tensor.

  Args:
    sym: Mathematical symbol for the check performed on each element, e.g.
      "> 0".
    sym_name: English-language name for the op described by sym, e.g.
      "positive". It is capitalized before being inserted into the text.

  Returns:
    Decorator that adds the appropriate docstring to the function for symbol
    `sym`.
  """

  def _decorator(func):
    """Generated decorator that adds the appropriate docstring to the function for symbol `sym`.

    Args:
      func: Function for a TensorFlow op. Its `__name__` is used as the op
        name in the generated text.

    Returns:
      `func` itself, with its `__doc__` replaced.
    """
    opname = func.__name__
    cap_sym_name = sym_name.capitalize()

    # The decorated functions are the v1 endpoints and check a single tensor
    # `x`, so the usage example calls them as tf.compat.v1.<opname>(x).
    func.__doc__ = """
    Assert the condition `x {sym}` holds element-wise.

    When running in graph mode, you should add a dependency on this operation
    to ensure that it runs. Example of adding a dependency to an operation:

    ```python
    with tf.control_dependencies([tf.compat.v1.{opname}(x)]):
      output = tf.reduce_sum(x)
    ```

    {sym_name} means, for every element `x[i]` of `x`, we have `x[i] {sym}`.
    If `x` is empty this is trivially satisfied.

    Args:
      x:  Numeric `Tensor`.
      data:  The tensors to print out if the condition is False.  Defaults to
        error message and first few entries of `x`.
      summarize: Print this many entries of each tensor.
      message: A string to prefix to the default message.
      name: A name for this operation (optional).  Defaults to "{opname}".

    Returns:
      Op that raises `InvalidArgumentError` if `x {sym}` is False.
      @compatibility(eager)
        returns None
      @end_compatibility

    Raises:
      InvalidArgumentError: if the check can be performed immediately and
        `x {sym}` is False. The check can be performed immediately during
        eager execution or if `x` is statically known.
    """.format(
        sym=sym, sym_name=cap_sym_name, opname=opname)
    return func

  return _decorator
156
157
def _binary_assert_doc(sym):
  """Builds a decorator that installs the shared v1 binary assert_* docstring.

  Args:
    sym: The comparison symbol spliced into the docstring, e.g. "==".

  Returns:
    A decorator that sets `__doc__` on the function it receives, using that
    function's own `__name__` as the op name, and returns the same function.
  """
  template = """
    Assert the condition `x {sym} y` holds element-wise.

    This condition holds if for every pair of (possibly broadcast) elements
    `x[i]`, `y[i]`, we have `x[i] {sym} y[i]`.
    If both `x` and `y` are empty, this is trivially satisfied.

    When running in graph mode, you should add a dependency on this operation
    to ensure that it runs. Example of adding a dependency to an operation:

    ```python
    with tf.control_dependencies([tf.compat.v1.{opname}(x, y)]):
      output = tf.reduce_sum(x)
    ```

    Args:
      x:  Numeric `Tensor`.
      y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
      data:  The tensors to print out if the condition is False.  Defaults to
        error message and first few entries of `x`, `y`.
      summarize: Print this many entries of each tensor.
      message: A string to prefix to the default message.
      name: A name for this operation (optional).  Defaults to "{opname}".

    Returns:
      Op that raises `InvalidArgumentError` if `x {sym} y` is False.
      @compatibility(eager)
        returns None
      @end_compatibility

    Raises:
      InvalidArgumentError: if the check can be performed immediately and
        `x {sym} y` is False. The check can be performed immediately during
        eager execution or if `x` and `y` are statically known.
    """

  def _decorator(func):
    """Attaches the rendered docstring to `func` and returns `func`."""
    func.__doc__ = template.format(sym=sym, opname=func.__name__)
    return func

  return _decorator
219
220
def _make_assert_msg_data(sym, x, y, summarize, test_op):
  """Builds the pieces of _binary_assert's default eager-mode error message.

  Args:
    sym: Mathematical symbol for the test applied to pairs of tensor elements,
      e.g. "==".
    x: First input to the assertion, after `convert_to_tensor()`.
    y: Second input to the assertion.
    summarize: The "summarize" value from the original assert_* call; the
      maximum number of elements of each array to include.
    test_op: Boolean tensor that is True at every position where the
      assertion holds.

  Returns:
    A list that starts with a headline string. If `summarize` is positive,
    it continues with label strings interleaved with numpy arrays. First come
    the leading mismatching indices and their x and y values (only when `x`
    and `y` share the same non-scalar shape). Then come the leading elements
    of `x` and of `y`.
  """
  pieces = ['Condition x %s y did not hold.' % sym]
  if summarize <= 0:
    return pieces

  # Per-position mismatch details are only meaningful when x and y line up
  # one-to-one; for scalars or broadcast comparisons they would confuse.
  if x.shape == y.shape and x.shape.as_list():
    failed = math_ops.logical_not(test_op)
    failed_idx = array_ops.where(failed).numpy()
    x_failed = array_ops.boolean_mask(x, failed)
    y_failed = array_ops.boolean_mask(y, failed)
    count = min(summarize, failed_idx.shape[0])
    pieces.extend([
        'Indices of first %d different values:' % count,
        failed_idx[:count],
        'Corresponding x values:',
        x_failed.numpy().reshape((-1,))[:count],
        'Corresponding y values:',
        y_failed.numpy().reshape((-1,))[:count],
    ])

  # reshape((-1,)) flattens, returning a view when the layout allows it.
  flat_x = x.numpy().reshape((-1,))
  flat_y = y.numpy().reshape((-1,))
  n_x = min(flat_x.size, summarize)
  n_y = min(flat_y.size, summarize)
  pieces.extend([
      'First %d elements of x:' % n_x,
      flat_x[:n_x],
      'First %d elements of y:' % n_y,
      flat_y[:n_y],
  ])
  return pieces
273
274
def _pretty_print(data_item, summarize):
  """Renders one entry of an assert_* `data` list as text for an eager error.

  Args:
    data_item: One item of the "data" argument to an assert_* function; either
      a Tensor or any other value.
    summarize: Maximum number of elements to show for a tensor-valued item.

  Returns:
    `str(data_item)` for non-Tensors and zero-dimensional Tensors. For other
    Tensors, the string form of a list of up to `summarize` stringified
    elements, with '...' appended when elements were left out.
  """
  if not isinstance(data_item, ops.Tensor):
    return str(data_item)

  value = data_item.numpy()
  # Zero-dimensional tensors come back from .numpy() as numpy scalars.
  if np.isscalar(value):
    return str(value)

  elements = value.reshape((-1,))
  shown = [str(element) for element in elements[:summarize]]
  if len(shown) < elements.size:
    shown.append('...')
  return str(shown)
299
300
def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
                   message, name):
  """Generic binary elementwise assertion.

  Implements the behavior described in _binary_assert_doc().

  Args:
    sym: Mathematical symbol for the test to apply to pairs of tensor elements,
      e.g. "==".
    opname: Name of the assert op in the public API, e.g. "assert_equal". Also
      the default name scope.
    op_func: Function that, given the two Tensor inputs to the assertion (x
      and y), returns a Boolean Tensor that is True wherever the assertion
      passes; it is reduced with reduce_all(). E.g. math_ops.equal for
      assert_equal().
    static_func: Function that, given numpy ndarray versions of the two
      inputs to the assertion, returns a Boolean ndarray containing True in
      all positions where the assertion PASSES, e.g. np.equal for
      assert_equal().
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor. In eager mode, None
      means 3 and a negative value means "print everything".
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to the value of
      `opname`.

  Returns:
    None when executing eagerly (the check runs immediately). In graph mode,
    the op returned by control_flow_ops.Assert.

  Raises:
    InvalidArgumentError: when executing eagerly and the check fails, or in
      graph mode when `x` and `y` both have statically known values that fail
      the check.
  """
  with ops.name_scope(name, opname, [x, y, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y')

    if context.executing_eagerly():
      test_op = op_func(x, y)
      condition = math_ops.reduce_all(test_op)
      if condition:
        return

      # If we get here, the assertion has failed.
      # Default to printing 3 elements like control_flow_ops.Assert (used
      # by graph mode) does. Also treat negative values as "print
      # everything" for consistency with Tensor::SummarizeValue().
      if summarize is None:
        summarize = 3
      elif summarize < 0:
        # Must stay an int: it is used as a slice bound (e.g. in
        # _pretty_print), and numpy rejects float slice indices. Code below
        # clamps it to the exact sizes of x and y.
        summarize = 10**9

      if data is None:
        data = _make_assert_msg_data(sym, x, y, summarize, test_op)

      if message is not None:
        data = [message] + list(data)

      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message=('\n'.join(_pretty_print(d, summarize) for d in data)))

    else:  # not context.executing_eagerly()
      if data is None:
        data = [
            'Condition x %s y did not hold element-wise:' % sym,
            'x (%s) = ' % x.name, x,
            'y (%s) = ' % y.name, y
        ]
      if message is not None:
        data = [message] + list(data)
      condition = math_ops.reduce_all(op_func(x, y))
      x_static = tensor_util.constant_value(x)
      y_static = tensor_util.constant_value(y)
      if x_static is not None and y_static is not None:
        # Both inputs are known at graph-construction time: fail right away
        # instead of waiting for the Assert op to run.
        condition_static = np.all(static_func(x_static, y_static))
        _assert_static(condition_static, data)
      return control_flow_ops.Assert(condition, data, summarize=summarize)
374
375
@tf_export(
    'debugging.assert_proper_iterable',
    v1=['debugging.assert_proper_iterable', 'assert_proper_iterable'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_proper_iterable')
def assert_proper_iterable(values):
  """Static assert that values is a "proper" iterable.

  `Ops` that expect an iterable of `Tensor` can call this to validate their
  input. A plain `hasattr(values, '__iter__')` test is not enough, because
  `Tensor`, `ndarray` and byte/text types are themselves iterable.

  Args:
    values:  Object to be checked.

  Raises:
    TypeError:  If `values` is not iterable or is one of
      `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`.
  """
  # Iterable in Python, but not what a caller means by "an iterable".
  rejected_types = ((ops.Tensor, sparse_tensor.SparseTensor, np.ndarray) +
                    compat.bytes_or_text_types)
  if isinstance(values, rejected_types):
    raise TypeError(
        'Expected argument "values" to be a "proper" iterable.  Found: %s' %
        type(values))

  if hasattr(values, '__iter__'):
    return

  raise TypeError(
      'Expected argument "values" to be iterable.  Found: %s' % type(values))
406
407
@tf_export('debugging.assert_negative', v1=[])
@dispatch.add_dispatch_support
def assert_negative_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x < 0` holds element-wise.

  Every element `x[i]` of `x` must satisfy `x[i] < 0`. An empty `x` satisfies
  this trivially.

  When some element of `x` is not negative, `message` and the first
  `summarize` entries of `x` are printed, and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_negative".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all negative. Inside a
      `tf.function`, pass it to `tf.control_dependencies` so that later
      computation waits until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] < 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Thin v2 wrapper: the v1 implementation does the work.
  return assert_negative(
      x=x, summarize=summarize, message=message, name=name)
439
440
@tf_export(v1=['debugging.assert_negative', 'assert_negative'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_negative')
@_unary_assert_doc('< 0', 'negative')
def assert_negative(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  prefix = message if message else ''
  with ops.name_scope(name, 'assert_negative', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # x.name is only used in graph mode; eagerly, x is described by its
      # shape and dtype instead.
      x_desc = (_shape_and_dtype_str(x) if context.executing_eagerly()
                else x.name)
      data = [prefix,
              'Condition x < 0 did not hold element-wise:',
              'x (%s) = ' % x_desc,
              x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # `x < 0` is delegated to the generic `x < zero` comparison.
    return assert_less(x, zero, data=data, summarize=summarize)
460
461
@tf_export('debugging.assert_positive', v1=[])
@dispatch.add_dispatch_support
def assert_positive_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x > 0` holds element-wise.

  Every element `x[i]` of `x` must satisfy `x[i] > 0`. An empty `x` satisfies
  this trivially.

  When some element of `x` is not positive, `message` and the first
  `summarize` entries of `x` are printed, and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_positive".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all positive. Inside a
      `tf.function`, pass it to `tf.control_dependencies` so that later
      computation waits until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] > 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Thin v2 wrapper: the v1 implementation does the work.
  return assert_positive(
      x=x, message=message, summarize=summarize, name=name)
493
494
@tf_export(v1=['debugging.assert_positive', 'assert_positive'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_positive')
@_unary_assert_doc('> 0', 'positive')
def assert_positive(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  prefix = message if message else ''
  with ops.name_scope(name, 'assert_positive', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # x.name is only used in graph mode; eagerly, x is described by its
      # shape and dtype instead.
      x_desc = (_shape_and_dtype_str(x) if context.executing_eagerly()
                else x.name)
      data = [prefix,
              'Condition x > 0 did not hold element-wise:',
              'x (%s) = ' % x_desc,
              x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # `x > 0` is expressed as `zero < x` so the generic less-than assert works.
    return assert_less(zero, x, data=data, summarize=summarize)
513
514
@tf_export('debugging.assert_non_negative', v1=[])
@dispatch.add_dispatch_support
def assert_non_negative_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x >= 0` holds element-wise.

  Every element `x[i]` of `x` must satisfy `x[i] >= 0`. An empty `x`
  satisfies this trivially.

  When some element of `x` is below 0, `message` and the first `summarize`
  entries of `x` are printed, and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to
      "assert_non_negative".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all non-negative. Inside
      a `tf.function`, pass it to `tf.control_dependencies` so that later
      computation waits until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] >= 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Thin v2 wrapper: the v1 implementation does the work.
  return assert_non_negative(
      x=x, message=message, summarize=summarize, name=name)
548
549
@tf_export(v1=['debugging.assert_non_negative', 'assert_non_negative'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_non_negative')
@_unary_assert_doc('>= 0', 'non-negative')
def assert_non_negative(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  prefix = message if message else ''
  with ops.name_scope(name, 'assert_non_negative', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # x.name is only used in graph mode; eagerly, x is described by its
      # shape and dtype instead.
      if context.executing_eagerly():
        x_desc = _shape_and_dtype_str(x)
      else:
        x_desc = x.name
      header = 'Condition x >= 0 did not hold element-wise:'
      data = [prefix, header, 'x (%s) = ' % x_desc, x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # `x >= 0` is expressed as `zero <= x` for the generic less-equal assert.
    return assert_less_equal(zero, x, data=data, summarize=summarize)
569
570
@tf_export('debugging.assert_non_positive', v1=[])
@dispatch.add_dispatch_support
def assert_non_positive_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x <= 0` holds element-wise.

  Every element `x[i]` of `x` must satisfy `x[i] <= 0`. An empty `x`
  satisfies this trivially.

  When some element of `x` is above 0, `message` and the first `summarize`
  entries of `x` are printed, and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to
      "assert_non_positive".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all non-positive. Inside
      a `tf.function`, pass it to `tf.control_dependencies` so that later
      computation waits until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] <= 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Thin v2 wrapper: the v1 implementation does the work.
  return assert_non_positive(
      x=x, message=message, summarize=summarize, name=name)
604
605
@tf_export(v1=['debugging.assert_non_positive', 'assert_non_positive'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_non_positive')
@_unary_assert_doc('<= 0', 'non-positive')
def assert_non_positive(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  message = message or ''
  with ops.name_scope(name, 'assert_non_positive', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # x.name is only used in graph mode; eagerly, x is described by its
      # shape and dtype instead.
      if context.executing_eagerly():
        x_name = _shape_and_dtype_str(x)
      else:
        x_name = x.name
      # The header and the 'x (...) =' label are separate entries, matching
      # assert_negative / assert_positive / assert_non_negative.
      data = [
          message,
          'Condition x <= 0 did not hold element-wise:',
          'x (%s) = ' % x_name, x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # `x <= 0` is delegated to the generic `x <= zero` comparison.
    return assert_less_equal(x, zero, data=data, summarize=summarize)
625
626
@tf_export('debugging.assert_equal', 'assert_equal', v1=[])
@dispatch.add_dispatch_support
def assert_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x == y` holds element-wise.

  Every pair of (possibly broadcast) elements `x[i]`, `y[i]` must satisfy
  `x[i] == y[i]`. If both `x` and `y` are empty, this is trivially satisfied.

  When `x` and `y` are not equal, `message` and the first `summarize` entries
  of `x` and `y` are printed, and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x == y` is False. Inside a
      `tf.function`, pass it to `tf.control_dependencies` so that later
      computation waits until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x == y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # Thin v2 wrapper: the v1 implementation does the work.
  return assert_equal(
      x=x, y=y, message=message, summarize=summarize, name=name)
660
661
@tf_export(v1=['debugging.assert_equal', 'assert_equal'])
@dispatch.add_dispatch_support
@_binary_assert_doc('==')
def assert_equal(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  with ops.name_scope(name, 'assert_equal', [x, y, data]):
    # When x and y are the very same object, skip the element-wise comparison;
    # graph mode still gets an op that callers can depend on.
    if x is y:
      if context.executing_eagerly():
        return None
      return control_flow_ops.no_op()
  # NOTE(review): _binary_assert opens its own name scope after the one above
  # has closed, so in graph mode its ops may get a uniquified scope name
  # (e.g. "assert_equal_1") -- confirm whether that is intended.
  return _binary_assert('==', 'assert_equal', math_ops.equal, np.equal, x, y,
                        data, summarize, message, name)
672
673
@tf_export('debugging.assert_none_equal', v1=[])
@dispatch.add_dispatch_support
def assert_none_equal_v2(x, y, summarize=None, message=None, name=None):
  """Assert the condition `x != y` holds for all elements.

  Every pair of (possibly broadcast) elements `x[i]`, `y[i]` must satisfy
  `x[i] != y[i]`. If both `x` and `y` are empty, this is trivially satisfied.

  When any elements of `x` and `y` are equal, `message` and the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError`
  is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to
      "assert_none_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x != y` is ever False. Inside a
      `tf.function`, pass it to `tf.control_dependencies` so that later
      computation waits until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x != y` is False for any pair of elements in `x` and `y`. The check can
      be performed immediately during eager execution or if `x` and `y` are
      statically known.
  """
  # Thin v2 wrapper: the v1 implementation does the work.
  return assert_none_equal(
      x=x, y=y, message=message, summarize=summarize, name=name)
711
712
@tf_export(v1=['debugging.assert_none_equal', 'assert_none_equal'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_none_equal')
@_binary_assert_doc('!=')
def assert_none_equal(
    x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # Passes only when no (broadcast) element pair of x and y compares equal.
  return _binary_assert(
      sym='!=',
      opname='assert_none_equal',
      op_func=math_ops.not_equal,
      static_func=np.not_equal,
      x=x,
      y=y,
      data=data,
      summarize=summarize,
      message=message,
      name=name)
721
722
@tf_export('debugging.assert_near', v1=[])
@dispatch.add_dispatch_support
def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None,
                   name=None):
  """Assert the condition `x` and `y` are close element-wise.

  This Op checks that `tf.abs(x[i] - y[i]) <= atol + rtol * tf.abs(y[i])`
  holds for every pair of (possibly broadcast) elements of `x` and `y`. If
  both `x` and `y` are empty, this is trivially satisfied.

  If any elements of `x` and `y` are not close, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError`
  is raised.

  The default `atol` and `rtol` are both `10 * eps`, where `eps` is the
  machine epsilon of the (real) dtype of `x`: the smallest representable
  positive number such that `1 + eps != 1`.  `10 * eps` is about `1.2e-6` for
  `32bit`, `2.22e-15` for `64bit`, and `0.00977` for `16bit` floats.
  See `numpy.finfo`.

  Args:
    x: Float or complex `Tensor`.
    y: Float or complex `Tensor`, same dtype as and broadcastable to `x`.
    rtol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x` (for complex
      `x`, the corresponding real dtype). The relative tolerance.  Default is
      `10 * eps`.
    atol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x` (for complex
      `x`, the corresponding real dtype). The absolute tolerance.  Default is
      `10 * eps`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_near".

  Returns:
    Op that raises `InvalidArgumentError` if `x` and `y` are not close enough.
      This can be used with `tf.control_dependencies` inside of `tf.function`s
      to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x` and `y` are not close for some pair of elements. The check can be
      performed immediately during eager execution or if `x` and `y` are
      statically known.

  @compatibility(numpy)
  Similar to `numpy.testing.assert_allclose`, except tolerance depends on data
  type. This is due to the fact that `TensorFlow` is often used with `32bit`,
  `64bit`, and even `16bit` data.
  @end_compatibility
  """
  # Thin v2 wrapper: the v1 implementation does the work.
  return assert_near(x=x, y=y, rtol=rtol, atol=atol, summarize=summarize,
                     message=message, name=name)
775
776
@tf_export(v1=['debugging.assert_near', 'assert_near'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_near')
def assert_near(
    x, y, rtol=None, atol=None, data=None, summarize=None, message=None,
    name=None):
  """Assert the condition `x` and `y` are close element-wise.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_near(x, y)]):
    output = tf.reduce_sum(x)
  ```

  This condition holds if for every pair of (possibly broadcast) elements
  `x[i]`, `y[i]`, we have

  ```tf.abs(x[i] - y[i]) <= atol + rtol * tf.abs(y[i])```.

  If both `x` and `y` are empty, this is trivially satisfied.

  The default `atol` and `rtol` is `10 * eps`, where `eps` is the smallest
  representable positive number such that `1 + eps != 1`.  This is about
  `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`.
  See `numpy.finfo`.

  Args:
    x:  Float or complex `Tensor`.
    y:  Float or complex `Tensor`, same `dtype` as, and broadcastable to, `x`.
    rtol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The relative tolerance.  Default is `10 * eps`.
    atol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The absolute tolerance.  Default is `10 * eps`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_near".

  Returns:
    Op that raises `InvalidArgumentError` if `x` and `y` are not close enough.

  @compatibility(numpy)
  Similar to `numpy.testing.assert_allclose`, except tolerance depends on data
  type. This is due to the fact that `TensorFlow` is often used with `32bit`,
  `64bit`, and even `16bit` data.
  @end_compatibility
  """
  message = message or ''
  with ops.name_scope(name, 'assert_near', [x, y, rtol, atol, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y', dtype=x.dtype)

    # Tolerances bound `abs(x - y)`, which is real even for complex inputs, so
    # they use the real counterpart of a complex dtype.
    dtype = x.dtype
    if dtype.is_complex:
      dtype = dtype.real_dtype
    eps = np.finfo(dtype.as_numpy_dtype).eps
    rtol = 10 * eps if rtol is None else rtol
    atol = 10 * eps if atol is None else atol

    rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=dtype)
    atol = ops.convert_to_tensor(atol, name='atol', dtype=dtype)

    # Eager tensors have no graph name, so describe them by shape and dtype.
    if context.executing_eagerly():
      x_name = _shape_and_dtype_str(x)
      y_name = _shape_and_dtype_str(y)
    else:
      x_name = x.name
      y_name = y.name

    if data is None:
      data = [
          message,
          'x and y not equal to tolerance rtol = %s, atol = %s' % (rtol, atol),
          'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
      ]
    tol = atol + rtol * math_ops.abs(y)
    diff = math_ops.abs(x - y)
    # Use `<=` to match the documented condition and
    # `numpy.testing.assert_allclose`. A strict `<` would reject exactly-equal
    # inputs whenever `atol` and `rtol` are both zero.
    condition = math_ops.reduce_all(math_ops.less_equal(diff, tol))
    return control_flow_ops.Assert(condition, data, summarize=summarize)
858
859
@tf_export('debugging.assert_less', 'assert_less', v1=[])
@dispatch.add_dispatch_support
def assert_less_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x < y` holds element-wise.

  This Op checks that `x[i] < y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. When both `x` and `y` are empty, the
  condition is trivially satisfied.

  When `x` is not less than `y` element-wise, `message` is printed together
  with the first `summarize` entries of `x` and `y`, and
  `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_less".

  Returns:
    Op that raises `InvalidArgumentError` if `x < y` is False.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x < y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # The v2 signature omits `data`. Every other argument is forwarded by
  # keyword to the v1 implementation.
  return assert_less(
      x=x, y=y, message=message, summarize=summarize, name=name)
894
895
@tf_export(v1=['debugging.assert_less', 'assert_less'])
@dispatch.add_dispatch_support
@_binary_assert_doc('<')
def assert_less(x, y, data=None, summarize=None, message=None, name=None):
  # No inline docstring: the `_binary_assert_doc('<')` decorator presumably
  # supplies it. The comparison is passed to the shared helper both as a
  # TensorFlow op (`math_ops.less`) and as the NumPy function (`np.less`).
  return _binary_assert(
      '<', 'assert_less', math_ops.less, np.less,
      x, y, data, summarize, message, name)
902
903
@tf_export('debugging.assert_less_equal', v1=[])
@dispatch.add_dispatch_support
def assert_less_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x <= y` holds element-wise.

  This Op checks that `x[i] <= y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. When both `x` and `y` are empty, the
  condition is trivially satisfied.

  When `x` is not less than or equal to `y` element-wise, `message` is printed
  together with the first `summarize` entries of `x` and `y`, and
  `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_less_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x <= y` is False. This can be
      used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x <= y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # The v2 signature omits `data`. Every other argument is forwarded by
  # keyword to the v1 implementation.
  return assert_less_equal(
      x=x, y=y, message=message, summarize=summarize, name=name)
939
940
@tf_export(v1=['debugging.assert_less_equal', 'assert_less_equal'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_less_equal')
@_binary_assert_doc('<=')
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
  # No inline docstring: the `_binary_assert_doc('<=')` decorator presumably
  # supplies it. The comparison is passed to the shared helper both as a
  # TensorFlow op and as the matching NumPy function.
  return _binary_assert(
      '<=', 'assert_less_equal', math_ops.less_equal, np.less_equal,
      x, y, data, summarize, message, name)
948
949
@tf_export('debugging.assert_greater', 'assert_greater', v1=[])
@dispatch.add_dispatch_support
def assert_greater_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x > y` holds element-wise.

  This Op checks that `x[i] > y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. When both `x` and `y` are empty, the
  condition is trivially satisfied.

  When `x` is not greater than `y` element-wise, `message` is printed together
  with the first `summarize` entries of `x` and `y`, and
  `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_greater".

  Returns:
    Op that raises `InvalidArgumentError` if `x > y` is False. This can be
      used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x > y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # The v2 signature omits `data`. Every other argument is forwarded by
  # keyword to the v1 implementation.
  return assert_greater(
      x=x, y=y, message=message, summarize=summarize, name=name)
985
986
@tf_export(v1=['debugging.assert_greater', 'assert_greater'])
@dispatch.add_dispatch_support
@_binary_assert_doc('>')
def assert_greater(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # No inline docstring: the `_binary_assert_doc('>')` decorator presumably
  # supplies it. The comparison is passed to the shared helper both as a
  # TensorFlow op and as the matching NumPy function.
  return _binary_assert(
      '>', 'assert_greater', math_ops.greater, np.greater,
      x, y, data, summarize, message, name)
993
994
@tf_export('debugging.assert_greater_equal', v1=[])
@dispatch.add_dispatch_support
def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x >= y` holds element-wise.

  This Op checks that `x[i] >= y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. When both `x` and `y` are empty, the
  condition is trivially satisfied.

  When `x` is not greater than or equal to `y` element-wise, `message` is
  printed together with the first `summarize` entries of `x` and `y`, and
  `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to
      "assert_greater_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x >= y` is False. This can be
      used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x >= y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # The v2 signature omits `data`. Every other argument is forwarded by
  # keyword to the v1 implementation.
  return assert_greater_equal(
      x=x, y=y, message=message, summarize=summarize, name=name)
1031
1032
@tf_export(v1=['debugging.assert_greater_equal', 'assert_greater_equal'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_greater_equal')
@_binary_assert_doc('>=')
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
                         name=None):
  # No inline docstring: the `_binary_assert_doc('>=')` decorator presumably
  # supplies it. The comparison is passed to the shared helper both as a
  # TensorFlow op and as the matching NumPy function.
  return _binary_assert(
      '>=', 'assert_greater_equal', math_ops.greater_equal, np.greater_equal,
      x, y, data, summarize, message, name)
1041
1042
def _assert_rank_condition(
    x, rank, static_condition, dynamic_condition, data, summarize):
  """Assert `x` has a rank that satisfies a given condition.

  When both `rank` and the rank of `x` are statically known, the condition is
  decided in Python. Otherwise a runtime `Assert` op is built.

  Args:
    x:  `Tensor` (or `SparseTensor`) whose rank is checked.
    rank:  Scalar int32 `Tensor`.
    static_condition:  Python callable taking `(actual_rank, given_rank)` and
      returning whether the condition is satisfied.
    dynamic_condition:  Callable taking `(actual_rank, given_rank)` tensors and
      returning a boolean op that is `True` iff the condition is satisfied.
    data:  The tensors to print out if the condition is false.
    summarize: Print this many entries of each tensor.

  Returns:
    A `no_op` if the condition was verified statically, otherwise an op
    raising `InvalidArgumentError` if `x` fails `dynamic_condition`.

  Raises:
    TypeError:  If `rank` is not int32 (raised by `assert_type`).
    ValueError:  If `rank` is statically known but not a scalar, or if static
      checks determine `x` fails `static_condition`. In the latter case the
      args are `('Static rank condition failed', actual_rank, given_rank)`,
      which callers unpack to build their own message.
  """
  assert_type(rank, dtypes.int32)

  # Try to resolve the requested rank statically.
  static_rank = tensor_util.constant_value(rank)
  if static_rank is not None:
    if static_rank.ndim != 0:
      raise ValueError('Rank must be a scalar.')
    actual_static_rank = x.get_shape().ndims
    if actual_static_rank is not None:
      if static_condition(actual_static_rank, static_rank):
        return control_flow_ops.no_op(name='static_checks_determined_all_ok')
      raise ValueError(
          'Static rank condition failed', actual_static_rank, static_rank)

  condition = dynamic_condition(array_ops.rank(x), rank)

  if static_rank is None:
    # `rank` is only known at run time, so additionally require it to be a
    # scalar. This catches `assert_rank(x, [n])` written instead of
    # `assert_rank(x, n)`.
    scalar_check = assert_rank(
        rank, 0, data=['Rank must be a scalar. Received rank: ', rank])
    condition = control_flow_ops.with_dependencies([scalar_check], condition)

  return control_flow_ops.Assert(condition, data, summarize=summarize)
1089
1090
@tf_export('debugging.assert_rank', 'assert_rank', v1=[])
@dispatch.add_dispatch_support
def assert_rank_v2(x, rank, message=None, name=None):
  """Assert that `x` has rank equal to `rank`.

  This Op checks that the rank of `x` is equal to `rank`.

  When `x` has a different rank, `message` is printed together with the shape
  of `x`, and `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    rank: Scalar integer `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to
      "assert_rank".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank.
    If static checks determine `x` has correct rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x` does not have rank `rank`. The check can be performed immediately
      during eager execution or if the shape of `x` is statically known.
  """
  # The v2 signature omits `data` and `summarize`. The remaining arguments
  # are forwarded by keyword to the v1 implementation.
  return assert_rank(x=x, rank=rank, name=name, message=message)
1123
1124
@tf_export(v1=['debugging.assert_rank', 'assert_rank'])
@dispatch.add_dispatch_support
def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank equal to `rank`.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x:  Numeric `Tensor` or `SparseTensor`.
    rank:  Scalar integer `Tensor`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_rank".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError:  If static checks determine `x` has wrong rank.
  """
  with ops.name_scope(name, 'assert_rank', (x, rank) + tuple(data or [])):
    is_sparse = isinstance(x, sparse_tensor.SparseTensor)
    if not is_sparse:
      x = ops.convert_to_tensor(x, name='x')
    rank = ops.convert_to_tensor(rank, name='rank')
    message = message or ''

    # Eager tensors and `SparseTensor`s are reported without a name.
    if context.executing_eagerly() or is_sparse:
      tensor_name = ''
    else:
      tensor_name = x.name

    if data is None:
      data = [
          message,
          'Tensor %s must have rank' % tensor_name, rank, 'Received shape: ',
          array_ops.shape(x)
      ]

    try:
      assert_op = _assert_rank_condition(
          x, rank,
          lambda actual_rank, given_rank: actual_rank == given_rank,
          math_ops.equal, data, summarize)
    except ValueError as e:
      if e.args[0] != 'Static rank condition failed':
        raise
      # Rewrite the helper's structured error as a readable message.
      # `e.args` is (tag, actual_rank, given_rank).
      raise ValueError(
          '%s.  Tensor %s must have rank %d.  Received rank %d, shape %s' %
          (message, tensor_name, e.args[2], e.args[1], x.get_shape()))

  return assert_op
1187
1188
@tf_export('debugging.assert_rank_at_least', v1=[])
@dispatch.add_dispatch_support
def assert_rank_at_least_v2(x, rank, message=None, name=None):
  """Assert that `x` has rank of at least `rank`.

  This Op checks that the rank of `x` is greater than or equal to `rank`.

  When `x` has a rank lower than `rank`, `message` is printed together with
  the shape of `x`, and `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    rank: Scalar integer `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to
      "assert_rank_at_least".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or higher.
    If static checks determine `x` has correct rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank at least `rank`, but the rank
      cannot be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  # The v2 signature omits `data` and `summarize`. The remaining arguments
  # are forwarded by keyword to the v1 implementation.
  return assert_rank_at_least(x=x, rank=rank, name=name, message=message)
1221
1222
@tf_export(v1=['debugging.assert_rank_at_least', 'assert_rank_at_least'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_rank_at_least')
def assert_rank_at_least(
    x, rank, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank equal to `rank` or higher.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank_at_least(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x:  Numeric `Tensor` or `SparseTensor`.
    rank:  Scalar `Tensor`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_at_least".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or higher.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError:  If static checks determine `x` has wrong rank.
  """
  with ops.name_scope(
      name, 'assert_rank_at_least', (x, rank) + tuple(data or [])):
    # Accept `SparseTensor` unchanged, as `assert_rank` does. The rank check
    # below only uses `get_shape()`, `array_ops.rank` and `array_ops.shape`,
    # the same operations `assert_rank` applies to sparse inputs. Previously
    # `convert_to_tensor` rejected sparse inputs here.
    is_sparse = isinstance(x, sparse_tensor.SparseTensor)
    if not is_sparse:
      x = ops.convert_to_tensor(x, name='x')
    rank = ops.convert_to_tensor(rank, name='rank')
    message = message or ''

    static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank
    dynamic_condition = math_ops.greater_equal

    # Eager tensors and `SparseTensor`s are reported without a name
    # (`SparseTensor` is never asked for `.name`, mirroring `assert_rank`).
    if context.executing_eagerly() or is_sparse:
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message,
          'Tensor %s must have rank at least' % name, rank,
          'Received shape: ', array_ops.shape(x)
      ]

    try:
      assert_op = _assert_rank_condition(x, rank, static_condition,
                                         dynamic_condition, data, summarize)

    except ValueError as e:
      # `e.args` is (tag, actual_rank, given_rank) for a static failure.
      if e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%s.  Tensor %s must have rank at least %d.  Received rank %d, '
            'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape()))
      else:
        raise

  return assert_op
1288
1289
def _static_rank_in(actual_rank, given_ranks):
  """Return whether `actual_rank` equals any entry of `given_ranks`."""
  return any(actual_rank == candidate for candidate in given_ranks)
1292
1293
def _dynamic_rank_in(actual_rank, given_ranks):
  """Build a boolean op that is true iff `actual_rank` is in `given_ranks`.

  An empty `given_ranks` yields a constant `False` tensor.
  """
  if len(given_ranks) == 0:  # pylint: disable=g-explicit-length-test
    return ops.convert_to_tensor(False)
  matches = [math_ops.equal(candidate, actual_rank)
             for candidate in given_ranks]
  any_match = matches[0]
  for match in matches[1:]:
    any_match = math_ops.logical_or(any_match, match)
  return any_match
1302
1303
def _assert_ranks_condition(
    x, ranks, static_condition, dynamic_condition, data, summarize):
  """Assert `x` has a rank that satisfies a condition over several ranks.

  When every entry of `ranks` and the rank of `x` are statically known, the
  condition is decided in Python. Otherwise a runtime `Assert` op is built.

  Args:
    x:  `Tensor` (or `SparseTensor`) whose rank is checked.
    ranks:  Sequence of scalar int32 `Tensor`s. It is iterated more than once.
    static_condition:  Python callable taking `(actual_rank, given_ranks)` and
      returning whether the condition is satisfied.
    dynamic_condition:  Callable taking `(actual_rank, given_ranks)` and
      returning a boolean op that is `True` iff the condition is satisfied.
    data:  The tensors to print out if the condition is false.
    summarize: Print this many entries of each tensor.

  Returns:
    A `no_op` if the condition was verified statically, otherwise an op
    raising `InvalidArgumentError` if `x` fails `dynamic_condition`.

  Raises:
    TypeError:  If any entry of `ranks` is not int32 (raised by `assert_type`).
    ValueError:  If a statically known rank is not a scalar, or if static
      checks determine `x` fails `static_condition`. In the latter case the
      args are `('Static rank condition failed', actual_rank, given_ranks)`.
  """
  for rank in ranks:
    assert_type(rank, dtypes.int32)

  # Try to resolve every requested rank statically.
  ranks_static = tuple(tensor_util.constant_value(rank) for rank in ranks)
  if all(r is not None for r in ranks_static):
    if any(r.ndim != 0 for r in ranks_static):
      raise ValueError('Rank must be a scalar.')
    x_rank_static = x.get_shape().ndims
    if x_rank_static is not None:
      if static_condition(x_rank_static, ranks_static):
        return control_flow_ops.no_op(name='static_checks_determined_all_ok')
      raise ValueError(
          'Static rank condition failed', x_rank_static, ranks_static)

  condition = dynamic_condition(array_ops.rank(x), ranks)

  # Ranks only known at run time must additionally be scalars. This catches
  # `assert_rank(x, [n])` written instead of `assert_rank(x, n)`.
  for rank, rank_static in zip(ranks, ranks_static):
    if rank_static is not None:
      continue
    scalar_check = assert_rank(
        rank, 0, data=['Rank must be a scalar. Received rank: ', rank])
    condition = control_flow_ops.with_dependencies([scalar_check], condition)

  return control_flow_ops.Assert(condition, data, summarize=summarize)
1354
1355
@tf_export('debugging.assert_rank_in', v1=[])
@dispatch.add_dispatch_support
def assert_rank_in_v2(x, ranks, message=None, name=None):
  """Assert that `x` has a rank in `ranks`.

  This Op checks that the rank of `x` is in `ranks`.

  When `x` has a different rank, `message` is printed together with the shape
  of `x`, and `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    ranks: `Iterable` of scalar `Tensor` objects.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_rank_in".

  Returns:
    Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
    If static checks determine `x` has matching rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank in `ranks`, but the rank cannot
      be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  # The v2 signature omits `data` and `summarize`. The remaining arguments
  # are forwarded by keyword to the v1 implementation.
  return assert_rank_in(x=x, ranks=ranks, name=name, message=message)
1387
1388
@tf_export(v1=['debugging.assert_rank_in', 'assert_rank_in'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_rank_in')
def assert_rank_in(
    x, ranks, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank in `ranks`.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank_in(x, (2, 4))]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x:  Numeric `Tensor` or `SparseTensor`.
    ranks:  Iterable of scalar `Tensor` objects. Any iterable works, including
      one-shot iterators such as generators.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_in".

  Returns:
    Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
    If static checks determine `x` has matching rank, a `no_op` is returned.

  Raises:
    ValueError:  If static checks determine `x` has mismatched rank.
  """
  # Materialize `ranks` once. It is consumed both for the name scope's value
  # list and for the per-rank conversion below. A generator would otherwise
  # be exhausted by the first pass and leave nothing to compare against, so
  # the assertion would always fail.
  ranks = tuple(ranks)
  with ops.name_scope(
      name, 'assert_rank_in', (x,) + ranks + tuple(data or [])):
    if not isinstance(x, sparse_tensor.SparseTensor):
      x = ops.convert_to_tensor(x, name='x')
    ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks])
    message = message or ''

    # Eager tensors and `SparseTensor`s are reported without a name.
    if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor):
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message, 'Tensor %s must have rank in' % name
      ] + list(ranks) + [
          'Received shape: ', array_ops.shape(x)
      ]

    try:
      assert_op = _assert_ranks_condition(x, ranks, _static_rank_in,
                                          _dynamic_rank_in, data, summarize)

    except ValueError as e:
      # `e.args` is (tag, actual_rank, given_ranks) for a static failure.
      if e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%s.  Tensor %s must have rank in %s.  Received rank %d, '
            'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape()))
      else:
        raise

  return assert_op
1452
1453
@tf_export('debugging.assert_integer', v1=[])
@dispatch.add_dispatch_support
def assert_integer_v2(x, message=None, name=None):
  """Assert that `x` is of integer dtype.

  When `x` has a non-integer type, `message` is printed together with the
  dtype of `x`, and `InvalidArgumentError` is raised.

  The dtype is always known statically, so this function returns nothing.

  Args:
    x: A `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_integer".

  Raises:
    TypeError:  If `x.dtype` is not a non-quantized integer type.
  """
  # Called only for its raising side effect. The v1 return value is discarded
  # so this endpoint returns None.
  assert_integer(x=x, name=name, message=message)
1473
1474
@tf_export(v1=['debugging.assert_integer', 'assert_integer'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_integer')
def assert_integer(x, message=None, name=None):
  """Assert that `x` is of integer dtype.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_integer(x)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: `Tensor` whose basetype is integer and is not quantized.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_integer".

  Raises:
    TypeError:  If `x.dtype` is anything other than non-quantized integer.

  Returns:
    A `no_op` that does nothing.  Type can be determined statically.
  """
  message = message or ''
  with ops.name_scope(name, 'assert_integer', [x]):
    x = ops.convert_to_tensor(x, name='x')
    if x.dtype.is_integer:
      return control_flow_ops.no_op('statically_determined_was_integer')
    # Eager tensors have no graph name, so a generic label is used instead.
    tensor_desc = 'tensor' if context.executing_eagerly() else x.name
    raise TypeError(
        '%s  Expected "x" to be integer type.  Found: %s of dtype %s'
        % (message, tensor_desc, x.dtype))
1513
1514
@tf_export('debugging.assert_type', v1=[])
@dispatch.add_dispatch_support
def assert_type_v2(tensor, tf_type, message=None, name=None):
  """Asserts that the given `Tensor` is of the specified type.

  The dtype is always known statically, so this function returns nothing.

  Args:
    tensor: A `Tensor` or `SparseTensor`.
    tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
      etc).
    message: A string to prefix to the default message.
    name:  A name for this operation. Defaults to "assert_type"

  Raises:
    TypeError: If the tensor's data type doesn't match `tf_type`.
  """
  # Called only for its raising side effect. The v1 return value is discarded
  # so this endpoint returns None.
  assert_type(tensor=tensor, tf_type=tf_type, name=name, message=message)
1533
1534
@tf_export(v1=['debugging.assert_type', 'assert_type'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_type')
def assert_type(tensor, tf_type, message=None, name=None):
  """Statically asserts that the given `Tensor` is of the specified type.

  Args:
    tensor: A `Tensor` or `SparseTensor`.
    tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
      etc).
    message: A string to prefix to the default message.
    name:  A name to give this `Op`.  Defaults to "assert_type"

  Raises:
    TypeError: If the tensor's data type doesn't match `tf_type`.

  Returns:
    A `no_op` that does nothing.  Type can be determined statically.
  """
  message = message or ''
  tf_type = dtypes.as_dtype(tf_type)
  with ops.name_scope(name, 'assert_type', [tensor]):
    if not isinstance(tensor, sparse_tensor.SparseTensor):
      tensor = ops.convert_to_tensor(tensor, name='tensor')
    if tensor.dtype == tf_type:
      return control_flow_ops.no_op('statically_determined_correct_type')
    if context.executing_eagerly():
      err_msg = '%s tensor must be of type %s' % (message, tf_type)
    else:
      # Inputs without a `name` attribute are reported with an empty label.
      tensor_label = tensor.name if hasattr(tensor, 'name') else ''
      err_msg = '%s  %s must be of type %s' % (message, tensor_label, tf_type)
    raise TypeError(err_msg)
1568
1569
def _dimension_sizes(x):
  """Return the per-dimension sizes of tensor `x`.

  Sizes known statically are returned as Python ints. Unknown sizes are
  returned as scalar slices of the dynamic shape. A scalar `x` is treated as
  rank 1 with size 1.

  Args:
    x: A `Tensor`.

  Returns:
    `(1,)` if `x` is statically a scalar. A list of ints and/or scalar tensors
    if the rank is statically known and positive. Otherwise a 1-D tensor
    chosen in-graph: `[1]` for a runtime scalar, or the dynamic shape.
  """
  dynamic_shape = array_ops.shape(x)
  static_rank = x.get_shape().rank

  if static_rank == 0:
    return (1,)

  if static_rank is not None and static_rank > 0:
    sizes = []
    for axis, size in enumerate(x.get_shape().as_list()):
      sizes.append(dynamic_shape[axis] if size is None else int(size))
    return sizes

  # Rank unknown statically: pick the scalar placeholder or the real shape
  # at run time.
  return control_flow_ops.cond(
      math_ops.equal(array_ops.rank(x), 0),
      lambda: array_ops.constant([1]),
      lambda: dynamic_shape)
1599
1600
def _symbolic_dimension_sizes(symbolic_shape):
  """Return `symbolic_shape`, or `(1,)` when it is empty (falsy).

  An empty symbolic shape describes a scalar. Like `_dimension_sizes`, a
  scalar is treated as a single dimension of size 1.
  """
  if symbolic_shape:
    return symbolic_shape
  return (1,)
1607
1608
def _has_known_value(dimension_size):
  """Returns True if `dimension_size` is not None and `int()` accepts it."""
  if dimension_size is None:
    return False
  try:
    int(dimension_size)
  except (ValueError, TypeError):
    return False
  return True
1617
1618
def _is_symbol_for_any_size(symbol):
  """Returns True if `symbol` is a wildcard (`None` or '.') for any size."""
  wildcard_symbols = (None, '.')
  return symbol in wildcard_symbols
1621
1622
# Per-tensor record built by `assert_shapes`:
#   x: the tensor whose shape is being checked.
#   unspecified_dim: True when its specified shape began with `...` or `*`.
#   actual_sizes: the result of `_dimension_sizes(x)`.
#   symbolic_sizes: the specified dimension symbols without any leading
#     `...`/`*`, with an empty specification replaced by `(1,)`.
_TensorDimSizes = collections.namedtuple(
    '_TensorDimSizes', 'x unspecified_dim actual_sizes symbolic_sizes')
1626
1627
@tf_export('debugging.assert_shapes', v1=[])
@dispatch.add_dispatch_support
def assert_shapes_v2(shapes, data=None, summarize=None, message=None,
                     name=None):
  """Assert tensor shapes and dimension size relationships between tensors.

  This Op checks that a collection of tensors shape relationships
  satisfies given constraints.

  Example:

  >>> n = 10
  >>> q = 3
  >>> d = 7
  >>> x = tf.zeros([n,q])
  >>> y = tf.ones([n,d])
  >>> param = tf.Variable([1.0, 2.0, 3.0])
  >>> scalar = 1.0
  >>> tf.debugging.assert_shapes([
  ...  (x, ('N', 'Q')),
  ...  (y, ('N', 'D')),
  ...  (param, ('Q',)),
  ...  (scalar, ()),
  ... ])

  >>> tf.debugging.assert_shapes([
  ...   (x, ('N', 'D')),
  ...   (y, ('N', 'D'))
  ... ])
  Traceback (most recent call last):
  ...
  ValueError: ...

  If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
  all specified constraints, `message`, as well as the first `summarize` entries
  of the first encountered violating tensor are printed, and
  `InvalidArgumentError` is raised.

  Size entries in the specified shapes are checked against other entries by
  their __hash__, except:
    - a size entry is interpreted as an explicit size if it can be parsed as an
      integer primitive.
    - a size entry is interpreted as *any* size if it is None or '.'.

  If the first entry of a shape is `...` (type `Ellipsis`) or '*' that indicates
  a variable number of outer dimensions of unspecified size, i.e. the constraint
  applies to the inner-most dimensions only.

  Scalar tensors and specified shapes of length zero (excluding the 'inner-most'
  prefix) are both treated as having a single dimension of size one.

  Args:
    shapes: dictionary with (`Tensor` to shape) items, or a list of
      (`Tensor`, shape) tuples. A shape must be an iterable.
    data: The tensors to print out if the condition is False.  Defaults to error
      message and first few entries of the violating tensor.
    summarize: Print this many entries of the tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_shapes".

  Returns:
    Op raising `InvalidArgumentError` unless all shape constraints are
    satisfied. If static checks determine all constraints are satisfied, a
    `no_op` is returned. The returned op can be passed to
    `tf.control_dependencies` to order later computation after the check.

  Raises:
    ValueError:  If static checks determine any shape constraint is violated.
  """
  # Hand the assertion op back to the caller (as the v1 endpoint does) so it
  # can be used as a control dependency instead of being discarded.
  return assert_shapes(
      shapes, data=data, summarize=summarize, message=message, name=name)
1693
1694
@tf_export(v1=['debugging.assert_shapes'])
@dispatch.add_dispatch_support
def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
  """Assert tensor shapes and dimension size relationships between tensors.

  This Op checks that a collection of tensors shape relationships
  satisfies given constraints.

  Example:

  >>> n = 10
  >>> q = 3
  >>> d = 7
  >>> x = tf.zeros([n,q])
  >>> y = tf.ones([n,d])
  >>> param = tf.Variable([1.0, 2.0, 3.0])
  >>> scalar = 1.0
  >>> tf.debugging.assert_shapes([
  ...  (x, ('N', 'Q')),
  ...  (y, ('N', 'D')),
  ...  (param, ('Q',)),
  ...  (scalar, ()),
  ... ])

  >>> tf.debugging.assert_shapes([
  ...   (x, ('N', 'D')),
  ...   (y, ('N', 'D'))
  ... ])
  Traceback (most recent call last):
  ...
  ValueError: ...

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.debugging.assert_shapes(shapes)]):
    output = tf.matmul(x, y, transpose_a=True)
  ```

  If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
  all specified constraints, `message`, as well as the first `summarize` entries
  of the first encountered violating tensor are printed, and
  `InvalidArgumentError` is raised.

  Size entries in the specified shapes are checked against other entries by
  their __hash__, except:
    - a size entry is interpreted as an explicit size if it can be parsed as an
      integer primitive.
    - a size entry is interpreted as *any* size if it is None or '.'.

  If the first entry of a shape is `...` (type `Ellipsis`) or '*' that indicates
  a variable number of outer dimensions of unspecified size, i.e. the constraint
  applies to the inner-most dimensions only.

  Scalar tensors and specified shapes of length zero (excluding the 'inner-most'
  prefix) are both treated as having a single dimension of size one.

  Args:
    shapes: A list of (`Tensor`, `shape`) tuples, wherein `shape` is the
      expected shape of `Tensor`. See the example code above. The `shape` must
      be an iterable. Each element of the iterable can be either a concrete
      integer value or a string that abstractly represents the dimension.
      For example,
        - `('N', 'Q')` specifies a 2D shape wherein the first and second
          dimensions of shape may or may not be equal.
        - `('N', 'N', 'Q')` specifies a 3D shape wherein the first and second
          dimensions are equal.
        - `(1, 'N')` specifies a 2D shape wherein the first dimension is
          exactly 1 and the second dimension can be any value.
      Note that the abstract dimension letters take effect across different
      tuple elements of the list. For example,
      `tf.debugging.assert_shapes([(x, ('N', 'A')), (y, ('N', 'B'))])` asserts
      that both `x` and `y` are rank-2 tensors and their first dimensions are
      equal (`N`).
      `shape` can also be a `tf.TensorShape`. A dict mapping tensors to
      shapes is also accepted.
    data: The tensors to print out if the condition is False.  Defaults to error
      message and first few entries of the violating tensor.
    summarize: Print this many entries of the tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_shapes".

  Returns:
    Op raising `InvalidArgumentError` unless all shape constraints are
    satisfied.
    If static checks determine all constraints are satisfied, a `no_op` is
    returned.

  Raises:
    ValueError:  If static checks determine any shape constraint is violated.
  """
  # If the user manages to assemble a dict containing tensors (possible in
  # Graph mode only), make sure we still accept that.
  if isinstance(shapes, dict):
    shapes = shapes.items()

  message = message or ''
  with ops.name_scope(name, 'assert_shapes', [shapes, data]):
    # Shape specified as None implies no constraint, so such pairs are dropped.
    # SparseTensors are kept as-is; everything else is converted to a Tensor.
    shape_constraints = [(x if isinstance(x, sparse_tensor.SparseTensor) else
                          ops.convert_to_tensor(x), s)
                         for x, s in shapes if s is not None]

    executing_eagerly = context.executing_eagerly()

    # Describes a tensor for error messages: by shape/dtype string when eager
    # or sparse, otherwise by its graph name.
    def tensor_name(x):
      if executing_eagerly or isinstance(x, sparse_tensor.SparseTensor):
        return _shape_and_dtype_str(x)
      return x.name

    # Validate each specified shape and pair each tensor's actual dimension
    # sizes with its specified symbols.
    tensor_dim_sizes = []
    for tensor, symbolic_shape in shape_constraints:
      is_iterable = (
          hasattr(symbolic_shape, '__iter__') or
          hasattr(symbolic_shape, '__getitem__')  # For Python 2 compat.
      )
      if not is_iterable:
        raise ValueError(
            '%s.  '
            'Tensor %s.  Specified shape must be an iterable.  '
            'An iterable has the attribute `__iter__` or `__getitem__`.  '
            'Received specified shape: %s' %
            (message, tensor_name(tensor), symbolic_shape))

      # We convert this into a tuple to handle strings, lists and numpy arrays
      symbolic_shape_tuple = tuple(symbolic_shape)

      # `...`/`*` is only legal as the first symbol; when present, the
      # remaining symbols constrain only the innermost dimensions.
      tensors_specified_innermost = False
      for i, symbol in enumerate(symbolic_shape_tuple):
        if symbol not in [Ellipsis, '*']:
          continue

        if i != 0:
          raise ValueError(
              '%s.  '
              'Tensor %s specified shape index %d.  '
              'Symbol `...` or `*` for a variable number of '
              'unspecified dimensions is only allowed as the first entry' %
              (message, tensor_name(tensor), i))

        tensors_specified_innermost = True

      # Only include the size of the specified dimensions since the 0th symbol
      # is either ellipsis or *
      tensor_dim_sizes.append(
          _TensorDimSizes(
              tensor, tensors_specified_innermost, _dimension_sizes(tensor),
              _symbolic_dimension_sizes(
                  symbolic_shape_tuple[1:]
                  if tensors_specified_innermost else symbolic_shape_tuple)))

    # Rank checks: a leading `...`/`*` only requires a minimum rank; otherwise
    # the rank must match exactly (with rank 0 and rank 1 interchangeable).
    rank_assertions = []
    for sizes in tensor_dim_sizes:
      rank = len(sizes.symbolic_sizes)
      rank_zero_or_one = rank in [0, 1]
      if sizes.unspecified_dim:
        if rank_zero_or_one:
          # No assertion of rank needed as `x` only needs to have rank at least
          # 0. See elif rank_zero_or_one case comment.
          continue
        assertion = assert_rank_at_least(
            x=sizes.x,
            rank=rank,
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      elif rank_zero_or_one:
        # Rank 0 is treated as rank 1 size 1, i.e. there is
        # no distinction between the two in terms of rank.
        # See _dimension_sizes.
        assertion = assert_rank_in(
            x=sizes.x,
            ranks=[0, 1],
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      else:
        assertion = assert_rank(
            x=sizes.x,
            rank=rank,
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      rank_assertions.append(assertion)

    # Size checks. The first occurrence of a non-integer symbol binds it in
    # `size_specifications` to (size, tensor, dim). Later occurrences, and
    # explicit integer sizes, are compared against that binding: statically
    # (raising ValueError) when both sizes are known, otherwise through a
    # runtime `Assert` op.
    size_assertions = []
    size_specifications = {}
    for sizes in tensor_dim_sizes:
      for i, size_symbol in enumerate(sizes.symbolic_sizes):

        if _is_symbol_for_any_size(size_symbol):
          # Size specified as any implies no constraint
          continue

        # With a leading `...`/`*` the symbols align with the innermost
        # dimensions, so index from the end with a negative offset.
        if sizes.unspecified_dim:
          tensor_dim = i - len(sizes.symbolic_sizes)
        else:
          tensor_dim = i

        if size_symbol in size_specifications or _has_known_value(size_symbol):
          if _has_known_value(size_symbol):
            specified_size = int(size_symbol)
            size_check_message = 'Specified explicitly'
          else:
            specified_size, specified_by_y, specified_at_dim = \
                size_specifications[size_symbol]
            size_check_message = (
                'Specified by tensor %s dimension %d' %
                (tensor_name(specified_by_y), specified_at_dim))

          # This is extremely subtle. If actual_sizes is dynamic, we must
          # make sure a control dependency is inserted here so that this slice
          # can not execute until the rank is asserted to be enough for the
          # slice to not fail.
          with ops.control_dependencies(rank_assertions):
            actual_size = sizes.actual_sizes[tensor_dim]
          if _has_known_value(actual_size) and _has_known_value(specified_size):
            if int(actual_size) != int(specified_size):
              raise ValueError(
                  '%s.  %s.  Tensor %s dimension %s must have size %d.  '
                  'Received size %d, shape %s' %
                  (message, size_check_message, tensor_name(sizes.x),
                   tensor_dim, specified_size, actual_size,
                   sizes.x.get_shape()))
            # No dynamic assertion needed
            continue

          condition = math_ops.equal(
              ops.convert_to_tensor(actual_size),
              ops.convert_to_tensor(specified_size))
          data_ = data
          if data is None:
            data_ = [
                message, size_check_message,
                'Tensor %s dimension' % tensor_name(sizes.x), tensor_dim,
                'must have size', specified_size, 'Received shape: ',
                array_ops.shape(sizes.x)
            ]
          size_assertions.append(
              control_flow_ops.Assert(condition, data_, summarize=summarize))
        else:
          # Not sure if actual_sizes is a constant, but for safety, guard
          # on rank. See explanation above about actual_sizes need for safety.
          with ops.control_dependencies(rank_assertions):
            size = sizes.actual_sizes[tensor_dim]
          size_specifications[size_symbol] = (size, sizes.x, tensor_dim)

  # Ensure both assertions actually occur.
  # NOTE(review): this runs after the `name_scope` above has exited, so the
  # grouping op is not created under that scope.
  with ops.control_dependencies(rank_assertions):
    shapes_assertion = control_flow_ops.group(size_assertions)

  return shapes_assertion
1949
1950
1951# pylint: disable=line-too-long
def _get_diff_for_monotonic_comparison(x):
  """Returns `x[1:] - x[:-1]` computed on `x` flattened in row-major order.

  When the flattened tensor has fewer than two elements there is nothing to
  compare, and an empty tensor of the same dtype is returned.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  flat = array_ops.reshape(x, [-1])
  if not is_numeric_tensor(flat):
    raise TypeError('Expected x to be numeric, instead found: %s' % flat)

  too_short = math_ops.less(array_ops.size(flat), 2)

  def _empty():
    return ops.convert_to_tensor([], dtype=flat.dtype)

  last = array_ops.shape(flat) - 1

  def _adjacent_differences():
    later = array_ops.strided_slice(flat, [1], [1] + last)
    earlier = array_ops.strided_slice(flat, [0], last)
    return later - earlier

  return control_flow_ops.cond(too_short, _empty, _adjacent_differences)
1966
1967
@tf_export(
    'debugging.is_numeric_tensor',
    v1=['debugging.is_numeric_tensor', 'is_numeric_tensor'])
@deprecation.deprecated_endpoints('is_numeric_tensor')
def is_numeric_tensor(tensor):
  """Returns `True` if the elements of `tensor` are numbers.

  `tensor` must be a `tf.Tensor` whose dtype is one of:

  * `tf.float32`
  * `tf.float64`
  * `tf.int8`
  * `tf.int16`
  * `tf.int32`
  * `tf.int64`
  * `tf.uint8`
  * `tf.qint8`
  * `tf.qint32`
  * `tf.quint8`
  * `tf.complex64`

  Anything else, including objects that are not `tf.Tensor`s, gives `False`.
  """
  if not isinstance(tensor, ops.Tensor):
    return False
  return tensor.dtype in NUMERIC_TYPES
1993
1994
@tf_export(
    'math.is_non_decreasing',
    v1=[
        'math.is_non_decreasing', 'debugging.is_non_decreasing',
        'is_non_decreasing'
    ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('debugging.is_non_decreasing',
                                  'is_non_decreasing')
def is_non_decreasing(x, name=None):
  """Returns `True` if `x` is non-decreasing.

  `x` is read in row-major order, and every adjacent pair must satisfy
  `x[i] <= x[i+1]`. A tensor with fewer than two elements is trivially
  non-decreasing.

  See also:  `is_strictly_increasing`

  >>> x1 = tf.constant([1.0, 1.0, 3.0])
  >>> tf.math.is_non_decreasing(x1)
  <tf.Tensor: shape=(), dtype=bool, numpy=True>
  >>> x2 = tf.constant([3.0, 1.0, 2.0])
  >>> tf.math.is_non_decreasing(x2)
  <tf.Tensor: shape=(), dtype=bool, numpy=False>

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional).  Defaults to "is_non_decreasing"

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is non-decreasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_non_decreasing', [x]):
    deltas = _get_diff_for_monotonic_comparison(x)
    zero = ops.convert_to_tensor(0, dtype=deltas.dtype)
    # With fewer than two elements `deltas` is empty and reduce_all gives True.
    step_ok = math_ops.less_equal(zero, deltas)
    return math_ops.reduce_all(step_ok)
2035
2036
@tf_export(
    'math.is_strictly_increasing',
    v1=[
        'math.is_strictly_increasing', 'debugging.is_strictly_increasing',
        'is_strictly_increasing'
    ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('debugging.is_strictly_increasing',
                                  'is_strictly_increasing')
def is_strictly_increasing(x, name=None):
  """Returns `True` if `x` is strictly increasing.

  `x` is read in row-major order, and every adjacent pair must satisfy
  `x[i] < x[i+1]`. A tensor with fewer than two elements is trivially
  strictly increasing.

  See also:  `is_non_decreasing`

  >>> x1 = tf.constant([1.0, 2.0, 3.0])
  >>> tf.math.is_strictly_increasing(x1)
  <tf.Tensor: shape=(), dtype=bool, numpy=True>
  >>> x2 = tf.constant([3.0, 1.0, 2.0])
  >>> tf.math.is_strictly_increasing(x2)
  <tf.Tensor: shape=(), dtype=bool, numpy=False>

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional).
      Defaults to "is_strictly_increasing"

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is strictly increasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_strictly_increasing', [x]):
    deltas = _get_diff_for_monotonic_comparison(x)
    zero = ops.convert_to_tensor(0, dtype=deltas.dtype)
    # With fewer than two elements `deltas` is empty and reduce_all gives True.
    step_ok = math_ops.less(zero, deltas)
    return math_ops.reduce_all(step_ok)
2078
2079
def _assert_same_base_type(items, expected_type=None):
  r"""Asserts all items are of the same base type.

  Args:
    items: Iterable of graph items exposing a `dtype` attribute (e.g.,
        `Variable`, `Tensor`, `SparseTensor`, or `IndexedSlices`). Can include
        `None` elements, which will be ignored. It is traversed exactly once,
        so one-shot iterables such as generators are supported.
    expected_type: Expected type. If not specified, assert all items are
        of the same base type.

  Returns:
    Validated type, or none if neither expected_type nor items provided.

  Raises:
    ValueError: If any types do not match.
  """

  def describe(item):
    # Graph tensors have a `name`; fall back to `str` for anything else.
    return item.name if hasattr(item, 'name') else str(item)

  # When no `expected_type` is given, the first non-None item fixes it; keep
  # that item so the error message can refer to it.
  first_item = None
  for item in items:
    if item is None:
      continue
    item_type = item.dtype.base_dtype
    if not expected_type:
      expected_type = item_type
      first_item = item
    elif expected_type != item_type:
      # Names are only computed here, on the (rare) error path.
      first_item_str = describe(first_item) if first_item is not None else None
      raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % (
          describe(item), item_type, expected_type,
          (' as %s' % first_item_str) if first_item_str else ''))
  return expected_type
2125
2126
@tf_export(
    'debugging.assert_same_float_dtype',
    v1=['debugging.assert_same_float_dtype', 'assert_same_float_dtype'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_same_float_dtype')
def assert_same_float_dtype(tensors=None, dtype=None):
  """Validate and return float type based on `tensors` and `dtype`.

  For ops such as matrix multiplication, inputs and weights must be of the
  same float type. This function validates that all `tensors` are the same type,
  validates that type is `dtype` (if supplied), and returns the type. Type must
  be a floating point type. If neither `tensors` nor `dtype` is supplied,
  the function will return `dtypes.float32`.

  Args:
    tensors: Tensors of input values. Can include `None` elements, which will be
        ignored.
    dtype: Expected type.

  Returns:
    Validated type; `dtypes.float32` if neither `tensors` nor `dtype` is
    supplied.

  Raises:
    ValueError: if the base types of `tensors` differ from each other or from
        `dtype`, or if the resulting type is not a floating point type.
  """
  if tensors:
    # All non-None tensors must share one base type (and match `dtype` if set).
    dtype = _assert_same_base_type(tensors, dtype)
  if not dtype:
    # Nothing to infer the type from: fall back to the default float type.
    dtype = dtypes.float32
  elif not dtype.is_floating:
    raise ValueError('Expected floating point type, got %s.' % dtype)
  return dtype
2160
2161
@tf_export('debugging.assert_scalar', v1=[])
@dispatch.add_dispatch_support
def assert_scalar_v2(tensor, message=None, name=None):
  """Asserts that the given `tensor` is a scalar.

  The check is purely static: a `ValueError` is raised unless `tensor` is
  known to have rank 0, so a tensor whose shape is unknown is also rejected.
  Because nothing is deferred to run time, this function returns nothing.

  Args:
    tensor: A `Tensor`.
    message: A string to prefix to the default message.
    name:  A name for this operation. Defaults to "assert_scalar"

  Raises:
    ValueError: If the tensor is not scalar (rank 0), or if its shape is
      unknown.
  """
  # Delegate to the v1 endpoint; its return value (the converted tensor) is
  # not part of the v2 contract.
  assert_scalar(tensor, name=name, message=message)
2183
2184
@tf_export(v1=['debugging.assert_scalar', 'assert_scalar'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_scalar')
def assert_scalar(tensor, name=None, message=None):
  """Asserts that the given `tensor` is a scalar (i.e. zero-dimensional).

  The check uses only the static shape: `ValueError` is raised unless the
  rank is known to be 0, which includes the case of an unknown shape.

  Args:
    tensor: A `Tensor`.
    name:  A name for this operation. Defaults to "assert_scalar"
    message: A string to prefix to the default message.

  Returns:
    The input tensor (potentially converted to a `Tensor`).

  Raises:
    ValueError: If the tensor is not scalar (rank 0), or if its shape is
      unknown.
  """
  with ops.name_scope(name, 'assert_scalar', [tensor]) as name_scope:
    tensor = ops.convert_to_tensor(tensor, name=name_scope)
    shape = tensor.get_shape()
    if shape.ndims == 0:
      return tensor
    prefix = message or ''
    # Tensor names are only included in the message outside eager mode.
    if context.executing_eagerly():
      raise ValueError('%sExpected scalar shape, saw shape: %s.'
                       % (prefix, shape))
    raise ValueError('%sExpected scalar shape for %s, saw shape: %s.'
                     % (prefix, tensor.name, shape))
2218
2219
@tf_export('ensure_shape')
@dispatch.add_dispatch_support
def ensure_shape(x, shape, name=None):
  """Updates the shape of a tensor and checks at runtime that the shape holds.

  When executed, this operation asserts that the input tensor `x`'s shape
  is compatible with the `shape` argument.
  See `tf.TensorShape.is_compatible_with` for details.

  >>> x = tf.constant([[1, 2, 3],
  ...                  [4, 5, 6]])
  >>> x = tf.ensure_shape(x, [2, 3])

  Use `None` for unknown dimensions:

  >>> x = tf.ensure_shape(x, [None, 3])
  >>> x = tf.ensure_shape(x, [2, None])

  If the tensor's shape is not compatible with the `shape` argument, an error
  is raised:

  >>> x = tf.ensure_shape(x, [5])
  Traceback (most recent call last):
  ...
  tf.errors.InvalidArgumentError: Shape of tensor dummy_input [2,3] is not
    compatible with expected shape [5]. [Op:EnsureShape]

  During graph construction (typically tracing a `tf.function`),
  `tf.ensure_shape` updates the static-shape of the **result** tensor by
  merging the two shapes. See `tf.TensorShape.merge_with` for details.

  This is most useful when **you** know a shape that can't be determined
  statically by TensorFlow.

  The following trivial `tf.function` prints the input tensor's
  static-shape before and after `ensure_shape` is applied.

  >>> @tf.function
  ... def f(tensor):
  ...   print("Static-shape before:", tensor.shape)
  ...   tensor = tf.ensure_shape(tensor, [None, 3])
  ...   print("Static-shape after:", tensor.shape)
  ...   return tensor

  This lets you see the effect of `tf.ensure_shape` when the function is traced:

  >>> cf = f.get_concrete_function(tf.TensorSpec([None, None]))
  Static-shape before: (None, None)
  Static-shape after: (None, 3)

  >>> cf(tf.zeros([3, 3])) # Passes
  >>> cf(tf.constant([1, 2, 3])) # fails
  Traceback (most recent call last):
  ...
  InvalidArgumentError:  Shape of tensor ... [3] is not compatible with expected shape [?,3].

  The above example raises `tf.errors.InvalidArgumentError`, because the input's
  shape, `(3,)`, is not compatible with the `shape` argument, `(None, 3)`.

  Inside a `tf.function` or `v1.Graph` context it checks both the buildtime and
  runtime shapes. This is stricter than `tf.Tensor.set_shape` which only
  checks the buildtime shape.

  Note: This differs from `tf.Tensor.set_shape` in that it sets the static shape
  of the resulting tensor and enforces it at runtime, raising an error if the
  tensor's runtime shape is incompatible with the specified shape.
  `tf.Tensor.set_shape` sets the static shape of the tensor without enforcing it
  at runtime, which may result in inconsistencies between the statically-known
  shape of tensors and the runtime value of tensors.

  For example, when loading images of a known size:

  >>> @tf.function
  ... def decode_image(png):
  ...   image = tf.image.decode_png(png, channels=3)
  ...   # the `print` executes during tracing.
  ...   print("Initial shape: ", image.shape)
  ...   image = tf.ensure_shape(image,[28, 28, 3])
  ...   print("Final shape: ", image.shape)
  ...   return image

  When tracing a function, no ops are executed, so shapes may be unknown.
  See the [Concrete Functions Guide](https://www.tensorflow.org/guide/concrete_function)
  for details.

  >>> concrete_decode = decode_image.get_concrete_function(
  ...     tf.TensorSpec([], dtype=tf.string))
  Initial shape:  (None, None, 3)
  Final shape:  (28, 28, 3)

  >>> image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32)
  >>> image = tf.cast(image,tf.uint8)
  >>> png = tf.image.encode_png(image)
  >>> image2 = concrete_decode(png)
  >>> print(image2.shape)
  (28, 28, 3)

  >>> image = tf.concat([image,image], axis=0)
  >>> print(image.shape)
  (56, 28, 3)
  >>> png = tf.image.encode_png(image)
  >>> image2 = concrete_decode(png)
  Traceback (most recent call last):
  ...
  tf.errors.InvalidArgumentError:  Shape of tensor DecodePng [56,28,3] is not
    compatible with expected shape [28,28,3].

  Caution: if you don't use the result of `tf.ensure_shape` the check may not
  run.

  >>> @tf.function
  ... def bad_decode_image(png):
  ...   image = tf.image.decode_png(png, channels=3)
  ...   # the `print` executes during tracing.
  ...   print("Initial shape: ", image.shape)
  ...   # BAD: forgot to use the returned tensor.
  ...   tf.ensure_shape(image,[28, 28, 3])
  ...   print("Final shape: ", image.shape)
  ...   return image

  >>> image = bad_decode_image(png)
  Initial shape:  (None, None, 3)
  Final shape:  (None, None, 3)
  >>> print(image.shape)
  (56, 28, 3)

  Args:
    x: A `Tensor`.
    shape: A `TensorShape` representing the shape of this tensor, a
      `TensorShapeProto`, a list, a tuple, or None.
    name: A name for this operation (optional). Defaults to "EnsureShape".

  Returns:
    A `Tensor`. Has the same type and contents as `x`.

  Raises:
    tf.errors.InvalidArgumentError: If `shape` is incompatible with the shape
      of `x`.
  """
  # Normalize every accepted `shape` spelling to a `TensorShape` before
  # handing it to the underlying op.
  if not isinstance(shape, tensor_shape.TensorShape):
    shape = tensor_shape.TensorShape(shape)

  return array_ops.ensure_shape(x, shape, name=name)
2362
2363
@ops.RegisterGradient('EnsureShape')
def _ensure_shape_grad(op, grad):
  """Gradient of `EnsureShape`: the incoming gradient is passed through."""
  _ = op  # The forward op is not needed to form this gradient.
  return grad
2368