1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""XLA tests for pfor."""
16# pylint: disable=g-direct-tensorflow-import
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.compiler.tf2xla.python import xla as xla_ops
23from tensorflow.python.compiler.xla import jit
24from tensorflow.python.compiler.xla import xla
25from tensorflow.python.eager import context
26from tensorflow.python.eager import def_function
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import test_util
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import control_flow_ops
31from tensorflow.python.ops import control_flow_v2_toggles
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops import random_ops
34from tensorflow.python.ops import resource_variable_ops
35from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
36from tensorflow.python.ops.parallel_for.test_util import PForTestCase
37from tensorflow.python.platform import test
38
39
@test_util.run_all_in_graph_and_eager_modes
class PForTest(PForTestCase):
  """Checks that pfor / vectorized_map computations work under XLA."""

  def __init__(self, method_name="runTest"):
    super(PForTest, self).__init__(method_name)
    # Make XLA devices available on the eager context for every test below.
    context.context().enable_xla_devices()

  def test_xla_einsum(self):
    """Vectorizes xla einsum over every mix of loop-variant/invariant inputs."""
    iters = 10
    lhs_all = random_ops.random_uniform([iters, 9, 9])
    rhs_all = random_ops.random_uniform([iters, 9, 1])

    def loop_fn(i):
      # Slice 0 does not depend on `i`, so it is invariant across iterations.
      lhs_fixed = array_ops.gather(lhs_all, 0)
      rhs_fixed = array_ops.gather(rhs_all, 0)
      lhs = array_ops.gather(lhs_all, i)
      rhs = array_ops.gather(rhs_all, i)
      return [
          xla_ops.einsum(lhs, rhs_fixed, "ab,bc->ac"),
          xla_ops.einsum(lhs_fixed, rhs, "ab,bc->ac"),
          xla_ops.einsum(lhs_fixed, rhs_fixed, "ab,bc->ac"),
          xla_ops.einsum(lhs, rhs, "ab,bc->ac"),
          # The output subscripts here force a transpose.
          xla_ops.einsum(rhs, lhs, "cd,ce->de"),
      ]

    self._test_loop_fn(loop_fn, iters)

  def test_xla(self):
    """Runs vectorized_map inside xla.compile."""

    def mean_over_first_axis(t):
      return math_ops.reduce_mean(t, axis=0, keepdims=True)

    def vectorized_compute(batch):
      return pfor_control_flow_ops.vectorized_map(mean_over_first_axis, batch)

    inputs = array_ops.ones((10, 5, 3))
    compiled = xla.compile(vectorized_compute, inputs=[inputs])
    self.run_and_assert_equal(compiled, array_ops.ones((10, 1, 3)))

  def test_function_jit_compile(self):
    """Runs vectorized_map inside a tf.function with jit_compile=True."""

    def mean_over_first_axis(t):
      return math_ops.reduce_mean(t, axis=0, keepdims=True)

    @def_function.function(jit_compile=True)
    def vectorized_compute(batch):
      return pfor_control_flow_ops.vectorized_map(mean_over_first_axis, batch)

    inputs = array_ops.ones((10, 5, 3))
    compiled = vectorized_compute(inputs)
    self.run_and_assert_equal(compiled, array_ops.ones((10, 1, 3)))

  def test_xla_while_loop(self):
    """Calls vectorized_map from the body of a while_loop_v2 under XLA."""

    def mean_over_first_axis(t):
      return math_ops.reduce_mean(t, axis=0, keepdims=True)

    def vectorized_compute(data, step_index):
      row = array_ops.gather(data, step_index)
      reduced = pfor_control_flow_ops.vectorized_map(mean_over_first_axis, row)
      # Pin the static shape so it matches the while-loop accumulator.
      reduced.set_shape([5, 1])
      return reduced

    def while_compute(data):

      def keep_going(step_index, _):
        return step_index < 10

      def step(step_index, acc):
        return (step_index + 1, acc + vectorized_compute(data, step_index))

      loop_result = control_flow_ops.while_loop_v2(
          keep_going, step, (0, array_ops.zeros([5, 1])))
      return loop_result[1]

    compiled = xla.compile(while_compute, inputs=[array_ops.ones((10, 5, 3))])
    # Each of the 10 iterations adds a mean of ones, i.e. ones([5, 1]).
    expected = array_ops.ones([5, 1]) * 10
    self.run_and_assert_equal(expected, compiled)

  def test_reduce_mean(self):
    """Uses pfor_config.reduce_mean inside a jit_compile=True tf.function."""
    x = random_ops.random_uniform([8, 3])

    @def_function.function(jit_compile=True)
    def f():

      # The parameter must be named `pfor_config`: pfor looks it up by name.
      def loop_fn(i, pfor_config):
        row = array_ops.gather(x, i)
        return row - pfor_config.reduce_mean(row)

      return pfor_control_flow_ops.pfor(loop_fn, 8)

    actual = f()
    expected = x - math_ops.reduce_mean(x, axis=0)
    actual_val, expected_val = self.evaluate([actual, expected])
    self.assertAllClose(expected_val, actual_val)
128
129
def _make_unstacked(cond, body, pfor_config):
  """Wraps a while_loop `cond`/`body` so the loop condition is pfor-invariant.

  The returned pair operates on loop variables prefixed with an extra boolean
  `not_done` flag. The new condition combines the per-iteration flags with
  `pfor_config.reduce_concat` and keeps looping while any of them is True. The
  new body refreshes the flag with `cond` and advances only the entries whose
  flag is still True. Entries whose flag is False keep their previous values
  via `where_v2`.

  Args:
    cond: Original loop condition, called with the user loop variables.
    body: Original loop body, called with the user loop variables; must return
      one value per loop variable.
    pfor_config: The pfor config object passed to the pfor loop_fn.

  Returns:
    A `(cond, body)` tuple to pass to `while_loop` with loop vars
    `[not_done] + user_loop_vars`.
  """

  def _cond(*args):
    still_running = args[0]
    gathered = pfor_config.reduce_concat(still_running)
    return math_ops.reduce_any(gathered)

  def _body(*args):
    still_running, loop_vars = args[0], args[1:]
    still_running = math_ops.logical_and(still_running, cond(*loop_vars))
    new_vars = body(*loop_vars)
    selected = [
        array_ops.where_v2(still_running, new, old)
        for new, old in zip(new_vars, loop_vars)
    ]
    return (still_running,) + tuple(selected)

  return _cond, _body
144
145
@test_util.run_all_in_graph_and_eager_modes
class WhileV2Test(PForTestCase):
  """Vectorizes loop bodies that contain while_loop, with control flow v2 on.

  Each test compares the plain pfor result with the same computation under an
  XLA jit scope. When requested, it also compares with explicit XLA
  compilation.
  """

  def setUp(self):
    # Control flow v2 is a global toggle: remember its prior state so it can be
    # restored after the test.
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.enable_control_flow_v2()
    if not self._enabled:
      # Use addCleanup rather than tearDown. If super().setUp() below raises,
      # unittest skips tearDown but still runs registered cleanups, so the
      # toggle never leaks into other tests.
      self.addCleanup(control_flow_v2_toggles.disable_control_flow_v2)
    super(WhileV2Test, self).setUp()

  def _test_loop_fn(self, loop_fn, iters, force_xla=False):
    """Checks pfor(loop_fn, iters) against XLA-compiled variants of itself.

    Args:
      loop_fn: Loop body passed to `pfor`.
      iters: Number of pfor iterations.
      force_xla: If True, also compare against a `jit_compile=True` tf.function
        and against `xla.compile` of the same pfor computation.
    """

    def f():
      return pfor_control_flow_ops.pfor(loop_fn, iters)

    @def_function.function
    def jit_f():
      with jit.experimental_jit_scope():
        return f()

    out = f()
    jit_out = jit_f()
    self.run_and_assert_equal(out, jit_out)
    # TODO(agarwal): The following may complain about uncompilable nodes. Hence
    # these are currently not enabled for all tests.
    if force_xla:
      out_exp_compile_f = def_function.function(jit_compile=True)(f)()
      self.run_and_assert_equal(out, out_exp_compile_f)
      out_xla_compile_f = xla.compile(f, inputs=[])
      self.run_and_assert_equal(out, out_xla_compile_f)

  def test_stateless_while(self):
    """while_loop whose trip count depends on the pfor iteration."""
    x = random_ops.random_uniform([3, 5])
    lengths = constant_op.constant([4, 0, 2])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      lengths_i = array_ops.gather(lengths, i)

      return control_flow_ops.while_loop(
          lambda j, _: j < lengths_i,
          lambda j, t: (j + 1, t + array_ops.gather(x_i, j)),
          [0, 0.])

    self._test_loop_fn(loop_fn, 3)

  def test_while_with_variable(self):
    """while_loop whose body reads a resource variable."""
    v = resource_variable_ops.ResourceVariable(5.)

    def loop_fn(_):
      _, output = control_flow_ops.while_loop(
          lambda j, x: j < 4,
          lambda j, x: (j + 1, x + v),
          [0, 0.])
      return output

    self._test_loop_fn(loop_fn, 3)

  def test_while_unstacked_condition(self):
    """while_loop whose condition does not depend on the pfor iteration."""

    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, x: j < 4,
          lambda j, x: (j + 1, x + i), [0, 0])

    self._test_loop_fn(loop_fn, 3, force_xla=True)

  def test_while_force_unstacked_condition(self):
    # The while_loop in this setup is similar to the one in test_stateless_while,
    # whose condition is loop-variant. Here, however, we wrap the cond and body
    # of the loop in a way that makes the while_loop condition pfor
    # loop-invariant. This allows XLA compilation to work, since the vectorized
    # code no longer needs to perform dynamic partitioning of the inputs.
    x = random_ops.random_uniform([3, 5])
    lengths = constant_op.constant([4, 0, 2])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      lengths_i = array_ops.gather(lengths, i)

      def _cond(j, _):
        return j < lengths_i

      def _body(j, t):
        return (j + 1, t + array_ops.gather(x_i, j))

      cond, body = _make_unstacked(_cond, _body, pfor_config)
      # The leading True is the initial per-iteration "not done" flag.
      return control_flow_ops.while_loop(
          cond,
          body,
          [True, 0, 0.])

    self._test_loop_fn(loop_fn, 3, force_xla=True)
242
243
if __name__ == "__main__":
  # Run the test cases in this module through TensorFlow's test entry point.
  test.main()
246