1 //===- Builders.h - MLIR Declarative Linalg Builders ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Provides intuitive composable interfaces for building structured MLIR
10 // snippets in a declarative fashion.
11 //
12 //===----------------------------------------------------------------------===//
13 #ifndef MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
14 #define MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
15 
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
18 #include "mlir/EDSC/Builders.h"
19 #include "mlir/IR/AffineExpr.h"
20 #include "mlir/IR/Builders.h"
21 
22 namespace mlir {
23 class AffineForOp;
24 class BlockArgument;
25 
26 namespace scf {
27 class ParallelOp;
28 } // namespace scf
29 
30 namespace edsc {
defaultRegionBuilder(ValueRange args)31 inline void defaultRegionBuilder(ValueRange args) {}
32 
/// Build a `linalg.generic` op with the specified `inputs`, `outputBuffers`,
/// `initTensors`, `resultTensorTypes` and `region`.
///
/// `otherValues` and `otherAttributes` may be passed and will be appended as
/// operands and attributes respectively.
///
/// Prerequisites:
/// =============
///
/// 1. `inputs` may contain StructuredIndexed that capture either buffer or
/// tensor values.
/// 2. `outputBuffers` may contain StructuredIndexed that capture buffer
/// values.
/// 3. `initTensors` contain tensor values, without indexing maps.
/// 4. `resultTensorTypes` may contain StructuredIndexed that capture return
/// tensor types.
Operation *makeGenericLinalgOp(
    ArrayRef<IteratorType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
    ArrayRef<StructuredIndexed> outputBuffers, ArrayRef<Value> initTensors,
    ArrayRef<StructuredIndexed> resultTensorTypes,
    function_ref<void(ValueRange)> regionBuilder = defaultRegionBuilder,
    ArrayRef<Value> otherValues = {}, ArrayRef<Attribute> otherAttributes = {});
55 
56 namespace ops {
57 using edsc::StructuredIndexed;
58 
59 //===----------------------------------------------------------------------===//
60 // EDSC builders for linalg generic operations.
61 //===----------------------------------------------------------------------===//
62 
/// Build the body of a region to compute a scalar multiply, under the current
/// ScopedContext, at the current insert point.
/// NOTE(review): presumably expects the two scalar factors as the leading
/// region arguments in `args` -- confirm against the implementation.
void mulRegionBuilder(ValueRange args);

/// Build the body of a region to compute a scalar multiply-accumulate, under
/// the current ScopedContext, at the current insert point.
/// NOTE(review): presumably expects the accumulator to be carried in `args`
/// alongside the two factors -- confirm against the implementation.
void macRegionBuilder(ValueRange args);

/// TODO: In the future we should tie these implementations to something in
/// Tablegen that generates the proper interfaces and the proper sugared named
/// ops.
74 
/// Build a linalg.pointwise, under the current ScopedContext, at the current
/// insert point, that computes:
/// ```
///    (i0, ..., in) = (par, ..., par)
///    |
///    |  O...(some_subset...(i0, ..., in)) =
///    |    some_pointwise_func...(I...(some_other_subset...(i0, ..., in)))
/// ```
///
/// This is a very generic entry point that can be configured in many ways to
/// build a perfect loop nest of parallel loops with arbitrarily complex
/// innermost loop code and whatever (explicit) broadcast semantics.
///
/// This can be used with both out-of-place and in-place semantics.
/// The client is responsible for ensuring the region operations are compatible
/// with in-place semantics and parallelism.

/// Unary pointwise operation (with broadcast) entry point.
/// `unaryOp` maps the scalar input element to the scalar output element.
using UnaryPointwiseOpBuilder = function_ref<Value(Value)>;
Operation *linalg_generic_pointwise(UnaryPointwiseOpBuilder unaryOp,
                                    StructuredIndexed I, StructuredIndexed O);

/// Build a linalg.pointwise with all `parallel` iterators and a region that
/// computes `O = tanh(I)`. The client is responsible for specifying the proper
/// indexings when creating the StructuredIndexed.
Operation *linalg_generic_pointwise_tanh(StructuredIndexed I,
                                         StructuredIndexed O);

/// Binary pointwise operation (with broadcast) entry point.
/// `binaryOp` maps the two scalar input elements to the scalar output element.
using BinaryPointwiseOpBuilder = function_ref<Value(Value, Value)>;
Operation *linalg_generic_pointwise(BinaryPointwiseOpBuilder binaryOp,
                                    StructuredIndexed I1, StructuredIndexed I2,
                                    StructuredIndexed O);

/// Build a linalg.pointwise with all `parallel` iterators and a region that
/// computes `O = I1 + I2`. The client is responsible for specifying the proper
/// indexings when creating the StructuredIndexed.
Operation *linalg_generic_pointwise_add(StructuredIndexed I1,
                                        StructuredIndexed I2,
                                        StructuredIndexed O);

/// Build a linalg.pointwise with all `parallel` iterators and a region that
/// computes `O = max(I1, I2)`. The client is responsible for specifying the
/// proper indexings when creating the StructuredIndexed.
Operation *linalg_generic_pointwise_max(StructuredIndexed I1,
                                        StructuredIndexed I2,
                                        StructuredIndexed O);
122 
123 // TODO: Implement more useful pointwise operations on a per-need basis.
124 
/// Callback type used to populate the scalar body of a matmul-like
/// linalg.generic (e.g. `mulRegionBuilder` or `macRegionBuilder`).
using MatmulRegionBuilder = function_ref<void(ValueRange args)>;

/// Build a linalg.generic, under the current ScopedContext, at the current
/// insert point, that computes:
/// ```
///    (m, n, k) = (par, par, seq)
///    |
///    |  C(m, n) += A(m, k) * B(k, n)
/// ```
/// Buffer variant: `vC` is updated in place.
Operation *
linalg_generic_matmul(Value vA, Value vB, Value vC,
                      MatmulRegionBuilder regionBuilder = macRegionBuilder);

/// Build a linalg.generic, under the current ScopedContext, at the current
/// insert point, that computes:
/// ```
///    (m, n, k) = (par, par, seq)
///    |
///    |  D(m, n) = C(m, n) + sum_k(A(m, k) * B(k, n))
/// ```
/// and returns the tensor `D` (of type `tD`) as the result of the op.
Operation *
linalg_generic_matmul(Value vA, Value vB, Value vC, RankedTensorType tD,
                      MatmulRegionBuilder regionBuilder = macRegionBuilder);
149 
150 template <typename Container>
151 Operation *
152 linalg_generic_matmul(Container values,
153                       MatmulRegionBuilder regionBuilder = macRegionBuilder) {
154   assert(values.size() == 3 && "Expected exactly 3 values");
155   return linalg_generic_matmul(values[0], values[1], values[2], regionBuilder);
156 }
157 
/// Build a linalg.generic, under the current ScopedContext, at the current
/// insert point, that computes:
/// ```
///    (batch, f, [h, w, ...], [kh, kw, ...], c) =
///    |  (par, par, [par, par, ...], [red, red, ...], red)
///    |
///    | O(batch, [h, w, ...], f) +=
///    |   I(batch,
///    |     [
///    |       strides[0] * h + dilations[0] * kh,
///    |       strides[1] * w + dilations[1] * kw, ...
///    |     ],
///    |     c)
///    |   *
///    |   W([kh, kw, ...], c, f)
/// ```
/// If `dilations` or `strides` are left empty, the default value of `1` is used
/// along each relevant dimension.
///
/// For now `...` must be empty (i.e. only 2-D convolutions are supported).
///
// TODO: Extend convolution rank with some template magic.
Operation *linalg_generic_conv_nhwc(Value vI, Value vW, Value vO,
                                    ArrayRef<int> strides = {},
                                    ArrayRef<int> dilations = {});
183 
184 template <typename Container>
185 Operation *linalg_generic_conv_nhwc(Container values,
186                                     ArrayRef<int> strides = {},
187                                     ArrayRef<int> dilations = {}) {
188   assert(values.size() == 3 && "Expected exactly 3 values");
189   return linalg_generic_conv_nhwc(values[0], values[1], values[2], strides,
190                                   dilations);
191 }
192 
/// Build a linalg.generic, under the current ScopedContext, at the current
/// insert point, that computes:
/// ```
///    (batch, dm, c, [h, w, ...], [kh, kw, ...]) =
///    |  (par, par, par, [par, par, ...], [red, red, ...])
///    |
///    | O(batch, [h, w, ...], c * depth_multiplier) +=
///    |   I(batch,
///    |     [
///    |       strides[0] * h + dilations[0] * kh,
///    |       strides[1] * w + dilations[1] * kw, ...
///    |     ],
///    |     c)
///    |   *
///    |   W([kh, kw, ...], c, depth_multiplier)
/// ```
/// If `dilations` or `strides` are left empty, the default value of `1` is used
/// along each relevant dimension.
///
/// For now `...` must be empty (i.e. only 2-D convolutions are supported).
///
// TODO: Extend convolution rank with some template magic.
Operation *linalg_generic_dilated_conv_nhwc(Value vI, Value vW, Value vO,
                                            int depth_multiplier = 1,
                                            ArrayRef<int> strides = {},
                                            ArrayRef<int> dilations = {});
219 
220 template <typename Container>
221 Operation *linalg_generic_dilated_conv_nhwc(Container values,
222                                             int depth_multiplier,
223                                             ArrayRef<int> strides = {},
224                                             ArrayRef<int> dilations = {}) {
225   assert(values.size() == 3 && "Expected exactly 3 values");
226   return linalg_generic_dilated_conv_nhwc(values[0], values[1], values[2],
227                                           depth_multiplier, strides, dilations);
228 }
229 
230 } // namespace ops
231 } // namespace edsc
232 } // namespace mlir
233 
234 #endif // MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
235