1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_module.h"
17
18 #include <unordered_map>
19
20 #include "absl/memory/memory.h"
21 #include "absl/types/span.h"
22 #include "tensorflow/compiler/xla/literal.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
26 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/test.h"
29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/lib/core/status_test_util.h"
32
33 namespace xla {
34
35 namespace {
36
37 namespace op = ::xla::testing::opcode_matchers;
38
39 class HloModuleTest : public HloTestBase {
40 protected:
HloModuleTest()41 HloModuleTest() {}
42
43 // Create a computation which returns a constant.
CreateConstantComputation()44 std::unique_ptr<HloComputation> CreateConstantComputation() {
45 auto builder = HloComputation::Builder("Constant");
46 builder.AddInstruction(
47 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
48 return builder.Build();
49 }
50
51 // Creates a computation which calls the given zero-parameter computations.
CreateCallComputation(absl::Span<HloComputation * const> computations)52 std::unique_ptr<HloComputation> CreateCallComputation(
53 absl::Span<HloComputation* const> computations) {
54 auto builder = HloComputation::Builder("Call");
55 for (auto computation : computations) {
56 builder.AddInstruction(
57 HloInstruction::CreateCall(r0f32_, {}, computation));
58 }
59 return builder.Build();
60 }
61
62 Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
63 };
64
TEST_F(HloModuleTest, OneComputationPostOrder) {
  // When the entry is the module's only computation, the post order must
  // consist of exactly that computation.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      module->AddEntryComputation(CreateConstantComputation());

  const auto post_order = module->MakeComputationPostOrder();
  EXPECT_THAT(post_order, ::testing::ElementsAre(entry));
}
73
TEST_F(HloModuleTest, TwoComputationsPostOrder) {
  // Two computations that never call each other: both must appear in the
  // post order, in no particular relative order.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      module->AddEntryComputation(CreateConstantComputation());
  HloComputation* embedded =
      module->AddEmbeddedComputation(CreateConstantComputation());

  EXPECT_THAT(module->MakeComputationPostOrder(),
              ::testing::UnorderedElementsAre(entry, embedded));

  // Both builders asked for the name "Constant"; the module is expected to
  // have made the second name unique.
  EXPECT_EQ("Constant", entry->name());
  EXPECT_EQ("Constant.1", embedded->name());
}
89
TEST_F(HloModuleTest, CloneTest) {
  // Build a diamond-shaped call graph and clone the module:
  //   entry -> {computation2, computation3} -> computation1
  auto module = CreateNewVerifiedModule();
  auto computation1 =
      module->AddEmbeddedComputation(CreateConstantComputation());
  auto computation2 =
      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
  auto computation3 =
      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
  module->AddEntryComputation(
      CreateCallComputation({computation2, computation3}));

  auto post_order = module->MakeComputationPostOrder();
  auto cloned_module = module->Clone("copy");
  auto post_order_copied = cloned_module->MakeComputationPostOrder();

  // ASSERT (not EXPECT): if the sizes differ, the pairwise comparison below
  // would otherwise silently skip the unmatched tail.
  ASSERT_EQ(post_order.size(), post_order_copied.size());
  for (size_t i = 0; i < post_order.size(); ++i) {
    // Clone("copy") should suffix every computation name with ".copy".
    EXPECT_EQ(post_order[i]->name() + ".copy", post_order_copied[i]->name());
  }
}
113
TEST_F(HloModuleTest, CloneHasFusion) {
  auto module = CreateNewVerifiedModule();

  // Fused computation: adds its single f32 scalar parameter to itself.
  HloComputation* fused_computation = nullptr;
  {
    auto b = HloComputation::Builder("Fused");
    auto x = b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
    b.AddInstruction(
        HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, x, x));
    fused_computation = module->AddEmbeddedComputation(b.Build());
  }

  // Entry computation: a kInput fusion applying the fused computation to the
  // constant 42.
  {
    auto b = HloComputation::Builder("Entry");
    auto input = b.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
    b.AddInstruction(
        HloInstruction::CreateFusion(r0f32_, HloInstruction::FusionKind::kInput,
                                     /*operands=*/{input}, fused_computation));
    module->AddEntryComputation(b.Build());
  }

  auto post_order = module->MakeComputationPostOrder();
  auto cloned_module = module->Clone("copy");
  auto post_order_copied = cloned_module->MakeComputationPostOrder();

  // ASSERT (not EXPECT) so a size mismatch stops the test instead of
  // comparing only the common prefix.
  ASSERT_EQ(post_order.size(), post_order_copied.size());
  for (size_t i = 0; i < post_order.size(); ++i) {
    const HloComputation* origin = post_order[i];
    const HloComputation* copied = post_order_copied[i];
    if (origin->name() == "Fused") {
      // The fused computation is cloned together with its fusion
      // instruction, which always uses the suffix ".clone".
      EXPECT_EQ(origin->name() + ".clone", copied->name());
    } else {
      EXPECT_EQ(origin->name() + ".copy", copied->name());
    }
  }
}
155
TEST_F(HloModuleTest, DiamondComputationsPostOrder) {
  // Diamond call graph: the entry calls `left` and `right`, and both of them
  // call the same `leaf` computation.
  auto module = CreateNewVerifiedModule();
  HloComputation* leaf =
      module->AddEmbeddedComputation(CreateConstantComputation());
  HloComputation* left =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* right =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* entry =
      module->AddEntryComputation(CreateCallComputation({left, right}));

  const auto post_order = module->MakeComputationPostOrder();
  EXPECT_THAT(post_order,
              ::testing::UnorderedElementsAre(leaf, left, right, entry));
  // Callees precede callers: the entry comes last and the leaf first.
  EXPECT_EQ(post_order.back(), entry);
  EXPECT_EQ(post_order.front(), leaf);
}
175
TEST_F(HloModuleTest, LargeConstantToString) {
  // Module whose entry computation holds a single 16-element f32 constant.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder("Constant");
  const std::vector<float> values(16, 42.0f);
  builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(values)));
  module->AddEntryComputation(builder.Build());

  // Large-constant printing disabled: the literal is elided as "{...}".
  const string elided =
      module->ToString(HloPrintOptions().set_print_large_constants(false));
  EXPECT_EQ(
      "HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n "
      "ROOT %constant = f32[16]{0} constant({...})\n}\n\n",
      elided);

  // Large-constant printing enabled: every element is spelled out.
  const string expanded =
      module->ToString(HloPrintOptions().set_print_large_constants(true));
  EXPECT_EQ(
      "HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n "
      "ROOT %constant = f32[16]{0} constant({42, 42, 42, 42, 42, 42, 42, 42, "
      "42, 42, 42, 42, 42, 42, 42, 42})\n}\n\n",
      expanded);
}
196
TEST_F(HloModuleTest, UniqueModuleId) {
  // Two independently created modules must not share an id.
  auto first = CreateNewVerifiedModule();
  auto second = CreateNewVerifiedModule();
  EXPECT_NE(first->unique_id(), second->unique_id());
}
202
TEST_F(HloModuleTest, ProtoSerializationWithoutSchedule) {
  // A module parsed without a schedule must still have none after a
  // ToProto / CreateFromProto round trip.
  const string text = R"(
HloModule axpy_module

ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
  %alpha = f32[] parameter(0)
  %x = f32[2,4]{1,0} parameter(1)
  %y = f32[2,4]{1,0} parameter(2)
  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
  ASSERT_FALSE(module->has_schedule());

  auto proto = module->ToProto();
  TF_ASSERT_OK_AND_ASSIGN(auto round_tripped,
                          HloModule::CreateFromProto(proto, module->config()));
  ASSERT_FALSE(round_tripped->has_schedule());
}
223
TEST_F(HloModuleTest, ProtoSerializationWithSchedule) {
  // A scheduled module must keep a valid schedule for its entry computation,
  // in the original instruction order, across a proto round trip.
  const string text = R"(
HloModule axpy_module, is_scheduled=true

ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
  %alpha = f32[] parameter(0)
  %x = f32[2,4]{1,0} parameter(1)
  %y = f32[2,4]{1,0} parameter(2)
  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
  ASSERT_TRUE(module->has_schedule());

  auto proto = module->ToProto();
  TF_ASSERT_OK_AND_ASSIGN(auto round_tripped,
                          HloModule::CreateFromProto(proto, module->config()));
  ASSERT_TRUE(round_tripped->has_schedule());

  const auto& schedule = round_tripped->schedule();
  TF_ASSERT_OK(schedule.Verify());
  EXPECT_EQ(schedule.sequences().size(), 1);

  auto* entry = round_tripped->entry_computation();
  ASSERT_TRUE(schedule.is_computation_scheduled(entry));
  EXPECT_THAT(
      schedule.sequence(entry).instructions(),
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(),
                             op::Broadcast(), op::Multiply(), op::Add()));
}
254
TEST_F(HloModuleTest, ProtoSerializationPreservesIds) {
  // Verify that serializing then deserializing an HLO proto preserves the
  // unique IDs of the computations and instructions (but not of the module).
  const string text =
      R"(HloModule ReduceR3ToR2_module

add_F32.v3 {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY ReduceR3ToR2.v3 {
  input = f32[8,16,256]{2,1,0} parameter(0)
  constant = f32[] constant(0)
  ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));

  // Perform various transformations on the graph:
  //
  // * clone the reduction function
  // * replace use of reduction function with the clone.
  // * add a random instruction to the entry computation.
  //
  // This will create instruction and computation IDs which are interesting:
  // not consecutive and not densely packed.
  HloComputation* entry = module->entry_computation();
  HloInstruction* root = entry->root_instruction();
  HloComputation* reduction = root->to_apply();
  HloComputation* reduction_clone =
      module->AddEmbeddedComputation(reduction->Clone());
  root->set_to_apply(reduction_clone);
  TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction));
  HloInstruction* negate = entry->AddInstruction(
      HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root));
  entry->set_root_instruction(negate);

  // Schedule the transformed module, this verifies that the serialized schedule
  // is robust against non-consecutive IDs as well (b/114712358).
  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  HloMemoryScheduler scheduler(size_fn);
  TF_ASSERT_OK(scheduler.Run(module.get()).status());
  ASSERT_TRUE(module->has_schedule());

  // Serialize and deserialize and verify that the instruction and computations
  // unique ids are the same.
  TF_ASSERT_OK_AND_ASSIGN(
      auto module_copy,
      HloModule::CreateFromProto(module->ToProto(), module->config()));

  // The module IDs should *not* be the same because module ids must be globally
  // unique.
  EXPECT_NE(module->unique_id(), module_copy->unique_id());

  // Verify that the computations and instructions all have the same unique id.
  // The copy's iterators are advanced in lock step with the original's, so the
  // counts are asserted first: otherwise a smaller copy would be dereferenced
  // past its end instead of failing the test.
  ASSERT_EQ(module->computation_count(), module_copy->computation_count());
  auto computation_copy_it = module_copy->computations().begin();
  for (const HloComputation* computation_orig : module->computations()) {
    const HloComputation* computation_copy = *computation_copy_it++;
    EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id())
        << absl::StrFormat(
               "ID of original computation %s != ID of deserialized "
               "computation %s: %d != %d",
               computation_orig->name(), computation_copy->name(),
               computation_orig->unique_id(), computation_copy->unique_id());

    ASSERT_EQ(computation_orig->instruction_count(),
              computation_copy->instruction_count());
    auto instruction_copy_it = computation_copy->instructions().begin();
    for (const HloInstruction* instruction_orig :
         computation_orig->instructions()) {
      const HloInstruction* instruction_copy = *instruction_copy_it++;
      EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id())
          << absl::StrFormat(
                 "ID of original instruction %s != ID of deserialized "
                 "instruction %s: %d != %d",
                 instruction_orig->name(), instruction_copy->name(),
                 instruction_orig->unique_id(), instruction_copy->unique_id());
    }
  }

  // Verify that the next unique ID which the module would have handed out is
  // greater than the unique id of any instruction.
  int next_id = module_copy->NewUniqueInstructionId();
  for (const HloComputation* computation : module_copy->computations()) {
    for (const HloInstruction* instruction : computation->instructions()) {
      EXPECT_GT(next_id, instruction->unique_id());
    }
  }
}
346
TEST_F(HloModuleTest, VerifyReplaceComputationsWithSortOp) {
  const string text = R"(
HloModule sort

compare {
    p.0.lhs = f32[] parameter(0)
    p.0.rhs = f32[] parameter(1)
    p.1.lhs = f32[] parameter(2)
    p.1.rhs = f32[] parameter(3)
    ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}

ENTRY top {
    p.0 = f32[32] parameter(0)
    p.1 = f32[32] parameter(1)
    ROOT %sort.148.1589 = (f32[32], f32[32]) sort(p.0, p.1), dimensions={0}, to_apply=compare
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));

  // Replacement comparator: same four f32 scalar parameters as `compare`, but
  // its root compares the first two parameters with GT instead of LT.
  HloComputation* new_comp = nullptr;
  {
    HloComputation::Builder b("Fused");
    HloInstruction* p0 =
        b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p0"));
    HloInstruction* p1 =
        b.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "p1"));
    b.AddInstruction(HloInstruction::CreateParameter(2, r0f32_, "p2"));
    b.AddInstruction(HloInstruction::CreateParameter(3, r0f32_, "p3"));
    b.AddInstruction(HloInstruction::CreateCompare(
        ShapeUtil::MakeShape(PRED, {}), p0, p1, ComparisonDirection::kGt));
    new_comp = module->AddEmbeddedComputation(b.Build());
  }

  // Before replacement, the sort's comparator is the parsed LT compare.
  HloInstruction* sort = module->entry_computation()->root_instruction();
  HloInstruction* old_compare_root = sort->to_apply()->root_instruction();
  EXPECT_EQ(old_compare_root->opcode(), HloOpcode::kCompare);
  EXPECT_EQ(old_compare_root->comparison_direction(),
            ComparisonDirection::kLt);

  std::unordered_map<HloComputation*, HloComputation*> replacement = {
      {sort->to_apply(), new_comp}};
  module->ReplaceComputations(replacement);

  // The sort must now point at the replacement comparator.
  EXPECT_EQ(sort->to_apply(), new_comp);
}
396
TEST_F(HloModuleTest, OneComputationAllAllowed) {
  // The module's single computation is on the allow-list, so the filtered
  // post order must still contain it.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      module->AddEntryComputation(CreateConstantComputation());

  absl::flat_hash_set<HloComputation*> allow_list = {entry};
  EXPECT_THAT(module->MakeComputationPostOrder(allow_list),
              ::testing::ElementsAre(entry));
}
407
TEST_F(HloModuleTest, OneComputationAllFiltered) {
  // Create a module with a single computation and an empty allow-list; the
  // filtered post order must then be empty.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(CreateConstantComputation());

  absl::flat_hash_set<HloComputation*> allowList = {};
  // Compute the post order once (the original also made an extra call whose
  // result was discarded).
  EXPECT_THAT(module->MakeComputationPostOrder(allowList),
              ::testing::IsEmpty());
}
418
TEST_F(HloModuleTest, DiamondComputationsPostOrderAllAllowed) {
  // Diamond call graph with every computation on the allow-list: the result
  // must match the unfiltered post order's contents and endpoints.
  auto module = CreateNewVerifiedModule();
  HloComputation* leaf =
      module->AddEmbeddedComputation(CreateConstantComputation());
  HloComputation* left =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* right =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* entry =
      module->AddEntryComputation(CreateCallComputation({left, right}));

  absl::flat_hash_set<HloComputation*> allow_list = {leaf, left, right,
                                                     entry};
  const auto post_order = module->MakeComputationPostOrder(allow_list);
  EXPECT_THAT(post_order,
              ::testing::UnorderedElementsAre(leaf, left, right, entry));
  EXPECT_EQ(post_order.back(), entry);
  EXPECT_EQ(post_order.front(), leaf);
}
440
TEST_F(HloModuleTest, DiamondComputationsPostOrderMiddleFiltered) {
  // Diamond call graph where only the leaf and the entry are allowed: the two
  // middle computations must be filtered out of the post order.
  auto module = CreateNewVerifiedModule();
  HloComputation* leaf =
      module->AddEmbeddedComputation(CreateConstantComputation());
  HloComputation* left =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* right =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* entry =
      module->AddEntryComputation(CreateCallComputation({left, right}));

  absl::flat_hash_set<HloComputation*> allow_list = {leaf, entry};
  const auto post_order = module->MakeComputationPostOrder(allow_list);
  EXPECT_THAT(post_order, ::testing::UnorderedElementsAre(leaf, entry));
}
458
TEST_F(HloModuleTest, DiamondComputationsPostOrderAllFiltered) {
  // Create a module with a diamond call graph of computations and an empty
  // allow-list; the filtered post order must then be empty.
  auto module = CreateNewVerifiedModule();
  auto computation1 =
      module->AddEmbeddedComputation(CreateConstantComputation());
  auto computation2 =
      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
  auto computation3 =
      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
  module->AddEntryComputation(
      CreateCallComputation({computation2, computation3}));

  absl::flat_hash_set<HloComputation*> allowList = {};
  // Check the computed post order directly instead of leaving it unused and
  // recomputing it inside the expectation.
  auto post_order = module->MakeComputationPostOrder(allowList);
  EXPECT_THAT(post_order, ::testing::IsEmpty());
}
476
477 } // namespace
478
479 } // namespace xla
480