1 //===- pass.c - Simple test of C APIs -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 /* RUN: mlir-capi-pass-test 2>&1 | FileCheck %s
11  */
12 
13 #include "mlir-c/Pass.h"
14 #include "mlir-c/IR.h"
15 #include "mlir-c/Registration.h"
16 #include "mlir-c/Transforms.h"
17 
18 #include <assert.h>
19 #include <math.h>
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <string.h>
23 
testRunPassOnModule()24 void testRunPassOnModule() {
25   MlirContext ctx = mlirContextCreate();
26   mlirRegisterAllDialects(ctx);
27 
28   MlirModule module = mlirModuleCreateParse(
29       ctx,
30       // clang-format off
31                             mlirStringRefCreateFromCString(
32 "func @foo(%arg0 : i32) -> i32 {                                            \n"
33 "  %res = addi %arg0, %arg0 : i32                                           \n"
34 "  return %res : i32                                                        \n"
35 "}"));
36   // clang-format on
37   if (mlirModuleIsNull(module)) {
38     fprintf(stderr, "Unexpected failure parsing module.\n");
39     exit(EXIT_FAILURE);
40   }
41 
42   // Run the print-op-stats pass on the top-level module:
43   // CHECK-LABEL: Operations encountered:
44   // CHECK: func              , 1
45   // CHECK: module_terminator , 1
46   // CHECK: std.addi          , 1
47   // CHECK: std.return        , 1
48   {
49     MlirPassManager pm = mlirPassManagerCreate(ctx);
50     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
51     mlirPassManagerAddOwnedPass(pm, printOpStatPass);
52     MlirLogicalResult success = mlirPassManagerRun(pm, module);
53     if (mlirLogicalResultIsFailure(success)) {
54       fprintf(stderr, "Unexpected failure running pass manager.\n");
55       exit(EXIT_FAILURE);
56     }
57     mlirPassManagerDestroy(pm);
58   }
59   mlirModuleDestroy(module);
60   mlirContextDestroy(ctx);
61 }
62 
testRunPassOnNestedModule()63 void testRunPassOnNestedModule() {
64   MlirContext ctx = mlirContextCreate();
65   mlirRegisterAllDialects(ctx);
66 
67   MlirModule module = mlirModuleCreateParse(
68       ctx,
69       // clang-format off
70                             mlirStringRefCreateFromCString(
71 "func @foo(%arg0 : i32) -> i32 {                                            \n"
72 "  %res = addi %arg0, %arg0 : i32                                           \n"
73 "  return %res : i32                                                        \n"
74 "}                                                                          \n"
75 "module {                                                                   \n"
76 "  func @bar(%arg0 : f32) -> f32 {                                          \n"
77 "    %res = addf %arg0, %arg0 : f32                                         \n"
78 "    return %res : f32                                                      \n"
79 "  }                                                                        \n"
80 "}"));
81   // clang-format on
82   if (mlirModuleIsNull(module))
83     exit(1);
84 
85   // Run the print-op-stats pass on functions under the top-level module:
86   // CHECK-LABEL: Operations encountered:
87   // CHECK-NOT: module_terminator
88   // CHECK: func              , 1
89   // CHECK: std.addi          , 1
90   // CHECK: std.return        , 1
91   {
92     MlirPassManager pm = mlirPassManagerCreate(ctx);
93     MlirOpPassManager nestedFuncPm = mlirPassManagerGetNestedUnder(
94         pm, mlirStringRefCreateFromCString("func"));
95     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
96     mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
97     MlirLogicalResult success = mlirPassManagerRun(pm, module);
98     if (mlirLogicalResultIsFailure(success))
99       exit(2);
100     mlirPassManagerDestroy(pm);
101   }
102   // Run the print-op-stats pass on functions under the nested module:
103   // CHECK-LABEL: Operations encountered:
104   // CHECK-NOT: module_terminator
105   // CHECK: func              , 1
106   // CHECK: std.addf          , 1
107   // CHECK: std.return        , 1
108   {
109     MlirPassManager pm = mlirPassManagerCreate(ctx);
110     MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
111         pm, mlirStringRefCreateFromCString("module"));
112     MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
113         nestedModulePm, mlirStringRefCreateFromCString("func"));
114     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
115     mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
116     MlirLogicalResult success = mlirPassManagerRun(pm, module);
117     if (mlirLogicalResultIsFailure(success))
118       exit(2);
119     mlirPassManagerDestroy(pm);
120   }
121 
122   mlirModuleDestroy(module);
123   mlirContextDestroy(ctx);
124 }
125 
/// Callback passed to mlirPrintPassPipeline: writes the `str.length` bytes
/// starting at `str.data` to stderr. `userData` is ignored.
static void printToStderr(MlirStringRef str, void *userData) {
  const char *bytes = str.data;
  size_t count = str.length;
  (void)userData; // Unused; present only to match the callback signature.
  fwrite(bytes, sizeof(char), count, stderr);
}
130 
testPrintPassPipeline()131 void testPrintPassPipeline() {
132   MlirContext ctx = mlirContextCreate();
133   MlirPassManager pm = mlirPassManagerCreate(ctx);
134   // Populate the pass-manager
135   MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
136       pm, mlirStringRefCreateFromCString("module"));
137   MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
138       nestedModulePm, mlirStringRefCreateFromCString("func"));
139   MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
140   mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
141 
142   // Print the top level pass manager
143   // CHECK: Top-level: module(func(print-op-stats))
144   fprintf(stderr, "Top-level: ");
145   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
146                         NULL);
147   fprintf(stderr, "\n");
148 
149   // Print the pipeline nested one level down
150   // CHECK: Nested Module: func(print-op-stats)
151   fprintf(stderr, "Nested Module: ");
152   mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL);
153   fprintf(stderr, "\n");
154 
155   // Print the pipeline nested two levels down
156   // CHECK: Nested Module>Func: print-op-stats
157   fprintf(stderr, "Nested Module>Func: ");
158   mlirPrintPassPipeline(nestedFuncPm, printToStderr, NULL);
159   fprintf(stderr, "\n");
160 
161   mlirPassManagerDestroy(pm);
162   mlirContextDestroy(ctx);
163 }
164 
testParsePassPipeline()165 void testParsePassPipeline() {
166   MlirContext ctx = mlirContextCreate();
167   MlirPassManager pm = mlirPassManagerCreate(ctx);
168   // Try parse a pipeline.
169   MlirLogicalResult status = mlirParsePassPipeline(
170       mlirPassManagerGetAsOpPassManager(pm),
171       mlirStringRefCreateFromCString(
172           "module(func(print-op-stats), func(print-op-stats))"));
173   // Expect a failure, we haven't registered the print-op-stats pass yet.
174   if (mlirLogicalResultIsSuccess(status)) {
175     fprintf(stderr, "Unexpected success parsing pipeline without registering the pass\n");
176     exit(EXIT_FAILURE);
177   }
178   // Try again after registrating the pass.
179   mlirRegisterTransformsPrintOpStats();
180   status = mlirParsePassPipeline(
181       mlirPassManagerGetAsOpPassManager(pm),
182       mlirStringRefCreateFromCString(
183           "module(func(print-op-stats), func(print-op-stats))"));
184   // Expect a failure, we haven't registered the print-op-stats pass yet.
185   if (mlirLogicalResultIsFailure(status)) {
186     fprintf(stderr, "Unexpected failure parsing pipeline after registering the pass\n");
187     exit(EXIT_FAILURE);
188   }
189 
190   // CHECK: Round-trip: module(func(print-op-stats), func(print-op-stats))
191   fprintf(stderr, "Round-trip: ");
192   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
193                         NULL);
194   fprintf(stderr, "\n");
195 }
196 
main()197 int main() {
198   testRunPassOnModule();
199   testRunPassOnNestedModule();
200   testPrintPassPipeline();
201   testParsePassPipeline();
202   return 0;
203 }
204