# Chapter 4: Enabling Generic Transformation with Interfaces

[TOC]

## Background: Grappling with an Extensible IR

Through dialects, MLIR allows for the representation of many different levels of
abstraction; the Toy dialect that we have previously defined is one such
example. Though these different dialects may represent different abstractions,
there is often a set of common transformations and analyses that we would like
to perform. The problem that arises is that naively implementing each
transformation for each dialect leads to large amounts of code duplication, as
the internal algorithms are generally very similar, if not the same. We would
like to provide the ability for transformations to opaquely hook into dialects
like Toy to get the information they need.

MLIR provides a set of always-available hooks for certain core transformations,
as seen in the [previous chapter](Ch-3.md), where we registered some
canonicalizations via a hook on our operations (`getCanonicalizationPatterns`).
However, these types of hooks don't really scale well. Therefore, a more generic
solution was designed, in the form of [interfaces](../../Interfaces.md), to make
the MLIR infrastructure as extensible as the representation. Interfaces provide
a generic mechanism for dialects and operations to provide information to a
transformation or analysis.

## Shape Inference: Preparing for Code Generation

Our Toy IR currently operates on generic tensors, meaning that we don't know the
shape of tensors other than during the initialization of constants. This
complicates optimizations, as well as code generation. Fortunately, we can
simply propagate the shapes through the computation until they are all known.
The issue is how to handle calls to user-defined generic functions: every call
site could deduce different shapes. One possibility would be to perform symbolic
inference based on the argument types, but this would be hard to generalize if
we were to introduce more control flow in the language. Another approach would
be function specialization, where every call site with new argument shapes
duplicates the called function and specializes it. The approach we take for Toy
is to inline all of the function calls, then perform intraprocedural shape
propagation.

### Inlining

Here we could write an inlining algorithm specifically designed for the Toy
dialect, but that can become quite complicated depending on the level of
complexity that we want. Disregarding cost modeling, the pure structural
transformation is already complex to implement from scratch. Thankfully, MLIR
provides a generic inliner algorithm that dialects can plug into. All we need to
do in Toy is to provide the [interfaces](../../Interfaces.md) for the inliner to
hook into.

The first thing we need to do is to define the constraints on inlining
operations in the Toy dialect. This information is provided through a
[dialect interface](../../Interfaces.md#dialect-interfaces). This is essentially
a class containing a set of virtual hooks which the dialect can override.
In this case, the interface is `DialectInlinerInterface`.

```c++
/// This class defines the interface for handling inlining with Toy operations.
/// We simply inherit from the base interface class and override
/// the necessary methods.
struct ToyInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// This hook checks to see if the given callable operation is legal to inline
  /// into the given call. For Toy this hook can simply return true, as the Toy
  /// Call operation is always inlinable.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }

  /// This hook checks to see if the given operation is legal to inline into the
  /// given region. For Toy this hook can simply return true, as all Toy
  /// operations are inlinable.
  bool isLegalToInline(Operation *, Region *, bool,
                       BlockAndValueMapping &) const final {
    return true;
  }

  /// This hook is called when a terminator operation has been inlined. The only
  /// terminator that we have in the Toy dialect is the return
  /// operation (toy.return). We handle the return by replacing the values
  /// previously returned by the call operation with the operands of the
  /// return.
  void handleTerminator(Operation *op,
                        ArrayRef<Value> valuesToRepl) const final {
    // Only "toy.return" needs to be handled here.
    auto returnOp = cast<ReturnOp>(op);

    // Replace the values directly with the return operands.
    assert(returnOp.getNumOperands() == valuesToRepl.size());
    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
      valuesToRepl[it.index()].replaceAllUsesWith(it.value());
  }
};
```

We then register our dialect interface directly on the Toy dialect, similarly to
how we did for operations.

```c++
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
  addInterfaces<ToyInlinerInterface>();
}
```

Next, we need to provide a way for the inliner to know that `toy.generic_call`
represents a call to a function. MLIR provides an
[operation interface](../../Interfaces.md#operation-interfaces) that can be used
to mark an operation as being "call-like". Unlike dialect interfaces, operation
interfaces provide a more refined granularity of information that is specific
and core to a single operation. The interface that we will be adding here is the
`CallOpInterface`.

To add this interface we just need to include the definition into our operation
specification file (`Ops.td`):

```tablegen
include "mlir/Interfaces/CallInterfaces.td"
```

and add it to the traits list of `GenericCallOp`:

```tablegen
def GenericCallOp : Toy_Op<"generic_call",
    [DeclareOpInterfaceMethods<CallOpInterface>]> {
  ...
}
```

In the above we also use the `DeclareOpInterfaceMethods` directive to
auto-declare all of the interface methods in the class declaration of
`GenericCallOp`. This means that we just need to provide a definition:

```c++
/// Return the callee of the generic call operation, this is required by the
/// call interface.
CallInterfaceCallable GenericCallOp::getCallableForCallee() {
  return getAttrOfType<SymbolRefAttr>("callee");
}

/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
```

Now that the inliner has been informed about the Toy dialect, we can add the
inliner pass to the pass manager for Toy:

```c++
  pm.addPass(mlir::createInlinerPass());
```

Now let's look at a working example:

```mlir
func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
  %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
  %1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64>
  %2 = toy.mul %0, %1 : tensor<*xf64>
  toy.return %2 : tensor<*xf64>
}
func @main() {
  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64>
  %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
  %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<2x3xf64>
  %4 = toy.generic_call @multiply_transpose(%1, %3) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
  %5 = toy.generic_call @multiply_transpose(%3, %1) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
  toy.print %5 : tensor<*xf64>
  toy.return
}
```

We have two calls to `multiply_transpose` that we would like to inline into
`main`, but if we look at the output, nothing has changed. We are missing one
last subtle piece: there is a hidden type conversion on the edge of the call. If
we look at the above, the operands to the generic_call are of type
`tensor<2x3xf64>`, while the inputs to the function expect `tensor<*xf64>`. To
resolve this difference, the inliner expects an explicit cast operation to be
inserted. For this, we need to add a new operation to the Toy dialect,
`ToyCastOp` (`toy.cast`), to represent casts between two different shapes.

```tablegen
def CastOp : Toy_Op<"cast", [NoSideEffect, SameOperandsAndResultShape]> {
  let summary = "shape cast operation";
  let description = [{
    The "cast" operation converts a tensor from one type to an equivalent type
    without changing any data elements. The source and destination types
    must both be tensor types with the same element type. If both are ranked
    then the rank should be the same and static dimensions should match. The
    operation is invalid if converting to a mismatching constant dimension.
  }];

  let arguments = (ins F64Tensor:$input);
  let results = (outs F64Tensor:$output);

  // Set the folder bit so that we can fold redundant cast operations.
  let hasFolder = 1;
}
```
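
Setting `hasFolder` declares a `fold` hook on the generated `CastOp` class that
we still need to define. The chapter does not show that definition, but a
minimal sketch could look like the following, assuming we only want to fold a
cast whose result type exactly matches its input type:

```c++
/// Fold a redundant cast, i.e. one whose input already has the result type.
/// This is only a sketch; it relies solely on the generic Operation accessors
/// and the single-result getType() accessor.
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  mlir::Value input = getOperation()->getOperand(0);
  // A cast between identical types is a no-op: forward the input value.
  if (input.getType() == getType())
    return input;
  // Otherwise, signal that nothing was folded.
  return {};
}
```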

We can then override the necessary hook on the `ToyInlinerInterface` to insert
this for us when necessary:

```c++
struct ToyInlinerInterface : public DialectInlinerInterface {
  ...

  /// Attempts to materialize a conversion for a type mismatch between a call
  /// from this dialect, and a callable region. This method should generate an
  /// operation that takes 'input' as the only operand, and produces a single
  /// result of 'resultType'. If a conversion can not be generated, nullptr
  /// should be returned.
  Operation *materializeCallConversion(OpBuilder &builder, Value input,
                                       Type resultType,
                                       Location conversionLoc) const final {
    return builder.create<CastOp>(conversionLoc, resultType, input);
  }
};
```

If we run the working example through the pipeline again, we get the expected result:

```mlir
func @main() {
  %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
  %1 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
  %2 = "toy.cast"(%1) : (tensor<2x3xf64>) -> tensor<*xf64>
  %3 = "toy.cast"(%0) : (tensor<2x3xf64>) -> tensor<*xf64>
  %4 = "toy.transpose"(%2) : (tensor<*xf64>) -> tensor<*xf64>
  %5 = "toy.transpose"(%3) : (tensor<*xf64>) -> tensor<*xf64>
  %6 = "toy.mul"(%4, %5) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
  toy.print %6 : tensor<*xf64>
  toy.return
}
```

NOTE: The generic inliner will also perform simplifications, so the output may
be a bit cleaner than expected.

### Intraprocedural Shape Inference

Now that we have inlined all of the functions, we are left with a main function
containing a mix of static and dynamically shaped operations. We can now write a
simple shape inference pass to propagate shapes intraprocedurally (within a
single function). We could write this as a pass that directly encodes the
constraints of the operations within the Toy dialect, but this seems like a good
candidate for a transformation that could be written generically. As a good rule
of thumb, it is best to express a transformation as generically as possible,
such that it can be extended to other dialects in the future. There is no
telling how many other dialects may have similar needs or encounter the same
problems.

For shape inference, if we break down the problem to its core, we really just
want operations to tell us the expected outputs given a set of statically known
inputs. (We can definitely get more complex than that, but for our needs we can
keep it simple.) Given that this property is core to a specific operation, we
can define an operation interface that can be specified on operations that need
to have their result shapes inferred.

Similarly to operations, we can also
[define operation interfaces](../../OpDefinitions.md#operation-interfaces) using
the operation definition specification (ODS) framework.

The interface is defined by inheriting from `OpInterface`, which takes the name
to be given to the generated C++ interface class as a template argument. For our
purposes, we will simply name the generated class `ShapeInference`. We also
provide a description for the interface.

```tablegen
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
  let description = [{
    Interface to access a registered method to infer the return types for an
    operation that can be used during type inference.
  }];
}
```

Next, we define the interface methods that the operations will need to provide.
An interface method consists of: a description; a C++ return type in string
form; a method name in string form; and a few optional components, depending on
the need. See the
[ODS documentation](../../OpDefinitions.md#operation-interfaces) for more
information.

```tablegen
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
  ...

  let methods = [
    InterfaceMethod<"Infer and set the output shape for the current operation.",
                    "void", "inferShapes">
  ];
}
```

Now that the interface is defined, we can add it to the necessary Toy operations
in a similar way to how we added the `CallOpInterface` to the `GenericCallOp`:

```tablegen
def MulOp : Toy_Op<"mul",
    [..., DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
  ...
}
```

Each of these operations will then need to provide a definition for the
`inferShapes()` method. As an example, for the mul op, the result shape is
inferred as the shape of the inputs.

```c++
/// Infer the output shape of the MulOp, this is required by the shape inference
/// interface.
void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
```
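
Operations whose result shape is not simply the shape of an input need a bit
more logic. As a sketch of what that looks like, the transpose operation could
implement the same hook roughly as follows, assuming that by the time the hook
runs its operand has already been resolved to a ranked tensor:

```c++
/// Infer the output shape of the TransposeOp: the result shape is the operand
/// shape with its dimensions reversed. Sketch only; it assumes the operand type
/// is already a RankedTensorType when inferShapes() is invoked.
void TransposeOp::inferShapes() {
  auto arrayTy = getOperand().getType().cast<RankedTensorType>();
  SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
  getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
```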

At this point, each of the necessary Toy operations provides a mechanism by
which to infer its output shapes. The `ShapeInferencePass` is a `FunctionPass`:
it will run on each function in isolation. MLIR also supports general
[OperationPasses](../../PassManagement.md#operation-pass) that run on any isolated
operation (i.e. other function-like operations), but here our module only
contains functions, so there is no need to generalize to all operations.

Implementing such a pass is done by creating a class inheriting from
`mlir::FunctionPass` and overriding the `runOnFunction()` method.

```c++
class ShapeInferencePass
    : public mlir::PassWrapper<ShapeInferencePass, FunctionPass> {
  void runOnFunction() override {
    FuncOp function = getFunction();
    ...
  }
};
```

While we're at it, let's also create a helper method for instantiating the pass:

```c++
std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() {
  return std::make_unique<ShapeInferencePass>();
}
```

The shape inference algorithm operates as follows (a sketch of the corresponding
worklist loop is shown after the list):

1.  Build a worklist containing all the operations that return a dynamically
    shaped tensor: these are the operations that need shape inference.
2.  Iterate on the worklist:
    -   find an operation to process: the next ready operation in the worklist
        has all of its arguments non-generic,
    -   if no operation is found, break out of the loop,
    -   remove the operation from the worklist,
    -   infer the shape of its output from the argument types.
3.  If the worklist is empty, the algorithm succeeded.
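
Filling in the skeleton from above, the core of the pass could be sketched
roughly as follows. The two helper predicates (`allOperandsInferred` and
`returnsDynamicShape`) are illustrative names, and the sketch assumes that
"dynamically shaped" simply means "not a `RankedTensorType`":

```c++
class ShapeInferencePass
    : public mlir::PassWrapper<ShapeInferencePass, FunctionPass> {
public:
  void runOnFunction() override {
    FuncOp f = getFunction();

    // Populate the worklist with the operations that need shape inference:
    // these are the operations that return a dynamic shape.
    llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
    f.walk([&](mlir::Operation *op) {
      if (returnsDynamicShape(op))
        opWorklist.insert(op);
    });

    // Iterate on the worklist until all operations have been inferred or no
    // further operation can be processed.
    while (!opWorklist.empty()) {
      // Find the next ready operation: one whose operands are all non-generic.
      auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
      if (nextop == opWorklist.end())
        break;

      Operation *op = *nextop;
      opWorklist.erase(op);

      // Infer the result shape of `op` through the ShapeInference interface,
      // as shown in the snippet below.
    }

    // If the worklist isn't empty, some operations could not be inferred.
    if (!opWorklist.empty()) {
      f.emitError("Shape inference failed, ")
          << opWorklist.size() << " operations couldn't be inferred\n";
      signalPassFailure();
    }
  }

  /// Returns true if every operand of `op` already has an inferred (ranked)
  /// shape.
  static bool allOperandsInferred(Operation *op) {
    return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
      return operandType.isa<RankedTensorType>();
    });
  }

  /// Returns true if `op` produces at least one dynamically shaped result.
  static bool returnsDynamicShape(Operation *op) {
    return llvm::any_of(op->getResultTypes(), [](Type resultType) {
      return !resultType.isa<RankedTensorType>();
    });
  }
};
```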

When processing an operation as described above, we query whether it has
registered the `ShapeInference` interface, using this code snippet:

```c++
  // Ask the operation to infer its output shapes.
  LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");

  /// We check if an operation has a particular interface by casting.
  if (ShapeInference shapeOp = dyn_cast<ShapeInference>(op)) {
    shapeOp.inferShapes();
  } else {
    op->emitError("unable to infer shape of operation without shape "
                  "inference interface");
    return signalPassFailure();
  }
```

We can then add our pass to the pass manager:

```c++
  pm.addPass(mlir::toy::createShapeInferencePass());
```
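
For reference, here is roughly how the pieces from this chapter could be
assembled into the full `-opt` pipeline. The nested function pass manager and
the extra cleanup passes (canonicalization and CSE) are assumptions about the
surrounding driver code, not something this chapter prescribes:

```c++
mlir::PassManager pm(&context);

// Inline all functions into main and then delete them.
pm.addPass(mlir::createInlinerPass());

// Now that there is only one function, we can infer the shapes of each of the
// operations, and clean up afterwards. These passes run on every function.
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
optPM.addPass(mlir::toy::createShapeInferencePass());
optPM.addPass(mlir::createCanonicalizerPass());
optPM.addPass(mlir::createCSEPass());
```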

If we rerun our original example, we now get the following:

```mlir
func @main() {
  %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
  %1 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64>
  %2 = "toy.mul"(%1, %1) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64>
  toy.print %2 : tensor<3x2xf64>
  toy.return
}
```

You can build `toyc-ch4` and try it out yourself: `toyc-ch4
test/Examples/Toy/Ch4/codegen.toy -emit=mlir -opt`.

In the [next chapter](Ch-5.md), we will start the process of code generation by
targeting a lower-level dialect for optimizing some of the more compute-heavy
Toy operations.