1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test utility."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import dtypes
24from tensorflow.python.ops import variables
25from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
26from tensorflow.python.platform import test
27from tensorflow.python.util import nest
28
29
class PForTestCase(test.TestCase):
  """Base class for tests comparing `pfor` against a sequential `for_loop`."""

  def _run_targets(self, targets1, targets2=None, run_init=True):
    """Evaluates `targets1` and, optionally, `targets2` in one `evaluate` call.

    Args:
      targets1: A (possibly nested) structure of tensors.
      targets2: An optional (possibly nested) structure of tensors. When it is
        non-empty, it must flatten to the same number of elements as
        `targets1`.
      run_init: If True, runs `variables.global_variables_initializer()`
        before evaluating the targets.

    Returns:
      The result of `self.evaluate` on the flattened `targets1` followed by
      the flattened `targets2`.
    """
    targets1 = nest.flatten(targets1)
    targets2 = [] if targets2 is None else nest.flatten(targets2)
    if targets2:
      # A TestCase assertion (unlike a bare `assert`) still runs under
      # `python -O` and reports both lengths when it fails.
      self.assertEqual(
          len(targets1), len(targets2),
          "targets1 and targets2 must flatten to the same number of tensors")
    if run_init:
      self.evaluate(variables.global_variables_initializer())
    return self.evaluate(targets1 + targets2)

  def run_and_assert_equal(self, targets1, targets2):
    """Evaluates two structures of tensors and asserts they match.

    Values are compared position by position after flattening. A value whose
    NumPy dtype kind is object, bytes or unicode is compared exactly with
    `assertAllEqual`. Every other value is compared with `assertAllClose`
    (rtol=1e-4, atol=1e-5).

    Args:
      targets1: A (possibly nested) structure of tensors.
      targets2: A structure of tensors that flattens to the same number of
        elements as `targets1`.
    """
    outputs = self._run_targets(targets1, targets2)
    outputs = nest.flatten(outputs)  # flatten SparseTensorValues
    n = len(outputs) // 2
    # The first half of `outputs` comes from `targets1` and the second half
    # from `targets2`. This assumes corresponding outputs flatten to the same
    # number of leaves.
    for value1, value2 in zip(outputs[:n], outputs[n:]):
      # `np.object` was removed in NumPy 1.24, so inspect the dtype kind
      # instead. `np.asarray` also handles values without a `.dtype`
      # attribute, such as plain `bytes`.
      if np.asarray(value2).dtype.kind in ("O", "S", "U"):
        self.assertAllEqual(value2, value1)
      else:
        self.assertAllClose(value2, value1, rtol=1e-4, atol=1e-5)

  def _test_loop_fn(self, loop_fn, iters,
                    loop_fn_dtypes=dtypes.float32,
                    parallel_iterations=None):
    """Asserts that `pfor` and `for_loop` give matching results for `loop_fn`.

    Args:
      loop_fn: The function passed to both `pfor` and `for_loop`.
      iters: Passed as `iters` to both `pfor` and `for_loop` (presumably the
        number of loop iterations).
      loop_fn_dtypes: Output dtype(s) of `loop_fn`, passed to `for_loop`.
      parallel_iterations: Forwarded to both `pfor` and `for_loop`.
    """
    t1 = pfor_control_flow_ops.pfor(loop_fn, iters=iters,
                                    parallel_iterations=parallel_iterations)
    t2 = pfor_control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, iters=iters,
                                        parallel_iterations=parallel_iterations)
    self.run_and_assert_equal(t1, t2)
60