1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestTypes.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/TypeUtilities.h"
16 #include "mlir/Transforms/FoldUtils.h"
17 #include "mlir/Transforms/InliningUtils.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/ADT/StringSwitch.h"
20 
21 using namespace mlir;
22 using namespace mlir::test;
23 
registerTestDialect(DialectRegistry & registry)24 void mlir::test::registerTestDialect(DialectRegistry &registry) {
25   registry.insert<TestDialect>();
26 }
27 
28 //===----------------------------------------------------------------------===//
29 // TestDialect Interfaces
30 //===----------------------------------------------------------------------===//
31 
32 namespace {
33 
34 // Test support for interacting with the AsmPrinter.
35 struct TestOpAsmInterface : public OpAsmDialectInterface {
36   using OpAsmDialectInterface::OpAsmDialectInterface;
37 
getAlias__anon9b4000320111::TestOpAsmInterface38   LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
39     StringAttr strAttr = attr.dyn_cast<StringAttr>();
40     if (!strAttr)
41       return failure();
42 
43     // Check the contents of the string attribute to see what the test alias
44     // should be named.
45     Optional<StringRef> aliasName =
46         StringSwitch<Optional<StringRef>>(strAttr.getValue())
47             .Case("alias_test:dot_in_name", StringRef("test.alias"))
48             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
49             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
50             .Case("alias_test:sanitize_conflict_a",
51                   StringRef("test_alias_conflict0"))
52             .Case("alias_test:sanitize_conflict_b",
53                   StringRef("test_alias_conflict0_"))
54             .Default(llvm::None);
55     if (!aliasName)
56       return failure();
57 
58     os << *aliasName;
59     return success();
60   }
61 
getAsmResultNames__anon9b4000320111::TestOpAsmInterface62   void getAsmResultNames(Operation *op,
63                          OpAsmSetValueNameFn setNameFn) const final {
64     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
65       setNameFn(asmOp, "result");
66   }
67 
getAsmBlockArgumentNames__anon9b4000320111::TestOpAsmInterface68   void getAsmBlockArgumentNames(Block *block,
69                                 OpAsmSetValueNameFn setNameFn) const final {
70     auto op = block->getParentOp();
71     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
72     if (!arrayAttr)
73       return;
74     auto args = block->getArguments();
75     auto e = std::min(arrayAttr.size(), args.size());
76     for (unsigned i = 0; i < e; ++i) {
77       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
78         setNameFn(args[i], strAttr.getValue());
79     }
80   }
81 };
82 
83 struct TestDialectFoldInterface : public DialectFoldInterface {
84   using DialectFoldInterface::DialectFoldInterface;
85 
86   /// Registered hook to check if the given region, which is attached to an
87   /// operation that is *not* isolated from above, should be used when
88   /// materializing constants.
shouldMaterializeInto__anon9b4000320111::TestDialectFoldInterface89   bool shouldMaterializeInto(Region *region) const final {
90     // If this is a one region operation, then insert into it.
91     return isa<OneRegionOp>(region->getParentOp());
92   }
93 };
94 
95 /// This class defines the interface for handling inlining with standard
96 /// operations.
97 struct TestInlinerInterface : public DialectInlinerInterface {
98   using DialectInlinerInterface::DialectInlinerInterface;
99 
100   //===--------------------------------------------------------------------===//
101   // Analysis Hooks
102   //===--------------------------------------------------------------------===//
103 
isLegalToInline__anon9b4000320111::TestInlinerInterface104   bool isLegalToInline(Operation *call, Operation *callable,
105                        bool wouldBeCloned) const final {
106     // Don't allow inlining calls that are marked `noinline`.
107     return !call->hasAttr("noinline");
108   }
isLegalToInline__anon9b4000320111::TestInlinerInterface109   bool isLegalToInline(Region *, Region *, bool,
110                        BlockAndValueMapping &) const final {
111     // Inlining into test dialect regions is legal.
112     return true;
113   }
isLegalToInline__anon9b4000320111::TestInlinerInterface114   bool isLegalToInline(Operation *, Region *, bool,
115                        BlockAndValueMapping &) const final {
116     return true;
117   }
118 
shouldAnalyzeRecursively__anon9b4000320111::TestInlinerInterface119   bool shouldAnalyzeRecursively(Operation *op) const final {
120     // Analyze recursively if this is not a functional region operation, it
121     // froms a separate functional scope.
122     return !isa<FunctionalRegionOp>(op);
123   }
124 
125   //===--------------------------------------------------------------------===//
126   // Transformation Hooks
127   //===--------------------------------------------------------------------===//
128 
129   /// Handle the given inlined terminator by replacing it with a new operation
130   /// as necessary.
handleTerminator__anon9b4000320111::TestInlinerInterface131   void handleTerminator(Operation *op,
132                         ArrayRef<Value> valuesToRepl) const final {
133     // Only handle "test.return" here.
134     auto returnOp = dyn_cast<TestReturnOp>(op);
135     if (!returnOp)
136       return;
137 
138     // Replace the values directly with the return operands.
139     assert(returnOp.getNumOperands() == valuesToRepl.size());
140     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
141       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
142   }
143 
144   /// Attempt to materialize a conversion for a type mismatch between a call
145   /// from this dialect, and a callable region. This method should generate an
146   /// operation that takes 'input' as the only operand, and produces a single
147   /// result of 'resultType'. If a conversion can not be generated, nullptr
148   /// should be returned.
materializeCallConversion__anon9b4000320111::TestInlinerInterface149   Operation *materializeCallConversion(OpBuilder &builder, Value input,
150                                        Type resultType,
151                                        Location conversionLoc) const final {
152     // Only allow conversion for i16/i32 types.
153     if (!(resultType.isSignlessInteger(16) ||
154           resultType.isSignlessInteger(32)) ||
155         !(input.getType().isSignlessInteger(16) ||
156           input.getType().isSignlessInteger(32)))
157       return nullptr;
158     return builder.create<TestCastOp>(conversionLoc, resultType, input);
159   }
160 };
161 } // end anonymous namespace
162 
163 //===----------------------------------------------------------------------===//
164 // TestDialect
165 //===----------------------------------------------------------------------===//
166 
initialize()167 void TestDialect::initialize() {
168   addOperations<
169 #define GET_OP_LIST
170 #include "TestOps.cpp.inc"
171       >();
172   addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
173                 TestInlinerInterface>();
174   addTypes<TestType, TestRecursiveType,
175 #define GET_TYPEDEF_LIST
176 #include "TestTypeDefs.cpp.inc"
177            >();
178   allowUnknownOperations();
179 }
180 
parseTestType(MLIRContext * ctxt,DialectAsmParser & parser,llvm::SetVector<Type> & stack)181 static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
182                           llvm::SetVector<Type> &stack) {
183   StringRef typeTag;
184   if (failed(parser.parseKeyword(&typeTag)))
185     return Type();
186 
187   auto genType = generatedTypeParser(ctxt, parser, typeTag);
188   if (genType != Type())
189     return genType;
190 
191   if (typeTag == "test_type")
192     return TestType::get(parser.getBuilder().getContext());
193 
194   if (typeTag != "test_rec")
195     return Type();
196 
197   StringRef name;
198   if (parser.parseLess() || parser.parseKeyword(&name))
199     return Type();
200   auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name);
201 
202   // If this type already has been parsed above in the stack, expect just the
203   // name.
204   if (stack.contains(rec)) {
205     if (failed(parser.parseGreater()))
206       return Type();
207     return rec;
208   }
209 
210   // Otherwise, parse the body and update the type.
211   if (failed(parser.parseComma()))
212     return Type();
213   stack.insert(rec);
214   Type subtype = parseTestType(ctxt, parser, stack);
215   stack.pop_back();
216   if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
217     return Type();
218 
219   return rec;
220 }
221 
parseType(DialectAsmParser & parser) const222 Type TestDialect::parseType(DialectAsmParser &parser) const {
223   llvm::SetVector<Type> stack;
224   return parseTestType(getContext(), parser, stack);
225 }
226 
printTestType(Type type,DialectAsmPrinter & printer,llvm::SetVector<Type> & stack)227 static void printTestType(Type type, DialectAsmPrinter &printer,
228                           llvm::SetVector<Type> &stack) {
229   if (succeeded(generatedTypePrinter(type, printer)))
230     return;
231   if (type.isa<TestType>()) {
232     printer << "test_type";
233     return;
234   }
235 
236   auto rec = type.cast<TestRecursiveType>();
237   printer << "test_rec<" << rec.getName();
238   if (!stack.contains(rec)) {
239     printer << ", ";
240     stack.insert(rec);
241     printTestType(rec.getBody(), printer, stack);
242     stack.pop_back();
243   }
244   printer << ">";
245 }
246 
printType(Type type,DialectAsmPrinter & printer) const247 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
248   llvm::SetVector<Type> stack;
249   printTestType(type, printer, stack);
250 }
251 
verifyOperationAttribute(Operation * op,NamedAttribute namedAttr)252 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
253                                                     NamedAttribute namedAttr) {
254   if (namedAttr.first == "test.invalid_attr")
255     return op->emitError() << "invalid to use 'test.invalid_attr'";
256   return success();
257 }
258 
verifyRegionArgAttribute(Operation * op,unsigned regionIndex,unsigned argIndex,NamedAttribute namedAttr)259 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
260                                                     unsigned regionIndex,
261                                                     unsigned argIndex,
262                                                     NamedAttribute namedAttr) {
263   if (namedAttr.first == "test.invalid_attr")
264     return op->emitError() << "invalid to use 'test.invalid_attr'";
265   return success();
266 }
267 
268 LogicalResult
verifyRegionResultAttribute(Operation * op,unsigned regionIndex,unsigned resultIndex,NamedAttribute namedAttr)269 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
270                                          unsigned resultIndex,
271                                          NamedAttribute namedAttr) {
272   if (namedAttr.first == "test.invalid_attr")
273     return op->emitError() << "invalid to use 'test.invalid_attr'";
274   return success();
275 }
276 
277 //===----------------------------------------------------------------------===//
278 // TestBranchOp
279 //===----------------------------------------------------------------------===//
280 
281 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)282 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
283   assert(index == 0 && "invalid successor index");
284   return targetOperandsMutable();
285 }
286 
287 //===----------------------------------------------------------------------===//
288 // TestFoldToCallOp
289 //===----------------------------------------------------------------------===//
290 
291 namespace {
292 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
293   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
294 
matchAndRewrite__anon9b4000320211::FoldToCallOpPattern295   LogicalResult matchAndRewrite(FoldToCallOp op,
296                                 PatternRewriter &rewriter) const override {
297     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
298                                         ValueRange());
299     return success();
300   }
301 };
302 } // end anonymous namespace
303 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)304 void FoldToCallOp::getCanonicalizationPatterns(
305     OwningRewritePatternList &results, MLIRContext *context) {
306   results.insert<FoldToCallOpPattern>(context);
307 }
308 
309 //===----------------------------------------------------------------------===//
310 // Test Format* operations
311 //===----------------------------------------------------------------------===//
312 
313 //===----------------------------------------------------------------------===//
314 // Parsing
315 
parseCustomDirectiveOperands(OpAsmParser & parser,OpAsmParser::OperandType & operand,Optional<OpAsmParser::OperandType> & optOperand,SmallVectorImpl<OpAsmParser::OperandType> & varOperands)316 static ParseResult parseCustomDirectiveOperands(
317     OpAsmParser &parser, OpAsmParser::OperandType &operand,
318     Optional<OpAsmParser::OperandType> &optOperand,
319     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
320   if (parser.parseOperand(operand))
321     return failure();
322   if (succeeded(parser.parseOptionalComma())) {
323     optOperand.emplace();
324     if (parser.parseOperand(*optOperand))
325       return failure();
326   }
327   if (parser.parseArrow() || parser.parseLParen() ||
328       parser.parseOperandList(varOperands) || parser.parseRParen())
329     return failure();
330   return success();
331 }
332 static ParseResult
parseCustomDirectiveResults(OpAsmParser & parser,Type & operandType,Type & optOperandType,SmallVectorImpl<Type> & varOperandTypes)333 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
334                             Type &optOperandType,
335                             SmallVectorImpl<Type> &varOperandTypes) {
336   if (parser.parseColon())
337     return failure();
338 
339   if (parser.parseType(operandType))
340     return failure();
341   if (succeeded(parser.parseOptionalComma())) {
342     if (parser.parseType(optOperandType))
343       return failure();
344   }
345   if (parser.parseArrow() || parser.parseLParen() ||
346       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
347     return failure();
348   return success();
349 }
350 static ParseResult
parseCustomDirectiveWithTypeRefs(OpAsmParser & parser,Type operandType,Type optOperandType,const SmallVectorImpl<Type> & varOperandTypes)351 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
352                                  Type optOperandType,
353                                  const SmallVectorImpl<Type> &varOperandTypes) {
354   if (parser.parseKeyword("type_refs_capture"))
355     return failure();
356 
357   Type operandType2, optOperandType2;
358   SmallVector<Type, 1> varOperandTypes2;
359   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
360                                   varOperandTypes2))
361     return failure();
362 
363   if (operandType != operandType2 || optOperandType != optOperandType2 ||
364       varOperandTypes != varOperandTypes2)
365     return failure();
366 
367   return success();
368 }
parseCustomDirectiveOperandsAndTypes(OpAsmParser & parser,OpAsmParser::OperandType & operand,Optional<OpAsmParser::OperandType> & optOperand,SmallVectorImpl<OpAsmParser::OperandType> & varOperands,Type & operandType,Type & optOperandType,SmallVectorImpl<Type> & varOperandTypes)369 static ParseResult parseCustomDirectiveOperandsAndTypes(
370     OpAsmParser &parser, OpAsmParser::OperandType &operand,
371     Optional<OpAsmParser::OperandType> &optOperand,
372     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
373     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
374   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
375       parseCustomDirectiveResults(parser, operandType, optOperandType,
376                                   varOperandTypes))
377     return failure();
378   return success();
379 }
parseCustomDirectiveRegions(OpAsmParser & parser,Region & region,SmallVectorImpl<std::unique_ptr<Region>> & varRegions)380 static ParseResult parseCustomDirectiveRegions(
381     OpAsmParser &parser, Region &region,
382     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
383   if (parser.parseRegion(region))
384     return failure();
385   if (failed(parser.parseOptionalComma()))
386     return success();
387   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
388   if (parser.parseRegion(*varRegion))
389     return failure();
390   varRegions.emplace_back(std::move(varRegion));
391   return success();
392 }
393 static ParseResult
parseCustomDirectiveSuccessors(OpAsmParser & parser,Block * & successor,SmallVectorImpl<Block * > & varSuccessors)394 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
395                                SmallVectorImpl<Block *> &varSuccessors) {
396   if (parser.parseSuccessor(successor))
397     return failure();
398   if (failed(parser.parseOptionalComma()))
399     return success();
400   Block *varSuccessor;
401   if (parser.parseSuccessor(varSuccessor))
402     return failure();
403   varSuccessors.append(2, varSuccessor);
404   return success();
405 }
parseCustomDirectiveAttributes(OpAsmParser & parser,IntegerAttr & attr,IntegerAttr & optAttr)406 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
407                                                   IntegerAttr &attr,
408                                                   IntegerAttr &optAttr) {
409   if (parser.parseAttribute(attr))
410     return failure();
411   if (succeeded(parser.parseOptionalComma())) {
412     if (parser.parseAttribute(optAttr))
413       return failure();
414   }
415   return success();
416 }
417 
parseCustomDirectiveAttrDict(OpAsmParser & parser,NamedAttrList & attrs)418 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
419                                                 NamedAttrList &attrs) {
420   return parser.parseOptionalAttrDict(attrs);
421 }
422 
423 //===----------------------------------------------------------------------===//
424 // Printing
425 
printCustomDirectiveOperands(OpAsmPrinter & printer,Operation *,Value operand,Value optOperand,OperandRange varOperands)426 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
427                                          Value operand, Value optOperand,
428                                          OperandRange varOperands) {
429   printer << operand;
430   if (optOperand)
431     printer << ", " << optOperand;
432   printer << " -> (" << varOperands << ")";
433 }
printCustomDirectiveResults(OpAsmPrinter & printer,Operation *,Type operandType,Type optOperandType,TypeRange varOperandTypes)434 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
435                                         Type operandType, Type optOperandType,
436                                         TypeRange varOperandTypes) {
437   printer << " : " << operandType;
438   if (optOperandType)
439     printer << ", " << optOperandType;
440   printer << " -> (" << varOperandTypes << ")";
441 }
printCustomDirectiveWithTypeRefs(OpAsmPrinter & printer,Operation * op,Type operandType,Type optOperandType,TypeRange varOperandTypes)442 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
443                                              Operation *op, Type operandType,
444                                              Type optOperandType,
445                                              TypeRange varOperandTypes) {
446   printer << " type_refs_capture ";
447   printCustomDirectiveResults(printer, op, operandType, optOperandType,
448                               varOperandTypes);
449 }
printCustomDirectiveOperandsAndTypes(OpAsmPrinter & printer,Operation * op,Value operand,Value optOperand,OperandRange varOperands,Type operandType,Type optOperandType,TypeRange varOperandTypes)450 static void printCustomDirectiveOperandsAndTypes(
451     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
452     OperandRange varOperands, Type operandType, Type optOperandType,
453     TypeRange varOperandTypes) {
454   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
455   printCustomDirectiveResults(printer, op, operandType, optOperandType,
456                               varOperandTypes);
457 }
printCustomDirectiveRegions(OpAsmPrinter & printer,Operation *,Region & region,MutableArrayRef<Region> varRegions)458 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
459                                         Region &region,
460                                         MutableArrayRef<Region> varRegions) {
461   printer.printRegion(region);
462   if (!varRegions.empty()) {
463     printer << ", ";
464     for (Region &region : varRegions)
465       printer.printRegion(region);
466   }
467 }
printCustomDirectiveSuccessors(OpAsmPrinter & printer,Operation *,Block * successor,SuccessorRange varSuccessors)468 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
469                                            Block *successor,
470                                            SuccessorRange varSuccessors) {
471   printer << successor;
472   if (!varSuccessors.empty())
473     printer << ", " << varSuccessors.front();
474 }
printCustomDirectiveAttributes(OpAsmPrinter & printer,Operation *,Attribute attribute,Attribute optAttribute)475 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
476                                            Attribute attribute,
477                                            Attribute optAttribute) {
478   printer << attribute;
479   if (optAttribute)
480     printer << ", " << optAttribute;
481 }
482 
printCustomDirectiveAttrDict(OpAsmPrinter & printer,Operation * op,MutableDictionaryAttr attrs)483 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
484                                          MutableDictionaryAttr attrs) {
485   printer.printOptionalAttrDict(attrs.getAttrs());
486 }
487 //===----------------------------------------------------------------------===//
488 // Test IsolatedRegionOp - parse passthrough region arguments.
489 //===----------------------------------------------------------------------===//
490 
parseIsolatedRegionOp(OpAsmParser & parser,OperationState & result)491 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
492                                          OperationState &result) {
493   OpAsmParser::OperandType argInfo;
494   Type argType = parser.getBuilder().getIndexType();
495 
496   // Parse the input operand.
497   if (parser.parseOperand(argInfo) ||
498       parser.resolveOperand(argInfo, argType, result.operands))
499     return failure();
500 
501   // Parse the body region, and reuse the operand info as the argument info.
502   Region *body = result.addRegion();
503   return parser.parseRegion(*body, argInfo, argType,
504                             /*enableNameShadowing=*/true);
505 }
506 
print(OpAsmPrinter & p,IsolatedRegionOp op)507 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
508   p << "test.isolated_region ";
509   p.printOperand(op.getOperand());
510   p.shadowRegionArgs(op.region(), op.getOperand());
511   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
512 }
513 
514 //===----------------------------------------------------------------------===//
515 // Test SSACFGRegionOp
516 //===----------------------------------------------------------------------===//
517 
getRegionKind(unsigned index)518 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
519   return RegionKind::SSACFG;
520 }
521 
522 //===----------------------------------------------------------------------===//
523 // Test GraphRegionOp
524 //===----------------------------------------------------------------------===//
525 
parseGraphRegionOp(OpAsmParser & parser,OperationState & result)526 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
527                                       OperationState &result) {
528   // Parse the body region, and reuse the operand info as the argument info.
529   Region *body = result.addRegion();
530   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
531 }
532 
print(OpAsmPrinter & p,GraphRegionOp op)533 static void print(OpAsmPrinter &p, GraphRegionOp op) {
534   p << "test.graph_region ";
535   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
536 }
537 
getRegionKind(unsigned index)538 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
539   return RegionKind::Graph;
540 }
541 
542 //===----------------------------------------------------------------------===//
543 // Test AffineScopeOp
544 //===----------------------------------------------------------------------===//
545 
parseAffineScopeOp(OpAsmParser & parser,OperationState & result)546 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
547                                       OperationState &result) {
548   // Parse the body region, and reuse the operand info as the argument info.
549   Region *body = result.addRegion();
550   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
551 }
552 
print(OpAsmPrinter & p,AffineScopeOp op)553 static void print(OpAsmPrinter &p, AffineScopeOp op) {
554   p << "test.affine_scope ";
555   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
556 }
557 
558 //===----------------------------------------------------------------------===//
559 // Test parser.
560 //===----------------------------------------------------------------------===//
561 
parseWrappedKeywordOp(OpAsmParser & parser,OperationState & result)562 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser,
563                                          OperationState &result) {
564   StringRef keyword;
565   if (parser.parseKeyword(&keyword))
566     return failure();
567   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
568   return success();
569 }
570 
print(OpAsmPrinter & p,WrappedKeywordOp op)571 static void print(OpAsmPrinter &p, WrappedKeywordOp op) {
572   p << WrappedKeywordOp::getOperationName() << " " << op.keyword();
573 }
574 
575 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
577 
parseWrappingRegionOp(OpAsmParser & parser,OperationState & result)578 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
579                                          OperationState &result) {
580   if (parser.parseKeyword("wraps"))
581     return failure();
582 
583   // Parse the wrapped op in a region
584   Region &body = *result.addRegion();
585   body.push_back(new Block);
586   Block &block = body.back();
587   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
588   if (!wrapped_op)
589     return failure();
590 
591   // Create a return terminator in the inner region, pass as operand to the
592   // terminator the returned values from the wrapped operation.
593   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
594   OpBuilder builder(parser.getBuilder().getContext());
595   builder.setInsertionPointToEnd(&block);
596   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
597 
598   // Get the results type for the wrapping op from the terminator operands.
599   Operation &return_op = body.back().back();
600   result.types.append(return_op.operand_type_begin(),
601                       return_op.operand_type_end());
602 
603   // Use the location of the wrapped op for the "test.wrapping_region" op.
604   result.location = wrapped_op->getLoc();
605 
606   return success();
607 }
608 
print(OpAsmPrinter & p,WrappingRegionOp op)609 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
610   p << op.getOperationName() << " wraps ";
611   p.printGenericOp(&op.region().front().front());
612 }
613 
614 //===----------------------------------------------------------------------===//
615 // Test PolyForOp - parse list of region arguments.
616 //===----------------------------------------------------------------------===//
617 
parsePolyForOp(OpAsmParser & parser,OperationState & result)618 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
619   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
620   // Parse list of region arguments without a delimiter.
621   if (parser.parseRegionArgumentList(ivsInfo))
622     return failure();
623 
624   // Parse the body region.
625   Region *body = result.addRegion();
626   auto &builder = parser.getBuilder();
627   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
628   return parser.parseRegion(*body, ivsInfo, argTypes);
629 }
630 
631 //===----------------------------------------------------------------------===//
632 // Test removing op with inner ops.
633 //===----------------------------------------------------------------------===//
634 
635 namespace {
636 struct TestRemoveOpWithInnerOps
637     : public OpRewritePattern<TestOpWithRegionPattern> {
638   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
639 
matchAndRewrite__anon9b4000320311::TestRemoveOpWithInnerOps640   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
641                                 PatternRewriter &rewriter) const override {
642     rewriter.eraseOp(op);
643     return success();
644   }
645 };
646 } // end anonymous namespace
647 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)648 void TestOpWithRegionPattern::getCanonicalizationPatterns(
649     OwningRewritePatternList &results, MLIRContext *context) {
650   results.insert<TestRemoveOpWithInnerOps>(context);
651 }
652 
fold(ArrayRef<Attribute> operands)653 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
654   return operand();
655 }
656 
fold(ArrayRef<Attribute> operands)657 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
658   return getValue();
659 }
660 
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)661 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
662     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
663   for (Value input : this->operands()) {
664     results.push_back(input);
665   }
666   return success();
667 }
668 
fold(ArrayRef<Attribute> operands)669 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
670   assert(operands.size() == 1);
671   if (operands.front()) {
672     setAttr("attr", operands.front());
673     return getResult();
674   }
675   return {};
676 }
677 
inferReturnTypes(MLIRContext *,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)678 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
679     MLIRContext *, Optional<Location> location, ValueRange operands,
680     DictionaryAttr attributes, RegionRange regions,
681     SmallVectorImpl<Type> &inferredReturnTypes) {
682   if (operands[0].getType() != operands[1].getType()) {
683     return emitOptionalError(location, "operand type mismatch ",
684                              operands[0].getType(), " vs ",
685                              operands[1].getType());
686   }
687   inferredReturnTypes.assign({operands[0].getType()});
688   return success();
689 }
690 
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)691 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
692     MLIRContext *context, Optional<Location> location, ValueRange operands,
693     DictionaryAttr attributes, RegionRange regions,
694     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
695   // Create return type consisting of the last element of the first operand.
696   auto operandType = *operands.getTypes().begin();
697   auto sval = operandType.dyn_cast<ShapedType>();
698   if (!sval) {
699     return emitOptionalError(location, "only shaped type operands allowed");
700   }
701   int64_t dim =
702       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
703   auto type = IntegerType::get(17, context);
704   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
705   return success();
706 }
707 
reifyReturnTypeShapes(OpBuilder & builder,llvm::SmallVectorImpl<Value> & shapes)708 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
709     OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
710   shapes = SmallVector<Value, 1>{
711       builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)};
712   return success();
713 }
714 
715 //===----------------------------------------------------------------------===//
716 // Test SideEffect interfaces
717 //===----------------------------------------------------------------------===//
718 
719 namespace {
720 /// A test resource for side effects.
721 struct TestResource : public SideEffects::Resource::Base<TestResource> {
getName__anon9b4000320411::TestResource722   StringRef getName() final { return "<Test>"; }
723 };
724 } // end anonymous namespace
725 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)726 void SideEffectOp::getEffects(
727     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
728   // Check for an effects attribute on the op instance.
729   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
730   if (!effectsAttr)
731     return;
732 
733   // If there is one, it is an array of dictionary attributes that hold
734   // information on the effects of this operation.
735   for (Attribute element : effectsAttr) {
736     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
737 
738     // Get the specific memory effect.
739     MemoryEffects::Effect *effect =
740         StringSwitch<MemoryEffects::Effect *>(
741             effectElement.get("effect").cast<StringAttr>().getValue())
742             .Case("allocate", MemoryEffects::Allocate::get())
743             .Case("free", MemoryEffects::Free::get())
744             .Case("read", MemoryEffects::Read::get())
745             .Case("write", MemoryEffects::Write::get());
746 
747     // Check for a non-default resource to use.
748     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
749     if (effectElement.get("test_resource"))
750       resource = TestResource::get();
751 
752     // Check for a result to affect.
753     if (effectElement.get("on_result"))
754       effects.emplace_back(effect, getResult(), resource);
755     else if (Attribute ref = effectElement.get("on_reference"))
756       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
757     else
758       effects.emplace_back(effect, resource);
759   }
760 }
761 
getEffects(SmallVectorImpl<TestEffects::EffectInstance> & effects)762 void SideEffectOp::getEffects(
763     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
764   auto effectsAttr = (*this)->getAttrOfType<AffineMapAttr>("effect_parameter");
765   if (!effectsAttr)
766     return;
767 
768   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
769 }
770 
771 //===----------------------------------------------------------------------===//
772 // StringAttrPrettyNameOp
773 //===----------------------------------------------------------------------===//
774 
775 // This op has fancy handling of its SSA result name.
parseStringAttrPrettyNameOp(OpAsmParser & parser,OperationState & result)776 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
777                                                OperationState &result) {
778   // Add the result types.
779   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
780     result.addTypes(parser.getBuilder().getIntegerType(32));
781 
782   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
783     return failure();
784 
785   // If the attribute dictionary contains no 'names' attribute, infer it from
786   // the SSA name (if specified).
787   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
788     return attr.first == "names";
789   });
790 
791   // If there was no name specified, check to see if there was a useful name
792   // specified in the asm file.
793   if (hadNames || parser.getNumResults() == 0)
794     return success();
795 
796   SmallVector<StringRef, 4> names;
797   auto *context = result.getContext();
798 
799   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
800     auto resultName = parser.getResultName(i);
801     StringRef nameStr;
802     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
803       nameStr = resultName.first;
804 
805     names.push_back(nameStr);
806   }
807 
808   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
809   result.attributes.push_back({Identifier::get("names", context), namesAttr});
810   return success();
811 }
812 
print(OpAsmPrinter & p,StringAttrPrettyNameOp op)813 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
814   p << "test.string_attr_pretty_name";
815 
816   // Note that we only need to print the "name" attribute if the asmprinter
817   // result name disagrees with it.  This can happen in strange cases, e.g.
818   // when there are conflicts.
819   bool namesDisagree = op.names().size() != op.getNumResults();
820 
821   SmallString<32> resultNameStr;
822   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
823     resultNameStr.clear();
824     llvm::raw_svector_ostream tmpStream(resultNameStr);
825     p.printOperand(op.getResult(i), tmpStream);
826 
827     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
828     if (!expectedName ||
829         tmpStream.str().drop_front() != expectedName.getValue()) {
830       namesDisagree = true;
831     }
832   }
833 
834   if (namesDisagree)
835     p.printOptionalAttrDictWithKeyword(op.getAttrs());
836   else
837     p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"});
838 }
839 
840 // We set the SSA name in the asm syntax to the contents of the name
841 // attribute.
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)842 void StringAttrPrettyNameOp::getAsmResultNames(
843     function_ref<void(Value, StringRef)> setNameFn) {
844 
845   auto value = names();
846   for (size_t i = 0, e = value.size(); i != e; ++i)
847     if (auto str = value[i].dyn_cast<StringAttr>())
848       if (!str.getValue().empty())
849         setNameFn(getResult(i), str.getValue());
850 }
851 
852 //===----------------------------------------------------------------------===//
853 // RegionIfOp
854 //===----------------------------------------------------------------------===//
855 
print(OpAsmPrinter & p,RegionIfOp op)856 static void print(OpAsmPrinter &p, RegionIfOp op) {
857   p << RegionIfOp::getOperationName() << " ";
858   p.printOperands(op.getOperands());
859   p << ": " << op.getOperandTypes();
860   p.printArrowTypeList(op.getResultTypes());
861   p << " then";
862   p.printRegion(op.thenRegion(),
863                 /*printEntryBlockArgs=*/true,
864                 /*printBlockTerminators=*/true);
865   p << " else";
866   p.printRegion(op.elseRegion(),
867                 /*printEntryBlockArgs=*/true,
868                 /*printBlockTerminators=*/true);
869   p << " join";
870   p.printRegion(op.joinRegion(),
871                 /*printEntryBlockArgs=*/true,
872                 /*printBlockTerminators=*/true);
873 }
874 
parseRegionIfOp(OpAsmParser & parser,OperationState & result)875 static ParseResult parseRegionIfOp(OpAsmParser &parser,
876                                    OperationState &result) {
877   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
878   SmallVector<Type, 2> operandTypes;
879 
880   result.regions.reserve(3);
881   Region *thenRegion = result.addRegion();
882   Region *elseRegion = result.addRegion();
883   Region *joinRegion = result.addRegion();
884 
885   // Parse operand, type and arrow type lists.
886   if (parser.parseOperandList(operandInfos) ||
887       parser.parseColonTypeList(operandTypes) ||
888       parser.parseArrowTypeList(result.types))
889     return failure();
890 
891   // Parse all attached regions.
892   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
893       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
894       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
895     return failure();
896 
897   return parser.resolveOperands(operandInfos, operandTypes,
898                                 parser.getCurrentLocation(), result.operands);
899 }
900 
getSuccessorEntryOperands(unsigned index)901 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
902   assert(index < 2 && "invalid region index");
903   return getOperands();
904 }
905 
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)906 void RegionIfOp::getSuccessorRegions(
907     Optional<unsigned> index, ArrayRef<Attribute> operands,
908     SmallVectorImpl<RegionSuccessor> &regions) {
909   // We always branch to the join region.
910   if (index.hasValue()) {
911     if (index.getValue() < 2)
912       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
913     else
914       regions.push_back(RegionSuccessor(getResults()));
915     return;
916   }
917 
918   // The then and else regions are the entry regions of this op.
919   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
920   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
921 }
922 
923 #include "TestOpEnums.cpp.inc"
924 #include "TestOpInterfaces.cpp.inc"
925 #include "TestOpStructs.cpp.inc"
926 #include "TestTypeInterfaces.cpp.inc"
927 
928 #define GET_OP_CLASSES
929 #include "TestOps.cpp.inc"
930