1 //===- LinalgToSPIRV.cpp - Linalg to SPIR-V dialect conversion ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h"
10 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
11 #include "mlir/Dialect/Linalg/Utils/Utils.h"
12 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
13 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
14 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
17 #include "mlir/IR/AffineExpr.h"
18
19 using namespace mlir;
20
21 //===----------------------------------------------------------------------===//
22 // Utilities
23 //===----------------------------------------------------------------------===//
24
25 /// Returns a `Value` containing the `dim`-th dimension's size of SPIR-V
26 /// location invocation ID. This function will create necessary operations with
27 /// `builder` at the proper region containing `op`.
getLocalInvocationDimSize(Operation * op,int dim,Location loc,OpBuilder * builder)28 static Value getLocalInvocationDimSize(Operation *op, int dim, Location loc,
29 OpBuilder *builder) {
30 assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions");
31 Value invocation = spirv::getBuiltinVariableValue(
32 op, spirv::BuiltIn::LocalInvocationId, *builder);
33 Type xType = invocation.getType().cast<ShapedType>().getElementType();
34 return builder->create<spirv::CompositeExtractOp>(
35 loc, xType, invocation, builder->getI32ArrayAttr({dim}));
36 }
37
38 //===----------------------------------------------------------------------===//
39 // Reduction (single workgroup)
40 //===----------------------------------------------------------------------===//
41
42 namespace {
43
44 /// A pattern to convert a linalg.generic op to SPIR-V ops under the condition
45 /// that the linalg.generic op is performing reduction with a workload size that
46 /// can fit in one workgroup.
47 class SingleWorkgroupReduction final
48 : public SPIRVOpLowering<linalg::GenericOp> {
49 public:
50 using SPIRVOpLowering<linalg::GenericOp>::SPIRVOpLowering;
51
52 /// Matches the given linalg.generic op as performing reduction and returns
53 /// the binary op kind if successful.
54 static Optional<linalg::RegionMatcher::BinaryOpKind>
55 matchAsPerformingReduction(linalg::GenericOp genericOp);
56
57 LogicalResult
58 matchAndRewrite(linalg::GenericOp genericOp, ArrayRef<Value> operands,
59 ConversionPatternRewriter &rewriter) const override;
60 };
61
62 } // namespace
63
64 Optional<linalg::RegionMatcher::BinaryOpKind>
matchAsPerformingReduction(linalg::GenericOp genericOp)65 SingleWorkgroupReduction::matchAsPerformingReduction(
66 linalg::GenericOp genericOp) {
67 Operation *op = genericOp.getOperation();
68
69 // Make sure the linalg.generic is working on memrefs.
70 if (!genericOp.hasBufferSemantics())
71 return llvm::None;
72
73 // Make sure this is reduction with one input and one output.
74 if (genericOp.getNumInputs() != 1 || genericOp.getNumOutputs() != 1)
75 return llvm::None;
76
77 auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
78 auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();
79
80 // Make sure the original input has one dimension.
81 if (!originalInputType.hasStaticShape() || originalInputType.getRank() != 1)
82 return llvm::None;
83 // Make sure the original output has one element.
84 if (!originalOutputType.hasStaticShape() ||
85 originalOutputType.getNumElements() != 1)
86 return llvm::None;
87
88 if (!genericOp.hasSingleReductionLoop())
89 return llvm::None;
90
91 if (genericOp.indexing_maps().getValue().size() != 2)
92 return llvm::None;
93
94 // TODO: create utility functions for these checks in Linalg
95 // and use them.
96 auto inputMap = genericOp.indexing_maps().getValue()[0].cast<AffineMapAttr>();
97 auto outputMap =
98 genericOp.indexing_maps().getValue()[1].cast<AffineMapAttr>();
99 // The indexing map for the input should be `(i) -> (i)`.
100 if (inputMap.getValue() !=
101 AffineMap::get(1, 0, getAffineDimExpr(0, op->getContext())))
102 return llvm::None;
103 // The indexing map for the input should be `(i) -> (0)`.
104 if (outputMap.getValue() !=
105 AffineMap::get(1, 0, getAffineConstantExpr(0, op->getContext())))
106 return llvm::None;
107
108 return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp);
109 }
110
matchAndRewrite(linalg::GenericOp genericOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const111 LogicalResult SingleWorkgroupReduction::matchAndRewrite(
112 linalg::GenericOp genericOp, ArrayRef<Value> operands,
113 ConversionPatternRewriter &rewriter) const {
114 Operation *op = genericOp.getOperation();
115 auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
116 auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();
117
118 auto binaryOpKind = matchAsPerformingReduction(genericOp);
119 if (!binaryOpKind)
120 return failure();
121
122 // Query the shader interface for local workgroup size to make sure the
123 // invocation configuration fits with the input memref's shape.
124 DenseIntElementsAttr localSize = spirv::lookupLocalWorkGroupSize(genericOp);
125 if (!localSize)
126 return failure();
127
128 if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0))
129 return failure();
130 if (llvm::any_of(llvm::drop_begin(localSize.getIntValues(), 1),
131 [](const APInt &size) { return !size.isOneValue(); }))
132 return failure();
133
134 // TODO: Query the target environment to make sure the current
135 // workload fits in a local workgroup.
136
137 Value convertedInput = operands[0], convertedOutput = operands[1];
138 Location loc = genericOp.getLoc();
139
140 // Get the invocation ID.
141 Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, loc, &rewriter);
142
143 // TODO: Load to Workgroup storage class first.
144
145 // Get the input element accessed by this invocation.
146 Value inputElementPtr = spirv::getElementPtr(
147 typeConverter, originalInputType, convertedInput, {x}, loc, rewriter);
148 Value inputElement = rewriter.create<spirv::LoadOp>(loc, inputElementPtr);
149
150 // Perform the group reduction operation.
151 Value groupOperation;
152 #define CREATE_GROUP_NON_UNIFORM_BIN_OP(opKind, spvOp) \
153 case linalg::RegionMatcher::BinaryOpKind::opKind: { \
154 groupOperation = rewriter.create<spirv::spvOp>( \
155 loc, originalInputType.getElementType(), spirv::Scope::Subgroup, \
156 spirv::GroupOperation::Reduce, inputElement, \
157 /*cluster_size=*/nullptr); \
158 } break
159 switch (*binaryOpKind) {
160 CREATE_GROUP_NON_UNIFORM_BIN_OP(IAdd, GroupNonUniformIAddOp);
161 }
162 #undef CREATE_GROUP_NON_UNIFORM_BIN_OP
163
164 // Get the output element accessed by this reduction.
165 Value zero = spirv::ConstantOp::getZero(
166 typeConverter.getIndexType(rewriter.getContext()), loc, rewriter);
167 SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero);
168 Value outputElementPtr =
169 spirv::getElementPtr(typeConverter, originalOutputType, convertedOutput,
170 zeroIndices, loc, rewriter);
171
172 // Write out the final reduction result. This should be only conducted by one
173 // invocation. We use spv.GroupNonUniformElect to find the invocation with the
174 // lowest ID.
175 //
176 // ```
177 // if (spv.GroupNonUniformElect) { output = ... }
178 // ```
179
180 Value condition = rewriter.create<spirv::GroupNonUniformElectOp>(
181 loc, spirv::Scope::Subgroup);
182
183 auto createAtomicOp = [&](OpBuilder &builder) {
184 #define CREATE_ATOMIC_BIN_OP(opKind, spvOp) \
185 case linalg::RegionMatcher::BinaryOpKind::opKind: { \
186 builder.create<spirv::spvOp>(loc, outputElementPtr, spirv::Scope::Device, \
187 spirv::MemorySemantics::AcquireRelease, \
188 groupOperation); \
189 } break
190 switch (*binaryOpKind) { CREATE_ATOMIC_BIN_OP(IAdd, AtomicIAddOp); }
191 #undef CREATE_ATOMIC_BIN_OP
192 };
193
194 spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, rewriter);
195
196 rewriter.eraseOp(genericOp);
197 return success();
198 }
199
200 //===----------------------------------------------------------------------===//
201 // Pattern population
202 //===----------------------------------------------------------------------===//
203
populateLinalgToSPIRVPatterns(MLIRContext * context,SPIRVTypeConverter & typeConverter,OwningRewritePatternList & patterns)204 void mlir::populateLinalgToSPIRVPatterns(MLIRContext *context,
205 SPIRVTypeConverter &typeConverter,
206 OwningRewritePatternList &patterns) {
207 patterns.insert<SingleWorkgroupReduction>(context, typeConverter);
208 }
209