1//===-- ROCDLOps.td - ROCDL IR dialect op definition file --*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is the ROCDL IR operation definition file.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef ROCDLIR_OPS
14#define ROCDLIR_OPS
15
16include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
17include "mlir/Interfaces/SideEffectInterfaces.td"
18
19//===----------------------------------------------------------------------===//
20// ROCDL dialect definitions
21//===----------------------------------------------------------------------===//
22
// The `rocdl` dialect: ops that lower to AMDGPU (`amdgcn`) LLVM intrinsics and
// device-library calls. It declares the LLVM dialect as a dependent dialect.
def ROCDL_Dialect : Dialect {
  let cppNamespace = "::mlir::ROCDL";
  let name = "rocdl";
  let dependentDialects = ["LLVM::LLVMDialect"];
}
28
29//===----------------------------------------------------------------------===//
30// ROCDL op definitions
31//===----------------------------------------------------------------------===//
32
// Common base of all ROCDL ops: binds the op to ROCDL_Dialect and forwards
// the mnemonic and trait list unchanged to LLVM_OpBase.
class ROCDL_Op<string mnemonic, list<OpTrait> traits = []>
    : LLVM_OpBase<ROCDL_Dialect, mnemonic, traits>;
36
37//===----------------------------------------------------------------------===//
38// ROCDL special register op definitions
39//===----------------------------------------------------------------------===//
40
// Operand-less op marked NoSideEffect whose single result `$res` comes from an
// `llvm::Intrinsic::amdgcn_*` intrinsic. The enumerator name is the mnemonic
// with every '.' replaced by '_', e.g. "workitem.id.x" -> amdgcn_workitem_id_x.
class ROCDL_SpecialRegisterOp<string mnemonic, list<OpTrait> traits = []>
    : ROCDL_Op<mnemonic, !listconcat(traits, [NoSideEffect])>,
      Results<(outs LLVM_Type:$res)>,
      Arguments<(ins)> {
  // Generated C++: $res = createIntrinsicCall(builder,llvm::Intrinsic::amdgcn_<name>);
  string llvmBuilder =
      "$res = createIntrinsicCall(builder,"
      # "llvm::Intrinsic::amdgcn_"
      # !subst(".", "_", mnemonic)
      # ");";

  let assemblyFormat = "attr-dict `:` type($res)";
}
49
// Operand-less op marked NoSideEffect whose single result `$res` is the value
// returned by calling `device_function` with the integer `parameter`. The call
// is emitted through createDeviceFunctionCall, which is defined outside this
// file.
class ROCDL_DeviceFunctionOp<string mnemonic, string device_function,
                             int parameter, list<OpTrait> traits = []>
    : ROCDL_Op<mnemonic, !listconcat(traits, [NoSideEffect])>,
      Results<(outs LLVM_Type:$res)>,
      Arguments<(ins)> {
  // Generated C++, e.g.:
  //   $res = createDeviceFunctionCall(builder, "__ockl_get_local_size", 0);
  string llvmBuilder =
      "$res = createDeviceFunctionCall(builder, \"" # device_function
      # "\", " # parameter # ");";

  let assemblyFormat = "attr-dict `:` type($res)";
}
58
59//===----------------------------------------------------------------------===//
60// Thread index and Block index
61
// Work-item id per dimension; lowers to llvm::Intrinsic::amdgcn_workitem_id_{x,y,z}.
def ROCDL_ThreadIdXOp : ROCDL_SpecialRegisterOp<"workitem.id.x">;
def ROCDL_ThreadIdYOp : ROCDL_SpecialRegisterOp<"workitem.id.y">;
def ROCDL_ThreadIdZOp : ROCDL_SpecialRegisterOp<"workitem.id.z">;

