1 //===- pass.c - Simple test of C APIs -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9
10 /* RUN: mlir-capi-pass-test 2>&1 | FileCheck %s
11 */
12
13 #include "mlir-c/Pass.h"
14 #include "mlir-c/IR.h"
15 #include "mlir-c/Registration.h"
16 #include "mlir-c/Transforms.h"
17
18 #include <assert.h>
19 #include <math.h>
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <string.h>
23
testRunPassOnModule()24 void testRunPassOnModule() {
25 MlirContext ctx = mlirContextCreate();
26 mlirRegisterAllDialects(ctx);
27
28 MlirModule module = mlirModuleCreateParse(
29 ctx,
30 // clang-format off
31 mlirStringRefCreateFromCString(
32 "func @foo(%arg0 : i32) -> i32 { \n"
33 " %res = addi %arg0, %arg0 : i32 \n"
34 " return %res : i32 \n"
35 "}"));
36 // clang-format on
37 if (mlirModuleIsNull(module)) {
38 fprintf(stderr, "Unexpected failure parsing module.\n");
39 exit(EXIT_FAILURE);
40 }
41
42 // Run the print-op-stats pass on the top-level module:
43 // CHECK-LABEL: Operations encountered:
44 // CHECK: func , 1
45 // CHECK: module_terminator , 1
46 // CHECK: std.addi , 1
47 // CHECK: std.return , 1
48 {
49 MlirPassManager pm = mlirPassManagerCreate(ctx);
50 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
51 mlirPassManagerAddOwnedPass(pm, printOpStatPass);
52 MlirLogicalResult success = mlirPassManagerRun(pm, module);
53 if (mlirLogicalResultIsFailure(success)) {
54 fprintf(stderr, "Unexpected failure running pass manager.\n");
55 exit(EXIT_FAILURE);
56 }
57 mlirPassManagerDestroy(pm);
58 }
59 mlirModuleDestroy(module);
60 mlirContextDestroy(ctx);
61 }
62
testRunPassOnNestedModule()63 void testRunPassOnNestedModule() {
64 MlirContext ctx = mlirContextCreate();
65 mlirRegisterAllDialects(ctx);
66
67 MlirModule module = mlirModuleCreateParse(
68 ctx,
69 // clang-format off
70 mlirStringRefCreateFromCString(
71 "func @foo(%arg0 : i32) -> i32 { \n"
72 " %res = addi %arg0, %arg0 : i32 \n"
73 " return %res : i32 \n"
74 "} \n"
75 "module { \n"
76 " func @bar(%arg0 : f32) -> f32 { \n"
77 " %res = addf %arg0, %arg0 : f32 \n"
78 " return %res : f32 \n"
79 " } \n"
80 "}"));
81 // clang-format on
82 if (mlirModuleIsNull(module))
83 exit(1);
84
85 // Run the print-op-stats pass on functions under the top-level module:
86 // CHECK-LABEL: Operations encountered:
87 // CHECK-NOT: module_terminator
88 // CHECK: func , 1
89 // CHECK: std.addi , 1
90 // CHECK: std.return , 1
91 {
92 MlirPassManager pm = mlirPassManagerCreate(ctx);
93 MlirOpPassManager nestedFuncPm = mlirPassManagerGetNestedUnder(
94 pm, mlirStringRefCreateFromCString("func"));
95 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
96 mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
97 MlirLogicalResult success = mlirPassManagerRun(pm, module);
98 if (mlirLogicalResultIsFailure(success))
99 exit(2);
100 mlirPassManagerDestroy(pm);
101 }
102 // Run the print-op-stats pass on functions under the nested module:
103 // CHECK-LABEL: Operations encountered:
104 // CHECK-NOT: module_terminator
105 // CHECK: func , 1
106 // CHECK: std.addf , 1
107 // CHECK: std.return , 1
108 {
109 MlirPassManager pm = mlirPassManagerCreate(ctx);
110 MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
111 pm, mlirStringRefCreateFromCString("module"));
112 MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
113 nestedModulePm, mlirStringRefCreateFromCString("func"));
114 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
115 mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
116 MlirLogicalResult success = mlirPassManagerRun(pm, module);
117 if (mlirLogicalResultIsFailure(success))
118 exit(2);
119 mlirPassManagerDestroy(pm);
120 }
121
122 mlirModuleDestroy(module);
123 mlirContextDestroy(ctx);
124 }
125
// String callback passed to mlirPrintPassPipeline in this file: writes the
// `str.length` bytes starting at `str.data` to stderr. `userData` is ignored.
static void printToStderr(MlirStringRef str, void *userData) {
  fwrite(str.data, sizeof(char), str.length, stderr);
  (void)userData; // Unused; silences unused-parameter warnings.
}
130
testPrintPassPipeline()131 void testPrintPassPipeline() {
132 MlirContext ctx = mlirContextCreate();
133 MlirPassManager pm = mlirPassManagerCreate(ctx);
134 // Populate the pass-manager
135 MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
136 pm, mlirStringRefCreateFromCString("module"));
137 MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
138 nestedModulePm, mlirStringRefCreateFromCString("func"));
139 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
140 mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
141
142 // Print the top level pass manager
143 // CHECK: Top-level: module(func(print-op-stats))
144 fprintf(stderr, "Top-level: ");
145 mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
146 NULL);
147 fprintf(stderr, "\n");
148
149 // Print the pipeline nested one level down
150 // CHECK: Nested Module: func(print-op-stats)
151 fprintf(stderr, "Nested Module: ");
152 mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL);
153 fprintf(stderr, "\n");
154
155 // Print the pipeline nested two levels down
156 // CHECK: Nested Module>Func: print-op-stats
157 fprintf(stderr, "Nested Module>Func: ");
158 mlirPrintPassPipeline(nestedFuncPm, printToStderr, NULL);
159 fprintf(stderr, "\n");
160
161 mlirPassManagerDestroy(pm);
162 mlirContextDestroy(ctx);
163 }
164
testParsePassPipeline()165 void testParsePassPipeline() {
166 MlirContext ctx = mlirContextCreate();
167 MlirPassManager pm = mlirPassManagerCreate(ctx);
168 // Try parse a pipeline.
169 MlirLogicalResult status = mlirParsePassPipeline(
170 mlirPassManagerGetAsOpPassManager(pm),
171 mlirStringRefCreateFromCString(
172 "module(func(print-op-stats), func(print-op-stats))"));
173 // Expect a failure, we haven't registered the print-op-stats pass yet.
174 if (mlirLogicalResultIsSuccess(status)) {
175 fprintf(stderr, "Unexpected success parsing pipeline without registering the pass\n");
176 exit(EXIT_FAILURE);
177 }
178 // Try again after registrating the pass.
179 mlirRegisterTransformsPrintOpStats();
180 status = mlirParsePassPipeline(
181 mlirPassManagerGetAsOpPassManager(pm),
182 mlirStringRefCreateFromCString(
183 "module(func(print-op-stats), func(print-op-stats))"));
184 // Expect a failure, we haven't registered the print-op-stats pass yet.
185 if (mlirLogicalResultIsFailure(status)) {
186 fprintf(stderr, "Unexpected failure parsing pipeline after registering the pass\n");
187 exit(EXIT_FAILURE);
188 }
189
190 // CHECK: Round-trip: module(func(print-op-stats), func(print-op-stats))
191 fprintf(stderr, "Round-trip: ");
192 mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
193 NULL);
194 fprintf(stderr, "\n");
195 }
196
// Test driver: runs every test case in sequence. The call order matters
// because the tool's output is matched against the expectations embedded in
// this file, which appear in the same order as the calls below.
int main() {
  // Running passes on parsed modules.
  testRunPassOnModule();
  testRunPassOnNestedModule();

  // Printing and parsing textual pass pipelines.
  testPrintPassPipeline();
  testParsePassPipeline();

  return 0;
}
204