1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
17 
18 #include <iterator>
19 #include <memory>
20 #include <string>
21 #include <tuple>
22 #include <type_traits>
23 #include <unordered_set>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/algorithm/container.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "absl/container/inlined_vector.h"
31 #include "absl/strings/escaping.h"
32 #include "absl/strings/numbers.h"
33 #include "absl/strings/str_cat.h"
34 #include "absl/strings/str_join.h"
35 #include "absl/strings/string_view.h"
36 #include "absl/strings/strip.h"
37 #include "llvm/ADT/ArrayRef.h"
38 #include "llvm/ADT/DenseMap.h"
39 #include "llvm/ADT/DenseSet.h"
40 #include "llvm/ADT/STLExtras.h"
41 #include "llvm/ADT/SetVector.h"
42 #include "llvm/ADT/SmallVector.h"
43 #include "llvm/ADT/StringRef.h"
44 #include "llvm/ADT/StringSet.h"
45 #include "llvm/ADT/Twine.h"
46 #include "llvm/Support/FormatVariadic.h"
47 #include "llvm/Support/SourceMgr.h"
48 #include "llvm/Support/raw_ostream.h"
49 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
50 #include "mlir/IR/Attributes.h"  // from @llvm-project
51 #include "mlir/IR/Builders.h"  // from @llvm-project
52 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
53 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
54 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
55 #include "mlir/IR/Identifier.h"  // from @llvm-project
56 #include "mlir/IR/Location.h"  // from @llvm-project
57 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
58 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
59 #include "mlir/IR/Types.h"  // from @llvm-project
60 #include "mlir/IR/Verifier.h"  // from @llvm-project
61 #include "mlir/Pass/PassManager.h"  // from @llvm-project
62 #include "tensorflow/cc/saved_model/constants.h"
63 #include "tensorflow/cc/saved_model/loader_util.h"
64 #include "tensorflow/compiler/jit/shape_inference_helpers.h"
65 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
66 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
67 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
68 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
69 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
70 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
71 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
72 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
73 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
74 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
75 #include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"
76 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h"
77 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
78 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
79 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
80 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
81 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
82 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
83 #include "tensorflow/compiler/xla/status_macros.h"
84 #include "tensorflow/core/common_runtime/function.h"
85 #include "tensorflow/core/common_runtime/graph_constructor.h"
86 #include "tensorflow/core/common_runtime/shape_refiner.h"
87 #include "tensorflow/core/framework/attr_value.pb.h"
88 #include "tensorflow/core/framework/function.pb.h"
89 #include "tensorflow/core/framework/graph.pb.h"
90 #include "tensorflow/core/framework/node_def.pb.h"
91 #include "tensorflow/core/framework/node_def_util.h"
92 #include "tensorflow/core/framework/op.h"
93 #include "tensorflow/core/framework/resource_var.h"
94 #include "tensorflow/core/framework/shape_inference.h"
95 #include "tensorflow/core/framework/tensor.pb.h"
96 #include "tensorflow/core/framework/types.h"
97 #include "tensorflow/core/framework/types.pb.h"
98 #include "tensorflow/core/framework/versions.pb.h"
99 #include "tensorflow/core/graph/algorithm.h"
100 #include "tensorflow/core/graph/graph.h"
101 #include "tensorflow/core/graph/node_builder.h"
102 #include "tensorflow/core/graph/tensor_id.h"
103 #include "tensorflow/core/grappler/utils/transitive_fanin.h"
104 #include "tensorflow/core/lib/core/errors.h"
105 #include "tensorflow/core/lib/strings/str_util.h"
106 #include "tensorflow/core/platform/errors.h"
107 #include "tensorflow/core/platform/path.h"
108 #include "tensorflow/core/platform/protobuf.h"
109 #include "tensorflow/core/platform/types.h"
110 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
111 #include "tensorflow/core/protobuf/meta_graph.pb.h"
112 #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
113 #include "tensorflow/core/protobuf/saver.pb.h"
114 #include "tensorflow/core/protobuf/struct.pb.h"
115 #include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
116 #include "tensorflow/stream_executor/lib/statusor.h"
117 
StringRefToView(llvm::StringRef ref)118 static inline absl::string_view StringRefToView(llvm::StringRef ref) {
119   return {ref.data(), ref.size()};
120 }
121 
122 namespace tensorflow {
123 using mlir::NamedAttrList;
124 using mlir::TensorType;
125 using mlir::tf_saved_model::AssetOp;
126 using mlir::tf_saved_model::GlobalTensorOp;
127 using mlir::tf_saved_model::SessionInitializerOp;
128 using stream_executor::port::StatusOr;
129 
130 namespace {
131 
IsOutputShapesAttribute(const AttrValue & attr_value,llvm::StringRef attr_name)132 bool IsOutputShapesAttribute(const AttrValue& attr_value,
133                              llvm::StringRef attr_name) {
134   return attr_name.compare("_output_shapes") == 0 &&
135          attr_value.value_case() == AttrValue::kList;
136 }
137 
IsResourceOutputShapesAttribute(const AttrValue & attr_value,llvm::StringRef attr_name)138 bool IsResourceOutputShapesAttribute(const AttrValue& attr_value,
139                                      llvm::StringRef attr_name) {
140   if (attr_name == "_handle_dtypes" || attr_name == "_handle_shapes")
141     return attr_value.value_case() == AttrValue::kList;
142   return false;
143 }
144 
LoadImporterDialects(mlir::MLIRContext & context)145 void LoadImporterDialects(mlir::MLIRContext& context) {
146   // Load dialects involved in the conversion
147   mlir::DialectRegistry registry;
148   mlir::RegisterAllTensorFlowDialects(registry);
149   context.appendDialectRegistry(registry);
150   context.loadAllAvailableDialects();
151 }
152 
// This class is used to generate new MLIR function name strings that are both
// unique in the TF function library `flib_` and unique among the name strings
// generated by the class object during its lifetime.
//
// In theory, this class is not necessary because we should simply take
// the TF function name and use it as MLIR function name. However, for some
// unknown reasons (callout for investigation in b/142268695), keeping the
// function names unchanged in an MLIR roundtrip causes test failures.
// TODO(b/142268695) Re-evaluate whether we need this class vs. directly using
// the TF function name as MLIR function name after b/142268695 is root caused.
class NameUniquifier : public OpOrArgNameMapper {
 public:
  explicit NameUniquifier(const FunctionLibraryDefinition& flib)
      : flib_(flib) {}

 private:
  // A candidate name is unique iff the function library contains no function
  // with that name. Names handed out earlier are tracked by the base class.
  bool IsUnique(llvm::StringRef name) override {
    return !flib_.Contains(std::string(name));
  }

  // Never called: candidate names are always supplied by callers, this mapper
  // does not derive a name from an op or value itself.
  std::string GetName(OpOrVal op_or_val) override {
    DCHECK(false) << "Unimplemented";
    return "";
  }

  const FunctionLibraryDefinition& flib_;
};
180 
// Stateful helper class to import a TensorFlow model into an MLIR Module.
//
// This is the base class that contains common utilities shared between the
// GraphDef importer and SavedModel importer.
//
// A subclass is expected to call `PrepareConvert` first to perform necessary
// preparation over the graph and also certain internal bookkeeping data.
// Afterwards the other protected methods can be called.
class ImporterBase {
 protected:
  // Note: `flib`, `specs` and `debug_info` are stored by reference, and
  // `tf_name_to_mlir_name` / `function_name_uniquifier` by pointer, so all of
  // them must outlive this importer.
  explicit ImporterBase(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier,
      llvm::StringRef function_name_for_debug_info = "")
      : builder_(module.getContext()),
        module_(module),
        context_(module.getContext()),
        tf_name_to_mlir_name_(tf_name_to_mlir_name),
        graph_flib_(flib),
        specs_(specs),
        debug_info_(debug_info),
        function_name_for_debug_info_(function_name_for_debug_info),
        function_name_uniquifier_(function_name_uniquifier),
        error_handler_(module.getContext()) {}

  // Returns the inferred function signature of the given function body. Input
  // types are unranked tensor of the respective datatype in the function and
  // result types are inferred by the shape_refiner_. Result types need not be
  // unranked tensors and could be ranked tensors in cases where result type
  // depends on an op with static output shape like tf.Const.
  StatusOr<mlir::FunctionType> InferLibFunctionType(const FunctionBody& fbody);

  // Extracts arg and ret nodes from FunctionBody.
  void GetArgsAndRetsFromFunctionBody(
      const FunctionBody& fbody,
      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
      absl::InlinedVector<OutputTensor, 4>* ret_nodes,
      absl::InlinedVector<Node*, 4>* control_ret_nodes);

  // Prepares converting the graph to an MLIR module. This step removes the
  // backedges of the graph, orders the nodes and infers the shapes.
  Status PrepareConvert(const Graph& graph);

  // Converts the prepared graph to a Function and adds it to the module. A set
  // of nodes from the graph are given to converted to the arguments and returns
  // of the function.
  Status Convert(llvm::StringRef func_name, mlir::FunctionType func_type,
                 const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
                 const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
                 const absl::InlinedVector<Node*, 4>& control_ret_nodes,
                 llvm::ArrayRef<mlir::NamedAttribute> attrs);

  // Finds out the function definition for the given function name from the
  // graph and converts it to a function of the module. This method is called
  // on demand because the graph flib_def does not provide an iterator
  // interface.
  Status ConvertLibFunction(llvm::StringRef func_name);

  // Returns the list of nodes in the graph. Nodes are presented in the reverse
  // order of a post-order depth-first visit starting from the graph's source
  // nodes.
  llvm::ArrayRef<Node*> GetOrderedNodes() const { return ordered_nodes_; }

  // Returns the inferred input type at index `idx` of the `node` in the
  // context.
  StatusOr<mlir::Type> InferInputType(const Node& node, int idx,
                                      mlir::Builder builder);

  // Returns the inferred output type at index `idx` of the `node` in the
  // context.
  StatusOr<mlir::Type> InferOutputType(const Node& node, int idx,
                                       mlir::Builder builder);

 private:
  // Most types with subtypes have only one subtype.
  using ElementSubtypes = llvm::SmallVector<TensorType, 1>;

  // Adds all the ordered_nodes to the shape refiner shape_refiner_. Then all
  // data type and shape information is maintained by the shape_refiner_.
  // TODO(jpienaar): Remove once shape inference on import is removed.
  Status AddNodesToShapeRefiner(
      std::unordered_map<string, Node*>* node_name_map);

  // Prune nodes that do not feed into fetch nodes.
  Status PruneUnreachableNodes(
      std::unordered_map<string, Node*>* node_name_map);

  // Converts feeds to Placeholder nodes.
  Status ConvertFeedsToPlaceholders(
      std::unordered_map<string, Node*>* node_name_map);

  // Converts the inferred shape referred to by 'handle' in 'context', with
  // given element type, and returns an MLIR tensor type.
  StatusOr<TensorType> ConvertDataTypeAndShape(
      DataType dtype, const shape_inference::ShapeHandle& handle,
      const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the inferred shape referred to by 'handle' in 'context', with
  // given element type, and returns an MLIR tensor type.
  StatusOr<TensorType> ConvertElementTypeAndShape(
      mlir::Type element_type, const shape_inference::ShapeHandle& handle,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the inferred subtypes for an element type to corresponding MLIR
  // types in 'context'.
  StatusOr<ElementSubtypes> ConvertSubtypes(
      const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the tensor proto into an MLIR elements attribute.
  StatusOr<mlir::ElementsAttr> ConvertTensorProto(const TensorProto& value) {
    return ::tensorflow::ConvertTensorProto(value, &builder_);
  }

  // Converts func name in graphdef to mlir::SymbolRefAttribute.
  StatusOr<mlir::FlatSymbolRefAttr> ConvertFunctionCallName(
      const std::string& func_name);

  // Converts the given non-function-call AttrValue to an MLIR Attribute.
  StatusOr<mlir::Attribute> ConvertAttributeValue(const AttrValue& value);

  // Converts the given function-call AttrValue to MLIR Attributes and pushes
  // them to the given attributes list. For example, if there is a kFunc
  // AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it to
  // a list of MLIR Attributes: [{base_name : foo}, {base_name.k1 : bar},
  // {base_name.k2 : rfc}}.
  Status ConvertFunctionCallAttribute(const std::string& base_name,
                                      const AttrValue& value,
                                      NamedAttrList* attributes);

  // Helper to create either a tf_executor operation or a TF operation wrapped
  // in an island.
  mlir::Operation* CreateOperation(
      const Node& node, llvm::StringRef node_type_name,
      const mlir::OperationState& result,
      const llvm::SmallVectorImpl<mlir::Value>& control_operands);

  // Converts one NodeDef from the input GraphDef into an Operation and
  // inserts it into the MLIR module using builder_.
  Status ConvertNode(const Node& node);

  // If the input graph represents a while-loop, the edges pointing from a
  // "NextIteration" node to a "Merge" node add cyclic dependencies and make the
  // topological sorting impossible. We need to remove these edges from the
  // input graph to infer shapes and construct a Function. For each
  // "NextIteration" node, there are two operations, "NextIteration.source"
  // and "NextIteration.sink" are added to the MLIR module.
  using BackEdge = BackEdgeHelper::BackEdge;

  // Removes backedges from the input graph. The removed edges are added back to
  // to OpBuilder after the remaining graph is converted to the Function.
  Status RemoveBackedges(const Graph& graph);

  // Restores backedges removed during shape inference to the final Function.
  Status AddBackedges();

  // Restores a single backedge in the Function by adding a replicated
  // operation before the dst operation.
  Status AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
                     int dst_input);

  // Adds the input arguments and return operation to the function. The
  // arguments are added as basic block argument. Also the argument types and
  // the id of the nodes from the input graph needs to be specified.
  Status ConvertFunctionArgAndRets(
      mlir::FuncOp func, mlir::tf_executor::GraphOp graph_op,
      llvm::ArrayRef<mlir::Type> arg_types,
      const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
      const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
      const absl::InlinedVector<Node*, 4>& control_ret_nodes);

  // Gets the location information of the given node. It uses the
  // "original_node_name" in the NodeDef to get the corresponding file location
  // (FileLineColLoc) from the input DebugInfo and returns an CallSiteLoc. If
  // there are multiple "original_node_names", a FusedLoc is returned. If the
  // node name couldn't be found in the input DebugInfo, a NameLoc is used as
  // the location.
  mlir::Location GetLocation(const Node& node);

  // Appends the location string for the node to the error message and returns
  // the combined error status.
  Status EmitErrorWithLocationStr(const Node& node, const Status& error_status);

  // Inserts a placeholder node in the graph to replace a feed output tensor,
  // and returns the new placeholder node and a boolean indicating if the
  // original input node was removed from the graph. Uses of the feed output
  // tensor are replaced with this placeholder node. If the feed output tensor
  // is of a single output node, the control dependencies are forwarded to the
  // the placeholder node, and the original node will be removed.
  // Note: This modifies the graph, and so any list of ordered nodes needs to be
  // reconstructed.
  StatusOr<std::pair<Node*, bool>> CreatePlaceholderNodeForFeed(
      const TensorShapeProto& shape, DataType dtype, Node* node, int index,
      const std::unordered_map<string, Node*>& node_name_map);

  // Gets the input and output nodes corresponding to the specified input and
  // output nodes in specs_. If there are no input or output nodes specified,
  // nodes will be empty.
  Status GetInputOutputNodes(
      const std::unordered_map<string, Node*>& node_name_map,
      std::unordered_set<const Node*>* nodes);

  // Returns a process-wide set shared across all importers; entries persist
  // for the lifetime of the process.
  llvm::StringSet<>& GetUnmodelledOpTypes() {
    // All the TF ops encountered that aren't modelled in dialect.
    static auto* unmodelled_op_types = new llvm::StringSet<>();
    return *unmodelled_op_types;
  }

  // The input graph with backedges removed. The removed backedges are stored
  // in the back_edge_helper.
  BackEdgeHelper back_edge_helper_;
  // A map between node and output index, for each backedge.
  absl::flat_hash_map<const Node*, int> back_edge_node_output_;
  absl::flat_hash_map<const Node*, BackEdge> back_edge_dst_inputs_;
  // A map between sink and source operation of NextIteration
  absl::flat_hash_map<mlir::Operation*, mlir::Operation*>
      next_iteration_sink_source_;

  // All nodes and version information about the (copied) imported graph.
  std::unique_ptr<Graph> graph_;
  // Nodes of graph_ in reverse post order (see GetOrderedNodes).
  std::vector<Node*> ordered_nodes_;

  // Maps from a Node ID to a MLIR value.
  using NodeValueMap = absl::flat_hash_map<int, mlir::Operation*>;

  // Builder used to insert converted operations into module_.
  mlir::OpBuilder builder_;
  mlir::ModuleOp module_;
  mlir::MLIRContext* context_;
  // Maps TF function names to their MLIR function names (shared, not owned).
  std::unordered_map<std::string, std::string>* tf_name_to_mlir_name_;
  const FunctionLibraryDefinition& graph_flib_;
  const GraphImportConfig& specs_;
  const GraphDebugInfo& debug_info_;
  llvm::StringRef function_name_for_debug_info_;
  // Maps node IDs to the MLIR operations converted from them.
  NodeValueMap node_values_;
  // TODO(jpienaar): Remove once shape inference on import is removed.
  // The shape_refinner_ will be nullptr if shape inference on import is
  // not enabled.
  std::unique_ptr<ShapeRefiner> shape_refiner_ = nullptr;
  // Not owned; shared across importers of the same module.
  NameUniquifier* function_name_uniquifier_;
  // Captures MLIR diagnostics emitted on context_ and surfaces them as Status.
  mlir::StatusScopedDiagnosticHandler error_handler_;

 protected:
  // Maps feed as TensorId to new Placeholder node name.
  absl::flat_hash_map<TensorId, absl::string_view> remapped_feeds_;
};
429 
430 // Returns true if the node with given name has a non primary output that is
431 // used by some other node as an input. Returns false if no outputs are in use
432 // or only the first output is in use.
HasNonPrimaryOutputInUse(const GraphDef & graph_def,const std::string & node)433 bool HasNonPrimaryOutputInUse(const GraphDef& graph_def,
434                               const std::string& node) {
435   for (const auto& node_def : graph_def.node()) {
436     for (const auto& input : node_def.input()) {
437       if (absl::StartsWith(input, node + ":") && input != node + ":0") {
438         return true;
439       }
440     }
441   }
442   return false;
443 }
444 
445 // Updates the given LegacyFedInput node with Placeholder node if it is one of
446 // the inputs. Returns an error if non primary output of the LegacyFedInput node
447 // is in use and therefore can not be replaced by the Placeholder node that only
448 // has a single output.
UpdateLegacyFedInputNode(const GraphDef & graph_def,const GraphImportConfig::InputArrays & inputs,NodeDef * node)449 Status UpdateLegacyFedInputNode(const GraphDef& graph_def,
450                                 const GraphImportConfig::InputArrays& inputs,
451                                 NodeDef* node) {
452   const std::string& node_name = node->name();
453   auto it = inputs.find(node_name);
454 
455   // Node is not an input.
456   if (it == inputs.end()) return Status::OK();
457 
458   if (HasNonPrimaryOutputInUse(graph_def, node_name)) {
459     return errors::InvalidArgument(
460         "LegacyFedInput node ", node->name(),
461         " has non primary output in use and can not be replaced with "
462         "Placeholder node");
463   }
464 
465   DataType dtype = it->second.imported_dtype;
466   // Uses the existing output type if it isn't specified by the user.
467   if (dtype == DT_INVALID) {
468     dtype = node->attr().at("output_types").list().type(0);
469   }
470   // Update op name, drop inputs and set attributes required by the Placeholder
471   // op.
472   *node->mutable_op() = "Placeholder";
473   node->clear_attr();
474   node->clear_input();
475   AddNodeAttr("dtype", dtype, node);
476   AddNodeAttr("shape", it->second.shape, node);
477   return Status::OK();
478 }
479 
480 // Preprocesses GraphDef before it can be converted to Graph by,
481 // - Adding the default attributes to each node def if they are missing from
482 //   the GraphDef.
483 // - Replacing LegacyFedInput nodes with Placeholder nodes if
484 //   convert_legacy_fed_inputs option is enabled.
PreprocessGraphDef(const GraphImportConfig * specs,GraphDef * graph_def)485 Status PreprocessGraphDef(const GraphImportConfig* specs, GraphDef* graph_def) {
486   for (auto& node_def : *graph_def->mutable_node()) {
487     // TODO(hinsu): Completely deprecate support for LegacyFedInput ops. One
488     // solution could be have a tool to let users upgrade old serialized graphs.
489     if (specs && specs->convert_legacy_fed_inputs &&
490         node_def.op() == "LegacyFedInput") {
491       TF_RETURN_IF_ERROR(
492           UpdateLegacyFedInputNode(*graph_def, specs->inputs, &node_def));
493     }
494 
495     const tensorflow::OpRegistrationData* op_reg_data =
496         tensorflow::OpRegistry::Global()->LookUp(node_def.op());
497     if (!op_reg_data) {
498       // This is likely a function call node, so we should continue.
499       continue;
500     }
501     ::tensorflow::AddDefaultsToNodeDef(op_reg_data->op_def, &node_def);
502   }
503   return Status::OK();
504 }
505 
506 // Mapping from node name to feed (index and ArrayInfo). Node name must outlive
507 // this map.
508 using FeedsByNode = absl::flat_hash_map<
509     absl::string_view,
510     absl::flat_hash_map<int, const std::pair<std::string, ArrayInfo>*>>;
511 
512 // Creates from a `GraphImportConfig::InputArrays` a mapping from a feeds output
513 // tensor name to index and ArrayInfo. Keys and values are backed by
514 // `GraphImportConfig::InputArrays`.
GetFeedsByNode(const GraphImportConfig::InputArrays & inputs)515 StatusOr<FeedsByNode> GetFeedsByNode(
516     const GraphImportConfig::InputArrays& inputs) {
517   FeedsByNode feeds_by_node;
518   feeds_by_node.reserve(inputs.size());
519 
520   for (const auto& input : inputs) {
521     TensorId tensor = ParseTensorName(input.first);
522     if (tensor.index() < 0)
523       return errors::FailedPrecondition(
524           "Feed output tensor must be a data output '", tensor.ToString(), "'");
525 
526     auto& node = feeds_by_node[tensor.node()];
527     if (!node.insert({tensor.index(), &input}).second)
528       return errors::FailedPrecondition(
529           "Multiple feeds for the same output tensor '", tensor.ToString(),
530           "'");
531   }
532 
533   return feeds_by_node;
534 }
535 
536 // Creates a unique name for a node that will be replacing a feed output tensor.
GetUniqueNodeName(absl::string_view node_name,int index,const std::unordered_map<string,Node * > & node_name_map)537 std::string GetUniqueNodeName(
538     absl::string_view node_name, int index,
539     const std::unordered_map<string, Node*>& node_name_map) {
540   std::string new_node_name_base = absl::StrCat(node_name, "_", index);
541   int count = 0;
542   std::string new_node_name = new_node_name_base;
543   while (node_name_map.find(new_node_name) != node_name_map.end()) {
544     new_node_name = absl::StrCat(new_node_name_base, "_", count++);
545   }
546   return new_node_name;
547 }
548 
// Clones `graph` into graph_, strips its backedges (recording them for later
// restoration), and computes a stable reverse post-order of the nodes.
Status ImporterBase::RemoveBackedges(const Graph& graph) {
  // TODO(fengliuai): Converting to GraphDef and back is the easiest way to
  // clone a graph.
  // TODO(fengliuai): clone the graph without going to graph_def first.
  GraphDef graph_def;
  graph.ToGraphDef(&graph_def);
  graph_ = absl::make_unique<Graph>(graph.flib_def());
  GraphConstructorOptions opts;
  opts.allow_internal_ops = true;
  opts.add_default_attributes = false;
  TF_RETURN_IF_ERROR(::tensorflow::ConvertGraphDefToGraph(
      opts, std::move(graph_def), graph_.get()));

  // Remove all the backedges. So the nodes can be added to the shape refiner.
  TF_RETURN_IF_ERROR(back_edge_helper_.Remove(graph_.get()));
  VLOG(1) << "Found " << (back_edge_helper_.RemovedEdges().size())
          << " backedges.";

  // Creates a map for quickly identifying whether a node output is a backedge.
  for (const auto& edge : back_edge_helper_.RemovedEdges()) {
    // Each source node may contribute at most one distinct backedge output.
    if (back_edge_node_output_.find(edge.src) != back_edge_node_output_.end() &&
        back_edge_node_output_[edge.src] != edge.src_output) {
      return errors::FailedPrecondition(
          "More than one of the src node outputs are backedges!");
    }
    back_edge_node_output_[edge.src] = edge.src_output;
    // We expect a merge to receive a single backedge (multiple NextIteration
    // nodes feeding into the same merge is unexpected here).
    DCHECK(!back_edge_dst_inputs_.contains(edge.dst));
    back_edge_dst_inputs_[edge.dst] = edge;
  }

  // Obtains a RPO ordering, using node names as a tiebreak for stable sorting.
  GetReversePostOrder(
      *graph_, &ordered_nodes_,
      [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
  return Status::OK();
}
587 
CopyStackTraces(const Graph & from,Graph * to)588 Status CopyStackTraces(const Graph& from, Graph* to) {
589   // Copy over the stack traces.
590   // TODO(jpienaar): This really shouldn't be needed, copying the Graph above
591   // and then needing these traversals is unfortunate.
592   std::unordered_map<string, Node*> node_map = from.BuildNodeNameIndex();
593   for (Node* node : to->nodes()) {
594     if (const Node* old_node = node_map[node->name()]) {
595       if (const std::shared_ptr<AbstractStackTrace>& stack =
596               old_node->GetStackTrace()) {
597         DVLOG(2) << "Stack for " << node->name() << " "
598                  << old_node->GetStackTrace()->ToString(
599                         AbstractStackTrace::TracePrintingOptions());
600         node->SetStackTrace(stack);
601       } else {
602         DVLOG(1) << "No stack for " << node->name() << " (" << node
603                  << ") in Graph " << &from;
604       }
605     } else {
606       DVLOG(1) << "No stack for " << node->name() << " (" << node
607                << ") in Graph " << &from;
608     }
609   }
610 
611   return Status::OK();
612 }
613 
// Replaces the feed output tensor `node:index` with a new Placeholder node of
// the given shape/dtype. Returns the Placeholder and whether the original node
// was removed from the graph (see the declaration comment for the contract).
StatusOr<std::pair<Node*, bool>> ImporterBase::CreatePlaceholderNodeForFeed(
    const TensorShapeProto& shape, DataType dtype, Node* node, int index,
    const std::unordered_map<string, Node*>& node_name_map) {
  DCHECK_LT(index, node->num_outputs());
  // If the feed is the node's only output, the Placeholder can take over the
  // node's name and replace it entirely; otherwise a fresh name is needed.
  const bool update_inplace = node->num_outputs() == 1 && index == 0;
  std::string new_node_name =
      update_inplace ? node->name()
                     : GetUniqueNodeName(node->name(), index, node_name_map);

  Node* placeholder_node;
  NodeBuilder builder(new_node_name, "Placeholder");
  builder.Attr("shape", shape);
  builder.Attr("dtype", dtype);
  TF_RETURN_IF_ERROR(builder.Finalize(graph_.get(), &placeholder_node));

  // Update edges from original feed with Placeholder node.
  // Collect first, then mutate: rewiring while iterating out_edges() would
  // invalidate the iteration.
  std::vector<const Edge*> data_edges;
  std::vector<const Edge*> control_edges;
  for (const tensorflow::Edge* edge : node->out_edges()) {
    if (edge->src_output() == index) {
      data_edges.push_back(edge);
    } else if (update_inplace && edge->IsControlEdge()) {
      control_edges.push_back(edge);
    }
  }

  // Redirect all consumers of the fed output to the Placeholder's output 0.
  for (const auto* edge : data_edges) {
    TF_RETURN_IF_ERROR(graph_->UpdateEdge(placeholder_node, 0, edge->dst(),
                                          edge->dst_input()));
  }

  // TODO(lyandy): Preserve control dependencies properly by not forwarding
  // control dependencies to data outputs and not removing single output nodes.
  // When a data output is replaced as a feed, unless there is another non feed
  // data output or an explicit control output used by the same node, transitive
  // control dependencies are not to be executed. For single output nodes,
  // Placeholders can be converted to a NoOp if there are no uses, and
  // PlaceholderWithDefault can be converted to an Identity.
  for (const auto* edge : control_edges) {
    graph_->AddControlEdge(placeholder_node, edge->dst());
    graph_->RemoveControlEdge(edge);
  }

  // The original node is removed only in the single-output takeover case.
  if (update_inplace) {
    graph_->RemoveNode(node);
  }

  return std::pair<Node*, bool>(placeholder_node, update_inplace);
}
663 
GetInputOutputNodes(const std::unordered_map<string,Node * > & node_name_map,std::unordered_set<const Node * > * nodes)664 Status ImporterBase::GetInputOutputNodes(
665     const std::unordered_map<string, Node*>& node_name_map,
666     std::unordered_set<const Node*>* nodes) {
667   auto add_node = [&](absl::string_view name) {
668     auto it = node_name_map.find(std::string(name));
669     if (it == node_name_map.end()) {
670       return errors::FailedPrecondition(
671           absl::StrCat("Graph does not contain node: ", name));
672     }
673     nodes->insert(it->second);
674     return Status::OK();
675   };
676 
677   // Remap feeds and fetches to newly created Placeholder nodes.
678   for (const auto& input : specs_.inputs) {
679     TensorId tensor = ParseTensorName(input.first);
680     auto remapped_it = remapped_feeds_.find(tensor);
681     if (remapped_it != remapped_feeds_.end()) {
682       TF_RETURN_IF_ERROR(add_node(remapped_it->second));
683     } else {
684       TF_RETURN_IF_ERROR(add_node(tensor.node()));
685     }
686   }
687 
688   for (const auto& output : specs_.outputs) {
689     TensorId tensor = ParseTensorName(output);
690     auto remapped_it = remapped_feeds_.find(tensor);
691     if (remapped_it != remapped_feeds_.end()) {
692       TF_RETURN_IF_ERROR(add_node(remapped_it->second));
693     } else {
694       TF_RETURN_IF_ERROR(add_node(tensor.node()));
695     }
696   }
697 
698   for (const auto& control_output : specs_.control_outputs)
699     TF_RETURN_IF_ERROR(add_node(control_output));
700 
701   return Status::OK();
702 }
703 
704 // TODO(jpienaar): Remove this post shape inference on import flag is removed.
// Adds every node in `ordered_nodes_` to a fresh ShapeRefiner, rerouting
// feeds to Placeholder nodes where needed, then iterates shape inference to a
// fixpoint. May insert/remove graph nodes and rebuilds `ordered_nodes_`.
Status ImporterBase::AddNodesToShapeRefiner(
    std::unordered_map<string, Node*>* node_name_map) {
  shape_refiner_ = absl::make_unique<ShapeRefiner>(graph_->versions(),
                                                   graph_->op_registry());
  // Some operations (for example "TPUExecute") don't have shape inference
  // function defined, so we should set this to false for adding nodes with
  // these types of operations.
  shape_refiner_->set_require_shape_inference_fns(false);
  shape_refiner_->set_function_library_for_shape_inference(&graph_flib_);

  TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs));

  // First add all nodes to the refiner.
  for (Node* node : ordered_nodes_) {
    // We need to use a TensorFlow node to teach the shape refiner that user
    // specifies certain data type and shape for the inputs in the `specs_`.
    // This node shouldn't have any inputs, only have one output and its
    // output type/shape is only determined by its "named" attributes. (The
    // attributes should have fixed names so we can use the info from `specs_`
    // to set the value of them.) `Placeholder` satisfies these constraints.
    //
    // Therefore, if the input node isn't a `Placeholder`, we create one and use
    // it to replace the original input node, so the shape refiner can
    // successfully propagate the user's input type and shape to the rest of the
    // graph.
    bool node_added_to_shape_refiner = false;
    auto it = feeds_by_node.find(node->name());
    if (it != feeds_by_node.end()) {
      auto op_name = node->op_def().name();
      if (op_name != "Placeholder" && op_name != "LegacyFedInput" &&
          op_name != FunctionLibraryDefinition::kArgOp) {
        // Fed node of another op type: replace each fed output with a
        // Placeholder node that carries the user-provided dtype/shape.
        for (const auto& output_tensor : it->second) {
          const int index = output_tensor.first;
          const ArrayInfo& array_info = output_tensor.second->second;

          DataType dtype = array_info.imported_dtype;
          // Uses the existing output type if it isn't specified by the user.
          if (dtype == DT_INVALID) {
            dtype = node->output_type(index);
          }

          TF_ASSIGN_OR_RETURN(
              auto placeholder_node_and_removed,
              CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index,
                                           *node_name_map));

          Node* placeholder_node = placeholder_node_and_removed.first;
          if (placeholder_node_and_removed.second) {
            // Original node has been removed from the graph.
            node = placeholder_node;
            node_added_to_shape_refiner = true;
          }
          // Record the remapping so later feed/fetch lookups resolve to the
          // Placeholder, and keep the name map in sync with the graph.
          remapped_feeds_[{it->first, index}] = placeholder_node->name();
          (*node_name_map)[placeholder_node->name()] = placeholder_node;
          // Add the new placeholder node to the shape refiner.
          Status status = shape_refiner_->AddNode(placeholder_node);
          if (!status.ok()) {
            return EmitErrorWithLocationStr(*placeholder_node, status);
          }
        }
      } else {
        // Placeholder-like fed node: only output 0 is supported; overwrite its
        // shape/dtype attributes in place with the user-provided values.
        auto index_it = it->second.find(0);
        if (index_it == it->second.end()) {
          return errors::FailedPrecondition(
              "Missing feed output tensor at index 0 for node '", node->name(),
              "'");
        }
        node->AddAttr("shape", index_it->second->second.shape);
        DataType dtype = index_it->second->second.imported_dtype;
        // Uses the existing output type if it isn't specified by the user.
        if (dtype == DT_INVALID) {
          dtype = node->output_type(0);
        }
        node->AddAttr("dtype", dtype);
      }
    }
    if (!node_added_to_shape_refiner) {
      // Add the node to the shape refiner if the node hasn't been removed.
      Status status = shape_refiner_->AddNode(node);
      if (!status.ok()) {
        return EmitErrorWithLocationStr(*node, status);
      }
    }

    // Sets this node's output shapes in its inference context from a
    // list-of-shapes attribute (e.g. "_output_shapes"), one entry per output.
    auto set_shape_from_list_attr = [&](const AttrValue* attr) {
      auto& list = attr->list();
      for (auto shape : llvm::enumerate(list.shape())) {
        auto* node_context = shape_refiner_->GetContext(node);
        shape_inference::ShapeHandle handle;
        Status status =
            node_context->MakeShapeFromShapeProto(shape.value(), &handle);
        if (!status.ok()) {
          return EmitErrorWithLocationStr(*node, status);
        }
        node_context->set_output(shape.index(), handle);
      }
      return Status::OK();
    };

    // We currently have no other way to get shapes from ReadVariableOp's.
    // Some graphs seem to have _output_shapes attributes on them, so use that
    // if possible.
    // TODO(silvasean): Ideally, we would do this in a separate shape inference
    // pass to avoid adding complexity to the importer. But right now, we don't
    // have an MLIR-native shape inference pass, so we need to do this while we
    // still have the Graph around, i.e. here, in the importer.
    if (node->op_def().name() == "ReadVariableOp") {
      // TODO(silvasean): In some graphs, this seems to be annotated on every
      // node. Why and by whom?
      // TODO(b/140588338): We should ideally incorporate that information for
      // all nodes, but right now, this can result in e.g. an Identity node with
      // signature such as
      // `(tensor<?x?xf32>) -> tensor<?x9216xf32>` which fails the verifier
      // (which checks for exact type equality; _output_shapes results in
      // us shoehorning in the more-precise type on the output).
      if (const AttrValue* attr = node->attrs().Find("_output_shapes"))
        TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr));
    }

    // If it is the argument node, the shape handle is set explicitly, so it
    // can be propagated to the body nodes of the function.
    if (StringPiece(node->type_string()) == FunctionLibraryDefinition::kArgOp) {
      auto* node_context = shape_refiner_->GetContext(node);
      DCHECK(node_context != nullptr);
      if (const AttrValue* attr = node->attrs().Find("shape")) {
        shape_inference::ShapeHandle handle;
        Status status =
            node_context->MakeShapeFromShapeProto(attr->shape(), &handle);
        if (!status.ok()) {
          return EmitErrorWithLocationStr(*node, status);
        }
        node_context->set_output(0, handle);
      } else if (const AttrValue* attr = node->attrs().Find("_output_shapes")) {
        TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr));
      } else {
        node_context->set_output(0, node_context->UnknownShape());
      }
    }
  }

  // Since we might have inserted and removed nodes from the graph, fix
  // source/sink edges and reconstruct the RPO ordering of nodes
  FixupSourceAndSinkEdges(graph_.get());

  // Prune nodes in the graph that are not reachable from the output.
  if (specs_.prune_unused_nodes) {
    std::unordered_set<const Node*> prune_start;
    TF_RETURN_IF_ERROR(GetInputOutputNodes(*node_name_map, &prune_start));
    if (!prune_start.empty()) {
      if (PruneForReverseReachability(graph_.get(), prune_start)) {
        VLOG(1) << "Pruned unused nodes in graphdef";
      } else {
        VLOG(1) << "No unused nodes in graphdef to prune";
      }
    } else {
      VLOG(1) << "No output nodes specified, skipping pruning";
    }
  } else {
    VLOG(1) << "Pruning unused nodes in graphdef is disabled";
  }

  // Re-initialize ordered_nodes_ since we might have modified the graph.
  // Ties in the reverse post-order are broken by node name for determinism.
  GetReversePostOrder(
      *graph_, &ordered_nodes_,
      [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });

  VLOG(1) << "Inferring graph shapes to fixpoint";

  // The "changed" information from UpdateNode can give false positives, so we
  // create a dedicated method to verify the shapes are not changed before and
  // after the shape refine.
  auto same_inferred_shape = [](shape_inference::InferenceContext* c,
                                shape_inference::ShapeHandle s0,
                                shape_inference::ShapeHandle s1) -> bool {
    if (s0.SameHandle(s1) || (!c->RankKnown(s0) && !c->RankKnown(s1))) {
      return true;
    }
    if (c->Rank(s0) != c->Rank(s1)) {
      return false;
    }
    for (int i = 0; i < c->Rank(s0); ++i) {
      if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
        int64 val0 = c->Value(c->Dim(s0, i));
        int64 val1 = c->Value(c->Dim(s1, i));
        // Negative value is treated as unknown so all negative values indicate
        // the same dimension.
        if (val0 >= 0 && val1 >= 0 && val0 != val1) return false;
      }
    }
    return true;
  };

  // Iterate UpdateNode over all nodes until no output shape changes, capped
  // at kMaxIterationCount full passes over the graph.
  bool changed = true;
  int i = 0;
  const int kMaxIterationCount = 2;
  while (changed && i != kMaxIterationCount) {
    changed = false;
    for (const Node* node : ordered_nodes_) {
      auto* shape_context = shape_refiner_->GetContext(node);
      DCHECK(shape_context != nullptr);
      // Snapshot pre-update output shapes so changes can be detected with
      // same_inferred_shape rather than trusting UpdateNode's `refined` flag.
      absl::InlinedVector<shape_inference::ShapeHandle, 4> existing;
      existing.reserve(shape_context->num_outputs());
      for (int o = 0; o < shape_context->num_outputs(); ++o) {
        existing.push_back(shape_context->output(o));
      }
      bool inferred = false;
      shape_inference::ShapeHandle handle;
      Status status =
          shape_refiner_->UpdateNode(node, /*relax=*/false, &inferred);
      if (!status.ok()) {
        return EmitErrorWithLocationStr(*node, status);
      }
      for (int o = 0; o < shape_context->num_outputs(); ++o) {
        if (!same_inferred_shape(shape_context, shape_context->output(o),
                                 existing[o])) {
          changed = true;
          break;
        }
      }
    }
    ++i;
  }
  // NOTE(review): this warning also fires when the shapes converge exactly on
  // the final allowed pass (changed became false with i == kMaxIterationCount)
  // — confirm whether that spurious warning is acceptable.
  if (i >= kMaxIterationCount) {
    LOG(WARNING) << "Graph shapes did not converge to a fixpoint within "
                 << kMaxIterationCount
                 << " iterations. Graph shapes may be conservative.";
  }
  VLOG(1) << "Graph shapes were inferred with " << (i - 1)
          << " extra rounds of analysis to reach a fixpoint.";
  return Status::OK();
}
936 
InferInputType(const Node & node,int idx,mlir::Builder builder)937 StatusOr<mlir::Type> ImporterBase::InferInputType(const Node& node, int idx,
938                                                   mlir::Builder builder) {
939   if (specs_.enable_shape_inference) {
940     // TODO(jpienaar): Remove this if shape inference on import flag is removed.
941     ExtendedInferenceContext* shape_context =
942         shape_refiner_->GetExtendedContext(&node);
943     DataType dtype = shape_context->input_type(idx);
944     auto* context = shape_context->get_context();
945     return ConvertDataTypeAndShape(dtype, context->input(idx),
946                                    context->input_handle_shapes_and_types(idx),
947                                    context, builder);
948   }
949   DataType dtype = node.properties()->input_types[idx];
950   mlir::Type element_type;
951   TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &element_type));
952   return mlir::UnrankedTensorType::get(element_type);
953 }
954 
// Infers the MLIR type of output `idx` of `node`. Uses the legacy shape
// refiner when enabled; otherwise special-cases several source/iterator ops,
// and finally falls back to running the op's registered shape function (for
// ops with no inputs) or an unranked tensor of the declared dtype.
StatusOr<mlir::Type> ImporterBase::InferOutputType(const Node& node, int idx,
                                                   mlir::Builder builder) {
  DataType dtype = node.properties()->output_types[idx];

  // Returns output type given inference context.
  auto shape_ic = [&](shape_inference::InferenceContext* c) {
    return ConvertDataTypeAndShape(dtype, c->output(idx),
                                   c->output_handle_shapes_and_types(idx), c,
                                   builder);
  };

  if (specs_.enable_shape_inference) {
    // TODO(jpienaar): Remove this if shape inference on import flag is removed.
    ExtendedInferenceContext* shape_context =
        shape_refiner_->GetExtendedContext(&node);
    return shape_ic(shape_context->get_context());
  }

  // Treat TensorList init ops specially here as the op requires knowing its
  // element dtype.
  // TODO(jpienaar): Reconsider post refactoring shape functions.
  if (node.type_string() == "TensorListReserve" ||
      node.type_string() == "EmptyTensorList") {
    mlir::Type etype;
    if (auto element_dtype = node.attrs().Find("element_dtype")) {
      TF_RETURN_IF_ERROR(
          ConvertDataType(element_dtype->type(), builder, &etype));
    }
    // Scalar tensor of variant type whose subtype is an unranked tensor of
    // the list's element dtype.
    return mlir::RankedTensorType::get(
        {}, mlir::TF::VariantType::get({mlir::UnrankedTensorType::get(etype)},
                                       etype.getContext()));
  }

  // While nodes annotated with "output_shapes" carry per-result shapes.
  // NOTE(review): element_types ("T") is dereferenced without a null check
  // when output_shapes is present and non-empty — presumably "T" is always
  // set on well-formed While nodes; confirm.
  if (node.IsWhileNode()) {
    auto* output_shapes = node.attrs().Find("output_shapes");
    auto* element_types = node.attrs().Find("T");
    if (output_shapes && !output_shapes->list().shape().empty()) {
      const auto& output_shape = output_shapes->list().shape(idx);
      const auto& element_type = element_types->list().type(idx);
      return ConvertToMlirTensorType(output_shape, element_type, &builder);
    }
  }

  // Builds the output type from a pair of list attributes holding the shape
  // and dtype for each output. Assumes both attributes exist on the node and
  // have an entry at `idx` (callers below only use it on ops that define
  // them).
  auto type_from_array_attr = [&node, &idx, &builder](
                                  absl::string_view output_shape_attr,
                                  absl::string_view element_type_attr) {
    auto* output_shapes = node.attrs().Find(output_shape_attr);
    auto* element_types = node.attrs().Find(element_type_attr);
    const auto& output_shape = output_shapes->list().shape(idx);
    const auto& element_type = element_types->list().type(idx);
    return ConvertToMlirTensorType(output_shape, element_type, &builder);
  };

  if (node.type_string() == "IteratorGetNext" ||
      node.type_string() == "IteratorGetNextSync" ||
      node.type_string() == "MultiDeviceIteratorGetNextFromShard")
    return type_from_array_attr("output_shapes", "output_types");

  if (node.type_string() == "InfeedDequeueTuple")
    return type_from_array_attr("shapes", "dtypes");

  if (node.type_string() == "InfeedDequeue") {
    // InfeedDequeue has a single output described by scalar "shape"/"dtype"
    // attributes rather than lists.
    assert(idx == 0);
    const auto& output_shape = node.attrs().Find("shape")->shape();
    const auto& element_type = node.attrs().Find("dtype")->type();
    return ConvertToMlirTensorType(output_shape, element_type, &builder);
  }

  // Returns a simple, more conservative unranked tensor type.
  auto default_type = [&]() -> StatusOr<mlir::Type> {
    mlir::Type element_type;
    TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &element_type));
    return mlir::UnrankedTensorType::get(element_type);
  };

  // Below we only try and do some shape inference for "source" ops which have
  // no inputs.
  if (node.num_inputs() > 0) return default_type();

  // Do some simply inference here to get the function arguments correct for
  // this common case.
  // TODO(jpienaar): Reconsider post refactoring shape functions.
  if (node.IsArg()) {
    if (dtype == DT_RESOURCE) {
      // Resource _Arg nodes may carry the handle's dtype/shape as list
      // attributes; only the first entry of each list is consumed here.
      const AttrValue* dtype_attr = node.attrs().Find("_handle_dtypes");
      const AttrValue* shape_attr = node.attrs().Find("_handle_shapes");
      if (dtype_attr && shape_attr) {
        if (dtype_attr->list().type().empty()) {
          return errors::InvalidArgument(
              "Invalid \"_handle_dtypes\" attribute value for _Arg node: ",
              shape_attr->DebugString());
        }
        if (shape_attr->list().shape().empty()) {
          return errors::InvalidArgument(
              "Invalid \"_handle_shapes\" attribute value for _Arg node: ",
              shape_attr->DebugString());
        }
        DataType dtype = dtype_attr->list().type(0);
        const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
        TF_ASSIGN_OR_RETURN(
            auto etype, ConvertToMlirTensorType(shape_proto, dtype, &builder));
        return mlir::UnrankedTensorType::get(mlir::TF::ResourceType::get(
            {etype.cast<TensorType>()}, builder.getContext()));
      } else {
        // No handle info: fall back to a resource type without subtypes.
        return mlir::UnrankedTensorType::get(
            mlir::TF::ResourceType::get(builder.getContext()));
      }
    } else if (auto shape = node.attrs().Find("_output_shapes")) {
      // Non-resource _Arg with a single annotated output shape.
      if (shape->has_list() && shape->list().shape_size() == 1) {
        return ConvertToMlirTensorType(shape->list().shape().at(0), dtype,
                                       &builder);
      }
    }
  }

  // Last resort for input-less ops: run the op's registered shape inference
  // function in a fresh context with no inputs.
  const tensorflow::OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(
      graph_->op_registry()->LookUp(node.type_string(), &op_reg_data));
  if (!op_reg_data) {
    DVLOG(1) << "Skipping inference for unregistered op " << node.type_string();
    return default_type();
  }
  if (op_reg_data->shape_inference_fn == nullptr) {
    DVLOG(1) << "Skipping inference for op without shape function "
             << node.type_string();
    return default_type();
  }
  shape_inference::InferenceContext c(graph_->versions().producer(),
                                      node.attrs(), op_reg_data->op_def,
                                      std::vector<PartialTensorShape>{}, {},
                                      /*input_tensors_as_shapes=*/{}, {});
  TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
  return shape_ic(&c);
}
1089 
ConvertDataTypeAndShape(DataType dtype,const shape_inference::ShapeHandle & handle,const std::vector<shape_inference::ShapeAndType> * handle_subtypes,shape_inference::InferenceContext * context,mlir::Builder builder)1090 StatusOr<TensorType> ImporterBase::ConvertDataTypeAndShape(
1091     DataType dtype, const shape_inference::ShapeHandle& handle,
1092     const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
1093     shape_inference::InferenceContext* context, mlir::Builder builder) {
1094   TF_ASSIGN_OR_RETURN(auto subtypes,
1095                       ConvertSubtypes(handle_subtypes, context, builder));
1096 
1097   mlir::Type element_type;
1098   if (dtype == DT_VARIANT)
1099     element_type = mlir::TF::VariantType::get(subtypes, context_);
1100   else if (dtype == DT_RESOURCE)
1101     element_type = mlir::TF::ResourceType::get(subtypes, context_);
1102   else
1103     TF_RETURN_IF_ERROR(
1104         ::tensorflow::ConvertDataType(dtype, builder, &element_type));
1105 
1106   return ConvertElementTypeAndShape(element_type, handle, context, builder);
1107 }
1108 
ConvertElementTypeAndShape(mlir::Type element_type,const shape_inference::ShapeHandle & handle,shape_inference::InferenceContext * context,mlir::Builder builder)1109 StatusOr<TensorType> ImporterBase::ConvertElementTypeAndShape(
1110     mlir::Type element_type, const shape_inference::ShapeHandle& handle,
1111     shape_inference::InferenceContext* context, mlir::Builder builder) {
1112   if (!context->RankKnown(handle)) {
1113     return mlir::UnrankedTensorType::get(element_type);
1114   }
1115 
1116   // Sentinel for an unknown dimension size. getTensorType interprets any
1117   // negative value as an unknown dimension.
1118   // TODO(jmolloy): Ideally this shouldn't be a local sentinel.
1119   const int64_t kUnknownDim = -1;
1120 
1121   absl::InlinedVector<int64_t, 4> dimensions;
1122   int32 rank = context->Rank(handle);
1123   dimensions.reserve(rank);
1124   for (int i = 0; i < rank; ++i) {
1125     auto dim_handle = context->Dim(handle, i);
1126     if (!context->ValueKnown(dim_handle))
1127       dimensions.push_back(kUnknownDim);
1128     else
1129       dimensions.push_back(context->Value(dim_handle));
1130   }
1131 
1132   return mlir::RankedTensorType::get(
1133       llvm::makeArrayRef(dimensions.begin(), dimensions.end()), element_type);
1134 }
1135 
ConvertSubtypes(const std::vector<shape_inference::ShapeAndType> * handle_subtypes,shape_inference::InferenceContext * context,mlir::Builder builder)1136 StatusOr<ImporterBase::ElementSubtypes> ImporterBase::ConvertSubtypes(
1137     const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
1138     shape_inference::InferenceContext* context, mlir::Builder builder) {
1139   ElementSubtypes subtypes;
1140   if (!handle_subtypes) return subtypes;
1141 
1142   subtypes.reserve(handle_subtypes->size());
1143   for (const auto& subtype : *handle_subtypes) {
1144     mlir::Type element_type;
1145     TF_RETURN_IF_ERROR(
1146         ::tensorflow::ConvertDataType(subtype.dtype, builder, &element_type));
1147     TF_ASSIGN_OR_RETURN(TensorType type,
1148                         ConvertElementTypeAndShape(element_type, subtype.shape,
1149                                                    context, builder));
1150     subtypes.push_back(type);
1151   }
1152   return subtypes;
1153 }
1154 
ConvertFunctionCallAttribute(const std::string & base_name,const AttrValue & value,NamedAttrList * attributes)1155 Status ImporterBase::ConvertFunctionCallAttribute(const std::string& base_name,
1156                                                   const AttrValue& value,
1157                                                   NamedAttrList* attributes) {
1158   TF_ASSIGN_OR_RETURN(auto func_attr,
1159                       ConvertFunctionCallName(value.func().name()));
1160   attributes->push_back(builder_.getNamedAttr(base_name, func_attr));
1161 
1162   for (const auto& it : value.func().attr()) {
1163     auto name = absl::StrCat(base_name, ".", it.first);
1164     TF_ASSIGN_OR_RETURN(auto value, ConvertAttributeValue(it.second));
1165     attributes->push_back(builder_.getNamedAttr(name, value));
1166   }
1167   return Status::OK();
1168 }
1169 
ConvertFunctionCallName(const std::string & func_name)1170 StatusOr<mlir::FlatSymbolRefAttr> ImporterBase::ConvertFunctionCallName(
1171     const std::string& func_name) {
1172   TF_RETURN_IF_ERROR(ConvertLibFunction(func_name));
1173   auto mlir_func_name = (*tf_name_to_mlir_name_)[func_name];
1174   auto func = module_.lookupSymbol<mlir::FuncOp>(mlir_func_name);
1175   return builder_.getSymbolRefAttr(func);
1176 }
1177 
ConvertAttributeValue(const AttrValue & value)1178 StatusOr<mlir::Attribute> ImporterBase::ConvertAttributeValue(
1179     const AttrValue& value) {
1180   switch (value.value_case()) {
1181     case AttrValue::kFunc: {
1182       // TODO(b/156546237): Unify kFunc/NameAttrList attribute representation.
1183       // Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue
1184       // will not use this representation.
1185       NamedAttrList attrs;
1186       for (const auto& func_attr : value.func().attr()) {
1187         TF_ASSIGN_OR_RETURN(
1188             auto attr, ImporterBase::ConvertAttributeValue(func_attr.second));
1189         attrs.push_back(builder_.getNamedAttr(func_attr.first, attr));
1190       }
1191       auto func_attrs = builder_.getDictionaryAttr(attrs);
1192       return mlir::TF::FuncAttr::get(context_, value.func().name(), func_attrs);
1193     }
1194     case AttrValue::kList: {
1195       if (!value.list().func().empty()) {
1196         absl::InlinedVector<mlir::Attribute, 8> attrs;
1197         for (const auto& item : value.list().func()) {
1198           TF_ASSIGN_OR_RETURN(auto attr, ConvertFunctionCallName(item.name()));
1199           if (item.attr_size() != 0)
1200             return errors::Unimplemented(
1201                 "func attributes with non-zero attr.size()");
1202           attrs.push_back(attr);
1203         }
1204         return builder_.getArrayAttr(
1205             llvm::makeArrayRef(attrs.begin(), attrs.end()));
1206       }
1207       return ConvertNonFuncAttributeValue(value, &builder_);
1208     }
1209     default:
1210       return ConvertNonFuncAttributeValue(value, &builder_);
1211   }
1212 }
1213 
GetArgsAndRetsFromFunctionBody(const FunctionBody & fbody,absl::InlinedVector<OutputTensor,4> * arg_nodes,absl::InlinedVector<OutputTensor,4> * ret_nodes,absl::InlinedVector<Node *,4> * control_ret_nodes)1214 void ImporterBase::GetArgsAndRetsFromFunctionBody(
1215     const FunctionBody& fbody, absl::InlinedVector<OutputTensor, 4>* arg_nodes,
1216     absl::InlinedVector<OutputTensor, 4>* ret_nodes,
1217     absl::InlinedVector<Node*, 4>* control_ret_nodes) {
1218   arg_nodes->reserve(fbody.arg_nodes.size());
1219   ret_nodes->reserve(fbody.ret_nodes.size());
1220   for (auto arg : fbody.arg_nodes) {
1221     arg_nodes->emplace_back(arg, 0);
1222   }
1223   for (auto ret : fbody.ret_nodes) {
1224     ret_nodes->emplace_back(ret, 0);
1225   }
1226   *control_ret_nodes = fbody.control_ret_nodes;
1227 }
1228 
// Imports the library function `func_name` (and, recursively, its gradient
// function if one is registered) into the MLIR module, memoizing converted
// names in `tf_name_to_mlir_name_`.
Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
  // If the library function has been converted already, nothing needs to be
  // done.
  if (tf_name_to_mlir_name_->find(std::string(func_name)) !=
      tf_name_to_mlir_name_->end())
    return Status::OK();

  // Reserve a unique MLIR symbol name up front; this also marks the function
  // as in-progress so recursive references (e.g. via the gradient) terminate.
  std::string mlir_func_name(
      function_name_uniquifier_->GetUniqueName(func_name));
  (*tf_name_to_mlir_name_)[std::string(func_name)] = mlir_func_name;

  const auto& func_lib = graph_flib_;
  const auto* func_def = func_lib.Find(std::string(func_name));
  if (func_def == nullptr) {
    return errors::FailedPrecondition(
        absl::StrCat("Failed to find function '", StringRefToView(func_name),
                     "'. The imported TensorFlow GraphDef is ill-formed."));
  }

  // Converts the function definition to a graph.
  std::unique_ptr<FunctionBody> fbody;
  TF_RETURN_IF_ERROR(
      FunctionDefToBodyHelper(*func_def, AttrSlice(), &func_lib, &fbody));

  // Converts the argument and return types to MLIR types.
  absl::InlinedVector<mlir::NamedAttribute, 8> attributes;
  attributes.reserve(func_def->attr_size());
  for (const auto& name_and_value : func_def->attr()) {
    // This is a function definition attribute, so it shouldn't contain
    // kFunc attribute and it is treated as normal one.
    TF_ASSIGN_OR_RETURN(auto attr,
                        ConvertAttributeValue(name_and_value.second));
    // Attribute names are mangled so they round-trip through MLIR.
    std::string attr_name =
        mangling_util::MangleAttributeName(name_and_value.first);
    attributes.push_back(builder_.getNamedAttr(attr_name, attr));
  }

  // Checks opdef stateful attribute and import that as Function Attribute
  if (func_def->signature().is_stateful()) {
    auto stateful_str = mlir::TF::TensorFlowDialect::GetStatefulAttrName();
    attributes.push_back(
        builder_.getNamedAttr(stateful_str, builder_.getUnitAttr()));
  }

  // Checks for an associated custom gradient function. Adds it to the attribute
  // list of this function.
  auto grad_func_name = func_lib.FindGradient(std::string(func_name));
  if (!grad_func_name.empty()) {
    // Recursively convert the gradient function first so its symbol exists.
    TF_RETURN_IF_ERROR(ConvertLibFunction(grad_func_name));
    auto mlir_grad_func_name = (*tf_name_to_mlir_name_)[grad_func_name];
    auto grad_func = module_.lookupSymbol<mlir::FuncOp>(mlir_grad_func_name);
    auto gradient_attr = builder_.getSymbolRefAttr(grad_func);
    auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
    attributes.push_back(builder_.getNamedAttr(grad_string, gradient_attr));
  }

  // Converts the graph to an MLIR function and adds it to the module.
  // We populate the NodeSpec so that all the _Arg ops get their shape
  // added correctly.
  GraphImportConfig specs;
  specs.enable_shape_inference = specs_.enable_shape_inference;
  for (const auto& name_and_value : func_def->attr()) {
    if (name_and_value.first == "_input_shapes") {
      auto& list = name_and_value.second.list();
      auto& signature = func_def->signature();
      // Each input argument must have exactly one corresponding shape entry.
      if (list.shape_size() != signature.input_arg_size()) {
        return errors::FailedPrecondition(
            "Number of input arguments must be equal to the length of "
            "_input_shapes attribute in function '",
            StringRefToView(func_name), "'.");
      }
      for (int i = 0; i < list.shape_size(); i++) {
        auto& input_arg = signature.input_arg(i);
        auto& array_info = specs.inputs[input_arg.name()];
        array_info.imported_dtype = input_arg.type();
        array_info.shape = list.shape(i);
      }
    }
  }

  // A child importer shares the module, name maps, and uniquifier so the
  // converted function lands in the same module under the reserved name.
  ImporterBase child_importer(graph_flib_, debug_info_, specs, module_,
                              tf_name_to_mlir_name_, function_name_uniquifier_,
                              func_name);
  TF_RETURN_IF_ERROR(child_importer.PrepareConvert(*fbody->graph));

  TF_ASSIGN_OR_RETURN(auto func_type,
                      child_importer.InferLibFunctionType(*fbody));

  absl::InlinedVector<OutputTensor, 4> arg_nodes;
  absl::InlinedVector<OutputTensor, 4> ret_nodes;
  absl::InlinedVector<Node*, 4> control_ret_nodes;
  GetArgsAndRetsFromFunctionBody(*fbody, &arg_nodes, &ret_nodes,
                                 &control_ret_nodes);

  TF_RETURN_IF_ERROR(child_importer.Convert(
      mlir_func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes,
      llvm::makeArrayRef(attributes.begin(), attributes.end())));
  return Status::OK();
}
1328 
PruneUnreachableNodes(std::unordered_map<string,Node * > * node_name_map)1329 Status ImporterBase::PruneUnreachableNodes(
1330     std::unordered_map<string, Node*>* node_name_map) {
1331   std::unordered_set<const Node*> prune_start;
1332   TF_RETURN_IF_ERROR(GetInputOutputNodes(*node_name_map, &prune_start));
1333 
1334   if (!prune_start.empty()) {
1335     if (PruneForReverseReachability(graph_.get(), prune_start)) {
1336       VLOG(1) << "Pruned unused nodes in graphdef";
1337     } else {
1338       VLOG(1) << "No unused nodes in graphdef to prune";
1339     }
1340   } else {
1341     VLOG(1) << "No output nodes specified, skipping pruning";
1342   }
1343   return Status::OK();
1344 }
1345 
// Rewrites each fed tensor into a dedicated single-output Placeholder node so
// later conversion can treat every feed uniformly. Updates `node_name_map`
// and `remapped_feeds_` to reflect the new nodes.
Status ImporterBase::ConvertFeedsToPlaceholders(
    std::unordered_map<string, Node*>* node_name_map) {
  // Feeds (edges) are converted into single-output placeholder nodes to
  // simplify the conversion process.
  TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs));
  for (const auto& it : feeds_by_node) {
    TensorId tensor = ParseTensorName(it.first);
    auto jt = node_name_map->find(std::string(tensor.node()));
    if (jt == node_name_map->end()) {
      return errors::FailedPrecondition(
          absl::StrCat("Graph does not contain node: ", tensor.node()));
    }

    Node* node = jt->second;
    auto op_name = node->op_def().name();
    // Nodes that already act as inputs (Placeholder, LegacyFedInput, _Arg)
    // are left as-is; only other ops get a placeholder inserted.
    if (op_name != "Placeholder" && op_name != "LegacyFedInput" &&
        op_name != FunctionLibraryDefinition::kArgOp) {
      for (const auto& output_tensor : it.second) {
        const int index = output_tensor.first;
        const ArrayInfo& array_info = output_tensor.second->second;

        DataType dtype = array_info.imported_dtype;
        // Uses the existing output type if it isn't specified by the user.
        if (dtype == DT_INVALID) {
          dtype = node->output_type(index);
        }

        TF_ASSIGN_OR_RETURN(
            auto placeholder_node_and_removed,
            CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index,
                                         *node_name_map));

        Node* placeholder_node = placeholder_node_and_removed.first;
        // Keep the new node connected between _SOURCE and _SINK so it is not
        // dangling in the graph.
        if (placeholder_node->in_edges().empty()) {
          graph_->AddControlEdge(graph_->source_node(), placeholder_node,
                                 true /* skip test for duplicates */);
        }
        if (placeholder_node->out_edges().empty()) {
          graph_->AddControlEdge(placeholder_node, graph_->sink_node(),
                                 true /* skip test for duplicates */);
        }
        // Record the remapping so subsequent feed lookups resolve to the
        // placeholder instead of the original node/output.
        remapped_feeds_[{it.first, index}] = placeholder_node->name();
        (*node_name_map)[placeholder_node->name()] = placeholder_node;
      }
    }
  }
  return Status::OK();
}
1394 
PrepareConvert(const Graph & graph)1395 Status ImporterBase::PrepareConvert(const Graph& graph) {
1396   TF_RETURN_IF_ERROR(RemoveBackedges(graph));
1397   TF_RETURN_IF_ERROR(CopyStackTraces(graph, graph_.get()));
1398 
1399   auto node_name_map = graph_->BuildNodeNameIndex();
1400 
1401   if (specs_.enable_shape_inference) {
1402     // TODO(jpienaar): Remove once infer shapes on import flag is removed.
1403     TF_RETURN_IF_ERROR(AddNodesToShapeRefiner(&node_name_map));
1404   } else {
1405     TF_RETURN_IF_ERROR(ConvertFeedsToPlaceholders(&node_name_map));
1406   }
1407 
1408   // Prune nodes in the graph that are not reachable from the output.
1409   if (specs_.prune_unused_nodes) {
1410     TF_RETURN_IF_ERROR(PruneUnreachableNodes(&node_name_map));
1411   }
1412 
1413   if (!specs_.enable_shape_inference) {
1414     // Re-initialize ordered_nodes_ since we might have modified the graph.
1415     GetReversePostOrder(
1416         *graph_, &ordered_nodes_,
1417         [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
1418   }
1419 
1420   return Status::OK();
1421 }
1422 
// Converts the prepared graph into an MLIR FuncOp named `func_name` with type
// `func_type`, wrapping the body in a single tf_executor.graph region. Nodes
// are emitted in ordered_nodes_ order, backedges are re-added, and function
// arguments/results are wired up from `arg_nodes`/`ret_nodes`.
Status ImporterBase::Convert(
    llvm::StringRef func_name, mlir::FunctionType func_type,
    const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
    const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
    const absl::InlinedVector<Node*, 4>& control_ret_nodes,
    llvm::ArrayRef<mlir::NamedAttribute> attrs) {
  // TODO(b/122040776): Uses debug info for FunctionDef.
  auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
                                       func_name, func_type, attrs);

  module_.push_back(function);
  // Seeds the builder with an initial block.
  function.addEntryBlock();
  builder_ = mlir::OpBuilder(function.getBody());

  // Create the graph operation in which we will convert the individual nodes.
  auto graph = builder_.create<mlir::tf_executor::GraphOp>(
      function.getLoc(), func_type.getResults());
  builder_.createBlock(&graph.body());

  // ordered_nodes_ was computed in PrepareConvert; ConvertNode expects defs
  // to appear before uses.
  for (const Node* node : ordered_nodes_) {
    TF_RETURN_IF_ERROR(ConvertNode(*node));
  }

  // Adds the backedges back to the function by creating the source and sink
  // pairs.
  TF_RETURN_IF_ERROR(AddBackedges());

  TF_RETURN_IF_ERROR(ConvertFunctionArgAndRets(function, graph,
                                               func_type.getInputs(), arg_nodes,
                                               ret_nodes, control_ret_nodes));

  // TODO(jpienaar): Update post removing shape_refiner_.
  if (!specs_.enable_shape_inference) {
    // Refine graph's type given more precise fetch: copy each fetch operand
    // type onto the corresponding graph result, tracking whether any changed.
    auto fetch = graph.GetFetch();
    bool all_equal = true;
    for (auto it :
         llvm::zip_first(graph.getResults(), fetch.getOperandTypes())) {
      auto rt = std::get<1>(it);
      if (rt == std::get<0>(it).getType()) continue;
      std::get<0>(it).setType(rt);
      all_equal = false;
    }
    // Only rebuild the function type when at least one result was refined.
    if (!all_equal) {
      function.setType(mlir::FunctionType::get(function.getContext(),
                                               func_type.getInputs(),
                                               graph.getResultTypes()));
    }
  }

  return Status::OK();
}
1476 
// Replaces the imported feed islands with block arguments, collects the
// function's return values (and control returns), terminates the graph with
// a Fetch op and the function with a Return op, and attaches per-arg/result
// attributes (e.g. "tf.device" and optional "_"-prefixed node attrs).
Status ImporterBase::ConvertFunctionArgAndRets(
    mlir::FuncOp func, mlir::tf_executor::GraphOp graph_op,
    llvm::ArrayRef<mlir::Type> arg_types,
    const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
    const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
    const absl::InlinedVector<Node*, 4>& control_ret_nodes) {
  // Store the arg/return attributes as a list rather than uniquing during
  // construction.
  llvm::SmallVector<mlir::NamedAttrList, 4> arg_attrs;
  arg_attrs.resize(func.getNumArguments());
  llvm::SmallVector<mlir::NamedAttrList, 4> ret_attrs;
  ret_attrs.resize(func.getNumResults());

  // Copies the node's "_"-prefixed (optional) attributes onto the arg/result
  // attribute list at `index`, prefixed with "tf.".
  auto set_attributes_on_func = [&](Node* node, int64_t index, bool is_arg) {
    for (const auto& node_attr : node->attrs()) {
      const auto& key = node_attr.first;
      // Only import optional attributes (e.g., those starting with an
      // underscore).
      if (key.empty() || key[0] != '_') continue;
      // Ignore shape inference attributes as shape information is already
      // populated in the result type.
      if (IsOutputShapesAttribute(node_attr.second, key) ||
          IsResourceOutputShapesAttribute(node_attr.second, key))
        continue;
      TF_ASSIGN_OR_RETURN(auto converted_attr,
                          ConvertAttributeValue(node_attr.second));
      std::string dialect_attribute = "tf." + key;
      if (is_arg) {
        arg_attrs[index].set(dialect_attribute, converted_attr);
      } else {
        // NOTE(review): setResultAttr here looks redundant — ret_attrs is
        // applied wholesale via setAllResultAttrs below, which presumably
        // overwrites this; confirm before removing.
        func.setResultAttr(index, dialect_attribute, converted_attr);
        ret_attrs[index].set(dialect_attribute, converted_attr);
      }
    }
    return Status::OK();
  };

  auto* bb = &func.front();
  // Maps each feed (node, output index) to the block argument that replaced
  // it, so fetches of a fed tensor can reuse the block arg directly.
  llvm::SmallDenseMap<std::pair<Node*, int>, mlir::Value, 4>
      arg_nodes_to_values;
  for (int i = 0, e = arg_types.size(); i < e; ++i) {
    auto& arg_node = arg_nodes[i];
    // The lookup can't fail here: otherwise some nodes in the function haven't
    // been converted to mlir operations and don't have a mapping.
    mlir::Operation* island = node_values_.find(arg_node.node->id())->second;

    auto bb_arg = bb->getArgument(i);
    mlir::Value arg_def = bb_arg;

    // An island for a feed must have exactly one data result plus the
    // control result.
    if (island->getNumResults() != 2)
      return errors::InvalidArgument(
          "Only feed output tensors of single output nodes are supported");

    // Collect mapping of OutputTensor to associated block arg.
    arg_nodes_to_values.try_emplace({arg_node.node, arg_node.index}, arg_def);
    island->getResult(0).replaceAllUsesWith(arg_def);
    // Erase control outputs from feed.
    auto control_uses = island->getResult(1).getUses();
    for (auto& control_use : llvm::make_early_inc_range(control_uses))
      control_use.getOwner()->eraseOperand(control_use.getOperandNumber());

    if (!arg_node.node->requested_device().empty())
      arg_attrs[i].set("tf.device", builder_.getStringAttr(
                                        arg_node.node->requested_device()));

    if (arg_node.node->IsArg()) {
      TF_RETURN_IF_ERROR(
          set_attributes_on_func(arg_node.node, i, /*is_arg=*/true));
    }

    // The feed island is now fully replaced by the block argument.
    island->dropAllReferences();
    island->erase();
  }

  llvm::SmallVector<mlir::Value, 8> inst_to_return;
  for (auto ret_and_idx : llvm::enumerate(ret_nodes)) {
    const auto& ret = ret_and_idx.value();
    auto* inst = node_values_[ret.node->id()];
    if (ret.node->IsRetval()) {
      if (!ret.node->requested_device().empty())
        ret_attrs[ret_and_idx.index()].set(
            "tf.device", builder_.getStringAttr(ret.node->requested_device()));
      TF_RETURN_IF_ERROR(set_attributes_on_func(ret.node, ret_and_idx.index(),
                                                /*is_arg=*/false));
      // Lookup the instruction inside the island
      auto island_op = llvm::cast<mlir::tf_executor::IslandOp>(inst);
      mlir::Operation* inner_op = &island_op.GetBody().front();
      // Remove kRetOp or kDeviceRetOp operation and return its operand.
      // kRetOp and kDeviceRetOp should have just one operand unless they have
      // control dependencies.
      if (inner_op->getNumOperands() != 1)
        return errors::Unimplemented("Return node with multiple inputs.");
      inst_to_return.push_back(inner_op->getOperand(0));
      inst->dropAllReferences();
      inst->erase();
    } else {
      // Lookup and use block arg if fetch is a feed.
      auto it = arg_nodes_to_values.find({ret.node, ret.index});
      if (it != arg_nodes_to_values.end())
        inst_to_return.push_back(it->second);
      else
        inst_to_return.push_back(inst->getResult(ret.index));
    }
  }

  // Control returns are the last (control) result of their island.
  for (Node* control_ret : control_ret_nodes) {
    auto* inst = node_values_[control_ret->id()];
    inst_to_return.push_back(*std::prev(inst->result_end()));
  }

  // Terminate the function by adding a Fetch operation to terminate the graph
  // and a return operation to return the Graph results.
  builder_.setInsertionPointToEnd(&graph_op.body().front());
  builder_.create<mlir::tf_executor::FetchOp>(graph_op.getLoc(),
                                              inst_to_return);
  builder_.setInsertionPointToEnd(bb);
  builder_.create<mlir::ReturnOp>(mlir::UnknownLoc::get(context_),
                                  graph_op.getResults());

  // Materialize the collected attribute lists as dictionaries on the func.
  func.setAllArgAttrs(
      llvm::to_vector<4>(llvm::map_range(arg_attrs, [&](NamedAttrList& list) {
        return list.getDictionary(context_);
      })));
  func.setAllResultAttrs(
      llvm::to_vector<4>(llvm::map_range(ret_attrs, [&](NamedAttrList& list) {
        return list.getDictionary(context_);
      })));

  return Status::OK();
}
1607 
// Builds an MLIR Location for `node`, preferring (in order) the node's stack
// trace, the serialized GraphDebugInfo, and finally a bare NameLoc. Nodes with
// experimental_debug_info original-node names produce a FusedLoc of all
// call-site locations plus the node's own location.
mlir::Location ImporterBase::GetLocation(const Node& node) {
  DVLOG(1) << "Getting location for " << node.name() << " " << &node;
  // TODO(b/142400497): What is the semantic contract for locations?
  const auto& debug_info = debug_info_.traces();

  // Create a location for node `name` in function `function_name`.
  auto create_location = [&](llvm::StringRef name,
                             llvm::StringRef function_name) -> mlir::Location {
    // Use the catenation of function and node names as the lookup key into the
    // debug info. This matches the way that the key is formed on the python
    // side.
    //
    // We also use this as the name for the NameLoc for ops in function, since
    // otherwise our names could collide across functions.
    // For ops in the main graph, we omit the "@function_name" (which, would be
    // just "@" since function_name would be empty) because some code seems to
    // depend on the name being this way for correctness.
    std::string debug_info_key = (name + "@" + function_name).str();
    std::string name_for_name_loc =
        function_name.empty() ? name.str() : (name + "@" + function_name).str();
    auto name_loc_id = mlir::Identifier::get(name_for_name_loc, context_);

    llvm::SmallVector<mlir::Location, 4> locations;
    // Prefer stack traces if available, fallback to debug info if not, and then
    // finally to just name.
    if (auto stack_trace = node.GetStackTrace()) {
      DVLOG(1) << "Stack available for " << node.name();
      absl::Span<const StackFrame> frames = stack_trace->ToFrames();
      locations.reserve(frames.size());
      // Reverse the frames so the outermost (root) caller comes first.
      for (const StackFrame& frame : llvm::reverse(frames)) {
        auto file_name = mlir::Identifier::get(frame.file_name, context_);
        // Use col 1 as there is no column info in StackTrace.
        auto file_line_loc = mlir::FileLineColLoc::get(
            file_name, frame.line_number, 1, context_);
        locations.push_back(file_line_loc);
      }
    } else {
      DVLOG(1) << "No stack trace for " << node.name();
      const auto location_it = debug_info.find(debug_info_key);
      if (location_it != debug_info.end()) {
        DVLOG(1) << "Available serialized debug info for " << node.name();
        // Convert the stack trace to a chain of mlir::CallSiteLocs.
        const auto& trace = location_it->second;
        locations.reserve(trace.file_line_cols_size());
        for (const auto& location : trace.file_line_cols()) {
          const auto& file = debug_info_.files(location.file_index());
          auto file_name = mlir::Identifier::get(file, context_);
          auto file_line_loc = mlir::FileLineColLoc::get(
              file_name, location.line(), location.col(), context_);
          locations.push_back(file_line_loc);
        }
      }
    }

    // If there are no locations in the stack trace, fall back to just a
    // NameLoc with no child.
    if (locations.empty()) return mlir::NameLoc::get(name_loc_id, context_);

    // Use the front FileLineColLoc to generate a NameLoc.
    mlir::Location node_name_loc =
        mlir::NameLoc::get(name_loc_id, locations.front());

    // If there are more locations then generate a stack trace, otherwise just
    // return the name loc.
    auto callsite_locs = llvm::makeArrayRef(locations).drop_front();
    return callsite_locs.empty()
               ? node_name_loc
               : mlir::CallSiteLoc::get(node_name_loc, callsite_locs);
  };

  // For NextIteration nodes, location is used to pair source and sink nodes.
  // Hence, we use node name as location to keep it unique.
  // TODO(prakalps): In future the plan is to use tokens to pair source/sink
  // nodes. Then NextIteration nodes would not need to be handled separately.
  if (node.type_string() == "NextIteration")
    return create_location(node.name(), function_name_for_debug_info_);

  // A node with its own stack trace does not need the fused-location handling
  // below; create_location will use the stack trace directly.
  if (node.GetStackTrace())
    return create_location(node.name(), function_name_for_debug_info_);

  const auto& node_def = node.def();
  auto original_nodes =
      node_def.experimental_debug_info().original_node_names();
  auto original_funcs =
      node_def.experimental_debug_info().original_func_names();

  if (original_nodes.empty()) {
    return create_location(node.name(), function_name_for_debug_info_);
  } else {
    // If the original nodes are defined, then we use them to get a list of
    // call sites, and then fuse them to a single fused location, with the name
    // of the node_def.
    llvm::SmallVector<mlir::Location, 4> node_locations;
    node_locations.reserve(original_nodes.size() + 1);

    // store the names in the experimental_debug_info
    for (int i = 0, e = original_nodes.size(); i != e; ++i) {
      auto node_name = original_nodes[i];
      // original_funcs may be shorter than original_nodes; missing entries
      // default to the empty (main graph) function name.
      auto func_name = (i < original_funcs.size()) ? original_funcs[i] : "";
      node_locations.push_back(create_location(node_name, func_name));
    }
    // store the name of the node_def
    node_locations.push_back(
        create_location(node.name(), function_name_for_debug_info_));
    return mlir::FusedLoc::get(node_locations, context_);
  }
}
1715 
// Emits an (empty) MLIR diagnostic at the node's location — so the location
// string is captured by the registered error handler — and then combines it
// with `error_status` into the returned Status.
Status ImporterBase::EmitErrorWithLocationStr(const Node& node,
                                              const Status& error_status) {
  mlir::emitError(GetLocation(node));
  return error_handler_.Combine(error_status);
}
1722 
// Creates the MLIR operation for `node` from the prepared OperationState.
// Control-flow nodes (Switch/Merge/NextIteration/LoopCond/Enter/Exit/
// ControlTrigger) map directly to tf_executor ops; every other node is
// wrapped in a tf_executor.island whose single inner op is the TF op.
// Returns the executor-level op (the island for wrapped ops).
mlir::Operation* ImporterBase::CreateOperation(
    const Node& node, llvm::StringRef node_type_name,
    const mlir::OperationState& result,
    const llvm::SmallVectorImpl<mlir::Value>& control_operands) {
  // For the tf.executor specific operations (not wrapped in an island), we
  // have an extra returned value for the control result, and we concatenate
  // control and non-control operands.
  mlir::SmallVector<mlir::Type, 4> types(result.types);
  types.push_back(mlir::tf_executor::ControlType::get(builder_.getContext()));
  mlir::SmallVector<mlir::Value, 4> operands(result.operands);
  operands.append(control_operands.begin(), control_operands.end());

  auto loc = result.location;
  // Dispatch based on the name and create the appropriate operation.
  if (node.IsSwitch()) {
    // Switch and _SwitchN both are in switch class, differentiate based on
    // op name.
    if (node.op_def().name() == "_SwitchN") {
      return builder_.create<mlir::tf_executor::SwitchNOp>(loc, types, operands,
                                                           result.attributes);
    }
    return builder_.create<mlir::tf_executor::SwitchOp>(loc, types, operands,
                                                        result.attributes);
  }
  if (node.IsMerge()) {
    return builder_.create<mlir::tf_executor::MergeOp>(loc, types, operands,
                                                       result.attributes);
  }
  if (node.IsNextIteration()) {
    // NextIteration is a bit special, we create a pair of operations that are
    // linked together through a token returned by the source.
    // We make use of a separate builder to insert the source at the top of
    // the block.
    mlir::OpBuilder builder_at_begin(builder_.getBlock(),
                                     builder_.getBlock()->begin());
    auto source_op =
        builder_at_begin.create<mlir::tf_executor::NextIterationSourceOp>(
            loc, operands[0].getType(), result.attributes);
    return builder_.create<mlir::tf_executor::NextIterationSinkOp>(
        loc, source_op.token(), operands, result.attributes);
  }
  if (node.IsLoopCond()) {
    return builder_.create<mlir::tf_executor::LoopCondOp>(loc, types, operands,
                                                          result.attributes);
  }
  if (node.IsEnter()) {
    return builder_.create<mlir::tf_executor::EnterOp>(loc, types, operands,
                                                       result.attributes);
  }
  if (node.IsExit()) {
    return builder_.create<mlir::tf_executor::ExitOp>(loc, types, operands,
                                                      result.attributes);
  }
  if (node.IsControlTrigger()) {
    return builder_.create<mlir::tf_executor::ControlTriggerOp>(
        loc, operands, result.attributes);
  }
  // Regular TensorFlow operation are wrapped in a tf_executor.island.
  // Note: the island takes only the control operands; data operands go to the
  // inner op created below from `result`.
  auto island = builder_.create<mlir::tf_executor::IslandOp>(
      result.location, types, control_operands,
      mlir::ArrayRef<mlir::NamedAttribute>{});
  island.body().push_back(new mlir::Block);
  mlir::OpBuilder island_builder =
      mlir::OpBuilder::atBlockEnd(&island.GetBody());

  // Create the operation inside the island now.
  mlir::Operation* inner_op = island_builder.createOperation(result);

  // Sets operand_segment_sizes or result_segment_sizes attribute to the op.
  // The segment size for each named arg is the width of its range per the
  // OpDef name ranges, encoded as an i32 vector attribute.
  const auto set_segment_sizes_attr =
      [&](const NameRangeMap& arg_ranges,
          const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
          llvm::StringRef attr_name) {
        std::vector<mlir::Attribute> values;
        values.reserve(args.size());
        for (const auto& arg : args) {
          auto range = arg_ranges.at(arg.name());
          values.push_back(
              island_builder.getI32IntegerAttr(range.second - range.first));
        }
        auto attr_type =
            mlir::VectorType::get(args.size(), builder_.getIntegerType(32));
        auto attr_value = mlir::DenseElementsAttr::get(attr_type, values);
        inner_op->setAttr(attr_name, attr_value);
      };

  if (inner_op->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>() ||
      inner_op->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
    // The op has multiple variadic operands or results.
    // Calculate operand and result segment sizes using the OpDef.
    NameRangeMap input_ranges, output_ranges;
    // This will fail only if the OpDef is syntactically invalid.
    // TODO(jpienaar): Convert this CHECK into a properly propagated error.
    TF_CHECK_OK(
        NameRangesForNode(node, node.op_def(), &input_ranges, &output_ranges));
    if (inner_op->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>()) {
      // Add derived "operand_segment_sizes" attr to the created operation.
      // TODO(b/146937733): Don't use <void> here.
      set_segment_sizes_attr(input_ranges, node.op_def().input_arg(),
                             mlir::OpTrait::AttrSizedOperandSegments<
                                 void>::getOperandSegmentSizeAttr());
    }

    if (inner_op->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
      // Add derived "result_segment_sizes" attr to the created operation.
      // TODO(b/146937733): Don't use <void> here.
      set_segment_sizes_attr(output_ranges, node.op_def().output_arg(),
                             mlir::OpTrait::AttrSizedResultSegments<
                                 void>::getResultSegmentSizeAttr());
    }
  }

  // Log (once per op type) ops that are not registered in MLIR, since their
  // side effects will be modelled conservatively.
  mlir::OperationName name = inner_op->getName();
  if (!name.getAbstractOperation() &&
      // Skip unmodelled ops that are handled differently.
      (node_type_name != "_Arg" && node_type_name != "_Retval")) {
    if (GetUnmodelledOpTypes().insert(name.getStringRef()).second) {
      LOG(INFO) << "Unmodelled op type `" << node.type_string() << "`"
                << (node.op_def().is_stateful()
                        ? " is stateful but effects not modelled"
                        : " is not stateful but will be treated as such "
                          "conservatively");
    }
  }

  // Add the terminator for the island
  island_builder.create<mlir::tf_executor::YieldOp>(result.location,
                                                    inner_op->getResults());
  return island.getOperation();
}
1853 
ConvertNode(const Node & node)1854 Status ImporterBase::ConvertNode(const Node& node) {
1855   if (!node.IsOp()) {
1856     // Don't import the pseudo-nodes _SOURCE or _SINK. These are added by
1857     // Graph and don't exist in GraphDef.
1858     return Status::OK();
1859   }
1860 
1861   // If it is a custom OP, its definition should be found in the library. We
1862   // create the MLIR function and insert it to the module if it doesn't exist.
1863   std::string node_type_name = node.type_string();
1864   const auto* func_def = graph_flib_.Find(node_type_name);
1865   bool convert_to_legacy_call = false;
1866   if (func_def) {
1867     TF_RETURN_IF_ERROR(ConvertLibFunction(node_type_name));
1868     node_type_name = (*tf_name_to_mlir_name_)[node_type_name];
1869     convert_to_legacy_call = true;
1870   }
1871 
1872   auto get_full_op_name = [&](const std::string& op_name) {
1873     const char* kTfPrefix = "tf.";
1874     return kTfPrefix + op_name;
1875   };
1876 
1877   std::string op_name = get_full_op_name(node_type_name);
1878   if (back_edge_node_output_.contains(&node)) {
1879     op_name = op_name + ".sink";
1880   }
1881 
1882   mlir::OperationState result(GetLocation(node), op_name);
1883   for (int i = 0; i < node.num_outputs(); ++i) {
1884     // The backedge has been removed, so we shouldn't count the corresponding
1885     // output from the src node when converting to an operation.
1886     if (back_edge_node_output_.contains(&node) &&
1887         back_edge_node_output_[&node] == i) {
1888       continue;
1889     }
1890     TF_ASSIGN_OR_RETURN(auto type, InferOutputType(node, i, builder_));
1891     result.types.push_back(type);
1892   }
1893 
1894   // Surprisingly input edges can be nondeterministically ordered. This
1895   // particularly seems to be the case for the control edges between _SOURCE
1896   // and _SINK that the Graph constructor inserts. Copy the input edges and
1897   // sort the edges, but only the control edges, not data edges!
1898   // TODO(jmolloy): We should probably just ignore _SOURCE and _SINK nodes.
1899   // They'll break roundtripping anyway unless we strip them when converting
1900   // back to graphdef.
1901   absl::InlinedVector<const Edge*, 8> in_edges(node.in_edges().size());
1902   absl::c_copy(node.in_edges(), in_edges.begin());
1903   absl::c_stable_sort(in_edges, [](const Edge* e1, const Edge* e2) {
1904     if (e1->IsControlEdge() && !e2->IsControlEdge()) return false;
1905     if (!e1->IsControlEdge() && e2->IsControlEdge()) return true;
1906     if (e1->IsControlEdge() && e2->IsControlEdge())
1907       return e1->src()->id() < e2->src()->id();
1908     return e1->dst_input() < e2->dst_input();
1909   });
1910 
1911   result.operands.reserve(in_edges.size());
1912 
1913   // Collect the control operands separately, they will be held by the island.
1914   mlir::SmallVector<mlir::Value, 8> control_operands;
1915 
1916   for (const auto* input_edge : in_edges) {
1917     const Node& input_node = *input_edge->src();
1918     if (input_node.IsSource()) {
1919       if (in_edges.size() != 1) {
1920         return errors::FailedPrecondition(
1921             "The node has other inputs besides the _Source node");
1922       }
1923       // We don't import the _SOURCE node.
1924       continue;
1925     }
1926     if (input_node.IsArg() && input_edge->IsControlEdge()) {
1927       // Currently we have not reached consensus as to what TF function
1928       // semantics are (b/133509504). Here we assume that all arguments to a
1929       // function should be available before we start execution of any internal
1930       // node. This makes the control dependencies between function arguments
1931       // and internal nodes redundant, and so we do not import them. The TF
1932       // inliner however assumes no such dependency between function args and
1933       // internal nodes exists, unless explicitly stated. Since we drop control
1934       // dependencies here, it leads to loss of information. If the function is
1935       // inlined later, the inliner would not know of these explicit control
1936       // dependencies present in the original graph.
1937       continue;
1938     }
1939     if (node_values_.find(input_node.id()) == node_values_.end())
1940       return errors::FailedPrecondition(
1941           "Graph not traversed in reverse post order; use seen before def!");
1942     mlir::Operation* inst = node_values_[input_node.id()];
1943     if (input_edge->IsControlEdge())
1944       control_operands.push_back(inst->getResult(inst->getNumResults() - 1));
1945     else
1946       result.operands.push_back(inst->getResult(input_edge->src_output()));
1947   }
1948 
1949   using FuncPairType = std::pair<const std::string*, const AttrValue*>;
1950   std::vector<FuncPairType> funcs;
1951   result.attributes.reserve(node.attrs().size() + 2);
1952   auto abstract_op = result.name.getAbstractOperation();
1953   auto derived_op =
1954       abstract_op
1955           ? abstract_op->getInterface<mlir::DerivedAttributeOpInterface>()
1956           : nullptr;
1957   for (const auto& name_and_value : node.attrs()) {
1958     const auto& attr_name = name_and_value.first;
1959     // Skip adding derived attributes to the generated op.
1960     if (derived_op && derived_op->isDerivedAttribute(attr_name)) continue;
1961     const AttrValue& attr_value = name_and_value.second;
1962 
1963     // Remove _output_shapes attribute that will be added by the exporter.
1964     if (IsOutputShapesAttribute(attr_value, attr_name)) continue;
1965 
1966     if (attr_value.value_case() == AttrValue::kFunc) {
1967       // Attribute iteration order is not defined for protocol buffer Map.
1968       // Process function attributes separately in the lexicographical order to
1969       // have deterministic order of functions in the constructed IR.
1970       funcs.emplace_back(&attr_name, &attr_value);
1971     } else {
1972       TF_ASSIGN_OR_RETURN(auto attr, ConvertAttributeValue(attr_value));
1973       result.attributes.push_back(builder_.getNamedAttr(attr_name, attr));
1974     }
1975   }
1976 
1977   auto comparator = [](const FuncPairType& a, const FuncPairType& b) {
1978     return *a.first < *b.first;
1979   };
1980   std::sort(funcs.begin(), funcs.end(), comparator);
1981   for (const auto& func : funcs) {
1982     TF_RETURN_IF_ERROR(ConvertFunctionCallAttribute(*func.first, *func.second,
1983                                                     &result.attributes));
1984   }
1985 
1986   const auto& node_def = node.def();
1987   result.attributes.push_back(builder_.getNamedAttr(
1988       "device", builder_.getStringAttr(std::string(node_def.device()))));
1989 
1990   // Map user function calls to LegacyCall ops and add the user function name
1991   // as an attribute.
1992   if (convert_to_legacy_call) {
1993     result.name = mlir::OperationName(get_full_op_name("LegacyCall"), context_);
1994     mlir::SymbolRefAttr val = builder_.getSymbolRefAttr(node_type_name);
1995     result.addAttribute("f", val);
1996 
1997     if (!result.attributes.get("_disable_call_shape_inference")) {
1998       result.addAttribute("_disable_call_shape_inference",
1999                           builder_.getBoolAttr(false));
2000     }
2001   }
2002 
2003   auto composite_control_flow_op = [&](const std::string& name) {
2004     result.name = mlir::OperationName(get_full_op_name(name), context_);
2005     bool stateless = absl::StartsWith(node_type_name, "Stateless");
2006     mlir::BoolAttr val = builder_.getBoolAttr(stateless);
2007     result.attributes.push_back(builder_.getNamedAttr("is_stateless", val));
2008   };
2009 
2010   // Map Case/If/While and StatelessCase/If/While op in TensorFlow to the common
2011   // Case/If/While op in MLIR and add the differentiating attribute.
2012   if (node.IsCaseNode()) composite_control_flow_op("Case");
2013   if (node.IsIfNode()) composite_control_flow_op("If");
2014   if (node.IsWhileNode()) {
2015     composite_control_flow_op("While");
2016     auto* output_shapes = node.attrs().Find("output_shapes");
2017     if (output_shapes && !output_shapes->list().shape().empty())
2018       result.attributes.push_back(
2019           builder_.getNamedAttr("shape_invariant", builder_.getUnitAttr()));
2020   }
2021 
2022   // Register the mapping between the TF node and the newly created operation.
2023   node_values_[node.id()] =
2024       CreateOperation(node, node_type_name, result, control_operands);
2025   return Status::OK();
2026 }
2027 
2028 // Add the backedges to the CFG. Given a backedge, we replace the original
2029 // source and destination operations by two new operations. Most of the
2030 // fields of the replacements are copied from the original operations.
2031 // However,
2032 // - for the src operation, one output is inserted to the front of the output
2033 //   list. The type of the output is set to the type of the non-control result
2034 //   of the dst operation, and
2035 // - for the dst operation, one operand is inserted to the front of the
2036 //   operand list. This operand is using the first result of the src
2037 //   operation.
2038 // TODO(fengliuai): Preserve the order of the results and operands if
2039 // necessary.
AddBackedges()2040 Status ImporterBase::AddBackedges() {
2041   for (auto it : back_edge_dst_inputs_) {
2042     BackEdge& edge = it.second;
2043     if (!edge.src->IsNextIteration() || !edge.dst->IsMerge()) {
2044       return errors::FailedPrecondition(
2045           "Invalid backedge; should be from NextIteration to Merge!");
2046     }
2047     auto* sink = node_values_[edge.src->id()];
2048     auto* dst = node_values_[edge.dst->id()];
2049     TF_RETURN_IF_ERROR(AddBackedge(sink, dst, edge.dst_input));
2050   }
2051   return Status::OK();
2052 }
2053 
AddBackedge(mlir::Operation * sink,mlir::Operation * dst,int dst_input)2054 Status ImporterBase::AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
2055                                  int dst_input) {
2056   // Get the NextIteration.Source operation from the token operand of the sink.
2057   mlir::Operation* source = sink->getOperand(0).getDefiningOp();
2058 
2059   // Adds the "source" to the operands of the dst by creating a new dst
2060   // operation.
2061   mlir::OperationState state(dst->getLoc(), dst->getName());
2062   auto num_operands = dst->getNumOperands();
2063   state.operands.reserve(num_operands + 1);
2064   for (int input = 0, e = num_operands + 1; input != e; ++input) {
2065     if (input < dst_input) {
2066       state.operands.push_back(dst->getOperand(input));
2067     } else if (input == dst_input) {
2068       state.operands.push_back(source->getResult(0));
2069     } else {
2070       state.operands.push_back(dst->getOperand(input - 1));
2071     }
2072   }
2073   state.attributes.assign(dst->getAttrs().begin(), dst->getAttrs().end());
2074   state.types.assign(dst->getResultTypes().begin(),
2075                      dst->getResultTypes().end());
2076   builder_.setInsertionPoint(dst);
2077   auto* new_dst = builder_.createOperation(state);
2078 
2079   // Replaces the output uses of the old operation by the corresponding
2080   // result of the new operation, and deletes the old operation.
2081   for (unsigned i = 0, e = dst->getNumResults(); i != e; ++i) {
2082     auto new_output = new_dst->getResult(i);
2083     dst->getResult(i).replaceAllUsesWith(new_output);
2084   }
2085   dst->dropAllReferences();
2086   dst->erase();
2087   return Status::OK();
2088 }
2089 
InferLibFunctionType(const FunctionBody & fbody)2090 StatusOr<mlir::FunctionType> ImporterBase::InferLibFunctionType(
2091     const FunctionBody& fbody) {
2092   mlir::Builder builder(context_);
2093 
2094   // The FunctionBody contains a graph with a single-output _Arg node for each
2095   // function argument and a single-input _Retval node for each function return
2096   // value.
2097   //
2098   // We already populated the ShapeRefiner with all the information about the
2099   // shapes of these graph edges, so we just query it to build the corresponding
2100   // MLIR function type signature.
2101 
2102   llvm::SmallVector<mlir::Type, 4> arg_types;
2103   if (specs_.inputs.empty()) {
2104     arg_types.reserve(fbody.arg_types.size());
2105     for (auto arg : fbody.arg_nodes) {
2106       // Find node in the graph using the node id instead of using `arg`
2107       // directly because the graph has been cloned.
2108       auto* node = graph_->FindNodeId(arg->id());
2109       TF_ASSIGN_OR_RETURN(auto type,
2110                           InferOutputType(*node, /*idx=*/0, builder));
2111       arg_types.push_back(type);
2112     }
2113   } else {
2114     arg_types.reserve(fbody.arg_types.size());
2115     for (const auto& it : llvm::enumerate(specs_.inputs)) {
2116       mlir::Type element_type;
2117       const auto& node_info = it.value().second;
2118       DataType dtype = node_info.imported_dtype;
2119       // Uses the existing output type of the arg node if the data type of the
2120       // the node isn't specified through the import configuration.
2121       if (dtype == DT_INVALID) {
2122         auto arg = fbody.arg_nodes[it.index()];
2123         auto* node = graph_->FindNodeId(arg->id());
2124         dtype = node->output_type(0);
2125         if (dtype == DT_INVALID) {
2126           return errors::InvalidArgument("Input ", it.index(),
2127                                          "has invalid data type");
2128         }
2129       }
2130       TF_RETURN_IF_ERROR(
2131           ::tensorflow::ConvertDataType(dtype, builder, &element_type));
2132       if (node_info.shape.unknown_rank()) {
2133         arg_types.push_back(mlir::UnrankedTensorType::get(element_type));
2134       } else {
2135         llvm::SmallVector<int64_t, 4> shape;
2136         TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape));
2137         arg_types.push_back(mlir::RankedTensorType::get(shape, element_type));
2138       }
2139     }
2140   }
2141 
2142   llvm::SmallVector<mlir::Type, 4> ret_types;
2143   ret_types.reserve(fbody.ret_types.size());
2144   for (auto ret : fbody.ret_nodes) {
2145     // Find node in the graph using the node id instead of using `ret` directly
2146     // because the graph has been cloned.
2147     auto* node = graph_->FindNodeId(ret->id());
2148     TF_ASSIGN_OR_RETURN(auto type, InferInputType(*node, /*idx=*/0, builder));
2149     ret_types.push_back(type);
2150   }
2151 
2152   return builder.getFunctionType(arg_types, ret_types);
2153 }
2154 
// Stateful helper class to import a TensorFlow model expressed in GraphDef
// into an MLIR Module.
//
// The nodes defined in the graph are converted to a function called
// 'func_name'. All library function definitions are converted to MLIR
// functions in the module.
class GraphDefImporter : public ImporterBase {
 public:
  // Main entry point: converts the given graph to an MLIR Module. `specs`
  // controls feeds/fetches, pruning, and whether the graph is imported as a
  // function graph (see GraphImportConfig).
  static StatusOr<mlir::OwningModuleRef> Convert(
      mlir::MLIRContext* context, const Graph& graph,
      const GraphDebugInfo& debug_info,
      const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
      llvm::StringRef func_name);

