1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/functionalize_cond.h"
17
18 #include <algorithm>
19 #include <deque>
20 #include <stack>
21 #include <unordered_set>
22 #include <vector>
23
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_join.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/compiler/jit/union_find.h"
28 #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
29 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
30 #include "tensorflow/core/common_runtime/function.h"
31 #include "tensorflow/core/framework/graph_to_functiondef.h"
32 #include "tensorflow/core/framework/node_def_builder.h"
33 #include "tensorflow/core/graph/algorithm.h"
34 #include "tensorflow/core/graph/control_flow.h"
35 #include "tensorflow/core/graph/node_builder.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/lib/hash/hash.h"
38 #include "tensorflow/core/lib/strings/strcat.h"
39 #include "tensorflow/core/util/dump_graph.h"
40
41 using xla::StatusOr;
42
43 namespace tensorflow {
44 namespace functionalize_cond {
45
operator <(const AncestorNode & other) const46 bool AncestorNode::operator<(const AncestorNode& other) const {
47 return (output_tensor.node->id() < other.output_tensor.node->id()) ||
48 (output_tensor.node->id() == other.output_tensor.node->id() &&
49 output_tensor.index < other.output_tensor.index) ||
50 (output_tensor.node->id() == other.output_tensor.node->id() &&
51 output_tensor.index == other.output_tensor.index &&
52 type < other.type);
53 }
54
operator ==(const AncestorNode & other) const55 bool AncestorNode::operator==(const AncestorNode& other) const {
56 return output_tensor.node->id() == other.output_tensor.node->id() &&
57 output_tensor.index == other.output_tensor.index && type == other.type;
58 }
59
operator ()(const AncestorNode & ancestor) const60 size_t AncestorNode::Hash::operator()(const AncestorNode& ancestor) const {
61 size_t h = std::hash<int>()(ancestor.output_tensor.node->id());
62 h = Hash64Combine(h, std::hash<int>()(ancestor.output_tensor.index));
63 return Hash64Combine(h, std::hash<int>()(static_cast<int>(ancestor.type)));
64 }
65
66 typedef std::tuple<StateMap::CondId, StateMap::AncestorId, OutputTensor>
67 ClusterTuple;
68
69 struct ClusterTupleLessThan {
operator ()tensorflow::functionalize_cond::ClusterTupleLessThan70 bool operator()(const ClusterTuple& a, const ClusterTuple& b) const {
71 if (std::tie(std::get<0>(a), std::get<1>(a)) <
72 std::tie(std::get<0>(b), std::get<1>(b))) {
73 return true;
74 } else if (std::tie(std::get<0>(a), std::get<1>(a)) ==
75 std::tie(std::get<0>(b), std::get<1>(b))) {
76 return StateMap::OutputTensorLess()(std::get<2>(a), std::get<2>(b));
77 } else {
78 return false;
79 }
80 }
81 };
82
83 // TODO(jpienaar): Move to OutputTensor.
DebugString(const OutputTensor & tensor)84 string DebugString(const OutputTensor& tensor) {
85 return absl::StrCat(tensor.node->name(), ":", tensor.index);
86 }
87
Branch_Name(BranchType b)88 string Branch_Name(BranchType b) {
89 switch (b) {
90 case BranchType::kElseBranch:
91 return "else";
92 case BranchType::kThenBranch:
93 return "then";
94 case BranchType::kBoth:
95 return "both";
96 case BranchType::kNeither:
97 return "neither";
98 }
99 }
100
DebugString(StateMap::CondId cond_state)101 string DebugString(StateMap::CondId cond_state) {
102 if (cond_state == nullptr || cond_state->empty()) return "{}";
103 using value_type = StateMap::CondState::value_type;
104 return absl::StrCat(
105 "{",
106 absl::StrJoin(*cond_state, ", ",
107 [](string* output, const value_type& pred_branch) {
108 const OutputTensor& pred = pred_branch.first;
109 const BranchType& branch = pred_branch.second;
110 if (branch == BranchType::kNeither)
111 absl::StrAppend(output, "d");
112 else
113 absl::StrAppend(output, "s(", DebugString(pred), ",",
114 Branch_Name(branch), ")");
115 }),
116 "}");
117 }
118
119 // Returns the predicate of a switch.
GetSwitchPredicate(const Node & switch_node,OutputTensor * pred)120 Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) {
121 const Edge* pred_edge;
122 TF_RETURN_IF_ERROR(switch_node.input_edge(1, &pred_edge));
123 // The predicate can be preceded by a identity node. Look through
124 // identity nodes to predicate.
125 while (pred_edge->src()->IsIdentity()) {
126 TF_RETURN_IF_ERROR(pred_edge->src()->input_edge(0, &pred_edge));
127 }
128 *pred = OutputTensor(pred_edge->src(), pred_edge->src_output());
129 return Status::OK();
130 }
131
GetSwitchValue(const Node & switch_node,OutputTensor * val)132 Status GetSwitchValue(const Node& switch_node, OutputTensor* val) {
133 const Edge* val_edge;
134 TF_RETURN_IF_ERROR(switch_node.input_edge(0, &val_edge));
135 *val = OutputTensor(val_edge->src(), val_edge->src_output());
136 return Status::OK();
137 }
138
operator ()(const OutputTensor & lhs,const OutputTensor & rhs) const139 bool StateMap::OutputTensorLess::operator()(const OutputTensor& lhs,
140 const OutputTensor& rhs) const {
141 return (lhs.node->id() < rhs.node->id()) ||
142 (lhs.node->id() == rhs.node->id() && lhs.index < rhs.index);
143 }
144
145 struct CondStateLess {
operator ()tensorflow::functionalize_cond::CondStateLess146 bool operator()(const StateMap::CondState::value_type& lhs,
147 const StateMap::CondState::value_type& rhs) const {
148 if (StateMap::OutputTensorLess().operator()(lhs.first, rhs.first))
149 return true;
150 if (lhs.first.node->id() == rhs.first.node->id() &&
151 lhs.first.index == rhs.first.index)
152 return lhs.second < rhs.second;
153 return false;
154 }
155 };
156
// Pre-sizes the dense per-node-id maps for every node currently in `graph`
// (nodes added later fall back to the overflow maps — see LookupCondId) and
// interns the canonical dead state.
StateMap::StateMap(Graph* graph) {
  node_to_condid_map_.resize(graph->num_node_ids());
  node_to_ancestorid_map_.resize(graph->num_node_ids());
  // Initialize the dead state (empty state is designated with a nullptr).
  dead_id_ = GetCondId(
      {std::make_pair(OutputTensor(nullptr, -1), BranchType::kNeither)});
}
164
IsDead(StateMap::CondId id) const165 bool StateMap::IsDead(StateMap::CondId id) const { return id == dead_id_; }
166
IsEmpty(StateMap::CondId id) const167 bool StateMap::IsEmpty(StateMap::CondId id) const { return id == nullptr; }
168
operator ()(const StateMap::CondState & map) const169 size_t StateMap::Hash::operator()(const StateMap::CondState& map) const {
170 if (map.empty()) return 0;
171 // Compute hash of the front element.
172 auto it = map.begin();
173 size_t h = Hash64Combine(OutputTensor::Hash()(it->first),
174 hash<BranchType>()(it->second));
175 for (++it; it != map.end(); ++it) {
176 // Combine the has with the different elements in the map.
177 h = Hash64Combine(h, Hash64Combine(OutputTensor::Hash()(it->first),
178 hash<BranchType>()(it->second)));
179 }
180 return h;
181 }
182
operator ()(const StateMap::AncestorState & map) const183 size_t StateMap::Hash::operator()(const StateMap::AncestorState& map) const {
184 if (map.empty()) return 0;
185 // Compute hash of the front element.
186 auto it = map.begin();
187 size_t h = AncestorNode::Hash()(*it);
188 for (++it; it != map.end(); ++it) {
189 // Combine the has with the different elements in the map.
190 h = Hash64Combine(h, AncestorNode::Hash()(*it));
191 }
192 return h;
193 }
194
// CondArgNode represents a input to the conditional and its corresponding
// switch nodes.
struct CondArgNode {
  // `src`:`src_output` is the tensor in the original graph that feeds the
  // conditional's switches.
  explicit CondArgNode(Node* src, int src_output)
      : src(src), src_output(src_output) {}

  // Human-readable summary for debugging/logging.
  string ToString() const {
    return absl::StrCat("src=", src->name(), ":", src_output,
                        " switches=", NodesToString(switches));
  }

  // Producer of the input value in the original graph.
  Node* src;
  int src_output;
  // Per-branch _Arg nodes created in the extracted function bodies
  // (populated by Conditional::BuildArgumentNodes).
  std::array<Node*, 2> branch_copy;
  // Switch nodes in the original graph that consume this input.
  std::vector<Node*> switches;
};
using CondArgNodes = std::vector<CondArgNode>;
212
DebugString(const CondArgNodes & nodes)213 string DebugString(const CondArgNodes& nodes) {
214 return absl::StrCat(
215 "[",
216 absl::StrJoin(nodes, ", ",
217 [](string* output, const CondArgNode& node) {
218 absl::StrAppend(output, node.ToString());
219 }),
220 "]");
221 }
222
LookupCondId(const Node * node) const223 StateMap::CondId StateMap::LookupCondId(const Node* node) const {
224 if (node->id() < node_to_condid_map_.size())
225 return node_to_condid_map_[node->id()];
226 return added_node_condid_mapping_.at(node->id());
227 }
228
// Interns `state` and returns a stable pointer to the canonical copy; the
// empty state maps to nullptr.
StateMap::CondId StateMap::GetCondId(const StateMap::CondState& state) {
  if (state.empty()) return nullptr;
  auto insert_result = condstate_set_.insert(state);
  return &*insert_result.first;
}
233
ResetCondId(const Node * node,StateMap::CondId id)234 void StateMap::ResetCondId(const Node* node, StateMap::CondId id) {
235 if (node->id() < node_to_condid_map_.size())
236 node_to_condid_map_[node->id()] = id;
237 else
238 added_node_condid_mapping_[node->id()] = id;
239 }
240
LookupAncestorId(const Node * node) const241 StateMap::AncestorId StateMap::LookupAncestorId(const Node* node) const {
242 if (node->id() < node_to_ancestorid_map_.size())
243 return node_to_ancestorid_map_[node->id()];
244 return added_node_ancestorid_mapping_.at(node->id());
245 }
246
// Interns `state` and returns a stable pointer to the canonical copy; the
// empty state maps to nullptr.
StateMap::AncestorId StateMap::GetAncestorId(
    const StateMap::AncestorState& state) {
  if (state.empty()) return nullptr;
  auto insert_result = ancestorstate_set_.insert(state);
  return &*insert_result.first;
}
252
ResetAncestorId(const Node * node,StateMap::AncestorId id)253 void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) {
254 if (node->id() < node_to_ancestorid_map_.size())
255 node_to_ancestorid_map_[node->id()] = id;
256 else
257 added_node_ancestorid_mapping_[node->id()] = id;
258 }
259
MarkDead(const Node * node)260 void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); }
261
CondStateToString(const Node * node) const262 string StateMap::CondStateToString(const Node* node) const {
263 return CondStateToString(LookupCondId(node));
264 }
265
// Delegates to the free DebugString overload for CondIds.
string StateMap::CondStateToString(StateMap::CondId id) const {
  return DebugString(id);
}
269
AncestorStateToString(const Node * node) const270 string StateMap::AncestorStateToString(const Node* node) const {
271 if (auto id = LookupAncestorId(node)) {
272 return absl::StrCat(
273 "{",
274 absl::StrJoin(*id, ",",
275 [](string* output, const AncestorNode& ancestor) {
276 absl::StrAppend(output,
277 ancestor.output_tensor.node->name(),
278 ":", ancestor.output_tensor.index);
279 }),
280 "}");
281 }
282 return "{}";
283 }
284
// Constructs a FunctionalizeCond over `graph`; extracted branch functions are
// written into `library`. Neither pointer is owned.
FunctionalizeCond::FunctionalizeCond(Graph* graph,
                                     FunctionLibraryDefinition* library)
    : state_map_(graph), library_(library), graph_(graph) {}
