1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/graph_constructor.h"
17
18 #include <algorithm>
19 #include <set>
20 #include <string>
21 #include <unordered_map>
22 #include <vector>
23
24 #include "tensorflow/core/common_runtime/shape_refiner.h"
25 #include "tensorflow/core/framework/function.h"
26 #include "tensorflow/core/framework/function.pb.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/framework/tensor_shape.pb.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/framework/versions.h"
33 #include "tensorflow/core/framework/versions.pb.h"
34 #include "tensorflow/core/graph/algorithm.h"
35 #include "tensorflow/core/graph/graph.h"
36 #include "tensorflow/core/graph/tensor_id.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/lib/gtl/flatmap.h"
39 #include "tensorflow/core/lib/gtl/flatset.h"
40 #include "tensorflow/core/lib/gtl/inlined_vector.h"
41 #include "tensorflow/core/lib/strings/scanner.h"
42 #include "tensorflow/core/lib/strings/str_util.h"
43 #include "tensorflow/core/platform/logging.h"
44 #include "tensorflow/core/public/version.h"
45
46 namespace tensorflow {
47
48 namespace {
IsMerge(const NodeDef & node_def)49 inline bool IsMerge(const NodeDef& node_def) {
50 return node_def.op() == "Merge" || node_def.op() == "RefMerge";
51 }
52
IsNextIteration(const NodeDef & node_def)53 inline bool IsNextIteration(const NodeDef& node_def) {
54 return node_def.op() == "NextIteration" ||
55 node_def.op() == "RefNextIteration";
56 }
57
IsValidNodeName(StringPiece s,bool allow_internal_ops)58 bool IsValidNodeName(StringPiece s, bool allow_internal_ops) {
59 using ::tensorflow::strings::Scanner;
60 return Scanner(s)
61 .One(allow_internal_ops ? Scanner::LETTER_DIGIT_DOT_UNDERSCORE
62 : Scanner::LETTER_DIGIT_DOT)
63 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
64 .Eos()
65 .GetResult();
66 }
67
68 class GraphConstructor {
69 public:
70 struct Options {
Optionstensorflow::__anone86fd2d90111::GraphConstructor::Options71 Options(const GraphConstructorOptions& in) // NOLINT(runtime/explicit)
72 : allow_internal_ops(in.allow_internal_ops),
73 expect_device_spec(in.expect_device_spec),
74 importing(false),
75 validate_colocation_constraints(false) {}
Optionstensorflow::__anone86fd2d90111::GraphConstructor::Options76 Options(const ImportGraphDefOptions& in) // NOLINT(runtime/explicit)
77 : allow_internal_ops(false),
78 expect_device_spec(false),
79 prefix(in.prefix.empty() || str_util::EndsWith(in.prefix, "/")
80 ? in.prefix
81 : in.prefix + "/"),
82 uniquify_names(in.uniquify_names),
83 uniquify_prefix(in.uniquify_prefix),
84 input_map(in.input_map.begin(), in.input_map.end()),
85 skip_mapped_nodes(in.skip_mapped_nodes),
86 control_dependencies(in.control_dependencies),
87 return_tensors(in.return_tensors.begin(), in.return_tensors.end()),
88 return_nodes(in.return_nodes),
89 importing(true),
90 validate_colocation_constraints(in.validate_colocation_constraints),
91 validate_shape(in.validate_shape),
92 default_device(in.default_device) {}
93
94 bool allow_internal_ops;
95 bool expect_device_spec;
96
97 string prefix;
98 bool uniquify_names;
99 bool uniquify_prefix;
100 std::map<TensorId, TensorId> input_map;
101 bool skip_mapped_nodes;
102 std::vector<string> control_dependencies;
103 std::vector<TensorId> return_tensors;
104 std::vector<string> return_nodes;
105
106 // TODO(ashankar): This bool exists to separate out functionality required
107 // to make ImportGraphDef a close equivalent of Python's import_graph_def
108 // without affecting the behavior of ConvertGraphDefToGraph at the time
109 // ImportGraphDef was added.
110 //
111 // That said, the functionality here (shape and op validation) seems
112 // applicable to ConvertGraphDefToGraph as well, so make an attempt to
113 // remove this.
114 bool importing;
115 bool validate_colocation_constraints;
116 bool validate_shape = true;
117
118 string default_device;
119 };
120
121 typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice;
122
123 // versions and library may be nullptr
Construct(const Options & opts,NodeDefSlice node_defs,const VersionDef * versions,const FunctionDefLibrary * library,Graph * g,ShapeRefiner * refiner,std::vector<std::pair<Node *,int>> * return_tensors,std::vector<Node * > * return_nodes,std::vector<SafeTensorId> * missing_unused_input_map_keys)124 static Status Construct(
125 const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
126 const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
127 std::vector<std::pair<Node*, int>>* return_tensors,
128 std::vector<Node*>* return_nodes,
129 std::vector<SafeTensorId>* missing_unused_input_map_keys) {
130 if (versions) {
131 TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
132 TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
133 "GraphDef", "graph"));
134 }
135 GraphConstructor c(opts, node_defs, versions, library, g, refiner,
136 return_tensors, return_nodes,
137 missing_unused_input_map_keys);
138 const Status s = c.TryImport();
139 if (!s.ok()) c.Undo();
140 return s;
141 }
142
143 private:
GraphConstructor(const Options & opts,NodeDefSlice node_defs,const VersionDef * versions,const FunctionDefLibrary * library,Graph * g,ShapeRefiner * refiner,std::vector<std::pair<Node *,int>> * return_tensors,std::vector<Node * > * return_nodes,std::vector<SafeTensorId> * missing_unused_input_map_keys)144 GraphConstructor(const Options& opts, NodeDefSlice node_defs,
145 const VersionDef* versions,
146 const FunctionDefLibrary* library, Graph* g,
147 ShapeRefiner* refiner,
148 std::vector<std::pair<Node*, int>>* return_tensors,
149 std::vector<Node*>* return_nodes,
150 std::vector<SafeTensorId>* missing_unused_input_map_keys)
151 : opts_(opts),
152 node_defs_(node_defs),
153 versions_(versions),
154 library_(library),
155 g_(g),
156 original_versions_(g->versions()),
157 prefix_(opts.prefix),
158 refiner_(refiner),
159 return_tensors_(return_tensors),
160 return_nodes_(return_nodes),
161 missing_unused_input_map_keys_(missing_unused_input_map_keys) {}
162
TryImport()163 Status TryImport() {
164 TF_RETURN_IF_ERROR(EnsureNoNameCollisions());
165 TF_RETURN_IF_ERROR(ValidateInputMapAndControlDependencies());
166 TF_RETURN_IF_ERROR(BuildNodeIndex());
167 TF_RETURN_IF_ERROR(InitFromEdges());
168 TF_RETURN_IF_ERROR(Convert());
169 TF_RETURN_IF_ERROR(AddBackEdges());
170 TF_RETURN_IF_ERROR(UpdateVersionDef());
171 TF_RETURN_IF_ERROR(PopulateReturnTensors());
172 TF_RETURN_IF_ERROR(PopulateReturnNodes());
173 TF_RETURN_IF_ERROR(PopulateMissingUnusedInputMapKeys());
174 UpdateUniquifiedColocationNames();
175 FixupSourceAndSinkEdges(g_);
176 return Status::OK();
177 }
178
179 Status EnsureNoNameCollisions();
180 Status ValidateInputMapAndControlDependencies();
181 Status BuildNodeIndex();
182 Status InitFromEdges();
183 Status Convert();
184 Status AddBackEdges();
185 Status UpdateVersionDef();
186 Status PopulateReturnTensors();
187 Status PopulateReturnNodes();
188 Status PopulateMissingUnusedInputMapKeys();
189
190 void Undo();
191
192 Status IsNodeFullyMapped(const NodeDef& node_def, bool* is_node_mapped);
193 Status ValidateColocationConstraints(const NodeDef& node_def);
194 Status MakeNode(const NodeDef& node_def, Node** node);
195 Status MakeEdge(Node* src, int output_index, Node* dst, int input_index);
196 Status ValidateShape(Node* node);
197 Status ModifyNodeDefForImport(NodeDef* node_def);
198 // Modifies node_def's inputs according to opts_.input_map.
199 // input_already_exists is a pre-initialized vector of length
200 // node_def->input_size(). This function will mark inputs that are remapped to
201 // true.
202 void RemapNodeDefInputs(NodeDef* node_def,
203 std::vector<bool>* input_already_exists);
204 // input_already_exists is a pre-initialized vector of length
205 // node_def->input_size(). This function will add and mark control inputs as
206 // true.
207 void AddControlDependencies(NodeDef* node_def,
208 std::vector<bool>* input_already_exists);
209 void AddPrefixToNodeDef(const std::vector<bool>& input_already_exists,
210 NodeDef* node_def);
211
212 // Modifies `node_def` if its name isn't unique, or if any of its inputs'
213 // names have been uniquified. This must be called in topological order on all
214 // nodes.
215 void UniquifyNames(const std::vector<bool>& input_already_exists,
216 NodeDef* node_def);
217
218 // Updates any constructed nodes' colocation group names if the name has been
219 // updated by UniquifyNames. This is called after all the nodes have been
220 // constructed so all the names have been uniquified if necessary.
221 void UpdateUniquifiedColocationNames();
222
223 // Returns true if `name` already exists in `g_` (either as a node name or
224 // prefix).
225 bool NameExistsInGraph(StringPiece name);
226
227 // Returns true if `name` already exists in the GraphDef being imported
228 // (either as a node name or prefix).
229 bool NameExistsInGraphDef(StringPiece name);
230
231 // Returns a unique version of `original_name`, or `original_name` if it's
232 // already unique in the graph.
233 string FindUniqueName(StringPiece original_name);
234
235 // Decrement pending count for users of `processed` and add the ones that now
236 // have all of their pending inputs satisfied to `ready_`.
237 void UpdatePendingCountAndReady(int processed);
238
239 // From constructor
240 const Options opts_;
241 const NodeDefSlice node_defs_;
242 const VersionDef* versions_;
243 const FunctionDefLibrary* library_;
244 Graph* g_;
245 const VersionDef original_versions_;
246
247 // A copy of opts_.prefix, possibly uniquified.
248 string prefix_;
249
250 ShapeRefiner* refiner_;
251
252 // May be null. Not owned.
253 std::vector<std::pair<Node*, int>>* return_tensors_;
254
255 // May be null. Not owned.
256 std::vector<Node*>* return_nodes_;
257
258 // May be null. Not owned.
259 std::vector<SafeTensorId>* missing_unused_input_map_keys_;
260
261 // Intermediate datastructure used to populate
262 // `missing_unused_input_map_keys_`.
263 std::set<TensorId> used_input_map_keys_;
264
265 // Mapping from node name to the index within node_defs_.
266 struct NodeInfo {
NodeInfotensorflow::__anone86fd2d90111::GraphConstructor::NodeInfo267 explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {}
268 // std::unordered_map<> requires that we have a default constructor.
NodeInfotensorflow::__anone86fd2d90111::GraphConstructor::NodeInfo269 NodeInfo() : NodeInfo(-1) {}
270 int gdef_index;
271 Node* node; // nullptr until the NodeDef is converted to a Node.
272 };
273 gtl::FlatMap<StringPiece, NodeInfo, StringPieceHasher> gdef_nodes_;
274
275 // Prefixes already used in the GraphDef being imported.
276 gtl::FlatSet<StringPiece, StringPieceHasher> gdef_prefixes_;
277
278 // Mapping from node name to the existing node in g_.
279 gtl::FlatMap<StringPiece, Node*, StringPieceHasher> existing_nodes_;
280
281 // Prefixes already used in the graph.
282 gtl::FlatSet<StringPiece, StringPieceHasher> existing_prefixes_;
283
284 // Imported node names that have been uniquified. The key is the original
285 // name, the value is the new unique name.
286 gtl::FlatMap<string, string> uniquified_names_;
287
288 // Index of NodeDefs in node_defs_ with all inputs already converted. We use a
289 // (sorted) set so nodes are created in the order defined in the GraphDef.
290 std::set<int> ready_;
291
292 // Mapping between index within node_defs_ and the number of inputs that
293 // still need to be converted.
294 std::vector<int> pending_count_;
295
296 // Mapping between index within node_defs_ and the index within node_defs_ of
297 // all nodes it outputs to.
298 std::vector<gtl::InlinedVector<int, 4>> outputs_;
299
300 // Used in the conversion from node_defs_ to g_ to represent the ith input
301 // of a node.
302 struct InputInfo {
InputInfotensorflow::__anone86fd2d90111::GraphConstructor::InputInfo303 explicit InputInfo(const string& node_name, Node* n, int i)
304 : name(node_name), node(n), index(i) {}
305 // Use string instead of StringPiece so we don't have to manage lifetime
306 string name;
307 Node* node;
308 int index;
309 };
310
311 // Used in the conversion from node_defs_ to g_ to represent an edge from
312 // the node named 'name' to node 'n'.
313 struct EdgeInfo {
EdgeInfotensorflow::__anone86fd2d90111::GraphConstructor::EdgeInfo314 explicit EdgeInfo(const string& name, int i1, Node* n, int i2)
315 : src_name(name), src_index(i1), dst_node(n), dst_index(i2) {}
316 // Use string instead of StringPiece so we don't have to manage lifetime
317 string src_name;
318 int src_index;
319 Node* dst_node;
320 int dst_index;
321 };
322 std::vector<EdgeInfo> back_edges_;
323 };
324
UpdatePendingCountAndReady(int processed)325 void GraphConstructor::UpdatePendingCountAndReady(int processed) {
326 // We didn't consider NextIteration->Merge edges when computing
327 // pending_counts_ so we should not have to consider it here either.
328 bool is_next_iteration = IsNextIteration(*node_defs_[processed]);
329 for (size_t i = 0; i < outputs_[processed].size(); ++i) {
330 const int output = outputs_[processed][i];
331 bool is_next_iteration_to_merge_edge =
332 is_next_iteration && IsMerge(*node_defs_[output]);
333 if (!is_next_iteration_to_merge_edge) {
334 int* current_pending_count = &pending_count_[output];
335 CHECK_GT(*current_pending_count, 0);
336 (*current_pending_count)--;
337 if (*current_pending_count == 0) {
338 ready_.insert(output);
339 }
340 }
341 }
342 }
343
344 // This could be expensive but we don't expect to call it often, if at all (only
345 // if there are multiple nodes in g_ with the same name)
NodeNameInValues(const std::map<TensorId,TensorId> & input_map,const StringPiece & node_name)346 bool NodeNameInValues(const std::map<TensorId, TensorId>& input_map,
347 const StringPiece& node_name) {
348 for (auto iter = input_map.begin(); iter != input_map.end(); ++iter) {
349 if (iter->second.first == node_name) return true;
350 }
351 return false;
352 }
353
NodeNameInValues(const std::vector<string> & control_dependencies,const StringPiece & node_name)354 bool NodeNameInValues(const std::vector<string>& control_dependencies,
355 const StringPiece& node_name) {
356 return std::find(control_dependencies.begin(), control_dependencies.end(),
357 node_name) != control_dependencies.end();
358 }
359
360 // Adds any prefixes of `node_name` (not including the full name itself) to
361 // `prefixes`.
AddPrefixes(StringPiece node_name,gtl::FlatSet<StringPiece,StringPieceHasher> * prefixes)362 void AddPrefixes(StringPiece node_name,
363 gtl::FlatSet<StringPiece, StringPieceHasher>* prefixes) {
364 size_t idx = -1;
365 while ((idx = node_name.find('/', idx + 1)) != StringPiece::npos) {
366 prefixes->insert(node_name.substr(0, idx));
367 }
368 }
369
EnsureNoNameCollisions()370 Status GraphConstructor::EnsureNoNameCollisions() {
371 existing_nodes_.reserve(g_->num_nodes());
372 // Populate existing_nodes_ and existing_prefixes_.
373 for (Node* n : g_->nodes()) {
374 bool already_exists = !existing_nodes_.insert({n->name(), n}).second;
375 if (already_exists) {
376 if (NodeNameInValues(opts_.input_map, n->name())) {
377 return errors::InvalidArgument(
378 "cannot resolve input_map because multiple nodes exist with name '",
379 n->name(), "'");
380 }
381 if (NodeNameInValues(opts_.control_dependencies, n->name())) {
382 return errors::InvalidArgument(
383 "cannot resolve control_dependencies because multiple nodes exist "
384 "with name '",
385 n->name(), "'");
386 }
387 }
388 AddPrefixes(n->name(), &existing_prefixes_);
389 }
390 if (prefix_.empty() && opts_.importing && !opts_.uniquify_names) {
391 for (const NodeDef* n : node_defs_) {
392 const string& name = n->name();
393 if (NameExistsInGraph(name)) {
394 return errors::InvalidArgument("Node name '", name,
395 "' already exists in the Graph");
396 }
397 }
398 } else if (!prefix_.empty()) {
399 StringPiece prefix_no_slash(prefix_);
400 prefix_no_slash.remove_suffix(1);
401 if (!IsValidNodeName(prefix_no_slash, false)) {
402 return errors::InvalidArgument("Imported node name prefix '", prefix_,
403 "' would lead to invalid node names");
404 }
405 if (NameExistsInGraph(prefix_no_slash) && opts_.uniquify_prefix) {
406 prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/");
407 }
408 }
409 return Status::OK();
410 }
411
ValidateInputMapAndControlDependencies()412 Status GraphConstructor::ValidateInputMapAndControlDependencies() {
413 for (const auto& mapping : opts_.input_map) {
414 TensorId src = mapping.first;
415 TensorId dst = mapping.second;
416 if (existing_nodes_.count(dst.first) == 0) {
417 return errors::InvalidArgument(
418 "node '", dst.first, "' in input_map does not exist in graph ",
419 "(input_map entry: ", src.ToString(), "->", dst.ToString(), ")");
420 }
421 if ((src.second == Graph::kControlSlot) !=
422 (dst.second == Graph::kControlSlot)) {
423 return errors::InvalidArgument("input_map entry ", src.ToString(), "->",
424 dst.ToString(), " between ",
425 "control edge and non-control edge");
426 }
427 }
428 for (const string& node : opts_.control_dependencies) {
429 if (existing_nodes_.count(node) == 0) {
430 return errors::InvalidArgument(
431 "node '", node,
432 "' in control_dependencies does not exist in "
433 "graph");
434 }
435 }
436 return Status::OK();
437 }
438
BuildNodeIndex()439 Status GraphConstructor::BuildNodeIndex() {
440 // Validate the node names and add them to gdef_nodes_ and gdef_prefixes_.
441 for (int n = 0; n < node_defs_.size(); ++n) {
442 const NodeDef& node_def = *node_defs_[n];
443 if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) {
444 return errors::InvalidArgument(
445 "Node '", node_def.name(),
446 "': Node name contains invalid characters");
447 }
448 if (!gdef_nodes_
449 .insert(std::make_pair(StringPiece(node_def.name()), NodeInfo(n)))
450 .second) {
451 return errors::InvalidArgument("Node '", node_def.name(),
452 "' is not unique");
453 }
454 // Validate the operation's type.
455 if (node_def.op().empty()) {
456 return errors::InvalidArgument("Node '", node_def.name(),
457 "' does not specify an operation");
458 }
459 if (opts_.expect_device_spec && node_def.device().empty()) {
460 return errors::InvalidArgument("Node '", node_def.name(),
461 "' is missing a device specification");
462 }
463 // Validate control edges at end
464 bool in_control_dependence = false;
465 for (int i = 0; i < node_def.input_size(); ++i) {
466 StringPiece input_name = node_def.input(i);
467 if (!input_name.empty() && str_util::StartsWith(input_name, "^")) {
468 in_control_dependence = true;
469 } else if (in_control_dependence) {
470 return errors::InvalidArgument(
471 "Node '", node_def.name(),
472 "': Control dependencies must come after regular dependencies");
473 }
474 }
475 // Update gdef_prefixes_.
476 AddPrefixes(node_def.name(), &gdef_prefixes_);
477 }
478 return Status::OK();
479 }
480
GetNextIterationNodes(const GraphConstructor::NodeDefSlice & node_defs)481 std::unordered_set<string> GetNextIterationNodes(
482 const GraphConstructor::NodeDefSlice& node_defs) {
483 std::unordered_set<string> next_iteration_nodes;
484
485 for (int n = 0; n < node_defs.size(); ++n) {
486 const NodeDef& node_def = *node_defs[n];
487 if (IsNextIteration(node_def)) {
488 next_iteration_nodes.insert(node_def.name());
489 }
490 }
491
492 return next_iteration_nodes;
493 }
494
InitFromEdges()495 Status GraphConstructor::InitFromEdges() {
496 const int num_nodes = node_defs_.size();
497 pending_count_.reserve(num_nodes);
498 outputs_.resize(num_nodes);
499 std::unordered_set<string> next_iteration_nodes_ =
500 GetNextIterationNodes(node_defs_);
501
502 // Parse the inputs for each node.
503 for (int n = 0; n < num_nodes; ++n) {
504 const NodeDef& node_def = *node_defs_[n];
505 int pending_count = node_def.input_size();
506 if (IsMerge(node_def)) {
507 // Cycles in the graph are only allowed for while loops. A while loop is
508 // identified by an edge from a NextIteration node to a Merge node. For
509 // such Merge nodes, only wait for one non-control input before
510 // considering the node ready to process in Convert().
511 int32 num_control_edges = 0;
512 bool has_loop_back_edge = false;
513 for (int i = 0; i < node_def.input_size(); ++i) {
514 StringPiece input_name(node_def.input(i));
515 if (str_util::StartsWith(input_name, "^")) {
516 num_control_edges++;
517 } else {
518 TensorId id(ParseTensorName(input_name));
519 if (next_iteration_nodes_.find(string(id.first)) !=
520 next_iteration_nodes_.end()) {
521 has_loop_back_edge = true;
522 }
523 }
524 }
525 if (has_loop_back_edge) {
526 pending_count = num_control_edges + 1;
527 }
528 }
529 for (int i = 0; i < node_def.input_size(); ++i) {
530 StringPiece input_name = node_def.input(i);
531 TensorId id(ParseTensorName(input_name));
532 if (opts_.input_map.count(id) == 0) {
533 // If an input is not mapped, then the input should appear in the graph
534 // being imported.
535 auto iter = gdef_nodes_.find(id.first);
536 if (iter == gdef_nodes_.end()) {
537 return errors::InvalidArgument("Node '", node_def.name(),
538 "': Unknown input node '",
539 node_def.input(i), "'");
540 }
541 outputs_[iter->second.gdef_index].push_back(n);
542 } else {
543 // This input is mapped to an existing edge. Therefore this input is
544 // as good as being already processed.
545 --pending_count;
546 DCHECK_GE(pending_count, 0);
547 }
548 }
549 if (pending_count == 0) {
550 ready_.insert(n);
551 }
552 pending_count_.push_back(pending_count);
553 }
554 return Status::OK();
555 }
556
ValidateColocationConstraints(const NodeDef & node_def)557 Status GraphConstructor::ValidateColocationConstraints(
558 const NodeDef& node_def) {
559 if (!opts_.validate_colocation_constraints || !opts_.importing)
560 return Status::OK();
561 const auto iter = node_def.attr().find(kColocationAttrName);
562 if (iter == node_def.attr().end()) return Status::OK();
563 for (const string& c : iter->second.list().s()) {
564 StringPiece s(c);
565 if (str_util::ConsumePrefix(&s, kColocationGroupPrefix) &&
566 gdef_nodes_.find(s) == gdef_nodes_.end()) {
567 return errors::InvalidArgument(
568 "Node '", node_def.name(),
569 "' expects to be colocated with unknown node '", s, "'");
570 }
571 }
572 return Status::OK();
573 }
574
MakeNode(const NodeDef & node_def,Node ** node)575 Status GraphConstructor::MakeNode(const NodeDef& node_def, Node** node) {
576 // Add the node to the graph.
577 Status status;
578 *node = g_->AddNode(node_def, &status);
579 if (!status.ok()) return status;
580 if (opts_.expect_device_spec) {
581 (*node)->set_assigned_device_name(node_def.device());
582 }
583 return Status::OK();
584 }
585
ValidateShape(Node * node)586 Status GraphConstructor::ValidateShape(Node* node) {
587 if (!opts_.importing || !opts_.validate_shape) return Status::OK();
588 TF_RETURN_IF_ERROR(refiner_->AddNode(node));
589 // For nodes with the _output_shapes attribute, override the shape.
590 std::vector<TensorShapeProto> shape_attrs;
591 const char* kAttrName = "_output_shapes";
592 if (!GetNodeAttr(node->attrs(), kAttrName, &shape_attrs).ok()) {
593 // No _output_shapes attribute, the AddNode call above was sufficient.
594 return Status::OK();
595 }
596 auto* ic = refiner_->GetContext(node);
597 DCHECK(ic != nullptr)
598 << "ShapeRefiner::AddNode() should have created the InferenceContext";
599 if (shape_attrs.size() < node->num_outputs()) {
600 return errors::InvalidArgument(
601 "Node '", node->name(), "' has ", node->num_outputs(),
602 " outputs but the ", kAttrName, " attribute specifies shapes for ",
603 shape_attrs.size(), " outputs");
604 }
605 // NOTE(skyewm): we don't raise an error here because some users depend on
606 // this behavior, even though it's unsafe.
607 // TODO(b/74619486): raise an error.
608 if (shape_attrs.size() > node->num_outputs()) {
609 LOG(WARNING) << "Node '" << node->name() << "' has " << node->num_outputs()
610 << " outputs but the " << kAttrName
611 << " attribute specifies shapes for " << shape_attrs.size()
612 << " outputs. Output shapes may be inaccurate.";
613 }
614 for (int i = 0; i < node->num_outputs(); ++i) {
615 const TensorShapeProto& p = shape_attrs[i];
616 shape_inference::ShapeHandle h;
617 Status s = ic->MakeShapeFromShapeProto(p, &h);
618 if (!s.ok()) {
619 return errors::InvalidArgument("Node '", node->name(), " has an invalid ",
620 kAttrName, " attribute (shape #", i,
621 " error:'", s.error_message(), "'");
622 }
623 s = refiner_->SetShape(node, i, h);
624 if (!s.ok()) {
625 // If the output shape is incompatible with what is inferred
626 // by the graph for a very specific whitelist of ops, then we
627 // ignore this output shape. This can happen if there is a
628 // bug in the shape function for some operation, and the
629 // serialized graph def has the incorrect shape set when
630 // running on a newer binary with the fixed shape function.
631 // This is an escape hatch that allows us to correct shape
632 // functions that are not critical to correct execution but
633 // would cause graphs to fail if imported after correcting.
634 //
635 const string& op = node->type_string();
636 const std::vector<string> whitelist = {
637 // To be removed after 2017/03/08.
638 "RandomShuffleQueue",
639 "PaddingFIFOQueue",
640 "FIFOQueue",
641 "PriorityQueue",
642 "QueueSize",
643 "Stack",
644 "Barrier",
645 "BarrierReadySize",
646 "BarrierIncompleteSize",
647 "HashTable",
648 "MutableHashTable",
649 "MutableHashTableOfTensors",
650 "Mutex",
651 "CuckooTable",
652 "IndexTable",
653 "WholeFileReader",
654 "TextLineReader",
655 "FixedLengthRecordReader",
656 "TFRecordReader",
657 "IdentityReader",
658 "RefSwitch",
659 "RefEnter",
660 "RefNextIteration",
661 "RefMerge",
662 "RefIdentity",
663 "LMDBReader",
664 // To be removed after 2017/04/24.
665 "ConditionalAccumulator",
666 "SparseConditionalAccumulator",
667 "Table",
668 };
669 if (std::find(whitelist.begin(), whitelist.end(), op) ==
670 whitelist.end()) {
671 return errors::InvalidArgument(
672 "Node '", node->name(), "' has an ", kAttrName,
673 " attribute inconsistent with the GraphDef for output #", i, ": ",
674 s.error_message());
675 }
676 }
677 }
678 node->ClearAttr(kAttrName);
679 return Status::OK();
680 }
681
ModifyNodeDefForImport(NodeDef * node_def)682 Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) {
683 const OpDef* op_def;
684 TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
685 AddDefaultsToNodeDef(*op_def, node_def);
686 TF_RETURN_IF_ERROR(ValidateNodeDef(*node_def, *op_def));
687 if (versions_) {
688 TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, versions_->producer()));
689 }
690 return Status::OK();
691 }
692
RemoveInputs(const std::vector<int> & inputs_to_remove,NodeDef * node_def,std::vector<bool> * input_already_exists)693 void RemoveInputs(const std::vector<int>& inputs_to_remove, NodeDef* node_def,
694 std::vector<bool>* input_already_exists) {
695 // Remove 'inputs_to_remove' from 'node_def'
696 NodeDef copy;
697 copy.mutable_input()->Reserve(node_def->input_size() -
698 inputs_to_remove.size());
699 for (int i = 0, j = 0; i < node_def->input_size(); ++i) {
700 if (j < inputs_to_remove.size() && i == inputs_to_remove[j]) {
701 ++j;
702 } else {
703 copy.add_input()->swap(*node_def->mutable_input(i));
704 }
705 }
706 node_def->mutable_input()->Swap(copy.mutable_input());
707 // Remove 'inputs_to_remove' from 'input_already_exists'
708 for (int idx : inputs_to_remove) {
709 input_already_exists->erase(input_already_exists->begin() + idx);
710 }
711 DCHECK_EQ(input_already_exists->size(), node_def->input_size());
712 }
713
RemapNodeDefInputs(NodeDef * node_def,std::vector<bool> * input_already_exists)714 void GraphConstructor::RemapNodeDefInputs(
715 NodeDef* node_def, std::vector<bool>* input_already_exists) {
716 DCHECK_EQ(input_already_exists->size(), node_def->input_size());
717 std::set<TensorId> control_inputs;
718 std::vector<int> inputs_to_remove;
719
720 for (int i = 0; i < node_def->input_size(); ++i) {
721 auto iter = opts_.input_map.find(ParseTensorName(node_def->input(i)));
722 if (iter == opts_.input_map.end()) continue;
723 used_input_map_keys_.insert(iter->first);
724
725 TensorId new_input = iter->second;
726 if (new_input.second == Graph::kControlSlot) {
727 // Check if we've already remapped a different input to new_input, and if
728 // so remove this input.
729 if (control_inputs.count(new_input) > 0) {
730 inputs_to_remove.push_back(i);
731 continue;
732 }
733 control_inputs.insert(new_input);
734 }
735 node_def->set_input(i, new_input.ToString());
736 (*input_already_exists)[i] = true;
737 }
738 if (!inputs_to_remove.empty()) {
739 RemoveInputs(inputs_to_remove, node_def, input_already_exists);
740 }
741 }
742
AddControlDependencies(NodeDef * node_def,std::vector<bool> * input_already_exists)743 void GraphConstructor::AddControlDependencies(
744 NodeDef* node_def, std::vector<bool>* input_already_exists) {
745 // To avoid adding redundant control dependencies to every imported node, skip
746 // nodes that will inherit the dependencies from another imported node.
747 bool inherits_deps = false;
748 for (int i = 0; i < node_def->input_size(); ++i) {
749 // Assume we won't inherit dependencies from remapped inputs that already
750 // exist in the graph. Even if we're wrong, we'll only add redundant
751 // dependencies.
752 if ((*input_already_exists)[i]) continue;
753
754 // If this input is a backedge, assume we won't inherit the dependencies.
755 // TODO(skyewm): we have many redundant ParseTensorName calls. It could be
756 // worth optimizing these.
757 TensorId id(ParseTensorName(node_def->input(i)));
758 auto iter = gdef_nodes_.find(id.first);
759 DCHECK(iter != gdef_nodes_.end()) << id.first;
760 if (iter->second.node == nullptr) {
761 // Input hasn't been created yet, indicating it's a backedge.
762 continue;
763 }
764 inherits_deps = true;
765 }
766 if (inherits_deps) return;
767
768 // node_def either has no inputs or all remapped inputs, add the control
769 // dependencies
770 for (const string& control_dep : opts_.control_dependencies) {
771 string input = TensorId(control_dep, Graph::kControlSlot).ToString();
772 bool found = false;
773 for (int i = node_def->input_size() - 1; i >= 0; --i) {
774 const string& node_input = node_def->input(i);
775 if (node_input[0] != '^') {
776 // Control inputs are at the end. Break when we reach the non-control
777 // inputs.
778 break;
779 }
780 if (node_input == input) {
781 // Control dependency already exists
782 found = true;
783 break;
784 }
785 }
786 if (found) {
787 continue;
788 }
789 node_def->add_input(input);
790 input_already_exists->push_back(true);
791 }
792 }
793
AddPrefixToNodeDef(const std::vector<bool> & input_already_exists,NodeDef * node_def)794 void GraphConstructor::AddPrefixToNodeDef(
795 const std::vector<bool>& input_already_exists, NodeDef* node_def) {
796 if (prefix_.empty()) return;
797 node_def->set_name(strings::StrCat(prefix_, node_def->name()));
798 // Update names of input nodes
799 for (int i = 0; i < node_def->input_size(); ++i) {
800 // Skip remapped inputs (which already exist in g_ and are not being
801 // imported).
802 if (input_already_exists[i]) continue;
803 StringPiece input(node_def->input(i));
804 if (str_util::ConsumePrefix(&input, "^")) {
805 node_def->set_input(i, strings::StrCat("^", prefix_, input));
806 } else {
807 node_def->set_input(i, strings::StrCat(prefix_, input));
808 }
809 }
810 // Update names of colocation groups
811 if (node_def->attr().find(kColocationAttrName) != node_def->attr().end()) {
812 auto* list =
813 node_def->mutable_attr()->at(kColocationAttrName).mutable_list();
814 for (int i = 0; i < list->s_size(); ++i) {
815 StringPiece v(list->s(i));
816 if (str_util::ConsumePrefix(&v, kColocationGroupPrefix)) {
817 list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v));
818 }
819 }
820 }
821 }
822
UniquifyNames(const std::vector<bool> & input_already_exists,NodeDef * node_def)823 void GraphConstructor::UniquifyNames(
824 const std::vector<bool>& input_already_exists, NodeDef* node_def) {
825 if (NameExistsInGraph(node_def->name())) {
826 string old_name = node_def->name();
827 node_def->set_name(FindUniqueName(node_def->name()));
828 uniquified_names_[old_name] = node_def->name();
829 // Note that we don't have to update gdef_nodes_ or gdef_prefixes_ with
830 // `name` because we guarantee the original NodeDef names are unique,
831 // meaning we won't generate this name again.
832 }
833 for (int i = 0; i < node_def->input_size(); ++i) {
834 // Skip remapped inputs (which already exist in g_ and are not being
835 // imported).
836 if (input_already_exists[i]) continue;
837 TensorId id = ParseTensorName(node_def->input(i));
838 // We require that UniquifyNames() is called on all NodeDefs in topological
839 // order. This guarantees that node_def's inputs will already be uniquified
840 // if necessary.
841 auto iter = uniquified_names_.find(string(id.first));
842 if (iter == uniquified_names_.end()) continue;
843 id.first = iter->second;
844 node_def->set_input(i, id.ToString());
845 }
846 }
847
UpdateUniquifiedColocationNames()848 void GraphConstructor::UpdateUniquifiedColocationNames() {
849 for (const auto& pair : gdef_nodes_) {
850 Node* node = pair.second.node;
851 if (node == nullptr) continue;
852 std::vector<string> coloc_values;
853 Status status =
854 GetNodeAttr(node->attrs(), kColocationAttrName, &coloc_values);
855 if (!status.ok()) continue;
856 bool updated = false;
857 for (int i = 0; i < coloc_values.size(); ++i) {
858 StringPiece val(coloc_values[i]);
859 if (str_util::ConsumePrefix(&val, kColocationGroupPrefix)) {
860 auto name_pair = uniquified_names_.find(string(val));
861 if (name_pair == uniquified_names_.end()) continue;
862 updated = true;
863 coloc_values[i] =
864 strings::StrCat(kColocationGroupPrefix, name_pair->second);
865 }
866 }
867 if (updated) {
868 node->AddAttr(kColocationAttrName, coloc_values);
869 }
870 }
871 }
872
NameExistsInGraph(StringPiece name)873 bool GraphConstructor::NameExistsInGraph(StringPiece name) {
874 if (existing_nodes_.find(name) != existing_nodes_.end()) return true;
875 if (existing_prefixes_.find(name) != existing_prefixes_.end()) return true;
876 return false;
877 }
878
NameExistsInGraphDef(StringPiece name)879 bool GraphConstructor::NameExistsInGraphDef(StringPiece name) {
880 if (gdef_nodes_.find(name) != gdef_nodes_.end()) return true;
881 if (gdef_prefixes_.find(name) != gdef_prefixes_.end()) return true;
882 return false;
883 }
884
FindUniqueName(StringPiece original_name)885 string GraphConstructor::FindUniqueName(StringPiece original_name) {
886 string name(original_name);
887 int count = 0;
888 // Check that any generated names don't collide with imported NodeDefs (as
889 // well as nodes in g_).
890 while (NameExistsInGraph(name) || (count > 0 && NameExistsInGraphDef(name))) {
891 name = strings::StrCat(original_name, "_", ++count);
892 }
893 return name;
894 }
895
IsNodeFullyMapped(const NodeDef & node_def,bool * is_node_mapped)896 Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def,
897 bool* is_node_mapped) {
898 const OpDef* op_def;
899 TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
900 for (int i = 0; i < op_def->output_arg_size(); ++i) {
901 if (opts_.input_map.find({node_def.name(), i}) == opts_.input_map.end()) {
902 *is_node_mapped = false;
903 return Status::OK();
904 }
905 }
906 *is_node_mapped = true;
907 return Status::OK();
908 }
909
// Creates a Node in g_ for every NodeDef in node_defs_, visiting them in the
// order they become ready (ready_ / UpdatePendingCountAndReady()), and wires
// up the edges between them.
//
// When opts_.importing is set, each NodeDef is copied and the copy is
// rewritten (input remapping, extra control dependencies, default device,
// prefixing, name uniquification, ModifyNodeDefForImport()) before the Node is
// made. Inputs whose source Node has not been created yet are recorded in
// back_edges_ instead of being connected here.
//
// Returns the first error reported by a helper, or InvalidArgument if an input
// names an out-of-range output of its source node, if a non-Merge node has an
// input whose source has not been created yet, or if some NodeDefs were never
// processed (reported as a cycle).
Status GraphConstructor::Convert() {
  // Import functions before adding nodes, since imported nodes may refer to
  // functions
  if (library_) {
    TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(*library_));
  }

  // Scratch buffers reused across iterations of the loop below.
  std::vector<InputInfo> inputs;
  int processed = 0;

  std::vector<bool> input_already_exists;

  // Process the NodeDefs in topological order.
  // (InitFromEdges() sets this up by filling in ready_ with nodes that have no
  // inputs, pending_count_ with the number of inputs for each node and
  // outputs_ with the outputs of each node).
  while (!ready_.empty()) {
    // Pop the next ready NodeDef index.
    int o = *ready_.begin();
    ready_.erase(ready_.begin());
    ++processed;
    inputs.clear();
    // NOTE(review): despite the name, this is set for any input whose source
    // node has not been created yet; the slot is not checked where it is set.
    bool has_data_back_edge = false;

    const NodeDef& original_node_def = *node_defs_[o];
    NodeDef imported_node_def;
    const NodeDef* node_def;

    // input_already_exists[i] is true iff the i-th input of the node we're
    // importing refers to a preexisting node in g_ (i.e. input[i] existed prior
    // to importing node_defs_). Conversely, input_already_exists[i] is false
    // iff the input refers to a node in node_defs_.
    input_already_exists.clear();
    input_already_exists.resize(original_node_def.input_size(), false);

    if (opts_.importing) {
      if (opts_.skip_mapped_nodes) {
        bool is_node_mapped = false;
        TF_RETURN_IF_ERROR(
            IsNodeFullyMapped(original_node_def, &is_node_mapped));
        if (is_node_mapped) {
          // Skip this node after updating pending_count_ for outputs. No Node
          // is created, so its gdef_nodes_ entry is not assigned a node here.
          UpdatePendingCountAndReady(o);
          continue;
        }
      }

      // TODO(ashankar): The line below means an additional copy of the
      // NodeDef, which can be expensive if the NodeDef contains large tensors
      // in it. Might make sense to change the API for ImportGraphDef to take
      // a mutable GraphDef* and avoid the copying.
      imported_node_def = original_node_def;
      if (!opts_.input_map.empty()) {
        // Note that input_already_exists can shrink here
        RemapNodeDefInputs(&imported_node_def, &input_already_exists);
      }
      if (!opts_.control_dependencies.empty()) {
        // Note that input_already_exists can grow here
        AddControlDependencies(&imported_node_def, &input_already_exists);
      }
      if (!opts_.default_device.empty() && imported_node_def.device().empty()) {
        imported_node_def.set_device(opts_.default_device);
      }

      node_def = &imported_node_def;
    } else {
      node_def = &original_node_def;
    }

    DCHECK_EQ(node_def->input_size(), input_already_exists.size());
    TF_RETURN_IF_ERROR(ValidateColocationConstraints(*node_def));
    // Resolve every input to its source Node (or nullptr if that Node has not
    // been created yet) and validate the referenced output index.
    for (int i = 0; i < node_def->input_size(); ++i) {
      TensorId id(ParseTensorName(node_def->input(i)));
      Node* src_node;
      int src_index;

      if (!input_already_exists[i]) {
        // Locate input in newly-imported nodes
        auto iter = gdef_nodes_.find(id.first);
        DCHECK(iter != gdef_nodes_.end()) << id.first;
        src_node = iter->second.node;
        src_index = id.second;
        if (src_node == nullptr) has_data_back_edge = true;
      } else {
        // Input refers to preexisting node in graph
        auto iter = existing_nodes_.find(id.first);
        DCHECK(iter != existing_nodes_.end()) << id.first;
        src_node = iter->second;
        src_index = id.second;
      }

      if (src_node != nullptr && src_index >= src_node->num_outputs()) {
        return errors::InvalidArgument(
            "Node '", node_def->name(), "': Connecting to invalid output ",
            id.second, " of source node ", id.first, " which has ",
            src_node->num_outputs(), " outputs");
      }

      inputs.emplace_back(string(id.first), src_node, src_index);
    }

    if (has_data_back_edge && !IsMerge(*node_def)) {
      return errors::InvalidArgument(
          "Node '", node_def->name(),
          "' had a back edge, but only Merge nodes can have back edges.");
    }

    Node* node;
    if (opts_.importing) {
      // The renaming steps below run after the input lookup above, which used
      // the names as they appear before prefixing/uniquification.
      if (!prefix_.empty()) {
        AddPrefixToNodeDef(input_already_exists, &imported_node_def);
      }
      // Note: no need to uniquify names if the prefix already guarantees
      // uniqueness
      if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) {
        UniquifyNames(input_already_exists, &imported_node_def);
      }
      TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&imported_node_def));
    }
    TF_RETURN_IF_ERROR(MakeNode(*node_def, &node));
    // Use original_node_def so name StringPiece remains valid
    gdef_nodes_[original_node_def.name()].node = node;

    // Add edges from inputs to *node to the graph.
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (inputs[i].node == nullptr) {
        // Record this back edge, which will be added after all nodes
        // are created.
        back_edges_.emplace_back(inputs[i].name, inputs[i].index, node, i);
      } else if (inputs[i].index == Graph::kControlSlot) {
        g_->AddControlEdge(inputs[i].node, node);
      } else {
        TF_RETURN_IF_ERROR(MakeEdge(inputs[i].node, inputs[i].index, node, i));
      }
    }

    TF_RETURN_IF_ERROR(ValidateShape(node));

    // Update pending_count_ for outputs.
    UpdatePendingCountAndReady(o);
  }

  // Any NodeDef that never became ready is reported as part of a cycle. Log
  // each one that still has a non-zero pending count before returning.
  if (processed < node_defs_.size()) {
    LOG(WARNING) << "IN " << __func__ << " " << (node_defs_.size() - processed)
                 << " NODES IN A CYCLE";
    for (int64 i = 0; i < node_defs_.size(); i++) {
      if (pending_count_[i] != 0) {
        LOG(WARNING) << "PENDING: " << SummarizeNodeDef(*node_defs_[i])
                     << " WITH PENDING COUNT = " << pending_count_[i];
      }
    }
    return errors::InvalidArgument(node_defs_.size() - processed,
                                   " nodes in a cycle");
  }

  return Status::OK();
}
1066
AddBackEdges()1067 Status GraphConstructor::AddBackEdges() {
1068 // Add the back edges after all nodes are created.
1069 for (auto e : back_edges_) {
1070 Node* src_node = gdef_nodes_[e.src_name].node;
1071 if (e.src_index == Graph::kControlSlot) {
1072 g_->AddControlEdge(src_node, e.dst_node);
1073 } else {
1074 TF_RETURN_IF_ERROR(
1075 MakeEdge(src_node, e.src_index, e.dst_node, e.dst_index));
1076 }
1077
1078 VLOG(2) << "Add back edge: " << src_node->name() << " -> "
1079 << e.dst_node->name();
1080 }
1081 return Status::OK();
1082 }
1083
UpdateVersionDef()1084 Status GraphConstructor::UpdateVersionDef() {
1085 if (versions_ == nullptr) return Status::OK();
1086
1087 if (!opts_.importing) {
1088 g_->set_versions(*versions_);
1089 return Status::OK();
1090 }
1091 VersionDef versions = g_->versions();
1092 versions.set_producer(std::min(versions.producer(), versions_->producer()));
1093 versions.set_min_consumer(
1094 std::max(versions.min_consumer(), versions_->min_consumer()));
1095 if (versions_->bad_consumers_size() > 0) {
1096 std::set<int> bad(versions.bad_consumers().begin(),
1097 versions.bad_consumers().end());
1098 bad.insert(versions_->bad_consumers().begin(),
1099 versions_->bad_consumers().end());
1100 versions.clear_bad_consumers();
1101 for (int v : bad) {
1102 versions.add_bad_consumers(v);
1103 }
1104 }
1105 g_->set_versions(versions);
1106 return Status::OK();
1107 }
1108
PopulateReturnTensors()1109 Status GraphConstructor::PopulateReturnTensors() {
1110 if (opts_.return_tensors.empty()) return Status::OK();
1111 for (const TensorId& id : opts_.return_tensors) {
1112 auto iter = opts_.input_map.find(id);
1113 if (iter == opts_.input_map.end()) {
1114 // Locate id in imported nodes
1115 auto iter = gdef_nodes_.find(id.first);
1116 if (iter == gdef_nodes_.end()) {
1117 return errors::InvalidArgument("Requested return tensor '",
1118 id.ToString(),
1119 "' not found in graph def");
1120 }
1121 int num_outputs = iter->second.node->num_outputs();
1122 if ((id.second < 0 || id.second >= num_outputs) &&
1123 id.second != Graph::kControlSlot) {
1124 return errors::InvalidArgument("Invalid return output ", id.second,
1125 " of node '", id.first, "', which has ",
1126 num_outputs, " output(s)");
1127 }
1128 return_tensors_->push_back({iter->second.node, id.second});
1129 } else {
1130 // id was remapped to existing node
1131 TensorId remapped_id = iter->second;
1132 DCHECK_GT(existing_nodes_.count(remapped_id.first), 0);
1133 Node* node = existing_nodes_[remapped_id.first];
1134 return_tensors_->push_back({node, remapped_id.second});
1135 }
1136 }
1137 return Status::OK();
1138 }
1139
PopulateReturnNodes()1140 Status GraphConstructor::PopulateReturnNodes() {
1141 if (opts_.return_nodes.empty()) return Status::OK();
1142 for (StringPiece name : opts_.return_nodes) {
1143 auto iter = gdef_nodes_.find(name);
1144 if (iter == gdef_nodes_.end()) {
1145 return errors::InvalidArgument("Requested return node '", name,
1146 "' not found in graph def");
1147 }
1148 return_nodes_->push_back(iter->second.node);
1149 }
1150 return Status::OK();
1151 }
1152
PopulateMissingUnusedInputMapKeys()1153 Status GraphConstructor::PopulateMissingUnusedInputMapKeys() {
1154 if (missing_unused_input_map_keys_ == nullptr) return Status::OK();
1155 for (const auto& input_map_pair : opts_.input_map) {
1156 TensorId key = input_map_pair.first;
1157 if (used_input_map_keys_.count(key) > 0) continue;
1158
1159 auto pair = gdef_nodes_.find(key.first);
1160 if (pair == gdef_nodes_.end()) {
1161 // key's node doesn't exist in GraphDef
1162 missing_unused_input_map_keys_->push_back(key);
1163 continue;
1164 }
1165
1166 // Check that key's index is in bounds. Get the number of outputs from the
1167 // NodeDef, rather than the imported Node, since the Node may not exist if
1168 // opts_.skip_mapped_nodes is true.
1169 const NodeDef* node_def = node_defs_[pair->second.gdef_index];
1170 const OpDef* op_def;
1171 TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
1172 int num_outputs;
1173 TF_RETURN_IF_ERROR(NumOutputsForNode(*node_def, *op_def, &num_outputs));
1174 if (key.second >= num_outputs) {
1175 // key's index out of bounds
1176 missing_unused_input_map_keys_->push_back(key);
1177 }
1178 }
1179 return Status::OK();
1180 }
1181
Undo()1182 void GraphConstructor::Undo() {
1183 for (const auto& iter : gdef_nodes_) {
1184 if (iter.second.node != nullptr) {
1185 g_->RemoveNode(iter.second.node);
1186 }
1187 }
1188 g_->set_versions(original_versions_);
1189 }
1190
MakeEdge(Node * src,int output_index,Node * dst,int input_index)1191 Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst,
1192 int input_index) {
1193 DataType src_out = src->output_type(output_index);
1194 DataType dst_in = dst->input_type(input_index);
1195 if (!TypesCompatible(dst_in, src_out)) {
1196 return errors::InvalidArgument(
1197 "Input ", input_index, " of node ", dst->name(), " was passed ",
1198 DataTypeString(src_out), " from ", src->name(), ":", output_index,
1199 " incompatible with expected ", DataTypeString(dst_in), ".");
1200 }
1201 g_->AddEdge(src, output_index, dst, input_index);
1202 return Status::OK();
1203 }
1204
1205 } // namespace
1206
ConvertGraphDefToGraph(const GraphConstructorOptions & opts,const GraphDef & gdef,Graph * g)1207 Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
1208 const GraphDef& gdef, Graph* g) {
1209 ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
1210 return GraphConstructor::Construct(
1211 opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,
1212 /*return_tensors=*/nullptr, /*return_nodes=*/nullptr,
1213 /*missing_unused_input_map_keys=*/nullptr);
1214 }
1215
ConvertNodeDefsToGraph(const GraphConstructorOptions & opts,gtl::ArraySlice<NodeDef> nodes,Graph * g)1216 Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
1217 gtl::ArraySlice<NodeDef> nodes, Graph* g) {
1218 ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry());
1219 // TODO(irving): Copy will go away once NodeInfo exists
1220 std::vector<const NodeDef*> node_defs;
1221 for (const auto& n : nodes) {
1222 node_defs.push_back(&n);
1223 }
1224 return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g,
1225 &refiner, /*return_tensors=*/nullptr,
1226 /*return_nodes=*/nullptr,
1227 /*missing_unused_input_map_keys=*/nullptr);
1228 }
1229
ImportGraphDef(const ImportGraphDefOptions & opts,const GraphDef & gdef,Graph * g,ShapeRefiner * refiner,ImportGraphDefResults * results)1230 Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
1231 Graph* g, ShapeRefiner* refiner,
1232 ImportGraphDefResults* results) {
1233 if (!opts.return_tensors.empty()) {
1234 if (results == nullptr) {
1235 return errors::InvalidArgument(
1236 "results argument to ImportGraphDef() must be non-null if "
1237 "opts.return_tensors is non-empty");
1238 }
1239 }
1240
1241 if (!opts.return_nodes.empty()) {
1242 if (opts.skip_mapped_nodes) {
1243 return errors::InvalidArgument(
1244 "Requesting return_nodes with skip_mapped_nodes set is not currently "
1245 "supported");
1246 }
1247 if (results == nullptr) {
1248 return errors::InvalidArgument(
1249 "results argument to ImportGraphDef() must be non-null if "
1250 "opts.return_nodes is non-empty");
1251 }
1252 }
1253
1254 if (results != nullptr) {
1255 if (!results->return_tensors.empty() || !results->return_nodes.empty() ||
1256 !results->missing_unused_input_map_keys.empty()) {
1257 return errors::InvalidArgument(
1258 "All fields in results argument to ImportGraphDef() must be empty.");
1259 }
1260 }
1261
1262 ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry());
1263 if (refiner == nullptr) {
1264 refiner = &default_refiner;
1265 } else {
1266 // Log a warning if we are importing a GraphDef at an older
1267 // producer version after already having added non-source/sink
1268 // nodes to the graph in the past.
1269 if (gdef.versions().producer() > 0 &&
1270 gdef.versions().producer() < refiner->graph_def_version() &&
1271 g->num_nodes() > 2) {
1272 LOG(WARNING) << "Importing a graph with a lower producer version "
1273 << gdef.versions().producer()
1274 << " into an existing graph with producer version "
1275 << refiner->graph_def_version() << ". Shape inference will "
1276 << "have run different parts of the graph with different "
1277 << "producer versions.";
1278 }
1279 }
1280
1281 // Set the graph def version of the refiner as the min of the
1282 // current value and the version from the graph we are about to
1283 // import.
1284 //
1285 // Note: to match Run() semantics, we should re-run shape inference
1286 // on the entire graph if the producer version has changed. For now
1287 // we log the warning above.
1288 refiner->set_graph_def_version(
1289 std::min(refiner->graph_def_version(), gdef.versions().producer()));
1290
1291 if (results == nullptr) {
1292 return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
1293 &gdef.library(), g, refiner, nullptr,
1294 nullptr, nullptr);
1295 } else {
1296 return GraphConstructor::Construct(
1297 opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner,
1298 &results->return_tensors, &results->return_nodes,
1299 &results->missing_unused_input_map_keys);
1300 }
1301 }
1302
CopyGraph(const Graph & src,Graph * dest)1303 void CopyGraph(const Graph& src, Graph* dest) {
1304 for (Node* n : dest->nodes()) {
1305 CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
1306 }
1307
1308 // Copy GraphDef versions
1309 dest->set_versions(src.versions());
1310
1311 // Copy the nodes
1312 std::unordered_map<const Node*, Node*>
1313 node_map; // "Node in src" -> "Node in *dest"
1314 node_map[src.source_node()] = dest->source_node();
1315 node_map[src.sink_node()] = dest->sink_node();
1316 for (Node* n : src.op_nodes()) {
1317 node_map[n] = dest->CopyNode(n);
1318 }
1319
1320 // Copy the edges
1321 for (const Edge* e : src.edges()) {
1322 Node* src_copy = node_map[e->src()];
1323 Node* dst_copy = node_map[e->dst()];
1324 dest->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
1325 }
1326 }
1327
1328 } // namespace tensorflow
1329