1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
17 
18 #include <set>
19 #include <unordered_map>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/protobuf_util.h"
26 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/hlo_parser.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/test.h"
32 #include "tensorflow/compiler/xla/test_helpers.h"
33 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
34 #include "tensorflow/compiler/xla/util.h"
35 #include "tensorflow/compiler/xla/window_util.h"
36 
37 namespace xla {
38 namespace {
39 
40 using ::testing::ElementsAre;
41 using ::testing::UnorderedElementsAre;
42 
43 class HloInstructionTest : public HloTestBase {
44  protected:
45   Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
46 };
47 
48 // Simple visitor that collects the number of users and operands for certain HLO
49 // nodes. It also verifies some of the DFS visiting invariants (operands visited
50 // before their users, nodes not visited twice, etc.)
51 class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault {
52  public:
DefaultAction(HloInstruction * hlo_instruction)53   Status DefaultAction(HloInstruction* hlo_instruction) override {
54     return Unimplemented("not implemented %s",
55                          HloOpcodeString(hlo_instruction->opcode()));
56   }
57 
HandleParameter(HloInstruction * parameter)58   Status HandleParameter(HloInstruction* parameter) override {
59     EXPECT_FALSE(count_.contains(parameter));
60     count_[parameter] = GetCountsForNode(parameter);
61     return Status::OK();
62   }
63 
HandleConstant(HloInstruction * constant)64   Status HandleConstant(HloInstruction* constant) override {
65     EXPECT_FALSE(count_.contains(constant));
66     count_[constant] = GetCountsForNode(constant);
67     return Status::OK();
68   }
69 
HandleAdd(HloInstruction * add)70   Status HandleAdd(HloInstruction* add) override {
71     auto lhs = add->operand(0);
72     auto rhs = add->operand(1);
73     EXPECT_FALSE(count_.contains(add));
74     EXPECT_TRUE(count_.contains(lhs));
75     EXPECT_TRUE(count_.contains(rhs));
76     count_[add] = GetCountsForNode(add);
77     return Status::OK();
78   }
79 
HandleNegate(HloInstruction * negate)80   Status HandleNegate(HloInstruction* negate) override {
81     auto operand = negate->operand(0);
82     EXPECT_FALSE(count_.contains(negate));
83     EXPECT_TRUE(count_.contains(operand));
84     count_[negate] = GetCountsForNode(negate);
85     return Status::OK();
86   }
87 
HandleMap(HloInstruction * map)88   Status HandleMap(HloInstruction* map) override {
89     EXPECT_FALSE(count_.contains(map));
90     for (HloInstruction* arg : map->operands()) {
91       EXPECT_TRUE(count_.contains(arg));
92     }
93     count_[map] = GetCountsForNode(map);
94     return Status::OK();
95   }
96 
HandleReduce(HloInstruction * reduce)97   Status HandleReduce(HloInstruction* reduce) override {
98     auto arg = reduce->operand(0);
99     auto init_value = reduce->operand(1);
100     EXPECT_FALSE(count_.contains(reduce));
101     EXPECT_TRUE(count_.contains(arg));
102     EXPECT_TRUE(count_.contains(init_value));
103     count_[reduce] = GetCountsForNode(reduce);
104     return Status::OK();
105   }
106 
NumOperands(const HloInstruction * node)107   int64 NumOperands(const HloInstruction* node) {
108     auto count_iterator = count_.find(node);
109     EXPECT_NE(count_.end(), count_iterator);
110     return count_iterator->second.operand_count;
111   }
112 
NumUsers(const HloInstruction * node)113   int64 NumUsers(const HloInstruction* node) {
114     auto count_iterator = count_.find(node);
115     EXPECT_NE(count_.end(), count_iterator);
116     return count_iterator->second.user_count;
117   }
118 
119  private:
120   struct NumOpsAndUsers {
121     int64 operand_count;
122     int64 user_count;
123   };
124 
125   // Helper function to count operands and users for the given HLO.
GetCountsForNode(const HloInstruction * node)126   NumOpsAndUsers GetCountsForNode(const HloInstruction* node) {
127     NumOpsAndUsers counts{node->operand_count(), node->user_count()};
128     return counts;
129   }
130 
131   // Counters for HLOs. Maps HLO to a NumOpsAndUsers.
132   absl::flat_hash_map<const HloInstruction*, NumOpsAndUsers> count_;
133 };
134 
// A freshly created (unattached) parameter reports its opcode and scalar F32
// shape, and has neither operands nor users.
TEST_F(HloInstructionTest, BasicProperties) {
  auto parameter = HloInstruction::CreateParameter(1, r0f32_, "foo");

  EXPECT_EQ(HloOpcode::kParameter, parameter->opcode());
  EXPECT_TRUE(ShapeUtil::IsScalarWithElementType(parameter->shape(), F32));
  EXPECT_FALSE(ShapeUtil::IsScalarWithElementType(parameter->shape(), S32));
  // Compare counts explicitly instead of relying on an int-to-bool conversion.
  EXPECT_EQ(0, parameter->operand_count());
  EXPECT_EQ(0, parameter->user_count());
}
143 
// Operand/user bookkeeping for a single add fed by two distinct parameters:
//
//   [Param foo]----->  |-----|
//                      | Add |
//   [Param bar]----->  |-----|
TEST_F(HloInstructionTest, UserWithTwoOperands) {
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs_param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* rhs_param =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, lhs_param, rhs_param));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Structure as exposed directly by the instruction API.
  EXPECT_THAT(sum->operands(), UnorderedElementsAre(lhs_param, rhs_param));
  EXPECT_THAT(lhs_param->users(), UnorderedElementsAre(sum));
  EXPECT_THAT(rhs_param->users(), UnorderedElementsAre(sum));

  // Same structure as recorded by a DFS rooted at the add.
  OpAndUserCollectingVisitor collector;
  ASSERT_IS_OK(sum->Accept(&collector));

  EXPECT_EQ(2, collector.NumOperands(sum));
  EXPECT_EQ(0, collector.NumUsers(sum));
  EXPECT_EQ(1, collector.NumUsers(lhs_param));
  EXPECT_EQ(1, collector.NumUsers(rhs_param));
}
170 
// A parameter with three users (two exps and an add) next to a parameter with
// a single user (the add):
//
//            [Param foo]            [Param bar]
//           /     |     \.               |
//          V      V      V               V
//      -------  -------  ------------------
//      | exp |  | exp |  |       add      |
//      -------  -------  ------------------
TEST_F(HloInstructionTest, MultipleUsers) {
  HloComputation::Builder builder(TestName());
  HloInstruction* foo_param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* bar_param =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* first_exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo_param));
  HloInstruction* second_exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo_param));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, foo_param, bar_param));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(3, foo_param->user_count());
  EXPECT_EQ(1, bar_param->user_count());
  for (const HloInstruction* leaf : {first_exp, second_exp, sum}) {
    EXPECT_EQ(0, leaf->user_count());
  }

  // A DFS from the add only reaches foo, bar and the add, but the user count
  // recorded for foo still includes both exps.
  OpAndUserCollectingVisitor collector;
  ASSERT_IS_OK(sum->Accept(&collector));

  EXPECT_EQ(2, collector.NumOperands(sum));
  EXPECT_EQ(3, collector.NumUsers(foo_param));
}
207 
// An add that uses the same parameter for both operands. The parameter must be
// counted as having one user, while the add still reports two operands.
//
//        [Param foo]
//           |   |
//           V   V
//          -------
//          | add |
//          -------
TEST_F(HloInstructionTest, RepeatedUser) {
  HloComputation::Builder builder(TestName());
  HloInstruction* foo_param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* doubled = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, foo_param, foo_param));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // The duplicated use does not produce a second user entry...
  EXPECT_EQ(1, foo_param->user_count());
  // ...but both operand slots of the add are still populated.
  EXPECT_EQ(2, doubled->operand_count());
}
233 
// Two adds share a constant operand and are combined by a third add:
//
//        [param0]          [param1]
//           |                 |
//           |       [c0]      |
//           |        |        |
//           V        |        V
//        -------     |     -------
//        | add | <---^---> | add |
//        -------           -------
//           |                 |
//           \     -------     /
//            ---->| add |<----
//                 -------
TEST_F(HloInstructionTest, MultipleUsersAndOperands) {
  HloComputation::Builder builder(TestName());
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "param0"));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "param1"));
  HloInstruction* shared = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* left_sum = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, p0, shared));
  HloInstruction* right_sum = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, shared, p1));
  HloInstruction* total = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, left_sum, right_sum));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  OpAndUserCollectingVisitor collector;
  ASSERT_IS_OK(total->Accept(&collector));

  // The constant feeds both inner adds.
  EXPECT_EQ(2, collector.NumUsers(shared));
  // Every add in the graph is binary.
  for (const HloInstruction* add : {left_sum, right_sum, total}) {
    EXPECT_EQ(2, collector.NumOperands(add));
  }
}
271 
// Like MultipleUsersAndOperands, but the shared operand is a negate of a
// constant and the final add feeds one more negate:
//
//        [param0]   [c0]   [param1]
//           |        |        |
//           |        V        |
//           |     -------     |
//           |     | neg |     |
//           |     -------     |
//           V        |        V
//        -------     |     -------
//        | add | <---^---> | add |
//        -------           -------
//           |                 |
//           \     -------     /
//            ---->| add |<----
//                 -------
//                    |
//                    V
//                 -------
//                 | neg |
//                 -------
TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) {
  HloComputation::Builder builder(TestName());
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "param0"));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "param1"));
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* inner_neg = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
  HloInstruction* left_sum = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, p0, inner_neg));
  HloInstruction* right_sum = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, inner_neg, p1));
  HloInstruction* total = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, left_sum, right_sum));
  HloInstruction* outer_neg = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, total));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  OpAndUserCollectingVisitor collector;
  ASSERT_IS_OK(outer_neg->Accept(&collector));

  EXPECT_EQ(1, collector.NumUsers(constant));
  EXPECT_EQ(2, collector.NumUsers(inner_neg));
  for (const HloInstruction* add : {left_sum, right_sum, total}) {
    EXPECT_EQ(2, collector.NumOperands(add));
  }
  // The final negate is unary and has no consumers.
  EXPECT_EQ(1, collector.NumOperands(outer_neg));
  EXPECT_EQ(0, collector.NumUsers(outer_neg));
}
323 
// Maps a trivial x+1 computation over a parameter, as the only operation:
//
//   param0[100x10] ---> (map x+1)
TEST_F(HloInstructionTest, TrivialMap) {
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {100, 10});
  auto module = CreateNewVerifiedModule();

  // The mapped scalar computation: x + 1.0.
  auto embedded_builder = HloComputation::Builder("f32+1");
  HloInstruction* x = embedded_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "x"));
  HloInstruction* one = embedded_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  embedded_builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, x, one));
  HloComputation* plus_one =
      module->AddEmbeddedComputation(embedded_builder.Build());

  // Entry computation: a single parameter consumed by the map.
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "p"));
  HloInstruction* map = builder.AddInstruction(
      HloInstruction::CreateMap(matrix_shape, {input}, plus_one));
  module->AddEntryComputation(builder.Build());

  OpAndUserCollectingVisitor collector;
  ASSERT_IS_OK(map->Accept(&collector));

  // Only the entry computation is walked; the mapped computation is not.
  EXPECT_EQ(1, collector.NumUsers(input));
  EXPECT_EQ(0, collector.NumUsers(map));
  EXPECT_EQ(1, collector.NumOperands(map));

  // TODO(dehnert):  Add walking and counters for the wrapped computation.
}
361 
// Reduces a 100x10 parameter over dimension 1 with an x+y computation, as the
// only operation:
//
//   param0[100x10] ---> (reduce x+y)
TEST_F(HloInstructionTest, TrivialReduce) {
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {100});
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {100, 10});

  // The reducer: x + y on scalars.
  auto embedded_builder = HloComputation::Builder("f32+f32");
  HloInstruction* x = embedded_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "x"));
  HloInstruction* y = embedded_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "y"));
  embedded_builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, x, y));
  auto module = CreateNewVerifiedModule();
  HloComputation* adder =
      module->AddEmbeddedComputation(embedded_builder.Build());

  // Entry computation: input and init value feeding the reduce.
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "p"));
  HloInstruction* init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  // This constant is not an operand of the reduce, so the visit below never
  // reaches it.
  builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* reduce = builder.AddInstruction(
      HloInstruction::CreateReduce(vector_shape, input, init,
                                   /*dimensions_to_reduce=*/{1}, adder));
  module->AddEntryComputation(builder.Build());

  OpAndUserCollectingVisitor collector;
  ASSERT_IS_OK(reduce->Accept(&collector));

  // Only the entry computation is walked; the reducer is not.
  EXPECT_EQ(1, collector.NumUsers(input));
  EXPECT_EQ(1, collector.NumUsers(init));
  EXPECT_EQ(0, collector.NumUsers(reduce));
  EXPECT_EQ(2, collector.NumOperands(reduce));
}
404 
TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) {
  // Construct a graph of a few binary ops using two different
  // parameters. Replace one of the parameters with the other parameter in one
  // of the instructions.
  HloComputation::Builder builder(TestName());
  auto foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  auto bar =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
  auto add_foobar = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
  auto add_foofoo = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo));
  builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
                                                      add_foobar, add_foofoo));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(2, foo->user_count());
  EXPECT_EQ(1, bar->user_count());

  // Replace the use of foo in add_foofoo with bar.
  ASSERT_IS_OK(foo->ReplaceUseWith(add_foofoo, bar));

  EXPECT_EQ(1, foo->user_count());
  EXPECT_EQ(2, bar->user_count());

  // add_foobar is untouched and is now foo's only user.
  EXPECT_THAT(foo->users(), UnorderedElementsAre(add_foobar));
  EXPECT_THAT(add_foobar->operands(), ElementsAre(foo, bar));

  // Both operand slots of add_foofoo now refer to bar.
  EXPECT_THAT(bar->users(), UnorderedElementsAre(add_foobar, add_foofoo));
  EXPECT_THAT(add_foofoo->operands(), ElementsAre(bar, bar));
}
439 
// A tuple that uses foo in two slots has both slots redirected to bar by one
// ReplaceUseWith call, while foo's other user (an add) is left untouched.
TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) {
  HloComputation::Builder builder(TestName());
  HloInstruction* param_foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* param_bar =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* param_baz =
      builder.AddInstruction(HloInstruction::CreateParameter(2, r0f32_, "baz"));

  HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple(
      {param_foo, param_bar, param_baz, param_foo}));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, param_foo, param_bar));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Before: foo is used by the tuple (twice, counted once) and by the add.
  EXPECT_EQ(2, param_foo->user_count());
  EXPECT_THAT(param_foo->users(), UnorderedElementsAre(tuple, sum));

  ASSERT_IS_OK(param_foo->ReplaceUseWith(tuple, param_bar));

  // After: only the add still uses foo, and every foo slot in the tuple is bar.
  EXPECT_THAT(param_foo->users(), UnorderedElementsAre(sum));
  EXPECT_THAT(tuple->operands(),
              ElementsAre(param_bar, param_bar, param_baz, param_bar));
}
469 
TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) {
  // Construct a couple unary instructions which use a parameter. Replace the
  // use of a parameter in one of the unary ops with the other parameter.
  HloComputation::Builder builder(TestName());
  auto foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  auto bar =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));

  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(2, foo->user_count());
  EXPECT_THAT(foo->users(), UnorderedElementsAre(exp, log));
  EXPECT_EQ(0, bar->user_count());

  // Replace the use of foo in exp with bar.
  ASSERT_IS_OK(foo->ReplaceUseWith(exp, bar));

  // The use of foo in log should not have been affected.
  EXPECT_EQ(1, foo->user_count());
  EXPECT_THAT(foo->users(), UnorderedElementsAre(log));
  EXPECT_THAT(log->operands(), ElementsAre(foo));

  // Bar should now be used (only) by exp, and be exp's sole operand. The
  // matchers check both the element and the size, and avoid a signed/unsigned
  // comparison against operands().size().
  EXPECT_EQ(1, bar->user_count());
  EXPECT_THAT(bar->users(), ElementsAre(exp));
  EXPECT_THAT(exp->operands(), ElementsAre(bar));
}
504 
// Two adds use foo (one of them twice). After ReplaceAllUsesWith(bar), foo has
// no users left and bar is used by both adds.
TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) {
  HloComputation::Builder builder(TestName());
  HloInstruction* param_foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* param_bar =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* mixed_sum = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, param_foo, param_bar));
  HloInstruction* foo_sum = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, param_foo, param_foo));
  builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
                                                      mixed_sum, foo_sum));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Before: foo feeds both adds, bar only the mixed one.
  EXPECT_EQ(2, param_foo->user_count());
  EXPECT_EQ(1, param_bar->user_count());

  ASSERT_IS_OK(param_foo->ReplaceAllUsesWith(param_bar));

  // After: every former use of foo now refers to bar.
  EXPECT_EQ(0, param_foo->user_count());
  EXPECT_EQ(2, param_bar->user_count());
  EXPECT_THAT(param_bar->users(), UnorderedElementsAre(mixed_sum, foo_sum));
}
534 
// foo is used by a binary op (add), a unary op (exp) and a variadic op
// (tuple). After ReplaceAllUsesWith(bar), all three are users of bar; the add
// and the tuple were already bar's users and are not counted twice.
TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) {
  HloComputation::Builder builder(TestName());
  HloInstruction* param_foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* param_bar =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));

  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, param_foo, param_bar));
  HloInstruction* exp_of_foo = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, param_foo));
  HloInstruction* pair = builder.AddInstruction(
      HloInstruction::CreateTuple({param_foo, param_bar}));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(3, param_foo->user_count());
  EXPECT_EQ(2, param_bar->user_count());

  ASSERT_IS_OK(param_foo->ReplaceAllUsesWith(param_bar));

  EXPECT_EQ(0, param_foo->user_count());
  EXPECT_EQ(3, param_bar->user_count());
  EXPECT_THAT(param_bar->users(), UnorderedElementsAre(sum, exp_of_foo, pair));
}
564 
565 // Simple visitor that collects and post-processes each node in the graph.
566 class NodeCollectorAndPostProcessor : public DfsHloVisitorWithDefault {
567  public:
NodeCollectorAndPostProcessor()568   NodeCollectorAndPostProcessor() {}
569 
Postprocess(HloInstruction * hlo)570   Status Postprocess(HloInstruction* hlo) override {
571     post_processed_nodes_.push_back(hlo);
572     return Status::OK();
573   }
574 
DefaultAction(HloInstruction * hlo_instruction)575   Status DefaultAction(HloInstruction* hlo_instruction) override {
576     visited_nodes_.push_back(hlo_instruction);
577     return Status::OK();
578   }
579 
visited_nodes()580   const std::vector<const HloInstruction*>& visited_nodes() {
581     return visited_nodes_;
582   }
583 
post_processed_nodes()584   const std::vector<const HloInstruction*>& post_processed_nodes() {
585     return post_processed_nodes_;
586   }
587 
588  private:
589   std::vector<const HloInstruction*> visited_nodes_;
590   std::vector<const HloInstruction*> post_processed_nodes_;
591 };
592 
593 // Returns true if "vec" contains distinct nodes.
Distinct(const std::vector<const HloInstruction * > & vec)594 bool Distinct(const std::vector<const HloInstruction*>& vec) {
595   std::set<const HloInstruction*> distinct_nodes(vec.begin(), vec.end());
596   return distinct_nodes.size() == vec.size();
597 }
598 
// Every visited node must also be post-processed, in the same order, and no
// node may be visited twice even though foo is reachable along two paths:
//
//     +--> exp --+
// foo |          |--> add
//     +--> log --+
TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) {
  HloComputation::Builder builder(TestName());
  HloInstruction* foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* exp_of_foo = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
  HloInstruction* log_of_foo = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, exp_of_foo, log_of_foo));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  NodeCollectorAndPostProcessor collector;
  ASSERT_IS_OK(sum->Accept(&collector));

  const std::vector<const HloInstruction*>& visited = collector.visited_nodes();
  EXPECT_EQ(visited, collector.post_processed_nodes());
  EXPECT_TRUE(Distinct(visited));
}
625 
// Fusing a lone exp leaves its constant operand outside the fusion: the
// constant becomes the fusion's only operand and the fusion its only user.
TEST_F(HloInstructionTest, SingletonFusionOp) {
  HloComputation::Builder builder(TestName());
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* exp_of_constant = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {exp_of_constant}, HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(fusion->operands(), ElementsAre(constant));
  EXPECT_THAT(constant->users(), ElementsAre(fusion));
}
641 
// Fusing a single add of two constants makes both constants operands of the
// fusion, in the add's operand order, with the fusion as each one's only user.
TEST_F(HloInstructionTest, BinaryFusionOp) {
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.1f)));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, lhs, rhs));
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {sum}, HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(fusion->operands(), ElementsAre(lhs, rhs));
  for (const HloInstruction* constant : {lhs, rhs}) {
    EXPECT_THAT(constant->users(), ElementsAre(fusion));
  }
}
660 
// Fusing a chain of three exps leaves only the constant at the head of the
// chain outside the fusion.
TEST_F(HloInstructionTest, ChainFusionOp) {
  HloComputation::Builder builder(TestName());
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));

  // Build exp(exp(exp(constant))). `to_fuse` lists the exps from the end of
  // the chain back to its start (newest first).
  std::vector<HloInstruction*> to_fuse;
  HloInstruction* tail = constant;
  for (int i = 0; i < 3; ++i) {
    tail = builder.AddInstruction(
        HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, tail));
    to_fuse.insert(to_fuse.begin(), tail);
  }

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      to_fuse, HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(fusion->operands(), ElementsAre(constant));
  EXPECT_THAT(constant->users(), ElementsAre(fusion));
}
681 
// Metadata set on fused instructions must be carried by the fusion, by the
// instructions inside the fused computation, and by a clone of the fusion.
TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
  HloComputation::Builder builder(TestName());
  // Create a chain of fused unary ops.
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  auto exp1 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
  auto exp2 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1));
  OpMetadata metadata;
  metadata.set_op_name("tf_op");
  exp1->set_metadata(metadata);
  exp2->set_metadata(metadata);

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  auto* fusion = computation->CreateFusionInstruction(
      {exp2, exp1}, HloInstruction::FusionKind::kLoop);

  EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata()));
  EXPECT_TRUE(protobuf_util::ProtobufEquals(
      metadata, fusion->fused_expression_root()->metadata()));
  EXPECT_TRUE(protobuf_util::ProtobufEquals(
      metadata, fusion->fused_expression_root()->operand(0)->metadata()));

  // Check the clone itself; re-checking `fusion` here would not exercise
  // CloneWithNewOperands at all.
  auto cloned = fusion->CloneWithNewOperands(fusion->shape(), {});
  EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, cloned->metadata()));
}
710 
// Cloning an outfeed must keep its outfeed_shape, including the layout, which
// is the only difference between the two outfeeds built here.
TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) {
  HloComputation::Builder builder(TestName());
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
          {1, 2},
          {3, 4},
      })));
  const Shape shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
  const Shape shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
  HloInstruction* token = builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* outfeed_10 = builder.AddInstruction(
      HloInstruction::CreateOutfeed(shape10, constant, token, ""));
  HloInstruction* outfeed_01 = builder.AddInstruction(
      HloInstruction::CreateOutfeed(shape01, constant, token, ""));

  HloInstruction* clone_01 = builder.AddInstruction(outfeed_01->Clone());
  HloInstruction* clone_10 = builder.AddInstruction(outfeed_10->Clone());

  EXPECT_TRUE(ShapeUtil::Equal(clone_01->outfeed_shape(), shape01));
  EXPECT_TRUE(ShapeUtil::Equal(clone_10->outfeed_shape(), shape10));
}
732 
TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) {
  // Give the two elements of a tuple different layouts, then check that the
  // clone's shape equals the original tuple's shape.
  HloComputation::Builder builder(TestName());
  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
          {1, 2},
          {3, 4},
      })));
  HloInstruction* pair =
      builder.AddInstruction(HloInstruction::CreateTuple({matrix, matrix}));

  Shape* pair_shape = pair->mutable_shape();
  *ShapeUtil::GetMutableSubshape(pair_shape, {0})->mutable_layout() =
      LayoutUtil::MakeLayout({0, 1});
  *ShapeUtil::GetMutableSubshape(pair_shape, {1})->mutable_layout() =
      LayoutUtil::MakeLayout({1, 0});

  std::unique_ptr<HloInstruction> pair_copy = pair->Clone();
  EXPECT_TRUE(ShapeUtil::Equal(pair_copy->shape(), pair->shape()));
}
749 
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
  // Build a chain of three kMap instructions that call two distinct embedded
  // computations (x, x, then y), and fuse the maps into one fusion, one at a
  // time. After every step, the fusion instruction's only called computation
  // must be its own fused computation, not the computations called by the
  // maps it absorbed.
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  auto module = CreateNewVerifiedModule();

  // Each call adds a new embedded computation consisting of a single scalar
  // parameter.
  auto make_map_computation = [&]() {
    auto builder = HloComputation::Builder("FusionMap");
    builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "param"));
    return module->AddEmbeddedComputation(builder.Build());
  };

  HloComputation* computation_x = make_map_computation();
  HloComputation* computation_y = make_map_computation();

  HloComputation::Builder builder(TestName());
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  auto map_1_x = builder.AddInstruction(
      HloInstruction::CreateMap(scalar_shape, {constant}, computation_x));
  auto map_2_x = builder.AddInstruction(
      HloInstruction::CreateMap(scalar_shape, {map_1_x}, computation_x));
  auto map_3_y = builder.AddInstruction(
      HloInstruction::CreateMap(scalar_shape, {map_2_x}, computation_y));
  auto* computation = module->AddEntryComputation(builder.Build());

  // Start with the last map only.
  auto* fusion = computation->CreateFusionInstruction(
      {map_3_y}, HloInstruction::FusionKind::kLoop);
  auto* fused_computation = fusion->fused_instructions_computation();
  EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));

  // Fusing further maps that call computation_x must not add it to the
  // fusion's called computations.
  fusion->FuseInstruction(map_2_x);
  EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));

  fusion->FuseInstruction(map_1_x);
  EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));
}
787 
TEST_F(HloInstructionTest, ComplexFusionOp) {
  HloComputation::Builder builder(TestName());
  // Fuse all instructions in complicated expression:
  //
  //   add = Add(C1, C2)
  //   clamp = Clamp(C2, add, add)
  //   exp = Exp(add)
  //   mul = Mul(exp, C3)
  //   sub = Sub(mul, clamp)
  //   tuple = Tuple({sub, sub, mul, C1})
  //
  // Notable complexities are repeated operands in the same instruction,
  // different shapes, use of value in different expressions.
  auto c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  auto c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.1f)));
  auto c3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(9.0f)));

  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2));
  auto clamp = builder.AddInstruction(
      HloInstruction::CreateTernary(r0f32_, HloOpcode::kClamp, c2, add, add));
  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, add));
  auto mul = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, exp, c3));
  auto sub = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kSubtract, mul, clamp));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1}));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  auto* fusion = computation->CreateFusionInstruction(
      {tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);

  // Operands in the fusion instruction's operands() vector should be in the
  // order in which their users were fused. With the fusion order above, the
  // first users are: tuple (C1), then mul (C3), then clamp (C2).
  EXPECT_THAT(fusion->operands(), ElementsAre(c1, c3, c2));
  EXPECT_THAT(c1->users(), ElementsAre(fusion));
}
831 
832 // Convenience function for comparing two HloInstructions.
Identical(const HloInstruction & instruction1,const HloInstruction & instruction2)833 static bool Identical(const HloInstruction& instruction1,
834                       const HloInstruction& instruction2) {
835   // Verify Identical is reflexive for both instructions.
836   EXPECT_TRUE(instruction1.Identical(instruction1));
837   EXPECT_TRUE(instruction2.Identical(instruction2));
838 
839   bool is_equal = instruction1.Identical(instruction2);
840   // Verify Identical is symmetric.
841   EXPECT_EQ(is_equal, instruction2.Identical(instruction1));
842   return is_equal;
843 }
844 
845 // Convenience function for comparing two HloInstructions for structural
846 // equality.
StructuralEqual(const HloInstruction & instruction1,const HloInstruction & instruction2)847 static bool StructuralEqual(const HloInstruction& instruction1,
848                             const HloInstruction& instruction2) {
849   auto eq_operand_shapes = [](const HloInstruction* a,
850                               const HloInstruction* b) {
851     return ShapeUtil::Equal(a->shape(), b->shape());
852   };
853   auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
854     return *a == *b;
855   };
856 
857   // Verify Identical is reflexive for both instructions.
858   EXPECT_TRUE(
859       instruction1.Identical(instruction1, eq_operand_shapes, eq_computations));
860   EXPECT_TRUE(
861       instruction2.Identical(instruction2, eq_operand_shapes, eq_computations));
862 
863   bool is_equal =
864       instruction1.Identical(instruction2, eq_operand_shapes, eq_computations);
865   // Verify Identical is symmetric.
866   EXPECT_EQ(is_equal, instruction2.Identical(instruction1, eq_operand_shapes,
867                                              eq_computations));
868   return is_equal;
869 }
870 
TEST_F(HloInstructionTest,IdenticalInstructions)871 TEST_F(HloInstructionTest, IdenticalInstructions) {
872   // Test HloInstruction::Identical with some subset of instructions types.
873 
874   // Create a set of random constant operands to use below. Make them matrices
875   // so dimensions are interesting.
876   auto operand1 = HloInstruction::CreateConstant(
877       LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
878   auto operand2 = HloInstruction::CreateConstant(
879       LiteralUtil::CreateR2<float>({{10.0, 20.0}, {30.0, 40.0}}));
880   auto vector_operand = HloInstruction::CreateConstant(
881       LiteralUtil::CreateR1<float>({42.0, 123.0}));
882   Shape shape = operand1->shape();
883 
884   // Convenient short names for the operands.
885   HloInstruction* op1 = operand1.get();
886   HloInstruction* op2 = operand2.get();
887 
888   // Operations which only depend on their operands and opcode.
889   EXPECT_TRUE(
890       Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
891                 *HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1)));
892   EXPECT_FALSE(
893       Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
894                 *HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op2)));
895   EXPECT_FALSE(
896       Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
897                 *HloInstruction::CreateUnary(shape, HloOpcode::kNegate, op1)));
898 
899   // Tuples.
900   EXPECT_TRUE(Identical(*HloInstruction::CreateTuple({op1, op2}),
901                         *HloInstruction::CreateTuple({op1, op2})));
902   EXPECT_FALSE(Identical(*HloInstruction::CreateTuple({op1, op2}),
903                          *HloInstruction::CreateTuple({op2, op1})));
904 
905   // Broadcasts.
906   EXPECT_TRUE(Identical(*HloInstruction::CreateBroadcast(shape, op1, {0, 1}),
907                         *HloInstruction::CreateBroadcast(shape, op1, {0, 1})));
908   EXPECT_FALSE(Identical(*HloInstruction::CreateBroadcast(shape, op1, {0, 1}),
909                          *HloInstruction::CreateBroadcast(shape, op1, {1, 0})));
910   Shape bcast_shape1 = ShapeUtil::MakeShape(F32, {2, 2, 42});
911   Shape bcast_shape2 = ShapeUtil::MakeShape(F32, {2, 2, 123});
912   EXPECT_FALSE(
913       Identical(*HloInstruction::CreateBroadcast(bcast_shape1, op1, {0, 1}),
914                 *HloInstruction::CreateBroadcast(bcast_shape2, op1, {0, 1})));
915 
916   // Binary operands.
917   EXPECT_TRUE(Identical(
918       *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
919       *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2)));
920   EXPECT_FALSE(Identical(
921       *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
922       *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op2, op1)));
923   EXPECT_FALSE(Identical(
924       *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
925       *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op1, op2)));
926 }
927 
TEST_F(HloInstructionTest, IdenticalCallInstructions) {
  // Two calls to the same computation (subcomp1) are structurally equal. A
  // call to a different computation (subcomp2: cosine instead of sine) is not.
  const char* const hlo_string = R"(
HloModule Module

subcomp1 (x: f32[]) -> f32[] {
  x = f32[] parameter(0)
  ROOT n = f32[] sine(x)
}

subcomp2 (x: f32[]) -> f32[] {
  x = f32[] parameter(0)
  ROOT n = f32[] cosine(x)
}

ENTRY entry (param: f32[]) -> (f32[], f32[], f32[]) {
  p = f32[] parameter(0)
  t1 = f32[] call(p), to_apply=subcomp1
  t2 = f32[] call(p), to_apply=subcomp1
  t3 = f32[] call(p), to_apply=subcomp2
  ROOT t = (f32[], f32[], f32[]) tuple(t1, t2, t3)
 }
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_string));

  const HloInstruction* tuple =
      module->entry_computation()->root_instruction();
  const HloInstruction* sine_call_a = tuple->operand(0);
  const HloInstruction* sine_call_b = tuple->operand(1);
  const HloInstruction* cosine_call = tuple->operand(2);

  EXPECT_TRUE(StructuralEqual(*sine_call_a, *sine_call_b));
  EXPECT_FALSE(StructuralEqual(*sine_call_a, *cosine_call));
}
961 
TEST_F(HloInstructionTest, FunctionVisitor) {
  // Accept() with a function, started at `add`, must visit every instruction
  // of this diamond exactly once, with each operand before its users:
  //
  //   param ---> negate ---+
  //     |                  +---> add
  //     +------> exp ------+
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, scalar, "0"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kNegate, param));
  HloInstruction* exp_inst = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kExp, param));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, negate, exp_inst));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Map each visited instruction to the position at which it was visited.
  absl::flat_hash_map<HloInstruction*, int> position;
  int next_position = 0;
  EXPECT_IS_OK(add->Accept([&](HloInstruction* visited) {
    // A second visit to the same instruction is an error.
    EXPECT_FALSE(position.contains(visited));
    position[visited] = next_position++;
    return Status::OK();
  }));

  EXPECT_EQ(0, position.at(param));
  // negate and exp are independent of each other, so either may be second.
  const int exp_position = position.at(exp_inst);
  EXPECT_TRUE(exp_position == 1 || exp_position == 2);
  const int negate_position = position.at(negate);
  EXPECT_TRUE(negate_position == 1 || negate_position == 2);
  EXPECT_NE(exp_position, negate_position);
  EXPECT_EQ(3, position.at(add));
}
1000 
TEST_F(HloInstructionTest, FullyElementwise) {
  // A plain add of two vectors is elementwise overall and on every operand.
  const Shape vec5 = ShapeUtil::MakeShape(F32, {5});
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs =
      builder.AddInstruction(HloInstruction::CreateParameter(0, vec5, "x"));
  HloInstruction* rhs =
      builder.AddInstruction(HloInstruction::CreateParameter(1, vec5, "y"));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(vec5, HloOpcode::kAdd, lhs, rhs));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(sum->IsElementwise());
  const int64 num_operands = sum->operand_count();
  for (int64 operand_index = 0; operand_index < num_operands;
       ++operand_index) {
    EXPECT_TRUE(sum->IsElementwiseOnOperand(operand_index));
  }
}
1018 
TEST_F(HloInstructionTest, MapIsElementwise) {
  // A map over a computation that just returns its scalar parameter is
  // elementwise.
  auto module = CreateNewVerifiedModule();
  const Shape matrix_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {10, 10}, {1, 0});
  HloComputation::Builder builder(TestName());

  HloComputation::Builder identity_builder("id");
  identity_builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"));
  HloComputation* identity =
      module->AddEmbeddedComputation(identity_builder.Build());

  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "x"));
  HloInstruction* map = builder.AddInstruction(
      HloInstruction::CreateMap(matrix_shape, {input}, identity));
  module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(map->IsElementwise());
}
1035 
TEST_F(HloInstructionTest, PartiallyElementwise) {
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {5});
  const Shape mat_shape = ShapeUtil::MakeShape(F32, {3, 5});

  // Fused expression:
  //
  // p0     p1   p2   p3
  //   \   /    /     |
  //    mul    /      |
  //      \   /       |
  //       div     broadcast
  //          \    /
  //           max
  //
  // p0, p1 and p2 reach the root only through elementwise ops. p3 goes
  // through a broadcast, which is not elementwise, so the fusion is not
  // elementwise on p3 and therefore not elementwise as a whole.
  HloComputation::Builder builder("PartiallyElementwise");
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, mat_shape, "p0"));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, mat_shape, "p1"));
  HloInstruction* p2 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, mat_shape, "p2"));
  HloInstruction* p3 = builder.AddInstruction(
      HloInstruction::CreateParameter(3, vec_shape, "p3"));
  HloInstruction* product = builder.AddInstruction(
      HloInstruction::CreateBinary(mat_shape, HloOpcode::kMultiply, p0, p1));
  HloInstruction* quotient = builder.AddInstruction(HloInstruction::CreateBinary(
      mat_shape, HloOpcode::kDivide, product, p2));
  // The only dimension of p3 ([5]) becomes dimension 1 of the [3x5] result.
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(mat_shape, p3, {1}));
  HloInstruction* maximum = builder.AddInstruction(HloInstruction::CreateBinary(
      mat_shape, HloOpcode::kMaximum, quotient, broadcast));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {maximum, broadcast, quotient, product},
      HloInstruction::FusionKind::kLoop);
  EXPECT_FALSE(fusion->IsElementwise());
  for (int64 i = 0; i < fusion->operand_count(); ++i) {
    const bool feeds_broadcast = fusion->operand(i) == p3;
    EXPECT_EQ(!feeds_broadcast, fusion->IsElementwiseOnOperand(i));
  }
}
1086 
TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
  // Fused expression:
  //         y
  //        /
  // x   broadcast
  //  \   /  |
  //   min   |
  //     \   /
  //      sub
  //
  // The broadcast of scalar y feeds both min and sub. The fusion is expected
  // to be elementwise on x but not on y.
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  const Shape vec5 = ShapeUtil::MakeShape(F32, {5});

  HloComputation::Builder builder("PartiallyElementwiseWithReuse");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, vec5, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, scalar, "y"));
  HloInstruction* splat =
      builder.AddInstruction(HloInstruction::CreateBroadcast(vec5, y, {}));
  HloInstruction* minimum = builder.AddInstruction(
      HloInstruction::CreateBinary(vec5, HloOpcode::kMinimum, x, splat));
  HloInstruction* difference = builder.AddInstruction(
      HloInstruction::CreateBinary(vec5, HloOpcode::kSubtract, minimum, splat));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {difference, splat, minimum}, HloInstruction::FusionKind::kLoop);
  EXPECT_FALSE(fusion->IsElementwise());
  for (int64 i = 0; i < fusion->operand_count(); ++i) {
    const bool is_y = fusion->operand(i) == y;
    EXPECT_EQ(!is_y, fusion->IsElementwiseOnOperand(i));
  }
}
1126 
TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
  // Fused expression:
  //
  // x     y
  // |     |
  // |  transpose
  //  \   /
  //   dot
  //
  // Clone() of the fusion must reproduce the shape of every fused
  // instruction, and the clone must be structurally equal to the original.
  const Shape x_shape = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape y_shape = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape y_transposed_shape = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape dot_shape = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, x_shape, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, y_shape, "y"));
  HloInstruction* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(y_transposed_shape, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      dot_shape, x, transpose, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {dot, transpose}, HloInstruction::FusionKind::kLoop);

  std::unique_ptr<HloInstruction> copy = fusion->Clone();
  const HloInstruction* root = fusion->fused_expression_root();
  const HloInstruction* copy_root = copy->fused_expression_root();
  const auto same_shape = [](const HloInstruction* a,
                             const HloInstruction* b) {
    return ShapeUtil::Equal(a->shape(), b->shape());
  };
  EXPECT_TRUE(same_shape(root, copy_root));
  EXPECT_TRUE(same_shape(root->operand(0), copy_root->operand(0)));
  EXPECT_TRUE(same_shape(root->operand(1), copy_root->operand(1)));
  EXPECT_TRUE(same_shape(root->operand(1)->operand(0),
                         copy_root->operand(1)->operand(0)));
  EXPECT_TRUE(StructuralEqual(*fusion, *copy));
}
1172 
TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) {
  // Fused expression:
  //
  // x     y
  // |     |
  // |  transpose
  //  \   /
  //   dot
  //
  // After every use of x is redirected to y, y is the fusion's only input.
  // The fusion must then list y once as an operand and have a single fused
  // parameter, not one per former input.
  const Shape s = ShapeUtil::MakeShape(F32, {10, 10});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s, "y"));
  HloInstruction* transpose =
      builder.AddInstruction(HloInstruction::CreateTranspose(s, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      s, x, transpose, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {dot, transpose}, HloInstruction::FusionKind::kLoop);

  EXPECT_IS_OK(x->ReplaceAllUsesWith(y));

  EXPECT_THAT(fusion->operands(), UnorderedElementsAre(y));
  EXPECT_EQ(fusion->fused_instructions_computation()->num_parameters(), 1);
}
1206 
TEST_F(HloInstructionTest, FusionEquality) {
  // Two single-op fusions over the same parameter (exp vs. negate) must not
  // be structurally equal. A fusion and its clone must be.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
  HloInstruction* exp_op = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, param));
  HloInstruction* negate_op = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param));
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  HloInstruction* exp_fusion = computation->CreateFusionInstruction(
      {exp_op}, HloInstruction::FusionKind::kLoop);
  HloInstruction* negate_fusion = computation->CreateFusionInstruction(
      {negate_op}, HloInstruction::FusionKind::kLoop);
  EXPECT_FALSE(StructuralEqual(*exp_fusion, *negate_fusion));

  std::unique_ptr<HloInstruction> exp_fusion_copy = exp_fusion->Clone();
  EXPECT_TRUE(StructuralEqual(*exp_fusion, *exp_fusion_copy));
}
1228 
TEST_F(HloInstructionTest, NestedFusionEquality) {
  // Builds mul = (dot + 1) * (dot - 1) with dot = a . transpose(b). It wraps
  // dot and the transpose in a loop fusion, then fuses that nested fusion
  // into two output fusions: one rooted at the add, one at the subtract.
  // The add fusion must equal its clone but differ from the subtract fusion.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape mat2x2 = ShapeUtil::MakeShape(F32, {2, 2});
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
  HloInstruction* rhs_transposed = builder.AddInstruction(
      HloInstruction::CreateTranspose(mat2x2, rhs, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      mat2x2, lhs, rhs_transposed, dot_dnums, DefaultPrecisionConfig(2)));
  HloInstruction* one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* ones = builder.AddInstruction(
      HloInstruction::CreateBroadcast(mat2x2, one, {}));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(mat2x2, HloOpcode::kAdd, dot, ones));
  HloInstruction* difference = builder.AddInstruction(
      HloInstruction::CreateBinary(mat2x2, HloOpcode::kSubtract, dot, ones));
  builder.AddInstruction(HloInstruction::CreateBinary(
      mat2x2, HloOpcode::kMultiply, sum, difference));
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloInstruction* inner = computation->CreateFusionInstruction(
      {dot, rhs_transposed}, HloInstruction::FusionKind::kLoop);

  HloInstruction* sum_fusion = computation->CreateFusionInstruction(
      {sum, inner}, HloInstruction::FusionKind::kOutput);
  HloInstruction* difference_fusion = computation->CreateFusionInstruction(
      {difference, inner}, HloInstruction::FusionKind::kOutput);
  std::unique_ptr<HloInstruction> sum_fusion_copy = sum_fusion->Clone();
  EXPECT_TRUE(StructuralEqual(*sum_fusion, *sum_fusion_copy));
  EXPECT_FALSE(StructuralEqual(*sum_fusion, *difference_fusion));
}
1269 
TEST_F(HloInstructionTest, CloneSuffixNames) {
  // Cloning must not stack suffixes ("foo.clone.clone"). Instead, a numeric
  // counter after the suffix is introduced or incremented ("foo.clone2").
  const auto scalar_param = [](const char* name) {
    return HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}),
                                           name);
  };

  // Cloning the same instruction repeatedly.
  auto foo = scalar_param("foo");
  EXPECT_EQ(foo->Clone()->name(), "foo.clone");
  EXPECT_EQ(foo->Clone()->Clone()->name(), "foo.clone2");
  EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "foo.clone3");

  // A caller-supplied suffix follows the same rule. Switching back to the
  // default suffix starts a new one.
  EXPECT_EQ(foo->Clone("bar")->name(), "foo.bar");
  EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "foo.bar2");
  EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(), "foo.bar2.clone");

  // A dot elsewhere in the name is not mistaken for a suffix.
  auto foo_baz = scalar_param("foo.baz");
  EXPECT_EQ(foo_baz->Clone()->name(), "foo.baz.clone");

  // A large existing counter is incremented.
  auto foo_clone234 = scalar_param("foo.clone234");
  EXPECT_EQ(foo_clone234->Clone()->name(), "foo.clone235");

  // Non-numeric text after the suffix is not treated as a counter.
  auto foo_clonexyz = scalar_param("foo.clonexyz");
  EXPECT_EQ(foo_clonexyz->Clone()->name(), "foo.clonexyz.clone");

  // Only the last occurrence of the suffix is incremented.
  auto foo_clone_clone3 = scalar_param("foo.clone.clone3");
  EXPECT_EQ(foo_clone_clone3->Clone()->name(), "foo.clone.clone4");
}
1307 
TEST_F(HloInstructionTest, Stringification) {
  // Tests stringification of a simple op (dot), a while, and a conditional.
  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
  HloInstruction* transpose =
      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      sout, x, transpose, dot_dnums, DefaultPrecisionConfig(2)));

  auto options = HloPrintOptions().set_print_metadata(false);

  EXPECT_EQ(dot->ToString(options),
            "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} "
            "%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}");

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());

  // NOTE(review): the while and conditional below are added to `builder`
  // after Build() was called, so they presumably never become part of
  // `computation`. They appear to exist only to exercise ToString().
  HloInstruction* loop = builder.AddInstruction(
      HloInstruction::CreateWhile(sout, computation, computation, x));
  EXPECT_EQ(loop->ToString(options),
            "%while = f32[5,20]{1,0} while(f32[5,10]{1,0} %x), "
            "condition=%TransposeDot, body=%TransposeDot");

  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* conditional =
      builder.AddInstruction(HloInstruction::CreateConditional(
          sout, pred, x, computation, x, computation));
  EXPECT_EQ(conditional->ToString(options),
            "%conditional = f32[5,20]{1,0} conditional(pred[] %constant, "
            "f32[5,10]{1,0} %x, f32[5,10]{1,0} %x), "
            "true_computation=%TransposeDot, false_computation=%TransposeDot");
}
1353 
TEST_F(HloInstructionTest, StringifyGather_0) {
  // Gather whose index vector lives in the last dimension (4) of the start
  // indices, with no collapsed slice dims. Checks the full ToString() output.
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  const Shape indices_shape = ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
  const Shape result_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});

  HloComputation::Builder builder("Gather");
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "input_tensor"));
  HloInstruction* indices = builder.AddInstruction(
      HloInstruction::CreateParameter(1, indices_shape, "start_indices"));

  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  HloInstruction* gather = builder.AddInstruction(HloInstruction::CreateGather(
      result_shape, operand, indices, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(gather->ToString(),
            "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
            "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
            "s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), "
            "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, "
            "start_index_map={0,1,2,3,4}, "
            "index_vector_dim=4, slice_sizes={30,29,28,27,26}");
}
1388 
// Checks the default ToString() rendering of a gather whose index vector lives
// in a middle dimension of start_indices (index_vector_dim=2). The other
// start_indices dimensions {10,9,7,6} lead the result shape, followed by the
// slice dimensions {30,29,28,27,26} (offset_dims={4,5,6,7,8}).
TEST_F(HloInstructionTest, StringifyGather_1) {
  Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  // Dimension 2 (size 5) holds one start index per input dimension, matching
  // the 5-entry start_index_map below.
  Shape start_indices_tensor_shape =
      ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
  Shape gather_result_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});

  HloComputation::Builder builder("Gather");
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
  HloInstruction* start_indices =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, start_indices_tensor_shape, "start_indices"));

  HloInstruction* gather_instruction = builder.AddInstruction(
      HloInstruction::CreateGather(gather_result_shape, input, start_indices,
                                   HloGatherInstruction::MakeGatherDimNumbers(
                                       /*offset_dims=*/{4, 5, 6, 7, 8},
                                       /*collapsed_slice_dims=*/{},
                                       /*start_index_map=*/{0, 1, 2, 3, 4},
                                       /*index_vector_dim=*/2),
                                   /*slice_sizes=*/{30, 29, 28, 27, 26}));

  // The computation is placed in a module before stringifying. The expected
  // text below names the instructions %gather, %input_tensor and
  // %start_indices.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(gather_instruction->ToString(),
            "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
            "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
            "s64[10,9,5,7,6]{4,3,2,1,0} %start_indices), "
            "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, "
            "start_index_map={0,1,2,3,4}, "
            "index_vector_dim=2, slice_sizes={30,29,28,27,26}");
}
1423 
// Checks the default ToString() rendering of a scatter. The rendering includes
// its three operands, its scatter dimension numbers and the name of its update
// computation (to_apply). The result shape equals the input (operand) shape.
TEST_F(HloInstructionTest, StringifyScatter) {
  Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  // Dimension 2 (size 5) holds the index vector (index_vector_dim=2). The
  // remaining dims {10,9,7,6} are the leading dims of the updates.
  Shape scatter_indices_tensor_shape =
      ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
  Shape scatter_updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});

  HloComputation::Builder builder("Scatter");
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
  HloInstruction* scatter_indices =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, scatter_indices_tensor_shape, "scatter_indices"));
  HloInstruction* scatter_updates =
      builder.AddInstruction(HloInstruction::CreateParameter(
          2, scatter_updates_shape, "scatter_updates"));

  // Scalar update computation with two f32[] parameters. No explicit root is
  // set; presumably Build() picks the last added instruction (p2). Only the
  // computation's name matters for the expected string.
  HloComputation::Builder update_builder("Scatter.update");
  update_builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p1"));
  update_builder.AddInstruction(
      HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p2"));

  // The update computation is registered with the module first so the scatter
  // below can reference it.
  auto module = CreateNewVerifiedModule();
  auto* update_computation =
      module->AddEmbeddedComputation(update_builder.Build());

  HloInstruction* scatter_instruction =
      builder.AddInstruction(HloInstruction::CreateScatter(
          input_tensor_shape, input, scatter_indices, scatter_updates,
          update_computation,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4, 5, 6, 7, 8},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/2)));
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(
      scatter_instruction->ToString(),
      "%scatter = f32[50,49,48,47,46]{4,3,2,1,0} "
      "scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
      "s64[10,9,5,7,6]{4,3,2,1,0} %scatter_indices, "
      "f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %scatter_updates), "
      "update_window_dims={4,5,6,7,8}, inserted_window_dims={}, "
      "scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=2, "
      "to_apply=%Scatter.update");
}
1472 
// Tests canonical stringification (HloPrintOptions().Canonical()) of a simple
// op (a dot) and of the loop fusion that wraps the dot together with its
// transpose operand. Canonical printing omits instruction names. Top-level
// operands are printed by shape only, and the instructions inside the fused
// computation are printed as tmp_0, tmp_1, ...
TEST_F(HloInstructionTest, CanonnicalStringificationFusion) {
  // Builds dot(x, transpose(y)) with x: f32[5,10] and y: f32[20,10].
  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
  // Despite its name, `reshape` is a transpose of y to f32[10,20].
  HloInstruction* reshape =
      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
  // Contract dimension 1 of x with dimension 0 of transpose(y).
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));

  auto options = HloPrintOptions().Canonical();

  EXPECT_EQ(dot->ToString(options),
            "f32[5,20]{1,0} dot(f32[5,10]{1,0}, f32[10,20]{1,0}), "
            "lhs_contracting_dims={1}, rhs_contracting_dims={0}");

  // Fuse the dot and the transpose. The fusion's operands become x and y.
  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {dot, reshape}, HloInstruction::FusionKind::kLoop);

  const string expected_fusion =
      R"(f32[5,20]{1,0} fusion(f32[5,10]{1,0}, f32[20,10]{1,0}), kind=kLoop, calls=
{
  tmp_0 = f32[5,10]{1,0} parameter(0)
  tmp_1 = f32[20,10]{1,0} parameter(1)
  tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
  ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
})";
  EXPECT_EQ(fusion->ToString(options), expected_fusion);
}
1514 
// Tests canonical stringification of a while instruction. Its condition and
// body are both the TransposeDot computation after that computation's dot and
// transpose have been fused. Canonical printing inlines each called
// computation, and the fusion's own called computation is nested one level
// deeper. Instructions are renamed tmp_0, tmp_1, ... within each computation.
TEST_F(HloInstructionTest, CanonnicalStringificationWhile) {
  // Builds dot(x, transpose(y)) with x: f32[5,10] and y: f32[20,10].
  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
  // Despite its name, `reshape` is a transpose of y to f32[10,20].
  HloInstruction* reshape =
      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));

  // After fusion the computation's root is the fusion of the dot and transpose.
  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  computation->CreateFusionInstruction({dot, reshape},
                                       HloInstruction::FusionKind::kLoop);

  // NOTE(review): `builder` was already consumed by Build() above, so this
  // while is presumably not part of `computation` or the module. Only its
  // printed form is checked. Its shape (sout, f32[5,20]) also differs from its
  // init operand x (f32[5,10]). That is presumably tolerated only because the
  // instruction is never shape-checked; confirm.
  HloInstruction* loop = builder.AddInstruction(
      HloInstruction::CreateWhile(sout, computation, computation, x));

  auto options = HloPrintOptions().Canonical();
  const string expected_loop =
      R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
{
  tmp_0 = f32[5,10]{1,0} parameter(0)
  tmp_1 = f32[20,10]{1,0} parameter(1)
  ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
  {
    tmp_0 = f32[5,10]{1,0} parameter(0)
    tmp_1 = f32[20,10]{1,0} parameter(1)
    tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
    ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  }
}, body=
{
  tmp_0 = f32[5,10]{1,0} parameter(0)
  tmp_1 = f32[20,10]{1,0} parameter(1)
  ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
  {
    tmp_0 = f32[5,10]{1,0} parameter(0)
    tmp_1 = f32[20,10]{1,0} parameter(1)
    tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
    ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  }
})";
  EXPECT_EQ(loop->ToString(options), expected_loop);
}
1570 
// Tests canonical stringification of a conditional instruction. Its true and
// false branches are both the TransposeDot computation after that
// computation's dot and transpose have been fused. The operands print as
// (predicate, true operand, false operand), by shape only. Each branch
// computation is inlined, and the fusion's called computation is nested one
// level deeper.
TEST_F(HloInstructionTest, CanonnicalStringificationConditional) {
  // Builds dot(x, transpose(y)) with x: f32[5,10] and y: f32[20,10].
  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
  // Despite its name, `reshape` is a transpose of y to f32[10,20].
  HloInstruction* reshape =
      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));

  // After fusion the computation's root is the fusion of the dot and transpose.
  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  computation->CreateFusionInstruction({dot, reshape},
                                       HloInstruction::FusionKind::kLoop);

  // NOTE(review): this while instruction is never referenced afterwards. It
  // looks like a leftover from a combined stringification test; confirm before
  // removing.
  builder.AddInstruction(
      HloInstruction::CreateWhile(sout, computation, computation, x));

  // As in the while test, these are added to `builder` after Build() and are
  // only stringified, presumably never verified; confirm.
  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* conditional =
      builder.AddInstruction(HloInstruction::CreateConditional(
          sout, pred, x, computation, x, computation));
  auto options = HloPrintOptions().Canonical();
  const string expected_conditional =
      R"(f32[5,20]{1,0} conditional(pred[], f32[5,10]{1,0}, f32[5,10]{1,0}), true_computation=
{
  tmp_0 = f32[5,10]{1,0} parameter(0)
  tmp_1 = f32[20,10]{1,0} parameter(1)
  ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
  {
    tmp_0 = f32[5,10]{1,0} parameter(0)
    tmp_1 = f32[20,10]{1,0} parameter(1)
    tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
    ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  }
}, false_computation=
{
  tmp_0 = f32[5,10]{1,0} parameter(0)
  tmp_1 = f32[20,10]{1,0} parameter(1)
  ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
  {
    tmp_0 = f32[5,10]{1,0} parameter(0)
    tmp_1 = f32[20,10]{1,0} parameter(1)
    tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
    ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  }
})";
  EXPECT_EQ(conditional->ToString(options), expected_conditional);
}
1631 
// Parses a module with several computations: a while loop whose body calls a
// computation that itself uses reduce-window with a to_apply computation.
// After HloModule::Clone(), every computation and every instruction in the
// clone must point back to the cloned module.
//
// NOTE(review): in the HLO text below, the ENTRY signature names its parameter
// `param` while the parameter instruction is `eparam`. The body's call `v` also
// does not feed the body's root. Presumably the parser tolerates both, and `v`
// exists only so the body references `calla`; confirm.
TEST_F(HloInstructionTest, CheckDeepClone) {
  const char* const hlo_string = R"(
HloModule Module

addy (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT zadd = s32[] add(lhs, rhs)
}

calla (x: s32[]) -> s32[] {
  x = s32[] parameter(0)
  reduce = s32[] reduce-window(x, x), to_apply=addy
  ROOT xadd = s32[] add(x, reduce)
}

body (bparam: s32[]) -> s32[] {
  constant = s32[] constant(1)
  bparam = s32[] parameter(0)
  v = s32[] call(bparam), to_apply=calla
  ROOT add = s32[] add(constant, bparam)
}

condition (cparam: s32[]) -> pred[] {
  xconstant = s32[] constant(5)
  cparam = s32[] parameter(0)
  ROOT greater-than = pred[] compare(xconstant, cparam), direction=GT
}

ENTRY entry (param: s32[]) -> s32[] {
  eparam = s32[] parameter(0)
  ROOT while = s32[] while(eparam), condition=condition, body=body
 }
)";
  // Check that a deep clone really deep-clones every instruction and
  // computation, without leaving dangling pointers to the old module.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_string));
  std::unique_ptr<HloModule> clone = module->Clone();
  for (HloComputation* computation : clone->computations()) {
    // Each cloned computation must be owned by the clone.
    EXPECT_EQ(computation->parent(), clone.get());
    for (HloInstruction* instruction : computation->instructions()) {
      // Each instruction's computation must also belong to the clone.
      EXPECT_EQ(instruction->parent()->parent(), clone.get());
    }
  }
}
1678 
// Two structurally equal instructions (both p + p) stop comparing as
// Identical() once one of them carries a raw backend config string.
TEST_F(HloInstructionTest, IdenticalAccountsForBackendConfig) {
  HloComputation::Builder builder("test");
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {42});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "p"));

  // Produces a fresh `param + param` each call, so the two results start out
  // structurally identical.
  auto add_param_to_itself = [&]() {
    return builder.AddInstruction(
        HloInstruction::CreateBinary(vec_shape, HloOpcode::kAdd, param, param));
  };
  HloInstruction* first = add_param_to_itself();
  HloInstruction* second = add_param_to_itself();

  EXPECT_TRUE(first->Identical(*second));
  first->set_raw_backend_config_string("abc");
  EXPECT_FALSE(first->Identical(*second));
}
1694 
// A custom-call and its clone are Identical() until a window is attached to
// one of them.
TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallWindow) {
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  auto original = HloInstruction::CreateCustomCall(
      scalar, /*operands=*/{}, /*custom_call_target=*/"foo");
  auto copy = original->Clone();
  EXPECT_TRUE(original->Identical(*copy));

  Window window = window_util::MakeWindow({1, 2, 3});
  original->set_window(window);
  EXPECT_FALSE(original->Identical(*copy));
}
1706 
// A custom-call and its clone are Identical() until convolution dimension
// numbers are attached to one of them.
TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallDnums) {
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  auto original = HloInstruction::CreateCustomCall(
      scalar, /*operands=*/{}, /*custom_call_target=*/"foo");
  auto copy = original->Clone();
  EXPECT_TRUE(original->Identical(*copy));

  ConvolutionDimensionNumbers conv_dnums;
  conv_dnums.set_output_batch_dimension(42);
  original->set_convolution_dimension_numbers(conv_dnums);
  EXPECT_FALSE(original->Identical(*copy));
}
1719 
// Cloning a custom-call must carry over the window set on it.
TEST_F(HloInstructionTest, CloneWindowOnCustomCall) {
  Window expected_window = window_util::MakeWindow({1, 2, 3});
  auto custom_call = HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {}), /*operands=*/{},
      /*custom_call_target=*/"foo");
  custom_call->set_window(expected_window);

  const auto cloned = custom_call->Clone();
  const Window& actual_window = cloned->window();
  EXPECT_TRUE(protobuf_util::ProtobufEquals(actual_window, expected_window))
      << actual_window.DebugString();
}
1730 
// Cloning a custom-call must carry over its convolution dimension numbers.
TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) {
  ConvolutionDimensionNumbers expected_dnums;
  expected_dnums.set_output_batch_dimension(42);

  auto custom_call = HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {}), /*operands=*/{},
      /*custom_call_target=*/"foo");
  custom_call->set_convolution_dimension_numbers(expected_dnums);

  const auto cloned = custom_call->Clone();
  const ConvolutionDimensionNumbers& actual_dnums =
      cloned->convolution_dimension_numbers();
  EXPECT_TRUE(protobuf_util::ProtobufEquals(actual_dnums, expected_dnums))
      << actual_dnums.DebugString();
}
1743 
// Cloning a convolution must keep the per-operand precision parsed from
// operand_precision={high,default}.
TEST_F(HloInstructionTest, PreserveOperandPrecisionOnCloneConv) {
  constexpr char kConvHlo[] = R"(
  HloModule test_module
  ENTRY test {
    arg0 = f32[1,2,1] parameter(0)
    arg1 = f32[1,1,1] parameter(1)
    ROOT conv = f32[1,2,1] convolution(arg0, arg1), window={size=1},
      dim_labels=b0f_0io->b0f, operand_precision={high,default}
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module, ParseHloString(kConvHlo));
  const HloInstruction* root =
      parsed_module->entry_computation()->root_instruction();

  const auto cloned_conv = root->Clone();
  EXPECT_THAT(cloned_conv->precision_config().operand_precision(),
              ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::DEFAULT));
}
1761 
1762 }  // namespace
1763 }  // namespace xla
1764