1 //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the NVVM IR dialect in
10 // MLIR, and the LLVM IR dialect.  It also registers the dialect.
11 //
12 // The NVVM dialect only contains GPU specific additions on top of the general
13 // LLVM dialect.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
18 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/IR/Operation.h"
23 #include "mlir/IR/OperationSupport.h"
24 #include "llvm/AsmParser/Parser.h"
25 #include "llvm/IR/Attributes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/Type.h"
28 #include "llvm/Support/SourceMgr.h"
29 
30 using namespace mlir;
31 using namespace NVVM;
32 
33 //===----------------------------------------------------------------------===//
34 // Printing/parsing for NVVM ops
35 //===----------------------------------------------------------------------===//
36 
printNVVMIntrinsicOp(OpAsmPrinter & p,Operation * op)37 static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
38   p << op->getName() << " " << op->getOperands();
39   if (op->getNumResults() > 0)
40     p << " : " << op->getResultTypes();
41 }
42 
43 // <operation> ::=
44 //     `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask`
45 //      ({return_value_and_is_valid})? : result_type
parseNVVMShflSyncBflyOp(OpAsmParser & parser,OperationState & result)46 static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
47                                            OperationState &result) {
48   SmallVector<OpAsmParser::OperandType, 8> ops;
49   Type resultType;
50   if (parser.parseOperandList(ops) ||
51       parser.parseOptionalAttrDict(result.attributes) ||
52       parser.parseColonType(resultType) ||
53       parser.addTypeToList(resultType, result.types))
54     return failure();
55 
56   auto type = resultType.cast<LLVM::LLVMType>();
57   for (auto &attr : result.attributes) {
58     if (attr.first != "return_value_and_is_valid")
59       continue;
60     if (type.isStructTy() && type.getStructNumElements() > 0)
61       type = type.getStructElementType(0);
62     break;
63   }
64 
65   auto int32Ty = LLVM::LLVMType::getInt32Ty(parser.getBuilder().getContext());
66   return parser.resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty},
67                                 parser.getNameLoc(), result.operands);
68 }
69 
70 // <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
parseNVVMVoteBallotOp(OpAsmParser & parser,OperationState & result)71 static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
72                                          OperationState &result) {
73   MLIRContext *context = parser.getBuilder().getContext();
74   auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
75   auto int1Ty = LLVM::LLVMType::getInt1Ty(context);
76 
77   SmallVector<OpAsmParser::OperandType, 8> ops;
78   Type type;
79   return failure(parser.parseOperandList(ops) ||
80                  parser.parseOptionalAttrDict(result.attributes) ||
81                  parser.parseColonType(type) ||
82                  parser.addTypeToList(type, result.types) ||
83                  parser.resolveOperands(ops, {int32Ty, int1Ty},
84                                         parser.getNameLoc(), result.operands));
85 }
86 
verify(MmaOp op)87 static LogicalResult verify(MmaOp op) {
88   MLIRContext *context = op.getContext();
89   auto f16Ty = LLVM::LLVMType::getHalfTy(context);
90   auto f16x2Ty = LLVM::LLVMType::getVectorTy(f16Ty, 2);
91   auto f32Ty = LLVM::LLVMType::getFloatTy(context);
92   auto f16x2x4StructTy = LLVM::LLVMType::getStructTy(
93       context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
94   auto f32x8StructTy = LLVM::LLVMType::getStructTy(
95       context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
96 
97   SmallVector<Type, 12> operand_types(op.getOperandTypes().begin(),
98                                       op.getOperandTypes().end());
99   if (operand_types != SmallVector<Type, 8>(8, f16x2Ty) &&
100       operand_types != SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
101                                              f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
102                                              f32Ty, f32Ty, f32Ty}) {
103     return op.emitOpError(
104         "expected operands to be 4 <halfx2>s followed by either "
105         "4 <halfx2>s or 8 floats");
106   }
107   if (op.getType() != f32x8StructTy && op.getType() != f16x2x4StructTy) {
108     return op.emitOpError("expected result type to be a struct of either 4 "
109                           "<halfx2>s or 8 floats");
110   }
111 
112   auto alayout = op->getAttrOfType<StringAttr>("alayout");
113   auto blayout = op->getAttrOfType<StringAttr>("blayout");
114 
115   if (!(alayout && blayout) ||
116       !(alayout.getValue() == "row" || alayout.getValue() == "col") ||
117       !(blayout.getValue() == "row" || blayout.getValue() == "col")) {
118     return op.emitOpError(
119         "alayout and blayout attributes must be set to either "
120         "\"row\" or \"col\"");
121   }
122 
123   if (operand_types == SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
124                                              f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
125                                              f32Ty, f32Ty, f32Ty} &&
126       op.getType() == f32x8StructTy && alayout.getValue() == "row" &&
127       blayout.getValue() == "col") {
128     return success();
129   }
130   return op.emitOpError("unimplemented mma.sync variant");
131 }
132 
133 //===----------------------------------------------------------------------===//
134 // NVVMDialect initialization, type parsing, and registration.
135 //===----------------------------------------------------------------------===//
136 
137 // TODO: This should be the llvm.nvvm dialect once this is supported.
initialize()138 void NVVMDialect::initialize() {
139   addOperations<
140 #define GET_OP_LIST
141 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
142       >();
143 
144   // Support unknown operations because not all NVVM operations are registered.
145   allowUnknownOperations();
146 }
147 
148 #define GET_OP_CLASSES
149 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
150