// Work-group id per dimension; lowers to llvm::Intrinsic::amdgcn_workgroup_id_{x,y,z}.
def ROCDL_BlockIdXOp  : ROCDL_SpecialRegisterOp<"workgroup.id.x">;
def ROCDL_BlockIdYOp  : ROCDL_SpecialRegisterOp<"workgroup.id.y">;
def ROCDL_BlockIdZOp  : ROCDL_SpecialRegisterOp<"workgroup.id.z">;
69
70//===----------------------------------------------------------------------===//
71// Thread range and Block range
72
// Work-group size along x / y / z, via __ockl_get_local_size(dim).
def ROCDL_BlockDimXOp
    : ROCDL_DeviceFunctionOp<"workgroup.dim.x", "__ockl_get_local_size", 0>;
def ROCDL_BlockDimYOp
    : ROCDL_DeviceFunctionOp<"workgroup.dim.y", "__ockl_get_local_size", 1>;
def ROCDL_BlockDimZOp
    : ROCDL_DeviceFunctionOp<"workgroup.dim.z", "__ockl_get_local_size", 2>;
81
// Grid size measured in work-groups along x / y / z, via
// __ockl_get_num_groups(dim).
//
// These previously called __ockl_get_global_size, which (like OpenCL's
// get_global_size) returns the total number of work-items in the grid, not
// the number of work-groups. The grid-dimension semantics these ops provide
// (the "block range", cf. gpu.grid_dim) is the work-group count.
def ROCDL_GridDimXOp
    : ROCDL_DeviceFunctionOp<"grid.dim.x", "__ockl_get_num_groups", 0>;
def ROCDL_GridDimYOp
    : ROCDL_DeviceFunctionOp<"grid.dim.y", "__ockl_get_num_groups", 1>;
def ROCDL_GridDimZOp
    : ROCDL_DeviceFunctionOp<"grid.dim.z", "__ockl_get_num_groups", 2>;
