1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_computation.h"
17 
18 #include <memory>
19 #include <set>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
28 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
29 #include "tensorflow/compiler/xla/service/hlo_parser.h"
30 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
31 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/test.h"
34 #include "tensorflow/compiler/xla/test_helpers.h"
35 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
36 
37 namespace xla {
38 
39 namespace {
40 
41 namespace m = match;
42 namespace op = xla::testing::opcode_matchers;
43 using ::testing::ElementsAre;
44 using ::testing::UnorderedElementsAre;
45 
46 class HloComputationTest : public HloTestBase {
47  protected:
HloComputationTest()48   HloComputationTest() {}
49 
50   // Create a computation which takes a scalar and returns its negation.
CreateNegateComputation()51   std::unique_ptr<HloComputation> CreateNegateComputation() {
52     auto builder = HloComputation::Builder("Negate");
53     auto param = builder.AddInstruction(
54         HloInstruction::CreateParameter(0, r0f32_, "param0"));
55     builder.AddInstruction(
56         HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param));
57     return builder.Build();
58   }
59 
60   // Creates a computation which calls map with the given computation.
CreateMapComputation(HloComputation * map_computation)61   std::unique_ptr<HloComputation> CreateMapComputation(
62       HloComputation* map_computation) {
63     auto builder = HloComputation::Builder("Map");
64     auto param = builder.AddInstruction(
65         HloInstruction::CreateParameter(0, r0f32_, "param0"));
66     builder.AddInstruction(
67         HloInstruction::CreateMap(r0f32_, {param}, map_computation));
68     return builder.Build();
69   }
70 
71   Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
72 };
73 
TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) {
  // A computation that calls no other computation has no embedded
  // computations.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      module->AddEntryComputation(CreateNegateComputation());
  const auto embedded = entry->MakeEmbeddedComputationsList();
  EXPECT_TRUE(embedded.empty());
}
80 
TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) {
  // The entry computation maps over a single callee, so the callee is its only
  // embedded computation and the callee itself embeds nothing.
  auto module = CreateNewVerifiedModule();
  HloComputation* callee =
      module->AddEmbeddedComputation(CreateNegateComputation());
  HloComputation* caller =
      module->AddEntryComputation(CreateMapComputation(callee));

  const auto embedded_in_callee = callee->MakeEmbeddedComputationsList();
  EXPECT_TRUE(embedded_in_callee.empty());

  const auto embedded_in_caller = caller->MakeEmbeddedComputationsList();
  EXPECT_THAT(embedded_in_caller, ElementsAre(callee));
}
92 
TEST_F(HloComputationTest, GetEmbeddedComputationsDiamond) {
  // Build a diamond-shaped call graph: the entry computation calls two map
  // computations, both of which call the same negate computation.
  auto module = CreateNewVerifiedModule();
  auto negate_computation =
      module->AddEmbeddedComputation(CreateNegateComputation());
  auto map1_computation =
      module->AddEmbeddedComputation(CreateMapComputation(negate_computation));
  auto map2_computation =
      module->AddEmbeddedComputation(CreateMapComputation(negate_computation));

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "param0"));
  auto map1 = builder.AddInstruction(
      HloInstruction::CreateMap(r0f32_, {param}, map1_computation));
  auto map2 = builder.AddInstruction(
      HloInstruction::CreateMap(r0f32_, {param}, map2_computation));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2));
  auto computation = module->AddEntryComputation(builder.Build());

  auto embedded_computations = computation->MakeEmbeddedComputationsList();
  // Fatal check: the first element is dereferenced below, which would be
  // undefined behavior on an empty list.
  ASSERT_EQ(3, embedded_computations.size());
  // MakeEmbeddedComputationsList is expected to return the embedded
  // computations in post order, so the negate computation must come first.
  EXPECT_EQ(negate_computation, embedded_computations.front());
  EXPECT_THAT(embedded_computations,
              UnorderedElementsAre(negate_computation, map1_computation,
                                   map2_computation));
}
123 
TEST_F(HloComputationTest, PostOrderSingleton) {
  // MakeInstructionPostOrder on a computation holding a single instruction
  // yields exactly that instruction.
  HloComputation::Builder builder(TestName());
  HloInstruction* only_instruction = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  const auto post_order = computation->MakeInstructionPostOrder();
  EXPECT_THAT(post_order, ElementsAre(only_instruction));
}
133 
TEST_F(HloComputationTest, PostOrderSimple) {
  // MakeInstructionPostOrder on a linear chain constant -> negate -> negate
  // lists each operand before its user.
  HloComputation::Builder builder(TestName());
  HloInstruction* source = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* first_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, source));
  HloInstruction* second_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, first_negate));
  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  const auto post_order = computation->MakeInstructionPostOrder();
  EXPECT_THAT(post_order, ElementsAre(source, first_negate, second_negate));
}
149 
TEST_F(HloComputationTest, PostOrderTrace) {
  // MakeInstructionPostOrder on a computation that contains a trace
  // instruction attached to an intermediate value.
  HloComputation::Builder builder(TestName());
  HloInstruction* source = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* first_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, source));
  HloInstruction* trace_instruction = builder.AddInstruction(
      HloInstruction::CreateTrace("foobar", first_negate));
  HloInstruction* second_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, first_negate));
  // This test deliberately uses an unverified module.
  auto module = CreateNewUnverifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // The trace instruction is expected at the very end of the ordering, even
  // though it was added before the second negate.
  const auto post_order = computation->MakeInstructionPostOrder();
  EXPECT_THAT(post_order, ElementsAre(source, first_negate, second_negate,
                                      trace_instruction));
}
167 
TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) {
  // MakeInstructionPostOrder on four constants that have no edges between
  // them: all must be present, in any order.
  HloComputation::Builder builder(TestName());
  const auto add_constant = [&builder]() {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  };
  HloInstruction* first = add_constant();
  HloInstruction* second = add_constant();
  HloInstruction* third = add_constant();
  HloInstruction* fourth = add_constant();

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  const auto post_order = computation->MakeInstructionPostOrder();
  EXPECT_THAT(post_order, UnorderedElementsAre(first, second, third, fourth));
}
185 
TEST_F(HloComputationTest, PostOrderWithMultipleRoots) {
  // MakeInstructionPostOrder on a computation with three adds that share
  // constant operands but are not used by one another, so the graph has
  // several roots. Every instruction must appear exactly once.
  HloComputation::Builder builder(TestName());
  const auto add_constant = [&builder]() {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  };
  const auto add_sum = [&builder, this](HloInstruction* lhs,
                                        HloInstruction* rhs) {
    return builder.AddInstruction(
        HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, lhs, rhs));
  };

  HloInstruction* c1 = add_constant();
  HloInstruction* c2 = add_constant();
  HloInstruction* c3 = add_constant();
  HloInstruction* sum12 = add_sum(c1, c2);
  HloInstruction* sum23 = add_sum(c2, c3);
  HloInstruction* sum13 = add_sum(c1, c3);

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  const auto ordering = computation->MakeInstructionPostOrder();
  EXPECT_EQ(6, ordering.size());
  EXPECT_THAT(ordering,
              UnorderedElementsAre(c1, c2, c3, sum12, sum23, sum13));
}
209 
TEST_F(HloComputationTest, VisitWithMultipleRoots) {
  // Accept() must visit every instruction of the computation, including
  // instructions that do not feed the root (dead code / extra roots).
  HloComputation::Builder builder(TestName());
  const auto add_constant = [&builder]() {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  };
  HloInstruction* c1 = add_constant();
  HloInstruction* c2 = add_constant();
  HloInstruction* c3 = add_constant();

  // Three add expressions, none of which uses another; the last one added
  // becomes the root of the computation.
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c2, c3));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c3));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // Visitor that records every visited instruction, the most recent one, and
  // how many times FinishVisit was invoked.
  class RecordingVisitor : public DfsHloVisitorWithDefault {
   public:
    explicit RecordingVisitor(HloComputation* computation)
        : computation_(computation) {}

    Status DefaultAction(HloInstruction* hlo_instruction) override {
      // insert() reports via `.second` whether the element was newly added;
      // each instruction must be visited at most once.
      const bool first_visit = visited_set_.insert(hlo_instruction).second;
      EXPECT_TRUE(first_visit);
      last_visited_ = hlo_instruction;
      return Status::OK();
    }

    Status FinishVisit(HloInstruction* root) override {
      ++finish_visit_calls_;
      EXPECT_EQ(computation_->root_instruction(), root);
      return Status::OK();
    }

    HloComputation* computation_;
    absl::flat_hash_set<HloInstruction*> visited_set_;
    int64 finish_visit_calls_ = 0;
    HloInstruction* last_visited_ = nullptr;
  };

  RecordingVisitor recorder(computation);
  EXPECT_IS_OK(computation->Accept(&recorder));

  // All six instructions are visited, FinishVisit runs once, and the root is
  // the last instruction visited.
  EXPECT_EQ(6, recorder.visited_set_.size());
  EXPECT_EQ(1, recorder.finish_visit_calls_);
  EXPECT_EQ(computation->root_instruction(), recorder.last_visited_);
}
261 
TEST_F(HloComputationTest, DeepCopyArray) {
  // DeepCopyInstruction on an array-shaped instruction yields a single copy of
  // that instruction.
  HloComputation::Builder builder(TestName());
  HloInstruction* array = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloInstruction* array_copy =
      computation->DeepCopyInstruction(array).ValueOrDie();
  EXPECT_THAT(array_copy, GmockMatch(m::Copy(m::Op().Is(array))));
}
273 
TEST_F(HloComputationTest, DeepCopyTuple) {
  // Test that DeepCopyInstruction properly copies a tuple: each element is
  // extracted with a get-tuple-element, copied, and the copies are re-tupled.
  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  auto tuple_copy = computation->DeepCopyInstruction(tuple).ValueOrDie();

  // Fatal check: the operand chains dereferenced below only exist if the
  // result really has the tuple(copy(gte), copy(gte)) structure.
  ASSERT_THAT(tuple_copy, GmockMatch(m::Tuple(
                              m::Copy(m::GetTupleElement(m::Op().Is(tuple))),
                              m::Copy(m::GetTupleElement(m::Op().Is(tuple))))));
  // The i-th copy must read the i-th element of the original tuple.
  EXPECT_EQ(0, tuple_copy->operand(0)->operand(0)->tuple_index());
  EXPECT_EQ(1, tuple_copy->operand(1)->operand(0)->tuple_index());
}
294 
TEST_F(HloComputationTest, DeepCopyArrayAtIndices) {
  // DeepCopyInstruction on an array when the caller states, via a
  // ShapeTree<bool>, which indices are to be copied.
  HloComputation::Builder builder(TestName());
  HloInstruction* array = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
  auto computation = builder.Build();

  {
    // Index marked true: a copy of the array is produced.
    ShapeTree<bool> copy_all(array->shape(), /*init_value=*/true);
    HloInstruction* result =
        computation->DeepCopyInstruction(array, &copy_all).ValueOrDie();
    EXPECT_THAT(result, GmockMatch(m::Copy(m::Op().Is(array))));
  }

  {
    // Index marked false: the original instruction is returned unchanged.
    ShapeTree<bool> copy_none(array->shape(), /*init_value=*/false);
    HloInstruction* result =
        computation->DeepCopyInstruction(array, &copy_none).ValueOrDie();
    EXPECT_EQ(result, array);
  }
}
319 
TEST_F(HloComputationTest,DeepCopyTupleAtIndices)320 TEST_F(HloComputationTest, DeepCopyTupleAtIndices) {
321   // Test that DeepCopyInstruction properly copies elements of a tuple as
322   // specified by the given indices.
323   auto builder = HloComputation::Builder(TestName());
324   auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
325       LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
326   auto constant2 = builder.AddInstruction(
327       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
328   auto tuple = builder.AddInstruction(
329       HloInstruction::CreateTuple({constant1, constant2}));
330   auto computation = builder.Build();
331 
332   {
333     // All true values should copy all array elements.
334     ShapeTree<bool> indices_to_copy(tuple->shape(), /*init_value=*/true);
335     ShapeTree<HloInstruction*> copies_added(tuple->shape(),
336                                             /*init_value=*/nullptr);
337     HloInstruction* deep_copy =
338         computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added)
339             .ValueOrDie();
340 
341     EXPECT_THAT(deep_copy, GmockMatch(m::Tuple(
342                                m::Copy(m::GetTupleElement(m::Op().Is(tuple)))
343                                    .Is(copies_added.element({0})),
344                                m::Copy(m::GetTupleElement(m::Op().Is(tuple)))
345                                    .Is(copies_added.element({1})))));
346   }
347 
348   {
349     // All false elements should copy no array elements, but the GTE and tuple
350     // instruction scaffolding should be built.
351     ShapeTree<bool> indices_to_copy(tuple->shape(), /*init_value=*/false);
352     ShapeTree<HloInstruction*> copies_added(tuple->shape(),
353                                             /*init_value=*/nullptr);
354     HloInstruction* deep_copy =
355         computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added)
356             .ValueOrDie();
357 
358     EXPECT_THAT(deep_copy,
359                 GmockMatch(m::Tuple(m::GetTupleElement(m::Op().Is(tuple)),
360                                     m::GetTupleElement(m::Op().Is(tuple)))));
361     EXPECT_TRUE(copies_added.element({}) == nullptr);
362     EXPECT_TRUE(copies_added.element({0}) == nullptr);
363     EXPECT_TRUE(copies_added.element({1}) == nullptr);
364   }
365 
366   {
367     // Verify one element copied, the other not.
368     ShapeTree<bool> indices_to_copy(tuple->shape(), /*init_value=*/false);
369     *indices_to_copy.mutable_element({0}) = true;
370     ShapeTree<HloInstruction*> copies_added(tuple->shape(),
371                                             /*init_value=*/nullptr);
372     HloInstruction* deep_copy =
373         computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added)
374             .ValueOrDie();
375 
376     EXPECT_THAT(deep_copy, GmockMatch(m::Tuple(
377                                m::Copy(m::GetTupleElement(m::Op().Is(tuple))),
378                                m::GetTupleElement(m::Op().Is(tuple)))));
379     EXPECT_TRUE(copies_added.element({}) == nullptr);
380     EXPECT_TRUE(copies_added.element({0}) != nullptr);
381     EXPECT_TRUE(copies_added.element({1}) == nullptr);
382   }
383 }
384 
TEST_F(HloComputationTest, DeepCopyToken) {
  // DeepCopyInstruction must not add a copy for a token-shaped instruction.
  HloComputation::Builder builder(TestName());
  HloInstruction* token = builder.AddInstruction(HloInstruction::CreateToken());
  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloInstruction* result = computation->DeepCopyInstruction(token).ValueOrDie();

  // The result is matched as an after-all instruction rather than a copy.
  EXPECT_THAT(result, GmockMatch(m::AfterAll()));
}
397 
TEST_F(HloComputationTest, DeepCopyTokenTuple) {
  // DeepCopyInstruction on a (token, array) tuple: the token element must be
  // forwarded without a copy while the array element is copied.
  HloComputation::Builder builder(TestName());
  HloInstruction* token = builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* scalar = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({token, scalar}));
  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloInstruction* result = computation->DeepCopyInstruction(tuple).ValueOrDie();

  // Element 0 (token) is a bare get-tuple-element; element 1 (array) is a copy
  // of a get-tuple-element.
  EXPECT_THAT(result, GmockMatch(m::Tuple(
                          m::GetTupleElement(m::Op().Is(tuple)),
                          m::Copy(m::GetTupleElement(m::Op().Is(tuple))))));
}
417 
TEST_F(HloComputationTest, CycleDetection) {
  // Visiting a computation whose graph contains a cycle must fail with an
  // error instead of succeeding.
  HloComputation::Builder builder(TestName());
  HloInstruction* source = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, source));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, negate, negate));
  // Unverified module: the cycle introduced below is intentional.
  auto module = CreateNewUnverifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // `add` already uses `negate` as an operand; a control edge from `add` back
  // to `negate` closes the loop.
  ASSERT_IS_OK(add->AddControlDependencyTo(negate));

  // The post order still reports all three instructions.
  const auto post_order = computation->MakeInstructionPostOrder();
  EXPECT_EQ(3, post_order.size());

  // A visitor function that accepts every instruction; any failure therefore
  // comes from the traversal itself.
  const auto accept_all = [](HloInstruction* /*instruction*/) {
    return Status::OK();
  };
  const auto status = computation->Accept(accept_all);
  ASSERT_FALSE(status.ok());
  // NOTE(review): the pattern stops at "detecte" — presumably a truncated
  // "detected"; confirm against the actual error message before changing it.
  ASSERT_THAT(status.error_message(),
              ::testing::ContainsRegex("cycle is detecte"));
}
441 
TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) {
  // RemoveInstructionAndUnusedOperands on an instruction that lists the same
  // dead operand twice: the operand must be removed only once.
  HloComputation::Builder builder(TestName());
  HloInstruction* source = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* unused_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, source));
  HloInstruction* unused_add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, unused_negate,
                                   unused_negate));
  HloInstruction* live_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, source));
  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // Before removal: four instructions, rooted at the live negate.
  EXPECT_EQ(4, computation->instruction_count());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Negate(m::Op().Is(source))));
  EXPECT_EQ(live_negate, computation->root_instruction());

  ASSERT_IS_OK(computation->RemoveInstructionAndUnusedOperands(unused_add));

  // After removal: only the constant and the live negate remain, and the root
  // is unchanged.
  EXPECT_EQ(2, computation->instruction_count());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Negate(m::Op().Is(source))));
  EXPECT_EQ(live_negate, computation->root_instruction());
}
469 
TEST_F(HloComputationTest, CloneWithControlDependency) {
  // Test that Clone() carries control dependencies over to the cloned
  // instructions: the cloned add must have the cloned negate as its only
  // control predecessor.
  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, constant1, constant2));

  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "param0"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param));
  auto module = CreateNewVerifiedModule();
  // The add is made the root explicitly; the negate is otherwise unused.
  auto computation =
      module->AddEntryComputation(builder.Build(/*root_instruction=*/add));

  // Report a setup failure as a test failure (as the other tests in this file
  // do) instead of aborting the whole test binary.
  ASSERT_IS_OK(negate->AddControlDependencyTo(add));

  auto clone = computation->Clone();

  auto cloned_add = clone->root_instruction();
  EXPECT_EQ(cloned_add->opcode(), HloOpcode::kAdd);

  auto predecessors = cloned_add->control_predecessors();
  // Fatal check: predecessors[0] is dereferenced below.
  ASSERT_EQ(1, predecessors.size());
  EXPECT_EQ(HloOpcode::kNegate, predecessors[0]->opcode());
  auto successors = predecessors[0]->control_successors();
  EXPECT_THAT(successors, ::testing::ElementsAre(cloned_add));
}
500 
TEST_F(HloComputationTest, CloneWithReplacements) {
  // CloneWithReplacements must substitute a replaced parameter (new shape for
  // parameter 2) and append an extra parameter (parameter 3) in the clone.
  const Shape scalar_s64 = ShapeUtil::MakeShape(S64, {});
  const Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});
  const Shape scalar_u32 = ShapeUtil::MakeShape(U32, {});

  HloComputation::Builder builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "p.0.lhs"));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "p.0.rhs"));
  HloInstruction* replaced_param = builder.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_s64, "p.1"));
  HloInstruction* less_than = builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), lhs, rhs,
                                    ComparisonDirection::kLt));
  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(
      builder.Build(/*root_instruction=*/less_than));

  // Replace the s64 parameter 2 with an s32 parameter of the same number and
  // name.
  absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
      replacements;
  replacements.emplace(replaced_param,
                       HloInstruction::CreateParameter(2, scalar_s32, "p.1"));

  // An additional u32 parameter; the unique_ptr stays alive for the duration
  // of the clone call, which only receives the raw pointer.
  auto extra_param = HloInstruction::CreateParameter(3, scalar_u32, "p.2");
  std::vector<const HloInstruction*> extra_parameters{extra_param.get()};

  auto clone = computation->CloneWithReplacements(std::move(replacements),
                                                  extra_parameters);

  ASSERT_EQ(clone->num_parameters(), 4);
  const Shape& shape0 = clone->parameter_instruction(0)->shape();
  const Shape& shape1 = clone->parameter_instruction(1)->shape();
  const Shape& shape2 = clone->parameter_instruction(2)->shape();
  const Shape& shape3 = clone->parameter_instruction(3)->shape();
  EXPECT_TRUE(ShapeUtil::Equal(shape0, r0f32_));
  EXPECT_TRUE(ShapeUtil::Equal(shape1, r0f32_));
  EXPECT_TRUE(ShapeUtil::Equal(shape2, scalar_s32));
  EXPECT_TRUE(ShapeUtil::Equal(shape3, scalar_u32));
}
536 
TEST_F(HloComputationTest, Stringification) {
  // ToString() of a small transpose + dot computation, printed with metadata
  // disabled, must match the expected text exactly.
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape rhs_transposed_shape = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {5, 20});

  // Contract dimension 1 of the lhs with dimension 0 of the rhs.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  // Default precision for both dot operands.
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "x"));
  HloInstruction* y = builder.AddInstruction(
      HloInstruction::CreateParameter(1, rhs_shape, "y"));
  HloInstruction* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(rhs_transposed_shape, y, {1, 0}));
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, x, transpose, dot_dnums, precision_config));
  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());

  const string expected_computation =
      R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
  %x = f32[5,10]{1,0} parameter(0)
  %y = f32[20,10]{1,0} parameter(1)
  %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
  ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
})";
  const auto print_options = HloPrintOptions().set_print_metadata(false);
  EXPECT_EQ(computation->ToString(print_options), expected_computation);
}
571 
TEST_F(HloComputationTest, StringificationIndent) {
  // Same computation as the plain stringification test, but printed with an
  // indent amount of 2; every line of the expected text is shifted right.
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape rhs_transposed_shape = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {5, 20});

  // Contract dimension 1 of the lhs with dimension 0 of the rhs.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  // Default precision for both dot operands.
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "x"));
  HloInstruction* y = builder.AddInstruction(
      HloInstruction::CreateParameter(1, rhs_shape, "y"));
  HloInstruction* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(rhs_transposed_shape, y, {1, 0}));
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, x, transpose, dot_dnums, precision_config));
  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());

  const string expected_computation =
      R"(    %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
      %x = f32[5,10]{1,0} parameter(0)
      %y = f32[20,10]{1,0} parameter(1)
      %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
      ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    })";
  const auto print_options =
      HloPrintOptions().set_print_metadata(false).set_indent_amount(2);
  EXPECT_EQ(computation->ToString(print_options), expected_computation);
}
607 
TEST_F(HloComputationTest, StringificationCanonical) {
  // Prints the same transpose + dot computation twice: once with metadata
  // disabled, once with the canonical print options, and checks both texts.
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape rhs_transposed_shape = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {5, 20});

  // Contract dimension 1 of the lhs with dimension 0 of the rhs.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  // Default precision for both dot operands.
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "x"));
  HloInstruction* y = builder.AddInstruction(
      HloInstruction::CreateParameter(1, rhs_shape, "y"));
  HloInstruction* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(rhs_transposed_shape, y, {1, 0}));
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, x, transpose, dot_dnums, precision_config));
  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());

  // Non-canonical form: instruction names and the signature are printed.
  const string expected_computation1 =
      R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
  %x = f32[5,10]{1,0} parameter(0)
  %y = f32[20,10]{1,0} parameter(1)
  %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
  ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
})";
  const auto no_metadata_options =
      HloPrintOptions().set_print_metadata(false);
  EXPECT_EQ(computation->ToString(no_metadata_options), expected_computation1);

  // Canonical form: names are replaced by tmp_<n> and the signature is
  // omitted, as the expected text below shows.
  const string expected_computation2 = R"(TransposeDot {
  tmp_0 = f32[5,10]{1,0} parameter(0)
  tmp_1 = f32[20,10]{1,0} parameter(1)
  tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
  ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
})";
  const auto canonical_options = HloPrintOptions().Canonical();
  EXPECT_EQ(computation->ToString(canonical_options), expected_computation2);
}
651 
MakeAddNComputation(int n)652 std::unique_ptr<HloComputation> MakeAddNComputation(int n) {
653   auto builder = HloComputation::Builder("add_n");
654   auto result = builder.AddInstruction(HloInstruction::CreateParameter(
655       0, ShapeUtil::MakeShape(F32, {}), "x_value"));
656   auto one = builder.AddInstruction(
657       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
658   for (int i = 0; i < n; ++i) {
659     result = builder.AddInstruction(HloInstruction::CreateBinary(
660         one->shape(), HloOpcode::kAdd, result, one));
661   }
662   return builder.Build();
663 }
664 
TEST_F(HloComputationTest, DeepEquality) {
  // operator== on very deep computations: two chains of identical length
  // compare equal, and a chain that is one add shorter compares unequal.
  constexpr int kChainLength = 200000;
  auto chain_a = MakeAddNComputation(kChainLength);
  auto chain_b = MakeAddNComputation(kChainLength);
  EXPECT_TRUE(*chain_a == *chain_b);

  auto shorter_chain = MakeAddNComputation(kChainLength - 1);
  EXPECT_FALSE(*chain_a == *shorter_chain);
  EXPECT_FALSE(*shorter_chain == *chain_b);
}
674 
// Tests that cross-module AllReduce instructions (here, two all-reduces that
// share all_reduce_id=1) are ordered after all their operands and before all
// their users in the instruction post order.
TEST_F(HloComputationTest, InstructionPostOrderWithAllReduce) {
  // Module with two all-reduce instructions that share all_reduce_id=1 and
  // read the same parameter; one feeds an add, the other feeds the root tuple
  // directly.
  const char* const kHloText = R"(
HloModule Module

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry {
  param = f32[128] parameter(0), sharding={maximal device=0}
  crs0 = f32[128] all-reduce(param),
    replica_groups={{0}}, all_reduce_id=1, barrier="", to_apply=add,
    sharding={maximal device=0}
  crs1 = f32[128] all-reduce(param),
    replica_groups={{0}}, all_reduce_id=1, barrier="", to_apply=add,
    sharding={maximal device=1}
  add = f32[128] add(crs0, crs0), sharding={maximal device=0}
  ROOT t = (f32[128], f32[128]) tuple(add, crs1)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kHloText));

  // Both all-reduces must come right after the parameter and ahead of the add
  // and the tuple.
  const auto post_order =
      module->entry_computation()->MakeInstructionPostOrder();
  EXPECT_THAT(post_order,
              ElementsAre(op::Parameter(), op::AllReduce(), op::AllReduce(),
                          op::Add(), op::Tuple()));
}
703 
704 }  // namespace
705 }  // namespace xla
706