/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"

#include <iterator>
#include <vector>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {
namespace gpu {

namespace {
// Appends the relevant parameters of 'instr' to 'params': the fused
// parameters if 'instr' is a fusion, its operands otherwise.
void AppendParams(const HloInstruction& instr,
                  std::vector<HloInstruction*>* params) {
  if (instr.opcode() == HloOpcode::kFusion) {
    params->insert(std::end(*params), std::begin(instr.fused_parameters()),
                   std::end(instr.fused_parameters()));
  } else {
    for (HloInstruction* operand : instr.operands()) {
      params->push_back(operand);
    }
  }
}
}  // namespace

bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
                                         const HloInstruction& reduce) {
  std::vector<HloInstruction*> params;
  AppendParams(producer, &params);
  AppendParams(reduce, &params);
  int64 max_rank = -1;
  const Layout* max_rank_layout = nullptr;
  for (HloInstruction* param : params) {
    if (param->shape().IsArray() && param->shape().rank() > max_rank) {
      max_rank = param->shape().rank();
      max_rank_layout = &param->shape().layout();
    }
  }
  return absl::c_all_of(params, [&](HloInstruction* param) {
    return (!param->shape().IsArray()) || (param->shape().rank() < max_rank) ||
           (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
  });
}
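
// For example, if one highest-rank parameter has shape f32[64,32]{1,0} and
// another has f32[64,32]{0,1}, the layouts differ and the producer is not
// considered friendly to fuse into a reduce input fusion: mixing physical
// layouts among the highest-rank inputs can break the coalesced memory
// accesses that the reduction emitter depends on.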

bool IsReduceInputFusion(const HloInstruction& instr) {
  if (instr.IsMultiOutputFusion()) {
    for (const HloInstruction* operand :
         instr.fused_expression_root()->operands()) {
      if (IsReductionToVector(*operand)) {
        CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput)
            << " Multi-output fusion rooted at reduction-to-vector ops must be "
               "of kind kInput: "
            << instr.ToString();
        return true;
      }
    }
  } else if (instr.opcode() == HloOpcode::kFusion &&
             IsReductionToVector(*instr.fused_expression_root())) {
    CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput)
        << " Fusion rooted at reduction-to-vector op must be of kind kInput: "
        << instr.ToString();
    return true;
  }
  return false;
}
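
// For example, IsReduceInputFusion returns true for a multi-output fusion of
// kind kInput whose root tuple contains a reduction-to-vector op (the `add`
// reduction computation is assumed to be defined elsewhere):
//
//   fused_computation {
//     p0 = f32[128,512]{1,0} parameter(0)
//     c0 = f32[] constant(0)
//     mul = f32[128,512]{1,0} multiply(p0, p0)
//     reduce = f32[128]{0} reduce(mul, c0), dimensions={1}, to_apply=add
//     ROOT tuple = (f32[128]{0}, f32[128,512]{1,0}) tuple(reduce, mul)
//   }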

bool IsInputFusibleReduction(const HloInstruction& instr) {
  return IsReduceInputFusion(instr) || IsReductionToVector(instr);
}

bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
                                          const HloInstruction& instr2) {
  // Returns the instruction that determines the emitter used for lowering,
  // sometimes referred to as "the real hero".
  auto get_real_hero =
      [&](const HloInstruction* instr) -> const HloInstruction* {
    if (instr->opcode() == HloOpcode::kFusion) {
      auto fused_expression_root = instr->fused_expression_root();
      if (instr->IsMultiOutputFusion()) {
        // If possible, we want to pick a reduction-to-vector operand of the
        // fusion root, because it has the most constraints.
        for (const auto* inst : fused_expression_root->operands()) {
          if (IsReductionToVector(*inst)) {
            return inst;
          }
        }
        return fused_expression_root->operands()[0];
      }
      return fused_expression_root;
    }
    return instr;
  };

  // Multi-output fusion kernels share a common parallel loop. The loop
  // dimensions are determined by instruction shapes.
  auto get_loop_shape = [&](const HloInstruction* element_instr) {
    // Special-case reduction-to-vector ops: The loop dimensions are determined
    // by the shape of the first operand.
    if (IsReductionToVector(*element_instr)) {
      return element_instr->operand(0)->shape();
    }
    return element_instr->shape();
  };

  // All shapes of the root tuple of multi-output fusions should agree, i.e.
  // all root ops should have equal output shapes. An exception is
  // reduction-to-vector ops: here, the input shape of the reduction (the
  // shape of the first operand) and the reduction dimensions must match.
  auto* instr_1 = get_real_hero(&instr1);
  auto* instr_2 = get_real_hero(&instr2);
  // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
  if (IsReductionToVector(*instr_1) && IsReductionToVector(*instr_2) &&
      (!ShapeUtil::Equal(instr_1->shape(), instr_2->shape()) ||
       instr_1->dimensions() != instr_2->dimensions())) {
    return false;
  }
  // The elementwise output shapes must be the same (including layout).
  // TODO(tjoerg): Further relax the constraint. The datatype does not matter.
  return ShapeUtil::EqualIgnoringFpPrecision(get_loop_shape(instr_1),
                                             get_loop_shape(instr_2));
}
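
// For example, a reduce of f32[200,300]{1,0} along dimensions={1} is
// compatible with an elementwise op producing f32[200,300]{1,0}: the loop
// shape of the reduction is its input shape f32[200,300]{1,0}, which matches
// the elementwise output shape. Two reductions are compatible only if their
// output shapes and reduction dimensions match and their input shapes agree
// up to floating-point precision.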

bool IsInputFusibleScatter(const HloInstruction& instr) {
  if (instr.opcode() == HloOpcode::kScatter ||
      (instr.opcode() == HloOpcode::kFusion &&
       instr.fusion_kind() == HloInstruction::FusionKind::kInput &&
       instr.fused_expression_root()->opcode() == HloOpcode::kScatter)) {
    return true;
  }
  return false;
}
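
// For example, an unfused scatter instruction qualifies, as does a kInput
// fusion rooted at a scatter, sketched below (the attributes follow the
// canonical 1-D scatter-of-scalars form; the `add` update computation is
// assumed to be defined elsewhere):
//
//   fused_computation {
//     operand = f32[100]{0} parameter(0)
//     indices = s32[10,1]{1,0} parameter(1)
//     updates = f32[10]{0} parameter(2)
//     ROOT scatter = f32[100]{0} scatter(operand, indices, updates),
//         update_window_dims={}, inserted_window_dims={0},
//         scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=add
//   }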

bool IsInputFusible(const HloInstruction& instr) {
  // Input fusion only handles non-elemental reduction and scatter operations.
  return IsInputFusibleReduction(instr) || IsInputFusibleScatter(instr);
}

bool IsLoopFusible(const HloInstruction& instr) {
  // Don't fuse get-tuple-element on GPU: We can, but it's slower than not
  // fusing.  We never generate kernels for unfused GTEs.  Instead, if an
  // unfused GTE is an input to a kernel (including a fusion kernel), we
  // compute the address of the GTE at the top of the kernel.  Often we know the
  // address of the GTE result statically, so we can do this without chasing any
  // pointers.
  return (instr.IsElementwise() && instr.operand_count() > 0) ||
         instr.opcode() == HloOpcode::kBitcast ||
         instr.opcode() == HloOpcode::kBroadcast ||
         instr.opcode() == HloOpcode::kConcatenate ||
         instr.opcode() == HloOpcode::kDynamicSlice ||
         instr.opcode() == HloOpcode::kDynamicUpdateSlice ||
         (instr.opcode() == HloOpcode::kFusion &&
          instr.fusion_kind() == HloInstruction::FusionKind::kLoop) ||
         instr.opcode() == HloOpcode::kGather ||
         instr.opcode() == HloOpcode::kIota ||
         instr.opcode() == HloOpcode::kPad ||
         (instr.opcode() == HloOpcode::kReduce &&
          !IsReductionToVector(instr)) ||
         instr.opcode() == HloOpcode::kReduceWindow ||
         instr.opcode() == HloOpcode::kReshape ||
         instr.opcode() == HloOpcode::kReverse ||
         instr.opcode() == HloOpcode::kSlice ||
         instr.opcode() == HloOpcode::kTranspose;
}
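
// For example, elementwise ops such as add, as well as broadcast, slice, and
// existing kLoop fusions are loop-fusible. The operand_count() > 0 check
// excludes nullary elementwise ops such as constants, and reductions to a
// vector are excluded here because they are handled by the input-fusion path
// (IsInputFusible) instead.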

bool IsFusible(const HloInstruction& instr) {
  return IsInputFusible(instr) || IsLoopFusible(instr);
}

}  // namespace gpu
}  // namespace xla