1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_module.h"
17
18 #include "absl/memory/memory.h"
19 #include "tensorflow/compiler/xla/literal.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
23 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
24 #include "tensorflow/compiler/xla/service/hlo_parser.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "absl/types/span.h"
30 #include "tensorflow/compiler/xla/test.h"
31
32 namespace xla {
33
34 namespace {
35
36 namespace op = ::xla::testing::opcode_matchers;
37
38 class HloModuleTest : public HloTestBase {
39 protected:
HloModuleTest()40 HloModuleTest() {}
41
42 // Create a computation which returns a constant.
CreateConstantComputation()43 std::unique_ptr<HloComputation> CreateConstantComputation() {
44 auto builder = HloComputation::Builder("Constant");
45 builder.AddInstruction(
46 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
47 return builder.Build();
48 }
49
50 // Creates a computation which calls the given zero-parameter computations.
CreateCallComputation(absl::Span<HloComputation * const> computations)51 std::unique_ptr<HloComputation> CreateCallComputation(
52 absl::Span<HloComputation* const> computations) {
53 auto builder = HloComputation::Builder("Call");
54 for (auto computation : computations) {
55 builder.AddInstruction(
56 HloInstruction::CreateCall(r0f32_, {}, computation));
57 }
58 return builder.Build();
59 }
60
61 Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
62 };
63
TEST_F(HloModuleTest, OneComputationPostOrder) {
  // A module that holds nothing but its entry computation must report exactly
  // that computation as its post order.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      module->AddEntryComputation(CreateConstantComputation());

  const auto post_order = module->MakeComputationPostOrder();
  EXPECT_THAT(post_order, ::testing::ElementsAre(entry));
}
72
TEST_F(HloModuleTest, TwoComputationsPostOrder) {
  // Two computations with no call edge between them: the post order must
  // contain both, in either order.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      module->AddEntryComputation(CreateConstantComputation());
  HloComputation* embedded =
      module->AddEmbeddedComputation(CreateConstantComputation());

  const auto post_order = module->MakeComputationPostOrder();
  EXPECT_THAT(post_order, ::testing::UnorderedElementsAre(entry, embedded));

  // Both computations were built with the name "Constant"; the module is
  // expected to have renamed the second one to keep names unique.
  EXPECT_EQ(entry->name(), "Constant");
  EXPECT_EQ(embedded->name(), "Constant.1");
}
88
TEST_F(HloModuleTest, CloneTest) {
  // Build a diamond call graph (entry -> {left, right} -> leaf), clone the
  // module, and check that each cloned computation carries the ".copy" suffix.
  auto module = CreateNewVerifiedModule();
  HloComputation* shared_leaf =
      module->AddEmbeddedComputation(CreateConstantComputation());
  HloComputation* left_caller =
      module->AddEmbeddedComputation(CreateCallComputation({shared_leaf}));
  HloComputation* right_caller =
      module->AddEmbeddedComputation(CreateCallComputation({shared_leaf}));
  module->AddEntryComputation(
      CreateCallComputation({left_caller, right_caller}));

  const auto original_order = module->MakeComputationPostOrder();
  auto clone = module->Clone("copy");
  const auto cloned_order = clone->MakeComputationPostOrder();

  EXPECT_EQ(original_order.size(), cloned_order.size());
  // Walk both post orders in lock step, stopping at the shorter one so a size
  // mismatch (already reported above) cannot cause an out-of-range access.
  for (size_t i = 0; i < original_order.size() && i < cloned_order.size();
       ++i) {
    EXPECT_EQ(original_order[i]->name() + ".copy", cloned_order[i]->name());
  }
}
112
TEST_F(HloModuleTest, CloneHasFusion) {
  auto module = CreateNewVerifiedModule();

  // Fused computation: x + x over a scalar f32 parameter.
  HloComputation::Builder fused_builder("Fused");
  HloInstruction* x = fused_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "x"));
  fused_builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, x, x));
  HloComputation* fused_computation =
      module->AddEmbeddedComputation(fused_builder.Build());

  // Entry computation: a fusion instruction that feeds a constant into the
  // fused computation.
  HloComputation::Builder entry_builder("Entry");
  HloInstruction* input = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  entry_builder.AddInstruction(
      HloInstruction::CreateFusion(r0f32_, HloInstruction::FusionKind::kInput,
                                   /*operands=*/{input}, fused_computation));
  module->AddEntryComputation(entry_builder.Build());

  const auto original_order = module->MakeComputationPostOrder();
  auto clone = module->Clone("copy");
  const auto cloned_order = clone->MakeComputationPostOrder();

  EXPECT_EQ(original_order.size(), cloned_order.size());
  for (size_t i = 0; i < original_order.size() && i < cloned_order.size();
       ++i) {
    const auto& original_name = original_order[i]->name();
    // The fused computation is cloned as part of cloning its fusion
    // instruction, and that path uses the suffix ".clone" rather than the
    // suffix passed to HloModule::Clone.
    const char* expected_suffix =
        (original_name == "Fused") ? ".clone" : ".copy";
    EXPECT_EQ(original_name + expected_suffix, cloned_order[i]->name());
  }
}
154
TEST_F(HloModuleTest, DiamondComputationsPostOrder) {
  // Diamond call graph: root calls left and right, which both call leaf.
  auto module = CreateNewVerifiedModule();
  HloComputation* leaf =
      module->AddEmbeddedComputation(CreateConstantComputation());
  HloComputation* left =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* right =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* root =
      module->AddEntryComputation(CreateCallComputation({left, right}));

  const auto post_order = module->MakeComputationPostOrder();

  // All four computations appear; the relative order of the two middle ones
  // is unspecified, but the callee-most must be first and the entry last.
  EXPECT_THAT(post_order,
              ::testing::UnorderedElementsAre(leaf, left, right, root));
  EXPECT_EQ(post_order.front(), leaf);
  EXPECT_EQ(post_order.back(), root);
}
174
// Checks how HloModule::ToString renders a 16-element f32 constant depending
// on the print_large_constants print option.
TEST_F(HloModuleTest, LargeConstantToString) {
  // Create a module with a single computation.
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder("Constant");
  // Sixteen copies of 42 form the rank-1 constant under test.
  std::vector<float> values(16, 42.0);
  builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(values)));
  module->AddEntryComputation(builder.Build());

  // NOTE(review): the expected strings below are compared verbatim, so the
  // whitespace between "{\n" and "ROOT" must match what ToString emits --
  // confirm the exact indentation against the printer before editing.

  // With large constants suppressed, the operand list is elided as "{...}".
  EXPECT_EQ(
      "HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n "
      "ROOT %constant = f32[16]{0} constant({...})\n}\n\n",
      module->ToString(HloPrintOptions().set_print_large_constants(false)));

  // With large constants enabled, every element is printed.
  EXPECT_EQ(
      "HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n "
      "ROOT %constant = f32[16]{0} constant({42, 42, 42, 42, 42, 42, 42, 42, "
      "42, 42, 42, 42, 42, 42, 42, 42})\n}\n\n",
      module->ToString(HloPrintOptions().set_print_large_constants(true)));
}
195
TEST_F(HloModuleTest, UniqueModuleId) {
  // Two independently created modules must not share a unique id.
  auto first_module = CreateNewVerifiedModule();
  auto second_module = CreateNewVerifiedModule();

  const auto first_id = first_module->unique_id();
  const auto second_id = second_module->unique_id();
  EXPECT_NE(first_id, second_id);
}
201
TEST_F(HloModuleTest, ProtoSerializationWithoutSchedule) {
  // An unscheduled module must stay unscheduled after a proto round trip.
  const string text = R"(
HloModule axpy_module

ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
  %alpha = f32[] parameter(0)
  %x = f32[2,4]{1,0} parameter(1)
  %y = f32[2,4]{1,0} parameter(2)
  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed,
                          ParseHloString(text));
  ASSERT_FALSE(parsed->has_schedule());

  // Serialize to a proto, then rebuild a module from it with the same config.
  const auto serialized = parsed->ToProto();
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> round_tripped,
      HloModule::CreateFromProto(serialized, parsed->config()));
  ASSERT_FALSE(round_tripped->has_schedule());
}
223
TEST_F(HloModuleTest, ProtoSerializationWithSchedule) {
  // A module parsed with is_scheduled=true must keep a valid schedule, in the
  // textual instruction order, after a proto round trip.
  const string text = R"(
HloModule axpy_module, is_scheduled=true

ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
  %alpha = f32[] parameter(0)
  %x = f32[2,4]{1,0} parameter(1)
  %y = f32[2,4]{1,0} parameter(2)
  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed,
                          ParseHloString(text));
  ASSERT_TRUE(parsed->has_schedule());

  const auto serialized = parsed->ToProto();
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> round_tripped,
      HloModule::CreateFromProto(serialized, parsed->config()));
  ASSERT_TRUE(round_tripped->has_schedule());
  TF_ASSERT_OK(round_tripped->schedule().Verify());

  // Only the entry computation exists, so exactly one sequence is expected.
  EXPECT_EQ(round_tripped->schedule().sequences().size(), 1);
  auto* entry = round_tripped->entry_computation();
  ASSERT_TRUE(round_tripped->schedule().is_computation_scheduled(entry));

  // The scheduled order must match the order the instructions were written.
  EXPECT_THAT(
      round_tripped->schedule().sequence(entry).instructions(),
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(),
                             op::Broadcast(), op::Multiply(), op::Add()));
}
255
TEST_F(HloModuleTest, ProtoSerializationPreservesIds) {
  // Verify that serializing then deserializing an HLO proto preserves the
  // unique IDs of the computations and instructions (the module ID itself is
  // expected to change, see below).
  const string text =
      R"(HloModule ReduceR3ToR2_module

add_F32.v3 {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY ReduceR3ToR2.v3 {
  input = f32[8,16,256]{2,1,0} parameter(0)
  constant = f32[] constant(0)
  ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(text));

  // Perform various transformations on the graph:
  //
  //  * clone the reduction function
  //  * replace use of reduction function with the clone.
  //  * add a random instruction to the entry computation.
  //
  // This will create instruction and computation IDs which are interesting:
  // not consecutive and not densely packed.
  HloComputation* entry = module->entry_computation();
  HloInstruction* root = entry->root_instruction();
  HloComputation* reduction = root->to_apply();
  HloComputation* reduction_clone =
      module->AddEmbeddedComputation(reduction->Clone());
  root->set_to_apply(reduction_clone);
  TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction));
  HloInstruction* negate = entry->AddInstruction(
      HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root));
  entry->set_root_instruction(negate);

  // Schedule the transformed module, this verifies that the serialized schedule
  // is robust against non-consecutive IDs as well (b/114712358).
  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  HloMemoryScheduler scheduler(size_fn);
  TF_ASSERT_OK(scheduler.Run(module.get()).status());
  ASSERT_TRUE(module->has_schedule());

  // Serialize and deserialize and verify that the instruction and computations
  // unique ids are the same.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module_copy,
      HloModule::CreateFromProto(module->ToProto(), module->config()));

  // The module IDs should *not* be the same because module ids must be globally
  // unique.
  EXPECT_NE(module->unique_id(), module_copy->unique_id());

  // The loops below walk the original and the copy in lock step. Bail out
  // first if they differ in size, so the copy's iterator can never be advanced
  // or dereferenced past its end and extra computations in the copy are not
  // silently ignored.
  ASSERT_EQ(module->computation_count(), module_copy->computation_count());

  // Verify that the computations and instructions all have the same unique id.
  auto computation_copy_it = module_copy->computations().begin();
  for (const HloComputation* computation_orig : module->computations()) {
    const HloComputation* computation_copy = *computation_copy_it++;
    EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id())
        << absl::StrFormat(
               "ID of original computation %s != ID of deserialized "
               "computation %s: %d != %d",
               computation_orig->name(), computation_copy->name(),
               computation_orig->unique_id(), computation_copy->unique_id());

    // Same guard as above, one level down: instruction counts must agree
    // before the instructions are compared pairwise.
    ASSERT_EQ(computation_orig->instruction_count(),
              computation_copy->instruction_count());

    auto instruction_copy_it = computation_copy->instructions().begin();
    for (const HloInstruction* instruction_orig :
         computation_orig->instructions()) {
      const HloInstruction* instruction_copy = *instruction_copy_it++;
      EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id())
          << absl::StrFormat(
                 "ID of original instruction %s != ID of deserialized "
                 "instruction %s: %d != %d",
                 instruction_orig->name(), instruction_copy->name(),
                 instruction_orig->unique_id(), instruction_copy->unique_id());
    }
  }

  // Verify that the next unique ID which the module would have handed out is
  // greater than the unique id of any instruction.
  int next_id = module_copy->NewUniqueInstructionId();
  for (const HloComputation* computation : module_copy->computations()) {
    for (const HloInstruction* instruction : computation->instructions()) {
      EXPECT_GT(next_id, instruction->unique_id());
    }
  }
}
348
349 } // namespace
350
351 } // namespace xla
352