1 //===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Builders.h"
10 #include "mlir/IR/AffineExpr.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BlockAndValueMapping.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/SymbolTable.h"
18 #include "llvm/Support/raw_ostream.h"
19 
20 using namespace mlir;
21 
getIdentifier(StringRef str)22 Identifier Builder::getIdentifier(StringRef str) {
23   return Identifier::get(str, context);
24 }
25 
26 //===----------------------------------------------------------------------===//
27 // Locations.
28 //===----------------------------------------------------------------------===//
29 
getUnknownLoc()30 Location Builder::getUnknownLoc() { return UnknownLoc::get(context); }
31 
getFileLineColLoc(Identifier filename,unsigned line,unsigned column)32 Location Builder::getFileLineColLoc(Identifier filename, unsigned line,
33                                     unsigned column) {
34   return FileLineColLoc::get(filename, line, column, context);
35 }
36 
getFusedLoc(ArrayRef<Location> locs,Attribute metadata)37 Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
38   return FusedLoc::get(locs, metadata, context);
39 }
40 
41 //===----------------------------------------------------------------------===//
42 // Types.
43 //===----------------------------------------------------------------------===//
44 
getBF16Type()45 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
46 
getF16Type()47 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
48 
getF32Type()49 FloatType Builder::getF32Type() { return FloatType::getF32(context); }
50 
getF64Type()51 FloatType Builder::getF64Type() { return FloatType::getF64(context); }
52 
getIndexType()53 IndexType Builder::getIndexType() { return IndexType::get(context); }
54 
getI1Type()55 IntegerType Builder::getI1Type() { return IntegerType::get(1, context); }
56 
getI32Type()57 IntegerType Builder::getI32Type() { return IntegerType::get(32, context); }
58 
getI64Type()59 IntegerType Builder::getI64Type() { return IntegerType::get(64, context); }
60 
getIntegerType(unsigned width)61 IntegerType Builder::getIntegerType(unsigned width) {
62   return IntegerType::get(width, context);
63 }
64 
getIntegerType(unsigned width,bool isSigned)65 IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
66   return IntegerType::get(
67       width, isSigned ? IntegerType::Signed : IntegerType::Unsigned, context);
68 }
69 
getFunctionType(TypeRange inputs,TypeRange results)70 FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) {
71   return FunctionType::get(inputs, results, context);
72 }
73 
getTupleType(TypeRange elementTypes)74 TupleType Builder::getTupleType(TypeRange elementTypes) {
75   return TupleType::get(elementTypes, context);
76 }
77 
getNoneType()78 NoneType Builder::getNoneType() { return NoneType::get(context); }
79 
80 //===----------------------------------------------------------------------===//
81 // Attributes.
82 //===----------------------------------------------------------------------===//
83 
getNamedAttr(StringRef name,Attribute val)84 NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
85   return NamedAttribute(getIdentifier(name), val);
86 }
87 
getUnitAttr()88 UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
89 
getBoolAttr(bool value)90 BoolAttr Builder::getBoolAttr(bool value) {
91   return BoolAttr::get(value, context);
92 }
93 
getDictionaryAttr(ArrayRef<NamedAttribute> value)94 DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
95   return DictionaryAttr::get(value, context);
96 }
97 
getIndexAttr(int64_t value)98 IntegerAttr Builder::getIndexAttr(int64_t value) {
99   return IntegerAttr::get(getIndexType(), APInt(64, value));
100 }
101 
getI64IntegerAttr(int64_t value)102 IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
103   return IntegerAttr::get(getIntegerType(64), APInt(64, value));
104 }
105 
getBoolVectorAttr(ArrayRef<bool> values)106 DenseIntElementsAttr Builder::getBoolVectorAttr(ArrayRef<bool> values) {
107   return DenseIntElementsAttr::get(
108       VectorType::get(static_cast<int64_t>(values.size()), getI1Type()),
109       values);
110 }
111 
getI32VectorAttr(ArrayRef<int32_t> values)112 DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
113   return DenseIntElementsAttr::get(
114       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(32)),
115       values);
116 }
117 
getI64VectorAttr(ArrayRef<int64_t> values)118 DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) {
119   return DenseIntElementsAttr::get(
120       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(64)),
121       values);
122 }
123 
getI32TensorAttr(ArrayRef<int32_t> values)124 DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef<int32_t> values) {
125   return DenseIntElementsAttr::get(
126       RankedTensorType::get(static_cast<int64_t>(values.size()),
127                             getIntegerType(32)),
128       values);
129 }
130 
getI64TensorAttr(ArrayRef<int64_t> values)131 DenseIntElementsAttr Builder::getI64TensorAttr(ArrayRef<int64_t> values) {
132   return DenseIntElementsAttr::get(
133       RankedTensorType::get(static_cast<int64_t>(values.size()),
134                             getIntegerType(64)),
135       values);
136 }
137 
getIndexTensorAttr(ArrayRef<int64_t> values)138 DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
139   return DenseIntElementsAttr::get(
140       RankedTensorType::get(static_cast<int64_t>(values.size()),
141                             getIndexType()),
142       values);
143 }
144 
getI32IntegerAttr(int32_t value)145 IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
146   return IntegerAttr::get(getIntegerType(32), APInt(32, value));
147 }
148 
getSI32IntegerAttr(int32_t value)149 IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
150   return IntegerAttr::get(getIntegerType(32, /*isSigned=*/true),
151                           APInt(32, value, /*isSigned=*/true));
152 }
153 
getUI32IntegerAttr(uint32_t value)154 IntegerAttr Builder::getUI32IntegerAttr(uint32_t value) {
155   return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false),
156                           APInt(32, (uint64_t)value, /*isSigned=*/false));
157 }
158 
getI16IntegerAttr(int16_t value)159 IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
160   return IntegerAttr::get(getIntegerType(16), APInt(16, value));
161 }
162 
getI8IntegerAttr(int8_t value)163 IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
164   return IntegerAttr::get(getIntegerType(8), APInt(8, value));
165 }
166 
getIntegerAttr(Type type,int64_t value)167 IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
168   if (type.isIndex())
169     return IntegerAttr::get(type, APInt(64, value));
170   return IntegerAttr::get(
171       type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
172 }
173 
getIntegerAttr(Type type,const APInt & value)174 IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
175   return IntegerAttr::get(type, value);
176 }
177 
getF64FloatAttr(double value)178 FloatAttr Builder::getF64FloatAttr(double value) {
179   return FloatAttr::get(getF64Type(), APFloat(value));
180 }
181 
getF32FloatAttr(float value)182 FloatAttr Builder::getF32FloatAttr(float value) {
183   return FloatAttr::get(getF32Type(), APFloat(value));
184 }
185 
getF16FloatAttr(float value)186 FloatAttr Builder::getF16FloatAttr(float value) {
187   return FloatAttr::get(getF16Type(), value);
188 }
189 
getFloatAttr(Type type,double value)190 FloatAttr Builder::getFloatAttr(Type type, double value) {
191   return FloatAttr::get(type, value);
192 }
193 
getFloatAttr(Type type,const APFloat & value)194 FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
195   return FloatAttr::get(type, value);
196 }
197 
getStringAttr(StringRef bytes)198 StringAttr Builder::getStringAttr(StringRef bytes) {
199   return StringAttr::get(bytes, context);
200 }
201 
getArrayAttr(ArrayRef<Attribute> value)202 ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
203   return ArrayAttr::get(value, context);
204 }
205 
getSymbolRefAttr(Operation * value)206 FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
207   auto symName =
208       value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
209   assert(symName && "value does not have a valid symbol name");
210   return getSymbolRefAttr(symName.getValue());
211 }
getSymbolRefAttr(StringRef value)212 FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
213   return SymbolRefAttr::get(value, getContext());
214 }
215 SymbolRefAttr
getSymbolRefAttr(StringRef value,ArrayRef<FlatSymbolRefAttr> nestedReferences)216 Builder::getSymbolRefAttr(StringRef value,
217                           ArrayRef<FlatSymbolRefAttr> nestedReferences) {
218   return SymbolRefAttr::get(value, nestedReferences, getContext());
219 }
220 
getBoolArrayAttr(ArrayRef<bool> values)221 ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
222   auto attrs = llvm::to_vector<8>(llvm::map_range(
223       values, [this](bool v) -> Attribute { return getBoolAttr(v); }));
224   return getArrayAttr(attrs);
225 }
226 
getI32ArrayAttr(ArrayRef<int32_t> values)227 ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
228   auto attrs = llvm::to_vector<8>(llvm::map_range(
229       values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }));
230   return getArrayAttr(attrs);
231 }
getI64ArrayAttr(ArrayRef<int64_t> values)232 ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
233   auto attrs = llvm::to_vector<8>(llvm::map_range(
234       values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }));
235   return getArrayAttr(attrs);
236 }
237 
getIndexArrayAttr(ArrayRef<int64_t> values)238 ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
239   auto attrs = llvm::to_vector<8>(
240       llvm::map_range(values, [this](int64_t v) -> Attribute {
241         return getIntegerAttr(IndexType::get(getContext()), v);
242       }));
243   return getArrayAttr(attrs);
244 }
245 
getF32ArrayAttr(ArrayRef<float> values)246 ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
247   auto attrs = llvm::to_vector<8>(llvm::map_range(
248       values, [this](float v) -> Attribute { return getF32FloatAttr(v); }));
249   return getArrayAttr(attrs);
250 }
251 
getF64ArrayAttr(ArrayRef<double> values)252 ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
253   auto attrs = llvm::to_vector<8>(llvm::map_range(
254       values, [this](double v) -> Attribute { return getF64FloatAttr(v); }));
255   return getArrayAttr(attrs);
256 }
257 
getStrArrayAttr(ArrayRef<StringRef> values)258 ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
259   auto attrs = llvm::to_vector<8>(llvm::map_range(
260       values, [this](StringRef v) -> Attribute { return getStringAttr(v); }));
261   return getArrayAttr(attrs);
262 }
263 
getTypeArrayAttr(TypeRange values)264 ArrayAttr Builder::getTypeArrayAttr(TypeRange values) {
265   auto attrs = llvm::to_vector<8>(llvm::map_range(
266       values, [](Type v) -> Attribute { return TypeAttr::get(v); }));
267   return getArrayAttr(attrs);
268 }
269 
getAffineMapArrayAttr(ArrayRef<AffineMap> values)270 ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
271   auto attrs = llvm::to_vector<8>(llvm::map_range(
272       values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }));
273   return getArrayAttr(attrs);
274 }
275 
getZeroAttr(Type type)276 Attribute Builder::getZeroAttr(Type type) {
277   if (type.isa<FloatType>())
278     return getFloatAttr(type, 0.0);
279   if (type.isa<IndexType>())
280     return getIndexAttr(0);
281   if (auto integerType = type.dyn_cast<IntegerType>())
282     return getIntegerAttr(type, APInt(type.cast<IntegerType>().getWidth(), 0));
283   if (type.isa<RankedTensorType, VectorType>()) {
284     auto vtType = type.cast<ShapedType>();
285     auto element = getZeroAttr(vtType.getElementType());
286     if (!element)
287       return {};
288     return DenseElementsAttr::get(vtType, element);
289   }
290   return {};
291 }
292 
293 //===----------------------------------------------------------------------===//
294 // Affine Expressions, Affine Maps, and Integer Sets.
295 //===----------------------------------------------------------------------===//
296 
getAffineDimExpr(unsigned position)297 AffineExpr Builder::getAffineDimExpr(unsigned position) {
298   return mlir::getAffineDimExpr(position, context);
299 }
300 
getAffineSymbolExpr(unsigned position)301 AffineExpr Builder::getAffineSymbolExpr(unsigned position) {
302   return mlir::getAffineSymbolExpr(position, context);
303 }
304 
getAffineConstantExpr(int64_t constant)305 AffineExpr Builder::getAffineConstantExpr(int64_t constant) {
306   return mlir::getAffineConstantExpr(constant, context);
307 }
308 
getEmptyAffineMap()309 AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); }
310 
getConstantAffineMap(int64_t val)311 AffineMap Builder::getConstantAffineMap(int64_t val) {
312   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
313                         getAffineConstantExpr(val));
314 }
315 
getDimIdentityMap()316 AffineMap Builder::getDimIdentityMap() {
317   return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, getAffineDimExpr(0));
318 }
319 
getMultiDimIdentityMap(unsigned rank)320 AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
321   SmallVector<AffineExpr, 4> dimExprs;
322   dimExprs.reserve(rank);
323   for (unsigned i = 0; i < rank; ++i)
324     dimExprs.push_back(getAffineDimExpr(i));
325   return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs,
326                         context);
327 }
328 
getSymbolIdentityMap()329 AffineMap Builder::getSymbolIdentityMap() {
330   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
331                         getAffineSymbolExpr(0));
332 }
333 
getSingleDimShiftAffineMap(int64_t shift)334 AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
335   // expr = d0 + shift.
336   auto expr = getAffineDimExpr(0) + shift;
337   return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
338 }
339 
getShiftedAffineMap(AffineMap map,int64_t shift)340 AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
341   SmallVector<AffineExpr, 4> shiftedResults;
342   shiftedResults.reserve(map.getNumResults());
343   for (auto resultExpr : map.getResults())
344     shiftedResults.push_back(resultExpr + shift);
345   return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults,
346                         context);
347 }
348 
349 //===----------------------------------------------------------------------===//
350 // OpBuilder
351 //===----------------------------------------------------------------------===//
352 
~Listener()353 OpBuilder::Listener::~Listener() {}
354 
355 /// Insert the given operation at the current insertion point and return it.
insert(Operation * op)356 Operation *OpBuilder::insert(Operation *op) {
357   if (block)
358     block->getOperations().insert(insertPoint, op);
359 
360   if (listener)
361     listener->notifyOperationInserted(op);
362   return op;
363 }
364 
365 /// Add new block with 'argTypes' arguments and set the insertion point to the
366 /// end of it. The block is inserted at the provided insertion point of
367 /// 'parent'.
createBlock(Region * parent,Region::iterator insertPt,TypeRange argTypes)368 Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
369                               TypeRange argTypes) {
370   assert(parent && "expected valid parent region");
371   if (insertPt == Region::iterator())
372     insertPt = parent->end();
373 
374   Block *b = new Block();
375   b->addArguments(argTypes);
376   parent->getBlocks().insert(insertPt, b);
377   setInsertionPointToEnd(b);
378 
379   if (listener)
380     listener->notifyBlockCreated(b);
381   return b;
382 }
383 
384 /// Add new block with 'argTypes' arguments and set the insertion point to the
385 /// end of it.  The block is placed before 'insertBefore'.
createBlock(Block * insertBefore,TypeRange argTypes)386 Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes) {
387   assert(insertBefore && "expected valid insertion block");
388   return createBlock(insertBefore->getParent(), Region::iterator(insertBefore),
389                      argTypes);
390 }
391 
392 /// Create an operation given the fields represented as an OperationState.
createOperation(const OperationState & state)393 Operation *OpBuilder::createOperation(const OperationState &state) {
394   return insert(Operation::create(state));
395 }
396 
/// Attempts to fold the given operation and places new results within
/// 'results'. Returns success if the operation was folded, failure otherwise.
/// Note: This function does not erase the operation on a successful fold.
///
/// On failure, 'results' is overwritten with the operation's own results. On
/// success, one entry per fold result is appended to 'results'. Each entry is
/// either a folded Value or the result of a constant materialized by the op's
/// dialect. Those constants are inserted at this builder's insertion point
/// (via insert(), which also notifies the listener).
LogicalResult OpBuilder::tryFold(Operation *op,
                                 SmallVectorImpl<Value> &results) {
  results.reserve(op->getNumResults());
  // Every failure path reports the operation's own (unfolded) results.
  auto cleanupFailure = [&] {
    results.assign(op->result_begin(), op->result_end());
    return failure();
  };

  // If this operation is already a constant, there is nothing to do.
  if (matchPattern(op, m_Constant()))
    return cleanupFailure();

  // Record the constant value of each operand that is a constant (slots for
  // non-constant operands presumably stay null), so the operation can try to
  // constant-fold itself.
  SmallVector<Attribute, 4> constOperands(op->getNumOperands());
  for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
    matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));

  // Try to fold the operation. A fold that succeeds but yields no results is
  // also treated as failure by this helper.
  SmallVector<OpFoldResult, 4> foldResults;
  if (failed(op->fold(constOperands, foldResults)) || foldResults.empty())
    return cleanupFailure();

  // A temporary builder used for creating constants during folding. It is
  // constructed from the context alone. The constants it produces are inserted
  // through this builder only after every fold result was handled.
  OpBuilder cstBuilder(context);
  SmallVector<Operation *, 1> generatedConstants;

  // Populate the results with the folded results.
  Dialect *dialect = op->getDialect();
  for (auto &it : llvm::enumerate(foldResults)) {
    // Normal values get pushed back directly.
    if (auto value = it.value().dyn_cast<Value>()) {
      results.push_back(value);
      continue;
    }

    // Otherwise, try to materialize a constant operation. Without a dialect
    // there is nothing to ask. Since `dialect` is fixed for the whole loop,
    // this fires before any constant has been generated, so nothing needs to
    // be erased here.
    if (!dialect)
      return cleanupFailure();

    // Ask the dialect to materialize a constant operation for this value,
    // typed like the corresponding result of the original operation.
    Attribute attr = it.value().get<Attribute>();
    auto *constOp = dialect->materializeConstant(
        cstBuilder, attr, op->getResult(it.index()).getType(), op->getLoc());
    if (!constOp) {
      // Erase any generated constants.
      for (Operation *cst : generatedConstants)
        cst->erase();
      return cleanupFailure();
    }
    // The materialized operation must itself be recognized as a constant.
    assert(matchPattern(constOp, m_Constant()));

    generatedConstants.push_back(constOp);
    results.push_back(constOp->getResult(0));
  }

  // If we were successful, insert any generated constants.
  for (Operation *cst : generatedConstants)
    insert(cst);

  return success();
}
462 
clone(Operation & op,BlockAndValueMapping & mapper)463 Operation *OpBuilder::clone(Operation &op, BlockAndValueMapping &mapper) {
464   Operation *newOp = op.clone(mapper);
465   // The `insert` call below handles the notification for inserting `newOp`
466   // itself. But if `newOp` has any regions, we need to notify the listener
467   // about any ops that got inserted inside those regions as part of cloning.
468   if (listener) {
469     auto walkFn = [&](Operation *walkedOp) {
470       listener->notifyOperationInserted(walkedOp);
471     };
472     for (Region &region : newOp->getRegions())
473       region.walk(walkFn);
474   }
475   return insert(newOp);
476 }
477 
clone(Operation & op)478 Operation *OpBuilder::clone(Operation &op) {
479   BlockAndValueMapping mapper;
480   return clone(op, mapper);
481 }
482