1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
17
18 #include <set>
19
20 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
21 #include "tensorflow/compiler/xla/test.h"
22 #include "tensorflow/compiler/xla/test_helpers.h"
23 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
24
25 namespace xla {
26
27 namespace {
28
29 class HloReachabilityTest : public HloTestBase {};
30
TEST_F(HloReachabilityTest,Reachability)31 TEST_F(HloReachabilityTest, Reachability) {
32 // Construct and test a reachability graph of the following form:
33 /*
34 a
35 / \
36 b c
37 \ / \
38 d e
39 */
40 auto builder = HloComputation::Builder(TestName());
41 auto a = builder.AddInstruction(
42 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
43 auto b = builder.AddInstruction(
44 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
45 auto c = builder.AddInstruction(
46 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
47 auto d = builder.AddInstruction(
48 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
49 auto e = builder.AddInstruction(
50 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
51 auto module = CreateNewVerifiedModule();
52 module->AddEntryComputation(builder.Build());
53
54 HloReachabilityMap reachability({a, b, c, d, e});
55 reachability.SetReachable(a, a);
56 EXPECT_TRUE(reachability.SetReachabilityToUnion({a}, b));
57 EXPECT_TRUE(reachability.SetReachabilityToUnion({a}, c));
58 EXPECT_TRUE(reachability.SetReachabilityToUnion({b, c}, d));
59 EXPECT_TRUE(reachability.SetReachabilityToUnion({c}, e));
60
61 EXPECT_TRUE(reachability.IsReachable(a, a));
62 EXPECT_TRUE(reachability.IsReachable(a, b));
63 EXPECT_TRUE(reachability.IsReachable(a, c));
64 EXPECT_TRUE(reachability.IsReachable(a, d));
65 EXPECT_TRUE(reachability.IsReachable(a, e));
66
67 EXPECT_FALSE(reachability.IsReachable(b, a));
68 EXPECT_TRUE(reachability.IsReachable(b, b));
69 EXPECT_FALSE(reachability.IsReachable(b, c));
70 EXPECT_TRUE(reachability.IsReachable(b, d));
71 EXPECT_FALSE(reachability.IsReachable(b, e));
72
73 EXPECT_FALSE(reachability.IsReachable(e, a));
74 EXPECT_FALSE(reachability.IsReachable(e, b));
75 EXPECT_FALSE(reachability.IsReachable(e, c));
76 EXPECT_FALSE(reachability.IsReachable(e, d));
77 EXPECT_TRUE(reachability.IsReachable(e, e));
78
79 // Recomputing the same reachability for a previously computed instruction
80 // should return false (no change).
81 EXPECT_FALSE(reachability.SetReachabilityToUnion({a}, b));
82 EXPECT_FALSE(reachability.SetReachabilityToUnion({b, c}, d));
83 }
84
TEST_F(HloReachabilityTest,NonTrivialReachability)85 TEST_F(HloReachabilityTest, NonTrivialReachability) {
86 // Test reachability of a non-trivial computation:
87 //
88 // const1 const2
89 // | |
90 // | +-------+
91 // | | |
92 // add .. negate
93 // | . |
94 // | .... exp
95 // | |
96 // +---+ +-+---+
97 // | | |
98 // multiply copy
99 //
100 // There is a control dependency from 'add' to 'exp'.
101 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
102 auto builder = HloComputation::Builder(TestName());
103 auto constant1 = builder.AddInstruction(
104 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
105 auto constant2 = builder.AddInstruction(
106 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
107 auto add = builder.AddInstruction(HloInstruction::CreateBinary(
108 r0f32, HloOpcode::kAdd, constant1, constant2));
109 auto negate = builder.AddInstruction(
110 HloInstruction::CreateUnary(r0f32, HloOpcode::kNegate, constant2));
111 auto exp = builder.AddInstruction(
112 HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, negate));
113 auto mul = builder.AddInstruction(
114 HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, add, exp));
115 auto copy = builder.AddInstruction(
116 HloInstruction::CreateUnary(r0f32, HloOpcode::kCopy, exp));
117
118 auto module = CreateNewVerifiedModule();
119 auto computation =
120 module->AddEntryComputation(builder.Build(/*root_instruction=*/mul));
121
122 TF_CHECK_OK(add->AddControlDependencyTo(exp));
123 auto reachability = HloReachabilityMap::Build(computation);
124
125 EXPECT_TRUE(reachability->IsReachable(constant1, constant1));
126 EXPECT_FALSE(reachability->IsReachable(constant1, constant2));
127 EXPECT_TRUE(reachability->IsReachable(constant1, add));
128 EXPECT_FALSE(reachability->IsReachable(constant1, negate));
129 EXPECT_TRUE(reachability->IsReachable(constant1, exp));
130 EXPECT_TRUE(reachability->IsReachable(constant1, mul));
131 EXPECT_TRUE(reachability->IsReachable(constant1, copy));
132
133 EXPECT_FALSE(reachability->IsReachable(constant2, constant1));
134 EXPECT_TRUE(reachability->IsReachable(constant2, constant2));
135 EXPECT_TRUE(reachability->IsReachable(constant2, add));
136 EXPECT_TRUE(reachability->IsReachable(constant2, negate));
137 EXPECT_TRUE(reachability->IsReachable(constant2, exp));
138 EXPECT_TRUE(reachability->IsReachable(constant2, mul));
139 EXPECT_TRUE(reachability->IsReachable(constant2, copy));
140
141 EXPECT_FALSE(reachability->IsReachable(exp, constant1));
142 EXPECT_FALSE(reachability->IsReachable(exp, constant2));
143 EXPECT_FALSE(reachability->IsReachable(exp, add));
144 EXPECT_FALSE(reachability->IsReachable(exp, negate));
145 EXPECT_TRUE(reachability->IsReachable(exp, exp));
146 EXPECT_TRUE(reachability->IsReachable(exp, mul));
147 EXPECT_TRUE(reachability->IsReachable(exp, copy));
148
149 EXPECT_FALSE(reachability->IsReachable(mul, constant1));
150 EXPECT_FALSE(reachability->IsReachable(mul, constant2));
151 EXPECT_FALSE(reachability->IsReachable(mul, add));
152 EXPECT_FALSE(reachability->IsReachable(mul, negate));
153 EXPECT_FALSE(reachability->IsReachable(mul, exp));
154 EXPECT_TRUE(reachability->IsReachable(mul, mul));
155 EXPECT_FALSE(reachability->IsReachable(mul, copy));
156
157 EXPECT_TRUE(reachability->IsConnected(constant1, copy));
158 EXPECT_TRUE(reachability->IsConnected(copy, constant1));
159 EXPECT_FALSE(reachability->IsConnected(negate, add));
160 EXPECT_FALSE(reachability->IsConnected(add, negate));
161
162 // Remove the control dependency then update and verify the reachability map
163 ASSERT_IS_OK(add->RemoveControlDependencyTo(exp));
164 reachability->UpdateReachabilityThroughInstruction(exp);
165
166 EXPECT_TRUE(reachability->IsReachable(constant1, constant1));
167 EXPECT_FALSE(reachability->IsReachable(constant1, constant2));
168 EXPECT_TRUE(reachability->IsReachable(constant1, add));
169 EXPECT_FALSE(reachability->IsReachable(constant1, negate));
170 EXPECT_FALSE(reachability->IsReachable(constant1, exp));
171 EXPECT_TRUE(reachability->IsReachable(constant1, mul));
172 EXPECT_FALSE(reachability->IsReachable(constant1, copy));
173
174 // Change a use within the graph then update and verify the reachability map
175 ASSERT_IS_OK(constant2->ReplaceUseWith(negate, constant1));
176 reachability->UpdateReachabilityThroughInstruction(negate);
177
178 EXPECT_FALSE(reachability->IsReachable(constant2, constant1));
179 EXPECT_TRUE(reachability->IsReachable(constant2, constant2));
180 EXPECT_TRUE(reachability->IsReachable(constant2, add));
181 EXPECT_FALSE(reachability->IsReachable(constant2, negate));
182 EXPECT_FALSE(reachability->IsReachable(constant2, exp));
183 EXPECT_TRUE(reachability->IsReachable(constant2, mul));
184 EXPECT_FALSE(reachability->IsReachable(constant2, copy));
185 }
186
TEST_F(HloReachabilityTest,ChannelReachability)187 TEST_F(HloReachabilityTest, ChannelReachability) {
188 const Shape shape = ShapeUtil::MakeShape(F32, {5, 7});
189 HloComputation::Builder builder("ChannelReachability");
190 auto param = builder.AddInstruction(
191 HloInstruction::CreateParameter(0, shape, "param"));
192 auto token0 = builder.AddInstruction(HloInstruction::CreateToken());
193 auto send =
194 builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1));
195 auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
196 auto token1 = builder.AddInstruction(HloInstruction::CreateToken());
197 auto recv =
198 builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1));
199 auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
200
201 auto module = CreateNewVerifiedModule();
202 auto computation = module->AddEntryComputation(builder.Build(recv_done));
203 auto reachability = HloReachabilityMap::Build(computation);
204 EXPECT_TRUE(reachability->IsReachable(param, recv_done));
205 EXPECT_FALSE(reachability->IsReachable(send, recv));
206 EXPECT_FALSE(reachability->IsReachable(send_done, recv));
207 }
208
209 } // namespace
210
211 } // namespace xla
212