1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for the behavior of the auto-compilation pass."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22from six.moves import xrange  # pylint: disable=redefined-builtin
23
24from tensorflow.compiler.tests import xla_test
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.platform import googletest
31
# Fully-qualified device string for the local host CPU. Tests wrap ops in
# `ops.device(CPU_DEVICE)` to force them off the XLA test device.
CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
33
34
class ClusteringTest(xla_test.XLATestCase):
  """Checks how the auto-compilation pass handles device placement.

  Some tests place every op on the XLA test device. Others interleave
  CPU-placed ops with XLA-placed ops, so clustering has to respect the
  cross-device edges.
  """

  def testAdd(self):
    """Adds two float32 constants on the test device and checks the sum."""
    val1 = np.array([4, 3, 2, 1], dtype=np.float32)
    val2 = np.array([5, 6, 7, 8], dtype=np.float32)
    expected = val1 + val2
    with self.cached_session():
      with self.test_scope():
        input1 = constant_op.constant(val1, name="const1")
        input2 = constant_op.constant(val2, name="const2")
        output = math_ops.add(input1, input2)
      result = self.evaluate(output)
    self.assertAllClose(result, expected, rtol=1e-3)

  def testAddFromCpuMultiple(self):
    """Adds CPU-placed constants on the test device, evaluating repeatedly.

    The same output is evaluated several times in one session. Presumably
    this checks that re-running the compiled cluster stays correct; confirm
    against the pass's caching behavior.
    """
    val1 = np.array([4, 3, 2, 1], dtype=np.float32)
    val2 = np.array([5, 6, 7, 8], dtype=np.float32)
    expected = val1 + val2
    with self.cached_session():
      with ops.device(CPU_DEVICE):
        input1 = constant_op.constant(val1, name="const1")
        input2 = constant_op.constant(val2, name="const2")
      with self.test_scope():
        output = math_ops.add(input1, input2)
      for _ in xrange(10):
        result = self.evaluate(output)
        self.assertAllClose(result, expected, rtol=1e-3)

  def testDeadlock(self):
    """Checks that clustering never merges nodes across a CPU round-trip."""
    # Builds a graph of the form:
    #  x -> y
    #       | \
    #       z -> w
    # where x and z are placed on the CPU and y and w are placed on the XLA
    # device. If y and w were clustered together for compilation, the
    # clustered graph would contain a self-loop through z and deadlock.
    with self.cached_session() as sess:
      with ops.device(CPU_DEVICE):
        x = array_ops.placeholder(dtypes.float32, [2])
      with self.test_scope():
        y = x * 2
      with ops.device(CPU_DEVICE):
        z = y * y
      with self.test_scope():
        w = y + z
      result = sess.run(w, {x: [1.5, 0.5]})
    # y = [3, 1], z = [9, 1], w = y + z = [12, 2].
    self.assertAllClose(result, [12., 2.], rtol=1e-3)

  def testHostMemory(self):
    """Uses one cluster's output as another cluster's compile-time shape."""
    with self.cached_session() as sess:
      x = array_ops.placeholder(dtypes.int32)
      with self.test_scope():
        y = x + 1
      with ops.device(CPU_DEVICE):
        # Place a computation on the CPU, so y and w cannot be merged into the
        # same JIT compilation.
        z = y * 2
      with self.test_scope():
        # Argument 'y' is a non-constant output of a previous cluster. Make sure
        # it is properly copied to host memory so it can be used as a
        # compile-time constant input for this cluster.
        w = array_ops.reshape(z, y)
      result = sess.run(w, {x: [1, 0]})
    # y = [2, 1] and z = [4, 2], so reshaping z to y gives a 2x1 int32 tensor.
    expected = np.array([[4], [2]], dtype=np.int32)
    # The result is an integer tensor, so compare exactly instead of using a
    # float tolerance.
    self.assertAllEqual(expected, result)
100
101
if __name__ == "__main__":
  # Entry point when this file is executed directly: hand control to
  # TensorFlow's googletest runner.
  googletest.main()
104