1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
17 
18 #include <set>
19 
20 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
21 #include "tensorflow/compiler/xla/test.h"
22 #include "tensorflow/compiler/xla/test_helpers.h"
23 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
24 
25 namespace xla {
26 
27 namespace {
28 
29 class HloReachabilityTest : public HloTestBase {};
30 
TEST_F(HloReachabilityTest, Reachability) {
  // Construct and test a reachability graph of the following form:
  /*
       a
      / \
     b   c
      \ / \
       d   e
  */
  // The five instructions are unconnected constants in the HLO graph; the
  // edges drawn above exist only in the reachability map, which is populated
  // by hand below via SetReachable / SetReachabilityToUnion.
  auto builder = HloComputation::Builder(TestName());
  auto a = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  auto b = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  auto c = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  auto d = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  auto e = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  HloReachabilityMap reachability({a, b, c, d, e});
  // 'a' has no predecessors in the map, so its self-reachability is set
  // explicitly rather than through SetReachabilityToUnion.
  reachability.SetReachable(a, a);
  // Each call returns true here because the target's reachability set changes.
  EXPECT_TRUE(reachability.SetReachabilityToUnion({a}, b));
  EXPECT_TRUE(reachability.SetReachabilityToUnion({a}, c));
  EXPECT_TRUE(reachability.SetReachabilityToUnion({b, c}, d));
  EXPECT_TRUE(reachability.SetReachabilityToUnion({c}, e));

  // IsReachable(x, y) is true when y is reachable from x (including x == y).
  EXPECT_TRUE(reachability.IsReachable(a, a));
  EXPECT_TRUE(reachability.IsReachable(a, b));
  EXPECT_TRUE(reachability.IsReachable(a, c));
  EXPECT_TRUE(reachability.IsReachable(a, d));
  EXPECT_TRUE(reachability.IsReachable(a, e));

  EXPECT_FALSE(reachability.IsReachable(b, a));
  EXPECT_TRUE(reachability.IsReachable(b, b));
  EXPECT_FALSE(reachability.IsReachable(b, c));
  EXPECT_TRUE(reachability.IsReachable(b, d));
  EXPECT_FALSE(reachability.IsReachable(b, e));

  // 'c' fans out to both 'd' and 'e' but is unrelated to its sibling 'b'.
  EXPECT_FALSE(reachability.IsReachable(c, a));
  EXPECT_FALSE(reachability.IsReachable(c, b));
  EXPECT_TRUE(reachability.IsReachable(c, c));
  EXPECT_TRUE(reachability.IsReachable(c, d));
  EXPECT_TRUE(reachability.IsReachable(c, e));

  // 'd' is a sink: it reaches only itself.
  EXPECT_FALSE(reachability.IsReachable(d, a));
  EXPECT_FALSE(reachability.IsReachable(d, b));
  EXPECT_FALSE(reachability.IsReachable(d, c));
  EXPECT_TRUE(reachability.IsReachable(d, d));
  EXPECT_FALSE(reachability.IsReachable(d, e));

  EXPECT_FALSE(reachability.IsReachable(e, a));
  EXPECT_FALSE(reachability.IsReachable(e, b));
  EXPECT_FALSE(reachability.IsReachable(e, c));
  EXPECT_FALSE(reachability.IsReachable(e, d));
  EXPECT_TRUE(reachability.IsReachable(e, e));

  // IsConnected is symmetric: true if either instruction reaches the other.
  EXPECT_TRUE(reachability.IsConnected(a, d));
  EXPECT_TRUE(reachability.IsConnected(d, a));
  EXPECT_FALSE(reachability.IsConnected(b, c));
  EXPECT_FALSE(reachability.IsConnected(c, b));
  EXPECT_FALSE(reachability.IsConnected(b, e));
  EXPECT_FALSE(reachability.IsConnected(d, e));

  // Recomputing the same reachability for a previously computed instruction
  // should return false (no change).
  EXPECT_FALSE(reachability.SetReachabilityToUnion({a}, b));
  EXPECT_FALSE(reachability.SetReachabilityToUnion({a}, c));
  EXPECT_FALSE(reachability.SetReachabilityToUnion({b, c}, d));
  EXPECT_FALSE(reachability.SetReachabilityToUnion({c}, e));
}
84 
TEST_F(HloReachabilityTest,NonTrivialReachability)85 TEST_F(HloReachabilityTest, NonTrivialReachability) {
86   // Test reachability of a non-trivial computation:
87   //
88   // const1    const2
89   //    |         |
90   //    | +-------+
91   //    | |       |
92   //    add ..   negate
93   //     |   .     |
94   //     |   .... exp
95   //     |         |
96   //     +---+   +-+---+
97   //         |   |     |
98   //       multiply   copy
99   //
100   // There is a control dependency from 'add' to 'exp'.
101   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
102   auto builder = HloComputation::Builder(TestName());
103   auto constant1 = builder.AddInstruction(
104       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
105   auto constant2 = builder.AddInstruction(
106       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
107   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
108       r0f32, HloOpcode::kAdd, constant1, constant2));
109   auto negate = builder.AddInstruction(
110       HloInstruction::CreateUnary(r0f32, HloOpcode::kNegate, constant2));
111   auto exp = builder.AddInstruction(
112       HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, negate));
113   auto mul = builder.AddInstruction(
114       HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, add, exp));
115   auto copy = builder.AddInstruction(
116       HloInstruction::CreateUnary(r0f32, HloOpcode::kCopy, exp));
117 
118   auto module = CreateNewVerifiedModule();
119   auto computation =
120       module->AddEntryComputation(builder.Build(/*root_instruction=*/mul));
121 
122   TF_CHECK_OK(add->AddControlDependencyTo(exp));
123   auto reachability = HloReachabilityMap::Build(computation);
124 
125   EXPECT_TRUE(reachability->IsReachable(constant1, constant1));
126   EXPECT_FALSE(reachability->IsReachable(constant1, constant2));
127   EXPECT_TRUE(reachability->IsReachable(constant1, add));
128   EXPECT_FALSE(reachability->IsReachable(constant1, negate));
129   EXPECT_TRUE(reachability->IsReachable(constant1, exp));
130   EXPECT_TRUE(reachability->IsReachable(constant1, mul));
131   EXPECT_TRUE(reachability->IsReachable(constant1, copy));
132 
133   EXPECT_FALSE(reachability->IsReachable(constant2, constant1));
134   EXPECT_TRUE(reachability->IsReachable(constant2, constant2));
135   EXPECT_TRUE(reachability->IsReachable(constant2, add));
136   EXPECT_TRUE(reachability->IsReachable(constant2, negate));
137   EXPECT_TRUE(reachability->IsReachable(constant2, exp));
138   EXPECT_TRUE(reachability->IsReachable(constant2, mul));
139   EXPECT_TRUE(reachability->IsReachable(constant2, copy));
140 
141   EXPECT_FALSE(reachability->IsReachable(exp, constant1));
142   EXPECT_FALSE(reachability->IsReachable(exp, constant2));
143   EXPECT_FALSE(reachability->IsReachable(exp, add));
144   EXPECT_FALSE(reachability->IsReachable(exp, negate));
145   EXPECT_TRUE(reachability->IsReachable(exp, exp));
146   EXPECT_TRUE(reachability->IsReachable(exp, mul));
147   EXPECT_TRUE(reachability->IsReachable(exp, copy));
148 
149   EXPECT_FALSE(reachability->IsReachable(mul, constant1));
150   EXPECT_FALSE(reachability->IsReachable(mul, constant2));
151   EXPECT_FALSE(reachability->IsReachable(mul, add));
152   EXPECT_FALSE(reachability->IsReachable(mul, negate));
153   EXPECT_FALSE(reachability->IsReachable(mul, exp));
154   EXPECT_TRUE(reachability->IsReachable(mul, mul));
155   EXPECT_FALSE(reachability->IsReachable(mul, copy));
156 
157   EXPECT_TRUE(reachability->IsConnected(constant1, copy));
158   EXPECT_TRUE(reachability->IsConnected(copy, constant1));
159   EXPECT_FALSE(reachability->IsConnected(negate, add));
160   EXPECT_FALSE(reachability->IsConnected(add, negate));
161 
162   // Remove the control dependency then update and verify the reachability map
163   ASSERT_IS_OK(add->RemoveControlDependencyTo(exp));
164   reachability->UpdateReachabilityThroughInstruction(exp);
165 
166   EXPECT_TRUE(reachability->IsReachable(constant1, constant1));
167   EXPECT_FALSE(reachability->IsReachable(constant1, constant2));
168   EXPECT_TRUE(reachability->IsReachable(constant1, add));
169   EXPECT_FALSE(reachability->IsReachable(constant1, negate));
170   EXPECT_FALSE(reachability->IsReachable(constant1, exp));
171   EXPECT_TRUE(reachability->IsReachable(constant1, mul));
172   EXPECT_FALSE(reachability->IsReachable(constant1, copy));
173 
174   // Change a use within the graph then update and verify the reachability map
175   ASSERT_IS_OK(constant2->ReplaceUseWith(negate, constant1));
176   reachability->UpdateReachabilityThroughInstruction(negate);
177 
178   EXPECT_FALSE(reachability->IsReachable(constant2, constant1));
179   EXPECT_TRUE(reachability->IsReachable(constant2, constant2));
180   EXPECT_TRUE(reachability->IsReachable(constant2, add));
181   EXPECT_FALSE(reachability->IsReachable(constant2, negate));
182   EXPECT_FALSE(reachability->IsReachable(constant2, exp));
183   EXPECT_TRUE(reachability->IsReachable(constant2, mul));
184   EXPECT_FALSE(reachability->IsReachable(constant2, copy));
185 }
186 
TEST_F(HloReachabilityTest, ChannelReachability) {
  // A Send chain (param -> send -> send_done) and a Recv chain
  // (recv -> recv_done) that use the same channel id (1) but share no data or
  // control edge in the HLO graph. Reachability built by
  // HloReachabilityMap::Build is expected to link the two chains through the
  // channel only at 'recv_done', leaving 'recv' itself independent of the send.
  const Shape shape = ShapeUtil::MakeShape(F32, {5, 7});
  // TestName() yields "ChannelReachability", matching the other tests' naming.
  HloComputation::Builder builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  auto token0 = builder.AddInstruction(HloInstruction::CreateToken());
  auto send =
      builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1));
  auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
  auto token1 = builder.AddInstruction(HloInstruction::CreateToken());
  auto recv =
      builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1));
  auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build(recv_done));
  auto reachability = HloReachabilityMap::Build(computation);

  // Only the channel dependency can connect the send side to 'recv_done'.
  EXPECT_TRUE(reachability->IsReachable(param, recv_done));
  EXPECT_TRUE(reachability->IsReachable(send, recv_done));
  EXPECT_TRUE(reachability->IsReachable(recv, recv_done));

  // The channel dependency does not make the Recv itself depend on the send.
  EXPECT_FALSE(reachability->IsReachable(param, recv));
  EXPECT_FALSE(reachability->IsReachable(send, recv));
  EXPECT_FALSE(reachability->IsReachable(send_done, recv));
}
208 
209 }  // namespace
210 
211 }  // namespace xla
212