/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/python/mlir.h"

#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "absl/strings/str_split.h"
#include "absl/types/span.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/InitAllPasses.h"  // from @llvm-project
#include "mlir/Parser.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/common_runtime/function_def_utils.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

namespace {

// If `pass_pipeline` is non-empty, runs it on `module`. Returns the resulting
// module serialized to a string, or "// error" with `status` set on failure.
std::string RunPassPipelineOnModule(mlir::ModuleOp module,
                                    const std::string &pass_pipeline,
                                    bool show_debug_info, TF_Status *status) {
  if (!pass_pipeline.empty()) {
    mlir::PassManager pm(module.getContext());
    std::string error;
    llvm::raw_string_ostream error_stream(error);
    if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
      TF_SetStatus(status, TF_INVALID_ARGUMENT,
                   ("Invalid pass_pipeline: " + error_stream.str()).c_str());
      return "// error";
    }

    mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext());
    if (failed(pm.run(module))) {
      Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus());
      return "// error";
    }
  }
  return MlirModuleToString(module, show_debug_info);
}

}  // anonymous namespace

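// Imports a serialized GraphDef (`proto`) into an MLIR module, optionally
// runs `pass_pipeline` over it, and returns the module as a string. On
// failure, sets `status` and returns "// error".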
std::string ImportGraphDef(const std::string &proto,
                           const std::string &pass_pipeline,
                           bool show_debug_info, TF_Status *status) {
  GraphDef graphdef;
  auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef);
  if (!s.ok()) {
    Set_TF_Status_from_Status(status, s);
    return "// error";
  }
  GraphDebugInfo debug_info;
  GraphImportConfig specs;
  mlir::MLIRContext context;
  auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
  if (!module.ok()) {
    Set_TF_Status_from_Status(status, module.status());
    return "// error";
  }

  return RunPassPipelineOnModule(module->get(), pass_pipeline, show_debug_info,
                                 status);
}

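// Imports the function named in `functiondef_proto`'s signature from
// `tfe_context`'s function library into an MLIR module, optionally runs
// `pass_pipeline` over it, and returns the module as a string.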
std::string ImportFunction(const std::string &functiondef_proto,
                           const std::string &pass_pipeline,
                           bool show_debug_info, TFE_Context *tfe_context,
                           TF_Status *status) {
  FunctionDef functiondef;
  auto s = tensorflow::LoadProtoFromBuffer(functiondef_proto, &functiondef);
  if (!s.ok()) {
    Set_TF_Status_from_Status(status, s);
    return "// error";
  }

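  // Only the name is taken from the parsed proto; the definition actually
  // imported is the one registered under that name in the eager context's
  // function library.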
  const std::string &function_name = functiondef.signature().name();
  EagerContext *cpp_context = ContextFromInterface(unwrap(tfe_context));
  FunctionLibraryDefinition &flib_def = *cpp_context->FuncLibDef();
  const tensorflow::FunctionDef *fdef = flib_def.Find(function_name);
  if (fdef == nullptr) {
    s = tensorflow::errors::NotFound("Cannot find function ", function_name);
    Set_TF_Status_from_Status(status, s);
    return "// error";
  }

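  // Convert the FunctionDef to a FunctionBody (graph form) so it can be fed
  // through the graph-to-MLIR importer.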
  std::unique_ptr<tensorflow::FunctionBody> fbody;
  s = FunctionDefToBodyHelper(*fdef, tensorflow::AttrSlice(), &flib_def,
                              &fbody);
  if (!s.ok()) {
    Set_TF_Status_from_Status(status, s);
    return "// error";
  }

  mlir::MLIRContext context;
  auto module = ConvertFunctionToMlir(fbody.get(), flib_def, &context);
  if (!module.ok()) {
    Set_TF_Status_from_Status(status, module.status());
    return "// error";
  }

  return RunPassPipelineOnModule(module->get(), pass_pipeline, show_debug_info,
                                 status);
}

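// Loads a V2 SavedModel from `saved_model_path` and converts the objects
// named in the comma-separated `exported_names_str` into an MLIR module,
// returned as a string.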
std::string ExperimentalConvertSavedModelToMlir(
    const std::string &saved_model_path, const std::string &exported_names_str,
    bool show_debug_info, TF_Status *status) {
  // Load the saved model into a SavedModelV2Bundle.

  tensorflow::SavedModelV2Bundle bundle;
  auto load_status =
      tensorflow::SavedModelV2Bundle::Load(saved_model_path, &bundle);
  if (!load_status.ok()) {
    Set_TF_Status_from_Status(status, load_status);
    return "// error";
  }

  // Convert the SavedModelV2Bundle to an MLIR module.

  std::vector<string> exported_names =
      absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
  mlir::MLIRContext context;
  auto module_or = ConvertSavedModelToMlir(
      &bundle, &context, absl::Span<std::string>(exported_names));
  if (!module_or.status().ok()) {
    Set_TF_Status_from_Status(status, module_or.status());
    return "// error";
  }

  return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
}

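// Converts the SignatureDefs of a V1 SavedModel at `saved_model_path`,
// selecting the MetaGraph matching the comma-separated `tags`, into an MLIR
// module. The "lite" import works from the MetaGraphDef alone, without
// creating a session.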
std::string ExperimentalConvertSavedModelV1ToMlirLite(
    const std::string &saved_model_path, const std::string &tags,
    bool upgrade_legacy, bool show_debug_info, TF_Status *status) {
  std::unordered_set<string> tag_set =
      absl::StrSplit(tags, ',', absl::SkipEmpty());

  mlir::MLIRContext context;

  tensorflow::MLIRImportOptions import_options;
  import_options.upgrade_legacy = upgrade_legacy;
  auto module_or = SavedModelSignatureDefsToMlirImportLite(
      saved_model_path, tag_set, /*exported_names=*/{}, &context,
      import_options);
  if (!module_or.status().ok()) {
    Set_TF_Status_from_Status(status, module_or.status());
    return "// error";
  }

  return MlirModuleToString(*module_or.ValueOrDie(), show_debug_info);
}

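// Loads a V1 SavedModel with a session, converts its SignatureDefs into an
// MLIR module, runs the TF standard pipeline on it, and, if `lift_variables`
// is set, lifts the session's variables into the module.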
std::string ExperimentalConvertSavedModelV1ToMlir(
    const std::string &saved_model_path, const std::string &tags,
    bool lift_variables, bool upgrade_legacy, bool show_debug_info,
    TF_Status *status) {
  // Load the saved model into a SavedModelBundle.

  std::unordered_set<string> tag_set =
      absl::StrSplit(tags, ',', absl::SkipEmpty());

  tensorflow::SavedModelBundle bundle;
  auto load_status =
      tensorflow::LoadSavedModel({}, {}, saved_model_path, tag_set, &bundle);
  if (!load_status.ok()) {
    Set_TF_Status_from_Status(status, load_status);
    return "// error";
  }

  // Convert the SavedModelBundle to an MLIR module.

  mlir::MLIRContext context;
  tensorflow::MLIRImportOptions import_options;
  import_options.upgrade_legacy = upgrade_legacy;
  auto module_or =
      ConvertSavedModelV1ToMlir(bundle, {}, &context, import_options);
  if (!module_or.status().ok()) {
    Set_TF_Status_from_Status(status, module_or.status());
    return "// error";
  }

  // Run the TF standard pipeline by default, then, if `lift_variables` is
  // set, run the passes that lift variables on the module.
  mlir::OwningModuleRef module = module_or.ConsumeValueOrDie();
  mlir::PassManager pm(&context);

  mlir::TF::StandardPipelineOptions tf_options;
  mlir::TF::CreateTFStandardPipeline(pm, tf_options);
  if (lift_variables) {
    pm.addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass());
    pm.addPass(
        mlir::tf_saved_model::CreateLiftVariablesPass(bundle.GetSession()));
  }

  mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
  if (failed(pm.run(*module))) {
    Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
    return "// error";
  }
  return MlirModuleToString(*module, show_debug_info);
}

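// Parses `mlir_txt` into a module with all TensorFlow dialects registered,
// runs the textual `pass_pipeline` (e.g. a registered pass name such as
// "canonicalize") on it, and returns the result as a string.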
std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
                                        const std::string &pass_pipeline,
                                        bool show_debug_info,
                                        TF_Status *status) {
  mlir::DialectRegistry registry;
  mlir::RegisterAllTensorFlowDialects(registry);
  mlir::MLIRContext context(registry);
  mlir::OwningModuleRef module;
  {
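    // Parse inside a scoped diagnostic handler so parse errors are captured
    // in `status` instead of being printed to stderr.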
    mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
    module = mlir::parseSourceString(mlir_txt, &context);
    if (!module) {
      Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
      return "// error";
    }
  }

  // Run the pass_pipeline on the module.
  mlir::PassManager pm(&context);
  std::string error;
  llvm::raw_string_ostream error_stream(error);
  mlir::registerAllPasses();
  if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 ("Invalid pass_pipeline: " + error_stream.str()).c_str());
    return "// error";
  }

  mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
  if (failed(pm.run(*module))) {
    Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
    return "// error";
  }
  return MlirModuleToString(*module, show_debug_info);
}

}  // namespace tensorflow