1 //===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
10 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
11 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/Pass/Pass.h"
14 
15 using namespace mlir;
16 
17 //===----------------------------------------------------------------------===//
18 // Printing op availability pass
19 //===----------------------------------------------------------------------===//
20 
21 namespace {
22 /// A pass for testing SPIR-V op availability.
23 struct PrintOpAvailability
24     : public PassWrapper<PrintOpAvailability, FunctionPass> {
25   void runOnFunction() override;
26 };
27 } // end anonymous namespace
28 
runOnFunction()29 void PrintOpAvailability::runOnFunction() {
30   auto f = getFunction();
31   llvm::outs() << f.getName() << "\n";
32 
33   Dialect *spvDialect = getContext().getLoadedDialect("spv");
34 
35   f->walk([&](Operation *op) {
36     if (op->getDialect() != spvDialect)
37       return WalkResult::advance();
38 
39     auto opName = op->getName();
40     auto &os = llvm::outs();
41 
42     if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
43       os << opName << " min version: "
44          << spirv::stringifyVersion(minVersion.getMinVersion()) << "\n";
45 
46     if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
47       os << opName << " max version: "
48          << spirv::stringifyVersion(maxVersion.getMaxVersion()) << "\n";
49 
50     if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) {
51       os << opName << " extensions: [";
52       for (const auto &exts : extension.getExtensions()) {
53         os << " [";
54         llvm::interleaveComma(exts, os, [&](spirv::Extension ext) {
55           os << spirv::stringifyExtension(ext);
56         });
57         os << "]";
58       }
59       os << " ]\n";
60     }
61 
62     if (auto capability = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
63       os << opName << " capabilities: [";
64       for (const auto &caps : capability.getCapabilities()) {
65         os << " [";
66         llvm::interleaveComma(caps, os, [&](spirv::Capability cap) {
67           os << spirv::stringifyCapability(cap);
68         });
69         os << "]";
70       }
71       os << " ]\n";
72     }
73     os.flush();
74 
75     return WalkResult::advance();
76   });
77 }
78 
79 namespace mlir {
registerPrintOpAvailabilityPass()80 void registerPrintOpAvailabilityPass() {
81   PassRegistration<PrintOpAvailability> printOpAvailabilityPass(
82       "test-spirv-op-availability", "Test SPIR-V op availability");
83 }
84 } // namespace mlir
85 
86 //===----------------------------------------------------------------------===//
87 // Converting target environment pass
88 //===----------------------------------------------------------------------===//
89 
90 namespace {
91 /// A pass for testing SPIR-V op availability.
92 struct ConvertToTargetEnv
93     : public PassWrapper<ConvertToTargetEnv, FunctionPass> {
94   void runOnFunction() override;
95 };
96 
97 struct ConvertToAtomCmpExchangeWeak : public RewritePattern {
98   ConvertToAtomCmpExchangeWeak(MLIRContext *context);
99   LogicalResult matchAndRewrite(Operation *op,
100                                 PatternRewriter &rewriter) const override;
101 };
102 
103 struct ConvertToBitReverse : public RewritePattern {
104   ConvertToBitReverse(MLIRContext *context);
105   LogicalResult matchAndRewrite(Operation *op,
106                                 PatternRewriter &rewriter) const override;
107 };
108 
109 struct ConvertToGroupNonUniformBallot : public RewritePattern {
110   ConvertToGroupNonUniformBallot(MLIRContext *context);
111   LogicalResult matchAndRewrite(Operation *op,
112                                 PatternRewriter &rewriter) const override;
113 };
114 
115 struct ConvertToModule : public RewritePattern {
116   ConvertToModule(MLIRContext *context);
117   LogicalResult matchAndRewrite(Operation *op,
118                                 PatternRewriter &rewriter) const override;
119 };
120 
121 struct ConvertToSubgroupBallot : public RewritePattern {
122   ConvertToSubgroupBallot(MLIRContext *context);
123   LogicalResult matchAndRewrite(Operation *op,
124                                 PatternRewriter &rewriter) const override;
125 };
126 } // end anonymous namespace
127 
runOnFunction()128 void ConvertToTargetEnv::runOnFunction() {
129   MLIRContext *context = &getContext();
130   FuncOp fn = getFunction();
131 
132   auto targetEnv = fn.getOperation()
133                        ->getAttr(spirv::getTargetEnvAttrName())
134                        .cast<spirv::TargetEnvAttr>();
135   if (!targetEnv) {
136     fn.emitError("missing 'spv.target_env' attribute");
137     return signalPassFailure();
138   }
139 
140   auto target = spirv::SPIRVConversionTarget::get(targetEnv);
141 
142   OwningRewritePatternList patterns;
143   patterns.insert<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
144                   ConvertToGroupNonUniformBallot, ConvertToModule,
145                   ConvertToSubgroupBallot>(context);
146 
147   if (failed(applyPartialConversion(fn, *target, std::move(patterns))))
148     return signalPassFailure();
149 }
150 
ConvertToAtomCmpExchangeWeak(MLIRContext * context)151 ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context)
152     : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op",
153                      {"spv.AtomicCompareExchangeWeak"}, 1, context) {}
154 
155 LogicalResult
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const156 ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
157                                               PatternRewriter &rewriter) const {
158   Value ptr = op->getOperand(0);
159   Value value = op->getOperand(1);
160   Value comparator = op->getOperand(2);
161 
162   // Create a spv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits in
163   // memory semantics to additionally require AtomicStorage capability.
164   rewriter.replaceOpWithNewOp<spirv::AtomicCompareExchangeWeakOp>(
165       op, value.getType(), ptr, spirv::Scope::Workgroup,
166       spirv::MemorySemantics::AcquireRelease |
167           spirv::MemorySemantics::AtomicCounterMemory,
168       spirv::MemorySemantics::Acquire, value, comparator);
169   return success();
170 }
171 
ConvertToBitReverse(MLIRContext * context)172 ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context)
173     : RewritePattern("test.convert_to_bit_reverse_op", {"spv.BitReverse"}, 1,
174                      context) {}
175 
176 LogicalResult
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const177 ConvertToBitReverse::matchAndRewrite(Operation *op,
178                                      PatternRewriter &rewriter) const {
179   Value predicate = op->getOperand(0);
180 
181   rewriter.replaceOpWithNewOp<spirv::BitReverseOp>(
182       op, op->getResult(0).getType(), predicate);
183   return success();
184 }
185 
ConvertToGroupNonUniformBallot(MLIRContext * context)186 ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot(
187     MLIRContext *context)
188     : RewritePattern("test.convert_to_group_non_uniform_ballot_op",
189                      {"spv.GroupNonUniformBallot"}, 1, context) {}
190 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const191 LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite(
192     Operation *op, PatternRewriter &rewriter) const {
193   Value predicate = op->getOperand(0);
194 
195   rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>(
196       op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate);
197   return success();
198 }
199 
ConvertToModule(MLIRContext * context)200 ConvertToModule::ConvertToModule(MLIRContext *context)
201     : RewritePattern("test.convert_to_module_op", {"spv.module"}, 1, context) {}
202 
203 LogicalResult
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const204 ConvertToModule::matchAndRewrite(Operation *op,
205                                  PatternRewriter &rewriter) const {
206   rewriter.replaceOpWithNewOp<spirv::ModuleOp>(
207       op, spirv::AddressingModel::PhysicalStorageBuffer64,
208       spirv::MemoryModel::Vulkan);
209   return success();
210 }
211 
ConvertToSubgroupBallot(MLIRContext * context)212 ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context)
213     : RewritePattern("test.convert_to_subgroup_ballot_op",
214                      {"spv.SubgroupBallotKHR"}, 1, context) {}
215 
216 LogicalResult
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const217 ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
218                                          PatternRewriter &rewriter) const {
219   Value predicate = op->getOperand(0);
220 
221   rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>(
222       op, op->getResult(0).getType(), predicate);
223   return success();
224 }
225 
226 namespace mlir {
registerConvertToTargetEnvPass()227 void registerConvertToTargetEnvPass() {
228   PassRegistration<ConvertToTargetEnv> convertToTargetEnvPass(
229       "test-spirv-target-env", "Test SPIR-V target environment");
230 }
231 } // namespace mlir
232