1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
21 #include "tensorflow/compiler/xla/service/hlo_module.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/test.h"
25 #include "tensorflow/compiler/xla/test_helpers.h"
26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28
29 namespace op = xla::testing::opcode_matchers;
30
31 namespace xla {
32
33 using ::testing::UnorderedElementsAre;
34
35 class ReducePrecisionInsertionTest : public HloTestBase {
36 protected:
InsertOps(HloModule * module,const HloReducePrecisionOptions::Location location,const std::function<bool (const HloInstruction *)> & filter)37 bool InsertOps(HloModule* module,
38 const HloReducePrecisionOptions::Location location,
39 const std::function<bool(const HloInstruction*)>& filter) {
40 ReducePrecisionInsertion op_insertion(5, 10, location, filter);
41 StatusOr<bool> result = op_insertion.Run(module);
42 EXPECT_IS_OK(result.status());
43 return result.ValueOrDie();
44 }
45 };
46
TEST_F(ReducePrecisionInsertionTest, BeforeUnaryInstruction) {
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  auto builder = HloComputation::Builder(TestName());

  // Build a minimal graph: parameter -> cosine.
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
  HloInstruction* cos = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, param));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Sanity-check the graph before running the pass.
  EXPECT_EQ(computation->root_instruction(), cos);
  EXPECT_EQ(cos->operand(0), param);

  EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
                        [](const HloInstruction* inst) {
                          return inst->opcode() == HloOpcode::kCos;
                        }));

  // A reduce-precision op should now sit between the parameter and the
  // cosine; the root is unchanged.
  EXPECT_EQ(computation->root_instruction(), cos);
  EXPECT_THAT(cos->operand(0), op::ReducePrecision(param));
}
73
TEST_F(ReducePrecisionInsertionTest, BeforeUnaryScalarInstruction) {
  // Same as BeforeUnaryInstruction, but with a scalar (rank-0) shape.
  Shape shape = ShapeUtil::MakeShape(F32, {});
  auto builder = HloComputation::Builder(TestName());

  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
  HloInstruction* cos = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, param));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Sanity-check the graph before running the pass.
  EXPECT_EQ(computation->root_instruction(), cos);
  EXPECT_EQ(cos->operand(0), param);

  EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
                        [](const HloInstruction* inst) {
                          return inst->opcode() == HloOpcode::kCos;
                        }));

  // A reduce-precision op should now feed the cosine.
  EXPECT_EQ(computation->root_instruction(), cos);
  EXPECT_THAT(cos->operand(0), op::ReducePrecision(param));
}
100
TEST_F(ReducePrecisionInsertionTest, BeforeBinaryInstruction) {
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  auto builder = HloComputation::Builder(TestName());

  // Build: two parameters feeding a binary add.
  HloInstruction* lhs =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
  HloInstruction* rhs =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs, rhs));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Sanity-check the graph before running the pass.
  EXPECT_EQ(computation->root_instruction(), add);
  EXPECT_EQ(add->operand(0), lhs);
  EXPECT_EQ(add->operand(1), rhs);

  EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
                        [](const HloInstruction* inst) {
                          return inst->opcode() == HloOpcode::kAdd;
                        }));

  // Both operands of the add should now go through reduce-precision ops.
  EXPECT_EQ(computation->root_instruction(), add);
  EXPECT_THAT(add->operand(0), op::ReducePrecision(lhs));
  EXPECT_THAT(add->operand(1), op::ReducePrecision(rhs));
}
132
TEST_F(ReducePrecisionInsertionTest, BeforeZeroInputInstruction) {
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  auto builder = HloComputation::Builder(TestName());

  // Build: parameter -> cosine.
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
  HloInstruction* cos = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, param));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Sanity-check the graph before running the pass.
  EXPECT_EQ(computation->root_instruction(), cos);
  EXPECT_EQ(cos->operand(0), param);

  // Matching only parameters (which have no operands) means there is no
  // input to wrap, so the pass should leave the module untouched.
  EXPECT_FALSE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
                         [](const HloInstruction* inst) {
                           return inst->opcode() ==
                                  HloOpcode::kParameter;
                         }));

  // The graph must be unchanged.
  EXPECT_EQ(computation->root_instruction(), cos);
  EXPECT_EQ(cos->operand(0), param);
}
160
TEST_F(ReducePrecisionInsertionTest, AvoidAddingDuplicateInstructions) {
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  auto builder = HloComputation::Builder(TestName());

  // Build: one parameter fanning out into cos and sin, whose results are
  // then added.
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
  HloInstruction* cos = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, param));
  HloInstruction* sin = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kSin, param));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cos, sin));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Sanity-check the graph before running the pass.
  EXPECT_EQ(computation->root_instruction(), add);
  EXPECT_EQ(cos->operand(0), param);
  EXPECT_EQ(sin->operand(0), param);

  EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
                        [](const HloInstruction* inst) {
                          return inst->opcode() == HloOpcode::kCos ||
                                 inst->opcode() == HloOpcode::kSin;
                        }));

  // Both cos and sin should read a reduced-precision version of the
  // parameter — and crucially it must be the SAME reduce-precision
  // instruction, not one per consumer.
  EXPECT_THAT(cos->operand(0), op::ReducePrecision(param));
  EXPECT_THAT(sin->operand(0), op::ReducePrecision(param));
  EXPECT_EQ(cos->operand(0), sin->operand(0));
}
197
TEST_F(ReducePrecisionInsertionTest, AfterRootInstruction) {
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  auto builder = HloComputation::Builder(TestName());

  // Build: parameter -> cosine, with the cosine as the root.
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
  HloInstruction* cos = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, param));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Sanity-check the graph before running the pass.
  EXPECT_EQ(computation->root_instruction(), cos);

  EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
                        [](const HloInstruction* inst) {
                          return inst->opcode() == HloOpcode::kCos;
                        }));

  // The reduce-precision op should have become the new root, consuming the
  // cosine's output.
  EXPECT_THAT(computation->root_instruction(), op::ReducePrecision(cos));
}
222
TEST_F(ReducePrecisionInsertionTest, AfterNonRootInstruction) {
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  auto builder = HloComputation::Builder(TestName());

  // Build two independent parameter->cosine chains and join them with an
  // add.  Joining the two cosine outputs lets us verify that the inserted
  // reduce-precision ops do not cross the two operand streams.
  HloInstruction* a =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
  HloInstruction* a_cos = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));

  HloInstruction* b =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
  HloInstruction* b_cos = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, b));

  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_cos, b_cos));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Sanity-check the graph before running the pass.
  EXPECT_EQ(add->operand(0), a_cos);
  EXPECT_EQ(add->operand(1), b_cos);

  EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
                        [](const HloInstruction* inst) {
                          return inst->opcode() == HloOpcode::kCos;
                        }));

  // Each add operand should now be a reduce-precision op wrapping the
  // corresponding cosine.
  EXPECT_THAT(add->operand(0), op::ReducePrecision());
  EXPECT_EQ(add->operand(0)->operand(0), a_cos);
  EXPECT_THAT(add->operand(1), op::ReducePrecision());
  EXPECT_EQ(add->operand(1)->operand(0), b_cos);
}
263
TEST_F(ReducePrecisionInsertionTest, OutputIsNotFloat) {
  // Use an integer shape: reduce-precision only applies to float data.
  Shape shape = ShapeUtil::MakeShape(S32, {4});
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
  HloInstruction* cos = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, param));

  // Unverified module: kCos on S32 would not pass the HLO verifier.
  auto module = CreateNewUnverifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Sanity-check the graph before running the pass.
  EXPECT_THAT(param->users(), UnorderedElementsAre(cos));
  EXPECT_EQ(computation->root_instruction(), cos);

  // Even with a filter that accepts everything, no instruction produces F32
  // data, so the pass should not change the graph.
  EXPECT_FALSE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
                         [](const HloInstruction*) { return true; }));

  // The graph must be unchanged.
  EXPECT_THAT(param->users(), UnorderedElementsAre(cos));
  EXPECT_EQ(computation->root_instruction(), cos);
}
289
TEST_F(ReducePrecisionInsertionTest, ShouldReduceOutputPrecisionIsFalse) {
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
  HloInstruction* cos = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, param));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Sanity-check the graph before running the pass.
  EXPECT_THAT(param->users(), UnorderedElementsAre(cos));
  EXPECT_EQ(computation->root_instruction(), cos);

  // A filter that rejects every instruction must leave the graph alone.
  EXPECT_FALSE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
                         [](const HloInstruction*) { return false; }));

  // The graph must be unchanged.
  EXPECT_THAT(param->users(), UnorderedElementsAre(cos));
  EXPECT_EQ(computation->root_instruction(), cos);
}
315
TEST_F(ReducePrecisionInsertionTest, InsertionIsNotRecursive) {
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
  // An existing reduce-precision op with different (8, 23) parameters than
  // the (5, 10) ones the pass inserts.
  HloInstruction* existing_rp = builder.AddInstruction(
      HloInstruction::CreateReducePrecision(shape, param, 8, 23));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Sanity-check the graph before running the pass.
  EXPECT_EQ(computation->root_instruction(), existing_rp);

  // Filtering on kReducePrecision should append exactly one new
  // reduce-precision op after the existing one — and must not then recurse
  // and wrap the freshly inserted op as well.
  EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
                        [](const HloInstruction* inst) {
                          return inst->opcode() ==
                                 HloOpcode::kReducePrecision;
                        }));

  // The new op is the root and consumes the pre-existing one.
  EXPECT_THAT(computation->root_instruction(), op::ReducePrecision());
  EXPECT_EQ(computation->root_instruction()->operand(0), existing_rp);
}
342
TEST_F(ReducePrecisionInsertionTest, SkipRedundantReducePrecisionAfter) {
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
  // Existing reduce-precision with the SAME (5, 10) parameters the pass
  // would insert.
  HloInstruction* existing_rp = builder.AddInstruction(
      HloInstruction::CreateReducePrecision(shape, param, 5, 10));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Sanity-check the graph before running the pass.
  EXPECT_THAT(param->users(), UnorderedElementsAre(existing_rp));
  EXPECT_EQ(computation->root_instruction(), existing_rp);

  // Inserting another (5, 10) reduce-precision after the parameter would be
  // redundant with the one already there, so the pass should do nothing.
  EXPECT_FALSE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
                         [](const HloInstruction* inst) {
                           return inst->opcode() ==
                                  HloOpcode::kParameter;
                         }));

  // The graph must be unchanged.
  EXPECT_THAT(param->users(), UnorderedElementsAre(existing_rp));
  EXPECT_EQ(computation->root_instruction(), existing_rp);
}
370
TEST_F(ReducePrecisionInsertionTest, AddNonRedundantReducePrecision) {
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
  // Existing reduce-precision with DIFFERENT (8, 23) parameters than the
  // (5, 10) ones the pass inserts.
  HloInstruction* existing_rp = builder.AddInstruction(
      HloInstruction::CreateReducePrecision(shape, param, 8, 23));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Sanity-check the graph before running the pass.
  EXPECT_THAT(param->users(), UnorderedElementsAre(existing_rp));
  EXPECT_EQ(computation->root_instruction(), existing_rp);

  // Because the parameters differ, the new op is not redundant and should
  // be inserted between the parameter and the existing op.
  EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
                        [](const HloInstruction* inst) {
                          return inst->opcode() == HloOpcode::kParameter;
                        }));

  // The root is unchanged, but its operand is now a reduce-precision of the
  // parameter.
  EXPECT_EQ(computation->root_instruction(), existing_rp);
  EXPECT_THAT(existing_rp->operand(0), op::ReducePrecision(param));
}
397
TEST_F(ReducePrecisionInsertionTest, IgnoreOpsInsideFusionNode) {
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
  HloInstruction* cos = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, param));
  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Manually wrap the kCos operation in a loop-fusion instruction.
  HloInstruction* fusion = computation->AddInstruction(
      HloInstruction::CreateFusion(shape, HloInstruction::FusionKind::kLoop,
                                   cos));
  EXPECT_IS_OK(cos->ReplaceAllUsesWith(fusion));
  EXPECT_IS_OK(computation->RemoveInstruction(cos));

  // Sanity-check the graph before running the pass.
  EXPECT_THAT(param->users(), UnorderedElementsAre(fusion));
  EXPECT_EQ(computation->root_instruction(), fusion);
  HloInstruction* fused_cos = fusion->fused_expression_root();
  EXPECT_EQ(fused_cos->opcode(), HloOpcode::kCos);

  // UNFUSED_OP_OUTPUTS must not look inside the fusion computation, so the
  // kCos hidden there should not trigger any insertion.
  EXPECT_FALSE(InsertOps(module.get(),
                         HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS,
                         [](const HloInstruction* inst) {
                           return inst->opcode() == HloOpcode::kCos;
                         }));

  // The graph must be unchanged.
  EXPECT_THAT(param->users(), UnorderedElementsAre(fusion));
  EXPECT_EQ(computation->root_instruction(), fusion);
  EXPECT_EQ(fusion->fused_expression_root(), fused_cos);
}
433
TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInHeadOfFusionNode) {
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
  HloInstruction* cos = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, param));
  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Manually wrap the kCos operation in a loop-fusion instruction.
  HloInstruction* fusion = computation->AddInstruction(
      HloInstruction::CreateFusion(shape, HloInstruction::FusionKind::kLoop,
                                   cos));
  EXPECT_IS_OK(cos->ReplaceAllUsesWith(fusion));
  EXPECT_IS_OK(computation->RemoveInstruction(cos));

  // Sanity-check the graph before running the pass.
  EXPECT_THAT(param->users(), UnorderedElementsAre(fusion));
  EXPECT_EQ(computation->root_instruction(), fusion);
  HloInstruction* fused_cos = fusion->fused_expression_root();
  EXPECT_EQ(fused_cos->opcode(), HloOpcode::kCos);

  // FUSION_INPUTS_BY_CONTENT should notice the kCos inside the fusion
  // computation and insert a reduce-precision at the fusion's input.
  EXPECT_TRUE(InsertOps(module.get(),
                        HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT,
                        [](const HloInstruction* inst) {
                          return inst->opcode() == HloOpcode::kCos;
                        }));

  // Running the pass again must be a no-op: a second reduce-precision
  // would be redundant with the one just inserted.
  EXPECT_FALSE(InsertOps(module.get(),
                         HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT,
                         [](const HloInstruction* inst) {
                           return inst->opcode() == HloOpcode::kCos;
                         }));

  // The top-level computation still holds only the fusion, while inside
  // the fused computation a reduce-precision now follows the parameter.
  EXPECT_THAT(param->users(), UnorderedElementsAre(fusion));
  EXPECT_EQ(computation->root_instruction(), fusion);
  EXPECT_THAT(fusion->fused_expression_root(), fused_cos);
  EXPECT_THAT(fused_cos->operand(0), op::ReducePrecision(op::Parameter()));
}
480
TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInTailOfFusionNode) {
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
  HloInstruction* cos = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, param));
  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Manually wrap the kCos operation in a loop-fusion instruction.
  HloInstruction* fusion = computation->AddInstruction(
      HloInstruction::CreateFusion(shape, HloInstruction::FusionKind::kLoop,
                                   cos));
  EXPECT_IS_OK(cos->ReplaceAllUsesWith(fusion));
  EXPECT_IS_OK(computation->RemoveInstruction(cos));

  // Sanity-check the graph before running the pass.
  EXPECT_THAT(param->users(), UnorderedElementsAre(fusion));
  EXPECT_EQ(computation->root_instruction(), fusion);
  HloInstruction* fused_cos = fusion->fused_expression_root();
  EXPECT_EQ(fused_cos->opcode(), HloOpcode::kCos);

  // FUSION_OUTPUTS_BY_CONTENT should notice the kCos inside the fusion
  // computation and insert a reduce-precision at the fusion's root.
  EXPECT_TRUE(InsertOps(module.get(),
                        HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT,
                        [](const HloInstruction* inst) {
                          return inst->opcode() == HloOpcode::kCos;
                        }));

  // Running the pass again must be a no-op: a second reduce-precision
  // would be redundant with the one just inserted.
  EXPECT_FALSE(InsertOps(module.get(),
                         HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT,
                         [](const HloInstruction* inst) {
                           return inst->opcode() == HloOpcode::kCos;
                         }));

  // The top-level computation still holds only the fusion, while the fused
  // computation's root is now a reduce-precision of the original kCos.
  EXPECT_THAT(param->users(), UnorderedElementsAre(fusion));
  EXPECT_EQ(computation->root_instruction(), fusion);
  EXPECT_THAT(fusion->fused_expression_root(),
              op::ReducePrecision(fused_cos));
}
526
TEST_F(ReducePrecisionInsertionTest, MakeFilterFunctionNoSubstrings) {
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
  HloInstruction* cos = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, param));
  HloInstruction* sin = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kSin, param));

  // Build an options proto that selects kCos only, with no op-name
  // substring restriction.
  auto options_proto = ReducePrecisionInsertion::make_options_proto(
      HloReducePrecisionOptions::OP_OUTPUTS, 5, 10,
      [](const HloOpcode opcode) { return opcode == HloOpcode::kCos; });

  auto filter_function =
      ReducePrecisionInsertion::make_filter_function(options_proto);

  // The round-tripped filter should match the cos but not the sin.
  EXPECT_TRUE(filter_function(cos));
  EXPECT_FALSE(filter_function(sin));
}
547
TEST_F(ReducePrecisionInsertionTest, MakeFilterFunctionWithSubstrings) {
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));

  // Two kCos instructions distinguished only by their metadata op names.
  HloInstruction* cos_foom = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, param));
  OpMetadata foom_metadata;
  foom_metadata.set_op_name("FlowTensor/foom");
  cos_foom->set_metadata(foom_metadata);

  HloInstruction* cos_barn = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, param));
  OpMetadata barn_metadata;
  barn_metadata.set_op_name("FlowTensor/barn");
  cos_barn->set_metadata(barn_metadata);

  // Filter on kCos AND on op names containing "foo" or "baz".
  auto options_proto = ReducePrecisionInsertion::make_options_proto(
      HloReducePrecisionOptions::OP_OUTPUTS, 5, 10,
      [](const HloOpcode opcode) { return opcode == HloOpcode::kCos; },
      {"foo", "baz"});

  auto filter_function =
      ReducePrecisionInsertion::make_filter_function(options_proto);

  // "FlowTensor/foom" contains "foo"; "FlowTensor/barn" matches neither
  // substring.
  EXPECT_TRUE(filter_function(cos_foom));
  EXPECT_FALSE(filter_function(cos_barn));
}
577
578 } // namespace xla
579