1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
17
18 #include <set>
19 #include <unordered_map>
20 #include <utility>
21 #include <vector>
22
23 #include "absl/container/flat_hash_map.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/protobuf_util.h"
26 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/hlo_parser.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/test.h"
32 #include "tensorflow/compiler/xla/test_helpers.h"
33 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
34 #include "tensorflow/compiler/xla/util.h"
35 #include "tensorflow/compiler/xla/window_util.h"
36
37 namespace xla {
38 namespace {
39
40 using ::testing::ElementsAre;
41 using ::testing::UnorderedElementsAre;
42
43 class HloInstructionTest : public HloTestBase {
44 protected:
45 Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
46 };
47
48 // Simple visitor that collects the number of users and operands for certain HLO
49 // nodes. It also verifies some of the DFS visiting invariants (operands visited
50 // before their users, nodes not visited twice, etc.)
51 class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault {
52 public:
DefaultAction(HloInstruction * hlo_instruction)53 Status DefaultAction(HloInstruction* hlo_instruction) override {
54 return Unimplemented("not implemented %s",
55 HloOpcodeString(hlo_instruction->opcode()));
56 }
57
HandleParameter(HloInstruction * parameter)58 Status HandleParameter(HloInstruction* parameter) override {
59 EXPECT_FALSE(count_.contains(parameter));
60 count_[parameter] = GetCountsForNode(parameter);
61 return Status::OK();
62 }
63
HandleConstant(HloInstruction * constant)64 Status HandleConstant(HloInstruction* constant) override {
65 EXPECT_FALSE(count_.contains(constant));
66 count_[constant] = GetCountsForNode(constant);
67 return Status::OK();
68 }
69
HandleAdd(HloInstruction * add)70 Status HandleAdd(HloInstruction* add) override {
71 auto lhs = add->operand(0);
72 auto rhs = add->operand(1);
73 EXPECT_FALSE(count_.contains(add));
74 EXPECT_TRUE(count_.contains(lhs));
75 EXPECT_TRUE(count_.contains(rhs));
76 count_[add] = GetCountsForNode(add);
77 return Status::OK();
78 }
79
HandleNegate(HloInstruction * negate)80 Status HandleNegate(HloInstruction* negate) override {
81 auto operand = negate->operand(0);
82 EXPECT_FALSE(count_.contains(negate));
83 EXPECT_TRUE(count_.contains(operand));
84 count_[negate] = GetCountsForNode(negate);
85 return Status::OK();
86 }
87
HandleMap(HloInstruction * map)88 Status HandleMap(HloInstruction* map) override {
89 EXPECT_FALSE(count_.contains(map));
90 for (HloInstruction* arg : map->operands()) {
91 EXPECT_TRUE(count_.contains(arg));
92 }
93 count_[map] = GetCountsForNode(map);
94 return Status::OK();
95 }
96
HandleReduce(HloInstruction * reduce)97 Status HandleReduce(HloInstruction* reduce) override {
98 auto arg = reduce->operand(0);
99 auto init_value = reduce->operand(1);
100 EXPECT_FALSE(count_.contains(reduce));
101 EXPECT_TRUE(count_.contains(arg));
102 EXPECT_TRUE(count_.contains(init_value));
103 count_[reduce] = GetCountsForNode(reduce);
104 return Status::OK();
105 }
106
NumOperands(const HloInstruction * node)107 int64 NumOperands(const HloInstruction* node) {
108 auto count_iterator = count_.find(node);
109 EXPECT_NE(count_.end(), count_iterator);
110 return count_iterator->second.operand_count;
111 }
112
NumUsers(const HloInstruction * node)113 int64 NumUsers(const HloInstruction* node) {
114 auto count_iterator = count_.find(node);
115 EXPECT_NE(count_.end(), count_iterator);
116 return count_iterator->second.user_count;
117 }
118
119 private:
120 struct NumOpsAndUsers {
121 int64 operand_count;
122 int64 user_count;
123 };
124
125 // Helper function to count operands and users for the given HLO.
GetCountsForNode(const HloInstruction * node)126 NumOpsAndUsers GetCountsForNode(const HloInstruction* node) {
127 NumOpsAndUsers counts{node->operand_count(), node->user_count()};
128 return counts;
129 }
130
131 // Counters for HLOs. Maps HLO to a NumOpsAndUsers.
132 absl::flat_hash_map<const HloInstruction*, NumOpsAndUsers> count_;
133 };
134
TEST_F(HloInstructionTest, BasicProperties) {
  // A standalone scalar F32 parameter (not added to any computation).
  auto parameter = HloInstruction::CreateParameter(1, r0f32_, "foo");

  EXPECT_EQ(HloOpcode::kParameter, parameter->opcode());
  EXPECT_TRUE(ShapeUtil::IsScalarWithElementType(parameter->shape(), F32));
  EXPECT_FALSE(ShapeUtil::IsScalarWithElementType(parameter->shape(), S32));
  // Parameters take no operands. Compare against 0 explicitly instead of
  // relying on an integer-to-bool conversion inside EXPECT_FALSE.
  EXPECT_EQ(0, parameter->operand_count());
}
143
TEST_F(HloInstructionTest, UserWithTwoOperands) {
  // Checks the operand/user edges of add(foo, bar), both directly and as seen
  // by a DFS visitor rooted at the add:
  //
  //   [Param foo]-----> |-----|
  //                     | Add |
  //   [Param bar]-----> |-----|
  HloComputation::Builder b(TestName());
  HloInstruction* lhs =
      b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* rhs =
      b.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* sum = b.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, lhs, rhs));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  // Direct graph queries.
  EXPECT_THAT(sum->operands(), UnorderedElementsAre(lhs, rhs));
  EXPECT_THAT(lhs->users(), UnorderedElementsAre(sum));
  EXPECT_THAT(rhs->users(), UnorderedElementsAre(sum));

  // The same information, as collected during a DFS from the root.
  OpAndUserCollectingVisitor collector;
  ASSERT_IS_OK(sum->Accept(&collector));

  EXPECT_EQ(2, collector.NumOperands(sum));
  EXPECT_EQ(0, collector.NumUsers(sum));
  EXPECT_EQ(1, collector.NumUsers(lhs));
  EXPECT_EQ(1, collector.NumUsers(rhs));
}
170
TEST_F(HloInstructionTest, MultipleUsers) {
  // One parameter feeding three users; a second parameter feeding only the
  // add:
  //
  //        +---------- [Param foo] ----------+
  //        |                |                |      [Param bar]
  //        V                V                V           |
  //     -------          -------          -------        |
  //     | exp |          | exp |          | add | <------+
  //     -------          -------          -------
  HloComputation::Builder b(TestName());
  HloInstruction* shared =
      b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* other =
      b.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* first_exp = b.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, shared));
  HloInstruction* second_exp = b.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, shared));
  HloInstruction* sum = b.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, shared, other));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  // Parameters have users; the three leaves have none.
  EXPECT_EQ(3, shared->user_count());
  EXPECT_EQ(1, other->user_count());
  EXPECT_EQ(0, first_exp->user_count());
  EXPECT_EQ(0, second_exp->user_count());
  EXPECT_EQ(0, sum->user_count());

  // A DFS from the add still sees all three users of the shared parameter.
  OpAndUserCollectingVisitor collector;
  ASSERT_IS_OK(sum->Accept(&collector));

  EXPECT_EQ(2, collector.NumOperands(sum));
  EXPECT_EQ(3, collector.NumUsers(shared));
}
207
TEST_F(HloInstructionTest, RepeatedUser) {
  // An 'add' that uses the same HLO as both of its operands must be counted
  // as a single user of that HLO:
  //
  //        [Param foo]
  //          |     |
  //          V     V
  //         -------
  //         | add |
  //         -------
  HloComputation::Builder b(TestName());
  HloInstruction* param =
      b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* doubled = b.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, param));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  EXPECT_EQ(1, param->user_count());

  // The add nevertheless keeps two operand slots, both pointing at `param`.
  EXPECT_EQ(2, doubled->operand_count());
}
233
TEST_F(HloInstructionTest, MultipleUsersAndOperands) {
  // Graph under test; the constant c0 is shared by both inner adds:
  //
  //   addleft  = add(param0, c0)
  //   addright = add(c0, param1)
  //   addtotal = add(addleft, addright)
  HloComputation::Builder b(TestName());
  HloInstruction* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "param0"));
  HloInstruction* p1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "param1"));
  HloInstruction* shared_const = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* left = b.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, p0, shared_const));
  HloInstruction* right = b.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, shared_const, p1));
  HloInstruction* total = b.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, left, right));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  OpAndUserCollectingVisitor collector;
  ASSERT_IS_OK(total->Accept(&collector));

  // The shared constant has two users; every add has exactly two operands.
  EXPECT_EQ(2, collector.NumUsers(shared_const));
  EXPECT_EQ(2, collector.NumOperands(left));
  EXPECT_EQ(2, collector.NumOperands(right));
  EXPECT_EQ(2, collector.NumOperands(total));
}
271
TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) {
  // Same shape as MultipleUsersAndOperands, with negations mixed in:
  //
  //   neg1     = negate(c0)
  //   addleft  = add(param0, neg1)
  //   addright = add(neg1, param1)
  //   addtotal = add(addleft, addright)
  //   neg2     = negate(addtotal)
  HloComputation::Builder b(TestName());
  HloInstruction* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "param0"));
  HloInstruction* p1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "param1"));
  HloInstruction* literal = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* inner_neg = b.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, literal));
  HloInstruction* left = b.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, p0, inner_neg));
  HloInstruction* right = b.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, inner_neg, p1));
  HloInstruction* total = b.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, left, right));
  HloInstruction* outer_neg = b.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, total));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  OpAndUserCollectingVisitor collector;
  ASSERT_IS_OK(outer_neg->Accept(&collector));

  // The constant now has one user (the inner negate), which has two.
  EXPECT_EQ(1, collector.NumUsers(literal));
  EXPECT_EQ(2, collector.NumUsers(inner_neg));
  EXPECT_EQ(2, collector.NumOperands(left));
  EXPECT_EQ(2, collector.NumOperands(right));
  EXPECT_EQ(2, collector.NumOperands(total));
  // The root negate has one operand and no users.
  EXPECT_EQ(1, collector.NumOperands(outer_neg));
  EXPECT_EQ(0, collector.NumUsers(outer_neg));
}
323
TEST_F(HloInstructionTest, TrivialMap) {
  // A single map applying x+1 elementwise to a 100x10 parameter:
  //
  //   param0[100x10] ---> (map x+1)
  Shape scalar = ShapeUtil::MakeShape(F32, {});
  Shape matrix = ShapeUtil::MakeShape(F32, {100, 10});
  auto module = CreateNewVerifiedModule();

  // The mapped computation: x + 1.0.
  auto plus_one_builder = HloComputation::Builder("f32+1");
  HloInstruction* x = plus_one_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "x"));
  HloInstruction* one = plus_one_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  plus_one_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, x, one));
  HloComputation* plus_one =
      module->AddEmbeddedComputation(plus_one_builder.Build());

  // The entry computation: map(p) using the computation above.
  HloComputation::Builder b(TestName());
  HloInstruction* input =
      b.AddInstruction(HloInstruction::CreateParameter(0, matrix, "p"));
  HloInstruction* mapped =
      b.AddInstruction(HloInstruction::CreateMap(matrix, {input}, plus_one));
  module->AddEntryComputation(b.Build());

  OpAndUserCollectingVisitor collector;
  ASSERT_IS_OK(mapped->Accept(&collector));

  // Only the entry computation is walked; the mapped computation is not.
  EXPECT_EQ(1, collector.NumUsers(input));
  EXPECT_EQ(0, collector.NumUsers(mapped));
  EXPECT_EQ(1, collector.NumOperands(mapped));

  // TODO(dehnert): Add walking and counters for the wrapped computation.
}
361
TEST_F(HloInstructionTest, TrivialReduce) {
  // A single x+y reduce over dimension 1 of a 100x10 parameter:
  //
  //   param0[100x10] ---> (reduce x+y)
  Shape scalar = ShapeUtil::MakeShape(F32, {});
  Shape vector100 = ShapeUtil::MakeShape(F32, {100});
  Shape matrix = ShapeUtil::MakeShape(F32, {100, 10});

  // The reducer computation: x + y.
  auto adder_builder = HloComputation::Builder("f32+f32");
  HloInstruction* x = adder_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "x"));
  HloInstruction* y = adder_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar, "y"));
  adder_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, x, y));
  auto module = CreateNewVerifiedModule();
  HloComputation* adder = module->AddEmbeddedComputation(adder_builder.Build());

  // The entry computation: reduce(p, 0.0) along dimension 1.
  HloComputation::Builder b(TestName());
  HloInstruction* input =
      b.AddInstruction(HloInstruction::CreateParameter(0, matrix, "p"));
  HloInstruction* init = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  // An extra constant that is not an operand of the reduce.
  b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* reduced = b.AddInstruction(HloInstruction::CreateReduce(
      vector100, input, init, /*dimensions_to_reduce=*/{1}, adder));
  module->AddEntryComputation(b.Build());

  OpAndUserCollectingVisitor collector;
  ASSERT_IS_OK(reduced->Accept(&collector));

  // Only the entry computation is walked; the reducer is not.
  EXPECT_EQ(1, collector.NumUsers(input));
  EXPECT_EQ(1, collector.NumUsers(init));
  EXPECT_EQ(0, collector.NumUsers(reduced));
  EXPECT_EQ(2, collector.NumOperands(reduced));
}
404
TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) {
  // Construct a graph of a few binary ops using two different
  // parameters. Replace one of the parameters with the other parameter in one
  // of the instructions.
  HloComputation::Builder builder(TestName());
  auto foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  auto bar =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
  auto add_foobar = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
  auto add_foofoo = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo));
  auto root = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, add_foobar, add_foofoo));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // add_foofoo uses foo twice but counts as a single user.
  EXPECT_EQ(2, foo->user_count());
  EXPECT_EQ(1, bar->user_count());

  // Replace the use of foo in add_foofoo with bar.
  ASSERT_IS_OK(foo->ReplaceUseWith(add_foofoo, bar));

  EXPECT_EQ(1, foo->user_count());
  EXPECT_EQ(2, bar->user_count());

  // add_foobar is untouched.
  EXPECT_THAT(foo->users(), UnorderedElementsAre(add_foobar));
  EXPECT_THAT(add_foobar->operands(), ElementsAre(foo, bar));

  // Both operand slots of add_foofoo now point at bar.
  EXPECT_THAT(bar->users(), UnorderedElementsAre(add_foobar, add_foofoo));
  EXPECT_THAT(add_foofoo->operands(), ElementsAre(bar, bar));

  // The replacement does not rewire the root's operands.
  EXPECT_THAT(root->operands(), ElementsAre(add_foobar, add_foofoo));
}
439
TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) {
  // A tuple that uses `foo` in two slots; replacing foo's use in the tuple
  // must rewrite both slots while leaving the add alone.
  HloComputation::Builder b(TestName());
  HloInstruction* p_foo =
      b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* p_bar =
      b.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* p_baz =
      b.AddInstruction(HloInstruction::CreateParameter(2, r0f32_, "baz"));

  HloInstruction* tup =
      b.AddInstruction(HloInstruction::CreateTuple({p_foo, p_bar, p_baz, p_foo}));
  HloInstruction* sum = b.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, p_foo, p_bar));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  EXPECT_EQ(2, p_foo->user_count());
  EXPECT_THAT(p_foo->users(), UnorderedElementsAre(tup, sum));

  ASSERT_IS_OK(p_foo->ReplaceUseWith(tup, p_bar));

  // Only the add still uses foo ...
  EXPECT_THAT(p_foo->users(), UnorderedElementsAre(sum));
  // ... and both of foo's tuple slots were redirected to bar.
  EXPECT_THAT(tup->operands(), ElementsAre(p_bar, p_bar, p_baz, p_bar));
}
469
TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) {
  // Construct a couple unary instructions which use a parameter. Replace the
  // use of a parameter in one of the unary ops with the other parameter.
  HloComputation::Builder builder(TestName());
  auto foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  auto bar =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));

  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(2, foo->user_count());
  EXPECT_THAT(foo->users(), UnorderedElementsAre(exp, log));
  EXPECT_EQ(0, bar->user_count());

  // Replace the use of foo in exp with bar.
  ASSERT_IS_OK(foo->ReplaceUseWith(exp, bar));

  // The use of foo in log should not have been affected.
  EXPECT_EQ(1, foo->user_count());
  EXPECT_THAT(foo->users(), UnorderedElementsAre(log));
  EXPECT_THAT(log->operands(), ElementsAre(foo));

  // Bar should now be the sole operand of exp, and exp its sole user. Matchers
  // check size and contents together, avoiding a signed/unsigned size()
  // comparison and a dereference of begin() on a possibly empty container.
  EXPECT_THAT(bar->users(), ElementsAre(exp));
  EXPECT_THAT(exp->operands(), ElementsAre(bar));
}
504
TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) {
  // Two parameters feeding a small tree of adds. After ReplaceAllUsesWith,
  // every former user of `foo` must use `bar` instead.
  HloComputation::Builder b(TestName());
  HloInstruction* from =
      b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* to =
      b.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* mixed = b.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, from, to));
  HloInstruction* twice = b.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, from, from));
  b.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, mixed, twice));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  EXPECT_EQ(2, from->user_count());
  EXPECT_EQ(1, to->user_count());

  ASSERT_IS_OK(from->ReplaceAllUsesWith(to));

  // `foo` is now dead and `bar` has picked up its users.
  EXPECT_EQ(0, from->user_count());
  EXPECT_EQ(2, to->user_count());
  EXPECT_THAT(to->users(), UnorderedElementsAre(mixed, twice));
}
534
TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) {
  // `foo` is used by a binary op, a unary op and a variadic op; after
  // ReplaceAllUsesWith(bar) all three must be users of `bar`.
  HloComputation::Builder b(TestName());
  HloInstruction* from =
      b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* to =
      b.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));

  HloInstruction* binary_user = b.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, from, to));
  HloInstruction* unary_user = b.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, from));
  HloInstruction* variadic_user =
      b.AddInstruction(HloInstruction::CreateTuple({from, to}));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  EXPECT_EQ(3, from->user_count());
  EXPECT_EQ(2, to->user_count());

  ASSERT_IS_OK(from->ReplaceAllUsesWith(to));

  EXPECT_EQ(0, from->user_count());
  EXPECT_EQ(3, to->user_count());
  EXPECT_THAT(to->users(),
              UnorderedElementsAre(binary_user, unary_user, variadic_user));
}
564
565 // Simple visitor that collects and post-processes each node in the graph.
566 class NodeCollectorAndPostProcessor : public DfsHloVisitorWithDefault {
567 public:
NodeCollectorAndPostProcessor()568 NodeCollectorAndPostProcessor() {}
569
Postprocess(HloInstruction * hlo)570 Status Postprocess(HloInstruction* hlo) override {
571 post_processed_nodes_.push_back(hlo);
572 return Status::OK();
573 }
574
DefaultAction(HloInstruction * hlo_instruction)575 Status DefaultAction(HloInstruction* hlo_instruction) override {
576 visited_nodes_.push_back(hlo_instruction);
577 return Status::OK();
578 }
579
visited_nodes()580 const std::vector<const HloInstruction*>& visited_nodes() {
581 return visited_nodes_;
582 }
583
post_processed_nodes()584 const std::vector<const HloInstruction*>& post_processed_nodes() {
585 return post_processed_nodes_;
586 }
587
588 private:
589 std::vector<const HloInstruction*> visited_nodes_;
590 std::vector<const HloInstruction*> post_processed_nodes_;
591 };
592
593 // Returns true if "vec" contains distinct nodes.
Distinct(const std::vector<const HloInstruction * > & vec)594 bool Distinct(const std::vector<const HloInstruction*>& vec) {
595 std::set<const HloInstruction*> distinct_nodes(vec.begin(), vec.end());
596 return distinct_nodes.size() == vec.size();
597 }
598
TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) {
  // Checks that every node is visited exactly once and that nodes are
  // post-processed in the same order they are visited. Graph (a diamond):
  //
  //          +--> exp --+
  //   foo ---+          +---> add
  //          +--> log --+
  HloComputation::Builder b(TestName());
  HloInstruction* source =
      b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* upper = b.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, source));
  HloInstruction* lower = b.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, source));
  HloInstruction* sink = b.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, upper, lower));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  NodeCollectorAndPostProcessor recorder;
  ASSERT_IS_OK(sink->Accept(&recorder));
  // Visit order and post-process order must match exactly.
  EXPECT_EQ(recorder.visited_nodes(), recorder.post_processed_nodes());
  // `foo` is reachable along two paths but must be visited only once.
  EXPECT_TRUE(Distinct(recorder.visited_nodes()));
}
625
TEST_F(HloInstructionTest, SingletonFusionOp) {
  // Fusing a lone exp(constant) must leave the constant outside the fusion,
  // as its only operand and with the fusion as its only user.
  HloComputation::Builder b(TestName());
  HloInstruction* literal = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* unary = b.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, literal));
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());
  HloInstruction* fused = entry->CreateFusionInstruction(
      {unary}, HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(fused->operands(), ElementsAre(literal));
  EXPECT_THAT(literal->users(), ElementsAre(fused));
}
641
TEST_F(HloInstructionTest, BinaryFusionOp) {
  // Fusing a lone add(c1, c2) must make both constants operands of the fusion,
  // in their original order, with the fusion as their only user.
  HloComputation::Builder b(TestName());
  HloInstruction* lhs = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* rhs = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.1f)));
  HloInstruction* sum = b.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, lhs, rhs));
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());
  HloInstruction* fused = entry->CreateFusionInstruction(
      {sum}, HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(fused->operands(), ElementsAre(lhs, rhs));
  EXPECT_THAT(lhs->users(), ElementsAre(fused));
  EXPECT_THAT(rhs->users(), ElementsAre(fused));
}
660
TEST_F(HloInstructionTest, ChainFusionOp) {
  // Fusing the chain exp(exp(exp(constant))) (passed root first) must leave
  // only the constant outside the fusion.
  HloComputation::Builder b(TestName());
  HloInstruction* literal = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* e1 = b.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, literal));
  HloInstruction* e2 = b.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, e1));
  HloInstruction* e3 = b.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, e2));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());
  HloInstruction* fused = entry->CreateFusionInstruction(
      {e3, e2, e1}, HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(fused->operands(), ElementsAre(literal));
  EXPECT_THAT(literal->users(), ElementsAre(fused));
}
681
TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
  // Checks that op metadata survives both fusion and cloning of the fusion.
  HloComputation::Builder builder(TestName());
  // Create a chain of fused unary ops.
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  auto exp1 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
  auto exp2 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1));
  OpMetadata metadata;
  metadata.set_op_name("tf_op");
  exp1->set_metadata(metadata);
  exp2->set_metadata(metadata);

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  auto* fusion = computation->CreateFusionInstruction(
      {exp2, exp1}, HloInstruction::FusionKind::kLoop);

  // The fusion and the instructions fused into it carry the metadata.
  EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata()));
  EXPECT_TRUE(protobuf_util::ProtobufEquals(
      metadata, fusion->fused_expression_root()->metadata()));
  EXPECT_TRUE(protobuf_util::ProtobufEquals(
      metadata, fusion->fused_expression_root()->operand(0)->metadata()));

  // The clone must carry it too. The original test re-checked `fusion` here,
  // so the clone was never actually inspected.
  auto cloned = fusion->CloneWithNewOperands(fusion->shape(), {});
  EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, cloned->metadata()));
  EXPECT_TRUE(protobuf_util::ProtobufEquals(
      metadata, cloned->fused_expression_root()->metadata()));
  EXPECT_TRUE(protobuf_util::ProtobufEquals(
      metadata, cloned->fused_expression_root()->operand(0)->metadata()));
}
710
TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) {
  // Two outfeeds of the same 2x2 constant that differ only in the layout of
  // their outfeed shape. Cloning each must keep its own outfeed shape.
  HloComputation::Builder b(TestName());
  HloInstruction* matrix = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
          {1, 2},
          {3, 4},
      })));
  auto row_major = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
  auto col_major = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
  HloInstruction* tok = b.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* outfeed_row_major = b.AddInstruction(
      HloInstruction::CreateOutfeed(row_major, matrix, tok, ""));
  HloInstruction* outfeed_col_major = b.AddInstruction(
      HloInstruction::CreateOutfeed(col_major, matrix, tok, ""));

  HloInstruction* col_major_copy =
      b.AddInstruction(outfeed_col_major->Clone());
  HloInstruction* row_major_copy =
      b.AddInstruction(outfeed_row_major->Clone());

  EXPECT_TRUE(ShapeUtil::Equal(col_major_copy->outfeed_shape(), col_major));
  EXPECT_TRUE(ShapeUtil::Equal(row_major_copy->outfeed_shape(), row_major));
}
732
TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) {
  // A tuple whose two elements are given different layouts after creation;
  // the clone must reproduce the tuple shape, layouts included.
  HloComputation::Builder b(TestName());
  HloInstruction* matrix = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
          {1, 2},
          {3, 4},
      })));
  HloInstruction* pair =
      b.AddInstruction(HloInstruction::CreateTuple({matrix, matrix}));
  Shape* first = ShapeUtil::GetMutableSubshape(pair->mutable_shape(), {0});
  *first->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
  Shape* second = ShapeUtil::GetMutableSubshape(pair->mutable_shape(), {1});
  *second->mutable_layout() = LayoutUtil::MakeLayout({1, 0});

  auto copy = pair->Clone();
  EXPECT_TRUE(ShapeUtil::Equal(copy->shape(), pair->shape()));
}
749
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
  // Builds a chain of three maps over a constant: two call `computation_x`
  // and one calls `computation_y`. It then fuses them one at a time, starting
  // from the root. After each step the fusion's only called computation must
  // be its own fused computation.
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  auto module = CreateNewVerifiedModule();

  // Each call adds a fresh computation whose root is its single parameter.
  auto add_identity_computation = [&]() {
    auto identity_builder = HloComputation::Builder("FusionMap");
    identity_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "param"));
    return module->AddEmbeddedComputation(identity_builder.Build());
  };

  HloComputation* computation_x = add_identity_computation();
  HloComputation* computation_y = add_identity_computation();

  HloComputation::Builder b(TestName());
  HloInstruction* literal = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* first_map = b.AddInstruction(
      HloInstruction::CreateMap(scalar_shape, {literal}, computation_x));
  HloInstruction* second_map = b.AddInstruction(
      HloInstruction::CreateMap(scalar_shape, {first_map}, computation_x));
  HloInstruction* third_map = b.AddInstruction(
      HloInstruction::CreateMap(scalar_shape, {second_map}, computation_y));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  HloInstruction* fused = entry->CreateFusionInstruction(
      {third_map}, HloInstruction::FusionKind::kLoop);
  HloComputation* fused_body = fused->fused_instructions_computation();
  EXPECT_THAT(fused->called_computations(), ElementsAre(fused_body));

  fused->FuseInstruction(second_map);
  EXPECT_THAT(fused->called_computations(), ElementsAre(fused_body));

  fused->FuseInstruction(first_map);
  EXPECT_THAT(fused->called_computations(), ElementsAre(fused_body));
}
787
TEST_F(HloInstructionTest, ComplexFusionOp) {
  HloComputation::Builder builder(TestName());
  // Fuse all instructions in complicated expression:
  //
  //   add = Add(C1, C2)
  //   clamp = Clamp(C2, add, add)
  //   exp = Exp(add)
  //   mul = Mul(exp, C3)
  //   sub = Sub(mul, clamp)
  //   tuple = Tuple({sub, sub, mul, C1})
  //
  // Notable complexities:
  //   - repeated operands in the same instruction (clamp uses add twice,
  //     tuple uses sub twice);
  //   - different shapes (scalar ops vs. the tuple);
  //   - one value used in different expressions (add feeds clamp and exp,
  //     mul feeds sub and tuple, C1 feeds add and tuple).
  auto c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  auto c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.1f)));
  auto c3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(9.0f)));

  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2));
  auto clamp = builder.AddInstruction(
      HloInstruction::CreateTernary(r0f32_, HloOpcode::kClamp, c2, add, add));
  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, add));
  auto mul = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, exp, c3));
  auto sub = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kSubtract, mul, clamp));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1}));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  // Fuse root-first. The order of this list determines the order of the
  // fusion's operands checked below.
  auto* fusion = computation->CreateFusionInstruction(
      {tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);

  // Operands in the fusion instruction's operands() vector should be in the
  // order in which their users were fused. With the fusion order above, C1 is
  // first needed by tuple, C3 by mul, and C2 by clamp.
  EXPECT_THAT(fusion->operands(), ElementsAre(c1, c3, c2));
  // C1 was used by both add and tuple; after fusion its only user is the
  // fusion instruction.
  EXPECT_THAT(c1->users(), ElementsAre(fusion));
}
831
832 // Convenience function for comparing two HloInstructions.
Identical(const HloInstruction & instruction1,const HloInstruction & instruction2)833 static bool Identical(const HloInstruction& instruction1,
834 const HloInstruction& instruction2) {
835 // Verify Identical is reflexive for both instructions.
836 EXPECT_TRUE(instruction1.Identical(instruction1));
837 EXPECT_TRUE(instruction2.Identical(instruction2));
838
839 bool is_equal = instruction1.Identical(instruction2);
840 // Verify Identical is symmetric.
841 EXPECT_EQ(is_equal, instruction2.Identical(instruction1));
842 return is_equal;
843 }
844
845 // Convenience function for comparing two HloInstructions for structural
846 // equality.
StructuralEqual(const HloInstruction & instruction1,const HloInstruction & instruction2)847 static bool StructuralEqual(const HloInstruction& instruction1,
848 const HloInstruction& instruction2) {
849 auto eq_operand_shapes = [](const HloInstruction* a,
850 const HloInstruction* b) {
851 return ShapeUtil::Equal(a->shape(), b->shape());
852 };
853 auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
854 return *a == *b;
855 };
856
857 // Verify Identical is reflexive for both instructions.
858 EXPECT_TRUE(
859 instruction1.Identical(instruction1, eq_operand_shapes, eq_computations));
860 EXPECT_TRUE(
861 instruction2.Identical(instruction2, eq_operand_shapes, eq_computations));
862
863 bool is_equal =
864 instruction1.Identical(instruction2, eq_operand_shapes, eq_computations);
865 // Verify Identical is symmetric.
866 EXPECT_EQ(is_equal, instruction2.Identical(instruction1, eq_operand_shapes,
867 eq_computations));
868 return is_equal;
869 }
870
TEST_F(HloInstructionTest, IdenticalInstructions) {
  // Test HloInstruction::Identical with some subset of instruction types.

  // Create two fixed constant operands to use below. Make them matrices so
  // dimensions are interesting.
  auto operand1 = HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
  auto operand2 = HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{10.0, 20.0}, {30.0, 40.0}}));
  const Shape shape = operand1->shape();

  // Convenient short names for the operands.
  HloInstruction* op1 = operand1.get();
  HloInstruction* op2 = operand2.get();

  // Operations which only depend on their operands and opcode.
  EXPECT_TRUE(
      Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
                *HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1)));
  EXPECT_FALSE(
      Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
                *HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op2)));
  EXPECT_FALSE(
      Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
                *HloInstruction::CreateUnary(shape, HloOpcode::kNegate, op1)));

  // Tuples: operand order matters.
  EXPECT_TRUE(Identical(*HloInstruction::CreateTuple({op1, op2}),
                        *HloInstruction::CreateTuple({op1, op2})));
  EXPECT_FALSE(Identical(*HloInstruction::CreateTuple({op1, op2}),
                         *HloInstruction::CreateTuple({op2, op1})));

  // Broadcasts: the broadcast dimensions and the output shape both matter.
  EXPECT_TRUE(Identical(*HloInstruction::CreateBroadcast(shape, op1, {0, 1}),
                        *HloInstruction::CreateBroadcast(shape, op1, {0, 1})));
  EXPECT_FALSE(Identical(*HloInstruction::CreateBroadcast(shape, op1, {0, 1}),
                         *HloInstruction::CreateBroadcast(shape, op1, {1, 0})));
  Shape bcast_shape1 = ShapeUtil::MakeShape(F32, {2, 2, 42});
  Shape bcast_shape2 = ShapeUtil::MakeShape(F32, {2, 2, 123});
  EXPECT_FALSE(
      Identical(*HloInstruction::CreateBroadcast(bcast_shape1, op1, {0, 1}),
                *HloInstruction::CreateBroadcast(bcast_shape2, op1, {0, 1})));

  // Binary operations: opcode and operand order both matter.
  EXPECT_TRUE(Identical(
      *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
      *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2)));
  EXPECT_FALSE(Identical(
      *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
      *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op2, op1)));
  EXPECT_FALSE(Identical(
      *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
      *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op1, op2)));
}
927
TEST_F(HloInstructionTest, IdenticalCallInstructions) {
  // Two calls to the same computation (subcomp1) are structurally equal. A
  // call to a different computation (subcomp2: cosine instead of sine) is not.
  const char* const hlo_string = R"(
HloModule Module

subcomp1 (x: f32[]) -> f32[] {
  x = f32[] parameter(0)
  ROOT n = f32[] sine(x)
}

subcomp2 (x: f32[]) -> f32[] {
  x = f32[] parameter(0)
  ROOT n = f32[] cosine(x)
}

ENTRY entry (param: f32[]) -> (f32[], f32[], f32[]) {
  p = f32[] parameter(0)
  t1 = f32[] call(p), to_apply=subcomp1
  t2 = f32[] call(p), to_apply=subcomp1
  t3 = f32[] call(p), to_apply=subcomp2
  ROOT t = (f32[], f32[], f32[]) tuple(t1, t2, t3)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_string));

  const HloInstruction* result_tuple =
      module->entry_computation()->root_instruction();
  const HloInstruction& call_sine_a = *result_tuple->operand(0);
  const HloInstruction& call_sine_b = *result_tuple->operand(1);
  const HloInstruction& call_cosine = *result_tuple->operand(2);

  EXPECT_TRUE(StructuralEqual(call_sine_a, call_sine_b));
  EXPECT_FALSE(StructuralEqual(call_sine_a, call_cosine));
}
961
TEST_F(HloInstructionTest, FunctionVisitor) {
  // Checks that HloInstruction::Accept with a function visitor, started from
  // `add`, visits every instruction of
  //
  //   param --+--> negate --+--> add
  //           +--> exp -----+
  //
  // exactly once, with param first and add last.
  const Shape f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, f32, "0"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32, HloOpcode::kNegate, param));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(f32, HloOpcode::kExp, param));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32, HloOpcode::kAdd, negate, exp));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Records the position at which each instruction was visited.
  int next_position = 0;
  absl::flat_hash_map<HloInstruction*, int> position_of;
  EXPECT_IS_OK(add->Accept([&next_position, &position_of](HloInstruction* inst) {
    // Seeing the same instruction a second time is a failure.
    EXPECT_FALSE(position_of.contains(inst));
    position_of[inst] = next_position++;
    return Status::OK();
  }));

  EXPECT_EQ(0, position_of.at(param));
  // negate and exp are independent of each other, so either may come first.
  const int exp_position = position_of.at(exp);
  EXPECT_TRUE(exp_position == 1 || exp_position == 2);
  const int negate_position = position_of.at(negate);
  EXPECT_TRUE(negate_position == 1 || negate_position == 2);
  EXPECT_NE(exp_position, negate_position);
  EXPECT_EQ(3, position_of.at(add));
}
1000
TEST_F(HloInstructionTest, FullyElementwise) {
  // An add of two rank-1 parameters is elementwise as a whole and on each of
  // its operands.
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {5});
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs =
      builder.AddInstruction(HloInstruction::CreateParameter(0, vec_shape, "x"));
  HloInstruction* rhs =
      builder.AddInstruction(HloInstruction::CreateParameter(1, vec_shape, "y"));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(vec_shape, HloOpcode::kAdd, lhs, rhs));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(sum->IsElementwise());
  for (int64 operand_index = 0; operand_index < sum->operand_count();
       ++operand_index) {
    EXPECT_TRUE(sum->IsElementwiseOnOperand(operand_index));
  }
}
1018
TEST_F(HloInstructionTest, MapIsElementwise) {
  // A kMap that applies a scalar computation over a [10,10] input reports
  // IsElementwise().
  auto module = CreateNewVerifiedModule();
  const Shape matrix_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {10, 10}, {1, 0});

  // The mapped computation consists solely of its scalar parameter.
  HloComputation::Builder identity_builder("id");
  identity_builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"));
  HloComputation* identity =
      module->AddEmbeddedComputation(identity_builder.Build());

  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "x"));
  HloInstruction* map = builder.AddInstruction(
      HloInstruction::CreateMap(matrix_shape, {input}, identity));
  module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(map->IsElementwise());
}
1035
TEST_F(HloInstructionTest, PartiallyElementwise) {
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
  const Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 5});

  // Fused expression:
  //
  //   p0     p1   p2   p3
  //     \   /    /     |
  //      mul    /      |
  //        \   /       |
  //         div     broadcast
  //           \       /
  //              max
  //
  // The broadcast is not elementwise, so the fusion is not elementwise as a
  // whole and not elementwise on p3. It is still elementwise on every other
  // operand.
  HloComputation::Builder builder("PartiallyElementwise");
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r2f32, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r2f32, "p1"));
  HloInstruction* p2 =
      builder.AddInstruction(HloInstruction::CreateParameter(2, r2f32, "p2"));
  HloInstruction* p3 =
      builder.AddInstruction(HloInstruction::CreateParameter(3, r1f32, "p3"));
  HloInstruction* product = builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, p0, p1));
  HloInstruction* quotient = builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, product, p2));
  // Dimension 0 of shape [5] is mapped to dimension 1 of shape [3x5].
  HloInstruction* broadcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, p3, {1}));
  HloInstruction* maximum = builder.AddInstruction(HloInstruction::CreateBinary(
      r2f32, HloOpcode::kMaximum, quotient, broadcast));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {maximum, broadcast, quotient, product},
      HloInstruction::FusionKind::kLoop);

  EXPECT_FALSE(fusion->IsElementwise());
  for (int64 i = 0; i < fusion->operand_count(); ++i) {
    // Only the operand that feeds the broadcast breaks elementwise-ness.
    const bool feeds_broadcast = fusion->operand(i) == p3;
    EXPECT_EQ(fusion->IsElementwiseOnOperand(i), !feeds_broadcast);
  }
}
1086
TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
  // Fused expression:
  //
  //         y
  //        /
  //   x  broadcast
  //    \  /  |
  //    min   |
  //      \  /
  //       sub
  //
  // The broadcast of y is used twice (by min and by sub). The fusion is not
  // elementwise on y but is elementwise on x.
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {5});

  HloComputation::Builder builder("PartiallyElementwiseWithReuse");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "y"));
  HloInstruction* y_broadcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(r1f32, y, {}));
  HloInstruction* minimum = builder.AddInstruction(HloInstruction::CreateBinary(
      r1f32, HloOpcode::kMinimum, x, y_broadcast));
  HloInstruction* difference =
      builder.AddInstruction(HloInstruction::CreateBinary(
          r1f32, HloOpcode::kSubtract, minimum, y_broadcast));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {difference, y_broadcast, minimum}, HloInstruction::FusionKind::kLoop);

  EXPECT_FALSE(fusion->IsElementwise());
  for (int64 i = 0; i < fusion->operand_count(); ++i) {
    const bool is_broadcast_source = fusion->operand(i) == y;
    EXPECT_EQ(fusion->IsElementwiseOnOperand(i), !is_broadcast_source);
  }
}
1126
TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
  // Fused expression:
  //
  //   x         y
  //   |         |
  //   |     transpose
  //    \      /
  //      dot
  //
  // Tests that shapes aren't mangled by Clone(). Each instruction of the
  // cloned fused expression must have the same shape as its counterpart in
  // the original, and the two fusions must be structurally equal.
  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
  HloInstruction* y_transposed =
      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      sout, x, y_transposed, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {dot, y_transposed}, HloInstruction::FusionKind::kLoop);

  std::unique_ptr<HloInstruction> fusion_clone = fusion->Clone();
  const HloInstruction* original_root = fusion->fused_expression_root();
  const HloInstruction* cloned_root = fusion_clone->fused_expression_root();

  auto expect_same_shape = [](const HloInstruction* expected,
                              const HloInstruction* actual) {
    EXPECT_TRUE(ShapeUtil::Equal(expected->shape(), actual->shape()));
  };
  // The dot itself.
  expect_same_shape(original_root, cloned_root);
  // Its lhs operand.
  expect_same_shape(original_root->operand(0), cloned_root->operand(0));
  // Its rhs operand (the transpose).
  expect_same_shape(original_root->operand(1), cloned_root->operand(1));
  // The transpose's operand.
  expect_same_shape(original_root->operand(1)->operand(0),
                    cloned_root->operand(1)->operand(0));
  EXPECT_TRUE(StructuralEqual(*fusion, *fusion_clone));
}
1172
TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) {
  // Fused expression:
  //
  //   x         y
  //   |         |
  //   |     transpose
  //    \      /
  //      dot
  //
  // After every use of x is replaced with y, y reaches the fusion through both
  // of its former operand slots. The fusion must end up with a single operand
  // (y) and a single fused parameter.
  const Shape s = ShapeUtil::MakeShape(F32, {10, 10});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s, "y"));
  HloInstruction* y_transposed =
      builder.AddInstruction(HloInstruction::CreateTranspose(s, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      s, x, y_transposed, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {dot, y_transposed}, HloInstruction::FusionKind::kLoop);

  // EXPECT_IS_OK (as used elsewhere in this file) prints the failing Status.
  // EXPECT_TRUE(status.ok()) would only report "false".
  EXPECT_IS_OK(x->ReplaceAllUsesWith(y));

  EXPECT_THAT(fusion->operands(), UnorderedElementsAre(y));
  EXPECT_EQ(fusion->fused_instructions_computation()->num_parameters(), 1);
}
1206
TEST_F(HloInstructionTest, FusionEquality) {
  // Two fusions, each wrapping a single unary op of the same parameter
  // (exp vs. negate), must not be structurally equal. A fusion and its clone
  // must be.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
  HloInstruction* exp_of_x = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, x));
  HloInstruction* negate_of_x = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, x));
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloInstruction* exp_fusion = computation->CreateFusionInstruction(
      {exp_of_x}, HloInstruction::FusionKind::kLoop);
  HloInstruction* negate_fusion = computation->CreateFusionInstruction(
      {negate_of_x}, HloInstruction::FusionKind::kLoop);
  EXPECT_FALSE(StructuralEqual(*exp_fusion, *negate_fusion));

  std::unique_ptr<HloInstruction> exp_fusion_clone = exp_fusion->Clone();
  EXPECT_TRUE(StructuralEqual(*exp_fusion, *exp_fusion_clone));
}
1228
TEST_F(HloInstructionTest, NestedFusionEquality) {
  // Builds a loop fusion of dot(a, transpose(b)) and fuses it into two output
  // fusions that differ only in their root op (add vs. subtract). The add
  // fusion must be structurally equal to its clone but not to the subtract
  // fusion.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);

  HloInstruction* a = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
  HloInstruction* b = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
  HloInstruction* b_transposed = builder.AddInstruction(
      HloInstruction::CreateTranspose(data_shape, b, {1, 0}));
  HloInstruction* product = builder.AddInstruction(HloInstruction::CreateDot(
      data_shape, a, b_transposed, dot_dnums, DefaultPrecisionConfig(2)));
  HloInstruction* one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* ones = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      data_shape, HloOpcode::kAdd, product, ones));
  HloInstruction* difference =
      builder.AddInstruction(HloInstruction::CreateBinary(
          data_shape, HloOpcode::kSubtract, product, ones));
  builder.AddInstruction(HloInstruction::CreateBinary(
      data_shape, HloOpcode::kMultiply, sum, difference));
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloInstruction* nested_fusion = computation->CreateFusionInstruction(
      {product, b_transposed}, HloInstruction::FusionKind::kLoop);

  HloInstruction* sum_fusion = computation->CreateFusionInstruction(
      {sum, nested_fusion}, HloInstruction::FusionKind::kOutput);
  HloInstruction* difference_fusion = computation->CreateFusionInstruction(
      {difference, nested_fusion}, HloInstruction::FusionKind::kOutput);
  std::unique_ptr<HloInstruction> sum_fusion_clone = sum_fusion->Clone();

  EXPECT_TRUE(StructuralEqual(*sum_fusion, *sum_fusion_clone));
  EXPECT_FALSE(StructuralEqual(*sum_fusion, *difference_fusion));
}
1269
TEST_F(HloInstructionTest, CloneSuffixNames) {
  // Test that the suffix string added to cloned instructions is not
  // duplicated. Re-cloning a clone bumps a numeric counter instead: we want
  // "foo.clone2", not "foo.clone.clone".
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  auto make_param = [&scalar](const char* name) {
    return HloInstruction::CreateParameter(0, scalar, name);
  };

  // Cloning the same instruction multiple times.
  auto foo = make_param("foo");
  EXPECT_EQ(foo->Clone()->name(), "foo.clone");
  EXPECT_EQ(foo->Clone()->Clone()->name(), "foo.clone2");
  EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "foo.clone3");

  // Custom suffixes.
  EXPECT_EQ(foo->Clone("bar")->name(), "foo.bar");
  EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "foo.bar2");
  EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(), "foo.bar2.clone");

  // A single Clone() of instructions whose names already look unusual.
  const std::pair<const char*, const char*> kSingleCloneCases[] = {
      // Name containing a dot.
      {"foo.baz", "foo.baz.clone"},
      // Large number after the suffix is incremented.
      {"foo.clone234", "foo.clone235"},
      // Non-numeric text after the cloning suffix.
      {"foo.clonexyz", "foo.clonexyz.clone"},
      // Multiple appearances of the suffix.
      {"foo.clone.clone3", "foo.clone.clone4"},
  };
  for (const auto& name_and_expected : kSingleCloneCases) {
    auto instruction = make_param(name_and_expected.first);
    EXPECT_EQ(instruction->Clone()->name(), name_and_expected.second);
  }
}
1307
TEST_F(HloInstructionTest, Stringification) {
  // Tests stringification of a simple op (dot), a while, and a conditional.
  // No fusion is exercised here.
  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
  HloInstruction* reshape =
      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));

  // Metadata printing is disabled for every ToString() call below.
  auto options = HloPrintOptions().set_print_metadata(false);

  EXPECT_EQ(dot->ToString(options),
            "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} "
            "%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}");

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());

  // NOTE(review): `builder` was already built into `computation` above, yet
  // the while and conditional below are still added to `builder`. Nothing in
  // this test builds `builder` again, so they appear to be used only for
  // their ToString() output. Confirm this is intentional.
  // `computation` serves as both the condition and the body of the while.
  HloInstruction* loop = builder.AddInstruction(
      HloInstruction::CreateWhile(sout, computation, computation, x));
  EXPECT_EQ(loop->ToString(options),
            "%while = f32[5,20]{1,0} while(f32[5,10]{1,0} %x), "
            "condition=%TransposeDot, body=%TransposeDot");

  // `computation` serves as both the true and the false branch, and `x` as
  // both branch operands.
  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* conditional =
      builder.AddInstruction(HloInstruction::CreateConditional(
          sout, pred, x, computation, x, computation));
  EXPECT_EQ(conditional->ToString(options),
            "%conditional = f32[5,20]{1,0} conditional(pred[] %constant, "
            "f32[5,10]{1,0} %x, f32[5,10]{1,0} %x), "
            "true_computation=%TransposeDot, false_computation=%TransposeDot");
}
1353
TEST_F(HloInstructionTest, StringifyGather_0) {
  // Gather whose start indices hold the index vector in their last dimension
  // (index_vector_dim=4, of size 5). Checks the full ToString() rendering,
  // including every gather dimension-number field.
  const Shape input_tensor_shape =
      ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  const Shape start_indices_tensor_shape =
      ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
  const Shape gather_result_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});

  HloComputation::Builder builder("Gather");
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
  HloInstruction* start_indices =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, start_indices_tensor_shape, "start_indices"));

  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  HloInstruction* gather = builder.AddInstruction(HloInstruction::CreateGather(
      gather_result_shape, input, start_indices, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(gather->ToString(),
            "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
            "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
            "s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), "
            "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, "
            "start_index_map={0,1,2,3,4}, "
            "index_vector_dim=4, slice_sizes={30,29,28,27,26}");
}
1388
TEST_F(HloInstructionTest, StringifyGather_1) {
  // Like StringifyGather_0, but the start indices hold the index vector in a
  // middle dimension (index_vector_dim=2, of size 5).
  const Shape input_tensor_shape =
      ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  const Shape start_indices_tensor_shape =
      ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
  const Shape gather_result_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});

  HloComputation::Builder builder("Gather");
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
  HloInstruction* start_indices =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, start_indices_tensor_shape, "start_indices"));

  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/2);
  HloInstruction* gather = builder.AddInstruction(HloInstruction::CreateGather(
      gather_result_shape, input, start_indices, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(gather->ToString(),
            "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
            "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
            "s64[10,9,5,7,6]{4,3,2,1,0} %start_indices), "
            "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, "
            "start_index_map={0,1,2,3,4}, "
            "index_vector_dim=2, slice_sizes={30,29,28,27,26}");
}
1423
TEST_F(HloInstructionTest, StringifyScatter) {
  // Scatter whose indices hold the index vector in dimension 2 (of size 5).
  // Checks the full ToString() rendering, including the update computation
  // name and every scatter dimension-number field.
  const Shape input_tensor_shape =
      ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  const Shape scatter_indices_tensor_shape =
      ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
  const Shape scatter_updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});

  HloComputation::Builder builder("Scatter");
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
  HloInstruction* scatter_indices =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, scatter_indices_tensor_shape, "scatter_indices"));
  HloInstruction* scatter_updates =
      builder.AddInstruction(HloInstruction::CreateParameter(
          2, scatter_updates_shape, "scatter_updates"));

  // Update computation: two scalar parameters; the last one added is the
  // builder's final instruction.
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder update_builder("Scatter.update");
  update_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "p1"));
  update_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "p2"));

  auto module = CreateNewVerifiedModule();
  HloComputation* update_computation =
      module->AddEmbeddedComputation(update_builder.Build());

  const auto scatter_dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 7, 8},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/2);
  HloInstruction* scatter = builder.AddInstruction(HloInstruction::CreateScatter(
      input_tensor_shape, input, scatter_indices, scatter_updates,
      update_computation, scatter_dim_numbers));
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(
      scatter->ToString(),
      "%scatter = f32[50,49,48,47,46]{4,3,2,1,0} "
      "scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
      "s64[10,9,5,7,6]{4,3,2,1,0} %scatter_indices, "
      "f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %scatter_updates), "
      "update_window_dims={4,5,6,7,8}, inserted_window_dims={}, "
      "scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=2, "
      "to_apply=%Scatter.update");
}
1472
TEST_F(HloInstructionTest, CanonnicalStringificationFusion) {
  // Tests canonical stringification of a simple op (dot) and of a loop fusion
  // containing it. As the expected strings below show,
  // HloPrintOptions::Canonical():
  //   - omits instruction names from operands;
  //   - prints the fused computation's instructions with tmp_N names.
  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
  HloInstruction* reshape =
      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));

  auto options = HloPrintOptions().Canonical();

  EXPECT_EQ(dot->ToString(options),
            "f32[5,20]{1,0} dot(f32[5,10]{1,0}, f32[10,20]{1,0}), "
            "lhs_contracting_dims={1}, rhs_contracting_dims={0}");

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {dot, reshape}, HloInstruction::FusionKind::kLoop);

  // The fused computation body is printed inline after "calls=". The
  // whitespace in this literal is part of the expected output.
  const string expected_fusion =
      R"(f32[5,20]{1,0} fusion(f32[5,10]{1,0}, f32[20,10]{1,0}), kind=kLoop, calls=
{
  tmp_0 = f32[5,10]{1,0} parameter(0)
  tmp_1 = f32[20,10]{1,0} parameter(1)
  tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
  ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
})";
  EXPECT_EQ(fusion->ToString(options), expected_fusion);
}
1514
TEST_F(HloInstructionTest, CanonnicalStringificationWhile) {
  // Tests canonical stringification of a while instruction. Its condition and
  // body are both the entry computation, which contains a loop fusion, so the
  // fused computation is printed nested inside each called computation.
  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
  // Transpose y from [20,10] to [10,20] so it can be the rhs of the dot.
  HloInstruction* transpose =
      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      sout, x, transpose, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  computation->CreateFusionInstruction({dot, transpose},
                                       HloInstruction::FusionKind::kLoop);

  // NOTE(review): `builder` was already consumed by Build() above. It is
  // reused here only to own the while instruction, which is therefore not
  // part of any computation. The condition computation returns f32[5,20]
  // rather than pred[], so this while is presumably not valid HLO; the test
  // only exercises printing.
  HloInstruction* loop = builder.AddInstruction(
      HloInstruction::CreateWhile(sout, computation, computation, x));

  auto options = HloPrintOptions().Canonical();
  // NOTE(review): the comparison is exact, so any whitespace inside the raw
  // string must match the printer's output byte-for-byte.
  const string expected_loop =
      R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
}, body=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
})";
  EXPECT_EQ(loop->ToString(options), expected_loop);
}
1570
TEST_F(HloInstructionTest, CanonnicalStringificationConditional) {
  // Tests canonical stringification of a conditional instruction. Its true and
  // false branches are both the entry computation, which contains a loop
  // fusion, so the fused computation is printed nested inside each branch.
  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
  // Transpose y from [20,10] to [10,20] so it can be the rhs of the dot.
  HloInstruction* transpose =
      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      sout, x, transpose, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  computation->CreateFusionInstruction({dot, transpose},
                                       HloInstruction::FusionKind::kLoop);

  // NOTE(review): `builder` was already consumed by Build() above. It is
  // reused here only to own the predicate and the conditional, which are
  // therefore not part of any computation. Only the printed form is checked.
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* conditional =
      builder.AddInstruction(HloInstruction::CreateConditional(
          sout, pred, x, computation, x, computation));

  auto options = HloPrintOptions().Canonical();
  // NOTE(review): the comparison is exact, so any whitespace inside the raw
  // string must match the printer's output byte-for-byte.
  const string expected_conditional =
      R"(f32[5,20]{1,0} conditional(pred[], f32[5,10]{1,0}, f32[5,10]{1,0}), true_computation=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
}, false_computation=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
})";
  EXPECT_EQ(conditional->ToString(options), expected_conditional);
}
1631
TEST_F(HloInstructionTest, CheckDeepClone) {
  // The module has a while loop whose body calls `calla`. `calla` in turn uses
  // a reduce-window with its own to_apply computation, so cloning has to walk
  // several levels of called computations.
  const char* const hlo_string = R"(
HloModule Module

addy (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT zadd = s32[] add(lhs, rhs)
}

calla (x: s32[]) -> s32[] {
x = s32[] parameter(0)
reduce = s32[] reduce-window(x, x), to_apply=addy
ROOT xadd = s32[] add(x, reduce)
}

body (bparam: s32[]) -> s32[] {
constant = s32[] constant(1)
bparam = s32[] parameter(0)
v = s32[] call(bparam), to_apply=calla
ROOT add = s32[] add(constant, bparam)
}

condition (cparam: s32[]) -> pred[] {
xconstant = s32[] constant(5)
cparam = s32[] parameter(0)
ROOT greater-than = pred[] compare(xconstant, cparam), direction=GT
}

ENTRY entry (param: s32[]) -> s32[] {
eparam = s32[] parameter(0)
ROOT while = s32[] while(eparam), condition=condition, body=body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_string));

  // A deep clone must re-parent every computation and every instruction to
  // the new module. Nothing in it may still point back at `module`.
  const std::unique_ptr<HloModule> clone = module->Clone();
  const HloModule* const cloned_module = clone.get();
  for (HloComputation* cloned_computation : clone->computations()) {
    EXPECT_EQ(cloned_computation->parent(), cloned_module);
    for (HloInstruction* cloned_instruction :
         cloned_computation->instructions()) {
      EXPECT_EQ(cloned_instruction->parent()->parent(), cloned_module);
    }
  }
}
1678
TEST_F(HloInstructionTest, IdenticalAccountsForBackendConfig) {
  // Two structurally equal adds stop comparing Identical() as soon as only one
  // of them carries a raw backend config string.
  const Shape shape = ShapeUtil::MakeShape(F32, {42});
  HloComputation::Builder builder("test");
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p"));

  // Each call adds a fresh `param + param` instruction to the builder.
  auto add_param_to_itself = [&]() {
    return builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param));
  };
  HloInstruction* first_add = add_param_to_itself();
  HloInstruction* second_add = add_param_to_itself();

  EXPECT_TRUE(first_add->Identical(*second_add));
  first_add->set_raw_backend_config_string("abc");
  EXPECT_FALSE(first_add->Identical(*second_add));
}
1694
TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallWindow) {
  // A custom-call and its clone compare Identical() until the original is
  // given a window that the clone does not have.
  std::unique_ptr<HloInstruction> original = HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {}), /*operands=*/{},
      /*custom_call_target=*/"foo");
  std::unique_ptr<HloInstruction> copy = original->Clone();
  EXPECT_TRUE(original->Identical(*copy));

  const Window window = window_util::MakeWindow({1, 2, 3});
  original->set_window(window);
  EXPECT_FALSE(original->Identical(*copy));
}
1706
TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallDnums) {
  // A custom-call and its clone compare Identical() until the original is
  // given convolution dimension numbers that the clone does not have.
  std::unique_ptr<HloInstruction> original = HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {}), /*operands=*/{},
      /*custom_call_target=*/"foo");
  std::unique_ptr<HloInstruction> copy = original->Clone();
  EXPECT_TRUE(original->Identical(*copy));

  ConvolutionDimensionNumbers conv_dnums;
  conv_dnums.set_output_batch_dimension(42);
  original->set_convolution_dimension_numbers(conv_dnums);
  EXPECT_FALSE(original->Identical(*copy));
}
1719
TEST_F(HloInstructionTest, CloneWindowOnCustomCall) {
  // Clone() must copy a custom-call's window onto the new instruction.
  const Window window = window_util::MakeWindow({1, 2, 3});
  auto custom_call = HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {}), /*operands=*/{},
      /*custom_call_target=*/"foo");
  custom_call->set_window(window);

  const auto cloned = custom_call->Clone();
  const Window& cloned_window = cloned->window();
  EXPECT_TRUE(protobuf_util::ProtobufEquals(cloned_window, window))
      << cloned_window.DebugString();
}
1730
TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) {
  // Clone() must copy a custom-call's convolution dimension numbers onto the
  // new instruction.
  ConvolutionDimensionNumbers conv_dnums;
  conv_dnums.set_output_batch_dimension(42);
  auto custom_call = HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {}), /*operands=*/{},
      /*custom_call_target=*/"foo");
  custom_call->set_convolution_dimension_numbers(conv_dnums);

  const auto cloned = custom_call->Clone();
  const ConvolutionDimensionNumbers& cloned_dnums =
      cloned->convolution_dimension_numbers();
  EXPECT_TRUE(protobuf_util::ProtobufEquals(cloned_dnums, conv_dnums))
      << cloned_dnums.DebugString();
}
1743
TEST_F(HloInstructionTest, PreserveOperandPrecisionOnCloneConv) {
  // Clone() must keep the per-operand precision config that was parsed from
  // the HLO text ({high,default} on the convolution below).
  constexpr char kHloString[] = R"(
HloModule test_module
ENTRY test {
arg0 = f32[1,2,1] parameter(0)
arg1 = f32[1,1,1] parameter(1)
ROOT conv = f32[1,2,1] convolution(arg0, arg1), window={size=1},
dim_labels=b0f_0io->b0f, operand_precision={high,default}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kHloString));
  HloInstruction* conv = module->entry_computation()->root_instruction();

  const auto cloned_conv = conv->Clone();
  EXPECT_THAT(cloned_conv->precision_config().operand_precision(),
              ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::DEFAULT));
}
1761
1762 } // namespace
1763 } // namespace xla
1764