1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for py_builtins module.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import sys 22 23import six 24 25from tensorflow.python.autograph.core import converter 26from tensorflow.python.autograph.core import function_wrappers 27from tensorflow.python.autograph.operators import data_structures 28from tensorflow.python.autograph.operators import py_builtins 29from tensorflow.python.data.ops import dataset_ops 30from tensorflow.python.eager import def_function 31from tensorflow.python.framework import constant_op 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import errors_impl 34from tensorflow.python.framework import test_util 35from tensorflow.python.ops import array_ops 36from tensorflow.python.ops import math_ops 37from tensorflow.python.ops import tensor_array_ops 38from tensorflow.python.platform import test 39 40 41class TestBase(object): 42 43 def overridden_method(self, x): 44 return x + 20 45 46 47@test_util.run_all_in_graph_and_eager_modes 48class PyBuiltinsTest(test.TestCase): 49 50 def test_abs(self): 51 self.assertEqual(py_builtins.abs_(-1), 1) 52 with self.cached_session() as sess: 53 t = py_builtins.abs_(constant_op.constant(-1)) 54 self.assertEqual(self.evaluate(t), 1) 55 t = py_builtins.abs_(constant_op.constant([-1, 2, -3])) 56 self.assertAllEqual(self.evaluate(t), [1, 2, 3]) 57 58 def test_abs_dataset(self): 59 dataset = dataset_ops.DatasetV2.from_tensor_slices([-1, 2, 3]) 60 dataset = py_builtins.abs_(dataset) 61 iterator = dataset_ops.make_one_shot_iterator(dataset) 62 with self.cached_session() as sess: 63 self.assertAllEqual(self.evaluate(iterator.get_next()), 1) 64 self.assertAllEqual(self.evaluate(iterator.get_next()), 2) 65 self.assertAllEqual(self.evaluate(iterator.get_next()), 3) 66 67 def test_abs_dataset_zipped(self): 68 dataset_1 = dataset_ops.DatasetV2.from_tensor_slices([-1, 2, 3]) 69 dataset_2 = dataset_ops.DatasetV2.from_tensor_slices([1, -2, 3]) 70 dataset = dataset_ops.DatasetV2.zip((dataset_1, dataset_2)) 71 dataset = py_builtins.abs_(dataset) 72 iterator = dataset_ops.make_one_shot_iterator(dataset) 73 with self.cached_session() as sess: 74 self.assertAllEqual(self.evaluate(iterator.get_next()), (1, 1)) 75 self.assertAllEqual(self.evaluate(iterator.get_next()), (2, 2)) 76 self.assertAllEqual(self.evaluate(iterator.get_next()), (3, 3)) 77 78 def test_abs_dataset_mixed(self): 79 dataset_1 = dataset_ops.DatasetV2.from_tensor_slices([-1, 2, 3]) 80 dataset_2 = dataset_ops.DatasetV2.from_tensor_slices([1, -2, 3]) 81 dataset_3 = dataset_ops.DatasetV2.from_tensor_slices([-1, -2, -3]) 82 dataset_4 = dataset_ops.DatasetV2.zip((dataset_1, dataset_2)) 83 dataset = dataset_ops.DatasetV2.zip((dataset_3, dataset_4)) 84 dataset = py_builtins.abs_(dataset) 85 iterator = dataset_ops.make_one_shot_iterator(dataset) 86 with self.cached_session() as sess: 87 for i in range(1, 4): 88 actual = self.evaluate(iterator.get_next()) 89 self.assertAllEqual(actual[0], i) 90 self.assertAllEqual(actual[1], (i, i)) 91 92 def test_float(self): 93 self.assertEqual(py_builtins.float_(10), 10.0) 94 self.assertEqual(py_builtins.float_('10.0'), 10.0) 95 with self.cached_session() as sess: 96 t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64)) 97 self.assertEqual(self.evaluate(t), 1.0) 98 st = py_builtins.float_(constant_op.constant('1.0')) 99 self.assertEqual(self.evaluate(st), 1.0) 100 101 def test_int(self): 102 self.assertEqual(py_builtins.int_(10.0), 10) 103 self.assertEqual(py_builtins.int_('11', 2), 3) 104 with self.cached_session() as sess: 105 t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64)) 106 self.assertEqual(self.evaluate(t), 1) 107 st = py_builtins.int_(constant_op.constant('1')) 108 self.assertEqual(self.evaluate(st), 1) 109 st = py_builtins.int_(constant_op.constant('1'), 10) 110 self.assertEqual(self.evaluate(st), 1) 111 112 def test_int_unsupported_base(self): 113 t = constant_op.constant(1, dtype=dtypes.float64) 114 with self.assertRaises(NotImplementedError): 115 py_builtins.int_(t, 2) 116 117 def test_len(self): 118 self.assertEqual(py_builtins.len_([1, 2, 3]), 3) 119 with self.cached_session() as sess: 120 t = py_builtins.len_(constant_op.constant([[1], [2], [3]])) 121 self.assertEqual(t, 3) 122 ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5)) 123 self.assertEqual(self.evaluate(ta), 5) 124 tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5])) 125 self.assertEqual(self.evaluate(tl), 3) 126 127 def test_len_dataset(self): 128 dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1]) 129 self.assertEqual(self.evaluate(py_builtins.len_(dataset)), 3) 130 131 # graph mode 132 @def_function.function(autograph=False) 133 def test_fn(): 134 dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1]) 135 return py_builtins.len_(dataset) 136 137 self.assertEqual(self.evaluate(test_fn()), 3) 138 139 def test_len_dataset_infinite(self): 140 dataset = dataset_ops.DatasetV2.range(5).repeat().batch(2) 141 with self.assertRaises(errors_impl.InvalidArgumentError): 142 _ = self.evaluate(py_builtins.len_(dataset)) 143 144 # graph mode 145 @def_function.function 146 def test_fn(): 147 dataset = dataset_ops.DatasetV2.range(5).repeat().batch(2) 148 return py_builtins.len_(dataset) 149 150 with self.assertRaises(errors_impl.InvalidArgumentError): 151 self.evaluate(test_fn()) 152 153 def test_len_dataset_unknown(self): 154 dataset = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2) 155 with self.assertRaises(errors_impl.InvalidArgumentError): 156 _ = self.evaluate(py_builtins.len_(dataset)) 157 158 # graph mode 159 @def_function.function(autograph=False) 160 def test_fn(): 161 dataset = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2) 162 return py_builtins.len_(dataset) 163 164 with self.assertRaises(errors_impl.InvalidArgumentError): 165 self.evaluate(test_fn()) 166 167 def test_len_scalar(self): 168 with self.assertRaises(ValueError): 169 py_builtins.len_(constant_op.constant(1)) 170 171 @test_util.run_deprecated_v1 172 def test_len_dynamic_shape(self): 173 with self.cached_session() as sess: 174 p = array_ops.placeholder(dtype=dtypes.int32, shape=None) 175 t = py_builtins.len_(p) 176 self.assertEqual(sess.run(t, {p: [1, 2, 3]}), 3) 177 178 with self.assertRaises(errors_impl.InvalidArgumentError): 179 t = py_builtins.len_(p) 180 sess.run(t, {p: 1}) 181 182 @test_util.run_deprecated_v1 183 def test_print_tensors(self): 184 try: 185 out_capturer = six.StringIO() 186 sys.stdout = out_capturer 187 with self.cached_session() as sess: 188 sess.run(py_builtins.print_(constant_op.constant('test message'), 1)) 189 self.assertEqual(out_capturer.getvalue(), 'test message 1\n') 190 finally: 191 sys.stdout = sys.__stdout__ 192 193 @test_util.run_deprecated_v1 194 def test_print_complex(self): 195 try: 196 out_capturer = six.StringIO() 197 sys.stdout = out_capturer 198 with self.cached_session() as sess: 199 sess.run( 200 py_builtins.print_(constant_op.constant('test message'), [1, 2])) 201 self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n') 202 finally: 203 sys.stdout = sys.__stdout__ 204 205 def test_range(self): 206 self.assertListEqual(list(py_builtins.range_(3)), [0, 1, 2]) 207 self.assertListEqual(list(py_builtins.range_(1, 3)), [1, 2]) 208 self.assertListEqual(list(py_builtins.range_(2, 0, -1)), [2, 1]) 209 210 def test_range_tensor(self): 211 with self.cached_session() as sess: 212 r = py_builtins.range_(constant_op.constant(3)) 213 self.assertAllEqual(self.evaluate(r), [0, 1, 2]) 214 r = py_builtins.range_(1, constant_op.constant(3)) 215 self.assertAllEqual(self.evaluate(r), [1, 2]) 216 r = py_builtins.range_(2, 0, constant_op.constant(-1)) 217 self.assertAllEqual(self.evaluate(r), [2, 1]) 218 219 def test_range_tensor_empty_range(self): 220 with self.session() as sess: 221 r = py_builtins.range_(constant_op.constant(-3)) 222 self.assertAllEqual(self.evaluate(r), []) 223 r = py_builtins.range_(5, constant_op.constant(2)) 224 self.assertAllEqual(self.evaluate(r), []) 225 226 def test_enumerate(self): 227 self.assertListEqual( 228 list(py_builtins.enumerate_([3, 2, 1])), [(0, 3), (1, 2), (2, 1)]) 229 self.assertListEqual( 230 list(py_builtins.enumerate_([3, 2, 1], 5)), [(5, 3), (6, 2), (7, 1)]) 231 self.assertListEqual(list(py_builtins.enumerate_([-8], -3)), [(-3, -8)]) 232 233 def test_enumerate_dataset(self): 234 dataset = dataset_ops.DatasetV2.from_tensor_slices(['a', 'c']) 235 start = constant_op.constant(20, dtype=dtypes.int64) 236 dataset = py_builtins.enumerate_(dataset, start) 237 iterator = dataset_ops.make_one_shot_iterator(dataset) 238 with self.cached_session() as sess: 239 self.assertAllEqual(self.evaluate(iterator.get_next()), (20, b'a')) 240 self.assertAllEqual(self.evaluate(iterator.get_next()), (21, b'c')) 241 242 def test_zip(self): 243 self.assertListEqual( 244 list(py_builtins.zip_([3, 2, 1], [1, 2, 3])), [(3, 1), (2, 2), (1, 3)]) 245 self.assertListEqual( 246 list(py_builtins.zip_([4, 5, 6], [-1, -2])), [(4, -1), (5, -2)]) 247 248 def test_zip_dataset(self): 249 ds1 = dataset_ops.DatasetV2.from_tensor_slices([-11, -12, 4]) 250 ds2 = dataset_ops.DatasetV2.from_tensor_slices([-21, -22, 5]) 251 ds3 = py_builtins.zip_(ds1, ds2) 252 iterator = dataset_ops.make_one_shot_iterator(ds3) 253 with self.cached_session() as sess: 254 self.assertAllEqual(self.evaluate(iterator.get_next()), (-11, -21)) 255 self.assertAllEqual(self.evaluate(iterator.get_next()), (-12, -22)) 256 self.assertAllEqual(self.evaluate(iterator.get_next()), (4, 5)) 257 258 def test_map(self): 259 260 def increment(x): 261 return x + 1 262 263 add_list = lambda x, y: x + y 264 self.assertListEqual( 265 list(py_builtins.map_(increment, [4, 5, 6])), [5, 6, 7]) 266 self.assertListEqual( 267 list(py_builtins.map_(add_list, [3, 2, 1], [-1, -2, -3])), [2, 0, -2]) 268 269 def test_map_dataset(self): 270 271 def increment(x): 272 return x + 1 273 274 ds1 = dataset_ops.DatasetV2.from_tensor_slices([4, 5, 6]) 275 ds2 = py_builtins.map_(increment, ds1) 276 iterator = dataset_ops.make_one_shot_iterator(ds2) 277 with self.cached_session() as sess: 278 self.assertAllEqual(self.evaluate(iterator.get_next()), 5) 279 self.assertAllEqual(self.evaluate(iterator.get_next()), 6) 280 self.assertAllEqual(self.evaluate(iterator.get_next()), 7) 281 282 def test_map_multiple_datasets(self): 283 add_list = lambda x, y: x + y 284 ds1 = dataset_ops.DatasetV2.from_tensor_slices([-11, -12, 4]) 285 ds2 = dataset_ops.DatasetV2.from_tensor_slices([-21, -22, 5]) 286 ds3 = py_builtins.map_(add_list, ds1, ds2) 287 iterator = dataset_ops.make_one_shot_iterator(ds3) 288 with self.cached_session() as sess: 289 self.assertAllEqual(self.evaluate(iterator.get_next()), -32) 290 self.assertAllEqual(self.evaluate(iterator.get_next()), -34) 291 self.assertAllEqual(self.evaluate(iterator.get_next()), 9) 292 293 def test_next_normal(self): 294 iterator = iter([1, 2, 3]) 295 self.assertEqual(py_builtins.next_(iterator), 1) 296 self.assertEqual(py_builtins.next_(iterator), 2) 297 self.assertEqual(py_builtins.next_(iterator), 3) 298 with self.assertRaises(StopIteration): 299 py_builtins.next_(iterator) 300 self.assertEqual(py_builtins.next_(iterator, 4), 4) 301 302 def test_next_tf_iterator(self): 303 # graph-mode iterators are only supported inside tf.function. 304 @def_function.function(autograph=False) 305 def test_fn(go_out_of_range, with_default): 306 iterator = iter(dataset_ops.Dataset.range(3)) 307 retval = ( 308 py_builtins.next_(iterator), 309 py_builtins.next_(iterator), 310 py_builtins.next_(iterator), 311 ) 312 if go_out_of_range: 313 if with_default: 314 retval += ( 315 py_builtins.next_(iterator, 316 constant_op.constant(-3, dtype=dtypes.int64)), 317 py_builtins.next_(iterator, 318 constant_op.constant(-4, dtype=dtypes.int64)), 319 ) 320 else: 321 py_builtins.next_(iterator) 322 return retval 323 324 self.assertAllEqual( 325 self.evaluate(test_fn(go_out_of_range=False, with_default=None)), 326 (0, 1, 2)) 327 self.assertAllEqual( 328 self.evaluate(test_fn(go_out_of_range=True, with_default=True)), 329 (0, 1, 2, -3, -4)) 330 with self.assertRaises(errors_impl.OutOfRangeError): 331 self.evaluate(test_fn(go_out_of_range=True, with_default=False)) 332 333 def test_next_tf_iterator_error_checking(self): 334 # graph-mode iterators are only supported inside tf.function. 335 @def_function.function(autograph=False) 336 def test_fn(): 337 iterator = iter(dataset_ops.Dataset.range(1)) 338 py_builtins.next_(iterator) 339 py_builtins.next_(iterator, constant_op.constant(-3)) 340 341 # Dataset.range defaults to int64, 342 with self.assertRaisesRegex(TypeError, 'default.*int64'): 343 self.evaluate(test_fn()) 344 345 def test_next_tf_iterator_error_checking_structures(self): 346 # graph-mode iterators are only supported inside tf.function. 347 @def_function.function(autograph=False) 348 def test_fn(default_val): 349 ds = dataset_ops.Dataset.range(1) 350 ds = ds.map(lambda i: {'a': i + 1, 'b': i + 10}) 351 iterator = iter(ds) 352 py_builtins.next_(iterator) 353 py_builtins.next_(iterator, default_val) 354 355 default = { 356 'a': constant_op.constant(3, dtype=dtypes.int64), 357 } 358 with self.assertRaisesRegex(TypeError, 'same element structure'): 359 test_fn(default) 360 default = { 361 'a': constant_op.constant(3.0), 362 'b': [constant_op.constant(30), constant_op.constant(300)] 363 } 364 with self.assertRaisesRegex(TypeError, 'same element structure'): 365 test_fn(default) 366 default = { 367 'a': constant_op.constant(3.0), 368 'b': constant_op.constant(30, dtype=dtypes.int64), 369 } 370 with self.assertRaisesRegex(TypeError, 'float32'): 371 test_fn(default) 372 373 def _basic_function_scope(self): 374 return function_wrappers.FunctionScope( 375 'test_function_name', 376 'test_scope', # Note: this must match the name in the `with` statement. 377 converter.ConversionOptions()) 378 379 def test_eval_in_original_context(self): 380 381 def test_fn(): 382 l = 1 # pylint:disable=unused-variable 383 with self._basic_function_scope() as test_scope: 384 return py_builtins.eval_in_original_context(eval, ('l',), test_scope) 385 386 self.assertEqual(test_fn(), 1) 387 388 def test_eval_in_original_context_inner_function(self): 389 390 def test_fn(): 391 l = 1 # pylint:disable=unused-variable 392 with self._basic_function_scope() as test_scope: 393 394 def inner_fn(): 395 # Note: a user function without a top-level function scope should 396 # never be found in user code; it's only possible in generated code. 397 l = 2 # pylint:disable=unused-variable 398 return py_builtins.eval_in_original_context(eval, ('l',), test_scope) 399 400 return inner_fn() 401 402 self.assertEqual(test_fn(), 2) 403 404 def test_locals_in_original_context(self): 405 406 def test_fn(): 407 l = 1 # pylint:disable=unused-variable 408 with self._basic_function_scope() as test_scope: 409 return py_builtins.locals_in_original_context(test_scope) 410 411 locs = test_fn() 412 413 self.assertEqual(locs['l'], 1) 414 415 def test_locals_in_original_context_inner_function(self): 416 417 def test_fn(): 418 l = 1 # pylint:disable=unused-variable 419 with self._basic_function_scope() as test_scope: 420 421 def inner_fn(): 422 # Note: a user function without a top-level function scope should 423 # never be found in user code; it's only possible in generated code. 424 l = 2 # pylint:disable=unused-variable 425 return py_builtins.locals_in_original_context(test_scope) 426 427 return inner_fn() 428 429 locs = test_fn() 430 431 self.assertEqual(locs['l'], 2) 432 433 def test_globals_in_original_context(self): 434 435 def test_fn(): 436 with self._basic_function_scope() as test_scope: 437 return py_builtins.globals_in_original_context(test_scope) 438 439 globs = test_fn() 440 441 self.assertIs(globs['TestBase'], TestBase) 442 443 def test_globals_in_original_context_inner_function(self): 444 445 def test_fn(): 446 with self._basic_function_scope() as test_scope: 447 448 def inner_fn(): 449 # Note: a user function without a top-level function scope should 450 # never be found in user code; it's only possible in generated code. 451 return py_builtins.globals_in_original_context(test_scope) 452 453 return inner_fn() 454 455 globs = test_fn() 456 457 self.assertIs(globs['TestBase'], TestBase) 458 459 def test_super_in_original_context_unary_call(self): 460 test_case_self = self 461 462 class TestSubclass(TestBase): 463 464 def overridden_method(self, x): 465 test_case_self.fail('This should never be called.') 466 467 def test_method(self): 468 with test_case_self._basic_function_scope() as test_scope: 469 test_base_unbound = py_builtins.super_in_original_context( 470 super, (TestSubclass,), test_scope) 471 test_base = test_base_unbound.__get__(self, TestSubclass) 472 return test_base.overridden_method(1) 473 474 tc = TestSubclass() 475 self.assertEqual(tc.test_method(), 21) 476 477 def test_super_in_original_context_binary_call(self): 478 test_case_self = self 479 480 class TestSubclass(TestBase): 481 482 def overridden_method(self, x): 483 test_case_self.fail('This should never be called.') 484 485 def test_method(self): 486 with test_case_self._basic_function_scope() as test_scope: 487 test_base = py_builtins.super_in_original_context( 488 super, (TestSubclass, self), test_scope) 489 return test_base.overridden_method(1) 490 491 tc = TestSubclass() 492 self.assertEqual(tc.test_method(), 21) 493 494 def test_super_in_original_context_niladic_call(self): 495 test_case_self = self 496 497 class TestSubclass(TestBase): 498 499 def overridden_method(self, x): 500 test_case_self.fail('This should never be called.') 501 502 def test_method(self): 503 with test_case_self._basic_function_scope() as test_scope: 504 b = py_builtins.super_in_original_context(super, (), test_scope) 505 return b.overridden_method(1) 506 507 tc = TestSubclass() 508 self.assertEqual(tc.test_method(), 21) 509 510 def test_super_in_original_context_caller_with_locals(self): 511 test_case_self = self 512 513 class TestSubclass(TestBase): 514 515 def overridden_method(self, x): 516 test_case_self.fail('This should never be called.') 517 518 def test_method(self, x): 519 y = 7 520 with test_case_self._basic_function_scope() as test_scope: 521 z = 7 522 return py_builtins.super_in_original_context( 523 super, (), test_scope).overridden_method(x + y - z) 524 525 tc = TestSubclass() 526 self.assertEqual(tc.test_method(1), 21) 527 528 def test_super_in_original_context_inner_function(self): 529 test_case_self = self 530 531 class TestSubclass(TestBase): 532 533 def overridden_method(self, x): 534 test_case_self.fail('This should never be called.') 535 536 def test_method(self, x): 537 with test_case_self._basic_function_scope() as test_scope: 538 # Oddly, it's sufficient to use `self` in an inner function 539 # to gain access to __class__ in this scope. 540 # TODO(mdan): Is this true across implementations? 541 # Note: normally, it's illegal to use super() in inner functions (it 542 # throws an error), but the generated code may create them. 543 def inner_fn(): 544 return py_builtins.super_in_original_context( 545 super, (), test_scope).overridden_method(x) 546 547 return inner_fn() 548 549 tc = TestSubclass() 550 self.assertEqual(tc.test_method(1), 21) 551 552 def test_super_in_original_context_inner_lambda(self): 553 test_case_self = self 554 555 class TestSubclass(TestBase): 556 557 def overridden_method(self, x): 558 test_case_self.fail('This should never be called.') 559 560 def test_method(self, x): 561 with test_case_self._basic_function_scope() as test_scope: 562 # Oddly, it's sufficient to use `self` in an inner function 563 # to gain access to __class__ in this scope. 564 # TODO(mdan): Is this true across implementations? 565 # Note: normally, it's illegal to use super() in inner functions (it 566 # throws an error), but the generated code may create them. 567 l = lambda: py_builtins.super_in_original_context( # pylint:disable=g-long-lambda 568 super, (), test_scope).overridden_method(x) 569 return l() 570 571 tc = TestSubclass() 572 self.assertEqual(tc.test_method(1), 21) 573 574 def test_filter(self): 575 self.assertListEqual( 576 list(py_builtins.filter_(lambda x: x == 'b', ['a', 'b', 'c'])), ['b']) 577 self.assertListEqual( 578 list(py_builtins.filter_(lambda x: x < 3, [3, 2, 1])), [2, 1]) 579 580 def test_filter_dataset(self): 581 dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1]) 582 dataset = py_builtins.filter_(lambda x: x < 3, dataset) 583 iterator = dataset_ops.make_one_shot_iterator(dataset) 584 with self.cached_session() as sess: 585 self.assertAllEqual(self.evaluate(iterator.get_next()), 2) 586 self.assertAllEqual(self.evaluate(iterator.get_next()), 1) 587 588 def test_any(self): 589 self.assertEqual(py_builtins.any_([False, True, False]), True) 590 self.assertEqual(py_builtins.any_([False, False, False]), False) 591 592 def test_any_dataset(self): 593 dataset_1 = dataset_ops.DatasetV2.from_tensor_slices([False, True, False]) 594 dataset_2 = dataset_ops.DatasetV2.from_tensor_slices([False, False, False]) 595 self.assertEqual(self.evaluate(py_builtins.any_(dataset_1)), True) 596 self.assertEqual(self.evaluate(py_builtins.any_(dataset_2)), False) 597 598 dataset_3 = dataset_ops.DatasetV2.from_tensor_slices([0, 1, 2]) 599 with self.assertRaises(ValueError): 600 py_builtins.any_(dataset_3) 601 602 dataset_4 = dataset_ops.DatasetV2.from_tensor_slices([False, True, False]) 603 dataset_zipped = dataset_ops.DatasetV2.zip((dataset_4, dataset_4)) 604 with self.assertRaises(ValueError): 605 py_builtins.any_(dataset_zipped) 606 607 dataset_mixed = dataset_ops.DatasetV2.zip((dataset_3, dataset_4)) 608 with self.assertRaises(ValueError): 609 py_builtins.any_(dataset_mixed) 610 611 def test_all(self): 612 self.assertEqual(py_builtins.all_([False, True, False]), False) 613 self.assertEqual(py_builtins.all_([True, True, True]), True) 614 615 def test_all_dataset(self): 616 dataset_1 = dataset_ops.DatasetV2.from_tensor_slices([False, True, False]) 617 dataset_2 = dataset_ops.DatasetV2.from_tensor_slices([True, True, True]) 618 self.assertEqual(self.evaluate(py_builtins.all_(dataset_1)), False) 619 self.assertEqual(self.evaluate(py_builtins.all_(dataset_2)), True) 620 621 dataset_3 = dataset_ops.DatasetV2.from_tensor_slices([0, 1, 2]) 622 with self.assertRaises(ValueError): 623 py_builtins.all_(dataset_3) 624 625 dataset_4 = dataset_ops.DatasetV2.from_tensor_slices([False, True, False]) 626 dataset_zipped = dataset_ops.DatasetV2.zip((dataset_4, dataset_4)) 627 with self.assertRaises(ValueError): 628 py_builtins.all_(dataset_zipped) 629 630 dataset_mixed = dataset_ops.DatasetV2.zip((dataset_3, dataset_4)) 631 with self.assertRaises(ValueError): 632 py_builtins.all_(dataset_mixed) 633 634 def test_sorted(self): 635 self.assertListEqual(py_builtins.sorted_([2, 3, 1]), [1, 2, 3]) 636 self.assertListEqual( 637 py_builtins.sorted_([2, 3, 1], key=lambda x: -x), [3, 2, 1]) 638 self.assertListEqual( 639 py_builtins.sorted_([2, 3, 1], reverse=True), [3, 2, 1]) 640 self.assertListEqual( 641 py_builtins.sorted_([2, 3, 1], key=lambda x: -x, reverse=True), 642 [1, 2, 3]) 643 self.assertAllEqual( 644 py_builtins.sorted_([[4, 3], [2, 1]], key=lambda x: sum(x)), 645 [[2, 1], [4, 3]]) 646 647 def test_sorted_tensor(self): 648 iterable_1 = constant_op.constant([2, 3, 1]) 649 self.assertListEqual( 650 list(self.evaluate(py_builtins.sorted_(iterable_1))), [1, 2, 3]) 651 self.assertListEqual( 652 list(self.evaluate(py_builtins.sorted_(iterable_1, key=lambda x: -x))), 653 [3, 2, 1]) 654 self.assertListEqual( 655 list(self.evaluate(py_builtins.sorted_(iterable_1, reverse=True))), 656 [3, 2, 1]) 657 self.assertListEqual( 658 list( 659 self.evaluate( 660 py_builtins.sorted_(iterable_1, key=lambda x: -x, 661 reverse=True))), [1, 2, 3]) 662 663 iterable_2 = constant_op.constant([[4, 3], [2, 1]]) 664 with self.assertRaises(ValueError): 665 py_builtins.sorted_(iterable_2) 666 with self.assertRaises(ValueError): 667 py_builtins.sorted_(iterable_2, key=lambda x: -x) 668 self.assertAllEqual( 669 list( 670 self.evaluate( 671 py_builtins.sorted_( 672 iterable_2, key=lambda x: math_ops.reduce_sum(x)))), 673 [[2, 1], [4, 3]]) 674 675 676if __name__ == '__main__': 677 test.main() 678