1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
17 
18 #include "absl/memory/memory.h"
19 #include "llvm/Support/raw_ostream.h"
20 #include "mlir/IR/Attributes.h"  // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
23 #include "mlir/IR/Identifier.h"  // from @llvm-project
24 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
25 #include "mlir/IR/Operation.h"  // from @llvm-project
26 #include "mlir/Parser.h"  // from @llvm-project
27 #include "tensorflow/cc/saved_model/bundle_v2.h"
28 #include "tensorflow/cc/saved_model/reader.h"
29 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
30 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
31 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
32 #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
33 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
34 #include "tensorflow/core/framework/graph.pb.h"
35 #include "tensorflow/core/framework/versions.pb.h"
36 #include "tensorflow/core/graph/tensor_id.h"
37 #include "tensorflow/core/grappler/utils/transitive_fanin.h"
38 #include "tensorflow/core/platform/errors.h"
39 #include "tensorflow/core/platform/protobuf.h"
40 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
41 
42 namespace tensorflow {
43 
GraphdefToMlirImport(llvm::StringRef input,absl::string_view debug_info_file,const std::vector<std::string> & input_arrays,const std::vector<std::string> & input_dtypes,const std::vector<llvm::Optional<std::vector<int>>> & input_shapes,const std::vector<std::string> & output_arrays,const std::vector<std::string> & control_output_arrays,bool prune_unused_nodes,bool convert_legacy_fed_inputs,bool graph_as_function,bool upgrade_legacy,bool enable_shape_inference,mlir::MLIRContext * context)44 static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
45     llvm::StringRef input, absl::string_view debug_info_file,
46     const std::vector<std::string>& input_arrays,
47     const std::vector<std::string>& input_dtypes,
48     const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
49     const std::vector<std::string>& output_arrays,
50     const std::vector<std::string>& control_output_arrays,
51     bool prune_unused_nodes, bool convert_legacy_fed_inputs,
52     bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference,
53     mlir::MLIRContext* context) {
54   GraphDef graphdef;
55   TF_RETURN_IF_ERROR(
56       tensorflow::LoadProtoFromBuffer({input.data(), input.size()}, &graphdef));
57 
58   GraphDebugInfo debug_info;
59   if (!debug_info_file.empty()) {
60     TF_RETURN_IF_ERROR(LoadProtoFromFile(debug_info_file, &debug_info));
61   }
62 
63   GraphImportConfig specs;
64   specs.prune_unused_nodes = prune_unused_nodes;
65   specs.convert_legacy_fed_inputs = convert_legacy_fed_inputs;
66   specs.graph_as_function = graph_as_function;
67   specs.upgrade_legacy = upgrade_legacy;
68   specs.enable_shape_inference = enable_shape_inference;
69   TF_RETURN_IF_ERROR(ParseInputArrayInfo(input_arrays, input_dtypes,
70                                          input_shapes, &specs.inputs));
71   TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.outputs));
72   TF_RETURN_IF_ERROR(
73       ParseOutputArrayInfo(control_output_arrays, &specs.control_outputs));
74   // TODO(b/142828368): Pruning should not be needed when TF import
75   // supports importing graphs w/ unregistered ops natively.
76   GraphDef pruned_graph_def;
77   if (specs.prune_unused_nodes) {
78     std::vector<std::string> terminal_nodes;
79     terminal_nodes.reserve(specs.outputs.size() + specs.inputs.size());
80     for (const auto& output : specs.outputs) {
81       terminal_nodes.push_back(std::string(ParseTensorName(output).node()));
82     }
83     for (const auto& control_output : specs.control_outputs) {
84       terminal_nodes.push_back(std::string(control_output));
85     }
86     for (const auto& input : specs.inputs) {
87       terminal_nodes.push_back(input.first);
88     }
89     TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph(
90         graphdef, &pruned_graph_def, terminal_nodes));
91     // TODO(ashwinm): Add a separate utility in grappler utils that abstracts
92     // both SetTransitiveFaninGraph and restoring the missing contents from the
93     // original graph like function def library and version.
94     pruned_graph_def.mutable_library()->Swap(graphdef.mutable_library());
95     pruned_graph_def.mutable_versions()->Swap(graphdef.mutable_versions());
96   }
97   return ConvertGraphdefToMlir(
98       specs.prune_unused_nodes ? pruned_graph_def : graphdef, debug_info, specs,
99       context);
100 }
101 
GraphdefToMlirTranslateFunction(llvm::StringRef input,absl::string_view debug_info_file,const std::vector<std::string> & input_arrays,const std::vector<std::string> & input_dtypes,const std::vector<llvm::Optional<std::vector<int>>> & input_shapes,const std::vector<std::string> & output_arrays,const std::vector<std::string> & control_output_arrays,bool prune_unused_nodes,bool convert_legacy_fed_inputs,bool graph_as_function,bool upgrade_legacy,bool enable_shape_inference,mlir::MLIRContext * context)102 StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
103     llvm::StringRef input, absl::string_view debug_info_file,
104     const std::vector<std::string>& input_arrays,
105     const std::vector<std::string>& input_dtypes,
106     const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
107     const std::vector<std::string>& output_arrays,
108     const std::vector<std::string>& control_output_arrays,
109     bool prune_unused_nodes, bool convert_legacy_fed_inputs,
110     bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference,
111     mlir::MLIRContext* context) {
112   auto module_or = GraphdefToMlirImport(
113       input, debug_info_file, input_arrays, input_dtypes, input_shapes,
114       output_arrays, control_output_arrays, prune_unused_nodes,
115       convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
116       enable_shape_inference, context);
117   if (!module_or.status().ok()) {
118     LOG(ERROR) << "Graph import failed: " << module_or.status();
119   }
120   return module_or;
121 }
122 
GraphdefToMlirTranslateFunction(llvm::StringRef input,absl::string_view debug_info_file,absl::string_view input_arrays,absl::string_view input_dtypes,absl::string_view input_shapes,absl::string_view output_arrays,absl::string_view control_output_arrays,bool prune_unused_nodes,bool convert_legacy_fed_inputs,bool graph_as_function,bool upgrade_legacy,bool enable_shape_inference,mlir::MLIRContext * context)123 StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
124     llvm::StringRef input, absl::string_view debug_info_file,
125     absl::string_view input_arrays, absl::string_view input_dtypes,
126     absl::string_view input_shapes, absl::string_view output_arrays,
127     absl::string_view control_output_arrays, bool prune_unused_nodes,
128     bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
129     bool enable_shape_inference, mlir::MLIRContext* context) {
130   std::vector<std::string> input_array_vector;
131   std::vector<std::string> input_dtype_vector;
132   std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
133   std::vector<std::string> output_array_vector;
134   std::vector<std::string> control_output_array_vector;
135   TF_RETURN_IF_ERROR(ParseNodeNames(input_arrays, input_array_vector));
136   TF_RETURN_IF_ERROR(ParseNodeDataTypes(input_dtypes, input_dtype_vector));
137   TF_RETURN_IF_ERROR(ParseNodeNames(output_arrays, output_array_vector));
138   TF_RETURN_IF_ERROR(ParseNodeShapes(input_shapes, input_shapes_vector));
139   TF_RETURN_IF_ERROR(
140       ParseNodeNames(control_output_arrays, control_output_array_vector));
141   return GraphdefToMlirTranslateFunction(
142       input, debug_info_file, input_array_vector, input_dtype_vector,
143       input_shapes_vector, output_array_vector, control_output_array_vector,
144       prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function,
145       upgrade_legacy, enable_shape_inference, context);
146 }
147 
SavedModelObjectGraphToMlirImport(absl::string_view saved_model_dir,const std::unordered_set<std::string> & tags,absl::Span<std::string> exported_names,mlir::MLIRContext * context)148 StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphToMlirImport(
149     absl::string_view saved_model_dir,
150     const std::unordered_set<std::string>& tags,
151     absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
152   tensorflow::SavedModelV2Bundle bundle;
153   auto load_status = tensorflow::SavedModelV2Bundle::Load(
154       std::string(saved_model_dir.data(), saved_model_dir.length()), &bundle);
155   if (!load_status.ok()) {
156     LOG(ERROR) << "Failed to load saved model '" << saved_model_dir
157                << "': " << load_status;
158     return load_status;
159   }
160 
161   auto module_or = ConvertSavedModelToMlir(&bundle, context, exported_names);
162   if (!module_or.status().ok()) {
163     LOG(ERROR) << "SavedModel import failed: " << module_or.status();
164   }
165   return module_or;
166 }
167 
SavedModelSignatureDefsToMlirImport(absl::string_view saved_model_dir,const std::unordered_set<std::string> & tags,absl::Span<std::string> exported_names,mlir::MLIRContext * context,MLIRImportOptions options)168 StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImport(
169     absl::string_view saved_model_dir,
170     const std::unordered_set<std::string>& tags,
171     absl::Span<std::string> exported_names, mlir::MLIRContext* context,
172     MLIRImportOptions options) {
173   tensorflow::SavedModelBundle bundle;
174   tensorflow::SessionOptions session_options;
175   // Force saved model states to be restored to CPU.
176   (*session_options.config.mutable_device_count())["GPU"] = 0;
177   auto load_status =
178       tensorflow::LoadSavedModel(session_options, /* run_options = */ {},
179                                  std::string(saved_model_dir), tags, &bundle);
180   if (!load_status.ok()) {
181     LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
182                << "': " << load_status;
183     return load_status;
184   }
185 
186   auto module_or =
187       ConvertSavedModelV1ToMlir(bundle, exported_names, context, options);
188   if (!module_or.status().ok()) {
189     LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
190   }
191   return module_or;
192 }
193 
SavedModelSignatureDefsToMlirImportLite(absl::string_view saved_model_dir,const std::unordered_set<std::string> & tags,absl::Span<std::string> exported_names,mlir::MLIRContext * context,MLIRImportOptions options)194 StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImportLite(
195     absl::string_view saved_model_dir,
196     const std::unordered_set<std::string>& tags,
197     absl::Span<std::string> exported_names, mlir::MLIRContext* context,
198     MLIRImportOptions options) {
199   MetaGraphDef meta_graph_def;
200   auto status = ReadMetaGraphDefFromSavedModel(std::string(saved_model_dir),
201                                                tags, &meta_graph_def);
202   if (!status.ok()) {
203     LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
204                << "': " << status;
205     return status;
206   }
207 
208   auto module_or = ConvertSavedModelV1ToMlirLite(
209       meta_graph_def, /*debug_info=*/{}, exported_names, context, options);
210   if (!module_or.status().ok()) {
211     LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
212   }
213   return module_or;
214 }
215 
GraphdefToSplattedMlirTranslateFunction(llvm::StringRef input,absl::string_view debug_info_file,const std::vector<std::string> & input_arrays,const std::vector<std::string> & input_dtypes,const std::vector<llvm::Optional<std::vector<int>>> & input_shapes,const std::vector<std::string> & output_arrays,const std::vector<std::string> & control_output_arrays,bool prune_unused_nodes,bool convert_legacy_fed_inputs,bool graph_as_function,bool upgrade_legacy,bool enable_shape_inference,mlir::MLIRContext * context)216 StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
217     llvm::StringRef input, absl::string_view debug_info_file,
218     const std::vector<std::string>& input_arrays,
219     const std::vector<std::string>& input_dtypes,
220     const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
221     const std::vector<std::string>& output_arrays,
222     const std::vector<std::string>& control_output_arrays,
223     bool prune_unused_nodes, bool convert_legacy_fed_inputs,
224     bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference,
225     mlir::MLIRContext* context) {
226   auto module_or = GraphdefToMlirImport(
227       input, debug_info_file, input_arrays, input_dtypes, input_shapes,
228       output_arrays, control_output_arrays, prune_unused_nodes,
229       convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
230       enable_shape_inference, context);
231   if (!module_or.status().ok()) {
232     LOG(ERROR) << "Graph import failed: " << module_or.status();
233     return module_or.status();
234   }
235   auto& module = module_or.ValueOrDie();
236   std::srand(0);
237   for (auto fn : module->getOps<mlir::FuncOp>()) {
238     for (auto& bb : fn) {
239       for (auto& inst : bb) {
240         auto attr_id = mlir::Identifier::get("value", context);
241         if (auto attr = inst.getAttrOfType<mlir::ElementsAttr>(attr_id)) {
242           mlir::Attribute rand_val;
243           mlir::Type element_type = attr.getType().getElementType();
244           if (element_type.isa<mlir::IntegerType>()) {
245             rand_val = mlir::IntegerAttr::get(element_type, std::rand());
246           } else if (element_type.isF16() || element_type.isF32() ||
247                      element_type.isF64()) {
248             rand_val = mlir::FloatAttr::get(element_type,
249                                             std::rand() * 1.0 / RAND_MAX);
250 
251           } else {
252             inst.emitWarning()
253                 << "Skipping splat conversion for "
254                 << "an unsupported attribute type " << element_type;
255             continue;
256           }
257           auto new_attr =
258               mlir::DenseElementsAttr::get(attr.getType(), rand_val);
259           inst.setAttr(attr_id, new_attr);
260         }
261       }
262     }
263   }
264   return module_or;
265 }
266 
GraphdefToSplattedMlirTranslateFunction(llvm::StringRef input,absl::string_view debug_info_file,absl::string_view input_arrays,absl::string_view input_dtypes,absl::string_view input_shapes,absl::string_view output_arrays,absl::string_view control_output_arrays,bool prune_unused_nodes,bool convert_legacy_fed_inputs,bool graph_as_function,bool upgrade_legacy,bool enable_shape_inference,mlir::MLIRContext * context)267 StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
268     llvm::StringRef input, absl::string_view debug_info_file,
269     absl::string_view input_arrays, absl::string_view input_dtypes,
270     absl::string_view input_shapes, absl::string_view output_arrays,
271     absl::string_view control_output_arrays, bool prune_unused_nodes,
272     bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
273     bool enable_shape_inference, mlir::MLIRContext* context) {
274   std::vector<std::string> input_array_vector;
275   std::vector<std::string> input_dtype_vector;
276   std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
277   std::vector<std::string> output_array_vector;
278   std::vector<std::string> control_output_array_vector;
279   TF_RETURN_IF_ERROR(ParseNodeNames(input_arrays, input_array_vector));
280   TF_RETURN_IF_ERROR(ParseNodeDataTypes(input_dtypes, input_dtype_vector));
281   TF_RETURN_IF_ERROR(ParseNodeNames(output_arrays, output_array_vector));
282   TF_RETURN_IF_ERROR(ParseNodeShapes(input_shapes, input_shapes_vector));
283   TF_RETURN_IF_ERROR(
284       ParseNodeNames(control_output_arrays, control_output_array_vector));
285   return GraphdefToSplattedMlirTranslateFunction(
286       input, debug_info_file, input_array_vector, input_dtype_vector,
287       input_shapes_vector, output_array_vector, control_output_array_vector,
288       prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function,
289       upgrade_legacy, enable_shape_inference, context);
290 }
291 
292 }  // namespace tensorflow
293