1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for py_builtins module."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import sys
22
23import six
24
25from tensorflow.python.autograph.core import converter
26from tensorflow.python.autograph.core import function_wrappers
27from tensorflow.python.autograph.operators import data_structures
28from tensorflow.python.autograph.operators import py_builtins
29from tensorflow.python.data.ops import dataset_ops
30from tensorflow.python.eager import def_function
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import errors_impl
34from tensorflow.python.framework import test_util
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import math_ops
37from tensorflow.python.ops import tensor_array_ops
38from tensorflow.python.platform import test
39
40
class TestBase(object):
  """Base class whose method the super() tests below expect to reach."""

  def overridden_method(self, x):
    """Returns `x` plus a fixed offset of 20."""
    offset = 20
    return x + offset
45
46
@test_util.run_all_in_graph_and_eager_modes
class PyBuiltinsTest(test.TestCase):
  """Tests for the autograph overloads of Python builtins in `py_builtins`.

  Each builtin is checked on plain Python values and, where an overload
  exists, on tensors, TensorArrays, tensor lists, datasets or iterators.
  """

  def test_abs(self):
    """Checks abs_ on a Python int and on scalar and vector tensors."""
    self.assertEqual(py_builtins.abs_(-1), 1)
    with self.cached_session():
      t = py_builtins.abs_(constant_op.constant(-1))
      self.assertEqual(self.evaluate(t), 1)
      t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
      self.assertAllEqual(self.evaluate(t), [1, 2, 3])

  def test_abs_dataset(self):
    """Checks abs_ applied element-wise to a flat dataset."""
    dataset = dataset_ops.DatasetV2.from_tensor_slices([-1, 2, 3])
    dataset = py_builtins.abs_(dataset)
    iterator = dataset_ops.make_one_shot_iterator(dataset)
    with self.cached_session():
      self.assertAllEqual(self.evaluate(iterator.get_next()), 1)
      self.assertAllEqual(self.evaluate(iterator.get_next()), 2)
      self.assertAllEqual(self.evaluate(iterator.get_next()), 3)

  def test_abs_dataset_zipped(self):
    """Checks abs_ on a dataset whose elements are 2-tuples."""
    dataset_1 = dataset_ops.DatasetV2.from_tensor_slices([-1, 2, 3])
    dataset_2 = dataset_ops.DatasetV2.from_tensor_slices([1, -2, 3])
    dataset = dataset_ops.DatasetV2.zip((dataset_1, dataset_2))
    dataset = py_builtins.abs_(dataset)
    iterator = dataset_ops.make_one_shot_iterator(dataset)
    with self.cached_session():
      self.assertAllEqual(self.evaluate(iterator.get_next()), (1, 1))
      self.assertAllEqual(self.evaluate(iterator.get_next()), (2, 2))
      self.assertAllEqual(self.evaluate(iterator.get_next()), (3, 3))

  def test_abs_dataset_mixed(self):
    """Checks abs_ on a dataset of nested tuples (scalar, (scalar, scalar))."""
    dataset_1 = dataset_ops.DatasetV2.from_tensor_slices([-1, 2, 3])
    dataset_2 = dataset_ops.DatasetV2.from_tensor_slices([1, -2, 3])
    dataset_3 = dataset_ops.DatasetV2.from_tensor_slices([-1, -2, -3])
    dataset_4 = dataset_ops.DatasetV2.zip((dataset_1, dataset_2))
    dataset = dataset_ops.DatasetV2.zip((dataset_3, dataset_4))
    dataset = py_builtins.abs_(dataset)
    iterator = dataset_ops.make_one_shot_iterator(dataset)
    with self.cached_session():
      for i in range(1, 4):
        actual = self.evaluate(iterator.get_next())
        self.assertAllEqual(actual[0], i)
        self.assertAllEqual(actual[1], (i, i))

  def test_float(self):
    """Checks float_ on Python values and on int64 and string tensors."""
    self.assertEqual(py_builtins.float_(10), 10.0)
    self.assertEqual(py_builtins.float_('10.0'), 10.0)
    with self.cached_session():
      t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
      self.assertEqual(self.evaluate(t), 1.0)
      st = py_builtins.float_(constant_op.constant('1.0'))
      self.assertEqual(self.evaluate(st), 1.0)

  def test_int(self):
    """Checks int_ on Python values (with a base) and on tensors."""
    self.assertEqual(py_builtins.int_(10.0), 10)
    self.assertEqual(py_builtins.int_('11', 2), 3)
    with self.cached_session():
      t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
      self.assertEqual(self.evaluate(t), 1)
      st = py_builtins.int_(constant_op.constant('1'))
      self.assertEqual(self.evaluate(st), 1)
      st = py_builtins.int_(constant_op.constant('1'), 10)
      self.assertEqual(self.evaluate(st), 1)

  def test_int_unsupported_base(self):
    """Checks that int_ rejects a base argument for a non-string tensor."""
    t = constant_op.constant(1, dtype=dtypes.float64)
    with self.assertRaises(NotImplementedError):
      py_builtins.int_(t, 2)

  def test_len(self):
    """Checks len_ on a list, a tensor, a TensorArray and a tensor list."""
    self.assertEqual(py_builtins.len_([1, 2, 3]), 3)
    with self.cached_session():
      t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
      # NOTE(review): compared without evaluate(), so len_ presumably returns
      # a Python int for statically shaped tensors.
      self.assertEqual(t, 3)
      ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
      self.assertEqual(self.evaluate(ta), 5)
      tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5]))
      self.assertEqual(self.evaluate(tl), 3)

  def test_len_dataset(self):
    """Checks len_ on a finite dataset, both eagerly and inside a function."""
    dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1])
    self.assertEqual(self.evaluate(py_builtins.len_(dataset)), 3)

    # graph mode
    @def_function.function(autograph=False)
    def test_fn():
      dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1])
      return py_builtins.len_(dataset)

    self.assertEqual(self.evaluate(test_fn()), 3)

  def test_len_dataset_infinite(self):
    """Checks that len_ of an infinitely repeated dataset raises."""
    dataset = dataset_ops.DatasetV2.range(5).repeat().batch(2)
    with self.assertRaises(errors_impl.InvalidArgumentError):
      _ = self.evaluate(py_builtins.len_(dataset))

    # graph mode
    @def_function.function
    def test_fn():
      dataset = dataset_ops.DatasetV2.range(5).repeat().batch(2)
      return py_builtins.len_(dataset)

    with self.assertRaises(errors_impl.InvalidArgumentError):
      self.evaluate(test_fn())

  def test_len_dataset_unknown(self):
    """Checks that len_ of a filtered dataset raises."""
    dataset = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2)
    with self.assertRaises(errors_impl.InvalidArgumentError):
      _ = self.evaluate(py_builtins.len_(dataset))

    # graph mode
    @def_function.function(autograph=False)
    def test_fn():
      dataset = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2)
      return py_builtins.len_(dataset)

    with self.assertRaises(errors_impl.InvalidArgumentError):
      self.evaluate(test_fn())

  def test_len_scalar(self):
    """Checks that len_ of a statically scalar tensor raises ValueError."""
    with self.assertRaises(ValueError):
      py_builtins.len_(constant_op.constant(1))

  @test_util.run_deprecated_v1
  def test_len_dynamic_shape(self):
    """Checks len_ on a placeholder whose shape is only known at run time."""
    with self.cached_session() as sess:
      p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
      t = py_builtins.len_(p)
      self.assertEqual(sess.run(t, {p: [1, 2, 3]}), 3)

      # Feeding a scalar into the same op must fail at run time.
      with self.assertRaises(errors_impl.InvalidArgumentError):
        sess.run(t, {p: 1})

  def _run_and_capture_stdout(self, build_op):
    """Runs the op returned by `build_op` in a session, capturing stdout.

    The op is built inside the session scope. The `sys.stdout` that was active
    before the call is restored afterwards, even if building or running the op
    raises. Restoring `sys.__stdout__` instead would discard any redirection
    installed by the test runner.

    Args:
      build_op: zero-argument callable returning the op to run.

    Returns:
      The text written to `sys.stdout` while the op was built and run.
    """
    original_stdout = sys.stdout
    out_capturer = six.StringIO()
    sys.stdout = out_capturer
    try:
      with self.cached_session() as sess:
        sess.run(build_op())
    finally:
      sys.stdout = original_stdout
    return out_capturer.getvalue()

  @test_util.run_deprecated_v1
  def test_print_tensors(self):
    """Checks print_ output for a string tensor followed by a Python int."""
    printed = self._run_and_capture_stdout(
        lambda: py_builtins.print_(constant_op.constant('test message'), 1))
    self.assertEqual(printed, 'test message 1\n')

  @test_util.run_deprecated_v1
  def test_print_complex(self):
    """Checks print_ output for a string tensor followed by a Python list."""
    printed = self._run_and_capture_stdout(
        lambda: py_builtins.print_(constant_op.constant('test message'),
                                   [1, 2]))
    self.assertEqual(printed, 'test message [1, 2]\n')

  def test_range(self):
    """Checks range_ with Python ints for 1, 2 and 3 arguments."""
    self.assertListEqual(list(py_builtins.range_(3)), [0, 1, 2])
    self.assertListEqual(list(py_builtins.range_(1, 3)), [1, 2])
    self.assertListEqual(list(py_builtins.range_(2, 0, -1)), [2, 1])

  def test_range_tensor(self):
    """Checks range_ when any one of its arguments is a tensor."""
    with self.cached_session():
      r = py_builtins.range_(constant_op.constant(3))
      self.assertAllEqual(self.evaluate(r), [0, 1, 2])
      r = py_builtins.range_(1, constant_op.constant(3))
      self.assertAllEqual(self.evaluate(r), [1, 2])
      r = py_builtins.range_(2, 0, constant_op.constant(-1))
      self.assertAllEqual(self.evaluate(r), [2, 1])

  def test_range_tensor_empty_range(self):
    """Checks that tensor range_ yields an empty result for empty ranges."""
    with self.cached_session():
      r = py_builtins.range_(constant_op.constant(-3))
      self.assertAllEqual(self.evaluate(r), [])
      r = py_builtins.range_(5, constant_op.constant(2))
      self.assertAllEqual(self.evaluate(r), [])

  def test_enumerate(self):
    """Checks enumerate_ on lists, with and without a start value."""
    self.assertListEqual(
        list(py_builtins.enumerate_([3, 2, 1])), [(0, 3), (1, 2), (2, 1)])
    self.assertListEqual(
        list(py_builtins.enumerate_([3, 2, 1], 5)), [(5, 3), (6, 2), (7, 1)])
    self.assertListEqual(list(py_builtins.enumerate_([-8], -3)), [(-3, -8)])

  def test_enumerate_dataset(self):
    """Checks enumerate_ on a dataset with an int64 tensor start value."""
    dataset = dataset_ops.DatasetV2.from_tensor_slices(['a', 'c'])
    start = constant_op.constant(20, dtype=dtypes.int64)
    dataset = py_builtins.enumerate_(dataset, start)
    iterator = dataset_ops.make_one_shot_iterator(dataset)
    with self.cached_session():
      self.assertAllEqual(self.evaluate(iterator.get_next()), (20, b'a'))
      self.assertAllEqual(self.evaluate(iterator.get_next()), (21, b'c'))

  def test_zip(self):
    """Checks zip_ on lists, including lists of unequal length."""
    self.assertListEqual(
        list(py_builtins.zip_([3, 2, 1], [1, 2, 3])), [(3, 1), (2, 2), (1, 3)])
    self.assertListEqual(
        list(py_builtins.zip_([4, 5, 6], [-1, -2])), [(4, -1), (5, -2)])

  def test_zip_dataset(self):
    """Checks zip_ on two datasets."""
    ds1 = dataset_ops.DatasetV2.from_tensor_slices([-11, -12, 4])
    ds2 = dataset_ops.DatasetV2.from_tensor_slices([-21, -22, 5])
    ds3 = py_builtins.zip_(ds1, ds2)
    iterator = dataset_ops.make_one_shot_iterator(ds3)
    with self.cached_session():
      self.assertAllEqual(self.evaluate(iterator.get_next()), (-11, -21))
      self.assertAllEqual(self.evaluate(iterator.get_next()), (-12, -22))
      self.assertAllEqual(self.evaluate(iterator.get_next()), (4, 5))

  def test_map(self):
    """Checks map_ on one list and on two lists."""

    def increment(x):
      return x + 1

    add_list = lambda x, y: x + y
    self.assertListEqual(
        list(py_builtins.map_(increment, [4, 5, 6])), [5, 6, 7])
    self.assertListEqual(
        list(py_builtins.map_(add_list, [3, 2, 1], [-1, -2, -3])), [2, 0, -2])

  def test_map_dataset(self):
    """Checks map_ on a single dataset."""

    def increment(x):
      return x + 1

    ds1 = dataset_ops.DatasetV2.from_tensor_slices([4, 5, 6])
    ds2 = py_builtins.map_(increment, ds1)
    iterator = dataset_ops.make_one_shot_iterator(ds2)
    with self.cached_session():
      self.assertAllEqual(self.evaluate(iterator.get_next()), 5)
      self.assertAllEqual(self.evaluate(iterator.get_next()), 6)
      self.assertAllEqual(self.evaluate(iterator.get_next()), 7)

  def test_map_multiple_datasets(self):
    """Checks map_ with a binary function over two datasets."""
    add_list = lambda x, y: x + y
    ds1 = dataset_ops.DatasetV2.from_tensor_slices([-11, -12, 4])
    ds2 = dataset_ops.DatasetV2.from_tensor_slices([-21, -22, 5])
    ds3 = py_builtins.map_(add_list, ds1, ds2)
    iterator = dataset_ops.make_one_shot_iterator(ds3)
    with self.cached_session():
      self.assertAllEqual(self.evaluate(iterator.get_next()), -32)
      self.assertAllEqual(self.evaluate(iterator.get_next()), -34)
      self.assertAllEqual(self.evaluate(iterator.get_next()), 9)

  def test_next_normal(self):
    """Checks next_ on a Python iterator, with and without a default."""
    iterator = iter([1, 2, 3])
    self.assertEqual(py_builtins.next_(iterator), 1)
    self.assertEqual(py_builtins.next_(iterator), 2)
    self.assertEqual(py_builtins.next_(iterator), 3)
    with self.assertRaises(StopIteration):
      py_builtins.next_(iterator)
    self.assertEqual(py_builtins.next_(iterator, 4), 4)

  def test_next_tf_iterator(self):
    """Checks next_ on a dataset iterator, including exhaustion behavior."""
    # graph-mode iterators are only supported inside tf.function.
    @def_function.function(autograph=False)
    def test_fn(go_out_of_range, with_default):
      iterator = iter(dataset_ops.Dataset.range(3))
      retval = (
          py_builtins.next_(iterator),
          py_builtins.next_(iterator),
          py_builtins.next_(iterator),
      )
      if go_out_of_range:
        if with_default:
          retval += (
              py_builtins.next_(iterator,
                                constant_op.constant(-3, dtype=dtypes.int64)),
              py_builtins.next_(iterator,
                                constant_op.constant(-4, dtype=dtypes.int64)),
          )
        else:
          py_builtins.next_(iterator)
      return retval

    self.assertAllEqual(
        self.evaluate(test_fn(go_out_of_range=False, with_default=None)),
        (0, 1, 2))
    self.assertAllEqual(
        self.evaluate(test_fn(go_out_of_range=True, with_default=True)),
        (0, 1, 2, -3, -4))
    with self.assertRaises(errors_impl.OutOfRangeError):
      self.evaluate(test_fn(go_out_of_range=True, with_default=False))

  def test_next_tf_iterator_error_checking(self):
    """Checks that next_ rejects a default whose dtype mismatches the data."""
    # graph-mode iterators are only supported inside tf.function.
    @def_function.function(autograph=False)
    def test_fn():
      iterator = iter(dataset_ops.Dataset.range(1))
      py_builtins.next_(iterator)
      py_builtins.next_(iterator, constant_op.constant(-3))

    # Dataset.range defaults to int64, so the untyped (int32) default above is
    # rejected.
    with self.assertRaisesRegex(TypeError, 'default.*int64'):
      self.evaluate(test_fn())

  def test_next_tf_iterator_error_checking_structures(self):
    """Checks that next_ rejects defaults whose structure/dtype mismatch."""
    # graph-mode iterators are only supported inside tf.function.
    @def_function.function(autograph=False)
    def test_fn(default_val):
      ds = dataset_ops.Dataset.range(1)
      ds = ds.map(lambda i: {'a': i + 1, 'b': i + 10})
      iterator = iter(ds)
      py_builtins.next_(iterator)
      py_builtins.next_(iterator, default_val)

    # Missing the 'b' key.
    default = {
        'a': constant_op.constant(3, dtype=dtypes.int64),
    }
    with self.assertRaisesRegex(TypeError, 'same element structure'):
      test_fn(default)
    # 'b' is a list of tensors rather than a single tensor.
    default = {
        'a': constant_op.constant(3.0),
        'b': [constant_op.constant(30), constant_op.constant(300)]
    }
    with self.assertRaisesRegex(TypeError, 'same element structure'):
      test_fn(default)
    # Structure matches, but 'a' is float32 instead of int64.
    default = {
        'a': constant_op.constant(3.0),
        'b': constant_op.constant(30, dtype=dtypes.int64),
    }
    with self.assertRaisesRegex(TypeError, 'float32'):
      test_fn(default)

  def _basic_function_scope(self):
    """Returns a FunctionScope named 'test_scope' with default options."""
    return function_wrappers.FunctionScope(
        'test_function_name',
        'test_scope',  # Note: this must match the name in the `with` statement.
        converter.ConversionOptions())

  def test_eval_in_original_context(self):
    """Checks that eval_in_original_context sees the caller's locals."""

    def test_fn():
      l = 1  # pylint:disable=unused-variable
      with self._basic_function_scope() as test_scope:
        return py_builtins.eval_in_original_context(eval, ('l',), test_scope)

    self.assertEqual(test_fn(), 1)

  def test_eval_in_original_context_inner_function(self):
    """Checks that the innermost frame holding the scope is used for eval."""

    def test_fn():
      l = 1  # pylint:disable=unused-variable
      with self._basic_function_scope() as test_scope:

        def inner_fn():
          # Note: a user function without a top-level function scope should
          # never be found in user code; it's only possible in generated code.
          l = 2  # pylint:disable=unused-variable
          return py_builtins.eval_in_original_context(eval, ('l',), test_scope)

        return inner_fn()

    self.assertEqual(test_fn(), 2)

  def test_locals_in_original_context(self):
    """Checks that locals_in_original_context returns the caller's locals."""

    def test_fn():
      l = 1  # pylint:disable=unused-variable
      with self._basic_function_scope() as test_scope:
        return py_builtins.locals_in_original_context(test_scope)

    locs = test_fn()

    self.assertEqual(locs['l'], 1)

  def test_locals_in_original_context_inner_function(self):
    """Checks locals_in_original_context when called from an inner function."""

    def test_fn():
      l = 1  # pylint:disable=unused-variable
      with self._basic_function_scope() as test_scope:

        def inner_fn():
          # Note: a user function without a top-level function scope should
          # never be found in user code; it's only possible in generated code.
          l = 2  # pylint:disable=unused-variable
          return py_builtins.locals_in_original_context(test_scope)

        return inner_fn()

    locs = test_fn()

    self.assertEqual(locs['l'], 2)

  def test_globals_in_original_context(self):
    """Checks that globals_in_original_context returns this module's globals."""

    def test_fn():
      with self._basic_function_scope() as test_scope:
        return py_builtins.globals_in_original_context(test_scope)

    globs = test_fn()

    self.assertIs(globs['TestBase'], TestBase)

  def test_globals_in_original_context_inner_function(self):
    """Checks globals_in_original_context when called from an inner function."""

    def test_fn():
      with self._basic_function_scope() as test_scope:

        def inner_fn():
          # Note: a user function without a top-level function scope should
          # never be found in user code; it's only possible in generated code.
          return py_builtins.globals_in_original_context(test_scope)

        return inner_fn()

    globs = test_fn()

    self.assertIs(globs['TestBase'], TestBase)

  def test_super_in_original_context_unary_call(self):
    """Checks super_in_original_context with a single (type) argument."""
    test_case_self = self

    class TestSubclass(TestBase):

      def overridden_method(self, x):
        test_case_self.fail('This should never be called.')

      def test_method(self):
        with test_case_self._basic_function_scope() as test_scope:
          test_base_unbound = py_builtins.super_in_original_context(
              super, (TestSubclass,), test_scope)
          test_base = test_base_unbound.__get__(self, TestSubclass)
          return test_base.overridden_method(1)

    tc = TestSubclass()
    self.assertEqual(tc.test_method(), 21)

  def test_super_in_original_context_binary_call(self):
    """Checks super_in_original_context with (type, instance) arguments."""
    test_case_self = self

    class TestSubclass(TestBase):

      def overridden_method(self, x):
        test_case_self.fail('This should never be called.')

      def test_method(self):
        with test_case_self._basic_function_scope() as test_scope:
          test_base = py_builtins.super_in_original_context(
              super, (TestSubclass, self), test_scope)
          return test_base.overridden_method(1)

    tc = TestSubclass()
    self.assertEqual(tc.test_method(), 21)

  def test_super_in_original_context_niladic_call(self):
    """Checks super_in_original_context with no arguments."""
    test_case_self = self

    class TestSubclass(TestBase):

      def overridden_method(self, x):
        test_case_self.fail('This should never be called.')

      def test_method(self):
        with test_case_self._basic_function_scope() as test_scope:
          b = py_builtins.super_in_original_context(super, (), test_scope)
          return b.overridden_method(1)

    tc = TestSubclass()
    self.assertEqual(tc.test_method(), 21)

  def test_super_in_original_context_caller_with_locals(self):
    """Checks niladic super_in_original_context in a method with locals."""
    test_case_self = self

    class TestSubclass(TestBase):

      def overridden_method(self, x):
        test_case_self.fail('This should never be called.')

      def test_method(self, x):
        y = 7
        with test_case_self._basic_function_scope() as test_scope:
          z = 7
          return py_builtins.super_in_original_context(
              super, (), test_scope).overridden_method(x + y - z)

    tc = TestSubclass()
    self.assertEqual(tc.test_method(1), 21)

  def test_super_in_original_context_inner_function(self):
    """Checks niladic super_in_original_context from an inner function."""
    test_case_self = self

    class TestSubclass(TestBase):

      def overridden_method(self, x):
        test_case_self.fail('This should never be called.')

      def test_method(self, x):
        with test_case_self._basic_function_scope() as test_scope:
          # Oddly, it's sufficient to use `self` in an inner function
          # to gain access to __class__ in this scope.
          # TODO(mdan): Is this true across implementations?
          # Note: normally, it's illegal to use super() in inner functions (it
          # throws an error), but the generated code may create them.
          def inner_fn():
            return py_builtins.super_in_original_context(
                super, (), test_scope).overridden_method(x)

          return inner_fn()

    tc = TestSubclass()
    self.assertEqual(tc.test_method(1), 21)

  def test_super_in_original_context_inner_lambda(self):
    """Checks niladic super_in_original_context from an inner lambda."""
    test_case_self = self

    class TestSubclass(TestBase):

      def overridden_method(self, x):
        test_case_self.fail('This should never be called.')

      def test_method(self, x):
        with test_case_self._basic_function_scope() as test_scope:
          # Oddly, it's sufficient to use `self` in an inner function
          # to gain access to __class__ in this scope.
          # TODO(mdan): Is this true across implementations?
          # Note: normally, it's illegal to use super() in inner functions (it
          # throws an error), but the generated code may create them.
          l = lambda: py_builtins.super_in_original_context(  # pylint:disable=g-long-lambda
              super, (), test_scope).overridden_method(x)
          return l()

    tc = TestSubclass()
    self.assertEqual(tc.test_method(1), 21)

  def test_filter(self):
    """Checks filter_ on Python lists."""
    self.assertListEqual(
        list(py_builtins.filter_(lambda x: x == 'b', ['a', 'b', 'c'])), ['b'])
    self.assertListEqual(
        list(py_builtins.filter_(lambda x: x < 3, [3, 2, 1])), [2, 1])

  def test_filter_dataset(self):
    """Checks filter_ on a dataset."""
    dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1])
    dataset = py_builtins.filter_(lambda x: x < 3, dataset)
    iterator = dataset_ops.make_one_shot_iterator(dataset)
    with self.cached_session():
      self.assertAllEqual(self.evaluate(iterator.get_next()), 2)
      self.assertAllEqual(self.evaluate(iterator.get_next()), 1)

  def test_any(self):
    """Checks any_ on Python lists of booleans."""
    self.assertEqual(py_builtins.any_([False, True, False]), True)
    self.assertEqual(py_builtins.any_([False, False, False]), False)

  def test_any_dataset(self):
    """Checks any_ on boolean datasets and its rejection of other datasets."""
    dataset_1 = dataset_ops.DatasetV2.from_tensor_slices([False, True, False])
    dataset_2 = dataset_ops.DatasetV2.from_tensor_slices([False, False, False])
    self.assertEqual(self.evaluate(py_builtins.any_(dataset_1)), True)
    self.assertEqual(self.evaluate(py_builtins.any_(dataset_2)), False)

    # Non-boolean elements.
    dataset_3 = dataset_ops.DatasetV2.from_tensor_slices([0, 1, 2])
    with self.assertRaises(ValueError):
      py_builtins.any_(dataset_3)

    # Tuple elements, even when every component is boolean.
    dataset_4 = dataset_ops.DatasetV2.from_tensor_slices([False, True, False])
    dataset_zipped = dataset_ops.DatasetV2.zip((dataset_4, dataset_4))
    with self.assertRaises(ValueError):
      py_builtins.any_(dataset_zipped)

    dataset_mixed = dataset_ops.DatasetV2.zip((dataset_3, dataset_4))
    with self.assertRaises(ValueError):
      py_builtins.any_(dataset_mixed)

  def test_all(self):
    """Checks all_ on Python lists of booleans."""
    self.assertEqual(py_builtins.all_([False, True, False]), False)
    self.assertEqual(py_builtins.all_([True, True, True]), True)

  def test_all_dataset(self):
    """Checks all_ on boolean datasets and its rejection of other datasets."""
    dataset_1 = dataset_ops.DatasetV2.from_tensor_slices([False, True, False])
    dataset_2 = dataset_ops.DatasetV2.from_tensor_slices([True, True, True])
    self.assertEqual(self.evaluate(py_builtins.all_(dataset_1)), False)
    self.assertEqual(self.evaluate(py_builtins.all_(dataset_2)), True)

    # Non-boolean elements.
    dataset_3 = dataset_ops.DatasetV2.from_tensor_slices([0, 1, 2])
    with self.assertRaises(ValueError):
      py_builtins.all_(dataset_3)

    # Tuple elements, even when every component is boolean.
    dataset_4 = dataset_ops.DatasetV2.from_tensor_slices([False, True, False])
    dataset_zipped = dataset_ops.DatasetV2.zip((dataset_4, dataset_4))
    with self.assertRaises(ValueError):
      py_builtins.all_(dataset_zipped)

    dataset_mixed = dataset_ops.DatasetV2.zip((dataset_3, dataset_4))
    with self.assertRaises(ValueError):
      py_builtins.all_(dataset_mixed)

  def test_sorted(self):
    """Checks sorted_ on Python lists with key and reverse combinations."""
    self.assertListEqual(py_builtins.sorted_([2, 3, 1]), [1, 2, 3])
    self.assertListEqual(
        py_builtins.sorted_([2, 3, 1], key=lambda x: -x), [3, 2, 1])
    self.assertListEqual(
        py_builtins.sorted_([2, 3, 1], reverse=True), [3, 2, 1])
    self.assertListEqual(
        py_builtins.sorted_([2, 3, 1], key=lambda x: -x, reverse=True),
        [1, 2, 3])
    self.assertAllEqual(
        py_builtins.sorted_([[4, 3], [2, 1]], key=sum),
        [[2, 1], [4, 3]])

  def test_sorted_tensor(self):
    """Checks sorted_ on 1-D and 2-D tensors with key and reverse options."""
    iterable_1 = constant_op.constant([2, 3, 1])
    self.assertListEqual(
        list(self.evaluate(py_builtins.sorted_(iterable_1))), [1, 2, 3])
    self.assertListEqual(
        list(self.evaluate(py_builtins.sorted_(iterable_1, key=lambda x: -x))),
        [3, 2, 1])
    self.assertListEqual(
        list(self.evaluate(py_builtins.sorted_(iterable_1, reverse=True))),
        [3, 2, 1])
    self.assertListEqual(
        list(
            self.evaluate(
                py_builtins.sorted_(iterable_1, key=lambda x: -x,
                                    reverse=True))), [1, 2, 3])

    # A 2-D tensor can only be sorted with a key that maps each row to a scalar.
    iterable_2 = constant_op.constant([[4, 3], [2, 1]])
    with self.assertRaises(ValueError):
      py_builtins.sorted_(iterable_2)
    with self.assertRaises(ValueError):
      py_builtins.sorted_(iterable_2, key=lambda x: -x)
    self.assertAllEqual(
        list(
            self.evaluate(
                py_builtins.sorted_(iterable_2, key=math_ops.reduce_sum))),
        [[2, 1], [4, 3]])
674
675
if __name__ == '__main__':
  # Run the test cases in this module via TensorFlow's test entry point.
  test.main()
678