/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

==============================================================================*/

#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace {

// TODO(herhut): Generate these out of op definitions.
#define MAP_XLA_OPERATION_CWISE_UNARY(fn, sep)                                \
  fn(AbsOp) sep fn(CeilOp) sep fn(ClzOp) sep fn(ConvertOp) sep fn(CosOp)      \
      sep fn(ExpOp) sep fn(Expm1Op) sep fn(FloorOp) sep fn(ImagOp)            \
          sep fn(IsFiniteOp) sep fn(LogOp) sep fn(Log1pOp) sep fn(LogisticOp) \
              sep fn(NotOp) sep fn(NegOp) sep fn(PopulationCountOp)           \
                  sep fn(RealOp) sep fn(RoundOp) sep fn(RsqrtOp)              \
                      sep fn(SignOp) sep fn(SinOp) sep fn(SqrtOp)             \
                          sep fn(TanhOp)

// TODO(herhut): Generate these out of op definitions.
#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep)                               \
  fn(AddOp) sep fn(AndOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp)     \
      sep fn(MaxOp) sep fn(MinOp) sep fn(MulOp) sep fn(OrOp) sep fn(PowOp)    \
          sep fn(RemOp) sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp)    \
              sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp)

// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep)                               \
  fn(AcosOp) sep fn(AcoshOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp)    \
      sep fn(AtanhOp) sep fn(ConjOp) sep fn(CoshOp) sep fn(DigammaOp)         \
          sep fn(ErfOp) sep fn(ErfcOp) sep fn(IsInfOp) sep fn(LgammaOp)       \
              sep fn(SinhOp) sep fn(TanOp)

// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_BINARY(fn, sep) fn(ZetaOp)

template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
  target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
    return llvm::all_of(op.getOperation()->getOperandTypes(),
                        [&](Type t) { return t.isa<RankedTensorType>(); });
  });
}

/// Element-wise operations on unranked tensors can be applied to the flattened
/// tensor operands with the same effect. This pattern rewrites every such
/// operation to
///   (i) flatten the input tensor,
///   (ii) apply the operation, and
///   (iii) restore the original shape.
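///
/// For illustration, applying this pattern to a unary op such as `mhlo.sqrt`
/// on an unranked operand produces IR along these lines (a sketch only; the
/// exact textual syntax and op ordering may differ):
///
///   %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
///   %n = shape.num_elements %shape : tensor<?xindex> -> index
///   %flat_shape = tensor.from_elements %n : tensor<1xindex>
///   %flat = "mhlo.dynamic_reshape"(%arg, %flat_shape)
///       : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
///   %res = "mhlo.sqrt"(%flat) : (tensor<?xf32>) -> tensor<?xf32>
///   %out = "mhlo.dynamic_reshape"(%res, %shape)
///       : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>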
template <typename OpTy>
struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
  explicit ElementwiseOpConversion(MLIRContext *context)
      : OpRewritePattern<OpTy>(context) {}

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Don't apply conversion unless all operands are unranked.
    if (!llvm::all_of(op.getOperation()->getOperands(), [&](Value operand) {
          return operand.getType().isa<UnrankedTensorType>();
        })) {
      return failure();
    }

    // Get operands' shape.
    auto loc = op.getLoc();
    Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
    SmallVector<Value, 3> operandShapes;
    for (Value operand : op.getOperation()->getOperands()) {
      Value shape =
          rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand);
      operandShapes.push_back(shape);
    }
    Value shape =
        operandShapes.size() == 1
            ? operandShapes.front()
            : rewriter.create<shape::AnyOp>(loc, extentTensorTy, operandShapes);

    // Derive flat shape.
    Type indexTy = rewriter.getIndexType();
    Value numElements =
        rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
    Value flatShape = rewriter.create<tensor::FromElementsOp>(loc, numElements);

    // Flatten operands.
    SmallVector<Value, 3> flatOperands;
    for (Value operand : op.getOperation()->getOperands()) {
      Type operandElementTy =
          operand.getType().template cast<ShapedType>().getElementType();
      Type flatTy =
          RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy);
      Value flat = rewriter.create<mhlo::DynamicReshapeOp>(loc, flatTy, operand,
                                                           flatShape);
      flatOperands.push_back(flat);
    }

    // Apply operation to flattened operands.
    Type resultElementTy =
        op.getType().template cast<ShapedType>().getElementType();
    Type flatResultTy =
        RankedTensorType::get({ShapedType::kDynamicSize}, resultElementTy);
    Value flatResult =
        rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op.getAttrs());

    // Restore original shape.
    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, op.getType(),
                                                        flatResult, shape);

    return success();
  }
};

// Converts a broadcasting binary operation with a scalar operand and an
// unranked operand to a ranked broadcasting operation by dynamically reshaping
// the unranked operand to a 1D tensor. This will always be safe because
// broadcasting from a scalar to another shape always works.
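//
// For example, a `chlo.broadcast_add` with a 0-d ranked lhs and an unranked
// rhs is rewritten roughly as follows (a sketch, not the exact pattern
// output):
//
//   %shape = shape.shape_of %rhs : tensor<*xf32> -> tensor<?xindex>
//   %n = shape.num_elements %shape
//   %size = tensor.from_elements %n : tensor<1xindex>
//   %flat = "mhlo.dynamic_reshape"(%rhs, %size)
//       : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
//   %add = "chlo.broadcast_add"(%lhs, %flat)
//       : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
//   %out = "mhlo.dynamic_reshape"(%add, %shape)
//       : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>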
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    typename ChloOpTy::Adaptor transformed(operands);
    Value lhs = transformed.lhs();
    Value rhs = transformed.rhs();

    auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
    auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();

    auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
    auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();

    bool lhs_is_scalar = lhs_ranked_type &&
                         lhs_ranked_type.getShape().empty() &&
                         rhs_unranked_type;
    bool rhs_is_scalar = rhs_ranked_type &&
                         rhs_ranked_type.getShape().empty() &&
                         lhs_unranked_type;

    // Only support the case where exactly one operand is scalar and the other
    // is unranked. Other patterns in chlo-to-hlo legalization will create more
    // efficient lowerings for cases where both ranks are known or will handle
    // the more generic case of both inputs being unranked.
    if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();

    auto scalar_element_type = lhs_is_scalar
                                   ? lhs_ranked_type.getElementType()
                                   : rhs_ranked_type.getElementType();
    auto result_type =
        op.getResult().getType().template dyn_cast<TensorType>();
    auto result_element_type = result_type.getElementType();

    // Reshape the non-scalar value into a dynamically sized, rank-1 tensor.
    Value shape =
        rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
    Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
    Value size_tensor =
        rewriter.create<tensor::FromElementsOp>(loc, num_elements);
    Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
        loc, RankedTensorType::get({-1}, scalar_element_type),
        lhs_is_scalar ? rhs : lhs, size_tensor);

    // Create a new ranked Chlo op that will be further lowered by other
    // patterns into Mhlo.
    SmallVector<Value, 2> new_operands{lhs_is_scalar ? lhs : reshaped,
                                       rhs_is_scalar ? rhs : reshaped};
    Value computed = rewriter.create<ChloOpTy>(
        loc, TypeRange{RankedTensorType::get({-1}, result_element_type)},
        new_operands, op.getAttrs());

    // Reshape the result back into an unranked tensor.
    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
                                                        computed, shape);

    return success();
  }
};

// Handles lowering of the following pattern to patterns that will be further
// matched by other patterns until they result in LHLO:
//   %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
//
// The sequence of specializations this handles is:
//   - Either operand being scalar
//   - Operands having equal shapes
//   - The resulting value being any of ranks [2,6]
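//
// Schematically, the generated code is a nest of scf.if regions of roughly
// the following shape (pseudo-code sketch, not the exact emitted IR):
//
//   if (num_elements(shape_of(%lhs)) == 1) {
//     // Reshape %lhs to a 0-d scalar, apply the op, broadcast back.
//   } else if (num_elements(shape_of(%rhs)) == 1) {
//     // Same, with the roles of %lhs and %rhs swapped.
//   } else if (shape_of(%lhs) == shape_of(%rhs)) {
//     // Apply the non-broadcasting HLO op directly.
//   } else {
//     // Rank-specialized broadcasting for ranks 1..6, asserting if the
//     // larger operand rank exceeds 6.
//   }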
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertUnrankedDynamicBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    typename ChloOpTy::Adaptor transformed(operands);
    Value lhs = transformed.lhs();
    Value rhs = transformed.rhs();
    auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
    auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
    auto result_type =
        op.getResult().getType().template dyn_cast<TensorType>();

    // Only support unranked operands. If either operand is ranked, another
    // pattern will handle the lowering.
    if (!lhs_type || !rhs_type) return failure();

    Value shape_of_lhs = rewriter.create<shape::ShapeOfOp>(loc, lhs);
    Value shape_of_rhs = rewriter.create<shape::ShapeOfOp>(loc, rhs);

    // If lhs has exactly one element
    auto if_op = rewriter.create<scf::IfOp>(
        loc, result_type, IsSingleElementShape(rewriter, op, shape_of_lhs),
        true);
    OpBuilder if_lhs_scalar_builder =
        if_op.getThenBodyBuilder(rewriter.getListener());
    Value reshaped_lhs = if_lhs_scalar_builder.create<mhlo::ReshapeOp>(
        loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
    Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
        op.getAttrs());
    Value extended_if_lhs_scalar_result = extendToBroadcastShape(
        if_lhs_scalar_builder, loc, if_lhs_scalar_result, shape_of_lhs,
        shape_of_rhs);
    if_lhs_scalar_builder.create<scf::YieldOp>(loc,
                                               extended_if_lhs_scalar_result);

    // If lhs does not have exactly one element
    //
    // See if rhs has exactly one element
    OpBuilder else_lhs_scalar_builder =
        if_op.getElseBodyBuilder(rewriter.getListener());
    auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
        loc, result_type,
        IsSingleElementShape(else_lhs_scalar_builder, op, shape_of_rhs), true);
    else_lhs_scalar_builder.create<scf::YieldOp>(loc,
                                                 if_rhs_scalar_op.getResult(0));
    OpBuilder if_rhs_scalar_builder =
        if_rhs_scalar_op.getThenBodyBuilder(rewriter.getListener());
    Value reshaped_rhs = if_rhs_scalar_builder.create<mhlo::ReshapeOp>(
        loc, RankedTensorType::get({}, rhs_type.getElementType()), rhs);
    Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
        op.getAttrs());
    Value extended_if_rhs_scalar_result = extendToBroadcastShape(
        if_rhs_scalar_builder, loc, if_rhs_scalar_result, shape_of_lhs,
        shape_of_rhs);
    if_rhs_scalar_builder.create<scf::YieldOp>(loc,
                                               extended_if_rhs_scalar_result);

    // If NEITHER shape has exactly one element
    //
    // See if shapes are equal.
    OpBuilder else_no_scalars_builder =
        if_rhs_scalar_op.getElseBodyBuilder(rewriter.getListener());
    Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
        loc, shape_of_lhs, shape_of_rhs);

    auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
        loc, result_type, equal_shapes, true);
    else_no_scalars_builder.create<scf::YieldOp>(loc,
                                                 if_eq_shapes_op.getResult(0));

    OpBuilder if_eq_shapes_builder =
        if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener());
    Value non_broadcast_op =
        Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
    if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);

    // If shapes do not have exactly one element, nor are equal
    //
    // See if values are of a rank that we support.
    OpBuilder if_neq_shapes_builder =
        if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
    if_neq_shapes_builder.create<scf::YieldOp>(
        loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));

    rewriter.replaceOp(op, {if_op.getResult(0)});
    return success();
  }

 private:
  // Returns the dynamic result of checking whether the given value is
  // effectively a scalar shape (i.e. the number of elements is 1).
  Value IsSingleElementShape(OpBuilder &rewriter, ChloOpTy op,
                             Value shape_of_tensor) const {
    auto loc = op.getLoc();

    Value num_elements =
        rewriter.create<shape::NumElementsOp>(loc, shape_of_tensor);
    return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(),
                                   CmpIPredicate::eq, num_elements,
                                   rewriter.create<ConstantIndexOp>(loc, 1));
  }

  Value GreaterRankIsN(OpBuilder &builder, Location loc, Value actual_rank,
                       int targeted_rank) const {
    return builder.create<CmpIOp>(
        loc, CmpIPredicate::eq, actual_rank,
        builder.create<ConstantIndexOp>(loc, targeted_rank));
  }

  scf::IfOp createIfOpForRankSpecializedBroadcastAndOp(
      OpBuilder &builder, ChloOpTy op, Value actual_rank,
      int targeted_rank) const {
    // Create the if block to place the current specialized logic in.
    Value greater_rank_is_n =
        GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank);
    return builder.create<scf::IfOp>(op.getLoc(), op.getResult().getType(),
                                     greater_rank_is_n, true);
  }

  Value extendToBroadcastShape(OpBuilder &builder, Location loc, Value value,
                               Value shape_of_lhs, Value shape_of_rhs) const {
    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
        {RankedTensorType::kDynamicSize}, builder.getIndexType());
    Value broadcast_shape = builder.create<shape::BroadcastOp>(
        loc, unknown_rank_extent_tensor_type, shape_of_lhs, shape_of_rhs,
        nullptr);
    return builder.create<mhlo::DynamicReshapeOp>(loc, value.getType(), value,
                                                  broadcast_shape);
  }

  Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op,
                                   Value value, int targeted_rank) const {
    auto loc = op.getLoc();
    Value shape = builder.create<shape::ShapeOfOp>(loc, value);
    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
        {RankedTensorType::kDynamicSize}, builder.getIndexType());
    auto known_rank_extent_tensor_type =
        RankedTensorType::get({targeted_rank}, builder.getIndexType());
    Value ranked_shape_val = builder.create<shape::ConstShapeOp>(
        loc, known_rank_extent_tensor_type,
        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
                                        ranked_shape));
    Value extended_value = builder.create<shape::BroadcastOp>(
        loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val,
        nullptr);
    return builder.create<tensor::CastOp>(loc, known_rank_extent_tensor_type,
                                          extended_value);
  }

  // Create the if statement and code for a broadcasting op with a result of a
  // given rank.
  void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, ChloOpTy op,
                                           Value lhs, Value rhs,
                                           int targeted_rank) const {
    auto loc = op.getLoc();

    // Handle shape broadcasting and inference.
    Value extended_lhs_casted =
        createBroadcastToKnownRank(if_builder, op, lhs, targeted_rank);
    Value extended_rhs_casted =
        createBroadcastToKnownRank(if_builder, op, rhs, targeted_rank);
    auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
        targeted_rank, RankedTensorType::kDynamicSize);
    auto reshaped_type = RankedTensorType::get(
        dynamic_dimensions,
        lhs.getType().template dyn_cast<TensorType>().getElementType());

    // 1. Reshape operands to the given rank (with the same number of elements)
    // 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
    //    can be broadcasted and do the actual broadcasting)
    // 3. Type erase the output back to unranked
    Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
        loc, reshaped_type, lhs, extended_lhs_casted);
    Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
        loc, reshaped_type, rhs, extended_rhs_casted);
    auto result_element_type = op.getResult()
                                   .getType()
                                   .template dyn_cast<TensorType>()
                                   .getElementType();
    auto result_type =
        RankedTensorType::get(dynamic_dimensions, result_element_type);
    Value result = if_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{result_type},
        ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
    Value reshaped_result = if_builder.create<tensor::CastOp>(
        loc, UnrankedTensorType::get(result_element_type), result);
    if_builder.create<scf::YieldOp>(loc, reshaped_result);
  }

  // Iterates over the desired ranks to be specialized and generates the code
  // snippet for each case.
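  //
  // The emitted structure is, schematically (a sketch; each branch is an
  // scf.if region yielding an unranked tensor):
  //
  //   if (greater_rank == 1)      { rank-1 specialization }
  //   else if (greater_rank == 2) { rank-2 specialization }
  //   ...
  //   else if (greater_rank == 5) { rank-5 specialization }
  //   else { assert(greater_rank == 6); rank-6 specialization }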
  Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
                             Value rhs) const {
    auto loc = op.getLoc();

    // Find the larger rank of the 2 operands.
    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
                                                    rewriter.getIndexType());
    Value lhs_shape =
        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
    Value rhs_shape =
        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
    Value lhs_rank =
        rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), lhs_shape);
    Value rhs_rank =
        rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), rhs_shape);
    Value greater_rank_lhs =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
    Value greater_rank =
        rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);

    // Generate a list of nested if/else statements to handle rank
    // specializations from 1 to `kMaxRankSpecialization`.
    scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
        rewriter, op, greater_rank, 1);
    OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
    createRankSpecializedBroadcastAndOp(if_builder, op, lhs, rhs, 1);

    // Put each subsequent rank specialization inside the else statement of the
    // previous one.
    OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
    constexpr int kMaxRankSpecialization = 6;
    for (int i = 2; i < kMaxRankSpecialization; i++) {
      auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
          else_builder, op, greater_rank, i);
      if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
      createRankSpecializedBroadcastAndOp(if_builder, op, lhs, rhs, i);
      else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
      else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
    }
    // Fire an assertion if none of the rank specializations applied (one of
    // the ranks was greater than `kMaxRankSpecialization`).
    else_builder.create<AssertOp>(
        loc,
        GreaterRankIsN(else_builder, op.getLoc(), greater_rank,
                       kMaxRankSpecialization),
        "Input for dynamic binary op lowering was of a rank greater than " +
            std::to_string(kMaxRankSpecialization));
    // Add the rank 6 specialization to the innermost else block.
    createRankSpecializedBroadcastAndOp(else_builder, op, lhs, rhs,
                                        kMaxRankSpecialization);

    // Return the result of the outermost if statement.
    return if_op.getResult(0);
  }
};

struct TransformUnrankedHloPass
    : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
  }

  void runOnFunction() override {
    // Setup conversion target.
    MLIRContext &ctx = getContext();
    ConversionTarget target(ctx);
    target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
                           shape::ShapeDialect, scf::SCFDialect,
                           tensor::TensorDialect>();
    target.addLegalOp<FuncOp>();
#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
    MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL_MHLO, ;);
    MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL_MHLO, ;);
    MAP_CHLO_OPERATION_CWISE_UNARY(ADD_LEGAL_CHLO, ;);
#undef ADD_LEGAL_MHLO
#undef ADD_LEGAL_CHLO
    AddLegalOpOnRankedTensor<mhlo::CompareOp>(&target);
    AddLegalOpOnRankedTensor<mhlo::SelectOp>(&target);
    target.addDynamicallyLegalDialect<chlo::HloClientDialect>(
        [](Operation *op) {
          return !llvm::any_of(op->getOperandTypes(), [](Type type) {
            return type.isa<UnrankedTensorType>();
          });
        });

    // Populate rewrite patterns.
    OwningRewritePatternList patterns;
    PopulateTransformUnrankedHloPatterns(&ctx, &patterns);

    // Apply transformation.
    if (failed(
            applyPartialConversion(getFunction(), target, std::move(patterns))))
      return signalPassFailure();
  }
};

}  // namespace

void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
                                          OwningRewritePatternList *patterns) {
#define MAP_HLO(op) ElementwiseOpConversion<mhlo::op>
#define MAP_CHLO(op) ElementwiseOpConversion<chlo::op>
#define COMMA ,
  // clang-format off
  patterns->insert<
      MAP_XLA_OPERATION_CWISE_UNARY(MAP_HLO, COMMA),
      MAP_XLA_OPERATION_CWISE_BINARY(MAP_HLO, COMMA),
      MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO, COMMA),
      MAP_CHLO_OPERATION_CWISE_BINARY(MAP_CHLO, COMMA),
      ElementwiseOpConversion<mhlo::CompareOp>,
      ElementwiseOpConversion<mhlo::SelectOp>>(context);
  // clang-format on
#undef MAP_HLO
#undef MAP_CHLO
#undef COMMA
  chlo::PopulateForBroadcastingBinaryOp<
      ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);
  chlo::PopulateForBroadcastingBinaryOp<
      ConvertUnrankedScalarDynamicBroadcastBinaryOp>(context, patterns);
}

std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
  return std::make_unique<TransformUnrankedHloPass>();
}

}  // namespace mlir