1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
17
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/service/call_graph.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/shape_util.h"
22 #include "tensorflow/compiler/xla/status_macros.h"
23 #include "tensorflow/compiler/xla/test.h"
24 #include "tensorflow/compiler/xla/test_helpers.h"
25 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
26 #include "tensorflow/compiler/xla/util.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29
30 namespace xla {
31 namespace {
32
33 class FlattenCallGraphTest : public HloTestBase {
34 protected:
35 // Build and return a trivial computation taking and returning a scalar.
MakeScalarComputation()36 std::unique_ptr<HloComputation> MakeScalarComputation() {
37 HloComputation::Builder builder(TestName() + ".ScalarComputation");
38 HloInstruction* param0 = builder.AddInstruction(
39 HloInstruction::CreateParameter(0, kScalarShape, "param0"));
40 builder.AddInstruction(
41 HloInstruction::CreateUnary(kScalarShape, HloOpcode::kNegate, param0));
42 return builder.Build();
43 }
44
45 // Build and return a computation which takes a scalar and maps (kMap) the
46 // given computation to the value 'callsites' number of times.
MakeMappingComputation(HloComputation * map_computation,int64 callsites)47 std::unique_ptr<HloComputation> MakeMappingComputation(
48 HloComputation* map_computation, int64 callsites) {
49 HloComputation::Builder builder(TestName() + ".MappingComputation");
50 HloInstruction* param0 = builder.AddInstruction(
51 HloInstruction::CreateParameter(0, kScalarShape, "param0"));
52 HloInstruction* last_value = param0;
53 for (int64 i = 0; i < callsites; ++i) {
54 last_value = builder.AddInstruction(HloInstruction::CreateMap(
55 kScalarShape, {last_value}, map_computation));
56 }
57 return builder.Build();
58 }
59
60 // Build and return a computation which takes a scalar and calls (kCall) the
61 // given computation with value 'callsites' number of times.
MakeCallingComputation(HloComputation * callee_computation,int64 callsites,const string & suffix=".CallingComputation")62 std::unique_ptr<HloComputation> MakeCallingComputation(
63 HloComputation* callee_computation, int64 callsites,
64 const string& suffix = ".CallingComputation") {
65 HloComputation::Builder builder(TestName() + suffix);
66 HloInstruction* param0 = builder.AddInstruction(
67 HloInstruction::CreateParameter(0, kScalarShape, "param0"));
68 HloInstruction* last_value = param0;
69 for (int64 i = 0; i < callsites; ++i) {
70 last_value = builder.AddInstruction(HloInstruction::CreateCall(
71 kScalarShape, {last_value}, callee_computation));
72 }
73 return builder.Build();
74 }
75
76 // Build and return a computation which takes a scalar and returns a PRED
77 // value.
MakeConditionComputation()78 std::unique_ptr<HloComputation> MakeConditionComputation() {
79 HloComputation::Builder builder(TestName() + ".ConditionComputation");
80 HloInstruction* param0 = builder.AddInstruction(
81 HloInstruction::CreateParameter(0, kScalarShape, "param0"));
82 HloInstruction* zero = builder.AddInstruction(
83 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
84 builder.AddInstruction(
85 HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0,
86 zero, ComparisonDirection::kGt));
87 return builder.Build();
88 }
89
RunFlattenCallGraph(HloModule * module)90 StatusOr<bool> RunFlattenCallGraph(HloModule* module) {
91 FlattenCallGraph flatten;
92 TF_ASSIGN_OR_RETURN(bool result, flatten.Run(module));
93 return result;
94 }
95
96 const Shape kScalarShape = ShapeUtil::MakeShape(F32, {});
97 };
98
TEST_F(FlattenCallGraphTest, ComplexGraph) {
  // Test a call graph of a module with several computations called in various
  // contexts. The call graph looks like:
  //
  //      entry
  //      /  |
  //     a   |
  //   / | \ |
  //  b  |  cond
  //   \ |
  //    c
  //
  // Calls are made via kCall, kWhile, and kMap instructions.
  auto module = CreateNewVerifiedModule();
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(MakeConditionComputation());
  HloComputation* c_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  // 'b' reaches 'c' through a single kMap instruction.
  HloComputation* b_computation = module->AddEmbeddedComputation(
      MakeMappingComputation(c_computation, /*callsites=*/1));

  // 'a' reaches 'c' through a kCall, then feeds the call result into a kWhile
  // that references both 'cond' and 'b'.
  HloComputation* a_computation;
  {
    HloComputation::Builder builder(TestName() + ".a");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    HloInstruction* call = builder.AddInstruction(
        HloInstruction::CreateCall(kScalarShape, {param0}, c_computation));
    builder.AddInstruction(HloInstruction::CreateWhile(
        kScalarShape, cond_computation, b_computation, call));
    a_computation = module->AddEmbeddedComputation(builder.Build());
  }

  // The entry computation is a single kWhile referencing 'cond' and 'a'. The
  // pointer returned by AddEntryComputation is not needed by this test, so it
  // is intentionally not stored.
  {
    HloComputation::Builder builder(TestName() + ".entry");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    builder.AddInstruction(HloInstruction::CreateWhile(
        kScalarShape, cond_computation, a_computation, param0));
    module->AddEntryComputation(builder.Build());
  }

  // The pass must report a change, and afterwards the original 'c'
  // computation must be left with exactly one caller callsite.
  {
    TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
    EXPECT_TRUE(result);
    std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module.get());
    const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation);
    EXPECT_EQ(1, c_node.caller_callsites().size());
  }
}
150
151 // Test corner case of a computation used as a body and a loop condition.
TEST_F(FlattenCallGraphTest,SharedWhileConditionAndBody)152 TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
153 auto module = CreateNewVerifiedModule();
154 HloComputation* cond_computation;
155 {
156 HloComputation::Builder builder(TestName() + ".cond");
157 HloInstruction* param0 =
158 builder.AddInstruction(HloInstruction::CreateParameter(
159 0, ShapeUtil::MakeShape(PRED, {}), "param0"));
160 HloInstruction* false_constant = builder.AddInstruction(
161 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
162 builder.AddInstruction(HloInstruction::CreateCompare(
163 ShapeUtil::MakeShape(PRED, {}), param0, false_constant,
164 ComparisonDirection::kEq));
165 cond_computation = module->AddEmbeddedComputation(builder.Build());
166 }
167
168 HloComputation* entry_computation;
169 {
170 HloComputation::Builder builder(TestName() + ".entry");
171 HloInstruction* false_constant = builder.AddInstruction(
172 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
173 builder.AddInstruction(HloInstruction::CreateWhile(
174 ShapeUtil::MakeShape(PRED, {}), cond_computation, cond_computation,
175 false_constant));
176 entry_computation = module->AddEntryComputation(builder.Build());
177 }
178
179 {
180 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
181 const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
182 EXPECT_EQ(2, cond_node.caller_callsites().size());
183 }
184
185 {
186 TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
187 EXPECT_TRUE(result);
188 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
189 const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
190 EXPECT_EQ(1, cond_node.caller_callsites().size());
191 }
192 }
193
194 // Test flattening of a nested calling computations.
195 //
196 // Entry
197 // / \
198 // \ /
199 // B
200 // / \
201 // \ /
202 // C
203 //
TEST_F(FlattenCallGraphTest,FlattenCalls)204 TEST_F(FlattenCallGraphTest, FlattenCalls) {
205 auto module = CreateNewVerifiedModule();
206 HloComputation* c_computation =
207 module->AddEmbeddedComputation(MakeScalarComputation());
208
209 HloComputation* b_computation = module->AddEmbeddedComputation(
210 MakeCallingComputation(c_computation, /*callsites=*/2, ".B"));
211
212 module->AddEntryComputation(
213 MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry"));
214
215 TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
216 EXPECT_TRUE(result);
217 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
218 EXPECT_EQ(7, module->computation_count());
219
220 const CallGraphNode& c_node = call_graph->GetNode(c_computation);
221 EXPECT_EQ(1, c_node.caller_callsites().size());
222
223 const CallGraphNode& b_node = call_graph->GetNode(b_computation);
224 EXPECT_EQ(1, b_node.caller_callsites().size());
225 }
226
// Test that a computation shared by both branches of a kConditional is split
// so that each branch refers to its own computation after flattening.
TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) {
  auto module = CreateNewVerifiedModule();
  HloComputation* const shared_branch =
      module->AddEmbeddedComputation(MakeScalarComputation());

  // The entry computation is a conditional whose true and false branches
  // both point at 'shared_branch'.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* const predicate = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* const true_operand = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
  HloInstruction* const false_operand = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
  entry_builder.AddInstruction(HloInstruction::CreateConditional(
      kScalarShape, predicate, true_operand, shared_branch, false_operand,
      shared_branch));
  module->AddEntryComputation(entry_builder.Build());
  // Only the entry computation and the shared branch exist so far.
  EXPECT_EQ(2, module->computation_count());

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunFlattenCallGraph(module.get()));
  EXPECT_TRUE(changed);
  std::unique_ptr<CallGraph> flattened = CallGraph::Build(module.get());
  // One extra computation: the two branches no longer share a computation.
  EXPECT_EQ(3, module->computation_count());

  // The original branch computation is left with a single caller callsite.
  EXPECT_EQ(1, flattened->GetNode(shared_branch).caller_callsites().size());
}
256
257 } // namespace
258 } // namespace xla
259