1 //===- JitRunner.h - MLIR CPU Execution Driver Library ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This library provides a shared implementation for command-line utilities
// that execute an MLIR file on the CPU: the MLIR is translated to LLVM IR,
// which is then JIT-compiled and executed.
12 //
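// The translation can be customized (see JitRunnerConfig) by providing an
// MLIR-to-MLIR transformation, a custom MLIR-to-LLVM-IR module builder, and a
// callback that supplies additional symbols to register at runtime.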
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #ifndef MLIR_SUPPORT_JITRUNNER_H_
19 #define MLIR_SUPPORT_JITRUNNER_H_
20 
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ExecutionEngine/Orc/Core.h"
23 
24 namespace llvm {
25 class Module;
26 class LLVMContext;
27 
28 namespace orc {
29 class MangleAndInterner;
30 } // namespace orc
31 } // namespace llvm
32 
33 namespace mlir {
34 
35 class ModuleOp;
36 struct LogicalResult;
37 
38 struct JitRunnerConfig {
39   /// MLIR transformer applied after parsing the input into MLIR IR and before
40   /// passing the MLIR module to the ExecutionEngine.
41   llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer = nullptr;
42 
43   /// A custom function that is passed to ExecutionEngine. It processes MLIR
44   /// module and creates LLVM IR module.
45   llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
46                                                    llvm::LLVMContext &)>
47       llvmModuleBuilder = nullptr;
48 
49   /// A callback to register symbols with ExecutionEngine at runtime.
50   llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
51       runtimesymbolMap = nullptr;
52 };
53 
54 // Entry point for all CPU runners. Expects the common argc/argv arguments for
55 // standard C++ main functions.
56 int JitRunnerMain(int argc, char **argv, JitRunnerConfig config = {});
57 
58 } // namespace mlir
59 
60 #endif // MLIR_SUPPORT_JITRUNNER_H_
61