1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_ 18 19 #include <memory> 20 #include <vector> 21 22 #include "absl/types/optional.h" 23 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 24 25 namespace xla { 26 namespace dot_as_convolution_util { 27 28 // Describes the dimensions of a convolution that can be interpreted as a dot 29 // or a normal convolution. 30 struct DotConvolutionDimsInfo { 31 // The dimension numbers for the operands and output corresponding to a 32 // logical dimension (e.g., batch, contracting, non-contracting). If an 33 // operand or the output doesn't have the logical dimension, it is set to 34 // -1. 35 struct DimNums { 36 int64 lhs; 37 int64 rhs; 38 int64 output; 39 // The corresponding spatial dimension in the convolution's config. Set to 40 // -1 if it's not mapped to a spatial dimension. 41 int64 spatial_dim; 42 }; 43 std::vector<DimNums> batch_dims; 44 std::vector<DimNums> contracting_dims; 45 std::vector<DimNums> lhs_non_contracting_dims; 46 std::vector<DimNums> rhs_non_contracting_dims; 47 std::vector<DimNums> conv_spatial_dims; 48 }; 49 50 // Parses a convolution and returns a DotGeneralAsConvolutionDimsInfo. If it can 51 // be interpreted as a dot, there is no conv_spatial_dims. 52 DotConvolutionDimsInfo ParseConvolutionDimsInfo(const HloInstruction* conv); 53 54 // Creates sharded convolution instruction that can be interpreted as a dot. 55 // This is a utility for per-op partitioners. 56 // - 'conv' is the original convolution instruction. 57 // - 'dot_dnums' is the result of ParseDotConvolutionDimsInfo() for 'conv'. 58 // - 'sharded_lhs_hlo' and 'sharded_rhs_hlo' are sharded inputs for the result 59 // convolution instruction. 60 StatusOr<std::unique_ptr<HloInstruction>> 61 CreateShardedConvForDotGeneralConvolution( 62 const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums, 63 HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo); 64 65 // Check if a spatial dim is parallel batch dimension. 66 // A parallel batch dimension in DotGeneral is represented as a spatial 67 // dimension with window size B (batch dimension size), stride B - 1, and base 68 // dilation B. 69 bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size); 70 71 // Returns a DotConvolutionDimsInfo from a kDot instruction, where all 72 // the spatial_dim values are set to -1. 73 DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot); 74 75 } // namespace dot_as_convolution_util 76 } // namespace xla 77 78 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_ 79