1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
17 
18 #include <vector>
19 
20 #include "tensorflow/compiler/xla/util.h"
21 #include "tensorflow/compiler/xla/xla_data.pb.h"
22 
23 namespace xla {
24 
Convert(llvm::ArrayRef<int64_t> elements,mlir::Builder * builder)25 static mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<int64_t> elements,
26                                           mlir::Builder* builder) {
27   return mlir::DenseIntElementsAttr::get(
28       mlir::RankedTensorType::get(elements.size(), builder->getIntegerType(64)),
29       elements);
30 }
31 
ConvertPrecisionConfig(const PrecisionConfig * config,mlir::Builder * builder)32 mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config,
33                                        mlir::Builder* builder) {
34   if (!config) return {};
35 
36   // TODO(b/129709049) The HLO text format elides this in the all DEFAULT
37   // case and the parser sticks it in. Maybe we should too.
38   llvm::SmallVector<mlir::Attribute, 4> operand_precision_attrs;
39 
40   for (auto prec : config->operand_precision()) {
41     operand_precision_attrs.push_back(
42         builder->getStringAttr(PrecisionConfig_Precision_Name(prec)));
43   }
44   return builder->getArrayAttr(operand_precision_attrs);
45 }
46 
47 // Converts the gather dimensions to attributes.
ConvertGatherDimensionNumbers(const xla::GatherDimensionNumbers & dnums,mlir::Builder * builder)48 mlir::mhlo::GatherDimensionNumbers ConvertGatherDimensionNumbers(
49     const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder) {
50   std::vector<int64_t> offset_dims(dnums.offset_dims().begin(),
51                                    dnums.offset_dims().end());
52   std::vector<int64_t> collapsed_slice_dims(
53       dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end());
54   std::vector<int64_t> start_index_map(dnums.start_index_map().begin(),
55                                        dnums.start_index_map().end());
56   return mlir::mhlo::GatherDimensionNumbers::get(
57       Convert(offset_dims, builder), Convert(collapsed_slice_dims, builder),
58       Convert(start_index_map, builder),
59       builder->getI64IntegerAttr(dnums.index_vector_dim()),
60       builder->getContext());
61 }
62 
ConvertScatterDimensionNumbers(const xla::ScatterDimensionNumbers & dnums,mlir::Builder * builder)63 mlir::mhlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers(
64     const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder) {
65   std::vector<int64_t> update_window_dims(dnums.update_window_dims().begin(),
66                                           dnums.update_window_dims().end());
67   std::vector<int64_t> inserted_window_dims(
68       dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end());
69   std::vector<int64_t> scatter_dims_to_operand_dims(
70       dnums.scatter_dims_to_operand_dims().begin(),
71       dnums.scatter_dims_to_operand_dims().end());
72   return mlir::mhlo::ScatterDimensionNumbers::get(
73       Convert(update_window_dims, builder),
74       Convert(inserted_window_dims, builder),
75       Convert(scatter_dims_to_operand_dims, builder),
76       builder->getI64IntegerAttr(dnums.index_vector_dim()),
77       builder->getContext());
78 }
79 
ConvertDotDimensionNumbers(const DotDimensionNumbers & dnums,mlir::Builder * builder)80 mlir::mhlo::DotDimensionNumbers ConvertDotDimensionNumbers(
81     const DotDimensionNumbers& dnums, mlir::Builder* builder) {
82   std::vector<int64_t> rhs_contracting_dimensions(
83       dnums.rhs_contracting_dimensions().begin(),
84       dnums.rhs_contracting_dimensions().end());
85   std::vector<int64_t> lhs_contracting_dimensions(
86       dnums.lhs_contracting_dimensions().begin(),
87       dnums.lhs_contracting_dimensions().end());
88   std::vector<int64_t> rhs_batch_dimensions(
89       dnums.rhs_batch_dimensions().begin(), dnums.rhs_batch_dimensions().end());
90   std::vector<int64_t> lhs_batch_dimensions(
91       dnums.lhs_batch_dimensions().begin(), dnums.lhs_batch_dimensions().end());
92 
93   // Push the attributes into our new DictionaryAttr.
94   auto lhs_batch_dims_attr = Convert(lhs_batch_dimensions, builder);
95   auto rhs_batch_dims_attr = Convert(rhs_batch_dimensions, builder);
96   auto lhs_contracting_dims_attr = Convert(lhs_contracting_dimensions, builder);
97   auto rhs_contracting_dims_attr = Convert(rhs_contracting_dimensions, builder);
98 
99   return mlir::mhlo::DotDimensionNumbers::get(
100       lhs_batch_dims_attr, rhs_batch_dims_attr, lhs_contracting_dims_attr,
101       rhs_contracting_dims_attr, builder->getContext());
102 }
103 
ConvertConvDimensionNumbers(const xla::ConvolutionDimensionNumbers & dnums,mlir::Builder * builder)104 mlir::mhlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
105     const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder) {
106   llvm::SmallVector<int64_t, 4> input_spatial_dims(
107       dnums.input_spatial_dimensions().begin(),
108       dnums.input_spatial_dimensions().end());
109   llvm::SmallVector<int64_t, 4> kernel_spatial_dims(
110       dnums.kernel_spatial_dimensions().begin(),
111       dnums.kernel_spatial_dimensions().end());
112   llvm::SmallVector<int64_t, 4> output_spatial_dims(
113       dnums.output_spatial_dimensions().begin(),
114       dnums.output_spatial_dimensions().end());
115   return mlir::mhlo::ConvDimensionNumbers::get(
116       builder->getI64IntegerAttr(dnums.input_batch_dimension()),
117       builder->getI64IntegerAttr(dnums.input_feature_dimension()),
118       Convert(input_spatial_dims, builder),
119       builder->getI64IntegerAttr(dnums.kernel_input_feature_dimension()),
120       builder->getI64IntegerAttr(dnums.kernel_output_feature_dimension()),
121       Convert(kernel_spatial_dims, builder),
122       builder->getI64IntegerAttr(dnums.output_batch_dimension()),
123       builder->getI64IntegerAttr(dnums.output_feature_dimension()),
124       Convert(output_spatial_dims, builder), builder->getContext());
125 }
126 
ConvertFftType(FftType type)127 StatusOr<mlir::mhlo::FftType> ConvertFftType(FftType type) {
128   switch (type) {
129     case FftType::FFT:
130       return mlir::mhlo::FftType::FFT;
131     case FftType::IFFT:
132       return mlir::mhlo::FftType::IFFT;
133     case FftType::RFFT:
134       return mlir::mhlo::FftType::RFFT;
135     case FftType::IRFFT:
136       return mlir::mhlo::FftType::IRFFT;
137     default:
138       return InvalidArgument("Unknown FFT type enum value #%d", type);
139   }
140 }
141 
ConvertTranspose(xla::TriangularSolveOptions_Transpose transpose)142 StatusOr<mlir::mhlo::Transpose> ConvertTranspose(
143     xla::TriangularSolveOptions_Transpose transpose) {
144   switch (transpose) {
145     case TriangularSolveOptions::NO_TRANSPOSE:
146       return mlir::mhlo::Transpose::NO_TRANSPOSE;
147     case TriangularSolveOptions::TRANSPOSE:
148       return mlir::mhlo::Transpose::TRANSPOSE;
149     case TriangularSolveOptions::ADJOINT:
150       return mlir::mhlo::Transpose::ADJOINT;
151     case TriangularSolveOptions::TRANSPOSE_INVALID:
152       return mlir::mhlo::Transpose::TRANSPOSE_INVALID;
153     default:
154       return InvalidArgument("Unknown transpose enum value #%d", transpose);
155   }
156 }
157 
158 }  // namespace xla
159