/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"

#include <iterator>
#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/escaping.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Identifier.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Verifier.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"

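// Converts an llvm::StringRef to an absl::string_view without copying; the
// returned view aliases the StringRef's underlying data.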
static inline absl::string_view StringRefToView(llvm::StringRef ref) {
  return {ref.data(), ref.size()};
}

namespace tensorflow {
using mlir::NamedAttrList;
using mlir::TensorType;
using mlir::tf_saved_model::AssetOp;
using mlir::tf_saved_model::GlobalTensorOp;
using mlir::tf_saved_model::SessionInitializerOp;
using stream_executor::port::StatusOr;

namespace {

bool IsOutputShapesAttribute(const AttrValue& attr_value,
                             llvm::StringRef attr_name) {
  return attr_name.compare("_output_shapes") == 0 &&
         attr_value.value_case() == AttrValue::kList;
}

bool IsResourceOutputShapesAttribute(const AttrValue& attr_value,
                                     llvm::StringRef attr_name) {
  if (attr_name == "_handle_dtypes" || attr_name == "_handle_shapes")
    return attr_value.value_case() == AttrValue::kList;
  return false;
}

void LoadImporterDialects(mlir::MLIRContext& context) {
  // Load dialects involved in the conversion.
  mlir::DialectRegistry registry;
  mlir::RegisterAllTensorFlowDialects(registry);
  context.appendDialectRegistry(registry);
  context.loadAllAvailableDialects();
}

// This class is used to generate new MLIR function name strings that are both
// unique in the TF function library `flib_` and unique among the name strings
// generated by the class object during its lifetime.
//
// In theory, this class is not necessary because we should simply take
// the TF function name and use it as the MLIR function name. However, for some
// unknown reason (callout for investigation in b/142268695), keeping the
// function names unchanged in an MLIR roundtrip causes test failures.
// TODO(b/142268695) Re-evaluate whether we need this class vs. directly using
// the TF function name as the MLIR function name after b/142268695 is root
// caused.
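// For example (illustrative; the exact suffix scheme is determined by
// OpOrArgNameMapper): if `flib_` already contains a function named "foo", a
// request for "foo" yields a fresh suffixed variant such as "foo0" instead of
// reusing the existing name.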
class NameUniquifier : public OpOrArgNameMapper {
 public:
  explicit NameUniquifier(const FunctionLibraryDefinition& flib)
      : flib_(flib) {}

 private:
  bool IsUnique(llvm::StringRef name) override {
    return !flib_.Contains(std::string(name));
  }

  std::string GetName(OpOrVal op_or_val) override {
    DCHECK(false) << "Unimplemented";
    return "";
  }

  const FunctionLibraryDefinition& flib_;
};

// Stateful helper class to import a TensorFlow model into an MLIR Module.
//
// This is the base class that contains common utilities shared between the
// GraphDef importer and SavedModel importer.
//
// A subclass is expected to call `PrepareConvert` first to perform necessary
// preparation over the graph and also certain internal bookkeeping data.
// Afterwards the other protected methods can be called.
class ImporterBase {
 protected:
  explicit ImporterBase(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier,
      llvm::StringRef function_name_for_debug_info = "")
      : builder_(module.getContext()),
        module_(module),
        context_(module.getContext()),
        tf_name_to_mlir_name_(tf_name_to_mlir_name),
        graph_flib_(flib),
        specs_(specs),
        debug_info_(debug_info),
        function_name_for_debug_info_(function_name_for_debug_info),
        function_name_uniquifier_(function_name_uniquifier),
        error_handler_(module.getContext()) {}

  // Returns the inferred function signature of the given function body. Input
  // types are unranked tensors of the respective datatype in the function, and
  // result types are inferred by the shape_refiner_. Result types need not be
  // unranked tensors and could be ranked tensors in cases where the result
  // type depends on an op with a static output shape, like tf.Const.
  StatusOr<mlir::FunctionType> InferLibFunctionType(const FunctionBody& fbody);

  // Extracts arg and ret nodes from FunctionBody.
  void GetArgsAndRetsFromFunctionBody(
      const FunctionBody& fbody,
      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
      absl::InlinedVector<OutputTensor, 4>* ret_nodes,
      absl::InlinedVector<Node*, 4>* control_ret_nodes);

  // Prepares converting the graph to an MLIR module. This step removes the
  // backedges of the graph, orders the nodes and infers the shapes.
  Status PrepareConvert(const Graph& graph);

  // Converts the prepared graph to a Function and adds it to the module. A set
  // of nodes from the graph is given to be converted into the arguments and
  // results of the function.
  Status Convert(llvm::StringRef func_name, mlir::FunctionType func_type,
                 const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
                 const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
                 const absl::InlinedVector<Node*, 4>& control_ret_nodes,
                 llvm::ArrayRef<mlir::NamedAttribute> attrs);

  // Finds the function definition for the given function name in the graph's
  // function library and converts it to a function of the module. This method
  // is called on demand because the graph flib_def does not provide an
  // iterator interface.
  Status ConvertLibFunction(llvm::StringRef func_name);

  // Returns the list of nodes in the graph. Nodes are presented in the reverse
  // order of a post-order depth-first visit starting from the graph's source
  // nodes.
  llvm::ArrayRef<Node*> GetOrderedNodes() const { return ordered_nodes_; }

  // Returns the inferred input type at index `idx` of the `node` in the
  // context.
  StatusOr<mlir::Type> InferInputType(const Node& node, int idx,
                                      mlir::Builder builder);

  // Returns the inferred output type at index `idx` of the `node` in the
  // context.
  StatusOr<mlir::Type> InferOutputType(const Node& node, int idx,
                                       mlir::Builder builder);

 private:
  // Most types with subtypes have only one subtype.
  using ElementSubtypes = llvm::SmallVector<TensorType, 1>;

  // Adds all the ordered_nodes to the shape refiner shape_refiner_. Then all
  // data type and shape information is maintained by the shape_refiner_.
  // TODO(jpienaar): Remove once shape inference on import is removed.
  Status AddNodesToShapeRefiner(
      std::unordered_map<string, Node*>* node_name_map);

  // Prune nodes that do not feed into fetch nodes.
  Status PruneUnreachableNodes(
      std::unordered_map<string, Node*>* node_name_map);

  // Converts feeds to Placeholder nodes.
  Status ConvertFeedsToPlaceholders(
      std::unordered_map<string, Node*>* node_name_map);

  // Converts the inferred shape referred to by 'handle' in 'context', with the
  // given data type, and returns an MLIR tensor type.
  StatusOr<TensorType> ConvertDataTypeAndShape(
      DataType dtype, const shape_inference::ShapeHandle& handle,
      const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the inferred shape referred to by 'handle' in 'context', with the
  // given element type, and returns an MLIR tensor type.
  StatusOr<TensorType> ConvertElementTypeAndShape(
      mlir::Type element_type, const shape_inference::ShapeHandle& handle,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the inferred subtypes for an element type to corresponding MLIR
  // types in 'context'.
  StatusOr<ElementSubtypes> ConvertSubtypes(
      const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the tensor proto into an MLIR elements attribute.
  StatusOr<mlir::ElementsAttr> ConvertTensorProto(const TensorProto& value) {
    return ::tensorflow::ConvertTensorProto(value, &builder_);
  }

  // Converts a function name in the GraphDef to an mlir::FlatSymbolRefAttr.
  StatusOr<mlir::FlatSymbolRefAttr> ConvertFunctionCallName(
      const std::string& func_name);

  // Converts the given non-function-call AttrValue to an MLIR Attribute.
  StatusOr<mlir::Attribute> ConvertAttributeValue(const AttrValue& value);

  // Converts the given function-call AttrValue to MLIR Attributes and pushes
  // them to the given attributes list. For example, if there is a kFunc
  // AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it
  // to a list of MLIR Attributes: [{base_name : foo}, {base_name.k1 : bar},
  // {base_name.k2 : rfc}].
  Status ConvertFunctionCallAttribute(const std::string& base_name,
                                      const AttrValue& value,
                                      NamedAttrList* attributes);

  // Helper to create either a tf_executor operation or a TF operation wrapped
  // in an island.
  mlir::Operation* CreateOperation(
      const Node& node, llvm::StringRef node_type_name,
      const mlir::OperationState& result,
      const llvm::SmallVectorImpl<mlir::Value>& control_operands);

  // Converts one NodeDef from the input GraphDef into an Operation and
  // inserts it into the MLIR module using builder_.
  Status ConvertNode(const Node& node);

  // If the input graph represents a while-loop, the edges pointing from a
  // "NextIteration" node to a "Merge" node add cyclic dependencies and make
  // topological sorting impossible. We need to remove these edges from the
  // input graph to infer shapes and construct a Function. For each
  // "NextIteration" node, two operations, "NextIteration.source" and
  // "NextIteration.sink", are added to the MLIR module.
  using BackEdge = BackEdgeHelper::BackEdge;

  // Removes backedges from the input graph. The removed edges are added back
  // via the OpBuilder after the remaining graph is converted to the Function.
  Status RemoveBackedges(const Graph& graph);

  // Restores backedges removed during shape inference to the final Function.
  Status AddBackedges();

  // Restores a single backedge in the Function by adding a replicated
  // operation before the dst operation.
  Status AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
                     int dst_input);

  // Adds the input arguments and return operation to the function. The
  // arguments are added as basic block arguments. The argument types and the
  // ids of the nodes from the input graph also need to be specified.
  Status ConvertFunctionArgAndRets(
      mlir::FuncOp func, mlir::tf_executor::GraphOp graph_op,
      llvm::ArrayRef<mlir::Type> arg_types,
      const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
      const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
      const absl::InlinedVector<Node*, 4>& control_ret_nodes);

  // Gets the location information of the given node. It uses the
  // "original_node_name" in the NodeDef to get the corresponding file location
  // (FileLineColLoc) from the input DebugInfo and returns a CallSiteLoc. If
  // there are multiple "original_node_names", a FusedLoc is returned. If the
  // node name couldn't be found in the input DebugInfo, a NameLoc is used as
  // the location.
  mlir::Location GetLocation(const Node& node);

  // Appends the location string for the node to the error message and returns
  // the combined error status.
  Status EmitErrorWithLocationStr(const Node& node, const Status& error_status);

  // Inserts a placeholder node in the graph to replace a feed output tensor,
  // and returns the new placeholder node and a boolean indicating if the
  // original input node was removed from the graph. Uses of the feed output
  // tensor are replaced with this placeholder node. If the feed output tensor
  // is of a single output node, the control dependencies are forwarded to
  // the placeholder node, and the original node will be removed.
  // Note: This modifies the graph, and so any list of ordered nodes needs to
  // be reconstructed.
  StatusOr<std::pair<Node*, bool>> CreatePlaceholderNodeForFeed(
      const TensorShapeProto& shape, DataType dtype, Node* node, int index,
      const std::unordered_map<string, Node*>& node_name_map);

  // Gets the input and output nodes corresponding to the specified input and
  // output nodes in specs_. If there are no input or output nodes specified,
  // nodes will be empty.
  Status GetInputOutputNodes(
      const std::unordered_map<string, Node*>& node_name_map,
      std::unordered_set<const Node*>* nodes);

  llvm::StringSet<>& GetUnmodelledOpTypes() {
    // All the TF ops encountered that aren't modelled in dialect.
    static auto* unmodelled_op_types = new llvm::StringSet<>();
    return *unmodelled_op_types;
  }

  // The input graph with backedges removed. The removed backedges are stored
  // in the back_edge_helper.
  BackEdgeHelper back_edge_helper_;
  // A map between node and output index, for each backedge.
  absl::flat_hash_map<const Node*, int> back_edge_node_output_;
  absl::flat_hash_map<const Node*, BackEdge> back_edge_dst_inputs_;
  // A map between the sink and source operation of each NextIteration.
  absl::flat_hash_map<mlir::Operation*, mlir::Operation*>
      next_iteration_sink_source_;

  // All nodes and version information about the (copied) imported graph.
  std::unique_ptr<Graph> graph_;
  std::vector<Node*> ordered_nodes_;

  // Maps from a Node ID to an MLIR Operation.
  using NodeValueMap = absl::flat_hash_map<int, mlir::Operation*>;

  mlir::OpBuilder builder_;
  mlir::ModuleOp module_;
  mlir::MLIRContext* context_;
  std::unordered_map<std::string, std::string>* tf_name_to_mlir_name_;
  const FunctionLibraryDefinition& graph_flib_;
  const GraphImportConfig& specs_;
  const GraphDebugInfo& debug_info_;
  llvm::StringRef function_name_for_debug_info_;
  NodeValueMap node_values_;
  // TODO(jpienaar): Remove once shape inference on import is removed.
  // The shape_refiner_ will be nullptr if shape inference on import is
  // not enabled.
  std::unique_ptr<ShapeRefiner> shape_refiner_ = nullptr;
  NameUniquifier* function_name_uniquifier_;
  mlir::StatusScopedDiagnosticHandler error_handler_;

 protected:
  // Maps feed as TensorId to new Placeholder node name.
  absl::flat_hash_map<TensorId, absl::string_view> remapped_feeds_;
};

// Returns true if the node with the given name has a non-primary output that
// is used by some other node as an input. Returns false if no outputs are in
// use or only the first output is in use.
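// For example, if any node consumes input "n:1", node "n" has a non-primary
// output in use; uses of "n" or "n:0" alone do not count.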
bool HasNonPrimaryOutputInUse(const GraphDef& graph_def,
                              const std::string& node) {
  for (const auto& node_def : graph_def.node()) {
    for (const auto& input : node_def.input()) {
      if (absl::StartsWith(input, node + ":") && input != node + ":0") {
        return true;
      }
    }
  }
  return false;
}

// Replaces the given LegacyFedInput node with a Placeholder node if it is one
// of the inputs. Returns an error if a non-primary output of the
// LegacyFedInput node is in use, in which case it cannot be replaced by the
// Placeholder node, which only has a single output.
Status UpdateLegacyFedInputNode(const GraphDef& graph_def,
                                const GraphImportConfig::InputArrays& inputs,
                                NodeDef* node) {
  const std::string& node_name = node->name();
  auto it = inputs.find(node_name);

  // Node is not an input.
  if (it == inputs.end()) return Status::OK();

  if (HasNonPrimaryOutputInUse(graph_def, node_name)) {
    return errors::InvalidArgument(
        "LegacyFedInput node ", node->name(),
        " has a non-primary output in use and cannot be replaced with a "
        "Placeholder node");
  }

  DataType dtype = it->second.imported_dtype;
  // Uses the existing output type if it isn't specified by the user.
  if (dtype == DT_INVALID) {
    dtype = node->attr().at("output_types").list().type(0);
  }
  // Update the op name, drop the inputs, and set the attributes required by
  // the Placeholder op.
  *node->mutable_op() = "Placeholder";
  node->clear_attr();
  node->clear_input();
  AddNodeAttr("dtype", dtype, node);
  AddNodeAttr("shape", it->second.shape, node);
  return Status::OK();
}

// Preprocesses the GraphDef before it can be converted to a Graph by:
// - Adding the default attributes to each node def if they are missing from
//   the GraphDef.
// - Replacing LegacyFedInput nodes with Placeholder nodes if the
//   convert_legacy_fed_inputs option is enabled.
Status PreprocessGraphDef(const GraphImportConfig* specs,
                          GraphDef* graph_def) {
  for (auto& node_def : *graph_def->mutable_node()) {
    // TODO(hinsu): Completely deprecate support for LegacyFedInput ops. One
    // solution could be to have a tool to let users upgrade old serialized
    // graphs.
    if (specs && specs->convert_legacy_fed_inputs &&
        node_def.op() == "LegacyFedInput") {
      TF_RETURN_IF_ERROR(
          UpdateLegacyFedInputNode(*graph_def, specs->inputs, &node_def));
    }

    const tensorflow::OpRegistrationData* op_reg_data =
        tensorflow::OpRegistry::Global()->LookUp(node_def.op());
    if (!op_reg_data) {
      // This is likely a function call node, so we should continue.
      continue;
    }
    ::tensorflow::AddDefaultsToNodeDef(op_reg_data->op_def, &node_def);
  }
  return Status::OK();
}

// Mapping from node name to feed (index and ArrayInfo). Node name must outlive
// this map.
using FeedsByNode = absl::flat_hash_map<
    absl::string_view,
    absl::flat_hash_map<int, const std::pair<std::string, ArrayInfo>*>>;

// Creates, from a `GraphImportConfig::InputArrays`, a mapping from a feed's
// output tensor name to its index and ArrayInfo. Keys and values are backed by
// `GraphImportConfig::InputArrays`.
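// For example, a feed named "x:1" produces the entry feeds_by_node["x"][1]; a
// control input such as "^x" parses to a negative index and is rejected below.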
StatusOr<FeedsByNode> GetFeedsByNode(
    const GraphImportConfig::InputArrays& inputs) {
  FeedsByNode feeds_by_node;
  feeds_by_node.reserve(inputs.size());

  for (const auto& input : inputs) {
    TensorId tensor = ParseTensorName(input.first);
    if (tensor.index() < 0)
      return errors::FailedPrecondition(
          "Feed output tensor must be a data output '", tensor.ToString(),
          "'");

    auto& node = feeds_by_node[tensor.node()];
    if (!node.insert({tensor.index(), &input}).second)
      return errors::FailedPrecondition(
          "Multiple feeds for the same output tensor '", tensor.ToString(),
          "'");
  }

  return feeds_by_node;
}

// Creates a unique name for a node that will be replacing a feed output
// tensor.
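// For example, ("node", 1) first tries "node_1", then "node_1_0", "node_1_1",
// and so on until an unused name is found.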
std::string GetUniqueNodeName(
    absl::string_view node_name, int index,
    const std::unordered_map<string, Node*>& node_name_map) {
  std::string new_node_name_base = absl::StrCat(node_name, "_", index);
  int count = 0;
  std::string new_node_name = new_node_name_base;
  while (node_name_map.find(new_node_name) != node_name_map.end()) {
    new_node_name = absl::StrCat(new_node_name_base, "_", count++);
  }
  return new_node_name;
}

Status ImporterBase::RemoveBackedges(const Graph& graph) {
  // TODO(fengliuai): Converting to GraphDef and back is the easiest way to
  // clone a graph.
  // TODO(fengliuai): Clone the graph without going through a GraphDef first.
  GraphDef graph_def;
  graph.ToGraphDef(&graph_def);
  graph_ = absl::make_unique<Graph>(graph.flib_def());
  GraphConstructorOptions opts;
  opts.allow_internal_ops = true;
  opts.add_default_attributes = false;
  TF_RETURN_IF_ERROR(::tensorflow::ConvertGraphDefToGraph(
      opts, std::move(graph_def), graph_.get()));

  // Remove all the backedges so the nodes can be added to the shape refiner.
  TF_RETURN_IF_ERROR(back_edge_helper_.Remove(graph_.get()));
  VLOG(1) << "Found " << (back_edge_helper_.RemovedEdges().size())
          << " backedges.";

  // Creates a map for quickly identifying whether a node output is a backedge.
  for (const auto& edge : back_edge_helper_.RemovedEdges()) {
    if (back_edge_node_output_.find(edge.src) != back_edge_node_output_.end() &&
        back_edge_node_output_[edge.src] != edge.src_output) {
      return errors::FailedPrecondition(
          "More than one of the src node outputs are backedges!");
    }
    back_edge_node_output_[edge.src] = edge.src_output;
    // We expect a merge to receive a single backedge (multiple NextIteration
    // nodes feeding into the same merge is unexpected here).
    DCHECK(!back_edge_dst_inputs_.contains(edge.dst));
    back_edge_dst_inputs_[edge.dst] = edge;
  }

  // Obtains an RPO ordering, using node names as a tiebreak for stable
  // sorting.
  GetReversePostOrder(
      *graph_, &ordered_nodes_,
      [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
  return Status::OK();
}

Status CopyStackTraces(const Graph& from, Graph* to) {
  // Copy over the stack traces.
  // TODO(jpienaar): This really shouldn't be needed, copying the Graph above
  // and then needing these traversals is unfortunate.
  std::unordered_map<string, Node*> node_map = from.BuildNodeNameIndex();
  for (Node* node : to->nodes()) {
    if (const Node* old_node = node_map[node->name()]) {
      if (const std::shared_ptr<AbstractStackTrace>& stack =
              old_node->GetStackTrace()) {
        DVLOG(2) << "Stack for " << node->name() << " "
                 << old_node->GetStackTrace()->ToString(
                        AbstractStackTrace::TracePrintingOptions());
        node->SetStackTrace(stack);
      } else {
        DVLOG(1) << "No stack for " << node->name() << " (" << node
                 << ") in Graph " << &from;
      }
    } else {
      DVLOG(1) << "No stack for " << node->name() << " (" << node
               << ") in Graph " << &from;
    }
  }

  return Status::OK();
}

StatusOr<std::pair<Node*, bool>> ImporterBase::CreatePlaceholderNodeForFeed(
    const TensorShapeProto& shape, DataType dtype, Node* node, int index,
    const std::unordered_map<string, Node*>& node_name_map) {
  DCHECK_LT(index, node->num_outputs());
  const bool update_inplace = node->num_outputs() == 1 && index == 0;
  std::string new_node_name =
      update_inplace ? node->name()
                     : GetUniqueNodeName(node->name(), index, node_name_map);

  Node* placeholder_node;
  NodeBuilder builder(new_node_name, "Placeholder");
  builder.Attr("shape", shape);
  builder.Attr("dtype", dtype);
  TF_RETURN_IF_ERROR(builder.Finalize(graph_.get(), &placeholder_node));

  // Update edges from the original feed with the Placeholder node.
  std::vector<const Edge*> data_edges;
  std::vector<const Edge*> control_edges;
  for (const tensorflow::Edge* edge : node->out_edges()) {
    if (edge->src_output() == index) {
      data_edges.push_back(edge);
    } else if (update_inplace && edge->IsControlEdge()) {
      control_edges.push_back(edge);
    }
  }

  for (const auto* edge : data_edges) {
    TF_RETURN_IF_ERROR(graph_->UpdateEdge(placeholder_node, 0, edge->dst(),
                                          edge->dst_input()));
  }

  // TODO(lyandy): Preserve control dependencies properly by not forwarding
  // control dependencies to data outputs and not removing single output nodes.
  // When a data output is replaced as a feed, unless there is another non-feed
  // data output or an explicit control output used by the same node,
  // transitive control dependencies are not to be executed. For single output
  // nodes, Placeholders can be converted to a NoOp if there are no uses, and
  // PlaceholderWithDefault can be converted to an Identity.
  for (const auto* edge : control_edges) {
    graph_->AddControlEdge(placeholder_node, edge->dst());
    graph_->RemoveControlEdge(edge);
  }

  if (update_inplace) {
    graph_->RemoveNode(node);
  }

  return std::pair<Node*, bool>(placeholder_node, update_inplace);
}

Status ImporterBase::GetInputOutputNodes(
    const std::unordered_map<string, Node*>& node_name_map,
    std::unordered_set<const Node*>* nodes) {
  auto add_node = [&](absl::string_view name) {
    auto it = node_name_map.find(std::string(name));
    if (it == node_name_map.end()) {
      return errors::FailedPrecondition(
          absl::StrCat("Graph does not contain node: ", name));
    }
    nodes->insert(it->second);
    return Status::OK();
  };

  // Remap feeds and fetches to newly created Placeholder nodes.
  for (const auto& input : specs_.inputs) {
    TensorId tensor = ParseTensorName(input.first);
    auto remapped_it = remapped_feeds_.find(tensor);
    if (remapped_it != remapped_feeds_.end()) {
      TF_RETURN_IF_ERROR(add_node(remapped_it->second));
    } else {
      TF_RETURN_IF_ERROR(add_node(tensor.node()));
    }
  }

  for (const auto& output : specs_.outputs) {
    TensorId tensor = ParseTensorName(output);
    auto remapped_it = remapped_feeds_.find(tensor);
    if (remapped_it != remapped_feeds_.end()) {
      TF_RETURN_IF_ERROR(add_node(remapped_it->second));
    } else {
      TF_RETURN_IF_ERROR(add_node(tensor.node()));
    }
  }

  for (const auto& control_output : specs_.control_outputs)
    TF_RETURN_IF_ERROR(add_node(control_output));

  return Status::OK();
}

// TODO(jpienaar): Remove this once the shape inference on import flag is
// removed.
Status ImporterBase::AddNodesToShapeRefiner(
    std::unordered_map<string, Node*>* node_name_map) {
  shape_refiner_ = absl::make_unique<ShapeRefiner>(graph_->versions(),
                                                   graph_->op_registry());
  // Some operations (for example "TPUExecute") don't have a shape inference
  // function defined, so we set this to false to allow adding nodes with such
  // operations.
  shape_refiner_->set_require_shape_inference_fns(false);
  shape_refiner_->set_function_library_for_shape_inference(&graph_flib_);

  TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs));

  // First add all nodes to the refiner.
  for (Node* node : ordered_nodes_) {
    // We need to use a TensorFlow node to teach the shape refiner that the
    // user specified a certain data type and shape for the inputs in the
    // `specs_`. This node shouldn't have any inputs, should have exactly one
    // output, and its output type/shape should be determined only by its
    // "named" attributes. (The attributes should have fixed names so we can
    // use the info from `specs_` to set their values.) `Placeholder` satisfies
    // these constraints.
    //
    // Therefore, if the input node isn't a `Placeholder`, we create one and
    // use it to replace the original input node, so the shape refiner can
    // successfully propagate the user's input type and shape to the rest of
    // the graph.
    bool node_added_to_shape_refiner = false;
    auto it = feeds_by_node.find(node->name());
    if (it != feeds_by_node.end()) {
      auto op_name = node->op_def().name();
      if (op_name != "Placeholder" && op_name != "LegacyFedInput" &&
          op_name != FunctionLibraryDefinition::kArgOp) {
        for (const auto& output_tensor : it->second) {
          const int index = output_tensor.first;
          const ArrayInfo& array_info = output_tensor.second->second;

          DataType dtype = array_info.imported_dtype;
          // Uses the existing output type if it isn't specified by the user.
          if (dtype == DT_INVALID) {
            dtype = node->output_type(index);
          }

          TF_ASSIGN_OR_RETURN(
              auto placeholder_node_and_removed,
              CreatePlaceholderNodeForFeed(array_info.shape, dtype, node,
                                           index, *node_name_map));

          Node* placeholder_node = placeholder_node_and_removed.first;
          if (placeholder_node_and_removed.second) {
            // The original node has been removed from the graph.
            node = placeholder_node;
            node_added_to_shape_refiner = true;
          }
          remapped_feeds_[{it->first, index}] = placeholder_node->name();
          (*node_name_map)[placeholder_node->name()] = placeholder_node;
          // Add the new placeholder node to the shape refiner.
          Status status = shape_refiner_->AddNode(placeholder_node);
          if (!status.ok()) {
            return EmitErrorWithLocationStr(*placeholder_node, status);
          }
        }
      } else {
        auto index_it = it->second.find(0);
        if (index_it == it->second.end()) {
          return errors::FailedPrecondition(
              "Missing feed output tensor at index 0 for node '", node->name(),
              "'");
        }
        node->AddAttr("shape", index_it->second->second.shape);
        DataType dtype = index_it->second->second.imported_dtype;
        // Uses the existing output type if it isn't specified by the user.
        if (dtype == DT_INVALID) {
          dtype = node->output_type(0);
        }
        node->AddAttr("dtype", dtype);
      }
    }
    if (!node_added_to_shape_refiner) {
      // Add the node to the shape refiner if the node hasn't been removed.
      Status status = shape_refiner_->AddNode(node);
      if (!status.ok()) {
        return EmitErrorWithLocationStr(*node, status);
      }
    }

    auto set_shape_from_list_attr = [&](const AttrValue* attr) {
      auto& list = attr->list();
      for (auto shape : llvm::enumerate(list.shape())) {
        auto* node_context = shape_refiner_->GetContext(node);
        shape_inference::ShapeHandle handle;
        Status status =
            node_context->MakeShapeFromShapeProto(shape.value(), &handle);
        if (!status.ok()) {
          return EmitErrorWithLocationStr(*node, status);
        }
        node_context->set_output(shape.index(), handle);
      }
      return Status::OK();
    };

    // We currently have no other way to get shapes from ReadVariableOps.
    // Some graphs seem to have _output_shapes attributes on them, so use that
    // if possible.
    // TODO(silvasean): Ideally, we would do this in a separate shape inference
    // pass to avoid adding complexity to the importer. But right now, we don't
    // have an MLIR-native shape inference pass, so we need to do this while we
    // still have the Graph around, i.e. here, in the importer.
    if (node->op_def().name() == "ReadVariableOp") {
      // TODO(silvasean): In some graphs, this seems to be annotated on every
      // node. Why and by whom?
      // TODO(b/140588338): We should ideally incorporate that information for
      // all nodes, but right now, this can result in e.g. an Identity node
      // with a signature such as
      // `(tensor<?x?xf32>) -> tensor<?x9216xf32>` which fails the verifier
      // (which checks for exact type equality; _output_shapes results in
      // us shoehorning in the more-precise type on the output).
      if (const AttrValue* attr = node->attrs().Find("_output_shapes"))
        TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr));
    }

    // If it is the argument node, the shape handle is set explicitly, so it
    // can be propagated to the body nodes of the function.
    if (StringPiece(node->type_string()) ==
        FunctionLibraryDefinition::kArgOp) {
      auto* node_context = shape_refiner_->GetContext(node);
      DCHECK(node_context != nullptr);
      if (const AttrValue* attr = node->attrs().Find("shape")) {
        shape_inference::ShapeHandle handle;
        Status status =
            node_context->MakeShapeFromShapeProto(attr->shape(), &handle);
        if (!status.ok()) {
          return EmitErrorWithLocationStr(*node, status);
        }
        node_context->set_output(0, handle);
      } else if (const AttrValue* attr =
                     node->attrs().Find("_output_shapes")) {
        TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr));
      } else {
        node_context->set_output(0, node_context->UnknownShape());
      }
    }
  }

  // Since we might have inserted and removed nodes from the graph, fix
  // source/sink edges and reconstruct the RPO ordering of nodes.
  FixupSourceAndSinkEdges(graph_.get());

  // Prune nodes in the graph that are not reachable from the output.
  if (specs_.prune_unused_nodes) {
    std::unordered_set<const Node*> prune_start;
    TF_RETURN_IF_ERROR(GetInputOutputNodes(*node_name_map, &prune_start));
    if (!prune_start.empty()) {
      if (PruneForReverseReachability(graph_.get(), prune_start)) {
        VLOG(1) << "Pruned unused nodes in graphdef";
      } else {
        VLOG(1) << "No unused nodes in graphdef to prune";
      }
    } else {
      VLOG(1) << "No output nodes specified, skipping pruning";
    }
  } else {
    VLOG(1) << "Pruning unused nodes in graphdef is disabled";
  }

  // Re-initialize ordered_nodes_ since we might have modified the graph.
  GetReversePostOrder(
      *graph_, &ordered_nodes_,
      [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });

  VLOG(1) << "Inferring graph shapes to fixpoint";

  // The "changed" information from UpdateNode can give false positives, so we
  // create a dedicated method to verify that the shapes did not change before
  // and after the shape refinement.
  auto same_inferred_shape = [](shape_inference::InferenceContext* c,
                                shape_inference::ShapeHandle s0,
                                shape_inference::ShapeHandle s1) -> bool {
    if (s0.SameHandle(s1) || (!c->RankKnown(s0) && !c->RankKnown(s1))) {
      return true;
    }
    if (c->Rank(s0) != c->Rank(s1)) {
      return false;
    }
    for (int i = 0; i < c->Rank(s0); ++i) {
      if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
        int64 val0 = c->Value(c->Dim(s0, i));
        int64 val1 = c->Value(c->Dim(s1, i));
        // A negative value is treated as unknown, so any two negative values
        // indicate the same (unknown) dimension.
        if (val0 >= 0 && val1 >= 0 && val0 != val1) return false;
      }
    }
    return true;
  };
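  // For example, two shapes [?, 3] whose unknown first dimensions come from
  // different handles still compare equal here: negative (unknown) dimension
  // values are treated as matching each other.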

  bool changed = true;
  int i = 0;
  const int kMaxIterationCount = 2;
  while (changed && i != kMaxIterationCount) {
    changed = false;
    for (const Node* node : ordered_nodes_) {
      auto* shape_context = shape_refiner_->GetContext(node);
      DCHECK(shape_context != nullptr);
      absl::InlinedVector<shape_inference::ShapeHandle, 4> existing;
      existing.reserve(shape_context->num_outputs());
      for (int o = 0; o < shape_context->num_outputs(); ++o) {
        existing.push_back(shape_context->output(o));
      }
      bool inferred = false;
      shape_inference::ShapeHandle handle;
      Status status =
          shape_refiner_->UpdateNode(node, /*relax=*/false, &inferred);
      if (!status.ok()) {
        return EmitErrorWithLocationStr(*node, status);
      }
      for (int o = 0; o < shape_context->num_outputs(); ++o) {
        if (!same_inferred_shape(shape_context, shape_context->output(o),
                                 existing[o])) {
          changed = true;
          break;
        }
      }
    }
    ++i;
  }
  if (i >= kMaxIterationCount) {
    LOG(WARNING) << "Graph shapes did not converge to a fixpoint within "
                 << kMaxIterationCount
                 << " iterations. Graph shapes may be conservative.";
  }
  VLOG(1) << "Graph shapes were inferred with " << (i - 1)
          << " extra rounds of analysis to reach a fixpoint.";
  return Status::OK();
}

StatusOr<mlir::Type> ImporterBase::InferInputType(const Node& node, int idx,
                                                  mlir::Builder builder) {
  if (specs_.enable_shape_inference) {
    // TODO(jpienaar): Remove this once the shape inference on import flag is
    // removed.
    ExtendedInferenceContext* shape_context =
        shape_refiner_->GetExtendedContext(&node);
    DataType dtype = shape_context->input_type(idx);
    auto* context = shape_context->get_context();
    return ConvertDataTypeAndShape(dtype, context->input(idx),
                                   context->input_handle_shapes_and_types(idx),
                                   context, builder);
  }
  DataType dtype = node.properties()->input_types[idx];
  mlir::Type element_type;
  TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &element_type));
  return mlir::UnrankedTensorType::get(element_type);
}

StatusOr<mlir::Type> ImporterBase::InferOutputType(const Node& node, int idx,
                                                   mlir::Builder builder) {
  DataType dtype = node.properties()->output_types[idx];

  // Returns the output type given an inference context.
  auto shape_ic = [&](shape_inference::InferenceContext* c) {
    return ConvertDataTypeAndShape(dtype, c->output(idx),
                                   c->output_handle_shapes_and_types(idx), c,
                                   builder);
  };

  if (specs_.enable_shape_inference) {
    // TODO(jpienaar): Remove this once the shape inference on import flag is
    // removed.
    ExtendedInferenceContext* shape_context =
        shape_refiner_->GetExtendedContext(&node);
    return shape_ic(shape_context->get_context());
  }

  // Treat TensorList init ops specially here as the op requires knowing its
  // element dtype.
  // TODO(jpienaar): Reconsider post refactoring shape functions.
  if (node.type_string() == "TensorListReserve" ||
      node.type_string() == "EmptyTensorList") {
    mlir::Type etype;
    if (auto element_dtype = node.attrs().Find("element_dtype")) {
      TF_RETURN_IF_ERROR(
          ConvertDataType(element_dtype->type(), builder, &etype));
    }
    return mlir::RankedTensorType::get(
        {}, mlir::TF::VariantType::get({mlir::UnrankedTensorType::get(etype)},
                                       etype.getContext()));
  }

  if (node.IsWhileNode()) {
    auto* output_shapes = node.attrs().Find("output_shapes");
    auto* element_types = node.attrs().Find("T");
    if (output_shapes && !output_shapes->list().shape().empty()) {
      const auto& output_shape = output_shapes->list().shape(idx);
      const auto& element_type = element_types->list().type(idx);
      return ConvertToMlirTensorType(output_shape, element_type, &builder);
    }
  }

  auto type_from_array_attr = [&node, &idx, &builder](
                                  absl::string_view output_shape_attr,
                                  absl::string_view element_type_attr) {
    auto* output_shapes = node.attrs().Find(output_shape_attr);
    auto* element_types = node.attrs().Find(element_type_attr);
    const auto& output_shape = output_shapes->list().shape(idx);
    const auto& element_type = element_types->list().type(idx);
    return ConvertToMlirTensorType(output_shape, element_type, &builder);
  };

  if (node.type_string() == "IteratorGetNext" ||
      node.type_string() == "IteratorGetNextSync" ||
      node.type_string() == "MultiDeviceIteratorGetNextFromShard")
    return type_from_array_attr("output_shapes", "output_types");

  if (node.type_string() == "InfeedDequeueTuple")
    return type_from_array_attr("shapes", "dtypes");

  if (node.type_string() == "InfeedDequeue") {
    assert(idx == 0);
    const auto& output_shape = node.attrs().Find("shape")->shape();
    const auto& element_type = node.attrs().Find("dtype")->type();
    return ConvertToMlirTensorType(output_shape, element_type, &builder);
  }

  // Returns a simple, more conservative unranked tensor type.
  auto default_type = [&]() -> StatusOr<mlir::Type> {
    mlir::Type element_type;
    TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &element_type));
    return mlir::UnrankedTensorType::get(element_type);
  };

  // Below we only try to do some shape inference for "source" ops, which have
  // no inputs.
  if (node.num_inputs() > 0) return default_type();

  // Do some simple inference here to get the function arguments correct for
  // this common case.
  // TODO(jpienaar): Reconsider post refactoring shape functions.
  if (node.IsArg()) {
    if (dtype == DT_RESOURCE) {
      const AttrValue* dtype_attr = node.attrs().Find("_handle_dtypes");
      const AttrValue* shape_attr = node.attrs().Find("_handle_shapes");
      if (dtype_attr && shape_attr) {
        if (dtype_attr->list().type().empty()) {
          return errors::InvalidArgument(
              "Invalid \"_handle_dtypes\" attribute value for _Arg node: ",
              shape_attr->DebugString());
        }
        if (shape_attr->list().shape().empty()) {
          return errors::InvalidArgument(
              "Invalid \"_handle_shapes\" attribute value for _Arg node: ",
              shape_attr->DebugString());
        }
        DataType dtype = dtype_attr->list().type(0);
        const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
        TF_ASSIGN_OR_RETURN(
            auto etype, ConvertToMlirTensorType(shape_proto, dtype, &builder));
        return mlir::UnrankedTensorType::get(mlir::TF::ResourceType::get(
            {etype.cast<TensorType>()}, builder.getContext()));
      } else {
        return mlir::UnrankedTensorType::get(
            mlir::TF::ResourceType::get(builder.getContext()));
      }
    } else if (auto shape = node.attrs().Find("_output_shapes")) {
      if (shape->has_list() && shape->list().shape_size() == 1) {
        return ConvertToMlirTensorType(shape->list().shape().at(0), dtype,
                                       &builder);
      }
    }
  }

  const tensorflow::OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(
      graph_->op_registry()->LookUp(node.type_string(), &op_reg_data));
  if (!op_reg_data) {
    DVLOG(1) << "Skipping inference for unregistered op "
             << node.type_string();
    return default_type();
  }
  if (op_reg_data->shape_inference_fn == nullptr) {
    DVLOG(1) << "Skipping inference for op without shape function "
             << node.type_string();
    return default_type();
  }
  shape_inference::InferenceContext c(graph_->versions().producer(),
                                      node.attrs(), op_reg_data->op_def,
                                      std::vector<PartialTensorShape>{}, {},
                                      /*input_tensors_as_shapes=*/{}, {});
  TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
  return shape_ic(&c);
}

StatusOr<TensorType> ImporterBase::ConvertDataTypeAndShape(
    DataType dtype, const shape_inference::ShapeHandle& handle,
    const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
    shape_inference::InferenceContext* context, mlir::Builder builder) {
  TF_ASSIGN_OR_RETURN(auto subtypes,
                      ConvertSubtypes(handle_subtypes, context, builder));

  mlir::Type element_type;
  if (dtype == DT_VARIANT)
    element_type = mlir::TF::VariantType::get(subtypes, context_);
  else if (dtype == DT_RESOURCE)
    element_type = mlir::TF::ResourceType::get(subtypes, context_);
  else
    TF_RETURN_IF_ERROR(
        ::tensorflow::ConvertDataType(dtype, builder, &element_type));

  return ConvertElementTypeAndShape(element_type, handle, context, builder);
}

StatusOr<TensorType> ImporterBase::ConvertElementTypeAndShape(
    mlir::Type element_type, const shape_inference::ShapeHandle& handle,
    shape_inference::InferenceContext* context, mlir::Builder builder) {
  if (!context->RankKnown(handle)) {
    return mlir::UnrankedTensorType::get(element_type);
  }

  // Sentinel for an unknown dimension size. getTensorType interprets any
  // negative value as an unknown dimension.
  // TODO(jmolloy): Ideally this shouldn't be a local sentinel.
  const int64_t kUnknownDim = -1;

  absl::InlinedVector<int64_t, 4> dimensions;
  int32 rank = context->Rank(handle);
  dimensions.reserve(rank);
  for (int i = 0; i < rank; ++i) {
    auto dim_handle = context->Dim(handle, i);
    if (!context->ValueKnown(dim_handle))
      dimensions.push_back(kUnknownDim);
    else
      dimensions.push_back(context->Value(dim_handle));
  }

  return mlir::RankedTensorType::get(
      llvm::makeArrayRef(dimensions.begin(), dimensions.end()), element_type);
}

StatusOr<ImporterBase::ElementSubtypes> ImporterBase::ConvertSubtypes(
    const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
    shape_inference::InferenceContext* context, mlir::Builder builder) {
  ElementSubtypes subtypes;
  if (!handle_subtypes) return subtypes;

  subtypes.reserve(handle_subtypes->size());
  for (const auto& subtype : *handle_subtypes) {
    mlir::Type element_type;
    TF_RETURN_IF_ERROR(
        ::tensorflow::ConvertDataType(subtype.dtype, builder, &element_type));
    TF_ASSIGN_OR_RETURN(TensorType type,
                        ConvertElementTypeAndShape(element_type, subtype.shape,
                                                   context, builder));
    subtypes.push_back(type);
  }
  return subtypes;
}

Status ImporterBase::ConvertFunctionCallAttribute(const std::string& base_name,
                                                  const AttrValue& value,
                                                  NamedAttrList* attributes) {
  TF_ASSIGN_OR_RETURN(auto func_attr,
                      ConvertFunctionCallName(value.func().name()));
  attributes->push_back(builder_.getNamedAttr(base_name, func_attr));

  for (const auto& it : value.func().attr()) {
    auto name = absl::StrCat(base_name, ".", it.first);
    TF_ASSIGN_OR_RETURN(auto value, ConvertAttributeValue(it.second));
    attributes->push_back(builder_.getNamedAttr(name, value));
  }
  return Status::OK();
}

StatusOr<mlir::FlatSymbolRefAttr> ImporterBase::ConvertFunctionCallName(
    const std::string& func_name) {
  TF_RETURN_IF_ERROR(ConvertLibFunction(func_name));
  auto mlir_func_name = (*tf_name_to_mlir_name_)[func_name];
  auto func = module_.lookupSymbol<mlir::FuncOp>(mlir_func_name);
  return builder_.getSymbolRefAttr(func);
}

StatusOr<mlir::Attribute> ImporterBase::ConvertAttributeValue(
    const AttrValue& value) {
  switch (value.value_case()) {
    case AttrValue::kFunc: {
      // TODO(b/156546237): Unify kFunc/NameAttrList attribute representation.
      // Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue
      // will not use this representation.
      NamedAttrList attrs;
      for (const auto& func_attr : value.func().attr()) {
        TF_ASSIGN_OR_RETURN(
            auto attr, ImporterBase::ConvertAttributeValue(func_attr.second));
        attrs.push_back(builder_.getNamedAttr(func_attr.first, attr));
      }
      auto func_attrs = builder_.getDictionaryAttr(attrs);
      return mlir::TF::FuncAttr::get(context_, value.func().name(),
                                     func_attrs);
    }
    case AttrValue::kList: {
      if (!value.list().func().empty()) {
        absl::InlinedVector<mlir::Attribute, 8> attrs;
        for (const auto& item : value.list().func()) {
          TF_ASSIGN_OR_RETURN(auto attr, ConvertFunctionCallName(item.name()));
          if (item.attr_size() != 0)
            return errors::Unimplemented(
                "func attributes with non-zero attr.size()");
          attrs.push_back(attr);
        }
        return builder_.getArrayAttr(
            llvm::makeArrayRef(attrs.begin(), attrs.end()));
      }
      return ConvertNonFuncAttributeValue(value, &builder_);
    }
    default:
      return ConvertNonFuncAttributeValue(value, &builder_);
  }
}

void ImporterBase::GetArgsAndRetsFromFunctionBody(
    const FunctionBody& fbody, absl::InlinedVector<OutputTensor, 4>* arg_nodes,
    absl::InlinedVector<OutputTensor, 4>* ret_nodes,
    absl::InlinedVector<Node*, 4>* control_ret_nodes) {
  arg_nodes->reserve(fbody.arg_nodes.size());
  ret_nodes->reserve(fbody.ret_nodes.size());
  for (auto arg : fbody.arg_nodes) {
    arg_nodes->emplace_back(arg, 0);
  }
  for (auto ret : fbody.ret_nodes) {
    ret_nodes->emplace_back(ret, 0);
  }
  *control_ret_nodes = fbody.control_ret_nodes;
}

Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
  // If the library function has been converted already, nothing needs to be
  // done.
  if (tf_name_to_mlir_name_->find(std::string(func_name)) !=
      tf_name_to_mlir_name_->end())
    return Status::OK();

  std::string mlir_func_name(
      function_name_uniquifier_->GetUniqueName(func_name));
  (*tf_name_to_mlir_name_)[std::string(func_name)] = mlir_func_name;

  const auto& func_lib = graph_flib_;
  const auto* func_def = func_lib.Find(std::string(func_name));
  if (func_def == nullptr) {
    return errors::FailedPrecondition(
        absl::StrCat("Failed to find function '", StringRefToView(func_name),
                     "'. The imported TensorFlow GraphDef is ill-formed."));
  }

  // Converts the function definition to a graph.
  std::unique_ptr<FunctionBody> fbody;
  TF_RETURN_IF_ERROR(
      FunctionDefToBodyHelper(*func_def, AttrSlice(), &func_lib, &fbody));

  // Converts the argument and return types to MLIR types.
  absl::InlinedVector<mlir::NamedAttribute, 8> attributes;
  attributes.reserve(func_def->attr_size());
  for (const auto& name_and_value : func_def->attr()) {
    // This is a function definition attribute, so it shouldn't contain a
    // kFunc attribute; it is treated as a normal attribute.
    TF_ASSIGN_OR_RETURN(auto attr,
                        ConvertAttributeValue(name_and_value.second));
    std::string attr_name =
        mangling_util::MangleAttributeName(name_and_value.first);
    attributes.push_back(builder_.getNamedAttr(attr_name, attr));
  }

  // Checks the OpDef's stateful attribute and imports it as a function
  // attribute.
  if (func_def->signature().is_stateful()) {
    auto stateful_str = mlir::TF::TensorFlowDialect::GetStatefulAttrName();
    attributes.push_back(
        builder_.getNamedAttr(stateful_str, builder_.getUnitAttr()));
  }
1272
1273 // Checks for an associated custom gradient function. Adds it to the attribute
1274 // list of this function.
1275 auto grad_func_name = func_lib.FindGradient(std::string(func_name));
1276 if (!grad_func_name.empty()) {
1277 TF_RETURN_IF_ERROR(ConvertLibFunction(grad_func_name));
1278 auto mlir_grad_func_name = (*tf_name_to_mlir_name_)[grad_func_name];
1279 auto grad_func = module_.lookupSymbol<mlir::FuncOp>(mlir_grad_func_name);
1280 auto gradient_attr = builder_.getSymbolRefAttr(grad_func);
1281 auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
1282 attributes.push_back(builder_.getNamedAttr(grad_string, gradient_attr));
1283 }
1284
1285 // Converts the graph to an MLIR function and adds it to the module.
1286 // We populate the NodeSpec so that all the _Arg ops get their shape
1287 // added correctly.
1288 GraphImportConfig specs;
1289 specs.enable_shape_inference = specs_.enable_shape_inference;
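  // Propagate any "_input_shapes" attribute on the FunctionDef to the
  // per-argument import specs below, so argument shapes survive import.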
  for (const auto& name_and_value : func_def->attr()) {
    if (name_and_value.first == "_input_shapes") {
      auto& list = name_and_value.second.list();
      auto& signature = func_def->signature();
      if (list.shape_size() != signature.input_arg_size()) {
        return errors::FailedPrecondition(
            "Number of input arguments must be equal to the length of "
            "_input_shapes attribute in function '",
            StringRefToView(func_name), "'.");
      }
      for (int i = 0; i < list.shape_size(); i++) {
        auto& input_arg = signature.input_arg(i);
        auto& array_info = specs.inputs[input_arg.name()];
        array_info.imported_dtype = input_arg.type();
        array_info.shape = list.shape(i);
      }
    }
  }

  ImporterBase child_importer(graph_flib_, debug_info_, specs, module_,
                              tf_name_to_mlir_name_, function_name_uniquifier_,
                              func_name);
  TF_RETURN_IF_ERROR(child_importer.PrepareConvert(*fbody->graph));

  TF_ASSIGN_OR_RETURN(auto func_type,
                      child_importer.InferLibFunctionType(*fbody));

  absl::InlinedVector<OutputTensor, 4> arg_nodes;
  absl::InlinedVector<OutputTensor, 4> ret_nodes;
  absl::InlinedVector<Node*, 4> control_ret_nodes;
  GetArgsAndRetsFromFunctionBody(*fbody, &arg_nodes, &ret_nodes,
                                 &control_ret_nodes);

  TF_RETURN_IF_ERROR(child_importer.Convert(
      mlir_func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes,
      llvm::makeArrayRef(attributes.begin(), attributes.end())));
  return Status::OK();
}

Status ImporterBase::PruneUnreachableNodes(
    std::unordered_map<string, Node*>* node_name_map) {
  std::unordered_set<const Node*> prune_start;
  TF_RETURN_IF_ERROR(GetInputOutputNodes(*node_name_map, &prune_start));

  if (!prune_start.empty()) {
    if (PruneForReverseReachability(graph_.get(), prune_start)) {
      VLOG(1) << "Pruned unused nodes in graphdef";
    } else {
      VLOG(1) << "No unused nodes in graphdef to prune";
    }
  } else {
    VLOG(1) << "No output nodes specified, skipping pruning";
  }
  return Status::OK();
}

Status ImporterBase::ConvertFeedsToPlaceholders(
    std::unordered_map<string, Node*>* node_name_map) {
  // Feeds (edges) are converted into single-output placeholder nodes to
  // simplify the conversion process.
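  // For example, a feed on output 0 of a non-Placeholder node "conv1" is
  // rerouted through a newly created Placeholder node that takes over that
  // output; the remapping is recorded in remapped_feeds_.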
  TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs));
  for (const auto& it : feeds_by_node) {
    TensorId tensor = ParseTensorName(it.first);
    auto jt = node_name_map->find(std::string(tensor.node()));
    if (jt == node_name_map->end()) {
      return errors::FailedPrecondition(
          absl::StrCat("Graph does not contain node: ", tensor.node()));
    }

    Node* node = jt->second;
    auto op_name = node->op_def().name();
    if (op_name != "Placeholder" && op_name != "LegacyFedInput" &&
        op_name != FunctionLibraryDefinition::kArgOp) {
      for (const auto& output_tensor : it.second) {
        const int index = output_tensor.first;
        const ArrayInfo& array_info = output_tensor.second->second;

        DataType dtype = array_info.imported_dtype;
        // Uses the existing output type if it isn't specified by the user.
        if (dtype == DT_INVALID) {
          dtype = node->output_type(index);
        }

        TF_ASSIGN_OR_RETURN(
            auto placeholder_node_and_removed,
            CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index,
                                         *node_name_map));

        Node* placeholder_node = placeholder_node_and_removed.first;
        if (placeholder_node->in_edges().empty()) {
          graph_->AddControlEdge(graph_->source_node(), placeholder_node,
                                 true /* skip test for duplicates */);
        }
        if (placeholder_node->out_edges().empty()) {
          graph_->AddControlEdge(placeholder_node, graph_->sink_node(),
                                 true /* skip test for duplicates */);
        }
        remapped_feeds_[{it.first, index}] = placeholder_node->name();
        (*node_name_map)[placeholder_node->name()] = placeholder_node;
      }
    }
  }
  return Status::OK();
}

Status ImporterBase::PrepareConvert(const Graph& graph) {
  TF_RETURN_IF_ERROR(RemoveBackedges(graph));
  TF_RETURN_IF_ERROR(CopyStackTraces(graph, graph_.get()));

  auto node_name_map = graph_->BuildNodeNameIndex();

  if (specs_.enable_shape_inference) {
    // TODO(jpienaar): Remove once infer shapes on import flag is removed.
    TF_RETURN_IF_ERROR(AddNodesToShapeRefiner(&node_name_map));
  } else {
    TF_RETURN_IF_ERROR(ConvertFeedsToPlaceholders(&node_name_map));
  }

  // Prune nodes in the graph that are not reachable from the output.
  if (specs_.prune_unused_nodes) {
    TF_RETURN_IF_ERROR(PruneUnreachableNodes(&node_name_map));
  }

  if (!specs_.enable_shape_inference) {
    // Re-initialize ordered_nodes_ since we might have modified the graph.
    GetReversePostOrder(
        *graph_, &ordered_nodes_,
        [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
  }

  return Status::OK();
}

Status ImporterBase::Convert(
    llvm::StringRef func_name, mlir::FunctionType func_type,
    const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
    const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
    const absl::InlinedVector<Node*, 4>& control_ret_nodes,
    llvm::ArrayRef<mlir::NamedAttribute> attrs) {
  // TODO(b/122040776): Uses debug info for FunctionDef.
  auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
                                       func_name, func_type, attrs);

  module_.push_back(function);
  // Seeds the builder with an initial block.
  function.addEntryBlock();
  builder_ = mlir::OpBuilder(function.getBody());

  // Create the graph operation in which we will convert the individual nodes.
  auto graph = builder_.create<mlir::tf_executor::GraphOp>(
      function.getLoc(), func_type.getResults());
  builder_.createBlock(&graph.body());

  for (const Node* node : ordered_nodes_) {
    TF_RETURN_IF_ERROR(ConvertNode(*node));
  }

  // Adds the backedges back to the function by creating the source and sink
  // pairs.
  TF_RETURN_IF_ERROR(AddBackedges());

  TF_RETURN_IF_ERROR(ConvertFunctionArgAndRets(function, graph,
                                               func_type.getInputs(), arg_nodes,
                                               ret_nodes, control_ret_nodes));

  // TODO(jpienaar): Update after removing shape_refiner_.
  if (!specs_.enable_shape_inference) {
    // Refine the graph's result types given the more precise fetch types.
    auto fetch = graph.GetFetch();
    bool all_equal = true;
    for (auto it :
         llvm::zip_first(graph.getResults(), fetch.getOperandTypes())) {
      auto rt = std::get<1>(it);
      if (rt == std::get<0>(it).getType()) continue;
      std::get<0>(it).setType(rt);
      all_equal = false;
    }
    if (!all_equal) {
      function.setType(mlir::FunctionType::get(function.getContext(),
                                               func_type.getInputs(),
                                               graph.getResultTypes()));
    }
  }

  return Status::OK();
}

Status ImporterBase::ConvertFunctionArgAndRets(
    mlir::FuncOp func, mlir::tf_executor::GraphOp graph_op,
    llvm::ArrayRef<mlir::Type> arg_types,
    const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
    const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
    const absl::InlinedVector<Node*, 4>& control_ret_nodes) {
  // Store the arg/return attributes as a list rather than uniquing during
  // construction.
  llvm::SmallVector<mlir::NamedAttrList, 4> arg_attrs;
  arg_attrs.resize(func.getNumArguments());
  llvm::SmallVector<mlir::NamedAttrList, 4> ret_attrs;
  ret_attrs.resize(func.getNumResults());

  auto set_attributes_on_func = [&](Node* node, int64_t index, bool is_arg) {
    for (const auto& node_attr : node->attrs()) {
      const auto& key = node_attr.first;
      // Only import optional attributes (e.g., those starting with an
      // underscore).
      if (key.empty() || key[0] != '_') continue;
      // Ignore shape inference attributes as shape information is already
      // populated in the result type.
      if (IsOutputShapesAttribute(node_attr.second, key) ||
          IsResourceOutputShapesAttribute(node_attr.second, key))
        continue;
      TF_ASSIGN_OR_RETURN(auto converted_attr,
                          ConvertAttributeValue(node_attr.second));
      std::string dialect_attribute = "tf." + key;
      if (is_arg) {
        arg_attrs[index].set(dialect_attribute, converted_attr);
      } else {
        func.setResultAttr(index, dialect_attribute, converted_attr);
        ret_attrs[index].set(dialect_attribute, converted_attr);
      }
    }
    return Status::OK();
  };

  auto* bb = &func.front();
  llvm::SmallDenseMap<std::pair<Node*, int>, mlir::Value, 4>
      arg_nodes_to_values;
  for (int i = 0, e = arg_types.size(); i < e; ++i) {
    auto& arg_node = arg_nodes[i];
    // The lookup can't fail here: otherwise some nodes in the function haven't
    // been converted to MLIR operations and don't have a mapping.
    mlir::Operation* island = node_values_.find(arg_node.node->id())->second;

    auto bb_arg = bb->getArgument(i);
    mlir::Value arg_def = bb_arg;

    if (island->getNumResults() != 2)
      return errors::InvalidArgument(
          "Only feed output tensors of single output nodes are supported");

    // Collect mapping of OutputTensor to associated block arg.
    arg_nodes_to_values.try_emplace({arg_node.node, arg_node.index}, arg_def);
    island->getResult(0).replaceAllUsesWith(arg_def);
    // Erase control outputs from feed.
    auto control_uses = island->getResult(1).getUses();
    for (auto& control_use : llvm::make_early_inc_range(control_uses))
      control_use.getOwner()->eraseOperand(control_use.getOperandNumber());

    if (!arg_node.node->requested_device().empty())
      arg_attrs[i].set("tf.device", builder_.getStringAttr(
                                        arg_node.node->requested_device()));

    if (arg_node.node->IsArg()) {
      TF_RETURN_IF_ERROR(
          set_attributes_on_func(arg_node.node, i, /*is_arg=*/true));
    }

    island->dropAllReferences();
    island->erase();
  }

  llvm::SmallVector<mlir::Value, 8> inst_to_return;
  for (auto ret_and_idx : llvm::enumerate(ret_nodes)) {
    const auto& ret = ret_and_idx.value();
    auto* inst = node_values_[ret.node->id()];
    if (ret.node->IsRetval()) {
      if (!ret.node->requested_device().empty())
        ret_attrs[ret_and_idx.index()].set(
            "tf.device", builder_.getStringAttr(ret.node->requested_device()));
      TF_RETURN_IF_ERROR(set_attributes_on_func(ret.node, ret_and_idx.index(),
                                                /*is_arg=*/false));
      // Lookup the instruction inside the island.
      auto island_op = llvm::cast<mlir::tf_executor::IslandOp>(inst);
      mlir::Operation* inner_op = &island_op.GetBody().front();
      // Remove the kRetOp or kDeviceRetOp operation and return its operand.
      // kRetOp and kDeviceRetOp should have just one operand unless they have
      // control dependencies.
      if (inner_op->getNumOperands() != 1)
        return errors::Unimplemented("Return node with multiple inputs.");
      inst_to_return.push_back(inner_op->getOperand(0));
      inst->dropAllReferences();
      inst->erase();
    } else {
      // Lookup and use the block arg if the fetch is a feed.
      auto it = arg_nodes_to_values.find({ret.node, ret.index});
      if (it != arg_nodes_to_values.end())
        inst_to_return.push_back(it->second);
      else
        inst_to_return.push_back(inst->getResult(ret.index));
    }
  }

  for (Node* control_ret : control_ret_nodes) {
    auto* inst = node_values_[control_ret->id()];
    inst_to_return.push_back(*std::prev(inst->result_end()));
  }

  // Terminate the function by adding a Fetch operation to terminate the graph
  // and a return operation to return the Graph results.
  builder_.setInsertionPointToEnd(&graph_op.body().front());
  builder_.create<mlir::tf_executor::FetchOp>(graph_op.getLoc(),
                                              inst_to_return);
  builder_.setInsertionPointToEnd(bb);
  builder_.create<mlir::ReturnOp>(mlir::UnknownLoc::get(context_),
                                  graph_op.getResults());

  func.setAllArgAttrs(
      llvm::to_vector<4>(llvm::map_range(arg_attrs, [&](NamedAttrList& list) {
        return list.getDictionary(context_);
      })));
  func.setAllResultAttrs(
      llvm::to_vector<4>(llvm::map_range(ret_attrs, [&](NamedAttrList& list) {
        return list.getDictionary(context_);
      })));

  return Status::OK();
}

mlir::Location ImporterBase::GetLocation(const Node& node) {
  DVLOG(1) << "Getting location for " << node.name() << " " << &node;
  // TODO(b/142400497): What is the semantic contract for locations?
  const auto& debug_info = debug_info_.traces();

  // Create a location for node `name` in function `function_name`.
  auto create_location = [&](llvm::StringRef name,
                             llvm::StringRef function_name) -> mlir::Location {
    // Use the concatenation of function and node names as the lookup key into
    // the debug info. This matches the way that the key is formed on the
    // python side.
    //
    // We also use this as the name for the NameLoc for ops in function, since
    // otherwise our names could collide across functions.
    // For ops in the main graph, we omit the "@function_name" (which would be
    // just "@" since function_name would be empty) because some code seems to
    // depend on the name being this way for correctness.
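    // For example, node "MatMul_1" in function "while_body_8" is keyed (and
    // named) as "MatMul_1@while_body_8", while in the main graph the NameLoc
    // is just "MatMul_1".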
    std::string debug_info_key = (name + "@" + function_name).str();
    std::string name_for_name_loc =
        function_name.empty() ? name.str() : (name + "@" + function_name).str();
    auto name_loc_id = mlir::Identifier::get(name_for_name_loc, context_);

    llvm::SmallVector<mlir::Location, 4> locations;
    // Prefer stack traces if available, fall back to debug info if not, and
    // finally to just the name.
    if (auto stack_trace = node.GetStackTrace()) {
      DVLOG(1) << "Stack available for " << node.name();
      absl::Span<const StackFrame> frames = stack_trace->ToFrames();
      locations.reserve(frames.size());
      for (const StackFrame& frame : llvm::reverse(frames)) {
        auto file_name = mlir::Identifier::get(frame.file_name, context_);
        // Use col 1 as there is no column info in StackTrace.
        auto file_line_loc = mlir::FileLineColLoc::get(
            file_name, frame.line_number, 1, context_);
        locations.push_back(file_line_loc);
      }
    } else {
      DVLOG(1) << "No stack trace for " << node.name();
      const auto location_it = debug_info.find(debug_info_key);
      if (location_it != debug_info.end()) {
        DVLOG(1) << "Available serialized debug info for " << node.name();
        // Convert the stack trace to a chain of mlir::CallSiteLocs.
        const auto& trace = location_it->second;
        locations.reserve(trace.file_line_cols_size());
        for (const auto& location : trace.file_line_cols()) {
          const auto& file = debug_info_.files(location.file_index());
          auto file_name = mlir::Identifier::get(file, context_);
          auto file_line_loc = mlir::FileLineColLoc::get(
              file_name, location.line(), location.col(), context_);
          locations.push_back(file_line_loc);
        }
      }
    }

    // If there are no locations in the stack trace, fall back to just a
    // NameLoc with no child.
    if (locations.empty()) return mlir::NameLoc::get(name_loc_id, context_);

    // Use the front FileLineColLoc to generate a NameLoc.
    mlir::Location node_name_loc =
        mlir::NameLoc::get(name_loc_id, locations.front());

    // If there are more locations, generate a stack trace; otherwise just
    // return the name loc.
    auto callsite_locs = llvm::makeArrayRef(locations).drop_front();
    return callsite_locs.empty()
               ? node_name_loc
               : mlir::CallSiteLoc::get(node_name_loc, callsite_locs);
  };

  // For NextIteration nodes, the location is used to pair source and sink
  // nodes. Hence, we use the node name as the location to keep it unique.
  // TODO(prakalps): In the future the plan is to use tokens to pair
  // source/sink nodes. Then NextIteration nodes would not need to be handled
  // separately.
  if (node.type_string() == "NextIteration")
    return create_location(node.name(), function_name_for_debug_info_);

  if (node.GetStackTrace())
    return create_location(node.name(), function_name_for_debug_info_);

  const auto& node_def = node.def();
  auto original_nodes =
      node_def.experimental_debug_info().original_node_names();
  auto original_funcs =
      node_def.experimental_debug_info().original_func_names();

  if (original_nodes.empty()) {
    return create_location(node.name(), function_name_for_debug_info_);
  } else {
    // If the original nodes are defined, then we use them to get a list of
    // call sites, and then fuse them to a single fused location, with the name
    // of the node_def.
    llvm::SmallVector<mlir::Location, 4> node_locations;
    node_locations.reserve(original_nodes.size() + 1);

    // Store the names from the experimental_debug_info.
    for (int i = 0, e = original_nodes.size(); i != e; ++i) {
      auto node_name = original_nodes[i];
      auto func_name = (i < original_funcs.size()) ? original_funcs[i] : "";
      node_locations.push_back(create_location(node_name, func_name));
    }
    // Store the name of the node_def.
    node_locations.push_back(
        create_location(node.name(), function_name_for_debug_info_));
    return mlir::FusedLoc::get(node_locations, context_);
  }
}

Status ImporterBase::EmitErrorWithLocationStr(const Node& node,
                                              const Status& error_status) {
  const mlir::Location location = GetLocation(node);
  mlir::emitError(location);
  return error_handler_.Combine(error_status);
}

mlir::Operation* ImporterBase::CreateOperation(
    const Node& node, llvm::StringRef node_type_name,
    const mlir::OperationState& result,
    const llvm::SmallVectorImpl<mlir::Value>& control_operands) {
  // For the tf.executor specific operations (not wrapped in an island), we
  // have an extra returned value for the control result, and we concatenate
  // control and non-control operands.
  mlir::SmallVector<mlir::Type, 4> types(result.types);
  types.push_back(mlir::tf_executor::ControlType::get(builder_.getContext()));
  mlir::SmallVector<mlir::Value, 4> operands(result.operands);
  operands.append(control_operands.begin(), control_operands.end());

  auto loc = result.location;
  // Dispatch based on the name and create the appropriate operation.
  if (node.IsSwitch()) {
    // Switch and _SwitchN are both in the switch class; differentiate based
    // on the op name.
    if (node.op_def().name() == "_SwitchN") {
      return builder_.create<mlir::tf_executor::SwitchNOp>(loc, types, operands,
                                                           result.attributes);
    }
    return builder_.create<mlir::tf_executor::SwitchOp>(loc, types, operands,
                                                        result.attributes);
  }
  if (node.IsMerge()) {
    return builder_.create<mlir::tf_executor::MergeOp>(loc, types, operands,
                                                       result.attributes);
  }
  if (node.IsNextIteration()) {
    // NextIteration is a bit special: we create a pair of operations that are
    // linked together through a token returned by the source.
    // We make use of a separate builder to insert the source at the top of
    // the block.
    mlir::OpBuilder builder_at_begin(builder_.getBlock(),
                                     builder_.getBlock()->begin());
    auto source_op =
        builder_at_begin.create<mlir::tf_executor::NextIterationSourceOp>(
            loc, operands[0].getType(), result.attributes);
    return builder_.create<mlir::tf_executor::NextIterationSinkOp>(
        loc, source_op.token(), operands, result.attributes);
  }
  if (node.IsLoopCond()) {
    return builder_.create<mlir::tf_executor::LoopCondOp>(loc, types, operands,
                                                          result.attributes);
  }
  if (node.IsEnter()) {
    return builder_.create<mlir::tf_executor::EnterOp>(loc, types, operands,
                                                       result.attributes);
  }
  if (node.IsExit()) {
    return builder_.create<mlir::tf_executor::ExitOp>(loc, types, operands,
                                                      result.attributes);
  }
  if (node.IsControlTrigger()) {
    return builder_.create<mlir::tf_executor::ControlTriggerOp>(
        loc, operands, result.attributes);
  }
  // Regular TensorFlow operations are wrapped in a tf_executor.island.
  auto island = builder_.create<mlir::tf_executor::IslandOp>(
      result.location, types, control_operands,
      mlir::ArrayRef<mlir::NamedAttribute>{});
  island.body().push_back(new mlir::Block);
  mlir::OpBuilder island_builder =
      mlir::OpBuilder::atBlockEnd(&island.GetBody());

  // Create the operation inside the island now.
  mlir::Operation* inner_op = island_builder.createOperation(result);

  // Sets the operand_segment_sizes or result_segment_sizes attribute on the
  // op.
  const auto set_segment_sizes_attr =
      [&](const NameRangeMap& arg_ranges,
          const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
          llvm::StringRef attr_name) {
        std::vector<mlir::Attribute> values;
        values.reserve(args.size());
        for (const auto& arg : args) {
          auto range = arg_ranges.at(arg.name());
          values.push_back(
              island_builder.getI32IntegerAttr(range.second - range.first));
        }
        auto attr_type =
            mlir::VectorType::get(args.size(), builder_.getIntegerType(32));
        auto attr_value = mlir::DenseElementsAttr::get(attr_type, values);
        inner_op->setAttr(attr_name, attr_value);
      };
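  // For example, for an op whose OpDef declares input args (x: T, ys: N * T)
  // and a node with N = 3, the derived operand_segment_sizes attribute would
  // be dense<[1, 3]> : vector<2xi32>.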
1808
1809 if (inner_op->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>() ||
1810 inner_op->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
1811 // The op has multiple variadic operands or results.
1812 // Calculate operand and result segment sizes using the OpDef.
1813 NameRangeMap input_ranges, output_ranges;
1814 // This will fail only if the OpDef is syntactically invalid.
1815 // TODO(jpienaar): Convert this CHECK into a properly propagated error.
1816 TF_CHECK_OK(
1817 NameRangesForNode(node, node.op_def(), &input_ranges, &output_ranges));
1818 if (inner_op->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>()) {
1819 // Add derived "operand_segment_sizes" attr to the created operation.
1820 // TODO(b/146937733): Don't use <void> here.
1821 set_segment_sizes_attr(input_ranges, node.op_def().input_arg(),
1822 mlir::OpTrait::AttrSizedOperandSegments<
1823 void>::getOperandSegmentSizeAttr());
1824 }
1825
1826 if (inner_op->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
1827 // Add derived "result_segment_sizes" attr to the created operation.
1828 // TODO(b/146937733): Don't use <void> here.
1829 set_segment_sizes_attr(output_ranges, node.op_def().output_arg(),
1830 mlir::OpTrait::AttrSizedResultSegments<
1831 void>::getResultSegmentSizeAttr());
1832 }
1833 }
1834
1835 mlir::OperationName name = inner_op->getName();
1836 if (!name.getAbstractOperation() &&
1837 // Skip unmodelled ops that are handled differently.
1838 (node_type_name != "_Arg" && node_type_name != "_Retval")) {
1839 if (GetUnmodelledOpTypes().insert(name.getStringRef()).second) {
1840 LOG(INFO) << "Unmodelled op type `" << node.type_string() << "`"
1841 << (node.op_def().is_stateful()
1842 ? " is stateful but effects not modelled"
1843 : " is not stateful but will be treated as such "
1844 "conservatively");
1845 }
1846 }
1847
1848 // Add the terminator for the island
1849 island_builder.create<mlir::tf_executor::YieldOp>(result.location,
1850 inner_op->getResults());
1851 return island.getOperation();
1852 }
1853
Status ImporterBase::ConvertNode(const Node& node) {
  if (!node.IsOp()) {
    // Don't import the pseudo-nodes _SOURCE or _SINK. These are added by
    // Graph and don't exist in GraphDef.
    return Status::OK();
  }

  // If it is a custom op, its definition should be found in the library. We
  // create the MLIR function and insert it into the module if it doesn't
  // exist.
  std::string node_type_name = node.type_string();
  const auto* func_def = graph_flib_.Find(node_type_name);
  bool convert_to_legacy_call = false;
  if (func_def) {
    TF_RETURN_IF_ERROR(ConvertLibFunction(node_type_name));
    node_type_name = (*tf_name_to_mlir_name_)[node_type_name];
    convert_to_legacy_call = true;
  }

  auto get_full_op_name = [&](const std::string& op_name) {
    const char* kTfPrefix = "tf.";
    return kTfPrefix + op_name;
  };

  std::string op_name = get_full_op_name(node_type_name);
  if (back_edge_node_output_.contains(&node)) {
    op_name = op_name + ".sink";
  }

  mlir::OperationState result(GetLocation(node), op_name);
  for (int i = 0; i < node.num_outputs(); ++i) {
    // The backedge has been removed, so we shouldn't count the corresponding
    // output from the src node when converting to an operation.
    if (back_edge_node_output_.contains(&node) &&
        back_edge_node_output_[&node] == i) {
      continue;
    }
    TF_ASSIGN_OR_RETURN(auto type, InferOutputType(node, i, builder_));
    result.types.push_back(type);
  }

  // Surprisingly, input edges can be nondeterministically ordered. This
  // particularly seems to be the case for the control edges between _SOURCE
  // and _SINK that the Graph constructor inserts. Copy the input edges and
  // sort the edges, but only the control edges, not data edges!
  // TODO(jmolloy): We should probably just ignore _SOURCE and _SINK nodes.
  // They'll break roundtripping anyway unless we strip them when converting
  // back to graphdef.
  absl::InlinedVector<const Edge*, 8> in_edges(node.in_edges().size());
  absl::c_copy(node.in_edges(), in_edges.begin());
  absl::c_stable_sort(in_edges, [](const Edge* e1, const Edge* e2) {
    if (e1->IsControlEdge() && !e2->IsControlEdge()) return false;
    if (!e1->IsControlEdge() && e2->IsControlEdge()) return true;
    if (e1->IsControlEdge() && e2->IsControlEdge())
      return e1->src()->id() < e2->src()->id();
    return e1->dst_input() < e2->dst_input();
  });
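  // After the sort, data edges come first (ordered by destination input
  // index), followed by control edges (ordered by source node id).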

  result.operands.reserve(in_edges.size());

  // Collect the control operands separately; they will be held by the island.
  mlir::SmallVector<mlir::Value, 8> control_operands;

  for (const auto* input_edge : in_edges) {
    const Node& input_node = *input_edge->src();
    if (input_node.IsSource()) {
      if (in_edges.size() != 1) {
        return errors::FailedPrecondition(
            "The node has other inputs besides the _Source node");
      }
      // We don't import the _SOURCE node.
      continue;
    }
    if (input_node.IsArg() && input_edge->IsControlEdge()) {
      // Currently we have not reached consensus as to what TF function
      // semantics are (b/133509504). Here we assume that all arguments to a
      // function should be available before we start execution of any internal
      // node. This makes the control dependencies between function arguments
      // and internal nodes redundant, and so we do not import them. The TF
      // inliner however assumes no such dependency between function args and
      // internal nodes exists, unless explicitly stated. Since we drop control
      // dependencies here, it leads to loss of information. If the function is
      // inlined later, the inliner would not know of these explicit control
      // dependencies present in the original graph.
      continue;
    }
    if (node_values_.find(input_node.id()) == node_values_.end())
      return errors::FailedPrecondition(
          "Graph not traversed in reverse post order; use seen before def!");
    mlir::Operation* inst = node_values_[input_node.id()];
    if (input_edge->IsControlEdge())
      control_operands.push_back(inst->getResult(inst->getNumResults() - 1));
    else
      result.operands.push_back(inst->getResult(input_edge->src_output()));
  }

  using FuncPairType = std::pair<const std::string*, const AttrValue*>;
  std::vector<FuncPairType> funcs;
  result.attributes.reserve(node.attrs().size() + 2);
  auto abstract_op = result.name.getAbstractOperation();
  auto derived_op =
      abstract_op
          ? abstract_op->getInterface<mlir::DerivedAttributeOpInterface>()
          : nullptr;
  for (const auto& name_and_value : node.attrs()) {
    const auto& attr_name = name_and_value.first;
    // Skip adding derived attributes to the generated op.
    if (derived_op && derived_op->isDerivedAttribute(attr_name)) continue;
    const AttrValue& attr_value = name_and_value.second;

    // Remove the _output_shapes attribute that will be added by the exporter.
    if (IsOutputShapesAttribute(attr_value, attr_name)) continue;

    if (attr_value.value_case() == AttrValue::kFunc) {
      // Attribute iteration order is not defined for protocol buffer Map.
      // Process function attributes separately in lexicographical order to
      // have a deterministic order of functions in the constructed IR.
      funcs.emplace_back(&attr_name, &attr_value);
    } else {
      TF_ASSIGN_OR_RETURN(auto attr, ConvertAttributeValue(attr_value));
      result.attributes.push_back(builder_.getNamedAttr(attr_name, attr));
    }
  }

  auto comparator = [](const FuncPairType& a, const FuncPairType& b) {
    return *a.first < *b.first;
  };
  std::sort(funcs.begin(), funcs.end(), comparator);
  for (const auto& func : funcs) {
    TF_RETURN_IF_ERROR(ConvertFunctionCallAttribute(*func.first, *func.second,
                                                    &result.attributes));
  }

  const auto& node_def = node.def();
  result.attributes.push_back(builder_.getNamedAttr(
      "device", builder_.getStringAttr(std::string(node_def.device()))));

  // Map user function calls to LegacyCall ops and add the user function name
  // as an attribute.
  if (convert_to_legacy_call) {
    result.name = mlir::OperationName(get_full_op_name("LegacyCall"), context_);
    mlir::SymbolRefAttr val = builder_.getSymbolRefAttr(node_type_name);
    result.addAttribute("f", val);

    if (!result.attributes.get("_disable_call_shape_inference")) {
      result.addAttribute("_disable_call_shape_inference",
                          builder_.getBoolAttr(false));
    }
  }

  auto composite_control_flow_op = [&](const std::string& name) {
    result.name = mlir::OperationName(get_full_op_name(name), context_);
    bool stateless = absl::StartsWith(node_type_name, "Stateless");
    mlir::BoolAttr val = builder_.getBoolAttr(stateless);
    result.attributes.push_back(builder_.getNamedAttr("is_stateless", val));
  };

  // Map the Case/If/While and StatelessCase/If/While ops in TensorFlow to the
  // common Case/If/While op in MLIR and add the differentiating attribute.
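  // For example, both "If" and "StatelessIf" become tf.If; only the stateless
  // variant gets is_stateless = true.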
  if (node.IsCaseNode()) composite_control_flow_op("Case");
  if (node.IsIfNode()) composite_control_flow_op("If");
  if (node.IsWhileNode()) {
    composite_control_flow_op("While");
    auto* output_shapes = node.attrs().Find("output_shapes");
    if (output_shapes && !output_shapes->list().shape().empty())
      result.attributes.push_back(
          builder_.getNamedAttr("shape_invariant", builder_.getUnitAttr()));
  }

  // Register the mapping between the TF node and the newly created operation.
  node_values_[node.id()] =
      CreateOperation(node, node_type_name, result, control_operands);
  return Status::OK();
}

// Add the backedges to the CFG. Given a backedge, we replace the original
// source and destination operations by two new operations. Most of the
// fields of the replacements are copied from the original operations.
// However,
// - for the src operation, one output is inserted at the front of the output
//   list. The type of the output is set to the type of the non-control result
//   of the dst operation, and
// - for the dst operation, one operand is inserted at the front of the
//   operand list. This operand uses the first result of the src operation.
// TODO(fengliuai): Preserve the order of the results and operands if
// necessary.
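// For example, a Merge op whose removed backedge fed operand `dst_input` is
// recreated with the NextIteration source's first result spliced in at that
// operand index: Merge(..., NextIteration.Source(...), ...).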
Status ImporterBase::AddBackedges() {
  for (auto it : back_edge_dst_inputs_) {
    BackEdge& edge = it.second;
    if (!edge.src->IsNextIteration() || !edge.dst->IsMerge()) {
      return errors::FailedPrecondition(
          "Invalid backedge; should be from NextIteration to Merge!");
    }
    auto* sink = node_values_[edge.src->id()];
    auto* dst = node_values_[edge.dst->id()];
    TF_RETURN_IF_ERROR(AddBackedge(sink, dst, edge.dst_input));
  }
  return Status::OK();
}

Status ImporterBase::AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
                                 int dst_input) {
  // Get the NextIteration.Source operation from the token operand of the sink.
  mlir::Operation* source = sink->getOperand(0).getDefiningOp();

  // Adds the "source" to the operands of the dst by creating a new dst
  // operation.
  mlir::OperationState state(dst->getLoc(), dst->getName());
  auto num_operands = dst->getNumOperands();
  state.operands.reserve(num_operands + 1);
  for (int input = 0, e = num_operands + 1; input != e; ++input) {
    if (input < dst_input) {
      state.operands.push_back(dst->getOperand(input));
    } else if (input == dst_input) {
      state.operands.push_back(source->getResult(0));
    } else {
      state.operands.push_back(dst->getOperand(input - 1));
    }
  }
  state.attributes.assign(dst->getAttrs().begin(), dst->getAttrs().end());
  state.types.assign(dst->getResultTypes().begin(),
                     dst->getResultTypes().end());
  builder_.setInsertionPoint(dst);
  auto* new_dst = builder_.createOperation(state);

  // Replaces the output uses of the old operation by the corresponding
  // result of the new operation, and deletes the old operation.
  for (unsigned i = 0, e = dst->getNumResults(); i != e; ++i) {
    auto new_output = new_dst->getResult(i);
    dst->getResult(i).replaceAllUsesWith(new_output);
  }
  dst->dropAllReferences();
  dst->erase();
  return Status::OK();
}

StatusOr<mlir::FunctionType> ImporterBase::InferLibFunctionType(
    const FunctionBody& fbody) {
  mlir::Builder builder(context_);

  // The FunctionBody contains a graph with a single-output _Arg node for each
  // function argument and a single-input _Retval node for each function return
  // value.
  //
  // We already populated the ShapeRefiner with all the information about the
  // shapes of these graph edges, so we just query it to build the corresponding
  // MLIR function type signature.

  llvm::SmallVector<mlir::Type, 4> arg_types;
  if (specs_.inputs.empty()) {
    arg_types.reserve(fbody.arg_types.size());
    for (auto arg : fbody.arg_nodes) {
      // Find the node in the graph using the node id instead of using `arg`
      // directly because the graph has been cloned.
      auto* node = graph_->FindNodeId(arg->id());
      TF_ASSIGN_OR_RETURN(auto type,
                          InferOutputType(*node, /*idx=*/0, builder));
      arg_types.push_back(type);
    }
  } else {
    arg_types.reserve(fbody.arg_types.size());
    for (const auto& it : llvm::enumerate(specs_.inputs)) {
      mlir::Type element_type;
      const auto& node_info = it.value().second;
      DataType dtype = node_info.imported_dtype;
      // Uses the existing output type of the arg node if the data type of the
      // node isn't specified through the import configuration.
      if (dtype == DT_INVALID) {
        auto arg = fbody.arg_nodes[it.index()];
        auto* node = graph_->FindNodeId(arg->id());
        dtype = node->output_type(0);
        if (dtype == DT_INVALID) {
          return errors::InvalidArgument("Input ", it.index(),
                                         " has invalid data type");
        }
      }
      TF_RETURN_IF_ERROR(
          ::tensorflow::ConvertDataType(dtype, builder, &element_type));
      if (node_info.shape.unknown_rank()) {
        arg_types.push_back(mlir::UnrankedTensorType::get(element_type));
      } else {
        llvm::SmallVector<int64_t, 4> shape;
        TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape));
        arg_types.push_back(mlir::RankedTensorType::get(shape, element_type));
      }
    }
  }

  llvm::SmallVector<mlir::Type, 4> ret_types;
  ret_types.reserve(fbody.ret_types.size());
  for (auto ret : fbody.ret_nodes) {
    // Find the node in the graph using the node id instead of using `ret`
    // directly because the graph has been cloned.
    auto* node = graph_->FindNodeId(ret->id());
    TF_ASSIGN_OR_RETURN(auto type, InferInputType(*node, /*idx=*/0, builder));
    ret_types.push_back(type);
  }

  return builder.getFunctionType(arg_types, ret_types);
}

// Stateful helper class to import a TensorFlow model expressed in GraphDef
// into an MLIR Module.
//
// The nodes defined in the graph are converted to a function called
// 'func_name'. All library function definitions are converted to MLIR
// functions in the module.
class GraphDefImporter : public ImporterBase {
 public:
  // Main entry point: converts the given graph to an MLIR Module.
  static StatusOr<mlir::OwningModuleRef> Convert(
      mlir::MLIRContext* context, const Graph& graph,
      const GraphDebugInfo& debug_info,
      const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
      llvm::StringRef func_name);

 private:
  explicit GraphDefImporter(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier)
      : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name,
                     function_name_uniquifier) {}

  // Returns the function signature of the main function of the converted MLIR
  // module, the input nodes and output nodes. The type and shape information
  // for the function arguments are read from `specs`, but the type and shape
  // information for the function returns are inferred by the shape refiner in
  // ImporterBase.
  StatusOr<mlir::FunctionType> InferMainFunctionType(
      const GraphImportConfig& specs, mlir::MLIRContext* context,
      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
      absl::InlinedVector<OutputTensor, 4>* ret_nodes);

  // Returns the function signature of the main function, alongside input and
  // output nodes, for function graphs. Arguments and return values are
  // determined by node op type. Type and shape information of the function are
  // inferred by the shape refiner in ImporterBase.
  StatusOr<mlir::FunctionType> GetArgsRetsAndTypesFromFunctionGraph(
      mlir::MLIRContext* context,
      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
      absl::InlinedVector<OutputTensor, 4>* ret_nodes);

  // Finds the graph's target nodes/function's control ret nodes based on
  // supplied node names in `control_outputs`. If `control_outputs` are not
  // unique or a control ret node is missing, an error will be returned.
  Status GetControlRetsFromGraph(
      llvm::ArrayRef<std::string> control_outputs,
      absl::InlinedVector<Node*, 4>* control_ret_nodes);
};

StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
    mlir::MLIRContext* context, const Graph& graph,
    const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
    const GraphImportConfig& specs, llvm::StringRef func_name) {
  LoadImporterDialects(*context);
  mlir::OwningModuleRef module =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
  NameUniquifier function_name_uniquifier(flib_def);

  GraphDefImporter importer(flib_def, debug_info, specs, module.get(),
                            &tf_name_to_mlir_name, &function_name_uniquifier);

  TF_RETURN_IF_ERROR(importer.PrepareConvert(graph));

  mlir::FunctionType func_type;
  absl::InlinedVector<OutputTensor, 4> arg_nodes;
  absl::InlinedVector<OutputTensor, 4> ret_nodes;
  absl::InlinedVector<Node*, 4> control_ret_nodes;
  llvm::SmallVector<mlir::NamedAttribute, 1> attrs;
  if (specs.graph_as_function) {
    if (specs.prune_unused_nodes || !specs.inputs.empty() ||
        !specs.outputs.empty())
      return errors::InvalidArgument(
          "Pruning of graph is currently unsupported when the main graph is "
          "converted to a function.");

    TF_ASSIGN_OR_RETURN(func_type,
                        importer.GetArgsRetsAndTypesFromFunctionGraph(
                            context, &arg_nodes, &ret_nodes));

    TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
                                                        &control_ret_nodes));

    mlir::Builder b(context);
    std::string s;
    llvm::raw_string_ostream ss(s);
    auto node_name = [&](const OutputTensor& tensor) {
      ss << tensor.node->name();
    };
    llvm::interleave(arg_nodes, ss, node_name, ",");
    auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
    s.clear();
    llvm::interleave(ret_nodes, ss, node_name, ",");
    auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
    s.clear();
    llvm::interleave(specs.control_outputs, ss, ",");
    auto control_outputs =
        b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));

    // Under `graph_as_function` mode, `tf.entry_function` is always set as it
    // is assumed feed, fetch, and target nodes are set correctly.
    attrs.push_back(b.getNamedAttr(
        "tf.entry_function",
        b.getDictionaryAttr({inputs, outputs, control_outputs})));
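    // The resulting attribute looks like, e.g.:
    //   tf.entry_function = {control_outputs = "init_op",
    //                        inputs = "args_0,args_1", outputs = "rets_0"}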
  } else {
    // Collects the argument and return nodes by looking up the node names
    // specified by the user.
    TF_ASSIGN_OR_RETURN(func_type, importer.InferMainFunctionType(
                                       specs, context, &arg_nodes, &ret_nodes));

    TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
                                                        &control_ret_nodes));

    // TODO(prakalps): Refactor to keep tf.entry_function attribute encoding
    // and decoding in a centralized place.
    // Record the input and output mapping.
    if (!specs.inputs.empty() || !specs.outputs.empty() ||
        !specs.control_outputs.empty()) {
      mlir::Builder b(context);
      std::string s;
      llvm::raw_string_ostream ss(s);
      llvm::interleave(
          specs.inputs, ss,
          [&](const std::pair<std::string, ArrayInfo>& v) { ss << v.first; },
          ",");
      auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
      s.clear();
      llvm::interleave(specs.outputs, ss, ",");
      auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
      s.clear();
      llvm::interleave(specs.control_outputs, ss, ",");
      auto control_outputs =
          b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));

      attrs.push_back(b.getNamedAttr(
          "tf.entry_function",
          b.getDictionaryAttr({inputs, outputs, control_outputs})));
    }
  }

  // Record version info.
  PopulateTfVersions(module.get(), graph.versions());

  TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(
      func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs));

  // Mark the main function public, others private.
  for (auto function : module.get().getOps<mlir::FuncOp>()) {
    auto visibility = function.getName() == func_name
                          ? mlir::FuncOp::Visibility::Public
                          : mlir::FuncOp::Visibility::Private;
    function.setVisibility(visibility);
  }
  return module;
}

StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
    const GraphImportConfig& specs, mlir::MLIRContext* context,
    absl::InlinedVector<OutputTensor, 4>* arg_nodes,
    absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
  // Find all the input nodes and output nodes.
  // Feeds have been remapped to single output nodes (Placeholder), so an exact
  // name match is sufficient.
  absl::flat_hash_map<absl::string_view, int> inputs;
  for (auto input_and_idx : llvm::enumerate(specs.inputs)) {
    TensorId tensor = ParseTensorName(input_and_idx.value().first);
    auto remapped_it = remapped_feeds_.find(tensor);
    if (remapped_it != remapped_feeds_.end()) {
      inputs.insert({remapped_it->second, input_and_idx.index()});
    } else {
      inputs.insert({tensor.node(), input_and_idx.index()});
    }
  }

  absl::flat_hash_set<absl::string_view> output_node_names;
  std::vector<TensorId> outputs;
  output_node_names.reserve(specs.outputs.size());
  for (const auto& output : specs.outputs) {
    TensorId tensor = ParseTensorName(output);
    auto remapped_it = remapped_feeds_.find(tensor);
    if (remapped_it != remapped_feeds_.end()) {
      output_node_names.insert(remapped_it->second);
      outputs.push_back({remapped_it->second, 0});
    } else {
      output_node_names.insert(tensor.node());
      outputs.push_back(tensor);
    }
  }

  if (!inputs.empty() || !outputs.empty()) {
    arg_nodes->resize(inputs.size());
    ret_nodes->resize(outputs.size());

    for (Node* n : GetOrderedNodes()) {
      // Handle inputs/arguments.
      auto input_it = inputs.find(n->name());
      if (input_it != inputs.end()) {
        (*arg_nodes)[input_it->second] = {n, 0};
      }

      // Handle outputs/returns.
      if (output_node_names.contains(n->name())) {
        for (int i = 0, e = outputs.size(); i != e; ++i) {
          TensorId tensor = outputs[i];
          if (n->name() != tensor.node()) continue;
          (*ret_nodes)[i] = {n, tensor.index()};
        }
      }
    }
  }

  // Starts to construct the function type.
  mlir::Builder builder(context);
  llvm::SmallVector<mlir::Type, 4> arg_types;
  arg_types.reserve(specs.inputs.size());
  int i = 0;
  for (const auto& it : specs.inputs) {
    Node* arg_node = arg_nodes->at(i).node;
    if (arg_node == nullptr) {
      return errors::InvalidArgument("Input ", it.first,
                                     " was not found in graph");
    }
    mlir::Type element_type;
    const auto& node_info = it.second;
    DataType imported_dtype = node_info.imported_dtype;
    // Uses the existing output type of the arg node if the data type of the
    // node isn't specified through the import configuration.
    if (imported_dtype == DT_INVALID) {
      imported_dtype = arg_node->output_type(0);
      if (imported_dtype == DT_INVALID) {
        return errors::InvalidArgument("Input ", i, " has invalid data type");
      }
    }
    TF_RETURN_IF_ERROR(
        ::tensorflow::ConvertDataType(imported_dtype, builder, &element_type));
    if (node_info.shape.unknown_rank()) {
      arg_types.push_back(mlir::UnrankedTensorType::get(element_type));
    } else {
      llvm::SmallVector<int64_t, 4> shape;
      TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape));
      arg_types.push_back(mlir::RankedTensorType::get(shape, element_type));
    }
    i++;
  }

  llvm::SmallVector<mlir::Type, 4> ret_types;
  ret_types.reserve(specs.outputs.size());
  for (int i = 0, e = specs.outputs.size(); i != e; ++i) {
    if (ret_nodes->at(i).node == nullptr) {
      return errors::InvalidArgument("Output ", specs.outputs[i],
                                     " was not found in graph");
    }
  }
  for (const auto& ret : *ret_nodes) {
    if (ret.node->num_outputs() <= ret.index) {
      return errors::InvalidArgument("Invalid output index ", ret.index,
                                     " specified for node: ", ret.node->name());
    }
    TF_ASSIGN_OR_RETURN(auto type,
                        InferOutputType(*ret.node, ret.index, builder));
    ret_types.push_back(type);
  }

  return builder.getFunctionType(arg_types, ret_types);
}

StatusOr<mlir::FunctionType>
GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph(
    mlir::MLIRContext* context, absl::InlinedVector<OutputTensor, 4>* arg_nodes,
    absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
  auto add_node = [](Node* node, absl::InlinedVector<OutputTensor, 4>* nodes) {
    auto* attr = node->attrs().Find("index");
    if (!attr)
      return errors::InvalidArgument(node->type_string(), " node '",
                                     node->name(),
                                     "' is missing attribute 'index'");

    auto index = attr->i();
    const int num_nodes = nodes->size();
    if (num_nodes < index + 1) nodes->resize(index + 1);

    if ((*nodes)[index].node != nullptr)
      return errors::InvalidArgument(node->type_string(), " node '",
                                     node->name(), "' has attribute 'index' ",
                                     index, " that conflicts with node '",
                                     (*nodes)[index].node->name(), "'");
    (*nodes)[index] = {node, 0};

    return Status::OK();
  };
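  // _Arg and _Retval nodes carry an "index" attribute that fixes their
  // position in the function signature; e.g. an _Arg with index = 2 becomes
  // the third function argument.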

  // Collect arg and ret nodes from graph.
  for (auto* node : GetOrderedNodes())
    if (node->IsArg())
      TF_RETURN_IF_ERROR(add_node(node, arg_nodes));
    else if (node->IsRetval())
      TF_RETURN_IF_ERROR(add_node(node, ret_nodes));

  // Collect arg and ret types and create function type.
  mlir::Builder builder(context);
  llvm::SmallVector<mlir::Type, 4> arg_types;
  arg_types.reserve(arg_nodes->size());
  for (auto arg_node_and_idx : llvm::enumerate(*arg_nodes)) {
    auto& arg_node = arg_node_and_idx.value();
    if (arg_node.node == nullptr)
      return errors::InvalidArgument("Graph missing _Arg at index ",
                                     arg_node_and_idx.index());

    TF_ASSIGN_OR_RETURN(auto type,
                        InferOutputType(*arg_node.node, /*idx=*/0, builder));
    arg_types.push_back(type);
  }

  llvm::SmallVector<mlir::Type, 4> ret_types;
  ret_types.reserve(ret_nodes->size());
  for (auto ret_node_and_idx : llvm::enumerate(*ret_nodes)) {
    auto& ret_node = ret_node_and_idx.value();
    if (ret_node.node == nullptr)
      return errors::InvalidArgument("Graph missing _Retval at index ",
                                     ret_node_and_idx.index());

    TF_ASSIGN_OR_RETURN(auto type,
                        InferInputType(*ret_node.node, /*idx=*/0, builder));
    ret_types.push_back(type);
  }

  return builder.getFunctionType(arg_types, ret_types);
}

Status GraphDefImporter::GetControlRetsFromGraph(
    llvm::ArrayRef<std::string> control_outputs,
    absl::InlinedVector<Node*, 4>* control_ret_nodes) {
  if (control_outputs.empty()) return Status::OK();

  llvm::SmallDenseMap<llvm::StringRef, int32_t> controls_to_idx;
  for (auto control_and_idx : llvm::enumerate(control_outputs))
    controls_to_idx.insert({control_and_idx.value(), control_and_idx.index()});

  if (controls_to_idx.size() != control_outputs.size())
    return errors::InvalidArgument("Control outputs must be unique");

  control_ret_nodes->resize(controls_to_idx.size());

  for (auto* node : GetOrderedNodes()) {
    auto it = controls_to_idx.find(node->name());
    if (it != controls_to_idx.end()) (*control_ret_nodes)[it->second] = node;
  }

  for (auto node_and_name : llvm::zip(*control_ret_nodes, control_outputs))
    if (std::get<0>(node_and_name) == nullptr)
      return errors::InvalidArgument(
          "Control output '", std::get<1>(node_and_name), "' is missing");

  return Status::OK();
}

// Stateful helper class to import a TensorFlow model expressed in SavedModel
// into an MLIR Module.
class SavedModelObjectGraphImporter : public ImporterBase {
 public:
  // Main entry point: converts all functions in the given meta graph to an
  // MLIR Module.
  static StatusOr<mlir::OwningModuleRef> Convert(
      SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
      mlir::MLIRContext* context, bool add_default_attributes);

 private:
  explicit SavedModelObjectGraphImporter(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier)
      : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name,
                     function_name_uniquifier) {}
};

2533 // Determines the names used to reference objects in the SavedObjectGraph.
2534 class ObjectNames {
2535 public:
2536 explicit ObjectNames(const SavedObjectGraph& object_graph,
2537 absl::Span<std::string> exported_names);
2538
2539 // Gets the names that external users of the SavedModel can use to refer to
2540 // this node.
2541 llvm::ArrayRef<llvm::StringRef> GetExportedNames(int node_id) const;
2542
2543 // Gets the name in the module symbol table for this node.
2544 // This name is only used for internal IR references.
2545 llvm::StringRef GetSymbolTableName(int node_id) const;
2546
2547 private:
2548 // In the absence of any other information, use this name as the symbol table
2549 // name for this node.
2550 std::string GetDefaultSymbolTableName(int node_id) const;
2551 // Determines if a name is exported.
2552 bool IsExported(const std::string& name);
2553 // Main object graph traversal function.
2554 void RecursivelyVisitObjectGraph(int node_id);
2555 // Gets a stable StringRef from a std::string.
2556 llvm::StringRef SaveString(const std::string& s) const;
2557
2558 // The object graph we are traversing.
2559 const SavedObjectGraph& object_graph_;
2560 // The set of names to export. Empty means "export all".
2561 std::unordered_set<std::string> names_to_export_;
2562
2563 // When we recursively follow the object graph tree structure from the root,
2564 // we track its path in the object graph by pushing and popping from here
2565 // during traversal.
2566 llvm::SmallVector<std::string, 8> path_segments_;
2567 // The set of node_id's that are on the current DFS stack.
2568 // For cyclic object graphs, this prevents infinite recursion.
2569 std::unordered_set<int> on_stack_nodes_;
2570
2571 // Key: node_id.
2572 // Value: all object names that node_id appears as.
2573 // Each object name corresponds to a unique path from the root of the object
2574 // graph.
2575 // The common intuitive case is when there is only one name for a given
2576 // object, which corresponds to the object graph being a tree.
2577 //
2578 // But there are cases where the object graph is a general graph. For
2579 // example, this happens commonly in Keras models, where `foo.bar` is
2580 // also reachable via the name `keras_api.foo.bar`.
2581 // Cycles are possible too.
2582 absl::flat_hash_map<int, std::vector<std::string>> object_names_;
2583
2584 // Key: node_id
2585 // Value: all names that this object is exported as
2586 absl::flat_hash_map<int, llvm::SmallVector<llvm::StringRef, 1>>
2587 exported_names_;
2588 // Key: node_id
2589 // Value: pretty symbol table name to use for internal references to this
2590 // object.
2591 absl::flat_hash_map<int, llvm::StringRef> pretty_symbol_table_name_;
2592
2593 // Stable strings we can take StringRef's into. Used only by the SaveString
2594 // method.
2595 mutable std::unordered_set<std::string> saved_strings_;
2596 };
2597
2598 ObjectNames::ObjectNames(const SavedObjectGraph& object_graph,
2599 absl::Span<std::string> exported_names)
2600 : object_graph_(object_graph),
2601 names_to_export_(exported_names.begin(), exported_names.end()) {
2602 // Visit all reachable nodes from the root of the object graph.
2603 // This builds up object_names_ to contain all names like `foo.bar` that a
2604 // particular node in the graph can be reached from.
2605 RecursivelyVisitObjectGraph(/*node_id=*/0);
2606
2607 // Populate the exported_names_ map.
2608 // TODO(silvasean): Diagnose typos in exported names?
2609 for (auto& kv : object_names_) {
2610 // Make object names map independent of our particular choice of object
2611 // graph traversal.
2612 std::sort(kv.second.begin(), kv.second.end(),
2613 [](absl::string_view a, absl::string_view b) {
2614 // The sort order here influences the "pretty name" we assign
2615 // below. We want the most debuggable name to be first.
2616 //
2617 // Debuggability heuristics:
2618 // 1. Names that end in digits are likely to be internal aliases
2619 // to the "real" names.
2620 // 2. Longer names are more likely to be internal aliases.
2621 //
2622 // Example set of object names created by Keras for the weight
2623 // matrix of a fully connected layer on a trivial FC mnist
2624 // model:
2625 // - `model.layer-1.kernel` (this is the "best" name)
2626 // - `model.keras_api.layers.1.kernel`
2627 // - `model.variables.0`
2628 // - `model.keras_api.layers.1.keras_api.trainable_variables.0`
2629 // - ... 10 more long aliases ending in digits ...
2630 return std::make_tuple(isdigit(a.back()), a.size(), a) <
2631 std::make_tuple(isdigit(b.back()), b.size(), b);
2632 });
2633 for (const std::string& name : kv.second) {
2634 if (IsExported(name)) {
2635 exported_names_[kv.first].push_back(SaveString(name));
2636 }
2637 }
2638 }
2639 // Create "pretty" symbol table names for nodes where that is applicable.
2640 // We could make all symbol table names use the default, which is basically
2641 // just the node id. But for debugging purposes, it's nicer if we can mix in
2642 // a recognizable object name if we have the information to do so.
2643 for (auto& kv : object_names_) {
2644 int node_id = kv.first;
2645 std::string internal_name =
2646 absl::StrCat(GetDefaultSymbolTableName(node_id), "__");
2647 // If the object has an exported name, we prefer that since it is probably
2648 // the most recognizable. Otherwise, we grab some non-exported name of the
2649 // object.
2650 if (exported_names_.find(node_id) != exported_names_.end()) {
2651 internal_name += exported_names_[node_id][0].str();
2652 } else {
2653 internal_name += object_names_[node_id][0];
2654 }
2655 pretty_symbol_table_name_[node_id] = SaveString(internal_name);
2656 }
2657 }
2658
2659 llvm::ArrayRef<llvm::StringRef> ObjectNames::GetExportedNames(
2660 int node_id) const {
2661 auto it = exported_names_.find(node_id);
2662 if (it != exported_names_.end()) {
2663 return it->second;
2664 }
2665 return {};
2666 }
2667
2668 llvm::StringRef ObjectNames::GetSymbolTableName(int node_id) const {
2669 auto it = pretty_symbol_table_name_.find(node_id);
2670 if (it != pretty_symbol_table_name_.end()) {
2671 return it->second;
2672 }
2673 return SaveString(GetDefaultSymbolTableName(node_id));
2674 }
2675
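// For example, node 42 gets the default symbol table name "__sm_node42".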
2676 std::string ObjectNames::GetDefaultSymbolTableName(int node_id) const {
2677 return absl::StrCat("__sm_node", node_id);
2678 }
2679
2680 bool ObjectNames::IsExported(const std::string& name) {
2681 if (names_to_export_.empty()) {
2682 return true;
2683 }
2684 return names_to_export_.find(name) != names_to_export_.end();
2685 }
2686
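// Performs a DFS over the object graph, recording the dotted path (e.g.
// "foo.bar") at which each constant, function, and variable is reachable.
// Backedges are skipped so that traversal terminates on cyclic graphs.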
2687 void ObjectNames::RecursivelyVisitObjectGraph(int node_id) {
2688 const SavedObject& object = object_graph_.nodes(node_id);
2689
2690 switch (object.kind_case()) {
2691 case SavedObject::kConstant:
2692 case SavedObject::kFunction:
2693 case SavedObject::kVariable: {
2694 object_names_[node_id].push_back(absl::StrJoin(path_segments_, "."));
2695 break;
2696 }
2697 default:
2698 break;
2699 }
2700
2701 for (const auto& child_ref : object.children()) {
2702 bool on_stack = !on_stack_nodes_.insert(child_ref.node_id()).second;
2703 if (on_stack) {
2704 // This is a backedge. Don't traverse it.
2705 continue;
2706 }
2707
2708 path_segments_.push_back(child_ref.local_name());
2709 RecursivelyVisitObjectGraph(child_ref.node_id());
2710 path_segments_.pop_back();
2711
2712 on_stack_nodes_.erase(child_ref.node_id());
2713 }
2714 }
2715
2716 llvm::StringRef ObjectNames::SaveString(const std::string& s) const {
2717 return llvm::StringRef(*saved_strings_.insert(s).first);
2718 }
2719
2720 // Extracts a TensorProto for a Const op from a GraphDef, given an op_name.
2721 // Returns nullptr if the op is not found or does not carry a tensor value.
2722 // The returned pointer points into the matching node inside `graph_def`,
2723 // so as to avoid expensive copies of the tensor.
2724 const TensorProto* ExtractConstTensorFromGraph(const GraphDef& graph_def,
2725 const std::string& op_name) {
2726 const NodeDef* match_node = nullptr;
2727 for (const auto& node : graph_def.node()) {
2728 if (node.name() == op_name) {
2729 match_node = &node;
2730 }
2731 }
2732
2733 if (!match_node) {
2734 return nullptr;
2735 }
2736
2737 auto value_it = match_node->attr().find("value");
2738 if (value_it == match_node->attr().end()) {
2739 return nullptr;
2740 }
2741
2742 if (!value_it->second.has_tensor()) {
2743 return nullptr;
2744 }
2745
2746 return &value_it->second.tensor();
2747 }
2748
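// Returns the SerializedTensor in `trackable_object` whose name matches
// `name`, or nullptr if there is none.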
2749 const TrackableObjectGraph::TrackableObject::SerializedTensor*
2750 FindSerializedTensorInTrackable(
2751 const TrackableObjectGraph::TrackableObject& trackable_object,
2752 StringPiece name) {
2753 for (const auto& maybe_serialized_tensor : trackable_object.attributes()) {
2754 if (maybe_serialized_tensor.name() == name) {
2755 return &maybe_serialized_tensor;
2756 }
2757 }
2758 return nullptr;
2759 }
2760
2761 Status DiagnoseMultipleConcreteFunctions(const SavedObjectGraph& object_graph,
2762 const ObjectNames& object_names) {
2763 for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) {
2764 const SavedObject& object = object_graph.nodes(node_id);
2765 if (object_names.GetExportedNames(node_id).empty()) {
2766 continue;
2767 }
2768 if (object.kind_case() == SavedObject::kFunction) {
2769 // We only allow a single input signature to each SavedFunction.
2770 // This assumption means we have a 1:1 correspondence between
2771 // tf.function <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef
2772 // This makes defining the ABI easier (or even well-defined at all).
2773 // TODO(silvasean): How to detect a function that doesn't have an
2774 // explicitly user-provided input signature, but happens to have been
2775 // traced exactly once?
2776 if (object.function().concrete_functions_size() != 1) {
2777 llvm::SmallVector<std::string, 4> names;
2778 for (llvm::StringRef s : object_names.GetExportedNames(node_id)) {
2779 names.push_back("'" + s.str() + "'");
2780 }
2781 return errors::InvalidArgument(
2782 "Exported function with exported name(s) ",
2783 absl::StrJoin(names, ", "),
2784 " with multiple concrete functions. Add "
2785 "@tf.function(input_signature=[...]) on this function, or use a "
2786 "narrower list of exported names that excludes this function.");
2787 }
2788 }
2789 }
2790 return Status::OK();
2791 }
2792
2793 // Recursively traverses a StructuredValue, linearizing all the leaves.
2794 //
2795 // This currently only handles the subset of StructuredValue that is needed for
2796 // signatures.
2797 //
2798 // Given a StructuredValue with structure [{"x": leaf0}], the "index path"
2799 // needed to reach leaf0 is `[0, "x"]`, as it would be if you were operating on
2800 // a Python object (`obj[0]["x"] is leaf0`). Each leaf corresponds to a
2801 // linearized function argument or return on a FunctionDef, and hence to an
2802 // mlir::FuncOp argument / return.
2803 //
2804 // This must match the linearization that happens in `tf.nest.flatten`.
2805 // In particular, dict values should be linearized in sorted key order.
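//
// As an illustrative sketch: the structure {"a": leafA, "b": [leafB0, leafB1]}
// linearizes to the leaves [leafA, leafB0, leafB1], with index paths
// ["a"], ["b", 0], and ["b", 1] respectively.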
2806 //
2807 // The linearized index paths can be mapped back to a structured
2808 // representation (e.g. to emit C structs matching a signature) with a simple
2809 // algorithm that recurses on each run of index paths with identical first
2810 // elements.
2811 class StructuredValueLinearizer {
2812 public:
2813 StructuredValueLinearizer(const StructuredValue& value,
2814 mlir::MLIRContext* context);
2815
2816 // Returns the list of index paths to each leaf of the StructuredValue,
2817 // in a linearized order matching `tf.nest.flatten`.
2818 //
2819 // If an error occurred during the linearization process, an error message
2820 // with `error_context` prepended will be included in the returned status.
2821 StatusOr<llvm::ArrayRef<mlir::ArrayAttr>> GetLeafIndexPaths(
2822 llvm::StringRef error_context) const;
2823
2824 private:
2825 // Main function that recursively traverses the StructuredValue.
2826 void RecursivelyFindLeaves(const StructuredValue& value);
2827
2828 mlir::Builder builder_;
2829 // The current index path. We push/pop this during recursive traversal of the
2830 // StructuredValue.
2831 llvm::SmallVector<mlir::Attribute, 4> current_index_path_;
2832 // The list of leaf index paths we have discovered so far.
2833 llvm::SmallVector<mlir::ArrayAttr, 4> leaf_index_paths_;
2834 // If non-empty, an error message to report.
2835 std::string error_message_;
2836 };
2837
2838 StructuredValueLinearizer::StructuredValueLinearizer(
2839 const StructuredValue& value, mlir::MLIRContext* context)
2840 : builder_(context) {
2841 RecursivelyFindLeaves(value);
2842 }
2843
2844 StatusOr<llvm::ArrayRef<mlir::ArrayAttr>>
2845 StructuredValueLinearizer::GetLeafIndexPaths(
2846 llvm::StringRef error_context) const {
2847 if (error_message_.empty()) {
2848 return llvm::makeArrayRef(leaf_index_paths_);
2849 }
2850 return errors::InvalidArgument(
2851 error_context.str(), error_message_,
2852 "This likely means that you have @tf.function "
2853 "on an exported function instead of "
2854 "@tf.function(input_signature=[...]). Consider annotating an "
2855 "input_signature or narrowing your set of "
2856 "exported names to not include this function.");
2857 }
2858
2859 void StructuredValueLinearizer::RecursivelyFindLeaves(
2860 const StructuredValue& value) {
2861 switch (value.kind_case()) {
2862 case StructuredValue::kDictValue: {
2863 // Dict values must be linearized in sorted order of keys.
2864 const DictValue& dict = value.dict_value();
2865 using FieldTy = protobuf::MapPair<std::string, StructuredValue>;
2866 llvm::SmallVector<const FieldTy*, 4> fields;
2867 for (auto& field : dict.fields()) {
2868 fields.push_back(&field);
2869 }
2870 llvm::sort(fields, [](const FieldTy* a, const FieldTy* b) {
2871 return a->first < b->first;
2872 });
2873 for (auto& field : fields) {
2874 current_index_path_.push_back(builder_.getStringAttr(field->first));
2875 RecursivelyFindLeaves(field->second);
2876 current_index_path_.pop_back();
2877 }
2878 return;
2879 }
2880 case StructuredValue::kTupleValue: {
2881 const TupleValue& tuple = value.tuple_value();
2882 for (int i = 0, e = tuple.values_size(); i < e; i++) {
2883 current_index_path_.push_back(builder_.getI64IntegerAttr(i));
2884 RecursivelyFindLeaves(tuple.values(i));
2885 current_index_path_.pop_back();
2886 }
2887 return;
2888 }
2889 // We don't differentiate between tuples and lists.
2890 case StructuredValue::kListValue: {
2891 const ListValue& list = value.list_value();
2892 for (int i = 0, e = list.values_size(); i < e; i++) {
2893 current_index_path_.push_back(builder_.getI64IntegerAttr(i));
2894 RecursivelyFindLeaves(list.values(i));
2895 current_index_path_.pop_back();
2896 }
2897 return;
2898 }
2899 case StructuredValue::kTensorSpecValue: {
2900 // Base case: record the current path stack as the index path needed to
2901 // get to this leaf.
2902 leaf_index_paths_.push_back(builder_.getArrayAttr(current_index_path_));
2903 return;
2904 }
2905 case StructuredValue::kNoneValue: {
2906 // Base case: do nothing.
2907 // This arises, for example, as the top-level object of an output
2908 // signature when there are no return values.
2909 return;
2910 }
2911 default: {
2912 llvm::raw_string_ostream os(error_message_);
2913 // TODO(silvasean): Use an enumerant name string instead of a number.
2914 os << "Unhandled structured value kind " << value.kind_case()
2915 << " at index path: <value>";
2916 for (auto path_element : current_index_path_) {
2917 os << ".";
2918 if (auto integer = path_element.dyn_cast<mlir::IntegerAttr>()) {
2919 os << integer.getValue();
2920 } else {
2921 auto str = path_element.cast<mlir::StringAttr>();
2922 os << str.getValue();
2923 }
2924 }
2925 os << "\n";
2926 }
2927 }
2928 }
2929
2930 // For exported functions with bound inputs, rewrite the function
2931 // signature to match the requirements of tf_saved_model bound input args.
2932 //
2933 // The raw imported functions have `tensor<*x!tf.resource>` as the type for
2934 // mutable bound inputs and `tensor<...>` as the type for immutable
2935 // bound inputs. Here we canonicalize both of them into
2936 // `tensor<!tf.resource<tensor<...>>>`.
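//
// For example (the element types here are illustrative), a mutable bound
// input declared as
//   %arg0: tensor<*x!tf.resource>
// becomes
//   %arg0: tensor<!tf.resource<tensor<f32>>>
// with a tf.Cast back to the original type inserted for existing uses;
// immutable bound inputs instead get a tf.ReadVariableOp.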
2937 void AdjustBoundInputArgTypes(mlir::ModuleOp module) {
2938 mlir::SymbolTable symbol_table(module);
2939 for (auto func : module.getOps<mlir::FuncOp>()) {
2940 if (!mlir::tf_saved_model::IsExported(func)) continue;
2941 mlir::OpBuilder builder(func.getBody());
2942 llvm::SmallVector<mlir::Type, 4> new_input_types;
2943 for (int i = 0, e = func.getNumArguments(); i < e; i++) {
2944 auto arg = func.getArgument(i);
2945 auto global_tensor = mlir::tf_saved_model::LookupBoundInputOfType<
2946 mlir::tf_saved_model::GlobalTensorOp>(func, i, symbol_table);
2947 if (global_tensor) {
2948 auto old_type = arg.getType();
2949 auto new_type =
2950 mlir::tf_saved_model::GetBoundInputArgTypeFor(global_tensor);
2951 arg.setType(new_type);
2952 if (global_tensor.is_mutable()) {
2953 auto arg_with_original_type = builder.create<mlir::TF::CastOp>(
2954 global_tensor.getLoc(), old_type, arg,
2955 /*Truncate=*/builder.getBoolAttr(false));
2956 arg.replaceAllUsesWith(arg_with_original_type);
2957 // The RAUW above also replaced the cast's own operand; set it back.
2958 arg_with_original_type.setOperand(arg);
2959 } else {
2960 auto arg_with_original_type =
2961 builder.create<mlir::TF::ReadVariableOp>(global_tensor.getLoc(),
2962 old_type, arg);
2963 arg.replaceAllUsesWith(arg_with_original_type);
2964 // The RAUW above also replaced the read's own operand; set it back.
2965 arg_with_original_type.setOperand(arg);
2966 }
2967 }
2968 new_input_types.push_back(arg.getType());
2969 }
2970 func.setType(mlir::FunctionType::get(module.getContext(), new_input_types,
2971 func.getType().getResults()));
2972 }
2973 }
2974
2975 // Marks the visibility of functions in the saved model module.
2976 void MarkSavedModelFunctionVisibility(mlir::ModuleOp module) {
2977 for (auto func : module.getOps<mlir::FuncOp>()) {
2978 auto visibility = mlir::tf_saved_model::IsExported(func)
2979 ? mlir::FuncOp::Visibility::Public
2980 : mlir::FuncOp::Visibility::Private;
2981 func.setVisibility(visibility);
2982 }
2983 }
2984
2985 // Reorder the ops in the module to make testing easier and less dependent
2986 // on implementation details such as the order of functions in the
2987 // FunctionDefLibrary.
2988 //
2989 // The order this ensures is:
2990 // 1. GlobalTensorOp's
2991 // 2. FuncOp's.
2992 //
2993 // Within each of 1. and 2., ops are sorted by exported name (if
2994 // available, and only the first exported name is considered), followed by
2995 // non-exported ops.
2996 void SortSavedModelModule(mlir::ModuleOp module) {
2997 struct NamedGlobalTensor {
2998 llvm::StringRef name;
2999 GlobalTensorOp global_tensor;
3000 };
3001 llvm::SmallVector<NamedGlobalTensor, 8> named_global_tensors;
3002 for (auto global_tensor : module.getOps<GlobalTensorOp>()) {
3003 auto exported_names = mlir::tf_saved_model::GetExportedNames(global_tensor);
3004 // We use stable_sort, so duplicate empty names are fine here.
3005 named_global_tensors.push_back(
3006 {exported_names.empty() ? "" : exported_names.front(), global_tensor});
3007 }
3008 llvm::stable_sort(named_global_tensors,
3009 [](const NamedGlobalTensor& a, const NamedGlobalTensor& b) {
3010 return std::make_tuple(a.name.empty(), a.name) <
3011 std::make_tuple(b.name.empty(), b.name);
3012 });
3013
3014 struct NamedFunc {
3015 llvm::StringRef name;
3016 mlir::FuncOp func;
3017 };
3018 llvm::SmallVector<NamedFunc, 8> named_funcs;
3019 llvm::SmallVector<mlir::FuncOp, 8> private_funcs;
3020 for (auto func : module.getOps<mlir::FuncOp>()) {
3021 auto exported_names = mlir::tf_saved_model::GetExportedNames(func);
3022 if (!exported_names.empty())
3023 named_funcs.push_back({exported_names.front(), func});
3024 else
3025 private_funcs.push_back(func);
3026 }
3027 llvm::stable_sort(named_funcs, [](const NamedFunc& a, const NamedFunc& b) {
3028 return a.name < b.name;
3029 });
3030 llvm::stable_sort(private_funcs, [](mlir::FuncOp a, mlir::FuncOp b) {
3031 return a.getName() < b.getName();
3032 });
3033
3034 struct NamedAsset {
3035 llvm::StringRef name;
3036 AssetOp asset;
3037 };
3038 llvm::SmallVector<NamedAsset, 4> assets;
3039 for (auto asset : module.getOps<AssetOp>()) {
3040 assets.push_back({asset.getName(), asset});
3041 }
3042 llvm::stable_sort(assets, [](const NamedAsset& a, const NamedAsset& b) {
3043 return a.name < b.name;
3044 });
3045
3046 // Move ops to the front of the module in reverse of the final desired order.
3047 for (auto func : llvm::reverse(private_funcs)) {
3048 func.getOperation()->moveBefore(&module.getBody()->front());
3049 }
3050 for (auto named_func : llvm::reverse(named_funcs)) {
3051 named_func.func.getOperation()->moveBefore(&module.getBody()->front());
3052 }
3053 for (auto named_global_tensor : llvm::reverse(named_global_tensors)) {
3054 named_global_tensor.global_tensor.getOperation()->moveBefore(
3055 &module.getBody()->front());
3056 }
3057
3058 for (auto asset : assets) {
3059 asset.asset.getOperation()->moveBefore(&module.getBody()->front());
3060 }
3061
3062 auto initializers = module.getOps<SessionInitializerOp>();
3063 if (!initializers.empty()) {
3064 (*initializers.begin())
3065 .getOperation()
3066 ->moveBefore(&module.getBody()->front());
3067 }
3068 }
3069
3070 Status CreateSavedModelIR(
3071 const ObjectNames& object_names, mlir::ModuleOp module,
3072 const SavedObjectGraph& object_graph,
3073 const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name,
3074 SavedModelV2Bundle* saved_model) {
3075 mlir::OpBuilder builder(module.getBodyRegion());
3076 mlir::SymbolTable symbol_table(module);
3077
3078 // Create a side data structure that maps from an object_graph node_id to
3079 // its restorable TrackableObject, if any.
3080 absl::flat_hash_map<int, const TrackableObjectGraph::TrackableObject*>
3081 restored_objects;
3082 TF_RETURN_IF_ERROR(saved_model->VisitObjectsToRestore(
3083 [&](int saved_node_id,
3084 const TrackableObjectGraph::TrackableObject& trackable_object) {
3085 restored_objects.insert(
3086 std::make_pair(saved_node_id, &trackable_object));
3087 return Status::OK();
3088 }));
3089
3090 for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) {
3091 const SavedObject& object = object_graph.nodes(node_id);
3092 // For correctness, we cannot import functions that don't have exported
3093 // names, since they don't necessarily have a well-defined ABI (diagnosed
3094 // earlier).
3095 //
3096 // For variables/constants, pruning them is purely an optimization,
3097 // and more complicated since it requires use-def analysis of which
3098 // functions use which variables/constants, so we don't do anything
3099 // special for them here as part of our initial IR construction.
3100 if (object.kind_case() == SavedObject::kFunction) {
3101 if (object_names.GetExportedNames(node_id).empty()) {
3102 continue;
3103 }
3104 std::string error_context =
3105 "While importing SavedModel function '" +
3106 object_names.GetExportedNames(node_id)[0].str() + "': ";
3107 const SavedFunction& function = object.function();
3108 auto orig_func = symbol_table.lookup<mlir::FuncOp>(
3109 tf_name_to_mlir_name.find(function.concrete_functions(0))->second);
3110 mlir::FuncOp func = orig_func;
3111 // If there are potentially references to this func from within the
3112 // module, create a wrapper around it and decorate the wrapper with the
3113 // tf_saved_model attributes instead.
3114 if (!mlir::SymbolTable::symbolKnownUseEmpty(orig_func.getName(),
3115 &module.getBodyRegion())) {
3116 func = orig_func.cloneWithoutRegions();
3117 module.insert(module.getBody()->begin(), func);
3118 func.addEntryBlock();
3119 func.setName("__sm_exported_" + orig_func.getName().str());
3120 llvm::SmallVector<mlir::Value, 4> args_as_values;
3121 for (auto block_argument : func.getArguments()) {
3122 args_as_values.push_back(block_argument);
3123 }
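// Forward the wrapper's arguments to the original function via a stateful
// call, then return the call's results.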
3124 mlir::OpBuilder body_builder(&func.getBody());
3125 auto call = body_builder.create<mlir::TF::StatefulPartitionedCallOp>(
3126 func.getLoc(), orig_func.getType().getResults(), args_as_values,
3127 builder.getSymbolRefAttr(orig_func.getName()),
3128 /*config=*/builder.getStringAttr(""),
3129 /*config_proto=*/builder.getStringAttr(""),
3130 /*executor_type=*/builder.getStringAttr(""));
3131 body_builder.create<mlir::ReturnOp>(func.getLoc(), call.getResults());
3132 }
3133 func->setAttr(
3134 "tf_saved_model.exported_names",
3135 builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
3136 const SavedConcreteFunction& concrete_function =
3137 object_graph.concrete_functions().at(function.concrete_functions(0));
3138
3139 // We do not handle the other element of this tuple, which corresponds to
3140 // Python kwonlyargs, since currently TensorFlow prohibits this in
3141 // combination with input_signature:
3142 // https://github.com/tensorflow/tensorflow/blob/8cb8627abb5ef83a6fba34f8fd0e4ee430562eb1/tensorflow/python/eager/function.py#L2027-L2030
3143 // Our SavedModel import requires input_signature on the tf.function, so
3144 // we never need to handle the kwonlyargs.
3145 auto positional_arg_structure =
3146 concrete_function.canonicalized_input_signature()
3147 .tuple_value()
3148 .values(0);
3149 StructuredValueLinearizer input_linearizer(positional_arg_structure,
3150 builder.getContext());
3151
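// Bound inputs (captured variables, constants, and assets) are appended
// after the user-provided arguments in the underlying FunctionDef.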
3152 int bound_input_base =
3153 func.getNumArguments() - concrete_function.bound_inputs_size();
3154 TF_ASSIGN_OR_RETURN(auto input_index_paths,
3155 input_linearizer.GetLeafIndexPaths(
3156 error_context + "in input signature: "));
3157 const int input_index_paths_size = input_index_paths.size();
3158 if (bound_input_base != input_index_paths_size) {
3159 return errors::InvalidArgument(
3160 error_context,
3161 "Argument mismatch between concrete function input signature "
3162 "vs underlying FunctionDef for concrete function '",
3163 function.concrete_functions(0), "' (", input_index_paths.size(),
3164 " vs ", bound_input_base, ")");
3165 }
3166 for (auto index_path : llvm::enumerate(input_index_paths)) {
3167 func.setArgAttr(index_path.index(), "tf_saved_model.index_path",
3168 index_path.value());
3169 }
3170
3171 for (auto& bound_input :
3172 llvm::enumerate(concrete_function.bound_inputs())) {
3173 int arg_index = bound_input_base + bound_input.index();
3174 auto symbol_ref = builder.getSymbolRefAttr(
3175 object_names.GetSymbolTableName(bound_input.value()));
3176 func.setArgAttr(arg_index, "tf_saved_model.bound_input", symbol_ref);
3177 }
3178
3179 StructuredValueLinearizer output_linearizer(
3180 concrete_function.output_signature(), builder.getContext());
3181 TF_ASSIGN_OR_RETURN(auto output_index_paths,
3182 output_linearizer.GetLeafIndexPaths(
3183 error_context + "in output signature: "));
3184 if (func.getNumResults() != output_index_paths.size()) {
3185 return errors::InvalidArgument(
3186 error_context,
3187 "Result mismatch between concrete function output signature "
3188 "vs underlying FunctionDef for concrete function '",
3189 function.concrete_functions(0), "' (", output_index_paths.size(),
3190 " vs ", func.getNumResults(), ")");
3191 }
3192 for (auto index_path : llvm::enumerate(output_index_paths)) {
3193 func.setResultAttr(index_path.index(), "tf_saved_model.index_path",
3194 index_path.value());
3195 }
3196 } else if (object.kind_case() == SavedObject::kVariable) {
3197 const SavedVariable& variable = object.variable();
3198 // Find the trackable in the side data structure.
3199 auto variable_trackable_it = restored_objects.find(node_id);
3200 if (variable_trackable_it == restored_objects.end()) {
3201 return errors::FailedPrecondition("Could not restore saved variable: ",
3202 variable.name());
3203 }
3204 const auto* serialized_tensor_attr = FindSerializedTensorInTrackable(
3205 *variable_trackable_it->second, "VARIABLE_VALUE");
3206 if (!serialized_tensor_attr) {
3207 return errors::FailedPrecondition(
3208 "Could not find serialized tensor for saved variable: ",
3209 variable.name());
3210 }
3211 const auto& checkpoint_key = serialized_tensor_attr->checkpoint_key();
3212
3213 // Load it from the reader.
3214 Tensor value;
3215 TF_RETURN_WITH_CONTEXT_IF_ERROR(
3216 saved_model->variable_reader()->Lookup(checkpoint_key, &value),
3217 "Could not read checkpoint key from variables bundle: ",
3218 checkpoint_key);
3219 TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder));
3220 // A variable can have a partially known type, such as tensor<?x27x?xf32>,
3221 // even if the initializer has a specific static shape.
3222 TF_ASSIGN_OR_RETURN(
3223 auto type, ConvertToMlirTensorType(variable.shape(), variable.dtype(),
3224 &builder));
3225 auto op = builder.create<GlobalTensorOp>(
3226 builder.getUnknownLoc(),
3227 builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
3228 value_attr,
3229 /*type=*/mlir::TypeAttr::get(type),
3230 /*is_mutable=*/builder.getUnitAttr());
3231 op->setAttr(
3232 "tf_saved_model.exported_names",
3233 builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
3234 } else if (object.kind_case() == SavedObject::kConstant) {
3235 const SavedConstant& constant = object.constant();
3236 const TensorProto* value = ExtractConstTensorFromGraph(
3237 saved_model->meta_graph_def().graph_def(), constant.operation());
3238 if (!value) {
3239 return errors::FailedPrecondition(
3240 "Unable to find const node referenced in object graph: ",
3241 constant.operation());
3242 }
3243 TF_ASSIGN_OR_RETURN(auto value_attr,
3244 ConvertTensorProto(*value, &builder));
3245 auto op = builder.create<GlobalTensorOp>(
3246 builder.getUnknownLoc(),
3247 builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
3248 value_attr,
3249 /*type=*/mlir::TypeAttr::get(value_attr.Attribute::getType()),
3250 /*is_mutable=*/nullptr);
3251 op->setAttr(
3252 "tf_saved_model.exported_names",
3253 builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
3254 }
3255 }
3256 AdjustBoundInputArgTypes(module);
3257 module->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
3258 SortSavedModelModule(module);
3259 MarkSavedModelFunctionVisibility(module);
3260 return Status::OK();
3261 }
3262
3263 StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
3264 SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
3265 mlir::MLIRContext* context, bool add_default_attributes) {
3266 LoadImporterDialects(*context);
3267 GraphDebugInfo dummy_debug_info;
3268 const GraphDebugInfo& debug_info =
3269 saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;
3270
3271 GraphImportConfig specs;
3272 specs.prune_unused_nodes = true;
3273 mlir::OwningModuleRef module =
3274 mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
3275 std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
3276
3277 const auto& graphdef = saved_model->meta_graph_def().graph_def();
3278 PopulateTfVersions(module.get(), graphdef.versions());
3279
3280 GraphConstructorOptions options;
3281 options.allow_internal_ops = true;
3282 options.add_default_attributes = add_default_attributes;
3283 Graph graph(OpRegistry::Global());
3284
3285 GraphDef preprocessed_graphdef(graphdef);
3286 if (add_default_attributes) {
3287 TF_RETURN_IF_ERROR(PreprocessGraphDef(nullptr, &preprocessed_graphdef));
3288 }
3289
3290 TF_RETURN_IF_ERROR(
3291 ConvertGraphDefToGraph(options, preprocessed_graphdef, &graph));
3292
3293 NameUniquifier function_name_uniquifier(graph.flib_def());
3294 SavedModelObjectGraphImporter importer(graph.flib_def(), debug_info, specs,
3295 module.get(), &tf_name_to_mlir_name,
3296 &function_name_uniquifier);
3297
3298 TF_RETURN_IF_ERROR(importer.PrepareConvert(graph));
3299
3300 auto fn_names = graph.flib_def().ListFunctionNames();
3301 for (const auto& fn_name : fn_names) {
3302 TF_RETURN_IF_ERROR(importer.ConvertLibFunction(fn_name));
3303 }
3304
3305 if (!saved_model->meta_graph_def().has_object_graph_def()) {
3306 return errors::InvalidArgument(
3307 "SavedModel does not have an object graph. Please use TF2.");
3308 }
3309 auto& object_graph = saved_model->meta_graph_def().object_graph_def();
3310 ObjectNames object_names(object_graph, exported_names);
3311
3312 // Clean up a couple of funcs that always seem to be present when importing a
3313 // SavedModel. This is not strictly needed, as there is a separate pass that
3314 // will clean them up, but this makes staring at the raw IR of minimal
3315 // examples quite a bit nicer.
3316 for (auto func : llvm::make_early_inc_range(module->getOps<mlir::FuncOp>())) {
3317 if (func.getName().startswith("__inference__traced_save_") ||
3318 func.getName().startswith("__inference__traced_restore_") ||
3319 func.getName().startswith("__inference_signature_wrapper_")) {
3320 func.erase();
3321 }
3322 }
3323
3324 // Diagnose SavedFunction's with multiple input signatures.
3325 TF_RETURN_IF_ERROR(
3326 DiagnoseMultipleConcreteFunctions(object_graph, object_names));
3327
3328 // Construct the SavedModel IR.
3329 TF_RETURN_IF_ERROR(CreateSavedModelIR(object_names, module.get(),
3330 object_graph, tf_name_to_mlir_name,
3331 saved_model));
3332 assert(mlir::succeeded(mlir::verify(module.get())));
3333
3334 return module;
3335 }
3336
3337 class SimpleSavedModelMLIRImportInput : public SavedModelMLIRImportInput {
3338 public:
3339 static StatusOr<SimpleSavedModelMLIRImportInput> Create(
3340 const MLIRImportOptions& import_options,
3341 const MetaGraphDef* meta_graph_def, const GraphDebugInfo& debug_info) {
3342 DCHECK(meta_graph_def);
3343 GraphDef graph_def;
3344 if (import_options.enable_grappler) {
3345 // Grappler is best-effort.
3346 auto statusor = RunGrappler(*meta_graph_def);
3347 if (statusor.ok()) {
3348 graph_def = std::move(statusor).ValueOrDie();
3349 } else {
3350 // If Grappler fails, fall back to the original graph def.
3351 LOG(WARNING) << "SimpleSavedModelMLIRImportInput: grappler failed: "
3352 << statusor.status();
3353 graph_def = meta_graph_def->graph_def();
3354 }
3355 } else {
3356 graph_def = meta_graph_def->graph_def();
3357 }
3358
3359 auto graph = std::make_unique<Graph>(OpRegistry::Global());
3360
3361 if (import_options.upgrade_legacy) {
3362 TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(
3363 graph_def, graph->flib_def().default_registry()));
3364 }
3365
3366 GraphConstructorOptions graph_ctor_options;
3367 graph_ctor_options.allow_internal_ops = true;
3368 graph_ctor_options.add_default_attributes = true;
3369 TF_RETURN_IF_ERROR(
3370 ConvertGraphDefToGraph(graph_ctor_options, graph_def, graph.get()));
3371
3372 if (import_options.upgrade_legacy) {
3373 // TODO(jpienaar): Remove need to const_cast.
3374 TF_RETURN_IF_ERROR(UpgradeLegacyGraph(
3375 graph.get(),
3376 const_cast<FunctionLibraryDefinition*>(&graph->flib_def()),
3377 /*restrict_functionalization_to_tpu_nodes=*/false));
3378 }
3379
3380 return SimpleSavedModelMLIRImportInput(meta_graph_def, debug_info,
3381 std::move(graph));
3382 }
3383
3384 SimpleSavedModelMLIRImportInput(const MetaGraphDef* meta_graph_def,
3385 const GraphDebugInfo& debug_info,
3386 std::unique_ptr<Graph> graph)
3387 : SavedModelMLIRImportInput(meta_graph_def, debug_info),
3388 graph_(std::move(graph)) {}
3389
3390 StatusOr<const Graph*> GetSubGraph(absl::string_view name,
3391 const GraphImportConfig& specs) override {
3392 DCHECK(CheckGraphNameValidity(name));
3393 DCHECK(CheckGraphContainsFeedsAndFetches(specs));
3394 return graph_.get();
3395 }
3396
3397 private:
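// Returns true iff every feed, fetch, and control output named in `specs`
// corresponds to a node that exists in the graph.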
3398 bool CheckGraphContainsFeedsAndFetches(const GraphImportConfig& specs) const {
3399 absl::flat_hash_set<std::string> feed_fetch_nodes;
3400 for (const auto& iter : specs.inputs) {
3401 TensorId tensor_id = ParseTensorName(iter.first);
3402 feed_fetch_nodes.insert(std::string(tensor_id.node()));
3403 }
3404 for (const auto& output : llvm::concat<const std::string>(
3405 specs.outputs, specs.control_outputs)) {
3406 TensorId tensor_id = ParseTensorName(output);
3407 feed_fetch_nodes.insert(std::string(tensor_id.node()));
3408 }
3409
3410 for (Node* node : graph_->op_nodes()) {
3411 feed_fetch_nodes.erase(node->name());
3412 }
3413
3414 return feed_fetch_nodes.empty();
3415 }
3416
3417 bool CheckGraphNameValidity(absl::string_view name) const {
3418 // If it is one of the signature names, it is valid.
3419 const auto& signature_defs = meta_graph_def().signature_def();
3420 if (signature_defs.contains(std::string(name))) return true;
3421
3422 // If it is the restore graph name, it is valid.
3423 if (meta_graph_def().has_saver_def() &&
3424 meta_graph_def().saver_def().restore_op_name() == name)
3425 return true;
3426
3427 // If it is the init graph name, it is valid.
3428 std::string init_op_name;
3429 if (internal::GetInitOp("", meta_graph_def(), &init_op_name).ok()) {
3430 if (init_op_name == name) return true;
3431 }
3432
3433 return false;
3434 }
3435
3436 // `graph_` contains the entire graph in the original MetaGraphDef.
3437 std::unique_ptr<Graph> graph_;
3438 };
3439
3440 // A helper class to import a TensorFlow model expressed in SavedModel V1 into
3441 // an MLIR Module in SavedModel dialect.
3442 //
3443 // TODO(b/179683149): Rename this class to avoid confusion with TFLite.
3444 class SavedModelSignatureDefImporterLite {
3445 public:
3446 // Main entry point: converts all functions (specified by SignatureDefs) in
3447 // the given meta graph to an MLIR Module.
3448 //
3449 // `import_restore` controls whether the restore graph is imported in e.g.
3450 // SavedModelSignatureDefImporter. Ideally, we would not need this option,
3451 // as the restore graph should always be imported. However, right now,
3452 // SavedModelSignatureDefImporter cannot handle the restore graph
3453 // correctly.
3454 //
3455 // TODO(chky): Remove import_restore once the restore graph is correctly
3456 // handled in SavedModelSignatureDefImporter.
3457 static StatusOr<mlir::OwningModuleRef> Convert(
3458 SavedModelMLIRImportInput& input, absl::Span<std::string> exported_names,
3459 mlir::MLIRContext* context, bool import_restore = true) {
3460 LoadImporterDialects(*context);
3461 SavedModelSignatureDefImporterLite importer(input, exported_names, context,
3462 import_restore);
3463 TF_ASSIGN_OR_RETURN(auto module, importer.ConvertSignatures());
3464
3465 SortSavedModelModule(*module);
3466 MarkSavedModelFunctionVisibility(*module);
3467
3468 return module;
3469 }
3470
3471 private:
3472 SavedModelSignatureDefImporterLite(SavedModelMLIRImportInput& input,
3473 absl::Span<std::string> exported_names,
3474 mlir::MLIRContext* context,
3475 bool import_restore)
3476 : input_(input),
3477 exported_names_(exported_names),
3478 module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))),
3479 symbol_table_(module_.get()),
3480 import_restore_(import_restore) {}
3481
3482 // Converts the SavedModel to the SavedModel dialect. Creates an MLIR function
3483 // for each signature.
3484 StatusOr<mlir::OwningModuleRef> ConvertSignatures();
3485 Status ConvertSignature(const std::string& sig_def_key,
3486 const SignatureDef& signature_def);
3487
3488 struct AssetInfo {
3489 std::string tensor_name;
3490 mlir::tf_saved_model::AssetOp op;
3491 };
3492 StatusOr<std::vector<AssetInfo>> ConvertAssets();
3493 // Converts the initialization graph in the SavedModel to an MLIR function.
3494 Status ConvertInitializer(const std::string& target_node_name,
3495 const std::vector<AssetInfo>& assets);
3496
3497 // Converts a graph with feeds and fetches to an MLIR function.
3498 StatusOr<mlir::OwningModuleRef> ConvertGraph(
3499 const std::string& name,
3500 const std::vector<std::pair<std::string, TensorInfo>>& inputs,
3501 const std::vector<std::pair<std::string, TensorInfo>>& outputs,
3502 const std::vector<std::string> control_outputs);
3503
3504 // Moves the functions in `sub_module` to `module_` and skips the duplicate
3505 // functions.
3506 Status MoveConvertedFunctionsToModule(absl::string_view name,
3507 mlir::ModuleOp sub_module);
3508
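// Builds the feed map for GraphImportConfig from SignatureDef inputs,
// recording each tensor's name, dtype, and (possibly unknown-rank) shape.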
3509 GraphImportConfig::InputArrays ParseInputArrays(
3510 llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs);
3511
3512 private:
3513 SavedModelMLIRImportInput& input_;
3514 absl::Span<std::string> exported_names_;
3515 mlir::OwningModuleRef module_;
3516 mlir::SymbolTable symbol_table_;
3517 bool import_restore_ = true;
3518 };
3519
3520 StatusOr<std::vector<SavedModelSignatureDefImporterLite::AssetInfo>>
3521 SavedModelSignatureDefImporterLite::ConvertAssets() {
3522 std::vector<AssetFileDef> asset_file_defs;
3523 TF_RETURN_IF_ERROR(
3524 internal::GetAssetFileDefs(input_.meta_graph_def(), &asset_file_defs));
3525
3526 std::vector<AssetInfo> results;
3527 results.reserve(asset_file_defs.size());
3528
3529 mlir::OpBuilder builder(module_->getBodyRegion());
3530 unsigned i = 0; // Used to generate unique sym_name(s) for duplicate assets.
3531 for (const auto& asset : asset_file_defs) {
3532 auto asset_op = builder.create<mlir::tf_saved_model::AssetOp>(
3533 module_->getLoc(),
3534 /*sym_name=*/
3535 builder.getStringAttr(
3536 absl::StrCat("__tf_saved_model_asset", i++, "_", asset.filename())),
3537 /*filename=*/
3538 builder.getStringAttr(
3539 io::JoinPath(kSavedModelAssetsDirectory, asset.filename())));
3540
3541 results.push_back({asset.tensor_info().name(), asset_op});
3542 }
3543
3544 return results;
3545 }
3546
3547 Status SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule(
3548 absl::string_view name, mlir::ModuleOp sub_module) {
3549 mlir::Builder builder(sub_module.getContext());
3550 mlir::SymbolTable sub_module_symbol_table(sub_module);
3551
3552 // Prefix private functions with the unique signature name, so that they
3553 // cannot collide with private functions used in the other signatures.
3554 for (auto func : sub_module.getOps<mlir::FuncOp>()) {
3555 if (mlir::tf_saved_model::IsExported(func)) continue;
3556
3557 std::string new_sym_name = absl::StrCat(name, "/", func.sym_name().str());
3558 if (mlir::failed(sub_module_symbol_table.replaceAllSymbolUses(
3559 func, new_sym_name, sub_module)))
3560 return tensorflow::errors::InvalidArgument(absl::StrCat(
3561 "SavedModelSignatureDefImporterLite: failed to assign a unique "
3562 "name to the private function used in a signature: ",
3563 func.sym_name().str()));
3564
3565 mlir::SymbolTable::setSymbolName(func, new_sym_name);
3566 }
3567
3568 // Copy all functions used by this signature to the final MLIR module.
3569 for (auto func : sub_module.getOps<mlir::FuncOp>()) {
3570 DCHECK(symbol_table_.lookup(func.sym_name()) == nullptr);
3571 symbol_table_.insert(func.clone());
3572 }
3573
3574 return Status::OK();
3575 }
3576
3577 Status SavedModelSignatureDefImporterLite::ConvertInitializer(
3578 const std::string& target_node_name, const std::vector<AssetInfo>& assets) {
3579 std::vector<std::pair<std::string, TensorInfo>> inputs;
3580 inputs.reserve(assets.size());
3581 for (const auto& asset : assets) {
3582 TensorInfo tensor_info;
3583 tensor_info.set_name(asset.tensor_name);
3584 tensor_info.set_dtype(DT_STRING);
3585 tensor_info.mutable_tensor_shape();
3586 inputs.push_back({asset.tensor_name, tensor_info});
3587 }
3588
3589 TF_ASSIGN_OR_RETURN(auto sub_module, ConvertGraph(target_node_name, inputs,
3590 {}, {target_node_name}));
3591
3592 mlir::SymbolTable sub_symbol_table(*sub_module);
3593
3594 auto init_func_op = sub_symbol_table.lookup<mlir::FuncOp>(target_node_name);
3595 init_func_op.removeAttr("tf.entry_function");
3596
3597 mlir::OpBuilder builder(module_->getBodyRegion());
3598
3599 // Bind asset inputs to asset ops.
3600 DCHECK_EQ(init_func_op.getNumArguments(), assets.size());
3601 for (const auto& iter : llvm::enumerate(assets)) {
3602 auto asset_op = iter.value().op;
3603 init_func_op.setArgAttr(iter.index(), "tf_saved_model.bound_input",
3604 builder.getSymbolRefAttr(asset_op.getName()));
3605 }
3606
3607 // Set the exported name of the init function to a reserved name for
3608 // tf_saved_model.
3609 init_func_op->setAttr(
3610 "tf_saved_model.exported_names",
3611 builder.getStrArrayAttr({absl::StrCat(
3612 "__tf_saved_model_session_initializer_", target_node_name)}));
3613
3614 // Move the converted functions to top level MLIR module.
3615 return MoveConvertedFunctionsToModule(target_node_name, *sub_module);
3616 }
3617
3618 StatusOr<mlir::OwningModuleRef>
3619 SavedModelSignatureDefImporterLite::ConvertGraph(
3620 const std::string& name,
3621 const std::vector<std::pair<std::string, TensorInfo>>& inputs,
3622 const std::vector<std::pair<std::string, TensorInfo>>& outputs,
3623 const std::vector<std::string> control_outputs) {
3624 VLOG(1) << "Importing Signature: " << name;
3625
3626 GraphImportConfig specs;
3627 specs.prune_unused_nodes = true;
3628 specs.inputs = ParseInputArrays(inputs);
3629 for (auto& output : outputs) specs.outputs.push_back(output.second.name());
3630 specs.control_outputs = control_outputs;
3631
3632 TF_ASSIGN_OR_RETURN(const auto* subgraph, input_.GetSubGraph(name, specs));
3633
3634 // Convert sub-graph to MLIR module.
3635 return GraphDefImporter::Convert(module_->getContext(), *subgraph,
3636 input_.debug_info(), subgraph->flib_def(),
3637 specs, name);
3638 }
3639
3640 Status SavedModelSignatureDefImporterLite::ConvertSignature(
3641 const std::string& sig_def_key, const SignatureDef& signature_def) {
3642 // Create local vectors for the inputs and outputs and sort them so the
3643 // order is deterministic. We don't want anyone to depend on the order;
3644 // clients should look up the argument/result mapping by attribute name. To
3645 // avoid accidental dependence, we sort by length, then reverse-lexically.
3646 std::vector<std::pair<std::string, TensorInfo>> inputs(
3647 signature_def.inputs().begin(), signature_def.inputs().end());
3648 llvm::sort(inputs, [](const auto& lhs, const auto& rhs) {
3649 return std::forward_as_tuple(lhs.first.size(), rhs.first) < std::forward_as_tuple(rhs.first.size(), lhs.first);
3650 });
3651 std::vector<std::pair<std::string, TensorInfo>> outputs(
3652 signature_def.outputs().begin(), signature_def.outputs().end());
3653 llvm::sort(outputs, [](const auto& lhs, const auto& rhs) {
3654 return std::forward_as_tuple(lhs.first.size(), rhs.first) < std::forward_as_tuple(rhs.first.size(), lhs.first);
3655 });
3656
3657 // Convert sub-graph to MLIR module.
3658 TF_ASSIGN_OR_RETURN(auto sub_module,
3659 ConvertGraph(sig_def_key, inputs, outputs, {}));
3660 mlir::OpBuilder builder(sub_module->getBodyRegion());
3661
3662 // Find the FuncOp which corresponds to current SignatureDef.
3663 mlir::SymbolTable sub_symbol_table(*sub_module);
3664 auto func_op = sub_symbol_table.lookup<mlir::FuncOp>(sig_def_key);
3665 TF_RET_CHECK(func_op)
3666 << "Graphdef importer should have created a function named "
3667 << sig_def_key << ".";
3668
3669 // Use unique SignatureDef key as exported name.
3670 func_op->setAttr("tf_saved_model.exported_names",
3671 builder.getStrArrayAttr({sig_def_key}));
3672
3673 // Transfer input and output parameter names to index_path attributes.
3674 for (auto input_and_idx : llvm::enumerate(inputs)) {
3675 func_op.setArgAttr(input_and_idx.index(), "tf_saved_model.index_path",
3676 builder.getStrArrayAttr({input_and_idx.value().first}));
3677 }
3678 for (auto output_and_idx : llvm::enumerate(outputs)) {
3679 func_op.setResultAttr(
3680 output_and_idx.index(), "tf_saved_model.index_path",
3681 builder.getStrArrayAttr({output_and_idx.value().first}));
3682 }
3683
3684 // Move the converted functions to top level MLIR module.
3685 return MoveConvertedFunctionsToModule(sig_def_key, *sub_module);
3686 }
3687
3688 GraphImportConfig::InputArrays
3689 SavedModelSignatureDefImporterLite::ParseInputArrays(
3690 llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs) {
3691 GraphImportConfig::InputArrays results;
3692 for (const auto& iter : inputs) {
3693 const auto& tensor_info = iter.second;
3694
3695 // Only dense tensors are supported.
3696 DCHECK_EQ(tensor_info.encoding_case(), tensorflow::TensorInfo::kName);
3697
3698 VLOG(1) << "Importing Signature Input: input_name = " << iter.first
3699 << ", tensor_info = " << tensor_info.DebugString();
3700
3701 ArrayInfo array_info;
3702 array_info.imported_dtype = tensor_info.dtype();
3703
3704 if (tensor_info.has_tensor_shape()) {
3705 array_info.shape = tensor_info.tensor_shape();
3706 } else {
3707 // If there is no tensor shape in the tensor info, conservatively set
3708 // unknown_rank to true.
3709 array_info.shape.set_unknown_rank(true);
3710 }
3711
3712 results.insert(std::pair<std::string, ArrayInfo>(tensor_info.name(),
3713 std::move(array_info)));
3714 }
3715 return results;
3716 }
3717
3718 StatusOr<mlir::OwningModuleRef>
3719 SavedModelSignatureDefImporterLite::ConvertSignatures() {
3720 const auto& signatures = input_.meta_graph_def().signature_def();
3721 PopulateTfVersions(module_.get(),
3722 input_.meta_graph_def().graph_def().versions());
3723
3724 llvm::DenseSet<llvm::StringRef> exported_name_set;
3725 exported_name_set.insert(exported_names_.begin(), exported_names_.end());
3726
3727 for (const auto& key_and_signature_def : signatures) {
3728 const std::string& sig_def_key = key_and_signature_def.first;
3729 const SignatureDef& signature_def = key_and_signature_def.second;
3730
3731 // It is safe to skip "__saved_model_init_op" since it is an internal
3732 // signature that is not user-accessible. This signature will be handled in
3733 // ConvertInitializer().
3734 if (sig_def_key == "__saved_model_init_op") {
3735 continue;
3736 }
3737 if (!exported_name_set.empty() &&
3738 exported_name_set.count(sig_def_key) == 0) {
3739 continue;
3740 }
3741
3742 TF_RETURN_IF_ERROR(ConvertSignature(sig_def_key, signature_def));
3743 }
3744
3745 TF_ASSIGN_OR_RETURN(auto assets, ConvertAssets());
3746
3747 mlir::OpBuilder builder(module_->getBodyRegion());
3748 llvm::SmallVector<mlir::Attribute, 2> init_sym_refs;
3749
3750 if (import_restore_ && input_.meta_graph_def().has_saver_def()) {
3751 std::vector<AssetInfo> variable_and_assets;
3752
3753 // Create an AssetOp for the variable checkpoint files. The relative
3754 // filename is used here.
3755 auto variable_filename_op = builder.create<mlir::tf_saved_model::AssetOp>(
3756 module_->getLoc(),
3757 /*sym_name=*/
3758 builder.getStringAttr("__tf_saved_model_variables"),
3759 /*filename=*/
3760 builder.getStringAttr(io::JoinPath(kSavedModelVariablesDirectory,
3761 kSavedModelVariablesFilename)));
3762 variable_and_assets.push_back(
3763 {input_.meta_graph_def().saver_def().filename_tensor_name(),
3764 variable_filename_op});
3765 variable_and_assets.insert(variable_and_assets.end(), assets.begin(),
3766 assets.end());
3767
3768 const auto& restore_op_name =
3769 input_.meta_graph_def().saver_def().restore_op_name();
3770 TF_RETURN_IF_ERROR(
3771 ConvertInitializer(restore_op_name, variable_and_assets));
3772 init_sym_refs.push_back(builder.getSymbolRefAttr(restore_op_name));
3773 }
3774
3775 std::string init_op_name;
3776 TF_RETURN_IF_ERROR(
3777 internal::GetInitOp("", input_.meta_graph_def(), &init_op_name));
3778 if (!init_op_name.empty()) {
3779 TF_RETURN_IF_ERROR(ConvertInitializer(init_op_name, assets));
3780 init_sym_refs.push_back(builder.getSymbolRefAttr(init_op_name));
3781 }
3782
3783 builder.create<mlir::tf_saved_model::SessionInitializerOp>(
3784 module_->getLoc(), builder.getArrayAttr(init_sym_refs));
3785
3786 (*module_)->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
3787
3788 SortSavedModelModule(*module_);
3789 MarkSavedModelFunctionVisibility(*module_);
3790
3791 return std::move(module_);
3792 }
3793
3794 // A helper class to import a TensorFlow model expressed in SavedModel V1 into
3795 // an MLIR Module in SavedModel dialect. In addition to importing the model, it
3796 // performs a few graph transformations, including:
3797 // 1) Convert read-only ref variables to resource variables
3798 // 2) Lift resource variables to global_tensors by using a TF session.
3799 class SavedModelSignatureDefImporter {
3800 public:
3801 // Main entry point: converts all functions (specified by SignatureDefs) in
3802 // the given meta graph to an MLIR Module.
3803 static StatusOr<mlir::OwningModuleRef> Convert(
3804 const SavedModelBundle& bundle, absl::Span<std::string> exported_names,
3805 mlir::MLIRContext* context, tensorflow::MLIRImportOptions options) {
3806 // debug_info might not be loaded with loader_lite.
3807 GraphDebugInfo debug_info;
3808 if (bundle.debug_info != nullptr) debug_info = *bundle.debug_info;
3809
3810 TF_ASSIGN_OR_RETURN(auto input,
3811 SimpleSavedModelMLIRImportInput::Create(
3812 options, &bundle.meta_graph_def, debug_info));
3813
3814 TF_ASSIGN_OR_RETURN(auto module,
3815 SavedModelSignatureDefImporterLite::Convert(
3816 input, exported_names, context,
3817 /*import_restore=*/false));
3818
3819 mlir::OpBuilder builder(module->getContext());
3820 (*module)->setAttr("tf_saved_model.under_construction",
3821 builder.getUnitAttr());
3822 TF_RETURN_IF_ERROR(LiftVariables(bundle, *module));
3823 module->removeAttr("tf_saved_model.under_construction");
3824
3825 return module;
3826 }
3827
3828 private:
3829 // Lifts the variables in `module`.
3830 static Status LiftVariables(const SavedModelBundle& bundle,
3831 mlir::ModuleOp module);
3832 };
3833
3834 Status SavedModelSignatureDefImporter::LiftVariables(
3835 const SavedModelBundle& bundle, mlir::ModuleOp module) {
3836 mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
3837
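// The pipeline below prunes the executor graph, converts it out of the
// executor dialect, removes variables from the session initializer,
// rewrites read-only reference variables as resource variables, promotes
// VarHandleOps to arguments, reads the variables' values from the session
// and lifts them into global_tensors, then dedups bound-input bindings.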
3838 mlir::PassManager pm(module.getContext());
3839 SetCrashReproducer(pm);
3840 pm.addNestedPass<mlir::FuncOp>(
3841 mlir::tf_executor::CreateTFExecutorGraphPruningPass());
3842 pm.addNestedPass<mlir::FuncOp>(
3843 mlir::CreateExecutorDialectToFunctionalConversionPass());
3844 pm.addPass(
3845 mlir::tf_saved_model::CreateRemoveVariablesInSessionInitializerPass());
3846 pm.addNestedPass<mlir::FuncOp>(
3847 mlir::TF::
3848 CreateConvertReadonlyReferenceVariablesToResourceVariablesPass());
3849 pm.addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass());
3850 pm.addPass(
3851 mlir::tf_saved_model::CreateLiftVariablesPass(bundle.GetSession()));
3852 pm.addNestedPass<mlir::FuncOp>(
3853 mlir::tf_saved_model::CreateDedupBoundInputBindingPass());
3854 if (mlir::failed(pm.run(module)))
3855 return diag_handler.Combine(errors::Internal("Failed to lift variables."));
3856
3857 return Status::OK();
3858 }
3859
3860 } // namespace
3861
3862 SavedModelMLIRImportInput::~SavedModelMLIRImportInput() {}
3863
3864 StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
3865 const GraphDef& graphdef, const GraphDebugInfo& debug_info,
3866 const GraphImportConfig& specs, mlir::MLIRContext* context,
3867 bool add_default_attributes) {
3868 GraphConstructorOptions options;
3869 options.allow_internal_ops = true;
3870 options.add_default_attributes = add_default_attributes;
3871 Graph graph(OpRegistry::Global());
3872
3873 GraphDef preprocessed_graphdef(graphdef);
3874 if (add_default_attributes) {
3875 TF_RETURN_IF_ERROR(PreprocessGraphDef(&specs, &preprocessed_graphdef));
3876 }
3877 if (specs.upgrade_legacy) {
3878 TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(
3879 preprocessed_graphdef, graph.flib_def().default_registry()));
3880 }
3881 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
3882 options, std::move(preprocessed_graphdef), &graph));
3883 return ConvertGraphToMlir(graph, debug_info, graph.flib_def(), specs,
3884 context);
3885 }
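
// A minimal usage sketch, assuming a hypothetical, already-populated
// `graph_def`; setting upgrade_legacy opts into functionalization of legacy
// control flow:
//
//   mlir::MLIRContext context;
//   tensorflow::GraphDebugInfo debug_info;
//   tensorflow::GraphImportConfig specs;
//   specs.upgrade_legacy = true;
//   auto module_or = tensorflow::ConvertGraphdefToMlir(
//       graph_def, debug_info, specs, &context,
//       /*add_default_attributes=*/true);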

StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
    const Graph& graph, const GraphDebugInfo& debug_info,
    const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
    mlir::MLIRContext* context) {
  // TODO(jpienaar): Remove need to const_cast.
  if (specs.upgrade_legacy) {
    TF_RETURN_IF_ERROR(
        UpgradeLegacyGraph(const_cast<Graph*>(&graph),
                           const_cast<FunctionLibraryDefinition*>(&flib_def),
                           specs.restrict_functionalization_to_tpu_nodes));
  }
  return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs,
                                   /*func_name=*/"main");
}

stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
    const FunctionBody* fbody, const FunctionLibraryDefinition& flib_def,
    mlir::MLIRContext* context) {
  tensorflow::GraphDebugInfo dummy_debug_info;
  tensorflow::GraphImportConfig specs;
  specs.enable_shape_inference = false;
  specs.graph_as_function = true;
  for (const auto* control_ret_node : fbody->control_ret_nodes)
    specs.control_outputs.push_back(control_ret_node->name());
  return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info,
                                   flib_def, specs,
                                   fbody->fdef.signature().name());
}
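
// A minimal usage sketch, assuming a hypothetical FunctionDef `fdef` that is
// registered in `flib_def`; FunctionDefToBodyHelper (from
// tensorflow/core/framework/function.h) instantiates the FunctionBody:
//
//   std::unique_ptr<FunctionBody> fbody;
//   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
//       fdef, AttrSlice(&fdef.attr()), &flib_def, &fbody));
//   mlir::MLIRContext context;
//   auto module_or = ConvertFunctionToMlir(fbody.get(), flib_def, &context);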

StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
    SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
    absl::Span<std::string> exported_names, bool add_default_attributes) {
  return SavedModelObjectGraphImporter::Convert(
      saved_model, exported_names, context, add_default_attributes);
}

StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir(
    const SavedModelBundle& saved_model, absl::Span<std::string> exported_names,
    mlir::MLIRContext* context, MLIRImportOptions options) {
  return SavedModelSignatureDefImporter::Convert(saved_model, exported_names,
                                                 context, options);
}

StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlirLite(
    const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info,
    absl::Span<std::string> exported_names, mlir::MLIRContext* context,
    MLIRImportOptions options) {
  TF_ASSIGN_OR_RETURN(auto input, SimpleSavedModelMLIRImportInput::Create(
                                      options, &meta_graph_def, debug_info));
  return ConvertSavedModelV1ToMlirLite(input, exported_names, context);
}

StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlirLite(
    SavedModelMLIRImportInput& input, absl::Span<std::string> exported_names,
    mlir::MLIRContext* context) {
  return SavedModelSignatureDefImporterLite::Convert(input, exported_names,
                                                     context);
}

std::string MlirModuleToString(mlir::ModuleOp module,
                               mlir::OpPrintingFlags flags) {
  std::string txt_module;
  {
    llvm::raw_string_ostream os{txt_module};
    module.print(os, flags);
  }
  return txt_module;
}

std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) {
  mlir::OpPrintingFlags flags;
  if (show_debug_info) flags.enableDebugInfo();
  return MlirModuleToString(module, flags);
}
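
// A minimal usage sketch: serializing a converted module for logging, with
// location debug info enabled; `module_ref` stands in for an OwningModuleRef
// produced by one of the conversion entry points above.
//
//   std::string text =
//       MlirModuleToString(*module_ref, /*show_debug_info=*/true);
//   VLOG(1) << text;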

}  // namespace tensorflow