1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/python/mlir.h"
17
18 #include <string>
19
20 #include "llvm/Support/raw_ostream.h"
21 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
22 #include "mlir/InitAllPasses.h" // from @llvm-project
23 #include "mlir/Parser.h" // from @llvm-project
24 #include "mlir/Pass/PassManager.h" // from @llvm-project
25 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
26 #include "tensorflow/c/eager/c_api.h"
27 #include "tensorflow/c/eager/tfe_context_internal.h"
28 #include "tensorflow/c/tf_status.h"
29 #include "tensorflow/c/tf_status_helper.h"
30 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
31 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
33 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
34 #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
35 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
36 #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
37 #include "tensorflow/core/common_runtime/eager/context.h"
38 #include "tensorflow/core/common_runtime/function_body.h"
39 #include "tensorflow/core/common_runtime/function_def_utils.h"
40 #include "tensorflow/core/framework/function.h"
41 #include "tensorflow/core/framework/function.pb.h"
42 #include "tensorflow/core/framework/op.h"
43
44 namespace tensorflow {
45
46 namespace {
47
48 // Runs pass pipeline `pass_pipeline` on `module` if `pass_pipeline` is not
49 // empty.
RunPassPipelineOnModule(mlir::ModuleOp module,const std::string & pass_pipeline,bool show_debug_info,TF_Status * status)50 std::string RunPassPipelineOnModule(mlir::ModuleOp module,
51 const std::string &pass_pipeline,
52 bool show_debug_info, TF_Status *status) {
53 if (!pass_pipeline.empty()) {
54 mlir::PassManager pm(module.getContext());
55 std::string error;
56 llvm::raw_string_ostream error_stream(error);
57 if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
58 TF_SetStatus(status, TF_INVALID_ARGUMENT,
59 ("Invalid pass_pipeline: " + error_stream.str()).c_str());
60 return "// error";
61 }
62
63 mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext());
64 if (failed(pm.run(module))) {
65 Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus());
66 return "// error";
67 }
68 }
69 return MlirModuleToString(module, show_debug_info);
70 }
71
72 } // anonymous namespace
73
ImportGraphDef(const std::string & proto,const std::string & pass_pipeline,bool show_debug_info,TF_Status * status)74 std::string ImportGraphDef(const std::string &proto,
75 const std::string &pass_pipeline,
76 bool show_debug_info, TF_Status *status) {
77 GraphDef graphdef;
78 auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef);
79 if (!s.ok()) {
80 Set_TF_Status_from_Status(status, s);
81 return "// error";
82 }
83 GraphDebugInfo debug_info;
84 GraphImportConfig specs;
85 mlir::MLIRContext context;
86 auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
87 if (!module.ok()) {
88 Set_TF_Status_from_Status(status, module.status());
89 return "// error";
90 }
91
92 return RunPassPipelineOnModule(module->get(), pass_pipeline, show_debug_info,
93 status);
94 }
95
ImportFunction(const std::string & functiondef_proto,const std::string & pass_pipeline,bool show_debug_info,TFE_Context * tfe_context,TF_Status * status)96 std::string ImportFunction(const std::string &functiondef_proto,
97 const std::string &pass_pipeline,
98 bool show_debug_info, TFE_Context *tfe_context,
99 TF_Status *status) {
100 FunctionDef functiondef;
101 auto s = tensorflow::LoadProtoFromBuffer(functiondef_proto, &functiondef);
102 if (!s.ok()) {
103 Set_TF_Status_from_Status(status, s);
104 return "// error";
105 }
106
107 const std::string &function_name = functiondef.signature().name();
108 EagerContext *cpp_context = ContextFromInterface(unwrap(tfe_context));
109 FunctionLibraryDefinition &flib_def = *cpp_context->FuncLibDef();
110 const tensorflow::FunctionDef *fdef = flib_def.Find(function_name);
111 if (fdef == nullptr) {
112 s = tensorflow::errors::NotFound("Cannot find function ", function_name);
113 Set_TF_Status_from_Status(status, s);
114 return "// error";
115 }
116
117 std::unique_ptr<tensorflow::FunctionBody> fbody;
118 s = FunctionDefToBodyHelper(*fdef, tensorflow::AttrSlice(), &flib_def,
119 &fbody);
120 if (!s.ok()) {
121 Set_TF_Status_from_Status(status, s);
122 return "// error";
123 }
124
125 mlir::MLIRContext context;
126 auto module = ConvertFunctionToMlir(fbody.get(), flib_def, &context);
127 if (!module.ok()) {
128 Set_TF_Status_from_Status(status, module.status());
129 return "// error";
130 }
131
132 return RunPassPipelineOnModule(module->get(), pass_pipeline, show_debug_info,
133 status);
134 }
135
ExperimentalConvertSavedModelToMlir(const std::string & saved_model_path,const std::string & exported_names_str,bool show_debug_info,TF_Status * status)136 std::string ExperimentalConvertSavedModelToMlir(
137 const std::string &saved_model_path, const std::string &exported_names_str,
138 bool show_debug_info, TF_Status *status) {
139 // Load the saved model into a SavedModelV2Bundle.
140
141 tensorflow::SavedModelV2Bundle bundle;
142 auto load_status =
143 tensorflow::SavedModelV2Bundle::Load(saved_model_path, &bundle);
144 if (!load_status.ok()) {
145 Set_TF_Status_from_Status(status, load_status);
146 return "// error";
147 }
148
149 // Convert the SavedModelV2Bundle to an MLIR module.
150
151 std::vector<string> exported_names =
152 absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
153 mlir::MLIRContext context;
154 auto module_or = ConvertSavedModelToMlir(
155 &bundle, &context, absl::Span<std::string>(exported_names));
156 if (!module_or.status().ok()) {
157 Set_TF_Status_from_Status(status, module_or.status());
158 return "// error";
159 }
160
161 return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
162 }
163
ExperimentalConvertSavedModelV1ToMlirLite(const std::string & saved_model_path,const std::string & tags,bool upgrade_legacy,bool show_debug_info,TF_Status * status)164 std::string ExperimentalConvertSavedModelV1ToMlirLite(
165 const std::string &saved_model_path, const std::string &tags,
166 bool upgrade_legacy, bool show_debug_info, TF_Status *status) {
167 std::unordered_set<string> tag_set =
168 absl::StrSplit(tags, ',', absl::SkipEmpty());
169
170 mlir::MLIRContext context;
171
172 tensorflow::MLIRImportOptions import_options;
173 import_options.upgrade_legacy = upgrade_legacy;
174 auto module_or = SavedModelSignatureDefsToMlirImportLite(
175 saved_model_path, tag_set, /*exported_names=*/{}, &context,
176 import_options);
177 if (!module_or.status().ok()) {
178 Set_TF_Status_from_Status(status, module_or.status());
179 return "// error";
180 }
181
182 return MlirModuleToString(*module_or.ValueOrDie(), show_debug_info);
183 }
184
ExperimentalConvertSavedModelV1ToMlir(const std::string & saved_model_path,const std::string & tags,bool lift_variables,bool upgrade_legacy,bool show_debug_info,TF_Status * status)185 std::string ExperimentalConvertSavedModelV1ToMlir(
186 const std::string &saved_model_path, const std::string &tags,
187 bool lift_variables, bool upgrade_legacy, bool show_debug_info,
188 TF_Status *status) {
189 // Load the saved model into a SavedModelBundle.
190
191 std::unordered_set<string> tag_set =
192 absl::StrSplit(tags, ',', absl::SkipEmpty());
193
194 tensorflow::SavedModelBundle bundle;
195 auto load_status =
196 tensorflow::LoadSavedModel({}, {}, saved_model_path, tag_set, &bundle);
197 if (!load_status.ok()) {
198 Set_TF_Status_from_Status(status, load_status);
199 return "// error";
200 }
201
202 // Convert the SavedModelBundle to an MLIR module.
203
204 mlir::MLIRContext context;
205 tensorflow::MLIRImportOptions import_options;
206 import_options.upgrade_legacy = upgrade_legacy;
207 auto module_or =
208 ConvertSavedModelV1ToMlir(bundle, {}, &context, import_options);
209 if (!module_or.status().ok()) {
210 Set_TF_Status_from_Status(status, module_or.status());
211 return "// error";
212 }
213
214 // Run the tf standard pipeline by default and then, run passes that lift
215 // variables if the flag is set on the module.
216 mlir::OwningModuleRef module = module_or.ConsumeValueOrDie();
217 mlir::PassManager pm(&context);
218 std::string error;
219 llvm::raw_string_ostream error_stream(error);
220
221 mlir::TF::StandardPipelineOptions tf_options;
222 mlir::TF::CreateTFStandardPipeline(pm, tf_options);
223 if (lift_variables) {
224 pm.addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass());
225 pm.addPass(
226 mlir::tf_saved_model::CreateLiftVariablesPass(bundle.GetSession()));
227 }
228
229 mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
230 if (failed(pm.run(*module))) {
231 Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
232 return "// error";
233 }
234 return MlirModuleToString(*module, show_debug_info);
235 }
236
ExperimentalRunPassPipeline(const std::string & mlir_txt,const std::string & pass_pipeline,bool show_debug_info,TF_Status * status)237 std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
238 const std::string &pass_pipeline,
239 bool show_debug_info,
240 TF_Status *status) {
241 mlir::DialectRegistry registry;
242 mlir::RegisterAllTensorFlowDialects(registry);
243 mlir::MLIRContext context(registry);
244 mlir::OwningModuleRef module;
245 {
246 mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
247 module = mlir::parseSourceString(mlir_txt, &context);
248 if (!module) {
249 Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
250 return "// error";
251 }
252 }
253
254 // Run the pass_pipeline on the module.
255 mlir::PassManager pm(&context);
256 std::string error;
257 llvm::raw_string_ostream error_stream(error);
258 mlir::registerAllPasses();
259 if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
260 TF_SetStatus(status, TF_INVALID_ARGUMENT,
261 ("Invalid pass_pipeline: " + error_stream.str()).c_str());
262 return "// error";
263 }
264
265 mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
266 if (failed(pm.run(*module))) {
267 Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
268 return "// error";
269 }
270 return MlirModuleToString(*module, show_debug_info);
271 }
272
273 } // namespace tensorflow
274