1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include <array>
#include <cstddef>
23 
24 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
25 
26 using namespace mlir;
27 using namespace mlir::omp;
28 
initialize()29 void OpenMPDialect::initialize() {
30   addOperations<
31 #define GET_OP_LIST
32 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
33       >();
34 }
35 
36 //===----------------------------------------------------------------------===//
37 // ParallelOp
38 //===----------------------------------------------------------------------===//
39 
build(OpBuilder & builder,OperationState & state,ArrayRef<NamedAttribute> attributes)40 void ParallelOp::build(OpBuilder &builder, OperationState &state,
41                        ArrayRef<NamedAttribute> attributes) {
42   ParallelOp::build(
43       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
44       /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
45       /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
46       /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
47       /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
48   state.addAttributes(attributes);
49 }
50 
51 /// Parse a list of operands with types.
52 ///
53 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
54 /// ssa-id-and-type-list ::= ssa-id-and-type |
55 ///                          ssa-id-and-type `,` ssa-id-and-type-list
56 /// ssa-id-and-type ::= ssa-id `:` type
57 static ParseResult
parseOperandAndTypeList(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::OperandType> & operands,SmallVectorImpl<Type> & types)58 parseOperandAndTypeList(OpAsmParser &parser,
59                         SmallVectorImpl<OpAsmParser::OperandType> &operands,
60                         SmallVectorImpl<Type> &types) {
61   if (parser.parseLParen())
62     return failure();
63 
64   do {
65     OpAsmParser::OperandType operand;
66     Type type;
67     if (parser.parseOperand(operand) || parser.parseColonType(type))
68       return failure();
69     operands.push_back(operand);
70     types.push_back(type);
71   } while (succeeded(parser.parseOptionalComma()));
72 
73   if (parser.parseRParen())
74     return failure();
75 
76   return success();
77 }
78 
79 /// Parse an allocate clause with allocators and a list of operands with types.
80 ///
81 /// operand-and-type-list ::= `(` allocate-operand-list `)`
82 /// allocate-operand-list :: = allocate-operand |
83 ///                            allocator-operand `,` allocate-operand-list
84 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
85 /// ssa-id-and-type ::= ssa-id `:` type
parseAllocateAndAllocator(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::OperandType> & operandsAllocate,SmallVectorImpl<Type> & typesAllocate,SmallVectorImpl<OpAsmParser::OperandType> & operandsAllocator,SmallVectorImpl<Type> & typesAllocator)86 static ParseResult parseAllocateAndAllocator(
87     OpAsmParser &parser,
88     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
89     SmallVectorImpl<Type> &typesAllocate,
90     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
91     SmallVectorImpl<Type> &typesAllocator) {
92   if (parser.parseLParen())
93     return failure();
94 
95   do {
96     OpAsmParser::OperandType operand;
97     Type type;
98 
99     if (parser.parseOperand(operand) || parser.parseColonType(type))
100       return failure();
101     operandsAllocator.push_back(operand);
102     typesAllocator.push_back(type);
103     if (parser.parseArrow())
104       return failure();
105     if (parser.parseOperand(operand) || parser.parseColonType(type))
106       return failure();
107 
108     operandsAllocate.push_back(operand);
109     typesAllocate.push_back(type);
110   } while (succeeded(parser.parseOptionalComma()));
111 
112   if (parser.parseRParen())
113     return failure();
114 
115   return success();
116 }
117 
verifyParallelOp(ParallelOp op)118 static LogicalResult verifyParallelOp(ParallelOp op) {
119   if (op.allocate_vars().size() != op.allocators_vars().size())
120     return op.emitError(
121         "expected equal sizes for allocate and allocator variables");
122   return success();
123 }
124 
printParallelOp(OpAsmPrinter & p,ParallelOp op)125 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
126   p << "omp.parallel";
127 
128   if (auto ifCond = op.if_expr_var())
129     p << " if(" << ifCond << " : " << ifCond.getType() << ")";
130 
131   if (auto threads = op.num_threads_var())
132     p << " num_threads(" << threads << " : " << threads.getType() << ")";
133 
134   // Print private, firstprivate, shared and copyin parameters
135   auto printDataVars = [&p](StringRef name, OperandRange vars) {
136     if (vars.size()) {
137       p << " " << name << "(";
138       for (unsigned i = 0; i < vars.size(); ++i) {
139         std::string separator = i == vars.size() - 1 ? ")" : ", ";
140         p << vars[i] << " : " << vars[i].getType() << separator;
141       }
142     }
143   };
144 
145   // Print allocator and allocate parameters
146   auto printAllocateAndAllocator = [&p](OperandRange varsAllocate,
147                                         OperandRange varsAllocator) {
148     if (varsAllocate.empty())
149       return;
150 
151     p << " allocate(";
152     for (unsigned i = 0; i < varsAllocate.size(); ++i) {
153       std::string separator = i == varsAllocate.size() - 1 ? ")" : ", ";
154       p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
155       p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
156     }
157   };
158 
159   printDataVars("private", op.private_vars());
160   printDataVars("firstprivate", op.firstprivate_vars());
161   printDataVars("shared", op.shared_vars());
162   printDataVars("copyin", op.copyin_vars());
163   printAllocateAndAllocator(op.allocate_vars(), op.allocators_vars());
164 
165   if (auto def = op.default_val())
166     p << " default(" << def->drop_front(3) << ")";
167 
168   if (auto bind = op.proc_bind_val())
169     p << " proc_bind(" << bind << ")";
170 
171   p.printRegion(op.getRegion());
172 }
173 
174 /// Emit an error if the same clause is present more than once on an operation.
allowedOnce(OpAsmParser & parser,llvm::StringRef clause,llvm::StringRef operation)175 static ParseResult allowedOnce(OpAsmParser &parser, llvm::StringRef clause,
176                                llvm::StringRef operation) {
177   return parser.emitError(parser.getNameLoc())
178          << " at most one " << clause << " clause can appear on the "
179          << operation << " operation";
180 }
181 
182 /// Parses a parallel operation.
183 ///
184 /// operation ::= `omp.parallel` clause-list
185 /// clause-list ::= clause | clause clause-list
186 /// clause ::= if | numThreads | private | firstprivate | shared | copyin |
187 ///            default | procBind
188 /// if ::= `if` `(` ssa-id `)`
189 /// numThreads ::= `num_threads` `(` ssa-id-and-type `)`
190 /// private ::= `private` operand-and-type-list
191 /// firstprivate ::= `firstprivate` operand-and-type-list
192 /// shared ::= `shared` operand-and-type-list
193 /// copyin ::= `copyin` operand-and-type-list
194 /// allocate ::= `allocate` operand-and-type `->` operand-and-type-list
195 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`)
196 /// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
197 ///
198 /// Note that each clause can only appear once in the clase-list.
parseParallelOp(OpAsmParser & parser,OperationState & result)199 static ParseResult parseParallelOp(OpAsmParser &parser,
200                                    OperationState &result) {
201   std::pair<OpAsmParser::OperandType, Type> ifCond;
202   std::pair<OpAsmParser::OperandType, Type> numThreads;
203   SmallVector<OpAsmParser::OperandType, 4> privates;
204   SmallVector<Type, 4> privateTypes;
205   SmallVector<OpAsmParser::OperandType, 4> firstprivates;
206   SmallVector<Type, 4> firstprivateTypes;
207   SmallVector<OpAsmParser::OperandType, 4> shareds;
208   SmallVector<Type, 4> sharedTypes;
209   SmallVector<OpAsmParser::OperandType, 4> copyins;
210   SmallVector<Type, 4> copyinTypes;
211   SmallVector<OpAsmParser::OperandType, 4> allocates;
212   SmallVector<Type, 4> allocateTypes;
213   SmallVector<OpAsmParser::OperandType, 4> allocators;
214   SmallVector<Type, 4> allocatorTypes;
215   std::array<int, 8> segments{0, 0, 0, 0, 0, 0, 0, 0};
216   llvm::StringRef keyword;
217   bool defaultVal = false;
218   bool procBind = false;
219 
220   const int ifClausePos = 0;
221   const int numThreadsClausePos = 1;
222   const int privateClausePos = 2;
223   const int firstprivateClausePos = 3;
224   const int sharedClausePos = 4;
225   const int copyinClausePos = 5;
226   const int allocateClausePos = 6;
227   const int allocatorPos = 7;
228   const llvm::StringRef opName = result.name.getStringRef();
229 
230   while (succeeded(parser.parseOptionalKeyword(&keyword))) {
231     if (keyword == "if") {
232       // Fail if there was already another if condition
233       if (segments[ifClausePos])
234         return allowedOnce(parser, "if", opName);
235       if (parser.parseLParen() || parser.parseOperand(ifCond.first) ||
236           parser.parseColonType(ifCond.second) || parser.parseRParen())
237         return failure();
238       segments[ifClausePos] = 1;
239     } else if (keyword == "num_threads") {
240       // fail if there was already another num_threads clause
241       if (segments[numThreadsClausePos])
242         return allowedOnce(parser, "num_threads", opName);
243       if (parser.parseLParen() || parser.parseOperand(numThreads.first) ||
244           parser.parseColonType(numThreads.second) || parser.parseRParen())
245         return failure();
246       segments[numThreadsClausePos] = 1;
247     } else if (keyword == "private") {
248       // fail if there was already another private clause
249       if (segments[privateClausePos])
250         return allowedOnce(parser, "private", opName);
251       if (parseOperandAndTypeList(parser, privates, privateTypes))
252         return failure();
253       segments[privateClausePos] = privates.size();
254     } else if (keyword == "firstprivate") {
255       // fail if there was already another firstprivate clause
256       if (segments[firstprivateClausePos])
257         return allowedOnce(parser, "firstprivate", opName);
258       if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
259         return failure();
260       segments[firstprivateClausePos] = firstprivates.size();
261     } else if (keyword == "shared") {
262       // fail if there was already another shared clause
263       if (segments[sharedClausePos])
264         return allowedOnce(parser, "shared", opName);
265       if (parseOperandAndTypeList(parser, shareds, sharedTypes))
266         return failure();
267       segments[sharedClausePos] = shareds.size();
268     } else if (keyword == "copyin") {
269       // fail if there was already another copyin clause
270       if (segments[copyinClausePos])
271         return allowedOnce(parser, "copyin", opName);
272       if (parseOperandAndTypeList(parser, copyins, copyinTypes))
273         return failure();
274       segments[copyinClausePos] = copyins.size();
275     } else if (keyword == "allocate") {
276       // fail if there was already another allocate clause
277       if (segments[allocateClausePos])
278         return allowedOnce(parser, "allocate", opName);
279       if (parseAllocateAndAllocator(parser, allocates, allocateTypes,
280                                     allocators, allocatorTypes))
281         return failure();
282       segments[allocateClausePos] = allocates.size();
283       segments[allocatorPos] = allocators.size();
284     } else if (keyword == "default") {
285       // fail if there was already another default clause
286       if (defaultVal)
287         return allowedOnce(parser, "default", opName);
288       defaultVal = true;
289       llvm::StringRef defval;
290       if (parser.parseLParen() || parser.parseKeyword(&defval) ||
291           parser.parseRParen())
292         return failure();
293       llvm::SmallString<16> attrval;
294       // The def prefix is required for the attribute as "private" is a keyword
295       // in C++
296       attrval += "def";
297       attrval += defval;
298       auto attr = parser.getBuilder().getStringAttr(attrval);
299       result.addAttribute("default_val", attr);
300     } else if (keyword == "proc_bind") {
301       // fail if there was already another proc_bind clause
302       if (procBind)
303         return allowedOnce(parser, "proc_bind", opName);
304       procBind = true;
305       llvm::StringRef bind;
306       if (parser.parseLParen() || parser.parseKeyword(&bind) ||
307           parser.parseRParen())
308         return failure();
309       auto attr = parser.getBuilder().getStringAttr(bind);
310       result.addAttribute("proc_bind_val", attr);
311     } else {
312       return parser.emitError(parser.getNameLoc())
313              << keyword << " is not a valid clause for the " << opName
314              << " operation";
315     }
316   }
317 
318   // Add if parameter
319   if (segments[ifClausePos] &&
320       parser.resolveOperand(ifCond.first, ifCond.second, result.operands))
321     return failure();
322 
323   // Add num_threads parameter
324   if (segments[numThreadsClausePos] &&
325       parser.resolveOperand(numThreads.first, numThreads.second,
326                             result.operands))
327     return failure();
328 
329   // Add private parameters
330   if (segments[privateClausePos] &&
331       parser.resolveOperands(privates, privateTypes, privates[0].location,
332                              result.operands))
333     return failure();
334 
335   // Add firstprivate parameters
336   if (segments[firstprivateClausePos] &&
337       parser.resolveOperands(firstprivates, firstprivateTypes,
338                              firstprivates[0].location, result.operands))
339     return failure();
340 
341   // Add shared parameters
342   if (segments[sharedClausePos] &&
343       parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
344                              result.operands))
345     return failure();
346 
347   // Add copyin parameters
348   if (segments[copyinClausePos] &&
349       parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
350                              result.operands))
351     return failure();
352 
353   // Add allocate parameters
354   if (segments[allocateClausePos] &&
355       parser.resolveOperands(allocates, allocateTypes, allocates[0].location,
356                              result.operands))
357     return failure();
358 
359   // Add allocator parameters
360   if (segments[allocatorPos] &&
361       parser.resolveOperands(allocators, allocatorTypes, allocators[0].location,
362                              result.operands))
363     return failure();
364 
365   result.addAttribute("operand_segment_sizes",
366                       parser.getBuilder().getI32VectorAttr(segments));
367 
368   Region *body = result.addRegion();
369   SmallVector<OpAsmParser::OperandType, 4> regionArgs;
370   SmallVector<Type, 4> regionArgTypes;
371   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
372     return failure();
373   return success();
374 }
375 
376 //===----------------------------------------------------------------------===//
377 // WsLoopOp
378 //===----------------------------------------------------------------------===//
379 
build(OpBuilder & builder,OperationState & state,ValueRange lowerBound,ValueRange upperBound,ValueRange step,ArrayRef<NamedAttribute> attributes)380 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
381                      ValueRange lowerBound, ValueRange upperBound,
382                      ValueRange step, ArrayRef<NamedAttribute> attributes) {
383   build(builder, state, TypeRange(), lowerBound, upperBound, step,
384         /*private_vars=*/ValueRange(),
385         /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
386         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
387         /*schedule_val=*/nullptr, /*schedule_chunk_var=*/nullptr,
388         /*collapse_val=*/nullptr,
389         /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr);
390   state.addAttributes(attributes);
391 }
392 
393 #define GET_OP_CLASSES
394 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
395