1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=g-short-docstring-punctuation
16"""Asserts and Boolean Checks."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import numpy as np
23
24from tensorflow.python.eager import context
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import errors
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import sparse_tensor
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.framework import tensor_util
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import control_flow_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.util import compat
35from tensorflow.python.util import deprecation
36from tensorflow.python.util.tf_export import tf_export
37
# Immutable set of dtypes that this module names as numeric.
# NOTE(review): presumably consulted by `is_numeric_tensor` (exported below,
# defined outside this view) -- confirm before adding or removing entries.
NUMERIC_TYPES = frozenset(
    [dtypes.float32, dtypes.float64, dtypes.int8, dtypes.int16, dtypes.int32,
     dtypes.int64, dtypes.uint8, dtypes.qint8, dtypes.qint32, dtypes.quint8,
     dtypes.complex64])
42
# Names exported by `from <this module> import *`. Only the un-suffixed
# assert/check functions are listed; the `*_v2` wrappers are not.
__all__ = [
    'assert_negative',
    'assert_positive',
    'assert_proper_iterable',
    'assert_non_negative',
    'assert_non_positive',
    'assert_equal',
    'assert_none_equal',
    'assert_near',
    'assert_integer',
    'assert_less',
    'assert_less_equal',
    'assert_greater',
    'assert_greater_equal',
    'assert_rank',
    'assert_rank_at_least',
    'assert_rank_in',
    'assert_same_float_dtype',
    'assert_scalar',
    'assert_type',
    'is_non_decreasing',
    'is_numeric_tensor',
    'is_strictly_increasing',
]
67
68
def _maybe_constant_value_string(t):
  """Returns a printable form of `t`, preferring its static value when known.

  Args:
    t: Any object; a `Tensor` receives special treatment.

  Returns:
    `str(t)` when `t` is not a `Tensor`. For a `Tensor`, the string form of
    its statically determined constant value, or `t` itself (unchanged) when
    no constant value can be determined.
  """
  if isinstance(t, ops.Tensor):
    static_value = tensor_util.constant_value(t)
    # No static value available: hand the tensor back untouched.
    return t if static_value is None else str(static_value)
  return str(t)
76
77
def _assert_static(condition, data):
  """Raises an InvalidArgumentError with as much information as possible.

  Args:
    condition: Statically evaluated truth value of the assertion.
    data: Iterable of objects (strings and/or `Tensor`s) describing the
      failure.

  Raises:
    InvalidArgumentError: If `condition` is falsy. The message is the
      newline-joined string form of every entry in `data`.
  """
  if condition:
    return
  # `_maybe_constant_value_string` returns the `Tensor` itself when no static
  # value is available; coerce with `str` so that `join` never receives a
  # non-string item and masks the real failure with a `TypeError`.
  data_static = [str(_maybe_constant_value_string(x)) for x in data]
  raise errors.InvalidArgumentError(node_def=None, op=None,
                                    message='\n'.join(data_static))
84
85
def _shape_and_dtype_str(tensor):
  """Formats `tensor.shape` and `tensor.dtype.name` as a single string."""
  shape, dtype_name = tensor.shape, tensor.dtype.name
  return 'shape=%s dtype=%s' % (shape, dtype_name)
89
90
@tf_export(
    'debugging.assert_proper_iterable',
    v1=['debugging.assert_proper_iterable', 'assert_proper_iterable'])
@deprecation.deprecated_endpoints('assert_proper_iterable')
def assert_proper_iterable(values):
  """Statically checks that `values` is a "proper" iterable.

  Intended for `Ops` that take an iterable of `Tensor`s. A `Tensor`, an
  `ndarray`, and byte/text strings are themselves iterable, so passing one of
  them where a collection was expected is rejected here.

  Args:
    values:  Object to be checked.

  Raises:
    TypeError:  If `values` is not iterable or is one of
      `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`.
  """
  # Types that are rejected outright, before the generic iterability check.
  rejected_types = (ops.Tensor, sparse_tensor.SparseTensor, np.ndarray)
  rejected_types = rejected_types + compat.bytes_or_text_types
  if isinstance(values, rejected_types):
    raise TypeError(
        'Expected argument "values" to be a "proper" iterable.  Found: %s' %
        type(values))

  if hasattr(values, '__iter__'):
    return
  raise TypeError(
      'Expected argument "values" to be iterable.  Found: %s' % type(values))
120
121
@tf_export('debugging.assert_negative', v1=[])
def assert_negative_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x < 0` holds element-wise.

  This Op checks that `x[i] < 0` holds for every element of `x`. If `x` is
  empty, this is trivially satisfied.

  If `x` is not negative everywhere, `message`, as well as the first `summarize`
  entries of `x` are printed, and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_negative".

  Returns:
    The result of `assert_negative`: an Op raising `InvalidArgumentError`
    unless `x` is all negative. Returning it lets graph-mode callers attach
    the check with `tf.control_dependencies`.

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] < 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Propagate the assert op instead of dropping it, so the check can be wired
  # into a graph by the caller.
  return assert_negative(x=x, message=message, summarize=summarize, name=name)
144
145
@tf_export(v1=['debugging.assert_negative', 'assert_negative'])
@deprecation.deprecated_endpoints('assert_negative')
def assert_negative(x, data=None, summarize=None, message=None, name=None):
  """Assert the condition `x < 0` holds element-wise.

  Typical use is as a control dependency of another operation:

  ```python
  with tf.control_dependencies([tf.assert_negative(x)]):
    output = tf.reduce_sum(x)
  ```

  The condition holds when every element `x[i]` of `x` satisfies `x[i] < 0`.
  An empty `x` trivially satisfies it.

  Args:
    x:  Numeric `Tensor`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_negative".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all negative.
  """
  message = message or ''
  with ops.name_scope(name, 'assert_negative', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # During eager execution `x` is labelled by shape/dtype rather than by
      # `x.name`.
      x_label = (_shape_and_dtype_str(x) if context.executing_eagerly()
                 else x.name)
      data = [
          message,
          'Condition x < 0 did not hold element-wise:',
          'x (%s) = ' % x_label,
          x,
      ]
    return assert_less(
        x, ops.convert_to_tensor(0, dtype=x.dtype),
        data=data, summarize=summarize)
186
187
@tf_export('debugging.assert_positive', v1=[])
def assert_positive_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x > 0` holds element-wise.

  This Op checks that `x[i] > 0` holds for every element of `x`. If `x` is
  empty, this is trivially satisfied.

  If `x` is not positive everywhere, `message`, as well as the first `summarize`
  entries of `x` are printed, and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_positive".

  Returns:
    The result of `assert_positive`: an Op raising `InvalidArgumentError`
    unless `x` is all positive. Returning it lets graph-mode callers attach
    the check with `tf.control_dependencies`.

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] > 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Propagate the assert op instead of dropping it, so the check can be wired
  # into a graph by the caller.
  return assert_positive(x=x, summarize=summarize, message=message, name=name)
210
211
@tf_export(v1=['debugging.assert_positive', 'assert_positive'])
@deprecation.deprecated_endpoints('assert_positive')
def assert_positive(x, data=None, summarize=None, message=None, name=None):
  """Assert the condition `x > 0` holds element-wise.

  Typical use is as a control dependency of another operation:

  ```python
  with tf.control_dependencies([tf.assert_positive(x)]):
    output = tf.reduce_sum(x)
  ```

  The condition holds when every element `x[i]` of `x` satisfies `x[i] > 0`.
  An empty `x` trivially satisfies it.

  Args:
    x:  Numeric `Tensor`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_positive".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all positive.
  """
  message = message or ''
  with ops.name_scope(name, 'assert_positive', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # During eager execution `x` is labelled by shape/dtype rather than by
      # `x.name`.
      x_label = (_shape_and_dtype_str(x) if context.executing_eagerly()
                 else x.name)
      data = [
          message,
          'Condition x > 0 did not hold element-wise:',
          'x (%s) = ' % x_label,
          x,
      ]
    # `0 < x` is expressed through `assert_less` with the operands swapped.
    return assert_less(
        ops.convert_to_tensor(0, dtype=x.dtype), x,
        data=data, summarize=summarize)
251
252
@tf_export('debugging.assert_non_negative', v1=[])
def assert_non_negative_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x >= 0` holds element-wise.

  This Op checks that `x[i] >= 0` holds for every element of `x`. If `x` is
  empty, this is trivially satisfied.

  If `x` is not >= 0 everywhere, `message`, as well as the first `summarize`
  entries of `x` are printed, and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to
      "assert_non_negative".

  Returns:
    The result of `assert_non_negative`: an Op raising `InvalidArgumentError`
    unless `x` is all non-negative. Returning it lets graph-mode callers
    attach the check with `tf.control_dependencies`.

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] >= 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Propagate the assert op instead of dropping it, so the check can be wired
  # into a graph by the caller.
  return assert_non_negative(
      x=x, summarize=summarize, message=message, name=name)
276
277
@tf_export(v1=['debugging.assert_non_negative', 'assert_non_negative'])
@deprecation.deprecated_endpoints('assert_non_negative')
def assert_non_negative(x, data=None, summarize=None, message=None, name=None):
  """Assert the condition `x >= 0` holds element-wise.

  Typical use is as a control dependency of another operation:

  ```python
  with tf.control_dependencies([tf.assert_non_negative(x)]):
    output = tf.reduce_sum(x)
  ```

  The condition holds when every element `x[i]` of `x` satisfies `x[i] >= 0`.
  An empty `x` trivially satisfies it.

  Args:
    x:  Numeric `Tensor`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_non_negative".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all non-negative.
  """
  message = message or ''
  with ops.name_scope(name, 'assert_non_negative', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # During eager execution `x` is labelled by shape/dtype rather than by
      # `x.name`.
      x_label = (_shape_and_dtype_str(x) if context.executing_eagerly()
                 else x.name)
      data = [
          message,
          'Condition x >= 0 did not hold element-wise:',
          'x (%s) = ' % x_label,
          x,
      ]
    # `0 <= x` is expressed through `assert_less_equal` with operands swapped.
    return assert_less_equal(
        ops.convert_to_tensor(0, dtype=x.dtype), x,
        data=data, summarize=summarize)
319
320
@tf_export('debugging.assert_non_positive', v1=[])
def assert_non_positive_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x <= 0` holds element-wise.

  This Op checks that `x[i] <= 0` holds for every element of `x`. If `x` is
  empty, this is trivially satisfied.

  If `x` is not <= 0 everywhere, `message`, as well as the first `summarize`
  entries of `x` are printed, and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to
      "assert_non_positive".

  Returns:
    The result of `assert_non_positive`: an Op raising `InvalidArgumentError`
    unless `x` is all non-positive. Returning it lets graph-mode callers
    attach the check with `tf.control_dependencies`.

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] <= 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Propagate the assert op instead of dropping it, so the check can be wired
  # into a graph by the caller.
  return assert_non_positive(
      x=x, summarize=summarize, message=message, name=name)
344
345
@tf_export(v1=['debugging.assert_non_positive', 'assert_non_positive'])
@deprecation.deprecated_endpoints('assert_non_positive')
def assert_non_positive(x, data=None, summarize=None, message=None, name=None):
  """Assert the condition `x <= 0` holds element-wise.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.assert_non_positive(x)]):
    output = tf.reduce_sum(x)
  ```

  Non-positive means, for every element `x[i]` of `x`, we have `x[i] <= 0`.
  If `x` is empty this is trivially satisfied.

  Args:
    x:  Numeric `Tensor`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_non_positive".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all non-positive.
  """
  message = message or ''
  with ops.name_scope(name, 'assert_non_positive', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # During eager execution `x` is labelled by shape/dtype rather than by
      # `x.name`.
      if context.executing_eagerly():
        x_label = _shape_and_dtype_str(x)
      else:
        x_label = x.name
      # Each entry is a separate element. The comma after the condition text
      # matters: without it the two adjacent literals are concatenated into a
      # single fused string.
      data = [
          message,
          'Condition x <= 0 did not hold element-wise:',
          'x (%s) = ' % x_label,
          x,
      ]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    return assert_less_equal(x, zero, data=data, summarize=summarize)
387
388
@tf_export('debugging.assert_equal', 'assert_equal', v1=[])
def assert_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x == y` holds element-wise.

  This Op checks that `x[i] == y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
  trivially satisfied.

  If `x` and `y` are not equal, `message`, as well as the first `summarize`
  entries of `x` and `y` are printed, and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_equal".

  Returns:
    The result of `assert_equal`: an Op that raises `InvalidArgumentError` if
    `x == y` is False, or `None` during eager execution. Returning the Op
    lets graph-mode callers attach the check with `tf.control_dependencies`.

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x == y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # Propagate the assert op instead of dropping it, so the check can be wired
  # into a graph by the caller.
  return assert_equal(
      x=x, y=y, summarize=summarize, message=message, name=name)
413
414
@tf_export(v1=['debugging.assert_equal', 'assert_equal'])
def assert_equal(x, y, data=None, summarize=None, message=None, name=None):
  """Assert the condition `x == y` holds element-wise.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.assert_equal(x, y)]):
    output = tf.reduce_sum(x)
  ```

  This condition holds if for every pair of (possibly broadcast) elements
  `x[i]`, `y[i]`, we have `x[i] == y[i]`.
  If both `x` and `y` are empty, this is trivially satisfied.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`, `y`. Not used during eager
      execution, where the error message is always built from `x` and `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x == y` is False.

    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x == y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  message = message or ''
  with ops.name_scope(name, 'assert_equal', [x, y, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y')

    # Eager path: evaluate the condition right away and raise with a
    # hand-built message; nothing is returned and `data` is never consulted.
    if context.executing_eagerly():
      eq = math_ops.equal(x, y)
      condition = math_ops.reduce_all(eq)
      if not condition:
        # Prepare a message with first elements of x and y.
        summary_msg = ''
        # Default to printing 3 elements like control_flow_ops.Assert (used
        # by graph mode) does.
        summarize = 3 if summarize is None else summarize
        # A `summarize` of 0 suppresses the "First N elements" section.
        if summarize:
          # reshape((-1,)) is the fastest way to get a flat array view.
          x_np = x.numpy().reshape((-1,))
          y_np = y.numpy().reshape((-1,))
          x_sum = min(x_np.size, summarize)
          y_sum = min(y_np.size, summarize)
          summary_msg = ('First %d elements of x:\n%s\n'
                         'First %d elements of y:\n%s\n' %
                         (x_sum, x_np[:x_sum],
                          y_sum, y_np[:y_sum]))

        index_and_values_str = ''
        if x.shape == y.shape and x.shape.as_list():
          # If the shapes of x and y are the same (and not scalars),
          # Get the values that actually differed and their indices.
          # If shapes are different this information is more confusing
          # than useful.
          mask = math_ops.logical_not(eq)
          indices = array_ops.where(mask)
          indices_np = indices.numpy()
          x_vals = array_ops.boolean_mask(x, mask)
          y_vals = array_ops.boolean_mask(y, mask)
          # Never report more entries than there are mismatches.
          summarize = min(summarize, indices_np.shape[0])
          index_and_values_str = (
              'Indices of first %s different values:\n%s\n'
              'Corresponding x values:\n%s\n'
              'Corresponding y values:\n%s\n' %
              (summarize, indices_np[:summarize],
               x_vals.numpy().reshape((-1,))[:summarize],
               y_vals.numpy().reshape((-1,))[:summarize]))

        # NOTE(review): `message or ''` is redundant here; `message` was
        # already normalized to a string at the top of the function.
        raise errors.InvalidArgumentError(
            node_def=None, op=None,
            message=('%s\nCondition x == y did not hold.\n%s%s' %
                     (message or '', index_and_values_str, summary_msg)))
      return

    # Graph path: build the default failure description unless the caller
    # supplied one.
    if data is None:
      data = [
          message,
          'Condition x == y did not hold element-wise:',
          'x (%s) = ' % x.name, x,
          'y (%s) = ' % y.name, y
      ]
    condition = math_ops.reduce_all(math_ops.equal(x, y))
    # When both operands have statically known values, check them now so a
    # violation is raised immediately rather than when the op runs.
    x_static = tensor_util.constant_value(x)
    y_static = tensor_util.constant_value(y)
    if x_static is not None and y_static is not None:
      condition_static = (x_static == y_static).all()
      _assert_static(condition_static, data)
    return control_flow_ops.Assert(condition, data, summarize=summarize)
513
514
@tf_export('debugging.assert_none_equal', v1=[])
def assert_none_equal_v2(x, y, summarize=None, message=None, name=None):
  """Assert the condition `x != y` holds for all elements.

  This Op checks that `x[i] != y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
  trivially satisfied.

  If any elements of `x` and `y` are equal, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError`
  is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to
      "assert_none_equal".

  Returns:
    The result of `assert_none_equal`: an Op that raises
    `InvalidArgumentError` if `x != y` is ever False. Returning it lets
    graph-mode callers attach the check with `tf.control_dependencies`.

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x != y` is False for any pair of elements in `x` and `y`. The check can
      be performed immediately during eager execution or if `x` and `y` are
      statically known.
  """
  # Propagate the assert op instead of dropping it, so the check can be wired
  # into a graph by the caller.
  return assert_none_equal(
      x=x, y=y, summarize=summarize, message=message, name=name)
542
543
@tf_export(v1=['debugging.assert_none_equal', 'assert_none_equal'])
@deprecation.deprecated_endpoints('assert_none_equal')
def assert_none_equal(
    x, y, data=None, summarize=None, message=None, name=None):
  """Assert the condition `x != y` holds for all elements.

  Typical use is as a control dependency of another operation:

  ```python
  with tf.control_dependencies([tf.assert_none_equal(x, y)]):
    output = tf.reduce_sum(x)
  ```

  The condition holds when every pair of (possibly broadcast) elements
  `x[i]`, `y[i]` satisfies `x[i] != y[i]`. It is trivially satisfied when
  both `x` and `y` are empty.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_none_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x != y` is ever False.
  """
  message = message or ''
  with ops.name_scope(name, 'assert_none_equal', [x, y, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y')
    # During eager execution the operands are labelled by shape/dtype rather
    # than by their `name` attribute.
    if context.executing_eagerly():
      x_label, y_label = _shape_and_dtype_str(x), _shape_and_dtype_str(y)
    else:
      x_label, y_label = x.name, y.name

    if data is None:
      data = [
          message,
          'Condition x != y did not hold for every single element:',
          'x (%s) = ' % x_label,
          x,
          'y (%s) = ' % y_label,
          y,
      ]
    all_differ = math_ops.reduce_all(math_ops.not_equal(x, y))
    return control_flow_ops.Assert(all_differ, data, summarize=summarize)
594
595
@tf_export('debugging.assert_near', v1=[])
def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None,
                   name=None):
  """Assert the condition `x` and `y` are close element-wise.

  This Op checks that `tf.abs(x[i] - y[i]) < atol + rtol * tf.abs(y[i])` holds
  for every pair of (possibly broadcast) elements of `x` and `y`. If both `x`
  and `y` are empty, this is trivially satisfied.

  If any elements of `x` and `y` are not close, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError`
  is raised.

  The default `atol` and `rtol` is `10 * eps`, where `eps` is the smallest
  representable positive number such that `1 + eps != 1`.  This is about
  `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`.
  See `numpy.finfo`.

  Args:
    x: Float or complex `Tensor`.
    y: Float or complex `Tensor`, same dtype as and broadcastable to `x`.
    rtol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The relative tolerance.  Default is `10 * eps`.
    atol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The absolute tolerance.  Default is `10 * eps`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_near".

  Returns:
    The result of `assert_near`: an Op that raises `InvalidArgumentError` if
    `x` and `y` are not close enough. Returning it lets graph-mode callers
    attach the check with `tf.control_dependencies`.

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x` and `y` are not close for some pair of elements. The check can
      be performed immediately during eager execution or if `x` and `y` are
      statically known.

  @compatibility(numpy)
  Similar to `numpy.assert_allclose`, except tolerance depends on data type.
  This is due to the fact that `TensorFlow` is often used with `32bit`, `64bit`,
  and even `16bit` data.
  @end_compatibility
  """
  # Propagate the assert op instead of dropping it, so the check can be wired
  # into a graph by the caller.
  return assert_near(x=x, y=y, rtol=rtol, atol=atol, summarize=summarize,
                     message=message, name=name)
639
640
@tf_export(v1=['debugging.assert_near', 'assert_near'])
@deprecation.deprecated_endpoints('assert_near')
def assert_near(
    x, y, rtol=None, atol=None, data=None, summarize=None, message=None,
    name=None):
  """Assert the condition `x` and `y` are close element-wise.

  Typical use is as a control dependency of another operation:

  ```python
  with tf.control_dependencies([tf.assert_near(x, y)]):
    output = tf.reduce_sum(x)
  ```

  The condition holds when every pair of (possibly broadcast) elements
  `x[i]`, `y[i]` satisfies

  ```tf.abs(x[i] - y[i]) < atol + rtol * tf.abs(y[i])```.

  Note the comparison is strict. It is trivially satisfied when both `x` and
  `y` are empty.

  The default `atol` and `rtol` is `10 * eps`, where `eps` is the smallest
  representable positive number such that `1 + eps != 1`.  This is about
  `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`.
  See `numpy.finfo`.

  Args:
    x:  Float or complex `Tensor`.
    y:  Float or complex `Tensor`, same `dtype` as, and broadcastable to, `x`.
    rtol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The relative tolerance.  Default is `10 * eps`.
    atol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The absolute tolerance.  Default is `10 * eps`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_near".

  Returns:
    Op that raises `InvalidArgumentError` if `x` and `y` are not close enough.

  @compatibility(numpy)
  Similar to `numpy.assert_allclose`, except tolerance depends on data type.
  This is due to the fact that `TensorFlow` is often used with `32bit`, `64bit`,
  and even `16bit` data.
  @end_compatibility
  """
  message = message or ''
  with ops.name_scope(name, 'assert_near', [x, y, rtol, atol, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y', dtype=x.dtype)

    # Tolerances left unspecified fall back to ten machine epsilons of the
    # dtype of `x`.
    eps = np.finfo(x.dtype.as_numpy_dtype).eps
    if rtol is None:
      rtol = 10 * eps
    if atol is None:
      atol = 10 * eps

    rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=x.dtype)
    atol = ops.convert_to_tensor(atol, name='atol', dtype=x.dtype)

    # During eager execution the operands are labelled by shape/dtype rather
    # than by their `name` attribute.
    if context.executing_eagerly():
      x_label, y_label = _shape_and_dtype_str(x), _shape_and_dtype_str(y)
    else:
      x_label, y_label = x.name, y.name

    if data is None:
      data = [
          message,
          'x and y not equal to tolerance rtol = %s, atol = %s' % (rtol, atol),
          'x (%s) = ' % x_label,
          x,
          'y (%s) = ' % y_label,
          y,
      ]
    allowed = atol + rtol * math_ops.abs(y)
    distance = math_ops.abs(x - y)
    is_close = math_ops.reduce_all(math_ops.less(distance, allowed))
    return control_flow_ops.Assert(is_close, data, summarize=summarize)
718
719
@tf_export('debugging.assert_less', 'assert_less', v1=[])
def assert_less_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x < y` holds element-wise.

  This Op checks that `x[i] < y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
  trivially satisfied.

  If `x` is not less than `y` element-wise, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` is
  raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_less".

  Returns:
    The result of `assert_less`: an Op that raises `InvalidArgumentError` if
    `x < y` is False. Returning it lets graph-mode callers attach the check
    with `tf.control_dependencies`.

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x < y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # Propagate the assert op instead of dropping it, so the check can be wired
  # into a graph by the caller.
  return assert_less(
      x=x, y=y, summarize=summarize, message=message, name=name)
