1 //===- Builders.cpp - MLIR Declarative Builder Classes --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/SCF/EDSC/Builders.h"
10 #include "mlir/Dialect/SCF/SCF.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
13
14 using namespace mlir;
15 using namespace mlir::edsc;
16
17 mlir::scf::LoopNest
loopNestBuilder(ValueRange lbs,ValueRange ubs,ValueRange steps,function_ref<void (ValueRange)> fun)18 mlir::edsc::loopNestBuilder(ValueRange lbs, ValueRange ubs, ValueRange steps,
19 function_ref<void(ValueRange)> fun) {
20 // Delegates actual construction to scf::buildLoopNest by wrapping `fun` into
21 // the expected function interface.
22 assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
23 return mlir::scf::buildLoopNest(
24 ScopedContext::getBuilderRef(), ScopedContext::getLocation(), lbs, ubs,
25 steps, [&](OpBuilder &builder, Location loc, ValueRange ivs) {
26 ScopedContext context(builder, loc);
27 if (fun)
28 fun(ivs);
29 });
30 }
31
32 mlir::scf::LoopNest
loopNestBuilder(Value lb,Value ub,Value step,function_ref<void (Value)> fun)33 mlir::edsc::loopNestBuilder(Value lb, Value ub, Value step,
34 function_ref<void(Value)> fun) {
35 // Delegates to the ValueRange-based version by wrapping the lambda.
36 auto wrapper = [&](ValueRange ivs) {
37 assert(ivs.size() == 1);
38 if (fun)
39 fun(ivs[0]);
40 };
41 return loopNestBuilder(ValueRange(lb), ValueRange(ub), ValueRange(step),
42 wrapper);
43 }
44
loopNestBuilder(Value lb,Value ub,Value step,ValueRange iterArgInitValues,function_ref<scf::ValueVector (Value,ValueRange)> fun)45 mlir::scf::LoopNest mlir::edsc::loopNestBuilder(
46 Value lb, Value ub, Value step, ValueRange iterArgInitValues,
47 function_ref<scf::ValueVector(Value, ValueRange)> fun) {
48 // Delegates actual construction to scf::buildLoopNest by wrapping `fun` into
49 // the expected function interface.
50 assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
51 return mlir::scf::buildLoopNest(
52 ScopedContext::getBuilderRef(), ScopedContext::getLocation(), lb, ub,
53 step, iterArgInitValues,
54 [&](OpBuilder &builder, Location loc, ValueRange ivs, ValueRange args) {
55 assert(ivs.size() == 1 && "expected one induction variable");
56 ScopedContext context(builder, loc);
57 if (fun)
58 return fun(ivs[0], args);
59 return scf::ValueVector(iterArgInitValues.begin(),
60 iterArgInitValues.end());
61 });
62 }
63
loopNestBuilder(ValueRange lbs,ValueRange ubs,ValueRange steps,ValueRange iterArgInitValues,function_ref<scf::ValueVector (ValueRange,ValueRange)> fun)64 mlir::scf::LoopNest mlir::edsc::loopNestBuilder(
65 ValueRange lbs, ValueRange ubs, ValueRange steps,
66 ValueRange iterArgInitValues,
67 function_ref<scf::ValueVector(ValueRange, ValueRange)> fun) {
68 // Delegates actual construction to scf::buildLoopNest by wrapping `fun` into
69 // the expected function interface.
70 assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
71 return mlir::scf::buildLoopNest(
72 ScopedContext::getBuilderRef(), ScopedContext::getLocation(), lbs, ubs,
73 steps, iterArgInitValues,
74 [&](OpBuilder &builder, Location loc, ValueRange ivs, ValueRange args) {
75 ScopedContext context(builder, loc);
76 if (fun)
77 return fun(ivs, args);
78 return scf::ValueVector(iterArgInitValues.begin(),
79 iterArgInitValues.end());
80 });
81 }
82
83 static std::function<void(OpBuilder &, Location)>
wrapIfBody(function_ref<scf::ValueVector ()> body,TypeRange expectedTypes)84 wrapIfBody(function_ref<scf::ValueVector()> body, TypeRange expectedTypes) {
85 (void)expectedTypes;
86 return [=](OpBuilder &builder, Location loc) {
87 ScopedContext context(builder, loc);
88 scf::ValueVector returned = body();
89 assert(ValueRange(returned).getTypes() == expectedTypes &&
90 "'if' body builder returned values of unexpected type");
91 builder.create<scf::YieldOp>(loc, returned);
92 };
93 }
94
95 ValueRange
conditionBuilder(TypeRange results,Value condition,function_ref<scf::ValueVector ()> thenBody,function_ref<scf::ValueVector ()> elseBody,scf::IfOp * ifOp)96 mlir::edsc::conditionBuilder(TypeRange results, Value condition,
97 function_ref<scf::ValueVector()> thenBody,
98 function_ref<scf::ValueVector()> elseBody,
99 scf::IfOp *ifOp) {
100 assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
101 assert(thenBody && "thenBody is mandatory");
102
103 auto newOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
104 ScopedContext::getLocation(), results, condition,
105 wrapIfBody(thenBody, results), wrapIfBody(elseBody, results));
106 if (ifOp)
107 *ifOp = newOp;
108 return newOp.getResults();
109 }
110
111 static std::function<void(OpBuilder &, Location)>
wrapZeroResultIfBody(function_ref<void ()> body)112 wrapZeroResultIfBody(function_ref<void()> body) {
113 return [=](OpBuilder &builder, Location loc) {
114 ScopedContext context(builder, loc);
115 body();
116 builder.create<scf::YieldOp>(loc);
117 };
118 }
119
conditionBuilder(Value condition,function_ref<void ()> thenBody,function_ref<void ()> elseBody,scf::IfOp * ifOp)120 ValueRange mlir::edsc::conditionBuilder(Value condition,
121 function_ref<void()> thenBody,
122 function_ref<void()> elseBody,
123 scf::IfOp *ifOp) {
124 assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
125 assert(thenBody && "thenBody is mandatory");
126
127 auto newOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
128 ScopedContext::getLocation(), condition, wrapZeroResultIfBody(thenBody),
129 elseBody ? llvm::function_ref<void(OpBuilder &, Location)>(
130 wrapZeroResultIfBody(elseBody))
131 : llvm::function_ref<void(OpBuilder &, Location)>(nullptr));
132 if (ifOp)
133 *ifOp = newOp;
134 return {};
135 }
136