1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"

#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser.h"

#include <cstdlib>
#include <cstring>
22 
23 using namespace mlir;
24 
25 //===----------------------------------------------------------------------===//
26 // Context API.
27 //===----------------------------------------------------------------------===//
28 
mlirContextCreate()29 MlirContext mlirContextCreate() {
30   auto *context = new MLIRContext;
31   return wrap(context);
32 }
33 
mlirContextEqual(MlirContext ctx1,MlirContext ctx2)34 bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
35   return unwrap(ctx1) == unwrap(ctx2);
36 }
37 
mlirContextDestroy(MlirContext context)38 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
39 
mlirContextSetAllowUnregisteredDialects(MlirContext context,bool allow)40 void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) {
41   unwrap(context)->allowUnregisteredDialects(allow);
42 }
43 
mlirContextGetAllowUnregisteredDialects(MlirContext context)44 bool mlirContextGetAllowUnregisteredDialects(MlirContext context) {
45   return unwrap(context)->allowsUnregisteredDialects();
46 }
mlirContextGetNumRegisteredDialects(MlirContext context)47 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
48   return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
49 }
50 
51 // TODO: expose a cheaper way than constructing + sorting a vector only to take
52 // its size.
mlirContextGetNumLoadedDialects(MlirContext context)53 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
54   return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size());
55 }
56 
mlirContextGetOrLoadDialect(MlirContext context,MlirStringRef name)57 MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
58                                         MlirStringRef name) {
59   return wrap(unwrap(context)->getOrLoadDialect(unwrap(name)));
60 }
61 
62 //===----------------------------------------------------------------------===//
63 // Dialect API.
64 //===----------------------------------------------------------------------===//
65 
mlirDialectGetContext(MlirDialect dialect)66 MlirContext mlirDialectGetContext(MlirDialect dialect) {
67   return wrap(unwrap(dialect)->getContext());
68 }
69 
mlirDialectEqual(MlirDialect dialect1,MlirDialect dialect2)70 bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
71   return unwrap(dialect1) == unwrap(dialect2);
72 }
73 
mlirDialectGetNamespace(MlirDialect dialect)74 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
75   return wrap(unwrap(dialect)->getNamespace());
76 }
77 
78 //===----------------------------------------------------------------------===//
79 // Printing flags API.
80 //===----------------------------------------------------------------------===//
81 
mlirOpPrintingFlagsCreate()82 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() {
83   return wrap(new OpPrintingFlags());
84 }
85 
mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)86 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) {
87   delete unwrap(flags);
88 }
89 
mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,intptr_t largeElementLimit)90 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,
91                                                 intptr_t largeElementLimit) {
92   unwrap(flags)->elideLargeElementsAttrs(largeElementLimit);
93 }
94 
mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags,bool prettyForm)95 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags,
96                                         bool prettyForm) {
97   unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm);
98 }
99 
mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)100 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) {
101   unwrap(flags)->printGenericOpForm();
102 }
103 
mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags)104 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
105   unwrap(flags)->useLocalScope();
106 }
107 
108 //===----------------------------------------------------------------------===//
109 // Location API.
110 //===----------------------------------------------------------------------===//
111 
mlirLocationFileLineColGet(MlirContext context,MlirStringRef filename,unsigned line,unsigned col)112 MlirLocation mlirLocationFileLineColGet(MlirContext context,
113                                         MlirStringRef filename, unsigned line,
114                                         unsigned col) {
115   return wrap(
116       FileLineColLoc::get(unwrap(filename), line, col, unwrap(context)));
117 }
118 
mlirLocationUnknownGet(MlirContext context)119 MlirLocation mlirLocationUnknownGet(MlirContext context) {
120   return wrap(UnknownLoc::get(unwrap(context)));
121 }
122 
mlirLocationEqual(MlirLocation l1,MlirLocation l2)123 bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) {
124   return unwrap(l1) == unwrap(l2);
125 }
126 
mlirLocationGetContext(MlirLocation location)127 MlirContext mlirLocationGetContext(MlirLocation location) {
128   return wrap(unwrap(location).getContext());
129 }
130 
mlirLocationPrint(MlirLocation location,MlirStringCallback callback,void * userData)131 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
132                        void *userData) {
133   detail::CallbackOstream stream(callback, userData);
134   unwrap(location).print(stream);
135 }
136 
137 //===----------------------------------------------------------------------===//
138 // Module API.
139 //===----------------------------------------------------------------------===//
140 
mlirModuleCreateEmpty(MlirLocation location)141 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
142   return wrap(ModuleOp::create(unwrap(location)));
143 }
144 
mlirModuleCreateParse(MlirContext context,MlirStringRef module)145 MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) {
146   OwningModuleRef owning = parseSourceString(unwrap(module), unwrap(context));
147   if (!owning)
148     return MlirModule{nullptr};
149   return MlirModule{owning.release().getOperation()};
150 }
151 
mlirModuleGetContext(MlirModule module)152 MlirContext mlirModuleGetContext(MlirModule module) {
153   return wrap(unwrap(module).getContext());
154 }
155 
mlirModuleGetBody(MlirModule module)156 MlirBlock mlirModuleGetBody(MlirModule module) {
157   return wrap(unwrap(module).getBody());
158 }
159 
mlirModuleDestroy(MlirModule module)160 void mlirModuleDestroy(MlirModule module) {
161   // Transfer ownership to an OwningModuleRef so that its destructor is called.
162   OwningModuleRef(unwrap(module));
163 }
164 
mlirModuleGetOperation(MlirModule module)165 MlirOperation mlirModuleGetOperation(MlirModule module) {
166   return wrap(unwrap(module).getOperation());
167 }
168 
169 //===----------------------------------------------------------------------===//
170 // Operation state API.
171 //===----------------------------------------------------------------------===//
172 
mlirOperationStateGet(MlirStringRef name,MlirLocation loc)173 MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) {
174   MlirOperationState state;
175   state.name = name;
176   state.location = loc;
177   state.nResults = 0;
178   state.results = nullptr;
179   state.nOperands = 0;
180   state.operands = nullptr;
181   state.nRegions = 0;
182   state.regions = nullptr;
183   state.nSuccessors = 0;
184   state.successors = nullptr;
185   state.nAttributes = 0;
186   state.attributes = nullptr;
187   return state;
188 }
189 
/// Appends the `n` elements of the array `elemName` to the heap array
/// `state->elemName`, growing it with `realloc`, and increases the element
/// count `state->sizeName` by `n`. The expansion site must have `state` and
/// `n` in scope. A non-positive `n` is a no-op (this also avoids calling
/// memcpy with a null source). If the reallocation fails we abort rather than
/// leak the old buffer and write through a null pointer. The do/while(false)
/// wrapper makes the expansion behave as a single statement.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  do {                                                                         \
    if (n <= 0)                                                                \
      break;                                                                   \
    type *grown = static_cast<type *>(                                         \
        realloc(state->elemName, (state->sizeName + n) * sizeof(type)));       \
    if (!grown)                                                                \
      abort();                                                                 \
    memcpy(grown + state->sizeName, elemName, n * sizeof(type));               \
    state->elemName = grown;                                                   \
    state->sizeName += n;                                                      \
  } while (false)
