1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for stateful_random_ops.py.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22import re 23 24from absl.testing import parameterized 25import numpy as np 26 27from tensorflow.python.distribute import values as dist_values 28from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy 29from tensorflow.python.eager import context 30from tensorflow.python.eager import def_function 31from tensorflow.python.framework import config 32from tensorflow.python.framework import constant_op 33from tensorflow.python.framework import dtypes 34from tensorflow.python.framework import errors 35from tensorflow.python.framework import ops 36from tensorflow.python.framework import test_util 37from tensorflow.python.kernel_tests.random import util as \ 38random_test_util 39from tensorflow.python.ops import array_ops 40from tensorflow.python.ops import gen_random_ops 41from tensorflow.python.ops import gen_stateful_random_ops 42from tensorflow.python.ops import logging_ops 43from tensorflow.python.ops import stateful_random_ops as \ 44random 45from tensorflow.python.ops import variables 46from tensorflow.python.platform import test 47from tensorflow.python.training.tracking import util as tracking_util 48 49 50g_seeded = None 51g_unseeded = None 52 53 54GPU_FLOATS = [dtypes.float16, dtypes.float32, dtypes.float64] 55CPU_FLOATS = GPU_FLOATS + [dtypes.bfloat16] 56FLOATS = GPU_FLOATS 57INTS = [dtypes.int32, dtypes.int64] 58 59 60class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase): 61 62 def setUp(self): 63 super(StatefulRandomOpsTest, self).setUp() 64 physical_devices = config.list_physical_devices("CPU") 65 config.set_logical_device_configuration( 66 physical_devices[0], [ 67 context.LogicalDeviceConfiguration(), 68 context.LogicalDeviceConfiguration() 69 ]) 70 71 def testCreateRNGStateIntSeed(self): 72 """Tests `create_rng_state` when `seed` is int.""" 73 # using leading 'F' to test overflow tolerance 74 state = random.create_rng_state(0xFFFF222233334444FFAA666677778888, 75 random.RNG_ALG_PHILOX) 76 self.assertAllEqual( 77 list(map(random._uint_to_int, 78 [0xFFAA666677778888, 0xFFFF222233334444] + 79 [0] * (random.PHILOX_STATE_SIZE - 2))), 80 state) 81 82 def assertAllDifferent(self, tensors): 83 """Checks that there are no duplicate elements anywhere among the tensors. 84 85 Args: 86 tensors: a list of tensors. They can have different shapes. 87 """ 88 tensors = [array_ops.reshape(t, shape=[-1]) for t in tensors] 89 ls = array_ops.concat(tensors, axis=0).numpy().tolist() 90 self.assertAllEqual(len(ls), len(set(ls))) 91 92 @test_util.run_v2_only 93 def testNonDeterministicInts(self): 94 """Tests that non_deterministic_ints returns different results every time. 95 96 This test is flaky, but with very low probability of failing. 97 """ 98 shape = [2, 3] 99 dtype = dtypes.int64 100 a = random.non_deterministic_ints(shape=shape, dtype=dtype) 101 self.assertAllEqual(shape, a.shape) 102 self.assertEqual(dtype, a.dtype) 103 b = random.non_deterministic_ints(shape, dtype=dtype) 104 self.assertAllDifferent([a, b]) 105 106 @test_util.run_v2_only 107 def testBatchSeeds(self): 108 """Test for batch seeds. 109 """ 110 shape = [2, 3] 111 count = 6 112 gen = random.Generator.from_seed(1234) 113 keys1 = gen._make_int64_keys(shape=shape) 114 keys2 = gen._make_int64_keys(shape=shape) 115 self.assertAllDifferent([keys1, keys2]) 116 seeds1 = gen.make_seeds(count=count) 117 seeds2 = gen.make_seeds(count=count) 118 self.assertAllDifferent([seeds1[0, :], seeds2[0, :]]) 119 gens = gen.split(count=count) 120 self.assertAllEqual(count, len(gens)) 121 randoms = [g.uniform_full_int(shape=shape, dtype=dtypes.int32) 122 for g in gens] 123 self.assertAllDifferent(randoms) 124 # Tests graph mode. 125 @def_function.function 126 def f(): 127 return gen.make_seeds(count=count) 128 for _ in range(3): 129 f() 130 131 def assertRegex(self, pattern, text): 132 self.assertTrue( 133 re.search(pattern, text), 134 "Can't find pattern '%s' in text '%s'" % (pattern, text)) 135 136 @test_util.run_v2_only 137 @test_util.run_cuda_only 138 def testCrossDeviceSplit(self): 139 """Tests that a CPU RNG can split into RNGs on GPU. 140 """ 141 with ops.device("/device:CPU:0"): 142 gen = random.Generator.from_seed(1234) # gen is on CPU 143 self.assertRegex("CPU", gen.state.device) 144 with ops.device(test_util.gpu_device_name()): 145 gens = gen.split(count=10) # gens are on GPU 146 self.assertRegex("GPU", gens[0].state.device) 147 148 @test_util.run_v2_only 149 def testReset(self): 150 shape = [2, 3] 151 gen = random.Generator.from_seed(0) 152 for resetter in [ 153 lambda g: g.reset(state=[1, 2, 3]), 154 lambda g: g.reset_from_seed(1234), 155 lambda g: g.reset_from_key_counter(key=1, counter=[2, 3]), 156 ]: 157 resetter(gen) 158 expected_normal = gen.normal(shape) 159 @def_function.function 160 def f(resetter): 161 resetter(gen) 162 return gen.normal(shape) 163 def check_results(expected_normal, v): 164 self.assertAllEqual(expected_normal, v) 165 check_results(expected_normal, f(resetter)) 166 check_results(expected_normal, f(resetter)) 167 168 @test_util.run_v2_only 169 def testGeneratorCreation(self): 170 """Tests generator creation, in both eager and tf.function. 171 172 The interaction between Generator creation and defun should be the same as 173 tf.Variable. 174 """ 175 shape = [2, 3] 176 alg = random.RNG_ALG_PHILOX 177 for constructor in [ 178 lambda: random.Generator(state=[1, 2, 3], alg=alg), 179 lambda: random.Generator.from_seed(1234), 180 lambda: random.Generator.from_key_counter( # pylint: disable=g-long-lambda 181 key=1, counter=[2, 3], alg=alg), 182 ]: 183 gen = constructor() 184 # Tests tf.function 185 expected_normal1 = gen.normal(shape) 186 expected_normal2 = gen.normal(shape) 187 global g_seeded 188 g_seeded = None 189 @def_function.function 190 def f(constructor): 191 global g_seeded 192 # defun'ed function should only create variables once 193 if g_seeded is None: 194 g_seeded = constructor() 195 return g_seeded.normal(shape) 196 def check_results(expected_normal, v): 197 self.assertAllEqual(expected_normal, v) 198 check_results(expected_normal1, f(constructor)) 199 check_results(expected_normal2, f(constructor)) 200 201 @parameterized.parameters([ 202 ("philox", random.RNG_ALG_PHILOX, random.Algorithm.PHILOX), 203 ("threefry", random.RNG_ALG_THREEFRY, random.Algorithm.THREEFRY)]) 204 @test_util.run_v2_only 205 def testAlg(self, name, int_id, enum_id): 206 g_by_name = random.Generator.from_seed(1234, name) 207 g_by_int = random.Generator.from_seed(1234, int_id) 208 g_by_enum = random.Generator.from_seed(1234, enum_id) 209 self.assertEqual(g_by_name.algorithm, g_by_int.algorithm) 210 self.assertEqual(g_by_name.algorithm, g_by_enum.algorithm) 211 212 @test_util.run_v2_only 213 def testGeneratorCreationWithVar(self): 214 """Tests creating generator with a variable. 215 """ 216 alg = random.RNG_ALG_PHILOX 217 state = [1, 2, 3] 218 var = variables.Variable(state, dtype=random.STATE_TYPE) 219 g = random.Generator(state=state, alg=alg) 220 g_var = random.Generator(state=var, alg=alg) 221 shape = [2, 3] 222 g.normal(shape) 223 g_var.normal(shape) 224 self.assertAllEqual(g.state.read_value(), var.read_value()) 225 226 @test_util.run_v2_only 227 def testGeneratorCreationUnseeded(self): 228 """Tests generator creation, the unseeded case.""" 229 shape = [2, 3] 230 global g_unseeded 231 g_unseeded = None 232 @def_function.function 233 def f(): 234 global g_unseeded 235 # defun'ed function should only create variables once 236 if g_unseeded is None: 237 g_unseeded = random.Generator.from_non_deterministic_state() 238 return g_unseeded.normal(shape) 239 self.assertAllEqual(shape, f().shape) 240 241 @test_util.run_v2_only 242 def testGeneratorCopy(self): 243 """Tests copying a generator.""" 244 g = random.Generator.from_seed(0) 245 g_copy = random.Generator(g) 246 self.assertAllEqual(g.algorithm, g_copy.algorithm) 247 self.assertAllEqual(g.state.read_value(), g_copy.state.read_value()) 248 # Tests tf.function 249 global g_seeded 250 g_seeded = None 251 # Do the same in tf.function 252 @def_function.function 253 def f(): 254 global g_seeded 255 # defun'ed function should only create variables once 256 if g_seeded is None: 257 g_seeded = random.Generator(g) 258 self.assertAllEqual(g.algorithm, g_seeded.algorithm) 259 self.assertAllEqual(g.state.read_value(), g_seeded.state.read_value()) 260 f() 261 262 @test_util.run_v1_only( 263 ("This test is specifically for checking TF1 compatibility. " 264 "It cannot run under TF2.")) 265 def testTF1(self): 266 seed = 1234 267 shape = [2, 3] 268 expected_normal1 = constant_op.constant( 269 [[0.9356609, 1.0854305, -0.93788373], 270 [-0.50615472, 1.31697023, 0.71375787]], dtype=dtypes.float32) 271 expected_normal2 = constant_op.constant( 272 [[-0.3964749, 0.8369565, -0.30946946], 273 [1.1206646, 1.00852597, -0.10185789]], dtype=dtypes.float32) 274 with self.cached_session() as sess: 275 gen1 = random.Generator.from_seed(seed) 276 gen2 = random.Generator.from_non_deterministic_state() 277 sess.run((gen1.state.initializer, gen2.state.initializer)) 278 r1 = gen1.normal(shape, dtype=dtypes.float32) 279 r2 = gen2.normal(shape, dtype=dtypes.float32) 280 def f(): 281 return sess.run((r1, r2)) 282 def check_results(expected_normal, v1, v2): 283 self.assertAllClose(expected_normal, v1, rtol=1e-5, atol=1e-5) 284 self.assertAllEqual(shape, v2.shape) 285 check_results(expected_normal1, *f()) 286 check_results(expected_normal2, *f()) 287 288 @test_util.run_v2_only 289 @test_util.also_run_as_tf_function 290 def testEagerAndDefun(self): 291 """A simple test to make sure the op works in eager and defunned mode.""" 292 random.get_global_generator().normal((3,)) 293 294 @test_util.run_v2_only 295 def testOpSeedSelectionAfterSetSeed(self): 296 """Tests that op-seed selection is reset after reseting global generator. 297 298 Fixing GitHub issue 9171: 299 https://github.com/tensorflow/tensorflow/issues/9171 300 """ 301 shape = (3,) 302 random.get_global_generator().reset_from_seed(1) 303 a = random.get_global_generator().normal(shape) 304 random.get_global_generator().reset_from_seed(1) 305 b = random.get_global_generator().normal(shape) 306 self.assertAllEqual(a, b) 307 308 # Now do the above again using accelerated ('defun'ed) computation 309 @def_function.function 310 def f(): 311 return random.get_global_generator().normal(shape) 312 313 random.get_global_generator().reset_from_seed(1) 314 c = f() 315 random.get_global_generator().reset_from_seed(1) 316 d = f() 317 self.assertAllEqual(c, d) 318 self.assertAllEqual(a, c) 319 320 @test_util.run_v2_only 321 def testOpSeedSelectionNotSensitive(self): 322 """Test that op-seed selection is not sensitive to trivial changes. 323 324 Test that op-seed selection is not sensitive to trivial computation 325 (i.e. graph) changes. 326 327 Fixing b/32087099 328 """ 329 def f(include_print): 330 shape = constant_op.constant([5]) 331 if include_print: 332 shape = logging_ops.Print(shape, [shape]) 333 return random.get_global_generator().normal(shape) 334 335 def compare(fst_includes_print, snd_includes_print): 336 random.get_global_generator().reset_from_seed(50) 337 fst = f(fst_includes_print) 338 random.get_global_generator().reset_from_seed(50) 339 snd = f(snd_includes_print) 340 self.assertAllEqual(fst, snd) 341 # Now do the above again using accelerated (defunned) 'f'. 342 # Running 'f' with two different Boolean arguments should cause 343 # two different graphs to be generated, hence demonstrating the 344 # insensitivity to graph changes. 345 f_acc = def_function.function(f) 346 random.get_global_generator().reset_from_seed(50) 347 fst = f_acc(fst_includes_print) 348 random.get_global_generator().reset_from_seed(50) 349 snd = f_acc(snd_includes_print) 350 self.assertAllEqual(fst, snd) 351 352 compare(False, False) 353 compare(True, True) 354 compare(True, False) 355 356 @test_util.run_v2_only 357 def testKey(self): 358 key = 1234 359 gen = random.Generator(state=[0, 0, key], alg=random.RNG_ALG_PHILOX) 360 got = gen.key 361 self.assertAllEqual(key, got) 362 @def_function.function 363 def f(): 364 return gen.key 365 got = f() 366 self.assertAllEqual(key, got) 367 368 @test_util.run_v2_only 369 def testSkip(self): 370 key = 1234 371 counter = 5678 372 gen = random.Generator(state=[counter, 0, key], alg=random.RNG_ALG_PHILOX) 373 delta = 432 374 gen.skip(delta) 375 new_counter = gen.state[0] 376 self.assertAllEqual(counter + delta * 256, new_counter) 377 378 def _sameAsOldRandomOps(self, device, floats): 379 def compare(dtype, old, new): 380 seed1, seed2 = 79, 25 381 # note how the two seeds for the old op correspond to the seed for the new 382 # op 383 with ops.device(device): 384 gen = random.Generator(state=[0, seed2, seed1], 385 alg=random.RNG_ALG_PHILOX) 386 387 # create a graph for the old op in order to call it many times 388 @def_function.function 389 def run_old(): 390 with ops.device(device): 391 return old(dtype, seed1, seed2) 392 393 def run_new(): 394 with ops.device(device): 395 return new(dtype, gen) 396 397 for _ in range(5): 398 self.assertAllEqual(run_old(), run_new()) 399 400 shape = constant_op.constant([4, 7]) 401 minval = 128 402 maxval = 256 403 404 # passing `dtype` around to compress go/gpylint-faq#cell-var-from-loop and 405 # go/gpylint-faq#undefined-loop-variable 406 def old_normal(dtype, seed1, seed2): 407 return gen_random_ops.random_standard_normal( 408 shape, dtype=dtype, seed=seed1, seed2=seed2) 409 def new_normal(dtype, gen): 410 return gen._standard_normal(shape, dtype=dtype) 411 def old_truncated_normal(dtype, seed1, seed2): 412 return gen_random_ops.truncated_normal( 413 shape, dtype=dtype, seed=seed1, seed2=seed2) 414 def new_truncated_normal(dtype, gen): 415 return gen._truncated_normal(shape, dtype=dtype) 416 def old_uniform_int(dtype, seed1, seed2): 417 minval2 = constant_op.constant(minval, dtype=dtype) 418 maxval2 = constant_op.constant(maxval, dtype=dtype) 419 return gen_random_ops.random_uniform_int( 420 shape, minval=minval2, maxval=maxval2, seed=seed1, seed2=seed2) 421 def new_uniform_int(dtype, gen): 422 return gen.uniform(shape, minval=minval, maxval=maxval, dtype=dtype) 423 def old_uniform(dtype, seed1, seed2): 424 return gen_random_ops.random_uniform( 425 shape, dtype=dtype, seed=seed1, seed2=seed2) 426 def new_uniform(dtype, gen): 427 return gen._uniform(shape, dtype=dtype) 428 429 for dtype in floats: 430 compare(dtype, old_normal, new_normal) 431 compare(dtype, old_truncated_normal, new_truncated_normal) 432 compare(dtype, old_uniform, new_uniform) 433 for dtype in INTS: 434 compare(dtype, old_uniform_int, new_uniform_int) 435 436 @test_util.run_v2_only 437 def testSameAsOldRandomOpsCPU(self): 438 """Tests that the generated numbers are the same as the old random_ops.py. 439 440 The CPU version. 441 """ 442 self._sameAsOldRandomOps("/device:CPU:0", CPU_FLOATS) 443 444 @test_util.run_v2_only 445 @test_util.run_cuda_only 446 def testSameAsOldRandomOpsGPU(self): 447 """Tests that the generated numbers are the same as the old random_ops.py. 448 449 The GPU version. 450 """ 451 self._sameAsOldRandomOps(test_util.gpu_device_name(), GPU_FLOATS) 452 453 @parameterized.parameters(INTS + [dtypes.uint32, dtypes.uint64]) 454 @test_util.run_v2_only 455 @test_util.run_cuda_only 456 def testGPUEqualsCPU(self, dtype): 457 """Tests that GPU and CPU generate the same integer outputs.""" 458 seed = 1234 459 shape = [315, 49] 460 with ops.device("/device:CPU:0"): 461 cpu = random.Generator.from_seed(seed).uniform_full_int( 462 shape=shape, dtype=dtype) 463 with ops.device(test_util.gpu_device_name()): 464 gpu = random.Generator.from_seed(seed).uniform_full_int( 465 shape=shape, dtype=dtype) 466 self.assertAllEqual(cpu, gpu) 467 468 @parameterized.parameters(FLOATS + INTS) 469 @test_util.run_v2_only 470 def testUniformIsInRange(self, dtype): 471 minval = 2 472 maxval = 33 473 size = 1000 474 gen = random.Generator.from_seed(1234) 475 x = gen.uniform( 476 shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy() 477 self.assertTrue(np.all(x >= minval)) 478 self.assertTrue(np.all(x < maxval)) 479 480 @parameterized.parameters(FLOATS) 481 @test_util.run_v2_only 482 def testNormalIsFinite(self, dtype): 483 gen = random.Generator.from_seed(1234) 484 x = gen.normal(shape=[10000], dtype=dtype).numpy() 485 self.assertTrue(np.all(np.isfinite(x))) 486 487 @parameterized.parameters(FLOATS + INTS) 488 @test_util.run_v2_only 489 def testDistributionOfUniform(self, dtype): 490 """Use Pearson's Chi-squared test to test for uniformity.""" 491 n = 1000 492 seed = 12 493 gen = random.Generator.from_seed(seed) 494 maxval = 1 495 if dtype.is_integer: 496 maxval = 100 497 x = gen.uniform(shape=[n], maxval=maxval, dtype=dtype).numpy() 498 if maxval > 1: 499 # Normalize y to range [0, 1). 500 x = x.astype(float) / maxval 501 # Tests that the values are distributed amongst 10 bins with equal 502 # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with 503 # p=0.05. This test is probabilistic and would be flaky if the random 504 # seed were not fixed. 505 val = random_test_util.chi_squared(x, 10) 506 self.assertLess(val, 16.92) 507 508 @parameterized.parameters(FLOATS) 509 @test_util.run_v2_only 510 def testDistributionOfNormal(self, dtype): 511 """Use Anderson-Darling test to test distribution appears normal.""" 512 n = 1000 513 gen = random.Generator.from_seed(1234) 514 x = gen.normal(shape=[n], dtype=dtype).numpy() 515 # The constant 2.492 is the 5% critical value for the Anderson-Darling 516 # test where the mean and variance are known. This test is probabilistic 517 # so to avoid flakiness the seed is fixed. 518 self.assertLess( 519 random_test_util.anderson_darling(x.astype(float)), 2.492) 520 521 @test_util.run_v2_only 522 def testErrors(self): 523 """Tests that proper errors are raised. 524 """ 525 shape = [2, 3] 526 gen = random.Generator.from_seed(1234) 527 with self.assertRaisesWithPredicateMatch( 528 errors.InvalidArgumentError, 529 r"must have shape \[\], not"): 530 gen_stateful_random_ops.stateful_standard_normal_v2( 531 gen.state.handle, [0, 0], shape) 532 with self.assertRaisesWithPredicateMatch( 533 errors.InvalidArgumentError, 534 r"must have shape \[\], not"): 535 gen_stateful_random_ops.rng_skip( 536 gen.state.handle, gen.algorithm, [0, 0]) 537 with self.assertRaisesWithPredicateMatch( 538 TypeError, "EagerTensor of dtype int64"): 539 gen_stateful_random_ops.stateful_standard_normal_v2( 540 gen.state.handle, 1.1, shape) 541 with self.assertRaisesWithPredicateMatch( 542 errors.InvalidArgumentError, 543 "Unsupported algorithm id"): 544 gen_stateful_random_ops.stateful_standard_normal_v2( 545 gen.state.handle, 123, shape) 546 var = variables.Variable([0, 0], dtype=dtypes.int32) 547 with self.assertRaisesWithPredicateMatch( 548 errors.InvalidArgumentError, 549 "dtype of RNG state variable must be int64, not"): 550 gen_stateful_random_ops.stateful_standard_normal_v2( 551 var.handle, random.RNG_ALG_PHILOX, shape) 552 var = variables.Variable([[0]], dtype=dtypes.int64) 553 with self.assertRaisesWithPredicateMatch( 554 errors.InvalidArgumentError, 555 "RNG state must have one and only one dimension, not"): 556 gen_stateful_random_ops.stateful_standard_normal_v2( 557 var.handle, random.RNG_ALG_PHILOX, shape) 558 var = variables.Variable([0], dtype=dtypes.int64) 559 with self.assertRaisesWithPredicateMatch( 560 errors.InvalidArgumentError, 561 "For the Philox algorithm, the size of state must be at least"): 562 gen_stateful_random_ops.stateful_standard_normal_v2( 563 var.handle, random.RNG_ALG_PHILOX, shape) 564 with self.assertRaisesWithPredicateMatch( 565 ValueError, 566 "minval must be a scalar; got a tensor of shape "): 567 @def_function.function 568 def f(): 569 gen.uniform(shape=shape, minval=array_ops.zeros(shape, "int32"), 570 maxval=100, dtype="int32") 571 f() 572 with self.assertRaisesWithPredicateMatch( 573 ValueError, 574 "maxval must be a scalar; got a tensor of shape "): 575 @def_function.function 576 def f2(): 577 gen.uniform( 578 shape=shape, minval=0, maxval=array_ops.ones(shape, "int32") * 100, 579 dtype="int32") 580 f2() 581 582 @test_util.run_v2_only 583 def testGetGlobalGeneratorWithXla(self): 584 """Demonstrates using the global generator with XLA.""" 585 # This test was passing before because soft placement silently picked the 586 # CPU kernel. 587 # TODO(wangpeng): Remove this skip 588 self.skipTest("NonDeterministicInts lacks XLA kernel.") 589 590 if not config.list_physical_devices("XLA_CPU"): 591 self.skipTest("No XLA_CPU device available.") 592 593 random.set_global_generator(None) 594 595 @def_function.function(jit_compile=True) 596 def make_seed(): 597 generator = random.get_global_generator() 598 state = array_ops.identity(generator.state, name="state") 599 return generator.uniform_full_int((2,), dtypes.int32, name="seed"), state 600 601 with ops.device("/device:XLA_CPU:0"): 602 seed, state = make_seed() 603 self.assertTrue(np.all(np.isfinite(seed.numpy()))) 604 random.get_global_generator().reset(state) 605 self.assertAllEqual(make_seed()[0], seed) 606 607 @test_util.run_v2_only 608 def testSetGlobalGeneratorBadWithDefun(self): 609 """Demonstrates that set_global_generator don't work properly with defun. 610 """ 611 shape = (3,) 612 613 @def_function.function 614 def f(): 615 return random.get_global_generator().normal(shape) 616 617 random.set_global_generator(random.Generator.from_seed(50)) 618 with self.assertRaisesWithPredicateMatch( 619 errors.NotFoundError, "Resource .+ does not exist"): 620 _ = f() 621 random.set_global_generator(random.Generator.from_seed(50)) 622 _ = f() 623 624 @test_util.run_v2_only 625 def testFunctionArg(self): 626 """Tests that RNG can be used as tf.function's argument. 627 """ 628 shape = [2, 3] 629 @def_function.function 630 def f(gen): 631 return gen.normal(shape) 632 g1 = random.Generator.from_seed(1) 633 g2 = random.Generator.from_seed(1) 634 res1 = f(g1) 635 res2 = g2.normal(shape) 636 self.assertAllEqual(res1, res2) 637 self.assertAllEqual(g1.state.read_value(), g2.state.read_value()) 638 639 @test_util.run_v2_only 640 def testCreateOutsideMirroredStrat(self): 641 """Tests RNG/MirrorStrategy interaction #1. 642 643 If an RNG is created outside a DS scope, all replicas will access the 644 same RNG object, and accesses are serialized. 645 """ 646 shape = [3, 4] 647 dtype = dtypes.int32 648 gen = random.Generator.from_seed(1234) 649 strat = MirroredStrategy(devices=["cpu:0", "cpu:1"]) 650 with strat.scope(): 651 def f(): 652 t1 = gen.uniform_full_int(shape=shape, dtype=dtype) 653 t2 = gen.uniform_full_int(shape=shape, dtype=dtype) 654 t = array_ops.stack([t1, t2]) 655 return t 656 results = strat.extended.call_for_each_replica( 657 fn=f) 658 values = results.values 659 self.assertAllEqual(2, len(values)) 660 self.assertAllDifferent(values) 661 662 @test_util.run_v2_only 663 def testMirroredStratParaAsync(self): 664 """Tests RNG/MirrorStrategy interaction #2. 665 666 The user can create n independent RNGs outside strategy.scope(), where n 667 is the number of replicas, and give one to each replica. The replicas can 668 thus get different random-number streams. 669 """ 670 shape = [3, 4] 671 dtype = dtypes.int32 672 gens = random.get_global_generator().split(count=2) 673 devices = ["cpu:0", "cpu:1"] 674 strat = MirroredStrategy(devices=devices) 675 # Use `PerReplica` to specify which `gen` is sent to which replica 676 gens = dist_values.PerReplica([[g] for g in gens]) 677 with strat.scope(): 678 def f(gen): 679 t1 = gen.uniform_full_int(shape=shape, dtype=dtype) 680 t2 = gen.uniform_full_int(shape=shape, dtype=dtype) 681 t = array_ops.stack([t1, t2]) 682 return t 683 results = strat.extended.call_for_each_replica( 684 fn=f, args=gens) 685 local_results = strat.experimental_local_results(results) 686 self.assertAllEqual(2, len(local_results)) 687 self.assertAllDifferent(local_results) 688 689 @test_util.run_v2_only 690 def testUniformFullInt(self): 691 """Tests full-range int uniform. 692 """ 693 shape = [3, 4] 694 dtype = dtypes.int32 695 g = random.Generator.from_seed(1) 696 r1 = g.uniform(shape=shape, dtype=dtype, minval=None) 697 g = random.Generator.from_seed(1) 698 r2 = g.uniform_full_int(shape=shape, dtype=dtype) 699 self.assertAllEqual(r1, r2) 700 701 @test_util.run_v2_only 702 def testRestore(self): 703 """Tests save and restore. 704 """ 705 fname = os.path.join(self.get_temp_dir(), "checkpoint") 706 g = random.Generator.from_seed(1) 707 cp = tracking_util.Checkpoint(g=g) 708 def write_restore_compare(): 709 cp.write(fname) 710 r1 = g.uniform([], dtype=dtypes.uint32, minval=None) 711 cp.restore(fname) 712 r2 = g.uniform([], dtype=dtypes.uint32, minval=None) 713 self.assertAllEqual(r1, r2) 714 # Run multiple times so that cp.write is called in various RNG states 715 for _ in range(2): 716 write_restore_compare() 717 718 719if __name__ == "__main__": 720 config.set_soft_device_placement(False) 721 test.main() 722