# Chapter 7: Adding a Composite Type to Toy

[TOC]

In the [previous chapter](Ch-6.md), we demonstrated an end-to-end compilation
flow from our Toy front-end to LLVM IR. In this chapter, we will extend the Toy
language to support a new composite `struct` type.

## Defining a `struct` in Toy

The first thing we need to define is the interface of this type in our `toy`
source language. The general syntax of a `struct` type in Toy is as follows:

```toy
# A struct is defined by using the `struct` keyword followed by a name.
struct MyStruct {
  # Inside of the struct is a list of variable declarations without initializers
  # or shapes, which may also be other previously defined structs.
  var a;
  var b;
}
```

Structs may now be used in functions as variables or parameters by using the
name of the struct instead of `var`. The members of the struct are accessed via
a `.` access operator. Values of `struct` type may be initialized with a
composite initializer, i.e. a comma-separated list of other initializers
surrounded by `{}`. An example is shown below:

```toy
struct Struct {
  var a;
  var b;
}

# User-defined generic functions may operate on struct types as well.
def multiply_transpose(Struct value) {
  # We can access the elements of a struct via the '.' operator.
  return transpose(value.a) * transpose(value.b);
}

def main() {
  # We initialize struct values using a composite initializer.
  Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]};

  # We pass these arguments to functions like we do with variables.
  var c = multiply_transpose(value);
  print(c);
}
```

## Defining a `struct` in MLIR

In MLIR, we will also need a representation for our struct types. MLIR does not
provide a type that does exactly what we need, so we will need to define our
own. We will simply define our `struct` as an unnamed container of a set of
element types. The names of the `struct` and its elements are only useful for
the AST of our `toy` compiler, so we don't need to encode them in the MLIR
representation.

### Defining the Type Class

As mentioned in [chapter 2](Ch-2.md), [`Type`](../../LangRef.md#type-system)
objects in MLIR are value-typed and rely on having an internal storage object
that holds the actual data for the type. The `Type` class itself acts as a
simple wrapper around an internal `TypeStorage` object that is uniqued within an
instance of an `MLIRContext`. When constructing a `Type`, we are internally just
constructing and uniquing an instance of a storage class.

When defining a new `Type` that contains parametric data (e.g. the `struct`
type, which requires additional information to hold the element types), we will
need to provide a derived storage class. The `singleton` types that don't have
any additional data (e.g. the [`index` type](../../LangRef.md#index-type)) don't
require a storage class and use the default `TypeStorage`.
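
To make the uniquing idea concrete, here is a minimal sketch that uses only the
C++ standard library. It is not how MLIR implements uniquing: `MiniContext`,
`MiniStorage`, and `MiniType` are hypothetical stand-ins for `MLIRContext`,
`TypeStorage`, and `Type`, and element types are modeled as plain strings. It
only illustrates that requesting a type with the same key twice yields the same
storage object, so equality reduces to a pointer comparison.

```c++
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Element types are modeled as strings in this sketch (an assumption).
using ElementList = std::vector<std::string>;

// Hypothetical stand-in for a TypeStorage subclass: it owns the key data.
struct MiniStorage {
  explicit MiniStorage(ElementList elements) : elements(std::move(elements)) {}
  ElementList elements;
};

// Hypothetical stand-in for MLIRContext: it owns and uniques storage objects.
class MiniContext {
public:
  // Return the unique storage for `key`, creating it on first use.
  const MiniStorage *getOrCreate(const ElementList &key) {
    auto it = storages.find(key);
    if (it == storages.end())
      it = storages.emplace(key, std::make_unique<MiniStorage>(key)).first;
    return it->second.get();
  }

private:
  std::map<ElementList, std::unique_ptr<MiniStorage>> storages;
};

// Hypothetical stand-in for Type: a value-typed wrapper around one pointer.
class MiniType {
public:
  static MiniType get(MiniContext &ctx, const ElementList &elements) {
    assert(!elements.empty() && "expected at least 1 element type");
    return MiniType(ctx.getOrCreate(elements));
  }

  // Because storage is uniqued, type equality is pointer equality.
  bool operator==(const MiniType &other) const { return impl == other.impl; }
  bool operator!=(const MiniType &other) const { return !(*this == other); }

  const ElementList &getElements() const { return impl->elements; }

private:
  explicit MiniType(const MiniStorage *impl) : impl(impl) {}
  const MiniStorage *impl;
};
```

Under these assumptions, two `MiniType::get` calls with equal element lists on
the same `MiniContext` compare equal, while a type built from a different
element list does not.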

##### Defining the Storage Class

Type storage objects contain all of the data necessary to construct and unique a
type instance. Derived storage classes must inherit from the base
`mlir::TypeStorage` and provide a set of aliases and hooks that will be used by
the `MLIRContext` for uniquing. Below is the definition of the storage instance
for our `struct` type, with each of the necessary requirements detailed inline:

```c++
/// This class represents the internal storage of the Toy `StructType`.
struct StructTypeStorage : public mlir::TypeStorage {
  /// The `KeyTy` is a required type that provides an interface for the storage
  /// instance. This type will be used when uniquing an instance of the type
  /// storage. For our struct type, we will unique each instance structurally on
  /// the elements that it contains.
  using KeyTy = llvm::ArrayRef<mlir::Type>;

  /// A constructor for the type storage instance.
  StructTypeStorage(llvm::ArrayRef<mlir::Type> elementTypes)
      : elementTypes(elementTypes) {}

  /// Define the comparison function for the key type with the current storage
  /// instance. This is used when constructing a new instance to ensure that we
  /// haven't already uniqued an instance of the given key.
  bool operator==(const KeyTy &key) const { return key == elementTypes; }

  /// Define a hash function for the key type. This is used when uniquing
  /// instances of the storage.
  /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type
  /// have hash functions available, so we could just omit this entirely.
  static llvm::hash_code hashKey(const KeyTy &key) {
    return llvm::hash_value(key);
  }

  /// Define a construction function for the key type from a set of parameters.
  /// These parameters will be provided when constructing the storage instance
  /// itself, see the `StructType::get` method further below.
  /// Note: This method isn't necessary because KeyTy can be directly
  /// constructed with the given parameters.
  static KeyTy getKey(llvm::ArrayRef<mlir::Type> elementTypes) {
    return KeyTy(elementTypes);
  }

  /// Define a construction method for creating a new instance of this storage.
  /// This method takes an instance of a storage allocator, and an instance of a
  /// `KeyTy`. The given allocator must be used for *all* necessary dynamic
  /// allocations used to create the type storage and its internals.
  static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
                                      const KeyTy &key) {
    // Copy the elements from the provided `KeyTy` into the allocator.
    llvm::ArrayRef<mlir::Type> elementTypes = allocator.copyInto(key);

    // Allocate the storage instance and construct it.
    return new (allocator.allocate<StructTypeStorage>())
        StructTypeStorage(elementTypes);
  }

  /// The following field contains the element types of the struct.
  llvm::ArrayRef<mlir::Type> elementTypes;
};
```
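
`llvm::ArrayRef` does not own the elements it refers to, so the storage must not
keep a view of memory that belongs to the caller. The sketch below models that
ownership concern with the standard library only. `IntView` and `MiniAllocator`
are hypothetical stand-ins for `llvm::ArrayRef` and
`mlir::TypeStorageAllocator`, and integers stand in for element types. It
illustrates the copy-then-view pattern used by `construct` above and is not
MLIR's allocator.

```c++
#include <cassert>
#include <cstddef>
#include <deque>
#include <vector>

// Hypothetical non-owning view, playing the role of llvm::ArrayRef.
struct IntView {
  const int *data;
  std::size_t size;
};

// Hypothetical arena, playing the role of mlir::TypeStorageAllocator.
class MiniAllocator {
public:
  // Copy the viewed elements into memory owned by the arena and return a view
  // of the copy. The copy stays valid for as long as the arena lives.
  IntView copyInto(IntView source) {
    assert((source.data != nullptr || source.size == 0) && "invalid view");
    // std::deque does not relocate existing elements on emplace_back, so
    // views handed out earlier remain valid.
    owned.emplace_back(source.data, source.data + source.size);
    const std::vector<int> &copy = owned.back();
    return IntView{copy.data(), copy.size()};
  }

private:
  std::deque<std::vector<int>> owned;
};

// Build a view that outlives the caller's temporary buffer.
IntView makeStableView(MiniAllocator &allocator) {
  std::vector<int> temporary = {1, 2, 3};
  // `temporary` is destroyed when this function returns, but the returned
  // view points at the arena's copy, not at `temporary`.
  return allocator.copyInto(IntView{temporary.data(), temporary.size()});
}
```

Returning `IntView{temporary.data(), temporary.size()}` directly would leave the
caller with a dangling view, which is the mistake the allocator copy avoids.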

##### Defining the Type Class

With the storage class defined, we can add the definition for the user-visible
`StructType` class. This is the class that we will actually interface with.

```c++
/// This class defines the Toy struct type. It represents a collection of
/// element types. All derived types in MLIR must inherit from the CRTP class
/// 'Type::TypeBase'. It takes as template parameters the concrete type
/// (StructType), the base class to use (Type), and the storage class
/// (StructTypeStorage).
class StructType : public mlir::Type::TypeBase<StructType, mlir::Type,
                                               StructTypeStorage> {
public:
  /// Inherit some necessary constructors from 'TypeBase'.
  using Base::Base;

  /// Create an instance of a `StructType` with the given element types. There
  /// *must* be at least one element type.
  static StructType get(llvm::ArrayRef<mlir::Type> elementTypes) {
    assert(!elementTypes.empty() && "expected at least 1 element type");

    // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
    // of this type. The first parameter is the context to unique in. The
    // parameters after are forwarded to the storage instance.
    mlir::MLIRContext *ctx = elementTypes.front().getContext();
    return Base::get(ctx, elementTypes);
  }

  /// Returns the element types of this struct type.
  llvm::ArrayRef<mlir::Type> getElementTypes() {
    // 'getImpl' returns a pointer to the internal storage instance.
    return getImpl()->elementTypes;
  }

  /// Returns the number of element types held by this struct.
  size_t getNumElementTypes() { return getElementTypes().size(); }
};
```

We register this type in the `ToyDialect` constructor in a similar way to how we
did with operations:

```c++
ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
    : mlir::Dialect(getDialectNamespace(), ctx) {
  addTypes<StructType>();
}
```

With this we can now use our `StructType` when generating MLIR from Toy. See
`examples/toy/Ch7/mlir/MLIRGen.cpp` for more details.

### Parsing and Printing

At this point we can use our `StructType` during MLIR generation and
transformation, but we can't output or parse `.mlir`. For this we need to add
support for parsing and printing instances of the `StructType`. This can be done
by overriding the `parseType` and `printType` methods on the `ToyDialect`.

```c++
class ToyDialect : public mlir::Dialect {
public:
  /// Parse an instance of a type registered to the toy dialect.
  mlir::Type parseType(mlir::DialectAsmParser &parser) const override;

  /// Print an instance of a type registered to the toy dialect.
  void printType(mlir::Type type,
                 mlir::DialectAsmPrinter &printer) const override;
};
```

These methods take an instance of a high-level parser or printer that allows for
easily implementing the necessary functionality. Before going into the
implementation, let's think about the syntax that we want for the `struct` type
in the printed IR. As described in the
[MLIR language reference](../../LangRef.md#dialect-types), dialect types are
generally represented as: `! dialect-namespace < type-data >`, with a pretty
form available under certain circumstances. The responsibility of our `Toy`
parser and printer is to provide the `type-data` bits. We will define our
`StructType` as having the following form:

```
  struct-type ::= `struct` `<` type (`,` type)* `>`
```

#### Parsing

An implementation of the parser is shown below:

```c++
/// Parse an instance of a type registered to the toy dialect.
mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
  // Parse a struct type in the following form:
  //   struct-type ::= `struct` `<` type (`,` type)* `>`

  // NOTE: All MLIR parser functions return a ParseResult. This is a
  // specialization of LogicalResult that auto-converts to a `true` boolean
  // value on failure to allow for chaining, but may be used with explicit
  // `mlir::failed/mlir::succeeded` as desired.

  // Parse: `struct` `<`
  if (parser.parseKeyword("struct") || parser.parseLess())
    return mlir::Type();

  // Parse the element types of the struct.
  llvm::SmallVector<mlir::Type, 1> elementTypes;
  do {
    // Parse the current element type.
    llvm::SMLoc typeLoc = parser.getCurrentLocation();
    mlir::Type elementType;
    if (parser.parseType(elementType))
      return mlir::Type();

    // Check that the type is either a TensorType or another StructType.
    if (!elementType.isa<mlir::TensorType, StructType>()) {
      parser.emitError(typeLoc, "element type for a struct must either "
                                "be a TensorType or a StructType, got: ")
          << elementType;
      return mlir::Type();
    }
    elementTypes.push_back(elementType);

    // Parse the optional: `,`
  } while (mlir::succeeded(parser.parseOptionalComma()));

  // Parse: `>`
  if (parser.parseGreater())
    return mlir::Type();
  return StructType::get(elementTypes);
}
```
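
The control flow of the parser above can be tried out without MLIR. The
following sketch is a simplified, standard-library-only model: `MiniParser` and
`countTensorLeaves` are hypothetical, the only element types are the bare
keyword `tensor` and nested structs, and whitespace is not handled. Mirroring
the convention noted in the parser comments, each parse function here returns
`true` on failure so that calls can be chained with `||`.

```c++
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>

// Hypothetical mini parser over a string. Every parse function returns `true`
// on FAILURE and `false` on success.
class MiniParser {
public:
  explicit MiniParser(std::string text) : text(std::move(text)) {}

  // Consume `token` if the remaining input starts with it.
  bool parseToken(const std::string &token) {
    assert(pos <= text.size() && "cursor ran past the end of the input");
    if (text.compare(pos, token.size(), token) != 0)
      return true; // failure, nothing consumed
    pos += token.size();
    return false; // success
  }

  // element ::= `tensor` | struct-type   (simplified element grammar)
  bool parseElement(int &numLeaves) {
    if (!parseToken("tensor")) {
      ++numLeaves;
      return false;
    }
    return parseStruct(numLeaves);
  }

  // struct-type ::= `struct` `<` element (`,` element)* `>`
  bool parseStruct(int &numLeaves) {
    if (parseToken("struct") || parseToken("<"))
      return true;
    do {
      if (parseElement(numLeaves))
        return true;
      // A successfully parsed comma means another element follows.
    } while (!parseToken(","));
    return parseToken(">");
  }

  bool atEnd() const { return pos == text.size(); }

private:
  std::string text;
  std::size_t pos = 0;
};

// Returns the number of `tensor` leaves, or -1 if `text` is not a struct type
// in the simplified grammar.
int countTensorLeaves(const std::string &text) {
  MiniParser parser(text);
  int numLeaves = 0;
  if (parser.parseStruct(numLeaves) || !parser.atEnd())
    return -1;
  return numLeaves;
}
```

With this model, `countTensorLeaves("struct<tensor,struct<tensor,tensor>>")`
counts three leaves, while inputs such as `struct<>` or `struct<tensor,>` are
rejected because at least one element is required and a comma must be followed
by another element.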

#### Printing

An implementation of the printer is shown below:

```c++
/// Print an instance of a type registered to the toy dialect.
void ToyDialect::printType(mlir::Type type,
                           mlir::DialectAsmPrinter &printer) const {
  // Currently the only toy type is a struct type.
  StructType structType = type.cast<StructType>();

  // Print the struct type according to the parser format.
  printer << "struct<";
  llvm::interleaveComma(structType.getElementTypes(), printer);
  printer << '>';
}
```
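
The comma-interleaving pattern used by the printer can be sketched with the
standard library as well. In this hypothetical model, `MiniTypeNode` is a small
type tree in which a leaf carries an already spelled-out type such as
`tensor<*xf64>`. The sketch only produces the `type-data` portion and does not
model the `!toy.` prefix that appears in the printed IR.

```c++
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical type tree: a node with no elements is a leaf whose spelling is
// stored in `leaf`; a node with elements is a struct of those elements.
struct MiniTypeNode {
  std::string leaf;
  std::vector<MiniTypeNode> elements;
};

// Print `type` in the form: `struct` `<` type (`,` type)* `>`.
void printType(const MiniTypeNode &type, std::ostream &os) {
  if (type.elements.empty()) {
    assert(!type.leaf.empty() && "a leaf needs a spelling");
    os << type.leaf;
    return;
  }
  os << "struct<";
  bool first = true;
  for (const MiniTypeNode &element : type.elements) {
    // Emit the separator before every element except the first one.
    if (!first)
      os << ", ";
    first = false;
    printType(element, os);
  }
  os << '>';
}

// Convenience wrapper that returns the printed form as a string.
std::string toString(const MiniTypeNode &type) {
  std::ostringstream os;
  printType(type, os);
  return os.str();
}
```

Emitting the separator before every element except the first avoids a trailing
comma, so the printed form matches what the `struct-type` grammar accepts.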

Before moving on, let's look at a quick example showcasing the functionality
we have now:

```toy
struct Struct {
  var a;
  var b;
}

def multiply_transpose(Struct value) {
}
```

Which generates the following:

```mlir
module {
  func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) {
    toy.return
  }
}
```

### Operating on `StructType`

Now that the `struct` type has been defined and we can round-trip it through
the IR, the next step is to add support for using it within our operations.

#### Updating Existing Operations

A few of our existing operations will need to be updated to handle `StructType`.
The first step is to make the ODS framework aware of our Type so that we can use
it in the operation definitions. A simple example is shown below:

```tablegen
// Provide a definition for the Toy StructType for use in ODS. This allows for
// using StructType in a similar way to Tensor or MemRef.
def Toy_StructType :
    Type<CPred<"$_self.isa<StructType>()">, "Toy struct type">;

// Provide a definition of the types that are used within the Toy dialect.
def Toy_Type : AnyTypeOf<[F64Tensor, Toy_StructType]>;
```

We can then update our operations, e.g. `ReturnOp`, to also accept the
`Toy_StructType`:

```tablegen
def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
  ...
  let arguments = (ins Variadic<Toy_Type>:$input);
  ...
}
```

#### Adding New `Toy` Operations

In addition to the existing operations, we will be adding a few new operations
that will provide more specific handling of `structs`.

##### `toy.struct_constant`

This new operation materializes a constant value for a struct. In our current
modeling, we just use an [array attribute](../../LangRef.md#array-attribute)
that contains a set of constant values for each of the `struct` elements.

```mlir
  %0 = toy.struct_constant [
    dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64>
  ] : !toy.struct<tensor<*xf64>>
```

##### `toy.struct_access`

This new operation materializes the Nth element of a `struct` value.

```mlir
  // Using %0 from above
  %1 = toy.struct_access %0[0] : !toy.struct<tensor<*xf64>> -> tensor<*xf64>
```

With these operations, we can revisit our original example:

```toy
struct Struct {
  var a;
  var b;
}

# User-defined generic functions may operate on struct types as well.
def multiply_transpose(Struct value) {
  # We can access the elements of a struct via the '.' operator.
  return transpose(value.a) * transpose(value.b);
}

def main() {
  # We initialize struct values using a composite initializer.
  Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]};

  # We pass these arguments to functions like we do with variables.
  var c = multiply_transpose(value);
  print(c);
}
```

and finally get a full MLIR module:

```mlir
module {
  func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64> {
    %0 = toy.struct_access %arg0[0] : !toy.struct<tensor<*xf64>, tensor<*xf64>> -> tensor<*xf64>
    %1 = toy.transpose(%0 : tensor<*xf64>) to tensor<*xf64>
    %2 = toy.struct_access %arg0[1] : !toy.struct<tensor<*xf64>, tensor<*xf64>> -> tensor<*xf64>
    %3 = toy.transpose(%2 : tensor<*xf64>) to tensor<*xf64>
    %4 = toy.mul %1, %3 : tensor<*xf64>
    toy.return %4 : tensor<*xf64>
  }
  func @main() {
    %0 = toy.struct_constant [
      dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>,
      dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
    ] : !toy.struct<tensor<*xf64>, tensor<*xf64>>
    %1 = toy.generic_call @multiply_transpose(%0) : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
    toy.print %1 : tensor<*xf64>
    toy.return
  }
}
```

#### Optimizing Operations on `StructType`

Now that we have a few operations operating on `StructType`, we also have many
new constant folding opportunities.

After inlining, the MLIR module in the previous section looks something like:

```mlir
module {
  func @main() {
    %0 = toy.struct_constant [
      dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>,
      dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
    ] : !toy.struct<tensor<*xf64>, tensor<*xf64>>
    %1 = toy.struct_access %0[0] : !toy.struct<tensor<*xf64>, tensor<*xf64>> -> tensor<*xf64>
    %2 = toy.transpose(%1 : tensor<*xf64>) to tensor<*xf64>
    %3 = toy.struct_access %0[1] : !toy.struct<tensor<*xf64>, tensor<*xf64>> -> tensor<*xf64>
    %4 = toy.transpose(%3 : tensor<*xf64>) to tensor<*xf64>
    %5 = toy.mul %2, %4 : tensor<*xf64>
    toy.print %5 : tensor<*xf64>
    toy.return
  }
}
```

We have several `toy.struct_access` operations that access into a
`toy.struct_constant`. As detailed in [chapter 3](Ch-3.md) (FoldConstantReshape),
we can add folders for these `toy` operations by setting the `hasFolder` bit
on the operation definition and providing a definition of the `*Op::fold`
method.

```c++
/// Fold constants.
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { return value(); }

/// Fold struct constants.
OpFoldResult StructConstantOp::fold(ArrayRef<Attribute> operands) {
  return value();
}

/// Fold simple struct access operations that access into a constant.
OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) {
  auto structAttr = operands.front().dyn_cast_or_null<mlir::ArrayAttr>();
  if (!structAttr)
    return nullptr;

  size_t elementIndex = index().getZExtValue();
  return structAttr[elementIndex];
}
```
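
The folding rule for `StructAccessOp` boils down to indexing into a constant
when one is available. Below is a standard-library sketch of that rule. The
`TensorConstant` and `StructConstant` aliases are hypothetical replacements for
MLIR attributes, a null pointer stands in for a non-constant operand, and the
bounds assertion is an addition of this sketch.

```c++
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical constant model: a tensor constant is a flat list of doubles,
// and a struct constant is a list of tensor constants.
using TensorConstant = std::vector<double>;
using StructConstant = std::vector<TensorConstant>;

// Sketch of the fold: if the operand is a known constant (non-null), the
// result is the element at `index`; otherwise there is nothing to fold.
std::optional<TensorConstant> foldStructAccess(const StructConstant *operand,
                                               std::size_t index) {
  if (!operand)
    return std::nullopt; // operand is not a constant, so no fold happens
  assert(index < operand->size() && "struct access index out of bounds");
  return (*operand)[index];
}
```

Returning an empty `std::optional` plays the role of returning `nullptr` from
`fold`: it signals that the operation could not be folded and must be kept.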

To ensure that MLIR generates the proper constant operations when folding our
`Toy` operations, i.e. `ConstantOp` for `TensorType` and `StructConstantOp` for
`StructType`, we will need to provide an override for the dialect hook
`materializeConstant`. This allows for generic MLIR operations to create
constants for the `Toy` dialect when necessary.

```c++
mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
                                                 mlir::Attribute value,
                                                 mlir::Type type,
                                                 mlir::Location loc) {
  if (type.isa<StructType>())
    return builder.create<StructConstantOp>(loc, type,
                                            value.cast<mlir::ArrayAttr>());
  return builder.create<ConstantOp>(loc, type,
                                    value.cast<mlir::DenseElementsAttr>());
}
```

With this, we can now generate code that can be lowered to LLVM without any
changes to our pipeline.

```mlir
module {
  func @main() {
    %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
    %1 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
    %2 = toy.mul %1, %1 : tensor<3x2xf64>
    toy.print %2 : tensor<3x2xf64>
    toy.return
  }
}
```

You can build `toyc-ch7` and try it yourself: `toyc-ch7
test/Examples/Toy/Ch7/struct-codegen.toy -emit=mlir`. More details on defining
custom types can be found in
[DefiningAttributesAndTypes](../DefiningAttributesAndTypes.md).