/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"

#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {
namespace cpu {

Run(HloModule * module)31 StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
32 bool changed = false;
33 for (HloInstruction* hlo :
34 module->entry_computation()->MakeInstructionPostOrder()) {
35 if (hlo->opcode() == HloOpcode::kConvolution &&
36 !PotentiallyImplementedAsEigenConvolution(*hlo,
37 target_machine_features_)) {
38 const ConvolutionDimensionNumbers& dnums =
39 hlo->convolution_dimension_numbers();
40 auto input_batch_dim = dnums.input_batch_dimension();
41 auto input_feature_dim = dnums.input_feature_dimension();
42 auto kernel_input_feature_dim = dnums.kernel_input_feature_dimension();
43 auto kernel_output_feature_dim = dnums.kernel_output_feature_dimension();
44
45 const int64 num_spatial_dims = dnums.output_spatial_dimensions_size();
46 const int64 num_dims = num_spatial_dims + 2;
47
48 // A canonical convolution's dimension numbers need to satisfy the
49 // following conditions (see cs/PotentiallyImplementedAsEigenConvolution).
50 //
51 // - the input is in NHWC order.
52 // - the kernel is in HWIO order.
53 //
54 // For simplicity, as a first step, we reshape the input and filter to
55 // NHWC and HWIO order, respectively. This may lose precision but won't
56 // break the soundness.
57 HloInstruction* input = hlo->mutable_operand(0);
58
59 std::vector<int64> new_input_dim_order(num_dims);
60 std::vector<int64> new_input_dims(num_dims);
61 new_input_dim_order[0] = input_batch_dim;
62 new_input_dims[0] = input->shape().dimensions(input_batch_dim);
63 for (int64 i = 0; i < num_spatial_dims; ++i) {
64 new_input_dim_order[i + 1] = dnums.input_spatial_dimensions(i);
65 new_input_dims[i + 1] =
66 input->shape().dimensions(dnums.input_spatial_dimensions(i));
67 }
68 new_input_dim_order[num_dims - 1] = input_feature_dim;
69 new_input_dims[num_dims - 1] =
70 input->shape().dimensions(input_feature_dim);
71
72 Shape new_input_shape =
73 ShapeUtil::MakeShape(input->shape().element_type(), new_input_dims);
74 HloInstruction* new_input = module->entry_computation()->AddInstruction(
75 HloInstruction::CreateTranspose(new_input_shape, input,
76 new_input_dim_order));
77
78 HloInstruction* kernel = hlo->mutable_operand(1);
79
80 std::vector<int64> new_kernel_dim_order(num_dims);
81 std::vector<int64> new_kernel_dims(num_dims);
82 for (int64 i = 0; i < num_spatial_dims; ++i) {
83 new_kernel_dim_order[i] = dnums.kernel_spatial_dimensions(i);
84 new_kernel_dims[i] =
85 kernel->shape().dimensions(dnums.kernel_spatial_dimensions(i));
86 }
87 new_kernel_dim_order[num_dims - 2] = kernel_input_feature_dim;
88 new_kernel_dims[num_dims - 2] =
89 kernel->shape().dimensions(kernel_input_feature_dim);
90 new_kernel_dim_order[num_dims - 1] = kernel_output_feature_dim;
91 new_kernel_dims[num_dims - 1] =
92 kernel->shape().dimensions(kernel_output_feature_dim);
93
94 Shape new_kernel_shape =
95 ShapeUtil::MakeShape(kernel->shape().element_type(), new_kernel_dims);
96 HloInstruction* new_kernel = module->entry_computation()->AddInstruction(
97 HloInstruction::CreateTranspose(new_kernel_shape, kernel,
98 new_kernel_dim_order));
99
100 std::vector<int64> new_output_dim_order(num_dims);
101 std::vector<int64> new_conv_dims(num_dims);
102 auto output_batch_dim = dnums.output_batch_dimension();
103 auto output_feature_dim = dnums.output_feature_dimension();
104 new_output_dim_order[0] = output_batch_dim;
105 new_conv_dims[0] = hlo->shape().dimensions(output_batch_dim);
106 for (int64 i = 0; i < num_spatial_dims; ++i) {
107 new_output_dim_order[i + 1] = dnums.output_spatial_dimensions(i);
108 new_conv_dims[i + 1] =
109 hlo->shape().dimensions(dnums.output_spatial_dimensions(i));
110 }
111 new_output_dim_order[num_dims - 1] = output_feature_dim;
112 new_conv_dims[num_dims - 1] = hlo->shape().dimensions(output_feature_dim);
113 Shape new_conv_shape =
114 ShapeUtil::MakeShape(hlo->shape().element_type(), new_conv_dims);
115
116 ConvolutionDimensionNumbers new_dnums;
117 new_dnums.set_input_batch_dimension(0);
118 new_dnums.set_output_batch_dimension(0);
119 for (int64 i = 0; i < num_spatial_dims; ++i) {
120 new_dnums.add_input_spatial_dimensions(i + 1);
121 new_dnums.add_kernel_spatial_dimensions(i);
122 new_dnums.add_output_spatial_dimensions(i + 1);
123 }
124 new_dnums.set_input_feature_dimension(num_dims - 1);
125 new_dnums.set_output_feature_dimension(num_dims - 1);
126 new_dnums.set_kernel_input_feature_dimension(num_dims - 2);
127 new_dnums.set_kernel_output_feature_dimension(num_dims - 1);
128
129 // The window of the old convolution is reused, because reshapes only
130 // change the dimension mapping but not the dimension sizes. For
131 // example, input height and width are the same as before the reshapes.
132 HloInstruction* new_conv = module->entry_computation()->AddInstruction(
133 HloInstruction::CreateConvolve(
134 new_conv_shape, new_input, new_kernel, hlo->feature_group_count(),
135 hlo->batch_group_count(), hlo->window(), new_dnums,
136 hlo->precision_config()));
137
138 // Reshape the output back to the shape of the original convolution.
139 TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction(
140 hlo, HloInstruction::CreateTranspose(
141 hlo->shape(), new_conv,
142 InversePermutation(new_output_dim_order))));
143 changed = true;
144 }
145 }
146
147 return changed;
148 }

}  // namespace cpu
}  // namespace xla