/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"

#include "absl/memory/memory.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Identifier.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/Parser.h"  // from @llvm-project
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"

namespace tensorflow {

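// Parses a GraphDef (binary or text proto) from `input` and converts it to an
// MLIR module. When `prune_unused_nodes` is set, the graph is first reduced
// to the transitive fanin of the requested inputs, outputs, and control
// outputs, carrying the function library and versions over from the original
// graph.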
static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
    llvm::StringRef input, absl::string_view debug_info_file,
    const std::vector<std::string>& input_arrays,
    const std::vector<std::string>& input_dtypes,
    const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
    const std::vector<std::string>& output_arrays,
    const std::vector<std::string>& control_output_arrays,
    bool prune_unused_nodes, bool convert_legacy_fed_inputs,
    bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference,
    mlir::MLIRContext* context) {
  GraphDef graphdef;
  TF_RETURN_IF_ERROR(
      tensorflow::LoadProtoFromBuffer({input.data(), input.size()}, &graphdef));

  GraphDebugInfo debug_info;
  if (!debug_info_file.empty()) {
    TF_RETURN_IF_ERROR(LoadProtoFromFile(debug_info_file, &debug_info));
  }

  GraphImportConfig specs;
  specs.prune_unused_nodes = prune_unused_nodes;
  specs.convert_legacy_fed_inputs = convert_legacy_fed_inputs;
  specs.graph_as_function = graph_as_function;
  specs.upgrade_legacy = upgrade_legacy;
  specs.enable_shape_inference = enable_shape_inference;
  TF_RETURN_IF_ERROR(ParseInputArrayInfo(input_arrays, input_dtypes,
                                         input_shapes, &specs.inputs));
  TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.outputs));
  TF_RETURN_IF_ERROR(
      ParseOutputArrayInfo(control_output_arrays, &specs.control_outputs));
  // TODO(b/142828368): Pruning should not be needed when TF import
  // supports importing graphs w/ unregistered ops natively.
  GraphDef pruned_graph_def;
  if (specs.prune_unused_nodes) {
    std::vector<std::string> terminal_nodes;
    terminal_nodes.reserve(specs.outputs.size() + specs.inputs.size());
    for (const auto& output : specs.outputs) {
      terminal_nodes.push_back(std::string(ParseTensorName(output).node()));
    }
    for (const auto& control_output : specs.control_outputs) {
      terminal_nodes.push_back(std::string(control_output));
    }
    for (const auto& input : specs.inputs) {
      terminal_nodes.push_back(input.first);
    }
    TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph(
        graphdef, &pruned_graph_def, terminal_nodes));
    // TODO(ashwinm): Add a separate utility in grappler utils that abstracts
    // both SetTransitiveFaninGraph and restoring the missing contents from the
    // original graph like function def library and version.
    pruned_graph_def.mutable_library()->Swap(graphdef.mutable_library());
    pruned_graph_def.mutable_versions()->Swap(graphdef.mutable_versions());
  }
  return ConvertGraphdefToMlir(
      specs.prune_unused_nodes ? pruned_graph_def : graphdef, debug_info,
      specs, context);
}

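// Public entry point: converts a GraphDef to an MLIR module, logging the
// status on import failure before propagating it to the caller.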
StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
    llvm::StringRef input, absl::string_view debug_info_file,
    const std::vector<std::string>& input_arrays,
    const std::vector<std::string>& input_dtypes,
    const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
    const std::vector<std::string>& output_arrays,
    const std::vector<std::string>& control_output_arrays,
    bool prune_unused_nodes, bool convert_legacy_fed_inputs,
    bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference,
    mlir::MLIRContext* context) {
  auto module_or = GraphdefToMlirImport(
      input, debug_info_file, input_arrays, input_dtypes, input_shapes,
      output_arrays, control_output_arrays, prune_unused_nodes,
      convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
      enable_shape_inference, context);
  if (!module_or.status().ok()) {
    LOG(ERROR) << "Graph import failed: " << module_or.status();
  }
  return module_or;
}

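// Convenience overload that takes the array/dtype/shape specifications as
// comma-separated flag strings, parses them into vectors, and forwards to the
// overload above.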
StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
    llvm::StringRef input, absl::string_view debug_info_file,
    absl::string_view input_arrays, absl::string_view input_dtypes,
    absl::string_view input_shapes, absl::string_view output_arrays,
    absl::string_view control_output_arrays, bool prune_unused_nodes,
    bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
    bool enable_shape_inference, mlir::MLIRContext* context) {
  std::vector<std::string> input_array_vector;
  std::vector<std::string> input_dtype_vector;
  std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
  std::vector<std::string> output_array_vector;
  std::vector<std::string> control_output_array_vector;
  TF_RETURN_IF_ERROR(ParseNodeNames(input_arrays, input_array_vector));
  TF_RETURN_IF_ERROR(ParseNodeDataTypes(input_dtypes, input_dtype_vector));
  TF_RETURN_IF_ERROR(ParseNodeNames(output_arrays, output_array_vector));
  TF_RETURN_IF_ERROR(ParseNodeShapes(input_shapes, input_shapes_vector));
  TF_RETURN_IF_ERROR(
      ParseNodeNames(control_output_arrays, control_output_array_vector));
  return GraphdefToMlirTranslateFunction(
      input, debug_info_file, input_array_vector, input_dtype_vector,
      input_shapes_vector, output_array_vector, control_output_array_vector,
      prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function,
      upgrade_legacy, enable_shape_inference, context);
}

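// Loads a TF2 (object graph based) SavedModel from `saved_model_dir` and
// converts the objects named in `exported_names` to an MLIR module.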
StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphToMlirImport(
    absl::string_view saved_model_dir,
    const std::unordered_set<std::string>& tags,
    absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
  tensorflow::SavedModelV2Bundle bundle;
  auto load_status = tensorflow::SavedModelV2Bundle::Load(
      std::string(saved_model_dir.data(), saved_model_dir.length()), &bundle);
  if (!load_status.ok()) {
    LOG(ERROR) << "Failed to load saved model '" << saved_model_dir
               << "': " << load_status;
    return load_status;
  }

  auto module_or = ConvertSavedModelToMlir(&bundle, context, exported_names);
  if (!module_or.status().ok()) {
    LOG(ERROR) << "SavedModel import failed: " << module_or.status();
  }
  return module_or;
}

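// Loads a TF1 (SignatureDef based) SavedModel through a session and converts
// it to an MLIR module. GPU devices are disabled in the session options so
// that variables are restored on the CPU.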
StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImport(
    absl::string_view saved_model_dir,
    const std::unordered_set<std::string>& tags,
    absl::Span<std::string> exported_names, mlir::MLIRContext* context,
    MLIRImportOptions options) {
  tensorflow::SavedModelBundle bundle;
  tensorflow::SessionOptions session_options;
  // Force saved model states to be restored to CPU.
  (*session_options.config.mutable_device_count())["GPU"] = 0;
  auto load_status =
      tensorflow::LoadSavedModel(session_options, /* run_options = */ {},
                                 std::string(saved_model_dir), tags, &bundle);
  if (!load_status.ok()) {
    LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
               << "': " << load_status;
    return load_status;
  }

  auto module_or =
      ConvertSavedModelV1ToMlir(bundle, exported_names, context, options);
  if (!module_or.status().ok()) {
    LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
  }
  return module_or;
}

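// "Lite" variant of the SignatureDef import: reads only the MetaGraphDef from
// the SavedModel and converts it without creating a session or restoring
// variable values.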
StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImportLite(
    absl::string_view saved_model_dir,
    const std::unordered_set<std::string>& tags,
    absl::Span<std::string> exported_names, mlir::MLIRContext* context,
    MLIRImportOptions options) {
  MetaGraphDef meta_graph_def;
  auto status = ReadMetaGraphDefFromSavedModel(std::string(saved_model_dir),
                                               tags, &meta_graph_def);
  if (!status.ok()) {
    LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
               << "': " << status;
    return status;
  }

  auto module_or = ConvertSavedModelV1ToMlirLite(
      meta_graph_def, /*debug_info=*/{}, exported_names, context, options);
  if (!module_or.status().ok()) {
    LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
  }
  return module_or;
}

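// Converts a GraphDef to MLIR and then "splats" constants: every ElementsAttr
// named "value" is replaced by a dense splat of a single pseudo-random scalar
// of the same element type and shape, keeping the graph structure while
// discarding the actual tensor contents. std::srand(0) makes the replacement
// deterministic across runs; attributes with element types other than integer
// or f16/f32/f64 float are left unchanged with a warning.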
StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
    llvm::StringRef input, absl::string_view debug_info_file,
    const std::vector<std::string>& input_arrays,
    const std::vector<std::string>& input_dtypes,
    const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
    const std::vector<std::string>& output_arrays,
    const std::vector<std::string>& control_output_arrays,
    bool prune_unused_nodes, bool convert_legacy_fed_inputs,
    bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference,
    mlir::MLIRContext* context) {
  auto module_or = GraphdefToMlirImport(
      input, debug_info_file, input_arrays, input_dtypes, input_shapes,
      output_arrays, control_output_arrays, prune_unused_nodes,
      convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
      enable_shape_inference, context);
  if (!module_or.status().ok()) {
    LOG(ERROR) << "Graph import failed: " << module_or.status();
    return module_or.status();
  }
  auto& module = module_or.ValueOrDie();
  std::srand(0);
  for (auto fn : module->getOps<mlir::FuncOp>()) {
    for (auto& bb : fn) {
      for (auto& inst : bb) {
        auto attr_id = mlir::Identifier::get("value", context);
        if (auto attr = inst.getAttrOfType<mlir::ElementsAttr>(attr_id)) {
          mlir::Attribute rand_val;
          mlir::Type element_type = attr.getType().getElementType();
          if (element_type.isa<mlir::IntegerType>()) {
            rand_val = mlir::IntegerAttr::get(element_type, std::rand());
          } else if (element_type.isF16() || element_type.isF32() ||
                     element_type.isF64()) {
            rand_val = mlir::FloatAttr::get(element_type,
                                            std::rand() * 1.0 / RAND_MAX);
          } else {
            inst.emitWarning()
                << "Skipping splat conversion for "
                << "an unsupported attribute type " << element_type;
            continue;
          }
          auto new_attr =
              mlir::DenseElementsAttr::get(attr.getType(), rand_val);
          inst.setAttr(attr_id, new_attr);
        }
      }
    }
  }
  return module_or;
}

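// Comma-separated-string overload of the splatted translation; parses the
// flag strings and forwards to the vector-based overload above.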
StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
    llvm::StringRef input, absl::string_view debug_info_file,
    absl::string_view input_arrays, absl::string_view input_dtypes,
    absl::string_view input_shapes, absl::string_view output_arrays,
    absl::string_view control_output_arrays, bool prune_unused_nodes,
    bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
    bool enable_shape_inference, mlir::MLIRContext* context) {
  std::vector<std::string> input_array_vector;
  std::vector<std::string> input_dtype_vector;
  std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
  std::vector<std::string> output_array_vector;
  std::vector<std::string> control_output_array_vector;
  TF_RETURN_IF_ERROR(ParseNodeNames(input_arrays, input_array_vector));
  TF_RETURN_IF_ERROR(ParseNodeDataTypes(input_dtypes, input_dtype_vector));
  TF_RETURN_IF_ERROR(ParseNodeNames(output_arrays, output_array_vector));
  TF_RETURN_IF_ERROR(ParseNodeShapes(input_shapes, input_shapes_vector));
  TF_RETURN_IF_ERROR(
      ParseNodeNames(control_output_arrays, control_output_array_vector));
  return GraphdefToSplattedMlirTranslateFunction(
      input, debug_info_file, input_array_vector, input_dtype_vector,
      input_shapes_vector, output_array_vector, control_output_array_vector,
      prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function,
      upgrade_legacy, enable_shape_inference, context);
}

}  // namespace tensorflow