1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/graph_constructor.h"
17 
18 #include <algorithm>
19 #include <set>
20 #include <string>
21 #include <unordered_map>
22 #include <vector>
23 
24 #include "tensorflow/core/common_runtime/shape_refiner.h"
25 #include "tensorflow/core/framework/function.h"
26 #include "tensorflow/core/framework/function.pb.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/framework/tensor_shape.pb.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/framework/versions.h"
33 #include "tensorflow/core/framework/versions.pb.h"
34 #include "tensorflow/core/graph/algorithm.h"
35 #include "tensorflow/core/graph/graph.h"
36 #include "tensorflow/core/graph/tensor_id.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/lib/gtl/flatmap.h"
39 #include "tensorflow/core/lib/gtl/flatset.h"
40 #include "tensorflow/core/lib/gtl/inlined_vector.h"
41 #include "tensorflow/core/lib/strings/scanner.h"
42 #include "tensorflow/core/lib/strings/str_util.h"
43 #include "tensorflow/core/platform/logging.h"
44 #include "tensorflow/core/public/version.h"
45 
46 namespace tensorflow {
47 
48 namespace {
IsMerge(const NodeDef & node_def)49 inline bool IsMerge(const NodeDef& node_def) {
50   return node_def.op() == "Merge" || node_def.op() == "RefMerge";
51 }
52 
IsNextIteration(const NodeDef & node_def)53 inline bool IsNextIteration(const NodeDef& node_def) {
54   return node_def.op() == "NextIteration" ||
55          node_def.op() == "RefNextIteration";
56 }
57 
IsValidNodeName(StringPiece s,bool allow_internal_ops)58 bool IsValidNodeName(StringPiece s, bool allow_internal_ops) {
59   using ::tensorflow::strings::Scanner;
60   return Scanner(s)
61       .One(allow_internal_ops ? Scanner::LETTER_DIGIT_DOT_UNDERSCORE
62                               : Scanner::LETTER_DIGIT_DOT)
63       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
64       .Eos()
65       .GetResult();
66 }
67 
68 class GraphConstructor {
69  public:
70   struct Options {
Optionstensorflow::__anone86fd2d90111::GraphConstructor::Options71     Options(const GraphConstructorOptions& in)  // NOLINT(runtime/explicit)
72         : allow_internal_ops(in.allow_internal_ops),
73           expect_device_spec(in.expect_device_spec),
74           importing(false),
75           validate_colocation_constraints(false) {}
Optionstensorflow::__anone86fd2d90111::GraphConstructor::Options76     Options(const ImportGraphDefOptions& in)  // NOLINT(runtime/explicit)
77         : allow_internal_ops(false),
78           expect_device_spec(false),
79           prefix(in.prefix.empty() || str_util::EndsWith(in.prefix, "/")
80                      ? in.prefix
81                      : in.prefix + "/"),
82           uniquify_names(in.uniquify_names),
83           uniquify_prefix(in.uniquify_prefix),
84           input_map(in.input_map.begin(), in.input_map.end()),
85           skip_mapped_nodes(in.skip_mapped_nodes),
86           control_dependencies(in.control_dependencies),
87           return_tensors(in.return_tensors.begin(), in.return_tensors.end()),
88           return_nodes(in.return_nodes),
89           importing(true),
90           validate_colocation_constraints(in.validate_colocation_constraints),
91           validate_shape(in.validate_shape),
92           default_device(in.default_device) {}
93 
94     bool allow_internal_ops;
95     bool expect_device_spec;
96 
97     string prefix;
98     bool uniquify_names;
99     bool uniquify_prefix;
100     std::map<TensorId, TensorId> input_map;
101     bool skip_mapped_nodes;
102     std::vector<string> control_dependencies;
103     std::vector<TensorId> return_tensors;
104     std::vector<string> return_nodes;
105 
106     // TODO(ashankar): This bool exists to separate out functionality required
107     // to make ImportGraphDef a close equivalent of Python's import_graph_def
108     // without affecting the behavior of ConvertGraphDefToGraph at the time
109     // ImportGraphDef was added.
110     //
111     // That said, the functionality here (shape and op validation) seems
112     // applicable to ConvertGraphDefToGraph as well, so make an attempt to
113     // remove this.
114     bool importing;
115     bool validate_colocation_constraints;
116     bool validate_shape = true;
117 
118     string default_device;
119   };
120 
121   typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice;
122 
123   // versions and library may be nullptr
Construct(const Options & opts,NodeDefSlice node_defs,const VersionDef * versions,const FunctionDefLibrary * library,Graph * g,ShapeRefiner * refiner,std::vector<std::pair<Node *,int>> * return_tensors,std::vector<Node * > * return_nodes,std::vector<SafeTensorId> * missing_unused_input_map_keys)124   static Status Construct(
125       const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
126       const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
127       std::vector<std::pair<Node*, int>>* return_tensors,
128       std::vector<Node*>* return_nodes,
129       std::vector<SafeTensorId>* missing_unused_input_map_keys) {
130     if (versions) {
131       TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
132                                        TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
133                                        "GraphDef", "graph"));
134     }
135     GraphConstructor c(opts, node_defs, versions, library, g, refiner,
136                        return_tensors, return_nodes,
137                        missing_unused_input_map_keys);
138     const Status s = c.TryImport();
139     if (!s.ok()) c.Undo();
140     return s;
141   }
142 
143  private:
GraphConstructor(const Options & opts,NodeDefSlice node_defs,const VersionDef * versions,const FunctionDefLibrary * library,Graph * g,ShapeRefiner * refiner,std::vector<std::pair<Node *,int>> * return_tensors,std::vector<Node * > * return_nodes,std::vector<SafeTensorId> * missing_unused_input_map_keys)144   GraphConstructor(const Options& opts, NodeDefSlice node_defs,
145                    const VersionDef* versions,
146                    const FunctionDefLibrary* library, Graph* g,
147                    ShapeRefiner* refiner,
148                    std::vector<std::pair<Node*, int>>* return_tensors,
149                    std::vector<Node*>* return_nodes,
150                    std::vector<SafeTensorId>* missing_unused_input_map_keys)
151       : opts_(opts),
152         node_defs_(node_defs),
153         versions_(versions),
154         library_(library),
155         g_(g),
156         original_versions_(g->versions()),
157         prefix_(opts.prefix),
158         refiner_(refiner),
159         return_tensors_(return_tensors),
160         return_nodes_(return_nodes),
161         missing_unused_input_map_keys_(missing_unused_input_map_keys) {}
162 
TryImport()163   Status TryImport() {
164     TF_RETURN_IF_ERROR(EnsureNoNameCollisions());
165     TF_RETURN_IF_ERROR(ValidateInputMapAndControlDependencies());
166     TF_RETURN_IF_ERROR(BuildNodeIndex());
167     TF_RETURN_IF_ERROR(InitFromEdges());
168     TF_RETURN_IF_ERROR(Convert());
169     TF_RETURN_IF_ERROR(AddBackEdges());
170     TF_RETURN_IF_ERROR(UpdateVersionDef());
171     TF_RETURN_IF_ERROR(PopulateReturnTensors());
172     TF_RETURN_IF_ERROR(PopulateReturnNodes());
173     TF_RETURN_IF_ERROR(PopulateMissingUnusedInputMapKeys());
174     UpdateUniquifiedColocationNames();
175     FixupSourceAndSinkEdges(g_);
176     return Status::OK();
177   }
178 
179   Status EnsureNoNameCollisions();
180   Status ValidateInputMapAndControlDependencies();
181   Status BuildNodeIndex();
182   Status InitFromEdges();
183   Status Convert();
184   Status AddBackEdges();
185   Status UpdateVersionDef();
186   Status PopulateReturnTensors();
187   Status PopulateReturnNodes();
188   Status PopulateMissingUnusedInputMapKeys();
189 
190   void Undo();
191 
192   Status IsNodeFullyMapped(const NodeDef& node_def, bool* is_node_mapped);
193   Status ValidateColocationConstraints(const NodeDef& node_def);
194   Status MakeNode(const NodeDef& node_def, Node** node);
195   Status MakeEdge(Node* src, int output_index, Node* dst, int input_index);
196   Status ValidateShape(Node* node);
197   Status ModifyNodeDefForImport(NodeDef* node_def);
198   // Modifies node_def's inputs according to opts_.input_map.
199   // input_already_exists is a pre-initialized vector of length
200   // node_def->input_size(). This function will mark inputs that are remapped to
201   // true.
202   void RemapNodeDefInputs(NodeDef* node_def,
203                           std::vector<bool>* input_already_exists);
204   // input_already_exists is a pre-initialized vector of length
205   // node_def->input_size(). This function will add and mark control inputs as
206   // true.
207   void AddControlDependencies(NodeDef* node_def,
208                               std::vector<bool>* input_already_exists);
209   void AddPrefixToNodeDef(const std::vector<bool>& input_already_exists,
210                           NodeDef* node_def);
211 
212   // Modifies `node_def` if its name isn't unique, or if any of its inputs'
213   // names have been uniquified. This must be called in topological order on all
214   // nodes.
215   void UniquifyNames(const std::vector<bool>& input_already_exists,
216                      NodeDef* node_def);
217 
218   // Updates any constructed nodes' colocation group names if the name has been
219   // updated by UniquifyNames. This is called after all the nodes have been
220   // constructed so all the names have been uniquified if necessary.
221   void UpdateUniquifiedColocationNames();
222 
223   // Returns true if `name` already exists in `g_` (either as a node name or
224   // prefix).
225   bool NameExistsInGraph(StringPiece name);
226 
227   // Returns true if `name` already exists in the GraphDef being imported
228   // (either as a node name or prefix).
229   bool NameExistsInGraphDef(StringPiece name);
230 
231   // Returns a unique version of `original_name`, or `original_name` if it's
232   // already unique in the graph.
233   string FindUniqueName(StringPiece original_name);
234 
235   // Decrement pending count for users of `processed` and add the ones that now
236   // have all of their pending inputs satisfied to `ready_`.
237   void UpdatePendingCountAndReady(int processed);
238 
239   // From constructor
240   const Options opts_;
241   const NodeDefSlice node_defs_;
242   const VersionDef* versions_;
243   const FunctionDefLibrary* library_;
244   Graph* g_;
245   const VersionDef original_versions_;
246 
247   // A copy of opts_.prefix, possibly uniquified.
248   string prefix_;
249 
250   ShapeRefiner* refiner_;
251 
252   // May be null. Not owned.
253   std::vector<std::pair<Node*, int>>* return_tensors_;
254 
255   // May be null. Not owned.
256   std::vector<Node*>* return_nodes_;
257 
258   // May be null. Not owned.
259   std::vector<SafeTensorId>* missing_unused_input_map_keys_;
260 
261   // Intermediate datastructure used to populate
262   // `missing_unused_input_map_keys_`.
263   std::set<TensorId> used_input_map_keys_;
264 
265   // Mapping from node name to the index within node_defs_.
266   struct NodeInfo {
NodeInfotensorflow::__anone86fd2d90111::GraphConstructor::NodeInfo267     explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {}
268     // std::unordered_map<> requires that we have a default constructor.
NodeInfotensorflow::__anone86fd2d90111::GraphConstructor::NodeInfo269     NodeInfo() : NodeInfo(-1) {}
270     int gdef_index;
271     Node* node;  // nullptr until the NodeDef is converted to a Node.
272   };
273   gtl::FlatMap<StringPiece, NodeInfo, StringPieceHasher> gdef_nodes_;
274 
275   // Prefixes already used in the GraphDef being imported.
276   gtl::FlatSet<StringPiece, StringPieceHasher> gdef_prefixes_;
277 
278   // Mapping from node name to the existing node in g_.
279   gtl::FlatMap<StringPiece, Node*, StringPieceHasher> existing_nodes_;
280 
281   // Prefixes already used in the graph.
282   gtl::FlatSet<StringPiece, StringPieceHasher> existing_prefixes_;
283 
284   // Imported node names that have been uniquified. The key is the original
285   // name, the value is the new unique name.
286   gtl::FlatMap<string, string> uniquified_names_;
287 
288   // Index of NodeDefs in node_defs_ with all inputs already converted. We use a
289   // (sorted) set so nodes are created in the order defined in the GraphDef.
290   std::set<int> ready_;
291 
292   // Mapping between index within node_defs_ and the number of inputs that
293   // still need to be converted.
294   std::vector<int> pending_count_;
295 
296   // Mapping between index within node_defs_ and the index within node_defs_ of
297   // all nodes it outputs to.
298   std::vector<gtl::InlinedVector<int, 4>> outputs_;
299 
300   // Used in the conversion from node_defs_ to g_ to represent the ith input
301   // of a node.
302   struct InputInfo {
InputInfotensorflow::__anone86fd2d90111::GraphConstructor::InputInfo303     explicit InputInfo(const string& node_name, Node* n, int i)
304         : name(node_name), node(n), index(i) {}
305     // Use string instead of StringPiece so we don't have to manage lifetime
306     string name;
307     Node* node;
308     int index;
309   };
310 
311   // Used in the conversion from node_defs_ to g_ to represent an edge from
312   // the node named 'name' to node 'n'.
313   struct EdgeInfo {
EdgeInfotensorflow::__anone86fd2d90111::GraphConstructor::EdgeInfo314     explicit EdgeInfo(const string& name, int i1, Node* n, int i2)
315         : src_name(name), src_index(i1), dst_node(n), dst_index(i2) {}
316     // Use string instead of StringPiece so we don't have to manage lifetime
317     string src_name;
318     int src_index;
319     Node* dst_node;
320     int dst_index;
321   };
322   std::vector<EdgeInfo> back_edges_;
323 };
324 
UpdatePendingCountAndReady(int processed)325 void GraphConstructor::UpdatePendingCountAndReady(int processed) {
326   // We didn't consider NextIteration->Merge edges when computing
327   // pending_counts_ so we should not have to consider it here either.
328   bool is_next_iteration = IsNextIteration(*node_defs_[processed]);
329   for (size_t i = 0; i < outputs_[processed].size(); ++i) {
330     const int output = outputs_[processed][i];
331     bool is_next_iteration_to_merge_edge =
332         is_next_iteration && IsMerge(*node_defs_[output]);
333     if (!is_next_iteration_to_merge_edge) {
334       int* current_pending_count = &pending_count_[output];
335       CHECK_GT(*current_pending_count, 0);
336       (*current_pending_count)--;
337       if (*current_pending_count == 0) {
338         ready_.insert(output);
339       }
340     }
341   }
342 }
343 
344 // This could be expensive but we don't expect to call it often, if at all (only
345 // if there are multiple nodes in g_ with the same name)
NodeNameInValues(const std::map<TensorId,TensorId> & input_map,const StringPiece & node_name)346 bool NodeNameInValues(const std::map<TensorId, TensorId>& input_map,
347                       const StringPiece& node_name) {
348   for (auto iter = input_map.begin(); iter != input_map.end(); ++iter) {
349     if (iter->second.first == node_name) return true;
350   }
351   return false;
352 }
353 
NodeNameInValues(const std::vector<string> & control_dependencies,const StringPiece & node_name)354 bool NodeNameInValues(const std::vector<string>& control_dependencies,
355                       const StringPiece& node_name) {
356   return std::find(control_dependencies.begin(), control_dependencies.end(),
357                    node_name) != control_dependencies.end();
358 }
359 
360 // Adds any prefixes of `node_name` (not including the full name itself) to
361 // `prefixes`.
AddPrefixes(StringPiece node_name,gtl::FlatSet<StringPiece,StringPieceHasher> * prefixes)362 void AddPrefixes(StringPiece node_name,
363                  gtl::FlatSet<StringPiece, StringPieceHasher>* prefixes) {
364   size_t idx = -1;
365   while ((idx = node_name.find('/', idx + 1)) != StringPiece::npos) {
366     prefixes->insert(node_name.substr(0, idx));
367   }
368 }
369 
EnsureNoNameCollisions()370 Status GraphConstructor::EnsureNoNameCollisions() {
371   existing_nodes_.reserve(g_->num_nodes());
372   // Populate existing_nodes_ and existing_prefixes_.
373   for (Node* n : g_->nodes()) {
374     bool already_exists = !existing_nodes_.insert({n->name(), n}).second;
375     if (already_exists) {
376       if (NodeNameInValues(opts_.input_map, n->name())) {
377         return errors::InvalidArgument(
378             "cannot resolve input_map because multiple nodes exist with name '",
379             n->name(), "'");
380       }
381       if (NodeNameInValues(opts_.control_dependencies, n->name())) {
382         return errors::InvalidArgument(
383             "cannot resolve control_dependencies because multiple nodes exist "
384             "with name '",
385             n->name(), "'");
386       }
387     }
388     AddPrefixes(n->name(), &existing_prefixes_);
389   }
390   if (prefix_.empty() && opts_.importing && !opts_.uniquify_names) {
391     for (const NodeDef* n : node_defs_) {
392       const string& name = n->name();
393       if (NameExistsInGraph(name)) {
394         return errors::InvalidArgument("Node name '", name,
395                                        "' already exists in the Graph");
396       }
397     }
398   } else if (!prefix_.empty()) {
399     StringPiece prefix_no_slash(prefix_);
400     prefix_no_slash.remove_suffix(1);
401     if (!IsValidNodeName(prefix_no_slash, false)) {
402       return errors::InvalidArgument("Imported node name prefix '", prefix_,
403                                      "' would lead to invalid node names");
404     }
405     if (NameExistsInGraph(prefix_no_slash) && opts_.uniquify_prefix) {
406       prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/");
407     }
408   }
409   return Status::OK();
410 }
411 
ValidateInputMapAndControlDependencies()412 Status GraphConstructor::ValidateInputMapAndControlDependencies() {
413   for (const auto& mapping : opts_.input_map) {
414     TensorId src = mapping.first;
415     TensorId dst = mapping.second;
416     if (existing_nodes_.count(dst.first) == 0) {
417       return errors::InvalidArgument(
418           "node '", dst.first, "' in input_map does not exist in graph ",
419           "(input_map entry: ", src.ToString(), "->", dst.ToString(), ")");
420     }
421     if ((src.second == Graph::kControlSlot) !=
422         (dst.second == Graph::kControlSlot)) {
423       return errors::InvalidArgument("input_map entry ", src.ToString(), "->",
424                                      dst.ToString(), " between ",
425                                      "control edge and non-control edge");
426     }
427   }
428   for (const string& node : opts_.control_dependencies) {
429     if (existing_nodes_.count(node) == 0) {
430       return errors::InvalidArgument(
431           "node '", node,
432           "' in control_dependencies does not exist in "
433           "graph");
434     }
435   }
436   return Status::OK();
437 }
438 
BuildNodeIndex()439 Status GraphConstructor::BuildNodeIndex() {
440   // Validate the node names and add them to gdef_nodes_ and gdef_prefixes_.
441   for (int n = 0; n < node_defs_.size(); ++n) {
442     const NodeDef& node_def = *node_defs_[n];
443     if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) {
444       return errors::InvalidArgument(
445           "Node '", node_def.name(),
446           "': Node name contains invalid characters");
447     }
448     if (!gdef_nodes_
449              .insert(std::make_pair(StringPiece(node_def.name()), NodeInfo(n)))
450              .second) {
451       return errors::InvalidArgument("Node '", node_def.name(),
452                                      "' is not unique");
453     }
454     // Validate the operation's type.
455     if (node_def.op().empty()) {
456       return errors::InvalidArgument("Node '", node_def.name(),
457                                      "' does not specify an operation");
458     }
459     if (opts_.expect_device_spec && node_def.device().empty()) {
460       return errors::InvalidArgument("Node '", node_def.name(),
461                                      "' is missing a device specification");
462     }
463     // Validate control edges at end
464     bool in_control_dependence = false;
465     for (int i = 0; i < node_def.input_size(); ++i) {
466       StringPiece input_name = node_def.input(i);
467       if (!input_name.empty() && str_util::StartsWith(input_name, "^")) {
468         in_control_dependence = true;
469       } else if (in_control_dependence) {
470         return errors::InvalidArgument(
471             "Node '", node_def.name(),
472             "': Control dependencies must come after regular dependencies");
473       }
474     }
475     // Update gdef_prefixes_.
476     AddPrefixes(node_def.name(), &gdef_prefixes_);
477   }
478   return Status::OK();
479 }
480 
GetNextIterationNodes(const GraphConstructor::NodeDefSlice & node_defs)481 std::unordered_set<string> GetNextIterationNodes(
482     const GraphConstructor::NodeDefSlice& node_defs) {
483   std::unordered_set<string> next_iteration_nodes;
484 
485   for (int n = 0; n < node_defs.size(); ++n) {
486     const NodeDef& node_def = *node_defs[n];
487     if (IsNextIteration(node_def)) {
488       next_iteration_nodes.insert(node_def.name());
489     }
490   }
491 
492   return next_iteration_nodes;
493 }
494 
InitFromEdges()495 Status GraphConstructor::InitFromEdges() {
496   const int num_nodes = node_defs_.size();
497   pending_count_.reserve(num_nodes);
498   outputs_.resize(num_nodes);
499   std::unordered_set<string> next_iteration_nodes_ =
500       GetNextIterationNodes(node_defs_);
501 
502   // Parse the inputs for each node.
503   for (int n = 0; n < num_nodes; ++n) {
504     const NodeDef& node_def = *node_defs_[n];
505     int pending_count = node_def.input_size();
506     if (IsMerge(node_def)) {
507       // Cycles in the graph are only allowed for while loops. A while loop is
508       // identified by an edge from a NextIteration node to a Merge node. For
509       // such Merge nodes, only wait for one non-control input before
510       // considering the node ready to process in Convert().
511       int32 num_control_edges = 0;
512       bool has_loop_back_edge = false;
513       for (int i = 0; i < node_def.input_size(); ++i) {
514         StringPiece input_name(node_def.input(i));
515         if (str_util::StartsWith(input_name, "^")) {
516           num_control_edges++;
517         } else {
518           TensorId id(ParseTensorName(input_name));
519           if (next_iteration_nodes_.find(string(id.first)) !=
520               next_iteration_nodes_.end()) {
521             has_loop_back_edge = true;
522           }
523         }
524       }
525       if (has_loop_back_edge) {
526         pending_count = num_control_edges + 1;
527       }
528     }
529     for (int i = 0; i < node_def.input_size(); ++i) {
530       StringPiece input_name = node_def.input(i);
531       TensorId id(ParseTensorName(input_name));
532       if (opts_.input_map.count(id) == 0) {
533         // If an input is not mapped, then the input should appear in the graph
534         // being imported.
535         auto iter = gdef_nodes_.find(id.first);
536         if (iter == gdef_nodes_.end()) {
537           return errors::InvalidArgument("Node '", node_def.name(),
538                                          "': Unknown input node '",
539                                          node_def.input(i), "'");
540         }
541         outputs_[iter->second.gdef_index].push_back(n);
542       } else {
543         // This input is mapped to an existing edge. Therefore this input is
544         // as good as being already processed.
545         --pending_count;
546         DCHECK_GE(pending_count, 0);
547       }
548     }
549     if (pending_count == 0) {
550       ready_.insert(n);
551     }
552     pending_count_.push_back(pending_count);
553   }
554   return Status::OK();
555 }
556 
ValidateColocationConstraints(const NodeDef & node_def)557 Status GraphConstructor::ValidateColocationConstraints(
558     const NodeDef& node_def) {
559   if (!opts_.validate_colocation_constraints || !opts_.importing)
560     return Status::OK();
561   const auto iter = node_def.attr().find(kColocationAttrName);
562   if (iter == node_def.attr().end()) return Status::OK();
563   for (const string& c : iter->second.list().s()) {
564     StringPiece s(c);
565     if (str_util::ConsumePrefix(&s, kColocationGroupPrefix) &&
566         gdef_nodes_.find(s) == gdef_nodes_.end()) {
567       return errors::InvalidArgument(
568           "Node '", node_def.name(),
569           "' expects to be colocated with unknown node '", s, "'");
570     }
571   }
572   return Status::OK();
573 }
574 
MakeNode(const NodeDef & node_def,Node ** node)575 Status GraphConstructor::MakeNode(const NodeDef& node_def, Node** node) {
576   // Add the node to the graph.
577   Status status;
578   *node = g_->AddNode(node_def, &status);
579   if (!status.ok()) return status;
580   if (opts_.expect_device_spec) {
581     (*node)->set_assigned_device_name(node_def.device());
582   }
583   return Status::OK();
584 }
585 
ValidateShape(Node * node)586 Status GraphConstructor::ValidateShape(Node* node) {
587   if (!opts_.importing || !opts_.validate_shape) return Status::OK();
588   TF_RETURN_IF_ERROR(refiner_->AddNode(node));
589   // For nodes with the _output_shapes attribute, override the shape.
590   std::vector<TensorShapeProto> shape_attrs;
591   const char* kAttrName = "_output_shapes";
592   if (!GetNodeAttr(node->attrs(), kAttrName, &shape_attrs).ok()) {
593     // No _output_shapes attribute, the AddNode call above was sufficient.
594     return Status::OK();
595   }
596   auto* ic = refiner_->GetContext(node);
597   DCHECK(ic != nullptr)
598       << "ShapeRefiner::AddNode() should have created the InferenceContext";
599   if (shape_attrs.size() < node->num_outputs()) {
600     return errors::InvalidArgument(
601         "Node '", node->name(), "' has ", node->num_outputs(),
602         " outputs but the ", kAttrName, " attribute specifies shapes for ",
603         shape_attrs.size(), " outputs");
604   }
605   // NOTE(skyewm): we don't raise an error here because some users depend on
606   // this behavior, even though it's unsafe.
607   // TODO(b/74619486): raise an error.
608   if (shape_attrs.size() > node->num_outputs()) {
609     LOG(WARNING) << "Node '" << node->name() << "' has " << node->num_outputs()
610                  << " outputs but the " << kAttrName
611                  << " attribute specifies shapes for " << shape_attrs.size()
612                  << " outputs. Output shapes may be inaccurate.";
613   }
614   for (int i = 0; i < node->num_outputs(); ++i) {
615     const TensorShapeProto& p = shape_attrs[i];
616     shape_inference::ShapeHandle h;
617     Status s = ic->MakeShapeFromShapeProto(p, &h);
618     if (!s.ok()) {
619       return errors::InvalidArgument("Node '", node->name(), " has an invalid ",
620                                      kAttrName, " attribute (shape #", i,
621                                      " error:'", s.error_message(), "'");
622     }
623     s = refiner_->SetShape(node, i, h);
624     if (!s.ok()) {
625       // If the output shape is incompatible with what is inferred
626       // by the graph for a very specific whitelist of ops, then we
627       // ignore this output shape.  This can happen if there is a
628       // bug in the shape function for some operation, and the
629       // serialized graph def has the incorrect shape set when
630       // running on a newer binary with the fixed shape function.
631       // This is an escape hatch that allows us to correct shape
632       // functions that are not critical to correct execution but
633       // would cause graphs to fail if imported after correcting.
634       //
635       const string& op = node->type_string();
636       const std::vector<string> whitelist = {
637           // To be removed after 2017/03/08.
638           "RandomShuffleQueue",
639           "PaddingFIFOQueue",
640           "FIFOQueue",
641           "PriorityQueue",
642           "QueueSize",
643           "Stack",
644           "Barrier",
645           "BarrierReadySize",
646           "BarrierIncompleteSize",
647           "HashTable",
648           "MutableHashTable",
649           "MutableHashTableOfTensors",
650           "Mutex",
651           "CuckooTable",
652           "IndexTable",
653           "WholeFileReader",
654           "TextLineReader",
655           "FixedLengthRecordReader",
656           "TFRecordReader",
657           "IdentityReader",
658           "RefSwitch",
659           "RefEnter",
660           "RefNextIteration",
661           "RefMerge",
662           "RefIdentity",
663           "LMDBReader",
664           // To be removed after 2017/04/24.
665           "ConditionalAccumulator",
666           "SparseConditionalAccumulator",
667           "Table",
668       };
669       if (std::find(whitelist.begin(), whitelist.end(), op) ==
670           whitelist.end()) {
671         return errors::InvalidArgument(
672             "Node '", node->name(), "' has an ", kAttrName,
673             " attribute inconsistent with the GraphDef for output #", i, ": ",
674             s.error_message());
675       }
676     }
677   }
678   node->ClearAttr(kAttrName);
679   return Status::OK();
680 }
681 
ModifyNodeDefForImport(NodeDef * node_def)682 Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) {
683   const OpDef* op_def;
684   TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
685   AddDefaultsToNodeDef(*op_def, node_def);
686   TF_RETURN_IF_ERROR(ValidateNodeDef(*node_def, *op_def));
687   if (versions_) {
688     TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, versions_->producer()));
689   }
690   return Status::OK();
691 }
692 
RemoveInputs(const std::vector<int> & inputs_to_remove,NodeDef * node_def,std::vector<bool> * input_already_exists)693 void RemoveInputs(const std::vector<int>& inputs_to_remove, NodeDef* node_def,
694                   std::vector<bool>* input_already_exists) {
695   // Remove 'inputs_to_remove' from 'node_def'
696   NodeDef copy;
697   copy.mutable_input()->Reserve(node_def->input_size() -
698                                 inputs_to_remove.size());
699   for (int i = 0, j = 0; i < node_def->input_size(); ++i) {
700     if (j < inputs_to_remove.size() && i == inputs_to_remove[j]) {
701       ++j;
702     } else {
703       copy.add_input()->swap(*node_def->mutable_input(i));
704     }
705   }
706   node_def->mutable_input()->Swap(copy.mutable_input());
707   // Remove 'inputs_to_remove' from 'input_already_exists'
708   for (int idx : inputs_to_remove) {
709     input_already_exists->erase(input_already_exists->begin() + idx);
710   }
711   DCHECK_EQ(input_already_exists->size(), node_def->input_size());
712 }
713 
RemapNodeDefInputs(NodeDef * node_def,std::vector<bool> * input_already_exists)714 void GraphConstructor::RemapNodeDefInputs(
715     NodeDef* node_def, std::vector<bool>* input_already_exists) {
716   DCHECK_EQ(input_already_exists->size(), node_def->input_size());
717   std::set<TensorId> control_inputs;
718   std::vector<int> inputs_to_remove;
719 
720   for (int i = 0; i < node_def->input_size(); ++i) {
721     auto iter = opts_.input_map.find(ParseTensorName(node_def->input(i)));
722     if (iter == opts_.input_map.end()) continue;
723     used_input_map_keys_.insert(iter->first);
724 
725     TensorId new_input = iter->second;
726     if (new_input.second == Graph::kControlSlot) {
727       // Check if we've already remapped a different input to new_input, and if
728       // so remove this input.
729       if (control_inputs.count(new_input) > 0) {
730         inputs_to_remove.push_back(i);
731         continue;
732       }
733       control_inputs.insert(new_input);
734     }
735     node_def->set_input(i, new_input.ToString());
736     (*input_already_exists)[i] = true;
737   }
738   if (!inputs_to_remove.empty()) {
739     RemoveInputs(inputs_to_remove, node_def, input_already_exists);
740   }
741 }
742 
AddControlDependencies(NodeDef * node_def,std::vector<bool> * input_already_exists)743 void GraphConstructor::AddControlDependencies(
744     NodeDef* node_def, std::vector<bool>* input_already_exists) {
745   // To avoid adding redundant control dependencies to every imported node, skip
746   // nodes that will inherit the dependencies from another imported node.
747   bool inherits_deps = false;
748   for (int i = 0; i < node_def->input_size(); ++i) {
749     // Assume we won't inherit dependencies from remapped inputs that already
750     // exist in the graph. Even if we're wrong, we'll only add redundant
751     // dependencies.
752     if ((*input_already_exists)[i]) continue;
753 
754     // If this input is a backedge, assume we won't inherit the dependencies.
755     // TODO(skyewm): we have many redundant ParseTensorName calls. It could be
756     // worth optimizing these.
757     TensorId id(ParseTensorName(node_def->input(i)));
758     auto iter = gdef_nodes_.find(id.first);
759     DCHECK(iter != gdef_nodes_.end()) << id.first;
760     if (iter->second.node == nullptr) {
761       // Input hasn't been created yet, indicating it's a backedge.
762       continue;
763     }
764     inherits_deps = true;
765   }
766   if (inherits_deps) return;
767 
768   // node_def either has no inputs or all remapped inputs, add the control
769   // dependencies
770   for (const string& control_dep : opts_.control_dependencies) {
771     string input = TensorId(control_dep, Graph::kControlSlot).ToString();
772     bool found = false;
773     for (int i = node_def->input_size() - 1; i >= 0; --i) {
774       const string& node_input = node_def->input(i);
775       if (node_input[0] != '^') {
776         // Control inputs are at the end. Break when we reach the non-control
777         // inputs.
778         break;
779       }
780       if (node_input == input) {
781         // Control dependency already exists
782         found = true;
783         break;
784       }
785     }
786     if (found) {
787       continue;
788     }
789     node_def->add_input(input);
790     input_already_exists->push_back(true);
791   }
792 }
793 
AddPrefixToNodeDef(const std::vector<bool> & input_already_exists,NodeDef * node_def)794 void GraphConstructor::AddPrefixToNodeDef(
795     const std::vector<bool>& input_already_exists, NodeDef* node_def) {
796   if (prefix_.empty()) return;
797   node_def->set_name(strings::StrCat(prefix_, node_def->name()));
798   // Update names of input nodes
799   for (int i = 0; i < node_def->input_size(); ++i) {
800     // Skip remapped inputs (which already exist in g_ and are not being
801     // imported).
802     if (input_already_exists[i]) continue;
803     StringPiece input(node_def->input(i));
804     if (str_util::ConsumePrefix(&input, "^")) {
805       node_def->set_input(i, strings::StrCat("^", prefix_, input));
806     } else {
807       node_def->set_input(i, strings::StrCat(prefix_, input));
808     }
809   }
810   // Update names of colocation groups
811   if (node_def->attr().find(kColocationAttrName) != node_def->attr().end()) {
812     auto* list =
813         node_def->mutable_attr()->at(kColocationAttrName).mutable_list();
814     for (int i = 0; i < list->s_size(); ++i) {
815       StringPiece v(list->s(i));
816       if (str_util::ConsumePrefix(&v, kColocationGroupPrefix)) {
817         list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v));
818       }
819     }
820   }
821 }
822 
UniquifyNames(const std::vector<bool> & input_already_exists,NodeDef * node_def)823 void GraphConstructor::UniquifyNames(
824     const std::vector<bool>& input_already_exists, NodeDef* node_def) {
825   if (NameExistsInGraph(node_def->name())) {
826     string old_name = node_def->name();
827     node_def->set_name(FindUniqueName(node_def->name()));
828     uniquified_names_[old_name] = node_def->name();
829     // Note that we don't have to update gdef_nodes_ or gdef_prefixes_ with
830     // `name` because we guarantee the original NodeDef names are unique,
831     // meaning we won't generate this name again.
832   }
833   for (int i = 0; i < node_def->input_size(); ++i) {
834     // Skip remapped inputs (which already exist in g_ and are not being
835     // imported).
836     if (input_already_exists[i]) continue;
837     TensorId id = ParseTensorName(node_def->input(i));
838     // We require that UniquifyNames() is called on all NodeDefs in topological
839     // order. This guarantees that node_def's inputs will already be uniquified
840     // if necessary.
841     auto iter = uniquified_names_.find(string(id.first));
842     if (iter == uniquified_names_.end()) continue;
843     id.first = iter->second;
844     node_def->set_input(i, id.ToString());
845   }
846 }
847 
UpdateUniquifiedColocationNames()848 void GraphConstructor::UpdateUniquifiedColocationNames() {
849   for (const auto& pair : gdef_nodes_) {
850     Node* node = pair.second.node;
851     if (node == nullptr) continue;
852     std::vector<string> coloc_values;
853     Status status =
854         GetNodeAttr(node->attrs(), kColocationAttrName, &coloc_values);
855     if (!status.ok()) continue;
856     bool updated = false;
857     for (int i = 0; i < coloc_values.size(); ++i) {
858       StringPiece val(coloc_values[i]);
859       if (str_util::ConsumePrefix(&val, kColocationGroupPrefix)) {
860         auto name_pair = uniquified_names_.find(string(val));
861         if (name_pair == uniquified_names_.end()) continue;
862         updated = true;
863         coloc_values[i] =
864             strings::StrCat(kColocationGroupPrefix, name_pair->second);
865       }
866     }
867     if (updated) {
868       node->AddAttr(kColocationAttrName, coloc_values);
869     }
870   }
871 }
872 
NameExistsInGraph(StringPiece name)873 bool GraphConstructor::NameExistsInGraph(StringPiece name) {
874   if (existing_nodes_.find(name) != existing_nodes_.end()) return true;
875   if (existing_prefixes_.find(name) != existing_prefixes_.end()) return true;
876   return false;
877 }
878 
NameExistsInGraphDef(StringPiece name)879 bool GraphConstructor::NameExistsInGraphDef(StringPiece name) {
880   if (gdef_nodes_.find(name) != gdef_nodes_.end()) return true;
881   if (gdef_prefixes_.find(name) != gdef_prefixes_.end()) return true;
882   return false;
883 }
884 
FindUniqueName(StringPiece original_name)885 string GraphConstructor::FindUniqueName(StringPiece original_name) {
886   string name(original_name);
887   int count = 0;
888   // Check that any generated names don't collide with imported NodeDefs (as
889   // well as nodes in g_).
890   while (NameExistsInGraph(name) || (count > 0 && NameExistsInGraphDef(name))) {
891     name = strings::StrCat(original_name, "_", ++count);
892   }
893   return name;
894 }
895 
IsNodeFullyMapped(const NodeDef & node_def,bool * is_node_mapped)896 Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def,
897                                            bool* is_node_mapped) {
898   const OpDef* op_def;
899   TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
900   for (int i = 0; i < op_def->output_arg_size(); ++i) {
901     if (opts_.input_map.find({node_def.name(), i}) == opts_.input_map.end()) {
902       *is_node_mapped = false;
903       return Status::OK();
904     }
905   }
906   *is_node_mapped = true;
907   return Status::OK();
908 }
909 
// Creates a Node in g_ for every NodeDef in node_defs_ and connects it to its
// inputs, processing the NodeDefs in topological order.
//
// When opts_.importing is set, each NodeDef is first copied and rewritten:
//   - inputs are remapped through opts_.input_map;
//   - opts_.control_dependencies are added;
//   - opts_.default_device is applied if the NodeDef has no device;
//   - the name is prefixed and/or uniquified;
//   - ModifyNodeDefForImport() is run.
// With opts_.skip_mapped_nodes, a NodeDef for which IsNodeFullyMapped()
// reports true is skipped and gets no Node.
//
// An input whose source Node has not been created yet is a back edge. Back
// edges are only accepted on Merge nodes and are recorded in back_edges_
// instead of being added here.
//
// Returns InvalidArgument if an input references an out-of-range output of
// its source, if a non-Merge node has a back edge, or if some NodeDefs were
// never processed because they form a cycle (the pending ones are logged).
Status GraphConstructor::Convert() {
  // Import functions before adding nodes, since imported nodes may refer to
  // functions
  if (library_) {
    TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(*library_));
  }

  // Resolved inputs of the node currently being processed; cleared per node.
  std::vector<InputInfo> inputs;
  // Number of NodeDefs taken from ready_, including skipped ones. Compared
  // against node_defs_.size() at the end to detect cycles.
  int processed = 0;

  // Declared outside the loop and cleared/resized for each node.
  std::vector<bool> input_already_exists;

  // Process the NodeDefs in topological order.
  // (InitFromEdges() sets this up by filling in ready_ with nodes that have no
  // inputs, pending_count_ with the number of inputs for each node and
  // outputs_ with the outputs of each node).
  while (!ready_.empty()) {
    int o = *ready_.begin();
    ready_.erase(ready_.begin());
    ++processed;
    inputs.clear();
    bool has_data_back_edge = false;

    const NodeDef& original_node_def = *node_defs_[o];
    NodeDef imported_node_def;
    // Points at imported_node_def when importing, else at original_node_def.
    const NodeDef* node_def;

    // input_already_exists[i] is true iff the i-th input of the node we're
    // importing refers to a preexisting node in g_ (i.e. input[i] existed prior
    // to importing node_defs_).  Conversely, input_already_exists[i] is false
    // iff the input refers to a node in node_defs_.
    input_already_exists.clear();
    input_already_exists.resize(original_node_def.input_size(), false);

    if (opts_.importing) {
      if (opts_.skip_mapped_nodes) {
        bool is_node_mapped = false;
        TF_RETURN_IF_ERROR(
            IsNodeFullyMapped(original_node_def, &is_node_mapped));
        if (is_node_mapped) {
          // Skip this node after updating pending_count_ for outputs
          UpdatePendingCountAndReady(o);
          continue;
        }
      }

      // TODO(ashankar): The line below means an additional copy of the
      // NodeDef, which can be expensive if the NodeDef contains large tensors
      // in it. Might make sense to change the API for ImportGraphDef to take
      // a mutable GraphDef* and avoid the copying.
      imported_node_def = original_node_def;
      if (!opts_.input_map.empty()) {
        // Note that input_already_exists can shrink here
        RemapNodeDefInputs(&imported_node_def, &input_already_exists);
      }
      if (!opts_.control_dependencies.empty()) {
        // Note that input_already_exists can grow here
        AddControlDependencies(&imported_node_def, &input_already_exists);
      }
      if (!opts_.default_device.empty() && imported_node_def.device().empty()) {
        imported_node_def.set_device(opts_.default_device);
      }

      node_def = &imported_node_def;
    } else {
      node_def = &original_node_def;
    }

    DCHECK_EQ(node_def->input_size(), input_already_exists.size());
    TF_RETURN_IF_ERROR(ValidateColocationConstraints(*node_def));
    // Resolve every input to its source Node. Imported sources are looked up
    // by their original (not yet prefixed/uniquified) names.
    for (int i = 0; i < node_def->input_size(); ++i) {
      TensorId id(ParseTensorName(node_def->input(i)));
      Node* src_node;
      int src_index;

      if (!input_already_exists[i]) {
        // Locate input in newly-imported nodes
        auto iter = gdef_nodes_.find(id.first);
        DCHECK(iter != gdef_nodes_.end()) << id.first;
        src_node = iter->second.node;
        src_index = id.second;
        // Source Node not created yet: this input is a back edge.
        if (src_node == nullptr) has_data_back_edge = true;
      } else {
        // Input refers to preexisting node in graph
        auto iter = existing_nodes_.find(id.first);
        DCHECK(iter != existing_nodes_.end()) << id.first;
        src_node = iter->second;
        src_index = id.second;
      }

      if (src_node != nullptr && src_index >= src_node->num_outputs()) {
        return errors::InvalidArgument(
            "Node '", node_def->name(), "': Connecting to invalid output ",
            id.second, " of source node ", id.first, " which has ",
            src_node->num_outputs(), " outputs");
      }

      // src_node is null for back edges; those are recorded below.
      inputs.emplace_back(string(id.first), src_node, src_index);
    }

    if (has_data_back_edge && !IsMerge(*node_def)) {
      return errors::InvalidArgument(
          "Node '", node_def->name(),
          "' had a back edge, but only Merge nodes can have back edges.");
    }

    Node* node;
    if (opts_.importing) {
      // Renaming is safe here because all inputs were resolved above. Since
      // node_def points at imported_node_def, MakeNode() sees these rewrites.
      if (!prefix_.empty()) {
        AddPrefixToNodeDef(input_already_exists, &imported_node_def);
      }
      // Note: no need to uniquify names if the prefix already guarantees
      // uniqueness
      if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) {
        UniquifyNames(input_already_exists, &imported_node_def);
      }
      TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&imported_node_def));
    }
    TF_RETURN_IF_ERROR(MakeNode(*node_def, &node));
    // Use original_node_def so name StringPiece remains valid
    gdef_nodes_[original_node_def.name()].node = node;

    // Add edges from inputs to *node to the graph.
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (inputs[i].node == nullptr) {
        // Record this back edge, which will be added after all nodes
        // are created.
        back_edges_.emplace_back(inputs[i].name, inputs[i].index, node, i);
      } else if (inputs[i].index == Graph::kControlSlot) {
        g_->AddControlEdge(inputs[i].node, node);
      } else {
        TF_RETURN_IF_ERROR(MakeEdge(inputs[i].node, inputs[i].index, node, i));
      }
    }

    TF_RETURN_IF_ERROR(ValidateShape(node));

    // Update pending_count_ for outputs.
    UpdatePendingCountAndReady(o);
  }

  // Any NodeDef never taken from ready_ still has pending inputs, which means
  // it is part of a cycle.
  if (processed < node_defs_.size()) {
    LOG(WARNING) << "IN " << __func__ << " " << (node_defs_.size() - processed)
                 << " NODES IN A CYCLE";
    for (int64 i = 0; i < node_defs_.size(); i++) {
      if (pending_count_[i] != 0) {
        LOG(WARNING) << "PENDING: " << SummarizeNodeDef(*node_defs_[i])
                     << " WITH PENDING COUNT = " << pending_count_[i];
      }
    }
    return errors::InvalidArgument(node_defs_.size() - processed,
                                   " nodes in a cycle");
  }

  return Status::OK();
}
1066 
AddBackEdges()1067 Status GraphConstructor::AddBackEdges() {
1068   // Add the back edges after all nodes are created.
1069   for (auto e : back_edges_) {
1070     Node* src_node = gdef_nodes_[e.src_name].node;
1071     if (e.src_index == Graph::kControlSlot) {
1072       g_->AddControlEdge(src_node, e.dst_node);
1073     } else {
1074       TF_RETURN_IF_ERROR(
1075           MakeEdge(src_node, e.src_index, e.dst_node, e.dst_index));
1076     }
1077 
1078     VLOG(2) << "Add back edge: " << src_node->name() << " -> "
1079             << e.dst_node->name();
1080   }
1081   return Status::OK();
1082 }
1083 
UpdateVersionDef()1084 Status GraphConstructor::UpdateVersionDef() {
1085   if (versions_ == nullptr) return Status::OK();
1086 
1087   if (!opts_.importing) {
1088     g_->set_versions(*versions_);
1089     return Status::OK();
1090   }
1091   VersionDef versions = g_->versions();
1092   versions.set_producer(std::min(versions.producer(), versions_->producer()));
1093   versions.set_min_consumer(
1094       std::max(versions.min_consumer(), versions_->min_consumer()));
1095   if (versions_->bad_consumers_size() > 0) {
1096     std::set<int> bad(versions.bad_consumers().begin(),
1097                       versions.bad_consumers().end());
1098     bad.insert(versions_->bad_consumers().begin(),
1099                versions_->bad_consumers().end());
1100     versions.clear_bad_consumers();
1101     for (int v : bad) {
1102       versions.add_bad_consumers(v);
1103     }
1104   }
1105   g_->set_versions(versions);
1106   return Status::OK();
1107 }
1108 
PopulateReturnTensors()1109 Status GraphConstructor::PopulateReturnTensors() {
1110   if (opts_.return_tensors.empty()) return Status::OK();
1111   for (const TensorId& id : opts_.return_tensors) {
1112     auto iter = opts_.input_map.find(id);
1113     if (iter == opts_.input_map.end()) {
1114       // Locate id in imported nodes
1115       auto iter = gdef_nodes_.find(id.first);
1116       if (iter == gdef_nodes_.end()) {
1117         return errors::InvalidArgument("Requested return tensor '",
1118                                        id.ToString(),
1119                                        "' not found in graph def");
1120       }
1121       int num_outputs = iter->second.node->num_outputs();
1122       if ((id.second < 0 || id.second >= num_outputs) &&
1123           id.second != Graph::kControlSlot) {
1124         return errors::InvalidArgument("Invalid return output ", id.second,
1125                                        " of node '", id.first, "', which has ",
1126                                        num_outputs, " output(s)");
1127       }
1128       return_tensors_->push_back({iter->second.node, id.second});
1129     } else {
1130       // id was remapped to existing node
1131       TensorId remapped_id = iter->second;
1132       DCHECK_GT(existing_nodes_.count(remapped_id.first), 0);
1133       Node* node = existing_nodes_[remapped_id.first];
1134       return_tensors_->push_back({node, remapped_id.second});
1135     }
1136   }
1137   return Status::OK();
1138 }
1139 
PopulateReturnNodes()1140 Status GraphConstructor::PopulateReturnNodes() {
1141   if (opts_.return_nodes.empty()) return Status::OK();
1142   for (StringPiece name : opts_.return_nodes) {
1143     auto iter = gdef_nodes_.find(name);
1144     if (iter == gdef_nodes_.end()) {
1145       return errors::InvalidArgument("Requested return node '", name,
1146                                      "' not found in graph def");
1147     }
1148     return_nodes_->push_back(iter->second.node);
1149   }
1150   return Status::OK();
1151 }
1152 
PopulateMissingUnusedInputMapKeys()1153 Status GraphConstructor::PopulateMissingUnusedInputMapKeys() {
1154   if (missing_unused_input_map_keys_ == nullptr) return Status::OK();
1155   for (const auto& input_map_pair : opts_.input_map) {
1156     TensorId key = input_map_pair.first;
1157     if (used_input_map_keys_.count(key) > 0) continue;
1158 
1159     auto pair = gdef_nodes_.find(key.first);
1160     if (pair == gdef_nodes_.end()) {
1161       // key's node doesn't exist in GraphDef
1162       missing_unused_input_map_keys_->push_back(key);
1163       continue;
1164     }
1165 
1166     // Check that key's index is in bounds. Get the number of outputs from the
1167     // NodeDef, rather than the imported Node, since the Node may not exist if
1168     // opts_.skip_mapped_nodes is true.
1169     const NodeDef* node_def = node_defs_[pair->second.gdef_index];
1170     const OpDef* op_def;
1171     TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
1172     int num_outputs;
1173     TF_RETURN_IF_ERROR(NumOutputsForNode(*node_def, *op_def, &num_outputs));
1174     if (key.second >= num_outputs) {
1175       // key's index out of bounds
1176       missing_unused_input_map_keys_->push_back(key);
1177     }
1178   }
1179   return Status::OK();
1180 }
1181 
Undo()1182 void GraphConstructor::Undo() {
1183   for (const auto& iter : gdef_nodes_) {
1184     if (iter.second.node != nullptr) {
1185       g_->RemoveNode(iter.second.node);
1186     }
1187   }
1188   g_->set_versions(original_versions_);
1189 }
1190 
MakeEdge(Node * src,int output_index,Node * dst,int input_index)1191 Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst,
1192                                   int input_index) {
1193   DataType src_out = src->output_type(output_index);
1194   DataType dst_in = dst->input_type(input_index);
1195   if (!TypesCompatible(dst_in, src_out)) {
1196     return errors::InvalidArgument(
1197         "Input ", input_index, " of node ", dst->name(), " was passed ",
1198         DataTypeString(src_out), " from ", src->name(), ":", output_index,
1199         " incompatible with expected ", DataTypeString(dst_in), ".");
1200   }
1201   g_->AddEdge(src, output_index, dst, input_index);
1202   return Status::OK();
1203 }
1204 
1205 }  // namespace
1206 
ConvertGraphDefToGraph(const GraphConstructorOptions & opts,const GraphDef & gdef,Graph * g)1207 Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
1208                               const GraphDef& gdef, Graph* g) {
1209   ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
1210   return GraphConstructor::Construct(
1211       opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,
1212       /*return_tensors=*/nullptr, /*return_nodes=*/nullptr,
1213       /*missing_unused_input_map_keys=*/nullptr);
1214 }
1215 
ConvertNodeDefsToGraph(const GraphConstructorOptions & opts,gtl::ArraySlice<NodeDef> nodes,Graph * g)1216 Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
1217                               gtl::ArraySlice<NodeDef> nodes, Graph* g) {
1218   ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry());
1219   // TODO(irving): Copy will go away once NodeInfo exists
1220   std::vector<const NodeDef*> node_defs;
1221   for (const auto& n : nodes) {
1222     node_defs.push_back(&n);
1223   }
1224   return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g,
1225                                      &refiner, /*return_tensors=*/nullptr,
1226                                      /*return_nodes=*/nullptr,
1227                                      /*missing_unused_input_map_keys=*/nullptr);
1228 }
1229 
ImportGraphDef(const ImportGraphDefOptions & opts,const GraphDef & gdef,Graph * g,ShapeRefiner * refiner,ImportGraphDefResults * results)1230 Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
1231                       Graph* g, ShapeRefiner* refiner,
1232                       ImportGraphDefResults* results) {
1233   if (!opts.return_tensors.empty()) {
1234     if (results == nullptr) {
1235       return errors::InvalidArgument(
1236           "results argument to ImportGraphDef() must be non-null if "
1237           "opts.return_tensors is non-empty");
1238     }
1239   }
1240 
1241   if (!opts.return_nodes.empty()) {
1242     if (opts.skip_mapped_nodes) {
1243       return errors::InvalidArgument(
1244           "Requesting return_nodes with skip_mapped_nodes set is not currently "
1245           "supported");
1246     }
1247     if (results == nullptr) {
1248       return errors::InvalidArgument(
1249           "results argument to ImportGraphDef() must be non-null if "
1250           "opts.return_nodes is non-empty");
1251     }
1252   }
1253 
1254   if (results != nullptr) {
1255     if (!results->return_tensors.empty() || !results->return_nodes.empty() ||
1256         !results->missing_unused_input_map_keys.empty()) {
1257       return errors::InvalidArgument(
1258           "All fields in results argument to ImportGraphDef() must be empty.");
1259     }
1260   }
1261 
1262   ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry());
1263   if (refiner == nullptr) {
1264     refiner = &default_refiner;
1265   } else {
1266     // Log a warning if we are importing a GraphDef at an older
1267     // producer version after already having added non-source/sink
1268     // nodes to the graph in the past.
1269     if (gdef.versions().producer() > 0 &&
1270         gdef.versions().producer() < refiner->graph_def_version() &&
1271         g->num_nodes() > 2) {
1272       LOG(WARNING) << "Importing a graph with a lower producer version "
1273                    << gdef.versions().producer()
1274                    << " into an existing graph with producer version "
1275                    << refiner->graph_def_version() << ". Shape inference will "
1276                    << "have run different parts of the graph with different "
1277                    << "producer versions.";
1278     }
1279   }
1280 
1281   // Set the graph def version of the refiner as the min of the
1282   // current value and the version from the graph we are about to
1283   // import.
1284   //
1285   // Note: to match Run() semantics, we should re-run shape inference
1286   // on the entire graph if the producer version has changed.  For now
1287   // we log the warning above.
1288   refiner->set_graph_def_version(
1289       std::min(refiner->graph_def_version(), gdef.versions().producer()));
1290 
1291   if (results == nullptr) {
1292     return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
1293                                        &gdef.library(), g, refiner, nullptr,
1294                                        nullptr, nullptr);
1295   } else {
1296     return GraphConstructor::Construct(
1297         opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner,
1298         &results->return_tensors, &results->return_nodes,
1299         &results->missing_unused_input_map_keys);
1300   }
1301 }
1302 
CopyGraph(const Graph & src,Graph * dest)1303 void CopyGraph(const Graph& src, Graph* dest) {
1304   for (Node* n : dest->nodes()) {
1305     CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
1306   }
1307 
1308   // Copy GraphDef versions
1309   dest->set_versions(src.versions());
1310 
1311   // Copy the nodes
1312   std::unordered_map<const Node*, Node*>
1313       node_map;  // "Node in src" -> "Node in *dest"
1314   node_map[src.source_node()] = dest->source_node();
1315   node_map[src.sink_node()] = dest->sink_node();
1316   for (Node* n : src.op_nodes()) {
1317     node_map[n] = dest->CopyNode(n);
1318   }
1319 
1320   // Copy the edges
1321   for (const Edge* e : src.edges()) {
1322     Node* src_copy = node_map[e->src()];
1323     Node* dst_copy = node_map[e->dst()];
1324     dest->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
1325   }
1326 }
1327 
1328 }  // namespace tensorflow
1329