1 //===- DecomposeCallGraphTypes.h - CG type decompositions -------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Conversion patterns for decomposing types along call graph edges. That is, 10 // decomposing types for calls, returns, and function args. 11 // 12 // TODO: Make this handle dialect-defined functions, calls, and returns. 13 // Currently, the generic interfaces aren't sophisticated enough for the 14 // types of mutations that we are doing here. 15 // 16 //===----------------------------------------------------------------------===// 17 18 #ifndef MLIR_DIALECT_STANDARDOPS_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H 19 #define MLIR_DIALECT_STANDARDOPS_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H 20 21 #include "mlir/Transforms/DialectConversion.h" 22 23 namespace mlir { 24 25 /// This class provides a hook that expands one Value into multiple Value's, 26 /// with a TypeConverter-inspired callback registration mechanism. 27 /// 28 /// For folks that are familiar with the dialect conversion framework / 29 /// TypeConverter, this is effectively the inverse of a source/argument 30 /// materialization. A target materialization is not what we want here because 31 /// it always produces a single Value, but in this case the whole point is to 32 /// decompose a Value into multiple Value's. 33 /// 34 /// The reason we need this inverse is easily understood by looking at what we 35 /// need to do for decomposing types for a return op. When converting a return 36 /// op, the dialect conversion framework will give the list of converted 37 /// operands, and will ensure that each converted operand, even if it expanded 38 /// into multiple types, is materialized as a single result. We then need to 39 /// undo that materialization to a single result, which we do with the 40 /// decomposeValue hooks registered on this object. 41 /// 42 /// TODO: Eventually, the type conversion infra should have this hook built-in. 43 /// See 44 /// https://llvm.discourse.group/t/extending-type-conversion-infrastructure/779/2 45 class ValueDecomposer { 46 public: 47 /// This method tries to decompose a value of a certain type using provided 48 /// decompose callback functions. If it is unable to do so, the original value 49 /// is returned. 50 void decomposeValue(OpBuilder &, Location, Type, Value, 51 SmallVectorImpl<Value> &); 52 53 /// This method registers a callback function that will be called to decompose 54 /// a value of a certain type into 0, 1, or multiple values. 55 template <typename FnT, 56 typename T = typename llvm::function_traits<FnT>::template arg_t<2>> addDecomposeValueConversion(FnT && callback)57 void addDecomposeValueConversion(FnT &&callback) { 58 decomposeValueConversions.emplace_back( 59 wrapDecomposeValueConversionCallback<T>(std::forward<FnT>(callback))); 60 } 61 62 private: 63 using DecomposeValueConversionCallFn = std::function<Optional<LogicalResult>( 64 OpBuilder &, Location, Type, Value, SmallVectorImpl<Value> &)>; 65 66 /// Generate a wrapper for the given decompose value conversion callback. 67 template <typename T, typename FnT> 68 DecomposeValueConversionCallFn wrapDecomposeValueConversionCallback(FnT && callback)69 wrapDecomposeValueConversionCallback(FnT &&callback) { 70 return [callback = std::forward<FnT>(callback)]( 71 OpBuilder &builder, Location loc, Type type, Value value, 72 SmallVectorImpl<Value> &newValues) -> Optional<LogicalResult> { 73 if (T derivedType = type.dyn_cast<T>()) 74 return callback(builder, loc, derivedType, value, newValues); 75 return llvm::None; 76 }; 77 } 78 79 SmallVector<DecomposeValueConversionCallFn, 2> decomposeValueConversions; 80 }; 81 82 /// Populates the patterns needed to drive the conversion process for 83 /// decomposing call graph types with the given `ValueDecomposer`. 84 void populateDecomposeCallGraphTypesPatterns( 85 MLIRContext *context, TypeConverter &typeConverter, 86 ValueDecomposer &decomposer, OwningRewritePatternList &patterns); 87 88 } // end namespace mlir 89 90 #endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H 91