1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_ 18 19 #include <map> 20 #include <memory> 21 #include <vector> 22 23 #include "tensorflow/core/framework/graph.pb.h" 24 #include "tensorflow/core/framework/node_def.pb.h" 25 #include "tensorflow/core/framework/op_def.pb.h" 26 #include "tensorflow/core/lib/core/status.h" 27 #include "tensorflow/core/protobuf/meta_graph.pb.h" 28 29 namespace tensorflow { 30 namespace grappler { 31 namespace graph_analyzer { 32 33 class GenNode; 34 35 // To find nodes by name. 36 using GenNodeMap = std::unordered_map<string, std::unique_ptr<GenNode>>; 37 38 // One node in the graph, in the form convenient for traversal and generation of 39 // subgraphs. It refers to the original NodeDef protobuf for most information 40 // and adds the extra enrichment. 41 // 42 // The graph building is 2-stage: first match a GenNode with each NodeDef and 43 // collect them into a map that finds them by name, then process the map, 44 // deep-parse the underlying NodeDefs and connect the GenNodes together. 45 class GenNode { 46 public: 47 // Will keep the pointer, so the underlying object must not be deleted while 48 // GenNode is alive. 49 explicit GenNode(const NodeDef* node); 50 51 // Access wrappers. name()52 const string& name() const { return node_->name(); } opcode()53 const string& opcode() const { return node_->op(); } node_def()54 const NodeDef* node_def() const { return node_; } 55 56 // Parse the inputs of this node and update the map accordingly, creating the 57 // links (i.e. edges, connections between nodes) in itself and in the nodes 58 // it's linked to (the map itself is unchanged, only the nodes in it are 59 // updated). 60 Status ParseInputs(const GenNodeMap* map); 61 62 // Does the full 2-stage build of the graph. The map should be initially 63 // empty. The map keeps pointers to the nodes in source, so the source must 64 // not be destroyed before the map. 65 static Status BuildGraphInMap(const GraphDef& source, GenNodeMap* map); 66 67 // The enrichment that constitutes the point of this class. 68 69 // Representation of a connection on a node. 70 class Port { 71 public: 72 // A port may be inbound or outbound. 73 // Negative ids (canonically -1) mean a control port. Port(bool inbound,int32_t id)74 Port(bool inbound, int32_t id) : value_(id << 1) { 75 if (inbound) { 76 value_ |= 1; 77 } 78 } 79 Port(const Port&) = default; 80 Port& operator=(const Port&) = default; 81 IsInbound()82 bool IsInbound() const { return (value_ & 0x1); } 83 IsControl()84 bool IsControl() const { return (value_ < 0); } 85 Id()86 int32_t Id() const { 87 // Arithmetic shift preserves the sign. 88 return (value_ >> 1); 89 } 90 91 // Integer type used to represent the encoded port value. 92 using IntPort = int32_t; 93 94 // Returns the encoded form of this port, so that it can be used 95 // as various map indexes. Encoded()96 IntPort Encoded() const { return value_; } 97 Decode(IntPort encoded)98 static Port Decode(IntPort encoded) { return Port(encoded); } 99 100 bool operator==(const Port& other) const { return value_ == other.value_; } 101 bool operator<(const Port& other) const { return value_ < other.value_; } 102 103 struct Hasher { operatorHasher104 size_t operator()(const Port& port) const noexcept { 105 return hasher(port.Encoded()); 106 } 107 std::hash<int32_t> hasher; 108 }; 109 110 // Convenient for printing. I've really wanted it to be implicit but 111 // ClangTidy insists on making it explicit. 112 explicit operator string() const; 113 114 private: Port(IntPort value)115 explicit Port(IntPort value) : value_(value) {} 116 117 IntPort value_; 118 }; 119 120 struct LinkTarget { 121 GenNode* node; // Node where this link points. 122 Port port; // Port on the remote side of this link. 123 LinkTargetLinkTarget124 LinkTarget(GenNode* a_node, Port a_port) : node(a_node), port(a_port) {} 125 }; 126 // All the links that are connected to the same port of this node 127 // are collected in one vector. A link is an edge of the graph that connects 128 // 2 nodes. Each of the connected nodes has its own perspective on the link, 129 // seeing its local port, remote port and the remote node. The direction of 130 // the link is encoded in the ports, one port is always incoming and another 131 // one outgoing. 132 using LinkTargetVector = std::vector<LinkTarget>; 133 // Both inputs and outputs are stored in the same map. 134 using LinkMap = std::unordered_map<Port, LinkTargetVector, Port::Hasher>; 135 136 // Access to the link map. links()137 const LinkMap& links() const { return links_; } 138 139 // Check whether the port is an input (including the controls) with multiple 140 // connections. Such inputs get handled in a special way when building the 141 // subgraphs, in an "all or nothing" fashion. 142 bool IsMultiInput(Port port) const; 143 144 // When building the subgraphs, must include either all non-control inputs of 145 // this node into the subgraph or none of them. This happens when at least one 146 // of the inputs is a multi-input (or if the opcode is commutative, thus 147 // treating all the inputs as one multi-input). AllInputsOrNone()148 bool AllInputsOrNone() const { return all_inputs_or_none_; } 149 150 private: 151 const NodeDef* node_; 152 // Becomes valid only after ParseInputs(). 153 const OpDef* op_; 154 155 // The opcode has a complicated structure of input args, with multi-input args 156 // that are not commutative. This means that to make sense, the subgraphs that 157 // include this node must also include either all its inputs or none of them. 158 bool all_inputs_or_none_ = false; 159 160 LinkMap links_; 161 }; 162 163 } // end namespace graph_analyzer 164 } // end namespace grappler 165 } // end namespace tensorflow 166 167 #endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_ 168