/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/transpose_folding.h"

#include <memory>
#include <vector>

#include "absl/algorithm/container.h"

#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

namespace {

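// Returns the operand indices of `dot` that are rank-2 transposes, as
// filtered by the `transposable_gemm_operands` callback. Returns an empty set
// if `dot` is not a batch-free dot, or if any operand is neither a rank-2
// transpose nor itself rank 2.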
TransposeFolding::OperandIndices CanFoldOperandsIntoDot(
    const HloInstruction& dot,
    const TransposeFolding::TransposableGemmOperandsFn&
        transposable_gemm_operands) {
  if (HloOpcode::kDot != dot.opcode() ||
      dot.dot_dimension_numbers().lhs_batch_dimensions_size() != 0) {
    return {};
  }

  TransposeFolding::OperandIndices operand_set;
  for (int64 i = 0; i < dot.operand_count(); ++i) {
    auto& operand = *dot.operand(i);
    if (operand.IsRank2Transpose()) {
      operand_set.push_back(i);
    } else if (operand.shape().rank() != 2) {
      return {};
    }
  }

  return transposable_gemm_operands(dot, operand_set);
}

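// Returns the operand indices of `convolution` that are transposes, as
// filtered by the `transposable_conv_operands` callback. Returns an empty set
// if the instruction is not a convolution.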
TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution(
    const HloInstruction& convolution,
    const TransposeFolding::TransposableConvOperandsFn&
        transposable_conv_operands) {
  if (HloOpcode::kConvolution != convolution.opcode()) {
    return {};
  }

  TransposeFolding::OperandIndices operand_set;
  for (int64 i = 0; i < convolution.operand_count(); ++i) {
    auto& operand = *convolution.operand(i);
    if (operand.opcode() == HloOpcode::kTranspose) {
      operand_set.push_back(i);
    }
  }

  return transposable_conv_operands(convolution, operand_set);
}

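// An HLO instruction paired with the indices of its operands that are
// foldable transposes.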
using InstructionOperandsPair =
    std::pair<HloInstruction*, TransposeFolding::OperandIndices>;

// Folds the operands of the dot `pair.first` named in `pair.second` that are
// foldable transposes, replacing the dot within its parent computation.
Status FoldTransposeIntoDot(InstructionOperandsPair pair) {
  HloInstruction* dot = pair.first;

  DotDimensionNumbers new_dim_numbers = dot->dot_dimension_numbers();
  HloInstruction* new_lhs = dot->mutable_operand(0);
  HloInstruction* new_rhs = dot->mutable_operand(1);

  CHECK_EQ(new_dim_numbers.lhs_batch_dimensions_size(), 0);
  CHECK_EQ(new_dim_numbers.rhs_batch_dimensions_size(), 0);
  CHECK_EQ(new_dim_numbers.lhs_contracting_dimensions_size(), 1);
  CHECK_EQ(new_dim_numbers.rhs_contracting_dimensions_size(), 1);

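  // Each folded operand strips off one rank-2 transpose and flips that side's
  // contracting dimension: e.g. dot(transpose(x), y) contracting lhs
  // dimension 1 becomes dot(x, y) contracting lhs dimension 0.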
  for (int64 operand_index : pair.second) {
    // We've checked that there aren't any batch dimensions and that the inputs
    // are rank 2, and shape inference guarantees that there is exactly one
    // contracting dimension.
    if (operand_index == 0) {
      CHECK_EQ(new_lhs->opcode(), HloOpcode::kTranspose);
      new_dim_numbers.set_lhs_contracting_dimensions(
          0, 1 - new_dim_numbers.lhs_contracting_dimensions(0));
      new_lhs = new_lhs->mutable_operand(0);
    } else {
      CHECK_EQ(operand_index, 1);
      CHECK_EQ(new_rhs->opcode(), HloOpcode::kTranspose);
      new_dim_numbers.set_rhs_contracting_dimensions(
          0, 1 - new_dim_numbers.rhs_contracting_dimensions(0));
      new_rhs = new_rhs->mutable_operand(0);
    }
  }

  std::unique_ptr<HloInstruction> new_dot = HloInstruction::CreateDot(
      dot->shape(), new_lhs, new_rhs, new_dim_numbers, dot->precision_config());
  return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot));
}

// Folds the operands of the convolution `pair.first` named in `pair.second`
// that are foldable transposes.
//
// Returns whether the computation was changed.
bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
  auto& convolution = *pair.first;
  auto& operand_indices = pair.second;

  if (operand_indices.empty()) {
    return false;
  }

  const ConvolutionDimensionNumbers& dnums =
      convolution.convolution_dimension_numbers();
  ConvolutionDimensionNumbers new_dnums = dnums;

  HloInstruction* new_lhs;
  const int64 kLhsIdx = 0;
  if (absl::c_linear_search(operand_indices, kLhsIdx)) {
    HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx);
    const auto& transpose_dimensions = transpose.dimensions();
    HloInstruction& transpose_operand = *transpose.mutable_operand(0);

    // Everything remains the same except for the input/output dimension
    // numbers. We need to apply the transpose permutation to the original
    // shape to figure out what the new logical dimensions are.
    new_dnums.set_input_batch_dimension(
        transpose_dimensions[dnums.input_batch_dimension()]);
    new_dnums.set_input_feature_dimension(
        transpose_dimensions[dnums.input_feature_dimension()]);
    for (auto& input_spatial_dimension :
         *new_dnums.mutable_input_spatial_dimensions()) {
      input_spatial_dimension = transpose_dimensions[input_spatial_dimension];
    }
    new_lhs = &transpose_operand;
  } else {
    new_lhs = convolution.mutable_operand(kLhsIdx);
  }

  HloInstruction* new_rhs;
  const int64 kRhsIdx = 1;
  if (absl::c_linear_search(operand_indices, kRhsIdx)) {
    HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx);
    const auto& transpose_dimensions = transpose.dimensions();
    HloInstruction& transpose_operand = *transpose.mutable_operand(0);

    // Everything remains the same except for the kernel dimension numbers. We
    // need to apply the transpose permutation to the original shape to figure
    // out what the new logical dimensions are.
    new_dnums.set_kernel_input_feature_dimension(
        transpose_dimensions[dnums.kernel_input_feature_dimension()]);
    new_dnums.set_kernel_output_feature_dimension(
        transpose_dimensions[dnums.kernel_output_feature_dimension()]);
    for (auto& kernel_spatial_dimension :
         *new_dnums.mutable_kernel_spatial_dimensions()) {
      kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension];
    }
    new_rhs = &transpose_operand;
  } else {
    new_rhs = convolution.mutable_operand(kRhsIdx);
  }

  auto new_conv = HloInstruction::CreateConvolve(
      convolution.shape(), new_lhs, new_rhs, convolution.feature_group_count(),
      convolution.batch_group_count(), convolution.window(), new_dnums,
      convolution.precision_config());
  TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
      &convolution, std::move(new_conv)));

  return true;
}

}  // namespace

TransposeFolding::TransposeFolding(
    TransposableGemmOperandsFn transposable_gemm_operands,
    TransposableConvOperandsFn transposable_conv_operands)
    : transposable_gemm_operands_(std::move(transposable_gemm_operands)),
      transposable_conv_operands_(std::move(transposable_conv_operands)) {}

StatusOr<bool> TransposeFolding::Run(HloModule* module) {
  // Modifying the graph while traversing is dangerous, so we find all folding
  // opportunities before actually folding them.
  std::vector<std::pair<HloInstruction*, OperandIndices>> foldable_dots;
  std::vector<std::pair<HloInstruction*, OperandIndices>> foldable_convolutions;
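  // Record every dot and convolution whose transposed operands the configured
  // callbacks consider foldable.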
  auto visit_fn = [this, &foldable_dots,
                   &foldable_convolutions](HloInstruction* instruction) {
    {
      OperandIndices operand_indices =
          CanFoldOperandsIntoDot(*instruction, transposable_gemm_operands_);
      if (!operand_indices.empty()) {
        foldable_dots.emplace_back(instruction, operand_indices);
      }
    }
    {
      OperandIndices operand_indices = CanFoldOperandsIntoConvolution(
          *instruction, transposable_conv_operands_);
      if (!operand_indices.empty()) {
        foldable_convolutions.emplace_back(instruction, operand_indices);
      }
    }
    return Status::OK();
  };

  for (auto* comp : module->MakeNonfusionComputations()) {
    TF_RETURN_IF_ERROR(comp->Accept(visit_fn));
  }

  bool changed = false;
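  // FoldTransposeIntoDot replaces the dot outright, so every successful fold
  // is a change; FoldTransposeIntoConvolution reports whether it rewrote the
  // convolution.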
  for (InstructionOperandsPair& pair : foldable_dots) {
    TF_RETURN_IF_ERROR(FoldTransposeIntoDot(pair));
    changed = true;
  }
  for (InstructionOperandsPair& pair : foldable_convolutions) {
    changed |= FoldTransposeIntoConvolution(pair);
  }
  return changed;
}

}  // namespace xla