1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include <cstddef>
#include <tuple>
23
24 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
25
26 using namespace mlir;
27 using namespace mlir::omp;
28
/// Registers the OpenMP operations with this dialect.
///
/// With GET_OP_LIST defined, the generated OpenMPOps.cpp.inc is expected to
/// expand to the list of generated op classes, which become the template
/// arguments of addOperations<>.
void OpenMPDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();
}
35
36 //===----------------------------------------------------------------------===//
37 // ParallelOp
38 //===----------------------------------------------------------------------===//
39
build(OpBuilder & builder,OperationState & state,ArrayRef<NamedAttribute> attributes)40 void ParallelOp::build(OpBuilder &builder, OperationState &state,
41 ArrayRef<NamedAttribute> attributes) {
42 ParallelOp::build(
43 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
44 /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
45 /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
46 /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
47 /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
48 state.addAttributes(attributes);
49 }
50
51 /// Parse a list of operands with types.
52 ///
53 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
54 /// ssa-id-and-type-list ::= ssa-id-and-type |
55 /// ssa-id-and-type `,` ssa-id-and-type-list
56 /// ssa-id-and-type ::= ssa-id `:` type
57 static ParseResult
parseOperandAndTypeList(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::OperandType> & operands,SmallVectorImpl<Type> & types)58 parseOperandAndTypeList(OpAsmParser &parser,
59 SmallVectorImpl<OpAsmParser::OperandType> &operands,
60 SmallVectorImpl<Type> &types) {
61 if (parser.parseLParen())
62 return failure();
63
64 do {
65 OpAsmParser::OperandType operand;
66 Type type;
67 if (parser.parseOperand(operand) || parser.parseColonType(type))
68 return failure();
69 operands.push_back(operand);
70 types.push_back(type);
71 } while (succeeded(parser.parseOptionalComma()));
72
73 if (parser.parseRParen())
74 return failure();
75
76 return success();
77 }
78
79 /// Parse an allocate clause with allocators and a list of operands with types.
80 ///
81 /// operand-and-type-list ::= `(` allocate-operand-list `)`
82 /// allocate-operand-list :: = allocate-operand |
83 /// allocator-operand `,` allocate-operand-list
84 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
85 /// ssa-id-and-type ::= ssa-id `:` type
parseAllocateAndAllocator(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::OperandType> & operandsAllocate,SmallVectorImpl<Type> & typesAllocate,SmallVectorImpl<OpAsmParser::OperandType> & operandsAllocator,SmallVectorImpl<Type> & typesAllocator)86 static ParseResult parseAllocateAndAllocator(
87 OpAsmParser &parser,
88 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
89 SmallVectorImpl<Type> &typesAllocate,
90 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
91 SmallVectorImpl<Type> &typesAllocator) {
92 if (parser.parseLParen())
93 return failure();
94
95 do {
96 OpAsmParser::OperandType operand;
97 Type type;
98
99 if (parser.parseOperand(operand) || parser.parseColonType(type))
100 return failure();
101 operandsAllocator.push_back(operand);
102 typesAllocator.push_back(type);
103 if (parser.parseArrow())
104 return failure();
105 if (parser.parseOperand(operand) || parser.parseColonType(type))
106 return failure();
107
108 operandsAllocate.push_back(operand);
109 typesAllocate.push_back(type);
110 } while (succeeded(parser.parseOptionalComma()));
111
112 if (parser.parseRParen())
113 return failure();
114
115 return success();
116 }
117
verifyParallelOp(ParallelOp op)118 static LogicalResult verifyParallelOp(ParallelOp op) {
119 if (op.allocate_vars().size() != op.allocators_vars().size())
120 return op.emitError(
121 "expected equal sizes for allocate and allocator variables");
122 return success();
123 }
124
printParallelOp(OpAsmPrinter & p,ParallelOp op)125 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
126 p << "omp.parallel";
127
128 if (auto ifCond = op.if_expr_var())
129 p << " if(" << ifCond << " : " << ifCond.getType() << ")";
130
131 if (auto threads = op.num_threads_var())
132 p << " num_threads(" << threads << " : " << threads.getType() << ")";
133
134 // Print private, firstprivate, shared and copyin parameters
135 auto printDataVars = [&p](StringRef name, OperandRange vars) {
136 if (vars.size()) {
137 p << " " << name << "(";
138 for (unsigned i = 0; i < vars.size(); ++i) {
139 std::string separator = i == vars.size() - 1 ? ")" : ", ";
140 p << vars[i] << " : " << vars[i].getType() << separator;
141 }
142 }
143 };
144
145 // Print allocator and allocate parameters
146 auto printAllocateAndAllocator = [&p](OperandRange varsAllocate,
147 OperandRange varsAllocator) {
148 if (varsAllocate.empty())
149 return;
150
151 p << " allocate(";
152 for (unsigned i = 0; i < varsAllocate.size(); ++i) {
153 std::string separator = i == varsAllocate.size() - 1 ? ")" : ", ";
154 p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
155 p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
156 }
157 };
158
159 printDataVars("private", op.private_vars());
160 printDataVars("firstprivate", op.firstprivate_vars());
161 printDataVars("shared", op.shared_vars());
162 printDataVars("copyin", op.copyin_vars());
163 printAllocateAndAllocator(op.allocate_vars(), op.allocators_vars());
164
165 if (auto def = op.default_val())
166 p << " default(" << def->drop_front(3) << ")";
167
168 if (auto bind = op.proc_bind_val())
169 p << " proc_bind(" << bind << ")";
170
171 p.printRegion(op.getRegion());
172 }
173
174 /// Emit an error if the same clause is present more than once on an operation.
allowedOnce(OpAsmParser & parser,llvm::StringRef clause,llvm::StringRef operation)175 static ParseResult allowedOnce(OpAsmParser &parser, llvm::StringRef clause,
176 llvm::StringRef operation) {
177 return parser.emitError(parser.getNameLoc())
178 << " at most one " << clause << " clause can appear on the "
179 << operation << " operation";
180 }
181
/// Parses a parallel operation.
///
/// operation ::= `omp.parallel` clause-list region
/// clause-list ::= (empty) | clause clause-list
/// clause ::= if | numThreads | private | firstprivate | shared | copyin |
///            allocate | default | procBind
/// if ::= `if` `(` ssa-id-and-type `)`
/// numThreads ::= `num_threads` `(` ssa-id-and-type `)`
/// private ::= `private` operand-and-type-list
/// firstprivate ::= `firstprivate` operand-and-type-list
/// shared ::= `shared` operand-and-type-list
/// copyin ::= `copyin` operand-and-type-list
/// allocate ::= `allocate` `(` allocate-operand (`,` allocate-operand)* `)`
/// allocate-operand ::= ssa-id-and-type `->` ssa-id-and-type
/// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`)
///             `)`
/// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
///
/// Each clause can appear at most once in the clause-list. The `default` and
/// `proc_bind` values are not checked against the sets listed above. Any
/// bare keyword is accepted here and stored as a string attribute.
static ParseResult parseParallelOp(OpAsmParser &parser,
                                   OperationState &result) {
  // Operands and types collected per clause. They are resolved only after
  // all clauses have been read.
  std::pair<OpAsmParser::OperandType, Type> ifCond;
  std::pair<OpAsmParser::OperandType, Type> numThreads;
  SmallVector<OpAsmParser::OperandType, 4> privates;
  SmallVector<Type, 4> privateTypes;
  SmallVector<OpAsmParser::OperandType, 4> firstprivates;
  SmallVector<Type, 4> firstprivateTypes;
  SmallVector<OpAsmParser::OperandType, 4> shareds;
  SmallVector<Type, 4> sharedTypes;
  SmallVector<OpAsmParser::OperandType, 4> copyins;
  SmallVector<Type, 4> copyinTypes;
  SmallVector<OpAsmParser::OperandType, 4> allocates;
  SmallVector<Type, 4> allocateTypes;
  SmallVector<OpAsmParser::OperandType, 4> allocators;
  SmallVector<Type, 4> allocatorTypes;
  // Number of operands in each operand group, indexed by the *Pos constants
  // below and emitted as `operand_segment_sizes`. Every operand clause
  // contributes at least one operand, so a nonzero entry also serves as the
  // "clause already seen" flag.
  std::array<int, 8> segments{0, 0, 0, 0, 0, 0, 0, 0};
  llvm::StringRef keyword;
  // `default` and `proc_bind` have no operands, so they need their own
  // "already seen" flags.
  bool defaultVal = false;
  bool procBind = false;

  const int ifClausePos = 0;
  const int numThreadsClausePos = 1;
  const int privateClausePos = 2;
  const int firstprivateClausePos = 3;
  const int sharedClausePos = 4;
  const int copyinClausePos = 5;
  const int allocateClausePos = 6;
  const int allocatorPos = 7;
  const llvm::StringRef opName = result.name.getStringRef();

  // Consume clauses until the next token is not a bare keyword (e.g. the
  // `{` that opens the region).
  while (succeeded(parser.parseOptionalKeyword(&keyword))) {
    if (keyword == "if") {
      // Fail if there was already another if condition
      if (segments[ifClausePos])
        return allowedOnce(parser, "if", opName);
      if (parser.parseLParen() || parser.parseOperand(ifCond.first) ||
          parser.parseColonType(ifCond.second) || parser.parseRParen())
        return failure();
      segments[ifClausePos] = 1;
    } else if (keyword == "num_threads") {
      // fail if there was already another num_threads clause
      if (segments[numThreadsClausePos])
        return allowedOnce(parser, "num_threads", opName);
      if (parser.parseLParen() || parser.parseOperand(numThreads.first) ||
          parser.parseColonType(numThreads.second) || parser.parseRParen())
        return failure();
      segments[numThreadsClausePos] = 1;
    } else if (keyword == "private") {
      // fail if there was already another private clause
      if (segments[privateClausePos])
        return allowedOnce(parser, "private", opName);
      if (parseOperandAndTypeList(parser, privates, privateTypes))
        return failure();
      segments[privateClausePos] = privates.size();
    } else if (keyword == "firstprivate") {
      // fail if there was already another firstprivate clause
      if (segments[firstprivateClausePos])
        return allowedOnce(parser, "firstprivate", opName);
      if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
        return failure();
      segments[firstprivateClausePos] = firstprivates.size();
    } else if (keyword == "shared") {
      // fail if there was already another shared clause
      if (segments[sharedClausePos])
        return allowedOnce(parser, "shared", opName);
      if (parseOperandAndTypeList(parser, shareds, sharedTypes))
        return failure();
      segments[sharedClausePos] = shareds.size();
    } else if (keyword == "copyin") {
      // fail if there was already another copyin clause
      if (segments[copyinClausePos])
        return allowedOnce(parser, "copyin", opName);
      if (parseOperandAndTypeList(parser, copyins, copyinTypes))
        return failure();
      segments[copyinClausePos] = copyins.size();
    } else if (keyword == "allocate") {
      // fail if there was already another allocate clause
      if (segments[allocateClausePos])
        return allowedOnce(parser, "allocate", opName);
      if (parseAllocateAndAllocator(parser, allocates, allocateTypes,
                                    allocators, allocatorTypes))
        return failure();
      // One clause fills two operand groups: allocate vars and allocators.
      segments[allocateClausePos] = allocates.size();
      segments[allocatorPos] = allocators.size();
    } else if (keyword == "default") {
      // fail if there was already another default clause
      if (defaultVal)
        return allowedOnce(parser, "default", opName);
      defaultVal = true;
      llvm::StringRef defval;
      if (parser.parseLParen() || parser.parseKeyword(&defval) ||
          parser.parseRParen())
        return failure();
      llvm::SmallString<16> attrval;
      // The def prefix is required for the attribute as "private" is a keyword
      // in C++
      attrval += "def";
      attrval += defval;
      auto attr = parser.getBuilder().getStringAttr(attrval);
      result.addAttribute("default_val", attr);
    } else if (keyword == "proc_bind") {
      // fail if there was already another proc_bind clause
      if (procBind)
        return allowedOnce(parser, "proc_bind", opName);
      procBind = true;
      llvm::StringRef bind;
      if (parser.parseLParen() || parser.parseKeyword(&bind) ||
          parser.parseRParen())
        return failure();
      auto attr = parser.getBuilder().getStringAttr(bind);
      result.addAttribute("proc_bind_val", attr);
    } else {
      return parser.emitError(parser.getNameLoc())
             << keyword << " is not a valid clause for the " << opName
             << " operation";
    }
  }

  // The resolution order below mirrors the index order of `segments`, which
  // is recorded as `operand_segment_sizes`. Keep the two in sync.

  // Add if parameter
  if (segments[ifClausePos] &&
      parser.resolveOperand(ifCond.first, ifCond.second, result.operands))
    return failure();

  // Add num_threads parameter
  if (segments[numThreadsClausePos] &&
      parser.resolveOperand(numThreads.first, numThreads.second,
                            result.operands))
    return failure();

  // Add private parameters
  if (segments[privateClausePos] &&
      parser.resolveOperands(privates, privateTypes, privates[0].location,
                             result.operands))
    return failure();

  // Add firstprivate parameters
  if (segments[firstprivateClausePos] &&
      parser.resolveOperands(firstprivates, firstprivateTypes,
                             firstprivates[0].location, result.operands))
    return failure();

  // Add shared parameters
  if (segments[sharedClausePos] &&
      parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
                             result.operands))
    return failure();

  // Add copyin parameters
  if (segments[copyinClausePos] &&
      parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
                             result.operands))
    return failure();

  // Add allocate parameters
  if (segments[allocateClausePos] &&
      parser.resolveOperands(allocates, allocateTypes, allocates[0].location,
                             result.operands))
    return failure();

  // Add allocator parameters
  if (segments[allocatorPos] &&
      parser.resolveOperands(allocators, allocatorTypes, allocators[0].location,
                             result.operands))
    return failure();

  result.addAttribute("operand_segment_sizes",
                      parser.getBuilder().getI32VectorAttr(segments));

  // The region is parsed without any entry-block arguments.
  Region *body = result.addRegion();
  SmallVector<OpAsmParser::OperandType, 4> regionArgs;
  SmallVector<Type, 4> regionArgTypes;
  if (parser.parseRegion(*body, regionArgs, regionArgTypes))
    return failure();
  return success();
}
375
376 //===----------------------------------------------------------------------===//
377 // WsLoopOp
378 //===----------------------------------------------------------------------===//
379
build(OpBuilder & builder,OperationState & state,ValueRange lowerBound,ValueRange upperBound,ValueRange step,ArrayRef<NamedAttribute> attributes)380 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
381 ValueRange lowerBound, ValueRange upperBound,
382 ValueRange step, ArrayRef<NamedAttribute> attributes) {
383 build(builder, state, TypeRange(), lowerBound, upperBound, step,
384 /*private_vars=*/ValueRange(),
385 /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
386 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
387 /*schedule_val=*/nullptr, /*schedule_chunk_var=*/nullptr,
388 /*collapse_val=*/nullptr,
389 /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr);
390 state.addAttributes(attributes);
391 }
392
393 #define GET_OP_CLASSES
394 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
395