 private:
  explicit GraphDefImporter(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier)
      : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name,
                     function_name_uniquifier) {}

  // Returns the function signature of the main function of converted MLIR
  // module, the input nodes and output nodes. The type and shape information
  // for the function arguments are read from `specs`, but the type and shape
  // information for the function returns are inferred by the shape refiner in
  // ImporterBase.
  StatusOr<mlir::FunctionType> InferMainFunctionType(
      const GraphImportConfig& specs, mlir::MLIRContext* context,
      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
      absl::InlinedVector<OutputTensor, 4>* ret_nodes);

  // Returns the function signature of the main function, alongside input and
  // output nodes, for function graphs. Arguments and return values are
  // determined by node op type (_Arg/_Retval). Type and shape information of
  // the function are inferred by the shape refiner in ImporterBase.
  StatusOr<mlir::FunctionType> GetArgsRetsAndTypesFromFunctionGraph(
      mlir::MLIRContext* context,
      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
      absl::InlinedVector<OutputTensor, 4>* ret_nodes);

  // Finds the graph's target nodes/function's control ret nodes based on
  // supplied node names in `control_outputs`. If `control_outputs` are not
  // unique or a control ret node is missing, an error will be returned.
  Status GetControlRetsFromGraph(
      llvm::ArrayRef<std::string> control_outputs,
      absl::InlinedVector<Node*, 4>* control_ret_nodes);
};
2205 
// Converts `graph` into an MLIR module containing one function named
// `func_name` for the main graph plus one function per library function.
// The main function is marked public; all others private.
StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
    mlir::MLIRContext* context, const Graph& graph,
    const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
    const GraphImportConfig& specs, llvm::StringRef func_name) {
  LoadImporterDialects(*context);
  // The module initially has an unknown location; versions are attached below.
  mlir::OwningModuleRef module =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
  NameUniquifier function_name_uniquifier(flib_def);

  GraphDefImporter importer(flib_def, debug_info, specs, module.get(),
                            &tf_name_to_mlir_name, &function_name_uniquifier);

  TF_RETURN_IF_ERROR(importer.PrepareConvert(graph));

  mlir::FunctionType func_type;
  absl::InlinedVector<OutputTensor, 4> arg_nodes;
  absl::InlinedVector<OutputTensor, 4> ret_nodes;
  absl::InlinedVector<Node*, 4> control_ret_nodes;
  llvm::SmallVector<mlir::NamedAttribute, 1> attrs;
  if (specs.graph_as_function) {
    // Function-graph mode: arguments/returns come from _Arg/_Retval nodes, so
    // user-specified feeds/fetches and pruning are not meaningful here.
    if (specs.prune_unused_nodes || !specs.inputs.empty() ||
        !specs.outputs.empty())
      return errors::InvalidArgument(
          "Pruning of graph is currently unsupported when the main graph is "
          "converted to a function.");

    TF_ASSIGN_OR_RETURN(func_type,
                        importer.GetArgsRetsAndTypesFromFunctionGraph(
                            context, &arg_nodes, &ret_nodes));

    TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
                                                        &control_ret_nodes));

    // Build the "tf.entry_function" dictionary attribute: comma-separated
    // node-name lists for inputs, outputs and control outputs. The single
    // string `s` is reused (cleared) between the three lists.
    mlir::Builder b(context);
    std::string s;
    llvm::raw_string_ostream ss(s);
    auto node_name = [&](const OutputTensor& tensor) {
      ss << tensor.node->name();
    };
    llvm::interleave(arg_nodes, ss, node_name, ",");
    auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
    s.clear();
    llvm::interleave(ret_nodes, ss, node_name, ",");
    auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
    s.clear();
    llvm::interleave(specs.control_outputs, ss, ",");
    auto control_outputs =
        b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));

    // Under `graph_as_function` mode, `tf.entry_function` is always set as it
    // is assumed feed, fetch, and target nodes are set correctly.
    attrs.push_back(b.getNamedAttr(
        "tf.entry_function",
        b.getDictionaryAttr({inputs, outputs, control_outputs})));
  } else {
    // Collects the argument and return nodes by looking up the node names
    // specified by the user.
    TF_ASSIGN_OR_RETURN(func_type, importer.InferMainFunctionType(
                                       specs, context, &arg_nodes, &ret_nodes));

    TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
                                                        &control_ret_nodes));

    // TODO(prakalps): Refactor to keep tf.entry_function attribute encoding and
    // decoding in a centralized place.
    // Record the input and output mapping. Unlike the graph_as_function case,
    // the attribute is only attached when the user actually specified feeds,
    // fetches, or control outputs.
    if (!specs.inputs.empty() || !specs.outputs.empty() ||
        !specs.control_outputs.empty()) {
      mlir::Builder b(context);
      std::string s;
      llvm::raw_string_ostream ss(s);
      llvm::interleave(
          specs.inputs, ss,
          [&](const std::pair<std::string, ArrayInfo>& v) { ss << v.first; },
          ",");
      auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
      s.clear();
      llvm::interleave(specs.outputs, ss, ",");
      auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
      s.clear();
      llvm::interleave(specs.control_outputs, ss, ",");
      auto control_outputs =
          b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));

      attrs.push_back(b.getNamedAttr(
          "tf.entry_function",
          b.getDictionaryAttr({inputs, outputs, control_outputs})));
    }
  }

  // Record version info.
  PopulateTfVersions(module.get(), graph.versions());

  TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(
      func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs));

  // Mark main function public, others private.
  for (auto function : module.get().getOps<mlir::FuncOp>()) {
    auto visibility = function.getName() == func_name
                          ? mlir::FuncOp::Visibility::Public
                          : mlir::FuncOp::Visibility::Private;
    function.setVisibility(visibility);
  }
  return module;
}
2312 
InferMainFunctionType(const GraphImportConfig & specs,mlir::MLIRContext * context,absl::InlinedVector<OutputTensor,4> * arg_nodes,absl::InlinedVector<OutputTensor,4> * ret_nodes)2313 StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
2314     const GraphImportConfig& specs, mlir::MLIRContext* context,
2315     absl::InlinedVector<OutputTensor, 4>* arg_nodes,
2316     absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
2317   // Find all the input nodes and output nodes.
2318   // Feeds have been remapped to single output nodes (Placeholder), so an exact
2319   // name match is sufficient.
2320   absl::flat_hash_map<absl::string_view, int> inputs;
2321   for (auto input_and_idx : llvm::enumerate(specs.inputs)) {
2322     TensorId tensor = ParseTensorName(input_and_idx.value().first);
2323     auto remapped_it = remapped_feeds_.find(tensor);
2324     if (remapped_it != remapped_feeds_.end()) {
2325       inputs.insert({remapped_it->second, input_and_idx.index()});
2326     } else {
2327       inputs.insert({tensor.node(), input_and_idx.index()});
2328     }
2329   }
2330 
2331   absl::flat_hash_set<absl::string_view> output_node_names;
2332   std::vector<TensorId> outputs;
2333   output_node_names.reserve(specs.outputs.size());
2334   for (const auto& output : specs.outputs) {
2335     TensorId tensor = ParseTensorName(output);
2336     auto remapped_it = remapped_feeds_.find(tensor);
2337     if (remapped_it != remapped_feeds_.end()) {
2338       output_node_names.insert(remapped_it->second);
2339       outputs.push_back({remapped_it->second, 0});
2340     } else {
2341       output_node_names.insert(tensor.node());
2342       outputs.push_back(tensor);
2343     }
2344   }
2345 
2346   if (!inputs.empty() || !outputs.empty()) {
2347     arg_nodes->resize(inputs.size());
2348     ret_nodes->resize(outputs.size());
2349 
2350     for (Node* n : GetOrderedNodes()) {
2351       // Handle inputs/arguments.
2352       auto input_it = inputs.find(n->name());
2353       if (input_it != inputs.end()) {
2354         (*arg_nodes)[input_it->second] = {n, 0};
2355       }
2356 
2357       // Handle outputs/returns.
2358       if (output_node_names.contains(n->name())) {
2359         for (int i = 0, e = outputs.size(); i != e; ++i) {
2360           TensorId tensor = outputs[i];
2361           if (n->name() != tensor.node()) continue;
2362           (*ret_nodes)[i] = {n, tensor.index()};
2363         }
2364       }
2365     }
2366   }
2367 
2368   // Starts to construct the function type.
2369   mlir::Builder builder(context);
2370   llvm::SmallVector<mlir::Type, 4> arg_types;
2371   arg_types.reserve(specs.inputs.size());
2372   int i = 0;
2373   for (const auto& it : specs.inputs) {
2374     Node* arg_node = arg_nodes->at(i).node;
2375     if (arg_node == nullptr) {
2376       return errors::InvalidArgument("Input ", it.first,
2377                                      " was not found in graph");
2378     }
2379     mlir::Type element_type;
2380     const auto& node_info = it.second;
2381     DataType imported_dtype = node_info.imported_dtype;
2382     // Uses the existing output type of the arg node if the data type of the
2383     // the node isn't specified through the import configuration.
2384     if (imported_dtype == DT_INVALID) {
2385       imported_dtype = arg_node->output_type(0);
2386       if (imported_dtype == DT_INVALID) {
2387         return errors::InvalidArgument("Input ", i, "has invalid data type");
2388       }
2389     }
2390     TF_RETURN_IF_ERROR(
2391         ::tensorflow::ConvertDataType(imported_dtype, builder, &element_type));
2392     if (node_info.shape.unknown_rank()) {
2393       arg_types.push_back(mlir::UnrankedTensorType::get(element_type));
2394     } else {
2395       llvm::SmallVector<int64_t, 4> shape;
2396       TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape));
2397       arg_types.push_back(mlir::RankedTensorType::get(shape, element_type));
2398     }
2399     i++;
2400   }
2401 
2402   llvm::SmallVector<mlir::Type, 4> ret_types;
2403   ret_types.reserve(specs.outputs.size());
2404   for (int i = 0, e = specs.outputs.size(); i != e; ++i) {
2405     if (ret_nodes->at(i).node == nullptr) {
2406       return errors::InvalidArgument("Output ", specs.outputs[i],
2407                                      " was not found in graph");
2408     }
2409   }
2410   for (const auto& ret : *ret_nodes) {
2411     if (ret.node->num_outputs() <= ret.index) {
2412       return errors::InvalidArgument("Invalid output index ", ret.index,
2413                                      " specified for node: ", ret.node->name());
2414     }
2415     TF_ASSIGN_OR_RETURN(auto type,
2416                         InferOutputType(*ret.node, ret.index, builder));
2417     ret_types.push_back(type);
2418   }
2419 
2420   return builder.getFunctionType(arg_types, ret_types);
2421 }
2422 
2423 StatusOr<mlir::FunctionType>
GetArgsRetsAndTypesFromFunctionGraph(mlir::MLIRContext * context,absl::InlinedVector<OutputTensor,4> * arg_nodes,absl::InlinedVector<OutputTensor,4> * ret_nodes)2424 GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph(
2425     mlir::MLIRContext* context, absl::InlinedVector<OutputTensor, 4>* arg_nodes,
2426     absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
2427   auto add_node = [](Node* node, absl::InlinedVector<OutputTensor, 4>* nodes) {
2428     auto* attr = node->attrs().Find("index");
2429     if (!attr)
2430       return errors::InvalidArgument(node->type_string(), " node '",
2431                                      node->name(),
2432                                      "' is missing attribute 'index'");
2433 
2434     auto index = attr->i();
2435     const int num_nodes = nodes->size();
2436     if (num_nodes < index + 1) nodes->resize(index + 1);
2437 
2438     if ((*nodes)[index].node != nullptr)
2439       return errors::InvalidArgument(node->type_string(), " node '",
2440                                      node->name(), "' has attribute 'index' ",
2441                                      index, " that conflicts with node '",
2442                                      (*nodes)[index].node->name(), "'");
2443     (*nodes)[index] = {node, 0};
2444 
2445     return Status::OK();
2446   };
2447 
2448   // Collect arg and ret nodes from graph.
2449   for (auto* node : GetOrderedNodes())
2450     if (node->IsArg())
2451       TF_RETURN_IF_ERROR(add_node(node, arg_nodes));
2452     else if (node->IsRetval())
2453       TF_RETURN_IF_ERROR(add_node(node, ret_nodes));
2454 
2455   // Collect arg and ret types and create function type.
2456   mlir::Builder builder(context);
2457   llvm::SmallVector<mlir::Type, 4> arg_types;
2458   arg_types.reserve(arg_nodes->size());
2459   for (auto arg_node_and_idx : llvm::enumerate(*arg_nodes)) {
2460     auto& arg_node = arg_node_and_idx.value();
2461     if (arg_node.node == nullptr)
2462       return errors::InvalidArgument("Graph missing _Arg at index ",
2463                                      arg_node_and_idx.index());
2464 
2465     TF_ASSIGN_OR_RETURN(auto type,
2466                         InferOutputType(*arg_node.node, /*idx=*/0, builder));
2467     arg_types.push_back(type);
2468   }
2469 
2470   llvm::SmallVector<mlir::Type, 4> ret_types;
2471   ret_types.reserve(ret_nodes->size());
2472   for (auto ret_node_and_idx : llvm::enumerate(*ret_nodes)) {
2473     auto& ret_node = ret_node_and_idx.value();
2474     if (ret_node.node == nullptr)
2475       return errors::InvalidArgument("Graph missing _Retval at index ",
2476                                      ret_node_and_idx.index());
2477 
2478     TF_ASSIGN_OR_RETURN(auto type,
2479                         InferInputType(*ret_node.node, /*idx=*/0, builder));
2480     ret_types.push_back(type);
2481   }
2482 
2483   return builder.getFunctionType(arg_types, ret_types);
2484 }
2485 
GetControlRetsFromGraph(llvm::ArrayRef<std::string> control_outputs,absl::InlinedVector<Node *,4> * control_ret_nodes)2486 Status GraphDefImporter::GetControlRetsFromGraph(
2487     llvm::ArrayRef<std::string> control_outputs,
2488     absl::InlinedVector<Node*, 4>* control_ret_nodes) {
2489   if (control_outputs.empty()) return Status::OK();
2490 
2491   llvm::SmallDenseMap<llvm::StringRef, int32_t> controls_to_idx;
2492   for (auto control_and_idx : llvm::enumerate(control_outputs))
2493     controls_to_idx.insert({control_and_idx.value(), control_and_idx.index()});
2494 
2495   if (controls_to_idx.size() != control_outputs.size())
2496     return errors::InvalidArgument("Control outputs must be unique");
2497 
2498   control_ret_nodes->resize(controls_to_idx.size());
2499 
2500   for (auto* node : GetOrderedNodes()) {
2501     auto it = controls_to_idx.find(node->name());
2502     if (it != controls_to_idx.end()) (*control_ret_nodes)[it->second] = node;
2503   }
2504 
2505   for (auto node_and_name : llvm::zip(*control_ret_nodes, control_outputs))
2506     if (std::get<0>(node_and_name) == nullptr)
2507       return errors::InvalidArgument(
2508           "Control output '", std::get<1>(node_and_name), "' is missing");
2509 
2510   return Status::OK();
2511 }
2512 
// Stateful helper class to import a TensorFlow model expressed in SavedModel
// into an MLIR Module.
class SavedModelObjectGraphImporter : public ImporterBase {
 public:
  // Main entry point: converts all functions in the given meta graph to an
  // MLIR Module. `exported_names` restricts which objects are exported; empty
  // means "export all" (see ObjectNames below).
  static StatusOr<mlir::OwningModuleRef> Convert(
      SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
      mlir::MLIRContext* context, bool add_default_attributes);

