1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_VIEW_H_
17 #define TENSORFLOW_CORE_GRAPPLER_GRAPH_VIEW_H_
18 
19 #include <unordered_map>
20 #include <unordered_set>
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/hash/hash.h"
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/core/framework/graph.pb.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/op_def.pb.h"
28 #include "tensorflow/core/graph/tensor_id.h"
29 #include "tensorflow/core/grappler/utils.h"
30 #include "tensorflow/core/lib/gtl/map_util.h"
31 #include "tensorflow/core/platform/types.h"
32 
33 namespace tensorflow {
34 namespace grappler {
35 
// Map a node/op's input/output port_id to arg_id.
//
// The port_id refers to the n-th tensor of the node, while the arg_id refers to
// the n-th arg of the op. These two can be different if an op's arg is a list
// of tensors.
//
// We return -1 for any invalid port_id (i.e., no corresponding arg_id).
//
// Output variant: `port_id` is an output port of `node`; `op` is the OpDef the
// arg_id refers to.
int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
// Input variant: `port_id` is an input port of `node`; `op` is the OpDef the
// arg_id refers to.
int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
45 
46 namespace internal {
47 
48 // GraphViewInternal is a helper class to simplify graph traversal. It creates
49 // an immutable view of the nodes and edges represented by a GraphDef protocol
50 // buffer.
51 //
52 // There are two public classes implementing GraphViewInternal:
53 //
// - GraphView: constructed from a `const GraphDef` and doesn't allow mutating
//   the underlying graph via input/output ports lookup functions (ports have
//   const pointers to nodes).
//
// - MutableGraphView: constructed from a `GraphDef` and allows mutating the
//   graph via input/output ports lookup functions (ports have non-const
//   pointers to nodes), and also has a couple of additional functions to
//   add/remove/replace nodes in the graph.
62 //
63 // --------------------------- !!! WARNING !!! ---------------------------------
64 //     Removing nodes from the graph outside of MutableGraphView will
65 //     lead to segfaults! Guaranteed by absl::string_view!
66 // -----------------------------------------------------------------------------
67 //
68 template <typename GraphDefT, typename NodeDefT>
69 class GraphViewInternal {
70  public:
71   struct Port {
PortPort72     Port() : node(nullptr), port_id(0) {}
PortPort73     Port(NodeDefT* n, int port) : node(n), port_id(port) {}
74 
75     bool operator==(const Port& other) const {
76       return node == other.node && port_id == other.port_id;
77     }
78 
79     template <typename H>
AbslHashValuePort80     friend H AbslHashValue(H h, const Port& p) {
81       return H::combine(std::move(h), p.node, p.port_id);
82     }
83 
84     NodeDefT* node;
85     int port_id;
86   };
87 
88   struct InputPort : public Port {
89     using Port::Port;
90   };
91 
92   struct OutputPort : public Port {
93     using Port::Port;
94   };
95 
96   struct Edge {
EdgeEdge97     Edge(OutputPort s, InputPort d) : src(s), dst(d) {}
98 
99     bool operator==(const Edge& other) const {
100       return src == other.src && dst == other.dst;
101     }
102 
103     template <typename H>
AbslHashValueEdge104     friend H AbslHashValue(H h, const Edge& e) {
105       return H::combine(std::move(h), e.src, e.dst);
106     }
107 
108     OutputPort src;
109     InputPort dst;
110   };
111 
  // Returns the graph pointer this view was constructed with. The graph must
  // outlive the view.
  GraphDefT* graph() const { return graph_; }
113 
114   // Finds a node by name or return `nullptr` if it's not in the graph view.
GetNode(absl::string_view node_name)115   NodeDefT* GetNode(absl::string_view node_name) const {
116     return gtl::FindWithDefault(nodes_, node_name, nullptr);
117   }
118 
119   // Checks if a node by name is in the graph view.
HasNode(absl::string_view node_name)120   bool HasNode(absl::string_view node_name) const {
121     return GetNode(node_name) != nullptr;
122   }
123 
124   // Gets the specified input port. Note that the special '-1' port_id can be
125   // used to access the controlling nodes (i.e. the nodes connected to node_name
126   // through an incoming control dependency).
GetInputPort(absl::string_view node_name,int port_id)127   InputPort GetInputPort(absl::string_view node_name, int port_id) const {
128     return InputPort(GetNode(node_name), port_id);
129   }
130 
131   // Gets the specified output port. Note that the special '-1' port_id can be
132   // used to access the controlled nodes (i.e. the nodes connected to node_name
133   // through an outgoing control dependency).
GetOutputPort(absl::string_view node_name,int port_id)134   OutputPort GetOutputPort(absl::string_view node_name, int port_id) const {
135     return OutputPort(GetNode(node_name), port_id);
136   }
137 
138   // Gets the input port(s) in the immediate fanout of an output port.
GetFanout(const OutputPort & port)139   const absl::flat_hash_set<InputPort>& GetFanout(
140       const OutputPort& port) const {
141     return gtl::FindWithDefault(fanouts_, port, fanout_not_found_value_);
142   }
143 
144   // Gets the output port(s) in the immediate fanin of an input port.
GetFanin(const InputPort & port)145   absl::flat_hash_set<OutputPort> GetFanin(const InputPort& port) const {
146     if (port.port_id >= 0) {
147       OutputPort regular_fanin = GetRegularFanin(port);
148       if (regular_fanin.node == nullptr) {
149         return {};
150       }
151       return {regular_fanin};
152     }
153 
154     // Collect fanin for the control input.
155     absl::flat_hash_set<OutputPort> result;
156     const int first_control_port =
157         gtl::FindWithDefault(max_regular_input_port_, port.node, -1) + 1;
158     for (int i = first_control_port; i < port.node->input_size(); ++i) {
159       TensorId tensor_id = ParseTensorName(port.node->input(i));
160 
161       auto it = nodes_.find(tensor_id.node());
162       if (it != nodes_.end()) result.emplace(it->second, tensor_id.index());
163     }
164     return result;
165   }
166 
167   // Special case: regular (i.e. non-control) input ports can only have one
168   // fanin. If port.port_id is out of range or is a control dependency, then an
169   // empty OutputPort is returned.
GetRegularFanin(const InputPort & port)170   const OutputPort GetRegularFanin(const InputPort& port) const {
171     if (port.port_id < 0 ||
172         port.port_id >
173             gtl::FindWithDefault(max_regular_input_port_, port.node, -1)) {
174       return OutputPort();
175     }
176 
177     TensorId tensor_id = ParseTensorName(port.node->input(port.port_id));
178     return GetOutputPort(tensor_id.node(), tensor_id.index());
179   }
180 
181   // Checks if a tensor id is a fanin of the node.
HasFanin(const NodeDefT & node,const TensorId & fanin)182   bool HasFanin(const NodeDefT& node, const TensorId& fanin) const {
183     int end = node.input_size();
184     if (end == 0 || fanin.index() < -1) {
185       return false;
186     }
187 
188     const int num_regular_fanins =
189         gtl::FindWithDefault(max_regular_input_port_, &node, -1) + 1;
190     int start = 0;
191     if (fanin.index() > -1) {
192       end = num_regular_fanins;
193     } else {
194       start = num_regular_fanins;
195     }
196     for (int i = start; i < end; ++i) {
197       if (ParseTensorName(node.input(i)) == fanin) {
198         return true;
199       }
200     }
201     return false;
202   }
203 
204   // Gets all the input ports in the immediate fanout of a node. Include the
205   // controlled nodes iff include_controlled_nodes is true.
GetFanouts(const NodeDefT & node,bool include_controlled_nodes)206   absl::flat_hash_set<InputPort> GetFanouts(
207       const NodeDefT& node, bool include_controlled_nodes) const {
208     absl::flat_hash_set<InputPort> result;
209 
210     OutputPort port;
211     port.node = const_cast<NodeDefT*>(&node);
212     const int first_port_id = include_controlled_nodes ? -1 : 0;
213     const int last_port_id =
214         gtl::FindWithDefault(max_regular_output_port_, &node, -1);
215 
216     for (int i = first_port_id; i <= last_port_id; ++i) {
217       port.port_id = i;
218       auto it = fanouts_.find(port);
219       if (it != fanouts_.end()) {
220         result.insert(it->second.begin(), it->second.end());
221       }
222     }
223     return result;
224   }
225 
226   // Gets all the output ports in the immediate fanin of a node. Include the
227   // controlling nodes iff include_controlling_nodes is true.
GetFanins(const NodeDefT & node,bool include_controlling_nodes)228   absl::flat_hash_set<OutputPort> GetFanins(
229       const NodeDefT& node, bool include_controlling_nodes) const {
230     absl::flat_hash_set<OutputPort> result;
231     const int max_input_port =
232         include_controlling_nodes
233             ? node.input_size() - 1
234             : gtl::FindWithDefault(max_regular_input_port_, &node, -1);
235     for (int i = 0; i <= max_input_port; ++i) {
236       TensorId tensor_id = ParseTensorName(node.input(i));
237 
238       auto it = nodes_.find(tensor_id.node());
239       if (it != nodes_.end()) result.emplace(it->second, tensor_id.index());
240     }
241     return result;
242   }
243 
244   // Gets the number of ports in the immediate fanin of a node. Count the
245   // controlling nodes iff include_controlling_nodes is true.
NumFanins(const NodeDefT & node,bool include_controlling_nodes)246   int NumFanins(const NodeDefT& node, bool include_controlling_nodes) const {
247     if (include_controlling_nodes) {
248       return node.input_size();
249     }
250     return gtl::FindWithDefault(max_regular_input_port_, &node, -1) + 1;
251   }
252 
253   // Gets the number of ports in the immediate fanout of a node. Count the
254   // controlled nodes iff include_controlled_nodes is true.
NumFanouts(const NodeDefT & node,bool include_controlled_nodes)255   int NumFanouts(const NodeDefT& node, bool include_controlled_nodes) const {
256     int count = 0;
257 
258     OutputPort port;
259     port.node = const_cast<NodeDefT*>(&node);
260     const int first_port_id = include_controlled_nodes ? -1 : 0;
261     const int last_port_id =
262         gtl::FindWithDefault(max_regular_output_port_, &node, -1);
263 
264     for (int i = first_port_id; i <= last_port_id; ++i) {
265       port.port_id = i;
266       auto it = fanouts_.find(port);
267       if (it != fanouts_.end()) count += it->second.size();
268     }
269 
270     return count;
271   }
272 
273   // Gets all the edges in the immediate fanout of a node. Include the
274   // controlled edges iff include_controlled_edges is true.
GetFanoutEdges(const NodeDefT & node,bool include_controlled_edges)275   absl::flat_hash_set<Edge> GetFanoutEdges(
276       const NodeDefT& node, bool include_controlled_edges) const {
277     absl::flat_hash_set<Edge> result;
278 
279     OutputPort port;
280     port.node = const_cast<NodeDefT*>(&node);
281     const int first_port_id = include_controlled_edges ? -1 : 0;
282     const int last_port_id =
283         gtl::FindWithDefault(max_regular_output_port_, &node, -1);
284 
285     for (int i = first_port_id; i <= last_port_id; ++i) {
286       port.port_id = i;
287       auto it = fanouts_.find(port);
288       if (it != fanouts_.end()) {
289         for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) {
290           result.emplace(/*src=*/port, /*dst=*/*itr);
291         }
292       }
293     }
294     return result;
295   }
296 
297   // Gets all the edges in the immediate fanin of a node. Include the
298   // controlling edges iff include_controlling_edges is true.
GetFaninEdges(const NodeDefT & node,bool include_controlling_edges)299   absl::flat_hash_set<Edge> GetFaninEdges(
300       const NodeDefT& node, bool include_controlling_edges) const {
301     absl::flat_hash_set<Edge> result;
302     const int max_input_port =
303         include_controlling_edges
304             ? node.input_size() - 1
305             : gtl::FindWithDefault(max_regular_input_port_, &node, -1);
306     for (int i = 0; i <= max_input_port; ++i) {
307       TensorId tensor_id = ParseTensorName(node.input(i));
308 
309       auto it = nodes_.find(tensor_id.node());
310       if (it != nodes_.end()) {
311         result.emplace(/*src=*/OutputPort(it->second, tensor_id.index()),
312                        /*dst=*/InputPort(const_cast<NodeDefT*>(&node), i));
313       }
314     }
315     return result;
316   }
317 
318  protected:
  // Stores the graph pointer only; no nodes or fanouts are indexed here.
  // Derived classes populate the view via AddUniqueNode*/AddFanouts.
  explicit GraphViewInternal(GraphDefT* graph) : graph_(graph) {}
320 
AddUniqueNode(NodeDefT * node)321   Status AddUniqueNode(NodeDefT* node) {
322     auto inserted = nodes_.emplace(node->name(), node);
323     return inserted.second
324                ? Status::OK()
325                : errors::InvalidArgument("Non unique node name detected: ",
326                                          node->name());
327   }
328 
  // Same as AddUniqueNode, but a failure (duplicate node name) trips a CHECK
  // carrying the error message instead of being returned to the caller.
  // TODO(ezhulenev): Remove this function.
  void AddUniqueNodeOrDie(NodeDefT* node) {
    Status st = AddUniqueNode(node);
    CHECK(st.ok()) << st.error_message();
  }
334 
335   // TODO(lyandy): Checks for self loops, Switch control dependencies, fanins
336   // exist, and all regular fanins come before controlling fanins.
AddFanouts(NodeDefT * node)337   void AddFanouts(NodeDefT* node) {
338     int max_input_port = -1;
339     for (int i = 0; i < node->input_size(); ++i) {
340       TensorId tensor_id = ParseTensorName(node->input(i));
341       OutputPort output(nodes_[tensor_id.node()], tensor_id.index());
342 
343       if (output.port_id < 0) {
344         fanouts_[output].emplace(node, -1);
345       } else {
346         max_input_port = i;
347         max_regular_output_port_[output.node] =
348             std::max(max_regular_output_port_[output.node], output.port_id);
349         fanouts_[output].emplace(node, i);
350       }
351     }
352     if (max_input_port > -1) {
353       max_regular_input_port_[node] = max_input_port;
354     }
355   }
356 
  // Access to the mutable internal state for MutableGraphView.

  // Node name -> node lookup table.
  absl::flat_hash_map<absl::string_view, NodeDefT*>& nodes() { return nodes_; }

  // Output port -> set of input ports reading from it.
  absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>>& fanouts() {
    return fanouts_;
  }

  // Node -> maximum index of its regular input tensors.
  absl::flat_hash_map<const NodeDefT*, int>& max_regular_input_port() {
    return max_regular_input_port_;
  }

  // Node -> maximum index of its output tensors fetched by other nodes.
  absl::flat_hash_map<const NodeDefT*, int>& max_regular_output_port() {
    return max_regular_output_port_;
  }
371 
372  private:
  GraphDefT* graph_;  // must outlive the graph view

  // A mapping from the node name to the node itself. Keys are string views
  // and do not own the characters they refer to.
  absl::flat_hash_map<absl::string_view, NodeDefT*> nodes_;

  // A mapping from the output port to all inputs that read from it.
  absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>> fanouts_;

  // Keep a maximum index of input tensors of the node. Nodes without an entry
  // are treated as having -1 by the lookup functions.
  absl::flat_hash_map<const NodeDefT*, int> max_regular_input_port_;

  // Keep a maximum index of tensor fetched from the node. It doesn't guarantee
  // that all tensors in the [0, max_regular_output_port] range are actually
  // fetched by other nodes. Nodes without an entry are treated as having -1
  // by the lookup functions.
  absl::flat_hash_map<const NodeDefT*, int> max_regular_output_port_;

  // If the node has no fanouts at given output port (output tensor consumers)
  // we return a reference to this set from `GetFanout` (we can't construct new
  // empty set every time, because we need a non-dangling reference).
  absl::flat_hash_set<InputPort> fanout_not_found_value_;
393 };
394 
395 }  // namespace internal
396 
397 // Immutable GraphView that keeps the constness of the GraphDef. If you need to
398 // mutate the graph or the nodes via the graph view lookup functions, see
399 // MutableGraphView.
400 class GraphView
401     : public internal::GraphViewInternal<const GraphDef, const NodeDef> {
402  public:
GraphView(const GraphDef * graph)403   explicit GraphView(const GraphDef* graph) : GraphViewInternal(graph) {
404     for (const NodeDef& node : graph->node()) AddUniqueNodeOrDie(&node);
405     for (const NodeDef& node : graph->node()) AddFanouts(&node);
406   }
407 };
408 
// Returns true if node has one (or zero) fanout nodes at given output port.
// `port` defaults to output port 0.
bool HasSingleFanoutNode(const GraphView& graph_view, const NodeDef* node,
                         int port = 0);

// Returns true if node has at least one fanout node at given output port.
// `port` defaults to output port 0.
bool HasFanouts(const GraphView& graph_view, const NodeDef* node, int port = 0);
// Returns true if the node has at least one input control dependency.
bool HasControlFanin(const GraphView& graph_view, const NodeDef* node);
// Returns true if the node has at least one output control dependency.
bool HasControlFanout(const GraphView& graph_view, const NodeDef* node);
// Returns true if the node has at least one input or output control dependency.
bool HasControlFaninOrFanout(const GraphView& graph_view, const NodeDef* node);
421 
422 }  // end namespace grappler
423 }  // end namespace tensorflow
424 
425 #endif  // TENSORFLOW_CORE_GRAPPLER_GRAPH_VIEW_H_
426