1 //===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares a generic pass for converting between MLIR dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
14 #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
15 
16 #include "mlir/Rewrite/FrozenRewritePatternList.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/StringMap.h"
19 
20 namespace mlir {
21 
22 // Forward declarations.
23 class Block;
24 class ConversionPatternRewriter;
25 class FuncOp;
26 class MLIRContext;
27 class Operation;
28 class Type;
29 class Value;
30 
31 //===----------------------------------------------------------------------===//
32 // Type Conversion
33 //===----------------------------------------------------------------------===//
34 
35 /// Type conversion class. Specific conversions and materializations can be
36 /// registered using addConversion and addMaterialization, respectively.
37 class TypeConverter {
38 public:
39   /// This class provides all of the information necessary to convert a type
40   /// signature.
41   class SignatureConversion {
42   public:
SignatureConversion(unsigned numOrigInputs)43     SignatureConversion(unsigned numOrigInputs)
44         : remappedInputs(numOrigInputs) {}
45 
46     /// This struct represents a range of new types or a single value that
47     /// remaps an existing signature input.
48     struct InputMapping {
49       size_t inputNo, size;
50       Value replacementValue;
51     };
52 
53     /// Return the argument types for the new signature.
getConvertedTypes()54     ArrayRef<Type> getConvertedTypes() const { return argTypes; }
55 
56     /// Get the input mapping for the given argument.
getInputMapping(unsigned input)57     Optional<InputMapping> getInputMapping(unsigned input) const {
58       return remappedInputs[input];
59     }
60 
61     //===------------------------------------------------------------------===//
62     // Conversion Hooks
63     //===------------------------------------------------------------------===//
64 
65     /// Remap an input of the original signature with a new set of types. The
66     /// new types are appended to the new signature conversion.
67     void addInputs(unsigned origInputNo, ArrayRef<Type> types);
68 
69     /// Append new input types to the signature conversion, this should only be
70     /// used if the new types are not intended to remap an existing input.
71     void addInputs(ArrayRef<Type> types);
72 
73     /// Remap an input of the original signature to another `replacement`
74     /// value. This drops the original argument.
75     void remapInput(unsigned origInputNo, Value replacement);
76 
77   private:
78     /// Remap an input of the original signature with a range of types in the
79     /// new signature.
80     void remapInput(unsigned origInputNo, unsigned newInputNo,
81                     unsigned newInputCount = 1);
82 
83     /// The remapping information for each of the original arguments.
84     SmallVector<Optional<InputMapping>, 4> remappedInputs;
85 
86     /// The set of new argument types.
87     SmallVector<Type, 4> argTypes;
88   };
89 
90   /// Register a conversion function. A conversion function must be convertible
91   /// to any of the following forms(where `T` is a class derived from `Type`:
92   ///   * Optional<Type>(T)
93   ///     - This form represents a 1-1 type conversion. It should return nullptr
94   ///       or `llvm::None` to signify failure. If `llvm::None` is returned, the
95   ///       converter is allowed to try another conversion function to perform
96   ///       the conversion.
97   ///   * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
98   ///     - This form represents a 1-N type conversion. It should return
99   ///       `failure` or `llvm::None` to signify a failed conversion. If the new
100   ///       set of types is empty, the type is removed and any usages of the
101   ///       existing value are expected to be removed during conversion. If
102   ///       `llvm::None` is returned, the converter is allowed to try another
103   ///       conversion function to perform the conversion.
104   /// Note: When attempting to convert a type, e.g. via 'convertType', the
105   ///       mostly recently added conversions will be invoked first.
106   template <typename FnT,
107             typename T = typename llvm::function_traits<FnT>::template arg_t<0>>
addConversion(FnT && callback)108   void addConversion(FnT &&callback) {
109     registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
110   }
111 
112   /// Register a materialization function, which must be convertible to the
113   /// following form:
114   ///   `Optional<Value>(OpBuilder &, T, ValueRange, Location)`,
115   /// where `T` is any subclass of `Type`. This function is responsible for
116   /// creating an operation, using the OpBuilder and Location provided, that
117   /// "casts" a range of values into a single value of the given type `T`. It
118   /// must return a Value of the converted type on success, an `llvm::None` if
119   /// it failed but other materialization can be attempted, and `nullptr` on
120   /// unrecoverable failure. It will only be called for (sub)types of `T`.
121   /// Materialization functions must be provided when a type conversion
122   /// results in more than one type, or if a type conversion may persist after
123   /// the conversion has finished.
124   ///
125   /// This method registers a materialization that will be called when
126   /// converting an illegal block argument type, to a legal type.
127   template <typename FnT,
128             typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
addArgumentMaterialization(FnT && callback)129   void addArgumentMaterialization(FnT &&callback) {
130     argumentMaterializations.emplace_back(
131         wrapMaterialization<T>(std::forward<FnT>(callback)));
132   }
133   /// This method registers a materialization that will be called when
134   /// converting a legal type to an illegal source type. This is used when
135   /// conversions to an illegal type must persist beyond the main conversion.
136   template <typename FnT,
137             typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
addSourceMaterialization(FnT && callback)138   void addSourceMaterialization(FnT &&callback) {
139     sourceMaterializations.emplace_back(
140         wrapMaterialization<T>(std::forward<FnT>(callback)));
141   }
142   /// This method registers a materialization that will be called when
143   /// converting type from an illegal, or source, type to a legal type.
144   template <typename FnT,
145             typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
addTargetMaterialization(FnT && callback)146   void addTargetMaterialization(FnT &&callback) {
147     targetMaterializations.emplace_back(
148         wrapMaterialization<T>(std::forward<FnT>(callback)));
149   }
150 
151   /// Convert the given type. This function should return failure if no valid
152   /// conversion exists, success otherwise. If the new set of types is empty,
153   /// the type is removed and any usages of the existing value are expected to
154   /// be removed during conversion.
155   LogicalResult convertType(Type t, SmallVectorImpl<Type> &results);
156 
157   /// This hook simplifies defining 1-1 type conversions. This function returns
158   /// the type to convert to on success, and a null type on failure.
159   Type convertType(Type t);
160 
161   /// Convert the given set of types, filling 'results' as necessary. This
162   /// returns failure if the conversion of any of the types fails, success
163   /// otherwise.
164   LogicalResult convertTypes(ArrayRef<Type> types,
165                              SmallVectorImpl<Type> &results);
166 
167   /// Return true if the given type is legal for this type converter, i.e. the
168   /// type converts to itself.
169   bool isLegal(Type type);
170   /// Return true if all of the given types are legal for this type converter.
171   template <typename RangeT>
172   std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
173                        !std::is_convertible<RangeT, Operation *>::value,
174                    bool>
isLegal(RangeT && range)175   isLegal(RangeT &&range) {
176     return llvm::all_of(range, [this](Type type) { return isLegal(type); });
177   }
178   /// Return true if the given operation has legal operand and result types.
179   bool isLegal(Operation *op);
180 
181   /// Return true if the types of block arguments within the region are legal.
182   bool isLegal(Region *region);
183 
184   /// Return true if the inputs and outputs of the given function type are
185   /// legal.
186   bool isSignatureLegal(FunctionType ty);
187 
188   /// This method allows for converting a specific argument of a signature. It
189   /// takes as inputs the original argument input number, type.
190   /// On success, it populates 'result' with any new mappings.
191   LogicalResult convertSignatureArg(unsigned inputNo, Type type,
192                                     SignatureConversion &result);
193   LogicalResult convertSignatureArgs(TypeRange types,
194                                      SignatureConversion &result,
195                                      unsigned origInputOffset = 0);
196 
197   /// This function converts the type signature of the given block, by invoking
198   /// 'convertSignatureArg' for each argument. This function should return a
199   /// valid conversion for the signature on success, None otherwise.
200   Optional<SignatureConversion> convertBlockSignature(Block *block);
201 
202   /// Materialize a conversion from a set of types into one result type by
203   /// generating a cast sequence of some kind. See the respective
204   /// `add*Materialization` for more information on the context for these
205   /// methods.
materializeArgumentConversion(OpBuilder & builder,Location loc,Type resultType,ValueRange inputs)206   Value materializeArgumentConversion(OpBuilder &builder, Location loc,
207                                       Type resultType, ValueRange inputs) {
208     return materializeConversion(argumentMaterializations, builder, loc,
209                                  resultType, inputs);
210   }
materializeSourceConversion(OpBuilder & builder,Location loc,Type resultType,ValueRange inputs)211   Value materializeSourceConversion(OpBuilder &builder, Location loc,
212                                     Type resultType, ValueRange inputs) {
213     return materializeConversion(sourceMaterializations, builder, loc,
214                                  resultType, inputs);
215   }
materializeTargetConversion(OpBuilder & builder,Location loc,Type resultType,ValueRange inputs)216   Value materializeTargetConversion(OpBuilder &builder, Location loc,
217                                     Type resultType, ValueRange inputs) {
218     return materializeConversion(targetMaterializations, builder, loc,
219                                  resultType, inputs);
220   }
221 
222 private:
223   /// The signature of the callback used to convert a type. If the new set of
224   /// types is empty, the type is removed and any usages of the existing value
225   /// are expected to be removed during conversion.
226   using ConversionCallbackFn =
227       std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>;
228 
229   /// The signature of the callback used to materialize a conversion.
230   using MaterializationCallbackFn =
231       std::function<Optional<Value>(OpBuilder &, Type, ValueRange, Location)>;
232 
233   /// Attempt to materialize a conversion using one of the provided
234   /// materialization functions.
235   Value materializeConversion(
236       MutableArrayRef<MaterializationCallbackFn> materializations,
237       OpBuilder &builder, Location loc, Type resultType, ValueRange inputs);
238 
239   /// Generate a wrapper for the given callback. This allows for accepting
240   /// different callback forms, that all compose into a single version.
241   /// With callback of form: `Optional<Type>(T)`
242   template <typename T, typename FnT>
243   std::enable_if_t<llvm::is_invocable<FnT, T>::value, ConversionCallbackFn>
wrapCallback(FnT && callback)244   wrapCallback(FnT &&callback) {
245     return wrapCallback<T>([callback = std::forward<FnT>(callback)](
246                                T type, SmallVectorImpl<Type> &results) {
247       if (Optional<Type> resultOpt = callback(type)) {
248         bool wasSuccess = static_cast<bool>(resultOpt.getValue());
249         if (wasSuccess)
250           results.push_back(resultOpt.getValue());
251         return Optional<LogicalResult>(success(wasSuccess));
252       }
253       return Optional<LogicalResult>();
254     });
255   }
256   /// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<> &)`
257   template <typename T, typename FnT>
258   std::enable_if_t<!llvm::is_invocable<FnT, T>::value, ConversionCallbackFn>
wrapCallback(FnT && callback)259   wrapCallback(FnT &&callback) {
260     return [callback = std::forward<FnT>(callback)](
261                Type type,
262                SmallVectorImpl<Type> &results) -> Optional<LogicalResult> {
263       T derivedType = type.dyn_cast<T>();
264       if (!derivedType)
265         return llvm::None;
266       return callback(derivedType, results);
267     };
268   }
269 
270   /// Register a type conversion.
registerConversion(ConversionCallbackFn callback)271   void registerConversion(ConversionCallbackFn callback) {
272     conversions.emplace_back(std::move(callback));
273     cachedDirectConversions.clear();
274     cachedMultiConversions.clear();
275   }
276 
277   /// Generate a wrapper for the given materialization callback. The callback
278   /// may take any subclass of `Type` and the wrapper will check for the target
279   /// type to be of the expected class before calling the callback.
280   template <typename T, typename FnT>
wrapMaterialization(FnT && callback)281   MaterializationCallbackFn wrapMaterialization(FnT &&callback) {
282     return [callback = std::forward<FnT>(callback)](
283                OpBuilder &builder, Type resultType, ValueRange inputs,
284                Location loc) -> Optional<Value> {
285       if (T derivedType = resultType.dyn_cast<T>())
286         return callback(builder, derivedType, inputs, loc);
287       return llvm::None;
288     };
289   }
290 
291   /// The set of registered conversion functions.
292   SmallVector<ConversionCallbackFn, 4> conversions;
293 
294   /// The list of registered materialization functions.
295   SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
296   SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
297   SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
298 
299   /// A set of cached conversions to avoid recomputing in the common case.
300   /// Direct 1-1 conversions are the most common, so this cache stores the
301   /// successful 1-1 conversions as well as all failed conversions.
302   DenseMap<Type, Type> cachedDirectConversions;
303   /// This cache stores the successful 1->N conversions, where N != 1.
304   DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
305 };
306 
307 //===----------------------------------------------------------------------===//
308 // Conversion Patterns
309 //===----------------------------------------------------------------------===//
310 
311 /// Base class for the conversion patterns. This pattern class enables type
312 /// conversions, and other uses specific to the conversion framework. As such,
313 /// patterns of this type can only be used with the 'apply*' methods below.
314 class ConversionPattern : public RewritePattern {
315 public:
316   /// Hook for derived classes to implement rewriting. `op` is the (first)
317   /// operation matched by the pattern, `operands` is a list of the rewritten
318   /// operand values that are passed to `op`, `rewriter` can be used to emit the
319   /// new operations. This function should not fail. If some specific cases of
320   /// the operation are not supported, these cases should not be matched.
rewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)321   virtual void rewrite(Operation *op, ArrayRef<Value> operands,
322                        ConversionPatternRewriter &rewriter) const {
323     llvm_unreachable("unimplemented rewrite");
324   }
325 
326   /// Hook for derived classes to implement combined matching and rewriting.
327   virtual LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)328   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
329                   ConversionPatternRewriter &rewriter) const {
330     if (failed(match(op)))
331       return failure();
332     rewrite(op, operands, rewriter);
333     return success();
334   }
335 
336   /// Attempt to match and rewrite the IR root at the specified operation.
337   LogicalResult matchAndRewrite(Operation *op,
338                                 PatternRewriter &rewriter) const final;
339 
340   /// Return the type converter held by this pattern, or nullptr if the pattern
341   /// does not require type conversion.
getTypeConverter()342   TypeConverter *getTypeConverter() const { return typeConverter; }
343 
344 protected:
345   /// See `RewritePattern::RewritePattern` for information on the other
346   /// available constructors.
347   using RewritePattern::RewritePattern;
348   /// Construct a conversion pattern that matches an operation with the given
349   /// root name. This constructor allows for providing a type converter to use
350   /// within the pattern.
ConversionPattern(StringRef rootName,PatternBenefit benefit,TypeConverter & typeConverter,MLIRContext * ctx)351   ConversionPattern(StringRef rootName, PatternBenefit benefit,
352                     TypeConverter &typeConverter, MLIRContext *ctx)
353       : RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {}
354   /// Construct a conversion pattern that matches any operation type. This
355   /// constructor allows for providing a type converter to use within the
356   /// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
357   /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
358   /// always be supplied here.
ConversionPattern(PatternBenefit benefit,TypeConverter & typeConverter,MatchAnyOpTypeTag tag)359   ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter,
360                     MatchAnyOpTypeTag tag)
361       : RewritePattern(benefit, tag), typeConverter(&typeConverter) {}
362 
363 protected:
364   /// An optional type converter for use by this pattern.
365   TypeConverter *typeConverter = nullptr;
366 
367 private:
368   using RewritePattern::rewrite;
369 };
370 
371 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
372 /// matching and rewriting against an instance of a derived operation class as
373 /// opposed to a raw Operation.
374 template <typename SourceOp>
375 struct OpConversionPattern : public ConversionPattern {
376   OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
ConversionPatternOpConversionPattern377       : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
378   OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
379                       PatternBenefit benefit = 1)
ConversionPatternOpConversionPattern380       : ConversionPattern(SourceOp::getOperationName(), benefit, typeConverter,
381                           context) {}
382 
383   /// Wrappers around the ConversionPattern methods that pass the derived op
384   /// type.
rewriteOpConversionPattern385   void rewrite(Operation *op, ArrayRef<Value> operands,
386                ConversionPatternRewriter &rewriter) const final {
387     rewrite(cast<SourceOp>(op), operands, rewriter);
388   }
389   LogicalResult
matchAndRewriteOpConversionPattern390   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
391                   ConversionPatternRewriter &rewriter) const final {
392     return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
393   }
394 
395   // TODO: Use OperandAdaptor when it supports access to unnamed operands.
396 
397   /// Rewrite and Match methods that operate on the SourceOp type. These must be
398   /// overridden by the derived pattern class.
rewriteOpConversionPattern399   virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
400                        ConversionPatternRewriter &rewriter) const {
401     llvm_unreachable("must override matchAndRewrite or a rewrite method");
402   }
403 
404   virtual LogicalResult
matchAndRewriteOpConversionPattern405   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
406                   ConversionPatternRewriter &rewriter) const {
407     if (failed(match(op)))
408       return failure();
409     rewrite(op, operands, rewriter);
410     return success();
411   }
412 
413 private:
414   using ConversionPattern::matchAndRewrite;
415 };
416 
417 /// Add a pattern to the given pattern list to convert the signature of a FuncOp
418 /// with the given type converter.
419 void populateFuncOpTypeConversionPattern(OwningRewritePatternList &patterns,
420                                          MLIRContext *ctx,
421                                          TypeConverter &converter);
422 
423 //===----------------------------------------------------------------------===//
424 // Conversion PatternRewriter
425 //===----------------------------------------------------------------------===//
426 
427 namespace detail {
428 struct ConversionPatternRewriterImpl;
429 } // end namespace detail
430 
431 /// This class implements a pattern rewriter for use with ConversionPatterns. It
432 /// extends the base PatternRewriter and provides special conversion specific
433 /// hooks.
434 class ConversionPatternRewriter final : public PatternRewriter {
435 public:
436   ConversionPatternRewriter(MLIRContext *ctx);
437   ~ConversionPatternRewriter() override;
438 
439   /// Apply a signature conversion to the entry block of the given region. This
440   /// replaces the entry block with a new block containing the updated
441   /// signature. The new entry block to the region is returned for convenience.
442   Block *
443   applySignatureConversion(Region *region,
444                            TypeConverter::SignatureConversion &conversion);
445 
446   /// Convert the types of block arguments within the given region. This
447   /// replaces each block with a new block containing the updated signature. The
448   /// entry block may have a special conversion if `entryConversion` is
449   /// provided. On success, the new entry block to the region is returned for
450   /// convenience. Otherwise, failure is returned.
451   FailureOr<Block *> convertRegionTypes(
452       Region *region, TypeConverter &converter,
453       TypeConverter::SignatureConversion *entryConversion = nullptr);
454 
455   /// Replace all the uses of the block argument `from` with value `to`.
456   void replaceUsesOfBlockArgument(BlockArgument from, Value to);
457 
458   /// Return the converted value that replaces 'key'. Return 'key' if there is
459   /// no such a converted value.
460   Value getRemappedValue(Value key);
461 
462   //===--------------------------------------------------------------------===//
463   // PatternRewriter Hooks
464   //===--------------------------------------------------------------------===//
465 
466   /// PatternRewriter hook for replacing the results of an operation.
467   void replaceOp(Operation *op, ValueRange newValues) override;
468   using PatternRewriter::replaceOp;
469 
470   /// PatternRewriter hook for erasing a dead operation. The uses of this
471   /// operation *must* be made dead by the end of the conversion process,
472   /// otherwise an assert will be issued.
473   void eraseOp(Operation *op) override;
474 
475   /// PatternRewriter hook for erase all operations in a block. This is not yet
476   /// implemented for dialect conversion.
477   void eraseBlock(Block *block) override;
478 
479   /// PatternRewriter hook creating a new block.
480   void notifyBlockCreated(Block *block) override;
481 
482   /// PatternRewriter hook for splitting a block into two parts.
483   Block *splitBlock(Block *block, Block::iterator before) override;
484 
485   /// PatternRewriter hook for merging a block into another.
486   void mergeBlocks(Block *source, Block *dest, ValueRange argValues) override;
487 
488   /// PatternRewriter hook for moving blocks out of a region.
489   void inlineRegionBefore(Region &region, Region &parent,
490                           Region::iterator before) override;
491   using PatternRewriter::inlineRegionBefore;
492 
493   /// PatternRewriter hook for cloning blocks of one region into another. The
494   /// given region to clone *must* not have been modified as part of conversion
495   /// yet, i.e. it must be within an operation that is either in the process of
496   /// conversion, or has not yet been converted.
497   void cloneRegionBefore(Region &region, Region &parent,
498                          Region::iterator before,
499                          BlockAndValueMapping &mapping) override;
500   using PatternRewriter::cloneRegionBefore;
501 
502   /// PatternRewriter hook for inserting a new operation.
503   void notifyOperationInserted(Operation *op) override;
504 
505   /// PatternRewriter hook for updating the root operation in-place.
506   /// Note: These methods only track updates to the top-level operation itself,
507   /// and not nested regions. Updates to regions will still require notification
508   /// through other more specific hooks above.
509   void startRootUpdate(Operation *op) override;
510 
511   /// PatternRewriter hook for updating the root operation in-place.
512   void finalizeRootUpdate(Operation *op) override;
513 
514   /// PatternRewriter hook for updating the root operation in-place.
515   void cancelRootUpdate(Operation *op) override;
516 
517   /// PatternRewriter hook for notifying match failure reasons.
518   LogicalResult
519   notifyMatchFailure(Operation *op,
520                      function_ref<void(Diagnostic &)> reasonCallback) override;
521   using PatternRewriter::notifyMatchFailure;
522 
523   /// Return a reference to the internal implementation.
524   detail::ConversionPatternRewriterImpl &getImpl();
525 
526 private:
527   std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
528 };
529 
530 //===----------------------------------------------------------------------===//
531 // ConversionTarget
532 //===----------------------------------------------------------------------===//
533 
534 /// This class describes a specific conversion target.
535 class ConversionTarget {
536 public:
537   /// This enumeration corresponds to the specific action to take when
538   /// considering an operation legal for this conversion target.
539   enum class LegalizationAction {
540     /// The target supports this operation.
541     Legal,
542 
543     /// This operation has dynamic legalization constraints that must be checked
544     /// by the target.
545     Dynamic,
546 
547     /// The target explicitly does not support this operation.
548     Illegal,
549   };
550 
551   /// A structure containing additional information describing a specific legal
552   /// operation instance.
553   struct LegalOpDetails {
554     /// A flag that indicates if this operation is 'recursively' legal. This
555     /// means that if an operation is legal, either statically or dynamically,
556     /// all of the operations nested within are also considered legal.
557     bool isRecursivelyLegal = false;
558   };
559 
560   /// The signature of the callback used to determine if an operation is
561   /// dynamically legal on the target.
562   using DynamicLegalityCallbackFn = std::function<bool(Operation *)>;
563 
ConversionTarget(MLIRContext & ctx)564   ConversionTarget(MLIRContext &ctx)
565       : unknownOpsDynamicallyLegal(false), ctx(ctx) {}
566   virtual ~ConversionTarget() = default;
567 
568   //===--------------------------------------------------------------------===//
569   // Legality Registration
570   //===--------------------------------------------------------------------===//
571 
572   /// Register a legality action for the given operation.
573   void setOpAction(OperationName op, LegalizationAction action);
setOpAction(LegalizationAction action)574   template <typename OpT> void setOpAction(LegalizationAction action) {
575     setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
576   }
577 
578   /// Register the given operations as legal.
addLegalOp()579   template <typename OpT> void addLegalOp() {
580     setOpAction<OpT>(LegalizationAction::Legal);
581   }
addLegalOp()582   template <typename OpT, typename OpT2, typename... OpTs> void addLegalOp() {
583     addLegalOp<OpT>();
584     addLegalOp<OpT2, OpTs...>();
585   }
586 
587   /// Register the given operation as dynamically legal, i.e. requiring custom
588   /// handling by the target via 'isDynamicallyLegal'.
addDynamicallyLegalOp()589   template <typename OpT> void addDynamicallyLegalOp() {
590     setOpAction<OpT>(LegalizationAction::Dynamic);
591   }
592   template <typename OpT, typename OpT2, typename... OpTs>
addDynamicallyLegalOp()593   void addDynamicallyLegalOp() {
594     addDynamicallyLegalOp<OpT>();
595     addDynamicallyLegalOp<OpT2, OpTs...>();
596   }
597 
598   /// Register the given operation as dynamically legal and set the dynamic
599   /// legalization callback to the one provided.
600   template <typename OpT>
addDynamicallyLegalOp(const DynamicLegalityCallbackFn & callback)601   void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
602     OperationName opName(OpT::getOperationName(), &ctx);
603     setOpAction(opName, LegalizationAction::Dynamic);
604     setLegalityCallback(opName, callback);
605   }
606   template <typename OpT, typename OpT2, typename... OpTs>
addDynamicallyLegalOp(const DynamicLegalityCallbackFn & callback)607   void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
608     addDynamicallyLegalOp<OpT>(callback);
609     addDynamicallyLegalOp<OpT2, OpTs...>(callback);
610   }
611   template <typename OpT, class Callable>
612   typename std::enable_if<
613       !llvm::is_invocable<Callable, Operation *>::value>::type
addDynamicallyLegalOp(Callable && callback)614   addDynamicallyLegalOp(Callable &&callback) {
615     addDynamicallyLegalOp<OpT>(
616         [=](Operation *op) { return callback(cast<OpT>(op)); });
617   }
618 
619   /// Register the given operation as illegal, i.e. this operation is known to
620   /// not be supported by this target.
addIllegalOp()621   template <typename OpT> void addIllegalOp() {
622     setOpAction<OpT>(LegalizationAction::Illegal);
623   }
addIllegalOp()624   template <typename OpT, typename OpT2, typename... OpTs> void addIllegalOp() {
625     addIllegalOp<OpT>();
626     addIllegalOp<OpT2, OpTs...>();
627   }
628 
629   /// Mark an operation, that *must* have either been set as `Legal` or
630   /// `DynamicallyLegal`, as being recursively legal. This means that in
631   /// addition to the operation itself, all of the operations nested within are
632   /// also considered legal. An optional dynamic legality callback may be
633   /// provided to mark subsets of legal instances as recursively legal.
634   template <typename OpT>
635   void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) {
636     OperationName opName(OpT::getOperationName(), &ctx);
637     markOpRecursivelyLegal(opName, callback);
638   }
639   template <typename OpT, typename OpT2, typename... OpTs>
640   void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) {
641     markOpRecursivelyLegal<OpT>(callback);
642     markOpRecursivelyLegal<OpT2, OpTs...>(callback);
643   }
644   template <typename OpT, class Callable>
645   typename std::enable_if<
646       !llvm::is_invocable<Callable, Operation *>::value>::type
markOpRecursivelyLegal(Callable && callback)647   markOpRecursivelyLegal(Callable &&callback) {
648     markOpRecursivelyLegal<OpT>(
649         [=](Operation *op) { return callback(cast<OpT>(op)); });
650   }
651 
652   /// Register a legality action for the given dialects.
653   void setDialectAction(ArrayRef<StringRef> dialectNames,
654                         LegalizationAction action);
655 
656   /// Register the operations of the given dialects as legal.
657   template <typename... Names>
addLegalDialect(StringRef name,Names...names)658   void addLegalDialect(StringRef name, Names... names) {
659     SmallVector<StringRef, 2> dialectNames({name, names...});
660     setDialectAction(dialectNames, LegalizationAction::Legal);
661   }
addLegalDialect()662   template <typename... Args> void addLegalDialect() {
663     SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
664     setDialectAction(dialectNames, LegalizationAction::Legal);
665   }
666 
667   /// Register the operations of the given dialects as dynamically legal, i.e.
668   /// requiring custom handling by the target via 'isDynamicallyLegal'.
669   template <typename... Names>
addDynamicallyLegalDialect(StringRef name,Names...names)670   void addDynamicallyLegalDialect(StringRef name, Names... names) {
671     SmallVector<StringRef, 2> dialectNames({name, names...});
672     setDialectAction(dialectNames, LegalizationAction::Dynamic);
673   }
674   template <typename... Args>
675   void addDynamicallyLegalDialect(
676       Optional<DynamicLegalityCallbackFn> callback = llvm::None) {
677     SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
678     setDialectAction(dialectNames, LegalizationAction::Dynamic);
679     if (callback)
680       setLegalityCallback(dialectNames, *callback);
681   }
682   template <typename... Args>
addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback)683   void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback) {
684     SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
685     setDialectAction(dialectNames, LegalizationAction::Dynamic);
686     setLegalityCallback(dialectNames, callback);
687   }
688 
  /// Register unknown operations as dynamically legal. Operations (and
  /// dialects) that do not have a set legalization action are treated as
  /// dynamically legal, and the given callback is invoked if one was
  /// provided; otherwise 'isDynamicallyLegal' is used.
markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn & fn)693   void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn) {
694     unknownOpsDynamicallyLegal = true;
695     unknownLegalityFn = fn;
696   }
markUnknownOpDynamicallyLegal()697   void markUnknownOpDynamicallyLegal() { unknownOpsDynamicallyLegal = true; }
698 
699   /// Register the operations of the given dialects as illegal, i.e.
700   /// operations of this dialect are not supported by the target.
701   template <typename... Names>
addIllegalDialect(StringRef name,Names...names)702   void addIllegalDialect(StringRef name, Names... names) {
703     SmallVector<StringRef, 2> dialectNames({name, names...});
704     setDialectAction(dialectNames, LegalizationAction::Illegal);
705   }
addIllegalDialect()706   template <typename... Args> void addIllegalDialect() {
707     SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
708     setDialectAction(dialectNames, LegalizationAction::Illegal);
709   }
710 
711   //===--------------------------------------------------------------------===//
712   // Legality Querying
713   //===--------------------------------------------------------------------===//
714 
715   /// Get the legality action for the given operation.
716   Optional<LegalizationAction> getOpAction(OperationName op) const;
717 
718   /// If the given operation instance is legal on this target, a structure
719   /// containing legality information is returned. If the operation is not
720   /// legal, None is returned.
721   Optional<LegalOpDetails> isLegal(Operation *op) const;
722 
723 protected:
724   /// Runs a custom legalization query for the given operation. This should
725   /// return true if the given operation is legal, otherwise false.
isDynamicallyLegal(Operation * op)726   virtual bool isDynamicallyLegal(Operation *op) const {
727     llvm_unreachable(
728         "targets with custom legalization must override 'isDynamicallyLegal'");
729   }
730 
731 private:
732   /// Set the dynamic legality callback for the given operation.
733   void setLegalityCallback(OperationName name,
734                            const DynamicLegalityCallbackFn &callback);
735 
736   /// Set the dynamic legality callback for the given dialects.
737   void setLegalityCallback(ArrayRef<StringRef> dialects,
738                            const DynamicLegalityCallbackFn &callback);
739 
740   /// Set the recursive legality callback for the given operation and mark the
741   /// operation as recursively legal.
742   void markOpRecursivelyLegal(OperationName name,
743                               const DynamicLegalityCallbackFn &callback);
744 
745   /// The set of information that configures the legalization of an operation.
746   struct LegalizationInfo {
747     /// The legality action this operation was given.
748     LegalizationAction action;
749 
750     /// If some legal instances of this operation may also be recursively legal.
751     bool isRecursivelyLegal;
752 
753     /// The legality callback if this operation is dynamically legal.
754     Optional<DynamicLegalityCallbackFn> legalityFn;
755   };
756 
757   /// Get the legalization information for the given operation.
758   Optional<LegalizationInfo> getOpInfo(OperationName op) const;
759 
760   /// A deterministic mapping of operation name and its respective legality
761   /// information.
762   llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
763 
764   /// A set of legality callbacks for given operation names that are used to
765   /// check if an operation instance is recursively legal.
766   DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
767 
768   /// A deterministic mapping of dialect name to the specific legality action to
769   /// take.
770   llvm::StringMap<LegalizationAction> legalDialects;
771 
772   /// A set of dynamic legality callbacks for given dialect names.
773   llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
774 
775   /// An optional legality callback for unknown operations.
776   Optional<DynamicLegalityCallbackFn> unknownLegalityFn;
777 
778   /// Flag indicating if unknown operations should be treated as dynamically
779   /// legal.
780   bool unknownOpsDynamicallyLegal;
781 
782   /// The current context this target applies to.
783   MLIRContext &ctx;
784 };
785 
786 //===----------------------------------------------------------------------===//
787 // Op Conversion Entry Points
788 //===----------------------------------------------------------------------===//
789 
790 /// Below we define several entry points for operation conversion. It is
791 /// important to note that the patterns provided to the conversion framework may
792 /// have additional constraints. See the `PatternRewriter Hooks` section of the
793 /// ConversionPatternRewriter, to see what additional constraints are imposed on
794 /// the use of the PatternRewriter.
795 
796 /// Apply a partial conversion on the given operations and all nested
797 /// operations. This method converts as many operations to the target as
798 /// possible, ignoring operations that failed to legalize. This method only
/// returns failure if there are ops explicitly marked as illegal. If an
800 /// `unconvertedOps` set is provided, all operations that are found not to be
801 /// legalizable to the given `target` are placed within that set. (Note that if
802 /// there is an op explicitly marked as illegal, the conversion terminates and
803 /// the `unconvertedOps` set will not necessarily be complete.)
804 LLVM_NODISCARD LogicalResult
805 applyPartialConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
806                        const FrozenRewritePatternList &patterns,
807                        DenseSet<Operation *> *unconvertedOps = nullptr);
808 LLVM_NODISCARD LogicalResult
809 applyPartialConversion(Operation *op, ConversionTarget &target,
810                        const FrozenRewritePatternList &patterns,
811                        DenseSet<Operation *> *unconvertedOps = nullptr);
812 
813 /// Apply a complete conversion on the given operations, and all nested
814 /// operations. This method returns failure if the conversion of any operation
815 /// fails, or if there are unreachable blocks in any of the regions nested
816 /// within 'ops'.
817 LLVM_NODISCARD LogicalResult
818 applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
819                     const FrozenRewritePatternList &patterns);
820 LLVM_NODISCARD LogicalResult
821 applyFullConversion(Operation *op, ConversionTarget &target,
822                     const FrozenRewritePatternList &patterns);
823 
824 /// Apply an analysis conversion on the given operations, and all nested
825 /// operations. This method analyzes which operations would be successfully
826 /// converted to the target if a conversion was applied. All operations that
827 /// were found to be legalizable to the given 'target' are placed within the
828 /// provided 'convertedOps' set; note that no actual rewrites are applied to the
829 /// operations on success and only pre-existing operations are added to the set.
830 /// This method only returns failure if there are unreachable blocks in any of
831 /// the regions nested within 'ops'.
832 LLVM_NODISCARD LogicalResult
833 applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
834                         const FrozenRewritePatternList &patterns,
835                         DenseSet<Operation *> &convertedOps);
836 LLVM_NODISCARD LogicalResult
837 applyAnalysisConversion(Operation *op, ConversionTarget &target,
838                         const FrozenRewritePatternList &patterns,
839                         DenseSet<Operation *> &convertedOps);
840 } // end namespace mlir
841 
842 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
843