1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
14 
15 #include "mlir/Dialect/SPIRV/ParserUtils.h"
16 #include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
17 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
18 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
19 #include "mlir/Dialect/SPIRV/TargetAndABI.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/FunctionImplementation.h"
24 #include "mlir/IR/OpImplementation.h"
25 #include "mlir/Interfaces/CallInterfaces.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/bit.h"
28 
29 using namespace mlir;
30 
// TODO: generate these strings using ODS.

// Names of the memory-access attributes attached to memory operands, and of
// their companion alignment attributes (plain and "source" variants).
static constexpr char kMemoryAccessAttrName[] = "memory_access";
static constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
static constexpr char kAlignmentAttrName[] = "alignment";
static constexpr char kSourceAlignmentAttrName[] = "source_alignment";

// Other attribute/keyword names used by the custom parsers and printers.
static constexpr char kBranchWeightAttrName[] = "branch_weights";
static constexpr char kCallee[] = "callee";
static constexpr char kClusterSize[] = "cluster_size";
static constexpr char kControl[] = "control";
static constexpr char kDefaultValueAttrName[] = "default_value";
static constexpr char kExecutionScopeAttrName[] = "execution_scope";
static constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
static constexpr char kFnNameAttrName[] = "fn";
static constexpr char kGroupOperationAttrName[] = "group_operation";
static constexpr char kIndicesAttrName[] = "indices";
static constexpr char kInitializerAttrName[] = "initializer";
static constexpr char kInterfaceAttrName[] = "interface";
static constexpr char kMemoryScopeAttrName[] = "memory_scope";
static constexpr char kSemanticsAttrName[] = "semantics";
static constexpr char kSpecIdAttrName[] = "spec_id";
static constexpr char kTypeAttrName[] = "type";
static constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
static constexpr char kValueAttrName[] = "value";
static constexpr char kValuesAttrName[] = "values";
static constexpr char kCompositeSpecConstituentsName[] = "constituents";
57 
58 //===----------------------------------------------------------------------===//
59 // Common utility functions
60 //===----------------------------------------------------------------------===//
61 
62 /// Returns true if the given op is a function-like op or nested in a
63 /// function-like op without a module-like op in the middle.
isNestedInFunctionLikeOp(Operation * op)64 static bool isNestedInFunctionLikeOp(Operation *op) {
65   if (!op)
66     return false;
67   if (op->hasTrait<OpTrait::SymbolTable>())
68     return false;
69   if (op->hasTrait<OpTrait::FunctionLike>())
70     return true;
71   return isNestedInFunctionLikeOp(op->getParentOp());
72 }
73 
74 /// Returns true if the given op is an module-like op that maintains a symbol
75 /// table.
isDirectInModuleLikeOp(Operation * op)76 static bool isDirectInModuleLikeOp(Operation *op) {
77   return op && op->hasTrait<OpTrait::SymbolTable>();
78 }
79 
extractValueFromConstOp(Operation * op,int32_t & value)80 static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
81   auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
82   if (!constOp) {
83     return failure();
84   }
85   auto valueAttr = constOp.value();
86   auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
87   if (!integerValueAttr) {
88     return failure();
89   }
90   value = integerValueAttr.getInt();
91   return success();
92 }
93 
94 template <typename Ty>
95 static ArrayAttr
getStrArrayAttrForEnumList(Builder & builder,ArrayRef<Ty> enumValues,function_ref<StringRef (Ty)> stringifyFn)96 getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
97                            function_ref<StringRef(Ty)> stringifyFn) {
98   if (enumValues.empty()) {
99     return nullptr;
100   }
101   SmallVector<StringRef, 1> enumValStrs;
102   enumValStrs.reserve(enumValues.size());
103   for (auto val : enumValues) {
104     enumValStrs.emplace_back(stringifyFn(val));
105   }
106   return builder.getStrArrayAttr(enumValStrs);
107 }
108 
109 /// Parses the next string attribute in `parser` as an enumerant of the given
110 /// `EnumClass`.
111 template <typename EnumClass>
112 static ParseResult
parseEnumStrAttr(EnumClass & value,OpAsmParser & parser,StringRef attrName=spirv::attributeName<EnumClass> ())113 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
114                  StringRef attrName = spirv::attributeName<EnumClass>()) {
115   Attribute attrVal;
116   NamedAttrList attr;
117   auto loc = parser.getCurrentLocation();
118   if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
119                             attrName, attr)) {
120     return failure();
121   }
122   if (!attrVal.isa<StringAttr>()) {
123     return parser.emitError(loc, "expected ")
124            << attrName << " attribute specified as string";
125   }
126   auto attrOptional =
127       spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue());
128   if (!attrOptional) {
129     return parser.emitError(loc, "invalid ")
130            << attrName << " attribute specification: " << attrVal;
131   }
132   value = attrOptional.getValue();
133   return success();
134 }
135 
136 /// Parses the next string attribute in `parser` as an enumerant of the given
137 /// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
138 /// attribute with the enum class's name as attribute name.
139 template <typename EnumClass>
140 static ParseResult
parseEnumStrAttr(EnumClass & value,OpAsmParser & parser,OperationState & state,StringRef attrName=spirv::attributeName<EnumClass> ())141 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
142                  StringRef attrName = spirv::attributeName<EnumClass>()) {
143   if (parseEnumStrAttr(value, parser)) {
144     return failure();
145   }
146   state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
147                                    llvm::bit_cast<int32_t>(value)));
148   return success();
149 }
150 
151 /// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
152 /// and inserts the enumerant into `state` as an 32-bit integer attribute with
153 /// the enum class's name as attribute name.
154 template <typename EnumClass>
155 static ParseResult
parseEnumKeywordAttr(EnumClass & value,OpAsmParser & parser,OperationState & state,StringRef attrName=spirv::attributeName<EnumClass> ())156 parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
157                      OperationState &state,
158                      StringRef attrName = spirv::attributeName<EnumClass>()) {
159   if (parseEnumKeywordAttr(value, parser)) {
160     return failure();
161   }
162   state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
163                                    llvm::bit_cast<int32_t>(value)));
164   return success();
165 }
166 
167 /// Parses Function, Selection and Loop control attributes. If no control is
168 /// specified, "None" is used as a default.
169 template <typename EnumClass>
170 static ParseResult
parseControlAttribute(OpAsmParser & parser,OperationState & state,StringRef attrName=spirv::attributeName<EnumClass> ())171 parseControlAttribute(OpAsmParser &parser, OperationState &state,
172                       StringRef attrName = spirv::attributeName<EnumClass>()) {
173   if (succeeded(parser.parseOptionalKeyword(kControl))) {
174     EnumClass control;
175     if (parser.parseLParen() || parseEnumKeywordAttr(control, parser, state) ||
176         parser.parseRParen())
177       return failure();
178     return success();
179   }
180   // Set control to "None" otherwise.
181   Builder builder = parser.getBuilder();
182   state.addAttribute(attrName, builder.getI32IntegerAttr(0));
183   return success();
184 }
185 
186 /// Parses optional memory access attributes attached to a memory access
187 /// operand/pointer. Specifically, parses the following syntax:
188 ///     (`[` memory-access `]`)?
189 /// where:
190 ///     memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
191 ///         integer-literal | `"NonTemporal"`
parseMemoryAccessAttributes(OpAsmParser & parser,OperationState & state)192 static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
193                                                OperationState &state) {
194   // Parse an optional list of attributes staring with '['
195   if (parser.parseOptionalLSquare()) {
196     // Nothing to do
197     return success();
198   }
199 
200   spirv::MemoryAccess memoryAccessAttr;
201   if (parseEnumStrAttr(memoryAccessAttr, parser, state,
202                        kMemoryAccessAttrName)) {
203     return failure();
204   }
205 
206   if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
207     // Parse integer attribute for alignment.
208     Attribute alignmentAttr;
209     Type i32Type = parser.getBuilder().getIntegerType(32);
210     if (parser.parseComma() ||
211         parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
212                               state.attributes)) {
213       return failure();
214     }
215   }
216   return parser.parseRSquare();
217 }
218 
219 // TODO Make sure to merge this and the previous function into one template
220 // parameterized by memory access attribute name and alignment. Doing so now
221 // results in VS2017 in producing an internal error (at the call site) that's
222 // not detailed enough to understand what is happening.
parseSourceMemoryAccessAttributes(OpAsmParser & parser,OperationState & state)223 static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
224                                                      OperationState &state) {
225   // Parse an optional list of attributes staring with '['
226   if (parser.parseOptionalLSquare()) {
227     // Nothing to do
228     return success();
229   }
230 
231   spirv::MemoryAccess memoryAccessAttr;
232   if (parseEnumStrAttr(memoryAccessAttr, parser, state,
233                        kSourceMemoryAccessAttrName)) {
234     return failure();
235   }
236 
237   if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
238     // Parse integer attribute for alignment.
239     Attribute alignmentAttr;
240     Type i32Type = parser.getBuilder().getIntegerType(32);
241     if (parser.parseComma() ||
242         parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
243                               state.attributes)) {
244       return failure();
245     }
246   }
247   return parser.parseRSquare();
248 }
249 
250 template <typename MemoryOpTy>
printMemoryAccessAttribute(MemoryOpTy memoryOp,OpAsmPrinter & printer,SmallVectorImpl<StringRef> & elidedAttrs,Optional<spirv::MemoryAccess> memoryAccessAtrrValue=None,Optional<uint32_t> alignmentAttrValue=None)251 static void printMemoryAccessAttribute(
252     MemoryOpTy memoryOp, OpAsmPrinter &printer,
253     SmallVectorImpl<StringRef> &elidedAttrs,
254     Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
255     Optional<uint32_t> alignmentAttrValue = None) {
256   // Print optional memory access attribute.
257   if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
258                                               : memoryOp.memory_access())) {
259     elidedAttrs.push_back(kMemoryAccessAttrName);
260 
261     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
262 
263     if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
264       // Print integer alignment attribute.
265       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
266                                                : memoryOp.alignment())) {
267         elidedAttrs.push_back(kAlignmentAttrName);
268         printer << ", " << alignment;
269       }
270     }
271     printer << "]";
272   }
273   elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
274 }
275 
276 // TODO Make sure to merge this and the previous function into one template
277 // parameterized by memory access attribute name and alignment. Doing so now
278 // results in VS2017 in producing an internal error (at the call site) that's
279 // not detailed enough to understand what is happening.
280 template <typename MemoryOpTy>
printSourceMemoryAccessAttribute(MemoryOpTy memoryOp,OpAsmPrinter & printer,SmallVectorImpl<StringRef> & elidedAttrs,Optional<spirv::MemoryAccess> memoryAccessAtrrValue=None,Optional<uint32_t> alignmentAttrValue=None)281 static void printSourceMemoryAccessAttribute(
282     MemoryOpTy memoryOp, OpAsmPrinter &printer,
283     SmallVectorImpl<StringRef> &elidedAttrs,
284     Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
285     Optional<uint32_t> alignmentAttrValue = None) {
286 
287   printer << ", ";
288 
289   // Print optional memory access attribute.
290   if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
291                                               : memoryOp.memory_access())) {
292     elidedAttrs.push_back(kSourceMemoryAccessAttrName);
293 
294     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
295 
296     if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
297       // Print integer alignment attribute.
298       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
299                                                : memoryOp.alignment())) {
300         elidedAttrs.push_back(kSourceAlignmentAttrName);
301         printer << ", " << alignment;
302       }
303     }
304     printer << "]";
305   }
306   elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
307 }
308 
verifyCastOp(Operation * op,bool requireSameBitWidth=true,bool skipBitWidthCheck=false)309 static LogicalResult verifyCastOp(Operation *op,
310                                   bool requireSameBitWidth = true,
311                                   bool skipBitWidthCheck = false) {
312   // Some CastOps have no limit on bit widths for result and operand type.
313   if (skipBitWidthCheck)
314     return success();
315 
316   Type operandType = op->getOperand(0).getType();
317   Type resultType = op->getResult(0).getType();
318 
319   // ODS checks that result type and operand type have the same shape.
320   if (auto vectorType = operandType.dyn_cast<VectorType>()) {
321     operandType = vectorType.getElementType();
322     resultType = resultType.cast<VectorType>().getElementType();
323   }
324 
325   if (auto coopMatrixType =
326           operandType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
327     operandType = coopMatrixType.getElementType();
328     resultType =
329         resultType.cast<spirv::CooperativeMatrixNVType>().getElementType();
330   }
331 
332   auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
333   auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
334   auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
335 
336   if (requireSameBitWidth) {
337     if (!isSameBitWidth) {
338       return op->emitOpError(
339                  "expected the same bit widths for operand type and result "
340                  "type, but provided ")
341              << operandType << " and " << resultType;
342     }
343     return success();
344   }
345 
346   if (isSameBitWidth) {
347     return op->emitOpError(
348                "expected the different bit widths for operand type and result "
349                "type, but provided ")
350            << operandType << " and " << resultType;
351   }
352   return success();
353 }
354 
355 template <typename MemoryOpTy>
verifyMemoryAccessAttribute(MemoryOpTy memoryOp)356 static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
357   // ODS checks for attributes values. Just need to verify that if the
358   // memory-access attribute is Aligned, then the alignment attribute must be
359   // present.
360   auto *op = memoryOp.getOperation();
361   auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
362   if (!memAccessAttr) {
363     // Alignment attribute shouldn't be present if memory access attribute is
364     // not present.
365     if (op->getAttr(kAlignmentAttrName)) {
366       return memoryOp.emitOpError(
367           "invalid alignment specification without aligned memory access "
368           "specification");
369     }
370     return success();
371   }
372 
373   auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
374   auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
375 
376   if (!memAccess) {
377     return memoryOp.emitOpError("invalid memory access specifier: ")
378            << memAccessVal;
379   }
380 
381   if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
382     if (!op->getAttr(kAlignmentAttrName)) {
383       return memoryOp.emitOpError("missing alignment value");
384     }
385   } else {
386     if (op->getAttr(kAlignmentAttrName)) {
387       return memoryOp.emitOpError(
388           "invalid alignment specification with non-aligned memory access "
389           "specification");
390     }
391   }
392   return success();
393 }
394 
395 // TODO Make sure to merge this and the previous function into one template
396 // parameterized by memory access attribute name and alignment. Doing so now
397 // results in VS2017 in producing an internal error (at the call site) that's
398 // not detailed enough to understand what is happening.
399 template <typename MemoryOpTy>
verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp)400 static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
401   // ODS checks for attributes values. Just need to verify that if the
402   // memory-access attribute is Aligned, then the alignment attribute must be
403   // present.
404   auto *op = memoryOp.getOperation();
405   auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
406   if (!memAccessAttr) {
407     // Alignment attribute shouldn't be present if memory access attribute is
408     // not present.
409     if (op->getAttr(kSourceAlignmentAttrName)) {
410       return memoryOp.emitOpError(
411           "invalid alignment specification without aligned memory access "
412           "specification");
413     }
414     return success();
415   }
416 
417   auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
418   auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
419 
420   if (!memAccess) {
421     return memoryOp.emitOpError("invalid memory access specifier: ")
422            << memAccessVal;
423   }
424 
425   if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
426     if (!op->getAttr(kSourceAlignmentAttrName)) {
427       return memoryOp.emitOpError("missing alignment value");
428     }
429   } else {
430     if (op->getAttr(kSourceAlignmentAttrName)) {
431       return memoryOp.emitOpError(
432           "invalid alignment specification with non-aligned memory access "
433           "specification");
434     }
435   }
436   return success();
437 }
438 
439 template <typename BarrierOp>
verifyMemorySemantics(BarrierOp op)440 static LogicalResult verifyMemorySemantics(BarrierOp op) {
441   // According to the SPIR-V specification:
442   // "Despite being a mask and allowing multiple bits to be combined, it is
443   // invalid for more than one of these four bits to be set: Acquire, Release,
444   // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
445   // Release semantics is done by setting the AcquireRelease bit, not by setting
446   // two bits."
447   auto memorySemantics = op.memory_semantics();
448   auto atMostOneInSet = spirv::MemorySemantics::Acquire |
449                         spirv::MemorySemantics::Release |
450                         spirv::MemorySemantics::AcquireRelease |
451                         spirv::MemorySemantics::SequentiallyConsistent;
452 
453   auto bitCount = llvm::countPopulation(
454       static_cast<uint32_t>(memorySemantics & atMostOneInSet));
455   if (bitCount > 1) {
456     return op.emitError("expected at most one of these four memory constraints "
457                         "to be set: `Acquire`, `Release`,"
458                         "`AcquireRelease` or `SequentiallyConsistent`");
459   }
460   return success();
461 }
462 
463 template <typename LoadStoreOpTy>
verifyLoadStorePtrAndValTypes(LoadStoreOpTy op,Value ptr,Value val)464 static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
465                                                    Value val) {
466   // ODS already checks ptr is spirv::PointerType. Just check that the pointee
467   // type of the pointer and the type of the value are the same
468   //
469   // TODO: Check that the value type satisfies restrictions of
470   // SPIR-V OpLoad/OpStore operations
471   if (val.getType() !=
472       ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
473     return op.emitOpError("mismatch in result type and pointer type");
474   }
475   return success();
476 }
477 
478 template <typename BlockReadWriteOpTy>
verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,Value ptr,Value val)479 static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
480                                                         Value ptr, Value val) {
481   auto valType = val.getType();
482   if (auto valVecTy = valType.dyn_cast<VectorType>())
483     valType = valVecTy.getElementType();
484 
485   if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
486     return op.emitOpError("mismatch in result type and pointer type");
487   }
488   return success();
489 }
490 
parseVariableDecorations(OpAsmParser & parser,OperationState & state)491 static ParseResult parseVariableDecorations(OpAsmParser &parser,
492                                             OperationState &state) {
493   auto builtInName = llvm::convertToSnakeFromCamelCase(
494       stringifyDecoration(spirv::Decoration::BuiltIn));
495   if (succeeded(parser.parseOptionalKeyword("bind"))) {
496     Attribute set, binding;
497     // Parse optional descriptor binding
498     auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
499         stringifyDecoration(spirv::Decoration::DescriptorSet));
500     auto bindingName = llvm::convertToSnakeFromCamelCase(
501         stringifyDecoration(spirv::Decoration::Binding));
502     Type i32Type = parser.getBuilder().getIntegerType(32);
503     if (parser.parseLParen() ||
504         parser.parseAttribute(set, i32Type, descriptorSetName,
505                               state.attributes) ||
506         parser.parseComma() ||
507         parser.parseAttribute(binding, i32Type, bindingName,
508                               state.attributes) ||
509         parser.parseRParen()) {
510       return failure();
511     }
512   } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
513     StringAttr builtIn;
514     if (parser.parseLParen() ||
515         parser.parseAttribute(builtIn, builtInName, state.attributes) ||
516         parser.parseRParen()) {
517       return failure();
518     }
519   }
520 
521   // Parse other attributes
522   if (parser.parseOptionalAttrDict(state.attributes))
523     return failure();
524 
525   return success();
526 }
527 
printVariableDecorations(Operation * op,OpAsmPrinter & printer,SmallVectorImpl<StringRef> & elidedAttrs)528 static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
529                                      SmallVectorImpl<StringRef> &elidedAttrs) {
530   // Print optional descriptor binding
531   auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
532       stringifyDecoration(spirv::Decoration::DescriptorSet));
533   auto bindingName = llvm::convertToSnakeFromCamelCase(
534       stringifyDecoration(spirv::Decoration::Binding));
535   auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
536   auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
537   if (descriptorSet && binding) {
538     elidedAttrs.push_back(descriptorSetName);
539     elidedAttrs.push_back(bindingName);
540     printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
541             << ")";
542   }
543 
544   // Print BuiltIn attribute if present
545   auto builtInName = llvm::convertToSnakeFromCamelCase(
546       stringifyDecoration(spirv::Decoration::BuiltIn));
547   if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
548     printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
549     elidedAttrs.push_back(builtInName);
550   }
551 
552   printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
553 }
554 
555 // Get bit width of types.
getBitWidth(Type type)556 static unsigned getBitWidth(Type type) {
557   if (type.isa<spirv::PointerType>()) {
558     // Just return 64 bits for pointer types for now.
559     // TODO: Make sure not caller relies on the actual pointer width value.
560     return 64;
561   }
562 
563   if (type.isIntOrFloat())
564     return type.getIntOrFloatBitWidth();
565 
566   if (auto vectorType = type.dyn_cast<VectorType>()) {
567     assert(vectorType.getElementType().isIntOrFloat());
568     return vectorType.getNumElements() *
569            vectorType.getElementType().getIntOrFloatBitWidth();
570   }
571   llvm_unreachable("unhandled bit width computation for type");
572 }
573 
574 /// Walks the given type hierarchy with the given indices, potentially down
575 /// to component granularity, to select an element type. Returns null type and
576 /// emits errors with the given loc on failure.
577 static Type
getElementType(Type type,ArrayRef<int32_t> indices,function_ref<InFlightDiagnostic (StringRef)> emitErrorFn)578 getElementType(Type type, ArrayRef<int32_t> indices,
579                function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
580   if (indices.empty()) {
581     emitErrorFn("expected at least one index for spv.CompositeExtract");
582     return nullptr;
583   }
584 
585   for (auto index : indices) {
586     if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
587       if (cType.hasCompileTimeKnownNumElements() &&
588           (index < 0 ||
589            static_cast<uint64_t>(index) >= cType.getNumElements())) {
590         emitErrorFn("index ") << index << " out of bounds for " << type;
591         return nullptr;
592       }
593       type = cType.getElementType(index);
594     } else {
595       emitErrorFn("cannot extract from non-composite type ")
596           << type << " with index " << index;
597       return nullptr;
598     }
599   }
600   return type;
601 }
602 
603 static Type
getElementType(Type type,Attribute indices,function_ref<InFlightDiagnostic (StringRef)> emitErrorFn)604 getElementType(Type type, Attribute indices,
605                function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
606   auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>();
607   if (!indicesArrayAttr) {
608     emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
609     return nullptr;
610   }
611   if (!indicesArrayAttr.size()) {
612     emitErrorFn("expected at least one index for spv.CompositeExtract");
613     return nullptr;
614   }
615 
616   SmallVector<int32_t, 2> indexVals;
617   for (auto indexAttr : indicesArrayAttr) {
618     auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>();
619     if (!indexIntAttr) {
620       emitErrorFn("expected an 32-bit integer for index, but found '")
621           << indexAttr << "'";
622       return nullptr;
623     }
624     indexVals.push_back(indexIntAttr.getInt());
625   }
626   return getElementType(type, indexVals, emitErrorFn);
627 }
628 
getElementType(Type type,Attribute indices,Location loc)629 static Type getElementType(Type type, Attribute indices, Location loc) {
630   auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
631     return ::mlir::emitError(loc, err);
632   };
633   return getElementType(type, indices, errorFn);
634 }
635 
getElementType(Type type,Attribute indices,OpAsmParser & parser,llvm::SMLoc loc)636 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
637                            llvm::SMLoc loc) {
638   auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
639     return parser.emitError(loc, err);
640   };
641   return getElementType(type, indices, errorFn);
642 }
643 
644 /// Returns true if the given `block` only contains one `spv.mlir.merge` op.
isMergeBlock(Block & block)645 static inline bool isMergeBlock(Block &block) {
646   return !block.empty() && std::next(block.begin()) == block.end() &&
647          isa<spirv::MergeOp>(block.front());
648 }
649 
650 //===----------------------------------------------------------------------===//
651 // Common parsers and printers
652 //===----------------------------------------------------------------------===//
653 
654 // Parses an atomic update op. If the update op does not take a value (like
655 // AtomicIIncrement) `hasValue` must be false.
parseAtomicUpdateOp(OpAsmParser & parser,OperationState & state,bool hasValue)656 static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
657                                        OperationState &state, bool hasValue) {
658   spirv::Scope scope;
659   spirv::MemorySemantics memoryScope;
660   SmallVector<OpAsmParser::OperandType, 2> operandInfo;
661   OpAsmParser::OperandType ptrInfo, valueInfo;
662   Type type;
663   llvm::SMLoc loc;
664   if (parseEnumStrAttr(scope, parser, state, kMemoryScopeAttrName) ||
665       parseEnumStrAttr(memoryScope, parser, state, kSemanticsAttrName) ||
666       parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
667       parser.getCurrentLocation(&loc) || parser.parseColonType(type))
668     return failure();
669 
670   auto ptrType = type.dyn_cast<spirv::PointerType>();
671   if (!ptrType)
672     return parser.emitError(loc, "expected pointer type");
673 
674   SmallVector<Type, 2> operandTypes;
675   operandTypes.push_back(ptrType);
676   if (hasValue)
677     operandTypes.push_back(ptrType.getPointeeType());
678   if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
679                              state.operands))
680     return failure();
681   return parser.addTypeToList(ptrType.getPointeeType(), state.types);
682 }
683 
684 // Prints an atomic update op.
printAtomicUpdateOp(Operation * op,OpAsmPrinter & printer)685 static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
686   printer << op->getName() << " \"";
687   auto scopeAttr = op->getAttrOfType<IntegerAttr>(kMemoryScopeAttrName);
688   printer << spirv::stringifyScope(
689                  static_cast<spirv::Scope>(scopeAttr.getInt()))
690           << "\" \"";
691   auto memorySemanticsAttr = op->getAttrOfType<IntegerAttr>(kSemanticsAttrName);
692   printer << spirv::stringifyMemorySemantics(
693                  static_cast<spirv::MemorySemantics>(
694                      memorySemanticsAttr.getInt()))
695           << "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
696 }
697 
698 // Verifies an atomic update op.
verifyAtomicUpdateOp(Operation * op)699 static LogicalResult verifyAtomicUpdateOp(Operation *op) {
700   auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
701   auto elementType = ptrType.getPointeeType();
702   if (!elementType.isa<IntegerType>())
703     return op->emitOpError(
704                "pointer operand must point to an integer value, found ")
705            << elementType;
706 
707   if (op->getNumOperands() > 1) {
708     auto valueType = op->getOperand(1).getType();
709     if (valueType != elementType)
710       return op->emitOpError("expected value to have the same type as the "
711                              "pointer operand's pointee type ")
712              << elementType << ", but found " << valueType;
713   }
714   return success();
715 }
716 
parseGroupNonUniformArithmeticOp(OpAsmParser & parser,OperationState & state)717 static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
718                                                     OperationState &state) {
719   spirv::Scope executionScope;
720   spirv::GroupOperation groupOperation;
721   OpAsmParser::OperandType valueInfo;
722   if (parseEnumStrAttr(executionScope, parser, state,
723                        kExecutionScopeAttrName) ||
724       parseEnumStrAttr(groupOperation, parser, state,
725                        kGroupOperationAttrName) ||
726       parser.parseOperand(valueInfo))
727     return failure();
728 
729   Optional<OpAsmParser::OperandType> clusterSizeInfo;
730   if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
731     clusterSizeInfo = OpAsmParser::OperandType();
732     if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
733         parser.parseRParen())
734       return failure();
735   }
736 
737   Type resultType;
738   if (parser.parseColonType(resultType))
739     return failure();
740 
741   if (parser.resolveOperand(valueInfo, resultType, state.operands))
742     return failure();
743 
744   if (clusterSizeInfo.hasValue()) {
745     Type i32Type = parser.getBuilder().getIntegerType(32);
746     if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
747       return failure();
748   }
749 
750   return parser.addTypeToList(resultType, state.types);
751 }
752 
printGroupNonUniformArithmeticOp(Operation * groupOp,OpAsmPrinter & printer)753 static void printGroupNonUniformArithmeticOp(Operation *groupOp,
754                                              OpAsmPrinter &printer) {
755   printer << groupOp->getName() << " \""
756           << stringifyScope(static_cast<spirv::Scope>(
757                  groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName)
758                      .getInt()))
759           << "\" \""
760           << stringifyGroupOperation(static_cast<spirv::GroupOperation>(
761                  groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName)
762                      .getInt()))
763           << "\" " << groupOp->getOperand(0);
764 
765   if (groupOp->getNumOperands() > 1)
766     printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
767   printer << " : " << groupOp->getResult(0).getType();
768 }
769 
verifyGroupNonUniformArithmeticOp(Operation * groupOp)770 static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
771   spirv::Scope scope = static_cast<spirv::Scope>(
772       groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName).getInt());
773   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
774     return groupOp->emitOpError(
775         "execution scope must be 'Workgroup' or 'Subgroup'");
776 
777   spirv::GroupOperation operation = static_cast<spirv::GroupOperation>(
778       groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName).getInt());
779   if (operation == spirv::GroupOperation::ClusteredReduce &&
780       groupOp->getNumOperands() == 1)
781     return groupOp->emitOpError("cluster size operand must be provided for "
782                                 "'ClusteredReduce' group operation");
783   if (groupOp->getNumOperands() > 1) {
784     Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
785     int32_t clusterSize = 0;
786 
787     // TODO: support specialization constant here.
788     if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
789       return groupOp->emitOpError(
790           "cluster size operand must come from a constant op");
791 
792     if (!llvm::isPowerOf2_32(clusterSize))
793       return groupOp->emitOpError(
794           "cluster size operand must be a power of two");
795   }
796   return success();
797 }
798 
parseUnaryOp(OpAsmParser & parser,OperationState & state)799 static ParseResult parseUnaryOp(OpAsmParser &parser, OperationState &state) {
800   OpAsmParser::OperandType operandInfo;
801   Type type;
802   if (parser.parseOperand(operandInfo) || parser.parseColonType(type) ||
803       parser.resolveOperands(operandInfo, type, state.operands)) {
804     return failure();
805   }
806   state.addTypes(type);
807   return success();
808 }
809 
printUnaryOp(Operation * unaryOp,OpAsmPrinter & printer)810 static void printUnaryOp(Operation *unaryOp, OpAsmPrinter &printer) {
811   printer << unaryOp->getName() << ' ' << unaryOp->getOperand(0) << " : "
812           << unaryOp->getOperand(0).getType();
813 }
814 
815 /// Result of a logical op must be a scalar or vector of boolean type.
getUnaryOpResultType(Builder & builder,Type operandType)816 static Type getUnaryOpResultType(Builder &builder, Type operandType) {
817   Type resultType = builder.getIntegerType(1);
818   if (auto vecType = operandType.dyn_cast<VectorType>()) {
819     return VectorType::get(vecType.getNumElements(), resultType);
820   }
821   return resultType;
822 }
823 
parseLogicalUnaryOp(OpAsmParser & parser,OperationState & state)824 static ParseResult parseLogicalUnaryOp(OpAsmParser &parser,
825                                        OperationState &state) {
826   OpAsmParser::OperandType operandInfo;
827   Type type;
828   if (parser.parseOperand(operandInfo) || parser.parseColonType(type) ||
829       parser.resolveOperand(operandInfo, type, state.operands)) {
830     return failure();
831   }
832   state.addTypes(getUnaryOpResultType(parser.getBuilder(), type));
833   return success();
834 }
835 
parseLogicalBinaryOp(OpAsmParser & parser,OperationState & result)836 static ParseResult parseLogicalBinaryOp(OpAsmParser &parser,
837                                         OperationState &result) {
838   SmallVector<OpAsmParser::OperandType, 2> ops;
839   Type type;
840   if (parser.parseOperandList(ops, 2) || parser.parseColonType(type) ||
841       parser.resolveOperands(ops, type, result.operands)) {
842     return failure();
843   }
844   result.addTypes(getUnaryOpResultType(parser.getBuilder(), type));
845   return success();
846 }
847 
printLogicalOp(Operation * logicalOp,OpAsmPrinter & printer)848 static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
849   printer << logicalOp->getName() << ' ' << logicalOp->getOperands() << " : "
850           << logicalOp->getOperand(0).getType();
851 }
852 
parseShiftOp(OpAsmParser & parser,OperationState & state)853 static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) {
854   SmallVector<OpAsmParser::OperandType, 2> operandInfo;
855   Type baseType;
856   Type shiftType;
857   auto loc = parser.getCurrentLocation();
858 
859   if (parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
860       parser.parseType(baseType) || parser.parseComma() ||
861       parser.parseType(shiftType) ||
862       parser.resolveOperands(operandInfo, {baseType, shiftType}, loc,
863                              state.operands)) {
864     return failure();
865   }
866   state.addTypes(baseType);
867   return success();
868 }
869 
printShiftOp(Operation * op,OpAsmPrinter & printer)870 static void printShiftOp(Operation *op, OpAsmPrinter &printer) {
871   Value base = op->getOperand(0);
872   Value shift = op->getOperand(1);
873   printer << op->getName() << ' ' << base << ", " << shift << " : "
874           << base.getType() << ", " << shift.getType();
875 }
876 
verifyShiftOp(Operation * op)877 static LogicalResult verifyShiftOp(Operation *op) {
878   if (op->getOperand(0).getType() != op->getResult(0).getType()) {
879     return op->emitError("expected the same type for the first operand and "
880                          "result, but provided ")
881            << op->getOperand(0).getType() << " and "
882            << op->getResult(0).getType();
883   }
884   return success();
885 }
886 
buildLogicalBinaryOp(OpBuilder & builder,OperationState & state,Value lhs,Value rhs)887 static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state,
888                                  Value lhs, Value rhs) {
889   assert(lhs.getType() == rhs.getType());
890 
891   Type boolType = builder.getI1Type();
892   if (auto vecType = lhs.getType().dyn_cast<VectorType>())
893     boolType = VectorType::get(vecType.getShape(), boolType);
894   state.addTypes(boolType);
895 
896   state.addOperands({lhs, rhs});
897 }
898 
899 //===----------------------------------------------------------------------===//
900 // spv.AccessChainOp
901 //===----------------------------------------------------------------------===//
902 
getElementPtrType(Type type,ValueRange indices,Location baseLoc)903 static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
904   auto ptrType = type.dyn_cast<spirv::PointerType>();
905   if (!ptrType) {
906     emitError(baseLoc, "'spv.AccessChain' op expected a pointer "
907                        "to composite type, but provided ")
908         << type;
909     return nullptr;
910   }
911 
912   auto resultType = ptrType.getPointeeType();
913   auto resultStorageClass = ptrType.getStorageClass();
914   int32_t index = 0;
915 
916   for (auto indexSSA : indices) {
917     auto cType = resultType.dyn_cast<spirv::CompositeType>();
918     if (!cType) {
919       emitError(baseLoc,
920                 "'spv.AccessChain' op cannot extract from non-composite type ")
921           << resultType << " with index " << index;
922       return nullptr;
923     }
924     index = 0;
925     if (resultType.isa<spirv::StructType>()) {
926       Operation *op = indexSSA.getDefiningOp();
927       if (!op) {
928         emitError(baseLoc, "'spv.AccessChain' op index must be an "
929                            "integer spv.constant to access "
930                            "element of spv.struct");
931         return nullptr;
932       }
933 
934       // TODO: this should be relaxed to allow
935       // integer literals of other bitwidths.
936       if (failed(extractValueFromConstOp(op, index))) {
937         emitError(baseLoc,
938                   "'spv.AccessChain' index must be an integer spv.constant to "
939                   "access element of spv.struct, but provided ")
940             << op->getName();
941         return nullptr;
942       }
943       if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
944         emitError(baseLoc, "'spv.AccessChain' op index ")
945             << index << " out of bounds for " << resultType;
946         return nullptr;
947       }
948     }
949     resultType = cType.getElementType(index);
950   }
951   return spirv::PointerType::get(resultType, resultStorageClass);
952 }
953 
build(OpBuilder & builder,OperationState & state,Value basePtr,ValueRange indices)954 void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state,
955                                  Value basePtr, ValueRange indices) {
956   auto type = getElementPtrType(basePtr.getType(), indices, state.location);
957   assert(type && "Unable to deduce return type based on basePtr and indices");
958   build(builder, state, type, basePtr, indices);
959 }
960 
parseAccessChainOp(OpAsmParser & parser,OperationState & state)961 static ParseResult parseAccessChainOp(OpAsmParser &parser,
962                                       OperationState &state) {
963   OpAsmParser::OperandType ptrInfo;
964   SmallVector<OpAsmParser::OperandType, 4> indicesInfo;
965   Type type;
966   auto loc = parser.getCurrentLocation();
967   SmallVector<Type, 4> indicesTypes;
968 
969   if (parser.parseOperand(ptrInfo) ||
970       parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
971       parser.parseColonType(type) ||
972       parser.resolveOperand(ptrInfo, type, state.operands)) {
973     return failure();
974   }
975 
976   // Check that the provided indices list is not empty before parsing their
977   // type list.
978   if (indicesInfo.empty()) {
979     return emitError(state.location, "'spv.AccessChain' op expected at "
980                                      "least one index ");
981   }
982 
983   if (parser.parseComma() || parser.parseTypeList(indicesTypes))
984     return failure();
985 
986   // Check that the indices types list is not empty and that it has a one-to-one
987   // mapping to the provided indices.
988   if (indicesTypes.size() != indicesInfo.size()) {
989     return emitError(state.location, "'spv.AccessChain' op indices "
990                                      "types' count must be equal to indices "
991                                      "info count");
992   }
993 
994   if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
995     return failure();
996 
997   auto resultType = getElementPtrType(
998       type, llvm::makeArrayRef(state.operands).drop_front(), state.location);
999   if (!resultType) {
1000     return failure();
1001   }
1002 
1003   state.addTypes(resultType);
1004   return success();
1005 }
1006 
print(spirv::AccessChainOp op,OpAsmPrinter & printer)1007 static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
1008   printer << spirv::AccessChainOp::getOperationName() << ' ' << op.base_ptr()
1009           << '[' << op.indices() << "] : " << op.base_ptr().getType() << ", "
1010           << op.indices().getTypes();
1011 }
1012 
verify(spirv::AccessChainOp accessChainOp)1013 static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
1014   SmallVector<Value, 4> indices(accessChainOp.indices().begin(),
1015                                 accessChainOp.indices().end());
1016   auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(),
1017                                       indices, accessChainOp.getLoc());
1018   if (!resultType) {
1019     return failure();
1020   }
1021 
1022   auto providedResultType =
1023       accessChainOp.getType().dyn_cast<spirv::PointerType>();
1024   if (!providedResultType) {
1025     return accessChainOp.emitOpError(
1026                "result type must be a pointer, but provided")
1027            << providedResultType;
1028   }
1029 
1030   if (resultType != providedResultType) {
1031     return accessChainOp.emitOpError("invalid result type: expected ")
1032            << resultType << ", but provided " << providedResultType;
1033   }
1034 
1035   return success();
1036 }
1037 
1038 //===----------------------------------------------------------------------===//
1039 // spv.mlir.addressof
1040 //===----------------------------------------------------------------------===//
1041 
build(OpBuilder & builder,OperationState & state,spirv::GlobalVariableOp var)1042 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
1043                                spirv::GlobalVariableOp var) {
1044   build(builder, state, var.type(), builder.getSymbolRefAttr(var));
1045 }
1046 
verify(spirv::AddressOfOp addressOfOp)1047 static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
1048   auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
1049       SymbolTable::lookupNearestSymbolFrom(addressOfOp->getParentOp(),
1050                                            addressOfOp.variable()));
1051   if (!varOp) {
1052     return addressOfOp.emitOpError("expected spv.globalVariable symbol");
1053   }
1054   if (addressOfOp.pointer().getType() != varOp.type()) {
1055     return addressOfOp.emitOpError(
1056         "result type mismatch with the referenced global variable's type");
1057   }
1058   return success();
1059 }
1060 
1061 //===----------------------------------------------------------------------===//
1062 // spv.AtomicCompareExchangeWeak
1063 //===----------------------------------------------------------------------===//
1064 
parseAtomicCompareExchangeWeakOp(OpAsmParser & parser,OperationState & state)1065 static ParseResult parseAtomicCompareExchangeWeakOp(OpAsmParser &parser,
1066                                                     OperationState &state) {
1067   spirv::Scope memoryScope;
1068   spirv::MemorySemantics equalSemantics, unequalSemantics;
1069   SmallVector<OpAsmParser::OperandType, 3> operandInfo;
1070   Type type;
1071   if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) ||
1072       parseEnumStrAttr(equalSemantics, parser, state,
1073                        kEqualSemanticsAttrName) ||
1074       parseEnumStrAttr(unequalSemantics, parser, state,
1075                        kUnequalSemanticsAttrName) ||
1076       parser.parseOperandList(operandInfo, 3))
1077     return failure();
1078 
1079   auto loc = parser.getCurrentLocation();
1080   if (parser.parseColonType(type))
1081     return failure();
1082 
1083   auto ptrType = type.dyn_cast<spirv::PointerType>();
1084   if (!ptrType)
1085     return parser.emitError(loc, "expected pointer type");
1086 
1087   if (parser.resolveOperands(
1088           operandInfo,
1089           {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()},
1090           parser.getNameLoc(), state.operands))
1091     return failure();
1092 
1093   return parser.addTypeToList(ptrType.getPointeeType(), state.types);
1094 }
1095 
print(spirv::AtomicCompareExchangeWeakOp atomOp,OpAsmPrinter & printer)1096 static void print(spirv::AtomicCompareExchangeWeakOp atomOp,
1097                   OpAsmPrinter &printer) {
1098   printer << spirv::AtomicCompareExchangeWeakOp::getOperationName() << " \""
1099           << stringifyScope(atomOp.memory_scope()) << "\" \""
1100           << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
1101           << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
1102           << atomOp.getOperands() << " : " << atomOp.pointer().getType();
1103 }
1104 
verify(spirv::AtomicCompareExchangeWeakOp atomOp)1105 static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
1106   // According to the spec:
1107   // "The type of Value must be the same as Result Type. The type of the value
1108   // pointed to by Pointer must be the same as Result Type. This type must also
1109   // match the type of Comparator."
1110   if (atomOp.getType() != atomOp.value().getType())
1111     return atomOp.emitOpError("value operand must have the same type as the op "
1112                               "result, but found ")
1113            << atomOp.value().getType() << " vs " << atomOp.getType();
1114 
1115   if (atomOp.getType() != atomOp.comparator().getType())
1116     return atomOp.emitOpError(
1117                "comparator operand must have the same type as the op "
1118                "result, but found ")
1119            << atomOp.comparator().getType() << " vs " << atomOp.getType();
1120 
1121   Type pointeeType =
1122       atomOp.pointer().getType().cast<spirv::PointerType>().getPointeeType();
1123   if (atomOp.getType() != pointeeType)
1124     return atomOp.emitOpError(
1125                "pointer operand's pointee type must have the same "
1126                "as the op result type, but found ")
1127            << pointeeType << " vs " << atomOp.getType();
1128 
1129   // TODO: Unequal cannot be set to Release or Acquire and Release.
1130   // In addition, Unequal cannot be set to a stronger memory-order then Equal.
1131 
1132   return success();
1133 }
1134 
1135 //===----------------------------------------------------------------------===//
1136 // spv.BitcastOp
1137 //===----------------------------------------------------------------------===//
1138 
verify(spirv::BitcastOp bitcastOp)1139 static LogicalResult verify(spirv::BitcastOp bitcastOp) {
1140   // TODO: The SPIR-V spec validation rules are different for different
1141   // versions.
1142   auto operandType = bitcastOp.operand().getType();
1143   auto resultType = bitcastOp.result().getType();
1144   if (operandType == resultType) {
1145     return bitcastOp.emitError(
1146         "result type must be different from operand type");
1147   }
1148   if (operandType.isa<spirv::PointerType>() &&
1149       !resultType.isa<spirv::PointerType>()) {
1150     return bitcastOp.emitError(
1151         "unhandled bit cast conversion from pointer type to non-pointer type");
1152   }
1153   if (!operandType.isa<spirv::PointerType>() &&
1154       resultType.isa<spirv::PointerType>()) {
1155     return bitcastOp.emitError(
1156         "unhandled bit cast conversion from non-pointer type to pointer type");
1157   }
1158   auto operandBitWidth = getBitWidth(operandType);
1159   auto resultBitWidth = getBitWidth(resultType);
1160   if (operandBitWidth != resultBitWidth) {
1161     return bitcastOp.emitOpError("mismatch in result type bitwidth ")
1162            << resultBitWidth << " and operand type bitwidth "
1163            << operandBitWidth;
1164   }
1165   return success();
1166 }
1167 
1168 //===----------------------------------------------------------------------===//
1169 // spv.BranchOp
1170 //===----------------------------------------------------------------------===//
1171 
1172 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)1173 spirv::BranchOp::getMutableSuccessorOperands(unsigned index) {
1174   assert(index == 0 && "invalid successor index");
1175   return targetOperandsMutable();
1176 }
1177 
1178 //===----------------------------------------------------------------------===//
1179 // spv.BranchConditionalOp
1180 //===----------------------------------------------------------------------===//
1181 
1182 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)1183 spirv::BranchConditionalOp::getMutableSuccessorOperands(unsigned index) {
1184   assert(index < 2 && "invalid successor index");
1185   return index == kTrueIndex ? trueTargetOperandsMutable()
1186                              : falseTargetOperandsMutable();
1187 }
1188 
parseBranchConditionalOp(OpAsmParser & parser,OperationState & state)1189 static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
1190                                             OperationState &state) {
1191   auto &builder = parser.getBuilder();
1192   OpAsmParser::OperandType condInfo;
1193   Block *dest;
1194 
1195   // Parse the condition.
1196   Type boolTy = builder.getI1Type();
1197   if (parser.parseOperand(condInfo) ||
1198       parser.resolveOperand(condInfo, boolTy, state.operands))
1199     return failure();
1200 
1201   // Parse the optional branch weights.
1202   if (succeeded(parser.parseOptionalLSquare())) {
1203     IntegerAttr trueWeight, falseWeight;
1204     NamedAttrList weights;
1205 
1206     auto i32Type = builder.getIntegerType(32);
1207     if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
1208         parser.parseComma() ||
1209         parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
1210         parser.parseRSquare())
1211       return failure();
1212 
1213     state.addAttribute(kBranchWeightAttrName,
1214                        builder.getArrayAttr({trueWeight, falseWeight}));
1215   }
1216 
1217   // Parse the true branch.
1218   SmallVector<Value, 4> trueOperands;
1219   if (parser.parseComma() ||
1220       parser.parseSuccessorAndUseList(dest, trueOperands))
1221     return failure();
1222   state.addSuccessors(dest);
1223   state.addOperands(trueOperands);
1224 
1225   // Parse the false branch.
1226   SmallVector<Value, 4> falseOperands;
1227   if (parser.parseComma() ||
1228       parser.parseSuccessorAndUseList(dest, falseOperands))
1229     return failure();
1230   state.addSuccessors(dest);
1231   state.addOperands(falseOperands);
1232   state.addAttribute(
1233       spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
1234       builder.getI32VectorAttr({1, static_cast<int32_t>(trueOperands.size()),
1235                                 static_cast<int32_t>(falseOperands.size())}));
1236 
1237   return success();
1238 }
1239 
print(spirv::BranchConditionalOp branchOp,OpAsmPrinter & printer)1240 static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) {
1241   printer << spirv::BranchConditionalOp::getOperationName() << ' '
1242           << branchOp.condition();
1243 
1244   if (auto weights = branchOp.branch_weights()) {
1245     printer << " [";
1246     llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
1247       printer << a.cast<IntegerAttr>().getInt();
1248     });
1249     printer << "]";
1250   }
1251 
1252   printer << ", ";
1253   printer.printSuccessorAndUseList(branchOp.getTrueBlock(),
1254                                    branchOp.getTrueBlockArguments());
1255   printer << ", ";
1256   printer.printSuccessorAndUseList(branchOp.getFalseBlock(),
1257                                    branchOp.getFalseBlockArguments());
1258 }
1259 
verify(spirv::BranchConditionalOp branchOp)1260 static LogicalResult verify(spirv::BranchConditionalOp branchOp) {
1261   if (auto weights = branchOp.branch_weights()) {
1262     if (weights->getValue().size() != 2) {
1263       return branchOp.emitOpError("must have exactly two branch weights");
1264     }
1265     if (llvm::all_of(*weights, [](Attribute attr) {
1266           return attr.cast<IntegerAttr>().getValue().isNullValue();
1267         }))
1268       return branchOp.emitOpError("branch weights cannot both be zero");
1269   }
1270 
1271   return success();
1272 }
1273 
1274 //===----------------------------------------------------------------------===//
1275 // spv.CompositeConstruct
1276 //===----------------------------------------------------------------------===//
1277 
parseCompositeConstructOp(OpAsmParser & parser,OperationState & state)1278 static ParseResult parseCompositeConstructOp(OpAsmParser &parser,
1279                                              OperationState &state) {
1280   SmallVector<OpAsmParser::OperandType, 4> operands;
1281   Type type;
1282   auto loc = parser.getCurrentLocation();
1283 
1284   if (parser.parseOperandList(operands) || parser.parseColonType(type)) {
1285     return failure();
1286   }
1287   auto cType = type.dyn_cast<spirv::CompositeType>();
1288   if (!cType) {
1289     return parser.emitError(
1290                loc, "result type must be a composite type, but provided ")
1291            << type;
1292   }
1293 
1294   if (cType.hasCompileTimeKnownNumElements() &&
1295       operands.size() != cType.getNumElements()) {
1296     return parser.emitError(loc, "has incorrect number of operands: expected ")
1297            << cType.getNumElements() << ", but provided " << operands.size();
1298   }
1299   // TODO: Add support for constructing a vector type from the vector operands.
1300   // According to the spec: "for constructing a vector, the operands may
1301   // also be vectors with the same component type as the Result Type component
1302   // type".
1303   SmallVector<Type, 4> elementTypes;
1304   elementTypes.reserve(operands.size());
1305   for (auto index : llvm::seq<uint32_t>(0, operands.size())) {
1306     elementTypes.push_back(cType.getElementType(index));
1307   }
1308   state.addTypes(type);
1309   return parser.resolveOperands(operands, elementTypes, loc, state.operands);
1310 }
1311 
print(spirv::CompositeConstructOp compositeConstructOp,OpAsmPrinter & printer)1312 static void print(spirv::CompositeConstructOp compositeConstructOp,
1313                   OpAsmPrinter &printer) {
1314   printer << spirv::CompositeConstructOp::getOperationName() << " "
1315           << compositeConstructOp.constituents() << " : "
1316           << compositeConstructOp.getResult().getType();
1317 }
1318 
verify(spirv::CompositeConstructOp compositeConstructOp)1319 static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
1320   auto cType = compositeConstructOp.getType().cast<spirv::CompositeType>();
1321   SmallVector<Value, 4> constituents(compositeConstructOp.constituents());
1322 
1323   if (cType.isa<spirv::CooperativeMatrixNVType>()) {
1324     if (constituents.size() != 1)
1325       return compositeConstructOp.emitError(
1326                  "has incorrect number of operands: expected ")
1327              << "1, but provided " << constituents.size();
1328   } else if (constituents.size() != cType.getNumElements()) {
1329     return compositeConstructOp.emitError(
1330                "has incorrect number of operands: expected ")
1331            << cType.getNumElements() << ", but provided "
1332            << constituents.size();
1333   }
1334 
1335   for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1336     if (constituents[index].getType() != cType.getElementType(index)) {
1337       return compositeConstructOp.emitError(
1338                  "operand type mismatch: expected operand type ")
1339              << cType.getElementType(index) << ", but provided "
1340              << constituents[index].getType();
1341     }
1342   }
1343 
1344   return success();
1345 }
1346 
1347 //===----------------------------------------------------------------------===//
1348 // spv.CompositeExtractOp
1349 //===----------------------------------------------------------------------===//
1350 
build(OpBuilder & builder,OperationState & state,Value composite,ArrayRef<int32_t> indices)1351 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
1352                                       Value composite,
1353                                       ArrayRef<int32_t> indices) {
1354   auto indexAttr = builder.getI32ArrayAttr(indices);
1355   auto elementType =
1356       getElementType(composite.getType(), indexAttr, state.location);
1357   if (!elementType) {
1358     return;
1359   }
1360   build(builder, state, elementType, composite, indexAttr);
1361 }
1362 
parseCompositeExtractOp(OpAsmParser & parser,OperationState & state)1363 static ParseResult parseCompositeExtractOp(OpAsmParser &parser,
1364                                            OperationState &state) {
1365   OpAsmParser::OperandType compositeInfo;
1366   Attribute indicesAttr;
1367   Type compositeType;
1368   llvm::SMLoc attrLocation;
1369 
1370   if (parser.parseOperand(compositeInfo) ||
1371       parser.getCurrentLocation(&attrLocation) ||
1372       parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
1373       parser.parseColonType(compositeType) ||
1374       parser.resolveOperand(compositeInfo, compositeType, state.operands)) {
1375     return failure();
1376   }
1377 
1378   Type resultType =
1379       getElementType(compositeType, indicesAttr, parser, attrLocation);
1380   if (!resultType) {
1381     return failure();
1382   }
1383   state.addTypes(resultType);
1384   return success();
1385 }
1386 
print(spirv::CompositeExtractOp compositeExtractOp,OpAsmPrinter & printer)1387 static void print(spirv::CompositeExtractOp compositeExtractOp,
1388                   OpAsmPrinter &printer) {
1389   printer << spirv::CompositeExtractOp::getOperationName() << ' '
1390           << compositeExtractOp.composite() << compositeExtractOp.indices()
1391           << " : " << compositeExtractOp.composite().getType();
1392 }
1393 
verify(spirv::CompositeExtractOp compExOp)1394 static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
1395   auto indicesArrayAttr = compExOp.indices().dyn_cast<ArrayAttr>();
1396   auto resultType = getElementType(compExOp.composite().getType(),
1397                                    indicesArrayAttr, compExOp.getLoc());
1398   if (!resultType)
1399     return failure();
1400 
1401   if (resultType != compExOp.getType()) {
1402     return compExOp.emitOpError("invalid result type: expected ")
1403            << resultType << " but provided " << compExOp.getType();
1404   }
1405 
1406   return success();
1407 }
1408 
1409 //===----------------------------------------------------------------------===//
1410 // spv.CompositeInsert
1411 //===----------------------------------------------------------------------===//
1412 
build(OpBuilder & builder,OperationState & state,Value object,Value composite,ArrayRef<int32_t> indices)1413 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
1414                                      Value object, Value composite,
1415                                      ArrayRef<int32_t> indices) {
1416   auto indexAttr = builder.getI32ArrayAttr(indices);
1417   build(builder, state, composite.getType(), object, composite, indexAttr);
1418 }
1419 
parseCompositeInsertOp(OpAsmParser & parser,OperationState & state)1420 static ParseResult parseCompositeInsertOp(OpAsmParser &parser,
1421                                           OperationState &state) {
1422   SmallVector<OpAsmParser::OperandType, 2> operands;
1423   Type objectType, compositeType;
1424   Attribute indicesAttr;
1425   auto loc = parser.getCurrentLocation();
1426 
1427   return failure(
1428       parser.parseOperandList(operands, 2) ||
1429       parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
1430       parser.parseColonType(objectType) ||
1431       parser.parseKeywordType("into", compositeType) ||
1432       parser.resolveOperands(operands, {objectType, compositeType}, loc,
1433                              state.operands) ||
1434       parser.addTypesToList(compositeType, state.types));
1435 }
1436 
verify(spirv::CompositeInsertOp compositeInsertOp)1437 static LogicalResult verify(spirv::CompositeInsertOp compositeInsertOp) {
1438   auto indicesArrayAttr = compositeInsertOp.indices().dyn_cast<ArrayAttr>();
1439   auto objectType =
1440       getElementType(compositeInsertOp.composite().getType(), indicesArrayAttr,
1441                      compositeInsertOp.getLoc());
1442   if (!objectType)
1443     return failure();
1444 
1445   if (objectType != compositeInsertOp.object().getType()) {
1446     return compositeInsertOp.emitOpError("object operand type should be ")
1447            << objectType << ", but found "
1448            << compositeInsertOp.object().getType();
1449   }
1450 
1451   if (compositeInsertOp.composite().getType() != compositeInsertOp.getType()) {
1452     return compositeInsertOp.emitOpError("result type should be the same as "
1453                                          "the composite type, but found ")
1454            << compositeInsertOp.composite().getType() << " vs "
1455            << compositeInsertOp.getType();
1456   }
1457 
1458   return success();
1459 }
1460 
print(spirv::CompositeInsertOp compositeInsertOp,OpAsmPrinter & printer)1461 static void print(spirv::CompositeInsertOp compositeInsertOp,
1462                   OpAsmPrinter &printer) {
1463   printer << spirv::CompositeInsertOp::getOperationName() << " "
1464           << compositeInsertOp.object() << ", " << compositeInsertOp.composite()
1465           << compositeInsertOp.indices() << " : "
1466           << compositeInsertOp.object().getType() << " into "
1467           << compositeInsertOp.composite().getType();
1468 }
1469 
1470 //===----------------------------------------------------------------------===//
1471 // spv.constant
1472 //===----------------------------------------------------------------------===//
1473 
parseConstantOp(OpAsmParser & parser,OperationState & state)1474 static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) {
1475   Attribute value;
1476   if (parser.parseAttribute(value, kValueAttrName, state.attributes))
1477     return failure();
1478 
1479   Type type = value.getType();
1480   if (type.isa<NoneType, TensorType>()) {
1481     if (parser.parseColonType(type))
1482       return failure();
1483   }
1484 
1485   return parser.addTypeToList(type, state.types);
1486 }
1487 
print(spirv::ConstantOp constOp,OpAsmPrinter & printer)1488 static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) {
1489   printer << spirv::ConstantOp::getOperationName() << ' ' << constOp.value();
1490   if (constOp.getType().isa<spirv::ArrayType>())
1491     printer << " : " << constOp.getType();
1492 }
1493 
verify(spirv::ConstantOp constOp)1494 static LogicalResult verify(spirv::ConstantOp constOp) {
1495   auto opType = constOp.getType();
1496   auto value = constOp.value();
1497   auto valueType = value.getType();
1498 
1499   // ODS already generates checks to make sure the result type is valid. We just
1500   // need to additionally check that the value's attribute type is consistent
1501   // with the result type.
1502   if (value.isa<IntegerAttr, FloatAttr>()) {
1503     if (valueType != opType)
1504       return constOp.emitOpError("result type (")
1505              << opType << ") does not match value type (" << valueType << ")";
1506     return success();
1507   }
1508   if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
1509     if (valueType == opType)
1510       return success();
1511     auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1512     auto shapedType = valueType.dyn_cast<ShapedType>();
1513     if (!arrayType) {
1514       return constOp.emitOpError(
1515           "must have spv.array result type for array value");
1516     }
1517 
1518     int numElements = arrayType.getNumElements();
1519     auto opElemType = arrayType.getElementType();
1520     while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) {
1521       numElements *= t.getNumElements();
1522       opElemType = t.getElementType();
1523     }
1524     if (!opElemType.isIntOrFloat())
1525       return constOp.emitOpError("only support nested array result type");
1526 
1527     auto valueElemType = shapedType.getElementType();
1528     if (valueElemType != opElemType) {
1529       return constOp.emitOpError("result element type (")
1530              << opElemType << ") does not match value element type ("
1531              << valueElemType << ")";
1532     }
1533 
1534     if (numElements != shapedType.getNumElements()) {
1535       return constOp.emitOpError("result number of elements (")
1536              << numElements << ") does not match value number of elements ("
1537              << shapedType.getNumElements() << ")";
1538     }
1539     return success();
1540   }
1541   if (auto attayAttr = value.dyn_cast<ArrayAttr>()) {
1542     auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1543     if (!arrayType)
1544       return constOp.emitOpError(
1545           "must have spv.array result type for array value");
1546     Type elemType = arrayType.getElementType();
1547     for (Attribute element : attayAttr.getValue()) {
1548       if (element.getType() != elemType)
1549         return constOp.emitOpError("has array element whose type (")
1550                << element.getType()
1551                << ") does not match the result element type (" << elemType
1552                << ')';
1553     }
1554     return success();
1555   }
1556   return constOp.emitOpError("cannot have value of type ") << valueType;
1557 }
1558 
isBuildableWith(Type type)1559 bool spirv::ConstantOp::isBuildableWith(Type type) {
1560   // Must be valid SPIR-V type first.
1561   if (!type.isa<spirv::SPIRVType>())
1562     return false;
1563 
1564   if (isa<SPIRVDialect>(type.getDialect())) {
1565     // TODO: support constant struct
1566     return type.isa<spirv::ArrayType>();
1567   }
1568 
1569   return true;
1570 }
1571 
getZero(Type type,Location loc,OpBuilder & builder)1572 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
1573                                              OpBuilder &builder) {
1574   if (auto intType = type.dyn_cast<IntegerType>()) {
1575     unsigned width = intType.getWidth();
1576     if (width == 1)
1577       return builder.create<spirv::ConstantOp>(loc, type,
1578                                                builder.getBoolAttr(false));
1579     return builder.create<spirv::ConstantOp>(
1580         loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
1581   }
1582 
1583   llvm_unreachable("unimplemented types for ConstantOp::getZero()");
1584 }
1585 
getOne(Type type,Location loc,OpBuilder & builder)1586 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
1587                                             OpBuilder &builder) {
1588   if (auto intType = type.dyn_cast<IntegerType>()) {
1589     unsigned width = intType.getWidth();
1590     if (width == 1)
1591       return builder.create<spirv::ConstantOp>(loc, type,
1592                                                builder.getBoolAttr(true));
1593     return builder.create<spirv::ConstantOp>(
1594         loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
1595   }
1596 
1597   llvm_unreachable("unimplemented types for ConstantOp::getOne()");
1598 }
1599 
1600 //===----------------------------------------------------------------------===//
1601 // spv.EntryPoint
1602 //===----------------------------------------------------------------------===//
1603 
build(OpBuilder & builder,OperationState & state,spirv::ExecutionModel executionModel,spirv::FuncOp function,ArrayRef<Attribute> interfaceVars)1604 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
1605                                 spirv::ExecutionModel executionModel,
1606                                 spirv::FuncOp function,
1607                                 ArrayRef<Attribute> interfaceVars) {
1608   build(builder, state,
1609         builder.getI32IntegerAttr(static_cast<int32_t>(executionModel)),
1610         builder.getSymbolRefAttr(function),
1611         builder.getArrayAttr(interfaceVars));
1612 }
1613 
parseEntryPointOp(OpAsmParser & parser,OperationState & state)1614 static ParseResult parseEntryPointOp(OpAsmParser &parser,
1615                                      OperationState &state) {
1616   spirv::ExecutionModel execModel;
1617   SmallVector<OpAsmParser::OperandType, 0> identifiers;
1618   SmallVector<Type, 0> idTypes;
1619   SmallVector<Attribute, 4> interfaceVars;
1620 
1621   FlatSymbolRefAttr fn;
1622   if (parseEnumStrAttr(execModel, parser, state) ||
1623       parser.parseAttribute(fn, Type(), kFnNameAttrName, state.attributes)) {
1624     return failure();
1625   }
1626 
1627   if (!parser.parseOptionalComma()) {
1628     // Parse the interface variables
1629     do {
1630       // The name of the interface variable attribute isnt important
1631       auto attrName = "var_symbol";
1632       FlatSymbolRefAttr var;
1633       NamedAttrList attrs;
1634       if (parser.parseAttribute(var, Type(), attrName, attrs)) {
1635         return failure();
1636       }
1637       interfaceVars.push_back(var);
1638     } while (!parser.parseOptionalComma());
1639   }
1640   state.addAttribute(kInterfaceAttrName,
1641                      parser.getBuilder().getArrayAttr(interfaceVars));
1642   return success();
1643 }
1644 
print(spirv::EntryPointOp entryPointOp,OpAsmPrinter & printer)1645 static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter &printer) {
1646   printer << spirv::EntryPointOp::getOperationName() << " \""
1647           << stringifyExecutionModel(entryPointOp.execution_model()) << "\" ";
1648   printer.printSymbolName(entryPointOp.fn());
1649   auto interfaceVars = entryPointOp.interface().getValue();
1650   if (!interfaceVars.empty()) {
1651     printer << ", ";
1652     llvm::interleaveComma(interfaceVars, printer);
1653   }
1654 }
1655 
verify(spirv::EntryPointOp entryPointOp)1656 static LogicalResult verify(spirv::EntryPointOp entryPointOp) {
1657   // Checks for fn and interface symbol reference are done in spirv::ModuleOp
1658   // verification.
1659   return success();
1660 }
1661 
1662 //===----------------------------------------------------------------------===//
1663 // spv.ExecutionMode
1664 //===----------------------------------------------------------------------===//
1665 
build(OpBuilder & builder,OperationState & state,spirv::FuncOp function,spirv::ExecutionMode executionMode,ArrayRef<int32_t> params)1666 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
1667                                    spirv::FuncOp function,
1668                                    spirv::ExecutionMode executionMode,
1669                                    ArrayRef<int32_t> params) {
1670   build(builder, state, builder.getSymbolRefAttr(function),
1671         builder.getI32IntegerAttr(static_cast<int32_t>(executionMode)),
1672         builder.getI32ArrayAttr(params));
1673 }
1674 
parseExecutionModeOp(OpAsmParser & parser,OperationState & state)1675 static ParseResult parseExecutionModeOp(OpAsmParser &parser,
1676                                         OperationState &state) {
1677   spirv::ExecutionMode execMode;
1678   Attribute fn;
1679   if (parser.parseAttribute(fn, kFnNameAttrName, state.attributes) ||
1680       parseEnumStrAttr(execMode, parser, state)) {
1681     return failure();
1682   }
1683 
1684   SmallVector<int32_t, 4> values;
1685   Type i32Type = parser.getBuilder().getIntegerType(32);
1686   while (!parser.parseOptionalComma()) {
1687     NamedAttrList attr;
1688     Attribute value;
1689     if (parser.parseAttribute(value, i32Type, "value", attr)) {
1690       return failure();
1691     }
1692     values.push_back(value.cast<IntegerAttr>().getInt());
1693   }
1694   state.addAttribute(kValuesAttrName,
1695                      parser.getBuilder().getI32ArrayAttr(values));
1696   return success();
1697 }
1698 
print(spirv::ExecutionModeOp execModeOp,OpAsmPrinter & printer)1699 static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
1700   printer << spirv::ExecutionModeOp::getOperationName() << " ";
1701   printer.printSymbolName(execModeOp.fn());
1702   printer << " \"" << stringifyExecutionMode(execModeOp.execution_mode())
1703           << "\"";
1704   auto values = execModeOp.values();
1705   if (!values.size())
1706     return;
1707   printer << ", ";
1708   llvm::interleaveComma(values, printer, [&](Attribute a) {
1709     printer << a.cast<IntegerAttr>().getInt();
1710   });
1711 }
1712 
1713 //===----------------------------------------------------------------------===//
1714 // spv.func
1715 //===----------------------------------------------------------------------===//
1716 
parseFuncOp(OpAsmParser & parser,OperationState & state)1717 static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) {
1718   SmallVector<OpAsmParser::OperandType, 4> entryArgs;
1719   SmallVector<NamedAttrList, 4> argAttrs;
1720   SmallVector<NamedAttrList, 4> resultAttrs;
1721   SmallVector<Type, 4> argTypes;
1722   SmallVector<Type, 4> resultTypes;
1723   auto &builder = parser.getBuilder();
1724 
1725   // Parse the name as a symbol.
1726   StringAttr nameAttr;
1727   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1728                              state.attributes))
1729     return failure();
1730 
1731   // Parse the function signature.
1732   bool isVariadic = false;
1733   if (impl::parseFunctionSignature(parser, /*allowVariadic=*/false, entryArgs,
1734                                    argTypes, argAttrs, isVariadic, resultTypes,
1735                                    resultAttrs))
1736     return failure();
1737 
1738   auto fnType = builder.getFunctionType(argTypes, resultTypes);
1739   state.addAttribute(impl::getTypeAttrName(), TypeAttr::get(fnType));
1740 
1741   // Parse the optional function control keyword.
1742   spirv::FunctionControl fnControl;
1743   if (parseEnumStrAttr(fnControl, parser, state))
1744     return failure();
1745 
1746   // If additional attributes are present, parse them.
1747   if (parser.parseOptionalAttrDictWithKeyword(state.attributes))
1748     return failure();
1749 
1750   // Add the attributes to the function arguments.
1751   assert(argAttrs.size() == argTypes.size());
1752   assert(resultAttrs.size() == resultTypes.size());
1753   impl::addArgAndResultAttrs(builder, state, argAttrs, resultAttrs);
1754 
1755   // Parse the optional function body.
1756   auto *body = state.addRegion();
1757   OptionalParseResult result = parser.parseOptionalRegion(
1758       *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
1759   return failure(result.hasValue() && failed(*result));
1760 }
1761 
print(spirv::FuncOp fnOp,OpAsmPrinter & printer)1762 static void print(spirv::FuncOp fnOp, OpAsmPrinter &printer) {
1763   // Print function name, signature, and control.
1764   printer << spirv::FuncOp::getOperationName() << " ";
1765   printer.printSymbolName(fnOp.sym_name());
1766   auto fnType = fnOp.getType();
1767   impl::printFunctionSignature(printer, fnOp, fnType.getInputs(),
1768                                /*isVariadic=*/false, fnType.getResults());
1769   printer << " \"" << spirv::stringifyFunctionControl(fnOp.function_control())
1770           << "\"";
1771   impl::printFunctionAttributes(
1772       printer, fnOp, fnType.getNumInputs(), fnType.getNumResults(),
1773       {spirv::attributeName<spirv::FunctionControl>()});
1774 
1775   // Print the body if this is not an external function.
1776   Region &body = fnOp.body();
1777   if (!body.empty())
1778     printer.printRegion(body, /*printEntryBlockArgs=*/false,
1779                         /*printBlockTerminators=*/true);
1780 }
1781 
verifyType()1782 LogicalResult spirv::FuncOp::verifyType() {
1783   auto type = getTypeAttr().getValue();
1784   if (!type.isa<FunctionType>())
1785     return emitOpError("requires '" + getTypeAttrName() +
1786                        "' attribute of function type");
1787   if (getType().getNumResults() > 1)
1788     return emitOpError("cannot have more than one result");
1789   return success();
1790 }
1791 
verifyBody()1792 LogicalResult spirv::FuncOp::verifyBody() {
1793   FunctionType fnType = getType();
1794 
1795   auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1796     if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1797       if (fnType.getNumResults() != 0)
1798         return retOp.emitOpError("cannot be used in functions returning value");
1799     } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1800       if (fnType.getNumResults() != 1)
1801         return retOp.emitOpError(
1802                    "returns 1 value but enclosing function requires ")
1803                << fnType.getNumResults() << " results";
1804 
1805       auto retOperandType = retOp.value().getType();
1806       auto fnResultType = fnType.getResult(0);
1807       if (retOperandType != fnResultType)
1808         return retOp.emitOpError(" return value's type (")
1809                << retOperandType << ") mismatch with function's result type ("
1810                << fnResultType << ")";
1811     }
1812     return WalkResult::advance();
1813   });
1814 
1815   // TODO: verify other bits like linkage type.
1816 
1817   return failure(walkResult.wasInterrupted());
1818 }
1819 
build(OpBuilder & builder,OperationState & state,StringRef name,FunctionType type,spirv::FunctionControl control,ArrayRef<NamedAttribute> attrs)1820 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1821                           StringRef name, FunctionType type,
1822                           spirv::FunctionControl control,
1823                           ArrayRef<NamedAttribute> attrs) {
1824   state.addAttribute(SymbolTable::getSymbolAttrName(),
1825                      builder.getStringAttr(name));
1826   state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
1827   state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1828                      builder.getI32IntegerAttr(static_cast<uint32_t>(control)));
1829   state.attributes.append(attrs.begin(), attrs.end());
1830   state.addRegion();
1831 }
1832 
1833 // CallableOpInterface
getCallableRegion()1834 Region *spirv::FuncOp::getCallableRegion() {
1835   return isExternal() ? nullptr : &body();
1836 }
1837 
1838 // CallableOpInterface
getCallableResults()1839 ArrayRef<Type> spirv::FuncOp::getCallableResults() {
1840   return getType().getResults();
1841 }
1842 
1843 //===----------------------------------------------------------------------===//
1844 // spv.FunctionCall
1845 //===----------------------------------------------------------------------===//
1846 
verify(spirv::FunctionCallOp functionCallOp)1847 static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
1848   auto fnName = functionCallOp.callee();
1849 
1850   auto funcOp =
1851       dyn_cast_or_null<spirv::FuncOp>(SymbolTable::lookupNearestSymbolFrom(
1852           functionCallOp->getParentOp(), fnName));
1853   if (!funcOp) {
1854     return functionCallOp.emitOpError("callee function '")
1855            << fnName << "' not found in nearest symbol table";
1856   }
1857 
1858   auto functionType = funcOp.getType();
1859 
1860   if (functionCallOp.getNumResults() > 1) {
1861     return functionCallOp.emitOpError(
1862                "expected callee function to have 0 or 1 result, but provided ")
1863            << functionCallOp.getNumResults();
1864   }
1865 
1866   if (functionType.getNumInputs() != functionCallOp.getNumOperands()) {
1867     return functionCallOp.emitOpError(
1868                "has incorrect number of operands for callee: expected ")
1869            << functionType.getNumInputs() << ", but provided "
1870            << functionCallOp.getNumOperands();
1871   }
1872 
1873   for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
1874     if (functionCallOp.getOperand(i).getType() != functionType.getInput(i)) {
1875       return functionCallOp.emitOpError(
1876                  "operand type mismatch: expected operand type ")
1877              << functionType.getInput(i) << ", but provided "
1878              << functionCallOp.getOperand(i).getType() << " for operand number "
1879              << i;
1880     }
1881   }
1882 
1883   if (functionType.getNumResults() != functionCallOp.getNumResults()) {
1884     return functionCallOp.emitOpError(
1885                "has incorrect number of results has for callee: expected ")
1886            << functionType.getNumResults() << ", but provided "
1887            << functionCallOp.getNumResults();
1888   }
1889 
1890   if (functionCallOp.getNumResults() &&
1891       (functionCallOp.getResult(0).getType() != functionType.getResult(0))) {
1892     return functionCallOp.emitOpError("result type mismatch: expected ")
1893            << functionType.getResult(0) << ", but provided "
1894            << functionCallOp.getResult(0).getType();
1895   }
1896 
1897   return success();
1898 }
1899 
getCallableForCallee()1900 CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
1901   return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
1902 }
1903 
getArgOperands()1904 Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
1905   return arguments();
1906 }
1907 
1908 //===----------------------------------------------------------------------===//
1909 // spv.globalVariable
1910 //===----------------------------------------------------------------------===//
1911 
build(OpBuilder & builder,OperationState & state,Type type,StringRef name,unsigned descriptorSet,unsigned binding)1912 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1913                                     Type type, StringRef name,
1914                                     unsigned descriptorSet, unsigned binding) {
1915   build(builder, state, TypeAttr::get(type), builder.getStringAttr(name),
1916         nullptr);
1917   state.addAttribute(
1918       spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1919       builder.getI32IntegerAttr(descriptorSet));
1920   state.addAttribute(
1921       spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1922       builder.getI32IntegerAttr(binding));
1923 }
1924 
build(OpBuilder & builder,OperationState & state,Type type,StringRef name,spirv::BuiltIn builtin)1925 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1926                                     Type type, StringRef name,
1927                                     spirv::BuiltIn builtin) {
1928   build(builder, state, TypeAttr::get(type), builder.getStringAttr(name),
1929         nullptr);
1930   state.addAttribute(
1931       spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1932       builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1933 }
1934 
parseGlobalVariableOp(OpAsmParser & parser,OperationState & state)1935 static ParseResult parseGlobalVariableOp(OpAsmParser &parser,
1936                                          OperationState &state) {
1937   // Parse variable name.
1938   StringAttr nameAttr;
1939   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1940                              state.attributes)) {
1941     return failure();
1942   }
1943 
1944   // Parse optional initializer
1945   if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
1946     FlatSymbolRefAttr initSymbol;
1947     if (parser.parseLParen() ||
1948         parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
1949                               state.attributes) ||
1950         parser.parseRParen())
1951       return failure();
1952   }
1953 
1954   if (parseVariableDecorations(parser, state)) {
1955     return failure();
1956   }
1957 
1958   Type type;
1959   auto loc = parser.getCurrentLocation();
1960   if (parser.parseColonType(type)) {
1961     return failure();
1962   }
1963   if (!type.isa<spirv::PointerType>()) {
1964     return parser.emitError(loc, "expected spv.ptr type");
1965   }
1966   state.addAttribute(kTypeAttrName, TypeAttr::get(type));
1967 
1968   return success();
1969 }
1970 
print(spirv::GlobalVariableOp varOp,OpAsmPrinter & printer)1971 static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter &printer) {
1972   auto *op = varOp.getOperation();
1973   SmallVector<StringRef, 4> elidedAttrs{
1974       spirv::attributeName<spirv::StorageClass>()};
1975   printer << spirv::GlobalVariableOp::getOperationName();
1976 
1977   // Print variable name.
1978   printer << ' ';
1979   printer.printSymbolName(varOp.sym_name());
1980   elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1981 
1982   // Print optional initializer
1983   if (auto initializer = varOp.initializer()) {
1984     printer << " " << kInitializerAttrName << '(';
1985     printer.printSymbolName(initializer.getValue());
1986     printer << ')';
1987     elidedAttrs.push_back(kInitializerAttrName);
1988   }
1989 
1990   elidedAttrs.push_back(kTypeAttrName);
1991   printVariableDecorations(op, printer, elidedAttrs);
1992   printer << " : " << varOp.type();
1993 }
1994 
verify(spirv::GlobalVariableOp varOp)1995 static LogicalResult verify(spirv::GlobalVariableOp varOp) {
1996   // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1997   // object. It cannot be Generic. It must be the same as the Storage Class
1998   // operand of the Result Type."
1999   // Also, Function storage class is reserved by spv.Variable.
2000   auto storageClass = varOp.storageClass();
2001   if (storageClass == spirv::StorageClass::Generic ||
2002       storageClass == spirv::StorageClass::Function) {
2003     return varOp.emitOpError("storage class cannot be '")
2004            << stringifyStorageClass(storageClass) << "'";
2005   }
2006 
2007   if (auto init =
2008           varOp->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
2009     Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
2010         varOp->getParentOp(), init.getValue());
2011     // TODO: Currently only variable initialization with specialization
2012     // constants and other variables is supported. They could be normal
2013     // constants in the module scope as well.
2014     if (!initOp ||
2015         !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
2016       return varOp.emitOpError("initializer must be result of a "
2017                                "spv.specConstant or spv.globalVariable op");
2018     }
2019   }
2020 
2021   return success();
2022 }
2023 
2024 //===----------------------------------------------------------------------===//
2025 // spv.GroupBroadcast
2026 //===----------------------------------------------------------------------===//
2027 
verify(spirv::GroupBroadcastOp broadcastOp)2028 static LogicalResult verify(spirv::GroupBroadcastOp broadcastOp) {
2029   spirv::Scope scope = broadcastOp.execution_scope();
2030   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2031     return broadcastOp.emitOpError(
2032         "execution scope must be 'Workgroup' or 'Subgroup'");
2033 
2034   if (auto localIdTy = broadcastOp.localid().getType().dyn_cast<VectorType>())
2035     if (!(localIdTy.getNumElements() == 2 || localIdTy.getNumElements() == 3))
2036       return broadcastOp.emitOpError("localid is a vector and can be with only "
2037                                      " 2 or 3 components, actual number is ")
2038              << localIdTy.getNumElements();
2039 
2040   return success();
2041 }
2042 
2043 //===----------------------------------------------------------------------===//
2044 // spv.GroupNonUniformBallotOp
2045 //===----------------------------------------------------------------------===//
2046 
verify(spirv::GroupNonUniformBallotOp ballotOp)2047 static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
2048   spirv::Scope scope = ballotOp.execution_scope();
2049   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2050     return ballotOp.emitOpError(
2051         "execution scope must be 'Workgroup' or 'Subgroup'");
2052 
2053   return success();
2054 }
2055 
2056 //===----------------------------------------------------------------------===//
2057 // spv.GroupNonUniformBroadcast
2058 //===----------------------------------------------------------------------===//
2059 
verify(spirv::GroupNonUniformBroadcastOp broadcastOp)2060 static LogicalResult verify(spirv::GroupNonUniformBroadcastOp broadcastOp) {
2061   spirv::Scope scope = broadcastOp.execution_scope();
2062   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2063     return broadcastOp.emitOpError(
2064         "execution scope must be 'Workgroup' or 'Subgroup'");
2065 
2066   // SPIR-V spec: "Before version 1.5, Id must come from a
2067   // constant instruction.
2068   auto targetEnv = spirv::getDefaultTargetEnv(broadcastOp.getContext());
2069   if (auto spirvModule = broadcastOp->getParentOfType<spirv::ModuleOp>())
2070     targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
2071 
2072   if (targetEnv.getVersion() < spirv::Version::V_1_5) {
2073     auto *idOp = broadcastOp.id().getDefiningOp();
2074     if (!idOp || !isa<spirv::ConstantOp,           // for normal constant
2075                       spirv::ReferenceOfOp>(idOp)) // for spec constant
2076       return broadcastOp.emitOpError("id must be the result of a constant op");
2077   }
2078 
2079   return success();
2080 }
2081 
2082 //===----------------------------------------------------------------------===//
2083 // spv.SubgroupBlockReadINTEL
2084 //===----------------------------------------------------------------------===//
2085 
parseSubgroupBlockReadINTELOp(OpAsmParser & parser,OperationState & state)2086 static ParseResult parseSubgroupBlockReadINTELOp(OpAsmParser &parser,
2087                                                  OperationState &state) {
2088   // Parse the storage class specification
2089   spirv::StorageClass storageClass;
2090   OpAsmParser::OperandType ptrInfo;
2091   Type elementType;
2092   if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2093       parser.parseColon() || parser.parseType(elementType)) {
2094     return failure();
2095   }
2096 
2097   auto ptrType = spirv::PointerType::get(elementType, storageClass);
2098   if (auto valVecTy = elementType.dyn_cast<VectorType>())
2099     ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2100 
2101   if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
2102     return failure();
2103   }
2104 
2105   state.addTypes(elementType);
2106   return success();
2107 }
2108 
print(spirv::SubgroupBlockReadINTELOp blockReadOp,OpAsmPrinter & printer)2109 static void print(spirv::SubgroupBlockReadINTELOp blockReadOp,
2110                   OpAsmPrinter &printer) {
2111   SmallVector<StringRef, 4> elidedAttrs;
2112   printer << spirv::SubgroupBlockReadINTELOp::getOperationName() << " "
2113           << blockReadOp.ptr();
2114   printer << " : " << blockReadOp.getType();
2115 }
2116 
verify(spirv::SubgroupBlockReadINTELOp blockReadOp)2117 static LogicalResult verify(spirv::SubgroupBlockReadINTELOp blockReadOp) {
2118   if (failed(verifyBlockReadWritePtrAndValTypes(blockReadOp, blockReadOp.ptr(),
2119                                                 blockReadOp.value())))
2120     return failure();
2121 
2122   return success();
2123 }
2124 
2125 //===----------------------------------------------------------------------===//
2126 // spv.SubgroupBlockWriteINTEL
2127 //===----------------------------------------------------------------------===//
2128 
parseSubgroupBlockWriteINTELOp(OpAsmParser & parser,OperationState & state)2129 static ParseResult parseSubgroupBlockWriteINTELOp(OpAsmParser &parser,
2130                                                   OperationState &state) {
2131   // Parse the storage class specification
2132   spirv::StorageClass storageClass;
2133   SmallVector<OpAsmParser::OperandType, 2> operandInfo;
2134   auto loc = parser.getCurrentLocation();
2135   Type elementType;
2136   if (parseEnumStrAttr(storageClass, parser) ||
2137       parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
2138       parser.parseType(elementType)) {
2139     return failure();
2140   }
2141 
2142   auto ptrType = spirv::PointerType::get(elementType, storageClass);
2143   if (auto valVecTy = elementType.dyn_cast<VectorType>())
2144     ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2145 
2146   if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
2147                              state.operands)) {
2148     return failure();
2149   }
2150   return success();
2151 }
2152 
print(spirv::SubgroupBlockWriteINTELOp blockWriteOp,OpAsmPrinter & printer)2153 static void print(spirv::SubgroupBlockWriteINTELOp blockWriteOp,
2154                   OpAsmPrinter &printer) {
2155   SmallVector<StringRef, 4> elidedAttrs;
2156   printer << spirv::SubgroupBlockWriteINTELOp::getOperationName() << " "
2157           << blockWriteOp.ptr() << ", " << blockWriteOp.value();
2158   printer << " : " << blockWriteOp.value().getType();
2159 }
2160 
verify(spirv::SubgroupBlockWriteINTELOp blockWriteOp)2161 static LogicalResult verify(spirv::SubgroupBlockWriteINTELOp blockWriteOp) {
2162   if (failed(verifyBlockReadWritePtrAndValTypes(
2163           blockWriteOp, blockWriteOp.ptr(), blockWriteOp.value())))
2164     return failure();
2165 
2166   return success();
2167 }
2168 
2169 //===----------------------------------------------------------------------===//
2170 // spv.GroupNonUniformElectOp
2171 //===----------------------------------------------------------------------===//
2172 
build(OpBuilder & builder,OperationState & state,spirv::Scope scope)2173 void spirv::GroupNonUniformElectOp::build(OpBuilder &builder,
2174                                           OperationState &state,
2175                                           spirv::Scope scope) {
2176   build(builder, state, builder.getI1Type(), scope);
2177 }
2178 
verify(spirv::GroupNonUniformElectOp groupOp)2179 static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) {
2180   spirv::Scope scope = groupOp.execution_scope();
2181   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2182     return groupOp.emitOpError(
2183         "execution scope must be 'Workgroup' or 'Subgroup'");
2184 
2185   return success();
2186 }
2187 
2188 //===----------------------------------------------------------------------===//
2189 // spv.LoadOp
2190 //===----------------------------------------------------------------------===//
2191 
build(OpBuilder & builder,OperationState & state,Value basePtr,IntegerAttr memory_access,IntegerAttr alignment)2192 void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
2193                           Value basePtr, IntegerAttr memory_access,
2194                           IntegerAttr alignment) {
2195   auto ptrType = basePtr.getType().cast<spirv::PointerType>();
2196   build(builder, state, ptrType.getPointeeType(), basePtr, memory_access,
2197         alignment);
2198 }
2199 
parseLoadOp(OpAsmParser & parser,OperationState & state)2200 static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &state) {
2201   // Parse the storage class specification
2202   spirv::StorageClass storageClass;
2203   OpAsmParser::OperandType ptrInfo;
2204   Type elementType;
2205   if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2206       parseMemoryAccessAttributes(parser, state) ||
2207       parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() ||
2208       parser.parseType(elementType)) {
2209     return failure();
2210   }
2211 
2212   auto ptrType = spirv::PointerType::get(elementType, storageClass);
2213   if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
2214     return failure();
2215   }
2216 
2217   state.addTypes(elementType);
2218   return success();
2219 }
2220 
print(spirv::LoadOp loadOp,OpAsmPrinter & printer)2221 static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) {
2222   auto *op = loadOp.getOperation();
2223   SmallVector<StringRef, 4> elidedAttrs;
2224   StringRef sc = stringifyStorageClass(
2225       loadOp.ptr().getType().cast<spirv::PointerType>().getStorageClass());
2226   printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" "
2227           << loadOp.ptr();
2228 
2229   printMemoryAccessAttribute(loadOp, printer, elidedAttrs);
2230 
2231   printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
2232   printer << " : " << loadOp.getType();
2233 }
2234 
verify(spirv::LoadOp loadOp)2235 static LogicalResult verify(spirv::LoadOp loadOp) {
2236   // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
2237   // type with fixed size; i.e., it cannot be, nor include, any
2238   // OpTypeRuntimeArray types."
2239   if (failed(verifyLoadStorePtrAndValTypes(loadOp, loadOp.ptr(),
2240                                            loadOp.value()))) {
2241     return failure();
2242   }
2243   return verifyMemoryAccessAttribute(loadOp);
2244 }
2245 
2246 //===----------------------------------------------------------------------===//
2247 // spv.loop
2248 //===----------------------------------------------------------------------===//
2249 
build(OpBuilder & builder,OperationState & state)2250 void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
2251   state.addAttribute("loop_control",
2252                      builder.getI32IntegerAttr(
2253                          static_cast<uint32_t>(spirv::LoopControl::None)));
2254   state.addRegion();
2255 }
2256 
parseLoopOp(OpAsmParser & parser,OperationState & state)2257 static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &state) {
2258   if (parseControlAttribute<spirv::LoopControl>(parser, state))
2259     return failure();
2260   return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
2261                             /*argTypes=*/{});
2262 }
2263 
print(spirv::LoopOp loopOp,OpAsmPrinter & printer)2264 static void print(spirv::LoopOp loopOp, OpAsmPrinter &printer) {
2265   auto *op = loopOp.getOperation();
2266 
2267   printer << spirv::LoopOp::getOperationName();
2268   auto control = loopOp.loop_control();
2269   if (control != spirv::LoopControl::None)
2270     printer << " control(" << spirv::stringifyLoopControl(control) << ")";
2271   printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
2272                       /*printBlockTerminators=*/true);
2273 }
2274 
2275 /// Returns true if the given `srcBlock` contains only one `spv.Branch` to the
2276 /// given `dstBlock`.
hasOneBranchOpTo(Block & srcBlock,Block & dstBlock)2277 static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
2278   // Check that there is only one op in the `srcBlock`.
2279   if (!llvm::hasSingleElement(srcBlock))
2280     return false;
2281 
2282   auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
2283   return branchOp && branchOp.getSuccessor() == &dstBlock;
2284 }
2285 
verify(spirv::LoopOp loopOp)2286 static LogicalResult verify(spirv::LoopOp loopOp) {
2287   auto *op = loopOp.getOperation();
2288 
2289   // We need to verify that the blocks follow the following layout:
2290   //
2291   //                     +-------------+
2292   //                     | entry block |
2293   //                     +-------------+
2294   //                            |
2295   //                            v
2296   //                     +-------------+
2297   //                     | loop header | <-----+
2298   //                     +-------------+       |
2299   //                                           |
2300   //                           ...             |
2301   //                          \ | /            |
2302   //                            v              |
2303   //                    +---------------+      |
2304   //                    | loop continue | -----+
2305   //                    +---------------+
2306   //
2307   //                           ...
2308   //                          \ | /
2309   //                            v
2310   //                     +-------------+
2311   //                     | merge block |
2312   //                     +-------------+
2313 
2314   auto &region = op->getRegion(0);
2315   // Allow empty region as a degenerated case, which can come from
2316   // optimizations.
2317   if (region.empty())
2318     return success();
2319 
2320   // The last block is the merge block.
2321   Block &merge = region.back();
2322   if (!isMergeBlock(merge))
2323     return loopOp.emitOpError(
2324         "last block must be the merge block with only one 'spv.mlir.merge' op");
2325 
2326   if (std::next(region.begin()) == region.end())
2327     return loopOp.emitOpError(
2328         "must have an entry block branching to the loop header block");
2329   // The first block is the entry block.
2330   Block &entry = region.front();
2331 
2332   if (std::next(region.begin(), 2) == region.end())
2333     return loopOp.emitOpError(
2334         "must have a loop header block branched from the entry block");
2335   // The second block is the loop header block.
2336   Block &header = *std::next(region.begin(), 1);
2337 
2338   if (!hasOneBranchOpTo(entry, header))
2339     return loopOp.emitOpError(
2340         "entry block must only have one 'spv.Branch' op to the second block");
2341 
2342   if (std::next(region.begin(), 3) == region.end())
2343     return loopOp.emitOpError(
2344         "requires a loop continue block branching to the loop header block");
2345   // The second to last block is the loop continue block.
2346   Block &cont = *std::prev(region.end(), 2);
2347 
2348   // Make sure that we have a branch from the loop continue block to the loop
2349   // header block.
2350   if (llvm::none_of(
2351           llvm::seq<unsigned>(0, cont.getNumSuccessors()),
2352           [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
2353     return loopOp.emitOpError("second to last block must be the loop continue "
2354                               "block that branches to the loop header block");
2355 
2356   // Make sure that no other blocks (except the entry and loop continue block)
2357   // branches to the loop header block.
2358   for (auto &block : llvm::make_range(std::next(region.begin(), 2),
2359                                       std::prev(region.end(), 2))) {
2360     for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
2361       if (block.getSuccessor(i) == &header) {
2362         return loopOp.emitOpError("can only have the entry and loop continue "
2363                                   "block branching to the loop header block");
2364       }
2365     }
2366   }
2367 
2368   return success();
2369 }
2370 
getEntryBlock()2371 Block *spirv::LoopOp::getEntryBlock() {
2372   assert(!body().empty() && "op region should not be empty!");
2373   return &body().front();
2374 }
2375 
getHeaderBlock()2376 Block *spirv::LoopOp::getHeaderBlock() {
2377   assert(!body().empty() && "op region should not be empty!");
2378   // The second block is the loop header block.
2379   return &*std::next(body().begin());
2380 }
2381 
getContinueBlock()2382 Block *spirv::LoopOp::getContinueBlock() {
2383   assert(!body().empty() && "op region should not be empty!");
2384   // The second to last block is the loop continue block.
2385   return &*std::prev(body().end(), 2);
2386 }
2387 
getMergeBlock()2388 Block *spirv::LoopOp::getMergeBlock() {
2389   assert(!body().empty() && "op region should not be empty!");
2390   // The last block is the loop merge block.
2391   return &body().back();
2392 }
2393 
addEntryAndMergeBlock()2394 void spirv::LoopOp::addEntryAndMergeBlock() {
2395   assert(body().empty() && "entry and merge block already exist");
2396   body().push_back(new Block());
2397   auto *mergeBlock = new Block();
2398   body().push_back(mergeBlock);
2399   OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
2400 
2401   // Add a spv.mlir.merge op into the merge block.
2402   builder.create<spirv::MergeOp>(getLoc());
2403 }
2404 
2405 //===----------------------------------------------------------------------===//
2406 // spv.mlir.merge
2407 //===----------------------------------------------------------------------===//
2408 
verify(spirv::MergeOp mergeOp)2409 static LogicalResult verify(spirv::MergeOp mergeOp) {
2410   auto *parentOp = mergeOp->getParentOp();
2411   if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
2412     return mergeOp.emitOpError(
2413         "expected parent op to be 'spv.selection' or 'spv.loop'");
2414 
2415   Block &parentLastBlock = mergeOp->getParentRegion()->back();
2416   if (mergeOp.getOperation() != parentLastBlock.getTerminator())
2417     return mergeOp.emitOpError(
2418         "can only be used in the last block of 'spv.selection' or 'spv.loop'");
2419   return success();
2420 }
2421 
2422 //===----------------------------------------------------------------------===//
2423 // spv.module
2424 //===----------------------------------------------------------------------===//
2425 
build(OpBuilder & builder,OperationState & state,Optional<StringRef> name)2426 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
2427                             Optional<StringRef> name) {
2428   ensureTerminator(*state.addRegion(), builder, state.location);
2429   if (name) {
2430     state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
2431                             builder.getStringAttr(*name));
2432   }
2433 }
2434 
build(OpBuilder & builder,OperationState & state,spirv::AddressingModel addressingModel,spirv::MemoryModel memoryModel,Optional<StringRef> name)2435 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
2436                             spirv::AddressingModel addressingModel,
2437                             spirv::MemoryModel memoryModel,
2438                             Optional<StringRef> name) {
2439   state.addAttribute(
2440       "addressing_model",
2441       builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
2442   state.addAttribute("memory_model", builder.getI32IntegerAttr(
2443                                          static_cast<int32_t>(memoryModel)));
2444   ensureTerminator(*state.addRegion(), builder, state.location);
2445   if (name) {
2446     state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
2447                             builder.getStringAttr(*name));
2448   }
2449 }
2450 
parseModuleOp(OpAsmParser & parser,OperationState & state)2451 static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
2452   Region *body = state.addRegion();
2453 
2454   // If the name is present, parse it.
2455   StringAttr nameAttr;
2456   parser.parseOptionalSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2457                                  state.attributes);
2458 
2459   // Parse attributes
2460   spirv::AddressingModel addrModel;
2461   spirv::MemoryModel memoryModel;
2462   if (parseEnumKeywordAttr(addrModel, parser, state) ||
2463       parseEnumKeywordAttr(memoryModel, parser, state))
2464     return failure();
2465 
2466   if (succeeded(parser.parseOptionalKeyword("requires"))) {
2467     spirv::VerCapExtAttr vceTriple;
2468     if (parser.parseAttribute(vceTriple,
2469                               spirv::ModuleOp::getVCETripleAttrName(),
2470                               state.attributes))
2471       return failure();
2472   }
2473 
2474   if (parser.parseOptionalAttrDictWithKeyword(state.attributes))
2475     return failure();
2476 
2477   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
2478     return failure();
2479 
2480   spirv::ModuleOp::ensureTerminator(*body, parser.getBuilder(), state.location);
2481   return success();
2482 }
2483 
print(spirv::ModuleOp moduleOp,OpAsmPrinter & printer)2484 static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
2485   printer << spirv::ModuleOp::getOperationName();
2486 
2487   if (Optional<StringRef> name = moduleOp.getName()) {
2488     printer << ' ';
2489     printer.printSymbolName(*name);
2490   }
2491 
2492   SmallVector<StringRef, 2> elidedAttrs;
2493 
2494   printer << " " << spirv::stringifyAddressingModel(moduleOp.addressing_model())
2495           << " " << spirv::stringifyMemoryModel(moduleOp.memory_model());
2496   auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
2497   auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
2498   elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
2499                       SymbolTable::getSymbolAttrName()});
2500 
2501   if (Optional<spirv::VerCapExtAttr> triple = moduleOp.vce_triple()) {
2502     printer << " requires " << *triple;
2503     elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
2504   }
2505 
2506   printer.printOptionalAttrDictWithKeyword(moduleOp.getAttrs(), elidedAttrs);
2507   printer.printRegion(moduleOp.body(), /*printEntryBlockArgs=*/false,
2508                       /*printBlockTerminators=*/false);
2509 }
2510 
verify(spirv::ModuleOp moduleOp)2511 static LogicalResult verify(spirv::ModuleOp moduleOp) {
2512   auto &op = *moduleOp.getOperation();
2513   auto *dialect = op.getDialect();
2514   DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
2515       entryPoints;
2516   SymbolTable table(moduleOp);
2517 
2518   for (auto &op : moduleOp.getBlock()) {
2519     if (op.getDialect() != dialect)
2520       return op.emitError("'spv.module' can only contain spv.* ops");
2521 
2522     // For EntryPoint op, check that the function and execution model is not
2523     // duplicated in EntryPointOps. Also verify that the interface specified
2524     // comes from globalVariables here to make this check cheaper.
2525     if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
2526       auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.fn());
2527       if (!funcOp) {
2528         return entryPointOp.emitError("function '")
2529                << entryPointOp.fn() << "' not found in 'spv.module'";
2530       }
2531       if (auto interface = entryPointOp.interface()) {
2532         for (Attribute varRef : interface) {
2533           auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
2534           if (!varSymRef) {
2535             return entryPointOp.emitError(
2536                        "expected symbol reference for interface "
2537                        "specification instead of '")
2538                    << varRef;
2539           }
2540           auto variableOp =
2541               table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
2542           if (!variableOp) {
2543             return entryPointOp.emitError("expected spv.globalVariable "
2544                                           "symbol reference instead of'")
2545                    << varSymRef << "'";
2546           }
2547         }
2548       }
2549 
2550       auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
2551           funcOp, entryPointOp.execution_model());
2552       auto entryPtIt = entryPoints.find(key);
2553       if (entryPtIt != entryPoints.end()) {
2554         return entryPointOp.emitError("duplicate of a previous EntryPointOp");
2555       }
2556       entryPoints[key] = entryPointOp;
2557     } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
2558       if (funcOp.isExternal())
2559         return op.emitError("'spv.module' cannot contain external functions");
2560 
2561       // TODO: move this check to spv.func.
2562       for (auto &block : funcOp)
2563         for (auto &op : block) {
2564           if (op.getDialect() != dialect)
2565             return op.emitError(
2566                 "functions in 'spv.module' can only contain spv.* ops");
2567         }
2568     }
2569   }
2570 
2571   return success();
2572 }
2573 
2574 //===----------------------------------------------------------------------===//
2575 // spv.mlir.referenceof
2576 //===----------------------------------------------------------------------===//
2577 
verify(spirv::ReferenceOfOp referenceOfOp)2578 static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
2579   auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
2580       referenceOfOp->getParentOp(), referenceOfOp.spec_const());
2581   Type constType;
2582 
2583   auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
2584   if (specConstOp)
2585     constType = specConstOp.default_value().getType();
2586 
2587   auto specConstCompositeOp =
2588       dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
2589   if (specConstCompositeOp)
2590     constType = specConstCompositeOp.type();
2591 
2592   if (!specConstOp && !specConstCompositeOp)
2593     return referenceOfOp.emitOpError(
2594         "expected spv.specConstant or spv.SpecConstantComposite symbol");
2595 
2596   if (referenceOfOp.reference().getType() != constType)
2597     return referenceOfOp.emitOpError("result type mismatch with the referenced "
2598                                      "specialization constant's type");
2599 
2600   return success();
2601 }
2602 
2603 //===----------------------------------------------------------------------===//
2604 // spv.Return
2605 //===----------------------------------------------------------------------===//
2606 
verify(spirv::ReturnOp returnOp)2607 static LogicalResult verify(spirv::ReturnOp returnOp) {
2608   // Verification is performed in spv.func op.
2609   return success();
2610 }
2611 
2612 //===----------------------------------------------------------------------===//
2613 // spv.ReturnValue
2614 //===----------------------------------------------------------------------===//
2615 
verify(spirv::ReturnValueOp retValOp)2616 static LogicalResult verify(spirv::ReturnValueOp retValOp) {
2617   // Verification is performed in spv.func op.
2618   return success();
2619 }
2620 
2621 //===----------------------------------------------------------------------===//
2622 // spv.Select
2623 //===----------------------------------------------------------------------===//
2624 
build(OpBuilder & builder,OperationState & state,Value cond,Value trueValue,Value falseValue)2625 void spirv::SelectOp::build(OpBuilder &builder, OperationState &state,
2626                             Value cond, Value trueValue, Value falseValue) {
2627   build(builder, state, trueValue.getType(), cond, trueValue, falseValue);
2628 }
2629 
verify(spirv::SelectOp op)2630 static LogicalResult verify(spirv::SelectOp op) {
2631   if (auto conditionTy = op.condition().getType().dyn_cast<VectorType>()) {
2632     auto resultVectorTy = op.result().getType().dyn_cast<VectorType>();
2633     if (!resultVectorTy) {
2634       return op.emitOpError("result expected to be of vector type when "
2635                             "condition is of vector type");
2636     }
2637     if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
2638       return op.emitOpError("result should have the same number of elements as "
2639                             "the condition when condition is of vector type");
2640     }
2641   }
2642   return success();
2643 }
2644 
2645 //===----------------------------------------------------------------------===//
2646 // spv.selection
2647 //===----------------------------------------------------------------------===//
2648 
parseSelectionOp(OpAsmParser & parser,OperationState & state)2649 static ParseResult parseSelectionOp(OpAsmParser &parser,
2650                                     OperationState &state) {
2651   if (parseControlAttribute<spirv::SelectionControl>(parser, state))
2652     return failure();
2653   return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
2654                             /*argTypes=*/{});
2655 }
2656 
print(spirv::SelectionOp selectionOp,OpAsmPrinter & printer)2657 static void print(spirv::SelectionOp selectionOp, OpAsmPrinter &printer) {
2658   auto *op = selectionOp.getOperation();
2659 
2660   printer << spirv::SelectionOp::getOperationName();
2661   auto control = selectionOp.selection_control();
2662   if (control != spirv::SelectionControl::None)
2663     printer << " control(" << spirv::stringifySelectionControl(control) << ")";
2664   printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
2665                       /*printBlockTerminators=*/true);
2666 }
2667 
verify(spirv::SelectionOp selectionOp)2668 static LogicalResult verify(spirv::SelectionOp selectionOp) {
2669   auto *op = selectionOp.getOperation();
2670 
2671   // We need to verify that the blocks follow the following layout:
2672   //
2673   //                     +--------------+
2674   //                     | header block |
2675   //                     +--------------+
2676   //                          / | \
2677   //                           ...
2678   //
2679   //
2680   //         +---------+   +---------+   +---------+
2681   //         | case #0 |   | case #1 |   | case #2 |  ...
2682   //         +---------+   +---------+   +---------+
2683   //
2684   //
2685   //                           ...
2686   //                          \ | /
2687   //                            v
2688   //                     +-------------+
2689   //                     | merge block |
2690   //                     +-------------+
2691 
2692   auto &region = op->getRegion(0);
2693   // Allow empty region as a degenerated case, which can come from
2694   // optimizations.
2695   if (region.empty())
2696     return success();
2697 
2698   // The last block is the merge block.
2699   if (!isMergeBlock(region.back()))
2700     return selectionOp.emitOpError(
2701         "last block must be the merge block with only one 'spv.mlir.merge' op");
2702 
2703   if (std::next(region.begin()) == region.end())
2704     return selectionOp.emitOpError("must have a selection header block");
2705 
2706   return success();
2707 }
2708 
getHeaderBlock()2709 Block *spirv::SelectionOp::getHeaderBlock() {
2710   assert(!body().empty() && "op region should not be empty!");
2711   // The first block is the loop header block.
2712   return &body().front();
2713 }
2714 
getMergeBlock()2715 Block *spirv::SelectionOp::getMergeBlock() {
2716   assert(!body().empty() && "op region should not be empty!");
2717   // The last block is the loop merge block.
2718   return &body().back();
2719 }
2720 
addMergeBlock()2721 void spirv::SelectionOp::addMergeBlock() {
2722   assert(body().empty() && "entry and merge block already exist");
2723   auto *mergeBlock = new Block();
2724   body().push_back(mergeBlock);
2725   OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
2726 
2727   // Add a spv.mlir.merge op into the merge block.
2728   builder.create<spirv::MergeOp>(getLoc());
2729 }
2730 
createIfThen(Location loc,Value condition,function_ref<void (OpBuilder & builder)> thenBody,OpBuilder & builder)2731 spirv::SelectionOp spirv::SelectionOp::createIfThen(
2732     Location loc, Value condition,
2733     function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
2734   auto selectionControl = builder.getI32IntegerAttr(
2735       static_cast<uint32_t>(spirv::SelectionControl::None));
2736   auto selectionOp = builder.create<spirv::SelectionOp>(loc, selectionControl);
2737 
2738   selectionOp.addMergeBlock();
2739   Block *mergeBlock = selectionOp.getMergeBlock();
2740   Block *thenBlock = nullptr;
2741 
2742   // Build the "then" block.
2743   {
2744     OpBuilder::InsertionGuard guard(builder);
2745     thenBlock = builder.createBlock(mergeBlock);
2746     thenBody(builder);
2747     builder.create<spirv::BranchOp>(loc, mergeBlock);
2748   }
2749 
2750   // Build the header block.
2751   {
2752     OpBuilder::InsertionGuard guard(builder);
2753     builder.createBlock(thenBlock);
2754     builder.create<spirv::BranchConditionalOp>(
2755         loc, condition, thenBlock,
2756         /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
2757         /*falseArguments=*/ArrayRef<Value>());
2758   }
2759 
2760   return selectionOp;
2761 }
2762 
2763 //===----------------------------------------------------------------------===//
2764 // spv.specConstant
2765 //===----------------------------------------------------------------------===//
2766 
parseSpecConstantOp(OpAsmParser & parser,OperationState & state)2767 static ParseResult parseSpecConstantOp(OpAsmParser &parser,
2768                                        OperationState &state) {
2769   StringAttr nameAttr;
2770   Attribute valueAttr;
2771 
2772   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2773                              state.attributes))
2774     return failure();
2775 
2776   // Parse optional spec_id.
2777   if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
2778     IntegerAttr specIdAttr;
2779     if (parser.parseLParen() ||
2780         parser.parseAttribute(specIdAttr, kSpecIdAttrName, state.attributes) ||
2781         parser.parseRParen())
2782       return failure();
2783   }
2784 
2785   if (parser.parseEqual() ||
2786       parser.parseAttribute(valueAttr, kDefaultValueAttrName, state.attributes))
2787     return failure();
2788 
2789   return success();
2790 }
2791 
print(spirv::SpecConstantOp constOp,OpAsmPrinter & printer)2792 static void print(spirv::SpecConstantOp constOp, OpAsmPrinter &printer) {
2793   printer << spirv::SpecConstantOp::getOperationName() << ' ';
2794   printer.printSymbolName(constOp.sym_name());
2795   if (auto specID = constOp->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
2796     printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
2797   printer << " = " << constOp.default_value();
2798 }
2799 
verify(spirv::SpecConstantOp constOp)2800 static LogicalResult verify(spirv::SpecConstantOp constOp) {
2801   if (auto specID = constOp->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
2802     if (specID.getValue().isNegative())
2803       return constOp.emitOpError("SpecId cannot be negative");
2804 
2805   auto value = constOp.default_value();
2806   if (value.isa<IntegerAttr, FloatAttr>()) {
2807     // Make sure bitwidth is allowed.
2808     if (!value.getType().isa<spirv::SPIRVType>())
2809       return constOp.emitOpError("default value bitwidth disallowed");
2810     return success();
2811   }
2812   return constOp.emitOpError(
2813       "default value can only be a bool, integer, or float scalar");
2814 }
2815 
2816 //===----------------------------------------------------------------------===//
2817 // spv.StoreOp
2818 //===----------------------------------------------------------------------===//
2819 
parseStoreOp(OpAsmParser & parser,OperationState & state)2820 static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &state) {
2821   // Parse the storage class specification
2822   spirv::StorageClass storageClass;
2823   SmallVector<OpAsmParser::OperandType, 2> operandInfo;
2824   auto loc = parser.getCurrentLocation();
2825   Type elementType;
2826   if (parseEnumStrAttr(storageClass, parser) ||
2827       parser.parseOperandList(operandInfo, 2) ||
2828       parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
2829       parser.parseType(elementType)) {
2830     return failure();
2831   }
2832 
2833   auto ptrType = spirv::PointerType::get(elementType, storageClass);
2834   if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
2835                              state.operands)) {
2836     return failure();
2837   }
2838   return success();
2839 }
2840 
print(spirv::StoreOp storeOp,OpAsmPrinter & printer)2841 static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) {
2842   auto *op = storeOp.getOperation();
2843   SmallVector<StringRef, 4> elidedAttrs;
2844   StringRef sc = stringifyStorageClass(
2845       storeOp.ptr().getType().cast<spirv::PointerType>().getStorageClass());
2846   printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" "
2847           << storeOp.ptr() << ", " << storeOp.value();
2848 
2849   printMemoryAccessAttribute(storeOp, printer, elidedAttrs);
2850 
2851   printer << " : " << storeOp.value().getType();
2852   printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
2853 }
2854 
verify(spirv::StoreOp storeOp)2855 static LogicalResult verify(spirv::StoreOp storeOp) {
2856   // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
2857   // OpTypePointer whose Type operand is the same as the type of Object."
2858   if (failed(verifyLoadStorePtrAndValTypes(storeOp, storeOp.ptr(),
2859                                            storeOp.value()))) {
2860     return failure();
2861   }
2862   return verifyMemoryAccessAttribute(storeOp);
2863 }
2864 
2865 //===----------------------------------------------------------------------===//
2866 // spv.Unreachable
2867 //===----------------------------------------------------------------------===//
2868 
verify(spirv::UnreachableOp unreachableOp)2869 static LogicalResult verify(spirv::UnreachableOp unreachableOp) {
2870   auto *op = unreachableOp.getOperation();
2871   auto *block = op->getBlock();
2872   // Fast track: if this is in entry block, its invalid. Otherwise, if no
2873   // predecessors, it's valid.
2874   if (block->isEntryBlock())
2875     return unreachableOp.emitOpError("cannot be used in reachable block");
2876   if (block->hasNoPredecessors())
2877     return success();
2878 
2879   // TODO: further verification needs to analyze reachability from
2880   // the entry block.
2881 
2882   return success();
2883 }
2884 
2885 //===----------------------------------------------------------------------===//
2886 // spv.Variable
2887 //===----------------------------------------------------------------------===//
2888 
parseVariableOp(OpAsmParser & parser,OperationState & state)2889 static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) {
2890   // Parse optional initializer
2891   Optional<OpAsmParser::OperandType> initInfo;
2892   if (succeeded(parser.parseOptionalKeyword("init"))) {
2893     initInfo = OpAsmParser::OperandType();
2894     if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
2895         parser.parseRParen())
2896       return failure();
2897   }
2898 
2899   if (parseVariableDecorations(parser, state)) {
2900     return failure();
2901   }
2902 
2903   // Parse result pointer type
2904   Type type;
2905   if (parser.parseColon())
2906     return failure();
2907   auto loc = parser.getCurrentLocation();
2908   if (parser.parseType(type))
2909     return failure();
2910 
2911   auto ptrType = type.dyn_cast<spirv::PointerType>();
2912   if (!ptrType)
2913     return parser.emitError(loc, "expected spv.ptr type");
2914   state.addTypes(ptrType);
2915 
2916   // Resolve the initializer operand
2917   if (initInfo) {
2918     if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
2919                               state.operands))
2920       return failure();
2921   }
2922 
2923   auto attr = parser.getBuilder().getI32IntegerAttr(
2924       llvm::bit_cast<int32_t>(ptrType.getStorageClass()));
2925   state.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
2926 
2927   return success();
2928 }
2929 
print(spirv::VariableOp varOp,OpAsmPrinter & printer)2930 static void print(spirv::VariableOp varOp, OpAsmPrinter &printer) {
2931   SmallVector<StringRef, 4> elidedAttrs{
2932       spirv::attributeName<spirv::StorageClass>()};
2933   printer << spirv::VariableOp::getOperationName();
2934 
2935   // Print optional initializer
2936   if (varOp.getNumOperands() != 0)
2937     printer << " init(" << varOp.initializer() << ")";
2938 
2939   printVariableDecorations(varOp, printer, elidedAttrs);
2940   printer << " : " << varOp.getType();
2941 }
2942 
verify(spirv::VariableOp varOp)2943 static LogicalResult verify(spirv::VariableOp varOp) {
2944   // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
2945   // object. It cannot be Generic. It must be the same as the Storage Class
2946   // operand of the Result Type."
2947   if (varOp.storage_class() != spirv::StorageClass::Function) {
2948     return varOp.emitOpError(
2949         "can only be used to model function-level variables. Use "
2950         "spv.globalVariable for module-level variables.");
2951   }
2952 
2953   auto pointerType = varOp.pointer().getType().cast<spirv::PointerType>();
2954   if (varOp.storage_class() != pointerType.getStorageClass())
2955     return varOp.emitOpError(
2956         "storage class must match result pointer's storage class");
2957 
2958   if (varOp.getNumOperands() != 0) {
2959     // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
2960     // a global (module scope) OpVariable instruction".
2961     auto *initOp = varOp.getOperand(0).getDefiningOp();
2962     if (!initOp || !isa<spirv::ConstantOp,    // for normal constant
2963                         spirv::ReferenceOfOp, // for spec constant
2964                         spirv::AddressOfOp>(initOp))
2965       return varOp.emitOpError("initializer must be the result of a "
2966                                "constant or spv.globalVariable op");
2967   }
2968 
2969   // TODO: generate these strings using ODS.
2970   auto *op = varOp.getOperation();
2971   auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
2972       stringifyDecoration(spirv::Decoration::DescriptorSet));
2973   auto bindingName = llvm::convertToSnakeFromCamelCase(
2974       stringifyDecoration(spirv::Decoration::Binding));
2975   auto builtInName = llvm::convertToSnakeFromCamelCase(
2976       stringifyDecoration(spirv::Decoration::BuiltIn));
2977 
2978   for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
2979     if (op->getAttr(attr))
2980       return varOp.emitOpError("cannot have '")
2981              << attr << "' attribute (only allowed in spv.globalVariable)";
2982   }
2983 
2984   return success();
2985 }
2986 
2987 //===----------------------------------------------------------------------===//
2988 // spv.CooperativeMatrixLoadNV
2989 //===----------------------------------------------------------------------===//
2990 
parseCooperativeMatrixLoadNVOp(OpAsmParser & parser,OperationState & state)2991 static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser,
2992                                                   OperationState &state) {
2993   SmallVector<OpAsmParser::OperandType, 3> operandInfo;
2994   Type strideType = parser.getBuilder().getIntegerType(32);
2995   Type columnMajorType = parser.getBuilder().getIntegerType(1);
2996   Type ptrType;
2997   Type elementType;
2998   if (parser.parseOperandList(operandInfo, 3) ||
2999       parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
3000       parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
3001     return failure();
3002   }
3003   if (parser.resolveOperands(operandInfo,
3004                              {ptrType, strideType, columnMajorType},
3005                              parser.getNameLoc(), state.operands)) {
3006     return failure();
3007   }
3008 
3009   state.addTypes(elementType);
3010   return success();
3011 }
3012 
print(spirv::CooperativeMatrixLoadNVOp M,OpAsmPrinter & printer)3013 static void print(spirv::CooperativeMatrixLoadNVOp M, OpAsmPrinter &printer) {
3014   printer << spirv::CooperativeMatrixLoadNVOp::getOperationName() << " "
3015           << M.pointer() << ", " << M.stride() << ", " << M.columnmajor();
3016   // Print optional memory access attribute.
3017   if (auto memAccess = M.memory_access())
3018     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3019   printer << " : " << M.pointer().getType() << " as " << M.getType();
3020 }
3021 
verifyPointerAndCoopMatrixType(Operation * op,Type pointer,Type coopMatrix)3022 static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
3023                                                     Type coopMatrix) {
3024   Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
3025   if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
3026     return op->emitError(
3027                "Pointer must point to a scalar or vector type but provided ")
3028            << pointeeType;
3029   spirv::StorageClass storage =
3030       pointer.cast<spirv::PointerType>().getStorageClass();
3031   if (storage != spirv::StorageClass::Workgroup &&
3032       storage != spirv::StorageClass::StorageBuffer &&
3033       storage != spirv::StorageClass::PhysicalStorageBuffer)
3034     return op->emitError(
3035                "Pointer storage class must be Workgroup, StorageBuffer or "
3036                "PhysicalStorageBufferEXT but provided ")
3037            << stringifyStorageClass(storage);
3038   return success();
3039 }
3040 
3041 //===----------------------------------------------------------------------===//
3042 // spv.CooperativeMatrixStoreNV
3043 //===----------------------------------------------------------------------===//
3044 
parseCooperativeMatrixStoreNVOp(OpAsmParser & parser,OperationState & state)3045 static ParseResult parseCooperativeMatrixStoreNVOp(OpAsmParser &parser,
3046                                                    OperationState &state) {
3047   SmallVector<OpAsmParser::OperandType, 4> operandInfo;
3048   Type strideType = parser.getBuilder().getIntegerType(32);
3049   Type columnMajorType = parser.getBuilder().getIntegerType(1);
3050   Type ptrType;
3051   Type elementType;
3052   if (parser.parseOperandList(operandInfo, 4) ||
3053       parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
3054       parser.parseType(ptrType) || parser.parseComma() ||
3055       parser.parseType(elementType)) {
3056     return failure();
3057   }
3058   if (parser.resolveOperands(
3059           operandInfo, {ptrType, elementType, strideType, columnMajorType},
3060           parser.getNameLoc(), state.operands)) {
3061     return failure();
3062   }
3063 
3064   return success();
3065 }
3066 
print(spirv::CooperativeMatrixStoreNVOp coopMatrix,OpAsmPrinter & printer)3067 static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix,
3068                   OpAsmPrinter &printer) {
3069   printer << spirv::CooperativeMatrixStoreNVOp::getOperationName() << " "
3070           << coopMatrix.pointer() << ", " << coopMatrix.object() << ", "
3071           << coopMatrix.stride() << ", " << coopMatrix.columnmajor();
3072   // Print optional memory access attribute.
3073   if (auto memAccess = coopMatrix.memory_access())
3074     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3075   printer << " : " << coopMatrix.pointer().getType() << ", "
3076           << coopMatrix.getOperand(1).getType();
3077 }
3078 
3079 //===----------------------------------------------------------------------===//
3080 // spv.CooperativeMatrixMulAddNV
3081 //===----------------------------------------------------------------------===//
3082 
3083 static LogicalResult
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op)3084 verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
3085   if (op.c().getType() != op.result().getType())
3086     return op.emitOpError("result and third operand must have the same type");
3087   auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
3088   auto typeB = op.b().getType().cast<spirv::CooperativeMatrixNVType>();
3089   auto typeC = op.c().getType().cast<spirv::CooperativeMatrixNVType>();
3090   auto typeR = op.result().getType().cast<spirv::CooperativeMatrixNVType>();
3091   if (typeA.getRows() != typeR.getRows() ||
3092       typeA.getColumns() != typeB.getRows() ||
3093       typeB.getColumns() != typeR.getColumns())
3094     return op.emitOpError("matrix size must match");
3095   if (typeR.getScope() != typeA.getScope() ||
3096       typeR.getScope() != typeB.getScope() ||
3097       typeR.getScope() != typeC.getScope())
3098     return op.emitOpError("matrix scope must match");
3099   if (typeA.getElementType() != typeB.getElementType() ||
3100       typeR.getElementType() != typeC.getElementType())
3101     return op.emitOpError("matrix element type must match");
3102   return success();
3103 }
3104 
3105 //===----------------------------------------------------------------------===//
3106 // spv.MatrixTimesScalar
3107 //===----------------------------------------------------------------------===//
3108 
verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op)3109 static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) {
3110   // We already checked that result and matrix are both of matrix type in the
3111   // auto-generated verify method.
3112 
3113   auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
3114   auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
3115 
3116   // Check that the scalar type is the same as the matrix element type.
3117   if (op.scalar().getType() != inputMatrix.getElementType())
3118     return op.emitError("input matrix components' type and scaling value must "
3119                         "have the same type");
3120 
3121   // Note that the next three checks could be done using the AllTypesMatch
3122   // trait in the Op definition file but it generates a vague error message.
3123 
3124   // Check that the input and result matrices have the same columns' count
3125   if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns())
3126     return op.emitError("input and result matrices must have the same "
3127                         "number of columns");
3128 
3129   // Check that the input and result matrices' have the same rows count
3130   if (inputMatrix.getNumRows() != resultMatrix.getNumRows())
3131     return op.emitError("input and result matrices' columns must have "
3132                         "the same size");
3133 
3134   // Check that the input and result matrices' have the same component type
3135   if (inputMatrix.getElementType() != resultMatrix.getElementType())
3136     return op.emitError("input and result matrices' columns must have "
3137                         "the same component type");
3138 
3139   return success();
3140 }
3141 
3142 //===----------------------------------------------------------------------===//
3143 // spv.CopyMemory
3144 //===----------------------------------------------------------------------===//
3145 
print(spirv::CopyMemoryOp copyMemory,OpAsmPrinter & printer)3146 static void print(spirv::CopyMemoryOp copyMemory, OpAsmPrinter &printer) {
3147   auto *op = copyMemory.getOperation();
3148   printer << spirv::CopyMemoryOp::getOperationName() << ' ';
3149 
3150   StringRef targetStorageClass =
3151       stringifyStorageClass(copyMemory.target()
3152                                 .getType()
3153                                 .cast<spirv::PointerType>()
3154                                 .getStorageClass());
3155   printer << " \"" << targetStorageClass << "\" " << copyMemory.target()
3156           << ", ";
3157 
3158   StringRef sourceStorageClass =
3159       stringifyStorageClass(copyMemory.source()
3160                                 .getType()
3161                                 .cast<spirv::PointerType>()
3162                                 .getStorageClass());
3163   printer << " \"" << sourceStorageClass << "\" " << copyMemory.source();
3164 
3165   SmallVector<StringRef, 4> elidedAttrs;
3166   printMemoryAccessAttribute(copyMemory, printer, elidedAttrs);
3167   printSourceMemoryAccessAttribute(copyMemory, printer, elidedAttrs,
3168                                    copyMemory.source_memory_access(),
3169                                    copyMemory.source_alignment());
3170 
3171   printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
3172 
3173   Type pointeeType =
3174       copyMemory.target().getType().cast<spirv::PointerType>().getPointeeType();
3175   printer << " : " << pointeeType;
3176 }
3177 
parseCopyMemoryOp(OpAsmParser & parser,OperationState & state)3178 static ParseResult parseCopyMemoryOp(OpAsmParser &parser,
3179                                      OperationState &state) {
3180   spirv::StorageClass targetStorageClass;
3181   OpAsmParser::OperandType targetPtrInfo;
3182 
3183   spirv::StorageClass sourceStorageClass;
3184   OpAsmParser::OperandType sourcePtrInfo;
3185 
3186   Type elementType;
3187 
3188   if (parseEnumStrAttr(targetStorageClass, parser) ||
3189       parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
3190       parseEnumStrAttr(sourceStorageClass, parser) ||
3191       parser.parseOperand(sourcePtrInfo) ||
3192       parseMemoryAccessAttributes(parser, state)) {
3193     return failure();
3194   }
3195 
3196   if (!parser.parseOptionalComma()) {
3197     // Parse 2nd memory access attributes.
3198     if (parseSourceMemoryAccessAttributes(parser, state)) {
3199       return failure();
3200     }
3201   }
3202 
3203   if (parser.parseColon() || parser.parseType(elementType))
3204     return failure();
3205 
3206   if (parser.parseOptionalAttrDict(state.attributes))
3207     return failure();
3208 
3209   auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
3210   auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
3211 
3212   if (parser.resolveOperand(targetPtrInfo, targetPtrType, state.operands) ||
3213       parser.resolveOperand(sourcePtrInfo, sourcePtrType, state.operands)) {
3214     return failure();
3215   }
3216 
3217   return success();
3218 }
3219 
verifyCopyMemory(spirv::CopyMemoryOp copyMemory)3220 static LogicalResult verifyCopyMemory(spirv::CopyMemoryOp copyMemory) {
3221   Type targetType =
3222       copyMemory.target().getType().cast<spirv::PointerType>().getPointeeType();
3223 
3224   Type sourceType =
3225       copyMemory.source().getType().cast<spirv::PointerType>().getPointeeType();
3226 
3227   if (targetType != sourceType) {
3228     return copyMemory.emitOpError(
3229         "both operands must be pointers to the same type");
3230   }
3231 
3232   if (failed(verifyMemoryAccessAttribute(copyMemory))) {
3233     return failure();
3234   }
3235 
3236   // TODO - According to the spec:
3237   //
3238   // If two masks are present, the first applies to Target and cannot include
3239   // MakePointerVisible, and the second applies to Source and cannot include
3240   // MakePointerAvailable.
3241   //
3242   // Add such verification here.
3243 
3244   return verifySourceMemoryAccessAttribute(copyMemory);
3245 }
3246 
3247 //===----------------------------------------------------------------------===//
3248 // spv.Transpose
3249 //===----------------------------------------------------------------------===//
3250 
verifyTranspose(spirv::TransposeOp op)3251 static LogicalResult verifyTranspose(spirv::TransposeOp op) {
3252   auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
3253   auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
3254 
3255   // Verify that the input and output matrices have correct shapes.
3256   if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
3257     return op.emitError("input matrix rows count must be equal to "
3258                         "output matrix columns count");
3259 
3260   if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
3261     return op.emitError("input matrix columns count must be equal to "
3262                         "output matrix rows count");
3263 
3264   // Verify that the input and output matrices have the same component type
3265   if (inputMatrix.getElementType() != resultMatrix.getElementType())
3266     return op.emitError("input and output matrices must have the same "
3267                         "component type");
3268 
3269   return success();
3270 }
3271 
3272 //===----------------------------------------------------------------------===//
3273 // spv.MatrixTimesMatrix
3274 //===----------------------------------------------------------------------===//
3275 
verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op)3276 static LogicalResult verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op) {
3277   auto leftMatrix = op.leftmatrix().getType().cast<spirv::MatrixType>();
3278   auto rightMatrix = op.rightmatrix().getType().cast<spirv::MatrixType>();
3279   auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
3280 
3281   // left matrix columns' count and right matrix rows' count must be equal
3282   if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
3283     return op.emitError("left matrix columns' count must be equal to "
3284                         "the right matrix rows' count");
3285 
3286   // right and result matrices columns' count must be the same
3287   if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
3288     return op.emitError(
3289         "right and result matrices must have equal columns' count");
3290 
3291   // right and result matrices component type must be the same
3292   if (rightMatrix.getElementType() != resultMatrix.getElementType())
3293     return op.emitError("right and result matrices' component type must"
3294                         " be the same");
3295 
3296   // left and result matrices component type must be the same
3297   if (leftMatrix.getElementType() != resultMatrix.getElementType())
3298     return op.emitError("left and result matrices' component type"
3299                         " must be the same");
3300 
3301   // left and result matrices rows count must be the same
3302   if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
3303     return op.emitError("left and result matrices must have equal rows'"
3304                         " count");
3305 
3306   return success();
3307 }
3308 
3309 //===----------------------------------------------------------------------===//
3310 // spv.specConstantComposite
3311 //===----------------------------------------------------------------------===//
3312 
parseSpecConstantCompositeOp(OpAsmParser & parser,OperationState & state)3313 static ParseResult parseSpecConstantCompositeOp(OpAsmParser &parser,
3314                                                 OperationState &state) {
3315 
3316   StringAttr compositeName;
3317   if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
3318                              state.attributes))
3319     return failure();
3320 
3321   if (parser.parseLParen())
3322     return failure();
3323 
3324   SmallVector<Attribute, 4> constituents;
3325 
3326   do {
3327     // The name of the constituent attribute isn't important
3328     const char *attrName = "spec_const";
3329     FlatSymbolRefAttr specConstRef;
3330     NamedAttrList attrs;
3331 
3332     if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
3333       return failure();
3334 
3335     constituents.push_back(specConstRef);
3336   } while (!parser.parseOptionalComma());
3337 
3338   if (parser.parseRParen())
3339     return failure();
3340 
3341   state.addAttribute(kCompositeSpecConstituentsName,
3342                      parser.getBuilder().getArrayAttr(constituents));
3343 
3344   Type type;
3345   if (parser.parseColonType(type))
3346     return failure();
3347 
3348   state.addAttribute(kTypeAttrName, TypeAttr::get(type));
3349 
3350   return success();
3351 }
3352 
print(spirv::SpecConstantCompositeOp op,OpAsmPrinter & printer)3353 static void print(spirv::SpecConstantCompositeOp op, OpAsmPrinter &printer) {
3354   printer << spirv::SpecConstantCompositeOp::getOperationName() << " ";
3355   printer.printSymbolName(op.sym_name());
3356   printer << " (";
3357   auto constituents = op.constituents().getValue();
3358 
3359   if (!constituents.empty())
3360     llvm::interleaveComma(constituents, printer);
3361 
3362   printer << ") : " << op.type();
3363 }
3364 
verify(spirv::SpecConstantCompositeOp constOp)3365 static LogicalResult verify(spirv::SpecConstantCompositeOp constOp) {
3366   auto cType = constOp.type().dyn_cast<spirv::CompositeType>();
3367   auto constituents = constOp.constituents().getValue();
3368 
3369   if (!cType)
3370     return constOp.emitError(
3371                "result type must be a composite type, but provided ")
3372            << constOp.type();
3373 
3374   if (cType.isa<spirv::CooperativeMatrixNVType>())
3375     return constOp.emitError("unsupported composite type  ") << cType;
3376   else if (constituents.size() != cType.getNumElements())
3377     return constOp.emitError("has incorrect number of operands: expected ")
3378            << cType.getNumElements() << ", but provided "
3379            << constituents.size();
3380 
3381   for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
3382     auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
3383 
3384     auto constituentSpecConstOp =
3385         dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
3386             constOp->getParentOp(), constituent.getValue()));
3387 
3388     if (constituentSpecConstOp.default_value().getType() !=
3389         cType.getElementType(index))
3390       return constOp.emitError("has incorrect types of operands: expected ")
3391              << cType.getElementType(index) << ", but provided "
3392              << constituentSpecConstOp.default_value().getType();
3393   }
3394 
3395   return success();
3396 }
3397 
3398 //===----------------------------------------------------------------------===//
3399 // spv.mlir.yield
3400 //===----------------------------------------------------------------------===//
3401 
verify(spirv::YieldOp yieldOp)3402 static LogicalResult verify(spirv::YieldOp yieldOp) {
3403   Operation *parentOp = yieldOp->getParentOp();
3404 
3405   if (!parentOp || !isa<spirv::SpecConstantOperationOp>(parentOp))
3406     return yieldOp.emitOpError(
3407         "expected parent op to be 'spv.SpecConstantOperation'");
3408 
3409   Block &block = parentOp->getRegion(0).getBlocks().front();
3410   Operation &enclosedOp = block.getOperations().front();
3411 
3412   if (yieldOp.getOperand().getDefiningOp() != &enclosedOp)
3413     return yieldOp.emitOpError(
3414         "expected operand to be defined by preceeding op");
3415 
3416   return success();
3417 }
3418 
parseSpecConstantOperationOp(OpAsmParser & parser,OperationState & state)3419 static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser,
3420                                                 OperationState &state) {
3421   // TODO: For now, only generic form is supported.
3422   return failure();
3423 }
3424 
print(spirv::SpecConstantOperationOp op,OpAsmPrinter & printer)3425 static void print(spirv::SpecConstantOperationOp op, OpAsmPrinter &printer) {
3426   // TODO
3427   printer.printGenericOp(op);
3428 }
3429 
verify(spirv::SpecConstantOperationOp constOp)3430 static LogicalResult verify(spirv::SpecConstantOperationOp constOp) {
3431   Block &block = constOp.getRegion().getBlocks().front();
3432 
3433   if (block.getOperations().size() != 2)
3434     return constOp.emitOpError("expected exactly 2 nested ops");
3435 
3436   Operation &yieldOp = block.getOperations().back();
3437 
3438   if (!isa<spirv::YieldOp>(yieldOp))
3439     return constOp.emitOpError("expected terminator to be a yield op");
3440 
3441   Operation &enclosedOp = block.getOperations().front();
3442 
3443   // TODO Add a `UsableInSpecConstantOp` trait and mark ops from the list below
3444   // with it instead.
3445   if (!isa<spirv::SConvertOp, spirv::UConvertOp, spirv::FConvertOp,
3446            spirv::SNegateOp, spirv::NotOp, spirv::IAddOp, spirv::ISubOp,
3447            spirv::IMulOp, spirv::UDivOp, spirv::SDivOp, spirv::UModOp,
3448            spirv::SRemOp, spirv::SModOp, spirv::ShiftRightLogicalOp,
3449            spirv::ShiftRightArithmeticOp, spirv::ShiftLeftLogicalOp,
3450            spirv::BitwiseOrOp, spirv::BitwiseXorOp, spirv::BitwiseAndOp,
3451            spirv::CompositeExtractOp, spirv::CompositeInsertOp,
3452            spirv::LogicalOrOp, spirv::LogicalAndOp, spirv::LogicalNotOp,
3453            spirv::LogicalEqualOp, spirv::LogicalNotEqualOp, spirv::SelectOp,
3454            spirv::IEqualOp, spirv::INotEqualOp, spirv::ULessThanOp,
3455            spirv::SLessThanOp, spirv::UGreaterThanOp, spirv::SGreaterThanOp,
3456            spirv::ULessThanEqualOp, spirv::SLessThanEqualOp,
3457            spirv::UGreaterThanEqualOp, spirv::SGreaterThanEqualOp>(enclosedOp))
3458     return constOp.emitOpError("invalid enclosed op");
3459 
3460   if (enclosedOp.getNumOperands() != constOp.getOperands().size())
3461     return constOp.emitOpError("invalid number of operands; expected ")
3462            << enclosedOp.getNumOperands() << ", actual "
3463            << constOp.getOperands().size();
3464 
3465   if (enclosedOp.getNumOperands() != constOp.getRegion().getNumArguments())
3466     return constOp.emitOpError("invalid number of region arguments; expected ")
3467            << enclosedOp.getNumOperands() << ", actual "
3468            << constOp.getRegion().getNumArguments();
3469 
3470   for (auto operand : constOp.getOperands())
3471     if (!isa<spirv::ConstantOp, spirv::SpecConstantOp,
3472              spirv::SpecConstantCompositeOp, spirv::SpecConstantOperationOp>(
3473             operand.getDefiningOp()))
3474       return constOp.emitOpError("invalid operand");
3475 
3476   return success();
3477 }
3478 
3479 namespace mlir {
3480 namespace spirv {
3481 
3482 // TableGen'erated operation interfaces for querying versions, extensions, and
3483 // capabilities.
3484 #include "mlir/Dialect/SPIRV/SPIRVAvailability.cpp.inc"
3485 } // namespace spirv
3486 } // namespace mlir
3487 
// TableGen'erated operation definitions.
3489 #define GET_OP_CLASSES
3490 #include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc"
3491 
3492 namespace mlir {
3493 namespace spirv {
3494 // TableGen'erated operation availability interface implementations.
3495 #include "mlir/Dialect/SPIRV/SPIRVOpAvailabilityImpl.inc"
3496 
3497 } // namespace spirv
3498 } // namespace mlir
3499