 private:
  explicit SavedModelObjectGraphImporter(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier)
      : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name,
                     function_name_uniquifier) {}
};
2532 
// Determines the names used to reference objects in the SavedObjectGraph.
class ObjectNames {
 public:
  explicit ObjectNames(const SavedObjectGraph& object_graph,
                       absl::Span<std::string> exported_names);

  // Gets the names that external users of the SavedModel can use to refer to
  // this node.
  llvm::ArrayRef<llvm::StringRef> GetExportedNames(int node_id) const;

  // Gets the name in the module symbol table for this node.
  // This name is only used for internal IR references.
  llvm::StringRef GetSymbolTableName(int node_id) const;

 private:
  // In the absence of any other information, use this name as the symbol table
  // name for this node.
  std::string GetDefaultSymbolTableName(int node_id) const;
  // Determines if a name is exported.
  bool IsExported(const std::string& name);
  // Main object graph traversal function.
  void RecursivelyVisitObjectGraph(int node_id);
  // Gets a stable StringRef from a std::string (the string is stored in
  // saved_strings_ so the view cannot dangle).
  llvm::StringRef SaveString(const std::string& s) const;

  // The object graph we are traversing.
  const SavedObjectGraph& object_graph_;
  // The set of names to export. Empty means "export all".
  std::unordered_set<std::string> names_to_export_;

