1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_module.h"
17 
18 #include <unordered_map>
19 
20 #include "absl/memory/memory.h"
21 #include "absl/types/span.h"
22 #include "tensorflow/compiler/xla/literal.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
26 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/test.h"
29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/lib/core/status_test_util.h"
32 
33 namespace xla {
34 
35 namespace {
36 
37 namespace op = ::xla::testing::opcode_matchers;
38 
39 class HloModuleTest : public HloTestBase {
40  protected:
HloModuleTest()41   HloModuleTest() {}
42 
43   // Create a computation which returns a constant.
CreateConstantComputation()44   std::unique_ptr<HloComputation> CreateConstantComputation() {
45     auto builder = HloComputation::Builder("Constant");
46     builder.AddInstruction(
47         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
48     return builder.Build();
49   }
50 
51   // Creates a computation which calls the given zero-parameter computations.
CreateCallComputation(absl::Span<HloComputation * const> computations)52   std::unique_ptr<HloComputation> CreateCallComputation(
53       absl::Span<HloComputation* const> computations) {
54     auto builder = HloComputation::Builder("Call");
55     for (auto computation : computations) {
56       builder.AddInstruction(
57           HloInstruction::CreateCall(r0f32_, {}, computation));
58     }
59     return builder.Build();
60   }
61 
62   Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
63 };
64 
TEST_F(HloModuleTest, OneComputationPostOrder) {
  // A module whose only computation is its entry computation must report
  // exactly that computation in its post order.
  auto module = CreateNewVerifiedModule();
  HloComputation* const entry =
      module->AddEntryComputation(CreateConstantComputation());

  const auto post_order = module->MakeComputationPostOrder();
  EXPECT_THAT(post_order, ::testing::ElementsAre(entry));
}
73 
TEST_F(HloModuleTest, TwoComputationsPostOrder) {
  // Two computations that never call each other: both must show up in the
  // post order, in no particular relative order.
  auto module = CreateNewVerifiedModule();
  HloComputation* const entry =
      module->AddEntryComputation(CreateConstantComputation());
  HloComputation* const embedded =
      module->AddEmbeddedComputation(CreateConstantComputation());

  EXPECT_THAT(module->MakeComputationPostOrder(),
              ::testing::UnorderedElementsAre(entry, embedded));

  // Both builders asked for the name "Constant"; the module is expected to
  // have made the second one unique when it was added.
  EXPECT_EQ(entry->name(), "Constant");
  EXPECT_EQ(embedded->name(), "Constant.1");
}
89 
TEST_F(HloModuleTest, CloneTest) {
  // Build a diamond call graph: the entry calls computation2 and
  // computation3, and both of those call computation1. Then clone the module
  // and check, position by position in post order, that every computation
  // was copied with the ".copy" suffix.
  auto module = CreateNewVerifiedModule();
  auto computation1 =
      module->AddEmbeddedComputation(CreateConstantComputation());
  auto computation2 =
      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
  auto computation3 =
      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
  module->AddEntryComputation(
      CreateCallComputation({computation2, computation3}));

  const auto post_order = module->MakeComputationPostOrder();
  auto cloned_module = module->Clone("copy");
  const auto post_order_copied = cloned_module->MakeComputationPostOrder();

  // Stop on a size mismatch. Comparing names pairwise across post orders of
  // different lengths would only add a cascade of misleading failures.
  ASSERT_EQ(post_order.size(), post_order_copied.size());
  for (size_t i = 0; i < post_order.size(); ++i) {
    EXPECT_EQ(post_order[i]->name() + ".copy", post_order_copied[i]->name())
        << "at post-order index " << i;
  }
}
113 
TEST_F(HloModuleTest, CloneHasFusion) {
  auto module = CreateNewVerifiedModule();

  // Create the fused computation: x + x for a scalar f32 parameter x.
  HloComputation* fused_computation = nullptr;
  {
    auto b = HloComputation::Builder("Fused");
    auto x = b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
    b.AddInstruction(
        HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, x, x));
    fused_computation = module->AddEmbeddedComputation(b.Build());
  }

  // Create the entry computation: a kInput fusion of the fused computation
  // applied to the constant 42.
  {
    auto b = HloComputation::Builder("Entry");
    auto input = b.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
    b.AddInstruction(
        HloInstruction::CreateFusion(r0f32_, HloInstruction::FusionKind::kInput,
                                     /*operands=*/{input}, fused_computation));
    module->AddEntryComputation(b.Build());
  }

  const auto post_order = module->MakeComputationPostOrder();
  auto cloned_module = module->Clone("copy");
  const auto post_order_copied = cloned_module->MakeComputationPostOrder();

  // Stop on a size mismatch rather than comparing mis-aligned post orders.
  ASSERT_EQ(post_order.size(), post_order_copied.size());
  for (size_t i = 0; i < post_order.size(); ++i) {
    const HloComputation* origin = post_order[i];
    const HloComputation* copied = post_order_copied[i];
    // The fused computation is cloned when its fusion instruction is cloned,
    // which always uses the suffix ".clone". Everything else gets ".copy".
    const char* expected_suffix =
        origin->name() == "Fused" ? ".clone" : ".copy";
    EXPECT_EQ(origin->name() + expected_suffix, copied->name())
        << "at post-order index " << i;
  }
}
155 
TEST_F(HloModuleTest, DiamondComputationsPostOrder) {
  // Diamond call graph: the entry calls two middle computations, and each of
  // them calls the same leaf computation.
  auto module = CreateNewVerifiedModule();
  HloComputation* const leaf =
      module->AddEmbeddedComputation(CreateConstantComputation());
  HloComputation* const left =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* const right =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* const entry =
      module->AddEntryComputation(CreateCallComputation({left, right}));

  const auto post_order = module->MakeComputationPostOrder();
  EXPECT_THAT(post_order,
              ::testing::UnorderedElementsAre(leaf, left, right, entry));
  // Callees come before their callers: the shared leaf is first and the
  // entry is last.
  EXPECT_EQ(post_order.back(), entry);
  EXPECT_EQ(post_order.front(), leaf);
}
175 
TEST_F(HloModuleTest, LargeConstantToString) {
  // Builds a module whose entry computation is one f32[16] constant, then
  // checks both ToString renderings. With print_large_constants off, the
  // literal is printed as "{...}"; with it on, every element is printed.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder("Constant");
  const std::vector<float> values(16, 42.0);
  builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(values)));
  module->AddEntryComputation(builder.Build());

  const string elided_text =
      "HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n  "
      "ROOT %constant = f32[16]{0} constant({...})\n}\n\n";
  const string full_text =
      "HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n  "
      "ROOT %constant = f32[16]{0} constant({42, 42, 42, 42, 42, 42, 42, 42, "
      "42, 42, 42, 42, 42, 42, 42, 42})\n}\n\n";

  EXPECT_EQ(elided_text,
            module->ToString(
                HloPrintOptions().set_print_large_constants(false)));
  EXPECT_EQ(full_text,
            module->ToString(HloPrintOptions().set_print_large_constants(true)));
}
196 
TEST_F(HloModuleTest, UniqueModuleId) {
  // Two freshly created modules must not share a unique id.
  const auto first = CreateNewVerifiedModule();
  const auto second = CreateNewVerifiedModule();
  EXPECT_NE(first->unique_id(), second->unique_id());
}
202 
TEST_F(HloModuleTest, ProtoSerializationWithoutSchedule) {
  // An unscheduled module must still be unscheduled after a round trip
  // through ToProto / CreateFromProto.
  const string hlo_text = R"(
HloModule axpy_module

ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
  %alpha = f32[] parameter(0)
  %x = f32[2,4]{1,0} parameter(1)
  %y = f32[2,4]{1,0} parameter(2)
  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
  ASSERT_FALSE(module->has_schedule());

  const auto proto = module->ToProto();
  TF_ASSERT_OK_AND_ASSIGN(auto module_copy,
                          HloModule::CreateFromProto(proto, module->config()));
  ASSERT_FALSE(module_copy->has_schedule());
}
223 
TEST_F(HloModuleTest, ProtoSerializationWithSchedule) {
  // A scheduled module must keep a valid schedule across a proto round trip,
  // including the exact instruction sequence of its entry computation.
  const string hlo_text = R"(
HloModule axpy_module, is_scheduled=true

ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
  %alpha = f32[] parameter(0)
  %x = f32[2,4]{1,0} parameter(1)
  %y = f32[2,4]{1,0} parameter(2)
  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
  ASSERT_TRUE(module->has_schedule());

  TF_ASSERT_OK_AND_ASSIGN(
      auto module_copy,
      HloModule::CreateFromProto(module->ToProto(), module->config()));
  ASSERT_TRUE(module_copy->has_schedule());

  // Inspect the deserialized schedule through a single reference.
  const auto& schedule = module_copy->schedule();
  HloComputation* const entry = module_copy->entry_computation();
  TF_ASSERT_OK(schedule.Verify());
  EXPECT_EQ(schedule.sequences().size(), 1);
  ASSERT_TRUE(schedule.is_computation_scheduled(entry));
  EXPECT_THAT(schedule.sequence(entry).instructions(),
              ::testing::ElementsAre(op::Parameter(), op::Parameter(),
                                     op::Parameter(), op::Broadcast(),
                                     op::Multiply(), op::Add()));
}
254 
TEST_F(HloModuleTest,ProtoSerializationPreservesIds)255 TEST_F(HloModuleTest, ProtoSerializationPreservesIds) {
256   // Verify that serializing then deserializing an HLO proto preserves the
257   // unique IDs of the instruction and module.
258   const string text =
259       R"(HloModule ReduceR3ToR2_module
260 
261 add_F32.v3 {
262   lhs = f32[] parameter(0)
263   rhs = f32[] parameter(1)
264   ROOT add = f32[] add(lhs, rhs)
265 }
266 
267 ENTRY ReduceR3ToR2.v3 {
268   input = f32[8,16,256]{2,1,0} parameter(0)
269   constant = f32[] constant(0)
270   ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
271 }
272 )";
273   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
274 
275   // Perform various transformations on the graph:
276   //
277   //  * clone the reduction function
278   //  * replace use of reduction function with the clone.
279   //  * add a random instruction to the entry computation.
280   //
281   // This will create instruction and computation IDs which are interesting:
282   // not consecutive and not densely packed.
283   HloComputation* entry = module->entry_computation();
284   HloInstruction* root = entry->root_instruction();
285   HloComputation* reduction = root->to_apply();
286   HloComputation* reduction_clone =
287       module->AddEmbeddedComputation(reduction->Clone());
288   root->set_to_apply(reduction_clone);
289   TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction));
290   HloInstruction* negate = entry->AddInstruction(
291       HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root));
292   entry->set_root_instruction(negate);
293 
294   // Schedule the transformed module, this verifies that the serialized schedule
295   // is robust against non-consecutive IDs as well (b/114712358).
296   auto size_fn = [](const BufferValue& buffer) {
297     return ShapeUtil::ByteSizeOf(buffer.shape());
298   };
299   HloMemoryScheduler scheduler(size_fn);
300   TF_ASSERT_OK(scheduler.Run(module.get()).status());
301   ASSERT_TRUE(module->has_schedule());
302 
303   // Serialize and deserialize and verify that the instruction and computations
304   // unique ids are the same.
305   TF_ASSERT_OK_AND_ASSIGN(
306       auto module_copy,
307       HloModule::CreateFromProto(module->ToProto(), module->config()));
308 
309   // The module IDs should *not* be the same because module ids must be globally
310   // unique.
311   EXPECT_NE(module->unique_id(), module_copy->unique_id());
312 
313   // Verify that the computations and instructions all have the same unique id.
314   auto computation_copy_it = module_copy->computations().begin();
315   for (const HloComputation* computation_orig : module->computations()) {
316     const HloComputation* computation_copy = *computation_copy_it++;
317     EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id())
318         << absl::StrFormat(
319                "ID of original computation %s != ID of deserialized "
320                "computation %s: %d != %d",
321                computation_orig->name(), computation_copy->name(),
322                computation_orig->unique_id(), computation_copy->unique_id());
323 
324     auto instruction_copy_it = computation_copy->instructions().begin();
325     for (const HloInstruction* instruction_orig :
326          computation_orig->instructions()) {
327       const HloInstruction* instruction_copy = *instruction_copy_it++;
328       EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id())
329           << absl::StrFormat(
330                  "ID of original instruction %s != ID of deserialized "
331                  "instruction %s: %d != %d",
332                  instruction_orig->name(), instruction_copy->name(),
333                  instruction_orig->unique_id(), instruction_copy->unique_id());
334     }
335   }
336 
337   // Verify that the next unique ID which the module would have handed out is
338   // greater than the unique id of any instruction.
339   int next_id = module_copy->NewUniqueInstructionId();
340   for (const HloComputation* computation : module_copy->computations()) {
341     for (const HloInstruction* instruction : computation->instructions()) {
342       EXPECT_GT(next_id, instruction->unique_id());
343     }
344   }
345 }
346 
TEST_F(HloModuleTest, VerifyReplaceComputationsWithSortOp) {
  // Parses a module whose entry root is a sort with an LT comparator. Then
  // swaps that comparator for a hand-built GT comparator through
  // HloModule::ReplaceComputations and checks the sort now refers to it.
  const string hlo_text = R"(
  HloModule sort

  compare {
      p.0.lhs = f32[] parameter(0)
      p.0.rhs = f32[] parameter(1)
      p.1.lhs = f32[] parameter(2)
      p.1.rhs = f32[] parameter(3)
      ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
  }

  ENTRY top {
    p.0 = f32[32] parameter(0)
    p.1 = f32[32] parameter(1)
    ROOT %sort.148.1589 = (f32[32], f32[32]) sort(p.0, p.1), dimensions={0}, to_apply=compare
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  // Replacement comparator: four scalar f32 parameters, like the parsed
  // comparator, whose root is a GT compare of the first two.
  HloComputation::Builder builder("Fused");
  HloInstruction* const lhs =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p0"));
  HloInstruction* const rhs =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "p1"));
  builder.AddInstruction(HloInstruction::CreateParameter(2, r0f32_, "p2"));
  builder.AddInstruction(HloInstruction::CreateParameter(3, r0f32_, "p3"));
  builder.AddInstruction(HloInstruction::CreateCompare(
      ShapeUtil::MakeShape(PRED, {}), lhs, rhs, ComparisonDirection::kGt));
  HloComputation* const new_comp =
      module->AddEmbeddedComputation(builder.Build());

  // Before the replacement, the sort still uses the parsed LT comparator.
  HloInstruction* const sort = module->entry_computation()->root_instruction();
  const HloInstruction* const old_compare =
      sort->to_apply()->root_instruction();
  EXPECT_EQ(old_compare->opcode(), HloOpcode::kCompare);
  EXPECT_EQ(old_compare->comparison_direction(), ComparisonDirection::kLt);

  std::unordered_map<HloComputation*, HloComputation*> replacements = {
      {sort->to_apply(), new_comp}};
  module->ReplaceComputations(replacements);

  EXPECT_EQ(sort->to_apply(), new_comp);
}
396 
TEST_F(HloModuleTest, OneComputationAllAllowed) {
  // With its only computation on the allow list, the filtered post order of
  // a single-computation module still contains that computation.
  auto module = CreateNewVerifiedModule();
  HloComputation* const entry =
      module->AddEntryComputation(CreateConstantComputation());

  const absl::flat_hash_set<HloComputation*> allow_list{entry};
  EXPECT_THAT(module->MakeComputationPostOrder(allow_list),
              ::testing::ElementsAre(entry));
}
407 
TEST_F(HloModuleTest, OneComputationAllFiltered) {
  // With an empty allow list, the filtered post order must be empty even
  // though the module contains a computation.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(CreateConstantComputation());

  const absl::flat_hash_set<HloComputation*> allow_list{};
  EXPECT_THAT(module->MakeComputationPostOrder(allow_list),
              ::testing::IsEmpty());
}
418 
TEST_F(HloModuleTest, DiamondComputationsPostOrderAllAllowed) {
  // Diamond call graph (the entry calls two middle computations, both of
  // which call one shared leaf) with every computation on the allow list.
  // All four must appear, with the leaf first and the entry last.
  auto module = CreateNewVerifiedModule();
  HloComputation* const leaf =
      module->AddEmbeddedComputation(CreateConstantComputation());
  HloComputation* const left =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* const right =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* const entry =
      module->AddEntryComputation(CreateCallComputation({left, right}));

  const absl::flat_hash_set<HloComputation*> allow_list{leaf, left, right,
                                                        entry};
  const auto post_order = module->MakeComputationPostOrder(allow_list);
  EXPECT_THAT(post_order,
              ::testing::UnorderedElementsAre(leaf, left, right, entry));
  EXPECT_EQ(post_order.back(), entry);
  EXPECT_EQ(post_order.front(), leaf);
}
440 
TEST_F(HloModuleTest, DiamondComputationsPostOrderMiddleFiltered) {
  // Diamond call graph where only the leaf and the entry are allowed. The
  // two middle computations must be dropped from the filtered post order.
  auto module = CreateNewVerifiedModule();
  HloComputation* const leaf =
      module->AddEmbeddedComputation(CreateConstantComputation());
  HloComputation* const left =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* const right =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* const entry =
      module->AddEntryComputation(CreateCallComputation({left, right}));

  const absl::flat_hash_set<HloComputation*> allow_list{leaf, entry};
  const auto post_order = module->MakeComputationPostOrder(allow_list);
  EXPECT_THAT(post_order, ::testing::UnorderedElementsAre(leaf, entry));
}
458 
TEST_F(HloModuleTest, DiamondComputationsPostOrderAllFiltered) {
  // Diamond call graph with an empty allow list: every computation is
  // filtered out, so the post order must be empty.
  auto module = CreateNewVerifiedModule();
  auto computation1 =
      module->AddEmbeddedComputation(CreateConstantComputation());
  auto computation2 =
      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
  auto computation3 =
      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
  module->AddEntryComputation(
      CreateCallComputation({computation2, computation3}));

  const absl::flat_hash_set<HloComputation*> allow_list{};
  // Compute the filtered post order once and assert on that result.
  const auto post_order = module->MakeComputationPostOrder(allow_list);
  EXPECT_THAT(post_order, ::testing::IsEmpty());
}
476 
477 }  // namespace
478 
479 }  // namespace xla
480