1 //===- DecomposeCallGraphTypes.h - CG type decompositions -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Conversion patterns for decomposing types along call graph edges. That is,
10 // decomposing types for calls, returns, and function args.
11 //
12 // TODO: Make this handle dialect-defined functions, calls, and returns.
13 // Currently, the generic interfaces aren't sophisticated enough for the
14 // types of mutations that we are doing here.
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #ifndef MLIR_DIALECT_STANDARDOPS_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
19 #define MLIR_DIALECT_STANDARDOPS_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
20 
#include "mlir/Transforms/DialectConversion.h"

#include <type_traits>
22 
23 namespace mlir {
24 
25 /// This class provides a hook that expands one Value into multiple Value's,
26 /// with a TypeConverter-inspired callback registration mechanism.
27 ///
28 /// For folks that are familiar with the dialect conversion framework /
29 /// TypeConverter, this is effectively the inverse of a source/argument
30 /// materialization. A target materialization is not what we want here because
31 /// it always produces a single Value, but in this case the whole point is to
32 /// decompose a Value into multiple Value's.
33 ///
34 /// The reason we need this inverse is easily understood by looking at what we
35 /// need to do for decomposing types for a return op. When converting a return
36 /// op, the dialect conversion framework will give the list of converted
37 /// operands, and will ensure that each converted operand, even if it expanded
38 /// into multiple types, is materialized as a single result. We then need to
39 /// undo that materialization to a single result, which we do with the
40 /// decomposeValue hooks registered on this object.
41 ///
42 /// TODO: Eventually, the type conversion infra should have this hook built-in.
43 /// See
44 /// https://llvm.discourse.group/t/extending-type-conversion-infrastructure/779/2
45 class ValueDecomposer {
46 public:
47   /// This method tries to decompose a value of a certain type using provided
48   /// decompose callback functions. If it is unable to do so, the original value
49   /// is returned.
50   void decomposeValue(OpBuilder &, Location, Type, Value,
51                       SmallVectorImpl<Value> &);
52 
53   /// This method registers a callback function that will be called to decompose
54   /// a value of a certain type into 0, 1, or multiple values.
55   template <typename FnT,
56             typename T = typename llvm::function_traits<FnT>::template arg_t<2>>
addDecomposeValueConversion(FnT && callback)57   void addDecomposeValueConversion(FnT &&callback) {
58     decomposeValueConversions.emplace_back(
59         wrapDecomposeValueConversionCallback<T>(std::forward<FnT>(callback)));
60   }
61 
62 private:
63   using DecomposeValueConversionCallFn = std::function<Optional<LogicalResult>(
64       OpBuilder &, Location, Type, Value, SmallVectorImpl<Value> &)>;
65 
66   /// Generate a wrapper for the given decompose value conversion callback.
67   template <typename T, typename FnT>
68   DecomposeValueConversionCallFn
wrapDecomposeValueConversionCallback(FnT && callback)69   wrapDecomposeValueConversionCallback(FnT &&callback) {
70     return [callback = std::forward<FnT>(callback)](
71                OpBuilder &builder, Location loc, Type type, Value value,
72                SmallVectorImpl<Value> &newValues) -> Optional<LogicalResult> {
73       if (T derivedType = type.dyn_cast<T>())
74         return callback(builder, loc, derivedType, value, newValues);
75       return llvm::None;
76     };
77   }
78 
79   SmallVector<DecomposeValueConversionCallFn, 2> decomposeValueConversions;
80 };
81 
/// Populates `patterns` with the conversion patterns needed to drive the
/// conversion process for decomposing call graph types (function arguments,
/// calls, and returns) with the given `ValueDecomposer`.
///
/// `typeConverter` is expected to describe how each original type expands
/// into its decomposed types. `decomposer` is used to split values that the
/// conversion framework has materialized as a single value back into their
/// components (e.g. the operands of a return op).
///
/// NOTE(review): the created patterns presumably keep references to
/// `typeConverter` and `decomposer`; confirm whether both must outlive
/// `patterns`.
void populateDecomposeCallGraphTypesPatterns(
    MLIRContext *context, TypeConverter &typeConverter,
    ValueDecomposer &decomposer, OwningRewritePatternList &patterns);
87 
88 } // end namespace mlir
89 
90 #endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
91