1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_
18 
19 #include <memory>
20 #include <vector>
21 
22 #include "absl/types/optional.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 
25 namespace xla {
26 namespace dot_as_convolution_util {
27 
28 // Describes the dimensions of a convolution that can be interpreted as a dot
29 // or a normal convolution.
30 struct DotConvolutionDimsInfo {
31   // The dimension numbers for the operands and output corresponding to a
32   // logical dimension (e.g., batch, contracting, non-contracting). If an
33   // operand or the output doesn't have the logical dimension, it is set to
34   // -1.
35   struct DimNums {
36     int64 lhs;
37     int64 rhs;
38     int64 output;
39     // The corresponding spatial dimension in the convolution's config. Set to
40     // -1 if it's not mapped to a spatial dimension.
41     int64 spatial_dim;
42   };
43   std::vector<DimNums> batch_dims;
44   std::vector<DimNums> contracting_dims;
45   std::vector<DimNums> lhs_non_contracting_dims;
46   std::vector<DimNums> rhs_non_contracting_dims;
47   std::vector<DimNums> conv_spatial_dims;
48 };
49 
50 // Parses a convolution and returns a DotGeneralAsConvolutionDimsInfo. If it can
51 // be interpreted as a dot, there is no conv_spatial_dims.
52 DotConvolutionDimsInfo ParseConvolutionDimsInfo(const HloInstruction* conv);
53 
54 // Creates sharded convolution instruction that can be interpreted as a dot.
55 // This is a utility for per-op partitioners.
56 //  - 'conv' is the original convolution instruction.
57 //  - 'dot_dnums' is the result of ParseDotConvolutionDimsInfo() for 'conv'.
58 //  - 'sharded_lhs_hlo' and 'sharded_rhs_hlo' are sharded inputs for the result
59 //    convolution instruction.
60 StatusOr<std::unique_ptr<HloInstruction>>
61 CreateShardedConvForDotGeneralConvolution(
62     const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums,
63     HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo);
64 
65 // Check if a spatial dim is parallel batch dimension.
66 // A parallel batch dimension in DotGeneral is represented as a spatial
67 // dimension with window size B (batch dimension size), stride B - 1, and base
68 // dilation B.
69 bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size);
70 
71 // Returns a DotConvolutionDimsInfo from a kDot instruction, where all
72 // the spatial_dim values are set to -1.
73 DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot);
74 
75 }  // namespace dot_as_convolution_util
76 }  // namespace xla
77 
78 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_
79