/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {
namespace dot_as_convolution_util {
bool ConvSpatialDimensionIsParallel(const WindowDimension& wd,
                                    int64 lhs_size) {
  // A parallel batch dimension in DotGeneral is represented as a spatial
  // dimension with window size B (the batch dimension size), base dilation B,
  // and either stride max(1, B - 1) with window dilation 1, or window
  // dilation max(1, B - 1) with stride 1.
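  // For example, with B = 4 this matches a window dimension of
  //   {size: 4, stride: 3, base_dilation: 4}
  // or, equivalently,
  //   {size: 4, stride: 1, window_dilation: 3, base_dilation: 4}.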
  if (lhs_size == wd.size() && lhs_size == wd.base_dilation() &&
      ((std::max<int64>(1, lhs_size - 1) == wd.stride() &&
        wd.window_dilation() == 1) ||
       (std::max<int64>(1, lhs_size - 1) == wd.window_dilation() &&
        wd.stride() == 1)) &&
      wd.padding_high() == 0 && wd.padding_low() == 0 &&
      !wd.window_reversal()) {
    return true;
  }

  // Alternative representation of a batch dimension.
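  // With B = 4 this form is
  //   {size: 4, stride: 4, padding_low: 3, padding_high: 3,
  //    window_reversal: true, base_dilation: 3}.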
  if (wd.size() == lhs_size && wd.padding_high() == lhs_size - 1 &&
      wd.padding_low() == lhs_size - 1 && wd.window_reversal() &&
      wd.window_dilation() == 1 && wd.stride() == lhs_size &&
      wd.base_dilation() == lhs_size - 1) {
    return true;
  }

  return false;
}

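// Classifies each dimension of `conv` as a dot-style batch, contracting, or
// non-contracting dimension, or a genuinely spatial convolution dimension,
// based on how each spatial dimension's window is configured.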
/* static */ DotConvolutionDimsInfo ParseConvolutionDimsInfo(
    const HloInstruction* conv) {
  CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
  const auto& conv_dims = conv->convolution_dimension_numbers();
  DotConvolutionDimsInfo dims;
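  // In the {lhs, rhs, output, spatial_dim} entries below, -1 means the
  // dimension has no counterpart on that side (or, for spatial_dim, that it
  // is not encoded as a convolution spatial dimension).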
  dims.lhs_non_contracting_dims.push_back(
      {conv_dims.input_batch_dimension(), -1,
       conv_dims.output_batch_dimension(), -1});
  dims.rhs_non_contracting_dims.push_back(
      {-1, conv_dims.kernel_output_feature_dimension(),
       conv_dims.output_feature_dimension(), -1});
  dims.contracting_dims.push_back({conv_dims.input_feature_dimension(),
                                   conv_dims.kernel_input_feature_dimension(),
                                   -1, -1});

  for (int64 i = 0; i < conv_dims.input_spatial_dimensions_size(); ++i) {
    int64 lhs = conv_dims.input_spatial_dimensions(i);
    int64 lhs_size = conv->operand(0)->shape().dimensions(lhs);
    int64 rhs = conv_dims.kernel_spatial_dimensions(i);
    int64 rhs_size = conv->operand(1)->shape().dimensions(rhs);
    int64 output = conv_dims.output_spatial_dimensions(i);
    const auto& wd = conv->window().dimensions(i);
    if (ConvSpatialDimensionIsParallel(wd, lhs_size)) {
      dims.batch_dims.push_back({lhs, rhs, output, i});
    } else if (lhs_size == wd.size() && wd.base_dilation() == 1 &&
               wd.window_dilation() == 1 && wd.padding_high() == 0 &&
               wd.padding_low() == 0 && !wd.window_reversal()) {
      // A contracting dimension can be represented as a spatial dimension
      // with window size C (the contracting dimension size). Stride can be
      // any size since there is only one window position.
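      // For example, with C = 8 this matches {size: 8} with no padding,
      // dilation, or reversal, regardless of stride.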
      dims.contracting_dims.push_back({lhs, rhs, output, i});
    } else if (wd.stride() == 1 && wd.window_dilation() == 1 &&
               wd.base_dilation() == 1) {
      if (rhs_size == 1 && wd.size() == 1 && wd.padding_high() == 0 &&
          wd.padding_low() == 0 && !wd.window_reversal()) {
        // A LHS non-contracting dimension can be represented as a spatial
        // dimension with window size 1.
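        // For example, {size: 1} over a size-1 kernel dimension: each output
        // element sees exactly one LHS element.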
        dims.lhs_non_contracting_dims.push_back({lhs, rhs, output, i});
      } else if (lhs_size == 1 && wd.size() == rhs_size &&
                 wd.padding_high() == rhs_size - 1 &&
                 wd.padding_low() == rhs_size - 1 && wd.window_reversal()) {
        // A RHS non-contracting dimension can be represented as a spatial
        // dimension with window size N (non-contracting dimension size), low
        // padding N - 1, high padding N - 1 and window reversal.
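        // For example, with N = 5 and a size-1 LHS dimension this matches
        //   {size: 5, padding_low: 4, padding_high: 4, window_reversal: true},
        // which pairs the single LHS element with each of the 5 RHS elements.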
        dims.rhs_non_contracting_dims.push_back({lhs, rhs, output, i});
      } else {
        dims.conv_spatial_dims.push_back({lhs, rhs, output, i});
      }
    } else {
      dims.conv_spatial_dims.push_back({lhs, rhs, output, i});
    }
  }

  return dims;
}

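// Recreates `conv` on sharded operands. The window dimensions that encode dot
// batch, contracting, and RHS non-contracting dimensions are rewritten to
// match the (possibly smaller) sharded operand sizes before the result shape
// is re-inferred.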
StatusOr<std::unique_ptr<HloInstruction>>
CreateShardedConvForDotGeneralConvolution(
    const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums,
    HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo) {
  CHECK_EQ(conv.opcode(), HloOpcode::kConvolution);
  const auto& conv_dnums = conv.convolution_dimension_numbers();
  auto window = conv.window();
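  // Batch dims keep the parallel-batch window encoding (size B, stride
  // max(1, B - 1), base dilation B), resized to the sharded batch size.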
  for (const auto& dim : dot_dnums.batch_dims) {
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_lhs_hlo->shape().dimensions(
        conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
    wd->set_stride(std::max<int64>(1, wd->size() - 1));
    wd->set_base_dilation(wd->size());
  }
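  // Contracting dims only need the window size updated; stride is irrelevant
  // since there is a single window position.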
  for (const auto& dim : dot_dnums.contracting_dims) {
    if (dim.spatial_dim < 0) {
      continue;
    }
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_lhs_hlo->shape().dimensions(
        conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
  }
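  // RHS non-contracting dims keep the reversed window with N - 1 padding on
  // both sides, resized to the sharded RHS dimension size.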
  for (const auto& dim : dot_dnums.rhs_non_contracting_dims) {
    if (dim.spatial_dim < 0) {
      continue;
    }
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_rhs_hlo->shape().dimensions(
        conv_dnums.kernel_spatial_dimensions(dim.spatial_dim)));
    wd->set_padding_high(wd->size() - 1);
    wd->set_padding_low(wd->size() - 1);
  }
  TF_ASSIGN_OR_RETURN(
      Shape sharded_conv_shape,
      ShapeInference::InferConvolveShape(
          sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
          /*feature_group_count=*/conv.feature_group_count(),
          /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums,
          /*preferred_element_type=*/conv.shape().element_type()));
  *sharded_conv_shape.mutable_layout() = conv.shape().layout();
  return HloInstruction::CreateConvolve(
      sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo,
      /*feature_group_count=*/conv.feature_group_count(),
      /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums,
      conv.precision_config());
}

DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot) {
  const auto& dot_dim_numbs = dot->dot_dimension_numbers();
  dot_as_convolution_util::DotConvolutionDimsInfo dnums;
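  // A dot's output lays out its dimensions as batch dims first, then LHS
  // non-contracting dims, then RHS non-contracting dims; the `output` indices
  // assigned below follow that order.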
  for (int64 i = 0; i < dot_dim_numbs.lhs_batch_dimensions().size(); ++i) {
    dnums.batch_dims.emplace_back();
    dnums.batch_dims.back().lhs = dot_dim_numbs.lhs_batch_dimensions(i);
    dnums.batch_dims.back().rhs = dot_dim_numbs.rhs_batch_dimensions(i);
    dnums.batch_dims.back().output = i;
    dnums.batch_dims.back().spatial_dim = -1;
  }
  for (int64 i = 0; i < dot_dim_numbs.lhs_contracting_dimensions().size();
       ++i) {
    dnums.contracting_dims.emplace_back();
    dnums.contracting_dims.back().lhs =
        dot_dim_numbs.lhs_contracting_dimensions(i);
    dnums.contracting_dims.back().rhs =
        dot_dim_numbs.rhs_contracting_dimensions(i);
    dnums.contracting_dims.back().output = -1;
    dnums.contracting_dims.back().spatial_dim = -1;
  }
  for (int64 i = 0; i < dot->operand(0)->shape().rank(); ++i) {
    if (!absl::c_linear_search(dot_dim_numbs.lhs_batch_dimensions(), i) &&
        !absl::c_linear_search(dot_dim_numbs.lhs_contracting_dimensions(),
                               i)) {
      dnums.lhs_non_contracting_dims.emplace_back();
      dnums.lhs_non_contracting_dims.back().lhs = i;
      dnums.lhs_non_contracting_dims.back().rhs = -1;
      dnums.lhs_non_contracting_dims.back().output =
          dot_dim_numbs.lhs_batch_dimensions_size() +
          dnums.lhs_non_contracting_dims.size() - 1;
      dnums.lhs_non_contracting_dims.back().spatial_dim = -1;
    }
  }
  for (int64 i = 0; i < dot->operand(1)->shape().rank(); ++i) {
    if (!absl::c_linear_search(dot_dim_numbs.rhs_batch_dimensions(), i) &&
        !absl::c_linear_search(dot_dim_numbs.rhs_contracting_dimensions(),
                               i)) {
      dnums.rhs_non_contracting_dims.emplace_back();
      dnums.rhs_non_contracting_dims.back().lhs = -1;
      dnums.rhs_non_contracting_dims.back().rhs = i;
      dnums.rhs_non_contracting_dims.back().output =
          dot_dim_numbs.lhs_batch_dimensions_size() +
          dnums.lhs_non_contracting_dims.size() +
          dnums.rhs_non_contracting_dims.size() - 1;
      dnums.rhs_non_contracting_dims.back().spatial_dim = -1;
    }
  }
  return dnums;
}

}  // namespace dot_as_convolution_util
}  // namespace xla