195 
// The mlirOperationStateAdd* functions append copies of the given handles to
// the matching heap array inside `state` via APPEND_ELEMS. That macro refers
// to the parameters by the exact names `state` and `n`, so those names must be
// kept. Only the handles are copied; the regions passed to
// mlirOperationStateAddOwnedRegions are later adopted by mlirOperationCreate.

/// Appends `n` result types from `results` to `state`.
void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
                                  MlirType const *results) {
  APPEND_ELEMS(MlirType, nResults, results);
}

/// Appends `n` operand values from `operands` to `state`.
void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
                                   MlirValue const *operands) {
  APPEND_ELEMS(MlirValue, nOperands, operands);
}

/// Appends `n` regions from `regions` to `state`; ownership of the regions is
/// transferred to the operation created by mlirOperationCreate.
void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
                                       MlirRegion const *regions) {
  APPEND_ELEMS(MlirRegion, nRegions, regions);
}

/// Appends `n` successor blocks from `successors` to `state`.
void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
                                     MlirBlock const *successors) {
  APPEND_ELEMS(MlirBlock, nSuccessors, successors);
}

/// Appends `n` named attributes from `attributes` to `state`.
void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
                                     MlirNamedAttribute const *attributes) {
  APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
}
217 
218 //===----------------------------------------------------------------------===//
219 // Operation API.
220 //===----------------------------------------------------------------------===//
221 
mlirOperationCreate(const MlirOperationState * state)222 MlirOperation mlirOperationCreate(const MlirOperationState *state) {
223   assert(state);
224   OperationState cppState(unwrap(state->location), unwrap(state->name));
225   SmallVector<Type, 4> resultStorage;
226   SmallVector<Value, 8> operandStorage;
227   SmallVector<Block *, 2> successorStorage;
228   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
229   cppState.addOperands(
230       unwrapList(state->nOperands, state->operands, operandStorage));
231   cppState.addSuccessors(
232       unwrapList(state->nSuccessors, state->successors, successorStorage));
233 
234   cppState.attributes.reserve(state->nAttributes);
235   for (intptr_t i = 0; i < state->nAttributes; ++i)
236     cppState.addAttribute(unwrap(state->attributes[i].name),
237                           unwrap(state->attributes[i].attribute));
238 
239   for (intptr_t i = 0; i < state->nRegions; ++i)
240     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
241 
242   MlirOperation result = wrap(Operation::create(cppState));
243   free(state->results);
244   free(state->operands);
245   free(state->successors);
246   free(state->regions);
247   free(state->attributes);
248   return result;
249 }
250 
mlirOperationDestroy(MlirOperation op)251 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
252 
mlirOperationEqual(MlirOperation op,MlirOperation other)253 bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
254   return unwrap(op) == unwrap(other);
255 }
256 
mlirOperationGetName(MlirOperation op)257 MlirIdentifier mlirOperationGetName(MlirOperation op) {
258   return wrap(unwrap(op)->getName().getIdentifier());
259 }
260 
mlirOperationGetBlock(MlirOperation op)261 MlirBlock mlirOperationGetBlock(MlirOperation op) {
262   return wrap(unwrap(op)->getBlock());
263 }
264 
mlirOperationGetParentOperation(MlirOperation op)265 MlirOperation mlirOperationGetParentOperation(MlirOperation op) {
266   return wrap(unwrap(op)->getParentOp());
267 }
268 
mlirOperationGetNumRegions(MlirOperation op)269 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
270   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
271 }
272 
mlirOperationGetRegion(MlirOperation op,intptr_t pos)273 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
274   return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
275 }
276 
mlirOperationGetNextInBlock(MlirOperation op)277 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
278   return wrap(unwrap(op)->getNextNode());
279 }
280 
mlirOperationGetNumOperands(MlirOperation op)281 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
282   return static_cast<intptr_t>(unwrap(op)->getNumOperands());
283 }
284 
mlirOperationGetOperand(MlirOperation op,intptr_t pos)285 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
286   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
287 }
288 
mlirOperationGetNumResults(MlirOperation op)289 intptr_t mlirOperationGetNumResults(MlirOperation op) {
290   return static_cast<intptr_t>(unwrap(op)->getNumResults());
291 }
292 
mlirOperationGetResult(MlirOperation op,intptr_t pos)293 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
294   return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
295 }
296 
mlirOperationGetNumSuccessors(MlirOperation op)297 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
298   return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
299 }
300 
mlirOperationGetSuccessor(MlirOperation op,intptr_t pos)301 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
302   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
303 }
304 
mlirOperationGetNumAttributes(MlirOperation op)305 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
306   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
307 }
308 
mlirOperationGetAttribute(MlirOperation op,intptr_t pos)309 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
310   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
311   return MlirNamedAttribute{wrap(attr.first.strref()), wrap(attr.second)};
312 }
313 
mlirOperationGetAttributeByName(MlirOperation op,MlirStringRef name)314 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
315                                               MlirStringRef name) {
316   return wrap(unwrap(op)->getAttr(unwrap(name)));
317 }
318 
mlirOperationSetAttributeByName(MlirOperation op,MlirStringRef name,MlirAttribute attr)319 void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name,
320                                      MlirAttribute attr) {
321   unwrap(op)->setAttr(unwrap(name), unwrap(attr));
322 }
323 
mlirOperationRemoveAttributeByName(MlirOperation op,MlirStringRef name)324 bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) {
325   auto removeResult = unwrap(op)->removeAttr(unwrap(name));
326   return removeResult == MutableDictionaryAttr::RemoveResult::Removed;
327 }
328 
mlirOperationPrint(MlirOperation op,MlirStringCallback callback,void * userData)329 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
330                         void *userData) {
331   detail::CallbackOstream stream(callback, userData);
332   unwrap(op)->print(stream);
333 }
334 
mlirOperationPrintWithFlags(MlirOperation op,MlirOpPrintingFlags flags,MlirStringCallback callback,void * userData)335 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
336                                  MlirStringCallback callback, void *userData) {
337   detail::CallbackOstream stream(callback, userData);
338   unwrap(op)->print(stream, *unwrap(flags));
339 }
340 
mlirOperationDump(MlirOperation op)341 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
342 
mlirOperationVerify(MlirOperation op)343 bool mlirOperationVerify(MlirOperation op) {
344   return succeeded(verify(unwrap(op)));
345 }
346 
347 //===----------------------------------------------------------------------===//
348 // Region API.
349 //===----------------------------------------------------------------------===//
350 
mlirRegionCreate()351 MlirRegion mlirRegionCreate() { return wrap(new Region); }
352 
mlirRegionGetFirstBlock(MlirRegion region)353 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
354   Region *cppRegion = unwrap(region);
355   if (cppRegion->empty())
356     return wrap(static_cast<Block *>(nullptr));
357   return wrap(&cppRegion->front());
358 }
359 
mlirRegionAppendOwnedBlock(MlirRegion region,MlirBlock block)360 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
361   unwrap(region)->push_back(unwrap(block));
362 }
363 
mlirRegionInsertOwnedBlock(MlirRegion region,intptr_t pos,MlirBlock block)364 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
365                                 MlirBlock block) {
366   auto &blockList = unwrap(region)->getBlocks();
367   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
368 }
369 
mlirRegionInsertOwnedBlockAfter(MlirRegion region,MlirBlock reference,MlirBlock block)370 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
371                                      MlirBlock block) {
372   Region *cppRegion = unwrap(region);
373   if (mlirBlockIsNull(reference)) {
374     cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block));
375     return;
376   }
377 
378   assert(unwrap(reference)->getParent() == unwrap(region) &&
379          "expected reference block to belong to the region");
380   cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)),
381                                      unwrap(block));
382 }
383 
mlirRegionInsertOwnedBlockBefore(MlirRegion region,MlirBlock reference,MlirBlock block)384 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
385                                       MlirBlock block) {
386   if (mlirBlockIsNull(reference))
387     return mlirRegionAppendOwnedBlock(region, block);
388 
389   assert(unwrap(reference)->getParent() == unwrap(region) &&
390          "expected reference block to belong to the region");
391   unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)),
392                                      unwrap(block));
393 }
394 
mlirRegionDestroy(MlirRegion region)395 void mlirRegionDestroy(MlirRegion region) {
396   delete static_cast<Region *>(region.ptr);
397 }
398 
399 //===----------------------------------------------------------------------===//
400 // Block API.
401 //===----------------------------------------------------------------------===//
402 
mlirBlockCreate(intptr_t nArgs,MlirType const * args)403 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args) {
404   Block *b = new Block;
405   for (intptr_t i = 0; i < nArgs; ++i)
406     b->addArgument(unwrap(args[i]));
407   return wrap(b);
408 }
409 
mlirBlockEqual(MlirBlock block,MlirBlock other)410 bool mlirBlockEqual(MlirBlock block, MlirBlock other) {
411   return unwrap(block) == unwrap(other);
412 }
413 
mlirBlockGetNextInRegion(MlirBlock block)414 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
415   return wrap(unwrap(block)->getNextNode());
416 }
417 
mlirBlockGetFirstOperation(MlirBlock block)418 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
419   Block *cppBlock = unwrap(block);
420   if (cppBlock->empty())
421     return wrap(static_cast<Operation *>(nullptr));
422   return wrap(&cppBlock->front());
423 }
424 
mlirBlockGetTerminator(MlirBlock block)425 MlirOperation mlirBlockGetTerminator(MlirBlock block) {
426   Block *cppBlock = unwrap(block);
427   if (cppBlock->empty())
428     return wrap(static_cast<Operation *>(nullptr));
429   Operation &back = cppBlock->back();
430   if (!back.isKnownTerminator())
431     return wrap(static_cast<Operation *>(nullptr));
432   return wrap(&back);
433 }
434 
mlirBlockAppendOwnedOperation(MlirBlock block,MlirOperation operation)435 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
436   unwrap(block)->push_back(unwrap(operation));
437 }
438 
mlirBlockInsertOwnedOperation(MlirBlock block,intptr_t pos,MlirOperation operation)439 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
440                                    MlirOperation operation) {
441   auto &opList = unwrap(block)->getOperations();
442   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
443 }
444 
mlirBlockInsertOwnedOperationAfter(MlirBlock block,MlirOperation reference,MlirOperation operation)445 void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
446                                         MlirOperation reference,
447                                         MlirOperation operation) {
448   Block *cppBlock = unwrap(block);
449   if (mlirOperationIsNull(reference)) {
450     cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation));
451     return;
452   }
453 
454   assert(unwrap(reference)->getBlock() == unwrap(block) &&
455          "expected reference operation to belong to the block");
456   cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)),
457                                         unwrap(operation));
458 }
459 
mlirBlockInsertOwnedOperationBefore(MlirBlock block,MlirOperation reference,MlirOperation operation)460 void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
461                                          MlirOperation reference,
462                                          MlirOperation operation) {
463   if (mlirOperationIsNull(reference))
464     return mlirBlockAppendOwnedOperation(block, operation);
465 
466   assert(unwrap(reference)->getBlock() == unwrap(block) &&
467          "expected reference operation to belong to the block");
468   unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)),
469                                         unwrap(operation));
470 }
471 
mlirBlockDestroy(MlirBlock block)472 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
473 
mlirBlockGetNumArguments(MlirBlock block)474 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
475   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
476 }
477 
mlirBlockGetArgument(MlirBlock block,intptr_t pos)478 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
479   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
480 }
481 
mlirBlockPrint(MlirBlock block,MlirStringCallback callback,void * userData)482 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
483                     void *userData) {
484   detail::CallbackOstream stream(callback, userData);
485   unwrap(block)->print(stream);
486 }
487 
488 //===----------------------------------------------------------------------===//
489 // Value API.
490 //===----------------------------------------------------------------------===//
491 
mlirValueEqual(MlirValue value1,MlirValue value2)492 bool mlirValueEqual(MlirValue value1, MlirValue value2) {
493   return unwrap(value1) == unwrap(value2);
494 }
495 
mlirValueIsABlockArgument(MlirValue value)496 bool mlirValueIsABlockArgument(MlirValue value) {
497   return unwrap(value).isa<BlockArgument>();
498 }
499 
mlirValueIsAOpResult(MlirValue value)500 bool mlirValueIsAOpResult(MlirValue value) {
501   return unwrap(value).isa<OpResult>();
502 }
503 
mlirBlockArgumentGetOwner(MlirValue value)504 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
505   return wrap(unwrap(value).cast<BlockArgument>().getOwner());
506 }
507 
mlirBlockArgumentGetArgNumber(MlirValue value)508 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
509   return static_cast<intptr_t>(
510       unwrap(value).cast<BlockArgument>().getArgNumber());
511 }
512 
mlirBlockArgumentSetType(MlirValue value,MlirType type)513 void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
514   unwrap(value).cast<BlockArgument>().setType(unwrap(type));
515 }
516 
mlirOpResultGetOwner(MlirValue value)517 MlirOperation mlirOpResultGetOwner(MlirValue value) {
518   return wrap(unwrap(value).cast<OpResult>().getOwner());
519 }
520 
mlirOpResultGetResultNumber(MlirValue value)521 intptr_t mlirOpResultGetResultNumber(MlirValue value) {
522   return static_cast<intptr_t>(
523       unwrap(value).cast<OpResult>().getResultNumber());
524 }
525 
mlirValueGetType(MlirValue value)526 MlirType mlirValueGetType(MlirValue value) {
527   return wrap(unwrap(value).getType());
528 }
529 
mlirValueDump(MlirValue value)530 void mlirValueDump(MlirValue value) { unwrap(value).dump(); }
531 
mlirValuePrint(MlirValue value,MlirStringCallback callback,void * userData)532 void mlirValuePrint(MlirValue value, MlirStringCallback callback,
533                     void *userData) {
534   detail::CallbackOstream stream(callback, userData);
535   unwrap(value).print(stream);
536 }
537 
538 //===----------------------------------------------------------------------===//
539 // Type API.
540 //===----------------------------------------------------------------------===//
541 
mlirTypeParseGet(MlirContext context,MlirStringRef type)542 MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) {
543   return wrap(mlir::parseType(unwrap(type), unwrap(context)));
544 }
545 
mlirTypeGetContext(MlirType type)546 MlirContext mlirTypeGetContext(MlirType type) {
547   return wrap(unwrap(type).getContext());
548 }
549 
mlirTypeEqual(MlirType t1,MlirType t2)550 bool mlirTypeEqual(MlirType t1, MlirType t2) {
551   return unwrap(t1) == unwrap(t2);
552 }
553 
mlirTypePrint(MlirType type,MlirStringCallback callback,void * userData)554 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
555   detail::CallbackOstream stream(callback, userData);
556   unwrap(type).print(stream);
557 }
558 
mlirTypeDump(MlirType type)559 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
560 
561 //===----------------------------------------------------------------------===//
562 // Attribute API.
563 //===----------------------------------------------------------------------===//
564 
mlirAttributeParseGet(MlirContext context,MlirStringRef attr)565 MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) {
566   return wrap(mlir::parseAttribute(unwrap(attr), unwrap(context)));
567 }
568 
mlirAttributeGetContext(MlirAttribute attribute)569 MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
570   return wrap(unwrap(attribute).getContext());
571 }
572 
mlirAttributeGetType(MlirAttribute attribute)573 MlirType mlirAttributeGetType(MlirAttribute attribute) {
574   return wrap(unwrap(attribute).getType());
575 }
576 
mlirAttributeEqual(MlirAttribute a1,MlirAttribute a2)577 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
578   return unwrap(a1) == unwrap(a2);
579 }
580 
mlirAttributePrint(MlirAttribute attr,MlirStringCallback callback,void * userData)581 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
582                         void *userData) {
583   detail::CallbackOstream stream(callback, userData);
584   unwrap(attr).print(stream);
585 }
586 
mlirAttributeDump(MlirAttribute attr)587 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
588 
mlirNamedAttributeGet(MlirStringRef name,MlirAttribute attr)589 MlirNamedAttribute mlirNamedAttributeGet(MlirStringRef name,
590                                          MlirAttribute attr) {
591   return MlirNamedAttribute{name, attr};
592 }
593 
594 //===----------------------------------------------------------------------===//
595 // Identifier API.
596 //===----------------------------------------------------------------------===//
597 
mlirIdentifierGet(MlirContext context,MlirStringRef str)598 MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) {
599   return wrap(Identifier::get(unwrap(str), unwrap(context)));
600 }
601 
mlirIdentifierEqual(MlirIdentifier ident,MlirIdentifier other)602 bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) {
603   return unwrap(ident) == unwrap(other);
604 }
605 
mlirIdentifierStr(MlirIdentifier ident)606 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
607   return wrap(unwrap(ident).strref());
608 }
609