1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_
16 #define TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_
17 
18 #include <memory>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "mlir/IR/Builders.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
25 #include "mlir/IR/Location.h"  // from @llvm-project
26 #include "mlir/IR/Operation.h"  // from @llvm-project
27 #include "mlir/IR/Value.h"  // from @llvm-project
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/shape.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/core/platform/types.h"
33 #include "tensorflow/stream_executor/lib/statusor.h"
34 
35 namespace xla {
36 
37 // Provides a way to construct mhlo dialect ops in MLIR using XlaBuilder
38 // interface.
39 //
// Requires that every XlaOp argument was either returned by one of this
// builder's methods or constructed with this builder's MakeXlaOp method.
42 //
43 // TODO(hinsu): Support more ops and utility functions to set special attributes
44 // like OpMetadata and Sharding.
45 class MlirHloBuilder : public XlaBuilder {
46  public:
47   // Constructs builder for the given function. New operations are added to the
48   // beginning of the function, if it is non empty and has a block.
MlirHloBuilder(mlir::FuncOp func)49   explicit MlirHloBuilder(mlir::FuncOp func)
50       : XlaBuilder(func.getName().str()),
51         builder_(&func.getBody()),
52         loc_(builder_.getUnknownLoc()) {}
53 
54   // TODO(hinsu): Add a constructor to build a new MLIR function from scratch
55   // and override Build methods.
56 
MlirHloBuilder(std::string name,mlir::OpBuilder builder,mlir::Location loc)57   MlirHloBuilder(std::string name, mlir::OpBuilder builder, mlir::Location loc)
58       : XlaBuilder(name), builder_(builder), loc_(loc) {}
59 
60   MlirHloBuilder(const MlirHloBuilder&) = delete;
61   MlirHloBuilder& operator=(const MlirHloBuilder&) = delete;
62 
63   ~MlirHloBuilder() override;
64 
65   // Wraps the given MLIR value under an XlaOp instance. Note that all HLO
66   // operations returns exactly one result therefore each op has an XlaOp
67   // wrapping result of the op.
68   //
69   // Returns an error if the HLO dialect doesn't support type of the given
70   // value.
71   StatusOr<XlaOp> MakeXlaOp(mlir::Value val);
72 
73   // Returns value corresponding to the given op.
74   //
75   // Requires that the op was created by this builder.
GetValue(XlaOp op)76   mlir::Value GetValue(XlaOp op) {
77     void* ptr = reinterpret_cast<void*>(op.handle());
78     return mlir::Value::getFromOpaquePointer(ptr);
79   }
80 
81   // Returns MLIR values corresponding to the given XLA ops.
82   //
83   // Requires that the ops were created by this builder.
GetValues(absl::Span<const XlaOp> ops)84   std::vector<mlir::Value> GetValues(absl::Span<const XlaOp> ops) {
85     std::vector<mlir::Value> values;
86     for (auto xla_op : ops) {
87       values.push_back(GetValue(xla_op));
88     }
89     return values;
90   }
91 
92   // Sets location for newly built ops, until reset.
SetLocation(mlir::Location loc)93   void SetLocation(mlir::Location loc) { loc_ = loc; }
94 
95   // Update insertion point so that newly built ops are inserted before the
96   // given op in order, until reset.
setInsertionPoint(mlir::Operation * op)97   void setInsertionPoint(mlir::Operation* op) {
98     builder_.setInsertionPoint(op);
99   }
100 
101   // Returns the shape of the given op.
102   StatusOr<const Shape*> GetShapePtr(XlaOp op) const override;
103 
104   // Creates the given op at the current location.
105   template <typename OpTy, typename... Args>
create(Args &&...args)106   OpTy create(Args&&... args) {
107     return builder_.create<OpTy>(loc_, std::forward<Args>(args)...);
108   }
109 
110  private:
111   XlaOp ConstantLiteral(const LiteralSlice& literal) override;
112 
113   StatusOr<XlaOp> ConvGeneralDilatedInternal(
114       const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
115       absl::Span<const int64> window_strides,
116       absl::Span<const std::pair<int64, int64>> padding,
117       absl::Span<const int64> lhs_dilation,
118       absl::Span<const int64> rhs_dilation,
119       const ConvolutionDimensionNumbers& dimension_numbers,
120       int64 feature_group_count, int64 batch_group_count,
121       const PrecisionConfig* precision_config) override;
122 
123   StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand,
124                               FftType fft_type,
125                               absl::Span<const int64> fft_length) override;
126 
127   StatusOr<XlaOp> TriangularSolveInternal(
128       const Shape& shape, XlaOp a, XlaOp b,
129       TriangularSolveOptions options) override;
130 
131   StatusOr<XlaOp> CholeskyInternal(const Shape& shape, XlaOp a,
132                                    bool lower) override;
133 
134   StatusOr<XlaOp> CustomCallInternal(
135       const string& call_target_name, absl::Span<const XlaOp> operands,
136       const Shape& shape, const string& opaque,
137       absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
138       bool has_side_effect,
139       absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
140           output_operand_aliasing,
141       const Literal* literal) override;
142 
143   StatusOr<XlaOp> ReduceInternal(
144       const Shape& shape, absl::Span<const XlaOp> all_operands,
145       const XlaComputation& computation,
146       absl::Span<const int64> dimensions_to_reduce) override;
147 
148   StatusOr<XlaOp> ReduceWindowInternal(const Shape& shape, XlaOp operand,
149                                        XlaOp init_value,
150                                        const XlaComputation& computation,
151                                        Window window) override;
152 
153   XlaOp Iota(const Shape& shape, int64 iota_dimension) override;
154 
155   StatusOr<XlaOp> BitcastConvertTypeInternal(const Shape& shape,
156                                              XlaOp operand) override;
157 
158   StatusOr<XlaOp> TransposeInternal(
159       const Shape& shape, XlaOp operand,
160       absl::Span<const int64> permutation) override;
161 
162   StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand,
163                               absl::Span<const int64> dimensions) override;
164 
165   StatusOr<XlaOp> SortInternal(const Shape& shape,
166                                absl::Span<const XlaOp> operands,
167                                const XlaComputation& comparator,
168                                int64 dimension, bool is_stable) override;
169 
170   StatusOr<XlaOp> WhileInternal(const Shape& shape,
171                                 const XlaComputation& condition,
172                                 const XlaComputation& body,
173                                 XlaOp init) override;
174 
175   StatusOr<XlaOp> ReducePrecisionInternal(const Shape& shape, XlaOp operand,
176                                           const int exponent_bits,
177                                           const int mantissa_bits) override;
178 
179   StatusOr<XlaOp> GatherInternal(
180       const Shape& shape, XlaOp input, XlaOp start_indices,
181       const GatherDimensionNumbers& dimension_numbers,
182       absl::Span<const int64> slice_sizes, bool indices_are_sorted) override;
183 
184   StatusOr<XlaOp> ScatterInternal(
185       const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
186       const XlaComputation& update_computation,
187       const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
188       bool unique_indices) override;
189 
190   StatusOr<XlaOp> SetDimensionSizeInternal(const Shape& shape, XlaOp operand,
191                                            XlaOp val, int64 dimension) override;
192 
193   StatusOr<XlaOp> RngOpInternal(RandomDistribution distribution,
194                                 absl::Span<const XlaOp> parameters,
195                                 const Shape& shape) override;
196   StatusOr<XlaOp> RngBitGeneratorInternal(const Shape& full_result_shape,
197                                           RandomAlgorithm algorithm,
198                                           XlaOp initial_state) override;
199 
200   StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand,
201                                   int64 inferred_dimension) override;
202 
203   StatusOr<XlaOp> DotGeneralInternal(
204       const Shape& shape, XlaOp lhs, XlaOp rhs,
205       const DotDimensionNumbers& dimension_number,
206       const PrecisionConfig* precision_config) override;
207 
208   StatusOr<XlaOp> InDimBroadcast(
209       const Shape& shape, XlaOp operand,
210       absl::Span<const int64> broadcast_dimensions) override;
211 
212   StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
213                                  absl::Span<const XlaOp> operands) override;
214 
215   StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
216                           ComparisonDirection direction,
217                           Comparison::Type type) override;
218 
219   XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, XlaOp lhs,
220                             XlaOp rhs) override;
221 
222   StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
223                                  absl::Span<const XlaOp> operands) override;
224 
225   XlaOp CreateToken() override;
226 
227   StatusOr<XlaOp> InfeedWithTokenInternal(const Shape& infeed_instruction_shape,
228                                           XlaOp token,
229                                           const string& config) override;
230   StatusOr<XlaOp> OutfeedWithTokenInternal(
231       XlaOp operand, XlaOp token, const Shape& shape_with_layout,
232       const string& outfeed_config) override;
233 
234   StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
235                                       absl::Span<const XlaOp> operands,
236                                       int64 dimension) override;
237 
238   StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape, XlaOp tuple_data,
239                                           int64 index) override;
240 
241   StatusOr<XlaOp> SliceInternal(const Shape& shape, XlaOp operand,
242                                 absl::Span<const int64> start_indices,
243                                 absl::Span<const int64> limit_indices,
244                                 absl::Span<const int64> strides) override;
245 
246   StatusOr<XlaOp> DynamicSliceInternal(
247       const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
248       absl::Span<const int64> slice_sizes) override;
249 
250   StatusOr<XlaOp> DynamicUpdateSliceInternal(
251       const Shape& shape, XlaOp operand, XlaOp update,
252       absl::Span<const XlaOp> start_indices) override;
253 
254   StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
255                               XlaOp padding_value,
256                               const PaddingConfig& padding_config) override;
257 
258   StatusOr<XlaOp> TupleInternal(const Shape& shape,
259                                 absl::Span<const XlaOp> elements) override;
260 
261   // Creates HLO dialect op and returns the result as an XlaOp.
262   StatusOr<XlaOp> CreateOp(
263       const std::string& op_name, const Shape& shape,
264       llvm::ArrayRef<XlaOp> operands,
265       llvm::ArrayRef<mlir::NamedAttribute> attributes = {});
266 
267   Status ImportComputation(const HloModuleProto& computation,
268                            mlir::Region* region);
269 
270   mlir::OpBuilder builder_;
271   mlir::Location loc_;
272 
273   absl::flat_hash_map<int64, std::unique_ptr<Shape>> handle_to_shape_;
274 };
275 
276 }  // namespace xla
277 
278 #endif  // TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_
279