1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/graph.h"
17 
18 #include <vector>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "tensorflow/core/framework/graph.pb.h"
22 #include "tensorflow/core/framework/node_def.pb.h"
23 #include "tensorflow/core/framework/node_properties.h"
24 #include "tensorflow/core/framework/op_def_builder.h"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/versions.pb.h"
27 #include "tensorflow/core/graph/graph_node_util.h"
28 #include "tensorflow/core/graph/while_context.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/gtl/map_util.h"
31 #include "tensorflow/core/lib/hash/hash.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/lib/strings/stringprintf.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/public/version.h"
36 
37 namespace tensorflow {
38 
39 const int Graph::kControlSlot = -1;
40 
41 // Node
GetNodeClassForOp(const string & ts)42 Node::NodeClass Node::GetNodeClassForOp(const string& ts) {
43   static const absl::flat_hash_map<string, Node::NodeClass>* kNodeClassTable =
44 #define REF_CLASS(key, value) \
45   {key, value}, { "Ref" key, value }
46       new absl::flat_hash_map<string, Node::NodeClass>({
47           // Keep in same order as NodeClass values
48           REF_CLASS("Switch", NC_SWITCH),
49           REF_CLASS("_SwitchN", NC_SWITCH),
50           REF_CLASS("Merge", NC_MERGE),
51           REF_CLASS("Enter", NC_ENTER),
52           REF_CLASS("Exit", NC_EXIT),
53           REF_CLASS("NextIteration", NC_NEXT_ITERATION),
54           {"LoopCond", NC_LOOP_COND},
55           {"ControlTrigger", NC_CONTROL_TRIGGER},
56           {"_Send", NC_SEND},
57           {"_HostSend", NC_HOST_SEND},
58           {"_Recv", NC_RECV},
59           {"_HostRecv", NC_HOST_RECV},
60           {"Const", NC_CONSTANT},
61           {"HostConst", NC_CONSTANT},
62           {"Variable", NC_VARIABLE},
63           {"VariableV2", NC_VARIABLE},
64           REF_CLASS("Identity", NC_IDENTITY),
65           {"GetSessionHandle", NC_GET_SESSION_HANDLE},
66           {"GetSessionHandleV2", NC_GET_SESSION_HANDLE},
67           {"GetSessionTensor", NC_GET_SESSION_TENSOR},
68           {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR},
69           {"Size", NC_METADATA},
70           {"Shape", NC_METADATA},
71           {"Rank", NC_METADATA},
72           {"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
73           {"CollectiveReduce", NC_COLLECTIVE},
74           {"CollectiveBcastSend", NC_COLLECTIVE},
75           {"CollectiveBcastRecv", NC_COLLECTIVE},
76           {"CollectiveGather", NC_COLLECTIVE},
77           {"FakeParam", NC_FAKE_PARAM},
78           {"PartitionedCall", NC_PARTITIONED_CALL},
79           {"StatefulPartitionedCall", NC_PARTITIONED_CALL},
80           {"SymbolicGradient", NC_SYMBOLIC_GRADIENT},
81           {"If", NC_IF},
82           {"StatelessIf", NC_IF},
83           {"While", NC_WHILE},
84           {"StatelessWhile", NC_WHILE},
85           {"Case", NC_CASE},
86           {"StatelessCase", NC_CASE},
87           // Not using the constants defined in FunctionLibraryDefinition
88           // for the
89           // 4 ops below because android inference library does not link
90           // tf.function related files.
91           {"_Arg", NC_ARG},
92           {"_DeviceArg", NC_ARG},
93           {"_Retval", NC_RETVAL},
94           {"_DeviceRetval", NC_RETVAL},
95           {"_XlaMerge", NC_MERGE},
96       });
97 #undef REF_CLASS
98 
99   auto it = kNodeClassTable->find(ts);
100   if (it != kNodeClassTable->end()) {
101     return it->second;
102   } else {
103     return NC_OTHER;
104   }
105 }
106 
DebugString() const107 string Node::DebugString() const {
108   string ret = strings::StrCat("{name:'", name(), "' id:", id_);
109   if (IsSource()) {
110     strings::StrAppend(&ret, " source}");
111   } else if (IsSink()) {
112     strings::StrAppend(&ret, " sink}");
113   } else {
114     strings::StrAppend(&ret, " op device:");
115     strings::StrAppend(&ret, "{", assigned_device_name(), "}");
116     strings::StrAppend(&ret, " def:{", SummarizeNode(*this), "}}");
117   }
118   return ret;
119 }
120 
Node()121 Node::Node()
122     : id_(-1),
123       cost_id_(-1),
124       class_(NC_UNINITIALIZED),
125       props_(nullptr),
126       assigned_device_name_index_(0),
127       while_ctx_(nullptr) {}
128 
Initialize(int id,int cost_id,std::shared_ptr<NodeProperties> props,Node::NodeClass node_class)129 void Node::Initialize(int id, int cost_id,
130                       std::shared_ptr<NodeProperties> props,
131                       Node::NodeClass node_class) {
132   DCHECK_EQ(id_, -1);
133   DCHECK(in_edges_.empty());
134   DCHECK(out_edges_.empty());
135   id_ = id;
136   cost_id_ = cost_id;
137 
138   props_ = std::move(props);
139   class_ = node_class;
140 }
141 
Clear()142 void Node::Clear() {
143   in_edges_.clear();
144   out_edges_.clear();
145   id_ = -1;
146   cost_id_ = -1;
147   class_ = NC_UNINITIALIZED;
148   props_.reset();
149   assigned_device_name_index_ = 0;
150 }
151 
UpdateProperties()152 void Node::UpdateProperties() {
153   DataTypeVector inputs;
154   DataTypeVector outputs;
155   Status status =
156       InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
157   if (!status.ok()) {
158     LOG(ERROR) << "Failed at updating node: " << status;
159     return;
160   }
161   if (props_->input_types != inputs || props_->output_types != outputs) {
162     if (TF_PREDICT_TRUE(props_.use_count() == 1)) {
163       props_->input_types = inputs;
164       props_->input_types_slice = props_->input_types;
165       props_->output_types = outputs;
166       props_->output_types_slice = props_->output_types;
167     } else {
168       props_ = std::make_shared<NodeProperties>(
169           props_->op_def, std::move(props_->node_def), inputs, outputs);
170     }
171   }
172 }
173 
name() const174 const string& Node::name() const { return props_->node_def.name(); }
type_string() const175 const string& Node::type_string() const { return props_->node_def.op(); }
def() const176 const NodeDef& Node::def() const { return props_->node_def; }
op_def() const177 const OpDef& Node::op_def() const { return *props_->op_def; }
178 
num_inputs() const179 int32 Node::num_inputs() const { return props_->input_types.size(); }
input_type(int32 i) const180 DataType Node::input_type(int32 i) const { return props_->input_types[i]; }
input_types() const181 const DataTypeVector& Node::input_types() const { return props_->input_types; }
182 
num_outputs() const183 int32 Node::num_outputs() const { return props_->output_types.size(); }
output_type(int32 o) const184 DataType Node::output_type(int32 o) const { return props_->output_types[o]; }
output_types() const185 const DataTypeVector& Node::output_types() const {
186   return props_->output_types;
187 }
188 
attrs() const189 AttrSlice Node::attrs() const { return AttrSlice(def()); }
190 
requested_inputs() const191 const protobuf::RepeatedPtrField<string>& Node::requested_inputs() const {
192   return def().input();
193 }
194 
requested_device() const195 const string& Node::requested_device() const { return def().device(); }
196 
out_nodes() const197 gtl::iterator_range<NeighborIter> Node::out_nodes() const {
198   return gtl::make_range(NeighborIter(out_edges_.begin(), false),
199                          NeighborIter(out_edges_.end(), false));
200 }
201 
in_nodes() const202 gtl::iterator_range<NeighborIter> Node::in_nodes() const {
203   return gtl::make_range(NeighborIter(in_edges_.begin(), true),
204                          NeighborIter(in_edges_.end(), true));
205 }
206 
MaybeCopyOnWrite()207 void Node::MaybeCopyOnWrite() {
208   // NodeProperties may be shared between Nodes. Make a copy if so.
209   if (!props_.unique()) {
210     props_ = std::make_shared<NodeProperties>(*props_);
211   }
212 }
213 
AddAttrHelper(const string & name)214 AttrValue* Node::AddAttrHelper(const string& name) {
215   MaybeCopyOnWrite();
216   return &((*props_->node_def.mutable_attr())[name]);
217 }
218 
ClearAttr(const string & name)219 void Node::ClearAttr(const string& name) {
220   MaybeCopyOnWrite();
221   (*props_->node_def.mutable_attr()).erase(name);
222 }
223 
set_name(string name)224 void Node::set_name(string name) {
225   MaybeCopyOnWrite();
226   props_->node_def.set_name(std::move(name));
227 }
228 
set_requested_device(const string & device)229 void Node::set_requested_device(const string& device) {
230   MaybeCopyOnWrite();
231   props_->node_def.set_device(device);
232 }
233 
set_original_node_names(const std::vector<string> & names)234 void Node::set_original_node_names(const std::vector<string>& names) {
235   MaybeCopyOnWrite();
236   props_->node_def.mutable_experimental_debug_info()
237       ->clear_original_node_names();
238   if (!names.empty()) {
239     *props_->node_def.mutable_experimental_debug_info()
240          ->mutable_original_node_names() = {names.begin(), names.end()};
241   }
242 }
243 
input_edge(int idx,const Edge ** e) const244 Status Node::input_edge(int idx, const Edge** e) const {
245   if (idx < 0 || idx >= num_inputs()) {
246     return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ",
247                                    name(), " only has ", num_inputs(),
248                                    " inputs.");
249   }
250 
251   // This does a linear search over the edges.  In the common case,
252   // the number of elements is small enough that this search isn't
253   // expensive.  Should it become a bottleneck, one can make an
254   // optimization where, if the number of edges is small, we use
255   // linear iteration, and if the number of edges is large, we perform
256   // an indexing step during construction that keeps an array of Edges
257   // indexed by pointer.  This would keep the size of each Node small
258   // in the common case but make this function faster when the number
259   // of edges is large.
260   for (const Edge* edge : in_edges()) {
261     if (edge->dst_input() == idx) {
262       *e = edge;
263       return Status::OK();
264     }
265   }
266 
267   return errors::NotFound("Could not find input edge ", idx, " for ", name());
268 }
269 
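// Fills *input_edges with this node's non-control input edges, indexed by
// input slot (Edge::dst_input()). Fails if a slot is out of range, is fed by
// more than one edge, or has no edge at all.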
input_edges(std::vector<const Edge * > * input_edges) const271 Status Node::input_edges(std::vector<const Edge*>* input_edges) const {
272   input_edges->clear();
273   input_edges->resize(num_inputs(), nullptr);
274 
275   for (const Edge* edge : in_edges()) {
276     if (edge->IsControlEdge()) continue;
277     if (edge->dst_input() < 0 || edge->dst_input() >= num_inputs()) {
278       return errors::Internal("Invalid edge input number ", edge->dst_input());
279     }
280     if ((*input_edges)[edge->dst_input()] != nullptr) {
281       return errors::Internal("Duplicate edge input number: ",
282                               edge->dst_input());
283     }
284     (*input_edges)[edge->dst_input()] = edge;
285   }
286 
287   for (int i = 0; i < num_inputs(); ++i) {
288     if ((*input_edges)[i] == nullptr) {
289       return errors::InvalidArgument("Missing edge input number: ", i);
290     }
291   }
292   return Status::OK();
293 }
294 
input_node(int idx,Node ** n) const295 Status Node::input_node(int idx, Node** n) const {
296   const Edge* e;
297   TF_RETURN_IF_ERROR(input_edge(idx, &e));
298   if (e == nullptr) {
299     *n = nullptr;
300   } else {
301     *n = e->src();
302   }
303   return Status::OK();
304 }
305 
input_node(int idx,const Node ** const_n) const306 Status Node::input_node(int idx, const Node** const_n) const {
307   Node* n;
308   TF_RETURN_IF_ERROR(input_node(idx, &n));
309   *const_n = n;
310   return Status::OK();
311 }
312 
input_tensor(int idx,OutputTensor * t) const313 Status Node::input_tensor(int idx, OutputTensor* t) const {
314   const Edge* e;
315   TF_RETURN_IF_ERROR(input_edge(idx, &e));
316   DCHECK(e != nullptr);
317   *t = OutputTensor(e->src(), e->src_output());
318   return Status::OK();
319 }
320 
321 // NodeDebugInfo
322 
NodeDebugInfo(const Node & n)323 NodeDebugInfo::NodeDebugInfo(const Node& n) : NodeDebugInfo(n.def()) {}
NodeDebugInfo(const NodeDef & ndef)324 NodeDebugInfo::NodeDebugInfo(const NodeDef& ndef)
325     : NodeDebugInfo(ndef.name(), ndef.has_experimental_debug_info(),
326                     ndef.experimental_debug_info()) {}
NodeDebugInfo(StringPiece node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info)327 NodeDebugInfo::NodeDebugInfo(
328     StringPiece node_name, bool has_experimental_debug_info,
329     const NodeDef_ExperimentalDebugInfo& experimental_debug_info)
330     : name(node_name) {
331   if (has_experimental_debug_info) {
332     const auto& names = experimental_debug_info.original_node_names();
333     original_node_names.assign(names.begin(), names.end());
334   }
335 }
336 
337 // InputTensor
338 
operator ==(const InputTensor & other) const339 bool InputTensor::operator==(const InputTensor& other) const {
340   return node == other.node && index == other.index;
341 }
342 
operator ()(InputTensor const & s) const343 uint64 InputTensor::Hash::operator()(InputTensor const& s) const {
344   return Hash64Combine(std::hash<const Node*>()(s.node),
345                        std::hash<int>()(s.index));
346 }
347 
348 // OutputTensor
349 
operator ==(const OutputTensor & other) const350 bool OutputTensor::operator==(const OutputTensor& other) const {
351   return node == other.node && index == other.index;
352 }
353 
operator ()(OutputTensor const & s) const354 uint64 OutputTensor::Hash::operator()(OutputTensor const& s) const {
355   return Hash64Combine(std::hash<const Node*>()(s.node),
356                        std::hash<int>()(s.index));
357 }
358 
359 // Graph
360 
Graph(const OpRegistryInterface * ops)361 Graph::Graph(const OpRegistryInterface* ops)
362     : ops_(ops, FunctionDefLibrary()),
363       versions_(new VersionDef),
364       arena_(8 << 10 /* 8kB */) {
365   versions_->set_producer(TF_GRAPH_DEF_VERSION);
366   versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
367 
368   // Initialize the name interning table for assigned_device_name.
369   device_names_.push_back("");
370   DCHECK_EQ(0, InternDeviceName(""));
371 
372   // Source and sink have no endpoints, just control edges.
373   NodeDef def;
374   def.set_name("_SOURCE");
375   def.set_op("NoOp");
376   Status status;
377   Node* source = AddNode(def, &status);
378   TF_CHECK_OK(status);
379   CHECK_EQ(source->id(), kSourceId);
380 
381   def.set_name("_SINK");
382   Node* sink = AddNode(def, &status);
383   TF_CHECK_OK(status);
384   CHECK_EQ(sink->id(), kSinkId);
385 
386   AddControlEdge(source, sink);
387 }
388 
Graph(const FunctionLibraryDefinition & flib_def)389 Graph::Graph(const FunctionLibraryDefinition& flib_def)
390     : Graph(flib_def.default_registry()) {
391   // Need a new-enough consumer to support the functions we add to the graph.
392   if (flib_def.num_functions() > 0 && versions_->min_consumer() < 12) {
393     versions_->set_min_consumer(12);
394   }
395   Status s = ops_.AddLibrary(flib_def);
396   CHECK(s.ok()) << s.error_message();
397 }
398 
~Graph()399 Graph::~Graph() {
400   // Manually call the destructors for all the Nodes we constructed using
401   // placement new.
402   for (Node* node : nodes_) {
403     if (node != nullptr) {
404       node->~Node();
405     }
406   }
407   for (Node* node : free_nodes_) {
408     node->~Node();
409   }
410   // Edges have no destructor, and we arena-allocated them, so no need to
411   // destroy them.
412 }
413 
versions() const414 const VersionDef& Graph::versions() const { return *versions_; }
set_versions(const VersionDef & versions)415 void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }
416 
Copy(const Graph & src)417 void Graph::Copy(const Graph& src) {
418   SetConstructionContext(src.GetConstructionContextInternal());
419   for (Node* n : nodes()) {
420     CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
421   }
422 
423   // Copy GraphDef versions
424   set_versions(src.versions());
425 
426   // Copy the nodes.
427   // "Node in src" -> "Node in *dest"
428   gtl::FlatMap<const Node*, Node*> node_map;
429   node_map.reserve(src.num_nodes());
430   node_map[src.source_node()] = source_node();
431   node_map[src.sink_node()] = sink_node();
432   for (Node* n : src.op_nodes()) {
433     auto copy = CopyNode(n);
434     copy->in_edges_.reserve(n->in_edges().size());
435     copy->out_edges_.reserve(n->out_edges().size());
436     node_map[n] = copy;
437   }
438 
439   // Copy the edges
440   edges_.reserve(src.num_edges());
441   for (const Edge* e : src.edges()) {
442     Node* src_copy = node_map[e->src()];
443     Node* dst_copy = node_map[e->dst()];
444     AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
445   }
446 }
447 
AddNode(NodeDef node_def,Status * status)448 Node* Graph::AddNode(NodeDef node_def, Status* status) {
449   const OpRegistrationData* op_reg_data;
450   status->Update(ops_.LookUp(node_def.op(), &op_reg_data));
451   if (!status->ok()) return nullptr;
452 
453   DataTypeVector inputs;
454   DataTypeVector outputs;
455   status->Update(
456       InOutTypesForNode(node_def, op_reg_data->op_def, &inputs, &outputs));
457   if (!status->ok()) {
458     *status = AttachDef(*status, node_def);
459     return nullptr;
460   }
461 
462   Node::NodeClass node_class = op_reg_data->is_function_op
463                                    ? Node::NC_FUNCTION_OP
464                                    : Node::GetNodeClassForOp(node_def.op());
465 
466   Node* node = AllocateNode(
467       std::make_shared<NodeProperties>(&op_reg_data->op_def,
468                                        std::move(node_def), inputs, outputs),
469       nullptr, node_class);
470   return node;
471 }
472 
// Adds a copy of op node `node` (possibly from another graph). The copy shares
// `node`'s NodeProperties, and `node` supplies the copy's cost id.
Node* Graph::CopyNode(const Node* node) {
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());
  Node* copy = AllocateNode(node->props_, node, node->class_);
  copy->set_assigned_device_name(node->assigned_device_name());

  // A function's OpDef may be owned by the graph that owns `node`, so look
  // the OpDef up again in this graph. If it differs, give the copy its own
  // properties pointing at this graph's OpDef.
  const OpDef* op_def;
  TF_CHECK_OK(ops_.LookUpOpDef(node->type_string(), &op_def));
  const bool op_def_differs = op_def != node->props_->op_def;
  if (op_def_differs) {
    copy->MaybeCopyOnWrite();
    copy->props_->op_def = op_def;
  }

  copy->SetStackTrace(node->GetStackTrace());
  return copy;
}
492 
RemoveNode(Node * node)493 void Graph::RemoveNode(Node* node) {
494   TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
495   DCHECK(!node->IsSource());
496   DCHECK(!node->IsSink());
497 
498   // Remove any edges involving this node.
499   for (const Edge* e : node->in_edges_) {
500     CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
501     edges_[e->id_] = nullptr;
502     RecycleEdge(e);
503     --num_edges_;
504   }
505   node->in_edges_.clear();
506   for (const Edge* e : node->out_edges_) {
507     CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
508     edges_[e->id_] = nullptr;
509     RecycleEdge(e);
510     --num_edges_;
511   }
512   node->out_edges_.clear();
513   ReleaseNode(node);
514 }
515 
AddEdge(Node * source,int x,Node * dest,int y)516 const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
517   TF_DCHECK_OK(IsValidNode(source)) << source->DebugString();
518   TF_DCHECK_OK(IsValidNode(dest)) << dest->DebugString();
519 
520   // source/sink must only be linked via control slots, and
521   // control slots must only be linked to control slots.
522   if (source == source_node() || dest == sink_node() || x == kControlSlot ||
523       y == kControlSlot) {
524     DCHECK_EQ(x, kControlSlot) << source->DebugString();
525     DCHECK_EQ(y, kControlSlot) << dest->DebugString();
526   }
527 
528   Edge* e = nullptr;
529   if (free_edges_.empty()) {
530     e = new (arena_.Alloc(sizeof(Edge))) Edge;  // placement new
531   } else {
532     e = free_edges_.back();
533     free_edges_.pop_back();
534   }
535   e->id_ = edges_.size();
536   e->src_ = source;
537   e->dst_ = dest;
538   e->src_output_ = x;
539   e->dst_input_ = y;
540   CHECK(source->out_edges_.insert(e).second);
541   CHECK(dest->in_edges_.insert(e).second);
542   edges_.push_back(e);
543   ++num_edges_;
544   return e;
545 }
546 
RemoveEdge(const Edge * e)547 void Graph::RemoveEdge(const Edge* e) {
548   TF_DCHECK_OK(IsValidNode(e->src_)) << e->src_->DebugString();
549   TF_DCHECK_OK(IsValidNode(e->dst_)) << e->dst_->DebugString();
550   CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
551   CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
552   CHECK_EQ(e, edges_[e->id_]);
553   CHECK_GT(num_edges_, 0);
554 
555   edges_[e->id_] = nullptr;
556   RecycleEdge(e);
557   --num_edges_;
558 }
559 
RecycleEdge(const Edge * e)560 void Graph::RecycleEdge(const Edge* e) {
561   free_edges_.push_back(const_cast<Edge*>(e));
562 }
563 
AddControlEdge(Node * source,Node * dest,bool allow_duplicates)564 const Edge* Graph::AddControlEdge(Node* source, Node* dest,
565                                   bool allow_duplicates) {
566   if (!allow_duplicates) {
567     for (const Edge* edge : dest->in_edges()) {
568       if (edge->IsControlEdge() && edge->src() == source) {
569         // The requested edge already exists.
570         return nullptr;
571       }
572     }
573   }
574   // Modify dest's NodeDef if necessary.
575   if (!source->IsSource() && !dest->IsSink() && !allow_duplicates) {
576     // Check if this input is already in dest's NodeDef.
577     const string new_input = strings::StrCat("^", source->name());
578     bool input_exists = false;
579     for (const string& input : dest->props_->node_def.input()) {
580       if (input == new_input) {
581         input_exists = true;
582         break;
583       }
584     }
585     if (!input_exists) {
586       dest->MaybeCopyOnWrite();
587       dest->props_->node_def.add_input(new_input);
588     }
589   }
590   return AddEdge(source, kControlSlot, dest, kControlSlot);
591 }
592 
RemoveControlEdge(const Edge * e)593 void Graph::RemoveControlEdge(const Edge* e) {
594   if (!e->src_->IsSource() && !e->dst_->IsSink()) {
595     e->dst_->MaybeCopyOnWrite();
596     string e_src_name = strings::StrCat("^", e->src_->name());
597     auto* inputs = e->dst_->props_->node_def.mutable_input();
598     for (auto it = inputs->begin(); it != inputs->end(); ++it) {
599       if (*it == e_src_name) {
600         inputs->erase(it);
601         break;
602       }
603     }
604   }
605   RemoveEdge(e);
606 }
607 
608 namespace {
FindEdge(const Node * dst,int index)609 const Edge* FindEdge(const Node* dst, int index) {
610   for (const Edge* e : dst->in_edges()) {
611     if (e->dst_input() == index) return e;
612   }
613   return nullptr;
614 }
615 }  // namespace
616 
UpdateEdge(Node * new_src,int new_src_index,Node * dst,int dst_index)617 Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
618                          int dst_index) {
619   TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
620   TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
621   const Edge* e = FindEdge(dst, dst_index);
622   if (e == nullptr) {
623     return errors::InvalidArgument("Couldn't find edge to ",
624                                    FormatNodeForError(*dst));
625   }
626   RemoveEdge(e);
627   AddEdge(new_src, new_src_index, dst, dst_index);
628   dst->MaybeCopyOnWrite();
629   (*dst->props_->node_def.mutable_input())[dst_index] =
630       strings::StrCat(new_src->name(), ":", new_src_index);
631   return Status::OK();
632 }
633 
AddWhileInputHack(Node * new_src,int new_src_index,Node * dst)634 Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) {
635   if (!dst->IsWhileNode()) {
636     return errors::Internal(
637         "dst argument to AddWhileEdgeHack should be a While op, got: ",
638         dst->DebugString());
639   }
640   TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
641   // Find the current number of data inputs. We'll add the new edge to the next
642   // missing data input.
643   int dst_index = 0;
644   for (const Edge* edge : dst->in_edges()) {
645     if (edge->IsControlEdge()) continue;
646     ++dst_index;
647   }
648   TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
649   AddEdge(new_src, new_src_index, dst, dst_index);
650   dst->MaybeCopyOnWrite();
651   dst->props_->node_def.add_input(
652       strings::StrCat(new_src->name(), ":", new_src_index));
653   return Status::OK();
654 }
655 
AddFunctionLibrary(const FunctionDefLibrary & fdef_lib)656 Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
657   // Need a new-enough consumer to support the functions we add to the graph.
658   if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) {
659     versions_->set_min_consumer(12);
660   }
661   return ops_.AddLibrary(fdef_lib);
662 }
663 
664 namespace {
665 
AddInput(NodeDef * dst,StringPiece src_name,int src_slot)666 void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
667   if (src_slot == Graph::kControlSlot) {
668     dst->add_input(strings::StrCat("^", src_name));
669   } else if (src_slot == 0) {
670     dst->add_input(src_name.data(), src_name.size());
671   } else {
672     dst->add_input(strings::StrCat(src_name, ":", src_slot));
673   }
674 }
675 
676 }  // namespace
677 
ToGraphDef(GraphDef * graph_def) const678 void Graph::ToGraphDef(GraphDef* graph_def) const {
679   ToGraphDefSubRange(graph_def, 0);
680 }
681 
ToGraphDefDebug() const682 GraphDef Graph::ToGraphDefDebug() const {
683   GraphDef ret;
684   ToGraphDef(&ret);
685   return ret;
686 }
687 
ToGraphDefSubRange(GraphDef * graph_def,int from_node_id) const688 void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
689   graph_def->Clear();
690   *graph_def->mutable_versions() = versions();
691   *graph_def->mutable_library() = ops_.ToProto();
692 
693   graph_def->mutable_node()->Reserve(std::max(1, num_nodes() - from_node_id));
694 
695   std::vector<const Edge*>
696       inputs;  // Construct this outside the loop for speed.
697   for (auto id = from_node_id; id < num_node_ids(); ++id) {
698     const Node* node = FindNodeId(id);
699     if (node == nullptr || !node->IsOp()) continue;
700     NodeDef* node_def = graph_def->add_node();
701     *node_def = node->def();
702 
703     // Use the node's assigned device, if any, instead of the device requested
704     // in the NodeDef.
705     if (!node->assigned_device_name().empty()) {
706       node_def->set_device(node->assigned_device_name());
707     }
708 
709     // Get the inputs for this Node.  We make sure control inputs are
710     // after data inputs, as required by GraphDef.
711     inputs.clear();
712     inputs.resize(node->num_inputs(), nullptr);
713     for (const Edge* edge : node->in_edges()) {
714       if (edge->IsControlEdge()) {
715         inputs.push_back(edge);
716       } else {
717         DCHECK(edge->dst_input() < inputs.size())
718             << "Edge " << edge->DebugString()
719             << " is overflowing the expected number of inputs ("
720             << node->num_inputs() << ") for node " << node->DebugString();
721         CHECK(inputs[edge->dst_input()] == nullptr)
722             << "Edge " << edge->src()->name() << "->" << edge->dst()->name()
723             << " conflicts with pre-existing input edge "
724             << inputs[edge->dst_input()]->src()->name() << "->"
725             << inputs[edge->dst_input()]->dst()->name();
726 
727         inputs[edge->dst_input()] = edge;
728       }
729     }
730     // Sort the control inputs for more predictable serialization.
731     std::sort(inputs.begin() + node->num_inputs(), inputs.end(),
732               [](const Edge* a, const Edge* b) -> bool {
733                 return a->src()->name() < b->src()->name();
734               });
735     node_def->clear_input();
736     node_def->mutable_input()->Reserve(inputs.size());
737 
738     for (size_t i = 0; i < inputs.size(); ++i) {
739       const Edge* edge = inputs[i];
740       if (edge == nullptr) {
741         if (i < node->requested_inputs().size()) {
742           node_def->add_input(node->requested_inputs()[i]);
743         } else {
744           node_def->add_input("");
745         }
746       } else {
747         const Node* src = edge->src();
748         if (!src->IsOp()) continue;
749         AddInput(node_def, src->name(), edge->src_output());
750       }
751     }
752   }
753 }
754 
NewName(StringPiece prefix)755 string Graph::NewName(StringPiece prefix) {
756   return strings::StrCat(prefix, "/_", name_counter_++);
757 }
758 
IsValidNode(const Node * node) const759 Status Graph::IsValidNode(const Node* node) const {
760   if (node == nullptr) {
761     return errors::InvalidArgument("Node is null");
762   }
763   const int id = node->id();
764   if (id < 0) {
765     return errors::InvalidArgument("node id ", id, " is less than zero");
766   }
767   if (static_cast<size_t>(id) >= nodes_.size()) {
768     return errors::InvalidArgument(
769         "node id ", id, " is >= than number of nodes in graph ", nodes_.size());
770   }
771   if (nodes_[id] != node) {
772     return errors::InvalidArgument("Node with id ", id,
773                                    " is different from the passed in node. "
774                                    "Does it belong to a different graph?");
775   }
776   return Status::OK();
777 }
778 
IsValidOutputTensor(const Node * node,int idx) const779 Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
780   TF_RETURN_IF_ERROR(IsValidNode(node));
781   if (idx >= node->num_outputs() || idx < 0) {
782     return errors::OutOfRange("Node '", node->name(), "' (type: '",
783                               node->op_def().name(),
784                               "', num of outputs: ", node->num_outputs(),
785                               ") does not have ", "output ", idx);
786   }
787   return Status::OK();
788 }
789 
IsValidInputTensor(const Node * node,int idx) const790 Status Graph::IsValidInputTensor(const Node* node, int idx) const {
791   TF_RETURN_IF_ERROR(IsValidNode(node));
792   if (idx >= node->num_inputs() || idx < 0) {
793     return errors::OutOfRange("Node '", node->name(), "' (type: '",
794                               node->op_def().name(),
795                               "', num of inputs: ", node->num_inputs(),
796                               ") does not have ", "input ", idx);
797   }
798   return Status::OK();
799 }
800 
// Obtains a Node (recycled from free_nodes_ or placement-new'd in arena_),
// assigns it the next id, and initializes it. The cost id is taken from
// `cost_node` when given, otherwise it equals the new id.
Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
                          const Node* cost_node, Node::NodeClass node_class) {
  Node* node;
  if (!free_nodes_.empty()) {
    node = free_nodes_.back();
    free_nodes_.pop_back();
  } else {
    node = new (arena_.Alloc(sizeof(Node))) Node;  // placement new
  }
  node->graph_ = this;

  const int id = nodes_.size();
  const int cost_id = (cost_node == nullptr) ? id : cost_node->cost_id();
  node->Initialize(id, cost_id, std::move(props), node_class);
  nodes_.push_back(node);
  ++num_nodes_;
  return node;
}
818 
ReleaseNode(Node * node)819 void Graph::ReleaseNode(Node* node) {
820   TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
821   nodes_[node->id()] = nullptr;
822   free_nodes_.push_back(node);
823   --num_nodes_;
824   node->Clear();
825 }
826 
827 // Ensures that 'device_name' is present in the device name table, and returns
828 // the index of that device name. The index is stable, and can be used in
829 // calls to Node::set_assigned_device_name_index().
InternDeviceName(const string & device_name)830 int Graph::InternDeviceName(const string& device_name) {
831   // Special case, very common.  Also, this allows us to use a single map
832   // lookup below, instead of two.  The 'if (index_cell > 0)' test below
833   // relies on this check.
834   if (device_name.empty()) {
835     return 0;
836   }
837 
838   int& index_cell = device_names_map_[device_name];
839   if (index_cell > 0) {
840     return index_cell;
841   }
842 
843   const int index = device_names_map_.size();
844   index_cell = index;
845   device_names_.push_back(device_name);
846   return index;
847 }
848 
AddWhileContext(StringPiece frame_name,std::vector<Node * > enter_nodes,std::vector<Node * > exit_nodes,OutputTensor cond_output,std::vector<OutputTensor> body_inputs,std::vector<OutputTensor> body_outputs,WhileContext ** result)849 Status Graph::AddWhileContext(StringPiece frame_name,
850                               std::vector<Node*> enter_nodes,
851                               std::vector<Node*> exit_nodes,
852                               OutputTensor cond_output,
853                               std::vector<OutputTensor> body_inputs,
854                               std::vector<OutputTensor> body_outputs,
855                               WhileContext** result) {
856   auto pair = while_ctxs_.insert(std::pair<string, WhileContext>(
857       string(frame_name),
858       WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
859                    cond_output, std::move(body_inputs),
860                    std::move(body_outputs))));
861   if (!pair.second) {
862     *result = nullptr;
863     return errors::InvalidArgument("WhileContext with frame name '", frame_name,
864                                    "' already exists");
865   }
866   *result = &pair.first->second;
867   return Status::OK();
868 }
869 
BuildNodeNameIndex() const870 std::unordered_map<string, Node*> Graph::BuildNodeNameIndex() const {
871   std::unordered_map<string, Node*> result;
872   for (Node* n : nodes()) {
873     result[n->name()] = n;
874   }
875   return result;
876 }
877 
DebugString() const878 string Edge::DebugString() const {
879   return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
880                          src_output_, dst_->name().c_str(), dst_input_);
881 }
882 
883 }  // namespace tensorflow
884