1 //===- Builders.cpp - MLIR Declarative Builder Classes --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/EDSC/Builders.h"
10 #include "mlir/Dialect/StandardOps/EDSC/Builders.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
13 
14 using namespace mlir;
15 using namespace mlir::edsc;
16 
affineLoopNestBuilder(ValueRange lbs,ValueRange ubs,ArrayRef<int64_t> steps,function_ref<void (ValueRange)> bodyBuilderFn)17 void mlir::edsc::affineLoopNestBuilder(
18     ValueRange lbs, ValueRange ubs, ArrayRef<int64_t> steps,
19     function_ref<void(ValueRange)> bodyBuilderFn) {
20   assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
21 
22   // Wrap the body builder function into an interface compatible with the main
23   // builder.
24   auto wrappedBuilderFn = [&](OpBuilder &nestedBuilder, Location nestedLoc,
25                               ValueRange ivs) {
26     ScopedContext context(nestedBuilder, nestedLoc);
27     bodyBuilderFn(ivs);
28   };
29   function_ref<void(OpBuilder &, Location, ValueRange)> wrapper;
30   if (bodyBuilderFn)
31     wrapper = wrappedBuilderFn;
32 
33   // Extract the builder, location and construct the loop nest.
34   OpBuilder &builder = ScopedContext::getBuilderRef();
35   Location loc = ScopedContext::getLocation();
36   buildAffineLoopNest(builder, loc, lbs, ubs, steps, wrapper);
37 }
38 
affineLoopBuilder(ValueRange lbs,ValueRange ubs,int64_t step,function_ref<void (Value)> bodyBuilderFn)39 void mlir::edsc::affineLoopBuilder(ValueRange lbs, ValueRange ubs, int64_t step,
40                                    function_ref<void(Value)> bodyBuilderFn) {
41   // Fetch the builder and location.
42   assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
43   OpBuilder &builder = ScopedContext::getBuilderRef();
44   Location loc = ScopedContext::getLocation();
45 
46   // Create the actual loop and call the body builder, if provided, after
47   // updating the scoped context.
48   builder.create<AffineForOp>(
49       loc, lbs, builder.getMultiDimIdentityMap(lbs.size()), ubs,
50       builder.getMultiDimIdentityMap(ubs.size()), step, llvm::None,
51       [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
52           ValueRange itrArgs) {
53         if (bodyBuilderFn) {
54           ScopedContext nestedContext(nestedBuilder, nestedLoc);
55           OpBuilder::InsertionGuard guard(nestedBuilder);
56           bodyBuilderFn(iv);
57         }
58         nestedBuilder.create<AffineYieldOp>(nestedLoc);
59       });
60 }
61 
affineLoopBuilder(ValueRange lbs,ValueRange ubs,int64_t step,ValueRange iterArgs,function_ref<void (Value,ValueRange)> bodyBuilderFn)62 void mlir::edsc::affineLoopBuilder(
63     ValueRange lbs, ValueRange ubs, int64_t step, ValueRange iterArgs,
64     function_ref<void(Value, ValueRange)> bodyBuilderFn) {
65   // Fetch the builder and location.
66   assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
67   OpBuilder &builder = ScopedContext::getBuilderRef();
68   Location loc = ScopedContext::getLocation();
69 
70   // Create the actual loop and call the body builder, if provided, after
71   // updating the scoped context.
72   builder.create<AffineForOp>(
73       loc, lbs, builder.getMultiDimIdentityMap(lbs.size()), ubs,
74       builder.getMultiDimIdentityMap(ubs.size()), step, iterArgs,
75       [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
76           ValueRange itrArgs) {
77         if (bodyBuilderFn) {
78           ScopedContext nestedContext(nestedBuilder, nestedLoc);
79           OpBuilder::InsertionGuard guard(nestedBuilder);
80           bodyBuilderFn(iv, itrArgs);
81         } else if (itrArgs.empty())
82           nestedBuilder.create<AffineYieldOp>(nestedLoc);
83       });
84 }
85 
86 static std::pair<AffineExpr, Value>
categorizeValueByAffineType(MLIRContext * context,Value val,unsigned & numDims,unsigned & numSymbols)87 categorizeValueByAffineType(MLIRContext *context, Value val, unsigned &numDims,
88                             unsigned &numSymbols) {
89   AffineExpr d;
90   Value resultVal = nullptr;
91   if (auto constant = val.getDefiningOp<ConstantIndexOp>()) {
92     d = getAffineConstantExpr(constant.getValue(), context);
93   } else if (isValidSymbol(val) && !isValidDim(val)) {
94     d = getAffineSymbolExpr(numSymbols++, context);
95     resultVal = val;
96   } else {
97     d = getAffineDimExpr(numDims++, context);
98     resultVal = val;
99   }
100   return std::make_pair(d, resultVal);
101 }
102 
createBinaryIndexHandle(Value lhs,Value rhs,function_ref<AffineExpr (AffineExpr,AffineExpr)> affCombiner)103 static Value createBinaryIndexHandle(
104     Value lhs, Value rhs,
105     function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
106   MLIRContext *context = ScopedContext::getContext();
107   unsigned numDims = 0, numSymbols = 0;
108   AffineExpr d0, d1;
109   Value v0, v1;
110   std::tie(d0, v0) =
111       categorizeValueByAffineType(context, lhs, numDims, numSymbols);
112   std::tie(d1, v1) =
113       categorizeValueByAffineType(context, rhs, numDims, numSymbols);
114   SmallVector<Value, 2> operands;
115   if (v0)
116     operands.push_back(v0);
117   if (v1)
118     operands.push_back(v1);
119   auto map = AffineMap::get(numDims, numSymbols, affCombiner(d0, d1));
120 
121   // TODO: createOrFold when available.
122   Operation *op =
123       makeComposedAffineApply(ScopedContext::getBuilderRef(),
124                               ScopedContext::getLocation(), map, operands)
125           .getOperation();
126   assert(op->getNumResults() == 1 && "Expected single result AffineApply");
127   return op->getResult(0);
128 }
129 
130 template <typename IOp, typename FOp>
createBinaryHandle(Value lhs,Value rhs,function_ref<AffineExpr (AffineExpr,AffineExpr)> affCombiner)131 static Value createBinaryHandle(
132     Value lhs, Value rhs,
133     function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
134   auto thisType = lhs.getType();
135   auto thatType = rhs.getType();
136   assert(thisType == thatType && "cannot mix types in operators");
137   (void)thisType;
138   (void)thatType;
139   if (thisType.isIndex()) {
140     return createBinaryIndexHandle(lhs, rhs, affCombiner);
141   } else if (thisType.isSignlessInteger()) {
142     return ValueBuilder<IOp>(lhs, rhs);
143   } else if (thisType.isa<FloatType>()) {
144     return ValueBuilder<FOp>(lhs, rhs);
145   } else if (thisType.isa<VectorType, TensorType>()) {
146     auto aggregateType = thisType.cast<ShapedType>();
147     if (aggregateType.getElementType().isSignlessInteger())
148       return ValueBuilder<IOp>(lhs, rhs);
149     else if (aggregateType.getElementType().isa<FloatType>())
150       return ValueBuilder<FOp>(lhs, rhs);
151   }
152   llvm_unreachable("failed to create a Value");
153 }
154 
operator +(Value lhs,Value rhs)155 Value mlir::edsc::op::operator+(Value lhs, Value rhs) {
156   return createBinaryHandle<AddIOp, AddFOp>(
157       lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 + d1; });
158 }
159 
operator -(Value lhs,Value rhs)160 Value mlir::edsc::op::operator-(Value lhs, Value rhs) {
161   return createBinaryHandle<SubIOp, SubFOp>(
162       lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 - d1; });
163 }
164 
operator *(Value lhs,Value rhs)165 Value mlir::edsc::op::operator*(Value lhs, Value rhs) {
166   return createBinaryHandle<MulIOp, MulFOp>(
167       lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 * d1; });
168 }
169 
operator /(Value lhs,Value rhs)170 Value mlir::edsc::op::operator/(Value lhs, Value rhs) {
171   return createBinaryHandle<SignedDivIOp, DivFOp>(
172       lhs, rhs, [](AffineExpr d0, AffineExpr d1) -> AffineExpr {
173         llvm_unreachable("only exprs of non-index type support operator/");
174       });
175 }
176 
operator %(Value lhs,Value rhs)177 Value mlir::edsc::op::operator%(Value lhs, Value rhs) {
178   return createBinaryHandle<SignedRemIOp, RemFOp>(
179       lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 % d1; });
180 }
181 
floorDiv(Value lhs,Value rhs)182 Value mlir::edsc::op::floorDiv(Value lhs, Value rhs) {
183   return createBinaryIndexHandle(
184       lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.floorDiv(d1); });
185 }
186 
ceilDiv(Value lhs,Value rhs)187 Value mlir::edsc::op::ceilDiv(Value lhs, Value rhs) {
188   return createBinaryIndexHandle(
189       lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.ceilDiv(d1); });
190 }
191 
negate(Value value)192 Value mlir::edsc::op::negate(Value value) {
193   assert(value.getType().isInteger(1) && "expected boolean expression");
194   return ValueBuilder<ConstantIntOp>(1, 1) - value;
195 }
196 
operator &&(Value lhs,Value rhs)197 Value mlir::edsc::op::operator&&(Value lhs, Value rhs) {
198   assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS");
199   assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS");
200   return ValueBuilder<AndOp>(lhs, rhs);
201 }
202 
operator ||(Value lhs,Value rhs)203 Value mlir::edsc::op::operator||(Value lhs, Value rhs) {
204   assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS");
205   assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS");
206   return ValueBuilder<OrOp>(lhs, rhs);
207 }
208 
createIComparisonExpr(CmpIPredicate predicate,Value lhs,Value rhs)209 static Value createIComparisonExpr(CmpIPredicate predicate, Value lhs,
210                                    Value rhs) {
211   auto lhsType = lhs.getType();
212   auto rhsType = rhs.getType();
213   (void)lhsType;
214   (void)rhsType;
215   assert(lhsType == rhsType && "cannot mix types in operators");
216   assert((lhsType.isa<IndexType>() || lhsType.isSignlessInteger()) &&
217          "only integer comparisons are supported");
218 
219   return ScopedContext::getBuilderRef().create<CmpIOp>(
220       ScopedContext::getLocation(), predicate, lhs, rhs);
221 }
222 
createFComparisonExpr(CmpFPredicate predicate,Value lhs,Value rhs)223 static Value createFComparisonExpr(CmpFPredicate predicate, Value lhs,
224                                    Value rhs) {
225   auto lhsType = lhs.getType();
226   auto rhsType = rhs.getType();
227   (void)lhsType;
228   (void)rhsType;
229   assert(lhsType == rhsType && "cannot mix types in operators");
230   assert(lhsType.isa<FloatType>() && "only float comparisons are supported");
231 
232   return ScopedContext::getBuilderRef().create<CmpFOp>(
233       ScopedContext::getLocation(), predicate, lhs, rhs);
234 }
235 
// All floating-point comparisons built through EDSC use ordered predicates.
eq(Value lhs,Value rhs)237 Value mlir::edsc::op::eq(Value lhs, Value rhs) {
238   auto type = lhs.getType();
239   return type.isa<FloatType>()
240              ? createFComparisonExpr(CmpFPredicate::OEQ, lhs, rhs)
241              : createIComparisonExpr(CmpIPredicate::eq, lhs, rhs);
242 }
ne(Value lhs,Value rhs)243 Value mlir::edsc::op::ne(Value lhs, Value rhs) {
244   auto type = lhs.getType();
245   return type.isa<FloatType>()
246              ? createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs)
247              : createIComparisonExpr(CmpIPredicate::ne, lhs, rhs);
248 }
slt(Value lhs,Value rhs)249 Value mlir::edsc::op::slt(Value lhs, Value rhs) {
250   auto type = lhs.getType();
251   return type.isa<FloatType>()
252              ? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs)
253              : createIComparisonExpr(CmpIPredicate::slt, lhs, rhs);
254 }
sle(Value lhs,Value rhs)255 Value mlir::edsc::op::sle(Value lhs, Value rhs) {
256   auto type = lhs.getType();
257   return type.isa<FloatType>()
258              ? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs)
259              : createIComparisonExpr(CmpIPredicate::sle, lhs, rhs);
260 }
sgt(Value lhs,Value rhs)261 Value mlir::edsc::op::sgt(Value lhs, Value rhs) {
262   auto type = lhs.getType();
263   return type.isa<FloatType>()
264              ? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs)
265              : createIComparisonExpr(CmpIPredicate::sgt, lhs, rhs);
266 }
sge(Value lhs,Value rhs)267 Value mlir::edsc::op::sge(Value lhs, Value rhs) {
268   auto type = lhs.getType();
269   return type.isa<FloatType>()
270              ? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs)
271              : createIComparisonExpr(CmpIPredicate::sge, lhs, rhs);
272 }
ult(Value lhs,Value rhs)273 Value mlir::edsc::op::ult(Value lhs, Value rhs) {
274   auto type = lhs.getType();
275   return type.isa<FloatType>()
276              ? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs)
277              : createIComparisonExpr(CmpIPredicate::ult, lhs, rhs);
278 }
ule(Value lhs,Value rhs)279 Value mlir::edsc::op::ule(Value lhs, Value rhs) {
280   auto type = lhs.getType();
281   return type.isa<FloatType>()
282              ? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs)
283              : createIComparisonExpr(CmpIPredicate::ule, lhs, rhs);
284 }
ugt(Value lhs,Value rhs)285 Value mlir::edsc::op::ugt(Value lhs, Value rhs) {
286   auto type = lhs.getType();
287   return type.isa<FloatType>()
288              ? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs)
289              : createIComparisonExpr(CmpIPredicate::ugt, lhs, rhs);
290 }
uge(Value lhs,Value rhs)291 Value mlir::edsc::op::uge(Value lhs, Value rhs) {
292   auto type = lhs.getType();
293   return type.isa<FloatType>()
294              ? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs)
295              : createIComparisonExpr(CmpIPredicate::uge, lhs, rhs);
296 }
297