1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file defines the operations used in the MHLO dialect.
17 
18 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
19 
20 #include <assert.h>
21 #include <stddef.h>
22 #include <stdint.h>
23 
24 #include <algorithm>
25 #include <functional>
26 
27 #include "llvm/ADT/APFloat.h"
28 #include "llvm/ADT/APInt.h"
29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/StringRef.h"
34 #include "llvm/ADT/iterator_range.h"
35 #include "llvm/Support/Casting.h"
36 #include "llvm/Support/FormatVariadic.h"
37 #include "llvm/Support/MathExtras.h"
38 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
39 #include "mlir-hlo/utils/convert_op_folder.h"
40 #include "mlir-hlo/utils/hlo_utils.h"
41 #include "mlir/Dialect/Shape/IR/Shape.h"
42 #include "mlir/Dialect/StandardOps/IR/Ops.h"
43 #include "mlir/Dialect/Tensor/IR/Tensor.h"
44 #include "mlir/IR/Attributes.h"
45 #include "mlir/IR/Builders.h"
46 #include "mlir/IR/BuiltinTypes.h"
47 #include "mlir/IR/Dialect.h"
48 #include "mlir/IR/Location.h"
49 #include "mlir/IR/MLIRContext.h"
50 #include "mlir/IR/Matchers.h"
51 #include "mlir/IR/OpDefinition.h"
52 #include "mlir/IR/OpImplementation.h"
53 #include "mlir/IR/Operation.h"
54 #include "mlir/IR/OperationSupport.h"
55 #include "mlir/IR/PatternMatch.h"
56 #include "mlir/IR/TypeUtilities.h"
57 #include "mlir/IR/Types.h"
58 #include "mlir/IR/Value.h"
59 #include "mlir/Support/LLVM.h"
60 #include "mlir/Support/LogicalResult.h"
61 #include "mlir/Transforms/InliningUtils.h"
62 
63 namespace mlir {
64 #include "hlo_patterns.cc.inc"
65 }  // namespace mlir
66 
67 namespace mlir {
68 namespace mhlo {
69 
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)70 Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value,
71                                             Type type, Location loc) {
72   // HLO dialect constants only support ElementsAttr unlike standard dialect
73   // constant which supports all attributes.
74   if (value.isa<ElementsAttr>())
75     return builder.create<mhlo::ConstOp>(loc, type, value.cast<ElementsAttr>());
76   return nullptr;
77 }
78 
// Generic fallback overload of Verify: performs no checks and always reports
// success. The non-template Verify overloads defined for specific ops in this
// file are preferred by overload resolution when the argument type matches.
template <typename T>
static LogicalResult Verify(T op) {
  return success();
}
83 
84 namespace {
85 
86 //===----------------------------------------------------------------------===//
87 // Utilities for the canonicalize patterns
88 //===----------------------------------------------------------------------===//
89 
90 // Verifies that dimension attribute for the op correctly indexes in operand or
91 // result shape.
92 template <typename OpT>
VerifyDimAttr(OpT op)93 static LogicalResult VerifyDimAttr(OpT op) {
94   int64_t rank = -1;
95   if (auto ty = op.operand().getType().template dyn_cast<RankedTensorType>()) {
96     rank = ty.getRank();
97   } else if (auto ty = op.getType().template dyn_cast<RankedTensorType>()) {
98     rank = ty.getRank();
99   } else {
100     return success();
101   }
102 
103   int64_t dim = op.dimension();
104   if (dim < 0 || dim >= rank)
105     return op.emitOpError() << "requires dimension attribute in range [0, "
106                             << rank << "); found (" << dim << ")";
107   return success();
108 }
109 
110 // Returns 1D 64-bit dense elements attribute with the given values.
GetI64ElementsAttr(ArrayRef<int64_t> values,Builder * builder)111 DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
112                                         Builder* builder) {
113   RankedTensorType ty = RankedTensorType::get(
114       {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
115   return DenseIntElementsAttr::get(ty, values);
116 }
117 
118 // Given the start indices and slice sizes for a dynamic-slice that can be
119 // converted to a static slice, returns the limits for the static slice.
BuildSliceLimits(DenseIntElementsAttr start_indices,DenseIntElementsAttr slice_sizes,Builder * builder)120 DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices,
121                                       DenseIntElementsAttr slice_sizes,
122                                       Builder* builder) {
123   SmallVector<int64_t, 4> slice_limits;
124   for (int64_t i = 0; i < slice_sizes.getNumElements(); ++i) {
125     int64_t start_index = start_indices.getValue<IntegerAttr>(i).getInt();
126     int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
127     slice_limits.push_back(start_index + slice_size);
128   }
129   return GetI64ElementsAttr(slice_limits, builder);
130 }
131 
132 #include "mhlo_canonicalize.inc"
133 }  // namespace
134 
135 //===----------------------------------------------------------------------===//
136 // ConstOp
137 //===----------------------------------------------------------------------===//
138 
// Folds the constant op to the attribute stored in its `value` attribute.
// `operands` is asserted to be empty.
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");

  // Return the held attribute value.
  return value();
}
145 
146 // Builds a constant op with the specified attribute `value`.
build(OpBuilder & builder,OperationState & result,Attribute value)147 void ConstOp::build(OpBuilder& builder, OperationState& result,
148                     Attribute value) {
149   Type type;
150   if (auto elemAttr = value.dyn_cast<ElementsAttr>()) {
151     type = elemAttr.getType();
152   } else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
153              value.isa<IntegerAttr>()) {
154     // All XLA types must be tensor types. In the build() method, we want to
155     // provide more flexibility by allowing attributes of scalar types. But we
156     // need to wrap it up with ElementsAttr to construct valid XLA constants.
157     type = RankedTensorType::get(/*shape=*/{}, value.getType());
158     value = DenseElementsAttr::get(type.cast<TensorType>(), value);
159   }
160 
161   // TODO: support other XLA specific types.
162   assert(type && "unsupported attribute type for building mhlo.constant");
163   result.types.push_back(type);
164   result.addAttribute("value", value);
165 }
166 
167 //===----------------------------------------------------------------------===//
168 // DotGeneralOp
169 //===----------------------------------------------------------------------===//
170 
Verify(DotGeneralOp op)171 static LogicalResult Verify(DotGeneralOp op) {
172   auto dot_dimension_numbers = op.dot_dimension_numbers();
173   int64_t lhs_batching_dimensions_size = llvm::size(
174       dot_dimension_numbers.lhs_batching_dimensions().getValues<int64_t>());
175   int64_t rhs_batching_dimensions_size = llvm::size(
176       dot_dimension_numbers.rhs_batching_dimensions().getValues<int64_t>());
177   if (lhs_batching_dimensions_size != rhs_batching_dimensions_size) {
178     return op.emitError()
179            << "lhs and rhs should have the same number of batching dimensions";
180   }
181   int64_t lhs_contracting_dimensions_size = llvm::size(
182       dot_dimension_numbers.lhs_contracting_dimensions().getValues<int64_t>());
183   int64_t rhs_contracting_dimensions_size = llvm::size(
184       dot_dimension_numbers.rhs_contracting_dimensions().getValues<int64_t>());
185   if (lhs_contracting_dimensions_size != rhs_contracting_dimensions_size) {
186     return op.emitError() << "lhs and rhs should have the same number of "
187                              "contracting dimensions";
188   }
189   return success();
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // GatherOp
194 //===----------------------------------------------------------------------===//
195 
196 // Converts gather ops to slice ops in case we have a single set of constant
197 // indices.
198 struct GatherSlice : public OpRewritePattern<GatherOp> {
199   using OpRewritePattern<GatherOp>::OpRewritePattern;
200 
matchAndRewritemlir::mhlo::GatherSlice201   LogicalResult matchAndRewrite(GatherOp gather,
202                                 PatternRewriter& rewriter) const override {
203     DenseIntElementsAttr index;
204     if (!matchPattern(gather.start_indices(), m_Constant(&index)))
205       return failure();
206 
207     const auto& dnums = gather.dimension_numbers();
208     if (dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1)
209       return failure();
210 
211     // TODO(tberghammer): Remove when the verifier catches this case what is
212     // invalid if all previous condition holds.
213     if (index.getNumElements() != dnums.start_index_map().getNumElements())
214       return failure();
215 
216     auto slice_end =
217         llvm::to_vector<8>(gather.slice_sizes().getValues<int64_t>());
218     llvm::SmallVector<int64_t, 8> slice_start(slice_end.size(), 0);
219     for (auto it : llvm::zip(dnums.start_index_map().getIntValues(),
220                              index.getIntValues())) {
221       int64_t map_index = std::get<0>(it).getSExtValue();
222       int64_t offset = std::get<1>(it).getSExtValue();
223       slice_start[map_index] += offset;
224       slice_end[map_index] += offset;
225     }
226 
227     llvm::SmallVector<int64_t, 8> slice_stride(slice_end.size(), 1);
228     llvm::SmallVector<int64_t, 8> slice_shape(slice_end.size());
229     for (int64_t i = 0; i < slice_end.size(); ++i) {
230       slice_shape[i] = slice_end[i] - slice_start[i];
231     }
232     Type element_type = gather.getType().cast<TensorType>().getElementType();
233     auto slice_type = RankedTensorType::get(slice_shape, element_type);
234     Value result = rewriter.create<SliceOp>(
235         gather.getLoc(), slice_type, gather.getOperand(0),
236         GetI64ElementsAttr(slice_start, &rewriter),
237         GetI64ElementsAttr(slice_end, &rewriter),
238         GetI64ElementsAttr(slice_stride, &rewriter));
239 
240     if (dnums.collapsed_slice_dims().getNumElements() > 0) {
241       auto collapsed_slice_dims = llvm::to_vector<8>(llvm::map_range(
242           dnums.collapsed_slice_dims().getIntValues(),
243           [](const llvm::APInt& i) { return i.getSExtValue(); }));
244       llvm::SmallVector<int64_t, 8> reshape_shape;
245       for (int64_t i = 0; i < slice_shape.size(); ++i) {
246         if (llvm::count(collapsed_slice_dims, i) == 0) {
247           reshape_shape.push_back(slice_shape[i]);
248         }
249       }
250       auto reshape_type = RankedTensorType::get(reshape_shape, element_type);
251       result =
252           rewriter.create<ReshapeOp>(gather.getLoc(), reshape_type, result);
253     }
254 
255     result.setType(gather.getType());
256     rewriter.replaceOp(gather, result);
257     return success();
258   }
259 };
260 
// Registers the canonicalization patterns of gather: GatherSlice, which
// rewrites a gather with constant start indices into a slice.
void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                           MLIRContext* context) {
  results.insert<GatherSlice>(context);
}
265 
266 //===----------------------------------------------------------------------===//
267 // GetDimensionSizeOp
268 //===----------------------------------------------------------------------===//
// Verifies get_dimension_size by delegating to VerifyDimAttr, which checks
// that the `dimension` attribute is in range for the operand or result rank.
static LogicalResult Verify(GetDimensionSizeOp op) { return VerifyDimAttr(op); }
271 
272 /// Fold get_dimension_size when the said shape dimension is a constant.
fold(ArrayRef<Attribute> attrs)273 OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) {
274   RankedTensorType type = operand().getType().dyn_cast<RankedTensorType>();
275   if (!type) return {};
276 
277   int32_t dim = dimension();
278   if (type.isDynamic(dim)) return {};
279   // The result type is always is a 0-d i32 tensor.
280   return DenseIntElementsAttr::get<int32_t>(
281       getResult().getType().cast<RankedTensorType>(), type.getDimSize(dim));
282 }
283 
284 //===----------------------------------------------------------------------===//
285 // IotaOp
286 //===----------------------------------------------------------------------===//
287 
Verify(IotaOp op)288 static LogicalResult Verify(IotaOp op) {
289   auto shape = op.getType().cast<ShapedType>();
290   if (!shape.hasRank()) return success();
291 
292   if (shape.getRank() == 0)
293     return op.emitOpError() << "does not support scalars.";
294 
295   auto iota_dimension = op.iota_dimension();
296   if (iota_dimension >= shape.getRank() || iota_dimension < 0)
297     return op.emitOpError() << "iota dimension cannot go beyond the output "
298                                "rank or be negative.";
299   return success();
300 }
301 
302 // Iota operations across multiple dimensions can be reduced to an iota and a
303 // ranked broadcast.
304 struct IotaBroadcast : public OpRewritePattern<IotaOp> {
305   using OpRewritePattern<IotaOp>::OpRewritePattern;
306 
matchAndRewritemlir::mhlo::IotaBroadcast307   LogicalResult matchAndRewrite(IotaOp iota,
308                                 PatternRewriter& rewriter) const override {
309     auto result_ty = iota.getType().cast<ShapedType>();
310     if (!result_ty.hasRank() || result_ty.getRank() < 2) {
311       return failure();
312     }
313 
314     auto iota_dimension = iota.iota_dimension();
315 
316     auto iota_type = RankedTensorType::get(
317         {result_ty.getDimSize(iota_dimension)}, result_ty.getElementType());
318 
319     auto new_iota = rewriter.create<IotaOp>(iota.getLoc(), iota_type,
320                                             rewriter.getI64IntegerAttr(0));
321 
322     auto broadcast_attr = DenseIntElementsAttr::get(
323         RankedTensorType::get({1}, rewriter.getIntegerType(64)),
324         {iota_dimension});
325     rewriter.replaceOpWithNewOp<BroadcastInDimOp>(iota, result_ty, new_iota,
326                                                   broadcast_attr);
327     return success();
328   }
329 };
330 
// Registers the canonicalization patterns of iota: IotaBroadcast, which
// rewrites an iota of rank two or more into a rank-1 iota plus a broadcast.
void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                         MLIRContext* context) {
  results.insert<IotaBroadcast>(context);
}
335 
fold(ArrayRef<Attribute> operands)336 OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
337   auto dimension = iota_dimension();
338   auto result_ty = getResult().getType().cast<ShapedType>();
339   if (result_ty.hasRank() && result_ty.getDimSize(dimension) == 1) {
340     Builder builder(getContext());
341     return builder.getZeroAttr(result_ty);
342   }
343 
344   return {};
345 }
346 
347 //===----------------------------------------------------------------------===//
348 // DynamicIotaOp
349 //===----------------------------------------------------------------------===//
350 
351 namespace {
352 
353 struct DynamicIotaIsStatic : public OpRewritePattern<DynamicIotaOp> {
354   using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;
355 
matchAndRewritemlir::mhlo::__anon1950c2d00311::DynamicIotaIsStatic356   LogicalResult matchAndRewrite(DynamicIotaOp iota,
357                                 PatternRewriter& rewriter) const override {
358     auto result_ty = iota.getType().cast<ShapedType>();
359     if (!result_ty.hasStaticShape()) {
360       return failure();
361     }
362 
363     rewriter.replaceOpWithNewOp<IotaOp>(iota, result_ty, iota.iota_dimension());
364     return success();
365   }
366 };
367 
368 // Dynamic Iota operations across multiple dimensions can be reduced to an iota
369 // and a ranked broadcast.
370 struct DynamicIotaBroadcast : public OpRewritePattern<DynamicIotaOp> {
371   using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;
372 
matchAndRewritemlir::mhlo::__anon1950c2d00311::DynamicIotaBroadcast373   LogicalResult matchAndRewrite(DynamicIotaOp iota,
374                                 PatternRewriter& rewriter) const override {
375     auto result_ty = iota.getType().cast<ShapedType>();
376     if (!result_ty.hasRank() || result_ty.getRank() < 2) {
377       return failure();
378     }
379 
380     auto iota_dimension = iota.iota_dimension();
381     auto iota_dimension_int = iota_dimension;
382 
383     auto converted_shape = rewriter.create<IndexCastOp>(
384         iota.getLoc(),
385         RankedTensorType::get(
386             iota.output_shape().getType().cast<ShapedType>().getShape(),
387             rewriter.getI64Type()),
388         iota.output_shape());
389 
390     auto sliced_shape = rewriter.create<SliceOp>(
391         iota.getLoc(), converted_shape,
392         GetI64ElementsAttr(iota_dimension_int, &rewriter),
393         GetI64ElementsAttr(iota_dimension_int + 1, &rewriter),
394         GetI64ElementsAttr(1, &rewriter));
395 
396     auto converted_sliced_shape = rewriter.create<IndexCastOp>(
397         iota.getLoc(),
398         RankedTensorType::get(
399             {1},
400             iota.output_shape().getType().cast<ShapedType>().getElementType()),
401         sliced_shape);
402 
403     auto iota_type = RankedTensorType::get(
404         {result_ty.getDimSize(iota_dimension_int)}, result_ty.getElementType());
405 
406     auto new_iota = rewriter.create<DynamicIotaOp>(
407         iota.getLoc(), iota_type, converted_sliced_shape,
408         rewriter.getI64IntegerAttr(0));
409 
410     auto broadcast_attr = DenseIntElementsAttr::get(
411         RankedTensorType::get({1}, rewriter.getIntegerType(64)),
412         {iota_dimension});
413     rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
414         iota, result_ty, new_iota, iota.output_shape(), broadcast_attr);
415     return success();
416   }
417 };
418 
419 }  // namespace
420 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)421 void DynamicIotaOp::getCanonicalizationPatterns(
422     OwningRewritePatternList& results, MLIRContext* context) {
423   results.insert<DynamicIotaIsStatic>(context);
424   results.insert<DynamicIotaBroadcast>(context);
425 }
426 
427 //===----------------------------------------------------------------------===//
428 // DynamicUpdateSliceOp
429 //===----------------------------------------------------------------------===//
430 
Verify(DynamicUpdateSliceOp op)431 static LogicalResult Verify(DynamicUpdateSliceOp op) {
432   OperandRange indices = op.start_indices();
433   if (indices.size() <= 1) return success();
434 
435   // Note: start_indices is constrained to Variadic<HLO_ScalarIntTensor>, so it
436   // is OK to cast indices to ShapedType here.
437   auto idx_tensor = indices.take_front().front().getType().cast<ShapedType>();
438   Type first_elem_ty = idx_tensor.getElementType();
439   Type elem_ty;
440 
441   for (auto idx : llvm::drop_begin(indices, 1)) {
442     idx_tensor = idx.getType().cast<ShapedType>();
443     elem_ty = idx_tensor.getElementType();
444 
445     if (first_elem_ty != elem_ty) {
446       return op.emitOpError() << "start indices must have same element type "
447                                  "(encountered mismatch: "
448                               << first_elem_ty << " vs " << elem_ty << ")";
449     }
450   }
451   return success();
452 }
453 
454 //===----------------------------------------------------------------------===//
455 // AbsOp
456 //===----------------------------------------------------------------------===//
457 
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)458 LogicalResult AbsOp::inferReturnTypes(
459     MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
460     RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
461   auto operand_ty = (*operands.begin()).getType().cast<ShapedType>();
462   Type element_ty = operand_ty.getElementType();
463   if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) {
464     element_ty = complex_ty.getElementType();
465   }
466 
467   Type result_ty;
468   if (operand_ty.hasRank()) {
469     result_ty = RankedTensorType::get(operand_ty.getShape(), element_ty);
470   } else {
471     result_ty = UnrankedTensorType::get(element_ty);
472   }
473   inferredReturnTypes.push_back(result_ty);
474   return success();
475 }
476 
477 //===----------------------------------------------------------------------===//
478 // CollectivePermuteOp
479 //===----------------------------------------------------------------------===//
480 
Verify(CollectivePermuteOp op)481 static LogicalResult Verify(CollectivePermuteOp op) {
482   // Check that source target pair is Nx2 tensor.
483   auto type = op.source_target_pairs().getType().dyn_cast<RankedTensorType>();
484   if (type.getRank() != 2)
485     return op.emitError() << "expect source_target_pairs attribute to be of "
486                              "rank 2, but got rank "
487                           << type.getRank();
488   if (type.getShape()[1] != 2)
489     return op.emitError()
490            << "expect source_target_pairs attribute of shape (N, 2), but got ("
491            << type.getShape() << ")";
492   // Check source target pairs for duplicate sources or targets
493   llvm::DenseSet<int64_t> sources;
494   llvm::DenseSet<int64_t> targets;
495   for (auto i = op.source_target_pairs().begin(),
496             e = op.source_target_pairs().end();
497        i != e; ++i) {
498     auto val = (*i).getSExtValue();
499     if (i.getIndex() % 2 == 0) {
500       bool is_unique = sources.insert(val).second;
501       if (!is_unique) return op.emitError() << "duplicate sources not allowed.";
502     } else {
503       bool is_unique = targets.insert(val).second;
504       if (!is_unique) return op.emitError() << "duplicate targets not allowed.";
505     }
506   }
507   return success();
508 }
509 
510 //===----------------------------------------------------------------------===//
511 // ConvertOp
512 //===----------------------------------------------------------------------===//
513 
build(OpBuilder & builder,OperationState & result,Value operand,Type result_element_ty)514 void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand,
515                       Type result_element_ty) {
516   Type result_ty;
517   Type operand_ty = operand.getType();
518   if (auto ranked_ty = operand_ty.dyn_cast<RankedTensorType>()) {
519     result_ty = RankedTensorType::get(ranked_ty.getShape(), result_element_ty);
520   } else {
521     result_ty = UnrankedTensorType::get(result_element_ty);
522   }
523   build(builder, result, result_ty, operand);
524 }
525 
fold(ArrayRef<Attribute> operands)526 OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
527   auto operand_ty = getOperand().getType().cast<TensorType>();
528   auto result_ty = getResult().getType().cast<TensorType>();
529   if (operand_ty == result_ty) return getOperand();
530 
531   // If the result has non-static shape, a convert op is necessary to go from
532   // static shape to non-static shape.
533   if (!result_ty.hasStaticShape()) return {};
534 
535   // TODO(hinsu): Handle unsigned types.
536   if (operand_ty.getElementType().isUnsignedInteger() ||
537       result_ty.getElementType().isUnsignedInteger()) {
538     return {};
539   }
540 
541   // If the operand is constant, we can do the conversion now.
542   if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {
543     return hlo::ConvertElementsAttr(elementsAttr,
544                                     getElementTypeOrSelf(getResult()));
545   }
546 
547   return {};
548 }
549 
550 //===----------------------------------------------------------------------===//
551 // DequantizeOp
552 //===----------------------------------------------------------------------===//
553 
Verify(DequantizeOp op)554 static LogicalResult Verify(DequantizeOp op) {
555   auto input_type = op.input().getType().dyn_cast<ShapedType>();
556   auto output_type = op.output().getType().dyn_cast<ShapedType>();
557   if (!input_type || !output_type) {
558     return op.emitError() << "ranked input and output.";
559   }
560   auto input_shape = input_type.getShape();
561   auto output_shape = output_type.getShape().vec();
562   if (op.transpose_output()) {
563     std::reverse(output_shape.begin(), output_shape.end());
564   }
565 
566   // Check the input rank and output rank are same, and also the lower
567   // dimensions are same.
568   if (input_shape.size() != output_shape.size() ||
569       !std::equal(input_shape.begin(),
570                   std::next(input_shape.begin(), input_shape.size() - 1),
571                   output_shape.begin())) {
572     return op.emitError() << "mismatched dimensions.";
573   }
574 
575   // Check that the last dimension of the output is 2x or 4x of that of the
576   // input depending on the unpacked input is 16 or 8 bits.
577   int input_last_dim = *input_shape.rbegin();
578   int output_last_dim = *output_shape.rbegin();
579   int scale_factor = op.is_16bits() ? 2 : 4;
580   if (output_last_dim != scale_factor * input_last_dim) {
581     return op.emitError() << "last dimension of output should be "
582                           << scale_factor << "x of the input.";
583   }
584 
585   return success();
586 }
587 
588 //===----------------------------------------------------------------------===//
589 // GetTupleElementOp
590 //===----------------------------------------------------------------------===//
591 
Verify(GetTupleElementOp op)592 static LogicalResult Verify(GetTupleElementOp op) {
593   auto indexVal = op.index();
594   auto operandType = op.getOperand().getType().cast<TupleType>();
595   if (indexVal >= operandType.size()) {
596     return op.emitOpError(
597         llvm::formatv("index {0} is out of bounds of operand with size {1}",
598                       indexVal, operandType.size()));
599   }
600 
601   auto expectedType = operandType.getType(indexVal);
602   if (op.getType() != expectedType) {
603     return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}",
604                                         op.getType(), expectedType));
605   }
606   return success();
607 }
608 
fold(ArrayRef<Attribute> operands)609 OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
610   if (auto tupleOp =
611           dyn_cast_or_null<mhlo::TupleOp>(getOperand().getDefiningOp())) {
612     return tupleOp.getOperand(index());
613   }
614 
615   return {};
616 }
617 
618 //===----------------------------------------------------------------------===//
619 // TupleOp
620 //===----------------------------------------------------------------------===//
621 
Verify(TupleOp op)622 static LogicalResult Verify(TupleOp op) {
623   SmallVector<Type, 8> operandTypes = {op.operand_type_begin(),
624                                        op.operand_type_end()};
625   auto expectedType = TupleType::get(op.getContext(), operandTypes);
626   if (op.getType() != expectedType) {
627     return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}",
628                                         op.getType(), expectedType));
629   }
630   return success();
631 }
632 
633 namespace {
634 
635 // Pattern for unpacking and repacking the same tuple.
636 struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> {
637   using OpRewritePattern<TupleOp>::OpRewritePattern;
638 
matchAndRewritemlir::mhlo::__anon1950c2d00411::UnpackRepackSameTuple639   LogicalResult matchAndRewrite(TupleOp op,
640                                 PatternRewriter& rewriter) const override {
641     if (op.val().empty()) return failure();
642 
643     Value first_element = op.val().front();
644     auto first_element_op =
645         dyn_cast_or_null<GetTupleElementOp>(first_element.getDefiningOp());
646     if (!first_element_op || first_element_op.indexAttr().getInt() != 0)
647       return failure();
648 
649     Value tuple_predecessor = first_element_op.getOperand();
650     if (tuple_predecessor.getType() != op.getType()) return failure();
651 
652     for (auto element_and_idx : llvm::enumerate(op.val().drop_front(1))) {
653       auto element_op = dyn_cast_or_null<GetTupleElementOp>(
654           element_and_idx.value().getDefiningOp());
655       if (!element_op ||
656           element_op.indexAttr().getInt() != element_and_idx.index() + 1 ||
657           element_op.getOperand() != tuple_predecessor)
658         return failure();
659     }
660 
661     rewriter.replaceOp(op, tuple_predecessor);
662     return success();
663   }
664 };
665 
666 }  // namespace
667 
// Registers the canonicalization patterns of tuple: UnpackRepackSameTuple,
// which removes a tuple that unpacks and repacks the same tuple value.
void TupleOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                          MLIRContext* context) {
  results.insert<UnpackRepackSameTuple>(context);
}
672 
673 //===----------------------------------------------------------------------===//
674 // AllToAllOp
675 //===----------------------------------------------------------------------===//
676 
Verify(AllToAllOp op)677 static LogicalResult Verify(AllToAllOp op) {
678   // If operand is ranked, size of split dimension should be a multiple of split
679   // count.
680   auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
681   if (!type) return success();
682   auto split_dim_size = type.getDimSize(op.split_dimension());
683   auto split_count = op.split_count();
684   if (split_dim_size % split_count != 0) {
685     return op.emitError() << "split dimension has size " << split_dim_size
686                           << ", expected to be a multiple of split_count "
687                           << split_count;
688   }
689   return success();
690 }
691 
692 //===----------------------------------------------------------------------===//
693 // BroadcastOp
694 //===----------------------------------------------------------------------===//
695 
696 // TODO(b/129012527) These should be expressed as type constraints.
Verify(BroadcastOp op)697 static LogicalResult Verify(BroadcastOp op) {
698   auto sizes = op.broadcast_sizes();
699   auto sizesType = sizes.getType();
700   auto sizesRank = sizesType.getRank();
701   if (sizesRank != 1) {
702     return op.emitOpError(llvm::formatv(
703         "broadcast_sizes has rank {0} instead of rank 1", sizesRank));
704   }
705 
706   auto resultType = op.getResult().getType().cast<RankedTensorType>();
707   auto resultRank = resultType.getRank();
708   auto operandType = op.operand().getType().cast<RankedTensorType>();
709   auto operandRank = operandType.getRank();
710   auto sizesSize = sizesType.getNumElements();
711   auto expectedRank = operandRank + sizesSize;
712 
713   if (resultRank != expectedRank) {
714     return op.emitOpError(
715         llvm::formatv("result rank ({0}) does not match operand rank "
716                       "({1}) plus size of broadcast_sizes ({2})",
717                       resultRank, operandRank, sizesSize));
718   }
719 
720   llvm::SmallVector<int64_t, 10> expectedShape(sizes.getValues<int64_t>());
721 
722   auto operandShape = operandType.getShape();
723   expectedShape.insert(expectedShape.end(), operandShape.begin(),
724                        operandShape.end());
725 
726   auto resultShape = resultType.getShape();
727   if (resultShape != llvm::makeArrayRef(expectedShape)) {
728     return op.emitOpError(llvm::formatv(
729         "result has shape [{0}] instead of [{1}]",
730         llvm::make_range(resultShape.begin(), resultShape.end()),
731         llvm::make_range(expectedShape.begin(), expectedShape.end())));
732   }
733 
734   return success();
735 }
736 
737 //===----------------------------------------------------------------------===//
738 // BroadcastInDimOp
739 //===----------------------------------------------------------------------===//
740 
Verify(BroadcastInDimOp op)741 static LogicalResult Verify(BroadcastInDimOp op) {
742   auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
743   if (!operandType) {
744     // The following verification checks all depend on knowing the rank of
745     // the operand. Bail out now if we don't know the rank of the operand.
746     return success();
747   }
748 
749   auto operandRank = operandType.getRank();
750   if (!op.broadcast_dimensions()) {
751     if (operandRank == 0) {
752       return success();
753     }
754     return op.emitOpError(
755         llvm::formatv("broadcast_dimensions is absent, but required because "
756                       "operand has non-zero rank ({0})",
757                       operandRank));
758   }
759 
760   auto dimensions = op.broadcast_dimensions();
761   auto dimensionsType = op.broadcast_dimensions().getType();
762   auto dimensionsRank = dimensionsType.getRank();
763   if (dimensionsRank != 1) {
764     return op.emitOpError(llvm::formatv(
765         "broadcast_dimensions has rank {0} instead of rank 1", dimensionsRank));
766   }
767 
768   auto dimensionsSize = dimensionsType.getNumElements();
769   if (dimensionsSize != operandRank) {
770     return op.emitOpError(llvm::formatv(
771         "broadcast_dimensions size ({0}) does not match operand rank ({1})",
772         dimensionsSize, operandRank));
773   }
774 
775   auto resultType = op.getResult().getType().cast<RankedTensorType>();
776   auto resultRank = resultType.getRank();
777   if (resultRank < operandRank) {
778     return op.emitOpError(
779         llvm::formatv("result rank ({0}) is less than operand rank ({1})",
780                       resultRank, operandRank));
781   }
782 
783   for (int i = 0; i != dimensionsSize; ++i) {
784     auto dimIndex = dimensions.getValue<int64_t>(i);
785     if (dimIndex >= resultRank) {
786       return op.emitOpError(
787           llvm::formatv("broadcast_dimensions contains invalid value {0} for "
788                         "result with rank {1}",
789                         dimIndex, resultRank));
790     }
791 
792     if (!operandType.isDynamicDim(i)) {
793       auto dimSize = operandType.getDimSize(i);
794       auto resultDimSize = resultType.getDimSize(dimIndex);
795       if (dimSize != 1 && dimSize != resultDimSize) {
796         return op.emitOpError(
797             llvm::formatv("size of operand dimension {0} ({1}) is not equal to "
798                           "1 or size of result dimension {2} ({3})",
799                           i, dimSize, dimIndex, resultDimSize));
800       }
801     }
802   }
803 
804   return success();
805 }
806 
fold(ArrayRef<Attribute> attrs)807 OpFoldResult BroadcastInDimOp::fold(ArrayRef<Attribute> attrs) {
808   auto type = getType().cast<RankedTensorType>();
809   if (type == getOperand().getType()) {
810     auto broadcast_values = broadcast_dimensions().getValues<int64_t>();
811     if (!std::equal(broadcast_values.begin(), broadcast_values.end(),
812                     llvm::seq<int64_t>(0, type.getRank()).begin())) {
813       return {};
814     }
815     return getOperand();
816   }
817 
818   // Constant fold when an operand is a splat tensor attribute.
819   if (!attrs[0] || !type.hasStaticShape()) return {};
820   auto splatOperandAttr = attrs[0].dyn_cast<SplatElementsAttr>();
821   if (!splatOperandAttr) return {};
822   // MLIR core bug (https://bugs.llvm.org/show_bug.cgi?id=46588): dense element
823   // attribute iterator not implemented for complex element types.
824   if (type.getElementType().isa<ComplexType>()) return {};
825   return SplatElementsAttr::get(type, splatOperandAttr.getSplatValue());
826 }
827 
828 //===----------------------------------------------------------------------===//
829 // DynamicBroadcastInDimOp
830 //===----------------------------------------------------------------------===//
831 
Verify(DynamicBroadcastInDimOp op)832 static LogicalResult Verify(DynamicBroadcastInDimOp op) {
833   auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
834   auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
835 
836   // If either the operand or result are unranked, there is very little
837   // to verify statically.
838   if (!operandType || !resultType) {
839     return success();
840   }
841 
842   auto outputDimensionsType =
843       op.output_dimensions().getType().cast<RankedTensorType>();
844   auto outputDimensionsSize = outputDimensionsType.getDimSize(0);
845   auto operandRank = operandType.getRank();
846   auto resultRank = resultType.getRank();
847 
848   // Verify broadcast_dimensions.
849   auto bcastDimensions = op.broadcast_dimensions();
850   auto bcastDimensionsType = op.broadcast_dimensions().getType();
851   auto bcastDimensionsRank = bcastDimensionsType.getRank();
852   // TODO(laurenzo): Update the BroadcastDimAttr to constrain its rank to 1.
853   if (bcastDimensionsRank != 1) {
854     return op.emitOpError(
855         llvm::formatv("broadcast_dimensions has rank {0} instead of rank 1",
856                       bcastDimensionsRank));
857   }
858 
859   auto bcastDimensionsSize = bcastDimensionsType.getNumElements();
860   if (bcastDimensionsSize != operandRank) {
861     return op.emitOpError(llvm::formatv(
862         "broadcast_dimensions size ({0}) does not match operand rank ({1})",
863         bcastDimensionsSize, operandRank));
864   }
865 
866   if (resultRank < operandRank) {
867     return op.emitOpError(
868         llvm::formatv("result rank ({0}) is less than operand rank ({1})",
869                       resultRank, operandRank));
870   }
871 
872   for (int i = 0; i != bcastDimensionsSize; ++i) {
873     auto dimIndex = bcastDimensions.getValue<int64_t>(i);
874     if (dimIndex >= resultRank) {
875       return op.emitOpError(
876           llvm::formatv("broadcast_dimensions contains invalid value {0} for "
877                         "result with rank {1}",
878                         dimIndex, resultRank));
879     }
880 
881     auto dimSize = operandType.getDimSize(i);
882     auto resultDimSize = resultType.getDimSize(dimIndex);
883     // Note: verifyCompatibleShapes doesn't consider size-1 broadcasting, so we
884     // add a manual check for this.
885     if (dimSize != 1 && failed(verifyCompatibleShape(dimSize, resultDimSize))) {
886       return op.emitOpError(
887           llvm::formatv("size of operand dimension {0} ({1}) is not compatible "
888                         "with size of result dimension {2} ({3})",
889                         i, dimSize, dimIndex, resultDimSize));
890     }
891   }
892 
893   if (outputDimensionsSize != resultRank) {
894     return op.emitOpError(
895         llvm::formatv("result rank ({0}) is not equal to number of output "
896                       "dimensions ({1})",
897                       resultRank, outputDimensionsSize));
898   }
899 
900   return success();
901 }
902 
903 // If a DynamicBroadCastInDimOp is not actually dynamic, use an ordinary
904 // BroadcastInDimOp.
905 class DynamicBroadcastInDimOpNotActuallyDynamic
906     : public OpRewritePattern<DynamicBroadcastInDimOp> {
907   using OpRewritePattern::OpRewritePattern;
matchAndRewrite(DynamicBroadcastInDimOp op,PatternRewriter & rewriter) const908   LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op,
909                                 PatternRewriter& rewriter) const override {
910     auto type = op.getType().dyn_cast<RankedTensorType>();
911     if (!type || !type.hasStaticShape()) {
912       return rewriter.notifyMatchFailure(op, "requires static shape");
913     }
914     rewriter.replaceOpWithNewOp<BroadcastInDimOp>(
915         op, op.getType(), op.operand(), op.broadcast_dimensions());
916     return success();
917   }
918 };
919 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)920 void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
921     OwningRewritePatternList& results, MLIRContext* context) {
922   results.insert<DynamicBroadcastInDimOpNotActuallyDynamic,
923                  DynamicBroadcastToOwnShape_1, DynamicBroadcastToOwnShape_2>(
924       context);
925 }
926 
927 //===----------------------------------------------------------------------===//
928 // ClampOp
929 //===----------------------------------------------------------------------===//
930 
Verify(ClampOp op)931 static LogicalResult Verify(ClampOp op) {
932   auto operandType = op.operand().getType().cast<RankedTensorType>();
933   auto operandShape = operandType.getShape();
934   auto minType = op.min().getType().cast<RankedTensorType>();
935 
936   auto minShape = minType.getShape();
937   if (minShape != operandShape && minType.getRank() != 0) {
938     return op.emitOpError(llvm::formatv(
939         "min shape [{0}] is not scalar and does not match operand shape [{1}]",
940         llvm::make_range(minShape.begin(), minShape.end()),
941         llvm::make_range(operandShape.begin(), operandShape.end())));
942   }
943 
944   auto maxType = op.max().getType().cast<RankedTensorType>();
945   auto maxShape = maxType.getShape();
946   if (maxShape != operandShape && maxType.getRank() != 0) {
947     return op.emitOpError(llvm::formatv(
948         "max shape [{0}] is not scalar and does not match operand shape [{1}]",
949         llvm::make_range(maxShape.begin(), maxShape.end()),
950         llvm::make_range(operandShape.begin(), operandShape.end())));
951   }
952 
953   return success();
954 }
955 
956 //===----------------------------------------------------------------------===//
957 // ComplexOp
958 //===----------------------------------------------------------------------===//
959 
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)960 LogicalResult ComplexOp::inferReturnTypes(
961     MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
962     RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
963   auto type = operands[0].getType();
964   auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
965   Type result_ty;
966   if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
967     result_ty = RankedTensorType::get(ranked_type.getShape(), element_ty);
968   } else if (type.isa<UnrankedTensorType>()) {
969     result_ty = UnrankedTensorType::get(element_ty);
970   } else {
971     result_ty = element_ty;
972   }
973   inferredReturnTypes.push_back(result_ty);
974   return success();
975 }
976 
fold(ArrayRef<Attribute> operands)977 OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
978   auto real_op = dyn_cast_or_null<mhlo::RealOp>(getOperand(0).getDefiningOp());
979   auto imag_op = dyn_cast_or_null<mhlo::ImagOp>(getOperand(1).getDefiningOp());
980   if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) {
981     return real_op.getOperand();
982   }
983 
984   return {};
985 }
986 
987 //===----------------------------------------------------------------------===//
988 // ImagOp
989 //===----------------------------------------------------------------------===//
990 
991 namespace {
CreateRealType(Type type)992 Type CreateRealType(Type type) {
993   auto element_ty = getElementTypeOrSelf(type);
994   if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) {
995     element_ty = complex_ty.getElementType();
996   }
997 
998   if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
999     return RankedTensorType::get(ranked_type.getShape(), element_ty);
1000   } else if (type.dyn_cast<UnrankedTensorType>()) {
1001     return UnrankedTensorType::get(element_ty);
1002   }
1003 
1004   return element_ty;
1005 }
1006 }  // namespace
1007 
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)1008 LogicalResult ImagOp::inferReturnTypes(
1009     MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
1010     RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
1011   inferredReturnTypes.push_back(CreateRealType(operands[0].getType()));
1012   return success();
1013 }
1014 
fold(ArrayRef<Attribute> operands)1015 OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
1016   if (auto complex_op =
1017           dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
1018     return complex_op.getOperand(1);
1019   }
1020 
1021   return {};
1022 }
1023 
1024 //===----------------------------------------------------------------------===//
1025 // IsFiniteOp
1026 //===----------------------------------------------------------------------===//
1027 
getSameShapeTensorType(TensorType tensor_type,Type element_type)1028 TensorType getSameShapeTensorType(TensorType tensor_type, Type element_type) {
1029   if (auto ranked_tensor_ty = tensor_type.dyn_cast<RankedTensorType>()) {
1030     return RankedTensorType::get(ranked_tensor_ty.getShape(), element_type);
1031   }
1032   if (auto unranked_tensor_ty = tensor_type.dyn_cast<UnrankedTensorType>()) {
1033     return UnrankedTensorType::get(element_type);
1034   }
1035   llvm_unreachable("unhandled type");
1036 }
1037 
inferReturnTypes(MLIRContext * ctx,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)1038 LogicalResult IsFiniteOp::inferReturnTypes(
1039     MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
1040     RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
1041   auto arg_ty = operands.front().getType().cast<TensorType>();
1042   Builder b(ctx);
1043   inferredReturnTypes.push_back(getSameShapeTensorType(arg_ty, b.getI1Type()));
1044   return success();
1045 }
1046 
1047 //===----------------------------------------------------------------------===//
1048 // RealOp
1049 //===----------------------------------------------------------------------===//
1050 
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)1051 LogicalResult RealOp::inferReturnTypes(
1052     MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
1053     RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
1054   inferredReturnTypes.push_back(CreateRealType(operands[0].getType()));
1055   return success();
1056 }
1057 
fold(ArrayRef<Attribute> operands)1058 OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
1059   if (auto complex_op =
1060           dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
1061     return complex_op.getOperand(0);
1062   }
1063 
1064   return {};
1065 }
1066 
1067 //===----------------------------------------------------------------------===//
1068 // ConcatenateOp
1069 //===----------------------------------------------------------------------===//
1070 
1071 namespace {
1072 class ConcatenateOperandRemoval : public OpRewritePattern<ConcatenateOp> {
1073  public:
1074   using OpRewritePattern::OpRewritePattern;
matchAndRewrite(ConcatenateOp op,PatternRewriter & rewriter) const1075   LogicalResult matchAndRewrite(ConcatenateOp op,
1076                                 PatternRewriter& rewriter) const override {
1077     auto axis = op.dimension();
1078     llvm::SmallVector<Value, 6> new_operands;
1079     for (auto operand : op.getOperands()) {
1080       auto ty = operand.getType().cast<ShapedType>();
1081       if (ty.getDimSize(axis) != 0) {
1082         new_operands.push_back(operand);
1083       }
1084     }
1085 
1086     if (!new_operands.empty() && new_operands.size() < op.getNumOperands()) {
1087       rewriter.replaceOpWithNewOp<ConcatenateOp>(op, op.getResult().getType(),
1088                                                  new_operands, op.dimension());
1089       return success();
1090     }
1091 
1092     return failure();
1093   }
1094 };
1095 }  // namespace
1096 
inferReturnTypes(MLIRContext *,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)1097 LogicalResult ConcatenateOp::inferReturnTypes(
1098     MLIRContext*, Optional<Location> location, ValueRange operands,
1099     DictionaryAttr attributes, RegionRange regions,
1100     SmallVectorImpl<Type>& inferredReturnTypes) {
1101   if (operands.empty()) {
1102     return failure();
1103   }
1104 
1105   auto dimension_attr = attributes.get("dimension").cast<IntegerAttr>();
1106   auto dimension = dimension_attr.getInt();
1107 
1108   auto first_type = (*operands.begin()).getType().cast<ShapedType>();
1109   auto out_element = first_type.getElementType();
1110 
1111   for (auto operand : operands.getTypes()) {
1112     auto element_type = getElementTypeOrSelf(operand);
1113     if (element_type != out_element) {
1114       return failure();
1115     }
1116   }
1117 
1118   // Find the first ranked input to determine the output rank.
1119   for (auto type : operands.getTypes()) {
1120     auto shaped_type = type.cast<ShapedType>();
1121     if (shaped_type.hasRank()) {
1122       first_type = shaped_type;
1123       break;
1124     }
1125   }
1126 
1127   // If all inputs are unranked, the result must be unranked.
1128   if (!first_type.hasRank()) {
1129     inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
1130     return success();
1131   }
1132 
1133   if (first_type.getRank() == 0)
1134     return emitOptionalError(location, "rank-0 values cannot be concatenated");
1135 
1136   auto out_shape = llvm::to_vector<6>(first_type.getShape());
1137 
1138   // Determine what the non-concatenate dimensions should be.
1139   for (auto type : operands.getTypes()) {
1140     auto shaped_ty = type.cast<ShapedType>();
1141     if (!shaped_ty.hasRank()) {
1142       continue;
1143     }
1144 
1145     for (auto it : llvm::enumerate(shaped_ty.getShape())) {
1146       // If a dimension is not dynamic, the output shape should match.
1147       if (ShapedType::isDynamic(out_shape[it.index()])) {
1148         out_shape[it.index()] = it.value();
1149       }
1150     }
1151   }
1152 
1153   out_shape[dimension] = 0;
1154 
1155   for (auto operand : operands.getTypes()) {
1156     auto type = operand.cast<ShapedType>();
1157     if (!type.hasRank()) {
1158       inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
1159       return success();
1160     }
1161 
1162     // If the dimension is dynamic we know the output dimension is dynamic.
1163     auto dim = type.getShape()[dimension];
1164     if (dim == -1) {
1165       out_shape[dimension] = -1;
1166       break;
1167     }
1168 
1169     out_shape[dimension] += dim;
1170   }
1171 
1172   inferredReturnTypes.push_back(RankedTensorType::get(out_shape, out_element));
1173 
1174   return success();
1175 }
1176 
// Registers the canonicalization patterns of ConcatenateOp; currently only
// ConcatenateOperandRemoval.
void ConcatenateOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<ConcatenateOperandRemoval>(context);
}
1181 
1182 template <typename T>
foldConcatenateHelper(ConcatenateOp * op,ArrayRef<Attribute> operands)1183 static Attribute foldConcatenateHelper(ConcatenateOp* op,
1184                                        ArrayRef<Attribute> operands) {
1185   auto axis = op->dimension();
1186   auto type = op->getType().cast<ShapedType>();
1187 
1188   SmallVector<T, 6> values;
1189   auto shape = type.getShape();
1190 
1191   size_t top_size = 1;
1192   for (int i = 0, e = axis; i < e; i++) {
1193     top_size = top_size * shape[i];
1194   }
1195 
1196   for (size_t i = 0; i < top_size; i++) {
1197     for (auto operand : operands) {
1198       DenseElementsAttr attr = operand.cast<DenseElementsAttr>();
1199       size_t bottom_size = attr.getNumElements() / top_size;
1200       auto iter = attr.getValues<T>().begin() + i * bottom_size;
1201       values.append(iter, iter + bottom_size);
1202     }
1203   }
1204 
1205   return DenseElementsAttr::get(type, values);
1206 }
1207 
foldConcatenate(ConcatenateOp * op,ArrayRef<Attribute> operands)1208 static Attribute foldConcatenate(ConcatenateOp* op,
1209                                  ArrayRef<Attribute> operands) {
1210   for (auto operand : operands) {
1211     if (!operand) return {};
1212   }
1213 
1214   auto type = op->getResult().getType().cast<ShapedType>();
1215   auto etype = type.getElementType();
1216   if (etype.isa<IntegerType>()) {
1217     return foldConcatenateHelper<APInt>(op, operands);
1218   }
1219 
1220   if (etype.isa<FloatType>()) {
1221     return foldConcatenateHelper<APFloat>(op, operands);
1222   }
1223 
1224   return {};
1225 }
1226 
fold(ArrayRef<Attribute> operands)1227 OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
1228   if (getNumOperands() == 1) return getOperand(0);
1229 
1230   ShapedType type = getResult().getType().cast<ShapedType>();
1231   if (!type.hasStaticShape()) return {};
1232 
1233   auto axis = dimension();
1234   if (auto attr = foldConcatenate(this, operands)) {
1235     return attr;
1236   }
1237 
1238   llvm::SmallVector<Value, 6> new_operands;
1239   for (auto operand : getOperands()) {
1240     auto ty = operand.getType().cast<ShapedType>();
1241     if (ty.getDimSize(axis) != 0) {
1242       return {};
1243     }
1244   }
1245 
1246   return DenseElementsAttr::get(type, ArrayRef<Attribute>());
1247 }
1248 
Verify(ConcatenateOp op)1249 static LogicalResult Verify(ConcatenateOp op) {
1250   Type element_type = getElementTypeOrSelf(op.getOperand(0).getType());
1251   RankedTensorType first_ranked_type;
1252   int num_operands = op.getNumOperands();
1253   for (int i = 0; i < num_operands; i++) {
1254     auto second_type = op.getOperand(i).getType().dyn_cast<ShapedType>();
1255     if (second_type.getElementType() != element_type) {
1256       return op.emitOpError(
1257           llvm::formatv("operands (0) and ({0}) do not match element type", i));
1258     }
1259 
1260     if (!second_type.hasRank()) {
1261       continue;
1262     }
1263 
1264     if (!first_ranked_type) {
1265       first_ranked_type = second_type.cast<RankedTensorType>();
1266       continue;
1267     }
1268 
1269     if (first_ranked_type.getRank() != second_type.getRank()) {
1270       return op.emitOpError(
1271           llvm::formatv("operands (0) and ({0}) do not match rank", i));
1272     }
1273 
1274     auto first_shape = second_type.getShape();
1275     auto second_shape = second_type.getShape();
1276     for (int d = 0; d < first_ranked_type.getRank(); ++d) {
1277       if (first_shape[d] != second_shape[d] && d != op.dimension()) {
1278         return op.emitOpError(llvm::formatv(
1279             "operands (0) and ({0}) non-concat dimensions do not match "
1280             "({1}) != ({2})",
1281             i, llvm::make_range(first_shape.begin(), first_shape.end()),
1282             llvm::make_range(second_shape.begin(), second_shape.end())));
1283       }
1284     }
1285   }
1286   return success();
1287 }
1288 
1289 //===----------------------------------------------------------------------===//
1290 // DynamicReshapeOp
1291 //===----------------------------------------------------------------------===//
1292 
Verify(DynamicReshapeOp op)1293 static LogicalResult Verify(DynamicReshapeOp op) {
1294   auto result_type = op.result().getType().dyn_cast<RankedTensorType>();
1295   auto output_shape_type =
1296       op.output_shape().getType().dyn_cast<RankedTensorType>();
1297   if (result_type && output_shape_type && output_shape_type.hasStaticShape() &&
1298       output_shape_type.getDimSize(0) != result_type.getRank()) {
1299     return op.emitError() << "output should have a rank equal to the number of "
1300                              "elements in output_shape";
1301   }
1302   return success();
1303 }
1304 
1305 namespace {
1306 class DynamicReshapeOpNotActuallyDynamic
1307     : public OpRewritePattern<DynamicReshapeOp> {
1308  public:
1309   using OpRewritePattern::OpRewritePattern;
matchAndRewrite(DynamicReshapeOp op,PatternRewriter & rewriter) const1310   LogicalResult matchAndRewrite(DynamicReshapeOp op,
1311                                 PatternRewriter& rewriter) const override {
1312     auto type = op.result().getType().dyn_cast<RankedTensorType>();
1313     if (!type || !type.hasStaticShape()) {
1314       return rewriter.notifyMatchFailure(op, "requires static shape tensor");
1315     }
1316     rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), op.operand());
1317     return success();
1318   }
1319 };
1320 
1321 // Canonicalizes
1322 // %0 = some_op(%tensor)
1323 // %1 = "mhlo.dynamic_reshape"(%0, %shape)
1324 //      (tensor<?xT>, tensor<1xindex>) -> tensor<?xT>
1325 // ... uses of %1.
1326 //
1327 // into
1328 //
1329 // ... uses of %0.
1330 // This canonicalization is only correct if the input is correct!
1331 // TODO(b/178779691): Use a more sophisticated canonicalization that preserves
1332 // errors in input, and still allows us to get rid of redundant reshapes.
1333 class RemoveRedundantRank1DynamicReshape
1334     : public OpRewritePattern<DynamicReshapeOp> {
1335  public:
1336   using OpRewritePattern::OpRewritePattern;
matchAndRewrite(DynamicReshapeOp op,PatternRewriter & rewriter) const1337   LogicalResult matchAndRewrite(DynamicReshapeOp op,
1338                                 PatternRewriter& rewriter) const override {
1339     auto type = op.result().getType().dyn_cast<RankedTensorType>();
1340     if (!type || type.getRank() != 1 || type.hasStaticShape()) {
1341       return rewriter.notifyMatchFailure(
1342           op, "requires rank 1 shape tensor with dynamic dimension");
1343     }
1344     auto operand_type = op.operand().getType().dyn_cast<RankedTensorType>();
1345     if (!operand_type || operand_type.getRank() != 1 ||
1346         operand_type.hasStaticShape()) {
1347       return rewriter.notifyMatchFailure(
1348           op, "requires rank 1 shape tensor with dynamic dimension");
1349     }
1350     rewriter.replaceOp(op, {op.operand()});
1351     return success();
1352   }
1353 };
1354 
1355 // Canonicalizes
1356 // %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
1357 // %1 = same_operands_and_result_shape_op(%tensor)
1358 // %2 = "mhlo.dynamic_reshape"(%1, %shape)
1359 // ... uses of %2.
1360 //
1361 // into
1362 //
1363 // %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
1364 // %1 = same_operands_and_result_shape_op(%tensor)
1365 // ... uses of %1.
1366 class DynamicReshapeOpSameShapeOpResult
1367     : public OpRewritePattern<DynamicReshapeOp> {
1368  public:
1369   using OpRewritePattern::OpRewritePattern;
1370 
matchAndRewrite(DynamicReshapeOp op,PatternRewriter & rewriter) const1371   LogicalResult matchAndRewrite(DynamicReshapeOp op,
1372                                 PatternRewriter& rewriter) const override {
1373     Operation* def_op = op.operand().getDefiningOp();
1374     if (!def_op || !def_op->hasTrait<OpTrait::SameOperandsAndResultShape>()) {
1375       return failure();
1376     }
1377     Operation* input_def_op = def_op->getOperand(0).getDefiningOp();
1378     if (!input_def_op) {
1379       return failure();
1380     }
1381     auto reshape = dyn_cast<DynamicReshapeOp>(*input_def_op);
1382     if (reshape && reshape.output_shape() == op.output_shape()) {
1383       rewriter.replaceOp(op, {def_op->getResult(0)});
1384       return success();
1385     }
1386     return failure();
1387   }
1388 };
1389 }  // namespace
1390 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1391 void DynamicReshapeOp::getCanonicalizationPatterns(
1392     OwningRewritePatternList& results, MLIRContext* context) {
1393   // clang-format off
1394   results.insert<
1395       DynamicReshapeOpNotActuallyDynamic,
1396       DynamicReshapeOpSameShapeOpResult,
1397       RemoveRedundantDynamicBroadcast,
1398       RemoveRedundantDynamicReshape,
1399       RemoveRedundantRank1DynamicReshape,
1400       ShapeOfDynamicReshape
1401     >(context);
1402   // clang-format on
1403 }
1404 
1405 //===----------------------------------------------------------------------===//
1406 // DynamicSliceOp
1407 //===----------------------------------------------------------------------===//
1408 
1409 namespace {
1410 // Canonicalizes DynamicSlice ops that can be replaced instead with Slice ops.
1411 // This canonicalization is applied the case when the `begin` input values are
1412 // compile time constants and thus can be made into a tensor.
1413 struct DynamicSliceToSlice : public OpRewritePattern<DynamicSliceOp> {
1414   using OpRewritePattern<DynamicSliceOp>::OpRewritePattern;
1415 
matchAndRewritemlir::mhlo::__anon1950c2d00811::DynamicSliceToSlice1416   LogicalResult matchAndRewrite(DynamicSliceOp dynamic_slice,
1417                                 PatternRewriter& rewriter) const override {
1418     Value input = dynamic_slice.operand();
1419     auto input_tensor = input.getType().dyn_cast<RankedTensorType>();
1420     if (!input_tensor) return failure();
1421 
1422     SmallVector<int64_t, 4> temp_start_indices;
1423     for (Value start : dynamic_slice.start_indices()) {
1424       APInt val;
1425       if (!matchPattern(start, m_ConstantInt(&val))) {
1426         return failure();
1427       }
1428       temp_start_indices.push_back(*(val.getRawData()));
1429     }
1430 
1431     // At this point we've determined that the start indices are all constants;
1432     // pack them into a single tensor.
1433     auto loc = dynamic_slice.getLoc();
1434     int64_t input_rank = input_tensor.getRank();
1435     auto slice_start_indices =
1436         GetI64ElementsAttr(temp_start_indices, &rewriter);
1437     DenseIntElementsAttr slice_limits = BuildSliceLimits(
1438         slice_start_indices, dynamic_slice.slice_sizes(), &rewriter);
1439     DenseIntElementsAttr slice_strides =
1440         GetI64ElementsAttr(SmallVector<int64_t, 4>(input_rank, 1), &rewriter);
1441     auto result = rewriter.create<SliceOp>(loc, input, slice_start_indices,
1442                                            slice_limits, slice_strides);
1443     rewriter.replaceOp(dynamic_slice, {result});
1444     return success();
1445   }
1446 };
1447 
1448 }  // namespace
1449 
// Registers the canonicalization patterns of DynamicSliceOp; currently only
// DynamicSliceToSlice.
void DynamicSliceOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<DynamicSliceToSlice>(context);
}
1454 
1455 // Verifies that the number of slice sizes and the number of start indices match
Verify(DynamicSliceOp op)1456 static LogicalResult Verify(DynamicSliceOp op) {
1457   int num_slice_sizes = op.slice_sizes().getNumElements();
1458   int num_start_indices = op.start_indices().size();
1459   if (num_start_indices != num_slice_sizes) {
1460     return op.emitOpError()
1461            << "has mismatched number of slice sizes (" << num_slice_sizes
1462            << ") and number of start indices (" << num_start_indices << ")";
1463   }
1464   return success();
1465 }
1466 
1467 //===----------------------------------------------------------------------===//
1468 // InfeedOp
1469 //===----------------------------------------------------------------------===//
1470 
1471 // Checks that the result type is of the form `tuple< any_type, token >`.
Verify(InfeedOp op)1472 static LogicalResult Verify(InfeedOp op) {
1473   auto result_ty = op.getResult().getType().cast<TupleType>();
1474   auto subtypes = result_ty.getTypes();
1475   if (subtypes.size() != 2)
1476     return op.emitOpError()
1477            << "result is expected to be a tuple of size 2, but got "
1478            << subtypes.size();
1479   if (!subtypes[1].isa<TokenType>())
1480     return op.emitOpError() << "second element of result tuple is expected to "
1481                                "be of token type, but got "
1482                             << subtypes[1];
1483   return success();
1484 }
1485 
1486 //===----------------------------------------------------------------------===//
1487 // Logical Ops
1488 //===----------------------------------------------------------------------===//
1489 
fold(ArrayRef<Attribute> operands)1490 OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
1491   if (lhs() == rhs()) return lhs();
1492 
1493   auto rType = getType().cast<ShapedType>();
1494   auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
1495   auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();
1496 
1497   if (lhsVal && lhsVal.isSplat()) {
1498     if (lhsVal.getSplatValue()
1499             .cast<IntegerAttr>()
1500             .getValue()
1501             .isAllOnesValue()) {
1502       return rhs();
1503     }
1504 
1505     if (lhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
1506       return lhsVal;
1507     }
1508   }
1509 
1510   if (rhsVal && rhsVal.isSplat()) {
1511     if (rhsVal.getSplatValue()
1512             .cast<IntegerAttr>()
1513             .getValue()
1514             .isAllOnesValue()) {
1515       return lhs();
1516     }
1517 
1518     if (rhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
1519       return rhsVal;
1520     }
1521   }
1522 
1523   if (!rhsVal || !lhsVal) return {};
1524 
1525   llvm::SmallVector<APInt, 4> values;
1526   values.reserve(rhsVal.getNumElements());
1527   for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) {
1528     values.push_back(std::get<0>(it) & std::get<1>(it));
1529   }
1530 
1531   return DenseIntElementsAttr::get(rType, values);
1532 }
1533 
fold(ArrayRef<Attribute> operands)1534 OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
1535   if (lhs() == rhs()) return lhs();
1536 
1537   auto rType = getType().cast<ShapedType>();
1538   auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
1539   auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();
1540 
1541   if (lhsVal && lhsVal.isSplat()) {
1542     if (lhsVal.getSplatValue()
1543             .cast<IntegerAttr>()
1544             .getValue()
1545             .isAllOnesValue()) {
1546       return lhsVal;
1547     }
1548 
1549     if (lhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
1550       return rhs();
1551     }
1552   }
1553 
1554   if (rhsVal && rhsVal.isSplat()) {
1555     if (rhsVal.getSplatValue()
1556             .cast<IntegerAttr>()
1557             .getValue()
1558             .isAllOnesValue()) {
1559       return rhsVal;
1560     }
1561 
1562     if (rhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
1563       return lhs();
1564     }
1565   }
1566 
1567   if (!rhsVal || !lhsVal) return {};
1568 
1569   llvm::SmallVector<APInt, 4> values;
1570   values.reserve(rhsVal.getNumElements());
1571   for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) {
1572     values.push_back(std::get<0>(it) | std::get<1>(it));
1573   }
1574 
1575   return DenseIntElementsAttr::get(rType, values);
1576 }
1577 
fold(ArrayRef<Attribute> operands)1578 OpFoldResult XorOp::fold(ArrayRef<Attribute> operands) {
1579   auto rType = getType().cast<ShapedType>();
1580   if (lhs() == rhs()) {
1581     Builder builder(getContext());
1582     return builder.getZeroAttr(rType);
1583   }
1584 
1585   auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
1586   auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();
1587 
1588   if (lhsVal && lhsVal.isSplat()) {
1589     if (lhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
1590       return rhs();
1591     }
1592   }
1593 
1594   if (rhsVal && rhsVal.isSplat()) {
1595     if (rhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
1596       return lhs();
1597     }
1598   }
1599 
1600   if (!rhsVal || !lhsVal) return {};
1601 
1602   llvm::SmallVector<APInt, 4> values;
1603   values.reserve(rhsVal.getNumElements());
1604   for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) {
1605     values.push_back(std::get<0>(it) ^ std::get<1>(it));
1606   }
1607 
1608   return DenseIntElementsAttr::get(rType, values);
1609 }
1610 
1611 //===----------------------------------------------------------------------===//
1612 // MapOp
1613 //===----------------------------------------------------------------------===//
1614 
Verify(MapOp op)1615 static LogicalResult Verify(MapOp op) {
1616   // Checks if the number of `operands` match the arity of the map `computation`
1617   // region.
1618   auto& computation_block = op.computation().front();
1619   auto computation_args = computation_block.getArguments();
1620   if (op.operands().size() != computation_args.size())
1621     return op.emitOpError()
1622            << "expects number of operands to match the arity "
1623               "of map computation, but got: "
1624            << op.operands().size() << " and " << computation_args.size();
1625 
1626   // The parameters of computation should all be scalars and match the element
1627   // type of operands.
1628   auto operand_type = op.operands()[0].getType().cast<TensorType>();
1629   auto operand_elem_ty = operand_type.getElementType();
1630 
1631   for (auto indexed_arg : llvm::enumerate(computation_args)) {
1632     auto arg_type = indexed_arg.value().getType().dyn_cast<TensorType>();
1633     if (!arg_type || arg_type.getRank() != 0)
1634       return op.emitOpError()
1635              << "computation arguments must be 0-rank tensor, but got: arg #"
1636              << indexed_arg.index() << " of type "
1637              << indexed_arg.value().getType();
1638     if (arg_type.getElementType() != operand_elem_ty) {
1639       return op.emitOpError()
1640              << "element type of operands and computation arguments must "
1641                 "match, but got: "
1642              << operand_elem_ty << " and " << arg_type.getElementType();
1643     }
1644   }
1645 
1646   // Mapped computation must return single output
1647   auto computation_outputs = computation_block.getTerminator()->getOperands();
1648   if (computation_outputs.size() != 1)
1649     return op.emitOpError()
1650            << "computation must return single output, but got: "
1651            << computation_outputs.size();
1652 
1653   // The output of computation must be scalar and have the same element type
1654   // as op result.
1655   auto computation_output_type =
1656       computation_outputs[0].getType().dyn_cast<TensorType>();
1657   if (!computation_output_type || computation_output_type.getRank() != 0)
1658     return op.emitOpError()
1659            << "computation must return 0-rank tensor, but got: "
1660            << computation_outputs[0].getType();
1661 
1662   auto result_type = op.getType().cast<TensorType>();
1663   if (computation_output_type.getElementType() != result_type.getElementType())
1664     return op.emitOpError() << "element type of result and computation output "
1665                                "must match, but got: "
1666                             << result_type.getElementType() << " and "
1667                             << computation_output_type.getElementType();
1668 
1669   // Checks that the requested map dimension numbers are monotonically
1670   // increasing.
1671   auto values = op.dimensions().getValues<int64_t>();
1672   auto dimensions = std::vector<int64_t>{values.begin(), values.end()};
1673   for (int i = 0, e = dimensions.size(); i < e; ++i) {
1674     if (dimensions[i] != i)
1675       return op.emitOpError() << "requires monotonically increasing dimension "
1676                                  "numbers, but got: "
1677                               << op.dimensions();
1678   }
1679 
1680   // Checks that number of dimensions of operands matches the size of
1681   // `dimensions` since we currently only support mapping across all
1682   // dimensions: i.e., scalar map functions.
1683   if (operand_type.hasRank()) {
1684     if (dimensions.size() != operand_type.getShape().size())
1685       return op.emitOpError()
1686              << "applied to a subset of dimensions currently not supported: "
1687                 "operand dimensions = "
1688              << operand_type.getShape().size()
1689              << ", requested map dimensions size = " << dimensions.size();
1690   }
1691 
1692   return success();
1693 }
1694 
1695 //===----------------------------------------------------------------------===//
1696 // RecvOp
1697 //===----------------------------------------------------------------------===//
1698 
1699 // Checks that the result type is of the form `tuple<any_type, mhlo::token>`
Verify(RecvOp op)1700 static LogicalResult Verify(RecvOp op) {
1701   auto result_ty = op.getResult().getType().cast<TupleType>();
1702   auto subtypes = result_ty.getTypes();
1703   if (subtypes.size() != 2)
1704     return op.emitOpError()
1705            << "result is expected to be a tuple of size 2, but got "
1706            << subtypes.size();
1707   if (!subtypes[1].isa<TokenType>())
1708     return op.emitOpError() << "second element of result tuple is expected to "
1709                                "be of token type, but got "
1710                             << subtypes[1];
1711   return success();
1712 }
1713 
1714 //===----------------------------------------------------------------------===//
1715 // CopyOp
1716 //===----------------------------------------------------------------------===//
1717 
fold(ArrayRef<Attribute> operands)1718 OpFoldResult CopyOp::fold(ArrayRef<Attribute> operands) { return getOperand(); }
1719 
1720 //===----------------------------------------------------------------------===//
1721 // ReverseOp
1722 //===----------------------------------------------------------------------===//
1723 
fold(ArrayRef<Attribute> operands)1724 OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
1725   auto input = operand();
1726 
1727   // No dimensions to reverse.
1728   if (dimensions().getNumElements() == 0) return input;
1729 
1730   llvm::SmallVector<APInt, 5> new_dims;
1731   new_dims.reserve(dimensions().getNumElements());
1732 
1733   auto shaped_type = input.getType().cast<ShapedType>();
1734   for (auto dim : dimensions().getValues<APInt>()) {
1735     if (shaped_type.getDimSize(dim.getLimitedValue()) != 1) {
1736       return nullptr;
1737     }
1738   }
1739 
1740   return input;
1741 }
1742 
1743 //===----------------------------------------------------------------------===//
1744 // ReduceOp
1745 //===----------------------------------------------------------------------===//
1746 
1747 // Returns the result type after reducing operand of the given type across the
1748 // specified dimensions.
GetReduceResultType(Type operand_ty,DenseIntElementsAttr dimensions,Builder * builder)1749 static TensorType GetReduceResultType(Type operand_ty,
1750                                       DenseIntElementsAttr dimensions,
1751                                       Builder* builder) {
1752   Type element_ty = getElementTypeOrSelf(operand_ty);
1753 
1754   auto ranked_ty = operand_ty.dyn_cast<RankedTensorType>();
1755   if (!ranked_ty) return UnrankedTensorType::get(element_ty);
1756 
1757   int64_t rank = ranked_ty.getRank();
1758   llvm::SmallVector<bool, 4> dims_mask(rank, false);
1759   for (int64_t dim : dimensions.getValues<int64_t>()) dims_mask[dim] = true;
1760 
1761   SmallVector<int64_t, 4> shape;
1762   for (int64_t i = 0; i < rank; ++i) {
1763     if (!dims_mask[i]) shape.push_back(ranked_ty.getDimSize(i));
1764   }
1765 
1766   return RankedTensorType::get(shape, element_ty);
1767 }
1768 
build(OpBuilder & builder,OperationState & state,ValueRange operands,ValueRange init_values,DenseIntElementsAttr dimensions)1769 void ReduceOp::build(OpBuilder& builder, OperationState& state,
1770                      ValueRange operands, ValueRange init_values,
1771                      DenseIntElementsAttr dimensions) {
1772   SmallVector<Type, 1> result_ty;
1773   result_ty.reserve(operands.size());
1774 
1775   for (Value operand : operands) {
1776     result_ty.push_back(
1777         GetReduceResultType(operand.getType(), dimensions, &builder));
1778   }
1779   build(builder, state, result_ty, operands, init_values, dimensions);
1780 }
1781 
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1782 LogicalResult ReduceOp::fold(ArrayRef<Attribute> operands,
1783                              SmallVectorImpl<OpFoldResult>& results) {
1784   // No dimensions to reduce.
1785   if (dimensions().getNumElements() == 0) {
1786     for (Value input : this->operands()) {
1787       results.push_back(input);
1788     }
1789     return success();
1790   }
1791   return failure();
1792 }
1793 
1794 //===----------------------------------------------------------------------===//
1795 // SelectOp
1796 //===----------------------------------------------------------------------===//
1797 
Verify(SelectOp op)1798 static LogicalResult Verify(SelectOp op) {
1799   // TODO(jpienaar): Update to allow broadcastable and unranked inputs. This
1800   // corresponds to the client side HLO.
1801   return success();
1802 }
1803 
fold(ArrayRef<Attribute> operands)1804 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
1805   if (on_true() == on_false()) {
1806     return on_true();
1807   }
1808 
1809   auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1810   if (!predicate) {
1811     return {};
1812   }
1813 
1814   auto predicateTy = predicate.getType().cast<ShapedType>();
1815   if (!predicateTy.getElementType().isInteger(1)) {
1816     return {};
1817   }
1818 
1819   if (predicate.isSplat()) {
1820     return predicate.getSplatValue<APInt>().getBoolValue() ? on_true()
1821                                                            : on_false();
1822   }
1823 
1824   return {};
1825 }
1826 
1827 // Makes it such that a SelectOp that is a non-root operation in a DRR infers
1828 // the return type based on operand type.
inferReturnTypes(MLIRContext *,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)1829 LogicalResult SelectOp::inferReturnTypes(
1830     MLIRContext*, Optional<Location> location, ValueRange operands,
1831     DictionaryAttr attributes, RegionRange regions,
1832     SmallVectorImpl<Type>& inferredReturnTypes) {
1833   auto x_type = operands[1].getType();
1834   auto y_type = operands[2].getType();
1835   auto x_tensor = x_type.cast<TensorType>();
1836   auto y_tensor = y_type.cast<TensorType>();
1837 
1838   // Check for type compatibility in the select op. This requires that the two
1839   // non-predicate operands:
1840   //   (a) have the same element type
1841   //   (b) have compatible shapes (i.e. the same shape and/or at least one
1842   //       dynamic shape)
1843   if (x_tensor.getElementType() != y_tensor.getElementType() ||
1844       failed(mlir::verifyCompatibleShape(x_type, y_type))) {
1845     return emitOptionalError(location, "incompatible operand types: ", x_type,
1846                              " and ", y_type);
1847   }
1848 
1849   // TODO(lucyfox): Support output shape inference when operands have compatible
1850   // shapes. (The output shape should be the most general of the operand shapes
1851   // at each dimension.) For now, handle the straightforward cases and fail
1852   // otherwise. When this is fully implemented, this logic should move into
1853   // reusable functionality in MLIR Core.
1854   Type output_type;
1855   if (x_type == y_type || !x_tensor.hasRank()) {
1856     output_type = x_type;
1857   } else if (!y_tensor.hasRank()) {
1858     output_type = y_type;
1859   } else {
1860     return emitOptionalError(location,
1861                              "currently unsupported operand types: ", x_type,
1862                              " and ", y_type);
1863   }
1864   inferredReturnTypes.assign({output_type});
1865   return success();
1866 }
1867 
inferReturnTypeComponents(mlir::MLIRContext *,llvm::Optional<mlir::Location>,mlir::ValueRange,mlir::DictionaryAttr,mlir::RegionRange,llvm::SmallVectorImpl<mlir::ShapedTypeComponents> &)1868 LogicalResult SelectOp::inferReturnTypeComponents(
1869     mlir::MLIRContext*, llvm::Optional<mlir::Location>, mlir::ValueRange,
1870     mlir::DictionaryAttr, mlir::RegionRange,
1871     llvm::SmallVectorImpl<mlir::ShapedTypeComponents>&) {
1872   // TODO(b/168772852)
1873   return failure();
1874 }
1875 
reifyReturnTypeShapes(OpBuilder & builder,SmallVectorImpl<Value> & reifiedReturnShapes)1876 LogicalResult SelectOp::reifyReturnTypeShapes(
1877     OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
1878   return deriveShapeFromFirstOperand(&builder, getOperation(),
1879                                      &reifiedReturnShapes);
1880 }
1881 
1882 //===----------------------------------------------------------------------===//
1883 // SetDimensionSizeOp
1884 //===----------------------------------------------------------------------===//
1885 
Verify(SetDimensionSizeOp op)1886 static LogicalResult Verify(SetDimensionSizeOp op) {
1887   if (auto size = op.size().getType().dyn_cast<RankedTensorType>()) {
1888     if (size.getRank() != 0)
1889       return op.emitOpError() << "size operand should be of rank-0";
1890   }
1891 
1892   return VerifyDimAttr(op);
1893 }
1894 
fold(ArrayRef<Attribute> operands)1895 OpFoldResult SetDimensionSizeOp::fold(ArrayRef<Attribute> operands) {
1896   DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
1897   if (input) return input;
1898 
1899   DenseElementsAttr size = operands[1].dyn_cast_or_null<DenseElementsAttr>();
1900   if (!size || !size.isSplat()) return {};
1901 
1902   auto ty = getType().dyn_cast<RankedTensorType>();
1903   if (!ty) return {};
1904 
1905   int64_t dim_size = ty.getDimSize(dimension());
1906   if (dim_size == size.getSplatValue().cast<IntegerAttr>().getInt())
1907     return operand();
1908   return {};
1909 }
1910 
1911 //===----------------------------------------------------------------------===//
1912 // PadOp
1913 //===----------------------------------------------------------------------===//
1914 
Verify(PadOp op)1915 static LogicalResult Verify(PadOp op) {
1916   auto input_type = op.operand().getType().cast<RankedTensorType>();
1917   auto pad_type = op.padding_value().getType().cast<RankedTensorType>();
1918 
1919   if (pad_type.getRank() != 0) {
1920     return op.emitOpError(
1921         llvm::formatv("padding value type should be a rank-0 "
1922                       "tensor, is rank {0}",
1923                       pad_type.getRank()));
1924   }
1925 
1926   const auto& padding_low = op.edge_padding_low();
1927   if (padding_low.getType().getNumElements() != input_type.getRank()) {
1928     return op.emitOpError(llvm::formatv(
1929         "edge_padding_low length ({0}) must match operand rank ({1})",
1930         padding_low.getType().getNumElements(), input_type.getRank()));
1931   }
1932 
1933   const auto& padding_high = op.edge_padding_high();
1934   if (padding_high.getType().getNumElements() != input_type.getRank()) {
1935     return op.emitOpError(llvm::formatv(
1936         "edge_padding_high length ({0}) must match operand rank ({1})",
1937         padding_high.getType().getNumElements(), input_type.getRank()));
1938   }
1939 
1940   const auto& padding_interior = op.interior_padding();
1941   if (padding_interior.getType().getNumElements() != input_type.getRank()) {
1942     return op.emitOpError(llvm::formatv(
1943         "interior_padding length ({0}) must match operand rank ({1})",
1944         padding_interior.getType().getNumElements(), input_type.getRank()));
1945   }
1946 
1947   auto input_shape = input_type.getShape();
1948   auto output_shape =
1949       op.getResult().getType().cast<RankedTensorType>().getShape();
1950   if (input_shape.size() != output_shape.size()) {
1951     return op.emitOpError(
1952         llvm::formatv("operand rank ({0}) and result rank({0}) should match",
1953                       input_shape.size(), output_shape.size()));
1954   }
1955 
1956   for (int i = 0, e = input_shape.size(); i < e; i++) {
1957     int padding_low_val = padding_low.getValue<IntegerAttr>(i).getInt();
1958     int padding_high_val = padding_high.getValue<IntegerAttr>(i).getInt();
1959     int padding_interior_val =
1960         padding_interior.getValue<IntegerAttr>(i).getInt();
1961     int expected_output =
1962         input_shape[i] + padding_low_val + padding_high_val +
1963         std::max<int64_t>(input_shape[i] - 1, 0LL) * padding_interior_val;
1964     if (expected_output != output_shape[i]) {
1965       return op.emitOpError(llvm::formatv(
1966           "expected output shape's dimension #{0} to be {1} but found {2}", i,
1967           expected_output, output_shape[i]));
1968     }
1969   }
1970 
1971   return success();
1972 }
1973 
fold(ArrayRef<Attribute> operands)1974 OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
1975   // If all padding is zero then it is an identity pad.
1976   auto is_zero = [](const APInt& i) { return i == 0; };
1977   if (llvm::all_of(edge_padding_low().getIntValues(), is_zero) &&
1978       llvm::all_of(edge_padding_high().getIntValues(), is_zero) &&
1979       llvm::all_of(interior_padding().getIntValues(), is_zero))
1980     return operand();
1981 
1982   // If any padding is negative then it isn't supported by the folder (yet).
1983   auto is_negative = [](const APInt& i) { return i.slt(0); };
1984   if (llvm::all_of(edge_padding_low().getIntValues(), is_negative) &&
1985       llvm::all_of(edge_padding_high().getIntValues(), is_negative) &&
1986       llvm::all_of(interior_padding().getIntValues(), is_negative))
1987     return {};
1988 
1989   DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
1990   DenseElementsAttr padding = operands[1].dyn_cast_or_null<DenseElementsAttr>();
1991   RankedTensorType return_type = getType().dyn_cast_or_null<RankedTensorType>();
1992   if (!input || !input.getType().hasRank() || !padding || !return_type ||
1993       !return_type.hasStaticShape())
1994     return {};
1995 
1996   // Fill the full result tensor with the padding value.
1997   llvm::SmallVector<Attribute, 4> result(return_type.getNumElements(),
1998                                          padding.getValue({}));
1999 
2000   auto next_index = [](llvm::SmallVector<uint64_t, 8>& index,
2001                        llvm::ArrayRef<int64_t> shape) {
2002     for (int64_t i = index.size() - 1; i >= 0; --i) {
2003       ++index[i];
2004       if (index[i] < shape[i]) return;
2005       index[i] = 0;
2006     }
2007   };
2008 
2009   // Iterate over all elements of the input tensor and copy it to the correct
2010   // location in the output tensor.
2011   llvm::SmallVector<uint64_t, 8> index(input.getType().getRank(), 0);
2012   uint64_t num_elements = input.getNumElements();
2013   for (uint64_t operand_idx = 0; operand_idx < num_elements; operand_idx++) {
2014     uint64_t result_idx = 0;
2015     uint64_t idx_multiplyer = 1;
2016     for (int64_t i = index.size() - 1; i >= 0; --i) {
2017       result_idx +=
2018           (edge_padding_low().getValue<int64_t>({uint64_t(i)}) +
2019            index[i] *
2020                (interior_padding().getValue<int64_t>({uint64_t(i)}) + 1)) *
2021           idx_multiplyer;
2022       idx_multiplyer *= return_type.getDimSize(i);
2023     }
2024     result[result_idx] = input.getValue(index);
2025     next_index(index, input.getType().getShape());
2026   }
2027   return DenseElementsAttr::get(return_type, result);
2028 }
2029 
2030 //===----------------------------------------------------------------------===//
2031 // ReshapeOp
2032 //===----------------------------------------------------------------------===//
2033 
Verify(ReshapeOp op)2034 static LogicalResult Verify(ReshapeOp op) {
2035   // If the operand type is dynamically shaped there is nothing to verify.
2036   auto operand_ty = op.operand().getType().dyn_cast<RankedTensorType>();
2037   if (!operand_ty || !operand_ty.hasStaticShape()) return success();
2038 
2039   // If the operand type is statically shaped (not required) the number of
2040   // elements must match that of the result type.
2041   auto result_ty = op.getType().cast<RankedTensorType>();
2042   assert(result_ty && result_ty.hasStaticShape() &&
2043          "result type must be statically shaped");
2044   int64_t num_result_elements = result_ty.getNumElements();
2045   int64_t num_operand_elements = operand_ty.getNumElements();
2046   if (num_result_elements != num_operand_elements)
2047     return op.emitOpError()
2048            << "number of output elements (" << num_result_elements
2049            << ") doesn't match expected number of elements ("
2050            << num_operand_elements << ")";
2051 
2052   return success();
2053 }
2054 
fold(ArrayRef<Attribute> operands)2055 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
2056   if (getOperand().getType() == getType()) {
2057     return getOperand();
2058   }
2059 
2060   if (auto prev_op =
2061           dyn_cast_or_null<ReshapeOp>(getOperand().getDefiningOp())) {
2062     setOperand(prev_op.getOperand());
2063     return getResult();
2064   }
2065 
2066   if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
2067     return elements.reshape(getResult().getType().cast<ShapedType>());
2068   }
2069 
2070   return {};
2071 }
2072 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2073 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
2074                                             MLIRContext* context) {
2075   results.insert<IdentityBroadcastReshape, IdentityBroadcastInDimReshape>(
2076       context);
2077 }
2078 
2079 //===----------------------------------------------------------------------===//
2080 // ReplicaId Op
2081 //===----------------------------------------------------------------------===//
2082 
inferReturnTypes(MLIRContext * context,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)2083 LogicalResult ReplicaIdOp::inferReturnTypes(
2084     MLIRContext* context, Optional<Location>, ValueRange operands,
2085     DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
2086   inferredReturnTypes.push_back(RankedTensorType::get(
2087       /*shape=*/{}, IntegerType::get(context, 32, IntegerType::Unsigned)));
2088   return success();
2089 }
2090 
2091 //===----------------------------------------------------------------------===//
2092 // Case Op
2093 //===----------------------------------------------------------------------===//
2094 
Verify(CaseOp op)2095 static LogicalResult Verify(CaseOp op) {
2096   auto num_branches = op.branches().size();
2097   if (op.branch_operands().size() != num_branches)
2098     return op.emitOpError() << "expects number of branches " << num_branches
2099                             << " to be same as number of branch operands "
2100                             << op.branch_operands().size();
2101 
2102   MutableArrayRef<Region> branches = op.branches();
2103   OperandRange branch_operands = op.branch_operands();
2104   for (unsigned i = 0; i < num_branches; ++i) {
2105     mlir::Region& branch_region = branches[i];
2106     if (branch_region.empty())
2107       return op.emitOpError() << "cannot have empty regions";
2108     mlir::Block& entry_block = branch_region.front();
2109     if (entry_block.getNumArguments() != 1)
2110       return op.emitOpError()
2111              << "expects branch regions to have single argument, but found "
2112              << entry_block.getNumArguments() << " for branch " << i;
2113     auto operand = branch_operands[i];
2114     if (entry_block.getArgument(0).getType() != operand.getType())
2115       return op.emitOpError()
2116              << "expects operand " << i + 1 << " to be of type "
2117              << entry_block.getArgument(0).getType() << ", but found "
2118              << operand.getType();
2119     WalkResult walker = branch_region.walk([&](ReturnOp return_op) {
2120       if (return_op.getOperands().getTypes() != op.getResultTypes())
2121         return WalkResult::interrupt();
2122       return WalkResult::advance();
2123     });
2124     if (walker.wasInterrupted())
2125       return op.emitOpError()
2126              << "branch " << i
2127              << " returned values do not match op result types";
2128   }
2129   return success();
2130 }
2131 
2132 //===----------------------------------------------------------------------===//
2133 // SqrtOp
2134 //===----------------------------------------------------------------------===//
2135 
fold(ArrayRef<Attribute> operands)2136 OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
2137   auto val = operands[0].dyn_cast_or_null<DenseElementsAttr>();
2138   if (!val) return {};
2139 
2140   auto type = getElementTypeOrSelf(getType());
2141   if (!type.isF32() && !type.isF64()) return {};
2142 
2143   auto shaped_type = getType().cast<ShapedType>();
2144   if (!shaped_type.hasStaticShape()) return {};
2145 
2146   int bit_width = type.getIntOrFloatBitWidth();
2147   llvm::SmallVector<APFloat, 4> values;
2148   values.reserve(val.getNumElements());
2149   for (auto it : val.getFloatValues()) {
2150     double value = bit_width == 32 ? it.convertToFloat() : it.convertToDouble();
2151     if (value < 0) return {};
2152     value = std::sqrt(value);
2153     if (bit_width == 32)
2154       values.emplace_back(static_cast<float>(value));
2155     else
2156       values.emplace_back(value);
2157   }
2158   return DenseFPElementsAttr::get(shaped_type, values);
2159 }
2160 
2161 //===----------------------------------------------------------------------===//
2162 // UnaryOps
2163 //===----------------------------------------------------------------------===//
2164 
2165 template <typename Op, typename ElementType = Type, typename ValType,
2166           typename Convert>
UnaryFolder(Op * op,ArrayRef<Attribute> attrs)2167 static Attribute UnaryFolder(Op* op, ArrayRef<Attribute> attrs) {
2168   if (!attrs[0]) return {};
2169 
2170   DenseElementsAttr val = attrs[0].dyn_cast<DenseElementsAttr>();
2171   if (!val) return {};
2172 
2173   ShapedType type = op->getType().template cast<ShapedType>();
2174   if (!type.hasStaticShape()) {
2175     return {};
2176   }
2177 
2178   Type etype = type.getElementType();
2179 
2180   // Evaluate for integer values.
2181   if (!etype.isa<ElementType>()) {
2182     return {};
2183   }
2184 
2185   SmallVector<ValType, 6> values;
2186   values.reserve(val.getNumElements());
2187   for (const auto v : val.getValues<ValType>()) {
2188     values.push_back(Convert()(v));
2189   }
2190 
2191   return DenseElementsAttr::get(type, values);
2192 }
2193 
2194 struct round {
operator ()mlir::mhlo::round2195   APFloat operator()(const APFloat& f) {
2196     APFloat r = f;
2197     r.roundToIntegral(llvm::RoundingMode::NearestTiesToAway);
2198     return r;
2199   }
2200 };
2201 
2202 #define UNARY_FOLDER(Op, Func)                                                \
2203   OpFoldResult Op::fold(ArrayRef<Attribute> attrs) {                          \
2204     if (getElementTypeOrSelf(getType()).isa<FloatType>())                     \
2205       return UnaryFolder<Op, FloatType, APFloat, Func<APFloat>>(this, attrs); \
2206     if (getElementTypeOrSelf(getType()).isa<IntegerType>())                   \
2207       return UnaryFolder<Op, IntegerType, APInt, Func<APInt>>(this, attrs);   \
2208     return {};                                                                \
2209   }
2210 
2211 #define UNARY_FOLDER_FLOAT(Op, Func)                                 \
2212   OpFoldResult Op::fold(ArrayRef<Attribute> attrs) {                 \
2213     if (getElementTypeOrSelf(getType()).isa<FloatType>())            \
2214       return UnaryFolder<Op, FloatType, APFloat, Func>(this, attrs); \
2215     return {};                                                       \
2216   }
2217 
2218 UNARY_FOLDER(NegOp, std::negate);
2219 UNARY_FOLDER_FLOAT(RoundOp, round);
2220 
2221 //===----------------------------------------------------------------------===//
2222 // BinaryOps
2223 //===----------------------------------------------------------------------===//
2224 
2225 namespace {
2226 
2227 // Updates the element type of a (presumed) tensor type 'x', returning either
2228 // a permuted UnrankedTensorType or RankedTensorType.
UpdateResultElementType(Builder * builder,Type x,Type element_type)2229 static Type UpdateResultElementType(Builder* builder, Type x,
2230                                     Type element_type) {
2231   auto x_ranked = x.dyn_cast<RankedTensorType>();
2232   if (!x_ranked) {
2233     return UnrankedTensorType::get(element_type);
2234   }
2235 
2236   auto shape_x = x_ranked.getShape();
2237   return RankedTensorType::get(shape_x, element_type);
2238 }
2239 }  // namespace
2240 
2241 template <typename Op, typename ElementType = Type, typename ValType,
2242           typename Convert>
BinaryFolder(Op * op,ArrayRef<Attribute> attrs)2243 static Attribute BinaryFolder(Op* op, ArrayRef<Attribute> attrs) {
2244   if (!attrs[0] || !attrs[1]) return {};
2245 
2246   DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>();
2247   DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>();
2248   if (!lhs || !rhs) return {};
2249 
2250   ShapedType type = op->getType().template cast<ShapedType>();
2251   if (!type.hasStaticShape()) {
2252     return {};
2253   }
2254 
2255   Type etype = type.getElementType();
2256 
2257   // Evaluate for integer values.
2258   if (!etype.isa<ElementType>()) {
2259     return {};
2260   }
2261 
2262   SmallVector<ValType, 6> values;
2263   values.reserve(lhs.getNumElements());
2264   for (const auto zip :
2265        llvm::zip(lhs.getValues<ValType>(), rhs.getValues<ValType>())) {
2266     values.push_back(Convert()(std::get<0>(zip), std::get<1>(zip)));
2267   }
2268 
2269   return DenseElementsAttr::get(type, values);
2270 }
2271 
2272 template <typename T>
2273 struct divide : std::divides<T> {};
2274 
2275 template <>
2276 struct divide<APInt> {
operator ()mlir::mhlo::divide2277   APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); }
2278 };
2279 
2280 template <typename T>
2281 struct remainder : std::modulus<T> {};
2282 
2283 template <>
2284 struct remainder<APInt> {
operator ()mlir::mhlo::remainder2285   APInt operator()(const APInt& a, const APInt& b) const { return a.srem(b); }
2286 };
2287 
2288 template <>
2289 struct remainder<APFloat> {
operator ()mlir::mhlo::remainder2290   APFloat operator()(const APFloat& a, const APFloat& b) const {
2291     APFloat result(a);
2292     result.remainder(b);
2293     return result;
2294   }
2295 };
2296 
// Functor returning the larger of two values as ordered by `operator<`. When
// neither value compares less than the other, the first argument is returned.
// Used to fold MaxOp.
template <typename T>
struct max {
  T operator()(const T& a, const T& b) const { return (a < b) ? b : a; }
};
2301 
2302 template <>
2303 struct max<APInt> {
operator ()mlir::mhlo::max2304   APInt operator()(const APInt& a, const APInt& b) const {
2305     return llvm::APIntOps::smax(a, b);
2306   }
2307 };
2308 
// Functor returning the smaller of two values as ordered by `operator<`. When
// neither value compares less than the other, the first argument is returned.
// Used to fold MinOp.
template <typename T>
struct min {
  T operator()(const T& a, const T& b) const { return (b < a) ? b : a; }
};
2313 
2314 template <>
2315 struct min<APInt> {
operator ()mlir::mhlo::min2316   APInt operator()(const APInt& a, const APInt& b) const {
2317     return llvm::APIntOps::smin(a, b);
2318   }
2319 };
2320 
2321 #define BINARY_FOLDER(Op, Func)                                                \
2322   OpFoldResult Op::fold(ArrayRef<Attribute> attrs) {                           \
2323     if (getElementTypeOrSelf(getType()).isa<FloatType>())                      \
2324       return BinaryFolder<Op, FloatType, APFloat, Func<APFloat>>(this, attrs); \
2325     if (getElementTypeOrSelf(getType()).isa<IntegerType>())                    \
2326       return BinaryFolder<Op, IntegerType, APInt, Func<APInt>>(this, attrs);   \
2327     return {};                                                                 \
2328   }
2329 
2330 // Addition, subtraction and multiplication use the std:: versions of the ops.
2331 // Due to the other ops behaving differently in signed vs unsigned integers,
2332 // APInts need a special implementation. Currently, it replicates signed int
2333 // op behavior.
2334 BINARY_FOLDER(AddOp, std::plus);
2335 BINARY_FOLDER(SubOp, std::minus);
2336 BINARY_FOLDER(MulOp, std::multiplies);
2337 BINARY_FOLDER(DivOp, divide);
2338 BINARY_FOLDER(RemOp, remainder);
2339 BINARY_FOLDER(MaxOp, max);
2340 BINARY_FOLDER(MinOp, min);
2341 
2342 #undef BINARY_FOLDER
2343 
2344 //===----------------------------------------------------------------------===//
2345 // SliceOp
2346 //===----------------------------------------------------------------------===//
2347 
// Returns the output dimension size of a slice along one dimension, i.e.
// ceil((end - start) / stride).
//
// Returns -1 if the arguments are illegal or the size cannot be inferred:
// the input dimension is -1, `start` is negative, `start > end`,
// `end > input_dim`, or `stride` is not strictly positive.
static int64_t InferSliceDim(int64_t input_dim, int64_t start, int64_t end,
                             int64_t stride) {
  // A negative stride must be rejected together with zero: treating it as an
  // unsigned divisor would silently produce a meaningless size.
  if (input_dim == -1 || start < 0 || start > end || end > input_dim ||
      stride <= 0)
    return -1;

  // Ceiling division; both operands are non-negative here and `stride` is
  // positive, so the quotient/remainder form cannot overflow.
  const int64_t extent = end - start;
  return extent / stride + (extent % stride != 0 ? 1 : 0);
}
2358 
inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)2359 LogicalResult SliceOp::inferReturnTypes(
2360     MLIRContext* context, Optional<Location> location, ValueRange operands,
2361     DictionaryAttr attributes, RegionRange regions,
2362     SmallVectorImpl<Type>& inferredReturnTypes) {
2363   SliceOpAdaptor slice(operands, attributes);
2364   // TODO(jpienaar): Update this code after refactoring verify.
2365   if (failed(slice.verify(location.getValueOr(UnknownLoc::get(context))))) {
2366     return failure();
2367   }
2368 
2369   Type ty = slice.operand().getType();
2370   RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
2371   if (!ranked_ty) {
2372     // The operand type is unranked, so the best we can infer for the result
2373     // type is an unranked tensor with the same element type as the operand
2374     // type.
2375     inferredReturnTypes.assign({ty});
2376     return success();
2377   }
2378 
2379   ShapedType attr_ty = slice.start_indices().getType();
2380   if (attr_ty.getRank() != 1) {
2381     return emitOptionalError(location, "start_indices has rank ",
2382                              attr_ty.getRank(), " instead of required rank 1");
2383   }
2384 
2385   int64_t rank = ranked_ty.getRank();
2386   if (attr_ty.getNumElements() != rank) {
2387     return emitOptionalError(
2388         location, "the number of elements in start_indices (",
2389         attr_ty.getNumElements(), ") does not match the rank of the operand (",
2390         rank, ")");
2391   }
2392 
2393   if (!attr_ty.getElementType().isSignlessInteger(64) ||
2394       slice.limit_indices().getType() != attr_ty ||
2395       slice.strides().getType() != attr_ty) {
2396     // Unfortunately we can't rely on the AllTypesMatch trait for the SliceOp
2397     // having been verified at this point. Emit an error message that matches
2398     // the one that would be reported by AllTypesMatch for a more consistent
2399     // user experience.
2400     // TODO(b/171567182): Clean this up after AllTypesMatch has been refactored.
2401     return emitOptionalError(location,
2402                              "failed to verify that all of {start_indices, "
2403                              "limit_indices, strides} have same type");
2404   }
2405 
2406   SmallVector<int64_t, 4> start(slice.start_indices().getValues<int64_t>());
2407   SmallVector<int64_t, 4> limit(slice.limit_indices().getValues<int64_t>());
2408   SmallVector<int64_t, 4> stride_vals(slice.strides().getValues<int64_t>());
2409 
2410   SmallVector<int64_t, 4> shape;
2411   shape.reserve(rank);
2412   for (int64_t i = 0, e = rank; i != e; i++) {
2413     shape.push_back(InferSliceDim(ranked_ty.getDimSize(i), start[i], limit[i],
2414                                   stride_vals[i]));
2415   }
2416   inferredReturnTypes.assign(
2417       {RankedTensorType::get(shape, ranked_ty.getElementType())});
2418   return success();
2419 }
2420 
2421 template <typename I, typename E>
SliceElements(I values,ArrayRef<int64_t> sizes,ArrayRef<int64_t> starts,ArrayRef<int64_t> limits,ArrayRef<int64_t> strides,llvm::SmallVectorImpl<E> * out_values)2422 static void SliceElements(I values, ArrayRef<int64_t> sizes,
2423                           ArrayRef<int64_t> starts, ArrayRef<int64_t> limits,
2424                           ArrayRef<int64_t> strides,
2425                           llvm::SmallVectorImpl<E>* out_values) {
2426   assert(starts.size() == limits.size());
2427   assert(starts.size() == strides.size());
2428   if (starts.empty()) return;
2429 
2430   int64_t start = starts.front();
2431   int64_t limit = limits.front();
2432   int64_t stride = strides.front();
2433   if (starts.size() == 1) {
2434     for (int i = start; i < limit; i += stride) {
2435       out_values->push_back(*(values + i));
2436     }
2437     return;
2438   }
2439 
2440   for (; start < limit; start += stride) {
2441     auto begin = values + start * sizes.front();
2442     SliceElements<I, E>(begin, sizes.drop_front(), starts.drop_front(),
2443                         limits.drop_front(), strides.drop_front(), out_values);
2444   }
2445 }
2446 
2447 template <typename I, typename E>
FoldSlice(SliceOp * op,I values)2448 static Attribute FoldSlice(SliceOp* op, I values) {
2449   auto start = llvm::to_vector<6>(op->start_indices().getValues<int64_t>());
2450   auto limit = llvm::to_vector<6>(op->limit_indices().getValues<int64_t>());
2451   auto stride = llvm::to_vector<6>(op->strides().getValues<int64_t>());
2452 
2453   auto result_type = op->operand().getType().cast<ShapedType>();
2454   if (!result_type.hasStaticShape()) return {};
2455 
2456   auto shape = result_type.getShape();
2457   int64_t count = result_type.getNumElements();
2458   if (count == 0) {
2459     return DenseElementsAttr::get<E>(
2460         op->getResult().getType().cast<ShapedType>(),
2461         /*list=*/{});
2462   }
2463 
2464   // Compute the striding for each dimension.
2465   llvm::SmallVector<int64_t, 6> sizes;
2466   sizes.reserve(shape.size());
2467   for (auto v : shape) {
2468     count = count / v;
2469     sizes.push_back(count);
2470   }
2471 
2472   llvm::SmallVector<E, 6> out_values;
2473   out_values.reserve(result_type.getNumElements());
2474   SliceElements<I, E>(values, sizes, start, limit, stride, &out_values);
2475 
2476   return DenseElementsAttr::get(op->getResult().getType().cast<ShapedType>(),
2477                                 out_values);
2478 }
2479 
fold(ArrayRef<Attribute> operands)2480 OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
2481   // Check if the SliceOp is a NoOp operation.
2482   auto operand_type = getOperand().getType().cast<ShapedType>();
2483   auto result_type = getResult().getType().cast<ShapedType>();
2484 
2485   if (operand_type.hasStaticShape() && result_type.hasStaticShape() &&
2486       (operand_type.getShape() == result_type.getShape())) {
2487     return getOperand();
2488   }
2489 
2490   if (operands.empty() || !operands.front()) return {};
2491 
2492   // Evaluate for statically valued inputs.
2493   DenseElementsAttr elements = operands.front().dyn_cast<DenseElementsAttr>();
2494   if (!elements) return {};
2495 
2496   auto etype = elements.getType().getElementType();
2497   if (etype.isa<IntegerType>()) {
2498     return FoldSlice<DenseElementsAttr::IntElementIterator, APInt>(
2499         this, elements.getIntValues().begin());
2500   } else if (etype.isa<FloatType>()) {
2501     return FoldSlice<
2502         llvm::mapped_iterator<DenseElementsAttr::IntElementIterator,
2503                               std::function<APFloat(const APInt&)>>,
2504         APFloat>(this, elements.getFloatValues().begin());
2505   }
2506 
2507   return {};
2508 }
2509 
2510 namespace {
2511 // In cases where a concat is fed into a slice, it is possible the concat
2512 // can be simplified or bypassed. This checks which inputs to the concat are
2513 // used by the slice, either reducing the number of concatenated values or
2514 // entirely removes the concat.
2515 struct SimplifyConcatSlice : public OpRewritePattern<SliceOp> {
2516   using OpRewritePattern<SliceOp>::OpRewritePattern;
2517 
matchAndRewritemlir::mhlo::__anon1950c2d00e11::SimplifyConcatSlice2518   LogicalResult matchAndRewrite(SliceOp slice,
2519                                 PatternRewriter& rewriter) const override {
2520     auto result_ty = slice.getType().cast<ShapedType>();
2521     if (!result_ty.hasStaticShape()) {
2522       return failure();
2523     }
2524 
2525     auto slice_input = slice.operand();
2526     auto slice_input_ty = slice_input.getType().cast<ShapedType>();
2527     auto concat = dyn_cast_or_null<ConcatenateOp>(slice_input.getDefiningOp());
2528     if (!concat) {
2529       return failure();
2530     }
2531 
2532     auto dimension = concat.dimension();
2533 
2534     auto start = slice.start_indices().getIntValues();
2535     auto limit = slice.limit_indices().getIntValues();
2536 
2537     auto slice_start = (*(start.begin() + dimension)).getSExtValue();
2538     auto slice_limit = (*(limit.begin() + dimension)).getSExtValue();
2539 
2540     // We need to determine what inputs from the concat affect the slice, and
2541     // how the bounds of the slice need to be updated for the minimally required
2542     // inputs.
2543     int64_t running_size = 0;
2544     int64_t front_offset = slice_input_ty.getShape()[dimension];
2545 
2546     auto subset_start = concat.operand_end();
2547     auto subset_end = concat.operand_end();
2548     for (auto it = concat.operand_begin(); it < concat.operand_end(); ++it) {
2549       auto input = *it;
2550       ShapedType input_ty = input.getType().cast<ShapedType>();
2551       if (input_ty.isDynamicDim(dimension)) {
2552         return failure();
2553       }
2554       auto dim_size = input_ty.getShape()[dimension];
2555 
2556       // If this position is in the slice its the start of the subset and we
2557       // need to update the start and limit values.
2558       if (running_size + dim_size > slice_start &&
2559           subset_start == concat.operand_end()) {
2560         subset_start = it;
2561         front_offset = running_size;
2562       }
2563 
2564       // Determine the last required offset.
2565       if (running_size < slice_limit) {
2566         subset_end = it + 1;
2567       }
2568 
2569       running_size += dim_size;
2570     }
2571 
2572     auto subset_size = subset_end - subset_start;
2573     // We need all inputs so no optimization.
2574     if (subset_size == concat.getNumOperands()) {
2575       return failure();
2576     }
2577 
2578     if (subset_size > 1 && !concat.getResult().hasOneUse()) {
2579       return failure();
2580     }
2581 
2582     auto concat_range = OperandRange(subset_start, subset_end);
2583     auto new_concat = rewriter.create<ConcatenateOp>(
2584         concat.getLoc(), concat_range, concat.dimension());
2585 
2586     llvm::SmallVector<APInt, 6> new_start(start);
2587     llvm::SmallVector<APInt, 6> new_limit(limit);
2588     new_start[dimension] -= front_offset;
2589     new_limit[dimension] -= front_offset;
2590 
2591     auto attr_type = slice.start_indices().getType().cast<ShapedType>();
2592     auto create = rewriter.create<SliceOp>(
2593         slice.getLoc(), new_concat,
2594         DenseIntElementsAttr::get(attr_type, new_start),
2595         DenseIntElementsAttr::get(attr_type, new_limit), slice.strides());
2596     rewriter.replaceOp(slice, create.getResult());
2597     return success();
2598   }
2599 };
2600 }  // namespace
2601 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2602 void SliceOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
2603                                           MLIRContext* context) {
2604   results.insert<SimplifyConcatSlice>(context);
2605 }
2606 
2607 //===----------------------------------------------------------------------===//
2608 // SortOp
2609 //===----------------------------------------------------------------------===//
2610 
build(OpBuilder & builder,OperationState & state,ValueRange operands,int64_t dimension,bool is_stable)2611 void SortOp::build(OpBuilder& builder, OperationState& state,
2612                    ValueRange operands, int64_t dimension, bool is_stable) {
2613   state.addOperands(operands);
2614   state.addAttribute("dimension", builder.getI64IntegerAttr(dimension));
2615   state.addAttribute("is_stable", builder.getBoolAttr(dimension));
2616 
2617   for (Value operand : operands) state.addTypes(operand.getType());
2618 
2619   state.addRegion();
2620 }
2621 
Verify(SortOp op)2622 static LogicalResult Verify(SortOp op) {
2623   Operation::operand_range operands = op.operands();
2624   if (operands.empty()) return op.emitOpError("requires at least one input");
2625 
2626   // TODO(antiagainst): verify partionally dynamic shapes
2627   if (llvm::all_of(operands, [](Value operand) {
2628         return operand.getType().cast<ShapedType>().hasRank();
2629       })) {
2630     ArrayRef<int64_t> input_shape =
2631         (*operands.begin()).getType().cast<ShapedType>().getShape();
2632 
2633     if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value operand) {
2634           return operand.getType().cast<ShapedType>().getShape() != input_shape;
2635         }))
2636       return op.emitOpError("requires all inputs to have the same dimensions");
2637 
2638     int64_t rank = input_shape.size();
2639     int64_t cmp_dim = op.dimension();
2640     if (cmp_dim < -rank || cmp_dim >= rank)
2641       return op.emitOpError("dimension attribute value must be in range [-")
2642              << rank << ", " << rank << "), but found " << cmp_dim;
2643   }
2644 
2645   Block& block = op.comparator().front();
2646   size_t num_operands = op.getOperation()->getNumOperands();
2647   if (block.getNumArguments() != 2 * num_operands)
2648     return op.emitOpError("comparator block should have ")
2649            << 2 * num_operands << " arguments";
2650 
2651   for (auto indexed_operand : llvm::enumerate(operands)) {
2652     int index = indexed_operand.index();
2653     Type element_type =
2654         indexed_operand.value().getType().cast<ShapedType>().getElementType();
2655     Type tensor_type = RankedTensorType::get({}, element_type);
2656     for (int i : {2 * index, 2 * index + 1}) {
2657       Type arg_type = block.getArgument(i).getType();
2658       if (arg_type != tensor_type)
2659         return op.emitOpError("comparator block argument #")
2660                << i << " should be of type " << tensor_type << " but got "
2661                << arg_type;
2662     }
2663   }
2664 
2665   return success();
2666 }
2667 
2668 //===----------------------------------------------------------------------===//
2669 // TransposeOp
2670 //===----------------------------------------------------------------------===//
2671 
fold(ArrayRef<Attribute> operands)2672 OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
2673   for (auto it : llvm::enumerate(permutation().getValues<APInt>())) {
2674     if (it.index() != it.value()) {
2675       return {};
2676     }
2677   }
2678   return getOperand();
2679 }
2680 
Verify(TransposeOp op)2681 static LogicalResult Verify(TransposeOp op) {
2682   // permutation is an attribute of the op so it has static shape.
2683   auto permutationType = op.permutation().getType();
2684   auto permutationRank = permutationType.getRank();
2685   if (permutationRank != 1) {
2686     return op.emitOpError(llvm::formatv(
2687         "permutation has rank {0} instead of rank 1", permutationRank));
2688   }
2689   auto permutationSize = permutationType.getNumElements();
2690 
2691   auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
2692   if (operandType) {
2693     auto operandRank = operandType.getRank();
2694     if (operandRank != permutationSize) {
2695       return op.emitOpError(llvm::formatv(
2696           "operand rank ({0}) does not match permutation size ({1})",
2697           operandRank, permutationSize));
2698     }
2699   }
2700 
2701   auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
2702   if (resultType) {
2703     auto resultRank = resultType.getRank();
2704     if (resultRank != permutationSize) {
2705       return op.emitOpError(llvm::formatv(
2706           "result rank ({0}) does not match permutation size ({1})", resultRank,
2707           permutationSize));
2708     }
2709   }
2710 
2711   if (!resultType || !operandType) return success();
2712 
2713   auto operandRank = operandType.getRank();
2714   SmallVector<int64_t, 4> expectedShape(operandRank);
2715   for (int i = 0; i != operandRank; ++i) {
2716     auto permutedDim = op.permutation().getValue<IntegerAttr>(i).getInt();
2717     expectedShape[i] = operandType.getDimSize(permutedDim);
2718   }
2719 
2720   auto expectedType =
2721       RankedTensorType::get(expectedShape, resultType.getElementType());
2722   if (failed(verifyCompatibleShape(resultType, expectedType))) {
2723     return op.emitOpError(llvm::formatv(
2724         "result type {0} is incompatible with the expected type {1}",
2725         resultType, expectedType));
2726   }
2727 
2728   return success();
2729 }
2730 
2731 //===----------------------------------------------------------------------===//
2732 // TriangularSolveOp
2733 //===----------------------------------------------------------------------===//
2734 
Verify(TriangularSolveOp op)2735 static LogicalResult Verify(TriangularSolveOp op) {
2736   auto a_type = op.a().getType().dyn_cast<RankedTensorType>();
2737 
2738   // Skip verifier if a is unranked tensor.
2739   if (!a_type) return success();
2740 
2741   // Check that a should have rank >= 2
2742   auto a_rank = a_type.getRank();
2743   if (a_rank < 2)
2744     return op.emitOpError()
2745            << "operand 'a' must have rank >= 2, but got " << a_type;
2746 
2747   // The two minor dimensions of a must have same size.
2748   if (a_type.getDimSize(a_rank - 2) != a_type.getDimSize(a_rank - 1))
2749     return op.emitOpError() << "two minor dimensions of operand 'a' must have "
2750                                "equal size, but got "
2751                             << a_type;
2752 
2753   auto b_type = op.b().getType().dyn_cast<RankedTensorType>();
2754   // If b is unranked skip remaining checks.
2755   if (!b_type) return success();
2756 
2757   // Check that a and b have same rank.
2758   auto b_rank = b_type.getRank();
2759   if (a_rank != b_rank)
2760     return op.emitOpError() << "operands must have equal rank, but got "
2761                             << a_type << " and " << b_type;
2762 
2763   // The shared dimension of a and b should match.
2764   if (a_type.getDimSize(a_rank - 1) !=
2765       b_type.getDimSize(b_rank - (op.left_side() ? 2 : 1)))
2766     return op.emitOpError() << "shared dimension of operands 'a' and 'b' does "
2767                                "not match, but got "
2768                             << a_type << " and " << b_type;
2769 
2770   // The leading batch dimensions of a and b must be equal.
2771   auto a_batch_dims = a_type.getShape().drop_back(2);
2772   auto b_batch_dims = b_type.getShape().drop_back(2);
2773   if (a_batch_dims != b_batch_dims)
2774     return op.emitOpError()
2775            << "leading batch dimensions of the operands must be same, but got "
2776            << a_type << " and " << b_type;
2777 
2778   // Result and argument b must have same shape.
2779   auto result_type = op.getType().dyn_cast<RankedTensorType>();
2780   if (!result_type) return success();
2781   if (result_type != b_type)
2782     return op.emitOpError()
2783            << "result and operand 'b' must have same shape, but got "
2784            << result_type << " and " << b_type;
2785   return success();
2786 }
2787 
2788 //===----------------------------------------------------------------------===//
2789 // GetTupleElementOp
2790 //===----------------------------------------------------------------------===//
2791 
build(OpBuilder & builder,OperationState & result,Value tuple,int32_t index)2792 void GetTupleElementOp::build(OpBuilder& builder, OperationState& result,
2793                               Value tuple, int32_t index) {
2794   if (auto tuple_type = tuple.getType().dyn_cast<TupleType>()) {
2795     auto element_type = tuple_type.getType(index);
2796     build(builder, result, element_type, tuple,
2797           builder.getI32IntegerAttr(index));
2798     return;
2799   }
2800 
2801   build(builder, result, tuple.getType(), tuple,
2802         builder.getI32IntegerAttr(index));
2803 }
2804 
2805 //===----------------------------------------------------------------------===//
2806 // TupleOp
2807 //===----------------------------------------------------------------------===//
2808 
build(OpBuilder & builder,OperationState & result,ValueRange values)2809 void TupleOp::build(OpBuilder& builder, OperationState& result,
2810                     ValueRange values) {
2811   SmallVector<Type, 4> types;
2812   types.reserve(values.size());
2813   for (auto val : values) {
2814     types.push_back(val.getType());
2815   }
2816 
2817   build(builder, result, builder.getTupleType(types), values);
2818 }
2819 
2820 //===----------------------------------------------------------------------===//
2821 // UnaryEinsumOp
2822 //===----------------------------------------------------------------------===//
2823 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2824 void UnaryEinsumOp::getCanonicalizationPatterns(
2825     OwningRewritePatternList& results, MLIRContext* context) {
2826   results.insert<UnaryEinsumToEinsum>(context);
2827 }
2828 
2829 //===----------------------------------------------------------------------===//
2830 // CompareOp
2831 //===----------------------------------------------------------------------===//
2832 
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,StringAttr comparison_direction,StringAttr compare_type)2833 void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
2834                       Value rhs, StringAttr comparison_direction,
2835                       StringAttr compare_type) {
2836   auto new_type =
2837       UpdateResultElementType(&builder, lhs.getType(), builder.getI1Type());
2838   build(builder, result, new_type, lhs, rhs, comparison_direction,
2839         compare_type);
2840 }
2841 
inferReturnTypeComponents(mlir::MLIRContext *,llvm::Optional<mlir::Location>,mlir::ValueRange,mlir::DictionaryAttr,mlir::RegionRange,llvm::SmallVectorImpl<mlir::ShapedTypeComponents> &)2842 LogicalResult CompareOp::inferReturnTypeComponents(
2843     mlir::MLIRContext*, llvm::Optional<mlir::Location>, mlir::ValueRange,
2844     mlir::DictionaryAttr, mlir::RegionRange,
2845     llvm::SmallVectorImpl<mlir::ShapedTypeComponents>&) {
2846   // TODO(b/168772852)
2847   return failure();
2848 }
2849 
reifyReturnTypeShapes(OpBuilder & builder,SmallVectorImpl<Value> & reifiedReturnShapes)2850 LogicalResult CompareOp::reifyReturnTypeShapes(
2851     OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
2852   return deriveShapeFromFirstOperand(&builder, getOperation(),
2853                                      &reifiedReturnShapes);
2854 }
2855 
2856 template <typename T>
2857 struct less : std::less<T> {};
2858 
2859 template <>
2860 struct less<APInt> {
operator ()mlir::mhlo::less2861   bool operator()(const APInt& a, const APInt& b) const { return a.slt(b); }
2862 };
2863 
2864 template <typename T>
2865 struct less_equal : std::less_equal<T> {};
2866 
2867 template <>
2868 struct less_equal<APInt> {
operator ()mlir::mhlo::less_equal2869   bool operator()(const APInt& a, const APInt& b) const { return a.sle(b); }
2870 };
2871 
2872 template <typename T>
2873 struct greater : std::greater<T> {};
2874 
2875 template <>
2876 struct greater<APInt> {
operator ()mlir::mhlo::greater2877   bool operator()(const APInt& a, const APInt& b) const { return a.sgt(b); }
2878 };
2879 
2880 template <typename T>
2881 struct greater_equal : std::greater_equal<T> {};
2882 
2883 template <>
2884 struct greater_equal<APInt> {
operator ()mlir::mhlo::greater_equal2885   bool operator()(const APInt& a, const APInt& b) const { return a.sge(b); }
2886 };
2887 
2888 template <typename Op, typename ElementType, typename SrcType, typename Convert>
CompareFolder(CompareOp op,ArrayRef<Attribute> attrs)2889 static Attribute CompareFolder(CompareOp op, ArrayRef<Attribute> attrs) {
2890   if (!attrs[0] || !attrs[1]) return {};
2891 
2892   DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>();
2893   DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>();
2894   if (!lhs || !rhs) return {};
2895 
2896   ShapedType operand_type =
2897       op.getOperand(0).getType().template cast<ShapedType>();
2898   if (!operand_type.hasStaticShape()) {
2899     return {};
2900   }
2901 
2902   if (!operand_type.getElementType().isa<ElementType>()) {
2903     return {};
2904   }
2905 
2906   SmallVector<bool, 6> values;
2907   values.reserve(lhs.getNumElements());
2908   for (const auto zip :
2909        llvm::zip(lhs.getValues<SrcType>(), rhs.getValues<SrcType>())) {
2910     values.push_back(Convert()(std::get<0>(zip), std::get<1>(zip)));
2911   }
2912 
2913   auto result_ty = op.getType().cast<ShapedType>();
2914   return DenseElementsAttr::get(result_ty, values);
2915 }
2916 
fold(ArrayRef<Attribute> operands)2917 OpFoldResult CompareOp::fold(ArrayRef<Attribute> operands) {
2918   auto result_ty = getType().cast<ShapedType>();
2919   if (!result_ty.hasStaticShape()) return {};
2920 
2921   auto direction = comparison_direction();
2922   if (lhs() == rhs() && !getElementTypeOrSelf(lhs()).isa<FloatType>()) {
2923     if (direction == "LE" || direction == "EQ" || direction == "GE") {
2924       return DenseIntElementsAttr::get(result_ty, {true});
2925     }
2926     return DenseIntElementsAttr::get(result_ty, {false});
2927   }
2928 
2929   if (!operands[0] || !operands[1]) {
2930     return {};
2931   }
2932 
2933 #define COMPARE_FOLDER(Op, comparison, Func)                                \
2934   if (direction == comparison) {                                            \
2935     if (auto folded = CompareFolder<Op, FloatType, APFloat, Func<APFloat>>( \
2936             *this, operands))                                               \
2937       return folded;                                                        \
2938     if (auto folded = CompareFolder<Op, IntegerType, APInt, Func<APInt>>(   \
2939             *this, operands))                                               \
2940       return folded;                                                        \
2941   }
2942 
2943   COMPARE_FOLDER(CompareOp, "EQ", std::equal_to);
2944   COMPARE_FOLDER(CompareOp, "NE", std::not_equal_to);
2945   COMPARE_FOLDER(CompareOp, "LT", less);
2946   COMPARE_FOLDER(CompareOp, "LE", less_equal);
2947   COMPARE_FOLDER(CompareOp, "GT", greater);
2948   COMPARE_FOLDER(CompareOp, "GE", greater_equal);
2949 #undef COMPARE_FOLDER
2950 
2951   return {};
2952 }
2953 
2954 //===----------------------------------------------------------------------===//
2955 // ScatterOp
2956 //===----------------------------------------------------------------------===//
2957 
evaluateMhloRegion(Region & region,ArrayRef<Attribute> inputs)2958 llvm::SmallVector<Attribute, 4> evaluateMhloRegion(Region& region,
2959                                                    ArrayRef<Attribute> inputs) {
2960   if (region.getNumArguments() != inputs.size()) return {};
2961 
2962   llvm::DenseMap<Value, Attribute> values;
2963   values.reserve(region.getNumArguments());
2964   for (auto it : llvm::zip(region.getArguments(), inputs)) {
2965     values.try_emplace(std::get<0>(it), std::get<1>(it));
2966   }
2967 
2968   for (auto& op : region.getOps()) {
2969     llvm::SmallVector<Attribute, 4> inputs;
2970     for (auto& operand : op.getOpOperands()) {
2971       inputs.push_back(values.lookup(operand.get()));
2972     }
2973     if (isa<ReturnOp>(op)) return inputs;
2974 
2975     llvm::SmallVector<OpFoldResult, 4> results;
2976     if (failed(op.fold(inputs, results))) return {};
2977     for (auto it : llvm::zip(op.getResults(), results)) {
2978       if (!std::get<1>(it).is<Attribute>()) return {};
2979       values.insert({std::get<0>(it), std::get<1>(it).get<Attribute>()});
2980     }
2981   }
2982   return {};
2983 }
2984 
fold(ArrayRef<Attribute> operands)2985 OpFoldResult ScatterOp::fold(ArrayRef<Attribute> operands) {
2986   auto base = operands[0].dyn_cast_or_null<DenseElementsAttr>();
2987   auto index = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
2988   auto update = operands[2].dyn_cast_or_null<DenseElementsAttr>();
2989   if (!base || !index || !update) return {};
2990 
2991   auto base_type = base.getType().dyn_cast<RankedTensorType>();
2992   auto index_type = index.getType().dyn_cast<RankedTensorType>();
2993   auto update_type = update.getType().dyn_cast<RankedTensorType>();
2994   if (!base_type || !index_type || !update_type) return {};
2995 
2996   // Add the virtual trailing dimension of size 1 if index_vector_dim equals to
2997   // index_type.rank.
2998   const int64_t index_vector_dim =
2999       scatter_dimension_numbers().index_vector_dim().getInt();
3000   if (index_vector_dim == index_type.getRank()) {
3001     auto index_shape = index_type.getShape().vec();
3002     index_shape.push_back(1);
3003     index_type =
3004         RankedTensorType::get(index_shape, index_type.getElementType());
3005     index = index.reshape(index_type).cast<DenseIntElementsAttr>();
3006   }
3007 
3008   // Increment the multi-dimensional index vector based on the limits for each
3009   // dimension specified by shape and returns false if the index rolled around
3010   // with true otherwise.
3011   auto next_index = [](llvm::SmallVector<uint64_t, 8>& index,
3012                        llvm::ArrayRef<int64_t> shape) {
3013     for (int64_t i = index.size() - 1; i >= 0; --i) {
3014       ++index[i];
3015       if (index[i] < shape[i]) return true;
3016       index[i] = 0;
3017     }
3018     return false;
3019   };
3020 
3021   // Iterate over all elements of the update tensor, then find the corresponding
3022   // value in the indices tensor to determine which location we have to update
3023   // in the base/result tensor.
3024   llvm::SmallVector<Attribute, 8> results(base.getValues<Attribute>());
3025   llvm::SmallVector<uint64_t, 8> update_index(update_type.getRank(), 0);
3026   llvm::SmallVector<uint64_t, 8> index_index;
3027   index_index.reserve(index_type.getRank());
3028   llvm::SmallVector<uint64_t, 8> base_index;
3029   base_index.reserve(base_type.getRank());
3030   do {
3031     // Compute the index for the slice of the indices tensor for this update
3032     // value.
3033     index_index.clear();
3034     if (index_vector_dim == 0) index_index.push_back(0);
3035     for (int64_t i = 0; i < update_index.size(); ++i) {
3036       if (llvm::count(scatter_dimension_numbers().update_window_dims(), i) == 0)
3037         index_index.push_back(update_index[i]);
3038       if (index_index.size() == index_vector_dim) index_index.push_back(0);
3039     }
3040 
3041     // Compute the index for the given update value in the base tensor.
3042     base_index.assign(base_type.getRank(), 0);
3043     uint64_t index_count = index_type.getShape()[index_vector_dim];
3044     for (uint64_t i = 0; i < index_count; ++i) {
3045       uint64_t operand_dim = scatter_dimension_numbers()
3046                                  .scatter_dims_to_operand_dims()
3047                                  .getValue<APInt>({i})
3048                                  .getSExtValue();
3049       index_index[index_vector_dim] = i;
3050       base_index[operand_dim] +=
3051           index.getValue<APInt>(index_index).getSExtValue();
3052     }
3053     uint64_t update_window_dim_index = 0;
3054     for (uint64_t i = 0; i < base_index.size(); ++i) {
3055       if (llvm::count(scatter_dimension_numbers().inserted_window_dims(), i))
3056         continue;
3057       base_index[i] +=
3058           update_index[scatter_dimension_numbers()
3059                            .update_window_dims()
3060                            .getValue<APInt>({update_window_dim_index})
3061                            .getSExtValue()];
3062       update_window_dim_index++;
3063     }
3064 
3065     // Compute the linear index for the index into the base tensor.
3066     int64_t linear_base_index = 0;
3067     int64_t linear_base_index_multiplyer = 1;
3068     for (int64_t i = base_index.size() - 1; i >= 0; --i) {
3069       // Out of bound index have backend specific behaviour so avoid folding it.
3070       if (base_index[i] < 0 || base_index[i] >= base_type.getShape()[i])
3071         return {};
3072       linear_base_index += base_index[i] * linear_base_index_multiplyer;
3073       linear_base_index_multiplyer *= base_type.getShape()[i];
3074     }
3075 
3076     // Evaluate update computation and update the value with the newly computed
3077     // attribute in the base tensor.
3078     auto lhs = DenseElementsAttr::get(
3079         RankedTensorType::get({}, base_type.getElementType()),
3080         results[linear_base_index]);
3081     auto rhs = DenseElementsAttr::get(
3082         RankedTensorType::get({}, base_type.getElementType()),
3083         update.getValue<Attribute>(update_index));
3084     auto new_value = evaluateMhloRegion(update_computation(), {lhs, rhs});
3085     if (new_value.size() != 1 || !new_value[0]) return {};
3086     results[linear_base_index] =
3087         new_value[0].cast<DenseElementsAttr>().getValue<Attribute>({});
3088   } while (next_index(update_index, update_type.getShape()));
3089 
3090   return DenseElementsAttr::get(base_type, results);
3091 }
3092 
3093 }  // namespace mhlo
3094 }  // namespace mlir
3095 
3096 #define GET_OP_CLASSES
3097 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
3098 
3099 namespace mlir {
3100 namespace mhlo {
3101 
3102 //===----------------------------------------------------------------------===//
3103 // mhlo Dialect Interfaces
3104 //===----------------------------------------------------------------------===//
3105 
3106 namespace {
3107 struct HLOInlinerInterface : public DialectInlinerInterface {
3108   using DialectInlinerInterface::DialectInlinerInterface;
3109 
3110   // Allow all call operations to be inlined.
isLegalToInlinemlir::mhlo::__anon1950c2d01211::HLOInlinerInterface3111   bool isLegalToInline(Operation* call, Operation* callable,
3112                        bool wouldBeCloned) const final {
3113     return true;
3114   }
3115   // We don't have any special restrictions on what can be inlined into
3116   // destination regions (e.g. while/conditional bodies). Always allow it.
isLegalToInlinemlir::mhlo::__anon1950c2d01211::HLOInlinerInterface3117   bool isLegalToInline(Region* dest, Region* src, bool wouldBeCloned,
3118                        BlockAndValueMapping& valueMapping) const final {
3119     return true;
3120   }
3121   // Operations in mhlo dialect are always legal to inline since they are
3122   // pure.
isLegalToInlinemlir::mhlo::__anon1950c2d01211::HLOInlinerInterface3123   bool isLegalToInline(Operation*, Region*, bool,
3124                        BlockAndValueMapping&) const final {
3125     return true;
3126   }
3127 };
3128 }  // end anonymous namespace
3129 
3130 //===----------------------------------------------------------------------===//
3131 // mhlo Dialect Constructor
3132 //===----------------------------------------------------------------------===//
3133 
// Constructs the mhlo dialect in `context` and registers its operations, its
// inliner interface and its token type.
MhloDialect::MhloDialect(MLIRContext* context)
    : Dialect(getDialectNamespace(), context, TypeID::get<MhloDialect>()) {
  // With GET_OP_LIST defined, the included generated file supplies the
  // template argument list of addOperations.
  addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
      >();
  addInterfaces<HLOInlinerInterface>();
  addTypes<TokenType>();
  // Presumably needed because code in this file creates tensor dialect ops
  // (e.g. tensor::FromElementsOp in deriveShapeFromFirstOperand).
  context->loadDialect<tensor::TensorDialect>();
}
3144 
parseType(DialectAsmParser & parser) const3145 Type MhloDialect::parseType(DialectAsmParser& parser) const {
3146   StringRef data_type;
3147   if (parser.parseKeyword(&data_type)) return Type();
3148 
3149   if (data_type == "token") return TokenType::get(getContext());
3150   parser.emitError(parser.getNameLoc()) << "unknown mhlo type: " << data_type;
3151   return nullptr;
3152 }
3153 
printType(Type type,DialectAsmPrinter & os) const3154 void MhloDialect::printType(Type type, DialectAsmPrinter& os) const {
3155   if (type.isa<TokenType>()) {
3156     os << "token";
3157     return;
3158   }
3159   os << "<unknown mhlo type>";
3160 }
3161 
3162 //===----------------------------------------------------------------------===//
3163 // Shape inference
3164 //===----------------------------------------------------------------------===//
3165 
deriveShapeFromFirstOperand(OpBuilder * builder,Operation * op,SmallVectorImpl<Value> * reifiedReturnShapes)3166 LogicalResult deriveShapeFromFirstOperand(
3167     OpBuilder* builder, Operation* op,
3168     SmallVectorImpl<Value>* reifiedReturnShapes) {
3169   Value operand = op->getOperand(0);
3170   ShapedType operand_type = operand.getType().dyn_cast<ShapedType>();
3171   if (!operand_type) {
3172     op->emitOpError() << "first operand is not a shaped type";
3173     return failure();
3174   }
3175   auto loc = op->getLoc();
3176   SmallVector<Value, 4> shape_values;
3177   shape_values.reserve(operand_type.getRank());
3178   auto shape_scalar_type = builder->getIntegerType(64);
3179   for (auto element : llvm::enumerate(operand_type.getShape())) {
3180     if (element.value() == ShapedType::kDynamicSize) {
3181       Value dim = builder->create<DimOp>(loc, operand, element.index());
3182       shape_values.push_back(
3183           builder->create<IndexCastOp>(loc, dim, shape_scalar_type));
3184     } else {
3185       shape_values.push_back(builder->create<ConstantOp>(
3186           loc, builder->getI64IntegerAttr(element.value())));
3187     }
3188   }
3189   *reifiedReturnShapes = SmallVector<Value, 1>{
3190       builder->create<tensor::FromElementsOp>(loc, shape_values)};
3191   return success();
3192 }
3193 
3194 }  // namespace mhlo
3195 }  // namespace mlir
3196