1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
17
18 #include "absl/container/flat_hash_set.h"
19 #include "tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h"
20 #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
21 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/service/hlo_query.h"
24 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
25 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28
29 namespace xla {
30 namespace gpu {
31
32 namespace {
ElementIsF32OrF16(const Shape & shape)33 bool ElementIsF32OrF16(const Shape& shape) {
34 PrimitiveType type = shape.element_type();
35 return type == F32 || type == F16;
36 }
37 } // namespace
38
IsExpensive(const HloInstruction & instruction)39 /*static*/ bool GpuInstructionFusion::IsExpensive(
40 const HloInstruction& instruction) {
41 // We say that some floating-point math ops are cheap on the GPU. Unlike other
42 // intrinsics that can be expanded into many instructions, Div and Rsqrt are
43 // lowered into single hardware instructions.
44 switch (instruction.opcode()) {
45 case HloOpcode::kDivide:
46 case HloOpcode::kRsqrt:
47 if (ElementIsF32OrF16(instruction.shape())) {
48 return false;
49 }
50 break;
51 default:
52 break;
53 }
54 return InstructionFusion::IsExpensive(instruction);
55 }
56
ShouldFuseInexpensiveChecks(HloInstruction * consumer,int64 operand_index)57 bool GpuInstructionFusion::ShouldFuseInexpensiveChecks(HloInstruction* consumer,
58 int64 operand_index) {
59 HloInstruction* producer = consumer->mutable_operand(operand_index);
60
61 // Output fusions are not currently supported on GPUs.
62 if (producer->opcode() == HloOpcode::kFusion) {
63 VLOG(4) << "Producer " << producer->name() << " is a fusion op";
64 return false;
65 }
66 // Cost condition: not fuse (simple, expensive producers) and (consumers who
67 // reuse operand elements).
68 if (producer->opcode() != HloOpcode::kFusion && is_expensive(*producer) &&
69 ReusesOperandElements(consumer, operand_index)) {
70 VLOG(4) << "Do not fuse simple, expensive producer " << producer->name()
71 << " and consumer which reuses operand elements.";
72 return false;
73 }
74
75 if (!IsProducerConsumerFusible(*producer, *consumer) ||
76 !InstructionFusion::ShouldFuse(consumer, operand_index)) {
77 VLOG(4) << "Producer " << producer->name()
78 << " is not fusible or should not be fused.";
79 return false;
80 }
81 return true;
82 }
83
ShouldFuse(HloInstruction * consumer,int64 operand_index)84 bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
85 int64 operand_index) {
86 if (!ShouldFuseInexpensiveChecks(consumer, operand_index)) {
87 VLOG(5) << "Not fusing inexpensive checks of operand " << operand_index
88 << " of " << consumer->ToString();
89 return false;
90 }
91 auto producer = consumer->operand(operand_index);
92
93 // The following checks are potentially expensive.
94 if (FusionWouldBeTooLarge(*consumer, *producer,
95 /*is_consumer_producer_fusion=*/true)) {
96 VLOG(5) << "Fusion of (" << producer->ToString() << ") into ("
97 << consumer->ToString() << ") would be too large";
98 return false;
99 }
100 if (consumer->opcode() != HloOpcode::kFusion) {
101 return true;
102 }
103 // Also check that our emitter can handle the fusion node. We currently can
104 // have exponential time/memory requirements for emitting certain fusion
105 // kernels, in which case we don't want to fuse.
106 // TODO(b/119692968): Remove this once we have fixed our fusion emitter.
107 if (fusion_node_evaluations_.find(consumer) ==
108 fusion_node_evaluations_.end()) {
109 // We have no cached results for this fusion node yet. This can happen when
110 // we run the InstructionFusion pass more than once. We can only cache the
111 // results within one run.
112 fusion_node_evaluations_.emplace(consumer,
113 FusionNodeIndexingEvaluation(consumer));
114 }
115 if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh(producer)) {
116 VLOG(5) << "Fusion of " << producer->name() << " into " << consumer->name()
117 << " would result in overly large code duplication.";
118 return false;
119 }
120 return true;
121 }
122
ShouldFuseIntoMultiOutput(HloInstruction * consumer,int64 operand_index)123 bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer,
124 int64 operand_index) {
125 return false;
126 }
127
// Selects the fusion kind for fusing `producer` into `consumer` by delegating
// to the free function ChooseFusionKind, which takes its arguments by
// reference. Both pointers are dereferenced unconditionally, so neither may be
// null.
HloInstruction::FusionKind GpuInstructionFusion::ChooseKind(
    const HloInstruction* producer, const HloInstruction* consumer) {
  const HloInstruction& producer_ref = *producer;
  const HloInstruction& consumer_ref = *consumer;
  return ChooseFusionKind(producer_ref, consumer_ref);
}
132
FuseInstruction(HloInstruction * fusion_instruction,HloInstruction * producer)133 HloInstruction* GpuInstructionFusion::FuseInstruction(
134 HloInstruction* fusion_instruction, HloInstruction* producer) {
135 auto evaluation = fusion_node_evaluations_.find(fusion_instruction);
136 if (evaluation == fusion_node_evaluations_.end()) {
137 evaluation = fusion_node_evaluations_
138 .emplace(fusion_instruction,
139 FusionNodeIndexingEvaluation(fusion_instruction))
140 .first;
141 }
142 auto indexing_users = evaluation->second.RemoveFusionOperand(producer);
143 HloInstruction* new_producer =
144 InstructionFusion::FuseInstruction(fusion_instruction, producer);
145 evaluation->second.UpdateEvaluationCache(new_producer, indexing_users);
146 return new_producer;
147 }
148
149 } // namespace gpu
150 } // namespace xla
151