1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_
17 #define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_
18 
19 #include <map>
20 #include <memory>
21 #include <unordered_map>
22 #include <vector>
23 
24 #include "tensorflow/core/framework/graph.pb.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/framework/op_def.pb.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/protobuf/meta_graph.pb.h"
29 
30 namespace tensorflow {
31 namespace grappler {
32 namespace graph_analyzer {
33 
// Forward declaration so that GenNodeMap can name GenNode before the class
// definition below.
class GenNode;

// To find nodes by name. The map owns its GenNodes through std::unique_ptr;
// GenNode::BuildGraphInMap() fills it and GenNode::ParseInputs() reads it to
// link the nodes together.
using GenNodeMap = std::unordered_map<string, std::unique_ptr<GenNode>>;
38 
39 // One node in the graph, in the form convenient for traversal and generation of
40 // subgraphs. It refers to the original NodeDef protobuf for most information
41 // and adds the extra enrichment.
42 //
43 // The graph building is 2-stage: first match a GenNode with each NodeDef and
44 // collect them into a map that finds them by name, then process the map,
45 // deep-parse the underlying NodeDefs and connect the GenNodes together.
46 class GenNode {
47  public:
48   // Will keep the pointer, so the underlying object must not be deleted while
49   // GenNode is alive.
50   explicit GenNode(const NodeDef* node);
51 
52   // Access wrappers.
name()53   const string& name() const { return node_->name(); }
opcode()54   const string& opcode() const { return node_->op(); }
node_def()55   const NodeDef* node_def() const { return node_; }
56 
57   // Parse the inputs of this node and update the map accordingly, creating the
58   // links (i.e. edges, connections between nodes) in itself and in the nodes
59   // it's linked to (the map itself is unchanged, only the nodes in it are
60   // updated).
61   Status ParseInputs(const GenNodeMap* map);
62 
63   // Does the full 2-stage build of the graph. The map should be initially
64   // empty. The map keeps pointers to the nodes in source, so the source must
65   // not be destroyed before the map.
66   static Status BuildGraphInMap(const GraphDef& source, GenNodeMap* map);
67 
68   // The enrichment that constitutes the point of this class.
69 
70   // Representation of a connection on a node.
71   class Port {
72    public:
73     // A port may be inbound or outbound.
74     // Negative ids (canonically -1) mean a control port.
Port(bool inbound,int32_t id)75     Port(bool inbound, int32_t id) : value_(id << 1) {
76       if (inbound) {
77         value_ |= 1;
78       }
79     }
80     Port(const Port&) = default;
81     Port& operator=(const Port&) = default;
82 
IsInbound()83     bool IsInbound() const { return (value_ & 0x1); }
84 
IsControl()85     bool IsControl() const { return (value_ < 0); }
86 
Id()87     int32_t Id() const {
88       // Arithmetic shift preserves the sign.
89       return (value_ >> 1);
90     }
91 
92     // Integer type used to represent the encoded port value.
93     using IntPort = int32_t;
94 
95     // Returns the encoded form of this port, so that it can be used
96     // as various map indexes.
Encoded()97     IntPort Encoded() const { return value_; }
98 
Decode(IntPort encoded)99     static Port Decode(IntPort encoded) { return Port(encoded); }
100 
101     bool operator==(const Port& other) const { return value_ == other.value_; }
102     bool operator<(const Port& other) const { return value_ < other.value_; }
103 
104     struct Hasher {
operatorHasher105       size_t operator()(const Port& port) const noexcept {
106         return hasher(port.Encoded());
107       }
108       std::hash<int32_t> hasher;
109     };
110 
111     // Convenient for printing. I've really wanted it to be implicit but
112     // ClangTidy insists on making it explicit.
113     explicit operator string() const;
114 
115    private:
Port(IntPort value)116     explicit Port(IntPort value) : value_(value) {}
117 
118     IntPort value_;
119   };
120 
121   struct LinkTarget {
122     GenNode* node;  // Node where this link points.
123     Port port;      // Port on the remote side of this link.
124 
LinkTargetLinkTarget125     LinkTarget(GenNode* a_node, Port a_port) : node(a_node), port(a_port) {}
126   };
127   // All the links that are connected to the same port of this node
128   // are collected in one vector. A link is an edge of the graph that connects
129   // 2 nodes. Each of the connected nodes has its own perspective on the link,
130   // seeing its local port, remote port and the remote node. The direction of
131   // the link is encoded in the ports, one port is always incoming and another
132   // one outgoing.
133   using LinkTargetVector = std::vector<LinkTarget>;
134   // Both inputs and outputs are stored in the same map.
135   using LinkMap = std::unordered_map<Port, LinkTargetVector, Port::Hasher>;
136 
137   // Access to the link map.
links()138   const LinkMap& links() const { return links_; }
139 
140   // Check whether the port is an input (including the controls) with multiple
141   // connections. Such inputs get handled in a special way when building the
142   // subgraphs, in an "all or nothing" fashion.
143   bool IsMultiInput(Port port) const;
144 
145   // When building the subgraphs, must include either all non-control inputs of
146   // this node into the subgraph or none of them. This happens when at least one
147   // of the inputs is a multi-input (or if the opcode is commutative, thus
148   // treating all the inputs as one multi-input).
AllInputsOrNone()149   bool AllInputsOrNone() const { return all_inputs_or_none_; }
150 
151  private:
152   const NodeDef* node_;
153   // Becomes valid only after ParseInputs().
154   const OpDef* op_;
155 
156   // The opcode has a complicated structure of input args, with multi-input args
157   // that are not commutative. This means that to make sense, the subgraphs that
158   // include this node must also include either all its inputs or none of them.
159   bool all_inputs_or_none_ = false;
160 
161   LinkMap links_;
162 };
163 
164 }  // end namespace graph_analyzer
165 }  // end namespace grappler
166 }  // end namespace tensorflow
167 
168 #endif  // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_
169