/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

==============================================================================*/

#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace {

// TODO(herhut): Generate these out of op definitions.
#define MAP_XLA_OPERATION_CWISE_UNARY(fn, sep)                                \
  fn(AbsOp) sep fn(CeilOp) sep fn(ClzOp) sep fn(ConvertOp) sep fn(CosOp)      \
      sep fn(ExpOp) sep fn(Expm1Op) sep fn(FloorOp) sep fn(ImagOp)            \
          sep fn(IsFiniteOp) sep fn(LogOp) sep fn(Log1pOp) sep fn(LogisticOp) \
              sep fn(NotOp) sep fn(NegOp) sep fn(PopulationCountOp)           \
                  sep fn(RealOp) sep fn(RoundOp) sep fn(RsqrtOp)              \
                      sep fn(SignOp) sep fn(SinOp) sep fn(SqrtOp)             \
                          sep fn(TanhOp)

// TODO(herhut): Generate these out of op definitions.
#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep)                            \
  fn(AddOp) sep fn(AndOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp)  \
      sep fn(MaxOp) sep fn(MinOp) sep fn(MulOp) sep fn(OrOp) sep fn(PowOp) \
          sep fn(RemOp) sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
              sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp)

// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep)                            \
  fn(AcosOp) sep fn(AcoshOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) \
      sep fn(AtanhOp) sep fn(ConjOp) sep fn(CoshOp) sep fn(DigammaOp)      \
          sep fn(ErfOp) sep fn(ErfcOp) sep fn(IsInfOp) sep fn(LgammaOp)    \
              sep fn(SinhOp) sep fn(TanOp)

// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_BINARY(fn, sep) fn(ZetaOp)

// Marks ops of the given type as dynamically legal whenever all of their
// operands are ranked tensors.
template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
  target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
    return llvm::all_of(op.getOperation()->getOperandTypes(),
                        [&](Type t) { return t.isa<RankedTensorType>(); });
  });
}

/// Element-wise operations on unranked tensors can be applied to the flattened
/// tensor operands with the same effect.  This pattern rewrites every such
/// operation to
///   (i)   flatten the input tensor,
///   (ii)  apply the operation, and
///   (iii) restore the original shape.
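///
/// As a rough, illustrative sketch (assuming a unary mhlo.sqrt as the
/// element-wise op; op spellings and types are simplified), the rewrite
/// produces IR along these lines:
///
///   %shape      = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
///   %n          = shape.num_elements %shape : tensor<?xindex> -> index
///   %flat_shape = tensor.from_elements %n : tensor<1xindex>
///   %flat       = "mhlo.dynamic_reshape"(%arg, %flat_shape)
///                     : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
///   %flat_res   = "mhlo.sqrt"(%flat) : (tensor<?xf32>) -> tensor<?xf32>
///   %res        = "mhlo.dynamic_reshape"(%flat_res, %shape)
///                     : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>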
template <typename OpTy>
struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
  explicit ElementwiseOpConversion(MLIRContext *context)
      : OpRewritePattern<OpTy>(context) {}

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Don't apply conversion unless all operands are unranked.
    if (!llvm::all_of(op.getOperation()->getOperands(), [&](Value operand) {
          return operand.getType().isa<UnrankedTensorType>();
        })) {
      return failure();
    }

    // Get operands' shape.
    auto loc = op.getLoc();
    Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
    SmallVector<Value, 3> operandShapes;
    for (Value operand : op.getOperation()->getOperands()) {
      Value shape =
          rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand);
      operandShapes.push_back(shape);
    }
    Value shape =
        operandShapes.size() == 1
            ? operandShapes.front()
            : rewriter.create<shape::AnyOp>(loc, extentTensorTy, operandShapes);

    // Derive flat shape.
    Type indexTy = rewriter.getIndexType();
    Value numElements =
        rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
    Value flatShape = rewriter.create<tensor::FromElementsOp>(loc, numElements);

    // Flatten operands.
    SmallVector<Value, 3> flatOperands;
    for (Value operand : op.getOperation()->getOperands()) {
      Type operandElementTy =
          operand.getType().template cast<ShapedType>().getElementType();
      Type flatTy =
          RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy);
      Value flat = rewriter.create<mhlo::DynamicReshapeOp>(loc, flatTy, operand,
                                                           flatShape);
      flatOperands.push_back(flat);
    }

    // Apply operation to flattened operands.
    Type resultElementTy =
        op.getType().template cast<ShapedType>().getElementType();
    Type flatResultTy =
        RankedTensorType::get({ShapedType::kDynamicSize}, resultElementTy);
    Value flatResult =
        rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op.getAttrs());

    // Restore original shape.
    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, op.getType(),
                                                        flatResult, shape);

    return success();
  }
};

// Converts a broadcasting binary operation with a scalar operand and an
// unranked operand to a ranked broadcasting operation by dynamically reshaping
// the unranked operand to a 1D tensor. This will always be safe because
// broadcasting from a scalar to another shape always works.
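//
// As a rough, illustrative sketch (using chlo.broadcast_add as a stand-in for
// the generic ChloOpTy; op spellings and types are simplified), a scalar lhs
// combined with an unranked rhs becomes approximately:
//
//   %shape      = shape.shape_of %rhs : tensor<*xf32> -> tensor<?xindex>
//   %n          = shape.num_elements %shape
//   %flat_shape = tensor.from_elements %n : tensor<1xindex>
//   %flat_rhs   = "mhlo.dynamic_reshape"(%rhs, %flat_shape)
//                     : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
//   %flat_res   = "chlo.broadcast_add"(%lhs, %flat_rhs)
//                     : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
//   %res        = "mhlo.dynamic_reshape"(%flat_res, %shape)
//                     : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>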
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    typename ChloOpTy::Adaptor transformed(operands);
    Value lhs = transformed.lhs();
    Value rhs = transformed.rhs();

    auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
    auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();

    auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
    auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();

    bool lhs_is_scalar = lhs_ranked_type &&
                         lhs_ranked_type.getShape().empty() &&
                         rhs_unranked_type;
    bool rhs_is_scalar = rhs_ranked_type &&
                         rhs_ranked_type.getShape().empty() &&
                         lhs_unranked_type;

    // Only support the case where exactly one operand is scalar and the other
    // is unranked. Other patterns in chlo-to-hlo legalization will create more
    // efficient lowerings for cases where both ranks are known or will handle
    // the more generic case of both inputs being unranked.
    if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();

    auto scalar_element_type = lhs_is_scalar ? lhs_ranked_type.getElementType()
                                             : rhs_ranked_type.getElementType();
    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
    auto result_element_type = result_type.getElementType();

    // Reshape the non-scalar value into a dynamically sized, rank-1 tensor.
    Value shape =
        rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
    Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
    Value size_tensor =
        rewriter.create<tensor::FromElementsOp>(loc, num_elements);
    Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
        loc, RankedTensorType::get({-1}, scalar_element_type),
        lhs_is_scalar ? rhs : lhs, size_tensor);

    // Create a new ranked Chlo op that will be further lowered by other
    // patterns into Mhlo.
    SmallVector<Value, 2> new_operands{lhs_is_scalar ? lhs : reshaped,
                                       rhs_is_scalar ? rhs : reshaped};
    Value computed = rewriter.create<ChloOpTy>(
        loc, TypeRange{RankedTensorType::get({-1}, result_element_type)},
        new_operands, op.getAttrs());

    // Reshape the result back into an unranked tensor.
    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
                                                        computed, shape);

    return success();
  }
};

// Handles lowering of the following pattern to patterns that will be further
// matched by other patterns until they result in LHLO:
//   %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
//
// The sequence of specializations this handles is:
//   - Either operand being scalar
//   - Operands having equal shapes
//   - The resulting value being any of ranks [2,6]
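//
// The emitted control flow is roughly the following (pseudo-IR sketch, not
// exact syntax):
//
//   %res = scf.if %lhs_has_one_element {
//     // Reshape lhs to a scalar, apply the CHLO op, and dynamically reshape
//     // the result to the broadcast of the two operand shapes.
//   } else {
//     scf.if %rhs_has_one_element {
//       // Same as above with the roles of lhs and rhs swapped.
//     } else {
//       scf.if %shapes_are_equal {
//         // Apply the op directly; no broadcasting is needed.
//       } else {
//         // Nested rank specializations for ranks 1..6, see
//         // HandleBroadcastAndOp below.
//       }
//     }
//   }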
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertUnrankedDynamicBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    typename ChloOpTy::Adaptor transformed(operands);
    Value lhs = transformed.lhs();
    Value rhs = transformed.rhs();
    auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
    auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();

    // Only support unranked operands. If either operand is ranked, another
    // pattern will handle the lowering.
    if (!lhs_type || !rhs_type) return failure();

    Value shape_of_lhs = rewriter.create<shape::ShapeOfOp>(loc, lhs);
    Value shape_of_rhs = rewriter.create<shape::ShapeOfOp>(loc, rhs);

    // If lhs has exactly one element
    auto if_op = rewriter.create<scf::IfOp>(
        loc, result_type, IsSingleElementShape(rewriter, op, shape_of_lhs),
        true);
    OpBuilder if_lhs_scalar_builder =
        if_op.getThenBodyBuilder(rewriter.getListener());
    Value reshaped_lhs = if_lhs_scalar_builder.create<mhlo::ReshapeOp>(
        loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
    Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
        op.getAttrs());
    Value extended_if_lhs_scalar_result =
        extendToBroadcastShape(if_lhs_scalar_builder, loc, if_lhs_scalar_result,
                               shape_of_lhs, shape_of_rhs);
    if_lhs_scalar_builder.create<scf::YieldOp>(loc,
                                               extended_if_lhs_scalar_result);

    // If lhs does not have exactly one element
    //
    // See if rhs has exactly one element
    OpBuilder else_lhs_scalar_builder =
        if_op.getElseBodyBuilder(rewriter.getListener());
    auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
        loc, result_type,
        IsSingleElementShape(else_lhs_scalar_builder, op, shape_of_rhs), true);
    else_lhs_scalar_builder.create<scf::YieldOp>(loc,
                                                 if_rhs_scalar_op.getResult(0));
    OpBuilder if_rhs_scalar_builder =
        if_rhs_scalar_op.getThenBodyBuilder(rewriter.getListener());
    Value reshaped_rhs = if_rhs_scalar_builder.create<mhlo::ReshapeOp>(
        loc, RankedTensorType::get({}, rhs_type.getElementType()), rhs);
    Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
        op.getAttrs());
    Value extended_if_rhs_scalar_result =
        extendToBroadcastShape(if_rhs_scalar_builder, loc, if_rhs_scalar_result,
                               shape_of_lhs, shape_of_rhs);
    if_rhs_scalar_builder.create<scf::YieldOp>(loc,
                                               extended_if_rhs_scalar_result);

    // If NEITHER shape has exactly one element
    //
    // See if shapes are equal.
    OpBuilder else_no_scalars_builder =
        if_rhs_scalar_op.getElseBodyBuilder(rewriter.getListener());
    Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
        loc, shape_of_lhs, shape_of_rhs);

    auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
        loc, result_type, equal_shapes, true);
    else_no_scalars_builder.create<scf::YieldOp>(loc,
                                                 if_eq_shapes_op.getResult(0));

    OpBuilder if_eq_shapes_builder =
        if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener());
    Value non_broadcast_op =
        Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
    if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);

    // If shapes do not have exactly one element, nor are equal
    //
    // See if values are of a rank that we support.
    OpBuilder if_neq_shapes_builder =
        if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
    if_neq_shapes_builder.create<scf::YieldOp>(
        loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));

    rewriter.replaceOp(op, {if_op.getResult(0)});
    return success();
  }

 private:
  // Returns the dynamic result of checking whether the given value is
  // effectively a scalar shape (i.e. the number of elements is 1).
  Value IsSingleElementShape(OpBuilder &rewriter, ChloOpTy op,
                             Value shape_of_tensor) const {
    auto loc = op.getLoc();

    Value num_elements =
        rewriter.create<shape::NumElementsOp>(loc, shape_of_tensor);
    return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
                                   num_elements,
                                   rewriter.create<ConstantIndexOp>(loc, 1));
  }

  // Returns an i1 that is true iff `actual_rank` equals `targeted_rank`.
  Value GreaterRankIsN(OpBuilder &builder, Location loc, Value actual_rank,
                       int targeted_rank) const {
    return builder.create<CmpIOp>(
        loc, CmpIPredicate::eq, actual_rank,
        builder.create<ConstantIndexOp>(loc, targeted_rank));
  }

  // Creates the scf.if op that guards the rank specialization for the given
  // `targeted_rank`.
  scf::IfOp createIfOpForRankSpecializedBroadcastAndOp(
      OpBuilder &builder, ChloOpTy op, Value actual_rank,
      int targeted_rank) const {
    // Create the if block to place the current specialized logic in.
    Value greater_rank_is_n =
        GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank);
    return builder.create<scf::IfOp>(op.getLoc(), op.getResult().getType(),
                                     greater_rank_is_n, true);
  }

  // Reshapes `value` (via mhlo.dynamic_reshape) to the broadcast of the lhs
  // and rhs shapes.
  Value extendToBroadcastShape(OpBuilder &builder, Location loc, Value value,
                               Value shape_of_lhs, Value shape_of_rhs) const {
    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
        {RankedTensorType::kDynamicSize}, builder.getIndexType());
    Value broadcast_shape =
        builder.create<shape::BroadcastOp>(loc, unknown_rank_extent_tensor_type,
                                           shape_of_lhs, shape_of_rhs, nullptr);
    return builder.create<mhlo::DynamicReshapeOp>(loc, value.getType(), value,
                                                  broadcast_shape);
  }

  // Returns the shape of `value`, extended to `targeted_rank` by broadcasting
  // it with an all-ones shape of that rank, as a statically ranked extent
  // tensor.
  Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op, Value value,
                                   int targeted_rank) const {
    auto loc = op.getLoc();
    Value shape = builder.create<shape::ShapeOfOp>(loc, value);
    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
        {RankedTensorType::kDynamicSize}, builder.getIndexType());
    auto known_rank_extent_tensor_type =
        RankedTensorType::get({targeted_rank}, builder.getIndexType());
    Value ranked_shape_val = builder.create<shape::ConstShapeOp>(
        loc, known_rank_extent_tensor_type,
        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
                                        ranked_shape));
    Value extended_value = builder.create<shape::BroadcastOp>(
        loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val, nullptr);
    return builder.create<tensor::CastOp>(loc, known_rank_extent_tensor_type,
                                          extended_value);
  }

  // Create the if statement and code for a broadcasting op with a result of a
  // given rank.
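  //
  // As an illustrative sketch for targeted_rank == 3 (pseudo-notation, types
  // simplified), the generated block corresponds roughly to:
  //
  //   %lhs_shape3 = shape.broadcast(shape.shape_of(%lhs), const_shape [1, 1, 1])
  //   %rhs_shape3 = shape.broadcast(shape.shape_of(%rhs), const_shape [1, 1, 1])
  //   %lhs3 = mhlo.dynamic_reshape %lhs to %lhs_shape3 : tensor<?x?x?xf32>
  //   %rhs3 = mhlo.dynamic_reshape %rhs to %rhs_shape3 : tensor<?x?x?xf32>
  //   %res3 = <ranked chlo op>(%lhs3, %rhs3) : tensor<?x?x?xf32>
  //   %res  = tensor.cast %res3 : tensor<?x?x?xf32> to tensor<*xf32>
  //   scf.yield %res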
  void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, ChloOpTy op,
                                           Value lhs, Value rhs,
                                           int targeted_rank) const {
    auto loc = op.getLoc();

    // Handle shape broadcasting and inference.
    Value extended_lhs_casted =
        createBroadcastToKnownRank(if_builder, op, lhs, targeted_rank);
    Value extended_rhs_casted =
        createBroadcastToKnownRank(if_builder, op, rhs, targeted_rank);
    auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
        targeted_rank, RankedTensorType::kDynamicSize);
    auto reshaped_type = RankedTensorType::get(
        dynamic_dimensions,
        lhs.getType().template dyn_cast<TensorType>().getElementType());

    // 1. Reshape operands to the given rank (with the same number of elements)
    // 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
    //    can be broadcasted and do the actual broadcasting)
    // 3. Type erase the output back to unranked
    Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
        loc, reshaped_type, lhs, extended_lhs_casted);
    Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
        loc, reshaped_type, rhs, extended_rhs_casted);
    auto result_element_type = op.getResult()
                                   .getType()
                                   .template dyn_cast<TensorType>()
                                   .getElementType();
    auto result_type =
        RankedTensorType::get(dynamic_dimensions, result_element_type);
    Value result = if_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{result_type},
        ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
    Value reshaped_result = if_builder.create<tensor::CastOp>(
        loc, UnrankedTensorType::get(result_element_type), result);
    if_builder.create<scf::YieldOp>(loc, reshaped_result);
  }

  // Iterates over the desired ranks to be specialized and generates the code
  // snippet for each case.
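  //
  // The resulting structure is a chain of nested scf.if ops, roughly:
  //
  //   scf.if (%greater_rank == 1) { <rank-1 specialization> }
  //   else { scf.if (%greater_rank == 2) { <rank-2 specialization> }
  //          else { ...
  //                 else { assert(%greater_rank == 6);
  //                        <rank-6 specialization> } } }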
  Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
                             Value rhs) const {
    auto loc = op.getLoc();

    // Find the larger rank of the 2 operands.
    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
                                                    rewriter.getIndexType());
    Value lhs_shape =
        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
    Value rhs_shape =
        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
    Value lhs_rank =
        rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), lhs_shape);
    Value rhs_rank =
        rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), rhs_shape);
    Value greater_rank_lhs =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
    Value greater_rank =
        rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);

    // Generate a list of nested if/else statements to handle rank
    // specializations from 1 to `kMaxRankSpecialization`.
    scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
        rewriter, op, greater_rank, 1);
    OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
    createRankSpecializedBroadcastAndOp(if_builder, op, lhs, rhs, 1);

    // Put each subsequent rank specialization inside the else statement of the
    // previous one.
    OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
    constexpr int kMaxRankSpecialization = 6;
    for (int i = 2; i < kMaxRankSpecialization; i++) {
      auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
          else_builder, op, greater_rank, i);
      if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
      createRankSpecializedBroadcastAndOp(if_builder, op, lhs, rhs, i);
      else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
      else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
    }
    // Fire an assertion if none of the rank specializations applied (one of
    // the ranks was greater than `kMaxRankSpecialization`).
    else_builder.create<AssertOp>(
        loc,
        GreaterRankIsN(else_builder, op.getLoc(), greater_rank,
                       kMaxRankSpecialization),
        "Input for dynamic binary op lowering was of a rank greater than " +
            std::to_string(kMaxRankSpecialization));
    // Add the rank 6 specialization to the innermost else block.
    createRankSpecializedBroadcastAndOp(else_builder, op, lhs, rhs,
                                        kMaxRankSpecialization);

    // Return the result of the outermost if statement.
    return if_op.getResult(0);
  }
};

struct TransformUnrankedHloPass
    : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
  }

  void runOnFunction() override {
    // Setup conversion target.
    MLIRContext &ctx = getContext();
    ConversionTarget target(ctx);
    target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
                           shape::ShapeDialect, scf::SCFDialect,
                           tensor::TensorDialect>();
    target.addLegalOp<FuncOp>();
#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
    MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL_MHLO, ;);
    MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL_MHLO, ;);
    MAP_CHLO_OPERATION_CWISE_UNARY(ADD_LEGAL_CHLO, ;);
#undef ADD_LEGAL_MHLO
#undef ADD_LEGAL_CHLO
    AddLegalOpOnRankedTensor<mhlo::CompareOp>(&target);
    AddLegalOpOnRankedTensor<mhlo::SelectOp>(&target);
    target.addDynamicallyLegalDialect<chlo::HloClientDialect>(
        [](Operation *op) {
          return !llvm::any_of(op->getOperandTypes(), [](Type type) {
            return type.isa<UnrankedTensorType>();
          });
        });

    // Populate rewrite patterns.
    OwningRewritePatternList patterns;
    PopulateTransformUnrankedHloPatterns(&ctx, &patterns);

    // Apply transformation.
    if (failed(
            applyPartialConversion(getFunction(), target, std::move(patterns))))
      return signalPassFailure();
  }
};

}  // namespace

void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
                                          OwningRewritePatternList *patterns) {
#define MAP_HLO(op) ElementwiseOpConversion<mhlo::op>
#define MAP_CHLO(op) ElementwiseOpConversion<chlo::op>
#define COMMA ,
  // clang-format off
  patterns->insert<
      MAP_XLA_OPERATION_CWISE_UNARY(MAP_HLO, COMMA),
      MAP_XLA_OPERATION_CWISE_BINARY(MAP_HLO, COMMA),
      MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO, COMMA),
      MAP_CHLO_OPERATION_CWISE_BINARY(MAP_CHLO, COMMA),
      ElementwiseOpConversion<mhlo::CompareOp>,
      ElementwiseOpConversion<mhlo::SelectOp>>(context);
  // clang-format on
#undef MAP_HLO
#undef MAP_CHLO
#undef COMMA
  chlo::PopulateForBroadcastingBinaryOp<
      ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);
  chlo::PopulateForBroadcastingBinaryOp<
      ConvertUnrankedScalarDynamicBroadcastBinaryOp>(context, patterns);
}

std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
  return std::make_unique<TransformUnrankedHloPass>();
}

}  // namespace mlir