745
746
@tf_export(v1=['debugging.assert_less', 'assert_less'])
def assert_less(x, y, data=None, summarize=None, message=None, name=None):
  """Assert the condition `x < y` holds element-wise.

  Typical use is as a control dependency of another operation:

  ```python
  with tf.control_dependencies([tf.assert_less(x, y)]):
    output = tf.reduce_sum(x)
  ```

  The condition holds when every pair of (possibly broadcast) elements
  `x[i]`, `y[i]` satisfies `x[i] < y[i]`. It is trivially satisfied when both
  `x` and `y` are empty.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_less".

  Returns:
    Op that raises `InvalidArgumentError` if `x < y` is False.
  """
  message = message or ''
  with ops.name_scope(name, 'assert_less', [x, y, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y')
    # During eager execution the operands are labelled by shape/dtype rather
    # than by their `name` attribute.
    if context.executing_eagerly():
      x_label, y_label = _shape_and_dtype_str(x), _shape_and_dtype_str(y)
    else:
      x_label, y_label = x.name, y.name

    if data is None:
      data = [
          message,
          'Condition x < y did not hold element-wise:',
          'x (%s) = ' % x_label,
          x,
          'y (%s) = ' % y_label,
          y,
      ]
    all_less = math_ops.reduce_all(math_ops.less(x, y))
    return control_flow_ops.Assert(all_less, data, summarize=summarize)
793
794
@tf_export('debugging.assert_less_equal', v1=[])
def assert_less_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x <= y` holds element-wise.

  This Op checks that `x[i] <= y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
  trivially satisfied.

  If `x` is not less or equal than `y` element-wise, `message`, as well as the
  first `summarize` entries of `x` and `y` are printed, and
  `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_less_equal".

  Returns:
    Whatever `assert_less_equal` returns, passed through unchanged so that
    graph-mode callers can attach the check with `tf.control_dependencies`.

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x <= y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # Propagate the result instead of dropping it, so the check can be wired
  # into a graph by the caller.
  return assert_less_equal(
      x=x, y=y, summarize=summarize, message=message, name=name)
820
821
@tf_export(v1=['debugging.assert_less_equal', 'assert_less_equal'])
@deprecation.deprecated_endpoints('assert_less_equal')
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
  """Assert the condition `x <= y` holds element-wise.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.assert_less_equal(x, y)]):
    output = tf.reduce_sum(x)
  ```

  This condition holds if for every pair of (possibly broadcast) elements
  `x[i]`, `y[i]`, we have `x[i] <= y[i]`.
  If both `x` and `y` are empty, this is trivially satisfied.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to
      "assert_less_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x <= y` is False.
  """
  message = message or ''
  with ops.name_scope(name, 'assert_less_equal', [x, y, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y')
    # Operands are labelled via `_shape_and_dtype_str` in eager mode and via
    # their `name` attribute otherwise.
    if context.executing_eagerly():
      x_name = _shape_and_dtype_str(x)
      y_name = _shape_and_dtype_str(y)
    else:
      x_name = x.name
      y_name = y.name

    if data is None:
      # The comma after the condition string matters: without it Python
      # concatenates the two adjacent literals into a single list entry.
      data = [
          message,
          'Condition x <= y did not hold element-wise:',
          'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
      ]
    condition = math_ops.reduce_all(math_ops.less_equal(x, y))
    return control_flow_ops.Assert(condition, data, summarize=summarize)
869
870
@tf_export('debugging.assert_greater', 'assert_greater', v1=[])
def assert_greater_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x > y` holds element-wise.

  This Op checks that `x[i] > y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
  trivially satisfied.

  If `x` is not greater than `y` element-wise, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` is
  raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_greater".

  Returns:
    The value returned by `assert_greater`: an Op that raises
    `InvalidArgumentError` if `x > y` is False.

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x > y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # Return the assert op rather than dropping it, so callers can attach it as
  # a control dependency.
  return assert_greater(x=x, y=y, summarize=summarize, message=message,
                        name=name)
896
897
@tf_export(v1=['debugging.assert_greater', 'assert_greater'])
def assert_greater(x, y, data=None, summarize=None, message=None, name=None):
  """Assert the condition `x > y` holds element-wise.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.assert_greater(x, y)]):
    output = tf.reduce_sum(x)
  ```

  This condition holds if for every pair of (possibly broadcast) elements
  `x[i]`, `y[i]`, we have `x[i] > y[i]`.
  If both `x` and `y` are empty, this is trivially satisfied.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_greater".

  Returns:
    Op that raises `InvalidArgumentError` if `x > y` is False.
  """
  message = message or ''
  with ops.name_scope(name, 'assert_greater', [x, y, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y')
    # Operands are labelled via `_shape_and_dtype_str` in eager mode and via
    # their `name` attribute otherwise.
    if context.executing_eagerly():
      x_name = _shape_and_dtype_str(x)
      y_name = _shape_and_dtype_str(y)
    else:
      x_name = x.name
      y_name = y.name

    if data is None:
      # The comma after the condition string matters: without it Python
      # concatenates the two adjacent literals into a single list entry.
      data = [
          message,
          'Condition x > y did not hold element-wise:',
          'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
      ]
    condition = math_ops.reduce_all(math_ops.greater(x, y))
    return control_flow_ops.Assert(condition, data, summarize=summarize)
944
945
@tf_export('debugging.assert_greater_equal', v1=[])
def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x >= y` holds element-wise.

  This Op checks that `x[i] >= y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
  trivially satisfied.

  If `x` is not greater or equal to `y` element-wise, `message`, as well as the
  first `summarize` entries of `x` and `y` are printed, and
  `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to
      "assert_greater_equal".

  Returns:
    The value returned by `assert_greater_equal`: an Op that raises
    `InvalidArgumentError` if `x >= y` is False.

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x >= y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # Return the assert op rather than dropping it, so callers can attach it as
  # a control dependency.
  return assert_greater_equal(x=x, y=y, summarize=summarize, message=message,
                              name=name)
