1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7 http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
17 
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/service/call_graph.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/shape_util.h"
22 #include "tensorflow/compiler/xla/status_macros.h"
23 #include "tensorflow/compiler/xla/test.h"
24 #include "tensorflow/compiler/xla/test_helpers.h"
25 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
26 #include "tensorflow/compiler/xla/util.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 
30 namespace xla {
31 namespace {
32 
33 class FlattenCallGraphTest : public HloTestBase {
34  protected:
35   // Build and return a trivial computation taking and returning a scalar.
MakeScalarComputation()36   std::unique_ptr<HloComputation> MakeScalarComputation() {
37     HloComputation::Builder builder(TestName() + ".ScalarComputation");
38     HloInstruction* param0 = builder.AddInstruction(
39         HloInstruction::CreateParameter(0, kScalarShape, "param0"));
40     builder.AddInstruction(
41         HloInstruction::CreateUnary(kScalarShape, HloOpcode::kNegate, param0));
42     return builder.Build();
43   }
44 
45   // Build and return a computation which takes a scalar and maps (kMap) the
46   // given computation to the value 'callsites' number of times.
MakeMappingComputation(HloComputation * map_computation,int64 callsites)47   std::unique_ptr<HloComputation> MakeMappingComputation(
48       HloComputation* map_computation, int64 callsites) {
49     HloComputation::Builder builder(TestName() + ".MappingComputation");
50     HloInstruction* param0 = builder.AddInstruction(
51         HloInstruction::CreateParameter(0, kScalarShape, "param0"));
52     HloInstruction* last_value = param0;
53     for (int64 i = 0; i < callsites; ++i) {
54       last_value = builder.AddInstruction(HloInstruction::CreateMap(
55           kScalarShape, {last_value}, map_computation));
56     }
57     return builder.Build();
58   }
59 
60   // Build and return a computation which takes a scalar and calls (kCall) the
61   // given computation with value 'callsites' number of times.
MakeCallingComputation(HloComputation * callee_computation,int64 callsites,const string & suffix=".CallingComputation")62   std::unique_ptr<HloComputation> MakeCallingComputation(
63       HloComputation* callee_computation, int64 callsites,
64       const string& suffix = ".CallingComputation") {
65     HloComputation::Builder builder(TestName() + suffix);
66     HloInstruction* param0 = builder.AddInstruction(
67         HloInstruction::CreateParameter(0, kScalarShape, "param0"));
68     HloInstruction* last_value = param0;
69     for (int64 i = 0; i < callsites; ++i) {
70       last_value = builder.AddInstruction(HloInstruction::CreateCall(
71           kScalarShape, {last_value}, callee_computation));
72     }
73     return builder.Build();
74   }
75 
76   // Build and return a computation which takes a scalar and returns a PRED
77   // value.
MakeConditionComputation()78   std::unique_ptr<HloComputation> MakeConditionComputation() {
79     HloComputation::Builder builder(TestName() + ".ConditionComputation");
80     HloInstruction* param0 = builder.AddInstruction(
81         HloInstruction::CreateParameter(0, kScalarShape, "param0"));
82     HloInstruction* zero = builder.AddInstruction(
83         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
84     builder.AddInstruction(
85         HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0,
86                                       zero, ComparisonDirection::kGt));
87     return builder.Build();
88   }
89 
RunFlattenCallGraph(HloModule * module)90   StatusOr<bool> RunFlattenCallGraph(HloModule* module) {
91     FlattenCallGraph flatten;
92     TF_ASSIGN_OR_RETURN(bool result, flatten.Run(module));
93     return result;
94   }
95 
96   const Shape kScalarShape = ShapeUtil::MakeShape(F32, {});
97 };
98 
TEST_F(FlattenCallGraphTest, ComplexGraph) {
  // Test a call graph of a module with several computations called in various
  // contexts. The call graph looks like:
  //
  //      entry
  //      /  |
  //     a   |
  //   / | \ |
  //  b  |  cond
  //   \ |
  //    c
  //
  // Calls are made via kCall, kWhile, and kMap instructions. In particular, c
  // is reached both through a kCall in a and through a kMap in b.
  auto module = CreateNewVerifiedModule();
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(MakeConditionComputation());
  HloComputation* c_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  // b: maps c over its parameter once.
  HloComputation* b_computation = module->AddEmbeddedComputation(
      MakeMappingComputation(c_computation, /*callsites=*/1));

  // a: calls c, then runs a while loop with condition `cond` and body `b`.
  HloComputation* a_computation;
  {
    HloComputation::Builder builder(TestName() + ".a");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    HloInstruction* call = builder.AddInstruction(
        HloInstruction::CreateCall(kScalarShape, {param0}, c_computation));
    builder.AddInstruction(HloInstruction::CreateWhile(
        kScalarShape, cond_computation, b_computation, call));
    a_computation = module->AddEmbeddedComputation(builder.Build());
  }

  // entry: a while loop with condition `cond` and body `a`. The returned
  // pointer is not needed, so it is not stored (avoids a set-but-unused local).
  {
    HloComputation::Builder builder(TestName() + ".entry");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    builder.AddInstruction(HloInstruction::CreateWhile(
        kScalarShape, cond_computation, a_computation, param0));
    module->AddEntryComputation(builder.Build());
  }

  TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
  EXPECT_TRUE(result);
  std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module.get());
  // After flattening, c must be left with exactly one caller callsite.
  const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation);
  EXPECT_EQ(1, c_node.caller_callsites().size());
}
150 
// Test the corner case of one computation used as both the body and the
// condition of the same while loop.
TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
  auto module = CreateNewVerifiedModule();

  // cond: takes a PRED parameter and returns (param0 == false). Because its
  // input and output are both PRED, it can serve as a while condition and as
  // a while body over a PRED loop state.
  HloComputation* cond_computation;
  {
    HloComputation::Builder builder(TestName() + ".cond");
    HloInstruction* param0 =
        builder.AddInstruction(HloInstruction::CreateParameter(
            0, ShapeUtil::MakeShape(PRED, {}), "param0"));
    HloInstruction* false_constant = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
    builder.AddInstruction(HloInstruction::CreateCompare(
        ShapeUtil::MakeShape(PRED, {}), param0, false_constant,
        ComparisonDirection::kEq));
    cond_computation = module->AddEmbeddedComputation(builder.Build());
  }

  // entry: a while loop whose condition and body are both `cond`, starting
  // from the constant `false`. The returned pointer is not needed, so it is
  // not stored (avoids a set-but-unused local).
  {
    HloComputation::Builder builder(TestName() + ".entry");
    HloInstruction* false_constant = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
    builder.AddInstruction(HloInstruction::CreateWhile(
        ShapeUtil::MakeShape(PRED, {}), cond_computation, cond_computation,
        false_constant));
    module->AddEntryComputation(builder.Build());
  }

  // Before flattening, cond has two callsites: the while's condition and body.
  {
    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
    const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
    EXPECT_EQ(2, cond_node.caller_callsites().size());
  }

  // After flattening, cond must be left with exactly one callsite.
  {
    TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
    EXPECT_TRUE(result);
    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
    const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
    EXPECT_EQ(1, cond_node.caller_callsites().size());
  }
}
193 
// Test flattening of nested calling computations.
195 //
196 //   Entry
197 //    / \
198 //    \ /
199 //     B
200 //    / \
201 //    \ /
202 //     C
203 //
TEST_F(FlattenCallGraphTest, FlattenCalls) {
  auto module = CreateNewVerifiedModule();

  // C is a leaf, B calls C twice, and Entry calls B twice.
  HloComputation* const leaf =
      module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* const middle = module->AddEmbeddedComputation(
      MakeCallingComputation(leaf, /*callsites=*/2, ".B"));
  module->AddEntryComputation(
      MakeCallingComputation(middle, /*callsites=*/2, ".Entry"));

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunFlattenCallGraph(module.get()));
  EXPECT_TRUE(changed);

  std::unique_ptr<CallGraph> graph = CallGraph::Build(module.get());
  // 7 = Entry + one B per Entry callsite (2) + one C per B callsite (2 * 2).
  EXPECT_EQ(7, module->computation_count());

  // Each original computation now has a single caller callsite.
  EXPECT_EQ(1, graph->GetNode(leaf).caller_callsites().size());
  EXPECT_EQ(1, graph->GetNode(middle).caller_callsites().size());
}
226 
TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) {
  auto module = CreateNewVerifiedModule();
  HloComputation* branch =
      module->AddEmbeddedComputation(MakeScalarComputation());

  // The entry computation holds a conditional whose true and false branches
  // are the very same computation.
  HloComputation::Builder builder(TestName());
  HloInstruction* predicate = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* true_operand = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
  HloInstruction* false_operand = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
  builder.AddInstruction(HloInstruction::CreateConditional(
      kScalarShape, predicate, true_operand, branch, false_operand, branch));
  module->AddEntryComputation(builder.Build());
  // The entry computation plus the shared branch computation.
  EXPECT_EQ(2, module->computation_count());

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunFlattenCallGraph(module.get()));
  EXPECT_TRUE(changed);
  std::unique_ptr<CallGraph> graph = CallGraph::Build(module.get());
  // The shared branch must have been split so the true and false branches
  // are now different computations.
  EXPECT_EQ(3, module->computation_count());

  EXPECT_EQ(1, graph->GetNode(branch).caller_callsites().size());
}
256 
257 }  // namespace
258 }  // namespace xla
259