1 //===- LinalgToSPIRVPass.cpp - Linalg to SPIR-V conversion pass -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h" 10 #include "../PassDetail.h" 11 #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h" 12 #include "mlir/Dialect/SPIRV/SPIRVDialect.h" 13 #include "mlir/Dialect/SPIRV/SPIRVLowering.h" 14 15 using namespace mlir; 16 17 namespace { 18 /// A pass converting MLIR Linalg ops into SPIR-V ops. 19 class LinalgToSPIRVPass : public ConvertLinalgToSPIRVBase<LinalgToSPIRVPass> { 20 void runOnOperation() override; 21 }; 22 } // namespace 23 runOnOperation()24void LinalgToSPIRVPass::runOnOperation() { 25 MLIRContext *context = &getContext(); 26 ModuleOp module = getOperation(); 27 28 auto targetAttr = spirv::lookupTargetEnvOrDefault(module); 29 std::unique_ptr<ConversionTarget> target = 30 spirv::SPIRVConversionTarget::get(targetAttr); 31 32 SPIRVTypeConverter typeConverter(targetAttr); 33 OwningRewritePatternList patterns; 34 populateLinalgToSPIRVPatterns(context, typeConverter, patterns); 35 populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns); 36 37 // Allow builtin ops. 38 target->addLegalOp<ModuleOp, ModuleTerminatorOp>(); 39 target->addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 40 return typeConverter.isSignatureLegal(op.getType()) && 41 typeConverter.isLegal(&op.getBody()); 42 }); 43 44 if (failed(applyFullConversion(module, *target, std::move(patterns)))) 45 return signalPassFailure(); 46 } 47 createLinalgToSPIRVPass()48std::unique_ptr<OperationPass<ModuleOp>> mlir::createLinalgToSPIRVPass() { 49 return std::make_unique<LinalgToSPIRVPass>(); 50 } 51