  // When we recursively follow the object graph tree structure from the root,
  // we track its path in the object graph by pushing and popping from here
  // during traversal.
  llvm::SmallVector<std::string, 8> path_segments_;
  // The set of node_id's that are on the current DFS stack.
  // For cyclic object graphs, this prevents infinite recursion.
  std::unordered_set<int> on_stack_nodes_;

  // Key: node_id.
  // Value: all object names that node_id appears as.
  // Each object name corresponds to a unique path from the root of the object
  // graph.
  // The common intuitive case is when there is only one name for a given
  // object, which corresponds to the object graph being a tree.
  //
  // But, there are cases where the object graph is a general graph. For
  // example, this happens commonly in Keras models, where `foo.bar` is
  // also reachable via the name `keras_api.foo.bar`.
  // Cycles are possible too.
  absl::flat_hash_map<int, std::vector<std::string>> object_names_;

  // Key: node_id
  // Value: all names that this object is exported as
  absl::flat_hash_map<int, llvm::SmallVector<llvm::StringRef, 1>>
      exported_names_;
  // Key: node_id
  // Value: pretty symbol table name to use for internal references to this
  // object.
  absl::flat_hash_map<int, llvm::StringRef> pretty_symbol_table_name_;

  // Stable strings we can take StringRef's into. Used only by the SaveString
  // method.
  mutable std::unordered_set<std::string> saved_strings_;
};
2597 
ObjectNames(const SavedObjectGraph & object_graph,absl::Span<std::string> exported_names)2598 ObjectNames::ObjectNames(const SavedObjectGraph& object_graph,
2599                          absl::Span<std::string> exported_names)
2600     : object_graph_(object_graph),
2601       names_to_export_(exported_names.begin(), exported_names.end()) {
2602   // Visit all reachable nodes from the root of the object graph.
2603   // This builds up object_names_ to contain all names like `foo.bar` that a
2604   // particular node in the graph can be reached from.
2605   RecursivelyVisitObjectGraph(/*node_id=*/0);
2606 
2607   // Populate the exported_names_ map.
2608   // TODO(silvasean): Diagnose typos in exported names?
2609   for (auto& kv : object_names_) {
2610     // Make object names map independent of our particular choice of object
2611     // graph traversal.
2612     std::sort(kv.second.begin(), kv.second.end(),
2613               [](absl::string_view a, absl::string_view b) {
2614                 // The sort order here influences the "pretty name" we assign
2615                 // below. We want the most debuggable name to be first.
2616                 //
2617                 // Debuggability heuristics:
2618                 // 1. Names that end in digits are likely to be internal aliases
2619                 // to the "real" names.
2620                 // 2. Longer names are more likely to be internal aliases.
2621                 //
2622                 // Example set of object names created by Keras for the weight
2623                 // matrix of a fully connected layer on a trivial FC mnist
2624                 // model:
2625                 // - `model.layer-1.kernel` (this is the "best" name)
2626                 // - `model.keras_api.layers.1.kernel`
2627                 // - `model.variables.0`
2628                 // - `model.keras_api.layers.1.keras_api.trainable_variables.0`
2629                 // - ... 10 more long aliases ending in digits ...
2630                 return std::make_tuple(isdigit(a.back()), a.size(), a) <
2631                        std::make_tuple(isdigit(b.back()), b.size(), b);
2632               });
2633     for (const std::string& name : kv.second) {
2634       if (IsExported(name)) {
2635         exported_names_[kv.first].push_back(SaveString(name));
2636       }
2637     }
2638   }
2639   // Create "pretty" symbol table names for nodes where that is applicable.
2640   // We could make all symbol table names use the default, which is basically
2641   // just the node id. But for debugging purposes, it's nicer if we can mix in
2642   // a recognizable object name if we have the information to do so.
2643   for (auto& kv : object_names_) {
2644     int node_id = kv.first;
2645     std::string internal_name =
2646         absl::StrCat(GetDefaultSymbolTableName(node_id), "__");
2647     // If the object has an exported name, we prefer that since it is probably
2648     // the most recognizable. Otherwise, we grab some non-exported name of the
2649     // object.
2650     if (exported_names_.find(node_id) != exported_names_.end()) {
2651       internal_name += exported_names_[node_id][0].str();
2652     } else {
2653       internal_name += object_names_[node_id][0];
2654     }
2655     pretty_symbol_table_name_[node_id] = SaveString(internal_name);
2656   }
2657 }
2658 
GetExportedNames(int node_id) const2659 llvm::ArrayRef<llvm::StringRef> ObjectNames::GetExportedNames(
2660     int node_id) const {
2661   auto it = exported_names_.find(node_id);
2662   if (it != exported_names_.end()) {
2663     return it->second;
2664   }
2665   return {};
2666 }
2667 
GetSymbolTableName(int node_id) const2668 llvm::StringRef ObjectNames::GetSymbolTableName(int node_id) const {
2669   auto it = pretty_symbol_table_name_.find(node_id);
2670   if (it != pretty_symbol_table_name_.end()) {
2671     return it->second;
2672   }
2673   return SaveString(GetDefaultSymbolTableName(node_id));
2674 }
2675 
GetDefaultSymbolTableName(int node_id) const2676 std::string ObjectNames::GetDefaultSymbolTableName(int node_id) const {
2677   return absl::StrCat("__sm_node", node_id);
2678 }
2679 
IsExported(const std::string & name)2680 bool ObjectNames::IsExported(const std::string& name) {
2681   if (names_to_export_.empty()) {
2682     return true;
2683   }
2684   return names_to_export_.find(name) != names_to_export_.end();
2685 }
2686 
RecursivelyVisitObjectGraph(int node_id)2687 void ObjectNames::RecursivelyVisitObjectGraph(int node_id) {
2688   const SavedObject& object = object_graph_.nodes(node_id);
2689 
2690   switch (object.kind_case()) {
2691     case SavedObject::kConstant:
2692     case SavedObject::kFunction:
2693     case SavedObject::kVariable: {
2694       object_names_[node_id].push_back(absl::StrJoin(path_segments_, "."));
2695       break;
2696     }
2697     default:
2698       break;
2699   }
2700 
2701   for (const auto& child_ref : object.children()) {
2702     bool on_stack = !on_stack_nodes_.insert(child_ref.node_id()).second;
2703     if (on_stack) {
2704       // This is a backedge. Don't traverse it.
2705       continue;
2706     }
2707 
2708     path_segments_.push_back(child_ref.local_name());
2709     RecursivelyVisitObjectGraph(child_ref.node_id());
2710     path_segments_.pop_back();
2711 
2712     on_stack_nodes_.erase(child_ref.node_id());
2713   }
2714 }
2715 
SaveString(const std::string & s) const2716 llvm::StringRef ObjectNames::SaveString(const std::string& s) const {
2717   return llvm::StringRef(*saved_strings_.insert(s).first);
2718 }
2719 
2720 // Extracts a TensorProto for a Const op from a GraphDef, given an op_name.
2721 // Returns nullptr on not found or other mismatch.
2722 // This returns a pointer to the actual node within the graph_def so as to
2723 // avoid expensive copies.
ExtractConstTensorFromGraph(const GraphDef & graph_def,const std::string & op_name)2724 const TensorProto* ExtractConstTensorFromGraph(const GraphDef& graph_def,
2725                                                const std::string& op_name) {
2726   const NodeDef* match_node = nullptr;
2727   for (const auto& node : graph_def.node()) {
2728     if (node.name() == op_name) {
2729       match_node = &node;
2730     }
2731   }
2732 
2733   if (!match_node) {
2734     return nullptr;
2735   }
2736 
2737   auto value_it = match_node->attr().find("value");
2738   if (value_it == match_node->attr().end()) {
2739     return nullptr;
2740   }
2741 
2742   if (!value_it->second.has_tensor()) {
2743     return nullptr;
2744   }
2745 
2746   return &value_it->second.tensor();
2747 }
2748 
2749 const TrackableObjectGraph::TrackableObject::SerializedTensor*
FindSerializedTensorInTrackable(const TrackableObjectGraph::TrackableObject & trackable_object,StringPiece name)2750 FindSerializedTensorInTrackable(
2751     const TrackableObjectGraph::TrackableObject& trackable_object,
2752     StringPiece name) {
2753   for (const auto& maybe_serialized_tensor : trackable_object.attributes()) {
2754     if (maybe_serialized_tensor.name() == name) {
2755       return &maybe_serialized_tensor;
2756     }
2757   }
2758   return nullptr;
2759 }
2760 
DiagnoseMultipleConcreteFunctions(const SavedObjectGraph & object_graph,const ObjectNames & object_names)2761 Status DiagnoseMultipleConcreteFunctions(const SavedObjectGraph& object_graph,
2762                                          const ObjectNames& object_names) {
2763   for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) {
2764     const SavedObject& object = object_graph.nodes(node_id);
2765     if (object_names.GetExportedNames(node_id).empty()) {
2766       continue;
2767     }
2768     if (object.kind_case() == SavedObject::kFunction) {
2769       // We only allow a single input signature to each SavedFunction.
2770       // This assumption means we have a 1:1 correspondence between
2771       // tf.function <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef
2772       // This makes defining the ABI easier (or even well-defined at all).
2773       // TODO(silvasean): How to detect a function that doesn't have an
2774       // explicitly user-provided input signature, but happens to have been
2775       // traced exactly once?
2776       if (object.function().concrete_functions_size() != 1) {
2777         llvm::SmallVector<std::string, 4> names;
2778         for (llvm::StringRef s : object_names.GetExportedNames(node_id)) {
2779           names.push_back("'" + s.str() + "'");
2780         }
2781         return errors::InvalidArgument(
2782             "Exported function with exported name(s) ",
2783             absl::StrJoin(names, ", "),
2784             " with multiple concrete functions. Add "
2785             "@tf.function(input_signature=[...]) on this function, or use a "
2786             "narrower list of exported names that excludes this function.");
2787       }
2788     }
2789   }
2790   return Status::OK();
2791 }
2792 
2793 // Recursively traverses a StructuredValue, linearizing all the leaves.
2794 //
2795 // This currently only handles the subset of StructuredValue that is needed for
2796 // signatures.
2797 //
2798 // Given a StructuredValue with structure [{"x": leaf0}], the "index path"
2799 // needed to reach leaf0 is `[0, "x"]`, as it would be if you were operating on
2800 // a Python object (`obj[0]["x"] is leaf0`). Each leaf corresponds to a
2801 // linearized function argument or return on a FunctionDef, and hence to an
2802 // mlir::FuncOp argument / return.
2803 //
2804 // This must match the linearization that happens in `tf.nest.flatten`.
2805 // In particular, dict values should be linearized in sorted key order.
2806 //
2807 // The linearized index paths can be returned back to a structured
2808 // representation (e.g. to emit C structs matching a signature) with a simple
2809 // algorithm that recurses on each run of index paths with identical first
2810 // elements.
class StructuredValueLinearizer {
 public:
  // Eagerly traverses `value`; any unsupported StructuredValue kind is
  // recorded as an error and reported later by GetLeafIndexPaths.
  StructuredValueLinearizer(const StructuredValue& value,
                            mlir::MLIRContext* context);

