/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"

#include <iterator>
#include <vector>

#include "absl/algorithm/container.h"  // for absl::c_all_of
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {
namespace gpu {

namespace {
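// Appends the inputs of 'instr' to 'params': for a fusion instruction, the
// parameters of its fused computation; for any other instruction, its
// operands.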
void AppendParams(const HloInstruction& instr,
                  std::vector<HloInstruction*>* params) {
  if (instr.opcode() == HloOpcode::kFusion) {
    params->insert(std::end(*params), std::begin(instr.fused_parameters()),
                   std::end(instr.fused_parameters()));
  } else {
    for (HloInstruction* operand : instr.operands()) {
      params->push_back(operand);
    }
  }
}
}  // namespace

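// Layouts are friendly for reduce input fusion if, among the parameters of
// 'producer' and 'reduce', all array parameters of maximal rank share the
// same layout. Lower-rank parameters are not constrained.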
bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
                                         const HloInstruction& reduce) {
  std::vector<HloInstruction*> params;
  AppendParams(producer, &params);
  AppendParams(reduce, &params);
  int64 max_rank = -1;
  const Layout* max_rank_layout = nullptr;
  for (HloInstruction* param : params) {
    if (param->shape().IsArray() && param->shape().rank() > max_rank) {
      max_rank = param->shape().rank();
      max_rank_layout = &param->shape().layout();
    }
  }
  return absl::c_all_of(params, [&](HloInstruction* param) {
    return (!param->shape().IsArray()) || (param->shape().rank() < max_rank) ||
           (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
  });
}

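// Returns true if 'instr' is a fusion whose root is a reduction-to-vector op
// or, for multi-output fusions, has a reduction-to-vector op among the root's
// operands. Such fusions must be of kind kInput.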
bool IsReduceInputFusion(const HloInstruction& instr) {
  if (instr.IsMultiOutputFusion()) {
    for (const HloInstruction* operand :
         instr.fused_expression_root()->operands()) {
      if (IsReductionToVector(*operand)) {
        CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput)
            << " Multi-output fusion rooted at reduction-to-vector ops must be "
               "of kind kInput: "
            << instr.ToString();
        return true;
      }
    }
  } else if (instr.opcode() == HloOpcode::kFusion &&
             IsReductionToVector(*instr.fused_expression_root())) {
    CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput)
        << " Fusion rooted at reduction-to-vector op must be of kind kInput: "
        << instr.ToString();
    return true;
  }
  return false;
}

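// A reduction is input-fusible if it is an existing reduce input fusion or an
// unfused reduction-to-vector op.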
bool IsInputFusibleReduction(const HloInstruction& instr) {
  return IsReduceInputFusion(instr) || IsReductionToVector(instr);
}

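// Two instructions are compatible for multi-output fusion if their "real
// hero" instructions agree on the shape of the shared parallel loop, as
// detailed in the comments below.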
bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
                                          const HloInstruction& instr2) {
  // Returns the instruction that determines the emitter used for lowering,
  // sometimes referred to as "the real hero".
  auto get_real_hero =
      [&](const HloInstruction* instr) -> const HloInstruction* {
    if (instr->opcode() == HloOpcode::kFusion) {
      auto fused_expression_root = instr->fused_expression_root();
      if (instr->IsMultiOutputFusion()) {
        // If possible, we want to pick a reduction-to-vector operand of the
        // fusion root, because it has the most constraints.
        for (const auto* inst : fused_expression_root->operands()) {
          if (IsReductionToVector(*inst)) {
            return inst;
          }
        }
        return fused_expression_root->operands()[0];
      }
      return fused_expression_root;
    }
    return instr;
  };

  // Multi-output fusion kernels share a common parallel loop. The loop
  // dimensions are determined by instruction shapes.
  auto get_loop_shape = [&](const HloInstruction* element_instr) {
    // Special-case reduction-to-vector ops: The loop dimensions are determined
    // by the shape of the first operand.
    if (IsReductionToVector(*element_instr)) {
      return element_instr->operand(0)->shape();
    }
    return element_instr->shape();
  };

  // All shapes of the root tuple of multi-output fusions should agree, i.e.
  // all root ops should have equal output shapes. The exception is
  // reduction-to-vector ops: for those, the reduction input shapes (first
  // operand shape) and the reduction dimensions need to match.
  auto* instr_1 = get_real_hero(&instr1);
  auto* instr_2 = get_real_hero(&instr2);
  // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
  if (IsReductionToVector(*instr_1) && IsReductionToVector(*instr_2) &&
      (!ShapeUtil::Equal(instr_1->shape(), instr_2->shape()) ||
       instr_1->dimensions() != instr_2->dimensions())) {
    return false;
  }
  // The elementwise output shapes must be the same (including layout).
  // TODO(tjoerg): Further relax the constraint. The datatype does not matter.
  return ShapeUtil::EqualIgnoringFpPrecision(get_loop_shape(instr_1),
                                             get_loop_shape(instr_2));
}

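// Scatter is input-fusible both when unfused and as the root of a kInput
// fusion.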
bool IsInputFusibleScatter(const HloInstruction& instr) {
  if (instr.opcode() == HloOpcode::kScatter ||
      (instr.opcode() == HloOpcode::kFusion &&
       instr.fusion_kind() == HloInstruction::FusionKind::kInput &&
       instr.fused_expression_root()->opcode() == HloOpcode::kScatter)) {
    return true;
  }
  return false;
}

bool IsInputFusible(const HloInstruction& instr) {
  // Input fusion only handles non-elementwise reductions and scatter
  // operations.
  return IsInputFusibleReduction(instr) || IsInputFusibleScatter(instr);
}

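// Loop fusion covers elementwise instructions plus the fixed set of data
// movement ops below that the loop emitter can handle.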
bool IsLoopFusible(const HloInstruction& instr) {
  // Don't fuse get-tuple-element on GPU: We can, but it's slower than not
  // fusing. We never generate kernels for unfused GTEs. Instead, if an
  // unfused GTE is an input to a kernel (including a fusion kernel), we
  // compute the address of the GTE at the top of the kernel. Often we know the
  // address of the GTE result statically, so we can do this without chasing any
  // pointers.
  return (instr.IsElementwise() && instr.operand_count() > 0) ||
         instr.opcode() == HloOpcode::kBitcast ||
         instr.opcode() == HloOpcode::kBroadcast ||
         instr.opcode() == HloOpcode::kConcatenate ||
         instr.opcode() == HloOpcode::kDynamicSlice ||
         instr.opcode() == HloOpcode::kDynamicUpdateSlice ||
         (instr.opcode() == HloOpcode::kFusion &&
          instr.fusion_kind() == HloInstruction::FusionKind::kLoop) ||
         instr.opcode() == HloOpcode::kGather ||
         instr.opcode() == HloOpcode::kIota ||
         instr.opcode() == HloOpcode::kPad ||
         (instr.opcode() == HloOpcode::kReduce &&
          !IsReductionToVector(instr)) ||
         instr.opcode() == HloOpcode::kReduceWindow ||
         instr.opcode() == HloOpcode::kReshape ||
         instr.opcode() == HloOpcode::kReverse ||
         instr.opcode() == HloOpcode::kSlice ||
         instr.opcode() == HloOpcode::kTranspose;
}

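// An instruction is fusible if either input fusion or loop fusion applies.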
bool IsFusible(const HloInstruction& instr) {
  return IsInputFusible(instr) || IsLoopFusible(instr);
}

}  // namespace gpu
}  // namespace xla