90
91//===----------------------------------------------------------------------===//
92// Synchronization primitives
93
// Work-group barrier, emitted as three steps: a release fence, the
// amdgcn_s_barrier intrinsic, then an acquire fence. Both fences use the
// "workgroup" sync scope. No side-effect traits are declared, so the op is not
// treated as side-effect free.
def ROCDL_BarrierOp : ROCDL_Op<"barrier"> {
  let assemblyFormat = "attr-dict";
  string llvmBuilder = [{
    // Both fences share the same sync-scope ID, so look it up once.
    auto workgroupScope =
        builder.getContext().getOrInsertSyncScopeID("workgroup");
    builder.CreateFence(llvm::AtomicOrdering::Release, workgroupScope);
    createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_s_barrier);
    builder.CreateFence(llvm::AtomicOrdering::Acquire, workgroupScope);
  }];
}
105
//===----------------------------------------------------------------------===//
// Xdlops (MFMA) intrinsics
108
// Base class for the MFMA ("xdlops") intrinsic ops.
//
// The mnemonic `mfma.<suffix>` maps to the LLVM intrinsic enumerator
// `amdgcn_mfma_<suffix>`, with every '.' replaced by '_'. Operands are a
// variadic list of LLVM_Type, so ODS does not constrain their count or
// individual types.
//
// NOTE(review): the two empty lists and the trailing 1 are LLVM_IntrOpBase
// arguments. Presumably they mean "no overloaded result/operand types" and
// "one result" (the `$res` used in the assembly format). Confirm against
// LLVMOpBase.td.
class ROCDL_Mfma_IntrOp<string mnemonic, list<OpTrait> traits = []>
    : LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
                      "amdgcn_" # !subst(".", "_", mnemonic),
                      [], [], traits, 1>,
      Arguments<(ins Variadic<LLVM_Type>:$args)> {
  let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
}
117
// Each def lowers to llvm::Intrinsic::amdgcn_<mnemonic with '.' -> '_'>.
def ROCDL_mfma_f32_32x32x1f32  : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x1f32">;
def ROCDL_mfma_f32_32x32x2f32  : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2f32">;
def ROCDL_mfma_f32_16x16x4f32  : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f32">;
def ROCDL_mfma_f32_16x16x1f32  : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x1f32">;
def ROCDL_mfma_f32_32x32x4f16  : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4f16">;
def ROCDL_mfma_f32_32x32x8f16  : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x8f16">;
def ROCDL_mfma_f32_16x16x4f16  : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f16">;
def ROCDL_mfma_f32_16x16x16f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x16f16">;
def ROCDL_mfma_f32_32x32x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2bf16">;
def ROCDL_mfma_f32_32x32x4bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4bf16">;
def ROCDL_mfma_f32_16x16x8bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8bf16">;
def ROCDL_mfma_f32_16x16x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x2bf16">;
def ROCDL_mfma_f32_4x4x2bf16   : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x2bf16">;
def ROCDL_mfma_f32_4x4x1f32    : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x1f32">;
def ROCDL_mfma_f32_4x4x4f16    : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x4f16">;
def ROCDL_mfma_i32_32x32x4i8   : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x4i8">;
def ROCDL_mfma_i32_16x16x4i8   : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x4i8">;
def ROCDL_mfma_i32_4x4x4i8     : ROCDL_Mfma_IntrOp<"mfma.i32.4x4x4i8">;
def ROCDL_mfma_i32_32x32x8i8   : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x8i8">;
def ROCDL_mfma_i32_16x16x16i8  : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x16i8">;
138
//===----------------------------------------------------------------------===//
140// Vector buffer load/store intrinsics
141
// Vector buffer load, lowered to llvm::Intrinsic::amdgcn_buffer_load.
// The five operands are forwarded to the intrinsic in declaration order, and
// the op's result type is passed in the intrinsic's type-list argument
// (presumably its overload type).
def ROCDL_MubufLoadOp :
  ROCDL_Op<"buffer.load">,
  Results<(outs LLVM_Type:$res)>,
  Arguments<(ins LLVM_Type:$rsrc,
                 LLVM_Type:$vindex,
                 LLVM_Type:$offset,
                 LLVM_Type:$glc,
                 LLVM_Type:$slc)> {
  string llvmBuilder = [{
      $res = createIntrinsicCall(
          builder, llvm::Intrinsic::amdgcn_buffer_load,
          {$rsrc, $vindex, $offset, $glc, $slc}, {$_resultType});
  }];

  // Custom syntax: `<op name> <operands> : <result types>`. Parsing is done by
  // parseROCDLMubufLoadOp, which is defined outside this file.
  let parser = [{ return parseROCDLMubufLoadOp(parser, result); }];

  // NOTE(review): the attribute dictionary is not printed, so any attributes
  // on the op would be lost when round-tripping. Printing it also requires the
  // out-of-file parser to accept an attr-dict; TODO confirm before changing.
  let printer = [{
    p << getOperation()->getName() << " " << getOperation()->getOperands()
      << " : " << getOperation()->getResultTypes();
  }];
}
162
// Vector buffer store, lowered to llvm::Intrinsic::amdgcn_buffer_store.
// The six operands are forwarded in declaration order. The converted LLVM type
// of `$vdata` is passed in the intrinsic's type-list argument (presumably its
// overload type). The op has no results.
def ROCDL_MubufStoreOp :
  ROCDL_Op<"buffer.store">,
  Arguments<(ins LLVM_Type:$vdata,
                 LLVM_Type:$rsrc,
                 LLVM_Type:$vindex,
                 LLVM_Type:$offset,
                 LLVM_Type:$glc,
                 LLVM_Type:$slc)> {
  string llvmBuilder = [{
    auto storedDataType =
        convertType(op.vdata().getType().cast<LLVM::LLVMType>());
    createIntrinsicCall(
          builder, llvm::Intrinsic::amdgcn_buffer_store,
          {$vdata, $rsrc, $vindex, $offset, $glc, $slc}, {storedDataType});
  }];

  // Custom syntax: `<op name> <operands> : <type of $vdata>`. Parsing is done
  // by parseROCDLMubufStoreOp, which is defined outside this file.
  let parser = [{ return parseROCDLMubufStoreOp(parser, result); }];

  // NOTE(review): as with buffer.load, the attribute dictionary is not
  // printed. Any change must be matched in the out-of-file parser; TODO
  // confirm.
  let printer = [{
    p << getOperation()->getName() << " " << getOperation()->getOperands()
      << " : " << vdata().getType();
  }];
}
184
185#endif // ROCDLIR_OPS
186