  // Returns the list of index paths to each leaf of the StructuredValue,
  // in a linearized order matching `tf.nest.flatten`.
  //
  // If an error occurred during the linearization process, an error message
  // with `error_context` prepended will be included in the returned status.
  // The returned ArrayRef aliases storage owned by this object.
  StatusOr<llvm::ArrayRef<mlir::ArrayAttr>> GetLeafIndexPaths(
      llvm::StringRef error_context) const;

 private:
  // Main function that recursively traverses the StructuredValue.
  void RecursivelyFindLeaves(const StructuredValue& value);

  // Used to create the index-path attributes (string/integer/array attrs).
  mlir::Builder builder_;
  // The current index path. We push/pop this during recursive traversal of the
  // StructuredValue.
  llvm::SmallVector<mlir::Attribute, 4> current_index_path_;
  // The list of leaf index paths we have discovered so far.
  llvm::SmallVector<mlir::ArrayAttr, 4> leaf_index_paths_;
  // If non-empty, an error message to report.
  std::string error_message_;
};
2837 
// Traverses `value` immediately; errors encountered during the traversal are
// accumulated in error_message_ and surfaced by GetLeafIndexPaths, not here.
StructuredValueLinearizer::StructuredValueLinearizer(
    const StructuredValue& value, mlir::MLIRContext* context)
    : builder_(context) {
  RecursivelyFindLeaves(value);
}
2843 
2844 StatusOr<llvm::ArrayRef<mlir::ArrayAttr>>
GetLeafIndexPaths(llvm::StringRef error_context) const2845 StructuredValueLinearizer::GetLeafIndexPaths(
2846     llvm::StringRef error_context) const {
2847   if (error_message_.empty()) {
2848     return llvm::makeArrayRef(leaf_index_paths_);
2849   }
2850   return errors::InvalidArgument(
2851       error_context.str(), error_message_,
2852       "This likely means that you have @tf.function "
2853       "on an exported function instead of "
2854       "@tf.function(input_signature=[...]). Consider annotating an "
2855       "input_signature or narrowing your set of "
2856       "exported names to not include this function.");
2857 }
2858 
// Recursive worker: pushes the current container key/index onto
// current_index_path_ before descending, pops it after, and records the full
// path as an ArrayAttr whenever a leaf (TensorSpec) is reached. Unsupported
// kinds append to error_message_ rather than failing immediately.
void StructuredValueLinearizer::RecursivelyFindLeaves(
    const StructuredValue& value) {
  switch (value.kind_case()) {
    case StructuredValue::kDictValue: {
      // Dict values must be linearized in sorted order of keys.
      // Collect pointers to the map entries and sort them explicitly, since
      // protobuf map iteration order is not the key order.
      const DictValue& dict = value.dict_value();
      using FieldTy = protobuf::MapPair<std::string, StructuredValue>;
      llvm::SmallVector<const FieldTy*, 4> fields;
      for (auto& field : dict.fields()) {
        fields.push_back(&field);
      }
      llvm::sort(fields, [](const FieldTy* a, const FieldTy* b) {
        return a->first < b->first;
      });
      for (auto& field : fields) {
        // Dict keys become string attrs in the index path.
        current_index_path_.push_back(builder_.getStringAttr(field->first));
        RecursivelyFindLeaves(field->second);
        current_index_path_.pop_back();
      }
      return;
    }
    case StructuredValue::kTupleValue: {
      const TupleValue& tuple = value.tuple_value();
      for (int i = 0, e = tuple.values_size(); i < e; i++) {
        // Tuple positions become i64 integer attrs in the index path.
        current_index_path_.push_back(builder_.getI64IntegerAttr(i));
        RecursivelyFindLeaves(tuple.values(i));
        current_index_path_.pop_back();
      }
      return;
    }
    // We don't differentiate between tuples and lists.
    case StructuredValue::kListValue: {
      const ListValue& list = value.list_value();
      for (int i = 0, e = list.values_size(); i < e; i++) {
        current_index_path_.push_back(builder_.getI64IntegerAttr(i));
        RecursivelyFindLeaves(list.values(i));
        current_index_path_.pop_back();
      }
      return;
    }
    case StructuredValue::kTensorSpecValue: {
      // Base case: record the current path stack as the index path needed to
      // get to this leaf.
      leaf_index_paths_.push_back(builder_.getArrayAttr(current_index_path_));
      return;
    }
    case StructuredValue::kNoneValue: {
      // Base case: do nothing.
      // This arises, for example, as the top-level object of an output
      // signature when there are no return values.
      return;
    }
    default: {
      // Unsupported kind: accumulate a human-readable description of where in
      // the structure it occurred; reported later by GetLeafIndexPaths.
      llvm::raw_string_ostream os(error_message_);
      // TODO(silvasean): Use an enumerant name string instead of a number.
      os << "Unhandled structured value kind " << value.kind_case()
         << " at index path: <value>";
      for (auto path_element : current_index_path_) {
        os << ".";
        if (auto integer = path_element.dyn_cast<mlir::IntegerAttr>()) {
          os << integer.getValue();
        } else {
          auto str = path_element.cast<mlir::StringAttr>();
          os << str.getValue();
        }
      }
      os << "\n";
    }
  }
}
2929 
2930 // For exported functions with bound inputs, rewrite the function
2931 // signature to match the requirements of tf_saved_model bound input args.
2932 //
2933 // The raw imported functions have `tensor<*x!tf.resource>` as the type for
2934 // mutable bound inputs and `tensor<...>` as the type for immutable
2935 // bound inputs. Here we canonicalize both of them into
2936 // `tensor<!tf.resource<tensor<...>>>`.
void AdjustBoundInputArgTypes(mlir::ModuleOp module) {
  mlir::SymbolTable symbol_table(module);
  for (auto func : module.getOps<mlir::FuncOp>()) {
    // Only exported functions carry tf_saved_model bound-input args.
    if (!mlir::tf_saved_model::IsExported(func)) continue;
    mlir::OpBuilder builder(func.getBody());
    llvm::SmallVector<mlir::Type, 4> new_input_types;
    for (int i = 0, e = func.getNumArguments(); i < e; i++) {
      auto arg = func.getArgument(i);
      // Non-null only when argument i is a bound input referencing a
      // GlobalTensorOp via the symbol table.
      auto global_tensor = mlir::tf_saved_model::LookupBoundInputOfType<
          mlir::tf_saved_model::GlobalTensorOp>(func, i, symbol_table);
      if (global_tensor) {
        auto old_type = arg.getType();
        auto new_type =
            mlir::tf_saved_model::GetBoundInputArgTypeFor(global_tensor);
        arg.setType(new_type);
        if (global_tensor.is_mutable()) {
          // Mutable: bridge the type change with a Cast so existing uses
          // still see the original type.
          auto arg_with_original_type = builder.create<mlir::TF::CastOp>(
              global_tensor.getLoc(), old_type, arg,
              /*Truncate=*/builder.getBoolAttr(false));
          arg.replaceAllUsesWith(arg_with_original_type);
          // The RAUW replaces the arg with itself, so we need to set it back.
          arg_with_original_type.setOperand(arg);
        } else {
          // Immutable: materialize the value with a ReadVariableOp instead.
          auto arg_with_original_type =
              builder.create<mlir::TF::ReadVariableOp>(global_tensor.getLoc(),
                                                       old_type, arg);
          arg.replaceAllUsesWith(arg_with_original_type);
          // The RAUW replaces the arg with itself, so we need to set it back.
          arg_with_original_type.setOperand(arg);
        }
      }
      new_input_types.push_back(arg.getType());
    }
    // Rebuild the function type to reflect the adjusted argument types;
    // result types are unchanged.
    func.setType(mlir::FunctionType::get(module.getContext(), new_input_types,
                                         func.getType().getResults()));
  }
}
2974 
2975 // Marks the visibility of functions in the saved model module.
MarkSavedModelFunctionVisibility(mlir::ModuleOp module)2976 void MarkSavedModelFunctionVisibility(mlir::ModuleOp module) {
2977   for (auto func : module.getOps<mlir::FuncOp>()) {
2978     auto visibility = mlir::tf_saved_model::IsExported(func)
2979                           ? mlir::FuncOp::Visibility::Public
2980                           : mlir::FuncOp::Visibility::Private;
2981     func.setVisibility(visibility);
2982   }
2983 }
2984 
2985 // Reorder the ops in the module to make testing easier and less dependent
2986 // on implementation details such as the order of functions in the
2987 // FunctionDefLibrary.
2988 //
2989 // The order this ensures is:
2990 // 1. GlobalTensorOp's
2991 // 2. FuncOps's.
2992 //
2993 // Within each of 1. and 2., ops are sorted by exported name (if
2994 // available, and only the first exported name is considered), followed by
2995 // non-exported ops.
void SortSavedModelModule(mlir::ModuleOp module) {
  // Pair each global tensor with its first exported name (or "" if none) so
  // we can sort exported ones first, alphabetically.
  struct NamedGlobalTensor {
    llvm::StringRef name;
    GlobalTensorOp global_tensor;
  };
  llvm::SmallVector<NamedGlobalTensor, 8> named_global_tensors;
  for (auto global_tensor : module.getOps<GlobalTensorOp>()) {
    auto exported_names = mlir::tf_saved_model::GetExportedNames(global_tensor);
    // We use stable_sort, so duplicate empty names are fine here.
    named_global_tensors.push_back(
        {exported_names.empty() ? "" : exported_names.front(), global_tensor});
  }
  // Sort key (empty-last, then name) puts exported tensors before unexported.
  llvm::stable_sort(named_global_tensors,
                    [](const NamedGlobalTensor& a, const NamedGlobalTensor& b) {
                      return std::make_tuple(a.name.empty(), a.name) <
                             std::make_tuple(b.name.empty(), b.name);
                    });

  // Split functions into exported (sorted by first exported name) and
  // non-exported (sorted by symbol name).
  struct NamedFunc {
    llvm::StringRef name;
    mlir::FuncOp func;
  };
  llvm::SmallVector<NamedFunc, 8> named_funcs;
  llvm::SmallVector<mlir::FuncOp, 8> private_funcs;
  for (auto func : module.getOps<mlir::FuncOp>()) {
    auto exported_names = mlir::tf_saved_model::GetExportedNames(func);
    if (!exported_names.empty())
      named_funcs.push_back({exported_names.front(), func});
    else
      private_funcs.push_back(func);
  }
  llvm::stable_sort(named_funcs, [](const NamedFunc& a, const NamedFunc& b) {
    return a.name < b.name;
  });
  llvm::stable_sort(private_funcs, [](mlir::FuncOp a, mlir::FuncOp b) {
    return a.getName() < b.getName();
  });

  // Assets are sorted by symbol name.
  struct NamedAsset {
    llvm::StringRef name;
    AssetOp asset;
  };
  llvm::SmallVector<NamedAsset, 4> assets;
  for (auto asset : module.getOps<AssetOp>()) {
    assets.push_back({asset.getName(), asset});
  }
  llvm::stable_sort(assets, [](const NamedAsset& a, const NamedAsset& b) {
    return a.name < b.name;
  });

  // Move onto the front of the module in reverse of the final desired order.
  // Each group is itself moved in reverse so that, after all prepends, the
  // module reads: global tensors, then exported funcs, then private funcs.
  for (auto func : llvm::reverse(private_funcs)) {
    func.getOperation()->moveBefore(&module.getBody()->front());
  }
  for (auto named_func : llvm::reverse(named_funcs)) {
    named_func.func.getOperation()->moveBefore(&module.getBody()->front());
  }
  for (auto named_global_tensor : llvm::reverse(named_global_tensors)) {
    named_global_tensor.global_tensor.getOperation()->moveBefore(
        &module.getBody()->front());
  }

  // Assets are prepended after the groups above; iterating forward here
  // reverses their sorted order at the very front of the module.
  for (auto asset : assets) {
    asset.asset.getOperation()->moveBefore(&module.getBody()->front());
  }

  // Finally, any session initializer goes first of all.
  auto initializers = module.getOps<SessionInitializerOp>();
  if (!initializers.empty()) {
    (*initializers.begin())
        .getOperation()
        ->moveBefore(&module.getBody()->front());
  }
}
3069 
// Decorates the raw imported `module` with tf_saved_model semantics:
// attaches exported_names / index_path / bound_input attributes to exported
// functions, materializes variables and constants as GlobalTensorOp's, then
// adjusts bound-input arg types, sorts the module, and marks visibility.
// `tf_name_to_mlir_name` maps concrete-function names in the SavedModel to
// the MLIR function symbols created during graph import.
Status CreateSavedModelIR(
    const ObjectNames& object_names, mlir::ModuleOp module,
    const SavedObjectGraph& object_graph,
    const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name,
    SavedModelV2Bundle* saved_model) {
  mlir::OpBuilder builder(module.getBodyRegion());
  mlir::SymbolTable symbol_table(module);

  // Create a side data-structure, indexed by the object_graph node_id to
  // a TrackableObject that is restorable.
  absl::flat_hash_map<int, const TrackableObjectGraph::TrackableObject*>
      restored_objects;
  TF_RETURN_IF_ERROR(saved_model->VisitObjectsToRestore(
      [&](int saved_node_id,
          const TrackableObjectGraph::TrackableObject& trackable_object) {
        restored_objects.insert(
            std::make_pair(saved_node_id, &trackable_object));
        return Status::OK();
      }));

  for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) {
    const SavedObject& object = object_graph.nodes(node_id);
    // For correctness, we cannot import functions that don't have exported
    // names, since they don't necessarily have a well-defined ABI (diagnosed
    // earlier).
    //
    // For variables/constants, pruning them is purely an optimization,
    // and more complicated since it requires use-def analysis of which
    // functions use which variables/constants, so we don't do anything
    // special for them here as part of our initial IR construction.
    if (object.kind_case() == SavedObject::kFunction) {
      if (object_names.GetExportedNames(node_id).empty()) {
        continue;
      }
      std::string error_context =
          "While importing SavedModel function '" +
          object_names.GetExportedNames(node_id)[0].str() + "': ";
      const SavedFunction& function = object.function();
      // Exactly one concrete function exists here (enforced by
      // DiagnoseMultipleConcreteFunctions), so index 0 is the function.
      auto orig_func = symbol_table.lookup<mlir::FuncOp>(
          tf_name_to_mlir_name.find(function.concrete_functions(0))->second);
      mlir::FuncOp func = orig_func;
      // If there are potentially references to this func from within the
      // module, create a wrapper around it and decorate the wrapper with the
      // tf_saved_model attributes instead.
      if (!mlir::SymbolTable::symbolKnownUseEmpty(orig_func.getName(),
                                                  &module.getBodyRegion())) {
        func = orig_func.cloneWithoutRegions();
        module.insert(module.getBody()->begin(), func);
        func.addEntryBlock();
        func.setName("__sm_exported_" + orig_func.getName().str());
        llvm::SmallVector<mlir::Value, 4> args_as_values;
        for (auto block_argument : func.getArguments()) {
          args_as_values.push_back(block_argument);
        }
        // The wrapper simply calls the original function and returns its
        // results.
        mlir::OpBuilder body_builder(&func.getBody());
        auto call = body_builder.create<mlir::TF::StatefulPartitionedCallOp>(
            func.getLoc(), orig_func.getType().getResults(), args_as_values,
            builder.getSymbolRefAttr(orig_func.getName()),
            /*config=*/builder.getStringAttr(""),
            /*config_proto=*/builder.getStringAttr(""),
            /*executor_type=*/builder.getStringAttr(""));
        body_builder.create<mlir::ReturnOp>(func.getLoc(), call.getResults());
      }
      func->setAttr(
          "tf_saved_model.exported_names",
          builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
      const SavedConcreteFunction& concrete_function =
          object_graph.concrete_functions().at(function.concrete_functions(0));

      // We do not handle the other element of this tuple, which corresponds to
      // Python kwonlyargs, since currently TensorFlow prohibits this in
      // combination with input_signature:
      // https://github.com/tensorflow/tensorflow/blob/8cb8627abb5ef83a6fba34f8fd0e4ee430562eb1/tensorflow/python/eager/function.py#L2027-L2030
      // Our SavedModel import requires input_signature on the tf.function, so
      // we never need to handle the kwonlyargs.
      auto positional_arg_structure =
          concrete_function.canonicalized_input_signature()
              .tuple_value()
              .values(0);
      StructuredValueLinearizer input_linearizer(positional_arg_structure,
                                                 builder.getContext());

      // Arguments after the signature arguments are bound inputs (captured
      // variables/constants).
      int bound_input_base =
          func.getNumArguments() - concrete_function.bound_inputs_size();
      TF_ASSIGN_OR_RETURN(auto input_index_paths,
                          input_linearizer.GetLeafIndexPaths(
                              error_context + "in input signature: "));
      const int input_index_paths_size = input_index_paths.size();
      if (bound_input_base != input_index_paths_size) {
        return errors::InvalidArgument(
            error_context,
            "Argument mismatch between concrete function input signature "
            "vs underlying FunctionDef for concrete function '",
            function.concrete_functions(0), "' (", input_index_paths.size(),
            " vs ", bound_input_base, ")");
      }
      // Annotate each signature argument with its structured index path.
      for (auto index_path : llvm::enumerate(input_index_paths)) {
        func.setArgAttr(index_path.index(), "tf_saved_model.index_path",
                        index_path.value());
      }

      // Annotate each bound-input argument with a symbol ref to the global
      // tensor it captures.
      for (auto& bound_input :
           llvm::enumerate(concrete_function.bound_inputs())) {
        int arg_index = bound_input_base + bound_input.index();
        auto symbol_ref = builder.getSymbolRefAttr(
            object_names.GetSymbolTableName(bound_input.value()));
        func.setArgAttr(arg_index, "tf_saved_model.bound_input", symbol_ref);
      }

      // Do the same index-path annotation for results.
      StructuredValueLinearizer output_linearizer(
          concrete_function.output_signature(), builder.getContext());
      TF_ASSIGN_OR_RETURN(auto output_index_paths,
                          output_linearizer.GetLeafIndexPaths(
                              error_context + "in output signature: "));
      if (func.getNumResults() != output_index_paths.size()) {
        return errors::InvalidArgument(
            error_context,
            "Result mismatch between concrete function output signature "
            "vs underlying FunctionDef for concrete function '",
            function.concrete_functions(0), "' (", output_index_paths.size(),
            " vs ", func.getNumResults(), ")");
      }
      for (auto index_path : llvm::enumerate(output_index_paths)) {
        func.setResultAttr(index_path.index(), "tf_saved_model.index_path",
                           index_path.value());
      }
    } else if (object.kind_case() == SavedObject::kVariable) {
      const SavedVariable& variable = object.variable();
      // Find the trackable in the side data structure.
      auto variable_trackable_it = restored_objects.find(node_id);
      if (variable_trackable_it == restored_objects.end()) {
        return errors::FailedPrecondition("Could not restore saved variable: ",
                                          variable.name());
      }
      const auto* serialized_tensor_attr = FindSerializedTensorInTrackable(
          *variable_trackable_it->second, "VARIABLE_VALUE");
      if (!serialized_tensor_attr) {
        return errors::FailedPrecondition(
            "Could not find serialized tensor for saved variable: ",
            variable.name());
      }
      const auto& checkpoint_key = serialized_tensor_attr->checkpoint_key();

      // Load it from the reader.
      Tensor value;
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          saved_model->variable_reader()->Lookup(checkpoint_key, &value),
          "Could not read checkpoint key from variables bundle: ",
          checkpoint_key);
      TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder));
      // A variable can have a partially known type, such as tensor<?x27x?xf32>,
      // even if the initializer is a specific static shape.
      TF_ASSIGN_OR_RETURN(
          auto type, ConvertToMlirTensorType(variable.shape(), variable.dtype(),
                                             &builder));
      auto op = builder.create<GlobalTensorOp>(
          builder.getUnknownLoc(),
          builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
          value_attr,
          /*type=*/mlir::TypeAttr::get(type),
          /*is_mutable=*/builder.getUnitAttr());
      op->setAttr(
          "tf_saved_model.exported_names",
          builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
    } else if (object.kind_case() == SavedObject::kConstant) {
      const SavedConstant& constant = object.constant();
      // Constants live as Const nodes in the GraphDef; pull the tensor value
      // directly out of the graph.
      const TensorProto* value = ExtractConstTensorFromGraph(
          saved_model->meta_graph_def().graph_def(), constant.operation());
      if (!value) {
        return errors::FailedPrecondition(
            "Unable to find const node referenced in object graph: ",
            constant.operation());
      }
      TF_ASSIGN_OR_RETURN(auto value_attr,
                          ConvertTensorProto(*value, &builder));
      auto op = builder.create<GlobalTensorOp>(
          builder.getUnknownLoc(),
          builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
          value_attr,
          /*type=*/mlir::TypeAttr::get(value_attr.Attribute::getType()),
          /*is_mutable=*/nullptr);
      op->setAttr(
          "tf_saved_model.exported_names",
          builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
    }
  }
  // Finalize: canonicalize bound-input arg types, tag the module with
  // saved-model semantics, and normalize op order and visibility.
  AdjustBoundInputArgTypes(module);
  module->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
  SortSavedModelModule(module);
  MarkSavedModelFunctionVisibility(module);
  return Status::OK();
}
3262 
// Converts a TF2 (object-graph based) SavedModel bundle into an MLIR module
// in the tf_saved_model dialect.
//
// Pipeline: import the GraphDef into a Graph (optionally adding default
// attributes), convert every library function to an MLIR function, then build
// the tf_saved_model constructs (global tensors, exported names) from the
// object graph via CreateSavedModelIR().
StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
    SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
    mlir::MLIRContext* context, bool add_default_attributes) {
  LoadImporterDialects(*context);
  // Fall back to an empty GraphDebugInfo when the bundle carries none.
  GraphDebugInfo dummy_debug_info;
  const GraphDebugInfo& debug_info =
      saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;

  GraphImportConfig specs;
  specs.prune_unused_nodes = true;
  mlir::OwningModuleRef module =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
  // Filled in by ConvertLibFunction() below; consumed by CreateSavedModelIR().
  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;

  const auto& graphdef = saved_model->meta_graph_def().graph_def();
  PopulateTfVersions(module.get(), graphdef.versions());

  GraphConstructorOptions options;
  options.allow_internal_ops = true;
  options.add_default_attributes = add_default_attributes;
  Graph graph(OpRegistry::Global());

  // Work on a copy so the bundle's GraphDef is left untouched by
  // preprocessing.
  GraphDef preprocessed_graphdef(graphdef);
  if (add_default_attributes) {
    TF_RETURN_IF_ERROR(PreprocessGraphDef(nullptr, &preprocessed_graphdef));
  }

  TF_RETURN_IF_ERROR(
      ConvertGraphDefToGraph(options, preprocessed_graphdef, &graph));

  NameUniquifier function_name_uniquifier(graph.flib_def());
  SavedModelObjectGraphImporter importer(graph.flib_def(), debug_info, specs,
                                         module.get(), &tf_name_to_mlir_name,
                                         &function_name_uniquifier);

  TF_RETURN_IF_ERROR(importer.PrepareConvert(graph));

  // Convert every function in the library to an MLIR function.
  auto fn_names = graph.flib_def().ListFunctionNames();
  for (const auto& fn_name : fn_names) {
    TF_RETURN_IF_ERROR(importer.ConvertLibFunction(fn_name));
  }

  // Only TF2-style SavedModels carry an object graph; bail out otherwise.
  if (!saved_model->meta_graph_def().has_object_graph_def()) {
    return errors::InvalidArgument(
        "SavedModel does not have an object graph. Please use TF2.");
  }
  auto& object_graph = saved_model->meta_graph_def().object_graph_def();
  ObjectNames object_names(object_graph, exported_names);

  // Clean up a couple func's that always seem to be present when importing a
  // SavedModel. This is not strictly needed, as there is a separate pass that
  // will clean them up, but this makes staring at the raw IR of minimal
  // examples quite a bit nicer.
  for (auto func : llvm::make_early_inc_range(module->getOps<mlir::FuncOp>())) {
    if (func.getName().startswith("__inference__traced_save_") ||
        func.getName().startswith("__inference__traced_restore_") ||
        func.getName().startswith("__inference_signature_wrapper_")) {
      func.erase();
    }
  }

  // Diagnose SavedFunction's with multiple input signatures.
  TF_RETURN_IF_ERROR(
      DiagnoseMultipleConcreteFunctions(object_graph, object_names));

  // Construct the SavedModel IR.
  TF_RETURN_IF_ERROR(CreateSavedModelIR(object_names, module.get(),
                                        object_graph, tf_name_to_mlir_name,
                                        saved_model));
  assert(mlir::succeeded(mlir::verify(module.get())));

  return module;
}
3336 
3337 class SimpleSavedModelMLIRImportInput : public SavedModelMLIRImportInput {
3338  public:
Create(const MLIRImportOptions & import_options,const MetaGraphDef * meta_graph_def,const GraphDebugInfo & debug_info)3339   static StatusOr<SimpleSavedModelMLIRImportInput> Create(
3340       const MLIRImportOptions& import_options,
3341       const MetaGraphDef* meta_graph_def, const GraphDebugInfo& debug_info) {
3342     DCHECK(meta_graph_def);
3343     GraphDef graph_def;
3344     if (import_options.enable_grappler) {
3345       // Grappler is best-effort.
3346       auto statusor = RunGrappler(*meta_graph_def);
3347       if (statusor.ok()) {
3348         graph_def = std::move(statusor).ValueOrDie();
3349       } else {
3350         // If the grappler fails, use the original graph def.
3351         LOG(WARNING) << "SimpleSavedModelMLIRImportInput: grappler failed: "
3352                      << statusor.status();
3353         graph_def = meta_graph_def->graph_def();
3354       }
3355     } else {
3356       graph_def = meta_graph_def->graph_def();
3357     }
3358 
3359     auto graph = std::make_unique<Graph>(OpRegistry::Global());
3360 
3361     if (import_options.upgrade_legacy) {
3362       TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(
3363           graph_def, graph->flib_def().default_registry()));
3364     }
3365 
3366     GraphConstructorOptions graph_ctor_options;
3367     graph_ctor_options.allow_internal_ops = true;
3368     graph_ctor_options.add_default_attributes = true;
3369     TF_RETURN_IF_ERROR(
3370         ConvertGraphDefToGraph(graph_ctor_options, graph_def, graph.get()));
3371 
3372     if (import_options.upgrade_legacy) {
3373       // TODO(jpienaar): Remove need to const_cast.
3374       TF_RETURN_IF_ERROR(UpgradeLegacyGraph(
3375           graph.get(),
3376           const_cast<FunctionLibraryDefinition*>(&graph->flib_def()),
3377           /*restrict_functionalization_to_tpu_nodes=*/false));
3378     }
3379 
3380     return SimpleSavedModelMLIRImportInput(meta_graph_def, debug_info,
3381                                            std::move(graph));
3382   }
3383 
SimpleSavedModelMLIRImportInput(const MetaGraphDef * meta_graph_def,const GraphDebugInfo & debug_info,std::unique_ptr<Graph> graph)3384   SimpleSavedModelMLIRImportInput(const MetaGraphDef* meta_graph_def,
3385                                   const GraphDebugInfo& debug_info,
3386                                   std::unique_ptr<Graph> graph)
3387       : SavedModelMLIRImportInput(meta_graph_def, debug_info),
3388         graph_(std::move(graph)) {}
3389 
GetSubGraph(absl::string_view name,const GraphImportConfig & specs)3390   StatusOr<const Graph*> GetSubGraph(absl::string_view name,
3391                                      const GraphImportConfig& specs) override {
3392     DCHECK(CheckGraphNameValidity(name));
3393     DCHECK(CheckGraphContainsFeedsAndFetches(specs));
3394     return graph_.get();
3395   }
3396 
3397  private:
CheckGraphContainsFeedsAndFetches(const GraphImportConfig & specs) const3398   bool CheckGraphContainsFeedsAndFetches(const GraphImportConfig& specs) const {
3399     absl::flat_hash_set<std::string> feed_fetch_nodes;
3400     for (const auto& iter : specs.inputs) {
3401       TensorId tensor_id = ParseTensorName(iter.first);
3402       feed_fetch_nodes.insert(std::string(tensor_id.node()));
3403     }
3404     for (const auto& output : llvm::concat<const std::string>(
3405              specs.outputs, specs.control_outputs)) {
3406       TensorId tensor_id = ParseTensorName(output);
3407       feed_fetch_nodes.insert(std::string(tensor_id.node()));
3408     }
3409 
3410     for (Node* node : graph_->op_nodes()) {
3411       feed_fetch_nodes.erase(node->name());
3412     }
3413 
3414     return feed_fetch_nodes.empty();
3415   }
3416 
CheckGraphNameValidity(absl::string_view name) const3417   bool CheckGraphNameValidity(absl::string_view name) const {
3418     // If it is one of the signature name, it is valid.
3419     const auto& signature_defs = meta_graph_def().signature_def();
3420     if (signature_defs.contains(std::string(name))) return true;
3421 
3422     // If it is the restore graph name, it is valid.
3423     if (meta_graph_def().has_saver_def() &&
3424         meta_graph_def().saver_def().restore_op_name() == name)
3425       return true;
3426 
3427     // If it is the init graph name, it is valid.
3428     std::string init_op_name;
3429     if (internal::GetInitOp("", meta_graph_def(), &init_op_name).ok()) {
3430       if (init_op_name == name) return true;
3431     }
3432 
3433     return false;
3434   }
3435 
3436   // `graph_` contains the entire graph in the original MetaGraphDef.
3437   std::unique_ptr<Graph> graph_;
3438 };
3439 
// A helper class to import a TensorFlow model expressed in SavedModel V1 into
// an MLIR Module in SavedModel dialect.
//
// TODO(b/179683149): Rename this class to avoid confusion with TFLite.
class SavedModelSignatureDefImporterLite {
 public:
  // Main entry point: converts all functions (specified by SignatureDefs) in
  // the given meta graph to an MLIR Module.
  //
  // `import_restore` is introduced to control whether restore graph
  // is imported in eg. SavedModelSignatureDefImporter. Ideally, we don't need
  // this option to control this as restore graph should be always imported.
  // However, right now, SavedModelSignatureDefImporter cannot handle restore
  // graph correctly.
  //
  // TODO(chky): Remove import_restore once the restore graph is correctly
  // handled in SavedModelSignatureDefImporter.
  static StatusOr<mlir::OwningModuleRef> Convert(
      SavedModelMLIRImportInput& input, absl::Span<std::string> exported_names,
      mlir::MLIRContext* context, bool import_restore = true) {
    LoadImporterDialects(*context);
    SavedModelSignatureDefImporterLite importer(input, exported_names, context,
                                                import_restore);
    TF_ASSIGN_OR_RETURN(auto module, importer.ConvertSignatures());

    // Canonicalize module layout and function visibility for the
    // tf_saved_model dialect before handing the module back.
    SortSavedModelModule(*module);
    MarkSavedModelFunctionVisibility(*module);

    return module;
  }

