1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_computation.h"
17
18 #include <memory>
19 #include <set>
20 #include <vector>
21
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
28 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
29 #include "tensorflow/compiler/xla/service/hlo_parser.h"
30 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
31 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/test.h"
34 #include "tensorflow/compiler/xla/test_helpers.h"
35 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
36
37 namespace xla {
38
39 namespace {
40
41 namespace m = match;
42 namespace op = xla::testing::opcode_matchers;
43 using ::testing::ElementsAre;
44 using ::testing::UnorderedElementsAre;
45
46 class HloComputationTest : public HloTestBase {
47 protected:
HloComputationTest()48 HloComputationTest() {}
49
50 // Create a computation which takes a scalar and returns its negation.
CreateNegateComputation()51 std::unique_ptr<HloComputation> CreateNegateComputation() {
52 auto builder = HloComputation::Builder("Negate");
53 auto param = builder.AddInstruction(
54 HloInstruction::CreateParameter(0, r0f32_, "param0"));
55 builder.AddInstruction(
56 HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param));
57 return builder.Build();
58 }
59
60 // Creates a computation which calls map with the given computation.
CreateMapComputation(HloComputation * map_computation)61 std::unique_ptr<HloComputation> CreateMapComputation(
62 HloComputation* map_computation) {
63 auto builder = HloComputation::Builder("Map");
64 auto param = builder.AddInstruction(
65 HloInstruction::CreateParameter(0, r0f32_, "param0"));
66 builder.AddInstruction(
67 HloInstruction::CreateMap(r0f32_, {param}, map_computation));
68 return builder.Build();
69 }
70
71 Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
72 };
73
TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) {
  // A computation that calls no other computation embeds nothing.
  auto module = CreateNewVerifiedModule();
  HloComputation* const entry =
      module->AddEntryComputation(CreateNegateComputation());
  EXPECT_THAT(entry->MakeEmbeddedComputationsList(), ::testing::IsEmpty());
}
80
TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) {
  // The entry computation maps over a negate computation. The negate
  // computation is therefore the caller's only embedded computation, and the
  // negate computation itself embeds nothing.
  auto module = CreateNewVerifiedModule();
  HloComputation* const callee =
      module->AddEmbeddedComputation(CreateNegateComputation());
  HloComputation* const caller =
      module->AddEntryComputation(CreateMapComputation(callee));
  EXPECT_THAT(callee->MakeEmbeddedComputationsList(), ::testing::IsEmpty());
  EXPECT_THAT(caller->MakeEmbeddedComputationsList(), ElementsAre(callee));
}
92
TEST_F(HloComputationTest, GetEmbeddedComputationsDiamond) {
  // Create computations with a diamond-shaped callgraph: the entry computation
  // maps over map1_computation and map2_computation, and both of those map
  // over the same negate_computation.
  auto module = CreateNewVerifiedModule();
  auto negate_computation =
      module->AddEmbeddedComputation(CreateNegateComputation());
  auto map1_computation =
      module->AddEmbeddedComputation(CreateMapComputation(negate_computation));
  auto map2_computation =
      module->AddEmbeddedComputation(CreateMapComputation(negate_computation));

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "param0"));
  auto map1 = builder.AddInstruction(
      HloInstruction::CreateMap(r0f32_, {param}, map1_computation));
  auto map2 = builder.AddInstruction(
      HloInstruction::CreateMap(r0f32_, {param}, map2_computation));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2));
  auto computation = module->AddEntryComputation(builder.Build());

  auto embedded_computations = computation->MakeEmbeddedComputationsList();
  // ASSERT rather than EXPECT: the first element is dereferenced below, which
  // would be undefined behavior if the list came back empty.
  ASSERT_EQ(3, embedded_computations.size());
  // MakeEmbeddedComputationsList returns a post order of the embedded
  // computations, so the negate computation (called by both maps) must come
  // first.
  EXPECT_EQ(negate_computation, *embedded_computations.begin());
  EXPECT_THAT(embedded_computations,
              UnorderedElementsAre(negate_computation, map1_computation,
                                   map2_computation));
}
123
TEST_F(HloComputationTest, PostOrderSingleton) {
  // MakeInstructionPostOrder on a one-instruction computation yields exactly
  // that instruction.
  HloComputation::Builder builder(TestName());
  HloInstruction* const only_instruction = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->MakeInstructionPostOrder(),
              ElementsAre(only_instruction));
}
133
TEST_F(HloComputationTest, PostOrderSimple) {
  // A linear chain constant -> negate -> negate comes back in chain order from
  // MakeInstructionPostOrder, since each operand precedes its user.
  HloComputation::Builder builder(TestName());
  HloInstruction* const forty_two = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* const first_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, forty_two));
  HloInstruction* const second_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, first_negate));
  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->MakeInstructionPostOrder(),
              ElementsAre(forty_two, first_negate, second_negate));
}
149
TEST_F(HloComputationTest, PostOrderTrace) {
  // A trace instruction (which has no users here) is expected at the very end
  // of MakeInstructionPostOrder, after the instructions that follow it in
  // creation order.
  HloComputation::Builder builder(TestName());
  HloInstruction* const forty_two = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* const traced_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, forty_two));
  HloInstruction* const trace_instr = builder.AddInstruction(
      HloInstruction::CreateTrace("foobar", traced_negate));
  HloInstruction* const outer_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, traced_negate));
  auto module = CreateNewUnverifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_THAT(
      computation->MakeInstructionPostOrder(),
      ElementsAre(forty_two, traced_negate, outer_negate, trace_instr));
}
167
TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) {
  // Four unrelated constants have no dependencies to order by, so only the
  // membership of MakeInstructionPostOrder's result is checked.
  HloComputation::Builder builder(TestName());
  auto add_forty_two = [&builder]() {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  };
  HloInstruction* const c1 = add_forty_two();
  HloInstruction* const c2 = add_forty_two();
  HloInstruction* const c3 = add_forty_two();
  HloInstruction* const c4 = add_forty_two();
  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->MakeInstructionPostOrder(),
              UnorderedElementsAre(c1, c2, c3, c4));
}
185
TEST_F(HloComputationTest, PostOrderWithMultipleRoots) {
  // Test MakeInstructionPostOrder for a computation with multiple roots: three
  // constants feed three add instructions, and none of the adds has a user,
  // so each add is a separate sink. The post order must still contain all six
  // instructions exactly once.
  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  auto constant3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, constant1, constant2));
  auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, constant2, constant3));
  auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, constant1, constant3));
  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  auto post_order = computation->MakeInstructionPostOrder();
  EXPECT_EQ(6, post_order.size());
  EXPECT_THAT(post_order, UnorderedElementsAre(constant1, constant2, constant3,
                                               add1, add2, add3));
}
209
TEST_F(HloComputationTest, VisitWithMultipleRoots) {
  // Accept must visit every instruction, including the dead adds that are not
  // reachable from the root. It must call FinishVisit exactly once, with the
  // root, and the root must be the last instruction visited.
  HloComputation::Builder builder(TestName());
  auto add_forty_two = [&builder]() {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  };
  HloInstruction* const c1 = add_forty_two();
  HloInstruction* const c2 = add_forty_two();
  HloInstruction* const c3 = add_forty_two();
  // Three add expressions, none of which uses another.
  auto add_sum = [this, &builder](HloInstruction* lhs, HloInstruction* rhs) {
    builder.AddInstruction(
        HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, lhs, rhs));
  };
  add_sum(c1, c2);
  add_sum(c2, c3);
  add_sum(c1, c3);
  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());

  // Records every visited instruction and fails if any is visited twice.
  class RecordingVisitor : public DfsHloVisitorWithDefault {
   public:
    explicit RecordingVisitor(HloComputation* computation)
        : computation_(computation) {}

    Status DefaultAction(HloInstruction* hlo_instruction) override {
      // insert() reports `false` when the instruction was already recorded.
      EXPECT_TRUE(visited_set_.insert(hlo_instruction).second);
      last_visited_ = hlo_instruction;
      return Status::OK();
    }

    Status FinishVisit(HloInstruction* root) override {
      EXPECT_EQ(computation_->root_instruction(), root);
      ++finish_visit_calls_;
      return Status::OK();
    }

    HloComputation* computation_;
    absl::flat_hash_set<HloInstruction*> visited_set_;
    int64 finish_visit_calls_ = 0;
    HloInstruction* last_visited_ = nullptr;
  };

  RecordingVisitor recorder(computation);
  EXPECT_IS_OK(computation->Accept(&recorder));

  EXPECT_EQ(6, recorder.visited_set_.size());
  EXPECT_EQ(1, recorder.finish_visit_calls_);
  EXPECT_EQ(computation->root_instruction(), recorder.last_visited_);
}
261
TEST_F(HloComputationTest, DeepCopyArray) {
  // Deep-copying an array-shaped instruction yields a kCopy of it.
  HloComputation::Builder builder(TestName());
  HloInstruction* const vector_constant =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());
  HloInstruction* const deep_copy =
      computation->DeepCopyInstruction(vector_constant).ValueOrDie();

  EXPECT_THAT(deep_copy, GmockMatch(m::Copy(m::Op().Is(vector_constant))));
}
273
TEST_F(HloComputationTest, DeepCopyTuple) {
  // Test that DeepCopyInstruction properly copies a tuple: the result is a new
  // tuple whose elements are copies of GTEs of the original tuple, with the
  // GTE indices in element order.
  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  auto tuple_copy = computation->DeepCopyInstruction(tuple).ValueOrDie();

  // ASSERT rather than EXPECT: the operand chains dereferenced below only
  // exist if the copy has this tuple-of-copies-of-GTEs structure.
  ASSERT_THAT(tuple_copy, GmockMatch(m::Tuple(
                              m::Copy(m::GetTupleElement(m::Op().Is(tuple))),
                              m::Copy(m::GetTupleElement(m::Op().Is(tuple))))));
  EXPECT_EQ(0, tuple_copy->operand(0)->operand(0)->tuple_index());
  EXPECT_EQ(1, tuple_copy->operand(1)->operand(0)->tuple_index());
}
294
TEST_F(HloComputationTest, DeepCopyArrayAtIndices) {
  // DeepCopyInstruction on an array honors the indices_to_copy mask. A `true`
  // entry produces a kCopy; a `false` entry returns the instruction itself.
  HloComputation::Builder builder(TestName());
  HloInstruction* const vector_constant =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
  std::unique_ptr<HloComputation> computation = builder.Build();

  {
    // Marked for copying: a copy is made.
    ShapeTree<bool> copy_mask(vector_constant->shape(), /*init_value=*/true);
    HloInstruction* const result =
        computation->DeepCopyInstruction(vector_constant, &copy_mask)
            .ValueOrDie();
    EXPECT_THAT(result, GmockMatch(m::Copy(m::Op().Is(vector_constant))));
  }

  {
    // Not marked for copying: the original instruction comes back.
    ShapeTree<bool> copy_mask(vector_constant->shape(), /*init_value=*/false);
    HloInstruction* const result =
        computation->DeepCopyInstruction(vector_constant, &copy_mask)
            .ValueOrDie();
    EXPECT_EQ(result, vector_constant);
  }
}
319
TEST_F(HloComputationTest, DeepCopyTupleAtIndices) {
  // DeepCopyInstruction on a tuple always builds the GTE/tuple scaffolding, but
  // adds a kCopy only for the leaves marked in indices_to_copy. Each added copy
  // is reported through `copies_added`; leaves that were not copied stay
  // nullptr there.
  HloComputation::Builder builder(TestName());
  HloInstruction* const vector_constant =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
  HloInstruction* const scalar_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  HloInstruction* const tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({vector_constant, scalar_constant}));
  std::unique_ptr<HloComputation> computation = builder.Build();

  {
    // Every leaf marked: both elements are copied.
    ShapeTree<bool> copy_mask(tuple->shape(), /*init_value=*/true);
    ShapeTree<HloInstruction*> added(tuple->shape(), /*init_value=*/nullptr);
    HloInstruction* const result =
        computation->DeepCopyInstruction(tuple, &copy_mask, &added)
            .ValueOrDie();

    EXPECT_THAT(result, GmockMatch(m::Tuple(
                            m::Copy(m::GetTupleElement(m::Op().Is(tuple)))
                                .Is(added.element({0})),
                            m::Copy(m::GetTupleElement(m::Op().Is(tuple)))
                                .Is(added.element({1})))));
  }

  {
    // No leaf marked: only GTEs and a rebuilt tuple, no copies.
    ShapeTree<bool> copy_mask(tuple->shape(), /*init_value=*/false);
    ShapeTree<HloInstruction*> added(tuple->shape(), /*init_value=*/nullptr);
    HloInstruction* const result =
        computation->DeepCopyInstruction(tuple, &copy_mask, &added)
            .ValueOrDie();

    EXPECT_THAT(result,
                GmockMatch(m::Tuple(m::GetTupleElement(m::Op().Is(tuple)),
                                    m::GetTupleElement(m::Op().Is(tuple)))));
    EXPECT_THAT(added.element({}), ::testing::IsNull());
    EXPECT_THAT(added.element({0}), ::testing::IsNull());
    EXPECT_THAT(added.element({1}), ::testing::IsNull());
  }

  {
    // Only element 0 marked: it alone is copied.
    ShapeTree<bool> copy_mask(tuple->shape(), /*init_value=*/false);
    *copy_mask.mutable_element({0}) = true;
    ShapeTree<HloInstruction*> added(tuple->shape(), /*init_value=*/nullptr);
    HloInstruction* const result =
        computation->DeepCopyInstruction(tuple, &copy_mask, &added)
            .ValueOrDie();

    EXPECT_THAT(result, GmockMatch(m::Tuple(
                            m::Copy(m::GetTupleElement(m::Op().Is(tuple))),
                            m::GetTupleElement(m::Op().Is(tuple)))));
    EXPECT_THAT(added.element({}), ::testing::IsNull());
    EXPECT_THAT(added.element({0}), ::testing::NotNull());
    EXPECT_THAT(added.element({1}), ::testing::IsNull());
  }
}
384
TEST_F(HloComputationTest, DeepCopyToken) {
  // Tokens are not copied: deep-copying a token yields an instruction that
  // still matches kAfterAll rather than a kCopy.
  HloComputation::Builder builder(TestName());
  HloInstruction* const token =
      builder.AddInstruction(HloInstruction::CreateToken());
  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->DeepCopyInstruction(token).ValueOrDie(),
              GmockMatch(m::AfterAll()));
}
397
TEST_F(HloComputationTest, DeepCopyTokenTuple) {
  // In a (token, array) tuple, only the array element is copied. The token
  // element is forwarded as a plain GTE of the original tuple.
  HloComputation::Builder builder(TestName());
  HloInstruction* const token =
      builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* const scalar_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  HloInstruction* const tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({token, scalar_constant}));
  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());
  HloInstruction* const deep_copy =
      computation->DeepCopyInstruction(tuple).ValueOrDie();

  EXPECT_THAT(deep_copy,
              GmockMatch(m::Tuple(
                  m::GetTupleElement(m::Op().Is(tuple)),
                  m::Copy(m::GetTupleElement(m::Op().Is(tuple))))));
}
417
TEST_F(HloComputationTest, CycleDetection) {
  // `add` uses `negate` as an operand. Adding a control edge add -> negate
  // closes a cycle, which Accept must report as an error.
  HloComputation::Builder builder(TestName());
  HloInstruction* const forty_two = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* const negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, forty_two));
  HloInstruction* const add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, negate, negate));
  auto module = CreateNewUnverifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());
  ASSERT_IS_OK(add->AddControlDependencyTo(negate));

  // The post order still lists all three instructions despite the cycle.
  EXPECT_EQ(3, computation->MakeInstructionPostOrder().size());

  const auto noop_visitor = [](HloInstruction*) { return Status::OK(); };
  const Status visit_status = computation->Accept(noop_visitor);
  ASSERT_FALSE(visit_status.ok());
  ASSERT_THAT(visit_status.error_message(),
              ::testing::ContainsRegex("cycle is detecte"));
}
441
TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) {
  // `dead_add` uses `dead_negate` for both of its operands. Removing dead_add
  // together with its unused operands must delete dead_negate exactly once,
  // leaving two instructions and the live negate as root.
  HloComputation::Builder builder(TestName());
  HloInstruction* const forty_two = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* const dead_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, forty_two));
  HloInstruction* const dead_add =
      builder.AddInstruction(HloInstruction::CreateBinary(
          r0f32_, HloOpcode::kAdd, dead_negate, dead_negate));
  HloInstruction* const live_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, forty_two));
  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());

  // Checks that the root is still the live negate of the constant.
  auto expect_live_root = [&]() {
    EXPECT_THAT(computation->root_instruction(),
                GmockMatch(m::Negate(m::Op().Is(forty_two))));
    EXPECT_EQ(live_negate, computation->root_instruction());
  };

  EXPECT_EQ(4, computation->instruction_count());
  expect_live_root();

  ASSERT_IS_OK(computation->RemoveInstructionAndUnusedOperands(dead_add));

  EXPECT_EQ(2, computation->instruction_count());
  expect_live_root();
}
469
TEST_F(HloComputationTest, CloneWithControlDependency) {
  // Cloning must preserve control edges. The cloned root add must keep the
  // negate as its only control predecessor, and that negate must list the
  // cloned add as its only control successor.
  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, constant1, constant2));

  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "param0"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param));
  auto module = CreateNewVerifiedModule();
  auto computation =
      module->AddEntryComputation(builder.Build(/*root_instruction=*/add));

  // Fail this test (rather than abort the whole binary, as TF_CHECK_OK would)
  // if the control edge cannot be added.
  ASSERT_IS_OK(negate->AddControlDependencyTo(add));

  auto clone = computation->Clone();

  auto cloned_add = clone->root_instruction();
  EXPECT_EQ(HloOpcode::kAdd, cloned_add->opcode());

  const auto& predecessors = cloned_add->control_predecessors();
  // ASSERT rather than EXPECT: predecessors[0] below would be an
  // out-of-bounds access if the control edge was lost.
  ASSERT_EQ(1, predecessors.size());
  EXPECT_EQ(HloOpcode::kNegate, predecessors[0]->opcode());
  EXPECT_THAT(predecessors[0]->control_successors(), ElementsAre(cloned_add));
}
500
TEST_F(HloComputationTest, CloneWithReplacements) {
  // CloneWithReplacements swaps the S64 parameter 2 for an S32 parameter and
  // appends an extra U32 parameter. The clone therefore has four parameters
  // with shapes (f32, f32, s32, u32).
  HloComputation::Builder builder(TestName());
  const Shape r0s64 = ShapeUtil::MakeShape(S64, {});
  const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
  const Shape r0u32 = ShapeUtil::MakeShape(U32, {});
  HloInstruction* const lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "p.0.lhs"));
  HloInstruction* const rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "p.0.rhs"));
  HloInstruction* const wide_param =
      builder.AddInstruction(HloInstruction::CreateParameter(2, r0s64, "p.1"));
  HloInstruction* const less_than = builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), lhs, rhs,
                                    ComparisonDirection::kLt));
  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build(/*root_instruction=*/less_than));

  absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
      replacements;
  replacements.emplace(wide_param,
                       HloInstruction::CreateParameter(2, r0s32, "p.1"));
  std::unique_ptr<HloInstruction> extra_param =
      HloInstruction::CreateParameter(3, r0u32, "p.2");
  std::vector<const HloInstruction*> extra_parameters{extra_param.get()};
  auto clone = computation->CloneWithReplacements(std::move(replacements),
                                                  extra_parameters);

  ASSERT_EQ(clone->num_parameters(), 4);
  const std::vector<Shape> expected_shapes = {r0f32_, r0f32_, r0s32, r0u32};
  for (int i = 0; i < 4; ++i) {
    EXPECT_TRUE(ShapeUtil::Equal(clone->parameter_instruction(i)->shape(),
                                 expected_shapes[i]));
  }
}
536
TEST_F(HloComputationTest, Stringification) {
  // With metadata printing disabled, ToString renders the computation header
  // followed by one line per instruction.
  const Shape x_shape = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape y_shape = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape y_transposed_shape = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape dot_shape = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* const x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, x_shape, "x"));
  HloInstruction* const y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, y_shape, "y"));
  HloInstruction* const y_transposed = builder.AddInstruction(
      HloInstruction::CreateTranspose(y_transposed_shape, y, {1, 0}));
  // Contract dimension 1 of x with dimension 0 of the transposed y.
  DotDimensionNumbers contraction;
  contraction.add_lhs_contracting_dimensions(1);
  contraction.add_rhs_contracting_dimensions(0);
  PrecisionConfig precision;
  precision.mutable_operand_precision()->Resize(2, PrecisionConfig::DEFAULT);
  builder.AddInstruction(HloInstruction::CreateDot(
      dot_shape, x, y_transposed, contraction, precision));
  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());

  const HloPrintOptions print_options =
      HloPrintOptions().set_print_metadata(false);
  const string expected_computation =
      R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
  %x = f32[5,10]{1,0} parameter(0)
  %y = f32[20,10]{1,0} parameter(1)
  %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
  ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
})";
  EXPECT_EQ(computation->ToString(print_options), expected_computation);
}
571
TEST_F(HloComputationTest, StringificationIndent) {
  // Same computation as Stringification, printed with indent_amount 2. Every
  // output line, including the header and closing brace, gains leading
  // indentation.
  const Shape x_shape = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape y_shape = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape y_transposed_shape = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape dot_shape = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* const x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, x_shape, "x"));
  HloInstruction* const y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, y_shape, "y"));
  HloInstruction* const y_transposed = builder.AddInstruction(
      HloInstruction::CreateTranspose(y_transposed_shape, y, {1, 0}));
  // Contract dimension 1 of x with dimension 0 of the transposed y.
  DotDimensionNumbers contraction;
  contraction.add_lhs_contracting_dimensions(1);
  contraction.add_rhs_contracting_dimensions(0);
  PrecisionConfig precision;
  precision.mutable_operand_precision()->Resize(2, PrecisionConfig::DEFAULT);
  builder.AddInstruction(HloInstruction::CreateDot(
      dot_shape, x, y_transposed, contraction, precision));
  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());

  const HloPrintOptions print_options =
      HloPrintOptions().set_print_metadata(false).set_indent_amount(2);
  const string expected_computation =
      R"(    %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
      %x = f32[5,10]{1,0} parameter(0)
      %y = f32[20,10]{1,0} parameter(1)
      %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
      ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    })";
  EXPECT_EQ(computation->ToString(print_options), expected_computation);
}
607
TEST_F(HloComputationTest, StringificationCanonical) {
  // Prints the same computation twice. With default options (metadata off),
  // instruction names are kept. With HloPrintOptions::Canonical(), names are
  // replaced by tmp_N and the program shape and '%' prefixes are dropped.
  const Shape x_shape = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape y_shape = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape y_transposed_shape = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape dot_shape = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* const x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, x_shape, "x"));
  HloInstruction* const y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, y_shape, "y"));
  HloInstruction* const y_transposed = builder.AddInstruction(
      HloInstruction::CreateTranspose(y_transposed_shape, y, {1, 0}));
  // Contract dimension 1 of x with dimension 0 of the transposed y.
  DotDimensionNumbers contraction;
  contraction.add_lhs_contracting_dimensions(1);
  contraction.add_rhs_contracting_dimensions(0);
  PrecisionConfig precision;
  precision.mutable_operand_precision()->Resize(2, PrecisionConfig::DEFAULT);
  builder.AddInstruction(HloInstruction::CreateDot(
      dot_shape, x, y_transposed, contraction, precision));
  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());

  HloPrintOptions print_options = HloPrintOptions().set_print_metadata(false);
  const string expected_computation1 =
      R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
  %x = f32[5,10]{1,0} parameter(0)
  %y = f32[20,10]{1,0} parameter(1)
  %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
  ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
})";
  EXPECT_EQ(computation->ToString(print_options), expected_computation1);

  print_options = HloPrintOptions().Canonical();
  const string expected_computation2 = R"(TransposeDot {
  tmp_0 = f32[5,10]{1,0} parameter(0)
  tmp_1 = f32[20,10]{1,0} parameter(1)
  tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
  ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
})";
  EXPECT_EQ(computation->ToString(print_options), expected_computation2);
}
651
MakeAddNComputation(int n)652 std::unique_ptr<HloComputation> MakeAddNComputation(int n) {
653 auto builder = HloComputation::Builder("add_n");
654 auto result = builder.AddInstruction(HloInstruction::CreateParameter(
655 0, ShapeUtil::MakeShape(F32, {}), "x_value"));
656 auto one = builder.AddInstruction(
657 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
658 for (int i = 0; i < n; ++i) {
659 result = builder.AddInstruction(HloInstruction::CreateBinary(
660 one->shape(), HloOpcode::kAdd, result, one));
661 }
662 return builder.Build();
663 }
664
TEST_F(HloComputationTest, DeepEquality) {
  // Compares computations built from very long add chains. Identical chains
  // are equal; chains differing by one add are unequal in either operand
  // order.
  constexpr int kChainLength = 200000;
  auto computation_a = MakeAddNComputation(kChainLength);
  auto computation_b = MakeAddNComputation(kChainLength);
  EXPECT_TRUE(*computation_a == *computation_b);

  auto computation_c = MakeAddNComputation(kChainLength - 1);
  EXPECT_FALSE(*computation_a == *computation_c);
  EXPECT_FALSE(*computation_c == *computation_b);
}
674
// Tests that cross-module AllReduce instructions are ordered after all their
// predecessors and before all their successors. All-reduces sharing an
// all_reduce_id (crs0 and crs1 below) are emitted next to each other.
TEST_F(HloComputationTest, InstructionPostOrderWithAllReduce) {
  // crs0 and crs1 share all_reduce_id=1. The expected post order places both
  // all-reduces together, ahead of `add` (a user of crs0 only).
  const char* const hlo_string = R"(
HloModule Module

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry {
  param = f32[128] parameter(0), sharding={maximal device=0}
  crs0 = f32[128] all-reduce(param),
    replica_groups={{0}}, all_reduce_id=1, barrier="", to_apply=add,
    sharding={maximal device=0}
  crs1 = f32[128] all-reduce(param),
    replica_groups={{0}}, all_reduce_id=1, barrier="", to_apply=add,
    sharding={maximal device=1}
  add = f32[128] add(crs0, crs0), sharding={maximal device=0}
  ROOT t = (f32[128], f32[128]) tuple(add, crs1)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
  HloComputation* const entry = module->entry_computation();
  EXPECT_THAT(entry->MakeInstructionPostOrder(),
              ElementsAre(op::Parameter(), op::AllReduce(), op::AllReduce(),
                          op::Add(), op::Tuple()));
}
703
704 } // namespace
705 } // namespace xla
706