288
// Class representing the merge/switch nodes that will become a conditional.
class Conditional {
 public:
  Conditional(OutputTensor predicate, FunctionalizeCond* parent,
              StateMap* cond_state_map);

  // Adds merge node that is part of this conditional.
  Status AddMerge(Node* m);

  // Constructs an If node from the merge nodes.
  Status BuildAndReplace(
      Graph* graph, FunctionLibraryDefinition* library,
      std::unordered_map<Node*, OutputTensor>* merge_to_replacement);

 private:
  // Extracts the then/else bodies: creates new graphs with the nodes
  // corresponding to the nodes in the then/else branches as of this conditional
  // as function bodies.
  Status ExtractBodies(Graph* graph);

  // Builds the arguments that are the input to the If.
  Status BuildArgumentNodes();

  // Builds the If node for the extracted bodies with the given predicate.
  Status BuildIfNode(Graph* graph, FunctionLibraryDefinition* library);

  // Adds input edges to If node.
  Status AddInputEdges(
      Graph* graph,
      const std::unordered_map<Node*, OutputTensor>& merge_to_replacement);

  // Adds output edges from If node.
  // Record new output tensor for all Merge nodes in 'merge_to_replacement'.
  Status AddOutputEdges(
      Graph* graph,
      std::unordered_map<Node*, OutputTensor>* merge_to_replacement);

  // Adds switch node that is part of this conditional.
  Status AddSwitch(Node* s);

  // Adds a switch node along the edge and rewire the edge to go via the switch.
  Status AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
                                Graph* graph);

  // Internal name of conditional. The name is based on the first merge node
  // added.
  string name() const;

  // The FunctionalizeCond instance that created this. Not owned.
  FunctionalizeCond* parent_;

  // Mapping between nodes and their cond state. Not owned.
  StateMap* state_map_;

  // The predicate of the conditional.
  OutputTensor predicate_;

  // The predicate of the switches of the conditional. This may be different
  // than predicate (which is initialized from the original graph) as the
  // predicate could be the output of a newly created If node.
  OutputTensor switch_predicate_;

  // Switch nodes in graph that are part of this conditional.
  std::set<Node*, NodeCmpByNameResourcesLast> switches_;

  // Merge nodes in graph that are part of this conditional.
  std::set<Node*, NodeCmpByNameResourcesLast> merges_;

  // Vector of control inputs from outside the conditional to a node inside.
  std::vector<Node*> external_control_inputs_;
  std::vector<Node*> external_control_outputs_;

  // Graphs corresponding to the then and else branch, indexed by
  // static_cast<int>(BranchType).
  std::array<std::unique_ptr<Graph>, 2> bodies_;

  // Maps from graph_ to the branch body's graph, indexed per branch by
  // original node id.
  std::array<std::vector<Node*>, 2> node_maps_;

  // The argument nodes created for the switches.
  CondArgNodes cond_arg_nodes_;

  // The constructed If node.
  Node* if_node_ = nullptr;

  // Whether the merge nodes of this conditional have been replaced.
  bool replaced_ = false;
};
376
// A Conditional is rooted at `predicate`; `parent` and `cond_state_map` are
// borrowed and must outlive this object.
Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent,
                         StateMap* cond_state_map)
    : parent_(parent), state_map_(cond_state_map), predicate_(predicate) {}
