1 //===- Builders.cpp - MLIR Declarative Builder Classes --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/EDSC/Builders.h"
10 #include "mlir/Dialect/SCF/SCF.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
13 
14 using namespace mlir;
15 using namespace mlir::edsc;
16 
17 mlir::scf::LoopNest
loopNestBuilder(ValueRange lbs,ValueRange ubs,ValueRange steps,function_ref<void (ValueRange)> fun)18 mlir::edsc::loopNestBuilder(ValueRange lbs, ValueRange ubs, ValueRange steps,
19                             function_ref<void(ValueRange)> fun) {
20   // Delegates actual construction to scf::buildLoopNest by wrapping `fun` into
21   // the expected function interface.
22   assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
23   return mlir::scf::buildLoopNest(
24       ScopedContext::getBuilderRef(), ScopedContext::getLocation(), lbs, ubs,
25       steps, [&](OpBuilder &builder, Location loc, ValueRange ivs) {
26         ScopedContext context(builder, loc);
27         if (fun)
28           fun(ivs);
29       });
30 }
31 
32 mlir::scf::LoopNest
loopNestBuilder(Value lb,Value ub,Value step,function_ref<void (Value)> fun)33 mlir::edsc::loopNestBuilder(Value lb, Value ub, Value step,
34                             function_ref<void(Value)> fun) {
35   // Delegates to the ValueRange-based version by wrapping the lambda.
36   auto wrapper = [&](ValueRange ivs) {
37     assert(ivs.size() == 1);
38     if (fun)
39       fun(ivs[0]);
40   };
41   return loopNestBuilder(ValueRange(lb), ValueRange(ub), ValueRange(step),
42                          wrapper);
43 }
44 
loopNestBuilder(Value lb,Value ub,Value step,ValueRange iterArgInitValues,function_ref<scf::ValueVector (Value,ValueRange)> fun)45 mlir::scf::LoopNest mlir::edsc::loopNestBuilder(
46     Value lb, Value ub, Value step, ValueRange iterArgInitValues,
47     function_ref<scf::ValueVector(Value, ValueRange)> fun) {
48   // Delegates actual construction to scf::buildLoopNest by wrapping `fun` into
49   // the expected function interface.
50   assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
51   return mlir::scf::buildLoopNest(
52       ScopedContext::getBuilderRef(), ScopedContext::getLocation(), lb, ub,
53       step, iterArgInitValues,
54       [&](OpBuilder &builder, Location loc, ValueRange ivs, ValueRange args) {
55         assert(ivs.size() == 1 && "expected one induction variable");
56         ScopedContext context(builder, loc);
57         if (fun)
58           return fun(ivs[0], args);
59         return scf::ValueVector(iterArgInitValues.begin(),
60                                 iterArgInitValues.end());
61       });
62 }
63 
loopNestBuilder(ValueRange lbs,ValueRange ubs,ValueRange steps,ValueRange iterArgInitValues,function_ref<scf::ValueVector (ValueRange,ValueRange)> fun)64 mlir::scf::LoopNest mlir::edsc::loopNestBuilder(
65     ValueRange lbs, ValueRange ubs, ValueRange steps,
66     ValueRange iterArgInitValues,
67     function_ref<scf::ValueVector(ValueRange, ValueRange)> fun) {
68   // Delegates actual construction to scf::buildLoopNest by wrapping `fun` into
69   // the expected function interface.
70   assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
71   return mlir::scf::buildLoopNest(
72       ScopedContext::getBuilderRef(), ScopedContext::getLocation(), lbs, ubs,
73       steps, iterArgInitValues,
74       [&](OpBuilder &builder, Location loc, ValueRange ivs, ValueRange args) {
75         ScopedContext context(builder, loc);
76         if (fun)
77           return fun(ivs, args);
78         return scf::ValueVector(iterArgInitValues.begin(),
79                                 iterArgInitValues.end());
80       });
81 }
82 
83 static std::function<void(OpBuilder &, Location)>
wrapIfBody(function_ref<scf::ValueVector ()> body,TypeRange expectedTypes)84 wrapIfBody(function_ref<scf::ValueVector()> body, TypeRange expectedTypes) {
85   (void)expectedTypes;
86   return [=](OpBuilder &builder, Location loc) {
87     ScopedContext context(builder, loc);
88     scf::ValueVector returned = body();
89     assert(ValueRange(returned).getTypes() == expectedTypes &&
90            "'if' body builder returned values of unexpected type");
91     builder.create<scf::YieldOp>(loc, returned);
92   };
93 }
94 
95 ValueRange
conditionBuilder(TypeRange results,Value condition,function_ref<scf::ValueVector ()> thenBody,function_ref<scf::ValueVector ()> elseBody,scf::IfOp * ifOp)96 mlir::edsc::conditionBuilder(TypeRange results, Value condition,
97                              function_ref<scf::ValueVector()> thenBody,
98                              function_ref<scf::ValueVector()> elseBody,
99                              scf::IfOp *ifOp) {
100   assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
101   assert(thenBody && "thenBody is mandatory");
102 
103   auto newOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
104       ScopedContext::getLocation(), results, condition,
105       wrapIfBody(thenBody, results), wrapIfBody(elseBody, results));
106   if (ifOp)
107     *ifOp = newOp;
108   return newOp.getResults();
109 }
110 
111 static std::function<void(OpBuilder &, Location)>
wrapZeroResultIfBody(function_ref<void ()> body)112 wrapZeroResultIfBody(function_ref<void()> body) {
113   return [=](OpBuilder &builder, Location loc) {
114     ScopedContext context(builder, loc);
115     body();
116     builder.create<scf::YieldOp>(loc);
117   };
118 }
119 
conditionBuilder(Value condition,function_ref<void ()> thenBody,function_ref<void ()> elseBody,scf::IfOp * ifOp)120 ValueRange mlir::edsc::conditionBuilder(Value condition,
121                                         function_ref<void()> thenBody,
122                                         function_ref<void()> elseBody,
123                                         scf::IfOp *ifOp) {
124   assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
125   assert(thenBody && "thenBody is mandatory");
126 
127   auto newOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
128       ScopedContext::getLocation(), condition, wrapZeroResultIfBody(thenBody),
129       elseBody ? llvm::function_ref<void(OpBuilder &, Location)>(
130                      wrapZeroResultIfBody(elseBody))
131                : llvm::function_ref<void(OpBuilder &, Location)>(nullptr));
132   if (ifOp)
133     *ifOp = newOp;
134   return {};
135 }
136