1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_
17 #define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_
18 
#include <deque>
#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
23 
24 namespace tensorflow {
25 
26 // Functionalize all the switch-merge nodes of a loop-free graph into If
27 // nodes. That is, attempt to transform every remaining switch and merge nodes
28 // in the graph into If nodes.
29 // Precondition: All while loops have been removed from graph.
30 Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library);
31 
32 // Internal functions/classes exposed for testing purposes.
33 namespace functionalize_cond {
34 
// The branch of a conditional that a node belongs to: the else branch, the
// then branch, both branches (as is the case for merge nodes), or neither.
//
// kElseBranch and kThenBranch must stay 0 and 1 respectively: they correspond
// to the two outputs of a Switch node and other code relies on that mapping.
// kBoth and kNeither carry no such meaning; they only need to be distinct.
enum class BranchType {
  kElseBranch = 0,  // Switch output 0.
  kThenBranch = 1,  // Switch output 1.
  kBoth,            // Implicitly 2.
  kNeither,         // Implicitly 3.
};
45 
46 // When we keep track of which switch/merge node's feed into a node, we record
47 // 1) predicate for non-dead switch node,
48 // 2) the switch node itself for dead switch node,
49 // 3) the merge node itself for merge node.
50 // Case 1) is an optimization. With this optimization, if there are nodes from
51 // different switch nodes but those switch nodes have the same predicate, the
52 // nodes will still have same AncestorState, and they will be clustered into a
53 // single "If".
54 struct AncestorNode {
55   enum class AncestorNodeType {
56     kPred = 0,
57     kSwitch = 1,
58     kMerge = 2,
59   };
60 
61   OutputTensor output_tensor;
62   AncestorNodeType type;
63 
64   // Compare two AncestorNodes by (node id, index, type).
65   bool operator<(const AncestorNode& other) const;
66   bool operator==(const AncestorNode& other) const;
67 
68   struct Hash {
69     size_t operator()(const AncestorNode&) const;
70   };
71 };
72 
73 // StateMap is responsible for mapping from each graph Node to
74 // * a CondState, where each CondState is a map from predicate to branch (i,e.,
75 //   what predicates have to hold or not hold).
76 // * a AncestorState, where each AncestorState is a set of switch/merge nodes
77 //   that are an ancestor of the node in the graph;
78 // For efficiency, this class interns the CondState (AncestorState), so that
79 // CondState (AncestorState) equality comparisons are simply pointer
80 // comparisons.
81 class StateMap {
82  public:
83   explicit StateMap(Graph* graph);
84 
85   // Compare two OutputTensors by (node id, index).
86   struct OutputTensorLess {
87     bool operator()(const OutputTensor& lhs, const OutputTensor& rhs) const;
88   };
89 
90   // A node in the graph is executed when multiple conditions hold. Keep track
91   // of the predicates that must hold for a node to execute.
92   using CondState = std::map<OutputTensor, BranchType, OutputTensorLess>;
93 
94   // Every unique ID is mapped to a CondState.
95   using CondId = const CondState*;
96 
97   // Keep track of which switch/merge node's feed into a node's values.
98   using AncestorState = std::set<AncestorNode>;
99 
100   // Every unique ID is mapped to a AncestorState.
101   using AncestorId = const AncestorState*;
102 
103   // Returns the CondId for a given node.
104   CondId LookupCondId(const Node* node) const;
105 
106   // Returns the unique CondId for CondState.
107   CondId GetCondId(const CondState& state);
108 
109   // Resets the CondId for a given node.
110   void ResetCondId(const Node* node, CondId id);
111 
112   // Returns the AncestorId for a given node.
113   AncestorId LookupAncestorId(const Node* node) const;
114 
115   // Returns the unique AncestorId for CondState.
116   AncestorId GetAncestorId(const AncestorState& state);
117 
118   // Resets the AncestorId for a given node.
119   void ResetAncestorId(const Node* node, AncestorId id);
120 
121   // Marks `node` as dead.
122   void MarkDead(const Node* node);
123 
124   // Determine branch execution of CondState.
125   BranchType FindBranchOf(CondId id, OutputTensor predicate) const;
126 
127   // Returns textual representation of node's CondState.
128   string CondStateToString(const Node* node) const;
129   string CondStateToString(CondId id) const;
130 
131   // Returns textual representation of node's AncestorState.
132   string AncestorStateToString(const Node* node) const;
133 
134   // Returns whether the cond state is the dead state.
135   bool IsDead(CondId id) const;
136 
137   // Returns whether the cond state is the empty state.
138   bool IsEmpty(CondId id) const;
139 
140  private:
141   // Hash for CondState and AncestorState.
142   struct Hash {
143     size_t operator()(const CondState& map) const;
144     size_t operator()(const AncestorState& map) const;
145   };
146 
147   // Set to keep track of unique CondStates.
148   // Pointers to the entries in the unordered set are used as identifiers:
149   // unordered_set guarantees that the pointers remain the same.
150   std::unordered_set<CondState, Hash> condstate_set_;
151 
152   // Mapping from Node id to CondId.
153   std::vector<CondId> node_to_condid_map_;
154 
155   // Track the CondId for newly inserted nodes. We use a vector to quickly map
156   // from Node id in the original graph to the CondId, but there will be nodes
157   // added to the original graph (such as If nodes) whose CondState needs to be
158   // tracked too.
159   std::unordered_map<int, CondId> added_node_condid_mapping_;
160 
161   // AncestorId variants of the CondId members.
162   std::unordered_set<AncestorState, Hash> ancestorstate_set_;
163   std::vector<AncestorId> node_to_ancestorid_map_;
164   std::unordered_map<int, AncestorId> added_node_ancestorid_mapping_;
165 
166   // Identifier of the dead flow state. The empty flow state is represented with
167   // a nullptr.
168   CondId dead_id_;
169 };
170 
171 // FunctionalizeCond groups all the state used by functionalizing conditionals
172 // of the given graph together.
173 class FunctionalizeCond {
174  public:
175   // Functionalize all the switch-merge nodes of a loop-free graph into If
176   // nodes. That is, attempt to transform every remaining switch and merge nodes
177   // in the graph into If nodes.
178   // Precondition: All while loops have been removed from graph.
179   static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library);
180 
181   // Build identity node with the same name as the merge that will be replaced
182   // in case the output is fetched/colocated.
183   Status AddIdentityNode(const Node* replacee, Node* if_node, int port);
184 
185   // Add a If node to the graph defined by def that will, amongst other, replace
186   // replacee in the graph.
187   xla::StatusOr<Node*> AddIfNode(const NodeDef& def, const Node* replacee,
188                                  const OutputTensor& predicate);
189 
190   // Propagates the state of a newly inserted node.
191   Status PropagateUpdatedState(const Node* replacee);
192 
193   // Dump graph with the CondState annotated.
194   void DumpGraphWithCondState(const string& name);
195 
196   // Adds `switch_id` to the list of Switch node ids.
197   void AddSwitchId(int switch_id);
198 
199  private:
200   FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library);
201 
202   // Performs the actual cond functionalization. Iterate over groups of merge
203   // nodes (linked by common predicates & ancestor IDs), from innermost to
204   // outermost, and extract into If nodes.
205   Status FunctionalizeInternal();
206 
207   // Returns the forward flow state propagated along edge `e`.
208   // This may modify state_map_.
209   StateMap::CondId StateAlongEdge(const Edge* e);
210 
211   // Determines the CondState and AncestorState of all the nodes in the given
212   // vector where the input is expected in reverse topological order.
213   // This populates the state_map_.
214   Status DetermineStates(std::vector<Node*> rev_topo_order);
215 
216   // Determine the CondState for a given node using the incomming edges
217   // to the node. Note: it is expected that this node's CondState is only
218   // determined once its input's CondState is.
DetermineCondState(Node * dst)219   Status DetermineCondState(Node* dst) {
220     if (IsMerge(dst)) return DetermineCondStateMerge(dst);
221     return DetermineCondStateNonMerge(dst);
222   }
223 
224   // Helper functions for DetermineCondState.
225   Status DetermineCondStateNonMerge(Node* dst);
226   Status DetermineCondStateMerge(Node* dst);
227 
228   // Determines the dst node's CondState by joining the src and dst's CondState
229   // where either the dst node is a merge or not.
230   // These may modify state_map_.
231   xla::StatusOr<StateMap::CondId> JoinCondStatesMerge(Node* merge,
232                                                       StateMap::CondId src,
233                                                       StateMap::CondId dst);
234   xla::StatusOr<StateMap::CondId> JoinCondStatesNonMerge(StateMap::CondId src,
235                                                          StateMap::CondId dst);
236 
237   // Determines which switch/merge nodes are ancestors of this node.
238   Status DetermineAncestorState(Node* dst);
239 
240   // Checks if a merge node is redundant and if so removes it from the graph.
241   Status RemoveRedundantMerge(Node* node);
242 
243   // Checks if a switch node is redundant and if so removes it from the graph.
244   Status RemoveRedundantSwitch(Node* node);
245 
246   // Sorts merge nodes (in reverse topological order) in order of increasing
247   // nesting depth.
248   void SortMergeNodes(std::vector<Node*>* merge_order);
249 
250   // Deletes all nodes in/consumers reachable from switch/merge nodes that were
251   // extracted.
252   void DeleteReachableAndDeadNodes(const std::vector<Node*>& merge_order);
253 
254   // Member used to unique the CondState to a unique CondId (AncestorState to a
255   // unique AncestorId) and keep track of CondState/CondId
256   // (AncestorState/AncestorId) per Node.
257   StateMap state_map_;
258 
259   // Mapping from merge nodes to predicate.
260   std::unordered_map<Node*, OutputTensor> merge_to_predicate_;
261 
262   // Mapping from merge nodes to corresponding If node outputs.
263   std::unordered_map<Node*, OutputTensor> merge_to_replacement_;
264 
265   FunctionLibraryDefinition* library_;
266   Graph* graph_;
267 
268   friend class FunctionalizeCondTest;
269 
270   std::vector<int> switch_ids_;
271 };
272 
273 }  // namespace functionalize_cond
274 
275 }  // namespace tensorflow
276 
277 #endif  // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_
278