380
AddMerge(Node * m)381 Status Conditional::AddMerge(Node* m) {
382 merges_.insert(m);
383 return Status::OK();
384 }
385
AddSwitch(Node * s)386 Status Conditional::AddSwitch(Node* s) {
387 VLOG(5) << "Adding switch " << s->DebugString();
388 OutputTensor predicate;
389 TF_RETURN_IF_ERROR(GetSwitchPredicate(*s, &predicate));
390 if (switch_predicate_.node == nullptr) switch_predicate_ = predicate;
391 if (!(switch_predicate_ == predicate)) {
392 return errors::InvalidArgument(
393 "Merge nodes ", NodesToString(merges_),
394 " directly dominated by switch nodes with different predicates (",
395 DebugString(switch_predicate_), " vs ", DebugString(predicate), ").");
396 }
397 switches_.insert(s);
398 parent_->AddSwitchId(s->id());
399 return Status::OK();
400 }
401
BuildArgumentNodes()402 Status Conditional::BuildArgumentNodes() {
403 VLOG(1) << "Build function arguments";
404 struct Hash {
405 size_t operator()(const std::pair<Node*, int>& item) const {
406 return Hash64Combine(hash<Node*>()(item.first),
407 std::hash<int>()(item.second));
408 }
409 };
410
411 std::unordered_map<std::pair<Node*, int>, int, Hash> input_index;
412 for (Node* switch_node : switches_) {
413 const Edge* e;
414 TF_RETURN_IF_ERROR(switch_node->input_edge(0, &e));
415 std::pair<Node*, int> key = std::make_pair(e->src(), e->src_output());
416 if (input_index.find(key) == input_index.end()) {
417 input_index[key] = cond_arg_nodes_.size();
418 cond_arg_nodes_.emplace_back(key.first, key.second);
419 }
420 cond_arg_nodes_.at(input_index.at(key)).switches.push_back(switch_node);
421 }
422 VLOG(5) << "CondArg nodes created: " << DebugString(cond_arg_nodes_);
423
424 int arg_count = 0;
425 for (CondArgNode& cond_arg_node : cond_arg_nodes_) {
426 DataType dtype = cond_arg_node.src->output_type(cond_arg_node.src_output);
427 for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
428 int branch_index = static_cast<int>(branch);
429 TF_RETURN_IF_ERROR(
430 NodeBuilder(absl::StrCat("_Arg", arg_count),
431 FunctionLibraryDefinition::kArgOp)
432 .Attr("T", dtype)
433 .Attr("index", arg_count)
434 .Finalize(bodies_[branch_index].get(),
435 &cond_arg_node.branch_copy[branch_index]));
436 }
437 for (Node* node : cond_arg_node.switches) {
438 for (const Edge* e : node->out_edges()) {
439 if (e->IsControlEdge()) continue;
440 int branch_index = e->src_output();
441 Node* src_copy = cond_arg_node.branch_copy[branch_index];
442 Node* dst_copy = node_maps_[branch_index][e->dst()->id()];
443
444 // The graph may contain dead switch nodes,
445 if (dst_copy == nullptr) continue;
446
447 TF_RET_CHECK(dst_copy != nullptr)
448 << "Unable to find copied node for " << e->dst()->DebugString()
449 << " on branch " << Branch_Name(BranchType(branch_index));
450 // If the input goes directly to a merge then the merge has
451 // been replaced by a retval so the dst input is 0 instead of
452 // dst_input.
453 int dst_input = IsMerge(e->dst()) ? 0 : e->dst_input();
454 bodies_[branch_index]->AddEdge(src_copy, 0, dst_copy, dst_input);
455 }
456 }
457 ++arg_count;
458 }
459
460 // Verify that all retvals have an input.
461 // TODO(jpienaar): One could add a ZerosLike in the branch that doesn't have
462 // input.
463 for (Node* m : merges_) {
464 for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
465 bool has_input = false;
466 for (auto e : node_maps_[static_cast<int>(branch)][m->id()]->in_edges()) {
467 if (!e->IsControlEdge()) {
468 has_input = true;
469 break;
470 }
471 }
472 if (!has_input) {
473 return errors::Internal(
474 "Failed to functionalize control flow with merge ",
475 FormatNodeForError(*m), " that doesn't have input on ",
476 Branch_Name(branch), " branch.");
477 }
478 }
479 }
480
481 return Status::OK();
482 }
483
// Inserts a Switch (keyed on this conditional's predicate) along `edge` and
// rewires the edge's destination to read the requested `branch` output.
Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
                                           Graph* graph) {
  // Previously we had edge:
  //   src:src_output ---- edge ----> dst:dst_input
  // post this we have (in graph)
  //   src:src_output --> switch<pred> --- new_edge --> dst:dst_input

  // TODO(jpienaar): One could keep a map caching the extra switch nodes added
  // to avoid adding another switch to feed a value for which a switch was
  // already added.
  Node* switch_node;
  Node* src = edge->src();
  int src_output = edge->src_output();
  TF_RETURN_IF_ERROR(
      NodeBuilder(graph->NewName(absl::StrCat(src->name(), "_added_switch")),
                  "Switch")
          .Input(src, src_output)
          .Input(const_cast<Node*>(predicate_.node), predicate_.index)
          .Finalize(graph, &switch_node));
  // The new switch inherits the cond/ancestor state of its data input.
  state_map_->ResetCondId(switch_node, state_map_->LookupCondId(src));
  state_map_->ResetAncestorId(switch_node, state_map_->LookupAncestorId(src));

  // Read dst/dst_input BEFORE RemoveEdge, which invalidates `edge`.
  Node* dst = edge->dst();
  int dst_input = edge->dst_input();
  graph->RemoveEdge(edge);
  // The switch output port number encodes the branch taken.
  graph->AddEdge(switch_node, static_cast<int>(branch), dst, dst_input);
  return AddSwitch(switch_node);
}
512
// Extracts the then/else branch bodies into the two graphs in `bodies_` by
// walking backwards from the data inputs of the merge nodes, copying each
// in-branch node and edge, and finally replacing merges with _Retval nodes.
Status Conditional::ExtractBodies(Graph* graph) {
  VLOG(2) << "Extracting bodies for " << name();
  for (auto b : {BranchType::kElseBranch, BranchType::kThenBranch}) {
    bodies_[static_cast<int>(b)] =
        absl::make_unique<Graph>(graph->op_registry());
  }

  // Returns the branch an edge belongs to: for a switch source the output
  // port encodes the branch directly; otherwise consult the source's cond
  // state for this conditional's predicate.
  auto find_branch = [&](const Edge* e) {
    const auto& id = state_map_->LookupCondId(e->src());
    return IsSwitch(e->src()) ? BranchType(e->src_output())
                              : state_map_->FindBranchOf(id, predicate_);
  };

  // Seed per-branch worklists with the data inputs of the merge nodes.
  std::array<std::vector<Node*>, 2> stacks;
  VLOG(5) << "Merges: " << NodesToString(merges_);
  for (Node* m : merges_) {
    VLOG(5) << "For merge: " << m->DebugString() << " "
            << state_map_->CondStateToString(m);
    for (auto e : m->in_edges()) {
      if (e->IsControlEdge()) continue;
      BranchType branch = find_branch(e);
      TF_RET_CHECK(branch == BranchType::kThenBranch ||
                   branch == BranchType::kElseBranch)
          << "Error: " << e->src()->name()
          << " is not on either then or else branch (" << Branch_Name(branch)
          << ") for predicate " << DebugString(predicate_) << " ["
          << DebugString(state_map_->LookupCondId(e->src())) << "].";
      Node* src = e->src();
      if (IsSwitch(src)) {
        // Switch node outputs and dependencies are handled separately.
        TF_RETURN_IF_ERROR(AddSwitch(src));
      } else {
        stacks[static_cast<int>(branch)].push_back(src);
      }
    }
  }

  for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
    int branch_index = static_cast<int>(branch);
    auto output = bodies_[branch_index].get();
    auto& stack = stacks[branch_index];
    VLOG(5) << "In branch: " << Branch_Name(branch) << " "
            << NodesToString(stack);
    std::vector<bool> visited(graph->num_node_ids(), false);
    node_maps_[branch_index].resize(graph->num_node_ids(), nullptr);
    auto& node_map = node_maps_[branch_index];

    // DFS backwards from the merge inputs, copying this branch's nodes into
    // `output` and recording the original->copy mapping in `node_map`.
    while (!stack.empty()) {
      Node* n = stack.back();
      stack.pop_back();

      if (visited.at(n->id())) continue;
      visited[n->id()] = true;

      // Verify output edges and record control edges exiting scope.
      for (const Edge* e : n->out_edges()) {
        Node* dst = e->dst();
        if (IsMerge(dst)) continue;
        Node* src = e->src();

        auto dst_id = state_map_->LookupCondId(dst);
        auto src_id = state_map_->LookupCondId(src);
        if (dst_id != src_id) {
          if (e->IsControlEdge()) {
            external_control_outputs_.push_back(e->src());
          } else {
            // Constants are treated specially to workaround the case of
            // non-dominated constant nodes.
            if (!IsConstant(src)) {
              // TODO(b/78882471): A node that feeds into two different
              // CondState is not necessarily an error so log a warning for now
              // but revisit to improve the testing to enable making this an
              // error.
              LOG(WARNING) << errors::InvalidArgument(
                  "Graph contains node ", FormatNodeForError(*src),
                  " that feeds into node ", FormatNodeForError(*dst),
                  " but these nodes are in different control contexts (",
                  DebugString(src_id), " vs ", DebugString(dst_id),
                  " (detected during out edge testing)");
            }
          }
        }
      }

      // Copying incoming edges to dst node. Iterate over a copy of the edges
      // as they could be mutated during iteration.
      std::vector<const Edge*> in_edges(n->in_edges().begin(),
                                        n->in_edges().end());
      for (const Edge* e : in_edges) {
        Node* src = e->src();
        // Skip src/dst node.
        if (!src->IsOp()) continue;

        Node* dst = e->dst();
        if (IsSwitch(src)) {
          // Switch node outputs and dependencies are handled separately.
          TF_RETURN_IF_ERROR(AddSwitch(src));
          continue;
        }

        // Verify input is from the same context.
        auto src_id = state_map_->LookupCondId(src);
        auto dst_id = state_map_->LookupCondId(dst);
        if (IsMerge(dst) || src_id == dst_id) {
          // TODO(jpienaar): The merge case can be more strict.
          if (node_map.at(src->id()) == nullptr) {
            node_map.at(src->id()) = output->CopyNode(src);
            stack.push_back(src);
          }
        } else if (e->IsControlEdge()) {
          // Here we have a control flow edge between src and dst that are not
          // in the same context. This is an external control dependency except
          // for one case: where the only difference between CondId of e->src()
          // and CondId of e->dst() is that e->src() has {PRED, kNeither} and
          // e->dst() has {PRED, kThenBranch/kElseBranch}. This happens in
          // gradients code for tf.cond(), where e->src() is a control pivot
          // node for a branch and e->dst() is a data node in that branch.
          bool is_external_control_input = true;
          if (!state_map_->IsEmpty(src_id) && !state_map_->IsEmpty(dst_id)) {
            std::vector<StateMap::CondState::value_type> diff;
            std::set_symmetric_difference(
                src_id->begin(), src_id->end(), dst_id->begin(), dst_id->end(),
                std::back_inserter(diff), CondStateLess());
            if (diff.size() == 2 && diff[0].first == diff[1].first &&
                (diff[0].second == BranchType::kNeither ||
                 diff[1].second == BranchType::kNeither)) {
              auto src_branch = src_id->find(diff[0].first);
              if (src_branch != src_id->end() &&
                  src_branch->second == BranchType::kNeither) {
                is_external_control_input = false;
              }
            }
          }
          if (is_external_control_input) {
            external_control_inputs_.push_back(src);
          }
        } else {
          // This shouldn't happen, this means we have an external data input
          // not entering via a switch node. Work around this by for
          // * constant nodes copy them;
          // * non-constant nodes, insert a switch along the edge;
          if (IsConstant(src)) {
            node_map.at(src->id()) = output->CopyNode(src);
          } else {
            StateMap::CondState state = *dst_id;
            state.erase(predicate_);
            if (state_map_->GetCondId(state) == src_id) {
              TF_RETURN_IF_ERROR(AddSwitchNodeAlongEdge(e, branch, graph));
              continue;
            } else {
              return errors::InvalidArgument(
                  "Graph contains node ", FormatNodeForError(*src),
                  " that feeds into node ", FormatNodeForError(*dst),
                  " but these nodes are in different control contexts (",
                  DebugString(src_id), " vs ", DebugString(dst_id),
                  " (detected during in edge testing)");
            }
          }
        }

        Node* src_copy = node_map.at(e->src()->id());
        int src_output = e->src_output();
        if (node_map.at(dst->id()) == nullptr) {
          node_map.at(dst->id()) = output->CopyNode(dst);
        }
        Node* dst_copy = node_map.at(e->dst()->id());
        if (e->IsControlEdge()) {
          // Skip control inputs from external context.
          if (src_copy != nullptr) output->AddControlEdge(src_copy, dst_copy);
        } else {
          output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
        }
      }
    }
  }

  // Build return values from the merge nodes.
  int index = 0;
  for (Node* m : merges_) {
    for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
      int branch_index = static_cast<int>(branch);
      auto& node_map = node_maps_[branch_index];
      auto output = bodies_[branch_index].get();
      TF_ASSIGN_OR_RETURN(node_map[m->id()],
                          BuildRetvalNode(output, m->output_type(0), index));
    }
    ++index;

    // Connect the input to the merge_ with the retval, except if it is a
    // Switch node, which is handled separately.
    for (auto e : m->in_edges()) {
      if (e->IsControlEdge()) continue;
      int branch_index = static_cast<int>(find_branch(e));
      auto& node_map = node_maps_[branch_index];
      auto output = bodies_[branch_index].get();
      Node* in = e->src();
      if (!IsSwitch(in)) {
        if (node_map.at(in->id()) == nullptr) {
          node_map[in->id()] = output->CopyNode(in);
        }
        output->AddEdge(node_map[in->id()], e->src_output(),
                        node_map.at(m->id()), 0);
      }
    }
  }
  return Status::OK();
}
720
// Builds the If node: registers each extracted branch body as a function in
// `library`, then constructs the If NodeDef with the predicate as the first
// input followed by the deduplicated data inputs.
Status Conditional::BuildIfNode(Graph* graph,
                                FunctionLibraryDefinition* library) {
  VLOG(2) << "Build cond function for " << name();
  NodeDebugInfo debug_info((*merges_.begin())->def());
  NodeDefBuilder builder(name(), "If", library, &debug_info);
  const string branch_name[] = {"else_branch", "then_branch"};
  for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
    int branch_index = static_cast<int>(branch);
    // Process-wide counter so every extracted function gets a unique name.
    static std::atomic<int64> sequence_num(0LL);
    int64 id = ++sequence_num;

    NameAttrList body_name;
    body_name.set_name(
        absl::StrCat("_functionalize_if_", branch_name[branch_index], "_", id));

    VLOG(3) << "FunctionalizeControlFlow (" << branch_name[branch_index]
            << "): "
            << DumpGraphToFile(
                   "functionalize_cond_body_" + branch_name[branch_index],
                   *bodies_[branch_index], nullptr);

    // Convert the branch body graph to a FunctionDef and register it in the
    // library, then reference it from the If node's branch attribute.
    FunctionDef body_fdef;
    TF_RETURN_IF_ERROR(GraphToFunctionDef(*bodies_[branch_index],
                                          body_name.name(), &body_fdef));
    TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
    builder.Attr(branch_name[branch_index], body_name);
  }

  VLOG(3) << "Build input type";
  std::vector<NodeDefBuilder::NodeOut> inputs;
  DataTypeVector in_arg_types;
  for (auto& kv : cond_arg_nodes_) {
    bool inserted = false;
    for (const Node* arg : kv.switches) {
      const Edge* in_edge;
      TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
      if (in_edge->IsControlEdge()) {
        builder.ControlInput(in_edge->src()->name());
      } else {
        // Add the data input only once even if multiple switches share it.
        if (!inserted) {
          DataType dtype = arg->input_type(0);
          inputs.emplace_back(NodeDefBuilder::NodeOut(
              in_edge->src()->name(), in_edge->src_output(), dtype));
          in_arg_types.push_back(dtype);
          inserted = true;
        }
      }
    }
  }
  builder.Attr("Tin", in_arg_types);

  // Output types come from the merge nodes' first (data) output, in order.
  DataTypeVector out_type;
  for (const Node* merge : merges_) {
    DataType dtype = merge->output_type(0);
    out_type.push_back(dtype);
  }
  builder.Attr("Tout", out_type);
  VLOG(3) << "Build output type: " << DataTypeVectorString(out_type);

  builder.Attr("Tcond", DT_BOOL);
  // Propagate the outside-compilation attribute from the predicate, if set.
  string outside_compilation;
  if (GetNodeAttr(predicate_.node->def(), kXlaOutsideCompilationAttrName,
                  &outside_compilation)
          .ok()) {
    builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation);
  }
  builder.Device(predicate_.node->assigned_device_name());
  // Conditional should be the first input ...
  builder.Input(
      NodeDefBuilder::NodeOut(predicate_.node->name(), predicate_.index,
                              predicate_.node->output_type(predicate_.index)));
  // ... followed by the other inputs.
  builder.Input(inputs);

  VLOG(3) << "Build If node";
  NodeDef if_def;
  TF_RETURN_IF_ERROR(builder.Finalize(&if_def));
  TF_ASSIGN_OR_RETURN(if_node_,
                      parent_->AddIfNode(if_def, *merges_.begin(), predicate_));

  return Status::OK();
}
803
// Wires the inputs of the newly built If node: predicate first, then the
// deduplicated data inputs, then external control inputs.
Status Conditional::AddInputEdges(
    Graph* graph,
    const std::unordered_map<Node*, OutputTensor>& merge_to_replacement) {
  VLOG(2) << "AddInputEdges for " << if_node_->name();
  int index = 0;
  // Add predicate input.
  if (predicate_.node->IsMerge()) {
    // If the predicate is a Merge node, we should not use Merge output as
    // predicate. Instead, we should use the corresponding If output in
    // 'merge_to_replacement'. Otherwise, this Conditional's If node is still
    // connected to the predicate Merge node; and when we call
    // DeleteReachableAndDeadNodes(), the predicate Merge node and this
    // Conditional's If node will be removed.
    auto iter = merge_to_replacement.find(predicate_.node);
    if (iter == merge_to_replacement.end()) {
      return errors::Internal("Cannot find replacement for Merge node ",
                              predicate_.node->name());
    }
    graph->AddEdge(iter->second.node, iter->second.index, if_node_, index++);
  } else {
    graph->AddEdge(const_cast<Node*>(predicate_.node), predicate_.index,
                   if_node_, index++);
  }
  // Add function body inputs, in the same order the "Tin" attr was built.
  for (auto& arg : cond_arg_nodes_) {
    if (arg.src_output == Graph::kControlSlot) {
      graph->AddControlEdge(arg.src, if_node_);
    } else {
      graph->AddEdge(arg.src, arg.src_output, if_node_, index++);
    }
  }
  // Control dependencies from outside the conditional into its body.
  for (Node* n : external_control_inputs_) {
    graph->AddControlEdge(n, if_node_);
  }
  return Status::OK();
}
840
// Redirects all outgoing (data and control) edges of each replaced Merge
// node to the corresponding output of `if_node_`, adds an Identity node
// named after each merge (via AddIdentityNode) so references to the merge's
// name stay valid, and records the replacement tensor for each merge in
// `merge_to_replacement`. External control outputs are rewired last.
Status Conditional::AddOutputEdges(
    Graph* graph,
    std::unordered_map<Node*, OutputTensor>* merge_to_replacement) {
  VLOG(2) << "AddOutputEdges for " << if_node_->name();
  int i = 0;
  for (Node* node : merges_) {
    TF_RETURN_IF_ERROR(parent_->AddIdentityNode(node, if_node_, i));
    // Copy the edge list first: RemoveEdge below mutates node->out_edges().
    std::vector<const Edge*> edges(node->out_edges().begin(),
                                   node->out_edges().end());
    for (const Edge* edge : edges) {
      Node* dst = edge->dst();
      int dst_input = edge->dst_input();
      // Only output 0 of a merge is supported here; higher outputs
      // (presumably the merge's value_index output — confirm against the
      // Merge op definition) cannot be represented on the If node.
      if (edge->src_output() > 0) {
        return errors::Unimplemented("Output of index (", edge->src_output(),
                                     ") of merge node ",
                                     FormatNodeForError(*node));
      }

      bool control_edge = edge->IsControlEdge();
      graph->RemoveEdge(edge);
      if (control_edge) {
        graph->AddControlEdge(if_node_, dst);
      } else {
        graph->AddEdge(if_node_, i, dst, dst_input);
      }
    }

    // Record corresponding output tensor in 'merge_to_replacement'.
    (*merge_to_replacement)[node] = OutputTensor{if_node_, i};

    ++i;
  }
  for (Node* n : external_control_outputs_) {
    graph->AddControlEdge(if_node_, n);
  }

  return Status::OK();
}
879
// Builds the If node for this conditional and splices it into `graph`:
// extracts the two branch bodies, builds their argument nodes, constructs
// the If node, rewires its inputs and outputs, and propagates updated
// CondStates to downstream nodes. Idempotent — returns immediately if this
// conditional has already been replaced.
Status Conditional::BuildAndReplace(
    Graph* graph, FunctionLibraryDefinition* library,
    std::unordered_map<Node*, OutputTensor>* merge_to_replacement) {
  VLOG(1) << "Build If and replace merge nodes "
          << NodesToString(this->merges_);
  if (replaced_) return Status::OK();

  TF_RETURN_IF_ERROR(ExtractBodies(graph));
  TF_RETURN_IF_ERROR(BuildArgumentNodes());

  if (VLOG_IS_ON(3)) {
    // Dump both extracted branch bodies for debugging.
    LOG(INFO) << "Extracted bodies:";
    for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
      int branch_index = static_cast<int>(branch);
      auto output = bodies_[branch_index].get();
      LOG(INFO) << Branch_Name(branch) << ": "
                << DebugString(output->ToGraphDefDebug());
    }
  }

  TF_RETURN_IF_ERROR(BuildIfNode(graph, library));
  TF_RETURN_IF_ERROR(AddInputEdges(graph, *merge_to_replacement));
  TF_RETURN_IF_ERROR(AddOutputEdges(graph, merge_to_replacement));
  TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_));

  // Check that the if_node doesn't feed into itself.
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      CheckNodeNotInCycle(if_node_, graph->num_node_ids()),
      "Converting to If failed.");

  replaced_ = true;
  return Status::OK();
}
913
name() const914 string Conditional::name() const {
915 CHECK(!merges_.empty());
916 return absl::StrCat((*merges_.begin())->name(), "_if");
917 }
918
AddIdentityNode(const Node * replacee,Node * if_node,int port)919 Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node,
920 int port) {
921 Node* id;
922 TF_RETURN_IF_ERROR(NodeBuilder(replacee->name(), "Identity")
923 .Input(if_node, port)
924 .Finalize(graph_, &id));
925 state_map_.ResetCondId(id, state_map_.LookupCondId(if_node));
926 state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node));
927 return Status::OK();
928 }
929
// Adds the If node described by `def` to the graph. The new node inherits
// `replacee`'s CondState with `predicate` removed (the If node itself is no
// longer conditioned on the predicate it tests) and `replacee`'s
// AncestorState unchanged. Returns the newly added node.
StatusOr<Node*> FunctionalizeCond::AddIfNode(const NodeDef& def,
                                             const Node* replacee,
                                             const OutputTensor& predicate) {
  Status status;
  Node* ret = graph_->AddNode(def, &status);
  TF_RETURN_IF_ERROR(status);
  VLOG(1) << "Adding If for " << replacee->name();
  StateMap::CondId id = state_map_.LookupCondId(replacee);
  if (id) {
    // Copy the state and drop the predicate being functionalized.
    StateMap::CondState state = *id;
    state.erase(predicate);
    state_map_.ResetCondId(ret, state_map_.GetCondId(state));
  } else {
    state_map_.ResetCondId(ret, nullptr);
  }

  state_map_.ResetAncestorId(ret, state_map_.LookupAncestorId(replacee));

  return ret;
}
950
// Recomputes the CondState of every node downstream of `replacee` (typically
// a freshly inserted If node). Nodes are revisited in topological order and
// a node's successors are only enqueued when its state actually changed, so
// propagation stops once states reach a fixpoint.
Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) {
  VLOG(2) << "Propagating update state for " << replacee->name() << " "
          << state_map_.CondStateToString(replacee);
  // Redo topological sort as the order could have changed.
  // TODO(jpienaar): The original topological order could also be updated
  // dynamically if needed.
  std::vector<Node*> rev_topo_order;
  GetPostOrder(*graph_, &rev_topo_order);

  // All the outputs of the new node could potentially be updated.
  std::unordered_set<Node*> changed;
  for (auto n : replacee->out_nodes())
    if (n->IsOp()) changed.insert(n);

  // Iterate through the changed/possible changed nodes in topological order.
  for (auto it = rev_topo_order.rbegin();
       it != rev_topo_order.rend() && !changed.empty(); ++it) {
    if (changed.find(*it) != changed.end()) {
      // Update the node state.
      Node* n = *it;
      StateMap::CondId old_state = state_map_.LookupCondId(n);
      state_map_.ResetCondId(n, nullptr);
      TF_RETURN_IF_ERROR(DetermineCondState(n));
      if (state_map_.LookupCondId(n) != old_state) {
        // State changed; successors may need to be recomputed as well.
        for (auto out : n->out_nodes())
          if (out->IsOp()) changed.insert(out);
      }
      changed.erase(n);
    }
  }
  return Status::OK();
}
983
984 // Returns the most restrictive branch of two branches or neither. This is the
985 // meet operator of the BranchType lattice.
MeetBranch(const BranchType & lhs,const BranchType & rhs)986 BranchType MeetBranch(const BranchType& lhs, const BranchType& rhs) {
987 if (lhs == rhs) return lhs;
988 if (lhs == BranchType::kNeither) return rhs;
989 if (rhs == BranchType::kNeither) return lhs;
990 if (lhs == BranchType::kBoth) return rhs;
991 if (rhs == BranchType::kBoth) return lhs;
992 return BranchType::kNeither;
993 }
994
FindBranchOf(CondId id,OutputTensor predicate) const995 BranchType StateMap::FindBranchOf(CondId id, OutputTensor predicate) const {
996 if (IsEmpty(id)) return BranchType::kNeither;
997 const CondState& nodes = *id;
998 auto it = nodes.find(predicate);
999 if (it == nodes.end()) return BranchType::kNeither;
1000 return it->second;
1001 }
1002
// Joins two CondStates for a non-merge node: the result constrains every
// predicate mentioned in either state. Dead states absorb (any dead input
// yields a dead result), empty states are identities, and kNeither yields
// to a concrete branch. Two different concrete branches for the same
// predicate are an error, since a non-merge node cannot be reachable under
// incompatible predicates.
StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesNonMerge(
    StateMap::CondId src, StateMap::CondId dst) {
  VLOG(5) << "Joining src=" << DebugString(src) << " [" << src
          << "] and dst=" << DebugString(dst) << " [" << dst << "]";

  if (state_map_.IsEmpty(dst) || state_map_.IsDead(src)) return src;
  if (state_map_.IsDead(dst) || state_map_.IsEmpty(src)) return dst;

  // Nothing to do if the CondState is the same.
  if (src == dst) return src;

  // Merge 'dst' entries into a copy of 'src'.
  StateMap::CondState both = *src;
  for (const auto& kv : *dst) {
    auto it = both.find(kv.first);
    if (it == both.end()) {
      both.insert(kv);
    } else {
      if (it->second != kv.second) {
        if (it->second == BranchType::kNeither) {
          // BranchType for 'src' is kNeither. Use the BranchType in 'dst'.
          it->second = kv.second;
        } else if (kv.second == BranchType::kNeither) {
          // BranchType for 'dst' is kNeither. Use the BranchType in 'src'.
          // No need to change it->second.
        } else {
          return errors::InvalidArgument(
              "Graph contains node with inputs predicated on incompatible "
              "predicates: ",
              DebugString(src), " and ", DebugString(dst));
        }
      }
    }
  }
  return state_map_.GetCondId(both);
}
1038
// Joins two CondStates arriving at a Merge node and records the merge's
// controlling predicate in merge_to_predicate_. The two input states must
// differ in exactly one predicate, constrained to the then-branch on one
// side and the else-branch on the other; the resulting state is their
// common prefix (set intersection).
StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesMerge(
    Node* merge, StateMap::CondId src, StateMap::CondId dst) {
  // Determine the flow state when joining two states for a merge
  // node. Combining the two states for a merge node is effectively performing a
  // disjunction of the states along the different input edges. For a merge that
  // can be transformed into a If the two inputs paths have to have a predicate
  // on which they differ (e.g., along one edge predicate `p` has to hold while
  // on another it should not). This function first determines this predicate
  // and then the resultant state is the common path between the two inputs
  // followed by s(p, both).
  VLOG(4) << "Joining (for merge) " << DebugString(src) << " and "
          << DebugString(dst);
  if (state_map_.IsEmpty(dst)) return src;
  if (state_map_.IsEmpty(src)) {
    return errors::Internal("Merge node ", merge->name(),
                            " has input that's not in any CondContext.");
  }

  // Dead inputs simply propagate dead state.
  if (state_map_.IsDead(src)) return src;
  if (state_map_.IsDead(dst)) return dst;

  // 'diff' holds the entries present in exactly one of the two states;
  // 'merged' holds the entries common to both.
  std::vector<StateMap::CondState::value_type> diff;
  StateMap::CondState merged;
  std::set_symmetric_difference(src->begin(), src->end(), dst->begin(),
                                dst->end(), std::back_inserter(diff),
                                CondStateLess());
  std::set_intersection(src->begin(), src->end(), dst->begin(), dst->end(),
                        std::inserter(merged, merged.begin()), CondStateLess());

  // Update mapping from merge node to predicate.
  if (diff.size() == 2) {
    // The two differing entries must share the same predicate and take the
    // two concrete (then/else) branches.
    auto pred = diff[0].first;
    bool different_branches = (diff[0].second != diff[1].second) &&
                              (diff[0].second == BranchType::kThenBranch ||
                               diff[0].second == BranchType::kElseBranch) &&
                              (diff[1].second == BranchType::kThenBranch ||
                               diff[1].second == BranchType::kElseBranch);
    if (!(pred == diff[1].first) || !different_branches)
      return errors::InvalidArgument(
          "Unable to determine predicate for merge node");
    merge_to_predicate_[merge] = pred;
  } else {
    return errors::InvalidArgument(
        "Merge of two inputs that differ on more than one predicate ",
        DebugString(src), " and ", DebugString(dst));
  }

  return state_map_.GetCondId(merged);
}
1088
// Returns the CondState that flows along edge `e`. For most edges this is
// simply the source's state; an edge leaving a Switch additionally
// constrains the switch's predicate to the branch selected by the edge's
// source output (or kNeither for control edges). Dead states propagate
// unchanged.
StateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) {
  Node* src = e->src();
  StateMap::CondId id = state_map_.LookupCondId(e->src());

  // Dead nodes only propagate dead state.
  if (state_map_.IsDead(id)) return id;

  if (IsSwitch(src)) {
    StateMap::CondState state;
    if (id != nullptr) state = *id;
    OutputTensor predicate;
    TF_CHECK_OK(GetSwitchPredicate(*src, &predicate));
    if (e->IsControlEdge()) {
      // In gradients of tf.cond(), in each branch, we have a NoOp node as
      // control pivot. These NoOp nodes have control dependency from Switch
      // node. If we don't record this into CondState, branches might have
      // incorrect CondState (e.g. if the branch only has a Const data node).
      // We set it to kNeither because there is no way to tell whether it's
      // for true branch or false branch. This node's descendants might have
      // other incoming edges with defined BranchType, and we correctly handle
      // merging kNeither with other defined BranchType in StateAlongEdge().
      state[predicate] = BranchType::kNeither;
    } else {
      // Data edges leave a Switch on output 0 (else) or 1 (then); the output
      // index directly encodes the branch.
      state[predicate] = BranchType(e->src_output());
    }
    return state_map_.GetCondId(state);
  }
  return id;
}
1118
// Computes the CondState of Merge node `dst` by joining the states arriving
// along each of its input edges (via JoinCondStatesMerge). Errors out if
// the merge does not have exactly two data inputs, unless its state is
// already dead.
Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) {
  // Only Merge nodes with two inputs are supported, but if this is a redundant
  // merge, then the dead edge may already have been removed (if due to a
  // switch) and so the input count would be incorrect.
  if (state_map_.IsDead(state_map_.LookupCondId(dst))) return Status::OK();

  int data_inputs = 0;
  for (auto e : dst->in_edges()) {
    Node* src = e->src();
    VLOG(5) << "Processing forward flow for merge: " << e->DebugString() << " "
            << state_map_.CondStateToString(src);
    if (!src->IsOp()) continue;
    if (!e->IsControlEdge()) ++data_inputs;

    // Join the state along this edge into dst's accumulated state.
    StateMap::CondId prop = StateAlongEdge(e);
    auto id_or = JoinCondStatesMerge(dst, prop, state_map_.LookupCondId(dst));
    TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
                                    FormatNodeForError(*dst));
    state_map_.ResetCondId(dst, id_or.ValueOrDie());
  }

  // Incomplete Merge nodes are not supported.
  if (data_inputs != 2) {
    return errors::Unimplemented(
        dst->name(), " only has ", data_inputs,
        " inputs, while only merge nodes with two inputs supported.");
  }
  return Status::OK();
}
1148
DetermineCondStateNonMerge(Node * dst)1149 Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) {
1150 // Handle non-merge join.
1151 for (auto e : dst->in_edges()) {
1152 VLOG(4) << "Processing forward flow for: " << e->DebugString() << " "
1153 << state_map_.CondStateToString(dst);
1154 Node* src = e->src();
1155 if (!src->IsOp()) continue;
1156
1157 // Joining the state between the current and propagated state.
1158 StateMap::CondId prop = StateAlongEdge(e);
1159 auto id_or = JoinCondStatesNonMerge(prop, state_map_.LookupCondId(dst));
1160 TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
1161 FormatNodeForError(*dst));
1162 state_map_.ResetCondId(dst, id_or.ValueOrDie());
1163 }
1164 return Status::OK();
1165 }
1166
// Removes a Merge node whose CondState is dead by forwarding its single
// live (non-dead) input directly to all of the merge's consumers. The merge
// node itself is marked dead; its actual deletion happens later in
// DeleteReachableAndDeadNodes.
Status FunctionalizeCond::RemoveRedundantMerge(Node* node) {
  // Handle redundant merge nodes. A merge node is considered redundant if
  // one input edge is dead while the other has a value.
  if (!state_map_.IsDead(state_map_.LookupCondId(node))) return Status::OK();

  // Find the single live data input.
  const Edge* non_dead_edge = nullptr;
  for (auto e : node->in_edges()) {
    if (e->IsControlEdge()) continue;
    Node* src = e->src();

    // Handle merge with dead state.
    const auto& src_id = state_map_.LookupCondId(src);
    if (!state_map_.IsDead(src_id)) {
      non_dead_edge = e;
      break;
    }
  }

  if (non_dead_edge == nullptr) {
    return errors::InvalidArgument("Merge node ", FormatNodeForError(*node),
                                   " has no non-dead inputs.");
  }
  state_map_.MarkDead(node);
  VLOG(5) << "removing redundant merge: " << node->name();
  // Re-point every outgoing edge at the live input's source. The loop pops
  // one edge per iteration because RemoveEdge mutates node->out_edges().
  while (!node->out_edges().empty()) {
    const Edge* oe = *node->out_edges().begin();
    Node* dst_node = oe->dst();
    int dst_port = oe->dst_input();
    graph_->RemoveEdge(oe);
    // Control consumers keep a control edge; data consumers get the live
    // input's output slot.
    graph_->AddEdge(non_dead_edge->src(),
                    dst_port == Graph::kControlSlot
                        ? Graph::kControlSlot
                        : non_dead_edge->src_output(),
                    dst_node, dst_port);
  }
  return Status::OK();
}
1204
// Removes a Switch node whose outcome is already determined on the current
// branch: consumers on the branch matching the known predicate value get
// the switch's data input directly, while consumers on the other branch are
// marked dead. Control consumers have their CondStates re-joined with the
// switch's state.
Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) {
  // Handle redundant switch nodes. A switch node is considered redundant if
  // the predicate of the switch already holds on the current branch. E.g., if
  // p is the predicate of the switch but p is already known to hold on this
  // branch, then the switch can be removed and the dead state propagated
  // along one. The checking of predicate is based on the exact predicate
  // (rather than boolean equivalence) and aimed at redundant switches as
  // currently generated by gradient code.
  StateMap::CondId dst_id = state_map_.LookupCondId(node);
  if (state_map_.IsDead(dst_id)) return Status::OK();

  BranchType b;
  OutputTensor pred;
  TF_RETURN_IF_ERROR(GetSwitchPredicate(*node, &pred));

  // Determine if we are already on a branch where the switch predicate is
  // true/false. Consider both the data and predicate to determine if the
  // node is redundant (skipping over identity node).
  b = state_map_.FindBranchOf(dst_id, pred);
  if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) {
    // The predicate itself is not constrained; check whether the switch's
    // data input (following any Identity chain) is a known predicate value.
    OutputTensor val;
    const Edge* e;
    TF_RETURN_IF_ERROR(node->input_edge(0, &e));
    val = OutputTensor(e->src(), e->src_output());
    while (IsIdentity(val.node)) {
      TF_RETURN_IF_ERROR(val.node->input_edge(0, &e));
      val = OutputTensor(e->src(), e->src_output());
    }
    b = state_map_.FindBranchOf(dst_id, val);
    if (b != BranchType::kThenBranch && b != BranchType::kElseBranch)
      return Status::OK();
  }

  VLOG(5) << "Redundant switch " << node->name() << " " << Branch_Name(b) << " "
          << DebugString(dst_id);
  const Edge* value_edge;
  TF_RETURN_IF_ERROR(node->input_edge(0, &value_edge));
  Node* val_node = value_edge->src();
  int val_port = value_edge->src_output();
  // Re-point every consumer; RemoveEdge mutates node->out_edges(), so pop
  // one edge per iteration.
  while (!node->out_edges().empty()) {
    auto e = *node->out_edges().begin();
    Node* dst_node = e->dst();
    int dst_input = e->dst_input();
    int switch_branch = e->src_output();
    graph_->RemoveEdge(e);
    if (switch_branch == Graph::kControlSlot) {
      // Control consumers: re-join their state with the switch's state.
      if (IsMerge(dst_node)) {
        auto id_or = JoinCondStatesMerge(dst_node, dst_id,
                                         state_map_.LookupCondId(dst_node));
        TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
                                        FormatNodeForError(*dst_node));
        state_map_.ResetCondId(dst_node, id_or.ValueOrDie());
      } else {
        auto id_or =
            JoinCondStatesNonMerge(dst_id, state_map_.LookupCondId(dst_node));
        TF_RETURN_IF_ERROR(id_or.status());
        state_map_.ResetCondId(dst_node, id_or.ValueOrDie());
      }
    } else if (BranchType(switch_branch) != b) {
      // Consumer hangs off the branch that can never be taken: mark dead and
      // do not re-add an edge to it.
      state_map_.MarkDead(dst_node);
      continue;
    }
    graph_->AddEdge(
        val_node,
        switch_branch == Graph::kControlSlot ? Graph::kControlSlot : val_port,
        dst_node, dst_input);
  }
  return Status::OK();
}
1274
DetermineStates(std::vector<Node * > rev_topo_order)1275 Status FunctionalizeCond::DetermineStates(std::vector<Node*> rev_topo_order) {
1276 // The state that is propagated along the given edge.
1277 for (auto it = rev_topo_order.rbegin(); it != rev_topo_order.rend(); ++it) {
1278 Node* dst = *it;
1279 TF_RETURN_IF_ERROR(DetermineCondState(dst));
1280 TF_RETURN_IF_ERROR(DetermineAncestorState(dst));
1281 if (IsSwitch(dst)) TF_RETURN_IF_ERROR(RemoveRedundantSwitch(dst));
1282 if (IsMerge(dst)) TF_RETURN_IF_ERROR(RemoveRedundantMerge(dst));
1283
1284 VLOG(5) << dst->name() << " :: " << state_map_.CondStateToString(dst)
1285 << " @ " << state_map_.AncestorStateToString(dst);
1286 if (VLOG_IS_ON(10)) DumpGraphWithCondState("it");
1287 }
1288 return Status::OK();
1289 }
1290
// Computes the AncestorState of `dst`: the union over all of dst's inputs
// of (a) each input's own ancestor state and (b) the input itself if it is
// a Merge or Switch (a Switch contributes its predicate tensor when one can
// be extracted, otherwise the switch node itself).
Status FunctionalizeCond::DetermineAncestorState(Node* dst) {
  StateMap::AncestorId id = nullptr;
  StateMap::AncestorState state;

  // Note: `state` is captured by reference and accumulates across calls, so
  // each invocation folds one more input into the running union.
  auto insert = [&](StateMap::AncestorId id, Node* src) {
    auto other_id = state_map_.LookupAncestorId(src);
    if (other_id != id && other_id != nullptr) {
      state.insert(other_id->begin(), other_id->end());
    }
    if (IsMerge(src)) {
      state.insert({{src, 0}, AncestorNode::AncestorNodeType::kMerge});
    } else if (IsSwitch(src)) {
      OutputTensor pred;
      // For dead switch nodes, GetSwitchPredicate() will fail, and we use
      // the switch node directly as ancestor.
      if (GetSwitchPredicate(*src, &pred).ok()) {
        state.insert({pred, AncestorNode::AncestorNodeType::kPred});
      } else {
        state.insert({{src, 0}, AncestorNode::AncestorNodeType::kSwitch});
      }
    }
    return state_map_.GetAncestorId(state);
  };

  // Compute the union of all the switch/merge nodes that affects the input of
  // dst.
  for (auto e : dst->in_edges()) {
    Node* src = e->src();
    id = insert(id, src);
  }
  state_map_.ResetAncestorId(dst, id);
  return Status::OK();
}
1324
// Deletes (a) all recorded Switch nodes and the transformed Merge nodes in
// `merge_order`, (b) every node reachable through the data outputs of those
// nodes, and (c) every node whose CondState is dead — via a BFS over
// data-output edges. Control-edge consumers of switches/merges are kept, as
// they are executed unconditionally.
void FunctionalizeCond::DeleteReachableAndDeadNodes(
    const std::vector<Node*>& merge_order) {
  // Delete all nodes that have been extracted or are reachable from
  // deleted/dead nodes. The input and outgoing edges should have already been
  // removed.
  std::deque<int> delete_nodes;
  std::vector<bool> deleted(graph_->num_node_ids(), false);
  // Don't try to delete source or sink nodes.
  deleted[graph_->kSourceId] = true;
  deleted[graph_->kSinkId] = true;

  // All remaining Switch nodes are not reachable from a Merge node and
  // removed. This is to account for dead Switch nodes.
  for (int s_id : switch_ids_) {
    Node* s = graph_->FindNodeId(s_id);
    if (s == nullptr) continue;
    for (const Edge* e : s->out_edges()) {
      // Control outputs of switch nodes (which are unconditionally executed if
      // the switch is) are not removed as they need not be part of a
      // conditional.
      if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
    }
    deleted[s_id] = true;
    graph_->RemoveNode(s);
  }

  // All merge nodes should have been transformed at this point and we remove
  // them from the graph here.
  for (Node* m : merge_order) {
    for (const Edge* e : m->out_edges()) {
      // Similar to control outputs of switch nodes don't remove control
      // outputs of merge nodes.
      // TODO(jpienaar): Check cases where output edges still exist here vs
      // being removed in AddOutputEdges.
      if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
    }
    deleted[m->id()] = true;
    graph_->RemoveNode(m);
  }

  // Enqueue all the dead nodes.
  for (Node* n : graph_->nodes()) {
    if (state_map_.IsDead(state_map_.LookupCondId(n))) {
      delete_nodes.push_back(n->id());
    }
  }

  // BFS: delete each enqueued node and enqueue everything it feeds.
  while (!delete_nodes.empty()) {
    int d_id = delete_nodes.front();
    delete_nodes.pop_front();
    if (deleted[d_id]) continue;
    Node* d = graph_->FindNodeId(d_id);
    // Switch and Merge nodes could have been deleted already.
    if (d == nullptr) continue;
    for (const Edge* e : d->out_edges()) {
      delete_nodes.push_back(e->dst()->id());
    }
    deleted[d_id] = true;
    graph_->RemoveNode(d);
  }
}
1386
SortMergeNodes(std::vector<Node * > * merge_order)1387 void FunctionalizeCond::SortMergeNodes(std::vector<Node*>* merge_order) {
1388 // Sort merge nodes by nesting depth.
1389 using sort_pair = std::pair<int, Node*>;
1390 std::vector<sort_pair> inner_to_outer_merge_order;
1391 inner_to_outer_merge_order.reserve(merge_order->size());
1392 for (auto it = merge_order->rbegin(); it != merge_order->rend(); ++it) {
1393 Node* merge = *it;
1394 StateMap::CondId id = state_map_.LookupCondId(merge);
1395 int depth = id != nullptr ? id->size() : 0;
1396 inner_to_outer_merge_order.emplace_back(depth, merge);
1397 }
1398 std::stable_sort(
1399 inner_to_outer_merge_order.begin(), inner_to_outer_merge_order.end(),
1400 [](sort_pair lhs, sort_pair rhs) { return lhs.first > rhs.first; });
1401 merge_order->clear();
1402 for (sort_pair t : inner_to_outer_merge_order) {
1403 merge_order->push_back(t.second);
1404 }
1405 }
1406
// Top-level driver: converts all Switch/Merge conditionals in graph_ into
// functional If nodes, from the innermost nesting outwards, then cleans up
// the now-unreachable control-flow nodes.
Status FunctionalizeCond::FunctionalizeInternal() {
  // The general approach for converting a tf.cond (as lowered via switch/merge
  // nodes) to a functional if is as follows:
  // 1. Determine the topological order and collect all the switch and merge
  // nodes in the graph;
  // 2. Compute the predicates and dominance structure for all the nodes in the
  // graph - this includes which predicate must be true for a op to execute
  // (predicate values are considered directly rather than attempting to
  // determine deeper equivalence). We shall refer to this structure as the
  // CondState;
  // 3. Sort the merge nodes by nesting depth;
  // 4. Extract merge nodes together that have the same CondState and
  // AncestorState from the innermost to the outermost into IfOps;
  // Note: In the above only nodes that feed into a merge node will be
  // considered for functionalization.

  // Perform a DFS over the graph and
  // * Determine the reverse topological order of the nodes (there should be no
  //   cycles at this point so the post-order numbering corresponds to the
  //   reverse topological sorting);
  // * Record reverse topological for merge and switch nodes;
  std::vector<Node*> rev_topo_order;
  std::vector<Node*> merge_order;
  DFS(*graph_, nullptr, [&](Node* n) {
    if (IsSwitch(n)) {
      AddSwitchId(n->id());
    }
    if (IsMerge(n)) {
      merge_order.push_back(n);
    }
    if (n->IsOp()) {
      rev_topo_order.push_back(n);
    }
  });

  // No merges to functionalize.
  if (merge_order.empty()) {
    // No merges mean no switch values consumed (as only considering values
    // fetchable as output of merge);
    DeleteReachableAndDeadNodes(merge_order);
    return Status::OK();
  }

  TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order)));
  if (VLOG_IS_ON(4)) DumpGraphWithCondState("id");

  // Sort the merge nodes from innermost outwards.
  SortMergeNodes(&merge_order);

  // Cluster merge nodes by (CondId, AncestorId, predicate) in order of
  // nesting. (CondId, AncestorId) is not enough, e.g.
  //   pred1 = array_ops.placeholder(dtypes.bool, name='pred1')
  //   pred2 = array_ops.placeholder(dtypes.bool, name='pred2')
  //   cond1 = control_flow_ops.cond(pred1, ...)
  //   cond2 = control_flow_ops.cond(pred2, ...)
  //   cond3 = control_flow_ops.cond(pred1, use cond1 and cond2)
  //   cond4 = control_flow_ops.cond(pred2, use cond1 and cond2)
  // cond3 and cond4 have the same (CondId, AncestorId), but they should not
  // be merged into one "If" node (because they have different predicates).
  std::deque<std::vector<Node*>> merge_clusters;
  std::map<ClusterTuple, int, ClusterTupleLessThan> merge_cluster_index;
  for (Node* merge : merge_order) {
    auto cond_id = state_map_.LookupCondId(merge);
    // Dead merges were already handled by RemoveRedundantMerge.
    if (state_map_.IsDead(cond_id)) continue;

    auto predicate = merge_to_predicate_.find(merge);
    if (predicate == merge_to_predicate_.end()) {
      return errors::Internal("Cannot find predicate for Merge node ",
                              merge->name());
    }

    ClusterTuple key = std::make_tuple(
        cond_id, state_map_.LookupAncestorId(merge), predicate->second);
    auto idx = merge_cluster_index.find(key);
    if (idx == merge_cluster_index.end()) {
      merge_cluster_index[key] = merge_clusters.size();
      merge_clusters.push_back({merge});
    } else {
      merge_clusters[idx->second].emplace_back(merge);
    }
  }

  // Extract the conditionals from inner most to outer most. Extracting from
  // innermost to outermost enables the extraction pass to stop once it
  // encounters a Switch node instead of having to keep track of Switch/Merge
  // nodes seen.
  for (const auto& cluster : merge_clusters) {
    // Construct a Conditional with the predicate of the merge.
    Conditional cond(merge_to_predicate_.at(cluster.front()), this,
                     &state_map_);
    for (Node* merge : cluster) TF_RETURN_IF_ERROR(cond.AddMerge(merge));
    TF_RETURN_IF_ERROR(
        cond.BuildAndReplace(graph_, library_, &merge_to_replacement_));

    if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract");
  }

  DeleteReachableAndDeadNodes(merge_order);

  return Status::OK();
}
1508
DumpGraphWithCondState(const string & name)1509 void FunctionalizeCond::DumpGraphWithCondState(const string& name) {
1510 const char* const kCondGroupDebugAttr = "_XlaFunctionalizeCondGroup";
1511
1512 for (Node* n : graph_->nodes()) {
1513 n->ClearAttr(kCondGroupDebugAttr);
1514 n->AddAttr(kCondGroupDebugAttr,
1515 absl::StrCat(state_map_.CondStateToString(n), "_",
1516 state_map_.AncestorStateToString(n)));
1517 }
1518 LOG(INFO) << "FunctionalizeControlFlow (" << name << "): "
1519 << DumpGraphToFile(absl::StrCat("functionalize_cond_", name),
1520 *graph_, library_);
1521 }
1522
AddSwitchId(int switch_id)1523 void FunctionalizeCond::AddSwitchId(int switch_id) {
1524 switch_ids_.push_back(switch_id);
1525 }
1526
Functionalize(Graph * graph,FunctionLibraryDefinition * library)1527 Status FunctionalizeCond::Functionalize(Graph* graph,
1528 FunctionLibraryDefinition* library) {
1529 VLOG(1) << "FunctionalizeCond::Functionalize";
1530 FunctionalizeCond fc(graph, library);
1531 return fc.FunctionalizeInternal();
1532 }
1533
1534 } // namespace functionalize_cond
1535
// Public entry point (tensorflow namespace): functionalizes all tf.cond
// constructs in `graph` by delegating to the functionalize_cond pass.
Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) {
  // FunctionalizeControlFlow is invoked for every function, so the loops'
  // bodies and conditionals that were extracted into functions will be handled
  // in successive invocations.
  return functionalize_cond::FunctionalizeCond::Functionalize(graph, library);
}
1542
1543 } // namespace tensorflow
1544