1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_ 17 #define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_ 18 19 #include <deque> 20 #include "tensorflow/compiler/xla/status_macros.h" 21 #include "tensorflow/core/framework/function.h" 22 #include "tensorflow/core/graph/graph.h" 23 24 namespace tensorflow { 25 26 // Functionalize all the switch-merge nodes of a loop-free graph into If 27 // nodes. That is, attempt to transform every remaining switch and merge nodes 28 // in the graph into If nodes. 29 // Precondition: All while loops have been removed from graph. 30 Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library); 31 32 // Internal functions/classes exposed for testing purposes. 33 namespace functionalize_cond { 34 35 // All nodes are assumed to be either in no branch, then branch, else branch, 36 // or both branches (such as merge nodes). 37 // The code below relies on Else and Then being 0 and 1 (corresponding to the 38 // switch outputs). Both and Neither are arbitrary. 39 enum class BranchType { 40 kElseBranch = 0, 41 kThenBranch = 1, 42 kBoth = 2, 43 kNeither = 3, 44 }; 45 46 // When we keep track of which switch/merge node's feed into a node, we record 47 // 1) predicate for non-dead switch node, 48 // 2) the switch node itself for dead switch node, 49 // 3) the merge node itself for merge node. 50 // Case 1) is an optimization. With this optimization, if there are nodes from 51 // different switch nodes but those switch nodes have the same predicate, the 52 // nodes will still have same AncestorState, and they will be clustered into a 53 // single "If". 54 struct AncestorNode { 55 enum class AncestorNodeType { 56 kPred = 0, 57 kSwitch = 1, 58 kMerge = 2, 59 }; 60 61 OutputTensor output_tensor; 62 AncestorNodeType type; 63 64 // Compare two AncestorNodes by (node id, index, type). 65 bool operator<(const AncestorNode& other) const; 66 bool operator==(const AncestorNode& other) const; 67 68 struct Hash { 69 size_t operator()(const AncestorNode&) const; 70 }; 71 }; 72 73 // StateMap is responsible for mapping from each graph Node to 74 // * a CondState, where each CondState is a map from predicate to branch (i,e., 75 // what predicates have to hold or not hold). 76 // * a AncestorState, where each AncestorState is a set of switch/merge nodes 77 // that are an ancestor of the node in the graph; 78 // For efficiency, this class interns the CondState (AncestorState), so that 79 // CondState (AncestorState) equality comparisons are simply pointer 80 // comparisons. 81 class StateMap { 82 public: 83 explicit StateMap(Graph* graph); 84 85 // Compare two OutputTensors by (node id, index). 86 struct OutputTensorLess { 87 bool operator()(const OutputTensor& lhs, const OutputTensor& rhs) const; 88 }; 89 90 // A node in the graph is executed when multiple conditions hold. Keep track 91 // of the predicates that must hold for a node to execute. 92 using CondState = std::map<OutputTensor, BranchType, OutputTensorLess>; 93 94 // Every unique ID is mapped to a CondState. 95 using CondId = const CondState*; 96 97 // Keep track of which switch/merge node's feed into a node's values. 98 using AncestorState = std::set<AncestorNode>; 99 100 // Every unique ID is mapped to a AncestorState. 101 using AncestorId = const AncestorState*; 102 103 // Returns the CondId for a given node. 104 CondId LookupCondId(const Node* node) const; 105 106 // Returns the unique CondId for CondState. 107 CondId GetCondId(const CondState& state); 108 109 // Resets the CondId for a given node. 110 void ResetCondId(const Node* node, CondId id); 111 112 // Returns the AncestorId for a given node. 113 AncestorId LookupAncestorId(const Node* node) const; 114 115 // Returns the unique AncestorId for CondState. 116 AncestorId GetAncestorId(const AncestorState& state); 117 118 // Resets the AncestorId for a given node. 119 void ResetAncestorId(const Node* node, AncestorId id); 120 121 // Marks `node` as dead. 122 void MarkDead(const Node* node); 123 124 // Determine branch execution of CondState. 125 BranchType FindBranchOf(CondId id, OutputTensor predicate) const; 126 127 // Returns textual representation of node's CondState. 128 string CondStateToString(const Node* node) const; 129 string CondStateToString(CondId id) const; 130 131 // Returns textual representation of node's AncestorState. 132 string AncestorStateToString(const Node* node) const; 133 134 // Returns whether the cond state is the dead state. 135 bool IsDead(CondId id) const; 136 137 // Returns whether the cond state is the empty state. 138 bool IsEmpty(CondId id) const; 139 140 private: 141 // Hash for CondState and AncestorState. 142 struct Hash { 143 size_t operator()(const CondState& map) const; 144 size_t operator()(const AncestorState& map) const; 145 }; 146 147 // Set to keep track of unique CondStates. 148 // Pointers to the entries in the unordered set are used as identifiers: 149 // unordered_set guarantees that the pointers remain the same. 150 std::unordered_set<CondState, Hash> condstate_set_; 151 152 // Mapping from Node id to CondId. 153 std::vector<CondId> node_to_condid_map_; 154 155 // Track the CondId for newly inserted nodes. We use a vector to quickly map 156 // from Node id in the original graph to the CondId, but there will be nodes 157 // added to the original graph (such as If nodes) whose CondState needs to be 158 // tracked too. 159 std::unordered_map<int, CondId> added_node_condid_mapping_; 160 161 // AncestorId variants of the CondId members. 162 std::unordered_set<AncestorState, Hash> ancestorstate_set_; 163 std::vector<AncestorId> node_to_ancestorid_map_; 164 std::unordered_map<int, AncestorId> added_node_ancestorid_mapping_; 165 166 // Identifier of the dead flow state. The empty flow state is represented with 167 // a nullptr. 168 CondId dead_id_; 169 }; 170 171 // FunctionalizeCond groups all the state used by functionalizing conditionals 172 // of the given graph together. 173 class FunctionalizeCond { 174 public: 175 // Functionalize all the switch-merge nodes of a loop-free graph into If 176 // nodes. That is, attempt to transform every remaining switch and merge nodes 177 // in the graph into If nodes. 178 // Precondition: All while loops have been removed from graph. 179 static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library); 180 181 // Build identity node with the same name as the merge that will be replaced 182 // in case the output is fetched/colocated. 183 Status AddIdentityNode(const Node* replacee, Node* if_node, int port); 184 185 // Add a If node to the graph defined by def that will, amongst other, replace 186 // replacee in the graph. 187 xla::StatusOr<Node*> AddIfNode(const NodeDef& def, const Node* replacee, 188 const OutputTensor& predicate); 189 190 // Propagates the state of a newly inserted node. 191 Status PropagateUpdatedState(const Node* replacee); 192 193 // Dump graph with the CondState annotated. 194 void DumpGraphWithCondState(const string& name); 195 196 // Adds `switch_id` to the list of Switch node ids. 197 void AddSwitchId(int switch_id); 198 199 private: 200 FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library); 201 202 // Performs the actual cond functionalization. Iterate over groups of merge 203 // nodes (linked by common predicates & ancestor IDs), from innermost to 204 // outermost, and extract into If nodes. 205 Status FunctionalizeInternal(); 206 207 // Returns the forward flow state propagated along edge `e`. 208 // This may modify state_map_. 209 StateMap::CondId StateAlongEdge(const Edge* e); 210 211 // Determines the CondState and AncestorState of all the nodes in the given 212 // vector where the input is expected in reverse topological order. 213 // This populates the state_map_. 214 Status DetermineStates(std::vector<Node*> rev_topo_order); 215 216 // Determine the CondState for a given node using the incomming edges 217 // to the node. Note: it is expected that this node's CondState is only 218 // determined once its input's CondState is. DetermineCondState(Node * dst)219 Status DetermineCondState(Node* dst) { 220 if (IsMerge(dst)) return DetermineCondStateMerge(dst); 221 return DetermineCondStateNonMerge(dst); 222 } 223 224 // Helper functions for DetermineCondState. 225 Status DetermineCondStateNonMerge(Node* dst); 226 Status DetermineCondStateMerge(Node* dst); 227 228 // Determines the dst node's CondState by joining the src and dst's CondState 229 // where either the dst node is a merge or not. 230 // These may modify state_map_. 231 xla::StatusOr<StateMap::CondId> JoinCondStatesMerge(Node* merge, 232 StateMap::CondId src, 233 StateMap::CondId dst); 234 xla::StatusOr<StateMap::CondId> JoinCondStatesNonMerge(StateMap::CondId src, 235 StateMap::CondId dst); 236 237 // Determines which switch/merge nodes are ancestors of this node. 238 Status DetermineAncestorState(Node* dst); 239 240 // Checks if a merge node is redundant and if so removes it from the graph. 241 Status RemoveRedundantMerge(Node* node); 242 243 // Checks if a switch node is redundant and if so removes it from the graph. 244 Status RemoveRedundantSwitch(Node* node); 245 246 // Sorts merge nodes (in reverse topological order) in order of increasing 247 // nesting depth. 248 void SortMergeNodes(std::vector<Node*>* merge_order); 249 250 // Deletes all nodes in/consumers reachable from switch/merge nodes that were 251 // extracted. 252 void DeleteReachableAndDeadNodes(const std::vector<Node*>& merge_order); 253 254 // Member used to unique the CondState to a unique CondId (AncestorState to a 255 // unique AncestorId) and keep track of CondState/CondId 256 // (AncestorState/AncestorId) per Node. 257 StateMap state_map_; 258 259 // Mapping from merge nodes to predicate. 260 std::unordered_map<Node*, OutputTensor> merge_to_predicate_; 261 262 // Mapping from merge nodes to corresponding If node outputs. 263 std::unordered_map<Node*, OutputTensor> merge_to_replacement_; 264 265 FunctionLibraryDefinition* library_; 266 Graph* graph_; 267 268 friend class FunctionalizeCondTest; 269 270 std::vector<int> switch_ids_; 271 }; 272 273 } // namespace functionalize_cond 274 275 } // namespace tensorflow 276 277 #endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_ 278