# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for stateful_random_ops.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import re

from absl.testing import parameterized
import numpy as np

from tensorflow.python.distribute import values as dist_values
from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.kernel_tests.random import util as random_test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import stateful_random_ops as random
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training.tracking import util as tracking_util


g_seeded = None
g_unseeded = None


GPU_FLOATS = [dtypes.float16, dtypes.float32, dtypes.float64]
CPU_FLOATS = GPU_FLOATS + [dtypes.bfloat16]
FLOATS = GPU_FLOATS
INTS = [dtypes.int32, dtypes.int64]


class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(StatefulRandomOpsTest, self).setUp()
    physical_devices = config.list_physical_devices("CPU")
    config.set_logical_device_configuration(
        physical_devices[0], [
            context.LogicalDeviceConfiguration(),
            context.LogicalDeviceConfiguration()
        ])

  def testCreateRNGStateIntSeed(self):
    """Tests `create_rng_state` when `seed` is int."""
    # using leading 'F' to test overflow tolerance
    state = random.create_rng_state(0xFFFF222233334444FFAA666677778888,
                                    random.RNG_ALG_PHILOX)
    self.assertAllEqual(
        list(map(random._uint_to_int,
                 [0xFFAA666677778888, 0xFFFF222233334444] +
                 [0] * (random.PHILOX_STATE_SIZE - 2))),
        state)
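    # A worked reading of the assertion above, checked in plain Python: the
    # 128-bit seed is split into 64-bit words, low word first, and the
    # remaining PHILOX_STATE_SIZE - 2 words are zero-filled. _uint_to_int
    # then reinterprets each uint64 as int64, so words with the top bit set
    # come back negative.
    seed = 0xFFFF222233334444FFAA666677778888
    self.assertEqual(0xFFAA666677778888, seed & (2**64 - 1))  # low word
    self.assertEqual(0xFFFF222233334444, seed >> 64)  # high word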

  def assertAllDifferent(self, tensors):
    """Checks that there are no duplicate elements anywhere among the tensors.

    Args:
      tensors: a list of tensors. They can have different shapes.
    """
    tensors = [array_ops.reshape(t, shape=[-1]) for t in tensors]
    ls = array_ops.concat(tensors, axis=0).numpy().tolist()
    self.assertAllEqual(len(ls), len(set(ls)))

  @test_util.run_v2_only
  def testNonDeterministicInts(self):
    """Tests that non_deterministic_ints returns different results every time.

    This test is flaky, but with very low probability of failing.
    """
    shape = [2, 3]
    dtype = dtypes.int64
    a = random.non_deterministic_ints(shape=shape, dtype=dtype)
    self.assertAllEqual(shape, a.shape)
    self.assertEqual(dtype, a.dtype)
    b = random.non_deterministic_ints(shape, dtype=dtype)
    self.assertAllDifferent([a, b])

  @test_util.run_v2_only
  def testBatchSeeds(self):
    """Tests batch seeds."""
    shape = [2, 3]
    count = 6
    gen = random.Generator.from_seed(1234)
    keys1 = gen._make_int64_keys(shape=shape)
    keys2 = gen._make_int64_keys(shape=shape)
    self.assertAllDifferent([keys1, keys2])
    seeds1 = gen.make_seeds(count=count)
    seeds2 = gen.make_seeds(count=count)
    self.assertAllDifferent([seeds1[0, :], seeds2[0, :]])
    gens = gen.split(count=count)
    self.assertAllEqual(count, len(gens))
    randoms = [g.uniform_full_int(shape=shape, dtype=dtypes.int32)
               for g in gens]
    self.assertAllDifferent(randoms)
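    # A note on `split` (per its docstring): the returned generators are
    # statistically independent of each other and of `gen`, and splitting
    # advances `gen`'s own state so that later draws from `gen` do not
    # overlap the children's streams.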
    # Tests graph mode.
    @def_function.function
    def f():
      return gen.make_seeds(count=count)
    for _ in range(3):
      f()

  def assertRegex(self, pattern, text):
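    # Note: the argument order here is (pattern, text), the reverse of
    # unittest's own assertRegex(text, expected_regex), which this method
    # shadows.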
    self.assertTrue(
        re.search(pattern, text),
        "Can't find pattern '%s' in text '%s'" % (pattern, text))

  @test_util.run_v2_only
  @test_util.run_cuda_only
  def testCrossDeviceSplit(self):
    """Tests that a CPU RNG can split into RNGs on GPU."""
    with ops.device("/device:CPU:0"):
      gen = random.Generator.from_seed(1234)  # gen is on CPU
      self.assertRegex("CPU", gen.state.device)
    with ops.device(test_util.gpu_device_name()):
      gens = gen.split(count=10)  # gens are on GPU
      self.assertRegex("GPU", gens[0].state.device)

  @test_util.run_v2_only
  def testReset(self):
    shape = [2, 3]
    gen = random.Generator.from_seed(0)
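    # The three resetters below cover the three ways to reinitialize a
    # generator: `reset` takes a raw state vector, `reset_from_seed` derives
    # the state from an integer seed, and `reset_from_key_counter` sets the
    # Philox key and counter directly.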
    for resetter in [
        lambda g: g.reset(state=[1, 2, 3]),
        lambda g: g.reset_from_seed(1234),
        lambda g: g.reset_from_key_counter(key=1, counter=[2, 3]),
    ]:
      resetter(gen)
      expected_normal = gen.normal(shape)
      @def_function.function
      def f(resetter):
        resetter(gen)
        return gen.normal(shape)
      def check_results(expected_normal, v):
        self.assertAllEqual(expected_normal, v)
      check_results(expected_normal, f(resetter))
      check_results(expected_normal, f(resetter))

  @test_util.run_v2_only
  def testGeneratorCreation(self):
    """Tests generator creation, in both eager and tf.function.

    The interaction between Generator creation and defun should be the same as
    tf.Variable.
    """
    shape = [2, 3]
    alg = random.RNG_ALG_PHILOX
    for constructor in [
        lambda: random.Generator(state=[1, 2, 3], alg=alg),
        lambda: random.Generator.from_seed(1234),
        lambda: random.Generator.from_key_counter(  # pylint: disable=g-long-lambda
            key=1, counter=[2, 3], alg=alg),
    ]:
      gen = constructor()
      # Tests tf.function
      expected_normal1 = gen.normal(shape)
      expected_normal2 = gen.normal(shape)
      global g_seeded
      g_seeded = None
      @def_function.function
      def f(constructor):
        global g_seeded
        # defun'ed function should only create variables once
        if g_seeded is None:
          g_seeded = constructor()
        return g_seeded.normal(shape)
      def check_results(expected_normal, v):
        self.assertAllEqual(expected_normal, v)
      check_results(expected_normal1, f(constructor))
      check_results(expected_normal2, f(constructor))

  @parameterized.parameters([
      ("philox", random.RNG_ALG_PHILOX, random.Algorithm.PHILOX),
      ("threefry", random.RNG_ALG_THREEFRY, random.Algorithm.THREEFRY)])
  @test_util.run_v2_only
  def testAlg(self, name, int_id, enum_id):
    g_by_name = random.Generator.from_seed(1234, name)
    g_by_int = random.Generator.from_seed(1234, int_id)
    g_by_enum = random.Generator.from_seed(1234, enum_id)
    self.assertEqual(g_by_name.algorithm, g_by_int.algorithm)
    self.assertEqual(g_by_name.algorithm, g_by_enum.algorithm)

  @test_util.run_v2_only
  def testGeneratorCreationWithVar(self):
    """Tests creating a generator with a variable."""
    alg = random.RNG_ALG_PHILOX
    state = [1, 2, 3]
    var = variables.Variable(state, dtype=random.STATE_TYPE)
    g = random.Generator(state=state, alg=alg)
    g_var = random.Generator(state=var, alg=alg)
    shape = [2, 3]
    g.normal(shape)
    g_var.normal(shape)
    self.assertAllEqual(g.state.read_value(), var.read_value())

  @test_util.run_v2_only
  def testGeneratorCreationUnseeded(self):
    """Tests generator creation, the unseeded case."""
    shape = [2, 3]
    global g_unseeded
    g_unseeded = None
    @def_function.function
    def f():
      global g_unseeded
      # defun'ed function should only create variables once
      if g_unseeded is None:
        g_unseeded = random.Generator.from_non_deterministic_state()
      return g_unseeded.normal(shape)
    self.assertAllEqual(shape, f().shape)

  @test_util.run_v2_only
  def testGeneratorCopy(self):
    """Tests copying a generator."""
    g = random.Generator.from_seed(0)
    g_copy = random.Generator(g)
    self.assertAllEqual(g.algorithm, g_copy.algorithm)
    self.assertAllEqual(g.state.read_value(), g_copy.state.read_value())
    # Do the same in tf.function
    global g_seeded
    g_seeded = None
    @def_function.function
    def f():
      global g_seeded
      # defun'ed function should only create variables once
      if g_seeded is None:
        g_seeded = random.Generator(g)
      self.assertAllEqual(g.algorithm, g_seeded.algorithm)
      self.assertAllEqual(g.state.read_value(), g_seeded.state.read_value())
    f()

  @test_util.run_v1_only(
      ("This test is specifically for checking TF1 compatibility. "
       "It cannot run under TF2."))
  def testTF1(self):
    seed = 1234
    shape = [2, 3]
    expected_normal1 = constant_op.constant(
        [[0.9356609, 1.0854305, -0.93788373],
         [-0.50615472, 1.31697023, 0.71375787]], dtype=dtypes.float32)
    expected_normal2 = constant_op.constant(
        [[-0.3964749, 0.8369565, -0.30946946],
         [1.1206646, 1.00852597, -0.10185789]], dtype=dtypes.float32)
    with self.cached_session() as sess:
      gen1 = random.Generator.from_seed(seed)
      gen2 = random.Generator.from_non_deterministic_state()
      sess.run((gen1.state.initializer, gen2.state.initializer))
      r1 = gen1.normal(shape, dtype=dtypes.float32)
      r2 = gen2.normal(shape, dtype=dtypes.float32)
      def f():
        return sess.run((r1, r2))
      def check_results(expected_normal, v1, v2):
        self.assertAllClose(expected_normal, v1, rtol=1e-5, atol=1e-5)
        self.assertAllEqual(shape, v2.shape)
      check_results(expected_normal1, *f())
      check_results(expected_normal2, *f())

  @test_util.run_v2_only
  @test_util.also_run_as_tf_function
  def testEagerAndDefun(self):
    """A simple test to make sure the op works in eager and defunned mode."""
    random.get_global_generator().normal((3,))

  @test_util.run_v2_only
  def testOpSeedSelectionAfterSetSeed(self):
    """Tests that resetting the global generator resets op-seed selection.

    Fixing GitHub issue 9171:
    https://github.com/tensorflow/tensorflow/issues/9171
    """
    shape = (3,)
    random.get_global_generator().reset_from_seed(1)
    a = random.get_global_generator().normal(shape)
    random.get_global_generator().reset_from_seed(1)
    b = random.get_global_generator().normal(shape)
    self.assertAllEqual(a, b)

    # Now do the above again using accelerated ('defun'ed) computation
    @def_function.function
    def f():
      return random.get_global_generator().normal(shape)

    random.get_global_generator().reset_from_seed(1)
    c = f()
    random.get_global_generator().reset_from_seed(1)
    d = f()
    self.assertAllEqual(c, d)
    self.assertAllEqual(a, c)

  @test_util.run_v2_only
  def testOpSeedSelectionNotSensitive(self):
    """Tests that op-seed selection is insensitive to trivial changes.

    Trivial computation (i.e. graph) changes should not affect which op seeds
    are selected.

    Fixing b/32087099
    """
    def f(include_print):
      shape = constant_op.constant([5])
      if include_print:
        shape = logging_ops.Print(shape, [shape])
      return random.get_global_generator().normal(shape)

    def compare(fst_includes_print, snd_includes_print):
      random.get_global_generator().reset_from_seed(50)
      fst = f(fst_includes_print)
      random.get_global_generator().reset_from_seed(50)
      snd = f(snd_includes_print)
      self.assertAllEqual(fst, snd)
      # Now do the above again using accelerated (defunned) 'f'.
      # Running 'f' with two different Boolean arguments should cause
      # two different graphs to be generated, hence demonstrating the
      # insensitivity to graph changes.
      f_acc = def_function.function(f)
      random.get_global_generator().reset_from_seed(50)
      fst = f_acc(fst_includes_print)
      random.get_global_generator().reset_from_seed(50)
      snd = f_acc(snd_includes_print)
      self.assertAllEqual(fst, snd)

    compare(False, False)
    compare(True, True)
    compare(True, False)

  @test_util.run_v2_only
  def testKey(self):
    key = 1234
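    # For the Philox algorithm the state vector is laid out as
    # [counter_low, counter_high, key] (a reading consistent with testSkip
    # below, which treats state[0] as the counter).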
    gen = random.Generator(state=[0, 0, key], alg=random.RNG_ALG_PHILOX)
    got = gen.key
    self.assertAllEqual(key, got)
    @def_function.function
    def f():
      return gen.key
    got = f()
    self.assertAllEqual(key, got)

  @test_util.run_v2_only
  def testSkip(self):
    key = 1234
    counter = 5678
    gen = random.Generator(state=[counter, 0, key], alg=random.RNG_ALG_PHILOX)
    delta = 432
    gen.skip(delta)
    new_counter = gen.state[0]
    self.assertAllEqual(counter + delta * 256, new_counter)
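    # skip(delta) advances the counter by delta * 256, so here the expected
    # value is 5678 + 432 * 256 = 116270. The stride of 256 per skip unit is
    # an implementation constant that this test pins down.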

  def _sameAsOldRandomOps(self, device, floats):
    def compare(dtype, old, new):
      seed1, seed2 = 79, 25
      # Note how the two seeds of the old op map into the new op's state:
      # seed1 becomes the Philox key and seed2 the high counter word.
      with ops.device(device):
        gen = random.Generator(state=[0, seed2, seed1],
                               alg=random.RNG_ALG_PHILOX)

      # create a graph for the old op in order to call it many times
      @def_function.function
      def run_old():
        with ops.device(device):
          return old(dtype, seed1, seed2)

      def run_new():
        with ops.device(device):
          return new(dtype, gen)

      for _ in range(5):
        self.assertAllEqual(run_old(), run_new())

    shape = constant_op.constant([4, 7])
    minval = 128
    maxval = 256

    # passing `dtype` around to work around go/gpylint-faq#cell-var-from-loop
    # and go/gpylint-faq#undefined-loop-variable
    def old_normal(dtype, seed1, seed2):
      return gen_random_ops.random_standard_normal(
          shape, dtype=dtype, seed=seed1, seed2=seed2)
    def new_normal(dtype, gen):
      return gen._standard_normal(shape, dtype=dtype)
    def old_truncated_normal(dtype, seed1, seed2):
      return gen_random_ops.truncated_normal(
          shape, dtype=dtype, seed=seed1, seed2=seed2)
    def new_truncated_normal(dtype, gen):
      return gen._truncated_normal(shape, dtype=dtype)
    def old_uniform_int(dtype, seed1, seed2):
      minval2 = constant_op.constant(minval, dtype=dtype)
      maxval2 = constant_op.constant(maxval, dtype=dtype)
      return gen_random_ops.random_uniform_int(
          shape, minval=minval2, maxval=maxval2, seed=seed1, seed2=seed2)
    def new_uniform_int(dtype, gen):
      return gen.uniform(shape, minval=minval, maxval=maxval, dtype=dtype)
    def old_uniform(dtype, seed1, seed2):
      return gen_random_ops.random_uniform(
          shape, dtype=dtype, seed=seed1, seed2=seed2)
    def new_uniform(dtype, gen):
      return gen._uniform(shape, dtype=dtype)

    for dtype in floats:
      compare(dtype, old_normal, new_normal)
      compare(dtype, old_truncated_normal, new_truncated_normal)
      compare(dtype, old_uniform, new_uniform)
    for dtype in INTS:
      compare(dtype, old_uniform_int, new_uniform_int)

  @test_util.run_v2_only
  def testSameAsOldRandomOpsCPU(self):
    """Tests that the generated numbers are the same as the old random_ops.py.

    The CPU version.
    """
    self._sameAsOldRandomOps("/device:CPU:0", CPU_FLOATS)

  @test_util.run_v2_only
  @test_util.run_cuda_only
  def testSameAsOldRandomOpsGPU(self):
    """Tests that the generated numbers are the same as the old random_ops.py.

    The GPU version.
    """
    self._sameAsOldRandomOps(test_util.gpu_device_name(), GPU_FLOATS)

  @parameterized.parameters(INTS + [dtypes.uint32, dtypes.uint64])
  @test_util.run_v2_only
  @test_util.run_cuda_only
  def testGPUEqualsCPU(self, dtype):
    """Tests that GPU and CPU generate the same integer outputs."""
    seed = 1234
    shape = [315, 49]
    with ops.device("/device:CPU:0"):
      cpu = random.Generator.from_seed(seed).uniform_full_int(
          shape=shape, dtype=dtype)
    with ops.device(test_util.gpu_device_name()):
      gpu = random.Generator.from_seed(seed).uniform_full_int(
          shape=shape, dtype=dtype)
    self.assertAllEqual(cpu, gpu)

  @parameterized.parameters(FLOATS + INTS)
  @test_util.run_v2_only
  def testUniformIsInRange(self, dtype):
    minval = 2
    maxval = 33
    size = 1000
    gen = random.Generator.from_seed(1234)
    x = gen.uniform(
        shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy()
    self.assertTrue(np.all(x >= minval))
    self.assertTrue(np.all(x < maxval))

  @parameterized.parameters(FLOATS)
  @test_util.run_v2_only
  def testNormalIsFinite(self, dtype):
    gen = random.Generator.from_seed(1234)
    x = gen.normal(shape=[10000], dtype=dtype).numpy()
    self.assertTrue(np.all(np.isfinite(x)))

  @parameterized.parameters(FLOATS + INTS)
  @test_util.run_v2_only
  def testDistributionOfUniform(self, dtype):
    """Uses Pearson's chi-squared test to test for uniformity."""
    n = 1000
    seed = 12
    gen = random.Generator.from_seed(seed)
    maxval = 1
    if dtype.is_integer:
      maxval = 100
    x = gen.uniform(shape=[n], maxval=maxval, dtype=dtype).numpy()
    if maxval > 1:
      # Normalize x to range [0, 1).
      x = x.astype(float) / maxval
    # Tests that the values are distributed amongst 10 bins with equal
    # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with
    # p=0.05. This test is probabilistic and would be flaky if the random
    # seed were not fixed.
    val = random_test_util.chi_squared(x, 10)
    self.assertLess(val, 16.92)
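    # Cross-check with a direct chi-squared computation, assuming the same 10
    # equiprobable bins over [0, 1): each bin expects n / 10 = 100 hits, and
    # the statistic is sum((observed - expected)^2 / expected) with
    # 10 - 1 = 9 degrees of freedom.
    observed, _ = np.histogram(x, bins=10, range=(0.0, 1.0))
    chi2 = np.sum((observed - 100.0)**2 / 100.0)
    self.assertLess(chi2, 16.92)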

  @parameterized.parameters(FLOATS)
  @test_util.run_v2_only
  def testDistributionOfNormal(self, dtype):
    """Uses the Anderson-Darling test to check that samples look normal."""
    n = 1000
    gen = random.Generator.from_seed(1234)
    x = gen.normal(shape=[n], dtype=dtype).numpy()
    # The constant 2.492 is the 5% critical value for the Anderson-Darling
    # test where the mean and variance are known. This test is probabilistic
    # so to avoid flakiness the seed is fixed.
    self.assertLess(
        random_test_util.anderson_darling(x.astype(float)), 2.492)
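    # For reference, with fully specified parameters the statistic is
    # A^2 = -n - (1/n) * sum_{i=1..n} (2i - 1) *
    #       (ln F(x_(i)) + ln(1 - F(x_(n+1-i)))),
    # where F is the standard-normal CDF and x_(1) <= ... <= x_(n) are the
    # sorted samples; unlike chi-squared it weights the tails heavily.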

  @test_util.run_v2_only
  def testErrors(self):
    """Tests that proper errors are raised."""
    shape = [2, 3]
    gen = random.Generator.from_seed(1234)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        r"must have shape \[\], not"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          gen.state.handle, [0, 0], shape)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        r"must have shape \[\], not"):
      gen_stateful_random_ops.rng_skip(
          gen.state.handle, gen.algorithm, [0, 0])
    with self.assertRaisesWithPredicateMatch(
        TypeError, "EagerTensor of dtype int64"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          gen.state.handle, 1.1, shape)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        "Unsupported algorithm id"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          gen.state.handle, 123, shape)
    var = variables.Variable([0, 0], dtype=dtypes.int32)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        "dtype of RNG state variable must be int64, not"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          var.handle, random.RNG_ALG_PHILOX, shape)
    var = variables.Variable([[0]], dtype=dtypes.int64)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        "RNG state must have one and only one dimension, not"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          var.handle, random.RNG_ALG_PHILOX, shape)
    var = variables.Variable([0], dtype=dtypes.int64)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        "For the Philox algorithm, the size of state must be at least"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          var.handle, random.RNG_ALG_PHILOX, shape)
    with self.assertRaisesWithPredicateMatch(
        ValueError,
        "minval must be a scalar; got a tensor of shape "):
      @def_function.function
      def f():
        gen.uniform(shape=shape, minval=array_ops.zeros(shape, "int32"),
                    maxval=100, dtype="int32")
      f()
    with self.assertRaisesWithPredicateMatch(
        ValueError,
        "maxval must be a scalar; got a tensor of shape "):
      @def_function.function
      def f2():
        gen.uniform(
            shape=shape, minval=0, maxval=array_ops.ones(shape, "int32") * 100,
            dtype="int32")
      f2()

  @test_util.run_v2_only
  def testGetGlobalGeneratorWithXla(self):
    """Demonstrates using the global generator with XLA."""
    # This test was passing before because soft placement silently picked the
    # CPU kernel.
    # TODO(wangpeng): Remove this skip
    self.skipTest("NonDeterministicInts lacks XLA kernel.")

    if not config.list_physical_devices("XLA_CPU"):
      self.skipTest("No XLA_CPU device available.")

    random.set_global_generator(None)

    @def_function.function(jit_compile=True)
    def make_seed():
      generator = random.get_global_generator()
      state = array_ops.identity(generator.state, name="state")
      return generator.uniform_full_int((2,), dtypes.int32, name="seed"), state

    with ops.device("/device:XLA_CPU:0"):
      seed, state = make_seed()
      self.assertTrue(np.all(np.isfinite(seed.numpy())))
      random.get_global_generator().reset(state)
      self.assertAllEqual(make_seed()[0], seed)

  @test_util.run_v2_only
  def testSetGlobalGeneratorBadWithDefun(self):
    """Shows that set_global_generator doesn't interact well with defun."""
    shape = (3,)

    @def_function.function
    def f():
      return random.get_global_generator().normal(shape)

    random.set_global_generator(random.Generator.from_seed(50))
    with self.assertRaisesWithPredicateMatch(
        errors.NotFoundError, "Resource .+ does not exist"):
      _ = f()
      random.set_global_generator(random.Generator.from_seed(50))
      _ = f()

  @test_util.run_v2_only
  def testFunctionArg(self):
    """Tests that an RNG can be used as a tf.function argument."""
    shape = [2, 3]
    @def_function.function
    def f(gen):
      return gen.normal(shape)
    g1 = random.Generator.from_seed(1)
    g2 = random.Generator.from_seed(1)
    res1 = f(g1)
    res2 = g2.normal(shape)
    self.assertAllEqual(res1, res2)
    self.assertAllEqual(g1.state.read_value(), g2.state.read_value())
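    # The generator's state lives in a tf.Variable, so passing g1 into the
    # tf.function captures it by resource handle: the draw inside f advances
    # the caller's state, which is why g1 ends up in lock-step with its eager
    # twin g2.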

  @test_util.run_v2_only
  def testCreateOutsideMirroredStrat(self):
    """Tests RNG/MirroredStrategy interaction #1.

    If an RNG is created outside a distribution-strategy scope, all replicas
    will access the same RNG object, and accesses are serialized.
    """
    shape = [3, 4]
    dtype = dtypes.int32
    gen = random.Generator.from_seed(1234)
    strat = MirroredStrategy(devices=["cpu:0", "cpu:1"])
    with strat.scope():
      def f():
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t
      results = strat.extended.call_for_each_replica(fn=f)
      values = results.values
      self.assertAllEqual(2, len(values))
      self.assertAllDifferent(values)

  @test_util.run_v2_only
  def testMirroredStratParaAsync(self):
    """Tests RNG/MirroredStrategy interaction #2.

    The user can create n independent RNGs outside strategy.scope(), where n
    is the number of replicas, and give one to each replica. The replicas can
    thus get different random-number streams.
    """
    shape = [3, 4]
    dtype = dtypes.int32
    gens = random.get_global_generator().split(count=2)
    devices = ["cpu:0", "cpu:1"]
    strat = MirroredStrategy(devices=devices)
    # Use `PerReplica` to specify which `gen` is sent to which replica
    gens = dist_values.PerReplica([[g] for g in gens])
    with strat.scope():
      def f(gen):
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t
      results = strat.extended.call_for_each_replica(fn=f, args=gens)
      local_results = strat.experimental_local_results(results)
      self.assertAllEqual(2, len(local_results))
      self.assertAllDifferent(local_results)

  @test_util.run_v2_only
  def testUniformFullInt(self):
    """Tests full-range int uniform."""
    shape = [3, 4]
    dtype = dtypes.int32
    g = random.Generator.from_seed(1)
    r1 = g.uniform(shape=shape, dtype=dtype, minval=None)
    g = random.Generator.from_seed(1)
    r2 = g.uniform_full_int(shape=shape, dtype=dtype)
    self.assertAllEqual(r1, r2)
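    # With minval=None (and maxval left as None), `uniform` degenerates to
    # `uniform_full_int`: samples cover the entire range of the dtype,
    # e.g. [-2**31, 2**31 - 1] for int32.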

  @test_util.run_v2_only
  def testRestore(self):
    """Tests save and restore."""
    fname = os.path.join(self.get_temp_dir(), "checkpoint")
    g = random.Generator.from_seed(1)
    cp = tracking_util.Checkpoint(g=g)
    def write_restore_compare():
      cp.write(fname)
      r1 = g.uniform([], dtype=dtypes.uint32, minval=None)
      cp.restore(fname)
      r2 = g.uniform([], dtype=dtypes.uint32, minval=None)
      self.assertAllEqual(r1, r2)
    # Run multiple times so that cp.write is called in various RNG states
    for _ in range(2):
      write_restore_compare()


if __name__ == "__main__":
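  # Disabling soft placement makes an op fail loudly when it lacks a kernel
  # on its assigned device, rather than silently falling back to CPU (see the
  # soft-placement note in testGetGlobalGeneratorWithXla).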
  config.set_soft_device_placement(False)
  test.main()