1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/functionalize_cond.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <stack>
21 #include <unordered_set>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_join.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/compiler/jit/union_find.h"
28 #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
29 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
30 #include "tensorflow/core/common_runtime/function.h"
31 #include "tensorflow/core/framework/graph_to_functiondef.h"
32 #include "tensorflow/core/framework/node_def_builder.h"
33 #include "tensorflow/core/graph/algorithm.h"
34 #include "tensorflow/core/graph/control_flow.h"
35 #include "tensorflow/core/graph/node_builder.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/lib/hash/hash.h"
38 #include "tensorflow/core/lib/strings/strcat.h"
39 #include "tensorflow/core/util/dump_graph.h"
40 
41 using xla::StatusOr;
42 
43 namespace tensorflow {
44 namespace functionalize_cond {
45 
operator <(const AncestorNode & other) const46 bool AncestorNode::operator<(const AncestorNode& other) const {
47   return (output_tensor.node->id() < other.output_tensor.node->id()) ||
48          (output_tensor.node->id() == other.output_tensor.node->id() &&
49           output_tensor.index < other.output_tensor.index) ||
50          (output_tensor.node->id() == other.output_tensor.node->id() &&
51           output_tensor.index == other.output_tensor.index &&
52           type < other.type);
53 }
54 
operator ==(const AncestorNode & other) const55 bool AncestorNode::operator==(const AncestorNode& other) const {
56   return output_tensor.node->id() == other.output_tensor.node->id() &&
57          output_tensor.index == other.output_tensor.index && type == other.type;
58 }
59 
operator ()(const AncestorNode & ancestor) const60 size_t AncestorNode::Hash::operator()(const AncestorNode& ancestor) const {
61   size_t h = std::hash<int>()(ancestor.output_tensor.node->id());
62   h = Hash64Combine(h, std::hash<int>()(ancestor.output_tensor.index));
63   return Hash64Combine(h, std::hash<int>()(static_cast<int>(ancestor.type)));
64 }
65 
66 typedef std::tuple<StateMap::CondId, StateMap::AncestorId, OutputTensor>
67     ClusterTuple;
68 
69 struct ClusterTupleLessThan {
operator ()tensorflow::functionalize_cond::ClusterTupleLessThan70   bool operator()(const ClusterTuple& a, const ClusterTuple& b) const {
71     if (std::tie(std::get<0>(a), std::get<1>(a)) <
72         std::tie(std::get<0>(b), std::get<1>(b))) {
73       return true;
74     } else if (std::tie(std::get<0>(a), std::get<1>(a)) ==
75                std::tie(std::get<0>(b), std::get<1>(b))) {
76       return StateMap::OutputTensorLess()(std::get<2>(a), std::get<2>(b));
77     } else {
78       return false;
79     }
80   }
81 };
82 
83 // TODO(jpienaar): Move to OutputTensor.
DebugString(const OutputTensor & tensor)84 string DebugString(const OutputTensor& tensor) {
85   return absl::StrCat(tensor.node->name(), ":", tensor.index);
86 }
87 
Branch_Name(BranchType b)88 string Branch_Name(BranchType b) {
89   switch (b) {
90     case BranchType::kElseBranch:
91       return "else";
92     case BranchType::kThenBranch:
93       return "then";
94     case BranchType::kBoth:
95       return "both";
96     case BranchType::kNeither:
97       return "neither";
98   }
99 }
100 
DebugString(StateMap::CondId cond_state)101 string DebugString(StateMap::CondId cond_state) {
102   if (cond_state == nullptr || cond_state->empty()) return "{}";
103   using value_type = StateMap::CondState::value_type;
104   return absl::StrCat(
105       "{",
106       absl::StrJoin(*cond_state, ", ",
107                     [](string* output, const value_type& pred_branch) {
108                       const OutputTensor& pred = pred_branch.first;
109                       const BranchType& branch = pred_branch.second;
110                       if (branch == BranchType::kNeither)
111                         absl::StrAppend(output, "d");
112                       else
113                         absl::StrAppend(output, "s(", DebugString(pred), ",",
114                                         Branch_Name(branch), ")");
115                     }),
116       "}");
117 }
118 
119 // Returns the predicate of a switch.
GetSwitchPredicate(const Node & switch_node,OutputTensor * pred)120 Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) {
121   const Edge* pred_edge;
122   TF_RETURN_IF_ERROR(switch_node.input_edge(1, &pred_edge));
123   // The predicate can be preceded by a identity node. Look through
124   // identity nodes to predicate.
125   while (pred_edge->src()->IsIdentity()) {
126     TF_RETURN_IF_ERROR(pred_edge->src()->input_edge(0, &pred_edge));
127   }
128   *pred = OutputTensor(pred_edge->src(), pred_edge->src_output());
129   return Status::OK();
130 }
131 
GetSwitchValue(const Node & switch_node,OutputTensor * val)132 Status GetSwitchValue(const Node& switch_node, OutputTensor* val) {
133   const Edge* val_edge;
134   TF_RETURN_IF_ERROR(switch_node.input_edge(0, &val_edge));
135   *val = OutputTensor(val_edge->src(), val_edge->src_output());
136   return Status::OK();
137 }
138 
operator ()(const OutputTensor & lhs,const OutputTensor & rhs) const139 bool StateMap::OutputTensorLess::operator()(const OutputTensor& lhs,
140                                             const OutputTensor& rhs) const {
141   return (lhs.node->id() < rhs.node->id()) ||
142          (lhs.node->id() == rhs.node->id() && lhs.index < rhs.index);
143 }
144 
145 struct CondStateLess {
operator ()tensorflow::functionalize_cond::CondStateLess146   bool operator()(const StateMap::CondState::value_type& lhs,
147                   const StateMap::CondState::value_type& rhs) const {
148     if (StateMap::OutputTensorLess().operator()(lhs.first, rhs.first))
149       return true;
150     if (lhs.first.node->id() == rhs.first.node->id() &&
151         lhs.first.index == rhs.first.index)
152       return lhs.second < rhs.second;
153     return false;
154   }
155 };
156 
StateMap(Graph * graph)157 StateMap::StateMap(Graph* graph) {
158   node_to_condid_map_.resize(graph->num_node_ids());
159   node_to_ancestorid_map_.resize(graph->num_node_ids());
160   // Initialize the dead state (empty state is designated with a nullptr).
161   dead_id_ = GetCondId(
162       {std::make_pair(OutputTensor(nullptr, -1), BranchType::kNeither)});
163 }
164 
IsDead(StateMap::CondId id) const165 bool StateMap::IsDead(StateMap::CondId id) const { return id == dead_id_; }
166 
IsEmpty(StateMap::CondId id) const167 bool StateMap::IsEmpty(StateMap::CondId id) const { return id == nullptr; }
168 
// Hashes a CondState by folding the hash of each (predicate tensor, branch)
// entry; iteration order is the container's, so equal states hash equally.
size_t StateMap::Hash::operator()(const StateMap::CondState& map) const {
  if (map.empty()) return 0;
  // Compute hash of the front element.
  auto it = map.begin();
  size_t h = Hash64Combine(OutputTensor::Hash()(it->first),
                           hash<BranchType>()(it->second));
  for (++it; it != map.end(); ++it) {
    // Combine the hash with the remaining elements in the map.
    h = Hash64Combine(h, Hash64Combine(OutputTensor::Hash()(it->first),
                                       hash<BranchType>()(it->second)));
  }
  return h;
}
182 
// Hashes an AncestorState by folding the hash of each AncestorNode.
size_t StateMap::Hash::operator()(const StateMap::AncestorState& map) const {
  if (map.empty()) return 0;
  // Compute hash of the front element.
  auto it = map.begin();
  size_t h = AncestorNode::Hash()(*it);
  for (++it; it != map.end(); ++it) {
    // Combine the hash with the remaining elements in the map.
    h = Hash64Combine(h, AncestorNode::Hash()(*it));
  }
  return h;
}
194 
// CondArgNode represents an input to the conditional and its corresponding
// switch nodes.
struct CondArgNode {
  // `src`:`src_output` is the tensor entering the conditional.
  explicit CondArgNode(Node* src, int src_output)
      : src(src), src_output(src_output) {}

  // Human-readable description for logging/debugging.
  string ToString() const {
    return absl::StrCat("src=", src->name(), ":", src_output,
                        " switches=", NodesToString(switches));
  }

  // Source node and output index of the value feeding the switches.
  Node* src;
  int src_output;
  // Per-branch _Arg node copies, indexed by static_cast<int>(BranchType).
  std::array<Node*, 2> branch_copy;
  // All switch nodes fed by src:src_output.
  std::vector<Node*> switches;
};
using CondArgNodes = std::vector<CondArgNode>;
212 
DebugString(const CondArgNodes & nodes)213 string DebugString(const CondArgNodes& nodes) {
214   return absl::StrCat(
215       "[",
216       absl::StrJoin(nodes, ", ",
217                     [](string* output, const CondArgNode& node) {
218                       absl::StrAppend(output, node.ToString());
219                     }),
220       "]");
221 }
222 
LookupCondId(const Node * node) const223 StateMap::CondId StateMap::LookupCondId(const Node* node) const {
224   if (node->id() < node_to_condid_map_.size())
225     return node_to_condid_map_[node->id()];
226   return added_node_condid_mapping_.at(node->id());
227 }
228 
StateMap::CondId StateMap::GetCondId(const StateMap::CondState& state) {
  // The empty state is canonically nullptr (see IsEmpty).
  if (state.empty()) return nullptr;
  // Intern the state and hand back a pointer to the canonical copy.
  auto insert_result = condstate_set_.insert(state);
  return &*insert_result.first;
}
233 
ResetCondId(const Node * node,StateMap::CondId id)234 void StateMap::ResetCondId(const Node* node, StateMap::CondId id) {
235   if (node->id() < node_to_condid_map_.size())
236     node_to_condid_map_[node->id()] = id;
237   else
238     added_node_condid_mapping_[node->id()] = id;
239 }
240 
LookupAncestorId(const Node * node) const241 StateMap::AncestorId StateMap::LookupAncestorId(const Node* node) const {
242   if (node->id() < node_to_ancestorid_map_.size())
243     return node_to_ancestorid_map_[node->id()];
244   return added_node_ancestorid_mapping_.at(node->id());
245 }
246 
StateMap::AncestorId StateMap::GetAncestorId(
    const StateMap::AncestorState& state) {
  // The empty state is canonically nullptr.
  if (state.empty()) return nullptr;
  // Intern the state and return a pointer to the canonical copy.
  auto insert_result = ancestorstate_set_.insert(state);
  return &*insert_result.first;
}
252 
ResetAncestorId(const Node * node,StateMap::AncestorId id)253 void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) {
254   if (node->id() < node_to_ancestorid_map_.size())
255     node_to_ancestorid_map_[node->id()] = id;
256   else
257     added_node_ancestorid_mapping_[node->id()] = id;
258 }
259 
MarkDead(const Node * node)260 void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); }
261 
CondStateToString(const Node * node) const262 string StateMap::CondStateToString(const Node* node) const {
263   return CondStateToString(LookupCondId(node));
264 }
265 
// Delegates to the free DebugString helper for CondIds.
string StateMap::CondStateToString(StateMap::CondId id) const {
  return DebugString(id);
}
269 
AncestorStateToString(const Node * node) const270 string StateMap::AncestorStateToString(const Node* node) const {
271   if (auto id = LookupAncestorId(node)) {
272     return absl::StrCat(
273         "{",
274         absl::StrJoin(*id, ",",
275                       [](string* output, const AncestorNode& ancestor) {
276                         absl::StrAppend(output,
277                                         ancestor.output_tensor.node->name(),
278                                         ":", ancestor.output_tensor.index);
279                       }),
280         "}");
281   }
282   return "{}";
283 }
284 
// Constructs the pass over `graph`; extracted branch functions will be
// registered into `library`.
FunctionalizeCond::FunctionalizeCond(Graph* graph,
                                     FunctionLibraryDefinition* library)
    : state_map_(graph), library_(library), graph_(graph) {}
