1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_
17 #define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_
18 
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
28 
29 namespace tensorflow {
30 namespace grappler {
31 namespace graph_analyzer {
32 
33 class GenNode;
34 
35 // To find nodes by name.
36 using GenNodeMap = std::unordered_map<string, std::unique_ptr<GenNode>>;
37 
38 // One node in the graph, in the form convenient for traversal and generation of
39 // subgraphs. It refers to the original NodeDef protobuf for most information
40 // and adds the extra enrichment.
41 //
42 // The graph building is 2-stage: first match a GenNode with each NodeDef and
43 // collect them into a map that finds them by name, then process the map,
44 // deep-parse the underlying NodeDefs and connect the GenNodes together.
45 class GenNode {
46  public:
47   // Will keep the pointer, so the underlying object must not be deleted while
48   // GenNode is alive.
49   explicit GenNode(const NodeDef* node);
50 
51   // Access wrappers.
name()52   const string& name() const { return node_->name(); }
opcode()53   const string& opcode() const { return node_->op(); }
node_def()54   const NodeDef* node_def() const { return node_; }
55 
56   // Parse the inputs of this node and update the map accordingly, creating the
57   // links (i.e. edges, connections between nodes) in itself and in the nodes
58   // it's linked to (the map itself is unchanged, only the nodes in it are
59   // updated).
60   Status ParseInputs(const GenNodeMap* map);
61 
62   // Does the full 2-stage build of the graph. The map should be initially
63   // empty. The map keeps pointers to the nodes in source, so the source must
64   // not be destroyed before the map.
65   static Status BuildGraphInMap(const GraphDef& source, GenNodeMap* map);
66 
67   // The enrichment that constitutes the point of this class.
68 
69   // Representation of a connection on a node.
70   class Port {
71    public:
72     // A port may be inbound or outbound.
73     // Negative ids (canonically -1) mean a control port.
Port(bool inbound,int32_t id)74     Port(bool inbound, int32_t id) : value_(id << 1) {
75       if (inbound) {
76         value_ |= 1;
77       }
78     }
79     Port(const Port&) = default;
80     Port& operator=(const Port&) = default;
81 
IsInbound()82     bool IsInbound() const { return (value_ & 0x1); }
83 
IsControl()84     bool IsControl() const { return (value_ < 0); }
85 
Id()86     int32_t Id() const {
87       // Arithmetic shift preserves the sign.
88       return (value_ >> 1);
89     }
90 
91     // Integer type used to represent the encoded port value.
92     using IntPort = int32_t;
93 
94     // Returns the encoded form of this port, so that it can be used
95     // as various map indexes.
Encoded()96     IntPort Encoded() const { return value_; }
97 
Decode(IntPort encoded)98     static Port Decode(IntPort encoded) { return Port(encoded); }
99 
100     bool operator==(const Port& other) const { return value_ == other.value_; }
101     bool operator<(const Port& other) const { return value_ < other.value_; }
102 
103     struct Hasher {
operatorHasher104       size_t operator()(const Port& port) const noexcept {
105         return hasher(port.Encoded());
106       }
107       std::hash<int32_t> hasher;
108     };
109 
110     // Convenient for printing. I've really wanted it to be implicit but
111     // ClangTidy insists on making it explicit.
112     explicit operator string() const;
113 
114    private:
Port(IntPort value)115     explicit Port(IntPort value) : value_(value) {}
116 
117     IntPort value_;
118   };
119 
120   struct LinkTarget {
121     GenNode* node;  // Node where this link points.
122     Port port;      // Port on the remote side of this link.
123 
LinkTargetLinkTarget124     LinkTarget(GenNode* a_node, Port a_port) : node(a_node), port(a_port) {}
125   };
126   // All the links that are connected to the same port of this node
127   // are collected in one vector. A link is an edge of the graph that connects
128   // 2 nodes. Each of the connected nodes has its own perspective on the link,
129   // seeing its local port, remote port and the remote node. The direction of
130   // the link is encoded in the ports, one port is always incoming and another
131   // one outgoing.
132   using LinkTargetVector = std::vector<LinkTarget>;
133   // Both inputs and outputs are stored in the same map.
134   using LinkMap = std::unordered_map<Port, LinkTargetVector, Port::Hasher>;
135 
136   // Access to the link map.
links()137   const LinkMap& links() const { return links_; }
138 
139   // Check whether the port is an input (including the controls) with multiple
140   // connections. Such inputs get handled in a special way when building the
141   // subgraphs, in an "all or nothing" fashion.
142   bool IsMultiInput(Port port) const;
143 
144   // When building the subgraphs, must include either all non-control inputs of
145   // this node into the subgraph or none of them. This happens when at least one
146   // of the inputs is a multi-input (or if the opcode is commutative, thus
147   // treating all the inputs as one multi-input).
AllInputsOrNone()148   bool AllInputsOrNone() const { return all_inputs_or_none_; }
149 
150  private:
151   const NodeDef* node_;
152   // Becomes valid only after ParseInputs().
153   const OpDef* op_;
154 
155   // The opcode has a complicated structure of input args, with multi-input args
156   // that are not commutative. This means that to make sense, the subgraphs that
157   // include this node must also include either all its inputs or none of them.
158   bool all_inputs_or_none_ = false;
159 
160   LinkMap links_;
161 };
162 
163 }  // end namespace graph_analyzer
164 }  // end namespace grappler
165 }  // end namespace tensorflow
166 
167 #endif  // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_
168