1 //===- LinalgTraits.h - Linalg Traits ---------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_LINALG_LINALGTRAITS_H_
10 #define MLIR_DIALECT_LINALG_LINALGTRAITS_H_
11 
12 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
13 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/Support/LLVM.h"
19 
20 namespace mlir {
21 namespace OpTrait {
22 namespace linalg {
23 
24 /// This class provides the API for ops that are known to have a specified
25 /// number of inputs, all passed as operands. Use as a trait as follows:
26 ///
27 ///   class DotOp : public Op<DotOp, OpTrait::NInputs<2>::Impl> {
28 ///
29 template <unsigned N> class NInputs {
30 public:
31   template <typename ConcreteType>
32   class Impl : public OpTrait::TraitBase<ConcreteType, NInputs<N>::Impl> {
33   public:
getNumInputs()34     static unsigned getNumInputs() { return N; }
35   };
36 };
37 
38 /// This class provides the API for ops that are known to not have init tensor
39 /// operands. Use as a trait as follows:
40 ///
41 ///   class CopyOp : public Op<CopyOp, OpTrait::ZeroInitTensors> {
42 ///
43 template <typename ConcreteType>
44 class ZeroInitTensors : public TraitBase<ConcreteType, ZeroInitTensors> {
45 public:
getNumInitTensors()46   static unsigned getNumInitTensors() { return 0; }
47 };
48 
49 /// This class provides the API for ops that are known to have a specified
50 /// number of outputs, all passed as operands. Use as a trait as follows:
51 ///
52 ///   class DotOp : public Op<DotOp, OpTrait::NOutputs<2>::Impl> {
53 ///
54 template <unsigned N> class NOutputs {
55 public:
56   template <typename ConcreteType>
57   class Impl : public OpTrait::TraitBase<ConcreteType, NOutputs<N>::Impl> {
58   public:
getNumOutputs()59     static unsigned getNumOutputs() { return N; }
60   };
61 };
62 
63 /// This class provides a verifier for structured ops that are known to operate
64 /// on buffers or tensors. This trait must be used in conjunction with an op
65 /// definition or a trait that provides the methods `getNumInputs` and
66 /// `getNumOutputs`. Use as a trait as follows:
67 ///
68 ///   class DotOp : public Op<DotOp, OpTrait::StructuredOpTraits> {
69 ///
70 template <typename ConcreteType>
71 class StructuredOpTraits
72     : public OpTrait::TraitBase<ConcreteType, StructuredOpTraits> {
73 public:
verifyTrait(Operation * op)74   static LogicalResult verifyTrait(Operation *op) {
75     ConcreteType concreteOp = cast<ConcreteType>(op);
76     auto nOperands = concreteOp.getNumInputsAndOutputBuffers();
77     if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands)))
78       return failure();
79     if (op->getNumResults() > concreteOp.getNumOutputs())
80       return op->emitError("unexpected #results > #outputs");
81     return success();
82   }
83 };
84 
85 /// This class provides a verifier for structured ops that are known to operate
86 /// on buffers or tensors and that support `ins`, `outs` and `init` arguments.
87 /// This trait must be used in conjunction with an op definition or a trait that
88 /// provides the methods `getNumInputs` and `getNumOutputs`.
89 ///
90 /// Use as a trait as follows:
91 ///
92 ///   class MatmulOp : public Op<MatmulOp, OpTrait::NamedStructuredOpTrait> {
93 ///
94 template <typename ConcreteType>
95 class NamedStructuredOpTrait
96     : public OpTrait::TraitBase<ConcreteType, NamedStructuredOpTrait> {
97 public:
getNumInputs()98   unsigned getNumInputs() {
99     return cast<ConcreteType>(this->getOperation()).inputs().size();
100   }
getNumInitTensors()101   unsigned getNumInitTensors() {
102     return cast<ConcreteType>(this->getOperation()).init_tensors().size();
103   }
getNumOutputs()104   unsigned getNumOutputs() {
105     ConcreteType concreteOp = cast<ConcreteType>(this->getOperation());
106     return concreteOp.output_buffers().size() +
107            concreteOp.result_tensors().size();
108   }
verifyTrait(Operation * op)109   static LogicalResult verifyTrait(Operation *op) {
110     ConcreteType concreteOp = cast<ConcreteType>(op);
111     unsigned nInputAndBufferOperands =
112         concreteOp.getNumInputsAndOutputBuffers();
113     if (failed(
114             OpTrait::impl::verifyAtLeastNOperands(op, nInputAndBufferOperands)))
115       return failure();
116 
117     SmallVector<AffineExpr, 4> redDims;
118     concreteOp.getReductionDims(redDims);
119     // If no result and no reduction, only check there is no init tensor and we
120     // are done.
121     if (redDims.empty() || op->getNumResults() == 0) {
122       if (!concreteOp.init_tensors().empty())
123         return op->emitError("expected empty `init` when op has no "
124                              "results or no reduction dims");
125       return success();
126     }
127 
128     // Only a single tensor result supported atm.
129     if (op->getNumResults() != 1)
130       return op->emitError(
131           "expected single tensor result when reduction present");
132 
133     if (concreteOp.init_tensors().size() != op->getNumResults())
134       return op->emitError(
135           "expected #init tensors to match #results when reduction present");
136 
137     for (unsigned idx = 0, e = op->getNumResults(); idx < e; ++idx)
138       if (concreteOp.init_tensors()[idx].getType() != op->getResultTypes()[idx])
139         return op->emitError("expected init tensor #")
140                << idx << " of the same type as result #" << idx;
141 
142     // Output tensor indexing map may not depend on reduction index.
143     // TODO: this is not yet tested. Add a test when linalg.generic switches to
144     // this representation.
145     for (unsigned idx = 0, e = concreteOp.getNumOutputs(); idx < e; ++idx) {
146       AffineMap outputMap = concreteOp.getOutputIndexingMap(idx);
147       for (auto expr : outputMap.getResults()) {
148         for (auto dim : redDims) {
149           unsigned pos = dim.cast<AffineDimExpr>().getPosition();
150           if (expr.isFunctionOfDim(pos))
151             return op->emitError(
152                        "unexpected single tensor output indexing map ")
153                    << "is function of reduction dim @" << pos;
154         }
155       }
156     }
157 
158     return success();
159   }
160 };
161 
162 } // namespace linalg
163 } // namespace OpTrait
164 } // namespace mlir
165 
166 #endif // MLIR_DIALECT_LINALG_LINALGTRAITS_H_
167