288 
// Class representing the merge/switch nodes that will become a conditional.
class Conditional {
 public:
  Conditional(OutputTensor predicate, FunctionalizeCond* parent,
              StateMap* cond_state_map);

  // Adds merge node that is part of this conditional.
  Status AddMerge(Node* m);

  // Constructs an If node from the merge nodes.
  Status BuildAndReplace(
      Graph* graph, FunctionLibraryDefinition* library,
      std::unordered_map<Node*, OutputTensor>* merge_to_replacement);

 private:
  // Extracts the then/else bodies: creates new graphs with the nodes
  // corresponding to the nodes in the then/else branches as of this conditional
  // as function bodies.
  Status ExtractBodies(Graph* graph);

  // Builds the arguments that are the input to the If.
  Status BuildArgumentNodes();

  // Builds the If node for the extracted bodies with the given predicate.
  Status BuildIfNode(Graph* graph, FunctionLibraryDefinition* library);

  // Adds input edges to If node.
  Status AddInputEdges(
      Graph* graph,
      const std::unordered_map<Node*, OutputTensor>& merge_to_replacement);

  // Adds output edges from If node.
  // Record new output tensor for all Merge nodes in 'merge_to_replacement'.
  Status AddOutputEdges(
      Graph* graph,
      std::unordered_map<Node*, OutputTensor>* merge_to_replacement);

  // Adds switch node that is part of this conditional.
  Status AddSwitch(Node* s);

  // Adds a switch node along the edge and rewire the edge to go via the switch.
  Status AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
                                Graph* graph);

  // Internal name of conditional. The name is based on the first merge node
  // added.
  string name() const;

  // The FunctionalizeCond instance that created this.
  FunctionalizeCond* parent_;

  // Mapping between nodes and their cond state.
  StateMap* state_map_;

  // The predicate of the conditional.
  OutputTensor predicate_;

  // The predicate of the switches of the conditional. This may be different
  // than predicate (which is initialized from the original graph) as the
  // predicate could be the output of a newly created If node.
  OutputTensor switch_predicate_;

  // Switch nodes in graph that are part of this conditional.
  std::set<Node*, NodeCmpByNameResourcesLast> switches_;

  // Merge nodes in graph that are part of this conditional.
  std::set<Node*, NodeCmpByNameResourcesLast> merges_;

  // Vector of control inputs from outside the conditional to a node inside.
  std::vector<Node*> external_control_inputs_;
  std::vector<Node*> external_control_outputs_;

  // Graphs corresponding to the then and else branch
  // (indexed by static_cast<int>(BranchType)).
  std::array<std::unique_ptr<Graph>, 2> bodies_;

  // Maps from graph_ to the branch body's graph, per branch: node id in the
  // original graph -> copied node in the body (nullptr if not copied).
  std::array<std::vector<Node*>, 2> node_maps_;

  // The argument nodes created for the switches.
  CondArgNodes cond_arg_nodes_;

  // The constructed If node.
  Node* if_node_ = nullptr;

  // Whether the merge nodes of this conditional have been replaced.
  bool replaced_ = false;
};
376 
// Note: `switch_predicate_` is left unset here; it is established by the
// first AddSwitch call.
Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent,
                         StateMap* cond_state_map)
    : parent_(parent), state_map_(cond_state_map), predicate_(predicate) {}
