/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {

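// Usage sketch (illustrative, not part of this file): the GPU compiler
// typically registers this pass on an HloPassPipeline, roughly as follows.
// The exact registration lives in gpu_compiler.cc and may differ.
//
//   HloPassPipeline pipeline("fusion");
//   pipeline.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
//   TF_RETURN_IF_ERROR(pipeline.Run(module).status());
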
namespace {
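// Returns true if `shape` has element type F32 or F16.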
bool ElementIsF32OrF16(const Shape& shape) {
  PrimitiveType type = shape.element_type();
  return type == F32 || type == F16;
}
}  // namespace

/*static*/ bool GpuInstructionFusion::IsExpensive(
    const HloInstruction& instruction) {
  // We say that some floating-point math ops are cheap on the GPU. Unlike
  // other intrinsics that can be expanded into many instructions, Div and
  // Rsqrt are lowered into single hardware instructions.
  switch (instruction.opcode()) {
    case HloOpcode::kDivide:
    case HloOpcode::kRsqrt:
      if (ElementIsF32OrF16(instruction.shape())) {
        return false;
      }
      break;
    default:
      break;
  }
  return InstructionFusion::IsExpensive(instruction);
}

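// Cheap, local checks for whether the operand at `operand_index` of
// `consumer` may be fused into it; the potentially expensive checks live in
// ShouldFuse below.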
bool GpuInstructionFusion::ShouldFuseInexpensiveChecks(HloInstruction* consumer,
                                                       int64 operand_index) {
  HloInstruction* producer = consumer->mutable_operand(operand_index);

  // Output fusions are not currently supported on GPUs.
  if (producer->opcode() == HloOpcode::kFusion) {
    VLOG(4) << "Producer " << producer->name() << " is a fusion op";
    return false;
  }
  // Cost condition: do not fuse an expensive producer into a consumer that
  // reuses the producer's output elements, since fusion would recompute the
  // expensive producer once per reuse. (The producer is known not to be a
  // fusion at this point, so it is a single "simple" instruction.)
  if (is_expensive(*producer) &&
      ReusesOperandElements(consumer, operand_index)) {
    VLOG(4) << "Do not fuse expensive producer " << producer->name()
            << " into a consumer which reuses its operand elements.";
    return false;
  }

  if (!IsProducerConsumerFusible(*producer, *consumer) ||
      !InstructionFusion::ShouldFuse(consumer, operand_index)) {
    VLOG(4) << "Producer " << producer->name()
            << " is not fusible or should not be fused.";
    return false;
  }
  return true;
}

bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
                                      int64 operand_index) {
  if (!ShouldFuseInexpensiveChecks(consumer, operand_index)) {
    VLOG(5) << "Inexpensive checks failed for operand " << operand_index
            << " of " << consumer->ToString();
    return false;
  }
  auto producer = consumer->operand(operand_index);

  // The following checks are potentially expensive.
  if (FusionWouldBeTooLarge(*consumer, *producer,
                            /*is_consumer_producer_fusion=*/true)) {
    VLOG(5) << "Fusion of (" << producer->ToString() << ") into ("
            << consumer->ToString() << ") would be too large";
    return false;
  }
  if (consumer->opcode() != HloOpcode::kFusion) {
    return true;
  }
  // Also check that our emitter can handle the fusion node. We currently can
  // have exponential time/memory requirements for emitting certain fusion
  // kernels, in which case we don't want to fuse.
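  // (Roughly: an operand that is read with several distinct indices inside a
  // fusion has its code emitted once per indexing user, and that duplication
  // can compound multiplicatively as the fusion grows deeper.)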
  // TODO(b/119692968): Remove this once we have fixed our fusion emitter.
  if (fusion_node_evaluations_.find(consumer) ==
      fusion_node_evaluations_.end()) {
    // We have no cached results for this fusion node yet. This can happen when
    // we run the InstructionFusion pass more than once. We can only cache the
    // results within one run.
    fusion_node_evaluations_.emplace(consumer,
                                     FusionNodeIndexingEvaluation(consumer));
  }
  if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh(producer)) {
    VLOG(5) << "Fusion of " << producer->name() << " into " << consumer->name()
            << " would result in overly large code duplication.";
    return false;
  }
  return true;
}

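// This pass never creates multi-output fusions; on GPU those are formed by a
// separate multi-output fusion pass.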
bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer,
                                                     int64 operand_index) {
  return false;
}

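// Fusion kind selection (e.g. kLoop vs. kInput) is shared across GPU fusion
// passes via the ChooseFusionKind helper.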
HloInstruction::FusionKind GpuInstructionFusion::ChooseKind(
    const HloInstruction* producer, const HloInstruction* consumer) {
  return ChooseFusionKind(*producer, *consumer);
}

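// In addition to performing the fusion, keeps the cached
// FusionNodeIndexingEvaluation for `fusion_instruction` in sync so that later
// CodeDuplicationTooHigh queries stay accurate.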
HloInstruction* GpuInstructionFusion::FuseInstruction(
    HloInstruction* fusion_instruction, HloInstruction* producer) {
  auto evaluation = fusion_node_evaluations_.find(fusion_instruction);
  if (evaluation == fusion_node_evaluations_.end()) {
    evaluation = fusion_node_evaluations_
                     .emplace(fusion_instruction,
                              FusionNodeIndexingEvaluation(fusion_instruction))
                     .first;
  }
  auto indexing_users = evaluation->second.RemoveFusionOperand(producer);
  HloInstruction* new_producer =
      InstructionFusion::FuseInstruction(fusion_instruction, producer);
  evaluation->second.UpdateEvaluationCache(new_producer, indexing_users);
  return new_producer;
}

}  // namespace gpu
}  // namespace xla