1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // A Graph describes a set of computations that are to be
17 // performed, as well as the dependencies between those
18 // computations. The basic model is a DAG (directed acyclic graph) with
19 // * internal nodes representing computational operations to be performed;
// * edges representing dependencies, indicating the target may only be
21 //   executed once the source has completed; and
22 // * predefined "source" (start) and "sink" (finish) nodes -- the source
23 //   should be the only node that doesn't depend on anything, and the sink
24 //   should be the only node that nothing depends on.
25 //
26 // Note: Node ids are intended to be relatively dense in the
27 // 0..max_id range, but there may be gaps since ids won't be reused.
28 //
29 // Note: Some dependencies between operations are due to one operation
30 // consuming the output of another. In fact operations can produce
31 // multiple outputs and consume multiple inputs, and some
32 // optimizations will care about which specific outputs are connected
33 // to which specific inputs.  We therefore represent data dependency
34 // between output O of layer A and input I of layer B using
35 // "input index" and "output index" labels per edge.
36 
37 #ifndef TENSORFLOW_CORE_GRAPH_GRAPH_H_
38 #define TENSORFLOW_CORE_GRAPH_GRAPH_H_
39 
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/lib/core/arena.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
54 
55 namespace tensorflow {
56 
// Types defined later in this header.
class Edge;
class Graph;
class Node;
struct OutputTensor;
class NeighborIter;
class NodeIter;

// Types whose definitions are not part of this section of the header.
class EdgeSetTest;
class GraphDef;
class VersionDef;
class WhileContext;
struct NodeProperties;  // Defined in the .cc file
69 
70 class Node {
71  public:
72   string DebugString() const;
id()73   int id() const { return id_; }
cost_id()74   int cost_id() const { return cost_id_; }
75   const string& name() const;
76   void set_name(string name);
77   const string& type_string() const;
78 
79   // def() provides the NodeDef the user supplied, but the specifics
80   // of this Node may have changed due to placement, optimization, etc.
81   // In particular:
82   // * def().name() will match name();
83   // * def().op() will match type_string() and op_def().name();
84   // * def().input() is not reliable, use "in_edges()" below instead;
85   // * def().device() is the "user's requested device" and may not match
86   //   the actual assigned device, see assigned_device_name() below;
87   // * def().attr() is authoritative.
88   // TODO(irving): Replace with NodeInfo.
89   const NodeDef& def() const;
90   const OpDef& op_def() const;
91 
92   // input and output types
93   int32 num_inputs() const;
94   DataType input_type(int32 i) const;
95   const DataTypeVector& input_types() const;
96 
97   int32 num_outputs() const;
98   DataType output_type(int32 o) const;
99   const DataTypeVector& output_types() const;
100 
101   // The device requested by the user.  For the actual assigned device,
102   // use assigned_device_name() below.
103   const string& requested_device() const;
104 
105   // This changes the user requested device but not necessarily the device that
106   // on which the operation will run.
107   void set_requested_device(const string& device);
108 
109   // This gives the device the runtime has assigned this node to.  If
110   // you want the device the user requested, use def().device() instead.
111   // TODO(josh11b): Validate that the assigned_device, if not empty:
112   // fully specifies a device, and satisfies def().device().
113   // TODO(josh11b): Move assigned_device_name outside of Node into a
114   // NodeId->DeviceName map.
115   const string& assigned_device_name() const;
116   void set_assigned_device_name(const string& device_name);
has_assigned_device_name()117   bool has_assigned_device_name() const {
118     return assigned_device_name_index_ > 0;
119   }
assigned_device_name_index()120   int assigned_device_name_index() const { return assigned_device_name_index_; }
121   void set_assigned_device_name_index(int index);
122 
123   // Sets 'original_node_names' field of this node's DebugInfo proto to
124   // 'names'.
125   void set_original_node_names(const std::vector<string>& names);
126 
127   // Read only access to attributes
128   AttrSlice attrs() const;
129 
130   // Inputs requested by the NodeDef.  For the actual inputs, use in_edges.
131   const protobuf::RepeatedPtrField<string>& requested_inputs() const;
132 
133   // Get the neighboring nodes via edges either in or out of this node.  This
134   // includes control edges.
135   gtl::iterator_range<NeighborIter> in_nodes() const;
136   gtl::iterator_range<NeighborIter> out_nodes() const;
in_edges()137   const EdgeSet& in_edges() const { return in_edges_; }
out_edges()138   const EdgeSet& out_edges() const { return out_edges_; }
139 
140   // Node type helpers.
IsSource()141   bool IsSource() const { return id() == 0; }
IsSink()142   bool IsSink() const { return id() == 1; }
143   // Anything other than the special Source & Sink nodes.
IsOp()144   bool IsOp() const { return id() > 1; }
145 
146   // Node class helpers
IsSwitch()147   bool IsSwitch() const { return class_ == NC_SWITCH; }
IsMerge()148   bool IsMerge() const { return class_ == NC_MERGE; }
IsEnter()149   bool IsEnter() const { return class_ == NC_ENTER; }
IsExit()150   bool IsExit() const { return class_ == NC_EXIT; }
IsNextIteration()151   bool IsNextIteration() const { return class_ == NC_NEXT_ITERATION; }
IsLoopCond()152   bool IsLoopCond() const { return class_ == NC_LOOP_COND; }
IsControlTrigger()153   bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; }
IsSend()154   bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND; }
IsRecv()155   bool IsRecv() const { return class_ == NC_RECV || class_ == NC_HOST_RECV; }
IsConstant()156   bool IsConstant() const { return class_ == NC_CONSTANT; }
IsVariable()157   bool IsVariable() const { return class_ == NC_VARIABLE; }
IsIdentity()158   bool IsIdentity() const { return class_ == NC_IDENTITY; }
IsGetSessionHandle()159   bool IsGetSessionHandle() const { return class_ == NC_GET_SESSION_HANDLE; }
IsGetSessionTensor()160   bool IsGetSessionTensor() const { return class_ == NC_GET_SESSION_TENSOR; }
IsDeleteSessionTensor()161   bool IsDeleteSessionTensor() const {
162     return class_ == NC_DELETE_SESSION_TENSOR;
163   }
IsControlFlow()164   bool IsControlFlow() const {
165     return (class_ != NC_OTHER) &&  // Fast path
166            (IsSwitch() || IsMerge() || IsEnter() || IsExit() ||
167             IsNextIteration());
168   }
IsHostSend()169   bool IsHostSend() const { return class_ == NC_HOST_SEND; }
IsHostRecv()170   bool IsHostRecv() const { return class_ == NC_HOST_RECV; }
IsScopedAllocator()171   bool IsScopedAllocator() const { return class_ == NC_SCOPED_ALLOCATOR; }
IsCollective()172   bool IsCollective() const { return class_ == NC_COLLECTIVE; }
173 
IsMetadata()174   bool IsMetadata() const { return class_ == NC_METADATA; }
IsFakeParam()175   bool IsFakeParam() const { return class_ == NC_FAKE_PARAM; }
IsPartitionedCall()176   bool IsPartitionedCall() const { return class_ == NC_PARTITIONED_CALL; }
177   // Is this node a function input
IsArg()178   bool IsArg() const { return class_ == NC_ARG; }
179   // Is this node a function output
IsRetval()180   bool IsRetval() const { return class_ == NC_RETVAL; }
181 
182   template <typename T>
AddAttr(const string & name,const T & val)183   void AddAttr(const string& name, const T& val) {
184     SetAttrValue(val, AddAttrHelper(name));
185     UpdateProperties();
186   }
187 
188   void ClearAttr(const string& name);
189 
190   // Returns into '*e' the edge connecting to the 'idx' input of this Node.
191   Status input_edge(int idx, const Edge** e) const;
192 
193   // Returns into '*edges' the input data edges of this Node, indexed by input
194   // number. Does not return control edges.
195   Status input_edges(std::vector<const Edge*>* edges) const;
196 
197   // Returns into '*n' the node that has an output connected to the
198   // 'idx' input of this Node.
199   Status input_node(int idx, const Node** n) const;
200   Status input_node(int idx, Node** n) const;
201 
202   // Returns into '*t' the idx-th input tensor of this node, represented as the
203   // output tensor of input_node(idx).
204   Status input_tensor(int idx, OutputTensor* t) const;
205 
while_ctx()206   WhileContext* while_ctx() const { return while_ctx_; }
set_while_ctx(WhileContext * while_ctx)207   void set_while_ctx(WhileContext* while_ctx) {
208     DCHECK(IsExit());
209     DCHECK(while_ctx_ == nullptr);
210     while_ctx_ = while_ctx;
211   }
212 
213  private:
214   friend class Graph;
215   Node();
216 
properties()217   NodeProperties* properties() const { return props_.get(); }
218 
219   void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props);
220 
221   // Releases memory from props_, in addition to restoring *this to its
222   // uninitialized state.
223   void Clear();
224 
225   // Make a copy of the Node's props_ if props_ is shared with
226   // other nodes. This must be called before mutating properties,
227   // e.g. in AddAttr.
228   void MaybeCopyOnWrite();
229 
230   // Called after an attr has changed. Decides whether we need to update some
231   // property of the node (stored in props_).
232   void UpdateProperties();
233 
234   AttrValue* AddAttrHelper(const string& name);
235 
236   // A set of mutually exclusive classes for different kinds of nodes,
237   // class_ is initialized in the Node::Initialize routine based on the
238   // node's type_string().
239   enum NodeClass {
240     NC_UNINITIALIZED,
241     NC_SWITCH,
242     NC_MERGE,
243     NC_ENTER,
244     NC_EXIT,
245     NC_NEXT_ITERATION,
246     NC_LOOP_COND,
247     NC_CONTROL_TRIGGER,
248     NC_SEND,
249     NC_HOST_SEND,
250     NC_RECV,
251     NC_HOST_RECV,
252     NC_CONSTANT,
253     NC_VARIABLE,
254     NC_IDENTITY,
255     NC_GET_SESSION_HANDLE,
256     NC_GET_SESSION_TENSOR,
257     NC_DELETE_SESSION_TENSOR,
258     NC_METADATA,
259     NC_SCOPED_ALLOCATOR,
260     NC_COLLECTIVE,
261     NC_FAKE_PARAM,
262     NC_PARTITIONED_CALL,
263     NC_ARG,
264     NC_RETVAL,
265     NC_OTHER  // Not a special kind of node
266   };
267 
268   static const std::unordered_map<string, NodeClass>& kNodeClassTable;
269 
270   static NodeClass GetNodeClassForOp(const string& ts);
271 
272   int id_;       // -1 until Initialize() is called
273   int cost_id_;  // -1 if there is no corresponding cost accounting node
274   NodeClass class_;
275 
276   EdgeSet in_edges_;
277   EdgeSet out_edges_;
278 
279   // NOTE(skyewm): inheriting from core::RefCounted may have a slight
280   // performance benefit over using shared_ptr, at the cost of manual ref
281   // counting
282   std::shared_ptr<NodeProperties> props_;
283 
284   // Index within Graph::device_names_ of the name of device assigned
285   // to perform this computation.
286   int assigned_device_name_index_;
287 
288   // A back-pointer to the Graph that owns this node.  Currently, this exists
289   // solely to allow Node::[set_]assigned_device_name() to work. However, if all
290   // callers of Node::[set_]assigned_device_name() are modified to use the
291   // equivalent methods defined directly on Graph, then we can remove this
292   // field and reclaim that memory.
293   Graph* graph_;
294 
295   // Set if this is an exit node of a while loop with an associated
296   // WhileContext. Otherwise null. (This is only set for exit nodes because
297   // they're the first nodes of a loop encountered while creating the gradient
298   // graph. Exit nodes that are part of while loop gradient graphs will not have
299   // this set.)
300   WhileContext* while_ctx_;
301 
302   TF_DISALLOW_COPY_AND_ASSIGN(Node);
303 };
304 
305 // Stores debug information associated with the Node.
306 struct NodeDebugInfo {
307   const string name;
308   std::vector<string> original_node_names;
309 
310   NodeDebugInfo(const Node& n);
311   NodeDebugInfo(const NodeDef& ndef);
312 };
313 
314 // Represents an input of a node, i.e., the `index`-th input to `node`.
315 struct InputTensor {
316   Node* node;
317   int index;
318 
InputTensorInputTensor319   InputTensor(Node* n, int i) : node(n), index(i) {}
InputTensorInputTensor320   InputTensor() : node(nullptr), index(0) {}
321 
322   // Returns true if this InputTensor is identical to 'other'. Nodes are
323   // compared using pointer equality.
324   bool operator==(const InputTensor& other) const;
325 
326   // A hash function for InputTensors. Nodes are hashed based on their pointer
327   // value.
328   struct Hash {
329     uint64 operator()(InputTensor const& s) const;
330   };
331 };
332 
333 // Represents an output of a node, i.e., the `index`-th output of `node`. Note
334 // that a single `OutputTensor` can correspond to multiple `Edge`s if the output
335 // is consumed by multiple destination nodes.
336 struct OutputTensor {
337   Node* node;
338   int index;
339 
OutputTensorOutputTensor340   OutputTensor(Node* n, int i) : node(n), index(i) {}
OutputTensorOutputTensor341   OutputTensor() : node(nullptr), index(0) {}
342 
343   // Returns true if this OutputTensor is identical to 'other'. Nodes are
344   // compared using pointer equality.
345   bool operator==(const OutputTensor& other) const;
346 
347   // A hash function for OutputTensors. Nodes are hashed based on their pointer
348   // value.
349   struct Hash {
350     uint64 operator()(OutputTensor const& s) const;
351   };
352 };
353 
354 class Edge {
355  public:
src()356   Node* src() const { return src_; }
dst()357   Node* dst() const { return dst_; }
id()358   int id() const { return id_; }
359 
360   // Return the index of the source output that produces the data
361   // carried by this edge.  The special value kControlSlot is used
362   // for control dependencies.
src_output()363   int src_output() const { return src_output_; }
364 
365   // Return the index of the destination input that consumes the data
366   // carried by this edge.  The special value kControlSlot is used
367   // for control dependencies.
dst_input()368   int dst_input() const { return dst_input_; }
369 
370   // Return true iff this is an edge that indicates a control-flow
371   // (as opposed to a data-flow) dependency.
372   bool IsControlEdge() const;
373 
374   string DebugString() const;
375 
376  private:
Edge()377   Edge() {}
378 
379   friend class EdgeSetTest;
380   friend class Graph;
381   Node* src_;
382   Node* dst_;
383   int id_;
384   int src_output_;
385   int dst_input_;
386 };
387 
388 // Allows for iteration of the edges of a Graph, by iterating the underlying
389 // Graph.edges_ vector while skipping over null entries.
390 class GraphEdgesIterable {
391  private:
392   const std::vector<Edge*>& edges_;
393 
394  public:
GraphEdgesIterable(const std::vector<Edge * > & edges)395   explicit GraphEdgesIterable(const std::vector<Edge*>& edges)
396       : edges_(edges) {}
397 
398   typedef Edge* value_type;
399 
400   class const_iterator {
401    private:
402     // The underlying iterator.
403     std::vector<value_type>::const_iterator iter_;
404 
405     // The end of the underlying iterator.
406     std::vector<value_type>::const_iterator end_;
407 
408     // Advances iter_ until it reaches a non-null item, or reaches the end.
apply_filter()409     void apply_filter() {
410       while (iter_ != end_ && *iter_ == nullptr) {
411         ++iter_;
412       }
413     }
414 
415    public:
const_iterator(std::vector<value_type>::const_iterator iter,std::vector<value_type>::const_iterator end)416     const_iterator(std::vector<value_type>::const_iterator iter,
417                    std::vector<value_type>::const_iterator end)
418         : iter_(iter), end_(end) {
419       apply_filter();
420     }
421 
422     bool operator==(const const_iterator& other) const {
423       return iter_ == other.iter_;
424     }
425 
426     bool operator!=(const const_iterator& other) const {
427       return iter_ != other.iter_;
428     }
429 
430     // This is the prefix increment operator (++x), which is the operator
431     // used by C++ range iteration (for (x : y) ...).  We intentionally do not
432     // provide a postfix increment operator.
433     const_iterator& operator++() {
434       ++iter_;
435       apply_filter();
436       return *this;
437     }
438 
439     value_type operator*() { return *iter_; }
440   };
441 
begin()442   const_iterator begin() {
443     return const_iterator(edges_.begin(), edges_.end());
444   }
end()445   const_iterator end() { return const_iterator(edges_.end(), edges_.end()); }
446 };
447 
448 // Thread compatible but not thread safe.
449 class Graph {
450  public:
451   // Constructs a graph with a single SOURCE (always id kSourceId) and a
452   // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
453   //
454   // The graph can hold ops found in the registry. `ops`s lifetime must be at
455   // least that of the constructed graph's.
456   explicit Graph(const OpRegistryInterface* ops);
457 
458   // Constructs a graph with a single SOURCE (always id kSourceId) and a
459   // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
460   //
461   // The graph can hold ops found in `flib_def`. Unlike the constructor taking
462   // an OpRegistryInterface, this constructor copies the function definitions in
463   // `flib_def` so its lifetime may be shorter than that of the graph's. The
464   // OpRegistryInterface backing `flib_def` must still have the lifetime of the
465   // graph though.
466   explicit Graph(const FunctionLibraryDefinition& flib_def);
467 
468   ~Graph();
469 
470   static const int kControlSlot;
471 
472   // The GraphDef version range of this graph (see graph.proto).
473   const VersionDef& versions() const;
474   void set_versions(const VersionDef& versions);
475 
476   // Adds a new node to this graph, and returns it. Infers the Op and
477   // input/output types for the node. *this owns the returned instance.
478   // Returns nullptr and sets *status on error.
479   Node* AddNode(const NodeDef& node_def, Status* status);
480 
481   // Copies *node, which may belong to another graph, to a new node,
482   // which is returned.  Does not copy any edges.  *this owns the
483   // returned instance.
484   Node* CopyNode(const Node* node);
485 
486   // Removes a node from this graph, including all edges from or to it.
487   // *node should not be accessed after calling this function.
488   // REQUIRES: node->IsOp()
489   void RemoveNode(Node* node);
490 
491   // Adds an edge that connects the xth output of `source` to the yth input of
492   // `dest` and returns it. Does not update dest's NodeDef.
493   const Edge* AddEdge(Node* source, int x, Node* dest, int y);
494 
495   // Adds a control edge (no data flows along this edge) that connects `source`
496   // to `dest`. If `dest`s NodeDef is missing the corresponding control input,
497   // adds the control input.
498   //
499   // If such a control edge already exists and `allow_duplicates` is false, no
500   // edge is added and the function returns nullptr. Otherwise the edge is
501   // unconditionally created and returned. The NodeDef is not updated if
502   // `allow_duplicates` is true.
503   // TODO(skyewm): // TODO(skyewm): allow_duplicates is needed only by
504   // graph_partition.cc. Figure out if we can do away with it.
505   const Edge* AddControlEdge(Node* source, Node* dest,
506                              bool allow_duplicates = false);
507 
508   // Removes edge from the graph. Does not update the destination node's
509   // NodeDef.
510   // REQUIRES: The edge must exist.
511   void RemoveEdge(const Edge* edge);
512 
513   // Removes control edge `edge` from the graph. Note that this also updates
514   // the corresponding NodeDef to reflect the change.
515   // REQUIRES: The control edge must exist.
516   void RemoveControlEdge(const Edge* e);
517 
518   // Updates the input to a node.  The existing edge to `dst` is removed and an
519   // edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
520   // is also updated.
521   Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);
522 
523   // Like AddEdge but updates dst's NodeDef. Used to add an input edge to a
524   // "While" op during gradient construction, see AddInputWhileHack in
525   // python_api.h for more details.
526   Status AddWhileInputHack(Node* new_src, int new_src_index, Node* dst);
527 
528   // Adds the function and gradient definitions in `fdef_lib` to this graph's op
529   // registry. Ignores duplicate functions, and returns a bad status if an
530   // imported function differs from an existing function or op with the same
531   // name.
532   Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib);
533 
534   // The number of live nodes in the graph.
535   //
536   // Because nodes can be removed from the graph, num_nodes() is often
537   // smaller than num_node_ids(). If one needs to create an array of
538   // nodes indexed by node ids, num_node_ids() should be used as the
539   // array's size.
num_nodes()540   int num_nodes() const { return num_nodes_; }
541 
542   // The number of live nodes in the graph, excluding the Source and Sink nodes.
num_op_nodes()543   int num_op_nodes() const {
544     DCHECK_GE(num_nodes_, 2);
545     return num_nodes_ - 2;
546   }
547 
548   // The number of live edges in the graph.
549   //
550   // Because edges can be removed from the graph, num_edges() is often
551   // smaller than num_edge_ids(). If one needs to create an array of
552   // edges indexed by edge ids, num_edge_ids() should be used as the
553   // array's size.
num_edges()554   int num_edges() const { return num_edges_; }
555 
556   // Serialize the nodes starting at `from_node_id` to a GraphDef.
557   void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const;
558 
559   // Serialize to a GraphDef.
560   void ToGraphDef(GraphDef* graph_def) const;
561 
562   // This version can be called from debugger to inspect the graph content.
563   // Use the previous version outside debug context for efficiency reasons.
564   //
565   // Note: We do not expose a DebugString() API, since GraphDef.DebugString() is
566   // not defined in some TensorFlow builds.
567   GraphDef ToGraphDefDebug() const;
568 
569   // Generate new node name with the specified prefix that is unique
570   // across this graph.
571   string NewName(StringPiece prefix);
572 
573   // Access to the list of all nodes.  Example usage:
574   //   for (Node* node : graph.nodes()) { ... }
575   gtl::iterator_range<NodeIter> nodes() const;
576 
577   // Access to the list of all nodes, excluding the Source and Sink nodes.
578   gtl::iterator_range<NodeIter> op_nodes() const;
579 
580   // Returns one more than the maximum id assigned to any node.
num_node_ids()581   int num_node_ids() const { return nodes_.size(); }
582 
583   // Returns the node associated with an id, or nullptr if no node
584   // with that id (the node with that id was removed and the id has
585   // not yet been re-used). *this owns the returned instance.
586   // REQUIRES: 0 <= id < num_node_ids().
FindNodeId(int id)587   Node* FindNodeId(int id) const { return nodes_[id]; }
588 
589   // Returns one more than the maximum id assigned to any edge.
num_edge_ids()590   int num_edge_ids() const { return edges_.size(); }
591 
592   // Returns the Edge associated with an id, or nullptr if no edge
593   // with that id (the node with that id was removed and the id has
594   // not yet been re-used). *this owns the returned instance.
595   // REQUIRES: 0 <= id < num_node_ids().
FindEdgeId(int id)596   const Edge* FindEdgeId(int id) const { return edges_[id]; }
597 
598   // Access to the set of all edges.  Example usage:
599   //   for (const Edge* e : graph.edges()) { ... }
edges()600   GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); }
601 
602   // The pre-defined nodes.
603   enum { kSourceId = 0, kSinkId = 1 };
source_node()604   Node* source_node() const { return FindNodeId(kSourceId); }
sink_node()605   Node* sink_node() const { return FindNodeId(kSinkId); }
606 
op_registry()607   const OpRegistryInterface* op_registry() const { return &ops_; }
flib_def()608   const FunctionLibraryDefinition& flib_def() const { return ops_; }
609 
CheckDeviceNameIndex(int index)610   void CheckDeviceNameIndex(int index) {
611     DCHECK_GE(index, 0);
612     DCHECK_LT(index, static_cast<int>(device_names_.size()));
613   }
614 
615   int InternDeviceName(const string& device_name);
616 
get_assigned_device_name(const Node & node)617   const string& get_assigned_device_name(const Node& node) const {
618     return device_names_[node.assigned_device_name_index()];
619   }
620 
set_assigned_device_name_index(Node * node,int device_name_index)621   void set_assigned_device_name_index(Node* node, int device_name_index) {
622     CheckDeviceNameIndex(device_name_index);
623     node->assigned_device_name_index_ = device_name_index;
624   }
625 
set_assigned_device_name(Node * node,const string & device_name)626   void set_assigned_device_name(Node* node, const string& device_name) {
627     node->assigned_device_name_index_ = InternDeviceName(device_name);
628   }
629 
630   // Returns OK if `node` is non-null and belongs to this graph
631   Status IsValidNode(const Node* node) const;
632 
633   // Returns OK if IsValidNode(`node`) and `idx` is a valid output.  Does not
634   // accept control outputs.
635   Status IsValidOutputTensor(const Node* node, int idx) const;
636 
637   // Returns OK if IsValidNode(`node`) and `idx` a valid input.  Does not accept
638   // control inputs.
639   Status IsValidInputTensor(const Node* node, int idx) const;
640 
641   // Create and return a new WhileContext owned by this graph. This is called
642   // when a new while loop is created. `frame_name` must be unique among
643   // WhileContexts in this graph.
644   Status AddWhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
645                          std::vector<Node*> exit_nodes,
646                          OutputTensor cond_output,
647                          std::vector<OutputTensor> body_inputs,
648                          std::vector<OutputTensor> body_outputs,
649                          WhileContext** result);
650 
651   // Builds a node name to node pointer index for all nodes in the graph.
652   std::unordered_map<string, Node*> BuildNodeNameIndex() const;
653 
654   // TODO(josh11b): uint64 hash() const;
655 
656  private:
657   // If cost_node is non-null, then cost accounting (in CostModel)
658   // will be associated with that node rather than the new one being
659   // created.
660   //
661   // Ownership of the returned Node is not transferred to caller.
662   Node* AllocateNode(std::shared_ptr<NodeProperties> props,
663                      const Node* cost_node);
664   void ReleaseNode(Node* node);
665 
666   // Registry of all known ops, including functions.
667   FunctionLibraryDefinition ops_;
668 
669   // GraphDef versions
670   const std::unique_ptr<VersionDef> versions_;
671 
672   // Allocator which will give us good locality.
673   core::Arena arena_;
674 
675   // Map from node ids to allocated nodes.  nodes_[id] may be nullptr if
676   // the node with that id was removed from the graph.
677   std::vector<Node*> nodes_;
678 
679   // Number of nodes alive.
680   int64 num_nodes_ = 0;
681 
682   // Map from edge ids to allocated edges.  edges_[id] may be nullptr if
683   // the edge with that id was removed from the graph.
684   std::vector<Edge*> edges_;
685 
686   // The number of entries in edges_ that are not nullptr.
687   int num_edges_ = 0;
688 
689   // Allocated but free nodes and edges.
690   std::vector<Node*> free_nodes_;
691   std::vector<Edge*> free_edges_;
692 
693   // For generating unique names.
694   int name_counter_ = 0;
695 
696   // In most graphs, the number of unique values used for the
697   // Node::assigned_device_name() property is quite small.  If the graph is
698   // large, then this duplication of values can consume a significant amount of
699   // memory.  Instead, we represent the same information using an interning
700   // table, which consists of a vector of unique strings (device_names_), as
701   // well a map (device_names_map_) from unique strings to indices within the
702   // unique string table.
703   //
704   // The InternDeviceName() method handles adding a new entry into the table,
705   // or locating the index of an existing entry.
706   //
707   // The fact that Node::assigned_device_name() is implemented using an
708   // interning table is intentionally public.  This allows algorithms that
709   // frequently access this field to do so efficiently, especially for the case
710   // where the assigned_device_name of one Node is copied directly from that
711   // of another Node.
712 
713   // A table of the unique assigned device names.  Indices do NOT correspond
714   // to node IDs.  Index 0 is always the empty string.
715   std::vector<string> device_names_;
716 
717   // Maps unique device names to indices within device_names_[i].
718   std::unordered_map<string, int> device_names_map_;
719 
720   // All the while contexts owned by this graph, keyed by frame name,
721   // corresponding to all the while loops contained in this graph (including
722   // nested loops). The stored contexts are usually accessed via
723   // AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
724   std::map<string, WhileContext> while_ctxs_;
725 
726   TF_DISALLOW_COPY_AND_ASSIGN(Graph);
727 };
728 
729 // TODO(josh11b): We may want to support keeping an index on various
730 // node/edge attributes in a graph, particularly node names.
731 
732 // Helper routines
733 
IsSource(const Node * node)734 inline bool IsSource(const Node* node) { return node->IsSource(); }
IsSink(const Node * node)735 inline bool IsSink(const Node* node) { return node->IsSink(); }
IsSwitch(const Node * node)736 inline bool IsSwitch(const Node* node) { return node->IsSwitch(); }
IsMerge(const Node * node)737 inline bool IsMerge(const Node* node) { return node->IsMerge(); }
IsEnter(const Node * node)738 inline bool IsEnter(const Node* node) { return node->IsEnter(); }
IsExit(const Node * node)739 inline bool IsExit(const Node* node) { return node->IsExit(); }
IsNextIteration(const Node * n)740 inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); }
IsLoopCond(const Node * node)741 inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); }
IsControlTrigger(const Node * n)742 inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); }
IsSend(const Node * node)743 inline bool IsSend(const Node* node) { return node->IsSend(); }
IsRecv(const Node * node)744 inline bool IsRecv(const Node* node) { return node->IsRecv(); }
IsHostSend(const Node * node)745 inline bool IsHostSend(const Node* node) { return node->IsHostSend(); }
IsHostRecv(const Node * node)746 inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); }
747 
748 // True for Nodes that mediate the transfer of values between processes.
IsTransferNode(const Node * n)749 inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); }
750 
IsConstant(const Node * node)751 inline bool IsConstant(const Node* node) { return node->IsConstant(); }
IsVariable(const Node * node)752 inline bool IsVariable(const Node* node) { return node->IsVariable(); }
IsIdentity(const Node * node)753 inline bool IsIdentity(const Node* node) { return node->IsIdentity(); }
754 
755 // Returns true iff 'n' is a control flow node.
IsControlFlow(const Node * n)756 inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); }
757 
758 // Returns true if the node only depends on its input's metadata
759 // (shape).  Specifically, returns true for "Size", "Shape" and "Rank" ops.
IsMetadata(const Node * n)760 inline bool IsMetadata(const Node* n) { return n->IsMetadata(); }
761 
IsScopedAllocator(const Node * n)762 inline bool IsScopedAllocator(const Node* n) { return n->IsScopedAllocator(); }
763 
IsHostMemoryPreserving(const Node * node)764 inline bool IsHostMemoryPreserving(const Node* node) {
765   return IsIdentity(node) || IsControlFlow(node);
766 }
767 
768 // Iterator for stepping through the nodes of a graph.
769 class NodeIter {
770  public:
771   NodeIter(const Graph* graph, int id);
772   bool operator==(const NodeIter& rhs);
773   bool operator!=(const NodeIter& rhs);
774   void operator++();
775   Node* operator*();
776   Node* operator->();
777 
778  private:
779   // Invariant: id_ == graph_->num_node_ids() || graph_->FindId(id_) != nullptr
780   const Graph* graph_;
781   int id_;
782 };
783 
784 // Iterator for stepping through the neighbors of a node.
785 class NeighborIter {
786  public:
787   NeighborIter(EdgeSet::const_iterator iter, bool incoming);
788   bool operator==(const NeighborIter& rhs);
789   bool operator!=(const NeighborIter& rhs);
790   void operator++();
791   Node* operator*();
792   Node* operator->();
793 
794  private:
795   EdgeSet::const_iterator iter_;
796   bool incoming_;
797 };
798 
799 // IMPLEMENTATION DETAILS, PLEASE IGNORE
800 
NodeIter(const Graph * graph,int id)801 inline NodeIter::NodeIter(const Graph* graph, int id)
802     : graph_(graph), id_(id) {}
803 
804 inline bool NodeIter::operator==(const NodeIter& rhs) {
805   DCHECK(graph_ == rhs.graph_);
806   return id_ == rhs.id_;
807 }
808 
809 inline bool NodeIter::operator!=(const NodeIter& rhs) {
810   return !(*this == rhs);
811 }
812 
813 inline void NodeIter::operator++() {
814   while (1) {
815     DCHECK_LE(id_, graph_->num_node_ids());
816     ++id_;
817     if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) {
818       return;
819     }
820   }
821 }
822 
823 inline Node* NodeIter::operator*() { return graph_->FindNodeId(id_); }
824 
825 inline Node* NodeIter::operator->() { return graph_->FindNodeId(id_); }
826 
NeighborIter(EdgeSet::const_iterator iter,bool incoming)827 inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming)
828     : iter_(iter), incoming_(incoming) {}
829 
830 inline bool NeighborIter::operator==(const NeighborIter& rhs) {
831   return iter_ == rhs.iter_ && incoming_ == rhs.incoming_;
832 }
833 
834 inline bool NeighborIter::operator!=(const NeighborIter& rhs) {
835   return !(*this == rhs);
836 }
837 
838 inline void NeighborIter::operator++() { ++iter_; }
839 
840 inline Node* NeighborIter::operator*() {
841   const Edge* e = *iter_;
842   return incoming_ ? e->src() : e->dst();
843 }
844 
845 inline Node* NeighborIter::operator->() {
846   const Edge* e = *iter_;
847   return incoming_ ? e->src() : e->dst();
848 }
849 
IsControlEdge()850 inline bool Edge::IsControlEdge() const {
851   // Note that if either src_output_ or dst_input_ is kControlSlot,
852   // so is the other one (AddEdge checks this).
853   return src_output_ == Graph::kControlSlot;
854 }
855 
nodes()856 inline gtl::iterator_range<NodeIter> Graph::nodes() const {
857   // Note that NodeId 0 is always valid since we don't let the source
858   // node be removed from the graph.
859   return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids()));
860 }
861 
op_nodes()862 inline gtl::iterator_range<NodeIter> Graph::op_nodes() const {
863   // Note that NodeId 0 is always valid since we don't let the source
864   // node be removed from the graph.
865   //
866   // The current implementation of Graph maintains the invariant that the
867   // first two nodes are the source and sink nodes, and all other nodes are op
868   // nodes. This method (op_nodes()) relies on this invariant.
869   NodeIter begin(this, 0);
870   NodeIter end(this, num_node_ids());
871   if (begin != end) {
872     ++begin;
873   }
874   if (begin != end) {
875     ++begin;
876   }
877   return gtl::make_range(begin, end);
878 }
879 
set_assigned_device_name_index(int index)880 inline void Node::set_assigned_device_name_index(int index) {
881   graph_->CheckDeviceNameIndex(index);
882   assigned_device_name_index_ = index;
883 }
884 
set_assigned_device_name(const string & device_name)885 inline void Node::set_assigned_device_name(const string& device_name) {
886   graph_->set_assigned_device_name(this, device_name);
887 }
888 
assigned_device_name()889 inline const string& Node::assigned_device_name() const {
890   return graph_->get_assigned_device_name(*this);
891 }
892 
893 }  // namespace tensorflow
894 
895 #endif  // TENSORFLOW_CORE_GRAPH_GRAPH_H_
896