• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
17 
18 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
19 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 
28 namespace xla {
29 namespace cpu {
30 
Run(HloModule * module)31 StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
32   bool changed = false;
33   for (HloInstruction* hlo :
34        module->entry_computation()->MakeInstructionPostOrder()) {
35     if (hlo->opcode() == HloOpcode::kConvolution &&
36         !PotentiallyImplementedAsEigenConvolution(*hlo,
37                                                   target_machine_features_)) {
38       const ConvolutionDimensionNumbers& dnums =
39           hlo->convolution_dimension_numbers();
40       auto input_batch_dim = dnums.input_batch_dimension();
41       auto input_feature_dim = dnums.input_feature_dimension();
42       auto kernel_input_feature_dim = dnums.kernel_input_feature_dimension();
43       auto kernel_output_feature_dim = dnums.kernel_output_feature_dimension();
44 
45       const int64 num_spatial_dims = dnums.output_spatial_dimensions_size();
46       const int64 num_dims = num_spatial_dims + 2;
47 
48       // A canonical convolution's dimension numbers need to satisfy the
49       // following conditions (see cs/PotentiallyImplementedAsEigenConvolution).
50       //
51       // - the input is in NHWC order.
52       // - the kernel is in HWIO order.
53       //
54       // For simplicity, as a first step, we reshape the input and filter to
55       // NHWC and HWIO order, respectively. This may lose precision but won't
56       // break the soundness.
57       HloInstruction* input = hlo->mutable_operand(0);
58 
59       std::vector<int64> new_input_dim_order(num_dims);
60       std::vector<int64> new_input_dims(num_dims);
61       new_input_dim_order[0] = input_batch_dim;
62       new_input_dims[0] = input->shape().dimensions(input_batch_dim);
63       for (int64 i = 0; i < num_spatial_dims; ++i) {
64         new_input_dim_order[i + 1] = dnums.input_spatial_dimensions(i);
65         new_input_dims[i + 1] =
66             input->shape().dimensions(dnums.input_spatial_dimensions(i));
67       }
68       new_input_dim_order[num_dims - 1] = input_feature_dim;
69       new_input_dims[num_dims - 1] =
70           input->shape().dimensions(input_feature_dim);
71 
72       Shape new_input_shape =
73           ShapeUtil::MakeShape(input->shape().element_type(), new_input_dims);
74       HloInstruction* new_input = module->entry_computation()->AddInstruction(
75           HloInstruction::CreateTranspose(new_input_shape, input,
76                                           new_input_dim_order));
77 
78       HloInstruction* kernel = hlo->mutable_operand(1);
79 
80       std::vector<int64> new_kernel_dim_order(num_dims);
81       std::vector<int64> new_kernel_dims(num_dims);
82       for (int64 i = 0; i < num_spatial_dims; ++i) {
83         new_kernel_dim_order[i] = dnums.kernel_spatial_dimensions(i);
84         new_kernel_dims[i] =
85             kernel->shape().dimensions(dnums.kernel_spatial_dimensions(i));
86       }
87       new_kernel_dim_order[num_dims - 2] = kernel_input_feature_dim;
88       new_kernel_dims[num_dims - 2] =
89           kernel->shape().dimensions(kernel_input_feature_dim);
90       new_kernel_dim_order[num_dims - 1] = kernel_output_feature_dim;
91       new_kernel_dims[num_dims - 1] =
92           kernel->shape().dimensions(kernel_output_feature_dim);
93 
94       Shape new_kernel_shape =
95           ShapeUtil::MakeShape(kernel->shape().element_type(), new_kernel_dims);
96       HloInstruction* new_kernel = module->entry_computation()->AddInstruction(
97           HloInstruction::CreateTranspose(new_kernel_shape, kernel,
98                                           new_kernel_dim_order));
99 
100       std::vector<int64> new_output_dim_order(num_dims);
101       std::vector<int64> new_conv_dims(num_dims);
102       auto output_batch_dim = dnums.output_batch_dimension();
103       auto output_feature_dim = dnums.output_feature_dimension();
104       new_output_dim_order[0] = output_batch_dim;
105       new_conv_dims[0] = hlo->shape().dimensions(output_batch_dim);
106       for (int64 i = 0; i < num_spatial_dims; ++i) {
107         new_output_dim_order[i + 1] = dnums.output_spatial_dimensions(i);
108         new_conv_dims[i + 1] =
109             hlo->shape().dimensions(dnums.output_spatial_dimensions(i));
110       }
111       new_output_dim_order[num_dims - 1] = output_feature_dim;
112       new_conv_dims[num_dims - 1] = hlo->shape().dimensions(output_feature_dim);
113       Shape new_conv_shape =
114           ShapeUtil::MakeShape(hlo->shape().element_type(), new_conv_dims);
115 
116       ConvolutionDimensionNumbers new_dnums;
117       new_dnums.set_input_batch_dimension(0);
118       new_dnums.set_output_batch_dimension(0);
119       for (int64 i = 0; i < num_spatial_dims; ++i) {
120         new_dnums.add_input_spatial_dimensions(i + 1);
121         new_dnums.add_kernel_spatial_dimensions(i);
122         new_dnums.add_output_spatial_dimensions(i + 1);
123       }
124       new_dnums.set_input_feature_dimension(num_dims - 1);
125       new_dnums.set_output_feature_dimension(num_dims - 1);
126       new_dnums.set_kernel_input_feature_dimension(num_dims - 2);
127       new_dnums.set_kernel_output_feature_dimension(num_dims - 1);
128 
129       // The window of the old convolution is reused, because reshapes only
130       // change the dimension mapping but not the dimension sizes. For
131       // example, input height and width are the same as before the reshapes.
132       HloInstruction* new_conv = module->entry_computation()->AddInstruction(
133           HloInstruction::CreateConvolve(
134               new_conv_shape, new_input, new_kernel, hlo->feature_group_count(),
135               hlo->batch_group_count(), hlo->window(), new_dnums,
136               hlo->precision_config()));
137 
138       // Reshape the output back to the shape of the original convolution.
139       TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction(
140           hlo, HloInstruction::CreateTranspose(
141                    hlo->shape(), new_conv,
142                    InversePermutation(new_output_dim_order))));
143       changed = true;
144     }
145   }
146 
147   return changed;
148 }
149 
150 }  // namespace cpu
151 }  // namespace xla
152