973
974
@tf_export(v1=['debugging.assert_greater_equal', 'assert_greater_equal'])
@deprecation.deprecated_endpoints('assert_greater_equal')
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
                         name=None):
  """Assert the condition `x >= y` holds element-wise.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.assert_greater_equal(x, y)]):
    output = tf.reduce_sum(x)
  ```

  This condition holds if for every pair of (possibly broadcast) elements
  `x[i]`, `y[i]`, we have `x[i] >= y[i]`.
  If both `x` and `y` are empty, this is trivially satisfied.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to
      "assert_greater_equal"

  Returns:
    Op that raises `InvalidArgumentError` if `x >= y` is False.
  """
  message = message or ''
  with ops.name_scope(name, 'assert_greater_equal', [x, y, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y')
    # Operands are labelled via `_shape_and_dtype_str` in eager mode and via
    # their `name` attribute otherwise.
    if context.executing_eagerly():
      x_name = _shape_and_dtype_str(x)
      y_name = _shape_and_dtype_str(y)
    else:
      x_name = x.name
      y_name = y.name

    if data is None:
      # The comma after the condition string matters: without it Python
      # concatenates the two adjacent literals into a single list entry.
      data = [
          message,
          'Condition x >= y did not hold element-wise:',
          'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
      ]
    condition = math_ops.reduce_all(math_ops.greater_equal(x, y))
    return control_flow_ops.Assert(condition, data, summarize=summarize)
1024
1025
def _assert_rank_condition(
    x, rank, static_condition, dynamic_condition, data, summarize):
  """Build an assertion that the rank of `x` satisfies a condition.

  When both `rank` and the rank of `x` are known statically, the check is
  decided immediately with `static_condition`; otherwise an `Assert` op based
  on `dynamic_condition` is built.

  Args:
    x:  Numeric `Tensor`.
    rank:  Scalar `Tensor`.
    static_condition:   A python function that takes `[actual_rank, given_rank]`
      and returns `True` if the condition is satisfied, `False` otherwise.
    dynamic_condition:  An `op` that takes [actual_rank, given_rank]
      and return `True` if the condition is satisfied, `False` otherwise.
    data:  The tensors to print out if the condition is false.  Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.

  Returns:
    Op raising `InvalidArgumentError` if `x` fails dynamic_condition, or a
    `no_op` when the static checks already passed.

  Raises:
    ValueError:  If static checks determine `x` fails static_condition, or
      if the statically known `rank` is not a scalar.
  """
  assert_type(rank, dtypes.int32)

  # Try to resolve `rank` to a constant so the check can be done up front.
  known_rank = tensor_util.constant_value(rank)
  rank_is_static = known_rank is not None
  if rank_is_static:
    if known_rank.ndim != 0:
      raise ValueError('Rank must be a scalar.')

    known_x_rank = x.get_shape().ndims
    if known_x_rank is not None:
      if static_condition(known_x_rank, known_rank):
        return control_flow_ops.no_op(name='static_checks_determined_all_ok')
      # The marker string in args[0] is matched by the public callers.
      raise ValueError(
          'Static rank condition failed', known_x_rank, known_rank)

  condition = dynamic_condition(array_ops.rank(x), rank)

  # When `rank` is only known at run time, also require it to have rank zero.
  # Prevents the bug where someone does assert_rank(x, [n]), rather than
  # assert_rank(x, n).
  if not rank_is_static:
    scalar_check = assert_rank(
        rank, 0, data=['Rank must be a scalar. Received rank: ', rank])
    condition = control_flow_ops.with_dependencies([scalar_check], condition)

  return control_flow_ops.Assert(condition, data, summarize=summarize)
1072
1073
@tf_export('debugging.assert_rank', 'assert_rank', v1=[])
def assert_rank_v2(x, rank, message=None, name=None):
  """Assert that `x` has rank equal to `rank`.

  This Op checks that the rank of `x` is equal to `rank`.

  If `x` has a different rank, `message`, as well as the shape of `x` are
  printed, and `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    rank: Scalar integer `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to
      "assert_rank".

  Returns:
    The value returned by `assert_rank`: an Op raising `InvalidArgumentError`
    unless `x` has specified rank, or a `no_op` if static checks determine
    `x` has correct rank.

  Raises:
    InvalidArgumentError: `x` does not have rank `rank`, but the rank cannot
      be statically determined.
    ValueError: If static checks determine `x` has wrong rank.
  """
  # Return the assert op rather than dropping it, so callers can attach it as
  # a control dependency.
  return assert_rank(x=x, rank=rank, message=message, name=name)
1096
1097
@tf_export(v1=['debugging.assert_rank', 'assert_rank'])
def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
  """Assert that the rank of `x` equals `rank`.

  Typical use, as a control dependency of another operation:

  ```python
  with tf.control_dependencies([tf.assert_rank(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x:  Numeric `Tensor`.
    rank:  Scalar integer `Tensor`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_rank".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError:  If static checks determine `x` has wrong rank.
  """
  scope_values = (x, rank) + tuple(data or [])
  with ops.name_scope(name, 'assert_rank', scope_values):
    x = ops.convert_to_tensor(x, name='x')
    rank = ops.convert_to_tensor(rank, name='rank')
    message = message or ''

    # From here on `name` is the label used for `x` in error text: blank when
    # executing eagerly, `x.name` otherwise.
    name = '' if context.executing_eagerly() else x.name

    if data is None:
      data = [
          message,
          'Tensor %s must have rank' % name, rank, 'Received shape: ',
          array_ops.shape(x)
      ]

    try:
      assert_op = _assert_rank_condition(
          x, rank,
          lambda actual_rank, given_rank: actual_rank == given_rank,
          math_ops.equal, data, summarize)
    except ValueError as e:
      # Only the static-failure marker is rewritten into a readable message;
      # any other ValueError propagates unchanged.
      if e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%s.  Tensor %s must have rank %d.  Received rank %d, shape %s' %
            (message, name, e.args[2], e.args[1], x.get_shape()))
      raise

  return assert_op
1158
1159
@tf_export('debugging.assert_rank_at_least', v1=[])
def assert_rank_at_least_v2(x, rank, message=None, name=None):
  """Assert that `x` has rank of at least `rank`.

  This Op checks that the rank of `x` is greater or equal to `rank`.

  If `x` has a rank lower than `rank`, `message`, as well as the shape of `x`
  are printed, and `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    rank: Scalar integer `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to
      "assert_rank_at_least".

  Returns:
    The value returned by `assert_rank_at_least`: an Op raising
    `InvalidArgumentError` unless `x` has specified rank or higher, or a
    `no_op` if static checks determine `x` has correct rank.

  Raises:
    InvalidArgumentError: `x` does not have rank at least `rank`, but the rank
      cannot be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  # Return the assert op rather than dropping it, so callers can attach it as
  # a control dependency.
  return assert_rank_at_least(x=x, rank=rank, message=message, name=name)
1182
1183
@tf_export(v1=['debugging.assert_rank_at_least', 'assert_rank_at_least'])
@deprecation.deprecated_endpoints('assert_rank_at_least')
def assert_rank_at_least(
    x, rank, data=None, summarize=None, message=None, name=None):
  """Assert that the rank of `x` is `rank` or higher.

  Typical use, as a control dependency of another operation:

  ```python
  with tf.control_dependencies([tf.assert_rank_at_least(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x:  Numeric `Tensor`.
    rank:  Scalar `Tensor`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_at_least".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or higher.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError:  If static checks determine `x` has wrong rank.
  """
  scope_values = (x, rank) + tuple(data or [])
  with ops.name_scope(name, 'assert_rank_at_least', scope_values):
    x = ops.convert_to_tensor(x, name='x')
    rank = ops.convert_to_tensor(rank, name='rank')
    message = message or ''

    # From here on `name` is the label used for `x` in error text: blank when
    # executing eagerly, `x.name` otherwise.
    name = '' if context.executing_eagerly() else x.name

    if data is None:
      data = [
          message,
          'Tensor %s must have rank at least' % name, rank,
          'Received shape: ', array_ops.shape(x)
      ]

    try:
      assert_op = _assert_rank_condition(
          x, rank,
          lambda actual_rank, given_rank: actual_rank >= given_rank,
          math_ops.greater_equal, data, summarize)
    except ValueError as e:
      # Only the static-failure marker is rewritten into a readable message;
      # any other ValueError propagates unchanged.
      if e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%s.  Tensor %s must have rank at least %d.  Received rank %d, '
            'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape()))
      raise

  return assert_op
1248
1249
def _static_rank_in(actual_rank, given_ranks):
  """Returns whether `actual_rank` is contained in `given_ranks`."""
  is_member = actual_rank in given_ranks
  return is_member
1252
1253
def _dynamic_rank_in(actual_rank, given_ranks):
  """Builds an op testing whether `actual_rank` equals any of `given_ranks`.

  Args:
    actual_rank: The rank to look for.
    given_ranks: Sequence of candidate ranks.

  Returns:
    `False` converted to a tensor when `given_ranks` is empty; otherwise the
    logical OR of the equality comparisons against each candidate.
  """
  if len(given_ranks) < 1:
    return ops.convert_to_tensor(False)

  any_match = math_ops.equal(given_ranks[0], actual_rank)
  for candidate in given_ranks[1:]:
    candidate_matches = math_ops.equal(candidate, actual_rank)
    any_match = math_ops.logical_or(any_match, candidate_matches)
  return any_match
1262
1263
def _assert_ranks_condition(
    x, ranks, static_condition, dynamic_condition, data, summarize):
  """Build an assertion that the rank of `x` satisfies a multi-rank condition.

  When every entry of `ranks` and the rank of `x` are known statically, the
  check is decided immediately with `static_condition`; otherwise an `Assert`
  op based on `dynamic_condition` is built.

  Args:
    x:  Numeric `Tensor`.
    ranks:  Sequence of scalar `Tensor` objects.
    static_condition:   A python function that takes
      `[actual_rank, given_ranks]` and returns `True` if the condition is
      satisfied, `False` otherwise.
    dynamic_condition:  An `op` that takes [actual_rank, given_ranks]
      and return `True` if the condition is satisfied, `False` otherwise.
    data:  The tensors to print out if the condition is false.  Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.

  Returns:
    Op raising `InvalidArgumentError` if `x` fails dynamic_condition, or a
    `no_op` when the static checks already passed.

  Raises:
    ValueError:  If static checks determine `x` fails static_condition, or
      if a statically known rank is not a scalar.
  """
  for rank in ranks:
    assert_type(rank, dtypes.int32)

  # Try to resolve every rank to a constant so the check can be done up front.
  known_ranks = tuple(tensor_util.constant_value(rank) for rank in ranks)
  if all(known is not None for known in known_ranks):
    if any(known.ndim != 0 for known in known_ranks):
      raise ValueError('Rank must be a scalar.')

    known_x_rank = x.get_shape().ndims
    if known_x_rank is not None:
      if static_condition(known_x_rank, known_ranks):
        return control_flow_ops.no_op(name='static_checks_determined_all_ok')
      # The marker string in args[0] is matched by the public caller.
      raise ValueError(
          'Static rank condition failed', known_x_rank, known_ranks)

  condition = dynamic_condition(array_ops.rank(x), ranks)

  # For each rank only known at run time, also require it to have rank zero.
  # Prevents the bug where someone does assert_rank(x, [n]), rather than
  # assert_rank(x, n).
  for rank, known in zip(ranks, known_ranks):
    if known is not None:
      continue
    scalar_check = assert_rank(
        rank, 0, data=['Rank must be a scalar. Received rank: ', rank])
    condition = control_flow_ops.with_dependencies([scalar_check], condition)

  return control_flow_ops.Assert(condition, data, summarize=summarize)
1314
1315
@tf_export('debugging.assert_rank_in', v1=[])
def assert_rank_in_v2(x, ranks, message=None, name=None):
  """Assert that `x` has a rank in `ranks`.

  This Op checks that the rank of `x` is in `ranks`.

  If `x` has a different rank, `message`, as well as the shape of `x` are
  printed, and `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    ranks: `Iterable` of scalar `Tensor` objects.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_rank_in".

  Returns:
    The value returned by `assert_rank_in`: an Op raising
    `InvalidArgumentError` unless rank of `x` is in `ranks`, or a `no_op` if
    static checks determine `x` has matching rank.

  Raises:
    InvalidArgumentError: `x` does not have rank in `ranks`, but the rank cannot
      be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  # Return the assert op rather than dropping it, so callers can attach it as
  # a control dependency.
  return assert_rank_in(x=x, ranks=ranks, message=message, name=name)
1337
1338
@tf_export(v1=['debugging.assert_rank_in', 'assert_rank_in'])
@deprecation.deprecated_endpoints('assert_rank_in')
def assert_rank_in(
    x, ranks, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank in `ranks`.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.assert_rank_in(x, (2, 4))]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x:  Numeric `Tensor`.
    ranks:  Iterable of scalar `Tensor` objects.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message, the given ranks and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_in".

  Returns:
    Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
    If static checks determine `x` has matching rank, a `no_op` is returned.

  Raises:
    ValueError:  If static checks determine `x` has mismatched rank.
  """
  # Materialize `ranks` exactly once.  It is read both for the name scope and
  # for the tensor conversion below, so a one-shot iterable (e.g. a generator)
  # would otherwise be exhausted by the first read and yield no ranks at all.
  ranks = tuple(ranks)
  with ops.name_scope(
      name, 'assert_rank_in', (x,) + ranks + tuple(data or [])):
    x = ops.convert_to_tensor(x, name='x')
    ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks])
    message = message or ''

    # From here on `name` is the label used for `x` in error text: blank when
    # executing eagerly, `x.name` otherwise.
    if context.executing_eagerly():
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message, 'Tensor %s must have rank in' % name
      ] + list(ranks) + [
          'Received shape: ', array_ops.shape(x)
      ]

    try:
      assert_op = _assert_ranks_condition(x, ranks, _static_rank_in,
                                          _dynamic_rank_in, data, summarize)

    except ValueError as e:
      # Only the static-failure marker is rewritten into a readable message;
      # any other ValueError propagates unchanged.
      if e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%s.  Tensor %s must have rank in %s.  Received rank %d, '
            'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape()))
      else:
        raise

  return assert_op
1400
1401
@tf_export('debugging.assert_integer', v1=[])
def assert_integer_v2(x, message=None, name=None):
  """Assert that `x` is of integer dtype.

  Thin wrapper that forwards to `assert_integer`.  If `x` has a non-integer
  dtype, a `TypeError` is raised whose text includes `message` and the dtype
  of `x`.

  Args:
    x: A `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_integer".

  Raises:
    TypeError:  If `x.dtype` is not a non-quantized integer type.
  """
  assert_integer(x, message=message, name=name)
1418
1419
@tf_export(v1=['debugging.assert_integer', 'assert_integer'])
@deprecation.deprecated_endpoints('assert_integer')
def assert_integer(x, message=None, name=None):
  """Assert that `x` has an integer dtype.

  Typical use, as a control dependency of another operation:

  ```python
  with tf.control_dependencies([tf.assert_integer(x)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: `Tensor` whose basetype is integer and is not quantized.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_integer".

  Returns:
    A `no_op` that does nothing.  Type can be determined statically.

  Raises:
    TypeError:  If `x.dtype` is anything other than non-quantized integer.
  """
  message = message or ''
  with ops.name_scope(name, 'assert_integer', [x]):
    x = ops.convert_to_tensor(x, name='x')
    if x.dtype.is_integer:
      return control_flow_ops.no_op('statically_determined_was_integer')

    # Label for the offending tensor: a fixed word when executing eagerly,
    # `x.name` otherwise.
    name = 'tensor' if context.executing_eagerly() else x.name
    raise TypeError(
        '%s  Expected "x" to be integer type.  Found: %s of dtype %s'
        % (message, name, x.dtype))
1458
@tf_export('debugging.assert_type', v1=[])
def assert_type_v2(tensor, tf_type, message=None, name=None):
  """Asserts that the given `Tensor` is of the specified type.

  Thin wrapper that forwards to `assert_type`.

  Args:
    tensor: A `Tensor`.
    tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
      etc).
    message: A string to prefix to the default message.
    name:  A name for this operation. Defaults to "assert_type"

  Raises:
    TypeError: If the tensor's data type doesn't match `tf_type`.
  """
  assert_type(tensor, tf_type, message=message, name=name)
1474
1475
@tf_export(v1=['debugging.assert_type', 'assert_type'])
@deprecation.deprecated_endpoints('assert_type')
def assert_type(tensor, tf_type, message=None, name=None):
  """Statically asserts that the given `Tensor` has dtype `tf_type`.

  Args:
    tensor: A `Tensor`.
    tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
      etc).
    message: A string to prefix to the default message.
    name:  A name to give this `Op`.  Defaults to "assert_type"

  Returns:
    A `no_op` that does nothing.  Type can be determined statically.

  Raises:
    TypeError: If the tensor's data type doesn't match `tf_type`.
  """
  message = message or ''
  with ops.name_scope(name, 'assert_type', [tensor]):
    tensor = ops.convert_to_tensor(tensor, name='tensor')
    if tensor.dtype != tf_type:
      # The eager message uses the generic word "tensor"; otherwise the
      # tensor's own name is reported.
      if context.executing_eagerly():
        err_msg = '%s tensor must be of type %s' % (message, tf_type)
      else:
        err_msg = '%s  %s must be of type %s' % (message, tensor.name, tf_type)
      raise TypeError(err_msg)

    return control_flow_ops.no_op('statically_determined_correct_type')
1506
1507# pylint: disable=line-too-long
def _get_diff_for_monotonic_comparison(x):
  """Returns the adjacent differences `x[1:] - x[:-1]` of flattened `x`.

  Args:
    x: Value that is reshaped to one dimension before differencing.

  Returns:
    The result of a `cond`: an empty tensor of `x`'s dtype when `x` has fewer
    than two elements, otherwise the differences of adjacent elements.

  Raises:
    TypeError: if the flattened `x` is not a numeric tensor.
  """
  x = array_ops.reshape(x, [-1])
  if not is_numeric_tensor(x):
    raise TypeError('Expected x to be numeric, instead found: %s' % x)

  is_shorter_than_two = math_ops.less(array_ops.size(x), 2)
  s_len = array_ops.shape(x) - 1

  def _empty_diff():
    # If x has less than 2 elements, there is nothing to compare.
    return ops.convert_to_tensor([], dtype=x.dtype)

  def _adjacent_diff():
    # With 2 or more elements, compute x[1:] - x[:-1].
    tail = array_ops.strided_slice(x, [1], [1] + s_len)
    head = array_ops.strided_slice(x, [0], s_len)
    return tail - head

  return control_flow_ops.cond(
      is_shorter_than_two, _empty_diff, _adjacent_diff)
1522
1523
@tf_export(
    'debugging.is_numeric_tensor',
    v1=['debugging.is_numeric_tensor', 'is_numeric_tensor'])
@deprecation.deprecated_endpoints('is_numeric_tensor')
def is_numeric_tensor(tensor):
  """Returns `True` if `tensor` is a `tf.Tensor` with a numeric dtype.

  A dtype counts as numeric when it is a member of `NUMERIC_TYPES`, i.e. one
  of:

  * `tf.float32`
  * `tf.float64`
  * `tf.int8`
  * `tf.int16`
  * `tf.int32`
  * `tf.int64`
  * `tf.uint8`
  * `tf.qint8`
  * `tf.qint32`
  * `tf.quint8`
  * `tf.complex64`

  Returns `False` if `tensor` is of a non-numeric type or if `tensor` is not
  a `tf.Tensor` object.
  """
  if not isinstance(tensor, ops.Tensor):
    return False
  return tensor.dtype in NUMERIC_TYPES
1549
1550
@tf_export(
    'math.is_non_decreasing',
    v1=[
        'math.is_non_decreasing', 'debugging.is_non_decreasing',
        'is_non_decreasing'
    ])
@deprecation.deprecated_endpoints('debugging.is_non_decreasing',
                                  'is_non_decreasing')
def is_non_decreasing(x, name=None):
  """Returns `True` if `x` is non-decreasing.

  Elements of `x` are compared in row-major order.  The tensor `[x[0],...]`
  is non-decreasing if for every adjacent pair we have `x[i] <= x[i+1]`.
  If `x` has less than two elements, it is trivially non-decreasing.

  See also:  `is_strictly_increasing`

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional).  Defaults to "is_non_decreasing"

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is non-decreasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_non_decreasing', [x]):
    adjacent_diff = _get_diff_for_monotonic_comparison(x)
    zero = ops.convert_to_tensor(0, dtype=adjacent_diff.dtype)
    # With fewer than two elements the diff is empty, the comparison is empty,
    # and reduce_all([]) is True.
    diff_is_non_negative = math_ops.less_equal(zero, adjacent_diff)
    return math_ops.reduce_all(diff_is_non_negative)
