1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"

#include <functional>

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
28 
29 namespace op = xla::testing::opcode_matchers;
30 
31 namespace xla {
32 
33 using ::testing::UnorderedElementsAre;
34 
35 class ReducePrecisionInsertionTest : public HloTestBase {
36  protected:
InsertOps(HloModule * module,const HloReducePrecisionOptions::Location location,const std::function<bool (const HloInstruction *)> & filter)37   bool InsertOps(HloModule* module,
38                  const HloReducePrecisionOptions::Location location,
39                  const std::function<bool(const HloInstruction*)>& filter) {
40     ReducePrecisionInsertion op_insertion(5, 10, location, filter);
41     StatusOr<bool> result = op_insertion.Run(module);
42     EXPECT_IS_OK(result.status());
43     return result.ValueOrDie();
44   }
45 };
46 
TEST_F(ReducePrecisionInsertionTest,BeforeUnaryInstruction)47 TEST_F(ReducePrecisionInsertionTest, BeforeUnaryInstruction) {
48   auto builder = HloComputation::Builder(TestName());
49   Shape shape = ShapeUtil::MakeShape(F32, {4});
50 
51   // Create a simple graph with a parameter feeding a unary cosine function.
52   HloInstruction* a =
53       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
54   HloInstruction* b = builder.AddInstruction(
55       HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
56 
57   auto module = CreateNewVerifiedModule();
58   auto computation = module->AddEntryComputation(builder.Build());
59 
60   // Confirm expected state before adding ops.
61   EXPECT_EQ(computation->root_instruction(), b);
62   EXPECT_EQ(b->operand(0), a);
63 
64   EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
65                         [](const HloInstruction* instruction) {
66                           return instruction->opcode() == HloOpcode::kCos;
67                         }));
68 
69   // Confirm expected graph after adding ops.
70   EXPECT_EQ(computation->root_instruction(), b);
71   EXPECT_THAT(b->operand(0), op::ReducePrecision(a));
72 }
73 
TEST_F(ReducePrecisionInsertionTest,BeforeUnaryScalarInstruction)74 TEST_F(ReducePrecisionInsertionTest, BeforeUnaryScalarInstruction) {
75   auto builder = HloComputation::Builder(TestName());
76   Shape shape = ShapeUtil::MakeShape(F32, {});
77 
78   // Create a simple graph with a parameter feeding a unary cosine function.
79   HloInstruction* a =
80       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
81   HloInstruction* b = builder.AddInstruction(
82       HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
83 
84   auto module = CreateNewVerifiedModule();
85   auto computation = module->AddEntryComputation(builder.Build());
86 
87   // Confirm expected state before adding ops.
88   EXPECT_EQ(computation->root_instruction(), b);
89   EXPECT_EQ(b->operand(0), a);
90 
91   EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
92                         [](const HloInstruction* instruction) {
93                           return instruction->opcode() == HloOpcode::kCos;
94                         }));
95 
96   // Confirm expected graph after adding ops.
97   EXPECT_EQ(computation->root_instruction(), b);
98   EXPECT_THAT(b->operand(0), op::ReducePrecision(a));
99 }
100 
TEST_F(ReducePrecisionInsertionTest,BeforeBinaryInstruction)101 TEST_F(ReducePrecisionInsertionTest, BeforeBinaryInstruction) {
102   auto builder = HloComputation::Builder(TestName());
103   Shape shape = ShapeUtil::MakeShape(F32, {4});
104 
105   // Create a simple graph with parameter feeding a binary add function.
106 
107   HloInstruction* a =
108       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
109   HloInstruction* b =
110       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
111   HloInstruction* c = builder.AddInstruction(
112       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
113 
114   auto module = CreateNewVerifiedModule();
115   auto computation = module->AddEntryComputation(builder.Build());
116 
117   // Confirm expected state before adding ops.
118   EXPECT_EQ(computation->root_instruction(), c);
119   EXPECT_EQ(c->operand(0), a);
120   EXPECT_EQ(c->operand(1), b);
121 
122   EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
123                         [](const HloInstruction* instruction) {
124                           return instruction->opcode() == HloOpcode::kAdd;
125                         }));
126 
127   // Confirm expected graph after adding ops.
128   EXPECT_EQ(computation->root_instruction(), c);
129   EXPECT_THAT(c->operand(0), op::ReducePrecision(a));
130   EXPECT_THAT(c->operand(1), op::ReducePrecision(b));
131 }
132 
TEST_F(ReducePrecisionInsertionTest,BeforeZeroInputInstruction)133 TEST_F(ReducePrecisionInsertionTest, BeforeZeroInputInstruction) {
134   auto builder = HloComputation::Builder(TestName());
135   Shape shape = ShapeUtil::MakeShape(F32, {4});
136 
137   // Create a simple graph with a parameter feeding a unary cosine function.
138   HloInstruction* a =
139       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
140   HloInstruction* b = builder.AddInstruction(
141       HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
142 
143   auto module = CreateNewVerifiedModule();
144   auto computation = module->AddEntryComputation(builder.Build());
145 
146   // Confirm expected state before adding ops.
147   EXPECT_EQ(computation->root_instruction(), b);
148   EXPECT_EQ(b->operand(0), a);
149 
150   EXPECT_FALSE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
151                          [](const HloInstruction* instruction) {
152                            return instruction->opcode() ==
153                                   HloOpcode::kParameter;
154                          }));
155 
156   // Confirm that graph has not changed.
157   EXPECT_EQ(computation->root_instruction(), b);
158   EXPECT_EQ(b->operand(0), a);
159 }
160 
TEST_F(ReducePrecisionInsertionTest,AvoidAddingDuplicateInstructions)161 TEST_F(ReducePrecisionInsertionTest, AvoidAddingDuplicateInstructions) {
162   auto builder = HloComputation::Builder(TestName());
163   Shape shape = ShapeUtil::MakeShape(F32, {4});
164 
165   // Create a simple graph with parameter feeding a binary add function.
166 
167   HloInstruction* a =
168       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
169   HloInstruction* b = builder.AddInstruction(
170       HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
171   HloInstruction* c = builder.AddInstruction(
172       HloInstruction::CreateUnary(shape, HloOpcode::kSin, a));
173   HloInstruction* d = builder.AddInstruction(
174       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, b, c));
175 
176   auto module = CreateNewVerifiedModule();
177   auto computation = module->AddEntryComputation(builder.Build());
178 
179   // Confirm expected state before adding ops.
180   EXPECT_EQ(computation->root_instruction(), d);
181   EXPECT_EQ(b->operand(0), a);
182   EXPECT_EQ(c->operand(0), a);
183 
184   EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS,
185                         [](const HloInstruction* instruction) {
186                           return instruction->opcode() == HloOpcode::kCos ||
187                                  instruction->opcode() == HloOpcode::kSin;
188                         }));
189 
190   // Confirm expected graph after adding ops.  In particular, we want to confirm
191   // that the reduced-precision operation added for the input to b is re-used
192   // for the input to c.
193   EXPECT_THAT(b->operand(0), op::ReducePrecision(a));
194   EXPECT_THAT(c->operand(0), op::ReducePrecision(a));
195   EXPECT_EQ(b->operand(0), c->operand(0));
196 }
197 
TEST_F(ReducePrecisionInsertionTest,AfterRootInstruction)198 TEST_F(ReducePrecisionInsertionTest, AfterRootInstruction) {
199   auto builder = HloComputation::Builder(TestName());
200   Shape shape = ShapeUtil::MakeShape(F32, {4});
201 
202   // Create a simple graph with a parameter feeding a unary cosine function.
203   HloInstruction* a =
204       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
205   HloInstruction* b = builder.AddInstruction(
206       HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
207 
208   auto module = CreateNewVerifiedModule();
209   auto computation = module->AddEntryComputation(builder.Build());
210 
211   // Confirm expected state before adding ops.
212   EXPECT_EQ(computation->root_instruction(), b);
213 
214   EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
215                         [](const HloInstruction* instruction) {
216                           return instruction->opcode() == HloOpcode::kCos;
217                         }));
218 
219   // Confirm expected graph after adding ops.
220   EXPECT_THAT(computation->root_instruction(), op::ReducePrecision(b));
221 }
222 
TEST_F(ReducePrecisionInsertionTest,AfterNonRootInstruction)223 TEST_F(ReducePrecisionInsertionTest, AfterNonRootInstruction) {
224   auto builder = HloComputation::Builder(TestName());
225   Shape shape = ShapeUtil::MakeShape(F32, {4});
226 
227   // Create a graph with two parameters feeding into unary cosine functions,
228   // and the output of those feeds into an add function.  Feeding the outputs
229   // from the suffixed cosine functions into a binary add function allows us to
230   // confirm that the separate operand streams are not crossed when the new
231   // instructions are inserted.
232   HloInstruction* a =
233       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
234   HloInstruction* a_cos = builder.AddInstruction(
235       HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
236 
237   HloInstruction* b =
238       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
239   HloInstruction* b_cos = builder.AddInstruction(
240       HloInstruction::CreateUnary(shape, HloOpcode::kCos, b));
241 
242   HloInstruction* c = builder.AddInstruction(
243       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_cos, b_cos));
244 
245   auto module = CreateNewVerifiedModule();
246   module->AddEntryComputation(builder.Build());
247 
248   // Confirm expected graph before adding ops.
249   EXPECT_EQ(c->operand(0), a_cos);
250   EXPECT_EQ(c->operand(1), b_cos);
251 
252   EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
253                         [](const HloInstruction* instruction) {
254                           return instruction->opcode() == HloOpcode::kCos;
255                         }));
256 
257   // Confirm expected graph after adding ops.
258   EXPECT_THAT(c->operand(0), op::ReducePrecision());
259   EXPECT_EQ(c->operand(0)->operand(0), a_cos);
260   EXPECT_THAT(c->operand(1), op::ReducePrecision());
261   EXPECT_EQ(c->operand(1)->operand(0), b_cos);
262 }
263 
TEST_F(ReducePrecisionInsertionTest,OutputIsNotFloat)264 TEST_F(ReducePrecisionInsertionTest, OutputIsNotFloat) {
265   auto builder = HloComputation::Builder(TestName());
266   Shape shape = ShapeUtil::MakeShape(S32, {4});
267   HloInstruction* x =
268       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
269   HloInstruction* y = builder.AddInstruction(
270       HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
271 
272   auto module = CreateNewUnverifiedModule();
273   auto computation = module->AddEntryComputation(builder.Build());
274 
275   // Confirm expected graph before adding ops.
276   EXPECT_THAT(x->users(), UnorderedElementsAre(y));
277   EXPECT_EQ(computation->root_instruction(), y);
278 
279   // Since none of the instructions produce F32 data, this should not change
280   // the graph.
281   EXPECT_FALSE(
282       InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
283                 [](const HloInstruction* instruction) { return true; }));
284 
285   // Confirm that graph has not changed.
286   EXPECT_THAT(x->users(), UnorderedElementsAre(y));
287   EXPECT_EQ(computation->root_instruction(), y);
288 }
289 
TEST_F(ReducePrecisionInsertionTest,ShouldReduceOutputPrecisionIsFalse)290 TEST_F(ReducePrecisionInsertionTest, ShouldReduceOutputPrecisionIsFalse) {
291   auto builder = HloComputation::Builder(TestName());
292   Shape shape = ShapeUtil::MakeShape(F32, {4});
293   HloInstruction* x =
294       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
295   HloInstruction* y = builder.AddInstruction(
296       HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
297 
298   auto module = CreateNewVerifiedModule();
299   auto computation = module->AddEntryComputation(builder.Build());
300 
301   // Confirm expected graph before adding ops.
302   EXPECT_THAT(x->users(), UnorderedElementsAre(y));
303   EXPECT_EQ(computation->root_instruction(), y);
304 
305   // Since none of the instructions match the should_reduce_output_precision
306   // function, this should not change the graph.
307   EXPECT_FALSE(
308       InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
309                 [](const HloInstruction* instruction) { return false; }));
310 
311   // Confirm that graph has not changed.
312   EXPECT_THAT(x->users(), UnorderedElementsAre(y));
313   EXPECT_EQ(computation->root_instruction(), y);
314 }
315 
TEST_F(ReducePrecisionInsertionTest,InsertionIsNotRecursive)316 TEST_F(ReducePrecisionInsertionTest, InsertionIsNotRecursive) {
317   auto builder = HloComputation::Builder(TestName());
318   Shape shape = ShapeUtil::MakeShape(F32, {4});
319   HloInstruction* a =
320       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
321   HloInstruction* b = builder.AddInstruction(
322       HloInstruction::CreateReducePrecision(shape, a, 8, 23));
323 
324   auto module = CreateNewVerifiedModule();
325   auto computation = module->AddEntryComputation(builder.Build());
326 
327   // Confirm expected state before adding ops.
328   EXPECT_EQ(computation->root_instruction(), b);
329 
330   // This should insert a new ReducePrecision after the existing one, but
331   // should not then recurse by adding another after the just-inserted one.
332   EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
333                         [](const HloInstruction* instruction) {
334                           return instruction->opcode() ==
335                                  HloOpcode::kReducePrecision;
336                         }));
337 
338   // Confirm expected graph after adding ops.
339   EXPECT_THAT(computation->root_instruction(), op::ReducePrecision());
340   EXPECT_EQ(computation->root_instruction()->operand(0), b);
341 }
342 
TEST_F(ReducePrecisionInsertionTest,SkipRedundantReducePrecisionAfter)343 TEST_F(ReducePrecisionInsertionTest, SkipRedundantReducePrecisionAfter) {
344   auto builder = HloComputation::Builder(TestName());
345   Shape shape = ShapeUtil::MakeShape(F32, {4});
346   HloInstruction* x =
347       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
348   HloInstruction* y = builder.AddInstruction(
349       HloInstruction::CreateReducePrecision(shape, x, 5, 10));
350 
351   auto module = CreateNewVerifiedModule();
352   auto computation = module->AddEntryComputation(builder.Build());
353 
354   // Confirm expected graph before adding ops.
355   EXPECT_THAT(x->users(), UnorderedElementsAre(y));
356   EXPECT_EQ(computation->root_instruction(), y);
357 
358   // Since the new reduce-precision operation would be redundant, this
359   // should not change the graph.
360   EXPECT_FALSE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
361                          [](const HloInstruction* instruction) {
362                            return instruction->opcode() ==
363                                   HloOpcode::kParameter;
364                          }));
365 
366   // Confirm that graph has not changed.
367   EXPECT_THAT(x->users(), UnorderedElementsAre(y));
368   EXPECT_EQ(computation->root_instruction(), y);
369 }
370 
TEST_F(ReducePrecisionInsertionTest,AddNonRedundantReducePrecision)371 TEST_F(ReducePrecisionInsertionTest, AddNonRedundantReducePrecision) {
372   auto builder = HloComputation::Builder(TestName());
373   Shape shape = ShapeUtil::MakeShape(F32, {4});
374   HloInstruction* x =
375       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
376   HloInstruction* y = builder.AddInstruction(
377       HloInstruction::CreateReducePrecision(shape, x, 8, 23));
378 
379   auto module = CreateNewVerifiedModule();
380   auto computation = module->AddEntryComputation(builder.Build());
381 
382   // Confirm expected graph before adding ops.
383   EXPECT_THAT(x->users(), UnorderedElementsAre(y));
384   EXPECT_EQ(computation->root_instruction(), y);
385 
386   // Since the new reduce-precision operation is not the same as the existing
387   // one, this should add a new one.
388   EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS,
389                         [](const HloInstruction* instruction) {
390                           return instruction->opcode() == HloOpcode::kParameter;
391                         }));
392 
393   // Confirm that graph is as expected.
394   EXPECT_EQ(computation->root_instruction(), y);
395   EXPECT_THAT(y->operand(0), op::ReducePrecision(x));
396 }
397 
TEST_F(ReducePrecisionInsertionTest,IgnoreOpsInsideFusionNode)398 TEST_F(ReducePrecisionInsertionTest, IgnoreOpsInsideFusionNode) {
399   auto builder = HloComputation::Builder(TestName());
400   Shape shape = ShapeUtil::MakeShape(F32, {4});
401   HloInstruction* x =
402       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
403   HloInstruction* y = builder.AddInstruction(
404       HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
405   auto module = CreateNewVerifiedModule();
406   auto computation = module->AddEntryComputation(builder.Build());
407 
408   // Manually fuse the kCos operation into a fusion operation.
409   HloInstruction* z = computation->AddInstruction(HloInstruction::CreateFusion(
410       shape, HloInstruction::FusionKind::kLoop, y));
411   EXPECT_IS_OK(y->ReplaceAllUsesWith(z));
412   EXPECT_IS_OK(computation->RemoveInstruction(y));
413 
414   // Confirm expected graph before adding reduce-precision ops.
415   EXPECT_THAT(x->users(), UnorderedElementsAre(z));
416   EXPECT_EQ(computation->root_instruction(), z);
417   HloInstruction* y_fused = z->fused_expression_root();
418   EXPECT_EQ(y_fused->opcode(), HloOpcode::kCos);
419 
420   // The ReducePrecisionInsertion pass should not see inside the fusion
421   // operation, so this should not change the graph.
422   EXPECT_FALSE(InsertOps(module.get(),
423                          HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS,
424                          [](const HloInstruction* instruction) {
425                            return instruction->opcode() == HloOpcode::kCos;
426                          }));
427 
428   // Confirm that graph has not changed.
429   EXPECT_THAT(x->users(), UnorderedElementsAre(z));
430   EXPECT_EQ(computation->root_instruction(), z);
431   EXPECT_EQ(z->fused_expression_root(), y_fused);
432 }
433 
TEST_F(ReducePrecisionInsertionTest,OpGetsInsertedInHeadOfFusionNode)434 TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInHeadOfFusionNode) {
435   auto builder = HloComputation::Builder(TestName());
436   Shape shape = ShapeUtil::MakeShape(F32, {4});
437   HloInstruction* x =
438       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
439   HloInstruction* y = builder.AddInstruction(
440       HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
441   auto module = CreateNewVerifiedModule();
442   auto computation = module->AddEntryComputation(builder.Build());
443 
444   // Manually fuse the kCos operation into a fusion operation.
445   HloInstruction* z = computation->AddInstruction(HloInstruction::CreateFusion(
446       shape, HloInstruction::FusionKind::kLoop, y));
447   EXPECT_IS_OK(y->ReplaceAllUsesWith(z));
448   EXPECT_IS_OK(computation->RemoveInstruction(y));
449 
450   // Confirm expected graph before adding reduce-precision ops.
451   EXPECT_THAT(x->users(), UnorderedElementsAre(z));
452   EXPECT_EQ(computation->root_instruction(), z);
453   HloInstruction* y_fused = z->fused_expression_root();
454   EXPECT_EQ(y_fused->opcode(), HloOpcode::kCos);
455 
456   // This should see that the fusion computation contains a kCos operation,
457   // and insert a new reduce-precision node at its input.
458   EXPECT_TRUE(InsertOps(module.get(),
459                         HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT,
460                         [](const HloInstruction* instruction) {
461                           return instruction->opcode() == HloOpcode::kCos;
462                         }));
463 
464   // This should refuse to insert a second reduce-precision operation, as
465   // it would be redundant with the first.
466   EXPECT_FALSE(InsertOps(module.get(),
467                          HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT,
468                          [](const HloInstruction* instruction) {
469                            return instruction->opcode() == HloOpcode::kCos;
470                          }));
471 
472   // Confirm that the top-level computation still only contains the fusion
473   // instruction, but that the fused computation now has a reduce-precision
474   // instruction inserted after its parameter instruction.
475   EXPECT_THAT(x->users(), UnorderedElementsAre(z));
476   EXPECT_EQ(computation->root_instruction(), z);
477   EXPECT_THAT(z->fused_expression_root(), y_fused);
478   EXPECT_THAT(y_fused->operand(0), op::ReducePrecision(op::Parameter()));
479 }
480 
TEST_F(ReducePrecisionInsertionTest,OpGetsInsertedInTailOfFusionNode)481 TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInTailOfFusionNode) {
482   auto builder = HloComputation::Builder(TestName());
483   Shape shape = ShapeUtil::MakeShape(F32, {4});
484   HloInstruction* x =
485       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
486   HloInstruction* y = builder.AddInstruction(
487       HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
488   auto module = CreateNewVerifiedModule();
489   auto computation = module->AddEntryComputation(builder.Build());
490 
491   // Manually fuse the kCos operation into a fusion operation.
492   HloInstruction* z = computation->AddInstruction(HloInstruction::CreateFusion(
493       shape, HloInstruction::FusionKind::kLoop, y));
494   EXPECT_IS_OK(y->ReplaceAllUsesWith(z));
495   EXPECT_IS_OK(computation->RemoveInstruction(y));
496 
497   // Confirm expected graph before adding reduce-precision ops.
498   EXPECT_THAT(x->users(), UnorderedElementsAre(z));
499   EXPECT_EQ(computation->root_instruction(), z);
500   HloInstruction* y_fused = z->fused_expression_root();
501   EXPECT_EQ(y_fused->opcode(), HloOpcode::kCos);
502 
503   // This should see that the fusion computation contains a kCos operation,
504   // and insert a new reduce-precision node at its root.
505   EXPECT_TRUE(InsertOps(module.get(),
506                         HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT,
507                         [](const HloInstruction* instruction) {
508                           return instruction->opcode() == HloOpcode::kCos;
509                         }));
510 
511   // This should refuse to insert a second reduce-precision operation, as
512   // it would be redundant with the first.
513   EXPECT_FALSE(InsertOps(module.get(),
514                          HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT,
515                          [](const HloInstruction* instruction) {
516                            return instruction->opcode() == HloOpcode::kCos;
517                          }));
518 
519   // Confirm that the top-level computation still only contains the fusion
520   // instruction, but that the fused computation now has a reduce-precision
521   // instruction inserted as its root.
522   EXPECT_THAT(x->users(), UnorderedElementsAre(z));
523   EXPECT_EQ(computation->root_instruction(), z);
524   EXPECT_THAT(z->fused_expression_root(), op::ReducePrecision(y_fused));
525 }
526 
TEST_F(ReducePrecisionInsertionTest,MakeFilterFunctionNoSubstrings)527 TEST_F(ReducePrecisionInsertionTest, MakeFilterFunctionNoSubstrings) {
528   auto builder = HloComputation::Builder(TestName());
529   Shape shape = ShapeUtil::MakeShape(F32, {4});
530   HloInstruction* a =
531       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
532   HloInstruction* b = builder.AddInstruction(
533       HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
534   HloInstruction* c = builder.AddInstruction(
535       HloInstruction::CreateUnary(shape, HloOpcode::kSin, a));
536 
537   auto options_proto = ReducePrecisionInsertion::make_options_proto(
538       HloReducePrecisionOptions::OP_OUTPUTS, 5, 10,
539       [](const HloOpcode opcode) { return opcode == HloOpcode::kCos; });
540 
541   auto filter_function =
542       ReducePrecisionInsertion::make_filter_function(options_proto);
543 
544   EXPECT_TRUE(filter_function(b));
545   EXPECT_FALSE(filter_function(c));
546 }
547 
TEST_F(ReducePrecisionInsertionTest,MakeFilterFunctionWithSubstrings)548 TEST_F(ReducePrecisionInsertionTest, MakeFilterFunctionWithSubstrings) {
549   auto builder = HloComputation::Builder(TestName());
550   Shape shape = ShapeUtil::MakeShape(F32, {4});
551   HloInstruction* a =
552       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
553 
554   HloInstruction* b = builder.AddInstruction(
555       HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
556   OpMetadata b_metadata;
557   b_metadata.set_op_name("FlowTensor/foom");
558   b->set_metadata(b_metadata);
559 
560   HloInstruction* c = builder.AddInstruction(
561       HloInstruction::CreateUnary(shape, HloOpcode::kCos, a));
562   OpMetadata c_metadata;
563   c_metadata.set_op_name("FlowTensor/barn");
564   c->set_metadata(c_metadata);
565 
566   auto options_proto = ReducePrecisionInsertion::make_options_proto(
567       HloReducePrecisionOptions::OP_OUTPUTS, 5, 10,
568       [](const HloOpcode opcode) { return opcode == HloOpcode::kCos; },
569       {"foo", "baz"});
570 
571   auto filter_function =
572       ReducePrecisionInsertion::make_filter_function(options_proto);
573 
574   EXPECT_TRUE(filter_function(b));
575   EXPECT_FALSE(filter_function(c));
576 }
577 
578 }  // namespace xla
579