1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_module.h"
17 
18 #include "absl/memory/memory.h"
19 #include "tensorflow/compiler/xla/literal.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
23 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
24 #include "tensorflow/compiler/xla/service/hlo_parser.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "absl/types/span.h"
30 #include "tensorflow/compiler/xla/test.h"
31 
32 namespace xla {
33 
34 namespace {
35 
36 namespace op = ::xla::testing::opcode_matchers;
37 
38 class HloModuleTest : public HloTestBase {
39  protected:
HloModuleTest()40   HloModuleTest() {}
41 
42   // Create a computation which returns a constant.
CreateConstantComputation()43   std::unique_ptr<HloComputation> CreateConstantComputation() {
44     auto builder = HloComputation::Builder("Constant");
45     builder.AddInstruction(
46         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
47     return builder.Build();
48   }
49 
50   // Creates a computation which calls the given zero-parameter computations.
CreateCallComputation(absl::Span<HloComputation * const> computations)51   std::unique_ptr<HloComputation> CreateCallComputation(
52       absl::Span<HloComputation* const> computations) {
53     auto builder = HloComputation::Builder("Call");
54     for (auto computation : computations) {
55       builder.AddInstruction(
56           HloInstruction::CreateCall(r0f32_, {}, computation));
57     }
58     return builder.Build();
59   }
60 
61   Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
62 };
63 
TEST_F(HloModuleTest, OneComputationPostOrder) {
  // A module holding a single computation has a post order consisting of
  // exactly that computation.
  auto module = CreateNewVerifiedModule();
  auto* const entry = module->AddEntryComputation(CreateConstantComputation());

  const auto post_order = module->MakeComputationPostOrder();
  EXPECT_THAT(post_order, ::testing::ElementsAre(entry));
}
72 
TEST_F(HloModuleTest, TwoComputationsPostOrder) {
  // Two computations that never call each other: the post order must contain
  // both, in no particular order.
  auto module = CreateNewVerifiedModule();
  auto* const entry = module->AddEntryComputation(CreateConstantComputation());
  auto* const embedded =
      module->AddEmbeddedComputation(CreateConstantComputation());

  EXPECT_THAT(module->MakeComputationPostOrder(),
              ::testing::UnorderedElementsAre(entry, embedded));

  // Both builders asked for the name "Constant"; the module is expected to
  // uniquify the one added second.
  EXPECT_EQ(entry->name(), "Constant");
  EXPECT_EQ(embedded->name(), "Constant.1");
}
88 
TEST_F(HloModuleTest, CloneTest) {
  // Build a diamond-shaped call graph,
  //   entry -> {left, right} -> leaf,
  // clone the module, and check that every cloned computation name carries
  // the clone suffix.
  auto module = CreateNewVerifiedModule();
  auto* const leaf =
      module->AddEmbeddedComputation(CreateConstantComputation());
  auto* const left =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  auto* const right =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  module->AddEntryComputation(CreateCallComputation({left, right}));

  const auto original_order = module->MakeComputationPostOrder();
  auto cloned_module = module->Clone("copy");
  const auto cloned_order = cloned_module->MakeComputationPostOrder();

  EXPECT_EQ(original_order.size(), cloned_order.size());
  // Walk both post orders side by side, stopping at the shorter one.
  for (size_t i = 0; i < original_order.size() && i < cloned_order.size();
       ++i) {
    EXPECT_EQ(original_order[i]->name() + ".copy", cloned_order[i]->name());
  }
}
112 
TEST_F(HloModuleTest, CloneHasFusion) {
  auto module = CreateNewVerifiedModule();

  // Fused computation: x + x for a scalar F32 parameter x.
  HloComputation* fused_computation;
  {
    HloComputation::Builder fused_builder("Fused");
    HloInstruction* const param = fused_builder.AddInstruction(
        HloInstruction::CreateParameter(0, r0f32_, "x"));
    fused_builder.AddInstruction(
        HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, param));
    fused_computation = module->AddEmbeddedComputation(fused_builder.Build());
  }

  // Entry computation: a kInput fusion of the fused computation, applied to a
  // scalar constant.
  {
    HloComputation::Builder entry_builder("Entry");
    HloInstruction* const constant = entry_builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
    entry_builder.AddInstruction(HloInstruction::CreateFusion(
        r0f32_, HloInstruction::FusionKind::kInput,
        /*operands=*/{constant}, fused_computation));
    module->AddEntryComputation(entry_builder.Build());
  }

  const auto original_order = module->MakeComputationPostOrder();
  auto cloned_module = module->Clone("copy");
  const auto cloned_order = cloned_module->MakeComputationPostOrder();

  EXPECT_EQ(original_order.size(), cloned_order.size());
  for (size_t i = 0; i < original_order.size() && i < cloned_order.size();
       ++i) {
    const HloComputation* original = original_order[i];
    const HloComputation* cloned = cloned_order[i];
    // The fused computation is cloned together with its fusion instruction,
    // which always uses the ".clone" suffix; every other computation gets the
    // module-level ".copy" suffix.
    const char* expected_suffix =
        original->name() == "Fused" ? ".clone" : ".copy";
    EXPECT_EQ(original->name() + expected_suffix, cloned->name());
  }
}
154 
TEST_F(HloModuleTest, DiamondComputationsPostOrder) {
  // Diamond call graph: the entry calls two computations, each of which calls
  // the same leaf constant computation.
  auto module = CreateNewVerifiedModule();
  auto* const leaf =
      module->AddEmbeddedComputation(CreateConstantComputation());
  auto* const left =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  auto* const right =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  auto* const entry =
      module->AddEntryComputation(CreateCallComputation({left, right}));

  const auto post_order = module->MakeComputationPostOrder();
  EXPECT_THAT(post_order,
              ::testing::UnorderedElementsAre(leaf, left, right, entry));
  // Callees precede callers in a post order: the entry must be last and the
  // shared leaf must be first.
  EXPECT_EQ(post_order.back(), entry);
  EXPECT_EQ(post_order.front(), leaf);
}
174 
TEST_F(HloModuleTest, LargeConstantToString) {
  // Module whose only (entry) computation returns a 16-element F32 constant.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder("Constant");
  const std::vector<float> values(16, 42.0f);
  builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(values)));
  module->AddEntryComputation(builder.Build());

  // With large-constant printing disabled the literal is elided as "{...}".
  const char kExpectedElided[] =
      "HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n  "
      "ROOT %constant = f32[16]{0} constant({...})\n}\n\n";
  EXPECT_EQ(
      kExpectedElided,
      module->ToString(HloPrintOptions().set_print_large_constants(false)));

  // With large-constant printing enabled every element is spelled out.
  const char kExpectedFull[] =
      "HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n  "
      "ROOT %constant = f32[16]{0} constant({42, 42, 42, 42, 42, 42, 42, 42, "
      "42, 42, 42, 42, 42, 42, 42, 42})\n}\n\n";
  EXPECT_EQ(
      kExpectedFull,
      module->ToString(HloPrintOptions().set_print_large_constants(true)));
}
195 
TEST_F(HloModuleTest, UniqueModuleId) {
  // Two independently created modules are expected to get distinct ids.
  const auto first_module = CreateNewVerifiedModule();
  const auto second_module = CreateNewVerifiedModule();
  EXPECT_NE(first_module->unique_id(), second_module->unique_id());
}
201 
TEST_F(HloModuleTest, ProtoSerializationWithoutSchedule) {
  // A module parsed without a schedule must still have no schedule after a
  // round trip through its proto representation.
  const string hlo_text = R"(
HloModule axpy_module

ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
  %alpha = f32[] parameter(0)
  %x = f32[2,4]{1,0} parameter(1)
  %y = f32[2,4]{1,0} parameter(2)
  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed,
                          ParseHloString(hlo_text));
  ASSERT_FALSE(parsed->has_schedule());

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> round_tripped,
      HloModule::CreateFromProto(parsed->ToProto(), parsed->config()));
  ASSERT_FALSE(round_tripped->has_schedule());
}
223 
TEST_F(HloModuleTest, ProtoSerializationWithSchedule) {
  // A module parsed with is_scheduled=true must keep a valid schedule after a
  // round trip through its proto. The entry computation's sequence must also
  // follow the textual instruction order.
  const string text = R"(
HloModule axpy_module, is_scheduled=true

ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
  %alpha = f32[] parameter(0)
  %x = f32[2,4]{1,0} parameter(1)
  %y = f32[2,4]{1,0} parameter(2)
  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(text));
  ASSERT_TRUE(module->has_schedule());
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module_copy,
      HloModule::CreateFromProto(module->ToProto(), module->config()));
  ASSERT_TRUE(module_copy->schedule().is_computation_scheduled(
      module_copy->entry_computation()) ||
              true);
  ASSERT_TRUE(module_copy->has_schedule());
  TF_ASSERT_OK(module_copy->schedule().Verify());
  // Container size() is unsigned; comparing it against a plain int literal
  // trips -Wsign-compare inside gtest's comparison helper, so use size_t.
  EXPECT_EQ(module_copy->schedule().sequences().size(), size_t{1});
  ASSERT_TRUE(module_copy->schedule().is_computation_scheduled(
      module_copy->entry_computation()));
  EXPECT_THAT(
      module_copy->schedule()
          .sequence(module_copy->entry_computation())
          .instructions(),
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(),
                             op::Broadcast(), op::Multiply(), op::Add()));
}
255 
TEST_F(HloModuleTest,ProtoSerializationPreservesIds)256 TEST_F(HloModuleTest, ProtoSerializationPreservesIds) {
257   // Verify that serializing then deserializing an HLO proto preserves the
258   // unique IDs of the instruction and module.
259   const string text =
260       R"(HloModule ReduceR3ToR2_module
261 
262 add_F32.v3 {
263   lhs = f32[] parameter(0)
264   rhs = f32[] parameter(1)
265   ROOT add = f32[] add(lhs, rhs)
266 }
267 
268 ENTRY ReduceR3ToR2.v3 {
269   input = f32[8,16,256]{2,1,0} parameter(0)
270   constant = f32[] constant(0)
271   ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
272 }
273 )";
274   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
275                           ParseHloString(text));
276 
277   // Perform various transformations on the graph:
278   //
279   //  * clone the reduction function
280   //  * replace use of reduction function with the clone.
281   //  * add a random instruction to the entry computation.
282   //
283   // This will create instruction and computation IDs which are interesting:
284   // not consecutive and not densely packed.
285   HloComputation* entry = module->entry_computation();
286   HloInstruction* root = entry->root_instruction();
287   HloComputation* reduction = root->to_apply();
288   HloComputation* reduction_clone =
289       module->AddEmbeddedComputation(reduction->Clone());
290   root->set_to_apply(reduction_clone);
291   TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction));
292   HloInstruction* negate = entry->AddInstruction(
293       HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root));
294   entry->set_root_instruction(negate);
295 
296   // Schedule the transformed module, this verifies that the serialized schedule
297   // is robust against non-consecutive IDs as well (b/114712358).
298   auto size_fn = [](const BufferValue& buffer) {
299     return ShapeUtil::ByteSizeOf(buffer.shape());
300   };
301   HloMemoryScheduler scheduler(size_fn);
302   TF_ASSERT_OK(scheduler.Run(module.get()).status());
303   ASSERT_TRUE(module->has_schedule());
304 
305   // Serialize and deserialize and verify that the instruction and computations
306   // unique ids are the same.
307   TF_ASSERT_OK_AND_ASSIGN(
308       std::unique_ptr<HloModule> module_copy,
309       HloModule::CreateFromProto(module->ToProto(), module->config()));
310 
311   // The module IDs should *not* be the same because module ids must be globally
312   // unique.
313   EXPECT_NE(module->unique_id(), module_copy->unique_id());
314 
315   // Verify that the computations and instructions all have the same unique id.
316   auto computation_copy_it = module_copy->computations().begin();
317   for (const HloComputation* computation_orig : module->computations()) {
318     const HloComputation* computation_copy = *computation_copy_it++;
319     EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id())
320         << absl::StrFormat(
321                "ID of original computation %s != ID of deserialized "
322                "computation %s: %d != %d",
323                computation_orig->name(), computation_copy->name(),
324                computation_orig->unique_id(), computation_copy->unique_id());
325 
326     auto instruction_copy_it = computation_copy->instructions().begin();
327     for (const HloInstruction* instruction_orig :
328          computation_orig->instructions()) {
329       const HloInstruction* instruction_copy = *instruction_copy_it++;
330       EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id())
331           << absl::StrFormat(
332                  "ID of original instruction %s != ID of deserialized "
333                  "instruction %s: %d != %d",
334                  instruction_orig->name(), instruction_copy->name(),
335                  instruction_orig->unique_id(), instruction_copy->unique_id());
336     }
337   }
338 
339   // Verify that the next unique ID which the module would have handed out is
340   // greater than the unique id of any instruction.
341   int next_id = module_copy->NewUniqueInstructionId();
342   for (const HloComputation* computation : module_copy->computations()) {
343     for (const HloInstruction* instruction : computation->instructions()) {
344       EXPECT_GT(next_id, instruction->unique_id());
345     }
346   }
347 }
348 
349 }  // namespace
350 
351 }  // namespace xla
352