1583
1584
@tf_export(
    'math.is_strictly_increasing',
    v1=[
        'math.is_strictly_increasing', 'debugging.is_strictly_increasing',
        'is_strictly_increasing'
    ])
@deprecation.deprecated_endpoints('debugging.is_strictly_increasing',
                                  'is_strictly_increasing')
def is_strictly_increasing(x, name=None):
  """Returns `True` if `x` is strictly increasing.

  Elements of `x` are compared in row-major order.  The tensor `[x[0],...]`
  is strictly increasing if for every adjacent pair we have `x[i] < x[i+1]`.
  If `x` has less than two elements, it is trivially strictly increasing.

  See also:  `is_non_decreasing`

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional).
      Defaults to "is_strictly_increasing"

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is strictly increasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_strictly_increasing', [x]):
    adjacent_diff = _get_diff_for_monotonic_comparison(x)
    zero = ops.convert_to_tensor(0, dtype=adjacent_diff.dtype)
    # With fewer than two elements the diff is empty, the comparison is empty,
    # and reduce_all([]) is True.
    diff_is_positive = math_ops.less(zero, adjacent_diff)
    return math_ops.reduce_all(diff_is_positive)
1618
1619
def _assert_same_base_type(items, expected_type=None):
  r"""Checks that every non-`None` item shares one base dtype.

  Args:
    items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`,
        `Operation`, or `IndexedSlices`). Can include `None` elements, which
        will be ignored.
    expected_type: Expected type. If not specified, the base type of the first
        non-`None` item becomes the type all others must match.

  Returns:
    Validated type, or none if neither expected_type nor items provided.

  Raises:
    ValueError: If any types do not match.
  """
  # Fast pass: only detect whether a mismatch exists.
  agreed_type = expected_type
  for item in items:
    if item is None:
      continue
    item_type = item.dtype.base_dtype
    if not agreed_type:
      agreed_type = item_type
    elif agreed_type != item_type:
      break
  else:
    # Loop finished without a mismatch.
    return agreed_type

  # Slow pass, only taken on a mismatch: walk the items again from the
  # caller's original expectation to build an informative error message.
  agreed_type = expected_type
  reference_str = None
  for item in items:
    if item is None:
      continue
    item_type = item.dtype.base_dtype
    if not agreed_type:
      agreed_type = item_type
      reference_str = item.name if hasattr(item, 'name') else str(item)
    elif agreed_type != item_type:
      raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % (
          item.name if hasattr(item, 'name') else str(item),
          item_type, agreed_type,
          (' as %s' % reference_str) if reference_str else ''))
  return agreed_type  # Should be unreachable
1665
1666
@tf_export(
    'debugging.assert_same_float_dtype',
    v1=['debugging.assert_same_float_dtype', 'assert_same_float_dtype'])
@deprecation.deprecated_endpoints('assert_same_float_dtype')
def assert_same_float_dtype(tensors=None, dtype=None):
  """Validate and return float type based on `tensors` and `dtype`.

  Ops such as matrix multiplication require inputs and weights to share one
  float type. This function checks that all `tensors` have the same base type,
  that this type equals `dtype` (when supplied), and that it is a floating
  point type, and then returns it. When no type can be derived (neither
  `tensors` nor `dtype` yields one), `dtypes.float32` is returned.

  Args:
    tensors: Tensors of input values. Can include `None` elements, which will be
        ignored.
    dtype: Expected type.

  Returns:
    Validated type.

  Raises:
    ValueError: if the tensors do not share a single base type, if that type
        differs from `dtype`, or if the resulting type is not a floating point
        type.
  """
  validated = dtype
  if tensors:
    validated = _assert_same_base_type(tensors, dtype)
  if not validated:
    # Nothing to derive the type from: fall back to the default float type.
    return dtypes.float32
  if not validated.is_floating:
    raise ValueError('Expected floating point type, got %s.' % validated)
  return validated
1699
1700
@tf_export('debugging.assert_scalar', v1=[])
def assert_scalar_v2(tensor, message=None, name=None):
  """Asserts that the given `tensor` is a scalar.

  `ValueError` is raised unless the given `tensor` is known to be a scalar;
  this includes the case where the shape of `tensor` is unknown. The check is
  delegated to `assert_scalar`; this function itself returns nothing.

  Args:
    tensor: A `Tensor`.
    message: A string to prefix to the default message.
    name:  A name for this operation. Defaults to "assert_scalar"

  Raises:
    ValueError: If the tensor is not scalar (rank 0), or if its shape is
      unknown.
  """
  # Same check as the v1 endpoint; only the order of `message` and `name` in
  # the signature differs, and the result is intentionally discarded.
  assert_scalar(tensor, name=name, message=message)
1719
1720
@tf_export(v1=['debugging.assert_scalar', 'assert_scalar'])
@deprecation.deprecated_endpoints('assert_scalar')
def assert_scalar(tensor, name=None, message=None):
  """Asserts that the given `tensor` is a scalar (i.e. zero-dimensional).

  `ValueError` is raised unless the static shape of `tensor` is known to have
  rank 0; an unknown shape therefore raises as well.

  Args:
    tensor: A `Tensor`.
    name:  A name for this operation. Defaults to "assert_scalar"
    message: A string to prefix to the default message.

  Returns:
    The input tensor (potentially converted to a `Tensor`).

  Raises:
    ValueError: If the tensor is not scalar (rank 0), or if its shape is
      unknown.
  """
  with ops.name_scope(name, 'assert_scalar', [tensor]) as scope:
    tensor = ops.convert_to_tensor(tensor, name=scope)
    static_shape = tensor.get_shape()
    if static_shape.ndims == 0:
      return tensor

    prefix = message or ''
    if context.executing_eagerly():
      # The tensor name is left out of the message when executing eagerly.
      raise ValueError('%sExpected scalar shape, saw shape: %s.'
                       % (prefix, static_shape))
    raise ValueError('%sExpected scalar shape for %s, saw shape: %s.'
                     % (prefix, tensor.name, static_shape))
1753
1754
@tf_export('ensure_shape')
def ensure_shape(x, shape, name=None):
  """Updates the shape of a tensor and checks at runtime that the shape holds.

  For example:
  ```python
  x = tf.placeholder(tf.int32)
  print(x.shape)
  ==> TensorShape(None)
  y = x * 2
  print(y.shape)
  ==> TensorShape(None)

  y = tf.ensure_shape(y, (None, 3, 3))
  print(y.shape)
  ==> TensorShape([Dimension(None), Dimension(3), Dimension(3)])

  with tf.Session() as sess:
    # Raises tf.errors.InvalidArgumentError, because the shape (3,) is not
    # compatible with the shape (None, 3, 3)
    sess.run(y, feed_dict={x: [1, 2, 3]})

  ```

  NOTE: Unlike `Tensor.set_shape`, this both sets the static shape of the
  resulting tensor and enforces it at runtime, raising an error if the
  tensor's runtime shape is incompatible with the specified shape.
  `Tensor.set_shape` only sets the static shape without any runtime check,
  which may lead to inconsistencies between the statically-known shape of
  tensors and the runtime value of tensors.

  Args:
    x: A `Tensor`.
    shape: A `TensorShape` representing the shape of this tensor, a
      `TensorShapeProto`, a list, a tuple, or None.
    name: A name for this operation (optional). Defaults to "EnsureShape".

  Returns:
    A `Tensor`. Has the same type and contents as `x`. At runtime, raises a
    `tf.errors.InvalidArgumentError` if `shape` is incompatible with the shape
    of `x`.
  """
  # Anything that is not already a `TensorShape` is wrapped in one first.
  static_shape = (
      shape if isinstance(shape, tensor_shape.TensorShape)
      else tensor_shape.TensorShape(shape))
  return array_ops.ensure_shape(x, static_shape, name=name)
1801
1802
@ops.RegisterGradient('EnsureShape')
def _ensure_shape_grad(op, grad):  # pylint: disable=unused-argument
  """Gradient function registered for `EnsureShape`: returns `grad` as is."""
  return grad
1807