/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {
namespace dot_as_convolution_util {

bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size) {
  // A parallel batch dimension in DotGeneral is represented as a
  // spatial dimension with window size B (batch dimension size),
  // stride B - 1, and base dilation B.
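  // For example, a batch dimension of size 8 would show up as a spatial
  // dimension with window size 8, base dilation 8, and stride 7 (or,
  // equivalently, window dilation 7 with stride 1).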
  if (lhs_size == wd.size() && lhs_size == wd.base_dilation() &&
      ((std::max<int64>(1, lhs_size - 1) == wd.stride() &&
        wd.window_dilation() == 1) ||
       (std::max<int64>(1, lhs_size - 1) == wd.window_dilation() &&
        wd.stride() == 1)) &&
      wd.padding_high() == 0 && wd.padding_low() == 0 &&
      !wd.window_reversal()) {
    return true;
  }

  // An alternative representation of a batch dimension of size B: window size
  // B, low/high padding B - 1, stride B, window dilation 1, base dilation
  // B - 1, with window reversal.
  if (wd.size() == lhs_size && wd.padding_high() == lhs_size - 1 &&
      wd.padding_low() == lhs_size - 1 && wd.window_reversal() &&
      wd.window_dilation() == 1 && wd.stride() == lhs_size &&
      wd.base_dilation() == lhs_size - 1) {
    return true;
  }

  return false;
}

/* static */ DotConvolutionDimsInfo ParseConvolutionDimsInfo(
    const HloInstruction* conv) {
  CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
  const auto& conv_dims = conv->convolution_dimension_numbers();
  DotConvolutionDimsInfo dims;
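  // Each entry below is {lhs, rhs, output, spatial_dim}; -1 means the logical
  // dimension is absent on that operand/output or has no corresponding
  // convolution spatial dimension.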
  dims.lhs_non_contracting_dims.push_back(
      {conv_dims.input_batch_dimension(), -1,
       conv_dims.output_batch_dimension(), -1});
  dims.rhs_non_contracting_dims.push_back(
      {-1, conv_dims.kernel_output_feature_dimension(),
       conv_dims.output_feature_dimension(), -1});
  dims.contracting_dims.push_back({conv_dims.input_feature_dimension(),
                                   conv_dims.kernel_input_feature_dimension(),
                                   -1, -1});

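  // Classify each convolution spatial dimension as a batch, contracting,
  // LHS/RHS non-contracting, or genuine spatial dimension based on its window
  // configuration.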
  for (int64 i = 0; i < conv_dims.input_spatial_dimensions_size(); ++i) {
    int64 lhs = conv_dims.input_spatial_dimensions(i);
    int64 lhs_size = conv->operand(0)->shape().dimensions(lhs);
    int64 rhs = conv_dims.kernel_spatial_dimensions(i);
    int64 rhs_size = conv->operand(1)->shape().dimensions(rhs);
    int64 output = conv_dims.output_spatial_dimensions(i);
    const auto& wd = conv->window().dimensions(i);
    if (ConvSpatialDimensionIsParallel(wd, lhs_size)) {
      dims.batch_dims.push_back({lhs, rhs, output, i});
    } else if (lhs_size == wd.size() && wd.base_dilation() == 1 &&
               wd.window_dilation() == 1 && wd.padding_high() == 0 &&
               wd.padding_low() == 0 && !wd.window_reversal()) {
      // A contracting dimension can be represented as a spatial dimension
      // with window size C (contracting dimension size). Stride can be any
      // size since there is only one window.
      dims.contracting_dims.push_back({lhs, rhs, output, i});
    } else if (wd.stride() == 1 && wd.window_dilation() == 1 &&
               wd.base_dilation() == 1) {
      if (rhs_size == 1 && wd.size() == 1 && wd.padding_high() == 0 &&
          wd.padding_low() == 0 && !wd.window_reversal()) {
        // An LHS non-contracting dimension can be represented as a spatial
        // dimension with window size 1.
        dims.lhs_non_contracting_dims.push_back({lhs, rhs, output, i});
      } else if (lhs_size == 1 && wd.size() == rhs_size &&
                 wd.padding_high() == rhs_size - 1 &&
                 wd.padding_low() == rhs_size - 1 && wd.window_reversal()) {
        // A RHS non-contracting dimension can be represented as a spatial
        // dimension with window size N (non-contracting dimension size), low
        // padding N - 1, high padding N - 1 and window reversal.
        dims.rhs_non_contracting_dims.push_back({lhs, rhs, output, i});
      } else {
        dims.conv_spatial_dims.push_back({lhs, rhs, output, i});
      }
    } else {
      dims.conv_spatial_dims.push_back({lhs, rhs, output, i});
    }
  }

  return dims;
}

StatusOr<std::unique_ptr<HloInstruction>>
CreateShardedConvForDotGeneralConvolution(
    const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums,
    HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo) {
  CHECK_EQ(conv.opcode(), HloOpcode::kConvolution);
  const auto& conv_dnums = conv.convolution_dimension_numbers();
  auto window = conv.window();
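  // The original window was built for the unsharded operand shapes; rewrite
  // the window dimensions that encode dot dimensions so that they match the
  // sharded operand sizes while keeping the encodings described in
  // ParseConvolutionDimsInfo.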
  for (const auto& dim : dot_dnums.batch_dims) {
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_lhs_hlo->shape().dimensions(
        conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
    wd->set_stride(std::max<int64>(1, wd->size() - 1));
    wd->set_base_dilation(wd->size());
  }
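  // Contracting dims that are encoded as spatial dims keep a window size equal
  // to the (sharded) contracting dimension size. spatial_dim < 0 marks the
  // regular input/kernel feature dimensions, which have no window to update.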
  for (const auto& dim : dot_dnums.contracting_dims) {
    if (dim.spatial_dim < 0) {
      continue;
    }
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_lhs_hlo->shape().dimensions(
        conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
  }
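  // RHS non-contracting dims that are encoded as spatial dims keep window size
  // N with low/high padding N - 1 (window reversal is preserved from the
  // original window), where N is the (sharded) RHS dimension size.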
  for (const auto& dim : dot_dnums.rhs_non_contracting_dims) {
    if (dim.spatial_dim < 0) {
      continue;
    }
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_rhs_hlo->shape().dimensions(
        conv_dnums.kernel_spatial_dimensions(dim.spatial_dim)));
    wd->set_padding_high(wd->size() - 1);
    wd->set_padding_low(wd->size() - 1);
  }
  TF_ASSIGN_OR_RETURN(
      Shape sharded_conv_shape,
      ShapeInference::InferConvolveShape(
          sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
          /*feature_group_count=*/conv.feature_group_count(),
          /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums,
          /*preferred_element_type=*/conv.shape().element_type()));
  *sharded_conv_shape.mutable_layout() = conv.shape().layout();
  return HloInstruction::CreateConvolve(
      sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo,
      /*feature_group_count=*/conv.feature_group_count(),
      /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums,
      conv.precision_config());
}

DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot) {
  const auto& dot_dim_numbs = dot->dot_dimension_numbers();
  dot_as_convolution_util::DotConvolutionDimsInfo dnums;
  for (int64 i = 0; i < dot_dim_numbs.lhs_batch_dimensions().size(); ++i) {
    dnums.batch_dims.emplace_back();
    dnums.batch_dims.back().lhs = dot_dim_numbs.lhs_batch_dimensions(i);
    dnums.batch_dims.back().rhs = dot_dim_numbs.rhs_batch_dimensions(i);
    dnums.batch_dims.back().output = i;
    dnums.batch_dims.back().spatial_dim = -1;
  }
  for (int64 i = 0; i < dot_dim_numbs.lhs_contracting_dimensions().size();
       ++i) {
    dnums.contracting_dims.emplace_back();
    dnums.contracting_dims.back().lhs =
        dot_dim_numbs.lhs_contracting_dimensions(i);
    dnums.contracting_dims.back().rhs =
        dot_dim_numbs.rhs_contracting_dimensions(i);
    dnums.contracting_dims.back().output = -1;
    dnums.contracting_dims.back().spatial_dim = -1;
  }
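  // Any remaining operand dimension is non-contracting. In the dot output,
  // batch dimensions come first, followed by LHS non-contracting and then RHS
  // non-contracting dimensions; contracting dimensions do not appear.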
  for (int64 i = 0; i < dot->operand(0)->shape().rank(); ++i) {
    if (!absl::c_linear_search(dot_dim_numbs.lhs_batch_dimensions(), i) &&
        !absl::c_linear_search(dot_dim_numbs.lhs_contracting_dimensions(), i)) {
      dnums.lhs_non_contracting_dims.emplace_back();
      dnums.lhs_non_contracting_dims.back().lhs = i;
      dnums.lhs_non_contracting_dims.back().rhs = -1;
      dnums.lhs_non_contracting_dims.back().output =
          dot_dim_numbs.lhs_batch_dimensions_size() +
          dnums.lhs_non_contracting_dims.size() - 1;
      dnums.lhs_non_contracting_dims.back().spatial_dim = -1;
    }
  }
  for (int64 i = 0; i < dot->operand(1)->shape().rank(); ++i) {
    if (!absl::c_linear_search(dot_dim_numbs.rhs_batch_dimensions(), i) &&
        !absl::c_linear_search(dot_dim_numbs.rhs_contracting_dimensions(), i)) {
      dnums.rhs_non_contracting_dims.emplace_back();
      dnums.rhs_non_contracting_dims.back().lhs = -1;
      dnums.rhs_non_contracting_dims.back().rhs = i;
      dnums.rhs_non_contracting_dims.back().output =
          dot_dim_numbs.lhs_batch_dimensions_size() +
          dnums.lhs_non_contracting_dims.size() +
          dnums.rhs_non_contracting_dims.size() - 1;
      dnums.rhs_non_contracting_dims.back().spatial_dim = -1;
    }
  }
  return dnums;
}

}  // namespace dot_as_convolution_util
}  // namespace xla