1 //===- CommonFolders.h - Common Operation Folders----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file declares various common operation folders. These folders
10 // are intended to be used by dialects to support common folding behavior
11 // without requiring each dialect to provide its own implementation.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #ifndef MLIR_DIALECT_COMMONFOLDERS_H
16 #define MLIR_DIALECT_COMMONFOLDERS_H
17
18 #include "mlir/IR/BuiltinAttributes.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/STLExtras.h"
22
23 namespace mlir {
24 /// Performs constant folding `calculate` with element-wise behavior on the two
25 /// attributes in `operands` and returns the result if possible.
26 template <class AttrElementT,
27 class ElementValueT = typename AttrElementT::ValueType,
28 class CalculationT =
29 function_ref<ElementValueT(ElementValueT, ElementValueT)>>
constFoldBinaryOp(ArrayRef<Attribute> operands,const CalculationT & calculate)30 Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
31 const CalculationT &calculate) {
32 assert(operands.size() == 2 && "binary op takes two operands");
33 if (!operands[0] || !operands[1])
34 return {};
35 if (operands[0].getType() != operands[1].getType())
36 return {};
37
38 if (operands[0].isa<AttrElementT>() && operands[1].isa<AttrElementT>()) {
39 auto lhs = operands[0].cast<AttrElementT>();
40 auto rhs = operands[1].cast<AttrElementT>();
41
42 return AttrElementT::get(lhs.getType(),
43 calculate(lhs.getValue(), rhs.getValue()));
44 } else if (operands[0].isa<SplatElementsAttr>() &&
45 operands[1].isa<SplatElementsAttr>()) {
46 // Both operands are splats so we can avoid expanding the values out and
47 // just fold based on the splat value.
48 auto lhs = operands[0].cast<SplatElementsAttr>();
49 auto rhs = operands[1].cast<SplatElementsAttr>();
50
51 auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
52 rhs.getSplatValue<ElementValueT>());
53 return DenseElementsAttr::get(lhs.getType(), elementResult);
54 } else if (operands[0].isa<ElementsAttr>() &&
55 operands[1].isa<ElementsAttr>()) {
56 // Operands are ElementsAttr-derived; perform an element-wise fold by
57 // expanding the values.
58 auto lhs = operands[0].cast<ElementsAttr>();
59 auto rhs = operands[1].cast<ElementsAttr>();
60
61 auto lhsIt = lhs.getValues<ElementValueT>().begin();
62 auto rhsIt = rhs.getValues<ElementValueT>().begin();
63 SmallVector<ElementValueT, 4> elementResults;
64 elementResults.reserve(lhs.getNumElements());
65 for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt)
66 elementResults.push_back(calculate(*lhsIt, *rhsIt));
67 return DenseElementsAttr::get(lhs.getType(), elementResults);
68 }
69 return {};
70 }
71 } // namespace mlir
72
73 #endif // MLIR_DIALECT_COMMONFOLDERS_H
74