1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test utility."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.ops import variables
24from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
25from tensorflow.python.platform import test
26from tensorflow.python.util import nest
27
28
class PForTestCase(test.TestCase):
  """Base class for pfor test cases.

  Provides helpers that evaluate tensors and check that the outputs of
  `pfor_control_flow_ops.pfor` match those of `pfor_control_flow_ops.for_loop`.
  """

  def _run_targets(self, targets1, targets2=None, run_init=True):
    """Evaluates `targets1` followed by `targets2` in a single call.

    Args:
      targets1: A (possibly nested) structure of tensors to evaluate.
      targets2: Optional (possibly nested) structure of tensors. When given and
        non-empty, it must flatten to the same number of elements as
        `targets1`.
      run_init: If True, runs `variables.global_variables_initializer()` before
        evaluating the targets.

    Returns:
      The result of `self.evaluate` on the flattened `targets1` concatenated
      with the flattened `targets2`.
    """
    targets1 = nest.flatten(targets1)
    targets2 = [] if targets2 is None else nest.flatten(targets2)
    if targets2:
      # A test assertion (unlike a bare `assert`) is not stripped under -O.
      self.assertEqual(
          len(targets1), len(targets2),
          "targets1 and targets2 must flatten to the same number of tensors.")
    if run_init:
      init = variables.global_variables_initializer()
      self.evaluate(init)
    return self.evaluate(targets1 + targets2)

  def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
    """Evaluates two structures of tensors and asserts they match pairwise.

    Values whose NumPy dtype kind is object, bytes or unicode (e.g. string
    tensors) are compared exactly with `assertAllEqual`. All other values are
    compared with `assertAllClose` using the given tolerances.

    Args:
      targets1: A (possibly nested) structure of tensors.
      targets2: A (possibly nested) structure of tensors that flattens to the
        same number of elements as `targets1`.
      rtol: Relative tolerance for numeric comparisons.
      atol: Absolute tolerance for numeric comparisons.
    """
    outputs = self._run_targets(targets1, targets2)
    outputs = nest.flatten(outputs)  # flatten SparseTensorValues
    n = len(outputs) // 2
    for value1, value2 in zip(outputs[:n], outputs[n:]):
      # `np.object` was removed in NumPy 1.24, so inspect the dtype kind
      # instead. `np.asarray` also copes with scalar string results, which are
      # returned as plain `bytes` objects without a `dtype` attribute.
      if np.asarray(value2).dtype.kind in ("O", "S", "U"):
        self.assertAllEqual(value2, value1)
      else:
        self.assertAllClose(value2, value1, rtol=rtol, atol=atol)

  def _test_loop_fn(self,
                    loop_fn,
                    iters,
                    parallel_iterations=None,
                    fallback_to_while_loop=False,
                    rtol=1e-4,
                    atol=1e-5):
    """Checks that `pfor` and `for_loop` give matching results for `loop_fn`.

    Args:
      loop_fn: Callable passed to both `pfor` and `for_loop`.
      iters: Number of iterations, forwarded to both `pfor` and `for_loop`.
      parallel_iterations: Forwarded to both `pfor` and `for_loop`.
      fallback_to_while_loop: Forwarded to `pfor` only.
      rtol: Relative tolerance forwarded to `run_and_assert_equal`.
      atol: Absolute tolerance forwarded to `run_and_assert_equal`.
    """
    t1 = pfor_control_flow_ops.pfor(
        loop_fn,
        iters=iters,
        fallback_to_while_loop=fallback_to_while_loop,
        parallel_iterations=parallel_iterations)
    # `for_loop` takes the output dtypes explicitly; reuse the ones pfor
    # produced so both loops are asked for the same structure.
    loop_fn_dtypes = nest.map_structure(lambda x: x.dtype, t1)
    t2 = pfor_control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, iters=iters,
                                        parallel_iterations=parallel_iterations)
    self.run_and_assert_equal(t1, t2, rtol=rtol, atol=atol)
68