1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_
18 
19 #include <memory>
20 #include <vector>
21 
22 #include "tensorflow/core/framework/allocator.h"
23 #include "tensorflow/core/framework/types.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/lib/gtl/array_slice.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/macros.h"
28 #include "tensorflow/core/platform/types.h"
29 
30 namespace tensorflow {
31 
32 class Device;
33 class Graph;
34 class Node;
35 class OpKernel;
36 class Tensor;
37 
38 // Represents a single data edge in a `NodeItem`.
struct EdgeInfo {
  // NOTE(review): NodeItem computes its variable-length layout from
  // sizeof(EdgeInfo), and callers may aggregate-initialize this struct, so the
  // field order and bit-field widths below should be kept stable.

  // The node ID of the destination in the containing `GraphView`.
  int dst_id;
  // The index of the output that produces values on this edge.
  // Declared as a 31-bit bit-field, presumably so that it and `is_last` fit in
  // 32 bits. Whether bit-fields of different declared types (int / bool) share
  // one storage unit is implementation-defined (e.g. MSVC does not pack them).
  int output_slot : 31;
  // true if this is the last info for output_slot in the EdgeInfo list.
  bool is_last : 1;
  // The index of the input that consumes values on this edge.
  int input_slot;
};
49 
50 // Represents a single control edge in a `NodeItem`.
struct ControlEdgeInfo {
  // The node ID of the destination in the containing `GraphView`.
  // Unlike EdgeInfo, no output/input slot indices are recorded for a control
  // edge; only the destination is stored.
  int dst_id;
};
55 
56 // Compact structure representing a graph node and its associated kernel.
57 //
58 // Each NodeItem is an element of exactly one GraphView.
59 struct NodeItem {
60   // The index of this node's item in its GraphView.
61   int node_id = -1;
62 
63   // Cached attributes of this node for fast lookup.
64   bool kernel_is_async : 1;     // True iff kernel->AsAsync() != nullptr
65   bool is_merge : 1;            // True iff IsMerge(node)
66   bool is_enter : 1;            // True iff IsEnter(node)
67   bool is_constant_enter : 1;   // True iff IsEnter(node) and
68                                 // node->GetAttr("is_constant") == true.
69   bool is_exit : 1;             // True iff IsExit(node)
70   bool is_control_trigger : 1;  // True iff IsControlTrigger(node)
71   bool is_source : 1;           // True iff IsSource(node)
72   // True iff IsEnter(node) || IsExit(node) || IsNextIteration(node)
73   bool is_enter_exit_or_next_iter : 1;
74   bool is_transfer_node : 1;      // True iff IsTransferNode(node)
75   bool is_initialization_op : 1;  // True iff IsInitializationOp(node)
76   bool is_recv_or_switch : 1;     // True iff IsRecv(node) || IsSwitch(node)
77   bool is_next_iteration : 1;     // True iff IsNextIteration(node)
78   bool is_noop : 1;  // True iff item->kernel->type_string_view() == "NoOp")
79   bool
80       is_any_consumer_merge_or_control_trigger : 1;  // True iff the destination
81                                                      // of any output edge is a
82                                                      // merge or control trigger
83                                                      // node.
84   bool is_any_input_ref_typed : 1;  // True iff any IsRefType(dt) for dt in this
85                                     // node's input types.
86 
87   // The kernel for this node.
88   OpKernel* kernel = nullptr;
89 
90   // If the kernel is a Const op, this containts points to the constant tensor.
91   const Tensor* const_tensor = nullptr;
92 
93   // Cached values of node->num_inputs() and node->num_outputs(), to
94   // avoid levels of indirection.
95   int num_inputs;
96   int num_outputs;
97 
98   // ExecutorImpl::tensors_[input_start] is the 1st positional input
99   // for this node.
100   int input_start = 0;
101 
102   // Number of output edges, excluding control edges.
103   int32 num_output_edges;
104 
105   // Number of output control edges.
106   int32 num_output_control_edges;
107 
108   // If non-null, contains an array of num_outputs bools, where the ith bool
109   // is true if and only if the ith output is consumed by another node.
110   std::unique_ptr<bool[]> outputs_required;
111 
mutable_output_edgesNodeItem112   gtl::MutableArraySlice<EdgeInfo> mutable_output_edges() {
113     return gtl::MutableArraySlice<EdgeInfo>(output_edge_base(),
114                                             num_output_edges);
115   }
116 
output_edgesNodeItem117   gtl::ArraySlice<EdgeInfo> output_edges() const {
118     return gtl::ArraySlice<EdgeInfo>(output_edge_base(), num_output_edges);
119   }
120 
output_control_edgesNodeItem121   gtl::ArraySlice<ControlEdgeInfo> output_control_edges() const {
122     return gtl::ArraySlice<const ControlEdgeInfo>(output_control_edge_base(),
123                                                   num_output_control_edges);
124   }
125 
input_typeNodeItem126   DataType input_type(int i) const {
127     DCHECK_LT(i, num_inputs);
128     return static_cast<DataType>(input_type_base()[i]);
129   }
output_typeNodeItem130   DataType output_type(int i) const {
131     DCHECK_LT(i, num_outputs);
132     return static_cast<DataType>(output_type_base()[i]);
133   }
134 
135   // Return array of per-output allocator attributes.
output_attrsNodeItem136   const AllocatorAttributes* output_attrs() const { return output_attr_base(); }
137 
138   // Return array of expected input index from which each output should
139   // be forwarded:
140   // kNeverForward (-2) for DO NOT FORWARD (must allocate).
141   // kNoReservation (-1) for no expected forwarding.
142   // 0... for forward from that input.
forward_fromNodeItem143   const int* forward_from() const { return forward_from_base(); }
144 
145   string DebugString() const;
146 
147  private:
148   friend class GraphView;
149 
NodeItemNodeItem150   NodeItem() {}
151 
152   // Variable length section starts immediately after *this
153   // (uint8 is enough for DataType).
154   //   EdgeInfo            out_edges[num_output_edges];
155   //   ControlEdgeInfo     out_control_edges[num_output_control_edges];
156   //   AllocatorAttributes output_attr[num_outputs];
157   //   int                 forward_from[num_outputs];
158   //   uint8               input_type[num_inputs];
159   //   uint8               output_type[num_outputs];
160 
161   // Return pointer to variable length section.
varNodeItem162   char* var() const {
163     return const_cast<char*>(reinterpret_cast<const char*>(this) +
164                              sizeof(NodeItem));
165   }
166 
output_edge_baseNodeItem167   EdgeInfo* output_edge_base() const {
168     return reinterpret_cast<EdgeInfo*>(var());
169   }
170 
output_control_edge_baseNodeItem171   ControlEdgeInfo* output_control_edge_base() const {
172     return reinterpret_cast<ControlEdgeInfo*>(var() + sizeof(EdgeInfo) *
173                                                           num_output_edges);
174   }
175 
output_attr_baseNodeItem176   AllocatorAttributes* output_attr_base() const {
177     return reinterpret_cast<AllocatorAttributes*>(
178         var() + sizeof(EdgeInfo) * num_output_edges +
179         sizeof(ControlEdgeInfo) * num_output_control_edges);
180   }
forward_from_baseNodeItem181   int* forward_from_base() const {
182     return reinterpret_cast<int*>(var() + sizeof(EdgeInfo) * num_output_edges +
183                                   sizeof(ControlEdgeInfo) *
184                                       num_output_control_edges +
185                                   sizeof(AllocatorAttributes) * num_outputs);
186   }
input_type_baseNodeItem187   uint8* input_type_base() const {
188     return reinterpret_cast<uint8*>(
189         var() + sizeof(EdgeInfo) * num_output_edges +
190         sizeof(ControlEdgeInfo) * num_output_control_edges +
191         sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs);
192   }
output_type_baseNodeItem193   uint8* output_type_base() const {
194     return reinterpret_cast<uint8*>(
195         var() + sizeof(EdgeInfo) * num_output_edges +
196         sizeof(ControlEdgeInfo) * num_output_control_edges +
197         sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs +
198         sizeof(uint8) * num_inputs);
199   }
200 
201   TF_DISALLOW_COPY_AND_ASSIGN(NodeItem);
202 };
203 
204 // Immutable view of a Graph organized for efficient execution.
205 //
206 // TODO(b/152651962): Add independent unit tests for this class.
207 class GraphView {
208  public:
GraphView()209   GraphView() : space_(nullptr) {}
210   ~GraphView();
211 
212   Status Initialize(const Graph* g);
213   Status SetAllocAttrs(const Graph* g, const Device* device);
214   void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes);
215 
216   // Returns a mutable pointer to the `NodeItem` with the given `id` if it
217   // exists in the graph, or `nullptr` if it does not.
node(int32 id)218   NodeItem* node(int32 id) const {
219     DCHECK_GE(id, 0);
220     DCHECK_LT(id, num_nodes_);
221     uint32 offset = node_offsets_[id];
222     return ((offset == kuint32max)
223                 ? nullptr
224                 : reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]));
225   }
226 
227   // Returns the `NodeItem` with the given `id`.
228   //
229   // REQUIRES: `id` must be the ID of a valid node in the graph.
node_ref(int32 id)230   const NodeItem& node_ref(int32 id) const {
231     DCHECK_GE(id, 0);
232     DCHECK_LT(id, num_nodes_);
233     uint32 offset = node_offsets_[id];
234     DCHECK_NE(offset, kuint32max);
235     return *reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]);
236   }
237 
num_nodes()238   int32 num_nodes() const { return num_nodes_; }
239 
240  private:
241   char* InitializeNode(char* ptr, const Node* n);
242   size_t NodeItemBytes(const Node* n);
243 
244   int32 num_nodes_ = 0;
245   uint32* node_offsets_ = nullptr;  // array of size "num_nodes_"
246   // node_offsets_[id] holds the byte offset for node w/ "id" in space_
247 
248   char* space_;  // NodeItem objects are allocated here
249 
250   TF_DISALLOW_COPY_AND_ASSIGN(GraphView);
251 };
252 
253 }  // namespace tensorflow
254 
255 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_
256