380 
// Registers a merge node as part of this conditional.
Status Conditional::AddMerge(Node* m) {
  merges_.insert(m);
  return Status::OK();
}
385 
AddSwitch(Node * s)386 Status Conditional::AddSwitch(Node* s) {
387   VLOG(5) << "Adding switch " << s->DebugString();
388   OutputTensor predicate;
389   TF_RETURN_IF_ERROR(GetSwitchPredicate(*s, &predicate));
390   if (switch_predicate_.node == nullptr) switch_predicate_ = predicate;
391   if (!(switch_predicate_ == predicate)) {
392     return errors::InvalidArgument(
393         "Merge nodes ", NodesToString(merges_),
394         " directly dominated by switch nodes with different predicates (",
395         DebugString(switch_predicate_), " vs ", DebugString(predicate), ").");
396   }
397   switches_.insert(s);
398   parent_->AddSwitchId(s->id());
399   return Status::OK();
400 }
401 
BuildArgumentNodes()402 Status Conditional::BuildArgumentNodes() {
403   VLOG(1) << "Build function arguments";
404   struct Hash {
405     size_t operator()(const std::pair<Node*, int>& item) const {
406       return Hash64Combine(hash<Node*>()(item.first),
407                            std::hash<int>()(item.second));
408     }
409   };
410 
411   std::unordered_map<std::pair<Node*, int>, int, Hash> input_index;
412   for (Node* switch_node : switches_) {
413     const Edge* e;
414     TF_RETURN_IF_ERROR(switch_node->input_edge(0, &e));
415     std::pair<Node*, int> key = std::make_pair(e->src(), e->src_output());
416     if (input_index.find(key) == input_index.end()) {
417       input_index[key] = cond_arg_nodes_.size();
418       cond_arg_nodes_.emplace_back(key.first, key.second);
419     }
420     cond_arg_nodes_.at(input_index.at(key)).switches.push_back(switch_node);
421   }
422   VLOG(5) << "CondArg nodes created: " << DebugString(cond_arg_nodes_);
423 
424   int arg_count = 0;
425   for (CondArgNode& cond_arg_node : cond_arg_nodes_) {
426     DataType dtype = cond_arg_node.src->output_type(cond_arg_node.src_output);
427     for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
428       int branch_index = static_cast<int>(branch);
429       TF_RETURN_IF_ERROR(
430           NodeBuilder(absl::StrCat("_Arg", arg_count),
431                       FunctionLibraryDefinition::kArgOp)
432               .Attr("T", dtype)
433               .Attr("index", arg_count)
434               .Finalize(bodies_[branch_index].get(),
435                         &cond_arg_node.branch_copy[branch_index]));
436     }
437     for (Node* node : cond_arg_node.switches) {
438       for (const Edge* e : node->out_edges()) {
439         if (e->IsControlEdge()) continue;
440         int branch_index = e->src_output();
441         Node* src_copy = cond_arg_node.branch_copy[branch_index];
442         Node* dst_copy = node_maps_[branch_index][e->dst()->id()];
443 
444         // The graph may contain dead switch nodes,
445         if (dst_copy == nullptr) continue;
446 
447         TF_RET_CHECK(dst_copy != nullptr)
448             << "Unable to find copied node for " << e->dst()->DebugString()
449             << " on branch " << Branch_Name(BranchType(branch_index));
450         // If the input goes directly to a merge then the merge has
451         // been replaced by a retval so the dst input is 0 instead of
452         // dst_input.
453         int dst_input = IsMerge(e->dst()) ? 0 : e->dst_input();
454         bodies_[branch_index]->AddEdge(src_copy, 0, dst_copy, dst_input);
455       }
456     }
457     ++arg_count;
458   }
459 
460   // Verify that all retvals have an input.
461   // TODO(jpienaar): One could add a ZerosLike in the branch that doesn't have
462   // input.
463   for (Node* m : merges_) {
464     for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
465       bool has_input = false;
466       for (auto e : node_maps_[static_cast<int>(branch)][m->id()]->in_edges()) {
467         if (!e->IsControlEdge()) {
468           has_input = true;
469           break;
470         }
471       }
472       if (!has_input) {
473         return errors::Internal(
474             "Failed to functionalize control flow with merge ",
475             FormatNodeForError(*m), " that doesn't have input on ",
476             Branch_Name(branch), " branch.");
477       }
478     }
479   }
480 
481   return Status::OK();
482 }
483 
// Inserts a Switch node (driven by this conditional's predicate) along
// `edge`, rewiring the edge so the value flows through the new switch.
Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
                                           Graph* graph) {
  // Previously we had edge:
  //   src:src_output ---- edge ----> dst:dst_input
  // post this we have (in graph)
  //   src:src_output --> switch<pred> --- new_edge --> dst:dst_input

  // TODO(jpienaar): One could keep a map caching the extra switch nodes added
  // to avoid adding another switch to feed a value for which a switch was
  // already added.
  Node* switch_node;
  Node* src = edge->src();
  int src_output = edge->src_output();
  TF_RETURN_IF_ERROR(
      NodeBuilder(graph->NewName(absl::StrCat(src->name(), "_added_switch")),
                  "Switch")
          .Input(src, src_output)
          .Input(const_cast<Node*>(predicate_.node), predicate_.index)
          .Finalize(graph, &switch_node));
  // The new switch inherits the control context of its data input.
  state_map_->ResetCondId(switch_node, state_map_->LookupCondId(src));
  state_map_->ResetAncestorId(switch_node, state_map_->LookupAncestorId(src));

  // Read dst/dst_input before RemoveEdge invalidates `edge`.
  Node* dst = edge->dst();
  int dst_input = edge->dst_input();
  graph->RemoveEdge(edge);
  // Switch output port number encodes the branch taken.
  graph->AddEdge(switch_node, static_cast<int>(branch), dst, dst_input);
  return AddSwitch(switch_node);
}
512 
// Extracts the else/then branch bodies into `bodies_`: walks backwards from
// the merge nodes' inputs, copying reachable nodes into per-branch graphs,
// and records external control inputs/outputs crossing the conditional's
// boundary.
Status Conditional::ExtractBodies(Graph* graph) {
  VLOG(2) << "Extracting bodies for " << name();
  for (auto b : {BranchType::kElseBranch, BranchType::kThenBranch}) {
    bodies_[static_cast<int>(b)] =
        absl::make_unique<Graph>(graph->op_registry());
  }

  // Returns which branch a merge in-edge belongs to: for a switch the output
  // port encodes the branch; otherwise derive it from the src's CondState
  // relative to this conditional's predicate.
  auto find_branch = [&](const Edge* e) {
    const auto& id = state_map_->LookupCondId(e->src());
    return IsSwitch(e->src()) ? BranchType(e->src_output())
                              : state_map_->FindBranchOf(id, predicate_);
  };

  // Per-branch DFS worklists seeded with the merge nodes' data inputs.
  std::array<std::vector<Node*>, 2> stacks;
  VLOG(5) << "Merges: " << NodesToString(merges_);
  for (Node* m : merges_) {
    VLOG(5) << "For merge: " << m->DebugString() << " "
            << state_map_->CondStateToString(m);
    for (auto e : m->in_edges()) {
      if (e->IsControlEdge()) continue;
      BranchType branch = find_branch(e);
      TF_RET_CHECK(branch == BranchType::kThenBranch ||
                   branch == BranchType::kElseBranch)
          << "Error: " << e->src()->name()
          << " is not on either then or else branch (" << Branch_Name(branch)
          << ") for predicate " << DebugString(predicate_) << " ["
          << DebugString(state_map_->LookupCondId(e->src())) << "].";
      Node* src = e->src();
      if (IsSwitch(src)) {
        // Switch node outputs and dependencies are handled separately.
        TF_RETURN_IF_ERROR(AddSwitch(src));
      } else {
        stacks[static_cast<int>(branch)].push_back(src);
      }
    }
  }

  for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
    int branch_index = static_cast<int>(branch);
    auto output = bodies_[branch_index].get();
    auto& stack = stacks[branch_index];
    VLOG(5) << "In branch: " << Branch_Name(branch) << " "
            << NodesToString(stack);
    std::vector<bool> visited(graph->num_node_ids(), false);
    node_maps_[branch_index].resize(graph->num_node_ids(), nullptr);
    auto& node_map = node_maps_[branch_index];

    // Depth-first walk backwards from the merge inputs, copying each visited
    // node into the branch body.
    while (!stack.empty()) {
      Node* n = stack.back();
      stack.pop_back();

      if (visited.at(n->id())) continue;
      visited[n->id()] = true;

      // Verify output edges and record control edges exiting scope.
      for (const Edge* e : n->out_edges()) {
        Node* dst = e->dst();
        if (IsMerge(dst)) continue;
        Node* src = e->src();

        auto dst_id = state_map_->LookupCondId(dst);
        auto src_id = state_map_->LookupCondId(src);
        if (dst_id != src_id) {
          if (e->IsControlEdge()) {
            external_control_outputs_.push_back(e->src());
          } else {
            // Constants are treated specially to workaround the case of
            // non-dominated constant nodes.
            if (!IsConstant(src)) {
              // TODO(b/78882471): A node that feeds into two different
              // CondState is not necessarily an error so log a warning for now
              // but revisit to improve the testing to enable making this an
              // error.
              LOG(WARNING) << errors::InvalidArgument(
                  "Graph contains node ", FormatNodeForError(*src),
                  " that feeds into node ", FormatNodeForError(*dst),
                  " but these nodes are in different control contexts (",
                  DebugString(src_id), " vs ", DebugString(dst_id),
                  " (detected during out edge testing)");
            }
          }
        }
      }

      // Copying incoming edges to dst node. Iterate over a copy of the edges
      // as they could be mutated during iteration.
      std::vector<const Edge*> in_edges(n->in_edges().begin(),
                                        n->in_edges().end());
      for (const Edge* e : in_edges) {
        Node* src = e->src();
        // Skip src/dst node.
        if (!src->IsOp()) continue;

        Node* dst = e->dst();
        if (IsSwitch(src)) {
          // Switch node outputs and dependencies are handled separately.
          TF_RETURN_IF_ERROR(AddSwitch(src));
          continue;
        }

        // Verify input is from the same context.
        auto src_id = state_map_->LookupCondId(src);
        auto dst_id = state_map_->LookupCondId(dst);
        if (IsMerge(dst) || src_id == dst_id) {
          // TODO(jpienaar): The merge case can be more strict.
          if (node_map.at(src->id()) == nullptr) {
            node_map.at(src->id()) = output->CopyNode(src);
            stack.push_back(src);
          }
        } else if (e->IsControlEdge()) {
          // Here we have a control flow edge between src and dst that are not
          // in the same context. This is an external control dependency except
          // for one case: where the only difference between CondId of e->src()
          // and CondId of e->dst() is that e->src() has {PRED, kNeither} and
          // e->dst() has {PRED, kThenBranch/kElseBranch}. This happens in
          // gradients code for tf.cond(), where e->src() is a control pivot
          // node for a branch and e->dst() is a data node in that branch.
          bool is_external_control_input = true;
          if (!state_map_->IsEmpty(src_id) && !state_map_->IsEmpty(dst_id)) {
            std::vector<StateMap::CondState::value_type> diff;
            std::set_symmetric_difference(
                src_id->begin(), src_id->end(), dst_id->begin(), dst_id->end(),
                std::back_inserter(diff), CondStateLess());
            if (diff.size() == 2 && diff[0].first == diff[1].first &&
                (diff[0].second == BranchType::kNeither ||
                 diff[1].second == BranchType::kNeither)) {
              auto src_branch = src_id->find(diff[0].first);
              if (src_branch != src_id->end() &&
                  src_branch->second == BranchType::kNeither) {
                is_external_control_input = false;
              }
            }
          }
          if (is_external_control_input) {
            external_control_inputs_.push_back(src);
          }
        } else {
          // This shouldn't happen, this means we have an external data input
          // not entering via a switch node. Work around this by for
          // * constant nodes copy them;
          // * non-constant nodes, insert a switch along the edge;
          if (IsConstant(src)) {
            node_map.at(src->id()) = output->CopyNode(src);
          } else {
            StateMap::CondState state = *dst_id;
            state.erase(predicate_);
            if (state_map_->GetCondId(state) == src_id) {
              TF_RETURN_IF_ERROR(AddSwitchNodeAlongEdge(e, branch, graph));
              continue;
            } else {
              return errors::InvalidArgument(
                  "Graph contains node ", FormatNodeForError(*src),
                  " that feeds into node ", FormatNodeForError(*dst),
                  " but these nodes are in different control contexts (",
                  DebugString(src_id), " vs ", DebugString(dst_id),
                  " (detected during in edge testing)");
            }
          }
        }

        // Mirror the edge into the branch body, copying dst on demand.
        Node* src_copy = node_map.at(e->src()->id());
        int src_output = e->src_output();
        if (node_map.at(dst->id()) == nullptr) {
          node_map.at(dst->id()) = output->CopyNode(dst);
        }
        Node* dst_copy = node_map.at(e->dst()->id());
        if (e->IsControlEdge()) {
          // Skip control inputs from external context.
          if (src_copy != nullptr) output->AddControlEdge(src_copy, dst_copy);
        } else {
          output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
        }
      }
    }
  }

  // Build return values from the merge nodes.
  int index = 0;
  for (Node* m : merges_) {
    for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
      int branch_index = static_cast<int>(branch);
      auto& node_map = node_maps_[branch_index];
      auto output = bodies_[branch_index].get();
      TF_ASSIGN_OR_RETURN(node_map[m->id()],
                          BuildRetvalNode(output, m->output_type(0), index));
    }
    ++index;

    // Connect the input to the merge_ with the retval, except if it is a
    // Switch node, which is handled separately.
    for (auto e : m->in_edges()) {
      if (e->IsControlEdge()) continue;
      int branch_index = static_cast<int>(find_branch(e));
      auto& node_map = node_maps_[branch_index];
      auto output = bodies_[branch_index].get();
      Node* in = e->src();
      if (!IsSwitch(in)) {
        if (node_map.at(in->id()) == nullptr) {
          node_map[in->id()] = output->CopyNode(in);
        }
        output->AddEdge(node_map[in->id()], e->src_output(),
                        node_map.at(m->id()), 0);
      }
    }
  }
  return Status::OK();
}
720 
// Builds the "If" NodeDef that replaces this conditional cluster: converts
// both extracted branch bodies into functions registered in `library` and
// assembles the attributes and inputs the If op requires. On success the
// created node is stored in if_node_.
Status Conditional::BuildIfNode(Graph* graph,
                                FunctionLibraryDefinition* library) {
  VLOG(2) << "Build cond function for " << name();
  // Carry debug provenance over from one of the merge nodes being replaced.
  NodeDebugInfo debug_info((*merges_.begin())->def());
  NodeDefBuilder builder(name(), "If", library, &debug_info);
  const string branch_name[] = {"else_branch", "then_branch"};
  for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
    int branch_index = static_cast<int>(branch);
    // Process-wide counter so repeated functionalization produces unique
    // function names in the library.
    static std::atomic<int64> sequence_num(0LL);
    int64 id = ++sequence_num;

    NameAttrList body_name;
    body_name.set_name(
        absl::StrCat("_functionalize_if_", branch_name[branch_index], "_", id));

    VLOG(3) << "FunctionalizeControlFlow (" << branch_name[branch_index]
            << "): "
            << DumpGraphToFile(
                   "functionalize_cond_body_" + branch_name[branch_index],
                   *bodies_[branch_index], nullptr);

    // Convert the extracted branch graph into a FunctionDef, register it, and
    // reference it from the matching "else_branch"/"then_branch" attr.
    FunctionDef body_fdef;
    TF_RETURN_IF_ERROR(GraphToFunctionDef(*bodies_[branch_index],
                                          body_name.name(), &body_fdef));
    TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
    builder.Attr(branch_name[branch_index], body_name);
  }

  VLOG(3) << "Build input type";
  std::vector<NodeDefBuilder::NodeOut> inputs;
  DataTypeVector in_arg_types;
  for (auto& kv : cond_arg_nodes_) {
    bool inserted = false;
    for (const Node* arg : kv.switches) {
      const Edge* in_edge;
      TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
      if (in_edge->IsControlEdge()) {
        builder.ControlInput(in_edge->src()->name());
      } else {
        // Multiple switches may share the same argument; record the data
        // input (and its type in Tin) only once.
        if (!inserted) {
          DataType dtype = arg->input_type(0);
          inputs.emplace_back(NodeDefBuilder::NodeOut(
              in_edge->src()->name(), in_edge->src_output(), dtype));
          in_arg_types.push_back(dtype);
          inserted = true;
        }
      }
    }
  }
  builder.Attr("Tin", in_arg_types);

  // One If output per merge node; merges are read from output slot 0 only.
  DataTypeVector out_type;
  for (const Node* merge : merges_) {
    DataType dtype = merge->output_type(0);
    out_type.push_back(dtype);
  }
  builder.Attr("Tout", out_type);
  VLOG(3) << "Build output type: " << DataTypeVectorString(out_type);

  builder.Attr("Tcond", DT_BOOL);
  // Propagate an outside-compilation marker from the predicate node, if any.
  string outside_compilation;
  if (GetNodeAttr(predicate_.node->def(), kXlaOutsideCompilationAttrName,
                  &outside_compilation)
          .ok()) {
    builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation);
  }
  builder.Device(predicate_.node->assigned_device_name());
  // Conditional should be the first input ...
  builder.Input(
      NodeDefBuilder::NodeOut(predicate_.node->name(), predicate_.index,
                              predicate_.node->output_type(predicate_.index)));
  // ... followed by the other inputs.
  builder.Input(inputs);

  VLOG(3) << "Build If node";
  NodeDef if_def;
  TF_RETURN_IF_ERROR(builder.Finalize(&if_def));
  TF_ASSIGN_OR_RETURN(if_node_,
                      parent_->AddIfNode(if_def, *merges_.begin(), predicate_));

  return Status::OK();
}
803 
AddInputEdges(Graph * graph,const std::unordered_map<Node *,OutputTensor> & merge_to_replacement)804 Status Conditional::AddInputEdges(
805     Graph* graph,
806     const std::unordered_map<Node*, OutputTensor>& merge_to_replacement) {
807   VLOG(2) << "AddInputEdges for " << if_node_->name();
808   int index = 0;
809   // Add predicate input.
810   if (predicate_.node->IsMerge()) {
811     // If the predicate is a Merge node, we should not use Merge output as
812     // predicate. Instead, we should use the corresponding If output in
813     // 'merge_to_replacement'. Otherwise, this Conditional's If node is still
814     // connected to the predicate Merge node; and when we call
815     // DeleteReachableAndDeadNodes(), the predicate Merge node and this
816     // Conditional's If node will be removed.
817     auto iter = merge_to_replacement.find(predicate_.node);
818     if (iter == merge_to_replacement.end()) {
819       return errors::Internal("Cannot find replacement for Merge node ",
820                               predicate_.node->name());
821     }
822     graph->AddEdge(iter->second.node, iter->second.index, if_node_, index++);
823   } else {
824     graph->AddEdge(const_cast<Node*>(predicate_.node), predicate_.index,
825                    if_node_, index++);
826   }
827   // Add function body inputs.
828   for (auto& arg : cond_arg_nodes_) {
829     if (arg.src_output == Graph::kControlSlot) {
830       graph->AddControlEdge(arg.src, if_node_);
831     } else {
832       graph->AddEdge(arg.src, arg.src_output, if_node_, index++);
833     }
834   }
835   for (Node* n : external_control_inputs_) {
836     graph->AddControlEdge(n, if_node_);
837   }
838   return Status::OK();
839 }
840 
AddOutputEdges(Graph * graph,std::unordered_map<Node *,OutputTensor> * merge_to_replacement)841 Status Conditional::AddOutputEdges(
842     Graph* graph,
843     std::unordered_map<Node*, OutputTensor>* merge_to_replacement) {
844   VLOG(2) << "AddOutputEdges for " << if_node_->name();
845   int i = 0;
846   for (Node* node : merges_) {
847     TF_RETURN_IF_ERROR(parent_->AddIdentityNode(node, if_node_, i));
848     std::vector<const Edge*> edges(node->out_edges().begin(),
849                                    node->out_edges().end());
850     for (const Edge* edge : edges) {
851       Node* dst = edge->dst();
852       int dst_input = edge->dst_input();
853       if (edge->src_output() > 0) {
854         return errors::Unimplemented("Output of index (", edge->src_output(),
855                                      ") of merge node ",
856                                      FormatNodeForError(*node));
857       }
858 
859       bool control_edge = edge->IsControlEdge();
860       graph->RemoveEdge(edge);
861       if (control_edge) {
862         graph->AddControlEdge(if_node_, dst);
863       } else {
864         graph->AddEdge(if_node_, i, dst, dst_input);
865       }
866     }
867 
868     // Record corresponding output tensor in 'merge_to_replacement'.
869     (*merge_to_replacement)[node] = OutputTensor{if_node_, i};
870 
871     ++i;
872   }
873   for (Node* n : external_control_outputs_) {
874     graph->AddControlEdge(if_node_, n);
875   }
876 
877   return Status::OK();
878 }
879 
BuildAndReplace(Graph * graph,FunctionLibraryDefinition * library,std::unordered_map<Node *,OutputTensor> * merge_to_replacement)880 Status Conditional::BuildAndReplace(
881     Graph* graph, FunctionLibraryDefinition* library,
882     std::unordered_map<Node*, OutputTensor>* merge_to_replacement) {
883   VLOG(1) << "Build If and replace merge nodes "
884           << NodesToString(this->merges_);
885   if (replaced_) return Status::OK();
886 
887   TF_RETURN_IF_ERROR(ExtractBodies(graph));
888   TF_RETURN_IF_ERROR(BuildArgumentNodes());
889 
890   if (VLOG_IS_ON(3)) {
891     LOG(INFO) << "Extracted bodies:";
892     for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
893       int branch_index = static_cast<int>(branch);
894       auto output = bodies_[branch_index].get();
895       LOG(INFO) << Branch_Name(branch) << ": "
896                 << DebugString(output->ToGraphDefDebug());
897     }
898   }
899 
900   TF_RETURN_IF_ERROR(BuildIfNode(graph, library));
901   TF_RETURN_IF_ERROR(AddInputEdges(graph, *merge_to_replacement));
902   TF_RETURN_IF_ERROR(AddOutputEdges(graph, merge_to_replacement));
903   TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_));
904 
905   // Check that the if_node doesn't feed into itself.
906   TF_RETURN_WITH_CONTEXT_IF_ERROR(
907       CheckNodeNotInCycle(if_node_, graph->num_node_ids()),
908       "Converting to If failed.");
909 
910   replaced_ = true;
911   return Status::OK();
912 }
913 
name() const914 string Conditional::name() const {
915   CHECK(!merges_.empty());
916   return absl::StrCat((*merges_.begin())->name(), "_if");
917 }
918 
AddIdentityNode(const Node * replacee,Node * if_node,int port)919 Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node,
920                                           int port) {
921   Node* id;
922   TF_RETURN_IF_ERROR(NodeBuilder(replacee->name(), "Identity")
923                          .Input(if_node, port)
924                          .Finalize(graph_, &id));
925   state_map_.ResetCondId(id, state_map_.LookupCondId(if_node));
926   state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node));
927   return Status::OK();
928 }
929 
AddIfNode(const NodeDef & def,const Node * replacee,const OutputTensor & predicate)930 StatusOr<Node*> FunctionalizeCond::AddIfNode(const NodeDef& def,
931                                              const Node* replacee,
932                                              const OutputTensor& predicate) {
933   Status status;
934   Node* ret = graph_->AddNode(def, &status);
935   TF_RETURN_IF_ERROR(status);
936   VLOG(1) << "Adding If for " << replacee->name();
937   StateMap::CondId id = state_map_.LookupCondId(replacee);
938   if (id) {
939     StateMap::CondState state = *id;
940     state.erase(predicate);
941     state_map_.ResetCondId(ret, state_map_.GetCondId(state));
942   } else {
943     state_map_.ResetCondId(ret, nullptr);
944   }
945 
946   state_map_.ResetAncestorId(ret, state_map_.LookupAncestorId(replacee));
947 
948   return ret;
949 }
950 
PropagateUpdatedState(const Node * replacee)951 Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) {
952   VLOG(2) << "Propagating update state for " << replacee->name() << " "
953           << state_map_.CondStateToString(replacee);
954   // Redo topological sort as the order could have changed.
955   // TODO(jpienaar): The original topological order could also be updated
956   // dynamically if needed.
957   std::vector<Node*> rev_topo_order;
958   GetPostOrder(*graph_, &rev_topo_order);
959 
960   // All the outputs of the new node could potentially be updated.
961   std::unordered_set<Node*> changed;
962   for (auto n : replacee->out_nodes())
963     if (n->IsOp()) changed.insert(n);
964 
965   // Iterate through the changed/possible changed nodes in topological order.
966   for (auto it = rev_topo_order.rbegin();
967        it != rev_topo_order.rend() && !changed.empty(); ++it) {
968     if (changed.find(*it) != changed.end()) {
969       // Update the node state.
970       Node* n = *it;
971       StateMap::CondId old_state = state_map_.LookupCondId(n);
972       state_map_.ResetCondId(n, nullptr);
973       TF_RETURN_IF_ERROR(DetermineCondState(n));
974       if (state_map_.LookupCondId(n) != old_state) {
975         for (auto out : n->out_nodes())
976           if (out->IsOp()) changed.insert(out);
977       }
978       changed.erase(n);
979     }
980   }
981   return Status::OK();
982 }
983 
984 // Returns the most restrictive branch of two branches or neither. This is the
985 // meet operator of the BranchType lattice.
MeetBranch(const BranchType & lhs,const BranchType & rhs)986 BranchType MeetBranch(const BranchType& lhs, const BranchType& rhs) {
987   if (lhs == rhs) return lhs;
988   if (lhs == BranchType::kNeither) return rhs;
989   if (rhs == BranchType::kNeither) return lhs;
990   if (lhs == BranchType::kBoth) return rhs;
991   if (rhs == BranchType::kBoth) return lhs;
992   return BranchType::kNeither;
993 }
994 
FindBranchOf(CondId id,OutputTensor predicate) const995 BranchType StateMap::FindBranchOf(CondId id, OutputTensor predicate) const {
996   if (IsEmpty(id)) return BranchType::kNeither;
997   const CondState& nodes = *id;
998   auto it = nodes.find(predicate);
999   if (it == nodes.end()) return BranchType::kNeither;
1000   return it->second;
1001 }
1002 
JoinCondStatesNonMerge(StateMap::CondId src,StateMap::CondId dst)1003 StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesNonMerge(
1004     StateMap::CondId src, StateMap::CondId dst) {
1005   VLOG(5) << "Joining src=" << DebugString(src) << " [" << src
1006           << "] and dst=" << DebugString(dst) << " [" << dst << "]";
1007 
1008   if (state_map_.IsEmpty(dst) || state_map_.IsDead(src)) return src;
1009   if (state_map_.IsDead(dst) || state_map_.IsEmpty(src)) return dst;
1010 
1011   // Nothing to do if the CondState is the same.
1012   if (src == dst) return src;
1013 
1014   StateMap::CondState both = *src;
1015   for (const auto& kv : *dst) {
1016     auto it = both.find(kv.first);
1017     if (it == both.end()) {
1018       both.insert(kv);
1019     } else {
1020       if (it->second != kv.second) {
1021         if (it->second == BranchType::kNeither) {
1022           // BranchType for 'src' is kNeither. Use the BranchType in 'dst'.
1023           it->second = kv.second;
1024         } else if (kv.second == BranchType::kNeither) {
1025           // BranchType for 'dst' is kNeither. Use the BranchType in 'src'.
1026           // No need to change it->second.
1027         } else {
1028           return errors::InvalidArgument(
1029               "Graph contains node with inputs predicated on incompatible "
1030               "predicates: ",
1031               DebugString(src), " and ", DebugString(dst));
1032         }
1033       }
1034     }
1035   }
1036   return state_map_.GetCondId(both);
1037 }
1038 
StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesMerge(
    Node* merge, StateMap::CondId src, StateMap::CondId dst) {
  // Determine the flow state when joining two states for a merge
  // node. Combining the two states for a merge node is effectively performing a
  // disjunction of the states along the different input edges. For a merge that
  // can be transformed into a If the two inputs paths have to have a predicate
  // on which they differ (e.g., along one edge predicate `p` has to hold while
  // on another it should not). This function first determines this predicate
  // and then the resultant state is the common path between the two inputs
  // followed by s(p, both).
  VLOG(4) << "Joining (for merge) " << DebugString(src) << " and "
          << DebugString(dst);
  if (state_map_.IsEmpty(dst)) return src;
  // A merge input with no cond context cannot be functionalized.
  if (state_map_.IsEmpty(src)) {
    return errors::Internal("Merge node ", merge->name(),
                            " has input that's not in any CondContext.");
  }

  // Dead states absorb the join.
  if (state_map_.IsDead(src)) return src;
  if (state_map_.IsDead(dst)) return dst;

  // 'diff' collects the (predicate, branch) entries on which the two states
  // disagree; 'merged' collects the entries common to both (the shared
  // enclosing context that survives the join).
  std::vector<StateMap::CondState::value_type> diff;
  StateMap::CondState merged;
  std::set_symmetric_difference(src->begin(), src->end(), dst->begin(),
                                dst->end(), std::back_inserter(diff),
                                CondStateLess());
  std::set_intersection(src->begin(), src->end(), dst->begin(), dst->end(),
                        std::inserter(merged, merged.begin()), CondStateLess());

  // Update mapping from merge node to predicate.
  if (diff.size() == 2) {
    // The two sides must differ on exactly one predicate, taken as
    // then-branch on one input and else-branch on the other.
    auto pred = diff[0].first;
    bool different_branches = (diff[0].second != diff[1].second) &&
                              (diff[0].second == BranchType::kThenBranch ||
                               diff[0].second == BranchType::kElseBranch) &&
                              (diff[1].second == BranchType::kThenBranch ||
                               diff[1].second == BranchType::kElseBranch);
    if (!(pred == diff[1].first) || !different_branches)
      return errors::InvalidArgument(
          "Unable to determine predicate for merge node");
    merge_to_predicate_[merge] = pred;
  } else {
    return errors::InvalidArgument(
        "Merge of two inputs that differ on more than one predicate ",
        DebugString(src), " and ", DebugString(dst));
  }

  return state_map_.GetCondId(merged);
}
1088 
// Returns the CondState propagated along edge `e`: the source node's state,
// augmented with the switch predicate and branch when the source is a Switch.
StateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) {
  Node* src = e->src();
  StateMap::CondId id = state_map_.LookupCondId(e->src());

  // Dead nodes only propagate dead state.
  if (state_map_.IsDead(id)) return id;

  if (IsSwitch(src)) {
    StateMap::CondState state;
    if (id != nullptr) state = *id;
    OutputTensor predicate;
    TF_CHECK_OK(GetSwitchPredicate(*src, &predicate));
    if (e->IsControlEdge()) {
      // In gradients of tf.cond(), in each branch, we have a NoOp node as
      // control pivot. These NoOp nodes have control dependency from Switch
      // node. If we don't record this into CondState, branches might have
      // incorrect CondState (e.g. if the branch only has a Const data node).
      // We set it to kNeither because there is no way to tell whether it's
      // for true branch or false branch. This node's descendants might have
      // other incoming edges with defined BranchType, and we correctly handle
      // merging kNeither with other defined BranchType in StateAlongEdge().
      state[predicate] = BranchType::kNeither;
    } else {
      // For data edges the switch output port identifies the branch taken.
      state[predicate] = BranchType(e->src_output());
    }
    return state_map_.GetCondId(state);
  }
  return id;
}
1118 
DetermineCondStateMerge(Node * dst)1119 Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) {
1120   // Only Merge nodes with two inputs are supported, but if this is a redundant
1121   // merge, then the dead edge may already have been removed (if due to a
1122   // switch) and so the input count would be incorrect.
1123   if (state_map_.IsDead(state_map_.LookupCondId(dst))) return Status::OK();
1124 
1125   int data_inputs = 0;
1126   for (auto e : dst->in_edges()) {
1127     Node* src = e->src();
1128     VLOG(5) << "Processing forward flow for merge: " << e->DebugString() << " "
1129             << state_map_.CondStateToString(src);
1130     if (!src->IsOp()) continue;
1131     if (!e->IsControlEdge()) ++data_inputs;
1132 
1133     StateMap::CondId prop = StateAlongEdge(e);
1134     auto id_or = JoinCondStatesMerge(dst, prop, state_map_.LookupCondId(dst));
1135     TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
1136                                     FormatNodeForError(*dst));
1137     state_map_.ResetCondId(dst, id_or.ValueOrDie());
1138   }
1139 
1140   // Incomplete Merge nodes are not supported.
1141   if (data_inputs != 2) {
1142     return errors::Unimplemented(
1143         dst->name(), " only has ", data_inputs,
1144         " inputs, while only merge nodes with two inputs supported.");
1145   }
1146   return Status::OK();
1147 }
1148 
DetermineCondStateNonMerge(Node * dst)1149 Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) {
1150   // Handle non-merge join.
1151   for (auto e : dst->in_edges()) {
1152     VLOG(4) << "Processing forward flow for: " << e->DebugString() << " "
1153             << state_map_.CondStateToString(dst);
1154     Node* src = e->src();
1155     if (!src->IsOp()) continue;
1156 
1157     // Joining the state between the current and propagated state.
1158     StateMap::CondId prop = StateAlongEdge(e);
1159     auto id_or = JoinCondStatesNonMerge(prop, state_map_.LookupCondId(dst));
1160     TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
1161                                     FormatNodeForError(*dst));
1162     state_map_.ResetCondId(dst, id_or.ValueOrDie());
1163   }
1164   return Status::OK();
1165 }
1166 
RemoveRedundantMerge(Node * node)1167 Status FunctionalizeCond::RemoveRedundantMerge(Node* node) {
1168   // Handle redundant merge nodes. A merge node is considered redundant if
1169   // one input edge is dead while the other has a value.
1170   if (!state_map_.IsDead(state_map_.LookupCondId(node))) return Status::OK();
1171 
1172   const Edge* non_dead_edge = nullptr;
1173   for (auto e : node->in_edges()) {
1174     if (e->IsControlEdge()) continue;
1175     Node* src = e->src();
1176 
1177     // Handle merge with dead state.
1178     const auto& src_id = state_map_.LookupCondId(src);
1179     if (!state_map_.IsDead(src_id)) {
1180       non_dead_edge = e;
1181       break;
1182     }
1183   }
1184 
1185   if (non_dead_edge == nullptr) {
1186     return errors::InvalidArgument("Merge node ", FormatNodeForError(*node),
1187                                    " has no non-dead inputs.");
1188   }
1189   state_map_.MarkDead(node);
1190   VLOG(5) << "removing redundant merge: " << node->name();
1191   while (!node->out_edges().empty()) {
1192     const Edge* oe = *node->out_edges().begin();
1193     Node* dst_node = oe->dst();
1194     int dst_port = oe->dst_input();
1195     graph_->RemoveEdge(oe);
1196     graph_->AddEdge(non_dead_edge->src(),
1197                     dst_port == Graph::kControlSlot
1198                         ? Graph::kControlSlot
1199                         : non_dead_edge->src_output(),
1200                     dst_node, dst_port);
1201   }
1202   return Status::OK();
1203 }
1204 
// Removes a Switch whose predicate outcome is already implied by the current
// branch's CondState, forwarding its data input to consumers and marking
// consumers on the untaken branch as dead.
Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) {
  // Handle redundant switch nodes. A switch node is considered redundant if
  // the predicate of the switch already holds on the current branch. E.g., if
  // p is the predicate of the switch but p is already known to hold on this
  // branch, then the switch can be removed and the dead state propagated
  // along one. The checking of predicate is based on the exact predicate
  // (rather than boolean equivalence) and aimed at redundant switches as
  // currently generated by gradient code.
  StateMap::CondId dst_id = state_map_.LookupCondId(node);
  if (state_map_.IsDead(dst_id)) return Status::OK();

  BranchType b;
  OutputTensor pred;
  TF_RETURN_IF_ERROR(GetSwitchPredicate(*node, &pred));

  // Determine if we are already on a branch where the switch predicate is
  // true/false. Consider both the data and predicate to determine if the
  // node is redundant (skipping over identity node).
  b = state_map_.FindBranchOf(dst_id, pred);
  if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) {
    // The predicate itself isn't pinned; check whether the switch's data
    // input (looking through Identity chains) is a pinned predicate value.
    OutputTensor val;
    const Edge* e;
    TF_RETURN_IF_ERROR(node->input_edge(0, &e));
    val = OutputTensor(e->src(), e->src_output());
    while (IsIdentity(val.node)) {
      TF_RETURN_IF_ERROR(val.node->input_edge(0, &e));
      val = OutputTensor(e->src(), e->src_output());
    }
    b = state_map_.FindBranchOf(dst_id, val);
    if (b != BranchType::kThenBranch && b != BranchType::kElseBranch)
      return Status::OK();
  }

  VLOG(5) << "Redundant switch " << node->name() << " " << Branch_Name(b) << " "
          << DebugString(dst_id);
  const Edge* value_edge;
  TF_RETURN_IF_ERROR(node->input_edge(0, &value_edge));
  Node* val_node = value_edge->src();
  int val_port = value_edge->src_output();
  // Reroute consumers one edge at a time; RemoveEdge mutates out_edges().
  while (!node->out_edges().empty()) {
    auto e = *node->out_edges().begin();
    Node* dst_node = e->dst();
    int dst_input = e->dst_input();
    int switch_branch = e->src_output();
    graph_->RemoveEdge(e);
    if (switch_branch == Graph::kControlSlot) {
      // Control consumers stay live; re-join their cond states with the
      // switch's state.
      if (IsMerge(dst_node)) {
        auto id_or = JoinCondStatesMerge(dst_node, dst_id,
                                         state_map_.LookupCondId(dst_node));
        TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
                                        FormatNodeForError(*dst_node));
        state_map_.ResetCondId(dst_node, id_or.ValueOrDie());
      } else {
        auto id_or =
            JoinCondStatesNonMerge(dst_id, state_map_.LookupCondId(dst_node));
        TF_RETURN_IF_ERROR(id_or.status());
        state_map_.ResetCondId(dst_node, id_or.ValueOrDie());
      }
    } else if (BranchType(switch_branch) != b) {
      // Consumers on the branch not taken are dead; do not reconnect them.
      state_map_.MarkDead(dst_node);
      continue;
    }
    graph_->AddEdge(
        val_node,
        switch_branch == Graph::kControlSlot ? Graph::kControlSlot : val_port,
        dst_node, dst_input);
  }
  return Status::OK();
}
1274 
DetermineStates(std::vector<Node * > rev_topo_order)1275 Status FunctionalizeCond::DetermineStates(std::vector<Node*> rev_topo_order) {
1276   // The state that is propagated along the given edge.
1277   for (auto it = rev_topo_order.rbegin(); it != rev_topo_order.rend(); ++it) {
1278     Node* dst = *it;
1279     TF_RETURN_IF_ERROR(DetermineCondState(dst));
1280     TF_RETURN_IF_ERROR(DetermineAncestorState(dst));
1281     if (IsSwitch(dst)) TF_RETURN_IF_ERROR(RemoveRedundantSwitch(dst));
1282     if (IsMerge(dst)) TF_RETURN_IF_ERROR(RemoveRedundantMerge(dst));
1283 
1284     VLOG(5) << dst->name() << " :: " << state_map_.CondStateToString(dst)
1285             << " @ " << state_map_.AncestorStateToString(dst);
1286     if (VLOG_IS_ON(10)) DumpGraphWithCondState("it");
1287   }
1288   return Status::OK();
1289 }
1290 
// Computes the AncestorState of `dst`: the union, over all of its inputs, of
// the switch/merge/predicate ancestors influencing those inputs.
Status FunctionalizeCond::DetermineAncestorState(Node* dst) {
  StateMap::AncestorId id = nullptr;
  StateMap::AncestorState state;

  // Folds `src`'s ancestor information into `state` — note that `state` is
  // captured by reference and accumulates across successive calls — and
  // returns the interned id of the accumulated set.
  auto insert = [&](StateMap::AncestorId id, Node* src) {
    auto other_id = state_map_.LookupAncestorId(src);
    if (other_id != id && other_id != nullptr) {
      state.insert(other_id->begin(), other_id->end());
    }
    if (IsMerge(src)) {
      state.insert({{src, 0}, AncestorNode::AncestorNodeType::kMerge});
    } else if (IsSwitch(src)) {
      OutputTensor pred;
      // For dead switch nodes, GetSwitchPredicate() will fail, and we use
      // the switch node directly as ancestor.
      if (GetSwitchPredicate(*src, &pred).ok()) {
        state.insert({pred, AncestorNode::AncestorNodeType::kPred});
      } else {
        state.insert({{src, 0}, AncestorNode::AncestorNodeType::kSwitch});
      }
    }
    return state_map_.GetAncestorId(state);
  };

  // Compute the union of all the switch/merge nodes that affects the input of
  // dst.
  for (auto e : dst->in_edges()) {
    Node* src = e->src();
    id = insert(id, src);
  }
  state_map_.ResetAncestorId(dst, id);
  return Status::OK();
}
1324 
// Deletes leftover switch/merge scaffolding and everything downstream of
// deleted or dead nodes, now that merges have been replaced by If nodes.
void FunctionalizeCond::DeleteReachableAndDeadNodes(
    const std::vector<Node*>& merge_order) {
  // Delete all nodes that have been extracted or are reachable from
  // deleted/dead nodes. The input and outgoing edges should have already been
  // removed.
  std::deque<int> delete_nodes;
  std::vector<bool> deleted(graph_->num_node_ids(), false);
  // Don't try to delete source or sink nodes.
  deleted[graph_->kSourceId] = true;
  deleted[graph_->kSinkId] = true;

  // All remaining Switch nodes are not reachable from a Merge node and
  // removed. This is to account for dead Switch nodes.
  for (int s_id : switch_ids_) {
    Node* s = graph_->FindNodeId(s_id);
    // The switch may already have been removed by an earlier rewrite.
    if (s == nullptr) continue;
    for (const Edge* e : s->out_edges()) {
      // Control outputs of switch nodes (which are unconditionally executed if
      // the switch is) are not removed as they need not be part of a
      // conditional.
      if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
    }
    deleted[s_id] = true;
    graph_->RemoveNode(s);
  }

  // All merge nodes should have been transformed at this point and we remove
  // them from the graph here.
  for (Node* m : merge_order) {
    for (const Edge* e : m->out_edges()) {
      // Similar to control outputs of switch nodes don't remove control
      // outputs of merge nodes.
      // TODO(jpienaar): Check cases where output edges still exist here vs
      // being removed in AddOutputEdges.
      if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
    }
    deleted[m->id()] = true;
    graph_->RemoveNode(m);
  }

  // Enqueue all the dead nodes.
  for (Node* n : graph_->nodes()) {
    if (state_map_.IsDead(state_map_.LookupCondId(n))) {
      delete_nodes.push_back(n->id());
    }
  }

  // BFS from the seeds: any node (transitively) fed by a deleted node is
  // deleted as well.
  while (!delete_nodes.empty()) {
    int d_id = delete_nodes.front();
    delete_nodes.pop_front();
    if (deleted[d_id]) continue;
    Node* d = graph_->FindNodeId(d_id);
    // Switch and Merge nodes could have been deleted already.
    if (d == nullptr) continue;
    for (const Edge* e : d->out_edges()) {
      delete_nodes.push_back(e->dst()->id());
    }
    deleted[d_id] = true;
    graph_->RemoveNode(d);
  }
}
1386 
SortMergeNodes(std::vector<Node * > * merge_order)1387 void FunctionalizeCond::SortMergeNodes(std::vector<Node*>* merge_order) {
1388   // Sort merge nodes by nesting depth.
1389   using sort_pair = std::pair<int, Node*>;
1390   std::vector<sort_pair> inner_to_outer_merge_order;
1391   inner_to_outer_merge_order.reserve(merge_order->size());
1392   for (auto it = merge_order->rbegin(); it != merge_order->rend(); ++it) {
1393     Node* merge = *it;
1394     StateMap::CondId id = state_map_.LookupCondId(merge);
1395     int depth = id != nullptr ? id->size() : 0;
1396     inner_to_outer_merge_order.emplace_back(depth, merge);
1397   }
1398   std::stable_sort(
1399       inner_to_outer_merge_order.begin(), inner_to_outer_merge_order.end(),
1400       [](sort_pair lhs, sort_pair rhs) { return lhs.first > rhs.first; });
1401   merge_order->clear();
1402   for (sort_pair t : inner_to_outer_merge_order) {
1403     merge_order->push_back(t.second);
1404   }
1405 }
1406 
Status FunctionalizeCond::FunctionalizeInternal() {
  // The general approach for converting a tf.cond (as lowered via switch/merge
  // nodes) to a functional if is as follows:
  // 1. Determine the topological order and collect all the switch and merge
  // nodes in the graph;
  // 2. Compute the predicates and dominance structure for all the nodes in the
  // graph - this includes which predicate must be true for a op to execute
  // (predicate values are considered directly rather than attempting to
  // determine deeper equivalence). We shall refer to this structure as the
  // CondState;
  // 3. Sort the merge nodes by nesting depth;
  // 4. Extract merge nodes together that have the same CondState and
  // AncestorState from the innermost to the outermost into IfOps;
  // Note: In the above only nodes that feed into a merge node will be
  // considered for functionalization.

  // Perform a DFS over the graph and
  // * Determine the reverse topological order of the nodes (there should be no
  //   cycles at this point so the post-order numbering corresponds to the
  //   reverse topological sorting);
  // * Record reverse topological for merge and switch nodes;
  std::vector<Node*> rev_topo_order;
  std::vector<Node*> merge_order;
  DFS(*graph_, nullptr, [&](Node* n) {
    if (IsSwitch(n)) {
      // Switch ids are remembered so the pass can find/clean up switches
      // after extraction.
      AddSwitchId(n->id());
    }
    if (IsMerge(n)) {
      merge_order.push_back(n);
    }
    if (n->IsOp()) {
      rev_topo_order.push_back(n);
    }
  });

  // No merges to functionalize.
  if (merge_order.empty()) {
    // No merges mean no switch values consumed (as only considering values
    // fetchable as output of merge);
    DeleteReachableAndDeadNodes(merge_order);
    return Status::OK();
  }

  // Computes CondState/AncestorState for every node, consuming the
  // post-order collected above.
  TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order)));
  if (VLOG_IS_ON(4)) DumpGraphWithCondState("id");

  // Sort the merge nodes from innermost outwards.
  SortMergeNodes(&merge_order);

  // Cluster merge nodes by (CondId, AncestorId, predicate) in order of
  // nesting. (CondId, AncestorId) is not enough, e.g.
  //   pred1 = array_ops.placeholder(dtypes.bool, name='pred1')
  //   pred2 = array_ops.placeholder(dtypes.bool, name='pred2')
  //   cond1 = control_flow_ops.cond(pred1, ...)
  //   cond2 = control_flow_ops.cond(pred2, ...)
  //   cond3 = control_flow_ops.cond(pred1, use cond1 and cond2)
  //   cond4 = control_flow_ops.cond(pred2, use cond1 and cond2)
  // cond3 and cond4 have the same (CondId, AncestorId), but they should not
  // be merged into one "If" node (because they have different predicates).
  std::deque<std::vector<Node*>> merge_clusters;
  std::map<ClusterTuple, int, ClusterTupleLessThan> merge_cluster_index;
  for (Node* merge : merge_order) {
    auto cond_id = state_map_.LookupCondId(merge);
    // Merges whose CondState is dead are unreachable; skip them entirely.
    if (state_map_.IsDead(cond_id)) continue;

    auto predicate = merge_to_predicate_.find(merge);
    if (predicate == merge_to_predicate_.end()) {
      return errors::Internal("Cannot find predicate for Merge node ",
                              merge->name());
    }

    // Either start a new cluster for this (CondId, AncestorId, predicate)
    // key or append the merge to the existing one.
    ClusterTuple key = std::make_tuple(
        cond_id, state_map_.LookupAncestorId(merge), predicate->second);
    auto idx = merge_cluster_index.find(key);
    if (idx == merge_cluster_index.end()) {
      merge_cluster_index[key] = merge_clusters.size();
      merge_clusters.push_back({merge});
    } else {
      merge_clusters[idx->second].emplace_back(merge);
    }
  }

  // Extract the conditionals from inner most to outer most. Extracting from
  // innermost to outermost enables the extraction pass to stop once it
  // encounters a Switch node instead of having to keep track of Switch/Merge
  // nodes seen.
  for (const auto& cluster : merge_clusters) {
    // Construct a Conditional with the predicate of the merge.
    Conditional cond(merge_to_predicate_.at(cluster.front()), this,
                     &state_map_);
    for (Node* merge : cluster) TF_RETURN_IF_ERROR(cond.AddMerge(merge));
    // Rewrites the cluster in-place into an If op and records the merge ->
    // replacement mapping for later clusters.
    TF_RETURN_IF_ERROR(
        cond.BuildAndReplace(graph_, library_, &merge_to_replacement_));

    if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract");
  }

  // Remove the now-functionalized switch/merge machinery and any dead nodes.
  DeleteReachableAndDeadNodes(merge_order);

  return Status::OK();
}
1508 
DumpGraphWithCondState(const string & name)1509 void FunctionalizeCond::DumpGraphWithCondState(const string& name) {
1510   const char* const kCondGroupDebugAttr = "_XlaFunctionalizeCondGroup";
1511 
1512   for (Node* n : graph_->nodes()) {
1513     n->ClearAttr(kCondGroupDebugAttr);
1514     n->AddAttr(kCondGroupDebugAttr,
1515                absl::StrCat(state_map_.CondStateToString(n), "_",
1516                             state_map_.AncestorStateToString(n)));
1517   }
1518   LOG(INFO) << "FunctionalizeControlFlow (" << name << "): "
1519             << DumpGraphToFile(absl::StrCat("functionalize_cond_", name),
1520                                *graph_, library_);
1521 }
1522 
AddSwitchId(int switch_id)1523 void FunctionalizeCond::AddSwitchId(int switch_id) {
1524   switch_ids_.push_back(switch_id);
1525 }
1526 
Functionalize(Graph * graph,FunctionLibraryDefinition * library)1527 Status FunctionalizeCond::Functionalize(Graph* graph,
1528                                         FunctionLibraryDefinition* library) {
1529   VLOG(1) << "FunctionalizeCond::Functionalize";
1530   FunctionalizeCond fc(graph, library);
1531   return fc.FunctionalizeInternal();
1532 }
1533 
1534 }  // namespace functionalize_cond
1535 
FunctionalizeCond(Graph * graph,FunctionLibraryDefinition * library)1536 Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) {
1537   // FunctionalizeControlFlow is invoked for every function, so the loops's
1538   // bodies and conditionals that were extracted into functions will be handled
1539   // in successive invocations.
1540   return functionalize_cond::FunctionalizeCond::Functionalize(graph, library);
1541 }
1542 
1543 }  // namespace tensorflow
1544