 private:
  SavedModelSignatureDefImporterLite(SavedModelMLIRImportInput& input,
                                     absl::Span<std::string> exported_names,
                                     mlir::MLIRContext* context,
                                     bool import_restore)
      : input_(input),
        exported_names_(exported_names),
        module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))),
        symbol_table_(module_.get()),
        import_restore_(import_restore) {}

  // Converts the SavedModel to the SavedModel dialect. Creates an MLIR function
  // for each signature.
  StatusOr<mlir::OwningModuleRef> ConvertSignatures();
  // Converts a single SignatureDef to an exported MLIR function.
  Status ConvertSignature(const std::string& sig_def_key,
                          const SignatureDef& signature_def);

  // Pairs the name of the tensor an asset file is fed into with the AssetOp
  // created for that asset.
  struct AssetInfo {
    std::string tensor_name;
    mlir::tf_saved_model::AssetOp op;
  };
  // Creates an AssetOp in `module_` for every asset file in the meta graph.
  StatusOr<std::vector<AssetInfo>> ConvertAssets();
  // Converts the initialization graph in the SavedModel to an MLIR function.
  Status ConvertInitializer(const std::string& target_node_name,
                            const std::vector<AssetInfo>& assets);

  // Converts a graph with feeds and fetches to an MLIR function.
  StatusOr<mlir::OwningModuleRef> ConvertGraph(
      const std::string& name,
      const std::vector<std::pair<std::string, TensorInfo>>& inputs,
      const std::vector<std::pair<std::string, TensorInfo>>& outputs,
      const std::vector<std::string> control_outputs);

  // Moves the functions in `sub_module` to `module_` and skips the duplicate
  // functions.
  Status MoveConvertedFunctionsToModule(absl::string_view name,
                                        mlir::ModuleOp sub_module);

  // Translates signature TensorInfos into the feed ("input array")
  // configuration consumed by the GraphDef importer.
  GraphImportConfig::InputArrays ParseInputArrays(
      llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs);

 private:
  SavedModelMLIRImportInput& input_;        // Not owned.
  absl::Span<std::string> exported_names_;  // Empty means export everything.
  mlir::OwningModuleRef module_;            // Module under construction.
  mlir::SymbolTable symbol_table_;          // Symbol table of `module_`.
  bool import_restore_ = true;
};
3519 
3520 StatusOr<std::vector<SavedModelSignatureDefImporterLite::AssetInfo>>
ConvertAssets()3521 SavedModelSignatureDefImporterLite::ConvertAssets() {
3522   std::vector<AssetFileDef> asset_file_defs;
3523   TF_RETURN_IF_ERROR(
3524       internal::GetAssetFileDefs(input_.meta_graph_def(), &asset_file_defs));
3525 
3526   std::vector<AssetInfo> results;
3527   results.reserve(asset_file_defs.size());
3528 
3529   mlir::OpBuilder builder(module_->getBodyRegion());
3530   unsigned i = 0;  // Use to generate unique sym_name(s) for duplicate assets.
3531   for (const auto& asset : asset_file_defs) {
3532     auto asset_op = builder.create<mlir::tf_saved_model::AssetOp>(
3533         module_->getLoc(),
3534         /*sym_name=*/
3535         builder.getStringAttr(
3536             absl::StrCat("__tf_saved_model_asset", i++, "_", asset.filename())),
3537         /*filename=*/
3538         builder.getStringAttr(
3539             io::JoinPath(kSavedModelAssetsDirectory, asset.filename())));
3540 
3541     results.push_back({asset.tensor_info().name(), asset_op});
3542   }
3543 
3544   return results;
3545 }
3546 
MoveConvertedFunctionsToModule(absl::string_view name,mlir::ModuleOp sub_module)3547 Status SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule(
3548     absl::string_view name, mlir::ModuleOp sub_module) {
3549   mlir::Builder builder(sub_module.getContext());
3550   mlir::SymbolTable sub_module_symbol_table(sub_module);
3551 
3552   // Prefix private functions with the unique signature name, so that it cannot
3553   // collide with private functions used in the other signatures.
3554   for (auto func : sub_module.getOps<mlir::FuncOp>()) {
3555     if (mlir::tf_saved_model::IsExported(func)) continue;
3556 
3557     std::string new_sym_name = absl::StrCat(name, "/", func.sym_name().str());
3558     if (mlir::failed(sub_module_symbol_table.replaceAllSymbolUses(
3559             func, new_sym_name, sub_module)))
3560       return tensorflow::errors::InvalidArgument(absl::StrCat(
3561           "SavedModelSignatureDefImporterLite: failed to assign a unique "
3562           "name to the private function used in a signature: ",
3563           func.sym_name().str()));
3564 
3565     mlir::SymbolTable::setSymbolName(func, new_sym_name);
3566   }
3567 
3568   // Copy all functions used by this signature to the final MLIR module.
3569   for (auto func : sub_module.getOps<mlir::FuncOp>()) {
3570     DCHECK(symbol_table_.lookup(func.sym_name()) == nullptr);
3571     symbol_table_.insert(func.clone());
3572   }
3573 
3574   return Status::OK();
3575 }
3576 
ConvertInitializer(const std::string & target_node_name,const std::vector<AssetInfo> & assets)3577 Status SavedModelSignatureDefImporterLite::ConvertInitializer(
3578     const std::string& target_node_name, const std::vector<AssetInfo>& assets) {
3579   std::vector<std::pair<std::string, TensorInfo>> inputs;
3580   inputs.reserve(assets.size());
3581   for (const auto& asset : assets) {
3582     TensorInfo tensor_info;
3583     tensor_info.set_name(asset.tensor_name);
3584     tensor_info.set_dtype(DT_STRING);
3585     tensor_info.mutable_tensor_shape();
3586     inputs.push_back({asset.tensor_name, tensor_info});
3587   }
3588 
3589   TF_ASSIGN_OR_RETURN(auto sub_module, ConvertGraph(target_node_name, inputs,
3590                                                     {}, {target_node_name}));
3591 
3592   mlir::SymbolTable sub_symbol_table(*sub_module);
3593 
3594   auto init_func_op = sub_symbol_table.lookup<mlir::FuncOp>(target_node_name);
3595   init_func_op.removeAttr("tf.entry_function");
3596 
3597   mlir::OpBuilder builder(module_->getBodyRegion());
3598 
3599   // Bind asset inputs to asset ops.
3600   DCHECK_EQ(init_func_op.getNumArguments(), assets.size());
3601   for (const auto& iter : llvm::enumerate(assets)) {
3602     auto asset_op = iter.value().op;
3603     init_func_op.setArgAttr(iter.index(), "tf_saved_model.bound_input",
3604                             builder.getSymbolRefAttr(asset_op.getName()));
3605   }
3606 
3607   // Set the exported name of init function to an reserved name for
3608   // tf_saved_model.
3609   init_func_op->setAttr(
3610       "tf_saved_model.exported_names",
3611       builder.getStrArrayAttr({absl::StrCat(
3612           "__tf_saved_model_session_initializer_", target_node_name)}));
3613 
3614   // Move the converted functions to top level MLIR module.
3615   return MoveConvertedFunctionsToModule(target_node_name, *sub_module);
3616 }
3617 
3618 StatusOr<mlir::OwningModuleRef>
ConvertGraph(const std::string & name,const std::vector<std::pair<std::string,TensorInfo>> & inputs,const std::vector<std::pair<std::string,TensorInfo>> & outputs,const std::vector<std::string> control_outputs)3619 SavedModelSignatureDefImporterLite::ConvertGraph(
3620     const std::string& name,
3621     const std::vector<std::pair<std::string, TensorInfo>>& inputs,
3622     const std::vector<std::pair<std::string, TensorInfo>>& outputs,
3623     const std::vector<std::string> control_outputs) {
3624   VLOG(1) << "Importing Signature: " << name;
3625 
3626   GraphImportConfig specs;
3627   specs.prune_unused_nodes = true;
3628   specs.inputs = ParseInputArrays(inputs);
3629   for (auto& output : outputs) specs.outputs.push_back(output.second.name());
3630   specs.control_outputs = control_outputs;
3631 
3632   TF_ASSIGN_OR_RETURN(const auto* subgraph, input_.GetSubGraph(name, specs));
3633 
3634   // Convert sub-graph to MLIR module.
3635   return GraphDefImporter::Convert(module_->getContext(), *subgraph,
3636                                    input_.debug_info(), subgraph->flib_def(),
3637                                    specs, name);
3638 }
3639 
ConvertSignature(const std::string & sig_def_key,const SignatureDef & signature_def)3640 Status SavedModelSignatureDefImporterLite::ConvertSignature(
3641     const std::string& sig_def_key, const SignatureDef& signature_def) {
3642   // Create local vectors for the input and output and sort them to be
3643   // deterministic. We don't want anyone to really depend on the order, client
3644   // should lookup argument/result mapping by attribute name.
3645   // To avoid accidentally depending on the order we use an unintuitive sorting.
3646   std::vector<std::pair<std::string, TensorInfo>> inputs(
3647       signature_def.inputs().begin(), signature_def.inputs().end());
3648   llvm::sort(inputs, [](const auto& lhs, const auto& rhs) {
3649     return lhs.first.size() < rhs.first.size() || lhs.first > rhs.first;
3650   });
3651   std::vector<std::pair<std::string, TensorInfo>> outputs(
3652       signature_def.outputs().begin(), signature_def.outputs().end());
3653   llvm::sort(outputs, [](const auto& lhs, const auto& rhs) {
3654     return lhs.first.size() < rhs.first.size() || lhs.first > rhs.first;
3655   });
3656 
3657   // Convert sub-graph to MLIR module.
3658   TF_ASSIGN_OR_RETURN(auto sub_module,
3659                       ConvertGraph(sig_def_key, inputs, outputs, {}));
3660   mlir::OpBuilder builder(sub_module->getBodyRegion());
3661 
3662   // Find the FuncOp which corresponds to current SignatureDef.
3663   mlir::SymbolTable sub_symbol_table(*sub_module);
3664   auto func_op = sub_symbol_table.lookup<mlir::FuncOp>(sig_def_key);
3665   TF_RET_CHECK(func_op)
3666       << "Graphdef importer should have created a function named "
3667       << sig_def_key << ".";
3668 
3669   // Use unique SignatureDef key as exported name.
3670   func_op->setAttr("tf_saved_model.exported_names",
3671                    builder.getStrArrayAttr({sig_def_key}));
3672 
3673   // Transfer input and output parameter names to index_path attributes.
3674   for (auto input_and_idx : llvm::enumerate(inputs)) {
3675     func_op.setArgAttr(input_and_idx.index(), "tf_saved_model.index_path",
3676                        builder.getStrArrayAttr({input_and_idx.value().first}));
3677   }
3678   for (auto output_and_idx : llvm::enumerate(outputs)) {
3679     func_op.setResultAttr(
3680         output_and_idx.index(), "tf_saved_model.index_path",
3681         builder.getStrArrayAttr({output_and_idx.value().first}));
3682   }
3683 
3684   // Move the converted functions to top level MLIR module.
3685   return MoveConvertedFunctionsToModule(sig_def_key, *sub_module);
3686 }
3687 
3688 GraphImportConfig::InputArrays
ParseInputArrays(llvm::ArrayRef<std::pair<std::string,TensorInfo>> inputs)3689 SavedModelSignatureDefImporterLite::ParseInputArrays(
3690     llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs) {
3691   GraphImportConfig::InputArrays results;
3692   for (const auto& iter : inputs) {
3693     const auto& tensor_info = iter.second;
3694 
3695     // Only dense tensor is supported.
3696     DCHECK_EQ(tensor_info.encoding_case(), tensorflow::TensorInfo::kName);
3697 
3698     VLOG(1) << "Importing Signature Input: input_name = " << iter.first
3699             << ", tensor_info = " << tensor_info.DebugString();
3700 
3701     ArrayInfo array_info;
3702     array_info.imported_dtype = tensor_info.dtype();
3703 
3704     if (tensor_info.has_tensor_shape()) {
3705       array_info.shape = tensor_info.tensor_shape();
3706     } else {
3707       // If there is no tensor shape in the tensor info, conservatively set
3708       // unknown_rank to true.
3709       array_info.shape.set_unknown_rank(true);
3710     }
3711 
3712     results.insert(std::pair<std::string, ArrayInfo>(tensor_info.name(),
3713                                                      std::move(array_info)));
3714   }
3715   return results;
3716 }
3717 
// Converts every user-facing SignatureDef (restricted to `exported_names_`
// when that span is non-empty), then wires up assets, the restore graph (when
// `import_restore_` is set) and the init op as session initializers, and
// finally stamps the module with tf_saved_model semantics.
StatusOr<mlir::OwningModuleRef>
SavedModelSignatureDefImporterLite::ConvertSignatures() {
  const auto& signatures = input_.meta_graph_def().signature_def();
  PopulateTfVersions(module_.get(),
                     input_.meta_graph_def().graph_def().versions());

  llvm::DenseSet<llvm::StringRef> exported_name_set;
  exported_name_set.insert(exported_names_.begin(), exported_names_.end());

  for (const auto& key_and_signature_def : signatures) {
    const std::string& sig_def_key = key_and_signature_def.first;
    const SignatureDef& signature_def = key_and_signature_def.second;

    // It is safe to skip "__saved_model_init_op" since it is an internal
    // signature that is not user-accessible. This signature will be handled in
    // ConvertInitializer().
    if (sig_def_key == "__saved_model_init_op") {
      continue;
    }
    // An empty exported-name set means "convert all signatures".
    if (!exported_name_set.empty() &&
        exported_name_set.count(sig_def_key) == 0) {
      continue;
    }

    TF_RETURN_IF_ERROR(ConvertSignature(sig_def_key, signature_def));
  }

  TF_ASSIGN_OR_RETURN(auto assets, ConvertAssets());

  mlir::OpBuilder builder(module_->getBodyRegion());
  llvm::SmallVector<mlir::Attribute, 2> init_sym_refs;

  if (import_restore_ && input_.meta_graph_def().has_saver_def()) {
    std::vector<AssetInfo> variable_and_assets;

    // Create an AssetOp for the variable checkpoint files. The relative
    // filename is used here.
    auto variable_filename_op = builder.create<mlir::tf_saved_model::AssetOp>(
        module_->getLoc(),
        /*sym_name=*/
        builder.getStringAttr("__tf_saved_model_variables"),
        /*filename=*/
        builder.getStringAttr(io::JoinPath(kSavedModelVariablesDirectory,
                                           kSavedModelVariablesFilename)));
    variable_and_assets.push_back(
        {input_.meta_graph_def().saver_def().filename_tensor_name(),
         variable_filename_op});
    variable_and_assets.insert(variable_and_assets.end(), assets.begin(),
                               assets.end());

    // Import the restore graph as an initializer fed by the checkpoint
    // filename asset plus the regular assets.
    const auto& restore_op_name =
        input_.meta_graph_def().saver_def().restore_op_name();
    TF_RETURN_IF_ERROR(
        ConvertInitializer(restore_op_name, variable_and_assets));
    init_sym_refs.push_back(builder.getSymbolRefAttr(restore_op_name));
  }

  // Import the main init op (if any) as another session initializer.
  std::string init_op_name;
  TF_RETURN_IF_ERROR(
      internal::GetInitOp("", input_.meta_graph_def(), &init_op_name));
  if (!init_op_name.empty()) {
    TF_RETURN_IF_ERROR(ConvertInitializer(init_op_name, assets));
    init_sym_refs.push_back(builder.getSymbolRefAttr(init_op_name));
  }

  // Record all initializer functions on a single SessionInitializerOp.
  builder.create<mlir::tf_saved_model::SessionInitializerOp>(
      module_->getLoc(), builder.getArrayAttr(init_sym_refs));

  (*module_)->setAttr("tf_saved_model.semantics", builder.getUnitAttr());

  SortSavedModelModule(*module_);
  MarkSavedModelFunctionVisibility(*module_);

  return std::move(module_);
}
3793 
3794 // A helper class to import a TensorFlow model expressed in SavedModel V1 into
3795 // an MLIR Module in SavedModel dialect. In addition to importing the model, it
3796 // performs a few graph transformations, including:
3797 //  1) Convert read-only ref variables to resource variables
3798 //  2) Lift resource variables to global_tensors by using a TF session.
3799 class SavedModelSignatureDefImporter {
3800  public:
3801   // Main entry point: converts all functions (specified by SignatureDefs) in
3802   // the given meta graph to an MLIR Module.
Convert(const SavedModelBundle & bundle,absl::Span<std::string> exported_names,mlir::MLIRContext * context,tensorflow::MLIRImportOptions options)3803   static StatusOr<mlir::OwningModuleRef> Convert(
3804       const SavedModelBundle& bundle, absl::Span<std::string> exported_names,
3805       mlir::MLIRContext* context, tensorflow::MLIRImportOptions options) {
3806     // debug_info might not be loaded with loader_lite.
3807     GraphDebugInfo debug_info;
3808     if (bundle.debug_info != nullptr) debug_info = *bundle.debug_info;
3809 
3810     TF_ASSIGN_OR_RETURN(auto input,
3811                         SimpleSavedModelMLIRImportInput::Create(
3812                             options, &bundle.meta_graph_def, debug_info));
3813 
3814     TF_ASSIGN_OR_RETURN(auto module,
3815                         SavedModelSignatureDefImporterLite::Convert(
3816                             input, exported_names, context,
3817                             /*import_restore=*/false));
3818 
3819     mlir::OpBuilder builder(module->getContext());
3820     (*module)->setAttr("tf_saved_model.under_construction",
3821                        builder.getUnitAttr());
3822     TF_RETURN_IF_ERROR(LiftVariables(bundle, *module));
3823     module->removeAttr("tf_saved_model.under_construction");
3824 
3825     return module;
3826   }
3827 
3828  private:
3829   // Lifts the variables in `module`.
3830   static Status LiftVariables(const SavedModelBundle& bundle,
3831                               mlir::ModuleOp module);
3832 };
3833 
LiftVariables(const SavedModelBundle & bundle,mlir::ModuleOp module)3834 Status SavedModelSignatureDefImporter::LiftVariables(
3835     const SavedModelBundle& bundle, mlir::ModuleOp module) {
3836   mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
3837 
3838   mlir::PassManager pm(module.getContext());
3839   SetCrashReproducer(pm);
3840   pm.addNestedPass<mlir::FuncOp>(
3841       mlir::tf_executor::CreateTFExecutorGraphPruningPass());
3842   pm.addNestedPass<mlir::FuncOp>(
3843       mlir::CreateExecutorDialectToFunctionalConversionPass());
3844   pm.addPass(
3845       mlir::tf_saved_model::CreateRemoveVariablesInSessionInitializerPass());
3846   pm.addNestedPass<mlir::FuncOp>(
3847       mlir::TF::
3848           CreateConvertReadonlyReferenceVariablesToResourceVariablesPass());
3849   pm.addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass());
3850   pm.addPass(
3851       mlir::tf_saved_model::CreateLiftVariablesPass(bundle.GetSession()));
3852   pm.addNestedPass<mlir::FuncOp>(
3853       mlir::tf_saved_model::CreateDedupBoundInputBindingPass());
3854   if (mlir::failed(pm.run(module)))
3855     return diag_handler.Combine(errors::Internal("Failed to lift variables."));
3856 
3857   return Status::OK();
3858 }
3859 
3860 }  // namespace
3861 
~SavedModelMLIRImportInput()3862 SavedModelMLIRImportInput::~SavedModelMLIRImportInput() {}
3863 
ConvertGraphdefToMlir(const GraphDef & graphdef,const GraphDebugInfo & debug_info,const GraphImportConfig & specs,mlir::MLIRContext * context,bool add_default_attributes)3864 StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
3865     const GraphDef& graphdef, const GraphDebugInfo& debug_info,
3866     const GraphImportConfig& specs, mlir::MLIRContext* context,
3867     bool add_default_attributes) {
3868   GraphConstructorOptions options;
3869   options.allow_internal_ops = true;
3870   options.add_default_attributes = add_default_attributes;
3871   Graph graph(OpRegistry::Global());
3872 
3873   GraphDef preprocessed_graphdef(graphdef);
3874   if (add_default_attributes) {
3875     TF_RETURN_IF_ERROR(PreprocessGraphDef(&specs, &preprocessed_graphdef));
3876   }
3877   if (specs.upgrade_legacy) {
3878     TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(
3879         preprocessed_graphdef, graph.flib_def().default_registry()));
3880   }
3881   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
3882       options, std::move(preprocessed_graphdef), &graph));
3883   return ConvertGraphToMlir(graph, debug_info, graph.flib_def(), specs,
3884                             context);
3885 }
3886 
ConvertGraphToMlir(const Graph & graph,const GraphDebugInfo & debug_info,const FunctionLibraryDefinition & flib_def,const GraphImportConfig & specs,mlir::MLIRContext * context)3887 StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
3888     const Graph& graph, const GraphDebugInfo& debug_info,
3889     const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
3890     mlir::MLIRContext* context) {
3891   // TODO(jpienaar): Remove need to const_cast.
3892   if (specs.upgrade_legacy) {
3893     TF_RETURN_IF_ERROR(
3894         UpgradeLegacyGraph(const_cast<Graph*>(&graph),
3895                            const_cast<FunctionLibraryDefinition*>(&flib_def),
3896                            specs.restrict_functionalization_to_tpu_nodes));
3897   }
3898   return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs,
3899                                    /*func_name=*/"main");
3900 }
3901 
ConvertFunctionToMlir(const FunctionBody * fbody,const FunctionLibraryDefinition & flib_def,mlir::MLIRContext * context)3902 stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
3903     const FunctionBody* fbody, const FunctionLibraryDefinition& flib_def,
3904     mlir::MLIRContext* context) {
3905   tensorflow::GraphDebugInfo dummy_debug_info;
3906   tensorflow::GraphImportConfig specs;
3907   specs.enable_shape_inference = false;
3908   specs.graph_as_function = true;
3909   for (const auto* control_ret_node : fbody->control_ret_nodes)
3910     specs.control_outputs.push_back(control_ret_node->name());
3911   return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info,
3912                                    flib_def, specs,
3913                                    fbody->fdef.signature().name());
3914 }
3915 
ConvertSavedModelToMlir(SavedModelV2Bundle * saved_model,mlir::MLIRContext * context,absl::Span<std::string> exported_names,bool add_default_attributes)3916 StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
3917     SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
3918     absl::Span<std::string> exported_names, bool add_default_attributes) {
3919   return SavedModelObjectGraphImporter::Convert(
3920       saved_model, exported_names, context, add_default_attributes);
3921 }
3922 
ConvertSavedModelV1ToMlir(const SavedModelBundle & saved_model,absl::Span<std::string> exported_names,mlir::MLIRContext * context,MLIRImportOptions options)3923 StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir(
3924     const SavedModelBundle& saved_model, absl::Span<std::string> exported_names,
3925     mlir::MLIRContext* context, MLIRImportOptions options) {
3926   return SavedModelSignatureDefImporter::Convert(saved_model, exported_names,
3927                                                  context, options);
3928 }
3929 
ConvertSavedModelV1ToMlirLite(const MetaGraphDef & meta_graph_def,const GraphDebugInfo & debug_info,absl::Span<std::string> exported_names,mlir::MLIRContext * context,MLIRImportOptions options)3930 StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlirLite(
3931     const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info,
3932     absl::Span<std::string> exported_names, mlir::MLIRContext* context,
3933     MLIRImportOptions options) {
3934   TF_ASSIGN_OR_RETURN(auto input, SimpleSavedModelMLIRImportInput::Create(
3935                                       options, &meta_graph_def, debug_info));
3936   return ConvertSavedModelV1ToMlirLite(input, exported_names, context);
3937 }
3938 
ConvertSavedModelV1ToMlirLite(SavedModelMLIRImportInput & input,absl::Span<std::string> exported_names,mlir::MLIRContext * context)3939 StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlirLite(
3940     SavedModelMLIRImportInput& input, absl::Span<std::string> exported_names,
3941     mlir::MLIRContext* context) {
3942   return SavedModelSignatureDefImporterLite::Convert(input, exported_names,
3943                                                      context);
3944 }
3945 
MlirModuleToString(mlir::ModuleOp module,mlir::OpPrintingFlags flags)3946 std::string MlirModuleToString(mlir::ModuleOp module,
3947                                mlir::OpPrintingFlags flags) {
3948   std::string txt_module;
3949   {
3950     llvm::raw_string_ostream os{txt_module};
3951     module.print(os, flags);
3952   }
3953   return txt_module;
3954 }
3955 
MlirModuleToString(mlir::ModuleOp module,bool show_debug_info)3956 std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) {
3957   mlir::OpPrintingFlags flags;
3958   if (show_debug_info) flags.enableDebugInfo();
3959   return MlirModuleToString(module, flags);
3960 }
3961 
3962 }  // namespace tensorflow
3963