1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/graph.h"
17 
18 #include <vector>
19 #include "tensorflow/core/framework/graph.pb.h"
20 #include "tensorflow/core/framework/node_def.pb.h"
21 #include "tensorflow/core/framework/node_def_util.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/versions.pb.h"
24 #include "tensorflow/core/graph/while_context.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/gtl/map_util.h"
27 #include "tensorflow/core/lib/hash/hash.h"
28 #include "tensorflow/core/lib/strings/strcat.h"
29 #include "tensorflow/core/lib/strings/stringprintf.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/public/version.h"
32 
33 namespace tensorflow {
34 
// Slot index used on both ends of a control edge (AddControlEdge passes it as
// the source output and the destination input). It is negative, so it never
// passes the data-slot range checks in IsValidInputTensor/IsValidOutputTensor.
const int Graph::kControlSlot = -1;
36 
37 struct NodeProperties {
38  public:
NodePropertiestensorflow::NodeProperties39   NodeProperties(const OpDef* op_def, const NodeDef& node_def,
40                  const DataTypeSlice inputs, const DataTypeSlice outputs)
41       : op_def(op_def),
42         node_def(node_def),
43         input_types(inputs.begin(), inputs.end()),
44         output_types(outputs.begin(), outputs.end()) {}
45 
46   const OpDef* op_def;  // not owned
47   NodeDef node_def;
48   const DataTypeVector input_types;
49   const DataTypeVector output_types;
50 };
51 
52 // Node
53 
54 #define REF_CLASS(key, value) \
55   {key, value}, { "Ref" key, value }
56 
57 const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
58     *new std::unordered_map<string, Node::NodeClass>({
59         // Keep in same order as NodeClass values
60         REF_CLASS("Switch", NC_SWITCH),
61         REF_CLASS("Merge", NC_MERGE),
62         REF_CLASS("Enter", NC_ENTER),
63         REF_CLASS("Exit", NC_EXIT),
64         REF_CLASS("NextIteration", NC_NEXT_ITERATION),
65         {"LoopCond", NC_LOOP_COND},
66         {"ControlTrigger", NC_CONTROL_TRIGGER},
67         {"_Send", NC_SEND},
68         {"_HostSend", NC_HOST_SEND},
69         {"_Recv", NC_RECV},
70         {"_HostRecv", NC_HOST_RECV},
71         {"Const", NC_CONSTANT},
72         {"HostConst", NC_CONSTANT},
73         {"Variable", NC_VARIABLE},
74         {"VariableV2", NC_VARIABLE},
75         REF_CLASS("Identity", NC_IDENTITY),
76         {"GetSessionHandle", NC_GET_SESSION_HANDLE},
77         {"GetSessionHandleV2", NC_GET_SESSION_HANDLE},
78         {"GetSessionTensor", NC_GET_SESSION_TENSOR},
79         {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR},
80         {"Size", NC_METADATA},
81         {"Shape", NC_METADATA},
82         {"Rank", NC_METADATA},
83         {"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
84         {"CollectiveReduce", NC_COLLECTIVE},
85         {"CollectiveBcastSend", NC_COLLECTIVE},
86         {"CollectiveBcastRecv", NC_COLLECTIVE},
87         {"FakeParam", NC_FAKE_PARAM},
88         {"PartitionedCall", NC_PARTITIONED_CALL},
89         {"StatefulPartitionedCall", NC_PARTITIONED_CALL},
90         // Not using the constants defined in FunctionLibraryDefinition for the
91         // 4 ops below because android inference library does not link
92         // tf.function related files.
93         {"_Arg", NC_ARG},
94         {"_DeviceArg", NC_ARG},
95         {"_Retval", NC_RETVAL},
96         {"_DeviceRetval", NC_RETVAL},
97     });
98 
99 #undef REF_CLASS
100 
GetNodeClassForOp(const string & ts)101 Node::NodeClass Node::GetNodeClassForOp(const string& ts) {
102   auto it = kNodeClassTable.find(ts);
103   if (it != kNodeClassTable.end()) {
104     return it->second;
105   } else {
106     return NC_OTHER;
107   }
108 }
109 
DebugString() const110 string Node::DebugString() const {
111   string ret = strings::StrCat("{name:'", name(), "' id:", id_);
112   if (IsSource()) {
113     strings::StrAppend(&ret, " source}");
114   } else if (IsSink()) {
115     strings::StrAppend(&ret, " sink}");
116   } else {
117     strings::StrAppend(&ret, " op device:");
118     strings::StrAppend(&ret, "{", assigned_device_name(), "}");
119     strings::StrAppend(&ret, " def:{", SummarizeNode(*this), "}}");
120   }
121   return ret;
122 }
123 
Node()124 Node::Node()
125     : id_(-1),
126       cost_id_(-1),
127       class_(NC_UNINITIALIZED),
128       props_(nullptr),
129       assigned_device_name_index_(0),
130       while_ctx_(nullptr) {}
131 
Initialize(int id,int cost_id,std::shared_ptr<NodeProperties> props)132 void Node::Initialize(int id, int cost_id,
133                       std::shared_ptr<NodeProperties> props) {
134   DCHECK_EQ(id_, -1);
135   DCHECK(in_edges_.empty());
136   DCHECK(out_edges_.empty());
137   id_ = id;
138   cost_id_ = cost_id;
139 
140   props_ = std::move(props);
141   // Initialize the class_ based on the type string
142   class_ = GetNodeClassForOp(props_->node_def.op());
143 }
144 
Clear()145 void Node::Clear() {
146   in_edges_.clear();
147   out_edges_.clear();
148   id_ = -1;
149   cost_id_ = -1;
150   class_ = NC_UNINITIALIZED;
151   props_.reset();
152   assigned_device_name_index_ = 0;
153 }
154 
UpdateProperties()155 void Node::UpdateProperties() {
156   DataTypeVector inputs;
157   DataTypeVector outputs;
158   Status status =
159       InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
160   if (!status.ok()) {
161     LOG(ERROR) << "Failed at updating node: " << status;
162     return;
163   }
164   props_ = std::make_shared<NodeProperties>(props_->op_def, props_->node_def,
165                                             inputs, outputs);
166 }
167 
name() const168 const string& Node::name() const { return props_->node_def.name(); }
type_string() const169 const string& Node::type_string() const { return props_->node_def.op(); }
def() const170 const NodeDef& Node::def() const { return props_->node_def; }
op_def() const171 const OpDef& Node::op_def() const { return *props_->op_def; }
172 
num_inputs() const173 int32 Node::num_inputs() const { return props_->input_types.size(); }
input_type(int32 i) const174 DataType Node::input_type(int32 i) const { return props_->input_types[i]; }
input_types() const175 const DataTypeVector& Node::input_types() const { return props_->input_types; }
176 
num_outputs() const177 int32 Node::num_outputs() const { return props_->output_types.size(); }
output_type(int32 o) const178 DataType Node::output_type(int32 o) const { return props_->output_types[o]; }
output_types() const179 const DataTypeVector& Node::output_types() const {
180   return props_->output_types;
181 }
182 
attrs() const183 AttrSlice Node::attrs() const { return AttrSlice(def()); }
184 
requested_inputs() const185 const protobuf::RepeatedPtrField<string>& Node::requested_inputs() const {
186   return def().input();
187 }
188 
requested_device() const189 const string& Node::requested_device() const { return def().device(); }
190 
out_nodes() const191 gtl::iterator_range<NeighborIter> Node::out_nodes() const {
192   return gtl::make_range(NeighborIter(out_edges_.begin(), false),
193                          NeighborIter(out_edges_.end(), false));
194 }
195 
in_nodes() const196 gtl::iterator_range<NeighborIter> Node::in_nodes() const {
197   return gtl::make_range(NeighborIter(in_edges_.begin(), true),
198                          NeighborIter(in_edges_.end(), true));
199 }
200 
MaybeCopyOnWrite()201 void Node::MaybeCopyOnWrite() {
202   // NodeProperties may be shared between Nodes. Make a copy if so.
203   if (!props_.unique()) {
204     props_ = std::make_shared<NodeProperties>(*props_);
205   }
206 }
207 
AddAttrHelper(const string & name)208 AttrValue* Node::AddAttrHelper(const string& name) {
209   MaybeCopyOnWrite();
210   return &((*props_->node_def.mutable_attr())[name]);
211 }
212 
ClearAttr(const string & name)213 void Node::ClearAttr(const string& name) {
214   MaybeCopyOnWrite();
215   (*props_->node_def.mutable_attr()).erase(name);
216 }
217 
set_name(string name)218 void Node::set_name(string name) {
219   MaybeCopyOnWrite();
220   props_->node_def.set_name(std::move(name));
221 }
222 
set_requested_device(const string & device)223 void Node::set_requested_device(const string& device) {
224   MaybeCopyOnWrite();
225   props_->node_def.set_device(device);
226 }
227 
set_original_node_names(const std::vector<string> & names)228 void Node::set_original_node_names(const std::vector<string>& names) {
229   MaybeCopyOnWrite();
230   props_->node_def.mutable_experimental_debug_info()
231       ->clear_original_node_names();
232   if (!names.empty()) {
233     *props_->node_def.mutable_experimental_debug_info()
234          ->mutable_original_node_names() = {names.begin(), names.end()};
235   }
236 }
237 
input_edge(int idx,const Edge ** e) const238 Status Node::input_edge(int idx, const Edge** e) const {
239   if (idx < 0 || idx >= num_inputs()) {
240     return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ",
241                                    name(), " only has ", num_inputs(),
242                                    " inputs.");
243   }
244 
245   // This does a linear search over the edges.  In the common case,
246   // the number of elements is small enough that this search isn't
247   // expensive.  Should it become a bottleneck, one can make an
248   // optimization where, if the number of edges is small, we use
249   // linear iteration, and if the number of edges is large, we perform
250   // an indexing step during construction that keeps an array of Edges
251   // indexed by pointer.  This would keep the size of each Node small
252   // in the common case but make this function faster when the number
253   // of edges is large.
254   for (const Edge* edge : in_edges()) {
255     if (edge->dst_input() == idx) {
256       *e = edge;
257       return Status::OK();
258     }
259   }
260 
261   return errors::NotFound("Could not find input edge ", idx, " for ", name());
262 }
263 
264 // Returns a vector of the non-control input edges to a node, indexed by ID.
input_edges(std::vector<const Edge * > * input_edges) const265 Status Node::input_edges(std::vector<const Edge*>* input_edges) const {
266   input_edges->clear();
267   input_edges->resize(num_inputs(), nullptr);
268 
269   for (const Edge* edge : in_edges()) {
270     if (edge->IsControlEdge()) continue;
271     if (edge->dst_input() < 0 || edge->dst_input() >= num_inputs()) {
272       return errors::Internal("Invalid edge input number ", edge->dst_input());
273     }
274     if ((*input_edges)[edge->dst_input()] != nullptr) {
275       return errors::Internal("Duplicate edge input number: ",
276                               edge->dst_input());
277     }
278     (*input_edges)[edge->dst_input()] = edge;
279   }
280 
281   for (int i = 0; i < num_inputs(); ++i) {
282     if ((*input_edges)[i] == nullptr) {
283       return errors::InvalidArgument("Missing edge input number: ", i);
284     }
285   }
286   return Status::OK();
287 }
288 
input_node(int idx,Node ** n) const289 Status Node::input_node(int idx, Node** n) const {
290   const Edge* e;
291   TF_RETURN_IF_ERROR(input_edge(idx, &e));
292   if (e == nullptr) {
293     *n = nullptr;
294   } else {
295     *n = e->src();
296   }
297   return Status::OK();
298 }
299 
input_node(int idx,const Node ** const_n) const300 Status Node::input_node(int idx, const Node** const_n) const {
301   Node* n;
302   TF_RETURN_IF_ERROR(input_node(idx, &n));
303   *const_n = n;
304   return Status::OK();
305 }
306 
input_tensor(int idx,OutputTensor * t) const307 Status Node::input_tensor(int idx, OutputTensor* t) const {
308   const Edge* e;
309   TF_RETURN_IF_ERROR(input_edge(idx, &e));
310   DCHECK(e != nullptr);
311   *t = OutputTensor(e->src(), e->src_output());
312   return Status::OK();
313 }
314 
315 // NodeDebugInfo
316 
NodeDebugInfo(const Node & n)317 NodeDebugInfo::NodeDebugInfo(const Node& n) : NodeDebugInfo(n.def()) {}
NodeDebugInfo(const NodeDef & ndef)318 NodeDebugInfo::NodeDebugInfo(const NodeDef& ndef) : name(ndef.name()) {
319   if (ndef.has_experimental_debug_info()) {
320     const auto& names = ndef.experimental_debug_info().original_node_names();
321     original_node_names.assign(names.begin(), names.end());
322   }
323 }
324 
325 // InputTensor
326 
operator ==(const InputTensor & other) const327 bool InputTensor::operator==(const InputTensor& other) const {
328   return node == other.node && index == other.index;
329 }
330 
operator ()(InputTensor const & s) const331 uint64 InputTensor::Hash::operator()(InputTensor const& s) const {
332   return Hash64Combine(std::hash<const Node*>()(s.node),
333                        std::hash<int>()(s.index));
334 }
335 
336 // OutputTensor
337 
operator ==(const OutputTensor & other) const338 bool OutputTensor::operator==(const OutputTensor& other) const {
339   return node == other.node && index == other.index;
340 }
341 
operator ()(OutputTensor const & s) const342 uint64 OutputTensor::Hash::operator()(OutputTensor const& s) const {
343   return Hash64Combine(std::hash<const Node*>()(s.node),
344                        std::hash<int>()(s.index));
345 }
346 
347 // Graph
348 
Graph(const OpRegistryInterface * ops)349 Graph::Graph(const OpRegistryInterface* ops)
350     : ops_(ops, FunctionDefLibrary()),
351       versions_(new VersionDef),
352       arena_(8 << 10 /* 8kB */) {
353   versions_->set_producer(TF_GRAPH_DEF_VERSION);
354   versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
355 
356   // Initialize the name interning table for assigned_device_name.
357   device_names_.push_back("");
358   DCHECK_EQ(0, InternDeviceName(""));
359 
360   // Source and sink have no endpoints, just control edges.
361   NodeDef def;
362   def.set_name("_SOURCE");
363   def.set_op("NoOp");
364   Status status;
365   Node* source = AddNode(def, &status);
366   TF_CHECK_OK(status);
367   CHECK_EQ(source->id(), kSourceId);
368 
369   def.set_name("_SINK");
370   Node* sink = AddNode(def, &status);
371   TF_CHECK_OK(status);
372   CHECK_EQ(sink->id(), kSinkId);
373 
374   AddControlEdge(source, sink);
375 }
376 
Graph(const FunctionLibraryDefinition & flib_def)377 Graph::Graph(const FunctionLibraryDefinition& flib_def)
378     : Graph(flib_def.default_registry()) {
379   // Need a new-enough consumer to support the functions we add to the graph.
380   if (flib_def.ToProto().function_size() > 0 &&
381       versions_->min_consumer() < 12) {
382     versions_->set_min_consumer(12);
383   }
384   Status s = ops_.AddLibrary(flib_def);
385   CHECK(s.ok()) << s.error_message();
386 }
387 
~Graph()388 Graph::~Graph() {
389   // Manually call the destructors for all the Nodes we constructed using
390   // placement new.
391   for (Node* node : nodes_) {
392     if (node != nullptr) {
393       node->~Node();
394     }
395   }
396   for (Node* node : free_nodes_) {
397     node->~Node();
398   }
399   // Edges have no destructor, and we arena-allocated them, so no need to
400   // destroy them.
401 }
402 
versions() const403 const VersionDef& Graph::versions() const { return *versions_; }
set_versions(const VersionDef & versions)404 void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }
405 
AddNode(const NodeDef & node_def,Status * status)406 Node* Graph::AddNode(const NodeDef& node_def, Status* status) {
407   const OpDef* op_def;
408   status->Update(ops_.LookUpOpDef(node_def.op(), &op_def));
409   if (!status->ok()) return nullptr;
410 
411   DataTypeVector inputs;
412   DataTypeVector outputs;
413   status->Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs));
414   if (!status->ok()) {
415     *status = AttachDef(*status, node_def);
416     return nullptr;
417   }
418 
419   Node* node = AllocateNode(
420       std::make_shared<NodeProperties>(op_def, node_def, inputs, outputs),
421       nullptr);
422   return node;
423 }
424 
// Adds to this graph a copy of `node`, which must be neither source nor
// sink. The copy initially shares `node`'s NodeProperties and inherits its
// cost id and assigned device name. Crashes (TF_CHECK_OK) if this graph
// cannot look up the node's op.
Node* Graph::CopyNode(const Node* node) {
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());
  Node* copy = AllocateNode(node->props_, node);
  copy->set_assigned_device_name(node->assigned_device_name());

  // Since the OpDef of a function may be owned by the Graph that owns 'node',
  // relookup the OpDef in the target graph. If it differs, then clone the
  // node properties with the updated OpDef.
  const OpDef* target_op_def;
  TF_CHECK_OK(ops_.LookUpOpDef(node->type_string(), &target_op_def));
  if (target_op_def == node->props_->op_def) {
    return copy;
  }
  copy->MaybeCopyOnWrite();
  copy->props_->op_def = target_op_def;
  return copy;
}
443 
RemoveNode(Node * node)444 void Graph::RemoveNode(Node* node) {
445   TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
446   DCHECK(!node->IsSource());
447   DCHECK(!node->IsSink());
448 
449   // Remove any edges involving this node.
450   while (!node->in_edges_.empty()) {
451     RemoveEdge(*node->in_edges_.begin());
452   }
453   while (!node->out_edges_.empty()) {
454     RemoveEdge(*node->out_edges_.begin());
455   }
456   ReleaseNode(node);
457 }
458 
AddEdge(Node * source,int x,Node * dest,int y)459 const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
460   TF_DCHECK_OK(IsValidNode(source)) << source->DebugString();
461   TF_DCHECK_OK(IsValidNode(dest)) << dest->DebugString();
462 
463   // source/sink must only be linked via control slots, and
464   // control slots must only be linked to control slots.
465   if (source == source_node() || dest == sink_node() || x == kControlSlot ||
466       y == kControlSlot) {
467     DCHECK_EQ(x, kControlSlot) << source->DebugString();
468     DCHECK_EQ(y, kControlSlot) << dest->DebugString();
469   }
470 
471   Edge* e = nullptr;
472   if (free_edges_.empty()) {
473     e = new (arena_.Alloc(sizeof(Edge))) Edge;  // placement new
474   } else {
475     e = free_edges_.back();
476     free_edges_.pop_back();
477   }
478   e->id_ = edges_.size();
479   e->src_ = source;
480   e->dst_ = dest;
481   e->src_output_ = x;
482   e->dst_input_ = y;
483   CHECK(source->out_edges_.insert(e).second);
484   CHECK(dest->in_edges_.insert(e).second);
485   edges_.push_back(e);
486   ++num_edges_;
487   return e;
488 }
489 
RemoveEdge(const Edge * e)490 void Graph::RemoveEdge(const Edge* e) {
491   TF_DCHECK_OK(IsValidNode(e->src_)) << e->src_->DebugString();
492   TF_DCHECK_OK(IsValidNode(e->dst_)) << e->dst_->DebugString();
493   CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
494   CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
495   CHECK_EQ(e, edges_[e->id_]);
496   CHECK_GT(num_edges_, 0);
497 
498   edges_[e->id_] = nullptr;
499 
500   Edge* del = const_cast<Edge*>(e);
501   del->src_ = nullptr;
502   del->dst_ = nullptr;
503   del->id_ = -1;
504   del->src_output_ = kControlSlot - 1;
505   del->dst_input_ = kControlSlot - 1;
506   free_edges_.push_back(del);
507   --num_edges_;
508 }
509 
AddControlEdge(Node * source,Node * dest,bool allow_duplicates)510 const Edge* Graph::AddControlEdge(Node* source, Node* dest,
511                                   bool allow_duplicates) {
512   if (!allow_duplicates) {
513     for (const Edge* edge : dest->in_edges()) {
514       if (edge->IsControlEdge() && edge->src() == source) {
515         // The requested edge already exists.
516         return nullptr;
517       }
518     }
519   }
520   // Modify dest's NodeDef if necessary.
521   if (!source->IsSource() && !dest->IsSink() && !allow_duplicates) {
522     // Check if this input is already in dest's NodeDef.
523     const string new_input = strings::StrCat("^", source->name());
524     bool input_exists = false;
525     for (const string& input : dest->props_->node_def.input()) {
526       if (input == new_input) {
527         input_exists = true;
528         break;
529       }
530     }
531     if (!input_exists) {
532       dest->MaybeCopyOnWrite();
533       dest->props_->node_def.add_input(new_input);
534     }
535   }
536   return AddEdge(source, kControlSlot, dest, kControlSlot);
537 }
538 
RemoveControlEdge(const Edge * e)539 void Graph::RemoveControlEdge(const Edge* e) {
540   if (!e->src_->IsSource() && !e->dst_->IsSink()) {
541     e->dst_->MaybeCopyOnWrite();
542     string e_src_name = strings::StrCat("^", e->src_->name());
543     auto* inputs = e->dst_->props_->node_def.mutable_input();
544     for (auto it = inputs->begin(); it != inputs->end(); ++it) {
545       if (*it == e_src_name) {
546         inputs->erase(it);
547         break;
548       }
549     }
550   }
551   RemoveEdge(e);
552 }
553 
554 namespace {
FindEdge(const Node * dst,int index)555 const Edge* FindEdge(const Node* dst, int index) {
556   for (const Edge* e : dst->in_edges()) {
557     if (e->dst_input() == index) return e;
558   }
559   return nullptr;
560 }
561 }  // namespace
562 
UpdateEdge(Node * new_src,int new_src_index,Node * dst,int dst_index)563 Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
564                          int dst_index) {
565   TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
566   TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
567   const Edge* e = FindEdge(dst, dst_index);
568   if (e == nullptr) {
569     return errors::InvalidArgument("Couldn't find edge to ",
570                                    FormatNodeForError(*dst));
571   }
572   RemoveEdge(e);
573   AddEdge(new_src, new_src_index, dst, dst_index);
574   dst->MaybeCopyOnWrite();
575   (*dst->props_->node_def.mutable_input())[dst_index] =
576       strings::StrCat(new_src->name(), ":", new_src_index);
577   return Status::OK();
578 }
579 
AddWhileInputHack(Node * new_src,int new_src_index,Node * dst)580 Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) {
581   if (dst->type_string() != "While") {
582     return errors::Internal(
583         "dst argument to AddWhileEdgeHack should be a While op, got: ",
584         dst->DebugString());
585   }
586   TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
587   // Find the current number of data inputs. We'll add the new edge to the next
588   // missing data input.
589   int dst_index = 0;
590   for (const Edge* edge : dst->in_edges()) {
591     if (edge->IsControlEdge()) continue;
592     ++dst_index;
593   }
594   TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
595   AddEdge(new_src, new_src_index, dst, dst_index);
596   dst->MaybeCopyOnWrite();
597   dst->props_->node_def.add_input(
598       strings::StrCat(new_src->name(), ":", new_src_index));
599   return Status::OK();
600 }
601 
AddFunctionLibrary(const FunctionDefLibrary & fdef_lib)602 Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
603   // Need a new-enough consumer to support the functions we add to the graph.
604   if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) {
605     versions_->set_min_consumer(12);
606   }
607   return ops_.AddLibrary(fdef_lib);
608 }
609 
610 namespace {
611 
AddInput(NodeDef * dst,StringPiece src_name,int src_slot)612 void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
613   if (src_slot == Graph::kControlSlot) {
614     dst->add_input(strings::StrCat("^", src_name));
615   } else if (src_slot == 0) {
616     dst->add_input(src_name.data(), src_name.size());
617   } else {
618     dst->add_input(strings::StrCat(src_name, ":", src_slot));
619   }
620 }
621 
622 }  // namespace
623 
ToGraphDef(GraphDef * graph_def) const624 void Graph::ToGraphDef(GraphDef* graph_def) const {
625   ToGraphDefSubRange(graph_def, 0);
626 }
627 
ToGraphDefDebug() const628 GraphDef Graph::ToGraphDefDebug() const {
629   GraphDef ret;
630   ToGraphDef(&ret);
631   return ret;
632 }
633 
ToGraphDefSubRange(GraphDef * graph_def,int from_node_id) const634 void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
635   graph_def->Clear();
636   *graph_def->mutable_versions() = versions();
637   *graph_def->mutable_library() = ops_.ToProto();
638 
639   graph_def->mutable_node()->Reserve(std::max(1, num_nodes() - from_node_id));
640 
641   std::vector<const Edge*>
642       inputs;  // Construct this outside the loop for speed.
643   for (auto id = from_node_id; id < num_node_ids(); ++id) {
644     const Node* node = FindNodeId(id);
645     if (node == nullptr || !node->IsOp()) continue;
646     NodeDef* node_def = graph_def->add_node();
647     *node_def = node->def();
648 
649     // Use the node's assigned device, if any, instead of the device requested
650     // in the NodeDef.
651     if (!node->assigned_device_name().empty()) {
652       node_def->set_device(node->assigned_device_name());
653     }
654 
655     // Get the inputs for this Node.  We make sure control inputs are
656     // after data inputs, as required by GraphDef.
657     inputs.clear();
658     inputs.resize(node->num_inputs(), nullptr);
659     for (const Edge* edge : node->in_edges()) {
660       if (edge->IsControlEdge()) {
661         inputs.push_back(edge);
662       } else {
663         CHECK(inputs[edge->dst_input()] == nullptr)
664             << "Edge " << edge->src()->DebugString() << ":"
665             << edge->dst()->DebugString() << " with dst_input "
666             << edge->dst_input() << " and had pre-existing input edge "
667             << inputs[edge->dst_input()]->src()->DebugString() << ":"
668             << inputs[edge->dst_input()]->dst()->DebugString();
669 
670         inputs[edge->dst_input()] = edge;
671       }
672     }
673     // Sort the control inputs for more predictable serialization.
674     std::sort(inputs.begin() + node->num_inputs(), inputs.end(),
675               [](const Edge* a, const Edge* b) -> bool {
676                 return a->src()->name() < b->src()->name();
677               });
678     node_def->clear_input();
679     node_def->mutable_input()->Reserve(inputs.size());
680 
681     for (size_t i = 0; i < inputs.size(); ++i) {
682       const Edge* edge = inputs[i];
683       if (edge == nullptr) {
684         if (i < node->requested_inputs().size()) {
685           node_def->add_input(node->requested_inputs()[i]);
686         } else {
687           node_def->add_input("");
688         }
689       } else {
690         const Node* src = edge->src();
691         if (!src->IsOp()) continue;
692         AddInput(node_def, src->name(), edge->src_output());
693       }
694     }
695   }
696 }
697 
NewName(StringPiece prefix)698 string Graph::NewName(StringPiece prefix) {
699   return strings::StrCat(prefix, "/_", name_counter_++);
700 }
701 
IsValidNode(const Node * node) const702 Status Graph::IsValidNode(const Node* node) const {
703   if (node == nullptr) {
704     return errors::InvalidArgument("Node is null");
705   }
706   const int id = node->id();
707   if (id < 0) {
708     return errors::InvalidArgument("node id ", id, " is less than zero");
709   }
710   if (static_cast<size_t>(id) >= nodes_.size()) {
711     return errors::InvalidArgument(
712         "node id ", id, " is >= than number of nodes in graph ", nodes_.size());
713   }
714   if (nodes_[id] != node) {
715     return errors::InvalidArgument("Node with id ", id,
716                                    " is different from the passed in node. "
717                                    "Does it belong to a different graph?");
718   }
719   return Status::OK();
720 }
721 
IsValidOutputTensor(const Node * node,int idx) const722 Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
723   TF_RETURN_IF_ERROR(IsValidNode(node));
724   if (idx >= node->num_outputs() || idx < 0) {
725     return errors::OutOfRange("Node '", node->name(), "' (type: '",
726                               node->op_def().name(),
727                               "', num of outputs: ", node->num_outputs(),
728                               ") does not have ", "output ", idx);
729   }
730   return Status::OK();
731 }
732 
IsValidInputTensor(const Node * node,int idx) const733 Status Graph::IsValidInputTensor(const Node* node, int idx) const {
734   TF_RETURN_IF_ERROR(IsValidNode(node));
735   if (idx >= node->num_inputs() || idx < 0) {
736     return errors::OutOfRange("Node '", node->name(), "' (type: '",
737                               node->op_def().name(),
738                               "', num of inputs: ", node->num_inputs(),
739                               ") does not have ", "input ", idx);
740   }
741   return Status::OK();
742 }
743 
// Obtains a Node object (recycled from the free list when possible,
// otherwise placement-new'ed in the arena), initializes it with `props`, and
// appends it to nodes_. The node id is its position in nodes_. The cost id
// is taken from `cost_node` when one is given and equals the node id
// otherwise.
Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
                          const Node* cost_node) {
  Node* node;
  if (!free_nodes_.empty()) {
    node = free_nodes_.back();
    free_nodes_.pop_back();
  } else {
    node = new (arena_.Alloc(sizeof(Node))) Node;  // placement new
  }

  node->graph_ = this;
  const int id = nodes_.size();
  const int cost_id = (cost_node != nullptr) ? cost_node->cost_id() : id;
  node->Initialize(id, cost_id, std::move(props));

  nodes_.push_back(node);
  ++num_nodes_;
  return node;
}
761 
ReleaseNode(Node * node)762 void Graph::ReleaseNode(Node* node) {
763   TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
764   nodes_[node->id()] = nullptr;
765   free_nodes_.push_back(node);
766   --num_nodes_;
767   node->Clear();
768 }
769 
770 // Ensures that 'device_name' is present in the device name table, and returns
771 // the index of that device name. The index is stable, and can be used in
772 // calls to Node::set_assigned_device_name_index().
InternDeviceName(const string & device_name)773 int Graph::InternDeviceName(const string& device_name) {
774   // Special case, very common.  Also, this allows us to use a single map
775   // lookup below, instead of two.  The 'if (index_cell > 0)' test below
776   // relies on this check.
777   if (device_name.empty()) {
778     return 0;
779   }
780 
781   int& index_cell = device_names_map_[device_name];
782   if (index_cell > 0) {
783     return index_cell;
784   }
785 
786   const int index = device_names_map_.size();
787   index_cell = index;
788   device_names_.push_back(device_name);
789   return index;
790 }
791 
AddWhileContext(StringPiece frame_name,std::vector<Node * > enter_nodes,std::vector<Node * > exit_nodes,OutputTensor cond_output,std::vector<OutputTensor> body_inputs,std::vector<OutputTensor> body_outputs,WhileContext ** result)792 Status Graph::AddWhileContext(StringPiece frame_name,
793                               std::vector<Node*> enter_nodes,
794                               std::vector<Node*> exit_nodes,
795                               OutputTensor cond_output,
796                               std::vector<OutputTensor> body_inputs,
797                               std::vector<OutputTensor> body_outputs,
798                               WhileContext** result) {
799   auto pair = while_ctxs_.insert(std::pair<string, WhileContext>(
800       string(frame_name),
801       WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
802                    cond_output, std::move(body_inputs),
803                    std::move(body_outputs))));
804   if (!pair.second) {
805     *result = nullptr;
806     return errors::InvalidArgument("WhileContext with frame name '", frame_name,
807                                    "' already exists");
808   }
809   *result = &pair.first->second;
810   return Status::OK();
811 }
812 
BuildNodeNameIndex() const813 std::unordered_map<string, Node*> Graph::BuildNodeNameIndex() const {
814   std::unordered_map<string, Node*> result;
815   for (Node* n : nodes()) {
816     result[n->name()] = n;
817   }
818   return result;
819 }
820 
DebugString() const821 string Edge::DebugString() const {
822   return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
823                          src_output_, dst_->name().c_str(), dst_input_);
824 }
825 
826 }  // namespace tensorflow
827