/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/transpose_folding.h"

#include <memory>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

namespace {

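// Returns the indices of `dot`'s operands that are foldable rank-2 transposes,
// filtered through `transposable_gemm_operands`. Returns an empty list if
// `dot` is not a non-batched, rank-2 dot operation.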
TransposeFolding::OperandIndices CanFoldOperandsIntoDot(
    const HloInstruction& dot,
    const TransposeFolding::TransposableGemmOperandsFn&
        transposable_gemm_operands) {
  if (HloOpcode::kDot != dot.opcode() ||
      dot.dot_dimension_numbers().lhs_batch_dimensions_size() != 0) {
    return {};
  }

  TransposeFolding::OperandIndices operand_set;
  for (int64 i = 0; i < dot.operand_count(); ++i) {
    auto& operand = *dot.operand(i);
    if (operand.IsRank2Transpose()) {
      operand_set.push_back(i);
    } else if (operand.shape().rank() != 2) {
      return {};
    }
  }

  return transposable_gemm_operands(dot, operand_set);
}

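// Returns the indices of `convolution`'s operands that are transposes,
// filtered through `transposable_conv_operands`. Returns an empty list if
// `convolution` is not a convolution.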
TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution(
    const HloInstruction& convolution,
    const TransposeFolding::TransposableConvOperandsFn&
        transposable_conv_operands) {
  if (HloOpcode::kConvolution != convolution.opcode()) {
    return {};
  }

  TransposeFolding::OperandIndices operand_set;
  for (int64 i = 0; i < convolution.operand_count(); ++i) {
    auto& operand = *convolution.operand(i);
    if (operand.opcode() == HloOpcode::kTranspose) {
      operand_set.push_back(i);
    }
  }

  return transposable_conv_operands(convolution, operand_set);
}

using InstructionOperandsPair =
    std::pair<HloInstruction*, TransposeFolding::OperandIndices>;

// Folds the foldable transpose operands of `dot` and replaces `dot` with an
// equivalent dot instruction in its parent computation.
Status FoldTransposeIntoDot(InstructionOperandsPair pair) {
  HloInstruction* dot = pair.first;

  DotDimensionNumbers new_dim_numbers = dot->dot_dimension_numbers();
  HloInstruction* new_lhs = dot->mutable_operand(0);
  HloInstruction* new_rhs = dot->mutable_operand(1);

  CHECK_EQ(new_dim_numbers.lhs_batch_dimensions_size(), 0);
  CHECK_EQ(new_dim_numbers.rhs_batch_dimensions_size(), 0);
  CHECK_EQ(new_dim_numbers.lhs_contracting_dimensions_size(), 1);
  CHECK_EQ(new_dim_numbers.rhs_contracting_dimensions_size(), 1);

  for (int64 operand_index : pair.second) {
    // We've checked that there aren't any batch dimensions and that the inputs
    // are rank 2, and shape inference guarantees that there is exactly one
    // contracting dimension.
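    // Folding away a rank-2 transpose therefore just flips that contracting
    // dimension: e.g. dot(transpose(X), Y) contracting on lhs dimension 0
    // becomes dot(X, Y) contracting on lhs dimension 1.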
    if (operand_index == 0) {
      CHECK_EQ(new_lhs->opcode(), HloOpcode::kTranspose);
      new_dim_numbers.set_lhs_contracting_dimensions(
          0, 1 - new_dim_numbers.lhs_contracting_dimensions(0));
      new_lhs = new_lhs->mutable_operand(0);
    } else {
      CHECK_EQ(operand_index, 1);
      CHECK_EQ(new_rhs->opcode(), HloOpcode::kTranspose);
      new_dim_numbers.set_rhs_contracting_dimensions(
          0, 1 - new_dim_numbers.rhs_contracting_dimensions(0));
      new_rhs = new_rhs->mutable_operand(0);
    }
  }

  std::unique_ptr<HloInstruction> new_dot = HloInstruction::CreateDot(
      dot->shape(), new_lhs, new_rhs, new_dim_numbers, dot->precision_config());
  return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot));
}

// Folds the foldable transpose operands of `convolution` and replaces
// `convolution` with an equivalent convolution in its parent computation.
//
// Returns whether the module is changed.
bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
  auto& convolution = *pair.first;
  auto& operand_indices = pair.second;

  if (operand_indices.empty()) {
    return false;
  }

  const ConvolutionDimensionNumbers& dnums =
      convolution.convolution_dimension_numbers();
  ConvolutionDimensionNumbers new_dnums = dnums;

  HloInstruction* new_lhs;
  const int64 kLhsIdx = 0;
  if (absl::c_linear_search(operand_indices, kLhsIdx)) {
    HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx);
    const auto& transpose_dimensions = transpose.dimensions();
    HloInstruction& transpose_operand = *transpose.mutable_operand(0);

    // Everything remains the same except for the input dimension numbers. We
    // need to apply the transpose permutation to the original shape to figure
    // out what the new logical dimensions are.
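    // For example, with a transpose permutation of {3, 2, 1, 0} and an
    // original input_batch_dimension of 0, the folded convolution reads its
    // batch from dimension 3 of the transpose's operand.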
    new_dnums.set_input_batch_dimension(
        transpose_dimensions[dnums.input_batch_dimension()]);
    new_dnums.set_input_feature_dimension(
        transpose_dimensions[dnums.input_feature_dimension()]);
    for (auto& input_spatial_dimension :
         *new_dnums.mutable_input_spatial_dimensions()) {
      input_spatial_dimension = transpose_dimensions[input_spatial_dimension];
    }
    new_lhs = &transpose_operand;
  } else {
    new_lhs = convolution.mutable_operand(kLhsIdx);
  }

  HloInstruction* new_rhs;
  const int64 kRhsIdx = 1;
  if (absl::c_linear_search(operand_indices, kRhsIdx)) {
    HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx);
    const auto& transpose_dimensions = transpose.dimensions();
    HloInstruction& transpose_operand = *transpose.mutable_operand(0);

    // Everything remains the same except for the kernel dimension numbers. We
    // need to apply the transpose permutation to the original shape to figure
    // out what the new logical dimensions are.
    new_dnums.set_kernel_input_feature_dimension(
        transpose_dimensions[dnums.kernel_input_feature_dimension()]);
    new_dnums.set_kernel_output_feature_dimension(
        transpose_dimensions[dnums.kernel_output_feature_dimension()]);
    for (auto& kernel_spatial_dimension :
         *new_dnums.mutable_kernel_spatial_dimensions()) {
      kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension];
    }
    new_rhs = &transpose_operand;
  } else {
    new_rhs = convolution.mutable_operand(kRhsIdx);
  }

  auto new_conv = HloInstruction::CreateConvolve(
      convolution.shape(), new_lhs, new_rhs, convolution.feature_group_count(),
      convolution.batch_group_count(), convolution.window(), new_dnums,
      convolution.precision_config());
  TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
      &convolution, std::move(new_conv)));

  return true;
}

}  // namespace

TransposeFolding::TransposeFolding(
    TransposableGemmOperandsFn transposable_gemm_operands,
    TransposableConvOperandsFn transposable_conv_operands)
    : transposable_gemm_operands_(std::move(transposable_gemm_operands)),
      transposable_conv_operands_(std::move(transposable_conv_operands)) {}

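// Finds all foldable dots and convolutions in the module's non-fusion
// computations, then rewrites each of them in place.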
StatusOr<bool> TransposeFolding::Run(HloModule* module) {
  // Modifying the graph while traversing is dangerous, so we find all folding
  // opportunities before actually folding them.
  std::vector<std::pair<HloInstruction*, OperandIndices>> foldable_dots;
  std::vector<std::pair<HloInstruction*, OperandIndices>> foldable_convolutions;
  auto visit_fn = [this, &foldable_dots,
                   &foldable_convolutions](HloInstruction* instruction) {
    {
      OperandIndices operand_indices =
          CanFoldOperandsIntoDot(*instruction, transposable_gemm_operands_);
      if (!operand_indices.empty()) {
        foldable_dots.emplace_back(instruction, operand_indices);
      }
    }
    {
      OperandIndices operand_indices = CanFoldOperandsIntoConvolution(
          *instruction, transposable_conv_operands_);
      if (!operand_indices.empty()) {
        foldable_convolutions.emplace_back(instruction, operand_indices);
      }
    }
    return Status::OK();
  };

  for (auto* comp : module->MakeNonfusionComputations()) {
    TF_RETURN_IF_ERROR(comp->Accept(visit_fn));
  }

  bool changed = false;
  for (InstructionOperandsPair& pair : foldable_dots) {
    TF_RETURN_IF_ERROR(FoldTransposeIntoDot(pair));
    changed = true;
  }
  for (InstructionOperandsPair& pair : foldable_convolutions) {
    changed |= FoldTransposeIntoConvolution(pair);
  }
  return changed;
}

}  // namespace xla